Implement custom parser support for operations, enhance dim/addf to use it, and add a new load op.

This regresses parser error recovery in some cases (in invalid.mlir) which I'll
consider in a follow-up patch.  The important thing in this patch is that the
parse methods in StandardOps.cpp are nice and simple.

PiperOrigin-RevId: 206023308
diff --git a/include/mlir/IR/OpDefinition.h b/include/mlir/IR/OpDefinition.h
index 52ad4ad..95d95dd 100644
--- a/include/mlir/IR/OpDefinition.h
+++ b/include/mlir/IR/OpDefinition.h
@@ -31,6 +31,7 @@
 
 namespace mlir {
 class Type;
+class OpAsmParser;
 class OpAsmPrinter;
 
 /// This pointer represents a notional "Operation*" but where the actual
@@ -68,6 +69,22 @@
   const OpType value;
 };
 
+/// This is the result type of parsing a custom operation.  If an error is
+/// emitted, it is fine to return this in a partially mutated state.
+struct OpAsmParserResult {
+  SmallVector<SSAValue *, 4> operands;
+  SmallVector<Type *, 4> types;
+  SmallVector<NamedAttribute, 4> attributes;
+
+  /*implicit*/ OpAsmParserResult() {}
+
+  OpAsmParserResult(ArrayRef<SSAValue *> operands, ArrayRef<Type *> types,
+                    ArrayRef<NamedAttribute> attributes = {})
+      : operands(operands.begin(), operands.end()),
+        types(types.begin(), types.end()),
+        attributes(attributes.begin(), attributes.end()) {}
+};
+
 //===----------------------------------------------------------------------===//
 // OpImpl Types
 //===----------------------------------------------------------------------===//
@@ -108,6 +125,9 @@
   /// back to this one which accepts everything.
   const char *verify() const { return nullptr; }
 
+  // Unless overridden, the short form of an op is always rejected.
+  static OpAsmParserResult parse(OpAsmParser *parser);
+
   // The fallback for the printer is to print it the longhand form.
   void print(OpAsmPrinter *p) const;
 
@@ -138,6 +158,12 @@
     return op->getName().is(ConcreteType::getOperationName());
   }
 
