blob: 6ace8ffcda2946f9a256c90fbbd02dcb5f9ae965 [file] [log] [blame]
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -07001//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/MLFunction.h"
19#include "mlir/IR/Statements.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070020#include "mlir/IR/StmtVisitor.h"
Tatiana Shpeisman3838db72018-07-30 15:18:10 -070021#include "mlir/IR/Types.h"
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070022using namespace mlir;
23
24//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070025// StmtResult
26//===------------------------------------------------------------------===//
27
28/// Return the result number of this result.
29unsigned StmtResult::getResultNumber() const {
30 // Results are always stored consecutively, so use pointer subtraction to
31 // figure out what number this is.
32 return this - &getOwner()->getStmtResults()[0];
33}
34
35//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070036// Statement
37//===------------------------------------------------------------------===//
38
39// Statements are deleted through the destroy() member because we don't have
40// a virtual destructor.
41Statement::~Statement() {
42 assert(block == nullptr && "statement destroyed but still in a block");
43}
44
45/// Destroy this statement or one of its subclasses.
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070046void Statement::destroy() {
47 switch (this->getKind()) {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070048 case Kind::Operation:
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070049 cast<OperationStmt>(this)->destroy();
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070050 break;
51 case Kind::For:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070052 delete cast<ForStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070053 break;
54 case Kind::If:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070055 delete cast<IfStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070056 break;
57 }
58}
59
60MLFunction *Statement::getFunction() const {
61 return this->getBlock()->getFunction();
62}
63
Uday Bondhugula081d9e72018-07-27 10:58:14 -070064bool Statement::isInnermost() const {
65 struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070066 unsigned numNestedLoops;
67 NestedLoopCounter() : numNestedLoops(0) {}
Uday Bondhugula081d9e72018-07-27 10:58:14 -070068 void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070069 };
70
71 NestedLoopCounter nlc;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070072 nlc.walk(const_cast<Statement *>(this));
73 return nlc.numNestedLoops == 1;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070074}
75
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070076//===----------------------------------------------------------------------===//
77// ilist_traits for Statement
78//===----------------------------------------------------------------------===//
79
80StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
81 size_t Offset(
82 size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
83 iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
84 return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
85 Offset);
86}
87
88/// This is a trait method invoked when a statement is added to a block. We
89/// keep the block pointer up to date.
90void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
91 assert(!stmt->getBlock() && "already in a statement block!");
92 stmt->block = getContainingBlock();
93}
94
95/// This is a trait method invoked when a statement is removed from a block.
96/// We keep the block pointer up to date.
97void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
98 Statement *stmt) {
99 assert(stmt->block && "not already in a statement block!");
100 stmt->block = nullptr;
101}
102
103/// This is a trait method invoked when a statement is moved from one block
104/// to another. We keep the block pointer up to date.
105void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
106 ilist_traits<Statement> &otherList, stmt_iterator first,
107 stmt_iterator last) {
108 // If we are transferring statements within the same block, the block
109 // pointer doesn't need to be updated.
110 StmtBlock *curParent = getContainingBlock();
111 if (curParent == otherList.getContainingBlock())
112 return;
113
114 // Update the 'block' member of each statement.
115 for (; first != last; ++first)
116 first->block = curParent;
117}
118
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700119/// Remove this statement (and its descendants) from its StmtBlock and delete
120/// all of them.
121/// TODO: erase all descendents for ForStmt/IfStmt.
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700122void Statement::eraseFromBlock() {
123 assert(getBlock() && "Statement has no block");
124 getBlock()->getStatements().erase(this);
125}
126
127//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700128// OperationStmt
129//===----------------------------------------------------------------------===//
130
131/// Create a new OperationStmt with the specific fields.
132OperationStmt *OperationStmt::create(Identifier name,
133 ArrayRef<MLValue *> operands,
134 ArrayRef<Type *> resultTypes,
135 ArrayRef<NamedAttribute> attributes,
136 MLIRContext *context) {
137 auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
138 resultTypes.size());
139 void *rawMem = malloc(byteSize);
140
141 // Initialize the OperationStmt part of the statement.
142 auto stmt = ::new (rawMem) OperationStmt(
143 name, operands.size(), resultTypes.size(), attributes, context);
144
145 // Initialize the operands and results.
146 auto stmtOperands = stmt->getStmtOperands();
147 for (unsigned i = 0, e = operands.size(); i != e; ++i)
148 new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
149
150 auto stmtResults = stmt->getStmtResults();
151 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
152 new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
153 return stmt;
154}
155
156OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
157 unsigned numResults,
158 ArrayRef<NamedAttribute> attributes,
159 MLIRContext *context)
160 : Operation(name, /*isInstruction=*/false, attributes, context),
161 Statement(Kind::Operation), numOperands(numOperands),
162 numResults(numResults) {}
163
164OperationStmt::~OperationStmt() {
165 // Explicitly run the destructors for the operands and results.
166 for (auto &operand : getStmtOperands())
167 operand.~StmtOperand();
168
169 for (auto &result : getStmtResults())
170 result.~StmtResult();
171}
172
173void OperationStmt::destroy() {
174 this->~OperationStmt();
175 free(this);
176}
177
178/// This drops all operand uses from this statement, which is an essential
179/// step in breaking cyclic dependences between references when they are to
180/// be deleted.
181void OperationStmt::dropAllReferences() {
182 for (auto &op : getStmtOperands())
183 op.drop();
184}
185
186/// If this value is the result of an OperationStmt, return the statement
187/// that defines it.
188OperationStmt *SSAValue::getDefiningStmt() {
189 if (auto *result = dyn_cast<StmtResult>(this))
190 return result->getOwner();
191 return nullptr;
192}
193
194//===----------------------------------------------------------------------===//
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700195// ForStmt
196//===----------------------------------------------------------------------===//
197
198ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
199 AffineConstantExpr *step, MLIRContext *context)
200 : Statement(Kind::For), StmtBlock(StmtBlockKind::For),
201 MLValue(MLValueKind::InductionVar, Type::getAffineInt(context)),
202 lowerBound(lowerBound), upperBound(upperBound), step(step) {}
203
204//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700205// IfStmt
206//===----------------------------------------------------------------------===//
207
208IfStmt::~IfStmt() {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700209 delete thenClause;
210 if (elseClause != nullptr)
211 delete elseClause;
212}