//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

18#ifndef MLIR_IR_BUILDERS_H
19#define MLIR_IR_BUILDERS_H
20
Uday Bondhugula15984952018-08-01 22:36:12 -070021#include "mlir/IR/Attributes.h"
Chris Lattner158e0a3e2018-07-08 20:51:38 -070022#include "mlir/IR/CFGFunction.h"
Tatiana Shpeisman565b9642018-07-16 11:47:09 -070023#include "mlir/IR/MLFunction.h"
24#include "mlir/IR/Statements.h"
Chris Lattner158e0a3e2018-07-08 20:51:38 -070025
26namespace mlir {
// Forward declarations of the classes that the builders below only refer to
// through pointers, so their full definitions are not needed here.

// Context and module.
class MLIRContext;
class Module;

// Types.
class Type;
class PrimitiveType;
class IntegerType;
class FunctionType;
class VectorType;
class RankedTensorType;
class UnrankedTensorType;

// Attributes.
class BoolAttr;
class IntegerAttr;
class FloatAttr;
class StringAttr;
class TypeAttr;
class ArrayAttr;
class FunctionAttr;
class AffineMapAttr;

// Affine maps and expressions.
class AffineMap;
class AffineExpr;
class AffineConstantExpr;
class AffineDimExpr;
class AffineSymbolExpr;

50/// This class is a general helper class for creating context-global objects
51/// like types, attributes, and affine expressions.
52class Builder {
53public:
54 explicit Builder(MLIRContext *context) : context(context) {}
55 explicit Builder(Module *module);
56
57 MLIRContext *getContext() const { return context; }
58
Chris Lattner1ac20cb2018-07-10 10:59:53 -070059 Identifier getIdentifier(StringRef str);
60 Module *createModule();
61
Chris Lattner158e0a3e2018-07-08 20:51:38 -070062 // Types.
Chris Lattnerc3251192018-07-27 13:09:58 -070063 FloatType *getBF16Type();
64 FloatType *getF16Type();
65 FloatType *getF32Type();
66 FloatType *getF64Type();
67
68 OtherType *getAffineIntType();
69 OtherType *getTFControlType();
James Molloy72b0cbe2018-08-01 12:55:27 -070070 OtherType *getTFStringType();
Chris Lattner158e0a3e2018-07-08 20:51:38 -070071 IntegerType *getIntegerType(unsigned width);
72 FunctionType *getFunctionType(ArrayRef<Type *> inputs,
73 ArrayRef<Type *> results);
Jacques Pienaarc03c6952018-08-10 11:56:47 -070074 MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
75 ArrayRef<AffineMap *> affineMapComposition = {},
76 unsigned memorySpace = 0);
Chris Lattner158e0a3e2018-07-08 20:51:38 -070077 VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
78 RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
79 UnrankedTensorType *getTensorType(Type *elementType);
80
Chris Lattner1ac20cb2018-07-10 10:59:53 -070081 // Attributes.
82 BoolAttr *getBoolAttr(bool value);
83 IntegerAttr *getIntegerAttr(int64_t value);
84 FloatAttr *getFloatAttr(double value);
85 StringAttr *getStringAttr(StringRef bytes);
86 ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
MLIR Teamb61885d2018-07-18 16:29:21 -070087 AffineMapAttr *getAffineMapAttr(AffineMap *value);
James Molloyf0d2f442018-08-03 01:54:46 -070088 TypeAttr *getTypeAttr(Type *type);
Chris Lattner1aa46322018-08-21 17:55:22 -070089 FunctionAttr *getFunctionAttr(const Function *value);
Chris Lattner1ac20cb2018-07-10 10:59:53 -070090
91 // Affine Expressions and Affine Map.
92 AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
Uday Bondhugula0115dbb2018-07-11 21:31:07 -070093 ArrayRef<AffineExpr *> results,
94 ArrayRef<AffineExpr *> rangeSizes);
Chris Lattner1ac20cb2018-07-10 10:59:53 -070095 AffineDimExpr *getDimExpr(unsigned position);
96 AffineSymbolExpr *getSymbolExpr(unsigned position);
97 AffineConstantExpr *getConstantExpr(int64_t constant);
98 AffineExpr *getAddExpr(AffineExpr *lhs, AffineExpr *rhs);
99 AffineExpr *getSubExpr(AffineExpr *lhs, AffineExpr *rhs);
100 AffineExpr *getMulExpr(AffineExpr *lhs, AffineExpr *rhs);
101 AffineExpr *getModExpr(AffineExpr *lhs, AffineExpr *rhs);
102 AffineExpr *getFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs);
103 AffineExpr *getCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs);
104
Uday Bondhugulabc535622018-08-07 14:24:38 -0700105 // Integer set.
106 IntegerSet *getIntegerSet(unsigned dimCount, unsigned symbolCount,
107 ArrayRef<AffineExpr *> constraints,
108 ArrayRef<bool> isEq);
109
Tatiana Shpeismande8829f2018-08-24 23:38:14 -0700110 // Special cases of affine maps and integer sets
111 // One constant result: () -> (val).
112 AffineMap *getConstantMap(int64_t val);
113 // One dimension id identity map: (i) -> (i).
114 AffineMap *getDimIdentityMap();
115 // One symbol identity map: ()[s] -> (s).
116 AffineMap *getSymbolIdentityMap();
117
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700118 // TODO: Helpers for affine map/exprs, etc.
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700119protected:
120 MLIRContext *context;
121};
122
123/// This class helps build a CFGFunction. Instructions that are created are
124/// automatically inserted at an insertion point or added to the current basic
125/// block.
126class CFGFuncBuilder : public Builder {
127public:
Chris Lattner8174f3a2018-07-29 16:45:23 -0700128 CFGFuncBuilder(BasicBlock *block, BasicBlock::iterator insertPoint)
129 : Builder(block->getFunction()->getContext()),
130 function(block->getFunction()) {
131 setInsertionPoint(block, insertPoint);
132 }
133
134 CFGFuncBuilder(OperationInst *insertBefore)
135 : CFGFuncBuilder(insertBefore->getBlock(),
136 BasicBlock::iterator(insertBefore)) {}
137
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700138 CFGFuncBuilder(BasicBlock *block)
139 : Builder(block->getFunction()->getContext()),
140 function(block->getFunction()) {
141 setInsertionPoint(block);
142 }
Chris Lattner8174f3a2018-07-29 16:45:23 -0700143
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700144 CFGFuncBuilder(CFGFunction *function)
145 : Builder(function->getContext()), function(function) {}
146
147 /// Reset the insertion point to no location. Creating an operation without a
148 /// set insertion point is an error, but this can still be useful when the
149 /// current insertion point a builder refers to is being removed.
150 void clearInsertionPoint() {
151 this->block = nullptr;
152 insertPoint = BasicBlock::iterator();
153 }
154
Chris Lattner8174f3a2018-07-29 16:45:23 -0700155 /// Set the insertion point to the specified location.
156 void setInsertionPoint(BasicBlock *block, BasicBlock::iterator insertPoint) {
157 assert(block->getFunction() == function &&
158 "can't move to a different function");
159 this->block = block;
160 this->insertPoint = insertPoint;
161 }
162
163 /// Set the insertion point to the specified operation.
164 void setInsertionPoint(OperationInst *inst) {
165 setInsertionPoint(inst->getBlock(), BasicBlock::iterator(inst));
166 }
167
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700168 /// Set the insertion point to the end of the specified block.
169 void setInsertionPoint(BasicBlock *block) {
Chris Lattner8174f3a2018-07-29 16:45:23 -0700170 setInsertionPoint(block, block->end());
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700171 }
172
Chris Lattner992a1272018-08-07 12:02:37 -0700173 void insert(OperationInst *opInst) {
174 block->getOperations().insert(insertPoint, opInst);
175 }
176
Chris Lattner8a9310a2018-08-24 21:13:19 -0700177 /// Add new basic block and set the insertion point to the end of it. If an
178 /// 'insertBefore' basic block is passed, the block will be placed before the
179 /// specified block. If not, the block will be appended to the end of the
180 /// current function.
181 BasicBlock *createBlock(BasicBlock *insertBefore = nullptr);
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700182
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700183 /// Create an operation given the fields represented as an OperationState.
184 OperationInst *createOperation(const OperationState &state);
185
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700186 /// Create operation of specific op type at the current insertion point.
187 template <typename OpTy, typename... Args>
Chris Lattner1628fa02018-08-23 14:32:25 -0700188 OpPointer<OpTy> create(Attribute *location, Args... args) {
189 OperationState state(getContext(), location, OpTy::getOperationName());
Chris Lattner1eb77482018-08-22 19:25:49 -0700190 OpTy::build(this, &state, args...);
191 auto *inst = createOperation(state);
Chris Lattner992a1272018-08-07 12:02:37 -0700192 auto result = inst->template getAs<OpTy>();
193 assert(result && "Builder didn't return the right type");
194 return result;
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700195 }
196
Uday Bondhugula15984952018-08-01 22:36:12 -0700197 OperationInst *cloneOperation(const OperationInst &srcOpInst) {
198 auto *op = srcOpInst.clone();
Chris Lattner992a1272018-08-07 12:02:37 -0700199 insert(op);
Uday Bondhugula15984952018-08-01 22:36:12 -0700200 return op;
201 }
202
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700203 // Terminators.
204
Chris Lattner091a6b52018-08-23 14:58:27 -0700205 ReturnInst *createReturn(Attribute *location, ArrayRef<CFGValue *> operands) {
Chris Lattner1628fa02018-08-23 14:32:25 -0700206 return insertTerminator(ReturnInst::create(location, operands));
Chris Lattner40746442018-07-21 14:32:09 -0700207 }
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700208
Chris Lattner8a9310a2018-08-24 21:13:19 -0700209 BranchInst *createBranch(Attribute *location, BasicBlock *dest,
210 ArrayRef<CFGValue *> operands = {}) {
211 return insertTerminator(BranchInst::create(location, dest, operands));
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700212 }
213
Chris Lattner091a6b52018-08-23 14:58:27 -0700214 CondBranchInst *createCondBranch(Attribute *location, CFGValue *condition,
215 BasicBlock *trueDest,
216 BasicBlock *falseDest) {
James Molloy4f788372018-07-24 15:01:27 -0700217 return insertTerminator(
Chris Lattner1628fa02018-08-23 14:32:25 -0700218 CondBranchInst::create(location, condition, trueDest, falseDest));
James Molloy4f788372018-07-24 15:01:27 -0700219 }
220
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700221private:
222 template <typename T>
223 T *insertTerminator(T *term) {
224 block->setTerminator(term);
225 return term;
226 }
227
228 CFGFunction *function;
229 BasicBlock *block = nullptr;
230 BasicBlock::iterator insertPoint;
231};
232
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700233/// This class helps build an MLFunction. Statements that are created are
234/// automatically inserted at an insertion point or added to the current
235/// statement block.
236class MLFuncBuilder : public Builder {
237public:
Chris Lattnere787b322018-08-08 11:14:57 -0700238 /// Create ML function builder and set insertion point to the given statement,
239 /// which will cause subsequent insertions to go right before it.
240 MLFuncBuilder(Statement *stmt)
241 // TODO: Eliminate findFunction from this.
242 : Builder(stmt->findFunction()->getContext()) {
243 setInsertionPoint(stmt);
244 }
245
246 MLFuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint)
247 // TODO: Eliminate findFunction from this.
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700248 : Builder(block->findFunction()->getContext()) {
Chris Lattnere787b322018-08-08 11:14:57 -0700249 setInsertionPoint(block, insertPoint);
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700250 }
251
252 /// Reset the insertion point to no location. Creating an operation without a
253 /// set insertion point is an error, but this can still be useful when the
254 /// current insertion point a builder refers to is being removed.
255 void clearInsertionPoint() {
256 this->block = nullptr;
257 insertPoint = StmtBlock::iterator();
258 }
259
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700260 /// Set the insertion point to the specified location.
261 /// Unlike CFGFuncBuilder, MLFuncBuilder allows to set insertion
262 /// point to a different function.
263 void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) {
264 // TODO: check that insertPoint is in this rather than some other block.
265 this->block = block;
266 this->insertPoint = insertPoint;
267 }
268
Uday Bondhugula67701712018-08-21 16:01:23 -0700269 /// Set the insertion point to the specified operation, which will cause
270 /// subsequent insertions to go right before it.
Chris Lattnere787b322018-08-08 11:14:57 -0700271 void setInsertionPoint(Statement *stmt) {
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700272 setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
273 }
274
Chris Lattnere787b322018-08-08 11:14:57 -0700275 /// Set the insertion point to the start of the specified block.
276 void setInsertionPointToStart(StmtBlock *block) {
Tatiana Shpeismande8829f2018-08-24 23:38:14 -0700277 setInsertionPoint(block, block->begin());
Uday Bondhugula15984952018-08-01 22:36:12 -0700278 }
279
Chris Lattnere787b322018-08-08 11:14:57 -0700280 /// Set the insertion point to the end of the specified block.
281 void setInsertionPointToEnd(StmtBlock *block) {
Tatiana Shpeismande8829f2018-08-24 23:38:14 -0700282 setInsertionPoint(block, block->end());
Chris Lattnere787b322018-08-08 11:14:57 -0700283 }
284
Uday Bondhugula84b80952018-08-03 13:22:26 -0700285 /// Get the current insertion point of the builder.
286 StmtBlock::iterator getInsertionPoint() const { return insertPoint; }
287
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700288 /// Create an operation given the fields represented as an OperationState.
289 OperationStmt *createOperation(const OperationState &state);
290
Chris Lattner992a1272018-08-07 12:02:37 -0700291 /// Create operation of specific op type at the current insertion point.
Jacques Pienaarac86d102018-08-03 08:16:37 -0700292 template <typename OpTy, typename... Args>
Chris Lattner1628fa02018-08-23 14:32:25 -0700293 OpPointer<OpTy> create(Attribute *location, Args... args) {
294 OperationState state(getContext(), location, OpTy::getOperationName());
Chris Lattner1eb77482018-08-22 19:25:49 -0700295 OpTy::build(this, &state, args...);
296 auto *stmt = createOperation(state);
Chris Lattner992a1272018-08-07 12:02:37 -0700297 auto result = stmt->template getAs<OpTy>();
298 assert(result && "Builder didn't return the right type");
299 return result;
Jacques Pienaarac86d102018-08-03 08:16:37 -0700300 }
301
Chris Lattnere787b322018-08-08 11:14:57 -0700302 /// Create a deep copy of the specified statement, remapping any operands that
303 /// use values outside of the statement using the map that is provided (
304 /// leaving them alone if no entry is present). Replaces references to cloned
305 /// sub-statements to the corresponding statement that is copied, and adds
306 /// those mappings to the map.
307 Statement *clone(const Statement &stmt,
308 OperationStmt::OperandMapTy &operandMapping) {
309 Statement *cloneStmt = stmt.clone(operandMapping, getContext());
Uday Bondhugula134154e2018-08-06 18:40:34 -0700310 block->getStatements().insert(insertPoint, cloneStmt);
311 return cloneStmt;
Uday Bondhugula84b80952018-08-03 13:22:26 -0700312 }
313
Tatiana Shpeismande8829f2018-08-24 23:38:14 -0700314 /// Create a 'for' statement with bounds that may involve MLValue operands.
315 /// When step is not specified, it is set to 1.
316 ForStmt *createFor(Attribute *location, ArrayRef<MLValue *> lbOperands,
317 AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
318 AffineMap *ubMap, int64_t step = 1);
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700319
Tatiana Shpeismande8829f2018-08-24 23:38:14 -0700320 /// Create if statement.
321 /// TODO: pass operands.
Chris Lattner1628fa02018-08-23 14:32:25 -0700322 IfStmt *createIf(Attribute *location, IntegerSet *condition);
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700323
324private:
325 StmtBlock *block = nullptr;
326 StmtBlock::iterator insertPoint;
327};
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700328
329} // namespace mlir
330
331#endif