blob: 4e2a7b03f8fb2e663b1b3cd8e72466ddea60d190 [file] [log] [blame]
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -07001//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/MLFunction.h"
19#include "mlir/IR/Statements.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070020#include "mlir/IR/StmtVisitor.h"
Tatiana Shpeisman3838db72018-07-30 15:18:10 -070021#include "mlir/IR/Types.h"
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070022using namespace mlir;
23
24//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070025// StmtResult
26//===------------------------------------------------------------------===//
27
28/// Return the result number of this result.
29unsigned StmtResult::getResultNumber() const {
30 // Results are always stored consecutively, so use pointer subtraction to
31 // figure out what number this is.
32 return this - &getOwner()->getStmtResults()[0];
33}
34
35//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070036// Statement
37//===------------------------------------------------------------------===//
38
39// Statements are deleted through the destroy() member because we don't have
40// a virtual destructor.
41Statement::~Statement() {
42 assert(block == nullptr && "statement destroyed but still in a block");
43}
44
45/// Destroy this statement or one of its subclasses.
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070046void Statement::destroy() {
47 switch (this->getKind()) {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070048 case Kind::Operation:
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070049 cast<OperationStmt>(this)->destroy();
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070050 break;
51 case Kind::For:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070052 delete cast<ForStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070053 break;
54 case Kind::If:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070055 delete cast<IfStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070056 break;
57 }
58}
59
Tatiana Shpeismand880b352018-07-31 23:14:16 -070060Statement *Statement::getParentStmt() const { return block->getParentStmt(); }
61
62MLFunction *Statement::findFunction() const {
63 return this->getBlock()->findFunction();
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070064}
65
Uday Bondhugula081d9e72018-07-27 10:58:14 -070066bool Statement::isInnermost() const {
67 struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070068 unsigned numNestedLoops;
69 NestedLoopCounter() : numNestedLoops(0) {}
Uday Bondhugula081d9e72018-07-27 10:58:14 -070070 void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070071 };
72
73 NestedLoopCounter nlc;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070074 nlc.walk(const_cast<Statement *>(this));
75 return nlc.numNestedLoops == 1;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070076}
77
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070078//===----------------------------------------------------------------------===//
79// ilist_traits for Statement
80//===----------------------------------------------------------------------===//
81
82StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
83 size_t Offset(
84 size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
85 iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
86 return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
87 Offset);
88}
89
90/// This is a trait method invoked when a statement is added to a block. We
91/// keep the block pointer up to date.
92void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
93 assert(!stmt->getBlock() && "already in a statement block!");
94 stmt->block = getContainingBlock();
95}
96
97/// This is a trait method invoked when a statement is removed from a block.
98/// We keep the block pointer up to date.
99void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
100 Statement *stmt) {
101 assert(stmt->block && "not already in a statement block!");
102 stmt->block = nullptr;
103}
104
105/// This is a trait method invoked when a statement is moved from one block
106/// to another. We keep the block pointer up to date.
107void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
108 ilist_traits<Statement> &otherList, stmt_iterator first,
109 stmt_iterator last) {
110 // If we are transferring statements within the same block, the block
111 // pointer doesn't need to be updated.
112 StmtBlock *curParent = getContainingBlock();
113 if (curParent == otherList.getContainingBlock())
114 return;
115
116 // Update the 'block' member of each statement.
117 for (; first != last; ++first)
118 first->block = curParent;
119}
120
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700121/// Remove this statement (and its descendants) from its StmtBlock and delete
122/// all of them.
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700123void Statement::eraseFromBlock() {
124 assert(getBlock() && "Statement has no block");
125 getBlock()->getStatements().erase(this);
126}
127
128//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700129// OperationStmt
130//===----------------------------------------------------------------------===//
131
132/// Create a new OperationStmt with the specific fields.
133OperationStmt *OperationStmt::create(Identifier name,
134 ArrayRef<MLValue *> operands,
135 ArrayRef<Type *> resultTypes,
136 ArrayRef<NamedAttribute> attributes,
137 MLIRContext *context) {
138 auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
139 resultTypes.size());
140 void *rawMem = malloc(byteSize);
141
142 // Initialize the OperationStmt part of the statement.
143 auto stmt = ::new (rawMem) OperationStmt(
144 name, operands.size(), resultTypes.size(), attributes, context);
145
146 // Initialize the operands and results.
147 auto stmtOperands = stmt->getStmtOperands();
148 for (unsigned i = 0, e = operands.size(); i != e; ++i)
149 new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
150
151 auto stmtResults = stmt->getStmtResults();
152 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
153 new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
154 return stmt;
155}
156
Uday Bondhugula15984952018-08-01 22:36:12 -0700157/// Clone an existing OperationStmt.
158OperationStmt *OperationStmt::clone() const {
159 SmallVector<MLValue *, 8> operands;
160 SmallVector<Type *, 8> resultTypes;
161
162 // TODO(clattner): switch this to iterator logic.
163 // Put together operands and results.
164 for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
165 operands.push_back(getStmtOperand(i).get());
166
167 for (unsigned i = 0, e = getNumResults(); i != e; ++i)
168 resultTypes.push_back(getStmtResult(i).getType());
169
170 return create(getName(), operands, resultTypes, getAttrs(), getContext());
171}
172
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700173OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
174 unsigned numResults,
175 ArrayRef<NamedAttribute> attributes,
176 MLIRContext *context)
177 : Operation(name, /*isInstruction=*/false, attributes, context),
178 Statement(Kind::Operation), numOperands(numOperands),
179 numResults(numResults) {}
180
181OperationStmt::~OperationStmt() {
182 // Explicitly run the destructors for the operands and results.
183 for (auto &operand : getStmtOperands())
184 operand.~StmtOperand();
185
186 for (auto &result : getStmtResults())
187 result.~StmtResult();
188}
189
190void OperationStmt::destroy() {
191 this->~OperationStmt();
192 free(this);
193}
194
Chris Lattner95865062018-08-01 10:18:59 -0700195/// Return the context this operation is associated with.
196MLIRContext *OperationStmt::getContext() const {
197 // If we have a result or operand type, that is a constant time way to get
198 // to the context.
199 if (getNumResults())
200 return getResult(0)->getType()->getContext();
201 if (getNumOperands())
202 return getOperand(0)->getType()->getContext();
203
204 // In the very odd case where we have no operands or results, fall back to
205 // doing a find.
206 return findFunction()->getContext();
207}
208
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700209/// This drops all operand uses from this statement, which is an essential
210/// step in breaking cyclic dependences between references when they are to
211/// be deleted.
212void OperationStmt::dropAllReferences() {
213 for (auto &op : getStmtOperands())
214 op.drop();
215}
216
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700217//===----------------------------------------------------------------------===//
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700218// ForStmt
219//===----------------------------------------------------------------------===//
220
221ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
222 AffineConstantExpr *step, MLIRContext *context)
Uday Bondhugula15984952018-08-01 22:36:12 -0700223 : Statement(Kind::For),
Tatiana Shpeismanc9c4b342018-07-31 07:40:14 -0700224 MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
Uday Bondhugula15984952018-08-01 22:36:12 -0700225 StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
226 upperBound(upperBound), step(step) {}
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700227
228//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700229// IfStmt
230//===----------------------------------------------------------------------===//
231
232IfStmt::~IfStmt() {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700233 delete thenClause;
Uday Bondhugula15984952018-08-01 22:36:12 -0700234 if (elseClause)
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700235 delete elseClause;
236}