blob: b55b5976572e11434ce407bc4bbe091af2ca1ddb [file] [log] [blame]
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -07001//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/MLFunction.h"
19#include "mlir/IR/Statements.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070020#include "mlir/IR/StmtVisitor.h"
Tatiana Shpeisman3838db72018-07-30 15:18:10 -070021#include "mlir/IR/Types.h"
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070022using namespace mlir;
23
24//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070025// StmtResult
26//===------------------------------------------------------------------===//
27
28/// Return the result number of this result.
29unsigned StmtResult::getResultNumber() const {
30 // Results are always stored consecutively, so use pointer subtraction to
31 // figure out what number this is.
32 return this - &getOwner()->getStmtResults()[0];
33}
34
35//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070036// Statement
37//===------------------------------------------------------------------===//
38
39// Statements are deleted through the destroy() member because we don't have
40// a virtual destructor.
41Statement::~Statement() {
42 assert(block == nullptr && "statement destroyed but still in a block");
43}
44
45/// Destroy this statement or one of its subclasses.
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070046void Statement::destroy() {
47 switch (this->getKind()) {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070048 case Kind::Operation:
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -070049 cast<OperationStmt>(this)->destroy();
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070050 break;
51 case Kind::For:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070052 delete cast<ForStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070053 break;
54 case Kind::If:
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070055 delete cast<IfStmt>(this);
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070056 break;
57 }
58}
59
Tatiana Shpeismand880b352018-07-31 23:14:16 -070060Statement *Statement::getParentStmt() const { return block->getParentStmt(); }
61
62MLFunction *Statement::findFunction() const {
63 return this->getBlock()->findFunction();
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070064}
65
Uday Bondhugula081d9e72018-07-27 10:58:14 -070066bool Statement::isInnermost() const {
67 struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070068 unsigned numNestedLoops;
69 NestedLoopCounter() : numNestedLoops(0) {}
Uday Bondhugula081d9e72018-07-27 10:58:14 -070070 void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070071 };
72
73 NestedLoopCounter nlc;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070074 nlc.walk(const_cast<Statement *>(this));
75 return nlc.numNestedLoops == 1;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070076}
77
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -070078//===----------------------------------------------------------------------===//
79// ilist_traits for Statement
80//===----------------------------------------------------------------------===//
81
/// Recover the StmtBlock that embeds this statement list.
///
/// The traits object is a base of the list itself (hence the static_cast to
/// iplist<Statement> below), and that list is a member of a StmtBlock. The
/// member's byte offset is obtained by applying StmtBlock's sublist member
/// pointer to a null StmtBlock and reading back the resulting address. That
/// offset is then subtracted from this list's address to get the owning block.
/// NOTE(review): forming a member access through a null pointer is formally
/// undefined behavior. This mirrors the trick used by LLVM's
/// SymbolTableListTraits; keep it in sync if StmtBlock's layout changes.
StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
  size_t Offset(
      size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
  iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
  return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
                                       Offset);
}
89
90/// This is a trait method invoked when a statement is added to a block. We
91/// keep the block pointer up to date.
92void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
93 assert(!stmt->getBlock() && "already in a statement block!");
94 stmt->block = getContainingBlock();
95}
96
97/// This is a trait method invoked when a statement is removed from a block.
98/// We keep the block pointer up to date.
99void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
100 Statement *stmt) {
101 assert(stmt->block && "not already in a statement block!");
102 stmt->block = nullptr;
103}
104
105/// This is a trait method invoked when a statement is moved from one block
106/// to another. We keep the block pointer up to date.
107void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
108 ilist_traits<Statement> &otherList, stmt_iterator first,
109 stmt_iterator last) {
110 // If we are transferring statements within the same block, the block
111 // pointer doesn't need to be updated.
112 StmtBlock *curParent = getContainingBlock();
113 if (curParent == otherList.getContainingBlock())
114 return;
115
116 // Update the 'block' member of each statement.
117 for (; first != last; ++first)
118 first->block = curParent;
119}
120
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700121/// Remove this statement (and its descendants) from its StmtBlock and delete
122/// all of them.
123/// TODO: erase all descendents for ForStmt/IfStmt.
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700124void Statement::eraseFromBlock() {
125 assert(getBlock() && "Statement has no block");
126 getBlock()->getStatements().erase(this);
127}
128
129//===----------------------------------------------------------------------===//
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700130// OperationStmt
131//===----------------------------------------------------------------------===//
132
133/// Create a new OperationStmt with the specific fields.
134OperationStmt *OperationStmt::create(Identifier name,
135 ArrayRef<MLValue *> operands,
136 ArrayRef<Type *> resultTypes,
137 ArrayRef<NamedAttribute> attributes,
138 MLIRContext *context) {
139 auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
140 resultTypes.size());
141 void *rawMem = malloc(byteSize);
142
143 // Initialize the OperationStmt part of the statement.
144 auto stmt = ::new (rawMem) OperationStmt(
145 name, operands.size(), resultTypes.size(), attributes, context);
146
147 // Initialize the operands and results.
148 auto stmtOperands = stmt->getStmtOperands();
149 for (unsigned i = 0, e = operands.size(); i != e; ++i)
150 new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
151
152 auto stmtResults = stmt->getStmtResults();
153 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
154 new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
155 return stmt;
156}
157
158OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
159 unsigned numResults,
160 ArrayRef<NamedAttribute> attributes,
161 MLIRContext *context)
162 : Operation(name, /*isInstruction=*/false, attributes, context),
163 Statement(Kind::Operation), numOperands(numOperands),
164 numResults(numResults) {}
165
166OperationStmt::~OperationStmt() {
167 // Explicitly run the destructors for the operands and results.
168 for (auto &operand : getStmtOperands())
169 operand.~StmtOperand();
170
171 for (auto &result : getStmtResults())
172 result.~StmtResult();
173}
174
175void OperationStmt::destroy() {
176 this->~OperationStmt();
177 free(this);
178}
179
180/// This drops all operand uses from this statement, which is an essential
181/// step in breaking cyclic dependences between references when they are to
182/// be deleted.
183void OperationStmt::dropAllReferences() {
184 for (auto &op : getStmtOperands())
185 op.drop();
186}
187
188/// If this value is the result of an OperationStmt, return the statement
189/// that defines it.
190OperationStmt *SSAValue::getDefiningStmt() {
191 if (auto *result = dyn_cast<StmtResult>(this))
192 return result->getOwner();
193 return nullptr;
194}
195
196//===----------------------------------------------------------------------===//
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700197// ForStmt
198//===----------------------------------------------------------------------===//
199
200ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
201 AffineConstantExpr *step, MLIRContext *context)
202 : Statement(Kind::For), StmtBlock(StmtBlockKind::For),
Tatiana Shpeismanc9c4b342018-07-31 07:40:14 -0700203 MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
Tatiana Shpeisman3838db72018-07-30 15:18:10 -0700204 lowerBound(lowerBound), upperBound(upperBound), step(step) {}
205
206//===----------------------------------------------------------------------===//
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700207// IfStmt
208//===----------------------------------------------------------------------===//
209
210IfStmt::~IfStmt() {
Tatiana Shpeisman1bcfe982018-07-13 13:03:13 -0700211 delete thenClause;
212 if (elseClause != nullptr)
213 delete elseClause;
214}