Introduce the start of IR builder APIs, which makes it easier and less error
prone to create things.

PiperOrigin-RevId: 203703229
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
new file mode 100644
index 0000000..644c812
--- /dev/null
+++ b/include/mlir/IR/Builders.h
@@ -0,0 +1,122 @@
+//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_IR_BUILDERS_H
+#define MLIR_IR_BUILDERS_H
+
+#include "mlir/IR/CFGFunction.h"
+
+namespace mlir {
+class MLIRContext;
+class Module;
+class Type;
+class PrimitiveType;
+class IntegerType;
+class FunctionType;
+class VectorType;
+class RankedTensorType;
+class UnrankedTensorType;
+
+/// This class is a general helper class for creating context-global objects
+/// like types, attributes, and affine expressions.
+class Builder {
+public:
+  explicit Builder(MLIRContext *context) : context(context) {}
+  explicit Builder(Module *module);
+
+  MLIRContext *getContext() const { return context; }
+
+  // Types.
+  PrimitiveType *getAffineIntType();
+  PrimitiveType *getBF16Type();
+  PrimitiveType *getF16Type();
+  PrimitiveType *getF32Type();
+  PrimitiveType *getF64Type();
+  IntegerType *getIntegerType(unsigned width);
+  FunctionType *getFunctionType(ArrayRef<Type *> inputs,
+                                ArrayRef<Type *> results);
+  VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
+  RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
+  UnrankedTensorType *getTensorType(Type *elementType);
+
+  // TODO: Helpers for affine map/exprs, etc.
+  // TODO: Helpers for attributes.
+  // TODO: Identifier
+  // TODO: createModule()
+protected:
+  MLIRContext *context;
+};
+
+/// This class helps build a CFGFunction.  Instructions that are created are
+/// automatically inserted at an insertion point or added to the current basic
+/// block.
+class CFGFuncBuilder : public Builder {
+public:
+  CFGFuncBuilder(BasicBlock *block)
+      : Builder(block->getFunction()->getContext()),
+        function(block->getFunction()) {
+    setInsertionPoint(block);
+  }
+  CFGFuncBuilder(CFGFunction *function)
+      : Builder(function->getContext()), function(function) {}
+
+  /// Reset the insertion point to no location.  Creating an operation without a
+  /// set insertion point is an error, but this can still be useful when the
+  /// current insertion point a builder refers to is being removed.
+  void clearInsertionPoint() {
+    this->block = nullptr;
+    insertPoint = BasicBlock::iterator();
+  }
+
+  /// Set the insertion point to the end of the specified block.
+  void setInsertionPoint(BasicBlock *block) {
+    this->block = block;
+    insertPoint = block->end();
+  }
+
+  OperationInst *createOperation(Identifier name,
+                                 ArrayRef<NamedAttribute> attributes) {
+    auto op = new OperationInst(name, attributes, context);
+    block->getOperations().push_back(op);
+    return op;
+  }
+
+  // Terminators.
+
+  ReturnInst *createReturnInst() { return insertTerminator(new ReturnInst()); }
+
+  BranchInst *createBranchInst(BasicBlock *dest) {
+    return insertTerminator(new BranchInst(dest));
+  }
+
+private:
+  template <typename T>
+  T *insertTerminator(T *term) {
+    block->setTerminator(term);
+    return term;
+  }
+
+  CFGFunction *function;
+  BasicBlock *block = nullptr;
+  BasicBlock::iterator insertPoint;
+};
+
+// TODO: MLFuncBuilder
+
+} // namespace mlir
+
+#endif
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index ec277d6..a12b387 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -31,7 +31,9 @@
 
 class Module {
 public:
-  explicit Module();
+  explicit Module(MLIRContext *context);
+
+  MLIRContext *getContext() const { return context; }
 
   // FIXME: wrong representation and API.
   // TODO(someone): This should switch to llvm::iplist<Function>.
@@ -47,6 +49,9 @@
 
   void print(raw_ostream &os) const;
   void dump() const;
+
+private:
+  MLIRContext *context;
 };
 } // end namespace mlir
 
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index e4c8c8c..e499e56 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -69,7 +69,7 @@
   void dump() const;
 
   // Convenience factories.
