blob: ffc210afd7d5c099603ac14dbcb26fce3d0b85c7 [file] [log] [blame]
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -07001//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/MLFunction.h"
19#include "mlir/IR/Statements.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070020#include "mlir/IR/StmtVisitor.h"
Tatiana Shpeisman3838db72018-07-30 15:18:10 -070021#include "mlir/IR/Types.h"
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070022using namespace mlir;
23
24//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070025// StmtResult
26//===------------------------------------------------------------------===//
27
28/// Return the result number of this result.
29unsigned StmtResult::getResultNumber() const {
30 // Results are always stored consecutively, so use pointer subtraction to
31 // figure out what number this is.
32 return this - &getOwner()->getStmtResults()[0];
33}
34
35//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070036// Statement
37//===------------------------------------------------------------------===//
38
39// Statements are deleted through the destroy() member because we don't have
40// a virtual destructor.
41Statement::~Statement() {
42 assert(block == nullptr && "statement destroyed but still in a block");
43}
44
45/// Destroy this statement or one of its subclasses.
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070046void Statement::destroy() {
47 switch (this->getKind()) {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070048 case Kind::Operation:
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070049 cast<OperationStmt>(this)->destroy();
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070050 break;
51 case Kind::For:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070052 delete cast<ForStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070053 break;
54 case Kind::If:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070055 delete cast<IfStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070056 break;
57 }
58}
59
Tatiana Shpeismanc335d182018-08-03 11:12:34 -070060Statement *Statement::getParentStmt() const {
61 return block ? block->getParentStmt() : nullptr;
62}
Tatiana Shpeismand880b352018-07-31 23:14:16 -070063
64MLFunction *Statement::findFunction() const {
Tatiana Shpeismanc335d182018-08-03 11:12:34 -070065 return block ? block->findFunction() : nullptr;
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070066}
67
Uday Bondhugula081d9e72018-07-27 10:58:14 -070068bool Statement::isInnermost() const {
69 struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070070 unsigned numNestedLoops;
71 NestedLoopCounter() : numNestedLoops(0) {}
Uday Bondhugula081d9e72018-07-27 10:58:14 -070072 void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070073 };
74
75 NestedLoopCounter nlc;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070076 nlc.walk(const_cast<Statement *>(this));
77 return nlc.numNestedLoops == 1;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070078}
79
Uday Bondhugula84b80952018-08-03 13:22:26 -070080Statement *Statement::clone() const {
81 switch (kind) {
82 case Kind::Operation:
83 return cast<OperationStmt>(this)->clone();
84 case Kind::If:
85 llvm_unreachable("cloning for if's not implemented yet");
86 return cast<IfStmt>(this)->clone();
87 case Kind::For:
Uday Bondhugula84b80952018-08-03 13:22:26 -070088 return cast<ForStmt>(this)->clone();
89 }
90}
91
Uday Bondhugula134154e2018-08-06 18:40:34 -070092/// Replaces all uses of oldVal with newVal.
93// TODO(bondhugula,clattner): do this more efficiently by walking those uses of
94// oldVal that fall within this statement.
95void Statement::replaceUses(MLValue *oldVal, MLValue *newVal) {
96 struct ReplaceUseWalker : public StmtWalker<ReplaceUseWalker> {
97 // Value to be replaced.
98 MLValue *oldVal;
99 // Value to be replaced with.
100 MLValue *newVal;
101
102 ReplaceUseWalker(MLValue *oldVal, MLValue *newVal)
103 : oldVal(oldVal), newVal(newVal){};
104
105 void visitOperationStmt(OperationStmt *os) {
106 for (auto &operand : os->getStmtOperands()) {
107 if (operand.get() == oldVal)
108 operand.set(newVal);
109 }
110 }
111 };
112
113 ReplaceUseWalker ri(oldVal, newVal);
114 ri.walk(this);
115}
116
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700117//===----------------------------------------------------------------------===//
118// ilist_traits for Statement
119//===----------------------------------------------------------------------===//
120
121StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
122 size_t Offset(
123 size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
124 iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
125 return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
126 Offset);
127}
128
129/// This is a trait method invoked when a statement is added to a block. We
130/// keep the block pointer up to date.
131void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
132 assert(!stmt->getBlock() && "already in a statement block!");
133 stmt->block = getContainingBlock();
134}
135
136/// This is a trait method invoked when a statement is removed from a block.
137/// We keep the block pointer up to date.
138void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
139 Statement *stmt) {
140 assert(stmt->block && "not already in a statement block!");
141 stmt->block = nullptr;
142}
143
144/// This is a trait method invoked when a statement is moved from one block
145/// to another. We keep the block pointer up to date.
146void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
147 ilist_traits<Statement> &otherList, stmt_iterator first,
148 stmt_iterator last) {
149 // If we are transferring statements within the same block, the block
150 // pointer doesn't need to be updated.
151 StmtBlock *curParent = getContainingBlock();
152 if (curParent == otherList.getContainingBlock())
153 return;
154
155 // Update the 'block' member of each statement.
156 for (; first != last; ++first)
157 first->block = curParent;
158}
159
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700160/// Remove this statement (and its descendants) from its StmtBlock and delete
161/// all of them.
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700162void Statement::eraseFromBlock() {
163 assert(getBlock() && "Statement has no block");
164 getBlock()->getStatements().erase(this);
165}
166
167//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700168// OperationStmt
169//===----------------------------------------------------------------------===//
170
171/// Create a new OperationStmt with the specific fields.
172OperationStmt *OperationStmt::create(Identifier name,
173 ArrayRef<MLValue *> operands,
174 ArrayRef<Type *> resultTypes,
175 ArrayRef<NamedAttribute> attributes,
176 MLIRContext *context) {
177 auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
178 resultTypes.size());
179 void *rawMem = malloc(byteSize);
180
181 // Initialize the OperationStmt part of the statement.
182 auto stmt = ::new (rawMem) OperationStmt(
183 name, operands.size(), resultTypes.size(), attributes, context);
184
185 // Initialize the operands and results.
186 auto stmtOperands = stmt->getStmtOperands();
187 for (unsigned i = 0, e = operands.size(); i != e; ++i)
188 new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
189
190 auto stmtResults = stmt->getStmtResults();
191 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
192 new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
193 return stmt;
194}
195
Uday Bondhugula15984952018-08-01 22:36:12 -0700196/// Clone an existing OperationStmt.
197OperationStmt *OperationStmt::clone() const {
198 SmallVector<MLValue *, 8> operands;
199 SmallVector<Type *, 8> resultTypes;
200
201 // TODO(clattner): switch this to iterator logic.
202 // Put together operands and results.
203 for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
204 operands.push_back(getStmtOperand(i).get());
205
206 for (unsigned i = 0, e = getNumResults(); i != e; ++i)
207 resultTypes.push_back(getStmtResult(i).getType());
208
209 return create(getName(), operands, resultTypes, getAttrs(), getContext());
210}
211
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700212OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
213 unsigned numResults,
214 ArrayRef<NamedAttribute> attributes,
215 MLIRContext *context)
216 : Operation(name, /*isInstruction=*/false, attributes, context),
217 Statement(Kind::Operation), numOperands(numOperands),
218 numResults(numResults) {}
219
220OperationStmt::~OperationStmt() {
221 // Explicitly run the destructors for the operands and results.
222 for (auto &operand : getStmtOperands())
223 operand.~StmtOperand();
224
225 for (auto &result : getStmtResults())
226 result.~StmtResult();
227}
228
229void OperationStmt::destroy() {
230 this->~OperationStmt();
231 free(this);
232}
233
Chris Lattner95865062018-08-01 10:18:59 -0700234/// Return the context this operation is associated with.
235MLIRContext *OperationStmt::getContext() const {
236 // If we have a result or operand type, that is a constant time way to get
237 // to the context.
238 if (getNumResults())
239 return getResult(0)->getType()->getContext();
240 if (getNumOperands())
241 return getOperand(0)->getType()->getContext();
242
243 // In the very odd case where we have no operands or results, fall back to
244 // doing a find.
245 return findFunction()->getContext();
246}
247
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700248//===----------------------------------------------------------------------===//
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700249// ForStmt
250//===----------------------------------------------------------------------===//
251
252ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
253 AffineConstantExpr *step, MLIRContext *context)
Uday Bondhugula15984952018-08-01 22:36:12 -0700254 : Statement(Kind::For),
Tatiana Shpeismanc9c4b342018-07-31 07:40:14 -0700255 MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
Uday Bondhugula15984952018-08-01 22:36:12 -0700256 StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
257 upperBound(upperBound), step(step) {}
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700258
Uday Bondhugula84b80952018-08-03 13:22:26 -0700259ForStmt *ForStmt::clone() const {
Uday Bondhugula134154e2018-08-06 18:40:34 -0700260 auto *forStmt = new ForStmt(getLowerBound(), getUpperBound(), getStep(),
261 Statement::findFunction()->getContext());
262
263 // Pairs of <old op stmt result whose uses need to be replaced,
264 // new result generated by the corresponding cloned op stmt>.
265 SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
Uday Bondhugula84b80952018-08-03 13:22:26 -0700266 for (auto &s : getStatements()) {
Uday Bondhugula134154e2018-08-06 18:40:34 -0700267 auto *cloneStmt = s.clone();
268 forStmt->getStatements().push_back(cloneStmt);
269 if (auto *opStmt = dyn_cast<OperationStmt>(&s)) {
270 auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
271 for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
272 oldNewResultPairs.push_back(
273 std::make_pair(const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
274 &cloneOpStmt->getStmtResult(i)));
275 }
276 }
Uday Bondhugula84b80952018-08-03 13:22:26 -0700277 }
Uday Bondhugula134154e2018-08-06 18:40:34 -0700278 // Replace uses of old op results' with the newly created ones.
279 for (unsigned i = 0, e = oldNewResultPairs.size(); i < e; i++) {
280 for (auto &stmt : *forStmt) {
281 stmt.replaceUses(oldNewResultPairs[i].first, oldNewResultPairs[i].second);
282 }
283 }
284
285 // Replace uses of old loop IV with the new one.
286 forStmt->Statement::replaceUses(const_cast<ForStmt *>(this), forStmt);
287 return forStmt;
Uday Bondhugula84b80952018-08-03 13:22:26 -0700288}
289
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700290//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700291// IfStmt
292//===----------------------------------------------------------------------===//
293
294IfStmt::~IfStmt() {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700295 delete thenClause;
Uday Bondhugula15984952018-08-01 22:36:12 -0700296 if (elseClause)
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700297 delete elseClause;
Uday Bondhugulabc535622018-08-07 14:24:38 -0700298 // An IfStmt's IntegerSet 'condition' should not be deleted since it is
299 // allocated through MLIRContext's bump pointer allocator.
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700300}
Uday Bondhugula84b80952018-08-03 13:22:26 -0700301
302IfStmt *IfStmt::clone() const {
303 llvm_unreachable("cloning for if's not implemented yet");
304 return nullptr;
305}