+  /// This is the hook used by the AsmParser to parse the custom form of this
+  /// op from an .mlir file.  Op implementations should provide a parse method.
+  static OpAsmParserResult parseAssembly(OpAsmParser *parser) {
+    return ConcreteType::parse(parser);
+  }
+
   /// This is the hook used by the AsmPrinter to emit this to the .mlir file.
   /// Op implementations should provide a print method.
   static void printAssembly(const Operation *op, OpAsmPrinter *p) {
diff --git a/include/mlir/IR/OpImplementation.h b/include/mlir/IR/OpImplementation.h
index 8979d9e..1ef4846 100644
--- a/include/mlir/IR/OpImplementation.h
+++ b/include/mlir/IR/OpImplementation.h
@@ -23,10 +23,17 @@
 #define MLIR_IR_OPIMPLEMENTATION_H
 
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/SMLoc.h"
 
 namespace mlir {
 class AffineMap;
 class AffineExpr;
+class Builder;
+
+//===----------------------------------------------------------------------===//
+// OpAsmPrinter
+//===----------------------------------------------------------------------===//
 
 /// This is a pure-virtual base class that exposes the asmprinter hooks
 /// necessary to implement a custom print() method.
@@ -38,6 +45,19 @@
 
   /// Print implementations for various things an operation contains.
   virtual void printOperand(const SSAValue *value) = 0;
+
+  /// Print a comma separated list of operands.
+  template <typename ContainerType>
+  void printOperands(const ContainerType &container) {
+    auto it = container.begin(), end = container.end();
+    if (it == end)
+      return;
+    printOperand(*it);
+    for (++it; it != end; ++it) {
+      getStream() << ", ";
+      printOperand(*it);
+    }
+  }
   virtual void printType(const Type *type) = 0;
   virtual void printAttribute(const Attribute *attr) = 0;
   virtual void printAffineMap(const AffineMap *map) = 0;
@@ -78,6 +98,151 @@
   return p;
 }
 
+//===----------------------------------------------------------------------===//
+// OpAsmParser
+//===----------------------------------------------------------------------===//
+
+/// The OpAsmParser has methods for interacting with the asm parser: parsing
+/// things from it, emitting errors etc.  It has an intentionally high-level API
+/// that is designed to reduce/constrain syntax innovation in individual
+/// operations.
+///
+/// For example, consider an op like this:
+///
+///    %x = load %p[%1, %2] : memref<...>
+///
+/// The "%x = load" tokens are already parsed and therefore invisible to the
+/// custom op parser.  This can be supported by calling `parseOperandList` to
+/// parse the %p, then calling `parseOperandList` with a `SquareDelimeter` to
+/// parse the indices, then calling `parseColonTypeList` to parse the result
+/// type.
+///
+class OpAsmParser {
+public:
+  virtual ~OpAsmParser();
+
+  //===--------------------------------------------------------------------===//
+  // High level parsing methods.
+  //===--------------------------------------------------------------------===//
+
+  // These return void if they always succeed.  If they can fail, they emit an
+  // error and return "true".  On success, they can optionally provide location
+  // information for clients who want it.
+
+  /// This parses... a comma!
+  virtual bool parseComma(llvm::SMLoc *loc = nullptr) = 0;
+
+  /// Parse a colon followed by a type.
+  virtual bool parseColonType(Type *&result, llvm::SMLoc *loc = nullptr) = 0;
+
+  /// Parse a type of a specific kind, e.g. a FunctionType.
+  template <typename TypeType>
+  bool parseColonType(TypeType *&result, llvm::SMLoc *loc = nullptr) {
+    // Parse any kind of type.
+    Type *type;
+    llvm::SMLoc tmpLoc;
+    if (parseColonType(type, &tmpLoc))
+      return true;
+    if (loc)
+      *loc = tmpLoc;
+
+    // Check for the right kind of attribute.
+    result = dyn_cast<TypeType>(type);
+    if (!result) {
+      emitError(tmpLoc, "invalid kind of type specified");
+      return true;
+    }
+
+    return false;
+  }
+
+  /// Parse a colon followed by a type list, which must have at least one type.
+  virtual bool parseColonTypeList(SmallVectorImpl<Type *> &result,
+                                  llvm::SMLoc *loc = nullptr) = 0;
+
+  /// Parse an attribute.
+  virtual bool parseAttribute(Attribute *&result,
+                              llvm::SMLoc *loc = nullptr) = 0;
+
+  /// Parse an attribute of a specific kind.
+  template <typename AttrType>
+  bool parseAttribute(AttrType *&result, llvm::SMLoc *loc = nullptr) {
+    // Parse any kind of attribute.
+    Attribute *attr;
+    llvm::SMLoc tmpLoc;
+    if (parseAttribute(attr, &tmpLoc))
+      return true;
+    if (loc)
+      *loc = tmpLoc;
+
+    // Check for the right kind of attribute.
+    result = dyn_cast<AttrType>(attr);
+    if (!result) {
+      emitError(tmpLoc, "invalid kind of constant specified");
+      return true;
+    }
+
+    return false;
+  }
+
+  /// This is the representation of an operand reference.
+  struct OperandType {
+    llvm::SMLoc location; // Location of the token.
+    StringRef name;       // Value name, e.g. %42 or %abc
+    unsigned number;      // Number, e.g. 12 for an operand like %xyz#12
+  };
+
+  /// Parse a single operand.
+  virtual bool parseOperand(OperandType &result) = 0;
+
+  /// These are the supported delimeters around operand lists, used by
+  /// parseOperandList.
+  enum Delimeter {
+    NoDelimeter,
+    ParenDelimeter,
+    SquareDelimeter,
+  };
+
+  /// Parse zero or more SSA comma-separated operand references with a specified
+  /// surrounding delimeter, and an optional required operand count.
+  virtual bool
+  parseOperandList(SmallVectorImpl<OperandType> &result,
+                   int requiredOperandCount = -1,
+                   Delimeter delimeter = Delimeter::NoDelimeter) = 0;
+
+  //===--------------------------------------------------------------------===//
+  // Methods for interacting with the parser
+  //===--------------------------------------------------------------------===//
+
+  /// Return a builder which provides useful access to MLIRContext, global
+  /// objects like types and attributes.
+  virtual Builder &getBuilder() const = 0;
+
+  /// Return the location of the original name token.
+  virtual llvm::SMLoc getNameLoc() const = 0;
+
+  /// Resolve an operand to an SSA value, emitting an error and returning true
+  /// on failure.
+  virtual bool resolveOperand(OperandType operand, Type *type,
+                              SSAValue *&result) = 0;
+
+  /// Resolve a list of operands to SSA values, emitting an error and returning
+  /// true on failure, or appending the results to the list on success.
+  virtual bool resolveOperands(ArrayRef<OperandType> operand, Type *type,
+                               SmallVectorImpl<SSAValue *> &result) {
+    for (auto elt : operand) {
+      SSAValue *value;
+      if (resolveOperand(elt, type, value))
+        return true;
+      result.push_back(value);
+    }
+    return false;
+  }
+
+  /// Emit a diagnostic at the specified location.
+  virtual void emitError(llvm::SMLoc loc, const Twine &message) = 0;
+};
+
 } // end namespace mlir
 
 #endif
