More grooming of custom op parser APIs to allow many of them to use a single
parsing chain and resolve TODOs. NFC.
PiperOrigin-RevId: 207913754
diff --git a/include/mlir/IR/OpImplementation.h b/include/mlir/IR/OpImplementation.h
index 65b25ae..8172e7e 100644
--- a/include/mlir/IR/OpImplementation.h
+++ b/include/mlir/IR/OpImplementation.h
@@ -182,6 +182,14 @@
virtual bool parseColonTypeList(SmallVectorImpl<Type *> &result,
llvm::SMLoc *loc = nullptr) = 0;
+ /// Add the specified type to the end of the specified type list and return
+ /// false. This is a helper designed to allow parse methods to be simple and
+ /// chain through || operators.
+ bool addTypeToList(Type *type, SmallVectorImpl<Type *> &result) {
+ result.push_back(type);
+ return false;
+ }
+
/// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name. this
/// captures the location of the attribute in 'loc' if it is non-null.
@@ -263,18 +271,15 @@
/// Resolve an operand to an SSA value, emitting an error and returning true
/// on failure.
virtual bool resolveOperand(OperandType operand, Type *type,
- SSAValue *&result) = 0;
+ SmallVectorImpl<SSAValue *> &result) = 0;
/// Resolve a list of operands to SSA values, emitting an error and returning
/// true on failure, or appending the results to the list on success.
virtual bool resolveOperands(ArrayRef<OperandType> operand, Type *type,
SmallVectorImpl<SSAValue *> &result) {
- for (auto elt : operand) {
- SSAValue *value;
- if (resolveOperand(elt, type, value))
+ for (auto elt : operand)
+ if (resolveOperand(elt, type, result))
return true;
- result.push_back(value);
- }
return false;
}
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 95d7931..9644a00 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -63,15 +63,11 @@
bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type *type;
- if (parser->parseOperandList(ops, 2) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(type) ||
- parser->resolveOperands(ops, type, result->operands))
- return true;
-
- // TODO(clattner): rework parseColonType to eliminate the need for this.
- result->types.push_back(type);
- return false;
+ return parser->parseOperandList(ops, 2) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperands(ops, type, result->operands) ||
+ parser->addTypeToList(type, result->types);
}
void AddFOp::print(OpAsmPrinter *p) const {
@@ -197,13 +193,10 @@
Attribute *valueAttr;
Type *type;
- if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(type))
- return true;
-
- result->types.push_back(type);
- return false;
+ return parser->parseAttribute(valueAttr, "value", result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->addTypeToList(type, result->types);
}
/// The constant op requires an attribute, and furthermore requires that it
@@ -267,17 +260,13 @@
IntegerAttr *indexAttr;
Type *type;
- // TODO(clattner): remove resolveOperand or change it to push onto the
- // operands list.
- if (parser->parseOperand(operandInfo) || parser->parseComma() ||
- parser->parseAttribute(indexAttr, "index", result->attributes) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(type) ||
- parser->resolveOperands(operandInfo, type, result->operands))
- return true;
-
- result->types.push_back(parser->getBuilder().getAffineIntType());
- return false;
+ return parser->parseOperand(operandInfo) || parser->parseComma() ||
+ parser->parseAttribute(indexAttr, "index", result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperand(operandInfo, type, result->operands) ||
+ parser->addTypeToList(parser->getBuilder().getAffineIntType(),
+ result->types);
}
const char *DimOp::verify() const {
@@ -318,17 +307,14 @@
MemRefType *type;
auto affineIntTy = parser->getBuilder().getAffineIntType();
- if (parser->parseOperand(memrefInfo) ||
- parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(type) ||
- // TODO: use a new resolveOperand()
- parser->resolveOperands(memrefInfo, type, result->operands) ||
- parser->resolveOperands(indexInfo, affineIntTy, result->operands))
- return true;
-
- result->types.push_back(type->getElementType());
- return false;
+ return parser->parseOperand(memrefInfo) ||
+ parser->parseOperandList(indexInfo, -1,
+ OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperand(memrefInfo, type, result->operands) ||
+ parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
+ parser->addTypeToList(type->getElementType(), result->types);
}
const char *LoadOp::verify() const {
@@ -372,10 +358,9 @@
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(memrefType) ||
- parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
- result->operands) ||
- // TODO: use a new resolveOperand().
- parser->resolveOperands(memrefInfo, memrefType, result->operands) ||
+ parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
+ result->operands) ||
+ parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands);
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 5ca4e56..5c34bff 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1798,11 +1798,14 @@
llvm::SMLoc getNameLoc() const override { return nameLoc; }
bool resolveOperand(OperandType operand, Type *type,
- SSAValue *&result) override {
+ SmallVectorImpl<SSAValue *> &result) override {
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
operand.location};
- result = parser.resolveSSAUse(operandInfo, type);
- return result == nullptr;
+ if (auto *value = parser.resolveSSAUse(operandInfo, type)) {
+ result.push_back(value);
+ return false;
+ }
+ return true;
}
/// Emit a diagnostic at the specified location and return true.