Introduce the initial IR builder APIs, which make it easier and less
error-prone to create IR objects such as types, operations, and terminators.

PiperOrigin-RevId: 203703229
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 94d1468..b2fb4d6 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -24,7 +24,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OperationSet.h"
@@ -70,15 +70,15 @@
 public:
   Parser(llvm::SourceMgr &sourceMgr, MLIRContext *context,
          SMDiagnosticHandlerTy errorReporter)
-      : context(context), lex(sourceMgr, errorReporter),
+      : builder(context), lex(sourceMgr, errorReporter),
         curToken(lex.lexToken()), errorReporter(std::move(errorReporter)) {
-    module.reset(new Module());
+    module.reset(new Module(context));
   }
 
   Module *parseModule();
 private:
   // State.
-  MLIRContext *const context;
+  Builder builder;
 
   // The lexer for the source file we're parsing.
   Lexer lex;
@@ -170,14 +170,10 @@
   AffineExpr *parseIntegerExpr(const AffineMapParserState &state);
   AffineExpr *parseBareIdExpr(const AffineMapParserState &state);
 
-  static AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineHighPrecOp op,
-                                                   AffineExpr *lhs,
-                                                   AffineExpr *rhs,
-                                                   MLIRContext *context);
-  static AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineLowPrecOp op,
-                                                   AffineExpr *lhs,
-                                                   AffineExpr *rhs,
-                                                   MLIRContext *context);
+  AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineHighPrecOp op,
+                                            AffineExpr *lhs, AffineExpr *rhs);
+  AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineLowPrecOp op, AffineExpr *lhs,
+                                            AffineExpr *rhs);
   ParseResult parseAffineOperandExpr(const AffineMapParserState &state,
                                      AffineExpr *&result);
   ParseResult parseAffineLowPrecOpExpr(AffineExpr *llhs, AffineLowPrecOp llhsOp,
@@ -278,25 +274,25 @@
     return (emitError("expected type"), nullptr);
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
-    return Type::getBF16(context);
+    return builder.getBF16Type();
   case Token::kw_f16:
     consumeToken(Token::kw_f16);
-    return Type::getF16(context);
+    return builder.getF16Type();
   case Token::kw_f32:
     consumeToken(Token::kw_f32);
-    return Type::getF32(context);
+    return builder.getF32Type();
   case Token::kw_f64:
     consumeToken(Token::kw_f64);
-    return Type::getF64(context);
+    return builder.getF64Type();
   case Token::kw_affineint:
     consumeToken(Token::kw_affineint);
-    return Type::getAffineInt(context);
+    return builder.getAffineIntType();
   case Token::inttype: {
     auto width = curToken.getIntTypeBitwidth();
     if (!width.hasValue())
       return (emitError("invalid integer width"), nullptr);
     consumeToken(Token::inttype);
-    return Type::getInt(width.getValue(), context);
+    return builder.getIntegerType(width.getValue());
   }
   }
 }
@@ -426,8 +422,8 @@
     return (emitError("expected '>' in tensor type"), nullptr);
 
   if (isUnranked)
-    return UnrankedTensorType::get(elementType);
-  return RankedTensorType::get(dimensions, elementType);
+    return builder.getTensorType(elementType);
+  return builder.getTensorType(dimensions, elementType);
 }
 
 /// Parse a memref type.
@@ -460,7 +456,7 @@
     return (emitError("expected '>' in memref type"), nullptr);
 
   // FIXME: Add an IR representation for memref types.
-  return Type::getInt(1, context);
+  return builder.getIntegerType(1);
 }
 
 /// Parse a function type.
@@ -481,7 +477,7 @@
   if (parseTypeList(results))
     return nullptr;
 
-  return FunctionType::get(arguments, results, context);
+  return builder.getFunctionType(arguments, results);
 }
 
 /// Parse an arbitrary type.