-  static IntegerType *getInt(unsigned width, MLIRContext *ctx);
+  static IntegerType *getInteger(unsigned width, MLIRContext *ctx);
   static PrimitiveType *getAffineInt(MLIRContext *ctx);
   static PrimitiveType *getBF16(MLIRContext *ctx);
   static PrimitiveType *getF16(MLIRContext *ctx);
@@ -162,12 +162,10 @@
   IntegerType(unsigned width, MLIRContext *context);
 };
 
-inline IntegerType *Type::getInt(unsigned width, MLIRContext *ctx) {
+inline IntegerType *Type::getInteger(unsigned width, MLIRContext *ctx) {
   return IntegerType::get(width, ctx);
 }
 
-
-
 /// Function types map from a list of inputs to a list of results.
 class FunctionType : public Type {
 public:
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
new file mode 100644
index 0000000..a03f4b8
--- /dev/null
+++ b/lib/IR/Builders.cpp
@@ -0,0 +1,59 @@
+//===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Types.h"
+using namespace mlir;
+
+Builder::Builder(Module *module) : context(module->getContext()) {}
+
+// Types.
+PrimitiveType *Builder::getAffineIntType() {
+  return Type::getAffineInt(context);
+}
+
+PrimitiveType *Builder::getBF16Type() { return Type::getBF16(context); }
+
+PrimitiveType *Builder::getF16Type() { return Type::getF16(context); }
+
+PrimitiveType *Builder::getF32Type() { return Type::getF32(context); }
+
+PrimitiveType *Builder::getF64Type() { return Type::getF64(context); }
+
+IntegerType *Builder::getIntegerType(unsigned width) {
+  return Type::getInteger(width, context);
+}
+
+FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs,
+                                       ArrayRef<Type *> results) {
+  return FunctionType::get(inputs, results, context);
+}
+
+VectorType *Builder::getVectorType(ArrayRef<unsigned> shape,
+                                   Type *elementType) {
+  return VectorType::get(shape, elementType);
+}
+
+RankedTensorType *Builder::getTensorType(ArrayRef<int> shape,
+                                         Type *elementType) {
+  return RankedTensorType::get(shape, elementType);
+}
+
+UnrankedTensorType *Builder::getTensorType(Type *elementType) {
+  return UnrankedTensorType::get(elementType);
+}
diff --git a/lib/IR/Module.cpp b/lib/IR/Module.cpp
index b41ad2b..99e5e32 100644
--- a/lib/IR/Module.cpp
+++ b/lib/IR/Module.cpp
@@ -18,6 +18,4 @@
 #include "mlir/IR/Module.h"
 using namespace mlir;
 
-Module::Module() {
-}
-
+Module::Module(MLIRContext *context) : context(context) {}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 94d1468..b2fb4d6 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -24,7 +24,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OperationSet.h"
@@ -70,15 +70,15 @@
 public:
   Parser(llvm::SourceMgr &sourceMgr, MLIRContext *context,
          SMDiagnosticHandlerTy errorReporter)
-      : context(context), lex(sourceMgr, errorReporter),
+      : builder(context), lex(sourceMgr, errorReporter),
         curToken(lex.lexToken()), errorReporter(std::move(errorReporter)) {
-    module.reset(new Module());
+    module.reset(new Module(context));
   }
 
   Module *parseModule();
 private:
   // State.
-  MLIRContext *const context;
+  Builder builder;
 
   // The lexer for the source file we're parsing.
   Lexer lex;
@@ -170,14 +170,10 @@
   AffineExpr *parseIntegerExpr(const AffineMapParserState &state);
   AffineExpr *parseBareIdExpr(const AffineMapParserState &state);
 
-  static AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineHighPrecOp op,
-                                                   AffineExpr *lhs,
-                                                   AffineExpr *rhs,
-                                                   MLIRContext *context);
-  static AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineLowPrecOp op,
-                                                   AffineExpr *lhs,
-                                                   AffineExpr *rhs,
-                                                   MLIRContext *context);
+  AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineHighPrecOp op,
+                                            AffineExpr *lhs, AffineExpr *rhs);
+  AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineLowPrecOp op, AffineExpr *lhs,
+                                            AffineExpr *rhs);
   ParseResult parseAffineOperandExpr(const AffineMapParserState &state,
                                      AffineExpr *&result);
   ParseResult parseAffineLowPrecOpExpr(AffineExpr *llhs, AffineLowPrecOp llhsOp,
@@ -278,25 +274,25 @@
     return (emitError("expected type"), nullptr);
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
-    return Type::getBF16(context);
+    return builder.getBF16Type();
   case Token::kw_f16:
     consumeToken(Token::kw_f16);
-    return Type::getF16(context);
+    return builder.getF16Type();
   case Token::kw_f32:
     consumeToken(Token::kw_f32);
-    return Type::getF32(context);
+    return builder.getF32Type();
   case Token::kw_f64:
     consumeToken(Token::kw_f64);
-    return Type::getF64(context);
+    return builder.getF64Type();
   case Token::kw_affineint:
     consumeToken(Token::kw_affineint);
-    return Type::getAffineInt(context);
+    return builder.getAffineIntType();
   case Token::inttype: {
     auto width = curToken.getIntTypeBitwidth();
     if (!width.hasValue())
       return (emitError("invalid integer width"), nullptr);
     consumeToken(Token::inttype);
-    return Type::getInt(width.getValue(), context);
+    return builder.getIntegerType(width.getValue());
   }
   }
 }