diff --git a/include/mlir/IR/OperationSet.h b/include/mlir/IR/OperationSet.h
index 7f55517..59dbfb9 100644
--- a/include/mlir/IR/OperationSet.h
+++ b/include/mlir/IR/OperationSet.h
@@ -27,6 +27,8 @@
 
 namespace mlir {
 class Operation;
+class OpAsmParser;
+class OpAsmParserResult;
 class OpAsmPrinter;
 class MLIRContextImpl;
 class MLIRContext;
@@ -40,7 +42,8 @@
   template <typename T>
   static AbstractOperation get() {
     return AbstractOperation(T::getOperationName(), T::isClassFor,
-                             T::printAssembly, T::verifyInvariants);
+                             T::parseAssembly, T::printAssembly,
+                             T::verifyInvariants);
   }
 
   /// This is the name of the operation.
@@ -49,6 +52,9 @@
   /// Return true if this "op class" can match against the specified operation.
   bool (&isClassFor)(const Operation *op);
 
+  /// Use the specified object to parse this ops custom assembly format.
+  OpAsmParserResult (&parseAssembly)(OpAsmParser *parser);
+
   /// This hook implements the AsmPrinter for this operation.
   void (&printAssembly)(const Operation *op, OpAsmPrinter *p);
 
@@ -60,10 +66,11 @@
 
 private:
   AbstractOperation(StringRef name, bool (&isClassFor)(const Operation *op),
+                    OpAsmParserResult (&parseAssembly)(OpAsmParser *parser),
                     void (&printAssembly)(const Operation *op, OpAsmPrinter *p),
                     const char *(&verifyInvariants)(const Operation *op))
-      : name(name), isClassFor(isClassFor), printAssembly(printAssembly),
-        verifyInvariants(verifyInvariants) {}
+      : name(name), isClassFor(isClassFor), parseAssembly(parseAssembly),
+        printAssembly(printAssembly), verifyInvariants(verifyInvariants) {}
 };
 
 /// An instance of OperationSet is owned and maintained by MLIRContext.  It
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index 4b2f4fa..a3b4fcc 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -43,6 +43,7 @@
   static StringRef getOperationName() { return "addf"; }
 
   const char *verify() const;
+  static OpAsmParserResult parse(OpAsmParser *parser);
   void print(OpAsmPrinter *p) const;
 
 private:
@@ -109,6 +110,7 @@
 
   // Hooks to customize behavior of this op.
   const char *verify() const;
+  static OpAsmParserResult parse(OpAsmParser *parser);
   void print(OpAsmPrinter *p) const;
 
 private:
@@ -116,15 +118,49 @@
   explicit DimOp(const Operation *state) : Base(state) {}
 };
 
