Implement custom parser support for operations, enhance dim/addf to use it, and add a new load op.
This regresses parser error recovery in some cases (in invalid.mlir) which I'll
consider in a follow-up patch. The important thing in this patch is that the
parse methods in StandardOps.cpp are nice and simple.
PiperOrigin-RevId: 206023308
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 3c655d1..63ebeff 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Types.h"
@@ -52,7 +53,8 @@
SMDiagnosticHandlerTy errorReporter)
: context(module->getContext()), module(module),
lex(sourceMgr, errorReporter), curToken(lex.lexToken()),
- errorReporter(errorReporter) {}
+ errorReporter(errorReporter), operationSet(OperationSet::get(context)) {
+ }
// A map from affine map identifier to AffineMap.
llvm::StringMap<AffineMap *> affineMapDefinitions;
@@ -77,6 +79,9 @@
// The diagnostic error reporter.
SMDiagnosticHandlerTy const errorReporter;
+
+ // The active OperationSet we're parsing with.
+ OperationSet &operationSet;
};
} // end anonymous namespace
@@ -99,6 +104,7 @@
ParserState &getState() const { return state; }
MLIRContext *getContext() const { return state.context; }
Module *getModule() { return state.module; }
+ OperationSet &getOperationSet() const { return state.operationSet; }
/// Return the current token the parser is inspecting.
const Token &getToken() const { return state.curToken; }
@@ -611,8 +617,8 @@
return builder.getStringAttr(val);
}
- case Token::l_bracket: {
- consumeToken(Token::l_bracket);
+ case Token::l_square: {
+ consumeToken(Token::l_square);
SmallVector<Attribute *, 4> elements;
auto parseElt = [&]() -> ParseResult {
@@ -620,7 +626,7 @@
return elements.back() ? ParseSuccess : ParseFailure;
};
- if (parseCommaSeparatedListUntil(Token::r_bracket, parseElt))
+ if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
return nullptr;
return builder.getArrayAttr(elements);
}
@@ -1101,11 +1107,11 @@
/// Parse the list of symbolic identifiers to an affine map.
ParseResult AffineMapParser::parseSymbolIdList() {
- if (parseToken(Token::l_bracket, "expected '['"))
+ if (parseToken(Token::l_square, "expected '['"))
return ParseFailure;
auto parseElt = [&]() -> ParseResult { return parseDimOrSymbolId(false); };
- return parseCommaSeparatedListUntil(Token::r_bracket, parseElt);
+ return parseCommaSeparatedListUntil(Token::r_square, parseElt);
}
/// Parse the list of dimensional identifiers to an affine map.
@@ -1131,7 +1137,7 @@
return nullptr;
// Symbols are optional.
- if (getToken().is(Token::l_bracket)) {
+ if (getToken().is(Token::l_square)) {
if (parseSymbolIdList())
return nullptr;
}
@@ -1260,6 +1266,8 @@
// Operations
ParseResult parseOperation(const CreateOperationFunction &createOpFunc);
+ Operation *parseVerboseOperation(const CreateOperationFunction &createOpFunc);
+ Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc);
private:
/// This keeps track of all of the SSA values we are tracking, indexed by
@@ -1416,7 +1424,7 @@
///
ParseResult
FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
- if (!getToken().is(Token::percent_identifier))
+ if (getToken().isNot(Token::percent_identifier))
return ParseSuccess;
return parseCommaSeparatedList([&]() -> ParseResult {
SSAUseInfo result;
@@ -1514,60 +1522,15 @@
return ParseFailure;
}
- if (getToken().isNot(Token::string))
+ Operation *op;
+ if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
+ op = parseCustomOperation(createOpFunc);
+ else if (getToken().is(Token::string))
+ op = parseVerboseOperation(createOpFunc);
+ else
return emitError("expected operation name in quotes");
- auto name = getToken().getStringValue();
- if (name.empty())
- return emitError("empty operation name is invalid");
-
- consumeToken(Token::string);
-
- // Parse the operand list.
- SmallVector<SSAUseInfo, 8> operandInfos;
-
- if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
- parseOptionalSSAUseList(operandInfos) ||
- parseToken(Token::r_paren, "expected ')' to end operand list")) {
- return ParseFailure;
- }
-
- SmallVector<NamedAttribute, 4> attributes;
- if (getToken().is(Token::l_brace)) {
- if (parseAttributeDict(attributes))
- return ParseFailure;
- }
-
- if (parseToken(Token::colon, "expected ':' followed by instruction type"))
- return ParseFailure;
-
- auto typeLoc = getToken().getLoc();
- auto type = parseType();
- if (!type)
- return ParseFailure;
- auto fnType = dyn_cast<FunctionType>(type);
- if (!fnType)
- return emitError(typeLoc, "expected function type");
-
- // Check that we have the right number of types for the operands.
- auto operandTypes = fnType->getInputs();
- if (operandTypes.size() != operandInfos.size()) {
- auto plural = "s"[operandInfos.size() == 1];
- return emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
- " operand type" + plural + " but had " +
- llvm::utostr(operandTypes.size()));
- }
-
- // Resolve all of the operands.
- SmallVector<SSAValue *, 8> operands;
- for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
- operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
- if (!operands.back())
- return ParseFailure;
- }
-
- auto nameId = builder.getIdentifier(name);
- auto op = createOpFunc(nameId, operands, fnType->getResults(), attributes);
+ // If parsing of the basic operation failed, then this whole thing fails.
if (!op)
return ParseFailure;
@@ -1595,6 +1558,228 @@
return ParseSuccess;
}
+/// Parse an operation written in the verbose form.  On entry the current token
+/// is the quoted operation name; the remaining syntax handled here is:
+///
+///   `(` ssa-use-list? `)` attribute-dict? `:` function-type
+///
+/// Each operand is resolved against the corresponding input type of the
+/// trailing function type, and the operation is then built by `createOpFunc`
+/// using the function type's results.  Returns null if any step fails.
+Operation *FunctionParser::parseVerboseOperation(
+    const CreateOperationFunction &createOpFunc) {
+  auto name = getToken().getStringValue();
+  if (name.empty())
+    return (emitError("empty operation name is invalid"), nullptr);
+
+  consumeToken(Token::string);
+
+  // Parse the operand list.  The uses are only recorded here; they cannot be
+  // resolved until the operand types have been parsed below.
+  SmallVector<SSAUseInfo, 8> operandInfos;
+
+  if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
+      parseOptionalSSAUseList(operandInfos) ||
+      parseToken(Token::r_paren, "expected ')' to end operand list")) {
+    return nullptr;
+  }
+
+  // The attribute dictionary is optional and is introduced by '{'.
+  SmallVector<NamedAttribute, 4> attributes;
+  if (getToken().is(Token::l_brace)) {
+    if (parseAttributeDict(attributes))
+      return nullptr;
+  }
+
+  if (parseToken(Token::colon, "expected ':' followed by instruction type"))
+    return nullptr;
+
+  // Remember where the type starts so type-related errors point at it.
+  auto typeLoc = getToken().getLoc();
+  auto type = parseType();
+  if (!type)
+    return nullptr;
+  auto fnType = dyn_cast<FunctionType>(type);
+  if (!fnType)
+    return (emitError(typeLoc, "expected function type"), nullptr);
+
+  // Check that we have the right number of types for the operands.
+  auto operandTypes = fnType->getInputs();
+  if (operandTypes.size() != operandInfos.size()) {
+    // Take the address of the indexed character so that the suffix is the C
+    // string "s" or "" rather than a single char: appending the char would
+    // embed a NUL byte in the message when there is exactly one operand.
+    const char *plural = &"s"[operandInfos.size() == 1];
+    return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
+                                   " operand type" + plural + " but had " +
+                                   llvm::utostr(operandTypes.size())),
+            nullptr);
+  }
+
+  // Resolve all of the operands against their declared types.
+  SmallVector<SSAValue *, 8> operands;
+  for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
+    operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
+    if (!operands.back())
+      return nullptr;
+  }
+
+  auto nameId = builder.getIdentifier(name);
+  return createOpFunc(nameId, operands, fnType->getResults(), attributes);
+}
+
+namespace {
+/// Implementation of the OpAsmParser interface on top of a FunctionParser.
+/// An instance is created for a single custom-syntax operation and handed to
+/// that operation's parse hook.
+///
+/// All `bool`-returning methods return true on failure and false on success.
+/// Only errors reported through `emitError` on this class are recorded by
+/// `didEmitError`.
+class CustomOpAsmParser : public OpAsmParser {
+public:
+  /// `nameLoc` and `opName` are the location and spelling of the operation
+  /// name; `parser` is the underlying parser that all requests are forwarded
+  /// to.
+  CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser)
+      : nameLoc(nameLoc), opName(opName), parser(parser) {}
+
+  /// This is an internal helper to parse a colon; we don't want to expose
+  /// this to clients.  If `loc` is non-null it receives the location of the
+  /// token that was current before the colon was parsed.
+  bool internalParseColon(llvm::SMLoc *loc) {
+    if (loc)
+      *loc = parser.getToken().getLoc();
+    return parser.parseToken(Token::colon, "expected ':'");
+  }
+
+  //===--------------------------------------------------------------------===//
+  // High level parsing methods.
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a ','.  If `loc` is non-null it receives the location of the token
+  /// that was current on entry.
+  bool parseComma(llvm::SMLoc *loc = nullptr) override {
+    if (loc)
+      *loc = parser.getToken().getLoc();
+    return parser.parseToken(Token::comma, "expected ','");
+  }
+
+  /// Parse a ':' followed by a single type, which is stored in `result`.
+  bool parseColonType(Type *&result, llvm::SMLoc *loc = nullptr) override {
+    return internalParseColon(loc) || !(result = parser.parseType());
+  }
+
+  /// Parse a ':' followed by one or more comma-separated types, which are
+  /// appended to `result`.
+  bool parseColonTypeList(SmallVectorImpl<Type *> &result,
+                          llvm::SMLoc *loc = nullptr) override {
+    if (internalParseColon(loc))
+      return true;
+
+    do {
+      if (auto *type = parser.parseType())
+        result.push_back(type);
+      else
+        return true;
+
+    } while (parser.consumeIf(Token::comma));
+    return false;
+  }
+
+  /// Parse an attribute into `result`.  If `loc` is non-null it receives the
+  /// location of the token that was current on entry.
+  bool parseAttribute(Attribute *&result, llvm::SMLoc *loc = nullptr) override {
+    if (loc)
+      *loc = parser.getToken().getLoc();
+    result = parser.parseAttribute();
+    return result == nullptr;
+  }
+
+  /// Parse a single SSA use and record its location, name and number in
+  /// `result`.  The use is not resolved to a value here; see resolveOperand.
+  bool parseOperand(OperandType &result) override {
+    FunctionParser::SSAUseInfo useInfo;
+    if (parser.parseSSAUse(useInfo))
+      return true;
+
+    result = {useInfo.loc, useInfo.name, useInfo.number};
+    return false;
+  }
+
+  /// Parse a possibly empty, comma-separated list of SSA uses, optionally
+  /// wrapped in the requested delimiters, appending each one to `result`.
+  ///
+  /// If `requiredOperandCount` is not -1 and `result` does not hold exactly
+  /// that many entries afterwards, an error is emitted at the start of the
+  /// list and the call fails.
+  bool parseOperandList(SmallVectorImpl<OperandType> &result,
+                        int requiredOperandCount = -1,
+                        Delimeter delimeter = Delimeter::NoDelimeter) override {
+    auto startLoc = parser.getToken().getLoc();
+
+    // Handle the opening delimiter, if one was requested.
+    switch (delimeter) {
+    case Delimeter::NoDelimeter:
+      break;
+    case Delimeter::ParenDelimeter:
+      if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
+        return true;
+      break;
+    case Delimeter::SquareDelimeter:
+      if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
+        return true;
+      break;
+    }
+
+    // The list may be empty: only parse operands if one is actually present.
+    if (parser.getToken().is(Token::percent_identifier)) {
+      do {
+        OperandType operand;
+        if (parseOperand(operand))
+          return true;
+        result.push_back(operand);
+      } while (parser.consumeIf(Token::comma));
+    }
+
+    // Handle the closing delimiter, if one was requested.
+    switch (delimeter) {
+    case Delimeter::NoDelimeter:
+      break;
+    case Delimeter::ParenDelimeter:
+      if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
+        return true;
+      break;
+    case Delimeter::SquareDelimeter:
+      if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
+        return true;
+      break;
+    }
+
+    // Report a count mismatch as a failure.  Returning success here would let
+    // a caller that chains on the return value go on to index a list of the
+    // wrong length.
+    if (requiredOperandCount != -1 &&
+        result.size() != static_cast<size_t>(requiredOperandCount)) {
+      emitError(startLoc,
+                "expected " + Twine(requiredOperandCount) + " operands");
+      return true;
+    }
+    return false;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Methods for interacting with the parser
+  //===--------------------------------------------------------------------===//
+
+  /// Return the builder of the underlying parser.
+  Builder &getBuilder() const override { return parser.builder; }
+
+  /// Return the location of the operation name this parser was created for.
+  llvm::SMLoc getNameLoc() const override { return nameLoc; }
+
+  /// Resolve a previously parsed operand to an SSA value of the given type,
+  /// storing it in `result`.  Fails if the underlying parser returns null.
+  bool resolveOperand(OperandType operand, Type *type,
+                      SSAValue *&result) override {
+    FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
+                                              operand.location};
+    result = parser.resolveSSAUse(operandInfo, type);
+    return result == nullptr;
+  }
+
+  /// Emit a diagnostic at the specified location, prefixed with the name of
+  /// the custom op being parsed, and remember that an error was emitted.
+  void emitError(llvm::SMLoc loc, const Twine &message) override {
+    parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message);
+    emittedError = true;
+  }
+
+  /// Return true if emitError has been called on this object.
+  bool didEmitError() const { return emittedError; }
+
+private:
+  /// Location of the operation name.
+  SMLoc nameLoc;
+  /// Spelling of the operation name, used to prefix diagnostics.
+  StringRef opName;
+  /// The parser that requests are forwarded to.
+  FunctionParser &parser;
+  /// Set once emitError has been called.
+  bool emittedError = false;
+};
+} // end anonymous namespace.
+
+/// Parse an operation written in its custom form, where the current token is
+/// the bare operation name.  The name is looked up in the active OperationSet
+/// and the matching definition's parse hook is run; the operands, types and
+/// attributes it produces are passed to `createOpFunc`.  Returns null if the
+/// name is not found or if the hook reported an error through the
+/// CustomOpAsmParser.
+Operation *FunctionParser::parseCustomOperation(
+    const CreateOperationFunction &createOpFunc) {
+  // Capture the name's location and spelling before the token is consumed.
+  auto nameLoc = getToken().getLoc();
+  auto spelling = getTokenSpelling();
+  CustomOpAsmParser asmParser(nameLoc, spelling, *this);
+
+  // An operation that is not in the operation set cannot be parsed here.
+  auto *definition = getOperationSet().lookup(spelling);
+  if (definition == nullptr)
+    return (asmParser.emitError(nameLoc, "is unknown"), nullptr);
+
+  consumeToken();
+
+  // Have the op implementation take a crack at parsing this.
+  auto parsed = definition->parseAssembly(&asmParser);
+
+  // If the hook emitted an error, the parse failed.
+  if (asmParser.didEmitError())
+    return nullptr;
+
+  // Otherwise we succeeded: build the operation from the parsed state.
+  return createOpFunc(builder.getIdentifier(spelling), parsed.operands,
+                      parsed.types, parsed.attributes);
+}
+
//===----------------------------------------------------------------------===//
// CFG Functions
//===----------------------------------------------------------------------===//