blob: 7da08c2de6c64cbc4431762475edaf12ae021cbc [file] [log] [blame]
//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
17
18#include "mlir/IR/MLFunction.h"
19#include "mlir/IR/Statements.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070020#include "mlir/IR/StmtVisitor.h"
Tatiana Shpeisman3838db72018-07-30 15:18:10 -070021#include "mlir/IR/Types.h"
Chris Lattnere787b322018-08-08 11:14:57 -070022#include "llvm/ADT/DenseMap.h"
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070023using namespace mlir;
24
25//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070026// StmtResult
27//===------------------------------------------------------------------===//
28
29/// Return the result number of this result.
30unsigned StmtResult::getResultNumber() const {
31 // Results are always stored consecutively, so use pointer subtraction to
32 // figure out what number this is.
33 return this - &getOwner()->getStmtResults()[0];
34}
35
36//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070037// Statement
38//===------------------------------------------------------------------===//
39
40// Statements are deleted through the destroy() member because we don't have
41// a virtual destructor.
42Statement::~Statement() {
43 assert(block == nullptr && "statement destroyed but still in a block");
44}
45
46/// Destroy this statement or one of its subclasses.
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070047void Statement::destroy() {
48 switch (this->getKind()) {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070049 case Kind::Operation:
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070050 cast<OperationStmt>(this)->destroy();
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070051 break;
52 case Kind::For:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070053 delete cast<ForStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070054 break;
55 case Kind::If:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070056 delete cast<IfStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070057 break;
58 }
59}
60
Tatiana Shpeismanc335d182018-08-03 11:12:34 -070061Statement *Statement::getParentStmt() const {
62 return block ? block->getParentStmt() : nullptr;
63}
Tatiana Shpeismand880b352018-07-31 23:14:16 -070064
65MLFunction *Statement::findFunction() const {
Tatiana Shpeismanc335d182018-08-03 11:12:34 -070066 return block ? block->findFunction() : nullptr;
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070067}
68
Uday Bondhugula081d9e72018-07-27 10:58:14 -070069bool Statement::isInnermost() const {
70 struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070071 unsigned numNestedLoops;
72 NestedLoopCounter() : numNestedLoops(0) {}
Uday Bondhugula081d9e72018-07-27 10:58:14 -070073 void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070074 };
75
76 NestedLoopCounter nlc;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070077 nlc.walk(const_cast<Statement *>(this));
78 return nlc.numNestedLoops == 1;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070079}
80
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070081//===----------------------------------------------------------------------===//
82// ilist_traits for Statement
83//===----------------------------------------------------------------------===//
84
85StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
86 size_t Offset(
87 size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
88 iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
89 return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
90 Offset);
91}
92
93/// This is a trait method invoked when a statement is added to a block. We
94/// keep the block pointer up to date.
95void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
96 assert(!stmt->getBlock() && "already in a statement block!");
97 stmt->block = getContainingBlock();
98}
99
100/// This is a trait method invoked when a statement is removed from a block.
101/// We keep the block pointer up to date.
102void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
103 Statement *stmt) {
104 assert(stmt->block && "not already in a statement block!");
105 stmt->block = nullptr;
106}
107
108/// This is a trait method invoked when a statement is moved from one block
109/// to another. We keep the block pointer up to date.
110void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
111 ilist_traits<Statement> &otherList, stmt_iterator first,
112 stmt_iterator last) {
113 // If we are transferring statements within the same block, the block
114 // pointer doesn't need to be updated.
115 StmtBlock *curParent = getContainingBlock();
116 if (curParent == otherList.getContainingBlock())
117 return;
118
119 // Update the 'block' member of each statement.
120 for (; first != last; ++first)
121 first->block = curParent;
122}
123
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700124/// Remove this statement (and its descendants) from its StmtBlock and delete
125/// all of them.
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700126void Statement::eraseFromBlock() {
127 assert(getBlock() && "Statement has no block");
128 getBlock()->getStatements().erase(this);
129}
130
131//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700132// OperationStmt
133//===----------------------------------------------------------------------===//
134
135/// Create a new OperationStmt with the specific fields.
136OperationStmt *OperationStmt::create(Identifier name,
137 ArrayRef<MLValue *> operands,
138 ArrayRef<Type *> resultTypes,
139 ArrayRef<NamedAttribute> attributes,
140 MLIRContext *context) {
141 auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
142 resultTypes.size());
143 void *rawMem = malloc(byteSize);
144
145 // Initialize the OperationStmt part of the statement.
146 auto stmt = ::new (rawMem) OperationStmt(
147 name, operands.size(), resultTypes.size(), attributes, context);
148
149 // Initialize the operands and results.
150 auto stmtOperands = stmt->getStmtOperands();
151 for (unsigned i = 0, e = operands.size(); i != e; ++i)
152 new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
153
154 auto stmtResults = stmt->getStmtResults();
155 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
156 new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
157 return stmt;
158}
159
160OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
161 unsigned numResults,
162 ArrayRef<NamedAttribute> attributes,
163 MLIRContext *context)
164 : Operation(name, /*isInstruction=*/false, attributes, context),
165 Statement(Kind::Operation), numOperands(numOperands),
166 numResults(numResults) {}
167
168OperationStmt::~OperationStmt() {
169 // Explicitly run the destructors for the operands and results.
170 for (auto &operand : getStmtOperands())
171 operand.~StmtOperand();
172
173 for (auto &result : getStmtResults())
174 result.~StmtResult();
175}
176
177void OperationStmt::destroy() {
178 this->~OperationStmt();
179 free(this);
180}
181
Chris Lattner95865062018-08-01 10:18:59 -0700182/// Return the context this operation is associated with.
183MLIRContext *OperationStmt::getContext() const {
184 // If we have a result or operand type, that is a constant time way to get
185 // to the context.
186 if (getNumResults())
187 return getResult(0)->getType()->getContext();
188 if (getNumOperands())
189 return getOperand(0)->getType()->getContext();
190
191 // In the very odd case where we have no operands or results, fall back to
192 // doing a find.
193 return findFunction()->getContext();
194}
195
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700196//===----------------------------------------------------------------------===//
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700197// ForStmt
198//===----------------------------------------------------------------------===//
199
200ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
Uday Bondhugula67701712018-08-21 16:01:23 -0700201 int64_t step, MLIRContext *context)
Uday Bondhugula15984952018-08-01 22:36:12 -0700202 : Statement(Kind::For),
Tatiana Shpeismanc9c4b342018-07-31 07:40:14 -0700203 MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
Uday Bondhugula15984952018-08-01 22:36:12 -0700204 StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
205 upperBound(upperBound), step(step) {}
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700206
207//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700208// IfStmt
209//===----------------------------------------------------------------------===//
210
211IfStmt::~IfStmt() {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700212 delete thenClause;
Uday Bondhugula15984952018-08-01 22:36:12 -0700213 if (elseClause)
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700214 delete elseClause;
Uday Bondhugulabc535622018-08-07 14:24:38 -0700215 // An IfStmt's IntegerSet 'condition' should not be deleted since it is
216 // allocated through MLIRContext's bump pointer allocator.
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700217}
Uday Bondhugula84b80952018-08-03 13:22:26 -0700218
Chris Lattnere787b322018-08-08 11:14:57 -0700219//===----------------------------------------------------------------------===//
220// Statement Cloning
221//===----------------------------------------------------------------------===//
222
223/// Create a deep copy of this statement, remapping any operands that use
224/// values outside of the statement using the map that is provided (leaving
225/// them alone if no entry is present). Replaces references to cloned
226/// sub-statements to the corresponding statement that is copied, and adds
227/// those mappings to the map.
228Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
229 MLIRContext *context) const {
230 // If the specified value is in operandMap, return the remapped value.
231 // Otherwise return the value itself.
232 auto remapOperand = [&](const MLValue *value) -> MLValue * {
233 auto it = operandMap.find(value);
234 return it != operandMap.end() ? it->second : const_cast<MLValue *>(value);
235 };
236
237 if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
238 SmallVector<MLValue *, 8> operands;
239 operands.reserve(opStmt->getNumOperands());
240 for (auto *opValue : opStmt->getOperands())
241 operands.push_back(remapOperand(opValue));
242
243 SmallVector<Type *, 8> resultTypes;
244 resultTypes.reserve(opStmt->getNumResults());
245 for (auto *result : opStmt->getResults())
246 resultTypes.push_back(result->getType());
247 auto *newOp = OperationStmt::create(
248 opStmt->getName(), operands, resultTypes, opStmt->getAttrs(), context);
249 // Remember the mapping of any results.
250 for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
251 operandMap[opStmt->getResult(i)] = newOp->getResult(i);
252 return newOp;
253 }
254
255 if (auto *forStmt = dyn_cast<ForStmt>(this)) {
256 auto *newFor =
257 new ForStmt(forStmt->getLowerBound(), forStmt->getUpperBound(),
258 forStmt->getStep(), context);
259 // Remember the induction variable mapping.
260 operandMap[forStmt] = newFor;
261
262 // TODO: remap operands in loop bounds when they are added.
263 // Recursively clone the body of the for loop.
264 for (auto &subStmt : *forStmt)
265 newFor->push_back(subStmt.clone(operandMap, context));
266
267 return newFor;
268 }
269
270 // Otherwise, we must have an If statement.
271 auto *ifStmt = cast<IfStmt>(this);
272 auto *newIf = new IfStmt(ifStmt->getCondition());
273
274 // TODO: remap operands with remapOperand when if statements have them.
275
276 auto *resultThen = newIf->getThen();
277 for (auto &childStmt : *ifStmt->getThen())
278 resultThen->push_back(childStmt.clone(operandMap, context));
279
280 if (ifStmt->hasElse()) {
281 auto *resultElse = newIf->createElse();
282 for (auto &childStmt : *ifStmt->getElse())
283 resultElse->push_back(childStmt.clone(operandMap, context));
284 }
285
286 return newIf;
Uday Bondhugula84b80952018-08-03 13:22:26 -0700287}