-// The "affine_apply" operation applies an affine map to a list of operands,
-// yielding a list of results. The operand and result list sizes must be the
-// same. All operands and results are of type 'AffineInt'. This operation
-// requires a single affine map attribute named "map".
-// For example:
-//
-//   %y = "affine_apply" (%x) { map: (d0) -> (d0 + 1) } :
-//          (affineint) -> (affineint)
-//
+/// The "load" op reads an element from a memref specified by an index list. The
+/// output of load is a new value with the same type as the elements of the
+/// memref. The arity of indices is the rank of the memref (i.e., if the memref
+/// loaded from is of rank 3, then 3 indices are required for the load following
+/// the memref identifier).  For example:
+///
+///   %3 = load %0[%1, %1] : memref<4x4xi32>
+///
+class LoadOp
+    : public OpImpl::Base<LoadOp, OpImpl::VariadicOperands, OpImpl::OneResult> {
+public:
+  SSAValue *getMemRef() { return getOperand(0); }
+  const SSAValue *getMemRef() const { return getOperand(0); }
+
+  llvm::iterator_range<Operation::operand_iterator> getIndices() {
+    return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
+  }
+
+  llvm::iterator_range<Operation::const_operand_iterator> getIndices() const {
+    return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
+  }
+
+  static StringRef getOperationName() { return "load"; }
+
+  // Hooks to customize behavior of this op.
+  const char *verify() const;
+  static OpAsmParserResult parse(OpAsmParser *parser);
+  void print(OpAsmPrinter *p) const;
+
+private:
+  friend class Operation;
+  explicit LoadOp(const Operation *state) : Base(state) {}
+};
+
+/// The "affine_apply" operation applies an affine map to a list of operands,
+/// yielding a list of results. The operand and result list sizes must be the
+/// same. All operands and results are of type 'AffineInt'. This operation
+/// requires a single affine map attribute named "map".
+/// For example:
+///
+///   %y = "affine_apply" (%x) { map: (d0) -> (d0 + 1) } :
+///          (affineint) -> (affineint)
+///
 class AffineApplyOp
     : public OpImpl::Base<AffineApplyOp, OpImpl::VariadicOperands,
                           OpImpl::VariadicResults> {
diff --git a/lib/IR/OperationSet.cpp b/lib/IR/OperationSet.cpp
index c8e1e9b..43a65aa 100644
--- a/lib/IR/OperationSet.cpp
+++ b/lib/IR/OperationSet.cpp
@@ -19,10 +19,19 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/Twine.h"
 #include "llvm/Support/raw_ostream.h"
 using namespace mlir;
 using llvm::StringMap;
 
+OpAsmParser::~OpAsmParser() {}
+
+// The fallback for the printer is to reject the short form.
+OpAsmParserResult OpImpl::BaseState::parse(OpAsmParser *parser) {
+  parser->emitError(parser->getNameLoc(), "has no concise form");
+  return {};
+}
+
 // The fallback for the printer is to print it the longhand form.
 void OpImpl::BaseState::print(OpAsmPrinter *p) const {
   p->printDefaultOp(getOperation());
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 4da0d60..069f3ed 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -17,6 +17,7 @@
 
 #include "mlir/IR/StandardOps.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSet.h"
 #include "mlir/IR/SSAValue.h"
@@ -26,6 +27,18 @@
 
 // TODO: Have verify functions return std::string to enable more descriptive
 // error messages.
+OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
+  SmallVector<OpAsmParser::OperandType, 2> ops;
+  Type *type;
+  SSAValue *lhs, *rhs;
+  if (parser->parseOperandList(ops, 2) || parser->parseColonType(type) ||
+      parser->resolveOperand(ops[0], type, lhs) ||
+      parser->resolveOperand(ops[1], type, rhs))
+    return {};
+
+  return OpAsmParserResult({lhs, rhs}, type);
+}
+
 void AddFOp::print(OpAsmPrinter *p) const {
   *p << "addf " << *getOperand(0) << ", " << *getOperand(1) << " : "
      << *getType();
@@ -71,6 +84,22 @@
      << *getOperand()->getType();
 }
 
+OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
+  OpAsmParser::OperandType operandInfo;
+  IntegerAttr *indexAttr;
+  Type *type;
+  SSAValue *operand;
+  if (parser->parseOperand(operandInfo) || parser->parseComma() ||
+      parser->parseAttribute(indexAttr) || parser->parseColonType(type) ||
+      parser->resolveOperand(operandInfo, type, operand))
+    return {};
+
+  auto &builder = parser->getBuilder();
+  return OpAsmParserResult(
+      operand, builder.getAffineIntType(),
+      NamedAttribute(builder.getIdentifier("index"), indexAttr));
+}
+
 const char *DimOp::verify() const {
   // Check that we have an integer index operand.
   auto indexAttr = getAttrOfType<IntegerAttr>("index");
@@ -95,6 +124,35 @@
   return nullptr;
 }
 
+void LoadOp::print(OpAsmPrinter *p) const {
+  *p << "load " << *getMemRef() << '[';
+  p->printOperands(getIndices());
+  *p << "] : " << *getMemRef()->getType();
+}
+
+OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
+  OpAsmParser::OperandType memrefInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+  MemRefType *type;
+  SmallVector<SSAValue *, 4> operands;
+
+  auto affineIntTy = parser->getBuilder().getAffineIntType();
+  if (parser->parseOperand(memrefInfo) ||
+      parser->parseOperandList(indexInfo, -1,
+                               OpAsmParser::Delimeter::SquareDelimeter) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperands(memrefInfo, type, operands) ||
+      parser->resolveOperands(indexInfo, affineIntTy, operands))
+    return {};
+
+  return OpAsmParserResult(operands, type->getElementType());
+}
+
+const char *LoadOp::verify() const {
+  // TODO: Check load
+  return nullptr;
+}
+
 void AffineApplyOp::print(OpAsmPrinter *p) const {
   // TODO: Print operands etc.
   *p << "affine_apply map: " << *getAffineMap();
@@ -122,5 +180,6 @@
 
 /// Install the standard operations in the specified operation set.
 void mlir::registerStandardOperations(OperationSet &opSet) {
-  opSet.addOperations<AddFOp, ConstantOp, DimOp, AffineApplyOp>(/*prefix=*/"");
+  opSet.addOperations<AddFOp, ConstantOp, DimOp, LoadOp, AffineApplyOp>(
+      /*prefix=*/"");
 }
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 91596d3..011dfcb 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -77,8 +77,10 @@
   case ')': return formToken(Token::r_paren, tokStart);
   case '{': return formToken(Token::l_brace, tokStart);
   case '}': return formToken(Token::r_brace, tokStart);
-  case '[': return formToken(Token::l_bracket, tokStart);
-  case ']': return formToken(Token::r_bracket, tokStart);
+  case '[':
+    return formToken(Token::l_square, tokStart);
+  case ']':
+    return formToken(Token::r_square, tokStart);
   case '<': return formToken(Token::less, tokStart);
   case '>': return formToken(Token::greater, tokStart);
   case '=': return formToken(Token::equal, tokStart);
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 3c655d1..63ebeff 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSet.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/Types.h"
@@ -52,7 +53,8 @@
               SMDiagnosticHandlerTy errorReporter)
       : context(module->getContext()), module(module),
         lex(sourceMgr, errorReporter), curToken(lex.lexToken()),
