blob: 978137bee2e12b8e317d2a769b0dc148d966ddde [file] [log] [blame]
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -07001//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/MLFunction.h"
19#include "mlir/IR/Statements.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070020#include "mlir/IR/StmtVisitor.h"
Tatiana Shpeisman3838db72018-07-30 15:18:10 -070021#include "mlir/IR/Types.h"
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070022using namespace mlir;
23
24//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070025// StmtResult
26//===------------------------------------------------------------------===//
27
28/// Return the result number of this result.
29unsigned StmtResult::getResultNumber() const {
30 // Results are always stored consecutively, so use pointer subtraction to
31 // figure out what number this is.
32 return this - &getOwner()->getStmtResults()[0];
33}
34
35//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070036// Statement
37//===------------------------------------------------------------------===//
38
39// Statements are deleted through the destroy() member because we don't have
40// a virtual destructor.
41Statement::~Statement() {
42 assert(block == nullptr && "statement destroyed but still in a block");
43}
44
45/// Destroy this statement or one of its subclasses.
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070046void Statement::destroy() {
47 switch (this->getKind()) {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070048 case Kind::Operation:
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070049 cast<OperationStmt>(this)->destroy();
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070050 break;
51 case Kind::For:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070052 delete cast<ForStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070053 break;
54 case Kind::If:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070055 delete cast<IfStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070056 break;
57 }
58}
59
Tatiana Shpeismanc335d182018-08-03 11:12:34 -070060Statement *Statement::getParentStmt() const {
61 return block ? block->getParentStmt() : nullptr;
62}
Tatiana Shpeismand880b352018-07-31 23:14:16 -070063
64MLFunction *Statement::findFunction() const {
Tatiana Shpeismanc335d182018-08-03 11:12:34 -070065 return block ? block->findFunction() : nullptr;
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070066}
67
Uday Bondhugula081d9e72018-07-27 10:58:14 -070068bool Statement::isInnermost() const {
69 struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070070 unsigned numNestedLoops;
71 NestedLoopCounter() : numNestedLoops(0) {}
Uday Bondhugula081d9e72018-07-27 10:58:14 -070072 void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070073 };
74
75 NestedLoopCounter nlc;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070076 nlc.walk(const_cast<Statement *>(this));
77 return nlc.numNestedLoops == 1;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070078}
79
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070080//===----------------------------------------------------------------------===//
81// ilist_traits for Statement
82//===----------------------------------------------------------------------===//
83
84StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
85 size_t Offset(
86 size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
87 iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
88 return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
89 Offset);
90}
91
92/// This is a trait method invoked when a statement is added to a block. We
93/// keep the block pointer up to date.
94void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
95 assert(!stmt->getBlock() && "already in a statement block!");
96 stmt->block = getContainingBlock();
97}
98
99/// This is a trait method invoked when a statement is removed from a block.
100/// We keep the block pointer up to date.
101void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
102 Statement *stmt) {
103 assert(stmt->block && "not already in a statement block!");
104 stmt->block = nullptr;
105}
106
107/// This is a trait method invoked when a statement is moved from one block
108/// to another. We keep the block pointer up to date.
109void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
110 ilist_traits<Statement> &otherList, stmt_iterator first,
111 stmt_iterator last) {
112 // If we are transferring statements within the same block, the block
113 // pointer doesn't need to be updated.
114 StmtBlock *curParent = getContainingBlock();
115 if (curParent == otherList.getContainingBlock())
116 return;
117
118 // Update the 'block' member of each statement.
119 for (; first != last; ++first)
120 first->block = curParent;
121}
122
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700123/// Remove this statement (and its descendants) from its StmtBlock and delete
124/// all of them.
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700125void Statement::eraseFromBlock() {
126 assert(getBlock() && "Statement has no block");
127 getBlock()->getStatements().erase(this);
128}
129
130//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700131// OperationStmt
132//===----------------------------------------------------------------------===//
133
134/// Create a new OperationStmt with the specific fields.
135OperationStmt *OperationStmt::create(Identifier name,
136 ArrayRef<MLValue *> operands,
137 ArrayRef<Type *> resultTypes,
138 ArrayRef<NamedAttribute> attributes,
139 MLIRContext *context) {
140 auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
141 resultTypes.size());
142 void *rawMem = malloc(byteSize);
143
144 // Initialize the OperationStmt part of the statement.
145 auto stmt = ::new (rawMem) OperationStmt(
146 name, operands.size(), resultTypes.size(), attributes, context);
147
148 // Initialize the operands and results.
149 auto stmtOperands = stmt->getStmtOperands();
150 for (unsigned i = 0, e = operands.size(); i != e; ++i)
151 new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
152
153 auto stmtResults = stmt->getStmtResults();
154 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
155 new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
156 return stmt;
157}
158
Uday Bondhugula15984952018-08-01 22:36:12 -0700159/// Clone an existing OperationStmt.
160OperationStmt *OperationStmt::clone() const {
161 SmallVector<MLValue *, 8> operands;
162 SmallVector<Type *, 8> resultTypes;
163
164 // TODO(clattner): switch this to iterator logic.
165 // Put together operands and results.
166 for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
167 operands.push_back(getStmtOperand(i).get());
168
169 for (unsigned i = 0, e = getNumResults(); i != e; ++i)
170 resultTypes.push_back(getStmtResult(i).getType());
171
172 return create(getName(), operands, resultTypes, getAttrs(), getContext());
173}
174
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700175OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
176 unsigned numResults,
177 ArrayRef<NamedAttribute> attributes,
178 MLIRContext *context)
179 : Operation(name, /*isInstruction=*/false, attributes, context),
180 Statement(Kind::Operation), numOperands(numOperands),
181 numResults(numResults) {}
182
183OperationStmt::~OperationStmt() {
184 // Explicitly run the destructors for the operands and results.
185 for (auto &operand : getStmtOperands())
186 operand.~StmtOperand();
187
188 for (auto &result : getStmtResults())
189 result.~StmtResult();
190}
191
192void OperationStmt::destroy() {
193 this->~OperationStmt();
194 free(this);
195}
196
Chris Lattner95865062018-08-01 10:18:59 -0700197/// Return the context this operation is associated with.
198MLIRContext *OperationStmt::getContext() const {
199 // If we have a result or operand type, that is a constant time way to get
200 // to the context.
201 if (getNumResults())
202 return getResult(0)->getType()->getContext();
203 if (getNumOperands())
204 return getOperand(0)->getType()->getContext();
205
206 // In the very odd case where we have no operands or results, fall back to
207 // doing a find.
208 return findFunction()->getContext();
209}
210
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700211/// This drops all operand uses from this statement, which is an essential
212/// step in breaking cyclic dependences between references when they are to
213/// be deleted.
214void OperationStmt::dropAllReferences() {
215 for (auto &op : getStmtOperands())
216 op.drop();
217}
218
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700219//===----------------------------------------------------------------------===//
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700220// ForStmt
221//===----------------------------------------------------------------------===//
222
223ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
224 AffineConstantExpr *step, MLIRContext *context)
Uday Bondhugula15984952018-08-01 22:36:12 -0700225 : Statement(Kind::For),
Tatiana Shpeismanc9c4b342018-07-31 07:40:14 -0700226 MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
Uday Bondhugula15984952018-08-01 22:36:12 -0700227 StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
228 upperBound(upperBound), step(step) {}
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700229
230//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700231// IfStmt
232//===----------------------------------------------------------------------===//
233
234IfStmt::~IfStmt() {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700235 delete thenClause;
Uday Bondhugula15984952018-08-01 22:36:12 -0700236 if (elseClause)
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700237 delete elseClause;
238}