Introduce the start of IR builder APIs, which makes it easier and less error
prone to create things.
PiperOrigin-RevId: 203703229
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
new file mode 100644
index 0000000..644c812
--- /dev/null
+++ b/include/mlir/IR/Builders.h
@@ -0,0 +1,122 @@
+//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_IR_BUILDERS_H
+#define MLIR_IR_BUILDERS_H
+
+#include "mlir/IR/CFGFunction.h"
+
+namespace mlir {
+class MLIRContext;
+class Module;
+class Type;
+class PrimitiveType;
+class IntegerType;
+class FunctionType;
+class VectorType;
+class RankedTensorType;
+class UnrankedTensorType;
+
+/// This class is a general helper class for creating context-global objects
+/// like types, attributes, and affine expressions.
+class Builder {
+public:
+ explicit Builder(MLIRContext *context) : context(context) {}
+ explicit Builder(Module *module);
+
+ MLIRContext *getContext() const { return context; }
+
+ // Types.
+ PrimitiveType *getAffineIntType();
+ PrimitiveType *getBF16Type();
+ PrimitiveType *getF16Type();
+ PrimitiveType *getF32Type();
+ PrimitiveType *getF64Type();
+ IntegerType *getIntegerType(unsigned width);
+ FunctionType *getFunctionType(ArrayRef<Type *> inputs,
+ ArrayRef<Type *> results);
+ VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
+ RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
+ UnrankedTensorType *getTensorType(Type *elementType);
+
+ // TODO: Helpers for affine map/exprs, etc.
+ // TODO: Helpers for attributes.
+ // TODO: Identifier
+ // TODO: createModule()
+protected:
+ MLIRContext *context;
+};
+
+/// This class helps build a CFGFunction. Instructions that are created are
+/// automatically inserted at an insertion point or added to the current basic
+/// block.
+class CFGFuncBuilder : public Builder {
+public:
+ CFGFuncBuilder(BasicBlock *block)
+ : Builder(block->getFunction()->getContext()),
+ function(block->getFunction()) {
+ setInsertionPoint(block);
+ }
+ CFGFuncBuilder(CFGFunction *function)
+ : Builder(function->getContext()), function(function) {}
+
+ /// Reset the insertion point to no location. Creating an operation without a
+ /// set insertion point is an error, but this can still be useful when the
+ /// current insertion point a builder refers to is being removed.
+ void clearInsertionPoint() {
+ this->block = nullptr;
+ insertPoint = BasicBlock::iterator();
+ }
+
+ /// Set the insertion point to the end of the specified block.
+ void setInsertionPoint(BasicBlock *block) {
+ this->block = block;
+ insertPoint = block->end();
+ }
+
+ OperationInst *createOperation(Identifier name,
+ ArrayRef<NamedAttribute> attributes) {
+ auto op = new OperationInst(name, attributes, context);
+ block->getOperations().push_back(op);
+ return op;
+ }
+
+ // Terminators.
+
+ ReturnInst *createReturnInst() { return insertTerminator(new ReturnInst()); }
+
+ BranchInst *createBranchInst(BasicBlock *dest) {
+ return insertTerminator(new BranchInst(dest));
+ }
+
+private:
+ template <typename T>
+ T *insertTerminator(T *term) {
+ block->setTerminator(term);
+ return term;
+ }
+
+ CFGFunction *function;
+ BasicBlock *block = nullptr;
+ BasicBlock::iterator insertPoint;
+};
+
+// TODO: MLFuncBuilder
+
+} // namespace mlir
+
+#endif
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index ec277d6..a12b387 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -31,7 +31,9 @@
class Module {
public:
- explicit Module();
+ explicit Module(MLIRContext *context);
+
+ MLIRContext *getContext() const { return context; }
// FIXME: wrong representation and API.
// TODO(someone): This should switch to llvm::iplist<Function>.
@@ -47,6 +49,9 @@
void print(raw_ostream &os) const;
void dump() const;
+
+private:
+ MLIRContext *context;
};
} // end namespace mlir
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index e4c8c8c..e499e56 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -69,7 +69,7 @@
void dump() const;
// Convenience factories.
- static IntegerType *getInt(unsigned width, MLIRContext *ctx);
+ static IntegerType *getInteger(unsigned width, MLIRContext *ctx);
static PrimitiveType *getAffineInt(MLIRContext *ctx);
static PrimitiveType *getBF16(MLIRContext *ctx);
static PrimitiveType *getF16(MLIRContext *ctx);
@@ -162,12 +162,10 @@
IntegerType(unsigned width, MLIRContext *context);
};
-inline IntegerType *Type::getInt(unsigned width, MLIRContext *ctx) {
+inline IntegerType *Type::getInteger(unsigned width, MLIRContext *ctx) {
return IntegerType::get(width, ctx);
}
-
-
/// Function types map from a list of inputs to a list of results.
class FunctionType : public Type {
public:
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
new file mode 100644
index 0000000..a03f4b8
--- /dev/null
+++ b/lib/IR/Builders.cpp
@@ -0,0 +1,59 @@
+//===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Types.h"
+using namespace mlir;
+
+Builder::Builder(Module *module) : context(module->getContext()) {}
+
+// Types.
+PrimitiveType *Builder::getAffineIntType() {
+ return Type::getAffineInt(context);
+}
+
+PrimitiveType *Builder::getBF16Type() { return Type::getBF16(context); }
+
+PrimitiveType *Builder::getF16Type() { return Type::getF16(context); }
+
+PrimitiveType *Builder::getF32Type() { return Type::getF32(context); }
+
+PrimitiveType *Builder::getF64Type() { return Type::getF64(context); }
+
+IntegerType *Builder::getIntegerType(unsigned width) {
+ return Type::getInteger(width, context);
+}
+
+FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs,
+ ArrayRef<Type *> results) {
+ return FunctionType::get(inputs, results, context);
+}
+
+VectorType *Builder::getVectorType(ArrayRef<unsigned> shape,
+ Type *elementType) {
+ return VectorType::get(shape, elementType);
+}
+
+RankedTensorType *Builder::getTensorType(ArrayRef<int> shape,
+ Type *elementType) {
+ return RankedTensorType::get(shape, elementType);
+}
+
+UnrankedTensorType *Builder::getTensorType(Type *elementType) {
+ return UnrankedTensorType::get(elementType);
+}
diff --git a/lib/IR/Module.cpp b/lib/IR/Module.cpp
index b41ad2b..99e5e32 100644
--- a/lib/IR/Module.cpp
+++ b/lib/IR/Module.cpp
@@ -18,6 +18,4 @@
#include "mlir/IR/Module.h"
using namespace mlir;
-Module::Module() {
-}
-
+Module::Module(MLIRContext *context) : context(context) {}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 94d1468..b2fb4d6 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -24,7 +24,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
@@ -70,15 +70,15 @@
public:
Parser(llvm::SourceMgr &sourceMgr, MLIRContext *context,
SMDiagnosticHandlerTy errorReporter)
- : context(context), lex(sourceMgr, errorReporter),
+ : builder(context), lex(sourceMgr, errorReporter),
curToken(lex.lexToken()), errorReporter(std::move(errorReporter)) {
- module.reset(new Module());
+ module.reset(new Module(context));
}
Module *parseModule();
private:
// State.
- MLIRContext *const context;
+ Builder builder;
// The lexer for the source file we're parsing.
Lexer lex;
@@ -170,14 +170,10 @@
AffineExpr *parseIntegerExpr(const AffineMapParserState &state);
AffineExpr *parseBareIdExpr(const AffineMapParserState &state);
- static AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineHighPrecOp op,
- AffineExpr *lhs,
- AffineExpr *rhs,
- MLIRContext *context);
- static AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineLowPrecOp op,
- AffineExpr *lhs,
- AffineExpr *rhs,
- MLIRContext *context);
+ AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineHighPrecOp op,
+ AffineExpr *lhs, AffineExpr *rhs);
+ AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineLowPrecOp op, AffineExpr *lhs,
+ AffineExpr *rhs);
ParseResult parseAffineOperandExpr(const AffineMapParserState &state,
AffineExpr *&result);
ParseResult parseAffineLowPrecOpExpr(AffineExpr *llhs, AffineLowPrecOp llhsOp,
@@ -278,25 +274,25 @@
return (emitError("expected type"), nullptr);
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
- return Type::getBF16(context);
+ return builder.getBF16Type();
case Token::kw_f16:
consumeToken(Token::kw_f16);
- return Type::getF16(context);
+ return builder.getF16Type();
case Token::kw_f32:
consumeToken(Token::kw_f32);
- return Type::getF32(context);
+ return builder.getF32Type();
case Token::kw_f64:
consumeToken(Token::kw_f64);
- return Type::getF64(context);
+ return builder.getF64Type();
case Token::kw_affineint:
consumeToken(Token::kw_affineint);
- return Type::getAffineInt(context);
+ return builder.getAffineIntType();
case Token::inttype: {
auto width = curToken.getIntTypeBitwidth();
if (!width.hasValue())
return (emitError("invalid integer width"), nullptr);
consumeToken(Token::inttype);
- return Type::getInt(width.getValue(), context);
+ return builder.getIntegerType(width.getValue());
}
}
}
@@ -426,8 +422,8 @@
return (emitError("expected '>' in tensor type"), nullptr);
if (isUnranked)
- return UnrankedTensorType::get(elementType);
- return RankedTensorType::get(dimensions, elementType);
+ return builder.getTensorType(elementType);
+ return builder.getTensorType(dimensions, elementType);
}
/// Parse a memref type.
@@ -460,7 +456,7 @@
return (emitError("expected '>' in memref type"), nullptr);
// FIXME: Add an IR representation for memref types.
- return Type::getInt(1, context);
+ return builder.getIntegerType(1);
}
/// Parse a function type.
@@ -481,7 +477,7 @@
if (parseTypeList(results))
return nullptr;
- return FunctionType::get(arguments, results, context);
+ return builder.getFunctionType(arguments, results);
}
/// Parse an arbitrary type.
@@ -568,17 +564,17 @@
switch (curToken.getKind()) {
case Token::kw_true:
consumeToken(Token::kw_true);
- return BoolAttr::get(true, context);
+ return BoolAttr::get(true, builder.getContext());
case Token::kw_false:
consumeToken(Token::kw_false);
- return BoolAttr::get(false, context);
+ return BoolAttr::get(false, builder.getContext());
case Token::integer: {
auto val = curToken.getUInt64IntegerValue();
if (!val.hasValue() || (int64_t)val.getValue() < 0)
return (emitError("integer too large for attribute"), nullptr);
consumeToken(Token::integer);
- return IntegerAttr::get((int64_t)val.getValue(), context);
+ return IntegerAttr::get((int64_t)val.getValue(), builder.getContext());
}
case Token::minus: {
@@ -588,7 +584,7 @@
if (!val.hasValue() || (int64_t)-val.getValue() >= 0)
return (emitError("integer too large for attribute"), nullptr);
consumeToken(Token::integer);
- return IntegerAttr::get((int64_t)-val.getValue(), context);
+ return IntegerAttr::get((int64_t)-val.getValue(), builder.getContext());
}
return (emitError("expected constant integer or floating point value"),
@@ -598,7 +594,7 @@
case Token::string: {
auto val = curToken.getStringValue();
consumeToken(Token::string);
- return StringAttr::get(val, context);
+ return StringAttr::get(val, builder.getContext());
}
case Token::l_bracket: {
@@ -612,7 +608,7 @@
if (parseCommaSeparatedList(Token::r_bracket, parseElt))
return nullptr;
- return ArrayAttr::get(elements, context);
+ return ArrayAttr::get(elements, builder.getContext());
}
default:
// TODO: Handle floating point.
@@ -636,7 +632,7 @@
if (curToken.isNot(Token::bare_identifier, Token::inttype) &&
!curToken.isKeyword())
return emitError("expected attribute name");
- auto nameId = Identifier::get(curToken.getSpelling(), context);
+ auto nameId = Identifier::get(curToken.getSpelling(), builder.getContext());
consumeToken();
if (!consumeIf(Token::colon))
@@ -690,17 +686,16 @@
/// Create an affine op expression
AffineBinaryOpExpr *Parser::getBinaryAffineOpExpr(AffineHighPrecOp op,
AffineExpr *lhs,
- AffineExpr *rhs,
- MLIRContext *context) {
+ AffineExpr *rhs) {
switch (op) {
case Mul:
- return AffineMulExpr::get(lhs, rhs, context);
+ return AffineMulExpr::get(lhs, rhs, builder.getContext());
case FloorDiv:
- return AffineFloorDivExpr::get(lhs, rhs, context);
+ return AffineFloorDivExpr::get(lhs, rhs, builder.getContext());
case CeilDiv:
- return AffineCeilDivExpr::get(lhs, rhs, context);
+ return AffineCeilDivExpr::get(lhs, rhs, builder.getContext());
case Mod:
- return AffineModExpr::get(lhs, rhs, context);
+ return AffineModExpr::get(lhs, rhs, builder.getContext());
case HNoOp:
llvm_unreachable("can't create affine expression for null high prec op");
return nullptr;
@@ -709,13 +704,12 @@
AffineBinaryOpExpr *Parser::getBinaryAffineOpExpr(AffineLowPrecOp op,
AffineExpr *lhs,
- AffineExpr *rhs,
- MLIRContext *context) {
+ AffineExpr *rhs) {
switch (op) {
case AffineLowPrecOp::Add:
- return AffineAddExpr::get(lhs, rhs, context);
+ return AffineAddExpr::get(lhs, rhs, builder.getContext());
case AffineLowPrecOp::Sub:
- return AffineSubExpr::get(lhs, rhs, context);
+ return AffineSubExpr::get(lhs, rhs, builder.getContext());
case AffineLowPrecOp::LNoOp:
llvm_unreachable("can't create affine expression for null low prec op");
return nullptr;
@@ -762,8 +756,7 @@
if (llhs) {
// TODO(bondhugula): check whether 'lhs' here is a constant (for affine
// maps); semi-affine maps allow symbols.
- AffineExpr *expr =
- Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+ AffineExpr *expr = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
AffineExpr *subRes = nullptr;
if (parseAffineHighPrecOpExpr(expr, op, state, subRes)) {
if (!subRes)
@@ -794,7 +787,7 @@
if (llhs) {
// TODO(bondhugula): check whether lhs here is a constant (for affine
// maps); semi-affine maps allow symbols.
- result = Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+ result = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
return ParseSuccess;
}
@@ -874,8 +867,7 @@
AffineHighPrecOp rOp;
if ((lOp = consumeIfLowPrecOp())) {
if (llhs) {
- AffineExpr *sum =
- Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+ AffineExpr *sum = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
AffineExpr *recSum = nullptr;
parseAffineLowPrecOpExpr(sum, lOp, state, recSum);
result = recSum ? recSum : sum;
@@ -903,7 +895,7 @@
// found expression. If non-null, assume for now that the op to associate
// with llhs is add.
AffineExpr *expr =
- llhs ? getBinaryAffineOpExpr(llhsOp, llhs, highRes, context) : highRes;
+ llhs ? getBinaryAffineOpExpr(llhsOp, llhs, highRes) : highRes;
// Recurse for subsequent add's after the affine mul expression
AffineLowPrecOp nextOp = consumeIfLowPrecOp();
if (nextOp) {
@@ -917,7 +909,7 @@
} else {
// Last operand in the expression list.
if (llhs) {
- result = Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+ result = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
return ParseSuccess;
}
// No llhs, 'lhs' itself is the expression.
@@ -951,11 +943,11 @@
const auto &symbols = state.getSymbols();
if (dims.count(sRef)) {
consumeToken(Token::bare_identifier);
- return AffineDimExpr::get(dims.lookup(sRef), context);
+ return AffineDimExpr::get(dims.lookup(sRef), builder.getContext());
}
if (symbols.count(sRef)) {
consumeToken(Token::bare_identifier);
- return AffineSymbolExpr::get(symbols.lookup(sRef), context);
+ return AffineSymbolExpr::get(symbols.lookup(sRef), builder.getContext());
}
return emitError("identifier is neither dimensional nor symbolic"), nullptr;
}
@@ -968,7 +960,7 @@
AffineExpr *Parser::parseIntegerExpr(const AffineMapParserState &state) {
if (curToken.is(Token::integer)) {
auto *expr = AffineConstantExpr::get(
- curToken.getUnsignedIntegerValue().getValue(), context);
+ curToken.getUnsignedIntegerValue().getValue(), builder.getContext());
consumeToken(Token::integer);
return expr;
}
@@ -1094,7 +1086,7 @@
// Parsed a valid affine map.
return AffineMap::get(state.getNumDims(), state.getNumSymbols(), exprs,
- context);
+ builder.getContext());
}
//===----------------------------------------------------------------------===//
@@ -1187,7 +1179,7 @@
if (parseTypeList(results))
return ParseFailure;
}
- type = FunctionType::get(arguments, results, context);
+ type = builder.getFunctionType(arguments, results);
return ParseSuccess;
}
@@ -1214,11 +1206,13 @@
/// function as we are parsing it, e.g. the names for basic blocks. It handles
/// forward references.
class CFGFunctionParserState {
- public:
+public:
CFGFunction *function;
llvm::StringMap<std::pair<BasicBlock*, SMLoc>> blocksByName;
+ CFGFuncBuilder builder;
- CFGFunctionParserState(CFGFunction *function) : function(function) {}
+ CFGFunctionParserState(CFGFunction *function)
+ : function(function), builder(function) {}
/// Get the basic block with the specified name, creating it if it doesn't
/// already exist. The location specified is the point of use, which allows
@@ -1312,6 +1306,9 @@
if (!consumeIf(Token::colon))
return emitError("expected ':' after basic block name");
+ // Set the insertion point to the block we want to insert new operations into.
+ functionState.builder.setInsertionPoint(block);
+
// Parse the list of operations that make up the body of the block.
while (curToken.isNot(Token::kw_return, Token::kw_br)) {
auto loc = curToken.getLoc();
@@ -1322,17 +1319,14 @@
// We just parsed an operation. If it is a recognized one, verify that it
// is structurally as we expect. If not, produce an error with a reasonable
// source location.
- if (auto *opInfo = inst->getAbstractOperation(context))
+ if (auto *opInfo = inst->getAbstractOperation(builder.getContext()))
if (auto error = opInfo->verifyInvariants(inst))
return emitError(loc, error);
-
- block->getOperations().push_back(inst);
}
auto *term = parseTerminator(functionState);
if (!term)
return ParseFailure;
- block->setTerminator(term);
return ParseSuccess;
}
@@ -1380,8 +1374,8 @@
}
// TODO: Don't drop result name and operand names on the floor.
- auto nameId = Identifier::get(name, context);
- return new OperationInst(nameId, attributes, context);
+ auto nameId = Identifier::get(name, builder.getContext());
+ return functionState.builder.createOperation(nameId, attributes);
}
@@ -1400,7 +1394,7 @@
case Token::kw_return:
consumeToken(Token::kw_return);
- return new ReturnInst();
+ return functionState.builder.createReturnInst();
case Token::kw_br: {
consumeToken(Token::kw_br);
@@ -1408,7 +1402,7 @@
curToken.getLoc());
if (!consumeIf(Token::bare_identifier))
return (emitError("expected basic block name"), nullptr);
- return new BranchInst(destBB);
+ return functionState.builder.createBranchInst(destBB);
}
// TODO: cond_br.
}