-        errorReporter(errorReporter) {}
+        errorReporter(errorReporter), operationSet(OperationSet::get(context)) {
+  }
 
   // A map from affine map identifier to AffineMap.
   llvm::StringMap<AffineMap *> affineMapDefinitions;
@@ -77,6 +79,9 @@
 
   // The diagnostic error reporter.
   SMDiagnosticHandlerTy const errorReporter;
+
+  // The active OperationSet we're parsing with.
+  OperationSet &operationSet;
 };
 } // end anonymous namespace
 
@@ -99,6 +104,7 @@
   ParserState &getState() const { return state; }
   MLIRContext *getContext() const { return state.context; }
   Module *getModule() { return state.module; }
+  OperationSet &getOperationSet() const { return state.operationSet; }
 
   /// Return the current token the parser is inspecting.
   const Token &getToken() const { return state.curToken; }
@@ -611,8 +617,8 @@
     return builder.getStringAttr(val);
   }
 
-  case Token::l_bracket: {
-    consumeToken(Token::l_bracket);
+  case Token::l_square: {
+    consumeToken(Token::l_square);
     SmallVector<Attribute *, 4> elements;
 
     auto parseElt = [&]() -> ParseResult {
@@ -620,7 +626,7 @@
       return elements.back() ? ParseSuccess : ParseFailure;
     };
 
-    if (parseCommaSeparatedListUntil(Token::r_bracket, parseElt))
+    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
       return nullptr;
     return builder.getArrayAttr(elements);
   }
@@ -1101,11 +1107,11 @@
 
 /// Parse the list of symbolic identifiers to an affine map.
 ParseResult AffineMapParser::parseSymbolIdList() {
-  if (parseToken(Token::l_bracket, "expected '['"))
+  if (parseToken(Token::l_square, "expected '['"))
     return ParseFailure;
 
   auto parseElt = [&]() -> ParseResult { return parseDimOrSymbolId(false); };
-  return parseCommaSeparatedListUntil(Token::r_bracket, parseElt);
+  return parseCommaSeparatedListUntil(Token::r_square, parseElt);
 }
 
 /// Parse the list of dimensional identifiers to an affine map.
@@ -1131,7 +1137,7 @@
     return nullptr;
 
   // Symbols are optional.
-  if (getToken().is(Token::l_bracket)) {
+  if (getToken().is(Token::l_square)) {
     if (parseSymbolIdList())
       return nullptr;
   }
@@ -1260,6 +1266,8 @@
 
   // Operations
   ParseResult parseOperation(const CreateOperationFunction &createOpFunc);
+  Operation *parseVerboseOperation(const CreateOperationFunction &createOpFunc);
+  Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc);
 
 private:
   /// This keeps track of all of the SSA values we are tracking, indexed by
@@ -1416,7 +1424,7 @@
 ///
 ParseResult
 FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
-  if (!getToken().is(Token::percent_identifier))
+  if (getToken().isNot(Token::percent_identifier))
     return ParseSuccess;
   return parseCommaSeparatedList([&]() -> ParseResult {
     SSAUseInfo result;
@@ -1514,60 +1522,15 @@
       return ParseFailure;
   }
 
-  if (getToken().isNot(Token::string))
+  Operation *op;
+  if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
+    op = parseCustomOperation(createOpFunc);
+  else if (getToken().is(Token::string))
+    op = parseVerboseOperation(createOpFunc);
+  else
     return emitError("expected operation name in quotes");
 
-  auto name = getToken().getStringValue();
-  if (name.empty())
-    return emitError("empty operation name is invalid");
-
-  consumeToken(Token::string);
-
-  // Parse the operand list.
-  SmallVector<SSAUseInfo, 8> operandInfos;
-
-  if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
-      parseOptionalSSAUseList(operandInfos) ||
-      parseToken(Token::r_paren, "expected ')' to end operand list")) {
-    return ParseFailure;
-  }
-
-  SmallVector<NamedAttribute, 4> attributes;
-  if (getToken().is(Token::l_brace)) {
-    if (parseAttributeDict(attributes))
-      return ParseFailure;
-  }
-
-  if (parseToken(Token::colon, "expected ':' followed by instruction type"))
-    return ParseFailure;
-
-  auto typeLoc = getToken().getLoc();
-  auto type = parseType();
-  if (!type)
-    return ParseFailure;
-  auto fnType = dyn_cast<FunctionType>(type);
-  if (!fnType)
-    return emitError(typeLoc, "expected function type");
-
-  // Check that we have the right number of types for the operands.
-  auto operandTypes = fnType->getInputs();
-  if (operandTypes.size() != operandInfos.size()) {
-    auto plural = "s"[operandInfos.size() == 1];
-    return emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
-                                  " operand type" + plural + " but had " +
-                                  llvm::utostr(operandTypes.size()));
-  }
-
-  // Resolve all of the operands.
-  SmallVector<SSAValue *, 8> operands;
-  for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
-    operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
-    if (!operands.back())
-      return ParseFailure;
-  }
-
-  auto nameId = builder.getIdentifier(name);
-  auto op = createOpFunc(nameId, operands, fnType->getResults(), attributes);
+  // If parsing of the basic operation failed, then this whole thing fails.
   if (!op)
     return ParseFailure;
 
