Implement custom parser support for operations, enhance dim/addf to use it, and add a new load op.
This regresses parser error recovery in some cases (in invalid.mlir) which I'll
consider in a follow-up patch. The important thing in this patch is that the
parse methods in StandardOps.cpp are nice and simple.
PiperOrigin-RevId: 206023308
diff --git a/include/mlir/IR/OpDefinition.h b/include/mlir/IR/OpDefinition.h
index 52ad4ad..95d95dd 100644
--- a/include/mlir/IR/OpDefinition.h
+++ b/include/mlir/IR/OpDefinition.h
@@ -31,6 +31,7 @@
namespace mlir {
class Type;
+class OpAsmParser;
class OpAsmPrinter;
/// This pointer represents a notional "Operation*" but where the actual
@@ -68,6 +69,22 @@
const OpType value;
};
+/// This is the result type of parsing a custom operation. If an error is
+/// emitted, it is fine to return this in a partially mutated state.
+struct OpAsmParserResult {
+ SmallVector<SSAValue *, 4> operands;
+ SmallVector<Type *, 4> types;
+ SmallVector<NamedAttribute, 4> attributes;
+
+ /*implicit*/ OpAsmParserResult() {}
+
+ OpAsmParserResult(ArrayRef<SSAValue *> operands, ArrayRef<Type *> types,
+ ArrayRef<NamedAttribute> attributes = {})
+ : operands(operands.begin(), operands.end()),
+ types(types.begin(), types.end()),
+ attributes(attributes.begin(), attributes.end()) {}
+};
+
//===----------------------------------------------------------------------===//
// OpImpl Types
//===----------------------------------------------------------------------===//
@@ -108,6 +125,9 @@
/// back to this one which accepts everything.
const char *verify() const { return nullptr; }
+ // Unless overridden, the short form of an op is always rejected.
+ static OpAsmParserResult parse(OpAsmParser *parser);
+
// The fallback for the printer is to print it the longhand form.
void print(OpAsmPrinter *p) const;
@@ -138,6 +158,12 @@
return op->getName().is(ConcreteType::getOperationName());
}
+ /// This is the hook used by the AsmParser to parse the custom form of this
+ /// op from an .mlir file. Op implementations should provide a parse method.
+ static OpAsmParserResult parseAssembly(OpAsmParser *parser) {
+ return ConcreteType::parse(parser);
+ }
+
/// This is the hook used by the AsmPrinter to emit this to the .mlir file.
/// Op implementations should provide a print method.
static void printAssembly(const Operation *op, OpAsmPrinter *p) {
diff --git a/include/mlir/IR/OpImplementation.h b/include/mlir/IR/OpImplementation.h
index 8979d9e..1ef4846 100644
--- a/include/mlir/IR/OpImplementation.h
+++ b/include/mlir/IR/OpImplementation.h
@@ -23,10 +23,17 @@
#define MLIR_IR_OPIMPLEMENTATION_H
#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/SMLoc.h"
namespace mlir {
class AffineMap;
class AffineExpr;
+class Builder;
+
+//===----------------------------------------------------------------------===//
+// OpAsmPrinter
+//===----------------------------------------------------------------------===//
/// This is a pure-virtual base class that exposes the asmprinter hooks
/// necessary to implement a custom print() method.
@@ -38,6 +45,19 @@
/// Print implementations for various things an operation contains.
virtual void printOperand(const SSAValue *value) = 0;
+
+ /// Print a comma separated list of operands.
+ template <typename ContainerType>
+ void printOperands(const ContainerType &container) {
+ auto it = container.begin(), end = container.end();
+ if (it == end)
+ return;
+ printOperand(*it);
+ for (++it; it != end; ++it) {
+ getStream() << ", ";
+ printOperand(*it);
+ }
+ }
virtual void printType(const Type *type) = 0;
virtual void printAttribute(const Attribute *attr) = 0;
virtual void printAffineMap(const AffineMap *map) = 0;
@@ -78,6 +98,151 @@
return p;
}
+//===----------------------------------------------------------------------===//
+// OpAsmParser
+//===----------------------------------------------------------------------===//
+
+/// The OpAsmParser has methods for interacting with the asm parser: parsing
+/// things from it, emitting errors etc. It has an intentionally high-level API
+/// that is designed to reduce/constrain syntax innovation in individual
+/// operations.
+///
+/// For example, consider an op like this:
+///
+/// %x = load %p[%1, %2] : memref<...>
+///
+/// The "%x = load" tokens are already parsed and therefore invisible to the
+/// custom op parser. This can be supported by calling `parseOperandList` to
+/// parse the %p, then calling `parseOperandList` with a `SquareDelimeter` to
+/// parse the indices, then calling `parseColonTypeList` to parse the result
+/// type.
+///
+class OpAsmParser {
+public:
+ virtual ~OpAsmParser();
+
+ //===--------------------------------------------------------------------===//
+ // High level parsing methods.
+ //===--------------------------------------------------------------------===//
+
+ // These return void if they always succeed. If they can fail, they emit an
+ // error and return "true". On success, they can optionally provide location
+ // information for clients who want it.
+
+ /// This parses... a comma!
+ virtual bool parseComma(llvm::SMLoc *loc = nullptr) = 0;
+
+ /// Parse a colon followed by a type.
+ virtual bool parseColonType(Type *&result, llvm::SMLoc *loc = nullptr) = 0;
+
+ /// Parse a type of a specific kind, e.g. a FunctionType.
+ template <typename TypeType>
+ bool parseColonType(TypeType *&result, llvm::SMLoc *loc = nullptr) {
+ // Parse any kind of type.
+ Type *type;
+ llvm::SMLoc tmpLoc;
+ if (parseColonType(type, &tmpLoc))
+ return true;
+ if (loc)
+ *loc = tmpLoc;
+
+ // Check for the right kind of attribute.
+ result = dyn_cast<TypeType>(type);
+ if (!result) {
+ emitError(tmpLoc, "invalid kind of type specified");
+ return true;
+ }
+
+ return false;
+ }
+
+ /// Parse a colon followed by a type list, which must have at least one type.
+ virtual bool parseColonTypeList(SmallVectorImpl<Type *> &result,
+ llvm::SMLoc *loc = nullptr) = 0;
+
+ /// Parse an attribute.
+ virtual bool parseAttribute(Attribute *&result,
+ llvm::SMLoc *loc = nullptr) = 0;
+
+ /// Parse an attribute of a specific kind.
+ template <typename AttrType>
+ bool parseAttribute(AttrType *&result, llvm::SMLoc *loc = nullptr) {
+ // Parse any kind of attribute.
+ Attribute *attr;
+ llvm::SMLoc tmpLoc;
+ if (parseAttribute(attr, &tmpLoc))
+ return true;
+ if (loc)
+ *loc = tmpLoc;
+
+ // Check for the right kind of attribute.
+ result = dyn_cast<AttrType>(attr);
+ if (!result) {
+ emitError(tmpLoc, "invalid kind of constant specified");
+ return true;
+ }
+
+ return false;
+ }
+
+ /// This is the representation of an operand reference.
+ struct OperandType {
+ llvm::SMLoc location; // Location of the token.
+ StringRef name; // Value name, e.g. %42 or %abc
+ unsigned number; // Number, e.g. 12 for an operand like %xyz#12
+ };
+
+ /// Parse a single operand.
+ virtual bool parseOperand(OperandType &result) = 0;
+
+ /// These are the supported delimeters around operand lists, used by
+ /// parseOperandList.
+ enum Delimeter {
+ NoDelimeter,
+ ParenDelimeter,
+ SquareDelimeter,
+ };
+
+ /// Parse zero or more SSA comma-separated operand references with a specified
+ /// surrounding delimeter, and an optional required operand count.
+ virtual bool
+ parseOperandList(SmallVectorImpl<OperandType> &result,
+ int requiredOperandCount = -1,
+ Delimeter delimeter = Delimeter::NoDelimeter) = 0;
+
+ //===--------------------------------------------------------------------===//
+ // Methods for interacting with the parser
+ //===--------------------------------------------------------------------===//
+
+ /// Return a builder which provides useful access to MLIRContext, global
+ /// objects like types and attributes.
+ virtual Builder &getBuilder() const = 0;
+
+ /// Return the location of the original name token.
+ virtual llvm::SMLoc getNameLoc() const = 0;
+
+ /// Resolve an operand to an SSA value, emitting an error and returning true
+ /// on failure.
+ virtual bool resolveOperand(OperandType operand, Type *type,
+ SSAValue *&result) = 0;
+
+ /// Resolve a list of operands to SSA values, emitting an error and returning
+ /// true on failure, or appending the results to the list on success.
+ virtual bool resolveOperands(ArrayRef<OperandType> operand, Type *type,
+ SmallVectorImpl<SSAValue *> &result) {
+ for (auto elt : operand) {
+ SSAValue *value;
+ if (resolveOperand(elt, type, value))
+ return true;
+ result.push_back(value);
+ }
+ return false;
+ }
+
+ /// Emit a diagnostic at the specified location.
+ virtual void emitError(llvm::SMLoc loc, const Twine &message) = 0;
+};
+
} // end namespace mlir
#endif
diff --git a/include/mlir/IR/OperationSet.h b/include/mlir/IR/OperationSet.h
index 7f55517..59dbfb9 100644
--- a/include/mlir/IR/OperationSet.h
+++ b/include/mlir/IR/OperationSet.h
@@ -27,6 +27,8 @@
namespace mlir {
class Operation;
+class OpAsmParser;
+class OpAsmParserResult;
class OpAsmPrinter;
class MLIRContextImpl;
class MLIRContext;
@@ -40,7 +42,8 @@
template <typename T>
static AbstractOperation get() {
return AbstractOperation(T::getOperationName(), T::isClassFor,
- T::printAssembly, T::verifyInvariants);
+ T::parseAssembly, T::printAssembly,
+ T::verifyInvariants);
}
/// This is the name of the operation.
@@ -49,6 +52,9 @@
/// Return true if this "op class" can match against the specified operation.
bool (&isClassFor)(const Operation *op);
+ /// Use the specified object to parse this ops custom assembly format.
+ OpAsmParserResult (&parseAssembly)(OpAsmParser *parser);
+
/// This hook implements the AsmPrinter for this operation.
void (&printAssembly)(const Operation *op, OpAsmPrinter *p);
@@ -60,10 +66,11 @@
private:
AbstractOperation(StringRef name, bool (&isClassFor)(const Operation *op),
+ OpAsmParserResult (&parseAssembly)(OpAsmParser *parser),
void (&printAssembly)(const Operation *op, OpAsmPrinter *p),
const char *(&verifyInvariants)(const Operation *op))
- : name(name), isClassFor(isClassFor), printAssembly(printAssembly),
- verifyInvariants(verifyInvariants) {}
+ : name(name), isClassFor(isClassFor), parseAssembly(parseAssembly),
+ printAssembly(printAssembly), verifyInvariants(verifyInvariants) {}
};
/// An instance of OperationSet is owned and maintained by MLIRContext. It
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index 4b2f4fa..a3b4fcc 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -43,6 +43,7 @@
static StringRef getOperationName() { return "addf"; }
const char *verify() const;
+ static OpAsmParserResult parse(OpAsmParser *parser);
void print(OpAsmPrinter *p) const;
private:
@@ -109,6 +110,7 @@
// Hooks to customize behavior of this op.
const char *verify() const;
+ static OpAsmParserResult parse(OpAsmParser *parser);
void print(OpAsmPrinter *p) const;
private:
@@ -116,15 +118,49 @@
explicit DimOp(const Operation *state) : Base(state) {}
};
-// The "affine_apply" operation applies an affine map to a list of operands,
-// yielding a list of results. The operand and result list sizes must be the
-// same. All operands and results are of type 'AffineInt'. This operation
-// requires a single affine map attribute named "map".
-// For example:
-//
-// %y = "affine_apply" (%x) { map: (d0) -> (d0 + 1) } :
-// (affineint) -> (affineint)
-//
+/// The "load" op reads an element from a memref specified by an index list. The
+/// output of load is a new value with the same type as the elements of the
+/// memref. The arity of indices is the rank of the memref (i.e., if the memref
+/// loaded from is of rank 3, then 3 indices are required for the load following
+/// the memref identifier). For example:
+///
+/// %3 = load %0[%1, %1] : memref<4x4xi32>
+///
+class LoadOp
+ : public OpImpl::Base<LoadOp, OpImpl::VariadicOperands, OpImpl::OneResult> {
+public:
+ SSAValue *getMemRef() { return getOperand(0); }
+ const SSAValue *getMemRef() const { return getOperand(0); }
+
+ llvm::iterator_range<Operation::operand_iterator> getIndices() {
+ return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
+ }
+
+ llvm::iterator_range<Operation::const_operand_iterator> getIndices() const {
+ return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
+ }
+
+ static StringRef getOperationName() { return "load"; }
+
+ // Hooks to customize behavior of this op.
+ const char *verify() const;
+ static OpAsmParserResult parse(OpAsmParser *parser);
+ void print(OpAsmPrinter *p) const;
+
+private:
+ friend class Operation;
+ explicit LoadOp(const Operation *state) : Base(state) {}
+};
+
+/// The "affine_apply" operation applies an affine map to a list of operands,
+/// yielding a list of results. The operand and result list sizes must be the
+/// same. All operands and results are of type 'AffineInt'. This operation
+/// requires a single affine map attribute named "map".
+/// For example:
+///
+/// %y = "affine_apply" (%x) { map: (d0) -> (d0 + 1) } :
+/// (affineint) -> (affineint)
+///
class AffineApplyOp
: public OpImpl::Base<AffineApplyOp, OpImpl::VariadicOperands,
OpImpl::VariadicResults> {
diff --git a/lib/IR/OperationSet.cpp b/lib/IR/OperationSet.cpp
index c8e1e9b..43a65aa 100644
--- a/lib/IR/OperationSet.cpp
+++ b/lib/IR/OperationSet.cpp
@@ -19,10 +19,19 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using llvm::StringMap;
+OpAsmParser::~OpAsmParser() {}
+
+// The fallback for the printer is to reject the short form.
+OpAsmParserResult OpImpl::BaseState::parse(OpAsmParser *parser) {
+ parser->emitError(parser->getNameLoc(), "has no concise form");
+ return {};
+}
+
// The fallback for the printer is to print it the longhand form.
void OpImpl::BaseState::print(OpAsmPrinter *p) const {
p->printDefaultOp(getOperation());
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 4da0d60..069f3ed 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/SSAValue.h"
@@ -26,6 +27,18 @@
// TODO: Have verify functions return std::string to enable more descriptive
// error messages.
+OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
+ SmallVector<OpAsmParser::OperandType, 2> ops;
+ Type *type;
+ SSAValue *lhs, *rhs;
+ if (parser->parseOperandList(ops, 2) || parser->parseColonType(type) ||
+ parser->resolveOperand(ops[0], type, lhs) ||
+ parser->resolveOperand(ops[1], type, rhs))
+ return {};
+
+ return OpAsmParserResult({lhs, rhs}, type);
+}
+
void AddFOp::print(OpAsmPrinter *p) const {
*p << "addf " << *getOperand(0) << ", " << *getOperand(1) << " : "
<< *getType();
@@ -71,6 +84,22 @@
<< *getOperand()->getType();
}
+OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
+ OpAsmParser::OperandType operandInfo;
+ IntegerAttr *indexAttr;
+ Type *type;
+ SSAValue *operand;
+ if (parser->parseOperand(operandInfo) || parser->parseComma() ||
+ parser->parseAttribute(indexAttr) || parser->parseColonType(type) ||
+ parser->resolveOperand(operandInfo, type, operand))
+ return {};
+
+ auto &builder = parser->getBuilder();
+ return OpAsmParserResult(
+ operand, builder.getAffineIntType(),
+ NamedAttribute(builder.getIdentifier("index"), indexAttr));
+}
+
const char *DimOp::verify() const {
// Check that we have an integer index operand.
auto indexAttr = getAttrOfType<IntegerAttr>("index");
@@ -95,6 +124,35 @@
return nullptr;
}
+void LoadOp::print(OpAsmPrinter *p) const {
+ *p << "load " << *getMemRef() << '[';
+ p->printOperands(getIndices());
+ *p << "] : " << *getMemRef()->getType();
+}
+
+OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
+ OpAsmParser::OperandType memrefInfo;
+ SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+ MemRefType *type;
+ SmallVector<SSAValue *, 4> operands;
+
+ auto affineIntTy = parser->getBuilder().getAffineIntType();
+ if (parser->parseOperand(memrefInfo) ||
+ parser->parseOperandList(indexInfo, -1,
+ OpAsmParser::Delimeter::SquareDelimeter) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperands(memrefInfo, type, operands) ||
+ parser->resolveOperands(indexInfo, affineIntTy, operands))
+ return {};
+
+ return OpAsmParserResult(operands, type->getElementType());
+}
+
+const char *LoadOp::verify() const {
+ // TODO: Check load
+ return nullptr;
+}
+
void AffineApplyOp::print(OpAsmPrinter *p) const {
// TODO: Print operands etc.
*p << "affine_apply map: " << *getAffineMap();
@@ -122,5 +180,6 @@
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
- opSet.addOperations<AddFOp, ConstantOp, DimOp, AffineApplyOp>(/*prefix=*/"");
+ opSet.addOperations<AddFOp, ConstantOp, DimOp, LoadOp, AffineApplyOp>(
+ /*prefix=*/"");
}
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 91596d3..011dfcb 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -77,8 +77,10 @@
case ')': return formToken(Token::r_paren, tokStart);
case '{': return formToken(Token::l_brace, tokStart);
case '}': return formToken(Token::r_brace, tokStart);
- case '[': return formToken(Token::l_bracket, tokStart);
- case ']': return formToken(Token::r_bracket, tokStart);
+ case '[':
+ return formToken(Token::l_square, tokStart);
+ case ']':
+ return formToken(Token::r_square, tokStart);
case '<': return formToken(Token::less, tokStart);
case '>': return formToken(Token::greater, tokStart);
case '=': return formToken(Token::equal, tokStart);
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 3c655d1..63ebeff 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Types.h"
@@ -52,7 +53,8 @@
SMDiagnosticHandlerTy errorReporter)
: context(module->getContext()), module(module),
lex(sourceMgr, errorReporter), curToken(lex.lexToken()),
- errorReporter(errorReporter) {}
+ errorReporter(errorReporter), operationSet(OperationSet::get(context)) {
+ }
// A map from affine map identifier to AffineMap.
llvm::StringMap<AffineMap *> affineMapDefinitions;
@@ -77,6 +79,9 @@
// The diagnostic error reporter.
SMDiagnosticHandlerTy const errorReporter;
+
+ // The active OperationSet we're parsing with.
+ OperationSet &operationSet;
};
} // end anonymous namespace
@@ -99,6 +104,7 @@
ParserState &getState() const { return state; }
MLIRContext *getContext() const { return state.context; }
Module *getModule() { return state.module; }
+ OperationSet &getOperationSet() const { return state.operationSet; }
/// Return the current token the parser is inspecting.
const Token &getToken() const { return state.curToken; }
@@ -611,8 +617,8 @@
return builder.getStringAttr(val);
}
- case Token::l_bracket: {
- consumeToken(Token::l_bracket);
+ case Token::l_square: {
+ consumeToken(Token::l_square);
SmallVector<Attribute *, 4> elements;
auto parseElt = [&]() -> ParseResult {
@@ -620,7 +626,7 @@
return elements.back() ? ParseSuccess : ParseFailure;
};
- if (parseCommaSeparatedListUntil(Token::r_bracket, parseElt))
+ if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
return nullptr;
return builder.getArrayAttr(elements);
}
@@ -1101,11 +1107,11 @@
/// Parse the list of symbolic identifiers to an affine map.
ParseResult AffineMapParser::parseSymbolIdList() {
- if (parseToken(Token::l_bracket, "expected '['"))
+ if (parseToken(Token::l_square, "expected '['"))
return ParseFailure;
auto parseElt = [&]() -> ParseResult { return parseDimOrSymbolId(false); };
- return parseCommaSeparatedListUntil(Token::r_bracket, parseElt);
+ return parseCommaSeparatedListUntil(Token::r_square, parseElt);
}
/// Parse the list of dimensional identifiers to an affine map.
@@ -1131,7 +1137,7 @@
return nullptr;
// Symbols are optional.
- if (getToken().is(Token::l_bracket)) {
+ if (getToken().is(Token::l_square)) {
if (parseSymbolIdList())
return nullptr;
}
@@ -1260,6 +1266,8 @@
// Operations
ParseResult parseOperation(const CreateOperationFunction &createOpFunc);
+ Operation *parseVerboseOperation(const CreateOperationFunction &createOpFunc);
+ Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc);
private:
/// This keeps track of all of the SSA values we are tracking, indexed by
@@ -1416,7 +1424,7 @@
///
ParseResult
FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
- if (!getToken().is(Token::percent_identifier))
+ if (getToken().isNot(Token::percent_identifier))
return ParseSuccess;
return parseCommaSeparatedList([&]() -> ParseResult {
SSAUseInfo result;
@@ -1514,60 +1522,15 @@
return ParseFailure;
}
- if (getToken().isNot(Token::string))
+ Operation *op;
+ if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
+ op = parseCustomOperation(createOpFunc);
+ else if (getToken().is(Token::string))
+ op = parseVerboseOperation(createOpFunc);
+ else
return emitError("expected operation name in quotes");
- auto name = getToken().getStringValue();
- if (name.empty())
- return emitError("empty operation name is invalid");
-
- consumeToken(Token::string);
-
- // Parse the operand list.
- SmallVector<SSAUseInfo, 8> operandInfos;
-
- if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
- parseOptionalSSAUseList(operandInfos) ||
- parseToken(Token::r_paren, "expected ')' to end operand list")) {
- return ParseFailure;
- }
-
- SmallVector<NamedAttribute, 4> attributes;
- if (getToken().is(Token::l_brace)) {
- if (parseAttributeDict(attributes))
- return ParseFailure;
- }
-
- if (parseToken(Token::colon, "expected ':' followed by instruction type"))
- return ParseFailure;
-
- auto typeLoc = getToken().getLoc();
- auto type = parseType();
- if (!type)
- return ParseFailure;
- auto fnType = dyn_cast<FunctionType>(type);
- if (!fnType)
- return emitError(typeLoc, "expected function type");
-
- // Check that we have the right number of types for the operands.
- auto operandTypes = fnType->getInputs();
- if (operandTypes.size() != operandInfos.size()) {
- auto plural = "s"[operandInfos.size() == 1];
- return emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
- " operand type" + plural + " but had " +
- llvm::utostr(operandTypes.size()));
- }
-
- // Resolve all of the operands.
- SmallVector<SSAValue *, 8> operands;
- for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
- operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
- if (!operands.back())
- return ParseFailure;
- }
-
- auto nameId = builder.getIdentifier(name);
- auto op = createOpFunc(nameId, operands, fnType->getResults(), attributes);
+ // If parsing of the basic operation failed, then this whole thing fails.
if (!op)
return ParseFailure;
@@ -1595,6 +1558,228 @@
return ParseSuccess;
}
+Operation *FunctionParser::parseVerboseOperation(
+ const CreateOperationFunction &createOpFunc) {
+ auto name = getToken().getStringValue();
+ if (name.empty())
+ return (emitError("empty operation name is invalid"), nullptr);
+
+ consumeToken(Token::string);
+
+ // Parse the operand list.
+ SmallVector<SSAUseInfo, 8> operandInfos;
+
+ if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
+ parseOptionalSSAUseList(operandInfos) ||
+ parseToken(Token::r_paren, "expected ')' to end operand list")) {
+ return nullptr;
+ }
+
+ SmallVector<NamedAttribute, 4> attributes;
+ if (getToken().is(Token::l_brace)) {
+ if (parseAttributeDict(attributes))
+ return nullptr;
+ }
+
+ if (parseToken(Token::colon, "expected ':' followed by instruction type"))
+ return nullptr;
+
+ auto typeLoc = getToken().getLoc();
+ auto type = parseType();
+ if (!type)
+ return nullptr;
+ auto fnType = dyn_cast<FunctionType>(type);
+ if (!fnType)
+ return (emitError(typeLoc, "expected function type"), nullptr);
+
+ // Check that we have the right number of types for the operands.
+ auto operandTypes = fnType->getInputs();
+ if (operandTypes.size() != operandInfos.size()) {
+ auto plural = "s"[operandInfos.size() == 1];
+ return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
+ " operand type" + plural + " but had " +
+ llvm::utostr(operandTypes.size())),
+ nullptr);
+ }
+
+ // Resolve all of the operands.
+ SmallVector<SSAValue *, 8> operands;
+ for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
+ operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
+ if (!operands.back())
+ return nullptr;
+ }
+
+ auto nameId = builder.getIdentifier(name);
+ return createOpFunc(nameId, operands, fnType->getResults(), attributes);
+}
+
+namespace {
+class CustomOpAsmParser : public OpAsmParser {
+public:
+ CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser)
+ : nameLoc(nameLoc), opName(opName), parser(parser) {}
+
+ /// This is an internal helper to parser a colon, we don't want to expose
+ /// this to clients.
+ bool internalParseColon(llvm::SMLoc *loc) {
+ if (loc)
+ *loc = parser.getToken().getLoc();
+ return parser.parseToken(Token::colon, "expected ':'");
+ }
+
+ //===--------------------------------------------------------------------===//
+ // High level parsing methods.
+ //===--------------------------------------------------------------------===//
+
+ bool parseComma(llvm::SMLoc *loc = nullptr) override {
+ if (loc)
+ *loc = parser.getToken().getLoc();
+ return parser.parseToken(Token::comma, "expected ','");
+ }
+
+ bool parseColonType(Type *&result, llvm::SMLoc *loc = nullptr) override {
+ return internalParseColon(loc) || !(result = parser.parseType());
+ }
+
+ bool parseColonTypeList(SmallVectorImpl<Type *> &result,
+ llvm::SMLoc *loc = nullptr) override {
+ if (internalParseColon(loc))
+ return true;
+
+ do {
+ if (auto *type = parser.parseType())
+ result.push_back(type);
+ else
+ return true;
+
+ } while (parser.consumeIf(Token::comma));
+ return false;
+ }
+
+ bool parseAttribute(Attribute *&result, llvm::SMLoc *loc = nullptr) override {
+ if (loc)
+ *loc = parser.getToken().getLoc();
+ result = parser.parseAttribute();
+ return result == nullptr;
+ }
+
+ bool parseOperand(OperandType &result) override {
+ FunctionParser::SSAUseInfo useInfo;
+ if (parser.parseSSAUse(useInfo))
+ return true;
+
+ result = {useInfo.loc, useInfo.name, useInfo.number};
+ return false;
+ }
+
+ bool parseOperandList(SmallVectorImpl<OperandType> &result,
+ int requiredOperandCount = -1,
+ Delimeter delimeter = Delimeter::NoDelimeter) override {
+ auto startLoc = parser.getToken().getLoc();
+
+ // Handle delimeters.
+ switch (delimeter) {
+ case Delimeter::NoDelimeter:
+ break;
+ case Delimeter::ParenDelimeter:
+ if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
+ return true;
+ break;
+ case Delimeter::SquareDelimeter:
+ if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
+ return true;
+ break;
+ }
+
+ // Check for zero operands.
+ if (parser.getToken().is(Token::percent_identifier)) {
+ do {
+ OperandType operand;
+ if (parseOperand(operand))
+ return true;
+ result.push_back(operand);
+ } while (parser.consumeIf(Token::comma));
+ }
+
+ // Handle delimeters.
+ switch (delimeter) {
+ case Delimeter::NoDelimeter:
+ break;
+ case Delimeter::ParenDelimeter:
+ if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
+ return true;
+ break;
+ case Delimeter::SquareDelimeter:
+ if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
+ return true;
+ break;
+ }
+
+ if (requiredOperandCount != -1 && result.size() != requiredOperandCount)
+ emitError(startLoc,
+ "expected " + Twine(requiredOperandCount) + " operands");
+ return false;
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Methods for interacting with the parser
+ //===--------------------------------------------------------------------===//
+
+ Builder &getBuilder() const override { return parser.builder; }
+
+ llvm::SMLoc getNameLoc() const override { return nameLoc; }
+
+ bool resolveOperand(OperandType operand, Type *type,
+ SSAValue *&result) override {
+ FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
+ operand.location};
+ result = parser.resolveSSAUse(operandInfo, type);
+ return result == nullptr;
+ }
+
+ /// Emit a diagnostic at the specified location.
+ void emitError(llvm::SMLoc loc, const Twine &message) override {
+ parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message);
+ emittedError = true;
+ }
+
+ bool didEmitError() const { return emittedError; }
+
+private:
+ SMLoc nameLoc;
+ StringRef opName;
+ FunctionParser &parser;
+ bool emittedError = false;
+};
+} // end anonymous namespace.
+
+Operation *FunctionParser::parseCustomOperation(
+ const CreateOperationFunction &createOpFunc) {
+ auto opLoc = getToken().getLoc();
+ auto opName = getTokenSpelling();
+ CustomOpAsmParser opAsmParser(opLoc, opName, *this);
+
+ auto *opDefinition = getOperationSet().lookup(opName);
+ if (!opDefinition) {
+ opAsmParser.emitError(opLoc, "is unknown");
+ return nullptr;
+ }
+
+ consumeToken();
+
+ // Have the op implementation take a crack and parsing this.
+ auto result = opDefinition->parseAssembly(&opAsmParser);
+
+ // If it emitted an error, we failed.
+ if (opAsmParser.didEmitError())
+ return nullptr;
+
+ // Otherwise, we succeeded. Use the state it parsed as our op information.
+ auto nameId = builder.getIdentifier(opName);
+ return createOpFunc(nameId, result.operands, result.types, result.attributes);
+}
+
//===----------------------------------------------------------------------===//
// CFG Functions
//===----------------------------------------------------------------------===//
diff --git a/lib/Parser/TokenKinds.def b/lib/Parser/TokenKinds.def
index b9ef9b05..44a40ee 100644
--- a/lib/Parser/TokenKinds.def
+++ b/lib/Parser/TokenKinds.def
@@ -71,8 +71,8 @@
TOK_PUNCTUATION(r_paren, ")")
TOK_PUNCTUATION(l_brace, "{")
TOK_PUNCTUATION(r_brace, "}")
-TOK_PUNCTUATION(l_bracket, "[")
-TOK_PUNCTUATION(r_bracket, "]")
+TOK_PUNCTUATION(l_square, "[")
+TOK_PUNCTUATION(r_square, "]")
TOK_PUNCTUATION(less, "<")
TOK_PUNCTUATION(greater, ">")
TOK_PUNCTUATION(equal, "=")
diff --git a/test/IR/core-ops.mlir b/test/IR/core-ops.mlir
index ab82e4b..5618239 100644
--- a/test/IR/core-ops.mlir
+++ b/test/IR/core-ops.mlir
@@ -2,6 +2,7 @@
// CHECK: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), (d1 + 2))
#map5 = (d0, d1) -> (d0 + 1, d1 + 2)
+#id2 = (i,j)->(i,j)
// CHECK-LABEL: cfgfunc @cfgfunc_with_ops(f32) {
cfgfunc @cfgfunc_with_ops(f32) {
@@ -19,22 +20,23 @@
return
}
-// CHECK-LABEL: cfgfunc @standard_instrs() {
-cfgfunc @standard_instrs() {
-bb42: // CHECK: bb0:
- // CHECK: %0 = "getTensor"() : () -> tensor<4x4x?xf32>
- %42 = "getTensor"() : () -> tensor<4x4x?xf32>
+// CHECK-LABEL: cfgfunc @standard_instrs(tensor<4x4x?xf32>, f32) {
+cfgfunc @standard_instrs(tensor<4x4x?xf32>, f32) {
+// CHECK: bb0(%0: tensor<4x4x?xf32>, %1: f32):
+bb42(%t: tensor<4x4x?xf32>, %f: f32):
+ // CHECK: %2 = dim %0, 2 : tensor<4x4x?xf32>
+ %a = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> affineint
- // CHECK: dim %0, 2 : tensor<4x4x?xf32>
- %a = "dim"(%42){index: 2} : (tensor<4x4x?xf32>) -> affineint
+ // CHECK: %3 = dim %0, 2 : tensor<4x4x?xf32>
+ %a2 = dim %t, 2 : tensor<4x4x?xf32>
- // FIXME: Add support for fp attributes so this can use 'constant'.
- %f = "FIXMEConst"(){value: 1} : () -> f32
+ // CHECK: %4 = addf %1, %1 : f32
+ %f2 = "addf"(%f, %f) : (f32,f32) -> f32
- // CHECK: %3 = addf %2, %2 : f32
- "addf"(%f, %f) : (f32,f32) -> f32
+ // CHECK: %5 = addf %4, %4 : f32
+ %f3 = addf %f2, %f2 : f32
- // CHECK: %4 = "constant"(){value: 42} : () -> i32
+ // CHECK: %6 = "constant"(){value: 42} : () -> i32
%x = "constant"(){value: 42} : () -> i32
return
}
@@ -53,4 +55,18 @@
%y = "affine_apply" (%i, %j) { map: #map5 } :
(affineint, affineint) -> (affineint, affineint)
return
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: cfgfunc @load_store
+cfgfunc @load_store(memref<4x4xi32, #id2, 0>, affineint) {
+bb0(%0: memref<4x4xi32, #id2, 0>, %1: affineint):
+
+ // CHECK: %2 = load %0[%1, %1] : memref<4x4xi32, #map2, 0>
+ %2 = "load"(%0, %1, %1) : (memref<4x4xi32, #id2, 0>, affineint, affineint)->i32
+
+ // CHECK: %3 = load %0[%1, %1] : memref<4x4xi32, #map2, 0>
+ %3 = load %0[%1, %1] : memref<4x4xi32, #id2, 0>
+
+ return
+}
+
diff --git a/test/IR/invalid-ops.mlir b/test/IR/invalid-ops.mlir
index 4aafb21..2991d63 100644
--- a/test/IR/invalid-ops.mlir
+++ b/test/IR/invalid-ops.mlir
@@ -60,3 +60,11 @@
%x = "affine_apply" (%i, %j) {map: (d0, d1) -> ((d0 + 1), (d1 + 2))} : (affineint,affineint) -> (affineint) // expected-error {{'affine_apply' op result count and affine map result count must match}}
return
}
+
+// -----
+
+cfgfunc @unknown_custom_op() {
+bb0:
+ %i = crazyThing() {value: 0} : () -> affineint // expected-error {{custom op 'crazyThing' is unknown}}
+ return
+}
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index c82f0ec..2e87540 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -85,7 +85,7 @@
bb40:
return
bb41:
-bb42: // expected-error {{expected operation name}}
+bb42: // expected-error {{custom op 'bb42' is unknown}}
return
}
@@ -170,7 +170,7 @@
// -----
mlfunc @non_statement() {
- asd // expected-error {{expected operation name in quotes}}
+ asd // expected-error {{custom op 'asd' is unknown}}
}
// -----