blob: 4b6ddc77a0013933dfee586feeabb708d428c8ce [file] [log] [blame]
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -07001//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/MLFunction.h"
19#include "mlir/IR/Statements.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070020#include "mlir/IR/StmtVisitor.h"
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070021using namespace mlir;
22
23//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070024// StmtResult
25//===------------------------------------------------------------------===//
26
27/// Return the result number of this result.
28unsigned StmtResult::getResultNumber() const {
29 // Results are always stored consecutively, so use pointer subtraction to
30 // figure out what number this is.
31 return this - &getOwner()->getStmtResults()[0];
32}
33
34//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070035// Statement
36//===------------------------------------------------------------------===//
37
38// Statements are deleted through the destroy() member because we don't have
39// a virtual destructor.
40Statement::~Statement() {
41 assert(block == nullptr && "statement destroyed but still in a block");
42}
43
44/// Destroy this statement or one of its subclasses.
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070045void Statement::destroy() {
46 switch (this->getKind()) {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070047 case Kind::Operation:
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070048 cast<OperationStmt>(this)->destroy();
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070049 break;
50 case Kind::For:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070051 delete cast<ForStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070052 break;
53 case Kind::If:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070054 delete cast<IfStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070055 break;
56 }
57}
58
59MLFunction *Statement::getFunction() const {
60 return this->getBlock()->getFunction();
61}
62
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070063unsigned Statement::getNumNestedLoops() const {
64 struct NestedLoopCounter : public StmtVisitor<NestedLoopCounter> {
65 unsigned numNestedLoops;
66 NestedLoopCounter() : numNestedLoops(0) {}
67 void visitForStmt(const ForStmt *fs) { numNestedLoops++; }
68 };
69
70 NestedLoopCounter nlc;
71 nlc.visit(const_cast<Statement *>(this));
72 return nlc.numNestedLoops;
73}
74
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070075//===----------------------------------------------------------------------===//
76// ilist_traits for Statement
77//===----------------------------------------------------------------------===//
78
79StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
80 size_t Offset(
81 size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
82 iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
83 return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
84 Offset);
85}
86
87/// This is a trait method invoked when a statement is added to a block. We
88/// keep the block pointer up to date.
89void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
90 assert(!stmt->getBlock() && "already in a statement block!");
91 stmt->block = getContainingBlock();
92}
93
94/// This is a trait method invoked when a statement is removed from a block.
95/// We keep the block pointer up to date.
96void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
97 Statement *stmt) {
98 assert(stmt->block && "not already in a statement block!");
99 stmt->block = nullptr;
100}
101
102/// This is a trait method invoked when a statement is moved from one block
103/// to another. We keep the block pointer up to date.
104void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
105 ilist_traits<Statement> &otherList, stmt_iterator first,
106 stmt_iterator last) {
107 // If we are transferring statements within the same block, the block
108 // pointer doesn't need to be updated.
109 StmtBlock *curParent = getContainingBlock();
110 if (curParent == otherList.getContainingBlock())
111 return;
112
113 // Update the 'block' member of each statement.
114 for (; first != last; ++first)
115 first->block = curParent;
116}
117
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700118/// Remove this statement (and its descendants) from its StmtBlock and delete
119/// all of them.
120/// TODO: erase all descendents for ForStmt/IfStmt.
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700121void Statement::eraseFromBlock() {
122 assert(getBlock() && "Statement has no block");
123 getBlock()->getStatements().erase(this);
124}
125
126//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700127// OperationStmt
128//===----------------------------------------------------------------------===//
129
130/// Create a new OperationStmt with the specific fields.
131OperationStmt *OperationStmt::create(Identifier name,
132 ArrayRef<MLValue *> operands,
133 ArrayRef<Type *> resultTypes,
134 ArrayRef<NamedAttribute> attributes,
135 MLIRContext *context) {
136 auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
137 resultTypes.size());
138 void *rawMem = malloc(byteSize);
139
140 // Initialize the OperationStmt part of the statement.
141 auto stmt = ::new (rawMem) OperationStmt(
142 name, operands.size(), resultTypes.size(), attributes, context);
143
144 // Initialize the operands and results.
145 auto stmtOperands = stmt->getStmtOperands();
146 for (unsigned i = 0, e = operands.size(); i != e; ++i)
147 new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
148
149 auto stmtResults = stmt->getStmtResults();
150 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
151 new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
152 return stmt;
153}
154
155OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
156 unsigned numResults,
157 ArrayRef<NamedAttribute> attributes,
158 MLIRContext *context)
159 : Operation(name, /*isInstruction=*/false, attributes, context),
160 Statement(Kind::Operation), numOperands(numOperands),
161 numResults(numResults) {}
162
163OperationStmt::~OperationStmt() {
164 // Explicitly run the destructors for the operands and results.
165 for (auto &operand : getStmtOperands())
166 operand.~StmtOperand();
167
168 for (auto &result : getStmtResults())
169 result.~StmtResult();
170}
171
172void OperationStmt::destroy() {
173 this->~OperationStmt();
174 free(this);
175}
176
177/// This drops all operand uses from this statement, which is an essential
178/// step in breaking cyclic dependences between references when they are to
179/// be deleted.
180void OperationStmt::dropAllReferences() {
181 for (auto &op : getStmtOperands())
182 op.drop();
183}
184
185/// If this value is the result of an OperationStmt, return the statement
186/// that defines it.
187OperationStmt *SSAValue::getDefiningStmt() {
188 if (auto *result = dyn_cast<StmtResult>(this))
189 return result->getOwner();
190 return nullptr;
191}
192
193//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700194// IfStmt
195//===----------------------------------------------------------------------===//
196
197IfStmt::~IfStmt() {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700198 delete thenClause;
199 if (elseClause != nullptr)
200 delete elseClause;
201}