//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
17
18#ifndef MLIR_IR_BUILDERS_H
19#define MLIR_IR_BUILDERS_H
20
#include "mlir/IR/Attributes.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Statements.h"
#include <cassert>
Chris Lattner158e0a3e2018-07-08 20:51:38 -070025
26namespace mlir {
// Forward declarations for IR classes that the Builder interface in this
// header mentions only through pointers.
class MLIRContext;
class Module;

// Types.
class Type;
// NOTE(review): PrimitiveType is not referenced anywhere in this header; the
// declaration is kept so that code relying on it keeps compiling. Confirm and
// drop it if nothing else needs it.
class PrimitiveType;
// FloatType and OtherType are returned by the Builder type getters, so they
// are declared here alongside the other type classes rather than relying on
// a transitive include to provide them.
class FloatType;
class OtherType;
class IntegerType;
class FunctionType;
class VectorType;
class RankedTensorType;
class UnrankedTensorType;

// Attributes.
class BoolAttr;
class IntegerAttr;
class FloatAttr;
class StringAttr;
class ArrayAttr;
class AffineMapAttr;

// Affine maps and expressions.
class AffineMap;
class AffineExpr;
class AffineConstantExpr;
class AffineDimExpr;
class AffineSymbolExpr;
Chris Lattner158e0a3e2018-07-08 20:51:38 -070047
48/// This class is a general helper class for creating context-global objects
49/// like types, attributes, and affine expressions.
50class Builder {
51public:
52 explicit Builder(MLIRContext *context) : context(context) {}
53 explicit Builder(Module *module);
54
55 MLIRContext *getContext() const { return context; }
56
Chris Lattner1ac20cb2018-07-10 10:59:53 -070057 Identifier getIdentifier(StringRef str);
58 Module *createModule();
59
Chris Lattner158e0a3e2018-07-08 20:51:38 -070060 // Types.
Chris Lattnerc3251192018-07-27 13:09:58 -070061 FloatType *getBF16Type();
62 FloatType *getF16Type();
63 FloatType *getF32Type();
64 FloatType *getF64Type();
65
66 OtherType *getAffineIntType();
67 OtherType *getTFControlType();
James Molloy72b0cbe2018-08-01 12:55:27 -070068 OtherType *getTFStringType();
Chris Lattner158e0a3e2018-07-08 20:51:38 -070069 IntegerType *getIntegerType(unsigned width);
70 FunctionType *getFunctionType(ArrayRef<Type *> inputs,
71 ArrayRef<Type *> results);
72 VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
73 RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
74 UnrankedTensorType *getTensorType(Type *elementType);
75
Chris Lattner1ac20cb2018-07-10 10:59:53 -070076 // Attributes.
77 BoolAttr *getBoolAttr(bool value);
78 IntegerAttr *getIntegerAttr(int64_t value);
79 FloatAttr *getFloatAttr(double value);
80 StringAttr *getStringAttr(StringRef bytes);
81 ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
MLIR Teamb61885d2018-07-18 16:29:21 -070082 AffineMapAttr *getAffineMapAttr(AffineMap *value);
Chris Lattner1ac20cb2018-07-10 10:59:53 -070083
84 // Affine Expressions and Affine Map.
85 AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
Uday Bondhugula0115dbb2018-07-11 21:31:07 -070086 ArrayRef<AffineExpr *> results,
87 ArrayRef<AffineExpr *> rangeSizes);
Chris Lattner1ac20cb2018-07-10 10:59:53 -070088 AffineDimExpr *getDimExpr(unsigned position);
89 AffineSymbolExpr *getSymbolExpr(unsigned position);
90 AffineConstantExpr *getConstantExpr(int64_t constant);
91 AffineExpr *getAddExpr(AffineExpr *lhs, AffineExpr *rhs);
92 AffineExpr *getSubExpr(AffineExpr *lhs, AffineExpr *rhs);
93 AffineExpr *getMulExpr(AffineExpr *lhs, AffineExpr *rhs);
94 AffineExpr *getModExpr(AffineExpr *lhs, AffineExpr *rhs);
95 AffineExpr *getFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs);
96 AffineExpr *getCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs);
97
Chris Lattner158e0a3e2018-07-08 20:51:38 -070098 // TODO: Helpers for affine map/exprs, etc.
Chris Lattner158e0a3e2018-07-08 20:51:38 -070099protected:
100 MLIRContext *context;
101};
102
103/// This class helps build a CFGFunction. Instructions that are created are
104/// automatically inserted at an insertion point or added to the current basic
105/// block.
106class CFGFuncBuilder : public Builder {
107public:
Chris Lattner8174f3a2018-07-29 16:45:23 -0700108 CFGFuncBuilder(BasicBlock *block, BasicBlock::iterator insertPoint)
109 : Builder(block->getFunction()->getContext()),
110 function(block->getFunction()) {
111 setInsertionPoint(block, insertPoint);
112 }
113
114 CFGFuncBuilder(OperationInst *insertBefore)
115 : CFGFuncBuilder(insertBefore->getBlock(),
116 BasicBlock::iterator(insertBefore)) {}
117
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700118 CFGFuncBuilder(BasicBlock *block)
119 : Builder(block->getFunction()->getContext()),
120 function(block->getFunction()) {
121 setInsertionPoint(block);
122 }
Chris Lattner8174f3a2018-07-29 16:45:23 -0700123
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700124 CFGFuncBuilder(CFGFunction *function)
125 : Builder(function->getContext()), function(function) {}
126
127 /// Reset the insertion point to no location. Creating an operation without a
128 /// set insertion point is an error, but this can still be useful when the
129 /// current insertion point a builder refers to is being removed.
130 void clearInsertionPoint() {
131 this->block = nullptr;
132 insertPoint = BasicBlock::iterator();
133 }
134
Chris Lattner8174f3a2018-07-29 16:45:23 -0700135 /// Set the insertion point to the specified location.
136 void setInsertionPoint(BasicBlock *block, BasicBlock::iterator insertPoint) {
137 assert(block->getFunction() == function &&
138 "can't move to a different function");
139 this->block = block;
140 this->insertPoint = insertPoint;
141 }
142
143 /// Set the insertion point to the specified operation.
144 void setInsertionPoint(OperationInst *inst) {
145 setInsertionPoint(inst->getBlock(), BasicBlock::iterator(inst));
146 }
147
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700148 /// Set the insertion point to the end of the specified block.
149 void setInsertionPoint(BasicBlock *block) {
Chris Lattner8174f3a2018-07-29 16:45:23 -0700150 setInsertionPoint(block, block->end());
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700151 }
152
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700153 // Add new basic block and set the insertion point to the end of it.
154 BasicBlock *createBlock();
155
Chris Lattner8174f3a2018-07-29 16:45:23 -0700156 // Create an operation at the current insertion point.
Chris Lattner3b2ef762018-07-18 15:31:25 -0700157 OperationInst *createOperation(Identifier name, ArrayRef<CFGValue *> operands,
158 ArrayRef<Type *> resultTypes,
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700159 ArrayRef<NamedAttribute> attributes) {
Chris Lattner3b2ef762018-07-18 15:31:25 -0700160 auto op =
161 OperationInst::create(name, operands, resultTypes, attributes, context);
Chris Lattner8174f3a2018-07-29 16:45:23 -0700162 block->getOperations().insert(insertPoint, op);
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700163 return op;
164 }
165
Uday Bondhugula15984952018-08-01 22:36:12 -0700166 OperationInst *cloneOperation(const OperationInst &srcOpInst) {
167 auto *op = srcOpInst.clone();
168 block->getOperations().insert(insertPoint, op);
169 return op;
170 }
171
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700172 // Terminators.
173
Chris Lattner40746442018-07-21 14:32:09 -0700174 ReturnInst *createReturnInst(ArrayRef<CFGValue *> operands) {
175 return insertTerminator(ReturnInst::create(operands));
176 }
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700177
178 BranchInst *createBranchInst(BasicBlock *dest) {
Chris Lattner40746442018-07-21 14:32:09 -0700179 return insertTerminator(BranchInst::create(dest));
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700180 }
181
James Molloy4f788372018-07-24 15:01:27 -0700182 CondBranchInst *createCondBranchInst(CFGValue *condition,
183 BasicBlock *trueDest,
184 BasicBlock *falseDest) {
185 return insertTerminator(
186 CondBranchInst::create(condition, trueDest, falseDest));
187 }
188
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700189private:
190 template <typename T>
191 T *insertTerminator(T *term) {
192 block->setTerminator(term);
193 return term;
194 }
195
196 CFGFunction *function;
197 BasicBlock *block = nullptr;
198 BasicBlock::iterator insertPoint;
199};
200
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700201/// This class helps build an MLFunction. Statements that are created are
202/// automatically inserted at an insertion point or added to the current
203/// statement block.
204class MLFuncBuilder : public Builder {
205public:
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700206 /// Create ML function builder and set insertion point to the given
207 /// statement block, that is, given ML function, for statement or if statement
208 /// clause.
209 MLFuncBuilder(StmtBlock *block)
210 : Builder(block->findFunction()->getContext()) {
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700211 setInsertionPoint(block);
212 }
213
214 /// Reset the insertion point to no location. Creating an operation without a
215 /// set insertion point is an error, but this can still be useful when the
216 /// current insertion point a builder refers to is being removed.
217 void clearInsertionPoint() {
218 this->block = nullptr;
219 insertPoint = StmtBlock::iterator();
220 }
221
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700222 /// Set the insertion point to the specified location.
223 /// Unlike CFGFuncBuilder, MLFuncBuilder allows to set insertion
224 /// point to a different function.
225 void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) {
226 // TODO: check that insertPoint is in this rather than some other block.
227 this->block = block;
228 this->insertPoint = insertPoint;
229 }
230
231 /// Set the insertion point to the specified operation.
232 void setInsertionPoint(OperationStmt *stmt) {
233 setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
234 }
235
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700236 /// Set the insertion point to the end of the specified block.
237 void setInsertionPoint(StmtBlock *block) {
238 this->block = block;
239 insertPoint = block->end();
240 }
241
Uday Bondhugula15984952018-08-01 22:36:12 -0700242 /// Set the insertion point at the beginning of the specified block.
243 void setInsertionPointAtStart(StmtBlock *block) {
244 this->block = block;
245 insertPoint = block->begin();
246 }
247
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700248 OperationStmt *createOperation(Identifier name, ArrayRef<MLValue *> operands,
249 ArrayRef<Type *> resultTypes,
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700250 ArrayRef<NamedAttribute> attributes) {
Uday Bondhugula9f7754e2018-07-31 14:26:07 -0700251 auto *op =
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700252 OperationStmt::create(name, operands, resultTypes, attributes, context);
Uday Bondhugula9f7754e2018-07-31 14:26:07 -0700253 block->getStatements().insert(insertPoint, op);
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700254 return op;
255 }
256
Uday Bondhugula15984952018-08-01 22:36:12 -0700257 OperationStmt *cloneOperation(const OperationStmt &srcOpStmt) {
258 auto *op = srcOpStmt.clone();
259 block->getStatements().insert(insertPoint, op);
260 return op;
261 }
262
Chris Lattner1604e472018-07-23 08:42:19 -0700263 // Creates for statement. When step is not specified, it is set to 1.
Tatiana Shpeisman1da50c42018-07-19 09:52:39 -0700264 ForStmt *createFor(AffineConstantExpr *lowerBound,
265 AffineConstantExpr *upperBound,
266 AffineConstantExpr *step = nullptr);
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700267
268 IfStmt *createIf() {
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700269 auto *stmt = new IfStmt();
270 block->getStatements().insert(insertPoint, stmt);
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700271 return stmt;
272 }
273
Uday Bondhugula15984952018-08-01 22:36:12 -0700274 // TODO: subsume with a generate create<ConstantInt>() method.
275 OperationStmt *createConstInt32Op(int value) {
276 std::pair<Identifier, Attribute *> namedAttr(
277 Identifier::get("value", context), getIntegerAttr(value));
278 auto *mlconst = createOperation(Identifier::get("constant", context), {},
279 {getIntegerType(32)}, {namedAttr});
280 return mlconst;
281 }
282
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700283private:
284 StmtBlock *block = nullptr;
285 StmtBlock::iterator insertPoint;
286};
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700287
288} // namespace mlir
289
290#endif