Implement custom parser support for operations, enhance dim/addf to use it, and add a new load op.

This regresses parser error recovery in some cases (see invalid.mlir), which I'll
address in a follow-up patch.  The important thing in this patch is that the
parse methods in StandardOps.cpp are nice and simple.

PiperOrigin-RevId: 206023308
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 3c655d1..63ebeff 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSet.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/Types.h"
@@ -52,7 +53,8 @@
               SMDiagnosticHandlerTy errorReporter)
       : context(module->getContext()), module(module),
         lex(sourceMgr, errorReporter), curToken(lex.lexToken()),
-        errorReporter(errorReporter) {}
+        errorReporter(errorReporter), operationSet(OperationSet::get(context)) {
+  }
 
   // A map from affine map identifier to AffineMap.
   llvm::StringMap<AffineMap *> affineMapDefinitions;
@@ -77,6 +79,9 @@
 
   // The diagnostic error reporter.
   SMDiagnosticHandlerTy const errorReporter;
+
+  // The active OperationSet we're parsing with.
+  OperationSet &operationSet;
 };
 } // end anonymous namespace
 
@@ -99,6 +104,7 @@
   ParserState &getState() const { return state; }
   MLIRContext *getContext() const { return state.context; }
   Module *getModule() { return state.module; }
+  OperationSet &getOperationSet() const { return state.operationSet; }
 
   /// Return the current token the parser is inspecting.
   const Token &getToken() const { return state.curToken; }
@@ -611,8 +617,8 @@
     return builder.getStringAttr(val);
   }
 
-  case Token::l_bracket: {
-    consumeToken(Token::l_bracket);
+  case Token::l_square: {
+    consumeToken(Token::l_square);
     SmallVector<Attribute *, 4> elements;
 
     auto parseElt = [&]() -> ParseResult {
@@ -620,7 +626,7 @@
       return elements.back() ? ParseSuccess : ParseFailure;
     };
 
-    if (parseCommaSeparatedListUntil(Token::r_bracket, parseElt))
+    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
       return nullptr;
     return builder.getArrayAttr(elements);
   }
@@ -1101,11 +1107,11 @@
 
 /// Parse the list of symbolic identifiers to an affine map.
 ParseResult AffineMapParser::parseSymbolIdList() {
-  if (parseToken(Token::l_bracket, "expected '['"))
+  if (parseToken(Token::l_square, "expected '['"))
     return ParseFailure;
 
   auto parseElt = [&]() -> ParseResult { return parseDimOrSymbolId(false); };
-  return parseCommaSeparatedListUntil(Token::r_bracket, parseElt);
+  return parseCommaSeparatedListUntil(Token::r_square, parseElt);
 }
 
 /// Parse the list of dimensional identifiers to an affine map.
@@ -1131,7 +1137,7 @@
     return nullptr;
 
   // Symbols are optional.
-  if (getToken().is(Token::l_bracket)) {
+  if (getToken().is(Token::l_square)) {
     if (parseSymbolIdList())
       return nullptr;
   }
@@ -1260,6 +1266,8 @@
 
   // Operations
   ParseResult parseOperation(const CreateOperationFunction &createOpFunc);
+  Operation *parseVerboseOperation(const CreateOperationFunction &createOpFunc);
+  Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc);
 
 private:
   /// This keeps track of all of the SSA values we are tracking, indexed by
@@ -1416,7 +1424,7 @@
 ///
 ParseResult
 FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
-  if (!getToken().is(Token::percent_identifier))
+  if (getToken().isNot(Token::percent_identifier))
     return ParseSuccess;
   return parseCommaSeparatedList([&]() -> ParseResult {
     SSAUseInfo result;
@@ -1514,60 +1522,15 @@
       return ParseFailure;
   }
 
-  if (getToken().isNot(Token::string))
+  Operation *op;
+  if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
+    op = parseCustomOperation(createOpFunc);
+  else if (getToken().is(Token::string))
+    op = parseVerboseOperation(createOpFunc);
+  else
     return emitError("expected operation name in quotes");
 
-  auto name = getToken().getStringValue();
-  if (name.empty())
-    return emitError("empty operation name is invalid");
-
-  consumeToken(Token::string);
-
-  // Parse the operand list.
-  SmallVector<SSAUseInfo, 8> operandInfos;
-
-  if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
-      parseOptionalSSAUseList(operandInfos) ||
-      parseToken(Token::r_paren, "expected ')' to end operand list")) {
-    return ParseFailure;
-  }
-
-  SmallVector<NamedAttribute, 4> attributes;
-  if (getToken().is(Token::l_brace)) {
-    if (parseAttributeDict(attributes))
-      return ParseFailure;
-  }
-
-  if (parseToken(Token::colon, "expected ':' followed by instruction type"))
-    return ParseFailure;
-
-  auto typeLoc = getToken().getLoc();
-  auto type = parseType();
-  if (!type)
-    return ParseFailure;
-  auto fnType = dyn_cast<FunctionType>(type);
-  if (!fnType)
-    return emitError(typeLoc, "expected function type");
-
-  // Check that we have the right number of types for the operands.
-  auto operandTypes = fnType->getInputs();
-  if (operandTypes.size() != operandInfos.size()) {
-    auto plural = "s"[operandInfos.size() == 1];
-    return emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
-                                  " operand type" + plural + " but had " +
-                                  llvm::utostr(operandTypes.size()));
-  }
-
-  // Resolve all of the operands.
-  SmallVector<SSAValue *, 8> operands;
-  for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
-    operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
-    if (!operands.back())
-      return ParseFailure;
-  }
-
-  auto nameId = builder.getIdentifier(name);
-  auto op = createOpFunc(nameId, operands, fnType->getResults(), attributes);
+  // If parsing of the basic operation failed, then this whole thing fails.
   if (!op)
     return ParseFailure;
 