@@ -426,8 +422,8 @@
     return (emitError("expected '>' in tensor type"), nullptr);
 
   if (isUnranked)
-    return UnrankedTensorType::get(elementType);
-  return RankedTensorType::get(dimensions, elementType);
+    return builder.getTensorType(elementType);
+  return builder.getTensorType(dimensions, elementType);
 }
 
 /// Parse a memref type.
@@ -460,7 +456,7 @@
     return (emitError("expected '>' in memref type"), nullptr);
 
   // FIXME: Add an IR representation for memref types.
-  return Type::getInt(1, context);
+  return builder.getIntegerType(1);
 }
 
 /// Parse a function type.
@@ -481,7 +477,7 @@
   if (parseTypeList(results))
     return nullptr;
 
-  return FunctionType::get(arguments, results, context);
+  return builder.getFunctionType(arguments, results);
 }
 
 /// Parse an arbitrary type.
@@ -568,17 +564,17 @@
   switch (curToken.getKind()) {
   case Token::kw_true:
     consumeToken(Token::kw_true);
-    return BoolAttr::get(true, context);
+    return BoolAttr::get(true, builder.getContext());
   case Token::kw_false:
     consumeToken(Token::kw_false);
-    return BoolAttr::get(false, context);
+    return BoolAttr::get(false, builder.getContext());
 
   case Token::integer: {
     auto val = curToken.getUInt64IntegerValue();
     if (!val.hasValue() || (int64_t)val.getValue() < 0)
       return (emitError("integer too large for attribute"), nullptr);
     consumeToken(Token::integer);
-    return IntegerAttr::get((int64_t)val.getValue(), context);
+    return IntegerAttr::get((int64_t)val.getValue(), builder.getContext());
   }
 
   case Token::minus: {
@@ -588,7 +584,7 @@
       if (!val.hasValue() || (int64_t)-val.getValue() >= 0)
         return (emitError("integer too large for attribute"), nullptr);
       consumeToken(Token::integer);
-      return IntegerAttr::get((int64_t)-val.getValue(), context);
+      return IntegerAttr::get((int64_t)-val.getValue(), builder.getContext());
     }
 
     return (emitError("expected constant integer or floating point value"),
@@ -598,7 +594,7 @@
   case Token::string: {
     auto val = curToken.getStringValue();
     consumeToken(Token::string);
-    return StringAttr::get(val, context);
+    return StringAttr::get(val, builder.getContext());
   }
 
   case Token::l_bracket: {
@@ -612,7 +608,7 @@
 
     if (parseCommaSeparatedList(Token::r_bracket, parseElt))
       return nullptr;
-    return ArrayAttr::get(elements, context);
+    return ArrayAttr::get(elements, builder.getContext());
   }
   default:
     // TODO: Handle floating point.
@@ -636,7 +632,7 @@
     if (curToken.isNot(Token::bare_identifier, Token::inttype) &&
         !curToken.isKeyword())
       return emitError("expected attribute name");
-    auto nameId = Identifier::get(curToken.getSpelling(), context);
+    auto nameId = Identifier::get(curToken.getSpelling(), builder.getContext());
     consumeToken();
 
     if (!consumeIf(Token::colon))
@@ -690,17 +686,16 @@
 /// Create an affine op expression
 AffineBinaryOpExpr *Parser::getBinaryAffineOpExpr(AffineHighPrecOp op,
                                                   AffineExpr *lhs,
-                                                  AffineExpr *rhs,
-                                                  MLIRContext *context) {
+                                                  AffineExpr *rhs) {
   switch (op) {
   case Mul:
-    return AffineMulExpr::get(lhs, rhs, context);
+    return AffineMulExpr::get(lhs, rhs, builder.getContext());
   case FloorDiv:
-    return AffineFloorDivExpr::get(lhs, rhs, context);
+    return AffineFloorDivExpr::get(lhs, rhs, builder.getContext());
   case CeilDiv:
-    return AffineCeilDivExpr::get(lhs, rhs, context);
+    return AffineCeilDivExpr::get(lhs, rhs, builder.getContext());
   case Mod:
-    return AffineModExpr::get(lhs, rhs, context);
+    return AffineModExpr::get(lhs, rhs, builder.getContext());
   case HNoOp:
     llvm_unreachable("can't create affine expression for null high prec op");
     return nullptr;
@@ -709,13 +704,12 @@
 
 AffineBinaryOpExpr *Parser::getBinaryAffineOpExpr(AffineLowPrecOp op,
                                                   AffineExpr *lhs,
-                                                  AffineExpr *rhs,
-                                                  MLIRContext *context) {
+                                                  AffineExpr *rhs) {
   switch (op) {
   case AffineLowPrecOp::Add:
-    return AffineAddExpr::get(lhs, rhs, context);
+    return AffineAddExpr::get(lhs, rhs, builder.getContext());
   case AffineLowPrecOp::Sub:
-    return AffineSubExpr::get(lhs, rhs, context);
+    return AffineSubExpr::get(lhs, rhs, builder.getContext());
   case AffineLowPrecOp::LNoOp:
     llvm_unreachable("can't create affine expression for null low prec op");
     return nullptr;
@@ -762,8 +756,7 @@
     if (llhs) {
       // TODO(bondhugula): check whether 'lhs' here is a constant (for affine
       // maps); semi-affine maps allow symbols.
-      AffineExpr *expr =
-          Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+      AffineExpr *expr = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
       AffineExpr *subRes = nullptr;
       if (parseAffineHighPrecOpExpr(expr, op, state, subRes)) {
         if (!subRes)
@@ -794,7 +787,7 @@
   if (llhs) {
     // TODO(bondhugula): check whether lhs here is a constant (for affine
     // maps); semi-affine maps allow symbols.
-    result = Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+    result = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
     return ParseSuccess;
   }
 
@@ -874,8 +867,7 @@
   AffineHighPrecOp rOp;
   if ((lOp = consumeIfLowPrecOp())) {
     if (llhs) {
-      AffineExpr *sum =
-          Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+      AffineExpr *sum = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
       AffineExpr *recSum = nullptr;
       parseAffineLowPrecOpExpr(sum, lOp, state, recSum);
       result = recSum ? recSum : sum;
@@ -903,7 +895,7 @@
     // found expression. If non-null, assume for now that the op to associate
     // with llhs is add.
     AffineExpr *expr =
-        llhs ? getBinaryAffineOpExpr(llhsOp, llhs, highRes, context) : highRes;
+        llhs ? getBinaryAffineOpExpr(llhsOp, llhs, highRes) : highRes;
     // Recurse for subsequent add's after the affine mul expression
     AffineLowPrecOp nextOp = consumeIfLowPrecOp();
     if (nextOp) {
@@ -917,7 +909,7 @@
   } else {
     // Last operand in the expression list.
     if (llhs) {
-      result = Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+      result = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
       return ParseSuccess;
     }
     // No llhs, 'lhs' itself is the expression.
@@ -951,11 +943,11 @@
     const auto &symbols = state.getSymbols();
     if (dims.count(sRef)) {
       consumeToken(Token::bare_identifier);
-      return AffineDimExpr::get(dims.lookup(sRef), context);
+      return AffineDimExpr::get(dims.lookup(sRef), builder.getContext());
     }
     if (symbols.count(sRef)) {
       consumeToken(Token::bare_identifier);
-      return AffineSymbolExpr::get(symbols.lookup(sRef), context);
+      return AffineSymbolExpr::get(symbols.lookup(sRef), builder.getContext());
     }
     return emitError("identifier is neither dimensional nor symbolic"), nullptr;
   }
@@ -968,7 +960,7 @@
 AffineExpr *Parser::parseIntegerExpr(const AffineMapParserState &state) {
   if (curToken.is(Token::integer)) {
     auto *expr = AffineConstantExpr::get(
-        curToken.getUnsignedIntegerValue().getValue(), context);
+        curToken.getUnsignedIntegerValue().getValue(), builder.getContext());
     consumeToken(Token::integer);
     return expr;
   }
@@ -1094,7 +1086,7 @@
 
   // Parsed a valid affine map.
   return AffineMap::get(state.getNumDims(), state.getNumSymbols(), exprs,
-                        context);
+                        builder.getContext());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1187,7 +1179,7 @@
     if (parseTypeList(results))
       return ParseFailure;
   }
-  type = FunctionType::get(arguments, results, context);
+  type = builder.getFunctionType(arguments, results);
   return ParseSuccess;
 }
 
@@ -1214,11 +1206,13 @@
 /// function as we are parsing it, e.g. the names for basic blocks.  It handles
 /// forward references.
 class CFGFunctionParserState {
- public:
+public:
   CFGFunction *function;
   llvm::StringMap<std::pair<BasicBlock*, SMLoc>> blocksByName;
+  CFGFuncBuilder builder;
 
-  CFGFunctionParserState(CFGFunction *function) : function(function) {}
+  CFGFunctionParserState(CFGFunction *function)
+      : function(function), builder(function) {}
 
   /// Get the basic block with the specified name, creating it if it doesn't
   /// already exist.  The location specified is the point of use, which allows
@@ -1312,6 +1306,9 @@
   if (!consumeIf(Token::colon))
     return emitError("expected ':' after basic block name");
 
+  // Set the insertion point to the block we want to insert new operations into.
+  functionState.builder.setInsertionPoint(block);
+
   // Parse the list of operations that make up the body of the block.
   while (curToken.isNot(Token::kw_return, Token::kw_br)) {
     auto loc = curToken.getLoc();
@@ -1322,17 +1319,14 @@
     // We just parsed an operation.  If it is a recognized one, verify that it
     // is structurally as we expect.  If not, produce an error with a reasonable
     // source location.
-    if (auto *opInfo = inst->getAbstractOperation(context))
+    if (auto *opInfo = inst->getAbstractOperation(builder.getContext()))
       if (auto error = opInfo->verifyInvariants(inst))
         return emitError(loc, error);
-
-    block->getOperations().push_back(inst);
   }
 
   auto *term = parseTerminator(functionState);
   if (!term)
     return ParseFailure;
-  block->setTerminator(term);
 
   return ParseSuccess;
 }
@@ -1380,8 +1374,8 @@
   }
 
   // TODO: Don't drop result name and operand names on the floor.
-  auto nameId = Identifier::get(name, context);
-  return new OperationInst(nameId, attributes, context);
+  auto nameId = Identifier::get(name, builder.getContext());
+  return functionState.builder.createOperation(nameId, attributes);
 }
 
 
@@ -1400,7 +1394,7 @@
 
   case Token::kw_return:
     consumeToken(Token::kw_return);
-    return new ReturnInst();
+    return functionState.builder.createReturnInst();
 
   case Token::kw_br: {
     consumeToken(Token::kw_br);
@@ -1408,7 +1402,7 @@
                                               curToken.getLoc());
     if (!consumeIf(Token::bare_identifier))
       return (emitError("expected basic block name"), nullptr);
-    return new BranchInst(destBB);
+    return functionState.builder.createBranchInst(destBB);
   }
     // TODO: cond_br.
   }