Give custom ops the ability to also access general additional attributes in the
parser and printer. Fix the spelling of 'delimeter'
PiperOrigin-RevId: 207189892
diff --git a/include/mlir/IR/OpDefinition.h b/include/mlir/IR/OpDefinition.h
index 0edafcb..1b06a62 100644
--- a/include/mlir/IR/OpDefinition.h
+++ b/include/mlir/IR/OpDefinition.h
@@ -99,6 +99,9 @@
const Operation *getOperation() const { return state; }
Operation *getOperation() { return state; }
+ /// Return all of the attributes on this operation.
+ ArrayRef<NamedAttribute> getAttrs() const { return state->getAttrs(); }
+
/// Return an attribute with the specified name.
Attribute *getAttr(StringRef name) const { return state->getAttr(name); }
diff --git a/include/mlir/IR/OpImplementation.h b/include/mlir/IR/OpImplementation.h
index d377bed..62b8e4f 100644
--- a/include/mlir/IR/OpImplementation.h
+++ b/include/mlir/IR/OpImplementation.h
@@ -69,6 +69,14 @@
virtual void printAffineMap(const AffineMap *map) = 0;
virtual void printAffineExpr(const AffineExpr *expr) = 0;
+ /// If the specified operation has attributes, print out an attribute
+ /// dictionary with their values. elidedAttrs allows the client to ignore
+ /// specific well known attributes, commonly used if the attribute value is
+ /// printed some other way (like as a fixed operand).
+ virtual void
+ printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
+ ArrayRef<const char *> elidedAttrs = {}) = 0;
+
/// Print the entire operation with the default verbose formatting.
virtual void printDefaultOp(const Operation *op) = 0;
@@ -127,7 +135,7 @@
///
/// The "%x = load" tokens are already parsed and therefore invisible to the
/// custom op parser. This can be supported by calling `parseOperandList` to
-/// parse the %p, then calling `parseOperandList` with a `SquareDelimeter` to
+/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
/// parse the indices, then calling `parseColonTypeList` to parse the result
/// type.
///
@@ -174,17 +182,23 @@
virtual bool parseColonTypeList(SmallVectorImpl<Type *> &result,
llvm::SMLoc *loc = nullptr) = 0;
- /// Parse an attribute.
- virtual bool parseAttribute(Attribute *&result,
+ /// Parse an arbitrary attribute and return it in result. This also adds the
+ /// attribute to the specified attribute list with the specified name. this
+ /// captures the location of the attribute in 'loc' if it is non-null.
+ virtual bool parseAttribute(Attribute *&result, const char *attrName,
+ SmallVectorImpl<NamedAttribute> &attrs,
llvm::SMLoc *loc = nullptr) = 0;
- /// Parse an attribute of a specific kind.
+ /// Parse an attribute of a specific kind, capturing the location into `loc`
+ /// if specified.
template <typename AttrType>
- bool parseAttribute(AttrType *&result, llvm::SMLoc *loc = nullptr) {
+ bool parseAttribute(AttrType *&result, const char *attrName,
+ SmallVectorImpl<NamedAttribute> &attrs,
+ llvm::SMLoc *loc = nullptr) {
// Parse any kind of attribute.
Attribute *attr;
llvm::SMLoc tmpLoc;
- if (parseAttribute(attr, &tmpLoc))
+ if (parseAttribute(attr, attrName, attrs, &tmpLoc))
return true;
if (loc)
*loc = tmpLoc;
@@ -199,6 +213,11 @@
return false;
}
+ /// If a named attribute dictionary is present, parse it into result.
+ virtual bool
+ parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result,
+ llvm::SMLoc *loc = nullptr) = 0;
+
/// This is the representation of an operand reference.
struct OperandType {
llvm::SMLoc location; // Location of the token.
@@ -209,27 +228,26 @@
/// Parse a single operand.
virtual bool parseOperand(OperandType &result) = 0;
- /// These are the supported delimeters around operand lists, used by
+ /// These are the supported delimiters around operand lists, used by
/// parseOperandList.
- enum Delimeter {
- /// Zero or more operands with no delimeters.
- NoDelimeter,
+ enum Delimiter {
+ /// Zero or more operands with no delimiters.
+ None,
/// Parens surrounding zero or more operands.
- ParenDelimeter,
+ Paren,
/// Square brackets surrounding zero or more operands.
- SquareDelimeter,
+ Square,
/// Parens supporting zero or more operands, or nothing.
- OptionalParenDelimeter,
+ OptionalParen,
/// Square brackets supporting zero or more ops, or nothing.
- OptionalSquareDelimeter,
+ OptionalSquare,
};
/// Parse zero or more SSA comma-separated operand references with a specified
- /// surrounding delimeter, and an optional required operand count.
- virtual bool
- parseOperandList(SmallVectorImpl<OperandType> &result,
- int requiredOperandCount = -1,
- Delimeter delimeter = Delimeter::NoDelimeter) = 0;
+ /// surrounding delimiter, and an optional required operand count.
+ virtual bool parseOperandList(SmallVectorImpl<OperandType> &result,
+ int requiredOperandCount = -1,
+ Delimiter delimiter = Delimiter::None) = 0;
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
diff --git a/include/mlir/IR/Operation.h b/include/mlir/IR/Operation.h
index 1e2d614..b798f8e 100644
--- a/include/mlir/IR/Operation.h
+++ b/include/mlir/IR/Operation.h
@@ -101,6 +101,7 @@
// (maybe a dozen or so, but not hundreds or thousands) so we use linear
// searches for everything.
+ /// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() const;
/// Return the specified attribute if present, null otherwise.
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index d84bdc3..221f453 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -579,6 +579,8 @@
}
void printOperand(const SSAValue *value) { printValueID(value); }
+ void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
+ ArrayRef<const char *> elidedAttrs = {}) override;
enum { nameSentinel = ~0U };
@@ -711,6 +713,44 @@
};
} // end anonymous namespace
+void FunctionPrinter::printOptionalAttrDict(
+ ArrayRef<NamedAttribute> attrs, ArrayRef<const char *> elidedAttrs) {
+ // If there are no attributes, then there is nothing to be done.
+ if (attrs.empty())
+ return;
+
+ // Filter out any attributes that shouldn't be included.
+ SmallVector<NamedAttribute, 8> filteredAttrs;
+ for (auto attr : attrs) {
+ auto attrName = attr.first.str();
+ // Never print attributes that start with a colon. These are internal
+ // attributes that represent location or other internal metadata.
+ if (attrName.startswith(":"))
+ continue;
+
+ // If the caller has requested that this attribute be ignored, then drop it.
+ bool ignore = false;
+ for (const char *elide : elidedAttrs)
+ ignore |= attrName == StringRef(elide);
+
+ // Otherwise add it to our filteredAttrs list.
+ if (!ignore)
+ filteredAttrs.push_back(attr);
+ }
+
+ // If there are no attributes left to print after filtering, then we're done.
+ if (filteredAttrs.empty())
+ return;
+
+ // Otherwise, print them all out in braces.
+ os << " {";
+ interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
+ os << attr.first << ": ";
+ printAttribute(attr.second);
+ });
+ os << '}';
+}
+
void FunctionPrinter::printOperation(const Operation *op) {
if (op->getNumResults()) {
printValueID(op->getResult(0), /*printResultNo=*/false);
@@ -737,14 +777,7 @@
os << ')';
auto attrs = op->getAttrs();
- if (!attrs.empty()) {
- os << '{';
- interleaveComma(attrs, [&](NamedAttribute attr) {
- os << attr.first << ": ";
- printAttribute(attr.second);
- });
- os << '}';
- }
+ printOptionalAttrDict(attrs);
// Print the type signature of the operation.
os << " : (";
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 8f4f1b3..1458ce4 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -46,16 +46,15 @@
parseDimAndSymbolList(OpAsmParser *parser,
SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
- if (parser->parseOperandList(opInfos, -1,
- OpAsmParser::Delimeter::ParenDelimeter))
+ if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
return true;
// Store number of dimensions for validation by caller.
numDims = opInfos.size();
// Parse the optional symbol operands.
auto *affineIntTy = parser->getBuilder().getAffineIntType();
- if (parser->parseOperandList(
- opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
+ if (parser->parseOperandList(opInfos, -1,
+ OpAsmParser::Delimiter::OptionalSquare) ||
parser->resolveOperands(opInfos, affineIntTy, operands))
return true;
return false;
@@ -67,17 +66,21 @@
SmallVector<OpAsmParser::OperandType, 2> ops;
Type *type;
SSAValue *lhs, *rhs;
- if (parser->parseOperandList(ops, 2) || parser->parseColonType(type) ||
+ SmallVector<NamedAttribute, 4> attrs;
+ if (parser->parseOperandList(ops, 2) ||
+ parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseColonType(type) ||
parser->resolveOperand(ops[0], type, lhs) ||
parser->resolveOperand(ops[1], type, rhs))
return {};
- return OpAsmParserResult({lhs, rhs}, type);
+ return OpAsmParserResult({lhs, rhs}, type, attrs);
}
void AddFOp::print(OpAsmPrinter *p) const {
- *p << "addf " << *getOperand(0) << ", " << *getOperand(1) << " : "
- << *getType();
+ *p << "addf " << *getOperand(0) << ", " << *getOperand(1);
+ p->printOptionalAttrDict(getAttrs());
+ *p << " : " << *getType();
}
// Return an error message on failure.
@@ -91,14 +94,16 @@
OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
SmallVector<OpAsmParser::OperandType, 2> opInfos;
SmallVector<SSAValue *, 4> operands;
+ SmallVector<NamedAttribute, 4> attrs;
auto &builder = parser->getBuilder();
auto *affineIntTy = builder.getAffineIntType();
AffineMapAttr *mapAttr;
unsigned numDims;
- if (parser->parseAttribute(mapAttr) ||
- parseDimAndSymbolList(parser, opInfos, operands, numDims))
+ if (parser->parseAttribute(mapAttr, "map", attrs) ||
+ parseDimAndSymbolList(parser, opInfos, operands, numDims) ||
+ parser->parseOptionalAttributeDict(attrs))
return {};
auto *map = mapAttr->getValue();
@@ -110,15 +115,14 @@
}
SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
- return OpAsmParserResult(
- operands, resultTypes,
- NamedAttribute(builder.getIdentifier("map"), mapAttr));
+ return OpAsmParserResult(operands, resultTypes, attrs);
}
void AffineApplyOp::print(OpAsmPrinter *p) const {
auto *map = getAffineMap();
*p << "affine_apply " << *map;
printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
}
const char *AffineApplyOp::verify() const {
@@ -147,7 +151,7 @@
// Print dynamic dimension operands.
printDimAndSymbolList(operand_begin(), operand_end(),
type->getNumDynamicDims(), p);
- // Print memref type.
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
*p << " : " << *type;
}
@@ -155,12 +159,13 @@
MemRefType *type;
SmallVector<SSAValue *, 4> operands;
SmallVector<OpAsmParser::OperandType, 4> operandsInfo;
+ SmallVector<NamedAttribute, 4> attrs;
// Parse the dimension operands and optional symbol operands, followed by a
// memref type.
unsigned numDimOperands;
if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands) ||
- parser->parseColonType(type))
+ parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type))
return {};
// Check numDynamicDims against number of question marks in memref type.
@@ -182,7 +187,7 @@
return {};
}
- return OpAsmParserResult(operands, type);
+ return OpAsmParserResult(operands, type, attrs);
}
const char *AllocOp::verify() const {
@@ -191,19 +196,20 @@
}
void ConstantOp::print(OpAsmPrinter *p) const {
- *p << "constant " << *getValue() << " : " << *getType();
+ *p << "constant " << *getValue();
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
+ *p << " : " << *getType();
}
OpAsmParserResult ConstantOp::parse(OpAsmParser *parser) {
Attribute *valueAttr;
Type *type;
- if (parser->parseAttribute(valueAttr) || parser->parseColonType(type))
- return {};
+ SmallVector<NamedAttribute, 4> attrs;
- auto &builder = parser->getBuilder();
- return OpAsmParserResult(
- /*operands=*/{}, type,
- NamedAttribute(builder.getIdentifier("value"), valueAttr));
+ if (parser->parseAttribute(valueAttr, "value", attrs) ||
+ parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type))
+ return {};
+ return OpAsmParserResult(/*operands=*/{}, type, attrs);
}
/// The constant op requires an attribute, and furthermore requires that it
@@ -236,8 +242,9 @@
}
void DimOp::print(OpAsmPrinter *p) const {
- *p << "dim " << *getOperand() << ", " << getIndex() << " : "
- << *getOperand()->getType();
+ *p << "dim " << *getOperand() << ", " << getIndex();
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
+ *p << " : " << *getOperand()->getType();
}
OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
@@ -245,15 +252,17 @@
IntegerAttr *indexAttr;
Type *type;
SSAValue *operand;
+ SmallVector<NamedAttribute, 4> attrs;
+
if (parser->parseOperand(operandInfo) || parser->parseComma() ||
- parser->parseAttribute(indexAttr) || parser->parseColonType(type) ||
+ parser->parseAttribute(indexAttr, "index", attrs) ||
+ parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseColonType(type) ||
parser->resolveOperand(operandInfo, type, operand))
return {};
auto &builder = parser->getBuilder();
- return OpAsmParserResult(
- operand, builder.getAffineIntType(),
- NamedAttribute(builder.getIdentifier("index"), indexAttr));
+ return OpAsmParserResult(operand, builder.getAffineIntType(), attrs);
}
const char *DimOp::verify() const {
@@ -283,7 +292,9 @@
void LoadOp::print(OpAsmPrinter *p) const {
*p << "load " << *getMemRef() << '[';
p->printOperands(getIndices());
- *p << "] : " << *getMemRef()->getType();
+ *p << ']';
+ p->printOptionalAttrDict(getAttrs());
+ *p << " : " << *getMemRef()->getType();
}
OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
@@ -291,17 +302,18 @@
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType *type;
SmallVector<SSAValue *, 4> operands;
+ SmallVector<NamedAttribute, 4> attrs;
auto affineIntTy = parser->getBuilder().getAffineIntType();
if (parser->parseOperand(memrefInfo) ||
- parser->parseOperandList(indexInfo, -1,
- OpAsmParser::Delimeter::SquareDelimeter) ||
+ parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(attrs) ||
parser->parseColonType(type) ||
parser->resolveOperands(memrefInfo, type, operands) ||
parser->resolveOperands(indexInfo, affineIntTy, operands))
return {};
- return OpAsmParserResult(operands, type->getElementType());
+ return OpAsmParserResult(operands, type->getElementType(), attrs);
}
const char *LoadOp::verify() const {
@@ -327,7 +339,9 @@
*p << "store " << *getValueToStore();
*p << ", " << *getMemRef() << '[';
p->printOperands(getIndices());
- *p << "] : " << *getMemRef()->getType();
+ *p << ']';
+ p->printOptionalAttrDict(getAttrs());
+ *p << " : " << *getMemRef()->getType();
}
OpAsmParserResult StoreOp::parse(OpAsmParser *parser) {
@@ -336,12 +350,13 @@
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
SmallVector<SSAValue *, 4> operands;
MemRefType *memrefType;
+ SmallVector<NamedAttribute, 4> attrs;
auto affineIntTy = parser->getBuilder().getAffineIntType();
if (parser->parseOperand(storeValueInfo) || parser->parseComma() ||
parser->parseOperand(memrefInfo) ||
- parser->parseOperandList(indexInfo, -1,
- OpAsmParser::Delimeter::SquareDelimeter) ||
+ parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(attrs) ||
parser->parseColonType(memrefType) ||
parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
operands) ||
@@ -349,7 +364,7 @@
parser->resolveOperands(indexInfo, affineIntTy, operands))
return {};
- return OpAsmParserResult(operands, {});
+ return OpAsmParserResult(operands, {}, attrs);
}
const char *StoreOp::verify() const {
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 24bda89..c146a93 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1667,11 +1667,31 @@
return false;
}
- bool parseAttribute(Attribute *&result, llvm::SMLoc *loc = nullptr) override {
+ /// Parse an arbitrary attribute and return it in result. This also adds the
+ /// attribute to the specified attribute list with the specified name. this
+ /// captures the location of the attribute in 'loc' if it is non-null.
+ bool parseAttribute(Attribute *&result, const char *attrName,
+ SmallVectorImpl<NamedAttribute> &attrs,
+ llvm::SMLoc *loc = nullptr) override {
if (loc)
*loc = parser.getToken().getLoc();
result = parser.parseAttribute();
- return result == nullptr;
+ if (!result)
+ return true;
+
+ attrs.push_back(
+ NamedAttribute(parser.builder.getIdentifier(attrName), result));
+ return false;
+ }
+
+ /// If a named attribute list is present, parse is into result.
+ bool parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result,
+ llvm::SMLoc *loc = nullptr) override {
+ if (parser.getToken().isNot(Token::l_brace))
+ return false;
+ if (loc)
+ *loc = parser.getToken().getLoc();
+ return parser.parseAttributeDict(result) == ParseFailure;
}
bool parseOperand(OperandType &result) override {
@@ -1685,26 +1705,26 @@
bool parseOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
- Delimeter delimeter = Delimeter::NoDelimeter) override {
+ Delimiter delimiter = Delimiter::None) override {
auto startLoc = parser.getToken().getLoc();
- // Handle delimeters.
- switch (delimeter) {
- case Delimeter::NoDelimeter:
+ // Handle delimiters.
+ switch (delimiter) {
+ case Delimiter::None:
break;
- case Delimeter::OptionalParenDelimeter:
+ case Delimiter::OptionalParen:
if (parser.getToken().isNot(Token::l_paren))
return false;
LLVM_FALLTHROUGH;
- case Delimeter::ParenDelimeter:
+ case Delimiter::Paren:
if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
return true;
break;
- case Delimeter::OptionalSquareDelimeter:
+ case Delimiter::OptionalSquare:
if (parser.getToken().isNot(Token::l_square))
return false;
LLVM_FALLTHROUGH;
- case Delimeter::SquareDelimeter:
+ case Delimiter::Square:
if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
return true;
break;
@@ -1720,18 +1740,18 @@
} while (parser.consumeIf(Token::comma));
}
- // Handle delimeters. If we reach here, the optional delimiters were
+ // Handle delimiters. If we reach here, the optional delimiters were
// present, so we need to parse their closing one.
- switch (delimeter) {
- case Delimeter::NoDelimeter:
+ switch (delimiter) {
+ case Delimiter::None:
break;
- case Delimeter::OptionalParenDelimeter:
- case Delimeter::ParenDelimeter:
+ case Delimiter::OptionalParen:
+ case Delimiter::Paren:
if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
return true;
break;
- case Delimeter::OptionalSquareDelimeter:
- case Delimeter::SquareDelimeter:
+ case Delimiter::OptionalSquare:
+ case Delimiter::Square:
if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
return true;
break;
diff --git a/test/IR/core-ops.mlir b/test/IR/core-ops.mlir
index ae48e45..bb46bd7 100644
--- a/test/IR/core-ops.mlir
+++ b/test/IR/core-ops.mlir
@@ -45,6 +45,9 @@
// CHECK: %c42_i32_0 = constant 42 : i32
%7 = constant 42 : i32
+
+ // CHECK: %c43 = constant 43 {crazy: "foo"} : affineint
+ %8 = constant 43 {crazy: "foo"} : affineint
return
}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index c481f9b..15e5b8f 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -177,20 +177,20 @@
// CHECK: "foo"()
"foo"(){} : ()->()
- // CHECK: "foo"(){a: 1, b: -423, c: [true, false], d: 1.600000e+01} : () -> ()
- "foo"(){a: 1, b: -423, c: [true, false], d: 16.0 } : () -> ()
+ // CHECK: "foo"() {a: 1, b: -423, c: [true, false], d: 1.600000e+01} : () -> ()
+ "foo"() {a: 1, b: -423, c: [true, false], d: 16.0 } : () -> ()
- // CHECK: "foo"(){map1: #map{{[0-9]+}}}
- "foo"(){map1: #map1} : () -> ()
+ // CHECK: "foo"() {map1: #map{{[0-9]+}}}
+ "foo"() {map1: #map1} : () -> ()
- // CHECK: "foo"(){map2: #map{{[0-9]+}}}
- "foo"(){map2: (d0, d1, d2) -> (d0, d1, d2)} : () -> ()
+ // CHECK: "foo"() {map2: #map{{[0-9]+}}}
+ "foo"() {map2: (d0, d1, d2) -> (d0, d1, d2)} : () -> ()
- // CHECK: "foo"(){map12: [#map{{[0-9]+}}, #map{{[0-9]+}}]}
- "foo"(){map12: [#map1, #map2]} : () -> ()
+ // CHECK: "foo"() {map12: [#map{{[0-9]+}}, #map{{[0-9]+}}]}
+ "foo"() {map12: [#map1, #map2]} : () -> ()
- // CHECK: "foo"(){cfgfunc: [], d: 1.000000e-09, i123: 7, if: "foo"} : () -> ()
- "foo"(){if: "foo", cfgfunc: [], i123: 7, d: 1.e-9} : () -> ()
+ // CHECK: "foo"() {cfgfunc: [], d: 1.000000e-09, i123: 7, if: "foo"} : () -> ()
+ "foo"() {if: "foo", cfgfunc: [], i123: 7, d: 1.e-9} : () -> ()
return
}