@@ -1595,6 +1558,262 @@
   return ParseSuccess;
 }
 
+/// Parse an operation written in the generic (verbose) form:
+///
+///   "op.name"(%operand, ...) {attribute-dict}? : function-type
+///
+/// The current token must be the quoted operation name.  Returns the created
+/// operation, or null on failure.
+Operation *FunctionParser::parseVerboseOperation(
+    const CreateOperationFunction &createOpFunc) {
+  auto name = getToken().getStringValue();
+  if (name.empty())
+    return (emitError("empty operation name is invalid"), nullptr);
+
+  consumeToken(Token::string);
+
+  // Parse the parenthesized (possibly empty) operand list.
+  SmallVector<SSAUseInfo, 8> operandInfos;
+
+  if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
+      parseOptionalSSAUseList(operandInfos) ||
+      parseToken(Token::r_paren, "expected ')' to end operand list")) {
+    return nullptr;
+  }
+
+  // The attribute dictionary is optional.
+  SmallVector<NamedAttribute, 4> attributes;
+  if (getToken().is(Token::l_brace)) {
+    if (parseAttributeDict(attributes))
+      return nullptr;
+  }
+
+  if (parseToken(Token::colon, "expected ':' followed by instruction type"))
+    return nullptr;
+
+  // The trailing type must be a function type: its inputs are the operand
+  // types and its results are the result types of the operation.
+  auto typeLoc = getToken().getLoc();
+  auto type = parseType();
+  if (!type)
+    return nullptr;
+  auto fnType = dyn_cast<FunctionType>(type);
+  if (!fnType)
+    return (emitError(typeLoc, "expected function type"), nullptr);
+
+  // Check that we have the right number of types for the operands.
+  auto operandTypes = fnType->getInputs();
+  if (operandTypes.size() != operandInfos.size()) {
+    // Use a string suffix: a char '\0' would be embedded in the message.
+    const char *plural = operandInfos.size() == 1 ? "" : "s";
+    return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
+                                   " operand type" + plural + " but had " +
+                                   llvm::utostr(operandTypes.size())),
+            nullptr);
+  }
+
+  // Resolve each operand against its declared type.
+  SmallVector<SSAValue *, 8> operands;
+  for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
+    operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
+    if (!operands.back())
+      return nullptr;
+  }
+
+  auto nameId = builder.getIdentifier(name);
+  return createOpFunc(nameId, operands, fnType->getResults(), attributes);
+}
+
+namespace {
+/// OpAsmParser implementation handed to an operation's parseAssembly hook.
+/// It forwards to the FunctionParser and remembers whether any failure was
+/// reported, so parseCustomOperation can tell whether the hook succeeded.
+class CustomOpAsmParser : public OpAsmParser {
+public:
+  CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser)
+      : nameLoc(nameLoc), opName(opName), parser(parser) {}
+
+  /// This is an internal helper to parse a colon, we don't want to expose
+  /// this to clients.
+  bool internalParseColon(llvm::SMLoc *loc) {
+    if (loc)
+      *loc = parser.getToken().getLoc();
+    return recordFailure(parser.parseToken(Token::colon, "expected ':'"));
+  }
+
+  //===--------------------------------------------------------------------===//
+  // High level parsing methods.
+  //===--------------------------------------------------------------------===//
+
+  bool parseComma(llvm::SMLoc *loc = nullptr) override {
+    if (loc)
+      *loc = parser.getToken().getLoc();
+    return recordFailure(parser.parseToken(Token::comma, "expected ','"));
+  }
+
+  bool parseColonType(Type *&result, llvm::SMLoc *loc = nullptr) override {
+    return internalParseColon(loc) ||
+           recordFailure(!(result = parser.parseType()));
+  }
+
+  bool parseColonTypeList(SmallVectorImpl<Type *> &result,
+                          llvm::SMLoc *loc = nullptr) override {
+    if (internalParseColon(loc))
+      return true;
+
+    // Parse one or more comma-separated types.
+    do {
+      auto *type = parser.parseType();
+      if (!type)
+        return recordFailure(true);
+      result.push_back(type);
+    } while (parser.consumeIf(Token::comma));
+    return false;
+  }
+
+  bool parseAttribute(Attribute *&result, llvm::SMLoc *loc = nullptr) override {
+    if (loc)
+      *loc = parser.getToken().getLoc();
+    result = parser.parseAttribute();
+    return recordFailure(result == nullptr);
+  }
+
+  bool parseOperand(OperandType &result) override {
+    FunctionParser::SSAUseInfo useInfo;
+    if (parser.parseSSAUse(useInfo))
+      return recordFailure(true);
+
+    result = {useInfo.loc, useInfo.name, useInfo.number};
+    return false;
+  }
+
+  bool parseOperandList(SmallVectorImpl<OperandType> &result,
+                        int requiredOperandCount = -1,
+                        Delimeter delimeter = Delimeter::NoDelimeter) override {
+    auto startLoc = parser.getToken().getLoc();
+
+    // Handle the opening delimeter, if any.
+    switch (delimeter) {
+    case Delimeter::NoDelimeter:
+      break;
+    case Delimeter::ParenDelimeter:
+      if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
+        return recordFailure(true);
+      break;
+    case Delimeter::SquareDelimeter:
+      if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
+        return recordFailure(true);
+      break;
+    }
+
+    // The list may be empty: only parse operands if one is present.
+    if (parser.getToken().is(Token::percent_identifier)) {
+      do {
+        OperandType operand;
+        if (parseOperand(operand))
+          return true;
+        result.push_back(operand);
+      } while (parser.consumeIf(Token::comma));
+    }
+
+    // Handle the closing delimeter, if any.
+    switch (delimeter) {
+    case Delimeter::NoDelimeter:
+      break;
+    case Delimeter::ParenDelimeter:
+      if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
+        return recordFailure(true);
+      break;
+    case Delimeter::SquareDelimeter:
+      if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
+        return recordFailure(true);
+      break;
+    }
+
+    // A count mismatch must fail, not just be diagnosed: callers presumably
+    // rely on getting exactly requiredOperandCount operands back.
+    if (requiredOperandCount != -1 &&
+        result.size() != static_cast<size_t>(requiredOperandCount)) {
+      emitError(startLoc,
+                "expected " + Twine(requiredOperandCount) + " operands");
+      return true;
+    }
+    return false;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Methods for interacting with the parser
+  //===--------------------------------------------------------------------===//
+
+  Builder &getBuilder() const override { return parser.builder; }
+
+  llvm::SMLoc getNameLoc() const override { return nameLoc; }
+
+  bool resolveOperand(OperandType operand, Type *type,
+                      SSAValue *&result) override {
+    FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
+                                              operand.location};
+    result = parser.resolveSSAUse(operandInfo, type);
+    return recordFailure(result == nullptr);
+  }
+
+  /// Emit a diagnostic at the specified location, prefixed with the name of
+  /// the custom op being parsed.
+  void emitError(llvm::SMLoc loc, const Twine &message) override {
+    parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message);
+    emittedError = true;
+  }
+
+  /// Return true if any error was reported while parsing this operation,
+  /// either via emitError or by a failing parse/resolve method.
+  bool didEmitError() const { return emittedError; }
+
+private:
+  /// Record 'failed' in the error state and return it unchanged.  Failures
+  /// from the underlying FunctionParser are assumed to be diagnosed there.
+  bool recordFailure(bool failed) {
+    if (failed)
+      emittedError = true;
+    return failed;
+  }
+
+  SMLoc nameLoc;
+  StringRef opName;
+  FunctionParser &parser;
+  bool emittedError = false;
+};
+} // end anonymous namespace.
+
+/// Parse an operation written in its custom assembly form.  The current token
+/// is the operation name (a bare identifier or keyword); the rest is parsed by
+/// the operation's parseAssembly hook.  Returns the created operation, or null
+/// if the op is unknown or its hook reported an error.
+Operation *FunctionParser::parseCustomOperation(
+    const CreateOperationFunction &createOpFunc) {
+  auto opLoc = getToken().getLoc();
+  auto opName = getTokenSpelling();
+  CustomOpAsmParser opAsmParser(opLoc, opName, *this);
+
+  auto *opDefinition = getOperationSet().lookup(opName);
+  if (!opDefinition) {
+    opAsmParser.emitError(opLoc, "is unknown");
+    return nullptr;
+  }
+
+  consumeToken();
+
+  // Have the op implementation take a crack at parsing this.
+  auto result = opDefinition->parseAssembly(&opAsmParser);
+
+  // If it reported an error (directly or via a failed parse call), we failed.
+  if (opAsmParser.didEmitError())
+    return nullptr;
+
+  // Otherwise, we succeeded.  Use the state it parsed as our op information.
+  auto nameId = builder.getIdentifier(opName);
+  return createOpFunc(nameId, result.operands, result.types, result.attributes);
+}
+
 //===----------------------------------------------------------------------===//
 // CFG Functions
 //===----------------------------------------------------------------------===//