@@ -1595,6 +1558,228 @@
   return ParseSuccess;
 }
 
+Operation *FunctionParser::parseVerboseOperation(
+    const CreateOperationFunction &createOpFunc) {
+  auto name = getToken().getStringValue();
+  if (name.empty())
+    return (emitError("empty operation name is invalid"), nullptr);
+
+  consumeToken(Token::string);
+
+  // Parse the operand list.
+  SmallVector<SSAUseInfo, 8> operandInfos;
+
+  if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
+      parseOptionalSSAUseList(operandInfos) ||
+      parseToken(Token::r_paren, "expected ')' to end operand list")) {
+    return nullptr;
+  }
+
+  SmallVector<NamedAttribute, 4> attributes;
+  if (getToken().is(Token::l_brace)) {
+    if (parseAttributeDict(attributes))
+      return nullptr;
+  }
+
+  if (parseToken(Token::colon, "expected ':' followed by instruction type"))
+    return nullptr;
+
+  auto typeLoc = getToken().getLoc();
+  auto type = parseType();
+  if (!type)
+    return nullptr;
+  auto fnType = dyn_cast<FunctionType>(type);
+  if (!fnType)
+    return (emitError(typeLoc, "expected function type"), nullptr);
+
+  // Check that we have the right number of types for the operands.
+  auto operandTypes = fnType->getInputs();
+  if (operandTypes.size() != operandInfos.size()) {
+    auto plural = "s"[operandInfos.size() == 1];
+    return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
+                                   " operand type" + plural + " but had " +
+                                   llvm::utostr(operandTypes.size())),
+            nullptr);
+  }
+
+  // Resolve all of the operands.
+  SmallVector<SSAValue *, 8> operands;
+  for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
+    operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
+    if (!operands.back())
+      return nullptr;
+  }
+
+  auto nameId = builder.getIdentifier(name);
+  return createOpFunc(nameId, operands, fnType->getResults(), attributes);
+}
+
+namespace {
+class CustomOpAsmParser : public OpAsmParser {
+public:
+  CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser)
+      : nameLoc(nameLoc), opName(opName), parser(parser) {}
+
+  /// This is an internal helper to parser a colon, we don't want to expose
+  /// this to clients.
+  bool internalParseColon(llvm::SMLoc *loc) {
+    if (loc)
+      *loc = parser.getToken().getLoc();
+    return parser.parseToken(Token::colon, "expected ':'");
+  }
+
+  //===--------------------------------------------------------------------===//
+  // High level parsing methods.
+  //===--------------------------------------------------------------------===//
+
+  bool parseComma(llvm::SMLoc *loc = nullptr) override {
+    if (loc)
+      *loc = parser.getToken().getLoc();
+    return parser.parseToken(Token::comma, "expected ','");
+  }
+
+  bool parseColonType(Type *&result, llvm::SMLoc *loc = nullptr) override {
+    return internalParseColon(loc) || !(result = parser.parseType());
+  }
+
+  bool parseColonTypeList(SmallVectorImpl<Type *> &result,
+                          llvm::SMLoc *loc = nullptr) override {
+    if (internalParseColon(loc))
+      return true;
+
+    do {
+      if (auto *type = parser.parseType())
+        result.push_back(type);
+      else
+        return true;
+
+    } while (parser.consumeIf(Token::comma));
+    return false;
+  }
+
+  bool parseAttribute(Attribute *&result, llvm::SMLoc *loc = nullptr) override {
+    if (loc)
+      *loc = parser.getToken().getLoc();
+    result = parser.parseAttribute();
+    return result == nullptr;
+  }
+
+  bool parseOperand(OperandType &result) override {
+    FunctionParser::SSAUseInfo useInfo;
+    if (parser.parseSSAUse(useInfo))
+      return true;
+
+    result = {useInfo.loc, useInfo.name, useInfo.number};
+    return false;
+  }
+
+  bool parseOperandList(SmallVectorImpl<OperandType> &result,
+                        int requiredOperandCount = -1,
+                        Delimeter delimeter = Delimeter::NoDelimeter) override {
+    auto startLoc = parser.getToken().getLoc();
+
+    // Handle delimeters.
+    switch (delimeter) {
+    case Delimeter::NoDelimeter:
+      break;
+    case Delimeter::ParenDelimeter:
+      if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
+        return true;
+      break;
+    case Delimeter::SquareDelimeter:
+      if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
+        return true;
+      break;
+    }
+
+    // Check for zero operands.
+    if (parser.getToken().is(Token::percent_identifier)) {
+      do {
+        OperandType operand;
+        if (parseOperand(operand))
+          return true;
+        result.push_back(operand);
+      } while (parser.consumeIf(Token::comma));
+    }
+
+    // Handle delimeters.
+    switch (delimeter) {
+    case Delimeter::NoDelimeter:
+      break;
+    case Delimeter::ParenDelimeter:
+      if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
+        return true;
+      break;
+    case Delimeter::SquareDelimeter:
+      if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
+        return true;
+      break;
+    }
+
+    if (requiredOperandCount != -1 && result.size() != requiredOperandCount)
+      emitError(startLoc,
+                "expected " + Twine(requiredOperandCount) + " operands");
+    return false;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Methods for interacting with the parser
+  //===--------------------------------------------------------------------===//
+
+  Builder &getBuilder() const override { return parser.builder; }
+
+  llvm::SMLoc getNameLoc() const override { return nameLoc; }
+
+  bool resolveOperand(OperandType operand, Type *type,
+                      SSAValue *&result) override {
+    FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
+                                              operand.location};
+    result = parser.resolveSSAUse(operandInfo, type);
+    return result == nullptr;
+  }
+
+  /// Emit a diagnostic at the specified location.
+  void emitError(llvm::SMLoc loc, const Twine &message) override {
+    parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message);
+    emittedError = true;
+  }
+
+  bool didEmitError() const { return emittedError; }
+
+private:
+  SMLoc nameLoc;
+  StringRef opName;
+  FunctionParser &parser;
+  bool emittedError = false;
+};
+} // end anonymous namespace.
+
+Operation *FunctionParser::parseCustomOperation(
+    const CreateOperationFunction &createOpFunc) {
+  auto opLoc = getToken().getLoc();
+  auto opName = getTokenSpelling();
+  CustomOpAsmParser opAsmParser(opLoc, opName, *this);
+
+  auto *opDefinition = getOperationSet().lookup(opName);
+  if (!opDefinition) {
+    opAsmParser.emitError(opLoc, "is unknown");
+    return nullptr;
+  }
+
+  consumeToken();
+
+  // Have the op implementation take a crack and parsing this.
+  auto result = opDefinition->parseAssembly(&opAsmParser);
+
+  // If it emitted an error, we failed.
+  if (opAsmParser.didEmitError())
+    return nullptr;
+
+  // Otherwise, we succeeded.  Use the state it parsed as our op information.
+  auto nameId = builder.getIdentifier(opName);
+  return createOpFunc(nameId, result.operands, result.types, result.attributes);
+}
+
 //===----------------------------------------------------------------------===//
 // CFG Functions
 //===----------------------------------------------------------------------===//
