blob: 8a6248ea34a6b20a82c016cee3458ae9cfd72956 [file] [log] [blame]
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -07001//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/MLFunction.h"
Chris Lattner1628fa02018-08-23 14:32:25 -070019#include "mlir/IR/MLIRContext.h"
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070020#include "mlir/IR/Statements.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070021#include "mlir/IR/StmtVisitor.h"
Tatiana Shpeisman3838db72018-07-30 15:18:10 -070022#include "mlir/IR/Types.h"
Chris Lattnere787b322018-08-08 11:14:57 -070023#include "llvm/ADT/DenseMap.h"
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070024using namespace mlir;
25
26//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070027// StmtResult
28//===------------------------------------------------------------------===//
29
30/// Return the result number of this result.
31unsigned StmtResult::getResultNumber() const {
32 // Results are always stored consecutively, so use pointer subtraction to
33 // figure out what number this is.
34 return this - &getOwner()->getStmtResults()[0];
35}
36
37//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070038// Statement
39//===------------------------------------------------------------------===//
40
41// Statements are deleted through the destroy() member because we don't have
42// a virtual destructor.
43Statement::~Statement() {
44 assert(block == nullptr && "statement destroyed but still in a block");
45}
46
47/// Destroy this statement or one of its subclasses.
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070048void Statement::destroy() {
49 switch (this->getKind()) {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070050 case Kind::Operation:
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070051 cast<OperationStmt>(this)->destroy();
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070052 break;
53 case Kind::For:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070054 delete cast<ForStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070055 break;
56 case Kind::If:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070057 delete cast<IfStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070058 break;
59 }
60}
61
Chris Lattner1628fa02018-08-23 14:32:25 -070062/// Return the context this operation is associated with.
63MLIRContext *Statement::getContext() const {
64 // Work a bit to avoid calling findFunction() and getting its context.
65 switch (getKind()) {
66 case Kind::Operation:
67 return cast<OperationStmt>(this)->getContext();
68 case Kind::For:
69 return cast<ForStmt>(this)->getType()->getContext();
70 case Kind::If:
71 // TODO(shpeisman): When if statement has value operands, we can get a
72 // context from their type.
73 return findFunction()->getContext();
74 }
75}
76
Tatiana Shpeismanc335d182018-08-03 11:12:34 -070077Statement *Statement::getParentStmt() const {
78 return block ? block->getParentStmt() : nullptr;
79}
Tatiana Shpeismand880b352018-07-31 23:14:16 -070080
81MLFunction *Statement::findFunction() const {
Tatiana Shpeismanc335d182018-08-03 11:12:34 -070082 return block ? block->findFunction() : nullptr;
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070083}
84
Uday Bondhugula081d9e72018-07-27 10:58:14 -070085bool Statement::isInnermost() const {
86 struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070087 unsigned numNestedLoops;
88 NestedLoopCounter() : numNestedLoops(0) {}
Uday Bondhugula081d9e72018-07-27 10:58:14 -070089 void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070090 };
91
92 NestedLoopCounter nlc;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070093 nlc.walk(const_cast<Statement *>(this));
94 return nlc.numNestedLoops == 1;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070095}
96
Chris Lattner1628fa02018-08-23 14:32:25 -070097/// Emit a note about this statement, reporting up to any diagnostic
98/// handlers that may be listening.
99void Statement::emitNote(const Twine &message) const {
100 getContext()->emitDiagnostic(getLoc(), message,
101 MLIRContext::DiagnosticKind::Note);
102}
103
104/// Emit a warning about this statement, reporting up to any diagnostic
105/// handlers that may be listening.
106void Statement::emitWarning(const Twine &message) const {
107 getContext()->emitDiagnostic(getLoc(), message,
108 MLIRContext::DiagnosticKind::Warning);
109}
110
111/// Emit an error about fatal conditions with this statement, reporting up to
112/// any diagnostic handlers that may be listening. NOTE: This may terminate
113/// the containing application, only use when the IR is in an inconsistent
114/// state.
115void Statement::emitError(const Twine &message) const {
116 getContext()->emitDiagnostic(getLoc(), message,
117 MLIRContext::DiagnosticKind::Error);
118}
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700119//===----------------------------------------------------------------------===//
120// ilist_traits for Statement
121//===----------------------------------------------------------------------===//
122
123StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
124 size_t Offset(
125 size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
126 iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
127 return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
128 Offset);
129}
130
131/// This is a trait method invoked when a statement is added to a block. We
132/// keep the block pointer up to date.
133void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
134 assert(!stmt->getBlock() && "already in a statement block!");
135 stmt->block = getContainingBlock();
136}
137
138/// This is a trait method invoked when a statement is removed from a block.
139/// We keep the block pointer up to date.
140void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
141 Statement *stmt) {
142 assert(stmt->block && "not already in a statement block!");
143 stmt->block = nullptr;
144}
145
146/// This is a trait method invoked when a statement is moved from one block
147/// to another. We keep the block pointer up to date.
148void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
149 ilist_traits<Statement> &otherList, stmt_iterator first,
150 stmt_iterator last) {
151 // If we are transferring statements within the same block, the block
152 // pointer doesn't need to be updated.
153 StmtBlock *curParent = getContainingBlock();
154 if (curParent == otherList.getContainingBlock())
155 return;
156
157 // Update the 'block' member of each statement.
158 for (; first != last; ++first)
159 first->block = curParent;
160}
161
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700162/// Remove this statement (and its descendants) from its StmtBlock and delete
163/// all of them.
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700164void Statement::eraseFromBlock() {
165 assert(getBlock() && "Statement has no block");
166 getBlock()->getStatements().erase(this);
167}
168
169//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700170// OperationStmt
171//===----------------------------------------------------------------------===//
172
173/// Create a new OperationStmt with the specific fields.
Chris Lattner1628fa02018-08-23 14:32:25 -0700174OperationStmt *OperationStmt::create(Attribute *location, Identifier name,
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700175 ArrayRef<MLValue *> operands,
176 ArrayRef<Type *> resultTypes,
177 ArrayRef<NamedAttribute> attributes,
178 MLIRContext *context) {
179 auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
180 resultTypes.size());
181 void *rawMem = malloc(byteSize);
182
183 // Initialize the OperationStmt part of the statement.
184 auto stmt = ::new (rawMem) OperationStmt(
Chris Lattner1628fa02018-08-23 14:32:25 -0700185 location, name, operands.size(), resultTypes.size(), attributes, context);
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700186
187 // Initialize the operands and results.
188 auto stmtOperands = stmt->getStmtOperands();
189 for (unsigned i = 0, e = operands.size(); i != e; ++i)
190 new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
191
192 auto stmtResults = stmt->getStmtResults();
193 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
194 new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
195 return stmt;
196}
197
Chris Lattner1628fa02018-08-23 14:32:25 -0700198OperationStmt::OperationStmt(Attribute *location, Identifier name,
199 unsigned numOperands, unsigned numResults,
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700200 ArrayRef<NamedAttribute> attributes,
201 MLIRContext *context)
Chris Lattner1628fa02018-08-23 14:32:25 -0700202 : Operation(/*isInstruction=*/false, name, attributes, context),
203 Statement(Kind::Operation, location), numOperands(numOperands),
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700204 numResults(numResults) {}
205
206OperationStmt::~OperationStmt() {
207 // Explicitly run the destructors for the operands and results.
208 for (auto &operand : getStmtOperands())
209 operand.~StmtOperand();
210
211 for (auto &result : getStmtResults())
212 result.~StmtResult();
213}
214
215void OperationStmt::destroy() {
216 this->~OperationStmt();
217 free(this);
218}
219
Chris Lattner95865062018-08-01 10:18:59 -0700220/// Return the context this operation is associated with.
221MLIRContext *OperationStmt::getContext() const {
222 // If we have a result or operand type, that is a constant time way to get
223 // to the context.
224 if (getNumResults())
225 return getResult(0)->getType()->getContext();
226 if (getNumOperands())
227 return getOperand(0)->getType()->getContext();
228
229 // In the very odd case where we have no operands or results, fall back to
230 // doing a find.
231 return findFunction()->getContext();
232}
233
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700234//===----------------------------------------------------------------------===//
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700235// ForStmt
236//===----------------------------------------------------------------------===//
237
Chris Lattner1628fa02018-08-23 14:32:25 -0700238ForStmt::ForStmt(Attribute *location, AffineConstantExpr *lowerBound,
239 AffineConstantExpr *upperBound, int64_t step,
240 MLIRContext *context)
241 : Statement(Kind::For, location),
Tatiana Shpeismanc9c4b342018-07-31 07:40:14 -0700242 MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
Uday Bondhugula15984952018-08-01 22:36:12 -0700243 StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
244 upperBound(upperBound), step(step) {}
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700245
246//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700247// IfStmt
248//===----------------------------------------------------------------------===//
249
Chris Lattner1628fa02018-08-23 14:32:25 -0700250IfStmt::IfStmt(Attribute *location, IntegerSet *condition)
251 : Statement(Kind::If, location), thenClause(new IfClause(this)),
252 elseClause(nullptr), condition(condition) {}
253
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700254IfStmt::~IfStmt() {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700255 delete thenClause;
Uday Bondhugula15984952018-08-01 22:36:12 -0700256 if (elseClause)
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700257 delete elseClause;
Uday Bondhugulabc535622018-08-07 14:24:38 -0700258 // An IfStmt's IntegerSet 'condition' should not be deleted since it is
259 // allocated through MLIRContext's bump pointer allocator.
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700260}
Uday Bondhugula84b80952018-08-03 13:22:26 -0700261
Chris Lattnere787b322018-08-08 11:14:57 -0700262//===----------------------------------------------------------------------===//
263// Statement Cloning
264//===----------------------------------------------------------------------===//
265
266/// Create a deep copy of this statement, remapping any operands that use
267/// values outside of the statement using the map that is provided (leaving
268/// them alone if no entry is present). Replaces references to cloned
269/// sub-statements to the corresponding statement that is copied, and adds
270/// those mappings to the map.
271Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
272 MLIRContext *context) const {
273 // If the specified value is in operandMap, return the remapped value.
274 // Otherwise return the value itself.
275 auto remapOperand = [&](const MLValue *value) -> MLValue * {
276 auto it = operandMap.find(value);
277 return it != operandMap.end() ? it->second : const_cast<MLValue *>(value);
278 };
279
280 if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
281 SmallVector<MLValue *, 8> operands;
282 operands.reserve(opStmt->getNumOperands());
283 for (auto *opValue : opStmt->getOperands())
284 operands.push_back(remapOperand(opValue));
285
286 SmallVector<Type *, 8> resultTypes;
287 resultTypes.reserve(opStmt->getNumResults());
288 for (auto *result : opStmt->getResults())
289 resultTypes.push_back(result->getType());
Chris Lattner1628fa02018-08-23 14:32:25 -0700290 auto *newOp =
291 OperationStmt::create(getLoc(), opStmt->getName(), operands,
292 resultTypes, opStmt->getAttrs(), context);
Chris Lattnere787b322018-08-08 11:14:57 -0700293 // Remember the mapping of any results.
294 for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
295 operandMap[opStmt->getResult(i)] = newOp->getResult(i);
296 return newOp;
297 }
298
299 if (auto *forStmt = dyn_cast<ForStmt>(this)) {
300 auto *newFor =
Chris Lattner1628fa02018-08-23 14:32:25 -0700301 new ForStmt(getLoc(), forStmt->getLowerBound(),
302 forStmt->getUpperBound(), forStmt->getStep(), context);
Chris Lattnere787b322018-08-08 11:14:57 -0700303 // Remember the induction variable mapping.
304 operandMap[forStmt] = newFor;
305
306 // TODO: remap operands in loop bounds when they are added.
307 // Recursively clone the body of the for loop.
308 for (auto &subStmt : *forStmt)
309 newFor->push_back(subStmt.clone(operandMap, context));
310
311 return newFor;
312 }
313
314 // Otherwise, we must have an If statement.
315 auto *ifStmt = cast<IfStmt>(this);
Chris Lattner1628fa02018-08-23 14:32:25 -0700316 auto *newIf = new IfStmt(getLoc(), ifStmt->getCondition());
Chris Lattnere787b322018-08-08 11:14:57 -0700317
318 // TODO: remap operands with remapOperand when if statements have them.
319
320 auto *resultThen = newIf->getThen();
321 for (auto &childStmt : *ifStmt->getThen())
322 resultThen->push_back(childStmt.clone(operandMap, context));
323
324 if (ifStmt->hasElse()) {
325 auto *resultElse = newIf->createElse();
326 for (auto &childStmt : *ifStmt->getElse())
327 resultElse->push_back(childStmt.clone(operandMap, context));
328 }
329
330 return newIf;
Uday Bondhugula84b80952018-08-03 13:22:26 -0700331}