blob: 3324d4b566786ce36015c3ed20c2284d5a580259 [file] [log] [blame]
Chris Lattner158e0a3e2018-07-08 20:51:38 -07001//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#ifndef MLIR_IR_BUILDERS_H
19#define MLIR_IR_BUILDERS_H
20
Uday Bondhugula15984952018-08-01 22:36:12 -070021#include "mlir/IR/Attributes.h"
Chris Lattner158e0a3e2018-07-08 20:51:38 -070022#include "mlir/IR/CFGFunction.h"
Tatiana Shpeisman565b9642018-07-16 11:47:09 -070023#include "mlir/IR/MLFunction.h"
24#include "mlir/IR/Statements.h"
Chris Lattner158e0a3e2018-07-08 20:51:38 -070025
26namespace mlir {
// Forward declarations, so this header can name these IR classes without
// pulling in their full definitions.

// Context and top-level containers.
class MLIRContext;
class Module;

// Types.
class Type;
class PrimitiveType;
class IntegerType;
class FunctionType;
class VectorType;
class RankedTensorType;
class UnrankedTensorType;

// Attributes.
class BoolAttr;
class IntegerAttr;
class FloatAttr;
class StringAttr;
class TypeAttr;
class ArrayAttr;
class AffineMapAttr;

// Affine maps and expressions.
class AffineMap;
class AffineExpr;
class AffineConstantExpr;
class AffineDimExpr;
class AffineSymbolExpr;
Chris Lattner158e0a3e2018-07-08 20:51:38 -070048
49/// This class is a general helper class for creating context-global objects
50/// like types, attributes, and affine expressions.
51class Builder {
52public:
53 explicit Builder(MLIRContext *context) : context(context) {}
54 explicit Builder(Module *module);
55
56 MLIRContext *getContext() const { return context; }
57
Chris Lattner1ac20cb2018-07-10 10:59:53 -070058 Identifier getIdentifier(StringRef str);
59 Module *createModule();
60
Chris Lattner158e0a3e2018-07-08 20:51:38 -070061 // Types.
Chris Lattnerc3251192018-07-27 13:09:58 -070062 FloatType *getBF16Type();
63 FloatType *getF16Type();
64 FloatType *getF32Type();
65 FloatType *getF64Type();
66
67 OtherType *getAffineIntType();
68 OtherType *getTFControlType();
James Molloy72b0cbe2018-08-01 12:55:27 -070069 OtherType *getTFStringType();
Chris Lattner158e0a3e2018-07-08 20:51:38 -070070 IntegerType *getIntegerType(unsigned width);
71 FunctionType *getFunctionType(ArrayRef<Type *> inputs,
72 ArrayRef<Type *> results);
Jacques Pienaarc03c6952018-08-10 11:56:47 -070073 MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
74 ArrayRef<AffineMap *> affineMapComposition = {},
75 unsigned memorySpace = 0);
Chris Lattner158e0a3e2018-07-08 20:51:38 -070076 VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
77 RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
78 UnrankedTensorType *getTensorType(Type *elementType);
79
Chris Lattner1ac20cb2018-07-10 10:59:53 -070080 // Attributes.
81 BoolAttr *getBoolAttr(bool value);
82 IntegerAttr *getIntegerAttr(int64_t value);
83 FloatAttr *getFloatAttr(double value);
84 StringAttr *getStringAttr(StringRef bytes);
85 ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
MLIR Teamb61885d2018-07-18 16:29:21 -070086 AffineMapAttr *getAffineMapAttr(AffineMap *value);
James Molloyf0d2f442018-08-03 01:54:46 -070087 TypeAttr *getTypeAttr(Type *type);
Chris Lattner1ac20cb2018-07-10 10:59:53 -070088
89 // Affine Expressions and Affine Map.
90 AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
Uday Bondhugula0115dbb2018-07-11 21:31:07 -070091 ArrayRef<AffineExpr *> results,
92 ArrayRef<AffineExpr *> rangeSizes);
Chris Lattner1ac20cb2018-07-10 10:59:53 -070093 AffineDimExpr *getDimExpr(unsigned position);
94 AffineSymbolExpr *getSymbolExpr(unsigned position);
95 AffineConstantExpr *getConstantExpr(int64_t constant);
96 AffineExpr *getAddExpr(AffineExpr *lhs, AffineExpr *rhs);
97 AffineExpr *getSubExpr(AffineExpr *lhs, AffineExpr *rhs);
98 AffineExpr *getMulExpr(AffineExpr *lhs, AffineExpr *rhs);
99 AffineExpr *getModExpr(AffineExpr *lhs, AffineExpr *rhs);
100 AffineExpr *getFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs);
101 AffineExpr *getCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs);
102
Uday Bondhugulabc535622018-08-07 14:24:38 -0700103 // Integer set.
104 IntegerSet *getIntegerSet(unsigned dimCount, unsigned symbolCount,
105 ArrayRef<AffineExpr *> constraints,
106 ArrayRef<bool> isEq);
107
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700108 // TODO: Helpers for affine map/exprs, etc.
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700109protected:
110 MLIRContext *context;
111};
112
113/// This class helps build a CFGFunction. Instructions that are created are
114/// automatically inserted at an insertion point or added to the current basic
115/// block.
116class CFGFuncBuilder : public Builder {
117public:
Chris Lattner8174f3a2018-07-29 16:45:23 -0700118 CFGFuncBuilder(BasicBlock *block, BasicBlock::iterator insertPoint)
119 : Builder(block->getFunction()->getContext()),
120 function(block->getFunction()) {
121 setInsertionPoint(block, insertPoint);
122 }
123
124 CFGFuncBuilder(OperationInst *insertBefore)
125 : CFGFuncBuilder(insertBefore->getBlock(),
126 BasicBlock::iterator(insertBefore)) {}
127
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700128 CFGFuncBuilder(BasicBlock *block)
129 : Builder(block->getFunction()->getContext()),
130 function(block->getFunction()) {
131 setInsertionPoint(block);
132 }
Chris Lattner8174f3a2018-07-29 16:45:23 -0700133
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700134 CFGFuncBuilder(CFGFunction *function)
135 : Builder(function->getContext()), function(function) {}
136
137 /// Reset the insertion point to no location. Creating an operation without a
138 /// set insertion point is an error, but this can still be useful when the
139 /// current insertion point a builder refers to is being removed.
140 void clearInsertionPoint() {
141 this->block = nullptr;
142 insertPoint = BasicBlock::iterator();
143 }
144
Chris Lattner8174f3a2018-07-29 16:45:23 -0700145 /// Set the insertion point to the specified location.
146 void setInsertionPoint(BasicBlock *block, BasicBlock::iterator insertPoint) {
147 assert(block->getFunction() == function &&
148 "can't move to a different function");
149 this->block = block;
150 this->insertPoint = insertPoint;
151 }
152
153 /// Set the insertion point to the specified operation.
154 void setInsertionPoint(OperationInst *inst) {
155 setInsertionPoint(inst->getBlock(), BasicBlock::iterator(inst));
156 }
157
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700158 /// Set the insertion point to the end of the specified block.
159 void setInsertionPoint(BasicBlock *block) {
Chris Lattner8174f3a2018-07-29 16:45:23 -0700160 setInsertionPoint(block, block->end());
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700161 }
162
Chris Lattner992a1272018-08-07 12:02:37 -0700163 void insert(OperationInst *opInst) {
164 block->getOperations().insert(insertPoint, opInst);
165 }
166
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700167 // Add new basic block and set the insertion point to the end of it.
168 BasicBlock *createBlock();
169
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700170 /// Create an operation given the fields represented as an OperationState.
171 OperationInst *createOperation(const OperationState &state);
172
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700173 /// Create operation of specific op type at the current insertion point.
174 template <typename OpTy, typename... Args>
175 OpPointer<OpTy> create(Args... args) {
Chris Lattner992a1272018-08-07 12:02:37 -0700176 auto *inst = createOperation(OpTy::build(this, args...));
177 auto result = inst->template getAs<OpTy>();
178 assert(result && "Builder didn't return the right type");
179 return result;
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700180 }
181
Uday Bondhugula15984952018-08-01 22:36:12 -0700182 OperationInst *cloneOperation(const OperationInst &srcOpInst) {
183 auto *op = srcOpInst.clone();
Chris Lattner992a1272018-08-07 12:02:37 -0700184 insert(op);
Uday Bondhugula15984952018-08-01 22:36:12 -0700185 return op;
186 }
187
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700188 // Terminators.
189
Chris Lattner40746442018-07-21 14:32:09 -0700190 ReturnInst *createReturnInst(ArrayRef<CFGValue *> operands) {
191 return insertTerminator(ReturnInst::create(operands));
192 }
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700193
194 BranchInst *createBranchInst(BasicBlock *dest) {
Chris Lattner40746442018-07-21 14:32:09 -0700195 return insertTerminator(BranchInst::create(dest));
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700196 }
197
James Molloy4f788372018-07-24 15:01:27 -0700198 CondBranchInst *createCondBranchInst(CFGValue *condition,
199 BasicBlock *trueDest,
200 BasicBlock *falseDest) {
201 return insertTerminator(
202 CondBranchInst::create(condition, trueDest, falseDest));
203 }
204
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700205private:
206 template <typename T>
207 T *insertTerminator(T *term) {
208 block->setTerminator(term);
209 return term;
210 }
211
212 CFGFunction *function;
213 BasicBlock *block = nullptr;
214 BasicBlock::iterator insertPoint;
215};
216
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700217/// This class helps build an MLFunction. Statements that are created are
218/// automatically inserted at an insertion point or added to the current
219/// statement block.
220class MLFuncBuilder : public Builder {
221public:
Chris Lattnere787b322018-08-08 11:14:57 -0700222 /// Create ML function builder and set insertion point to the given statement,
223 /// which will cause subsequent insertions to go right before it.
224 MLFuncBuilder(Statement *stmt)
225 // TODO: Eliminate findFunction from this.
226 : Builder(stmt->findFunction()->getContext()) {
227 setInsertionPoint(stmt);
228 }
229
230 MLFuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint)
231 // TODO: Eliminate findFunction from this.
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700232 : Builder(block->findFunction()->getContext()) {
Chris Lattnere787b322018-08-08 11:14:57 -0700233 setInsertionPoint(block, insertPoint);
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700234 }
235
236 /// Reset the insertion point to no location. Creating an operation without a
237 /// set insertion point is an error, but this can still be useful when the
238 /// current insertion point a builder refers to is being removed.
239 void clearInsertionPoint() {
240 this->block = nullptr;
241 insertPoint = StmtBlock::iterator();
242 }
243
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700244 /// Set the insertion point to the specified location.
245 /// Unlike CFGFuncBuilder, MLFuncBuilder allows to set insertion
246 /// point to a different function.
247 void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) {
248 // TODO: check that insertPoint is in this rather than some other block.
249 this->block = block;
250 this->insertPoint = insertPoint;
251 }
252
253 /// Set the insertion point to the specified operation.
Chris Lattnere787b322018-08-08 11:14:57 -0700254 void setInsertionPoint(Statement *stmt) {
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700255 setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
256 }
257
Chris Lattnere787b322018-08-08 11:14:57 -0700258 /// Set the insertion point to the start of the specified block.
259 void setInsertionPointToStart(StmtBlock *block) {
Uday Bondhugula15984952018-08-01 22:36:12 -0700260 this->block = block;
261 insertPoint = block->begin();
262 }
263
Chris Lattnere787b322018-08-08 11:14:57 -0700264 /// Set the insertion point to the end of the specified block.
265 void setInsertionPointToEnd(StmtBlock *block) {
266 this->block = block;
267 insertPoint = block->end();
268 }
269
Uday Bondhugula84b80952018-08-03 13:22:26 -0700270 /// Get the current insertion point of the builder.
271 StmtBlock::iterator getInsertionPoint() const { return insertPoint; }
272
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700273 /// Create an operation given the fields represented as an OperationState.
274 OperationStmt *createOperation(const OperationState &state);
275
Chris Lattner992a1272018-08-07 12:02:37 -0700276 /// Create operation of specific op type at the current insertion point.
Jacques Pienaarac86d102018-08-03 08:16:37 -0700277 template <typename OpTy, typename... Args>
278 OpPointer<OpTy> create(Args... args) {
Chris Lattner992a1272018-08-07 12:02:37 -0700279 auto stmt = createOperation(OpTy::build(this, args...));
280 auto result = stmt->template getAs<OpTy>();
281 assert(result && "Builder didn't return the right type");
282 return result;
Jacques Pienaarac86d102018-08-03 08:16:37 -0700283 }
284
Chris Lattnere787b322018-08-08 11:14:57 -0700285 /// Create a deep copy of the specified statement, remapping any operands that
286 /// use values outside of the statement using the map that is provided (
287 /// leaving them alone if no entry is present). Replaces references to cloned
288 /// sub-statements to the corresponding statement that is copied, and adds
289 /// those mappings to the map.
290 Statement *clone(const Statement &stmt,
291 OperationStmt::OperandMapTy &operandMapping) {
292 Statement *cloneStmt = stmt.clone(operandMapping, getContext());
Uday Bondhugula134154e2018-08-06 18:40:34 -0700293 block->getStatements().insert(insertPoint, cloneStmt);
294 return cloneStmt;
Uday Bondhugula84b80952018-08-03 13:22:26 -0700295 }
296
Chris Lattner1604e472018-07-23 08:42:19 -0700297 // Creates for statement. When step is not specified, it is set to 1.
Tatiana Shpeisman1da50c42018-07-19 09:52:39 -0700298 ForStmt *createFor(AffineConstantExpr *lowerBound,
299 AffineConstantExpr *upperBound,
300 AffineConstantExpr *step = nullptr);
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700301
Uday Bondhugulabc535622018-08-07 14:24:38 -0700302 IfStmt *createIf(IntegerSet *condition) {
303 auto *stmt = new IfStmt(condition);
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700304 block->getStatements().insert(insertPoint, stmt);
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700305 return stmt;
306 }
307
308private:
309 StmtBlock *block = nullptr;
310 StmtBlock::iterator insertPoint;
311};
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700312
313} // namespace mlir
314
315#endif