diff --git a/lib/Parser/TokenKinds.def b/lib/Parser/TokenKinds.def
index b9ef9b05..44a40ee 100644
--- a/lib/Parser/TokenKinds.def
+++ b/lib/Parser/TokenKinds.def
@@ -71,8 +71,8 @@
 TOK_PUNCTUATION(r_paren,          ")")
 TOK_PUNCTUATION(l_brace,          "{")
 TOK_PUNCTUATION(r_brace,          "}")
-TOK_PUNCTUATION(l_bracket,        "[")
-TOK_PUNCTUATION(r_bracket,        "]")
+TOK_PUNCTUATION(l_square,         "[")
+TOK_PUNCTUATION(r_square,         "]")
 TOK_PUNCTUATION(less,             "<")
 TOK_PUNCTUATION(greater,          ">")
 TOK_PUNCTUATION(equal,            "=")
diff --git a/test/IR/core-ops.mlir b/test/IR/core-ops.mlir
index ab82e4b..5618239 100644
--- a/test/IR/core-ops.mlir
+++ b/test/IR/core-ops.mlir
@@ -2,6 +2,7 @@
 
 // CHECK: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), (d1 + 2))
 #map5 = (d0, d1) -> (d0 + 1, d1 + 2)
+#id2 = (i,j)->(i,j)
 
 // CHECK-LABEL: cfgfunc @cfgfunc_with_ops(f32) {
 cfgfunc @cfgfunc_with_ops(f32) {
@@ -19,22 +20,23 @@
   return
 }
 