@@ -568,17 +564,17 @@
   switch (curToken.getKind()) {
   case Token::kw_true:
     consumeToken(Token::kw_true);
-    return BoolAttr::get(true, context);
+    return BoolAttr::get(true, builder.getContext());
   case Token::kw_false:
     consumeToken(Token::kw_false);
-    return BoolAttr::get(false, context);
+    return BoolAttr::get(false, builder.getContext());
 
   case Token::integer: {
     auto val = curToken.getUInt64IntegerValue();
     if (!val.hasValue() || (int64_t)val.getValue() < 0)
       return (emitError("integer too large for attribute"), nullptr);
     consumeToken(Token::integer);
-    return IntegerAttr::get((int64_t)val.getValue(), context);
+    return IntegerAttr::get((int64_t)val.getValue(), builder.getContext());
   }
 
   case Token::minus: {
@@ -588,7 +584,7 @@
       if (!val.hasValue() || (int64_t)-val.getValue() >= 0)
         return (emitError("integer too large for attribute"), nullptr);
       consumeToken(Token::integer);
-      return IntegerAttr::get((int64_t)-val.getValue(), context);
+      return IntegerAttr::get((int64_t)-val.getValue(), builder.getContext());
     }
 
     return (emitError("expected constant integer or floating point value"),
@@ -598,7 +594,7 @@
   case Token::string: {
     auto val = curToken.getStringValue();
     consumeToken(Token::string);
-    return StringAttr::get(val, context);
+    return StringAttr::get(val, builder.getContext());
   }
 
   case Token::l_bracket: {
@@ -612,7 +608,7 @@
 
     if (parseCommaSeparatedList(Token::r_bracket, parseElt))
       return nullptr;
-    return ArrayAttr::get(elements, context);
+    return ArrayAttr::get(elements, builder.getContext());
   }
   default:
     // TODO: Handle floating point.
@@ -636,7 +632,7 @@
     if (curToken.isNot(Token::bare_identifier, Token::inttype) &&
         !curToken.isKeyword())
       return emitError("expected attribute name");
-    auto nameId = Identifier::get(curToken.getSpelling(), context);
+    auto nameId = Identifier::get(curToken.getSpelling(), builder.getContext());
     consumeToken();
 
     if (!consumeIf(Token::colon))
@@ -690,17 +686,16 @@
 /// Create an affine op expression
 AffineBinaryOpExpr *Parser::getBinaryAffineOpExpr(AffineHighPrecOp op,
                                                   AffineExpr *lhs,
-                                                  AffineExpr *rhs,
-                                                  MLIRContext *context) {
+                                                  AffineExpr *rhs) {
   switch (op) {
   case Mul:
-    return AffineMulExpr::get(lhs, rhs, context);
+    return AffineMulExpr::get(lhs, rhs, builder.getContext());
   case FloorDiv:
-    return AffineFloorDivExpr::get(lhs, rhs, context);
+    return AffineFloorDivExpr::get(lhs, rhs, builder.getContext());
   case CeilDiv:
-    return AffineCeilDivExpr::get(lhs, rhs, context);
+    return AffineCeilDivExpr::get(lhs, rhs, builder.getContext());
   case Mod:
-    return AffineModExpr::get(lhs, rhs, context);
+    return AffineModExpr::get(lhs, rhs, builder.getContext());
   case HNoOp:
     llvm_unreachable("can't create affine expression for null high prec op");
     return nullptr;
@@ -709,13 +704,12 @@
 
 AffineBinaryOpExpr *Parser::getBinaryAffineOpExpr(AffineLowPrecOp op,
                                                   AffineExpr *lhs,
-                                                  AffineExpr *rhs,
-                                                  MLIRContext *context) {
+                                                  AffineExpr *rhs) {
   switch (op) {
   case AffineLowPrecOp::Add:
-    return AffineAddExpr::get(lhs, rhs, context);
+    return AffineAddExpr::get(lhs, rhs, builder.getContext());
   case AffineLowPrecOp::Sub:
-    return AffineSubExpr::get(lhs, rhs, context);
+    return AffineSubExpr::get(lhs, rhs, builder.getContext());
   case AffineLowPrecOp::LNoOp:
     llvm_unreachable("can't create affine expression for null low prec op");
     return nullptr;
@@ -762,8 +756,7 @@
     if (llhs) {
       // TODO(bondhugula): check whether 'lhs' here is a constant (for affine
       // maps); semi-affine maps allow symbols.
-      AffineExpr *expr =
-          Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+      AffineExpr *expr = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
       AffineExpr *subRes = nullptr;
       if (parseAffineHighPrecOpExpr(expr, op, state, subRes)) {
         if (!subRes)
@@ -794,7 +787,7 @@
   if (llhs) {
     // TODO(bondhugula): check whether lhs here is a constant (for affine
     // maps); semi-affine maps allow symbols.
-    result = Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+    result = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
     return ParseSuccess;
   }
 
@@ -874,8 +867,7 @@
   AffineHighPrecOp rOp;
   if ((lOp = consumeIfLowPrecOp())) {
     if (llhs) {
-      AffineExpr *sum =
-          Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+      AffineExpr *sum = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
       AffineExpr *recSum = nullptr;
       parseAffineLowPrecOpExpr(sum, lOp, state, recSum);
       result = recSum ? recSum : sum;
@@ -903,7 +895,7 @@
     // found expression. If non-null, assume for now that the op to associate
     // with llhs is add.
     AffineExpr *expr =
-        llhs ? getBinaryAffineOpExpr(llhsOp, llhs, highRes, context) : highRes;
+        llhs ? getBinaryAffineOpExpr(llhsOp, llhs, highRes) : highRes;
     // Recurse for subsequent add's after the affine mul expression
     AffineLowPrecOp nextOp = consumeIfLowPrecOp();
     if (nextOp) {
@@ -917,7 +909,7 @@
   } else {
     // Last operand in the expression list.
     if (llhs) {
-      result = Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+      result = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
       return ParseSuccess;
     }
     // No llhs, 'lhs' itself is the expression.
@@ -951,11 +943,11 @@
     const auto &symbols = state.getSymbols();
     if (dims.count(sRef)) {
       consumeToken(Token::bare_identifier);
-      return AffineDimExpr::get(dims.lookup(sRef), context);
+      return AffineDimExpr::get(dims.lookup(sRef), builder.getContext());
     }
     if (symbols.count(sRef)) {
       consumeToken(Token::bare_identifier);
-      return AffineSymbolExpr::get(symbols.lookup(sRef), context);
+      return AffineSymbolExpr::get(symbols.lookup(sRef), builder.getContext());
     }
     return emitError("identifier is neither dimensional nor symbolic"), nullptr;
   }
@@ -968,7 +960,7 @@
 AffineExpr *Parser::parseIntegerExpr(const AffineMapParserState &state) {
   if (curToken.is(Token::integer)) {
     auto *expr = AffineConstantExpr::get(
-        curToken.getUnsignedIntegerValue().getValue(), context);
+        curToken.getUnsignedIntegerValue().getValue(), builder.getContext());
     consumeToken(Token::integer);
     return expr;
   }
@@ -1094,7 +1086,7 @@
 
   // Parsed a valid affine map.
   return AffineMap::get(state.getNumDims(), state.getNumSymbols(), exprs,
-                        context);
+                        builder.getContext());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1187,7 +1179,7 @@
     if (parseTypeList(results))
       return ParseFailure;
   }
-  type = FunctionType::get(arguments, results, context);
+  type = builder.getFunctionType(arguments, results);
   return ParseSuccess;
 }
 
@@ -1214,11 +1206,13 @@
 /// function as we are parsing it, e.g. the names for basic blocks.  It handles
 /// forward references.
 class CFGFunctionParserState {
- public:
+public:
   CFGFunction *function;
   llvm::StringMap<std::pair<BasicBlock*, SMLoc>> blocksByName;
+  CFGFuncBuilder builder;
 
-  CFGFunctionParserState(CFGFunction *function) : function(function) {}
+  CFGFunctionParserState(CFGFunction *function)
+      : function(function), builder(function) {}
 
   /// Get the basic block with the specified name, creating it if it doesn't
   /// already exist.  The location specified is the point of use, which allows
@@ -1312,6 +1306,9 @@
   if (!consumeIf(Token::colon))
     return emitError("expected ':' after basic block name");
 
+  // Set the insertion point to the block we want to insert new operations into.
+  functionState.builder.setInsertionPoint(block);
+
   // Parse the list of operations that make up the body of the block.
   while (curToken.isNot(Token::kw_return, Token::kw_br)) {
     auto loc = curToken.getLoc();
@@ -1322,17 +1319,14 @@
     // We just parsed an operation.  If it is a recognized one, verify that it
     // is structurally as we expect.  If not, produce an error with a reasonable
     // source location.
-    if (auto *opInfo = inst->getAbstractOperation(context))
+    if (auto *opInfo = inst->getAbstractOperation(builder.getContext()))
       if (auto error = opInfo->verifyInvariants(inst))
         return emitError(loc, error);
-
-    block->getOperations().push_back(inst);
   }
 
   auto *term = parseTerminator(functionState);
   if (!term)
     return ParseFailure;
-  block->setTerminator(term);
 
   return ParseSuccess;
 }
@@ -1380,8 +1374,8 @@
   }
 
   // TODO: Don't drop result name and operand names on the floor.
-  auto nameId = Identifier::get(name, context);
-  return new OperationInst(nameId, attributes, context);
+  auto nameId = Identifier::get(name, builder.getContext());
+  return functionState.builder.createOperation(nameId, attributes);
 }
 
 
@@ -1400,7 +1394,7 @@
 
   case Token::kw_return:
     consumeToken(Token::kw_return);
-    return new ReturnInst();
+    return functionState.builder.createReturnInst();
 
   case Token::kw_br: {
     consumeToken(Token::kw_br);
@@ -1408,7 +1402,7 @@
                                               curToken.getLoc());
     if (!consumeIf(Token::bare_identifier))
       return (emitError("expected basic block name"), nullptr);
-    return new BranchInst(destBB);
+    return functionState.builder.createBranchInst(destBB);
   }
     // TODO: cond_br.
   }