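Refactor the custom-op asm parser hook to work with a new OperationState type
that fully encapsulates an operation that is yet to be created. An op's parse
method now fills in an OperationState and returns true on error instead of
returning an OpAsmParserResult, and CFGFuncBuilder/MLFuncBuilder gain a
createOperation method that builds an operation from an OperationState. This is
a step towards custom ops providing create methods that don't need to be
templated, allowing them to move out of line in the future.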
PiperOrigin-RevId: 207725557
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 715e460..4424bd3 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -158,10 +158,36 @@
return b;
}
+/// Insert a new op built from `state`; all operands must be CFGValues.
+OperationInst *CFGFuncBuilder::createOperation(const OperationState &state) {
+ SmallVector<CFGValue *, 8> operands;
+ operands.reserve(state.operands.size());
+ for (auto elt : state.operands)
+ operands.push_back(cast<CFGValue>(elt));
+
+ auto *op = OperationInst::create(state.name, operands, state.types,
+ state.attributes, context);
+ block->getOperations().insert(insertPoint, op);
+ return op;
+}
+
//===----------------------------------------------------------------------===//
// Statements.
//===----------------------------------------------------------------------===//
+/// Insert a new op built from `state`; all operands must be MLValues.
+OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
+ SmallVector<MLValue *, 8> operands;
+ operands.reserve(state.operands.size());
+ for (auto elt : state.operands)
+ operands.push_back(cast<MLValue>(elt));
+
+ auto *op = OperationStmt::create(state.name, operands, state.types,
+ state.attributes, context);
+ block->getStatements().insert(insertPoint, op);
+ return op;
+}
+
ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
AffineConstantExpr *step) {
diff --git a/lib/IR/OperationSet.cpp b/lib/IR/OperationSet.cpp
index 9bbc84e..98bf5fe 100644
--- a/lib/IR/OperationSet.cpp
+++ b/lib/IR/OperationSet.cpp
@@ -27,9 +27,8 @@
OpAsmParser::~OpAsmParser() {}
-// The fallback for the printer is to reject the short form.
-OpAsmParserResult OpBaseState::parse(OpAsmParser *parser) {
- parser->emitError(parser->getNameLoc(), "has no concise form");
- return {};
+// The fallback for the parser is to reject the short form.
+bool OpBaseState::parse(OpAsmParser *parser, OperationState *result) {
+ return parser->emitError(parser->getNameLoc(), "has no concise form");
}
-// The fallback for the printer is to print it the longhand form.
+// The fallback for the printer is to print it in the longhand form.
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 1458ce4..b67c12d 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -44,8 +44,8 @@
// Returns 'false' on success and 'true' on error.
static bool
parseDimAndSymbolList(OpAsmParser *parser,
- SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
+ SmallVector<OpAsmParser::OperandType, 8> opInfos;
if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
return true;
// Store number of dimensions for validation by caller.
@@ -60,21 +60,18 @@
return false;
}
-// TODO: Have verify functions return std::string to enable more descriptive
-// error messages.
-OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
+bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type *type;
- SSAValue *lhs, *rhs;
- SmallVector<NamedAttribute, 4> attrs;
if (parser->parseOperandList(ops, 2) ||
- parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
- parser->resolveOperand(ops[0], type, lhs) ||
- parser->resolveOperand(ops[1], type, rhs))
- return {};
+ parser->resolveOperands(ops, type, result->operands))
+ return true;
- return OpAsmParserResult({lhs, rhs}, type, attrs);
+ // TODO(clattner): rework parseColonType to avoid pushing the type by hand.
+ result->types.push_back(type);
+ return false;
}
void AddFOp::print(OpAsmPrinter *p) const {
@@ -83,6 +80,8 @@
*p << " : " << *getType();
}
+// TODO: Have verify functions return std::string to enable more descriptive
+// error messages.
// Return an error message on failure.
const char *AddFOp::verify() const {
// TODO: Check that the types of the LHS and RHS match.
@@ -91,31 +90,26 @@
return nullptr;
}
-OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
- SmallVector<OpAsmParser::OperandType, 2> opInfos;
- SmallVector<SSAValue *, 4> operands;
- SmallVector<NamedAttribute, 4> attrs;
-
+bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
auto *affineIntTy = builder.getAffineIntType();
AffineMapAttr *mapAttr;
unsigned numDims;
- if (parser->parseAttribute(mapAttr, "map", attrs) ||
- parseDimAndSymbolList(parser, opInfos, operands, numDims) ||
- parser->parseOptionalAttributeDict(attrs))
- return {};
+ if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
+ parseDimAndSymbolList(parser, result->operands, numDims) ||
+ parser->parseOptionalAttributeDict(result->attributes))
+ return true;
auto *map = mapAttr->getValue();
if (map->getNumDims() != numDims ||
- numDims + map->getNumSymbols() != opInfos.size()) {
- parser->emitError(parser->getNameLoc(),
- "dimension or symbol index mismatch");
- return {};
+ numDims + map->getNumSymbols() != result->operands.size()) {
+ return parser->emitError(parser->getNameLoc(),
+ "dimension or symbol index mismatch");
}
- SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
- return OpAsmParserResult(operands, resultTypes, attrs);
+ result->types.append(map->getNumResults(), affineIntTy);
+ return false;
}
void AffineApplyOp::print(OpAsmPrinter *p) const {
@@ -155,39 +149,37 @@
*p << " : " << *type;
}
-OpAsmParserResult AllocOp::parse(OpAsmParser *parser) {
+bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
MemRefType *type;
- SmallVector<SSAValue *, 4> operands;
- SmallVector<OpAsmParser::OperandType, 4> operandsInfo;
- SmallVector<NamedAttribute, 4> attrs;
// Parse the dimension operands and optional symbol operands, followed by a
// memref type.
unsigned numDimOperands;
- if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands) ||
- parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type))
- return {};
+ if (parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type))
+ return true;
// Check numDynamicDims against number of question marks in memref type.
if (numDimOperands != type->getNumDynamicDims()) {
- parser->emitError(parser->getNameLoc(),
- "Dynamic dimensions count mismatch: dimension operand "
- "count does not equal memref dynamic dimension count.");
- return {};
+ return parser->emitError(parser->getNameLoc(),
+ "dimension operand count does not equal memref "
+ "dynamic dimension count");
}
// Check that the number of symbol operands matches the number of symbols in
// the first affinemap of the memref's affine map composition.
// Note that a memref must specify at least one affine map in the composition.
- if ((operandsInfo.size() - numDimOperands) !=
+ if (result->operands.size() - numDimOperands !=
type->getAffineMaps()[0]->getNumSymbols()) {
- parser->emitError(parser->getNameLoc(),
- "AffineMap symbol count mismatch: symbol operand "
- "count does not equal memref affine map symbol count.");
- return {};
+ return parser->emitError(
+ parser->getNameLoc(),
+ "affine map symbol operand count does not equal memref affine map "
+ "symbol count");
}
- return OpAsmParserResult(operands, type, attrs);
+ result->types.push_back(type);
+ return false;
}
const char *AllocOp::verify() const {
@@ -201,15 +193,17 @@
*p << " : " << *getType();
}
-OpAsmParserResult ConstantOp::parse(OpAsmParser *parser) {
+bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Attribute *valueAttr;
Type *type;
- SmallVector<NamedAttribute, 4> attrs;
- if (parser->parseAttribute(valueAttr, "value", attrs) ||
- parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type))
- return {};
- return OpAsmParserResult(/*operands=*/{}, type, attrs);
+ if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type))
+ return true;
+
+ result->types.push_back(type);
+ return false;
}
/// The constant op requires an attribute, and furthermore requires that it
@@ -247,22 +241,22 @@
*p << " : " << *getOperand()->getType();
}
-OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
+bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo;
IntegerAttr *indexAttr;
Type *type;
- SSAValue *operand;
- SmallVector<NamedAttribute, 4> attrs;
+ // TODO(clattner): remove resolveOperand or change it to push onto the
+ // operands list; until then resolveOperands handles the single operand.
if (parser->parseOperand(operandInfo) || parser->parseComma() ||
- parser->parseAttribute(indexAttr, "index", attrs) ||
- parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseAttribute(indexAttr, "index", result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
- parser->resolveOperand(operandInfo, type, operand))
- return {};
+ parser->resolveOperands(operandInfo, type, result->operands))
+ return true;
- auto &builder = parser->getBuilder();
- return OpAsmParserResult(operand, builder.getAffineIntType(), attrs);
+ result->types.push_back(parser->getBuilder().getAffineIntType());
+ return false;
}
const char *DimOp::verify() const {
@@ -297,23 +291,23 @@
*p << " : " << *getMemRef()->getType();
}
-OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
+bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType *type;
- SmallVector<SSAValue *, 4> operands;
- SmallVector<NamedAttribute, 4> attrs;
auto affineIntTy = parser->getBuilder().getAffineIntType();
if (parser->parseOperand(memrefInfo) ||
parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
- parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
- parser->resolveOperands(memrefInfo, type, operands) ||
- parser->resolveOperands(indexInfo, affineIntTy, operands))
- return {};
+ // TODO: resolve the single memref operand with a new resolveOperand().
+ parser->resolveOperands(memrefInfo, type, result->operands) ||
+ parser->resolveOperands(indexInfo, affineIntTy, result->operands))
+ return true;
- return OpAsmParserResult(operands, type->getElementType(), attrs);
+ result->types.push_back(type->getElementType());
+ return false;
}
const char *LoadOp::verify() const {
@@ -344,27 +338,24 @@
*p << " : " << *getMemRef()->getType();
}
-OpAsmParserResult StoreOp::parse(OpAsmParser *parser) {
+bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- SmallVector<SSAValue *, 4> operands;
MemRefType *memrefType;
- SmallVector<NamedAttribute, 4> attrs;
auto affineIntTy = parser->getBuilder().getAffineIntType();
- if (parser->parseOperand(storeValueInfo) || parser->parseComma() ||
- parser->parseOperand(memrefInfo) ||
- parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
- parser->parseOptionalAttributeDict(attrs) ||
- parser->parseColonType(memrefType) ||
- parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
- operands) ||
- parser->resolveOperands(memrefInfo, memrefType, operands) ||
- parser->resolveOperands(indexInfo, affineIntTy, operands))
- return {};
-
- return OpAsmParserResult(operands, {}, attrs);
+ return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
+ parser->parseOperand(memrefInfo) ||
+ parser->parseOperandList(indexInfo, -1,
+ OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(memrefType) ||
+ parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
+ result->operands) ||
+ // TODO: resolve single operands with a new resolveOperand().
+ parser->resolveOperands(memrefInfo, memrefType, result->operands) ||
+ parser->resolveOperands(indexInfo, affineIntTy, result->operands);
}
const char *StoreOp::verify() const {
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index b50e950..d093064 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1799,10 +1799,11 @@
return result == nullptr;
}
- /// Emit a diagnostic at the specified location.
- void emitError(llvm::SMLoc loc, const Twine &message) override {
+ /// Emit a diagnostic at the specified location and return true.
+ bool emitError(llvm::SMLoc loc, const Twine &message) override {
parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message);
emittedError = true;
+ return true;
}
bool didEmitError() const { return emittedError; }
@@ -1830,15 +1831,17 @@
consumeToken();
- // Have the op implementation take a crack and parsing this.
- auto result = opDefinition->parseAssembly(&opAsmParser);
+ // Have the op implementation take a crack at parsing this.
+ OperationState opState(builder.getIdentifier(opName));
+ if (opDefinition->parseAssembly(&opAsmParser, &opState))
+ return nullptr;
- // If it emitted an error, we failed.
+ // If it emitted an error even though it reported success, we still failed.
if (opAsmParser.didEmitError())
return nullptr;
// Otherwise, we succeeded. Use the state it parsed as our op information.
- auto nameId = builder.getIdentifier(opName);
- return createOpFunc(nameId, result.operands, result.types, result.attributes);
+ return createOpFunc(opState.name, opState.operands, opState.types,
+ opState.attributes);
}
//===----------------------------------------------------------------------===//