blob: 7cd0dd52d10da5802496226e67ecbc9b01b27846 [file] [log] [blame]
Chris Lattner158e0a3e2018-07-08 20:51:38 -07001//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#ifndef MLIR_IR_BUILDERS_H
19#define MLIR_IR_BUILDERS_H
20
21#include "mlir/IR/CFGFunction.h"
Tatiana Shpeisman565b9642018-07-16 11:47:09 -070022#include "mlir/IR/MLFunction.h"
23#include "mlir/IR/Statements.h"
Chris Lattner158e0a3e2018-07-08 20:51:38 -070024
25namespace mlir {
26class MLIRContext;
27class Module;
Chris Lattnerfc647d52018-08-27 21:05:16 -070028class UnknownLoc;
29class UniquedFilename;
30class FileLineColLoc;
Chris Lattner158e0a3e2018-07-08 20:51:38 -070031class Type;
32class PrimitiveType;
33class IntegerType;
34class FunctionType;
35class VectorType;
36class RankedTensorType;
37class UnrankedTensorType;
Chris Lattner1ac20cb2018-07-10 10:59:53 -070038class BoolAttr;
39class IntegerAttr;
40class FloatAttr;
41class StringAttr;
James Molloyf0d2f442018-08-03 01:54:46 -070042class TypeAttr;
Chris Lattner1ac20cb2018-07-10 10:59:53 -070043class ArrayAttr;
Chris Lattner4613d9e2018-08-19 21:17:22 -070044class FunctionAttr;
MLIR Teamb61885d2018-07-18 16:29:21 -070045class AffineMapAttr;
Chris Lattner1ac20cb2018-07-10 10:59:53 -070046class AffineMap;
47class AffineExpr;
48class AffineConstantExpr;
49class AffineDimExpr;
50class AffineSymbolExpr;
Chris Lattner158e0a3e2018-07-08 20:51:38 -070051
52/// This class is a general helper class for creating context-global objects
53/// like types, attributes, and affine expressions.
54class Builder {
55public:
56 explicit Builder(MLIRContext *context) : context(context) {}
57 explicit Builder(Module *module);
58
59 MLIRContext *getContext() const { return context; }
60
Chris Lattner1ac20cb2018-07-10 10:59:53 -070061 Identifier getIdentifier(StringRef str);
62 Module *createModule();
63
Chris Lattnerfc647d52018-08-27 21:05:16 -070064 // Locations.
65 UnknownLoc *getUnknownLoc();
66 UniquedFilename getUniquedFilename(StringRef filename);
67 FileLineColLoc *getFileLineColLoc(UniquedFilename filename, unsigned line,
68 unsigned column);
69
Chris Lattner158e0a3e2018-07-08 20:51:38 -070070 // Types.
Chris Lattnerc3251192018-07-27 13:09:58 -070071 FloatType *getBF16Type();
72 FloatType *getF16Type();
73 FloatType *getF32Type();
74 FloatType *getF64Type();
75
76 OtherType *getAffineIntType();
77 OtherType *getTFControlType();
James Molloy72b0cbe2018-08-01 12:55:27 -070078 OtherType *getTFStringType();
Chris Lattner158e0a3e2018-07-08 20:51:38 -070079 IntegerType *getIntegerType(unsigned width);
80 FunctionType *getFunctionType(ArrayRef<Type *> inputs,
81 ArrayRef<Type *> results);
Jacques Pienaarc03c6952018-08-10 11:56:47 -070082 MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
83 ArrayRef<AffineMap *> affineMapComposition = {},
84 unsigned memorySpace = 0);
Chris Lattner158e0a3e2018-07-08 20:51:38 -070085 VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
86 RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
87 UnrankedTensorType *getTensorType(Type *elementType);
88
Chris Lattner1ac20cb2018-07-10 10:59:53 -070089 // Attributes.
90 BoolAttr *getBoolAttr(bool value);
91 IntegerAttr *getIntegerAttr(int64_t value);
92 FloatAttr *getFloatAttr(double value);
93 StringAttr *getStringAttr(StringRef bytes);
94 ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
MLIR Teamb61885d2018-07-18 16:29:21 -070095 AffineMapAttr *getAffineMapAttr(AffineMap *value);
James Molloyf0d2f442018-08-03 01:54:46 -070096 TypeAttr *getTypeAttr(Type *type);
Chris Lattner1aa46322018-08-21 17:55:22 -070097 FunctionAttr *getFunctionAttr(const Function *value);
Chris Lattner1ac20cb2018-07-10 10:59:53 -070098
99 // Affine Expressions and Affine Map.
100 AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
Uday Bondhugula0115dbb2018-07-11 21:31:07 -0700101 ArrayRef<AffineExpr *> results,
102 ArrayRef<AffineExpr *> rangeSizes);
Chris Lattner1ac20cb2018-07-10 10:59:53 -0700103 AffineDimExpr *getDimExpr(unsigned position);
104 AffineSymbolExpr *getSymbolExpr(unsigned position);
105 AffineConstantExpr *getConstantExpr(int64_t constant);
106 AffineExpr *getAddExpr(AffineExpr *lhs, AffineExpr *rhs);
107 AffineExpr *getSubExpr(AffineExpr *lhs, AffineExpr *rhs);
108 AffineExpr *getMulExpr(AffineExpr *lhs, AffineExpr *rhs);
Uday Bondhugulacf4f4c42018-09-12 10:21:23 -0700109 AffineExpr *getMulExpr(AffineExpr *lhs, int64_t rhs);
Chris Lattner1ac20cb2018-07-10 10:59:53 -0700110 AffineExpr *getModExpr(AffineExpr *lhs, AffineExpr *rhs);
Uday Bondhugulacf4f4c42018-09-12 10:21:23 -0700111 AffineExpr *getModExpr(AffineExpr *lhs, uint64_t rhs);
Chris Lattner1ac20cb2018-07-10 10:59:53 -0700112 AffineExpr *getFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs);
Uday Bondhugulacf4f4c42018-09-12 10:21:23 -0700113 AffineExpr *getFloorDivExpr(AffineExpr *lhs, uint64_t rhs);
Chris Lattner1ac20cb2018-07-10 10:59:53 -0700114 AffineExpr *getCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs);
Uday Bondhugulacf4f4c42018-09-12 10:21:23 -0700115 AffineExpr *getCeilDivExpr(AffineExpr *lhs, uint64_t rhs);
Chris Lattner1ac20cb2018-07-10 10:59:53 -0700116
Uday Bondhugulabc535622018-08-07 14:24:38 -0700117 // Integer set.
118 IntegerSet *getIntegerSet(unsigned dimCount, unsigned symbolCount,
119 ArrayRef<AffineExpr *> constraints,
120 ArrayRef<bool> isEq);
121
Tatiana Shpeismande8829f2018-08-24 23:38:14 -0700122 // Special cases of affine maps and integer sets
123 // One constant result: () -> (val).
124 AffineMap *getConstantMap(int64_t val);
125 // One dimension id identity map: (i) -> (i).
126 AffineMap *getDimIdentityMap();
127 // One symbol identity map: ()[s] -> (s).
128 AffineMap *getSymbolIdentityMap();
129
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700130 // TODO: Helpers for affine map/exprs, etc.
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700131protected:
132 MLIRContext *context;
133};
134
135/// This class helps build a CFGFunction. Instructions that are created are
136/// automatically inserted at an insertion point or added to the current basic
137/// block.
138class CFGFuncBuilder : public Builder {
139public:
Chris Lattner8174f3a2018-07-29 16:45:23 -0700140 CFGFuncBuilder(BasicBlock *block, BasicBlock::iterator insertPoint)
141 : Builder(block->getFunction()->getContext()),
142 function(block->getFunction()) {
143 setInsertionPoint(block, insertPoint);
144 }
145
146 CFGFuncBuilder(OperationInst *insertBefore)
147 : CFGFuncBuilder(insertBefore->getBlock(),
148 BasicBlock::iterator(insertBefore)) {}
149
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700150 CFGFuncBuilder(BasicBlock *block)
151 : Builder(block->getFunction()->getContext()),
152 function(block->getFunction()) {
153 setInsertionPoint(block);
154 }
Chris Lattner8174f3a2018-07-29 16:45:23 -0700155
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700156 CFGFuncBuilder(CFGFunction *function)
157 : Builder(function->getContext()), function(function) {}
158
159 /// Reset the insertion point to no location. Creating an operation without a
160 /// set insertion point is an error, but this can still be useful when the
161 /// current insertion point a builder refers to is being removed.
162 void clearInsertionPoint() {
163 this->block = nullptr;
164 insertPoint = BasicBlock::iterator();
165 }
166
Chris Lattner8174f3a2018-07-29 16:45:23 -0700167 /// Set the insertion point to the specified location.
168 void setInsertionPoint(BasicBlock *block, BasicBlock::iterator insertPoint) {
169 assert(block->getFunction() == function &&
170 "can't move to a different function");
171 this->block = block;
172 this->insertPoint = insertPoint;
173 }
174
175 /// Set the insertion point to the specified operation.
176 void setInsertionPoint(OperationInst *inst) {
177 setInsertionPoint(inst->getBlock(), BasicBlock::iterator(inst));
178 }
179
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700180 /// Set the insertion point to the end of the specified block.
181 void setInsertionPoint(BasicBlock *block) {
Chris Lattner8174f3a2018-07-29 16:45:23 -0700182 setInsertionPoint(block, block->end());
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700183 }
184
Chris Lattner992a1272018-08-07 12:02:37 -0700185 void insert(OperationInst *opInst) {
186 block->getOperations().insert(insertPoint, opInst);
187 }
188
Chris Lattner8a9310a2018-08-24 21:13:19 -0700189 /// Add new basic block and set the insertion point to the end of it. If an
190 /// 'insertBefore' basic block is passed, the block will be placed before the
191 /// specified block. If not, the block will be appended to the end of the
192 /// current function.
193 BasicBlock *createBlock(BasicBlock *insertBefore = nullptr);
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700194
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700195 /// Create an operation given the fields represented as an OperationState.
196 OperationInst *createOperation(const OperationState &state);
197
Chris Lattner7879f842018-09-02 22:01:45 -0700198 /// Create operation of specific op type at the current insertion point
199 /// without verifying to see if it is valid.
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700200 template <typename OpTy, typename... Args>
Chris Lattnerfc647d52018-08-27 21:05:16 -0700201 OpPointer<OpTy> create(Location *location, Args... args) {
Chris Lattner1628fa02018-08-23 14:32:25 -0700202 OperationState state(getContext(), location, OpTy::getOperationName());
Chris Lattner1eb77482018-08-22 19:25:49 -0700203 OpTy::build(this, &state, args...);
204 auto *inst = createOperation(state);
Chris Lattner992a1272018-08-07 12:02:37 -0700205 auto result = inst->template getAs<OpTy>();
206 assert(result && "Builder didn't return the right type");
207 return result;
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700208 }
209
Chris Lattner7879f842018-09-02 22:01:45 -0700210 /// Create operation of specific op type at the current insertion point. If
211 /// the result is an invalid op (the verifier hook fails), emit a the
212 /// specified error message and return null.
213 template <typename OpTy, typename... Args>
Chris Lattner015100b2018-09-09 20:40:23 -0700214 OpPointer<OpTy> createChecked(Location *location, Args... args) {
Chris Lattner7879f842018-09-02 22:01:45 -0700215 OperationState state(getContext(), location, OpTy::getOperationName());
216 OpTy::build(this, &state, args...);
217 auto *inst = createOperation(state);
Chris Lattner7879f842018-09-02 22:01:45 -0700218
Chris Lattner015100b2018-09-09 20:40:23 -0700219 // If the OperationInst we produce is valid, return it.
220 if (!OpTy::verifyInvariants(inst)) {
221 auto result = inst->template getAs<OpTy>();
222 assert(result && "Builder didn't return the right type");
Chris Lattner7879f842018-09-02 22:01:45 -0700223 return result;
Chris Lattner015100b2018-09-09 20:40:23 -0700224 }
225
226 // Otherwise, the error message got emitted. Just remove the instruction
227 // we made.
Chris Lattner7879f842018-09-02 22:01:45 -0700228 inst->eraseFromBlock();
229 return OpPointer<OpTy>();
230 }
231
Uday Bondhugula15984952018-08-01 22:36:12 -0700232 OperationInst *cloneOperation(const OperationInst &srcOpInst) {
233 auto *op = srcOpInst.clone();
Chris Lattner992a1272018-08-07 12:02:37 -0700234 insert(op);
Uday Bondhugula15984952018-08-01 22:36:12 -0700235 return op;
236 }
237
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700238 // Terminators.
239
Chris Lattnerfc647d52018-08-27 21:05:16 -0700240 ReturnInst *createReturn(Location *location, ArrayRef<CFGValue *> operands) {
Chris Lattner1628fa02018-08-23 14:32:25 -0700241 return insertTerminator(ReturnInst::create(location, operands));
Chris Lattner40746442018-07-21 14:32:09 -0700242 }
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700243
Chris Lattnerfc647d52018-08-27 21:05:16 -0700244 BranchInst *createBranch(Location *location, BasicBlock *dest,
Chris Lattner8a9310a2018-08-24 21:13:19 -0700245 ArrayRef<CFGValue *> operands = {}) {
246 return insertTerminator(BranchInst::create(location, dest, operands));
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700247 }
248
Chris Lattnerfc647d52018-08-27 21:05:16 -0700249 CondBranchInst *createCondBranch(Location *location, CFGValue *condition,
Chris Lattner091a6b52018-08-23 14:58:27 -0700250 BasicBlock *trueDest,
251 BasicBlock *falseDest) {
James Molloy4f788372018-07-24 15:01:27 -0700252 return insertTerminator(
Chris Lattner1628fa02018-08-23 14:32:25 -0700253 CondBranchInst::create(location, condition, trueDest, falseDest));
James Molloy4f788372018-07-24 15:01:27 -0700254 }
255
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700256private:
257 template <typename T>
258 T *insertTerminator(T *term) {
259 block->setTerminator(term);
260 return term;
261 }
262
263 CFGFunction *function;
264 BasicBlock *block = nullptr;
265 BasicBlock::iterator insertPoint;
266};
267
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700268/// This class helps build an MLFunction. Statements that are created are
269/// automatically inserted at an insertion point or added to the current
270/// statement block.
271class MLFuncBuilder : public Builder {
272public:
Chris Lattnere787b322018-08-08 11:14:57 -0700273 /// Create ML function builder and set insertion point to the given statement,
274 /// which will cause subsequent insertions to go right before it.
275 MLFuncBuilder(Statement *stmt)
276 // TODO: Eliminate findFunction from this.
277 : Builder(stmt->findFunction()->getContext()) {
278 setInsertionPoint(stmt);
279 }
280
281 MLFuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint)
282 // TODO: Eliminate findFunction from this.
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700283 : Builder(block->findFunction()->getContext()) {
Chris Lattnere787b322018-08-08 11:14:57 -0700284 setInsertionPoint(block, insertPoint);
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700285 }
286
287 /// Reset the insertion point to no location. Creating an operation without a
288 /// set insertion point is an error, but this can still be useful when the
289 /// current insertion point a builder refers to is being removed.
290 void clearInsertionPoint() {
291 this->block = nullptr;
292 insertPoint = StmtBlock::iterator();
293 }
294
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700295 /// Set the insertion point to the specified location.
296 /// Unlike CFGFuncBuilder, MLFuncBuilder allows to set insertion
297 /// point to a different function.
298 void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) {
299 // TODO: check that insertPoint is in this rather than some other block.
300 this->block = block;
301 this->insertPoint = insertPoint;
302 }
303
Uday Bondhugula67701712018-08-21 16:01:23 -0700304 /// Set the insertion point to the specified operation, which will cause
305 /// subsequent insertions to go right before it.
Chris Lattnere787b322018-08-08 11:14:57 -0700306 void setInsertionPoint(Statement *stmt) {
Tatiana Shpeismand880b352018-07-31 23:14:16 -0700307 setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
308 }
309
Chris Lattnere787b322018-08-08 11:14:57 -0700310 /// Set the insertion point to the start of the specified block.
311 void setInsertionPointToStart(StmtBlock *block) {
Tatiana Shpeismande8829f2018-08-24 23:38:14 -0700312 setInsertionPoint(block, block->begin());
Uday Bondhugula15984952018-08-01 22:36:12 -0700313 }
314
Chris Lattnere787b322018-08-08 11:14:57 -0700315 /// Set the insertion point to the end of the specified block.
316 void setInsertionPointToEnd(StmtBlock *block) {
Tatiana Shpeismande8829f2018-08-24 23:38:14 -0700317 setInsertionPoint(block, block->end());
Chris Lattnere787b322018-08-08 11:14:57 -0700318 }
319
Uday Bondhugula84b80952018-08-03 13:22:26 -0700320 /// Get the current insertion point of the builder.
321 StmtBlock::iterator getInsertionPoint() const { return insertPoint; }
322
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700323 /// Create an operation given the fields represented as an OperationState.
324 OperationStmt *createOperation(const OperationState &state);
325
Chris Lattner992a1272018-08-07 12:02:37 -0700326 /// Create operation of specific op type at the current insertion point.
Jacques Pienaarac86d102018-08-03 08:16:37 -0700327 template <typename OpTy, typename... Args>
Chris Lattnerfc647d52018-08-27 21:05:16 -0700328 OpPointer<OpTy> create(Location *location, Args... args) {
Chris Lattner1628fa02018-08-23 14:32:25 -0700329 OperationState state(getContext(), location, OpTy::getOperationName());
Chris Lattner1eb77482018-08-22 19:25:49 -0700330 OpTy::build(this, &state, args...);
331 auto *stmt = createOperation(state);
Chris Lattner992a1272018-08-07 12:02:37 -0700332 auto result = stmt->template getAs<OpTy>();
333 assert(result && "Builder didn't return the right type");
334 return result;
Jacques Pienaarac86d102018-08-03 08:16:37 -0700335 }
336
Chris Lattner7879f842018-09-02 22:01:45 -0700337 /// Create operation of specific op type at the current insertion point. If
338 /// the result is an invalid op (the verifier hook fails), emit an error and
339 /// return null.
340 template <typename OpTy, typename... Args>
Chris Lattner015100b2018-09-09 20:40:23 -0700341 OpPointer<OpTy> createChecked(Location *location, Args... args) {
Chris Lattner7879f842018-09-02 22:01:45 -0700342 OperationState state(getContext(), location, OpTy::getOperationName());
343 OpTy::build(this, &state, args...);
344 auto *stmt = createOperation(state);
Chris Lattner7879f842018-09-02 22:01:45 -0700345
Chris Lattner015100b2018-09-09 20:40:23 -0700346 // If the OperationStmt we produce is valid, return it.
347 if (!OpTy::verifyInvariants(stmt)) {
348 auto result = stmt->template getAs<OpTy>();
349 assert(result && "Builder didn't return the right type");
Chris Lattner7879f842018-09-02 22:01:45 -0700350 return result;
Chris Lattner015100b2018-09-09 20:40:23 -0700351 }
352
353 // Otherwise, the error message got emitted. Just remove the statement
354 // we made.
Chris Lattner7879f842018-09-02 22:01:45 -0700355 stmt->eraseFromBlock();
356 return OpPointer<OpTy>();
357 }
358
Chris Lattnere787b322018-08-08 11:14:57 -0700359 /// Create a deep copy of the specified statement, remapping any operands that
360 /// use values outside of the statement using the map that is provided (
361 /// leaving them alone if no entry is present). Replaces references to cloned
362 /// sub-statements to the corresponding statement that is copied, and adds
363 /// those mappings to the map.
364 Statement *clone(const Statement &stmt,
365 OperationStmt::OperandMapTy &operandMapping) {
366 Statement *cloneStmt = stmt.clone(operandMapping, getContext());
Uday Bondhugula134154e2018-08-06 18:40:34 -0700367 block->getStatements().insert(insertPoint, cloneStmt);
368 return cloneStmt;
Uday Bondhugula84b80952018-08-03 13:22:26 -0700369 }
370
Tatiana Shpeismanc6aa35b2018-08-28 15:26:20 -0700371 // Create for statement. When step is not specified, it is set to 1.
Chris Lattnerfc647d52018-08-27 21:05:16 -0700372 ForStmt *createFor(Location *location, ArrayRef<MLValue *> lbOperands,
Tatiana Shpeismande8829f2018-08-24 23:38:14 -0700373 AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
374 AffineMap *ubMap, int64_t step = 1);
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700375
Tatiana Shpeismanc6aa35b2018-08-28 15:26:20 -0700376 /// Create if statement.
377 IfStmt *createIf(Location *location, ArrayRef<MLValue *> operands,
378 IntegerSet *set);
Tatiana Shpeisman565b9642018-07-16 11:47:09 -0700379
380private:
381 StmtBlock *block = nullptr;
382 StmtBlock::iterator insertPoint;
383};
Chris Lattner158e0a3e2018-07-08 20:51:38 -0700384
385} // namespace mlir
386
387#endif