-// CHECK-LABEL: cfgfunc @standard_instrs() {
-cfgfunc @standard_instrs() {
-bb42:       // CHECK: bb0:
-  // CHECK: %0 = "getTensor"() : () -> tensor<4x4x?xf32>
-  %42 = "getTensor"() : () -> tensor<4x4x?xf32>
+// CHECK-LABEL: cfgfunc @standard_instrs(tensor<4x4x?xf32>, f32) {
+cfgfunc @standard_instrs(tensor<4x4x?xf32>, f32) {
+// CHECK: bb0(%0: tensor<4x4x?xf32>, %1: f32):
+bb42(%t: tensor<4x4x?xf32>, %f: f32):
+  // CHECK: %2 = dim %0, 2 : tensor<4x4x?xf32>
+  %a = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> affineint
 
-  // CHECK: dim %0, 2 : tensor<4x4x?xf32>
-  %a = "dim"(%42){index: 2} : (tensor<4x4x?xf32>) -> affineint
+  // CHECK: %3 = dim %0, 2 : tensor<4x4x?xf32>
+  %a2 = dim %t, 2 : tensor<4x4x?xf32>
 
-  // FIXME: Add support for fp attributes so this can use 'constant'.
-  %f = "FIXMEConst"(){value: 1} : () -> f32
+  // CHECK: %4 = addf %1, %1 : f32
+  %f2 = "addf"(%f, %f) : (f32,f32) -> f32
 
-  // CHECK: %3 = addf %2, %2 : f32
-  "addf"(%f, %f) : (f32,f32) -> f32
+  // CHECK: %5 = addf %4, %4 : f32
+  %f3 = addf %f2, %f2 : f32
 
-  // CHECK: %4 = "constant"(){value: 42} : () -> i32
+  // CHECK: %6 = "constant"(){value: 42} : () -> i32
   %x = "constant"(){value: 42} : () -> i32
   return
 }
@@ -53,4 +55,18 @@
   %y = "affine_apply" (%i, %j) { map: #map5 } :
     (affineint, affineint) -> (affineint, affineint)
   return
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: cfgfunc @load_store
+cfgfunc @load_store(memref<4x4xi32, #id2, 0>, affineint) {
+bb0(%0: memref<4x4xi32, #id2, 0>, %1: affineint):
+
+  // CHECK: %2 = load %0[%1, %1] : memref<4x4xi32, #map2, 0>
+  %2 = "load"(%0, %1, %1) : (memref<4x4xi32, #id2, 0>, affineint, affineint)->i32
+
+  // CHECK: %3 = load %0[%1, %1] : memref<4x4xi32, #map2, 0>
+  %3 = load %0[%1, %1] : memref<4x4xi32, #id2, 0>
+
+  return
+}
+
diff --git a/test/IR/invalid-ops.mlir b/test/IR/invalid-ops.mlir
index 4aafb21..2991d63 100644
--- a/test/IR/invalid-ops.mlir
+++ b/test/IR/invalid-ops.mlir
@@ -60,3 +60,11 @@
   %x = "affine_apply" (%i, %j) {map: (d0, d1) -> ((d0 + 1), (d1 + 2))} : (affineint,affineint) -> (affineint) //  expected-error {{'affine_apply' op result count and affine map result count must match}}
   return
 }
+
+// -----
+
+cfgfunc @unknown_custom_op() {
+bb0:
+  %i = crazyThing() {value: 0} : () -> affineint  // expected-error {{custom op 'crazyThing' is unknown}}
+  return
+}
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index c82f0ec..2e87540 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -85,7 +85,7 @@
 bb40:
   return
 bb41:
-bb42:        // expected-error {{expected operation name}}
+bb42:        // expected-error {{custom op 'bb42' is unknown}}
   return
 }
 
@@ -170,7 +170,7 @@
 // -----
 
 mlfunc @non_statement() {
-  asd   // expected-error {{expected operation name in quotes}}
+  asd   // expected-error {{custom op 'asd' is unknown}}
 }
 
 // -----