Introduce the start of the IR builder APIs, which make it easier and less
error-prone to create IR objects such as types and CFG instructions.
PiperOrigin-RevId: 203703229
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 94d1468..b2fb4d6 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -24,7 +24,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
@@ -70,15 +70,15 @@
public:
Parser(llvm::SourceMgr &sourceMgr, MLIRContext *context,
SMDiagnosticHandlerTy errorReporter)
- : context(context), lex(sourceMgr, errorReporter),
+ : builder(context), lex(sourceMgr, errorReporter),
curToken(lex.lexToken()), errorReporter(std::move(errorReporter)) {
- module.reset(new Module());
+ module.reset(new Module(context));
}
Module *parseModule();
private:
// State.
- MLIRContext *const context;
+ Builder builder;
// The lexer for the source file we're parsing.
Lexer lex;
@@ -170,14 +170,10 @@
AffineExpr *parseIntegerExpr(const AffineMapParserState &state);
AffineExpr *parseBareIdExpr(const AffineMapParserState &state);
- static AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineHighPrecOp op,
- AffineExpr *lhs,
- AffineExpr *rhs,
- MLIRContext *context);
- static AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineLowPrecOp op,
- AffineExpr *lhs,
- AffineExpr *rhs,
- MLIRContext *context);
+ AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineHighPrecOp op,
+ AffineExpr *lhs, AffineExpr *rhs);
+ AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineLowPrecOp op, AffineExpr *lhs,
+ AffineExpr *rhs);
ParseResult parseAffineOperandExpr(const AffineMapParserState &state,
AffineExpr *&result);
ParseResult parseAffineLowPrecOpExpr(AffineExpr *llhs, AffineLowPrecOp llhsOp,
@@ -278,25 +274,25 @@
return (emitError("expected type"), nullptr);
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
- return Type::getBF16(context);
+ return builder.getBF16Type();
case Token::kw_f16:
consumeToken(Token::kw_f16);
- return Type::getF16(context);
+ return builder.getF16Type();
case Token::kw_f32:
consumeToken(Token::kw_f32);
- return Type::getF32(context);
+ return builder.getF32Type();
case Token::kw_f64:
consumeToken(Token::kw_f64);
- return Type::getF64(context);
+ return builder.getF64Type();
case Token::kw_affineint:
consumeToken(Token::kw_affineint);
- return Type::getAffineInt(context);
+ return builder.getAffineIntType();
case Token::inttype: {
auto width = curToken.getIntTypeBitwidth();
if (!width.hasValue())
return (emitError("invalid integer width"), nullptr);
consumeToken(Token::inttype);
- return Type::getInt(width.getValue(), context);
+ return builder.getIntegerType(width.getValue());
}
}
}
@@ -426,8 +422,8 @@
return (emitError("expected '>' in tensor type"), nullptr);
if (isUnranked)
- return UnrankedTensorType::get(elementType);
- return RankedTensorType::get(dimensions, elementType);
+ return builder.getTensorType(elementType);
+ return builder.getTensorType(dimensions, elementType);
}
/// Parse a memref type.
@@ -460,7 +456,7 @@
return (emitError("expected '>' in memref type"), nullptr);
// FIXME: Add an IR representation for memref types.
- return Type::getInt(1, context);
+ return builder.getIntegerType(1);
}
/// Parse a function type.
@@ -481,7 +477,7 @@
if (parseTypeList(results))
return nullptr;
- return FunctionType::get(arguments, results, context);
+ return builder.getFunctionType(arguments, results);
}
/// Parse an arbitrary type.
@@ -568,17 +564,17 @@
switch (curToken.getKind()) {
case Token::kw_true:
consumeToken(Token::kw_true);
- return BoolAttr::get(true, context);
+ return BoolAttr::get(true, builder.getContext());
case Token::kw_false:
consumeToken(Token::kw_false);
- return BoolAttr::get(false, context);
+ return BoolAttr::get(false, builder.getContext());
case Token::integer: {
auto val = curToken.getUInt64IntegerValue();
if (!val.hasValue() || (int64_t)val.getValue() < 0)
return (emitError("integer too large for attribute"), nullptr);
consumeToken(Token::integer);
- return IntegerAttr::get((int64_t)val.getValue(), context);
+ return IntegerAttr::get((int64_t)val.getValue(), builder.getContext());
}
case Token::minus: {
@@ -588,7 +584,7 @@
if (!val.hasValue() || (int64_t)-val.getValue() >= 0)
return (emitError("integer too large for attribute"), nullptr);
consumeToken(Token::integer);
- return IntegerAttr::get((int64_t)-val.getValue(), context);
+ return IntegerAttr::get((int64_t)-val.getValue(), builder.getContext());
}
return (emitError("expected constant integer or floating point value"),
@@ -598,7 +594,7 @@
case Token::string: {
auto val = curToken.getStringValue();
consumeToken(Token::string);
- return StringAttr::get(val, context);
+ return StringAttr::get(val, builder.getContext());
}
case Token::l_bracket: {
@@ -612,7 +608,7 @@
if (parseCommaSeparatedList(Token::r_bracket, parseElt))
return nullptr;
- return ArrayAttr::get(elements, context);
+ return ArrayAttr::get(elements, builder.getContext());
}
default:
// TODO: Handle floating point.
@@ -636,7 +632,7 @@
if (curToken.isNot(Token::bare_identifier, Token::inttype) &&
!curToken.isKeyword())
return emitError("expected attribute name");
- auto nameId = Identifier::get(curToken.getSpelling(), context);
+ auto nameId = Identifier::get(curToken.getSpelling(), builder.getContext());
consumeToken();
if (!consumeIf(Token::colon))
@@ -690,17 +686,16 @@
/// Create an affine op expression
AffineBinaryOpExpr *Parser::getBinaryAffineOpExpr(AffineHighPrecOp op,
AffineExpr *lhs,
- AffineExpr *rhs,
- MLIRContext *context) {
+ AffineExpr *rhs) {
switch (op) {
case Mul:
- return AffineMulExpr::get(lhs, rhs, context);
+ return AffineMulExpr::get(lhs, rhs, builder.getContext());
case FloorDiv:
- return AffineFloorDivExpr::get(lhs, rhs, context);
+ return AffineFloorDivExpr::get(lhs, rhs, builder.getContext());
case CeilDiv:
- return AffineCeilDivExpr::get(lhs, rhs, context);
+ return AffineCeilDivExpr::get(lhs, rhs, builder.getContext());
case Mod:
- return AffineModExpr::get(lhs, rhs, context);
+ return AffineModExpr::get(lhs, rhs, builder.getContext());
case HNoOp:
llvm_unreachable("can't create affine expression for null high prec op");
return nullptr;
@@ -709,13 +704,12 @@
AffineBinaryOpExpr *Parser::getBinaryAffineOpExpr(AffineLowPrecOp op,
AffineExpr *lhs,
- AffineExpr *rhs,
- MLIRContext *context) {
+ AffineExpr *rhs) {
switch (op) {
case AffineLowPrecOp::Add:
- return AffineAddExpr::get(lhs, rhs, context);
+ return AffineAddExpr::get(lhs, rhs, builder.getContext());
case AffineLowPrecOp::Sub:
- return AffineSubExpr::get(lhs, rhs, context);
+ return AffineSubExpr::get(lhs, rhs, builder.getContext());
case AffineLowPrecOp::LNoOp:
llvm_unreachable("can't create affine expression for null low prec op");
return nullptr;
@@ -762,8 +756,7 @@
if (llhs) {
// TODO(bondhugula): check whether 'lhs' here is a constant (for affine
// maps); semi-affine maps allow symbols.
- AffineExpr *expr =
- Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+ AffineExpr *expr = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
AffineExpr *subRes = nullptr;
if (parseAffineHighPrecOpExpr(expr, op, state, subRes)) {
if (!subRes)
@@ -794,7 +787,7 @@
if (llhs) {
// TODO(bondhugula): check whether lhs here is a constant (for affine
// maps); semi-affine maps allow symbols.
- result = Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+ result = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
return ParseSuccess;
}
@@ -874,8 +867,7 @@
AffineHighPrecOp rOp;
if ((lOp = consumeIfLowPrecOp())) {
if (llhs) {
- AffineExpr *sum =
- Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+ AffineExpr *sum = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
AffineExpr *recSum = nullptr;
parseAffineLowPrecOpExpr(sum, lOp, state, recSum);
result = recSum ? recSum : sum;
@@ -903,7 +895,7 @@
// found expression. If non-null, assume for now that the op to associate
// with llhs is add.
AffineExpr *expr =
- llhs ? getBinaryAffineOpExpr(llhsOp, llhs, highRes, context) : highRes;
+ llhs ? getBinaryAffineOpExpr(llhsOp, llhs, highRes) : highRes;
// Recurse for subsequent add's after the affine mul expression
AffineLowPrecOp nextOp = consumeIfLowPrecOp();
if (nextOp) {
@@ -917,7 +909,7 @@
} else {
// Last operand in the expression list.
if (llhs) {
- result = Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+ result = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
return ParseSuccess;
}
// No llhs, 'lhs' itself is the expression.
@@ -951,11 +943,11 @@
const auto &symbols = state.getSymbols();
if (dims.count(sRef)) {
consumeToken(Token::bare_identifier);
- return AffineDimExpr::get(dims.lookup(sRef), context);
+ return AffineDimExpr::get(dims.lookup(sRef), builder.getContext());
}
if (symbols.count(sRef)) {
consumeToken(Token::bare_identifier);
- return AffineSymbolExpr::get(symbols.lookup(sRef), context);
+ return AffineSymbolExpr::get(symbols.lookup(sRef), builder.getContext());
}
return emitError("identifier is neither dimensional nor symbolic"), nullptr;
}
@@ -968,7 +960,7 @@
AffineExpr *Parser::parseIntegerExpr(const AffineMapParserState &state) {
if (curToken.is(Token::integer)) {
auto *expr = AffineConstantExpr::get(
- curToken.getUnsignedIntegerValue().getValue(), context);
+ curToken.getUnsignedIntegerValue().getValue(), builder.getContext());
consumeToken(Token::integer);
return expr;
}
@@ -1094,7 +1086,7 @@
// Parsed a valid affine map.
return AffineMap::get(state.getNumDims(), state.getNumSymbols(), exprs,
- context);
+ builder.getContext());
}
//===----------------------------------------------------------------------===//
@@ -1187,7 +1179,7 @@
if (parseTypeList(results))
return ParseFailure;
}
- type = FunctionType::get(arguments, results, context);
+ type = builder.getFunctionType(arguments, results);
return ParseSuccess;
}
@@ -1214,11 +1206,13 @@
/// function as we are parsing it, e.g. the names for basic blocks. It handles
/// forward references.
class CFGFunctionParserState {
- public:
+public:
CFGFunction *function;
llvm::StringMap<std::pair<BasicBlock*, SMLoc>> blocksByName;
+ CFGFuncBuilder builder;
- CFGFunctionParserState(CFGFunction *function) : function(function) {}
+ CFGFunctionParserState(CFGFunction *function)
+ : function(function), builder(function) {}
/// Get the basic block with the specified name, creating it if it doesn't
/// already exist. The location specified is the point of use, which allows
@@ -1312,6 +1306,9 @@
if (!consumeIf(Token::colon))
return emitError("expected ':' after basic block name");
+ // Set the insertion point to the block we want to insert new operations into.
+ functionState.builder.setInsertionPoint(block);
+
// Parse the list of operations that make up the body of the block.
while (curToken.isNot(Token::kw_return, Token::kw_br)) {
auto loc = curToken.getLoc();
@@ -1322,17 +1319,14 @@
// We just parsed an operation. If it is a recognized one, verify that it
// is structurally as we expect. If not, produce an error with a reasonable
// source location.
- if (auto *opInfo = inst->getAbstractOperation(context))
+ if (auto *opInfo = inst->getAbstractOperation(builder.getContext()))
if (auto error = opInfo->verifyInvariants(inst))
return emitError(loc, error);
-
- block->getOperations().push_back(inst);
}
auto *term = parseTerminator(functionState);
if (!term)
return ParseFailure;
- block->setTerminator(term);
return ParseSuccess;
}
@@ -1380,8 +1374,8 @@
}
// TODO: Don't drop result name and operand names on the floor.
- auto nameId = Identifier::get(name, context);
- return new OperationInst(nameId, attributes, context);
+ auto nameId = Identifier::get(name, builder.getContext());
+ return functionState.builder.createOperation(nameId, attributes);
}
@@ -1400,7 +1394,7 @@
case Token::kw_return:
consumeToken(Token::kw_return);
- return new ReturnInst();
+ return functionState.builder.createReturnInst();
case Token::kw_br: {
consumeToken(Token::kw_br);
@@ -1408,7 +1402,7 @@
curToken.getLoc());
if (!consumeIf(Token::bare_identifier))
return (emitError("expected basic block name"), nullptr);
- return new BranchInst(destBB);
+ return functionState.builder.createBranchInst(destBB);
}
// TODO: cond_br.
}