blob: 44e44c821f8d6ca7566fc8919673f05ef89ecd5f [file] [log] [blame]
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -07001//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/MLFunction.h"
19#include "mlir/IR/Statements.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070020#include "mlir/IR/StmtVisitor.h"
Tatiana Shpeisman3838db72018-07-30 15:18:10 -070021#include "mlir/IR/Types.h"
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070022using namespace mlir;
23
24//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070025// StmtResult
//===----------------------------------------------------------------------===//
27
28/// Return the result number of this result.
29unsigned StmtResult::getResultNumber() const {
30 // Results are always stored consecutively, so use pointer subtraction to
31 // figure out what number this is.
32 return this - &getOwner()->getStmtResults()[0];
33}
34
35//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070036// Statement
//===----------------------------------------------------------------------===//
38
39// Statements are deleted through the destroy() member because we don't have
40// a virtual destructor.
41Statement::~Statement() {
42 assert(block == nullptr && "statement destroyed but still in a block");
43}
44
45/// Destroy this statement or one of its subclasses.
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070046void Statement::destroy() {
47 switch (this->getKind()) {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070048 case Kind::Operation:
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070049 cast<OperationStmt>(this)->destroy();
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070050 break;
51 case Kind::For:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070052 delete cast<ForStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070053 break;
54 case Kind::If:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070055 delete cast<IfStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070056 break;
57 }
58}
59
Tatiana Shpeismanc335d182018-08-03 11:12:34 -070060Statement *Statement::getParentStmt() const {
61 return block ? block->getParentStmt() : nullptr;
62}
Tatiana Shpeismand880b352018-07-31 23:14:16 -070063
64MLFunction *Statement::findFunction() const {
Tatiana Shpeismanc335d182018-08-03 11:12:34 -070065 return block ? block->findFunction() : nullptr;
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070066}
67
Uday Bondhugula081d9e72018-07-27 10:58:14 -070068bool Statement::isInnermost() const {
69 struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070070 unsigned numNestedLoops;
71 NestedLoopCounter() : numNestedLoops(0) {}
Uday Bondhugula081d9e72018-07-27 10:58:14 -070072 void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070073 };
74
75 NestedLoopCounter nlc;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070076 nlc.walk(const_cast<Statement *>(this));
77 return nlc.numNestedLoops == 1;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070078}
79
Uday Bondhugula84b80952018-08-03 13:22:26 -070080Statement *Statement::clone() const {
81 switch (kind) {
82 case Kind::Operation:
83 return cast<OperationStmt>(this)->clone();
84 case Kind::If:
85 llvm_unreachable("cloning for if's not implemented yet");
86 return cast<IfStmt>(this)->clone();
87 case Kind::For:
88 llvm_unreachable("cloning for loops not implemented yet");
89 return cast<ForStmt>(this)->clone();
90 }
91}
92
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070093//===----------------------------------------------------------------------===//
94// ilist_traits for Statement
95//===----------------------------------------------------------------------===//
96
97StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
98 size_t Offset(
99 size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
100 iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
101 return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
102 Offset);
103}
104
105/// This is a trait method invoked when a statement is added to a block. We
106/// keep the block pointer up to date.
107void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
108 assert(!stmt->getBlock() && "already in a statement block!");
109 stmt->block = getContainingBlock();
110}
111
112/// This is a trait method invoked when a statement is removed from a block.
113/// We keep the block pointer up to date.
114void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
115 Statement *stmt) {
116 assert(stmt->block && "not already in a statement block!");
117 stmt->block = nullptr;
118}
119
120/// This is a trait method invoked when a statement is moved from one block
121/// to another. We keep the block pointer up to date.
122void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
123 ilist_traits<Statement> &otherList, stmt_iterator first,
124 stmt_iterator last) {
125 // If we are transferring statements within the same block, the block
126 // pointer doesn't need to be updated.
127 StmtBlock *curParent = getContainingBlock();
128 if (curParent == otherList.getContainingBlock())
129 return;
130
131 // Update the 'block' member of each statement.
132 for (; first != last; ++first)
133 first->block = curParent;
134}
135
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700136/// Remove this statement (and its descendants) from its StmtBlock and delete
137/// all of them.
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700138void Statement::eraseFromBlock() {
139 assert(getBlock() && "Statement has no block");
140 getBlock()->getStatements().erase(this);
141}
142
143//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700144// OperationStmt
145//===----------------------------------------------------------------------===//
146
147/// Create a new OperationStmt with the specific fields.
148OperationStmt *OperationStmt::create(Identifier name,
149 ArrayRef<MLValue *> operands,
150 ArrayRef<Type *> resultTypes,
151 ArrayRef<NamedAttribute> attributes,
152 MLIRContext *context) {
153 auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
154 resultTypes.size());
155 void *rawMem = malloc(byteSize);
156
157 // Initialize the OperationStmt part of the statement.
158 auto stmt = ::new (rawMem) OperationStmt(
159 name, operands.size(), resultTypes.size(), attributes, context);
160
161 // Initialize the operands and results.
162 auto stmtOperands = stmt->getStmtOperands();
163 for (unsigned i = 0, e = operands.size(); i != e; ++i)
164 new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
165
166 auto stmtResults = stmt->getStmtResults();
167 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
168 new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
169 return stmt;
170}
171
Uday Bondhugula15984952018-08-01 22:36:12 -0700172/// Clone an existing OperationStmt.
173OperationStmt *OperationStmt::clone() const {
174 SmallVector<MLValue *, 8> operands;
175 SmallVector<Type *, 8> resultTypes;
176
177 // TODO(clattner): switch this to iterator logic.
178 // Put together operands and results.
179 for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
180 operands.push_back(getStmtOperand(i).get());
181
182 for (unsigned i = 0, e = getNumResults(); i != e; ++i)
183 resultTypes.push_back(getStmtResult(i).getType());
184
185 return create(getName(), operands, resultTypes, getAttrs(), getContext());
186}
187
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700188OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
189 unsigned numResults,
190 ArrayRef<NamedAttribute> attributes,
191 MLIRContext *context)
192 : Operation(name, /*isInstruction=*/false, attributes, context),
193 Statement(Kind::Operation), numOperands(numOperands),
194 numResults(numResults) {}
195
196OperationStmt::~OperationStmt() {
197 // Explicitly run the destructors for the operands and results.
198 for (auto &operand : getStmtOperands())
199 operand.~StmtOperand();
200
201 for (auto &result : getStmtResults())
202 result.~StmtResult();
203}
204
205void OperationStmt::destroy() {
206 this->~OperationStmt();
207 free(this);
208}
209
Chris Lattner95865062018-08-01 10:18:59 -0700210/// Return the context this operation is associated with.
211MLIRContext *OperationStmt::getContext() const {
212 // If we have a result or operand type, that is a constant time way to get
213 // to the context.
214 if (getNumResults())
215 return getResult(0)->getType()->getContext();
216 if (getNumOperands())
217 return getOperand(0)->getType()->getContext();
218
219 // In the very odd case where we have no operands or results, fall back to
220 // doing a find.
221 return findFunction()->getContext();
222}
223
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700224/// This drops all operand uses from this statement, which is an essential
225/// step in breaking cyclic dependences between references when they are to
226/// be deleted.
227void OperationStmt::dropAllReferences() {
228 for (auto &op : getStmtOperands())
229 op.drop();
230}
231
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700232//===----------------------------------------------------------------------===//
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700233// ForStmt
234//===----------------------------------------------------------------------===//
235
236ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
237 AffineConstantExpr *step, MLIRContext *context)
Uday Bondhugula15984952018-08-01 22:36:12 -0700238 : Statement(Kind::For),
Tatiana Shpeismanc9c4b342018-07-31 07:40:14 -0700239 MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
Uday Bondhugula15984952018-08-01 22:36:12 -0700240 StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
241 upperBound(upperBound), step(step) {}
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700242
Uday Bondhugula84b80952018-08-03 13:22:26 -0700243ForStmt *ForStmt::clone() const {
244 auto *stmt = new ForStmt(getLowerBound(), getUpperBound(), getStep(),
245 Statement::findFunction()->getContext());
246 for (auto &s : getStatements()) {
247 stmt->getStatements().push_back(s.clone());
248 }
249 return stmt;
250}
251
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700252//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700253// IfStmt
254//===----------------------------------------------------------------------===//
255
256IfStmt::~IfStmt() {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700257 delete thenClause;
Uday Bondhugula15984952018-08-01 22:36:12 -0700258 if (elseClause)
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700259 delete elseClause;
260}
Uday Bondhugula84b80952018-08-03 13:22:26 -0700261
262IfStmt *IfStmt::clone() const {
263 llvm_unreachable("cloning for if's not implemented yet");
264 return nullptr;
265}