Refactor the asmparser hook to work with a new OperationState type that fully
encapsulates an operation that is yet to be created. This is a patch towards
custom ops providing create methods that don't need to be templated, allowing
them to move out of line in the future.
PiperOrigin-RevId: 207725557
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 59aca69..80b30c6 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -155,7 +155,8 @@
// Add new basic block and set the insertion point to the end of it.
BasicBlock *createBlock();
- // Create an operation at the current insertion point.
+ // TODO(clattner): remove this.
+ /// Create an operation at the current insertion point.
OperationInst *createOperation(Identifier name, ArrayRef<CFGValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<NamedAttribute> attributes) {
@@ -165,18 +166,23 @@
return op;
}
+ /// Create an operation given the fields represented as an OperationState.
+ OperationInst *createOperation(const OperationState &state);
+
+ // TODO(clattner): rework build to return an OperationState so the
+ // implementations can moved out of line.
+ /// Create operation of specific op type at the current insertion point.
+ template <typename OpTy, typename... Args>
+ OpPointer<OpTy> create(Args... args) {
+ return OpTy::build(this, args...);
+ }
+
OperationInst *cloneOperation(const OperationInst &srcOpInst) {
auto *op = srcOpInst.clone();
block->getOperations().insert(insertPoint, op);
return op;
}
- // Create operation of specific op type at the current insertion point.
- template <typename OpTy, typename... Args>
- OpPointer<OpTy> create(Args... args) {
- return OpTy::build(this, args...);
- }
-
// Terminators.
ReturnInst *createReturnInst(ArrayRef<CFGValue *> operands) {
@@ -256,6 +262,7 @@
/// Get the current insertion point of the builder.
StmtBlock::iterator getInsertionPoint() const { return insertPoint; }
+ // TODO(clattner): remove this.
OperationStmt *createOperation(Identifier name, ArrayRef<MLValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<NamedAttribute> attributes) {
@@ -265,6 +272,11 @@
return op;
}
+ /// Create an operation given the fields represented as an OperationState.
+ OperationStmt *createOperation(const OperationState &state);
+
+ // TODO(clattner): rework build to return an OperationState so the
+ // implementations can moved out of line.
// Create operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Args... args) {
diff --git a/include/mlir/IR/OpDefinition.h b/include/mlir/IR/OpDefinition.h
index b2a95fe..5fc9de1 100644
--- a/include/mlir/IR/OpDefinition.h
+++ b/include/mlir/IR/OpDefinition.h
@@ -71,23 +71,6 @@
const OpType value;
};
-/// This is the result type of parsing a custom operation. If an error is
-/// emitted, it is fine to return this in a partially mutated state.
-struct OpAsmParserResult {
- SmallVector<SSAValue *, 4> operands;
- SmallVector<Type *, 4> types;
- SmallVector<NamedAttribute, 4> attributes;
-
- /*implicit*/ OpAsmParserResult() {}
-
- OpAsmParserResult(ArrayRef<SSAValue *> operands, ArrayRef<Type *> types,
- ArrayRef<NamedAttribute> attributes = {})
- : operands(operands.begin(), operands.end()),
- types(types.begin(), types.end()),
- attributes(attributes.begin(), attributes.end()) {}
-};
-
-
/// This is the concrete base class that holds the operation pointer and has
/// non-generic methods that only depend on State (to avoid having them
/// instantiated on template types that don't affect them.
@@ -125,8 +108,11 @@
/// back to this one which accepts everything.
const char *verify() const { return nullptr; }
- // Unless overridden, the short form of an op is always rejected.
- static OpAsmParserResult parse(OpAsmParser *parser);
+ /// Unless overridden, the short form of an op is always rejected. Op
+ /// implementations should implement this to return boolean true on failure.
+ /// On success, they should return false and fill in result with the fields to
+ /// use.
+ static bool parse(OpAsmParser *parser, OperationState *result);
// The fallback for the printer is to print it the longhand form.
void print(OpAsmPrinter *p) const;
@@ -159,9 +145,11 @@
}
/// This is the hook used by the AsmParser to parse the custom form of this
- /// op from an .mlir file. Op implementations should provide a parse method.
- static OpAsmParserResult parseAssembly(OpAsmParser *parser) {
- return ConcreteType::parse(parser);
+ /// op from an .mlir file. Op implementations should provide a parse method,
+ /// which returns boolean true on failure. On success, they should return
+ /// false and fill in result with the fields to use.
+ static bool parseAssembly(OpAsmParser *parser, OperationState *result) {
+ return ConcreteType::parse(parser, result);
}
/// This is the hook used by the AsmPrinter to emit this to the .mlir file.
diff --git a/include/mlir/IR/OpImplementation.h b/include/mlir/IR/OpImplementation.h
index 62b8e4f..65b25ae 100644
--- a/include/mlir/IR/OpImplementation.h
+++ b/include/mlir/IR/OpImplementation.h
@@ -278,8 +278,8 @@
return false;
}
- /// Emit a diagnostic at the specified location.
- virtual void emitError(llvm::SMLoc loc, const Twine &message) = 0;
+ /// Emit a diagnostic at the specified location and return true.
+ virtual bool emitError(llvm::SMLoc loc, const Twine &message) = 0;
};
} // end namespace mlir
diff --git a/include/mlir/IR/Operation.h b/include/mlir/IR/Operation.h
index 6f78232..51f7ee1 100644
--- a/include/mlir/IR/Operation.h
+++ b/include/mlir/IR/Operation.h
@@ -32,12 +32,34 @@
template <typename ObjectType, typename ElementType> class OperandIterator;
template <typename ObjectType, typename ElementType> class ResultIterator;
class SSAValue;
+class Type;
/// NamedAttribute is a used for operation attribute lists, it holds an
/// identifier for the name and a value for the attribute. The attribute
/// pointer should always be non-null.
typedef std::pair<Identifier, Attribute*> NamedAttribute;
+/// This represents an operation in an abstracted form, suitable for use with
+/// the builder APIs. This object is a large and heavy weight object meant to
+/// be used as a temporary object on the stack. It is generally unwise to put
+/// this in a collection.
+struct OperationState {
+ Identifier name;
+ SmallVector<SSAValue *, 4> operands;
+ SmallVector<Type *, 4> types;
+ SmallVector<NamedAttribute, 4> attributes;
+
+public:
+ OperationState(Identifier name) : name(name) {}
+
+ OperationState(Identifier name, ArrayRef<SSAValue *> operands,
+ ArrayRef<Type *> types,
+ ArrayRef<NamedAttribute> attributes = {})
+ : name(name), operands(operands.begin(), operands.end()),
+ types(types.begin(), types.end()),
+ attributes(attributes.begin(), attributes.end()) {}
+};
+
/// Operations represent all of the arithmetic and other basic computation in
/// MLIR. This class is the common implementation details behind OperationInst
/// and OperationStmt.
diff --git a/include/mlir/IR/OperationSet.h b/include/mlir/IR/OperationSet.h
index 59dbfb9..1fea466 100644
--- a/include/mlir/IR/OperationSet.h
+++ b/include/mlir/IR/OperationSet.h
@@ -27,6 +27,7 @@
namespace mlir {
class Operation;
+class OperationState;
class OpAsmParser;
class OpAsmParserResult;
class OpAsmPrinter;
@@ -53,7 +54,7 @@
bool (&isClassFor)(const Operation *op);
/// Use the specified object to parse this ops custom assembly format.
- OpAsmParserResult (&parseAssembly)(OpAsmParser *parser);
+ bool (&parseAssembly)(OpAsmParser *parser, OperationState *result);
/// This hook implements the AsmPrinter for this operation.
void (&printAssembly)(const Operation *op, OpAsmPrinter *p);
@@ -66,7 +67,8 @@
private:
AbstractOperation(StringRef name, bool (&isClassFor)(const Operation *op),
- OpAsmParserResult (&parseAssembly)(OpAsmParser *parser),
+ bool (&parseAssembly)(OpAsmParser *parser,
+ OperationState *result),
void (&printAssembly)(const Operation *op, OpAsmPrinter *p),
const char *(&verifyInvariants)(const Operation *op))
: name(name), isClassFor(isClassFor), parseAssembly(parseAssembly),
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index f8701bf..ac5dbad 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -51,7 +51,7 @@
}
const char *verify() const;
- static OpAsmParserResult parse(OpAsmParser *parser);
+ static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
private:
@@ -85,7 +85,7 @@
static StringRef getOperationName() { return "affine_apply"; }
// Hooks to customize behavior of this op.
- static OpAsmParserResult parse(OpAsmParser *parser);
+ static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
const char *verify() const;
@@ -124,7 +124,7 @@
// Hooks to customize behavior of this op.
const char *verify() const;
- static OpAsmParserResult parse(OpAsmParser *parser);
+ static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
private:
@@ -148,7 +148,7 @@
static StringRef getOperationName() { return "constant"; }
// Hooks to customize behavior of this op.
- static OpAsmParserResult parse(OpAsmParser *parser);
+ static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
const char *verify() const;
@@ -235,7 +235,7 @@
// Hooks to customize behavior of this op.
const char *verify() const;
- static OpAsmParserResult parse(OpAsmParser *parser);
+ static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
private:
@@ -269,7 +269,7 @@
// Hooks to customize behavior of this op.
const char *verify() const;
- static OpAsmParserResult parse(OpAsmParser *parser);
+ static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
private:
@@ -308,7 +308,7 @@
// Hooks to customize behavior of this op.
const char *verify() const;
- static OpAsmParserResult parse(OpAsmParser *parser);
+ static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
private:
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 715e460..4424bd3 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -158,10 +158,36 @@
return b;
}
+/// Create an operation given the fields represented as an OperationState.
+OperationInst *CFGFuncBuilder::createOperation(const OperationState &state) {
+ SmallVector<CFGValue *, 8> operands;
+ operands.reserve(state.operands.size());
+ for (auto elt : state.operands)
+ operands.push_back(cast<CFGValue>(elt));
+
+ auto *op = OperationInst::create(state.name, operands, state.types,
+ state.attributes, context);
+ block->getOperations().insert(insertPoint, op);
+ return op;
+}
+
//===----------------------------------------------------------------------===//
// Statements.
//===----------------------------------------------------------------------===//
+/// Create an operation given the fields represented as an OperationState.
+OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
+ SmallVector<MLValue *, 8> operands;
+ operands.reserve(state.operands.size());
+ for (auto elt : state.operands)
+ operands.push_back(cast<MLValue>(elt));
+
+ auto *op = OperationStmt::create(state.name, operands, state.types,
+ state.attributes, context);
+ block->getStatements().insert(insertPoint, op);
+ return op;
+}
+
ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
AffineConstantExpr *step) {
diff --git a/lib/IR/OperationSet.cpp b/lib/IR/OperationSet.cpp
index 9bbc84e..98bf5fe 100644
--- a/lib/IR/OperationSet.cpp
+++ b/lib/IR/OperationSet.cpp
@@ -27,9 +27,8 @@
OpAsmParser::~OpAsmParser() {}
// The fallback for the printer is to reject the short form.
-OpAsmParserResult OpBaseState::parse(OpAsmParser *parser) {
- parser->emitError(parser->getNameLoc(), "has no concise form");
- return {};
+bool OpBaseState::parse(OpAsmParser *parser, OperationState *result) {
+ return parser->emitError(parser->getNameLoc(), "has no concise form");
}
// The fallback for the printer is to print it the longhand form.
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 1458ce4..b67c12d 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -44,8 +44,8 @@
// Returns 'false' on success and 'true' on error.
static bool
parseDimAndSymbolList(OpAsmParser *parser,
- SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
+ SmallVector<OpAsmParser::OperandType, 8> opInfos;
if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
return true;
// Store number of dimensions for validation by caller.
@@ -60,21 +60,18 @@
return false;
}
-// TODO: Have verify functions return std::string to enable more descriptive
-// error messages.
-OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
+bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type *type;
- SSAValue *lhs, *rhs;
- SmallVector<NamedAttribute, 4> attrs;
if (parser->parseOperandList(ops, 2) ||
- parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
- parser->resolveOperand(ops[0], type, lhs) ||
- parser->resolveOperand(ops[1], type, rhs))
- return {};
+ parser->resolveOperands(ops, type, result->operands))
+ return true;
- return OpAsmParserResult({lhs, rhs}, type, attrs);
+ // TODO(clattner): rework parseColonType to eliminate the need for this.
+ result->types.push_back(type);
+ return false;
}
void AddFOp::print(OpAsmPrinter *p) const {
@@ -83,6 +80,8 @@
*p << " : " << *getType();
}
+// TODO: Have verify functions return std::string to enable more descriptive
+// error messages.
// Return an error message on failure.
const char *AddFOp::verify() const {
// TODO: Check that the types of the LHS and RHS match.
@@ -91,31 +90,26 @@
return nullptr;
}
-OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
- SmallVector<OpAsmParser::OperandType, 2> opInfos;
- SmallVector<SSAValue *, 4> operands;
- SmallVector<NamedAttribute, 4> attrs;
-
+bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
auto *affineIntTy = builder.getAffineIntType();
AffineMapAttr *mapAttr;
unsigned numDims;
- if (parser->parseAttribute(mapAttr, "map", attrs) ||
- parseDimAndSymbolList(parser, opInfos, operands, numDims) ||
- parser->parseOptionalAttributeDict(attrs))
- return {};
+ if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
+ parseDimAndSymbolList(parser, result->operands, numDims) ||
+ parser->parseOptionalAttributeDict(result->attributes))
+ return true;
auto *map = mapAttr->getValue();
if (map->getNumDims() != numDims ||
- numDims + map->getNumSymbols() != opInfos.size()) {
- parser->emitError(parser->getNameLoc(),
- "dimension or symbol index mismatch");
- return {};
+ numDims + map->getNumSymbols() != result->operands.size()) {
+ return parser->emitError(parser->getNameLoc(),
+ "dimension or symbol index mismatch");
}
- SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
- return OpAsmParserResult(operands, resultTypes, attrs);
+ result->types.append(map->getNumResults(), affineIntTy);
+ return false;
}
void AffineApplyOp::print(OpAsmPrinter *p) const {
@@ -155,39 +149,37 @@
*p << " : " << *type;
}
-OpAsmParserResult AllocOp::parse(OpAsmParser *parser) {
+bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
MemRefType *type;
- SmallVector<SSAValue *, 4> operands;
- SmallVector<OpAsmParser::OperandType, 4> operandsInfo;
- SmallVector<NamedAttribute, 4> attrs;
// Parse the dimension operands and optional symbol operands, followed by a
// memref type.
unsigned numDimOperands;
- if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands) ||
- parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type))
- return {};
+ if (parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type))
+ return true;
// Check numDynamicDims against number of question marks in memref type.
if (numDimOperands != type->getNumDynamicDims()) {
- parser->emitError(parser->getNameLoc(),
- "Dynamic dimensions count mismatch: dimension operand "
- "count does not equal memref dynamic dimension count.");
- return {};
+ return parser->emitError(parser->getNameLoc(),
+ "dimension operand count does not equal memref "
+ "dynamic dimension count");
}
// Check that the number of symbol operands matches the number of symbols in
// the first affinemap of the memref's affine map composition.
// Note that a memref must specify at least one affine map in the composition.
- if ((operandsInfo.size() - numDimOperands) !=
+ if (result->operands.size() - numDimOperands !=
type->getAffineMaps()[0]->getNumSymbols()) {
- parser->emitError(parser->getNameLoc(),
- "AffineMap symbol count mismatch: symbol operand "
- "count does not equal memref affine map symbol count.");
- return {};
+ return parser->emitError(
+ parser->getNameLoc(),
+ "affine map symbol operand count does not equal memref affine map "
+ "symbol count");
}
- return OpAsmParserResult(operands, type, attrs);
+ result->types.push_back(type);
+ return false;
}
const char *AllocOp::verify() const {
@@ -201,15 +193,17 @@
*p << " : " << *getType();
}
-OpAsmParserResult ConstantOp::parse(OpAsmParser *parser) {
+bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Attribute *valueAttr;
Type *type;
- SmallVector<NamedAttribute, 4> attrs;
- if (parser->parseAttribute(valueAttr, "value", attrs) ||
- parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type))
- return {};
- return OpAsmParserResult(/*operands=*/{}, type, attrs);
+ if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type))
+ return true;
+
+ result->types.push_back(type);
+ return false;
}
/// The constant op requires an attribute, and furthermore requires that it
@@ -247,22 +241,22 @@
*p << " : " << *getOperand()->getType();
}
-OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
+bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo;
IntegerAttr *indexAttr;
Type *type;
- SSAValue *operand;
- SmallVector<NamedAttribute, 4> attrs;
+ // TODO(clattner): remove resolveOperand or change it to push onto the
+ // operands list.
if (parser->parseOperand(operandInfo) || parser->parseComma() ||
- parser->parseAttribute(indexAttr, "index", attrs) ||
- parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseAttribute(indexAttr, "index", result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
- parser->resolveOperand(operandInfo, type, operand))
- return {};
+ parser->resolveOperands(operandInfo, type, result->operands))
+ return true;
- auto &builder = parser->getBuilder();
- return OpAsmParserResult(operand, builder.getAffineIntType(), attrs);
+ result->types.push_back(parser->getBuilder().getAffineIntType());
+ return false;
}
const char *DimOp::verify() const {
@@ -297,23 +291,23 @@
*p << " : " << *getMemRef()->getType();
}
-OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
+bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType *type;
- SmallVector<SSAValue *, 4> operands;
- SmallVector<NamedAttribute, 4> attrs;
auto affineIntTy = parser->getBuilder().getAffineIntType();
if (parser->parseOperand(memrefInfo) ||
parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
- parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
- parser->resolveOperands(memrefInfo, type, operands) ||
- parser->resolveOperands(indexInfo, affineIntTy, operands))
- return {};
+ // TODO: use a new resolveOperand()
+ parser->resolveOperands(memrefInfo, type, result->operands) ||
+ parser->resolveOperands(indexInfo, affineIntTy, result->operands))
+ return true;
- return OpAsmParserResult(operands, type->getElementType(), attrs);
+ result->types.push_back(type->getElementType());
+ return false;
}
const char *LoadOp::verify() const {
@@ -344,27 +338,24 @@
*p << " : " << *getMemRef()->getType();
}
-OpAsmParserResult StoreOp::parse(OpAsmParser *parser) {
+bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- SmallVector<SSAValue *, 4> operands;
MemRefType *memrefType;
- SmallVector<NamedAttribute, 4> attrs;
auto affineIntTy = parser->getBuilder().getAffineIntType();
- if (parser->parseOperand(storeValueInfo) || parser->parseComma() ||
- parser->parseOperand(memrefInfo) ||
- parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
- parser->parseOptionalAttributeDict(attrs) ||
- parser->parseColonType(memrefType) ||
- parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
- operands) ||
- parser->resolveOperands(memrefInfo, memrefType, operands) ||
- parser->resolveOperands(indexInfo, affineIntTy, operands))
- return {};
-
- return OpAsmParserResult(operands, {}, attrs);
+ return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
+ parser->parseOperand(memrefInfo) ||
+ parser->parseOperandList(indexInfo, -1,
+ OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(memrefType) ||
+ parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
+ result->operands) ||
+ // TODO: use a new resolveOperand().
+ parser->resolveOperands(memrefInfo, memrefType, result->operands) ||
+ parser->resolveOperands(indexInfo, affineIntTy, result->operands);
}
const char *StoreOp::verify() const {
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index b50e950..d093064 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1799,10 +1799,11 @@
return result == nullptr;
}
- /// Emit a diagnostic at the specified location.
- void emitError(llvm::SMLoc loc, const Twine &message) override {
+ /// Emit a diagnostic at the specified location and return true.
+ bool emitError(llvm::SMLoc loc, const Twine &message) override {
parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message);
emittedError = true;
+ return true;
}
bool didEmitError() const { return emittedError; }
@@ -1830,15 +1831,17 @@
consumeToken();
// Have the op implementation take a crack and parsing this.
- auto result = opDefinition->parseAssembly(&opAsmParser);
+ OperationState opState(builder.getIdentifier(opName));
+ if (opDefinition->parseAssembly(&opAsmParser, &opState))
+ return nullptr;
// If it emitted an error, we failed.
if (opAsmParser.didEmitError())
return nullptr;
// Otherwise, we succeeded. Use the state it parsed as our op information.
- auto nameId = builder.getIdentifier(opName);
- return createOpFunc(nameId, result.operands, result.types, result.attributes);
+ return createOpFunc(opState.name, opState.operands, opState.types,
+ opState.attributes);
}
//===----------------------------------------------------------------------===//
diff --git a/test/IR/invalid-ops.mlir b/test/IR/invalid-ops.mlir
index 8275d44..27d2729 100644
--- a/test/IR/invalid-ops.mlir
+++ b/test/IR/invalid-ops.mlir
@@ -75,7 +75,7 @@
bb0:
%0 = "constant"() {value: 7} : () -> affineint
// Test alloc with wrong number of dynamic dimensions.
- %1 = alloc(%0)[%1] : memref<2x4xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> // expected-error {{Dynamic dimensions count mismatch: dimension operand count does not equal memref dynamic dimension count}}
+ %1 = alloc(%0)[%1] : memref<2x4xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> // expected-error {{custom op 'alloc' dimension operand count does not equal memref dynamic dimension count}}
return
}
@@ -85,7 +85,7 @@
bb0:
%0 = "constant"() {value: 7} : () -> affineint
// Test alloc with wrong number of symbols
- %1 = alloc(%0) : memref<2x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> // expected-error {{AffineMap symbol count mismatch: symbol operand count does not equal memref affine map symbol count}}
+ %1 = alloc(%0) : memref<2x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> // expected-error {{custom op 'alloc' affine map symbol operand count does not equal memref affine map symbol count}}
return
}