Give custom ops the ability to also access general additional attributes in the
parser and printer, via an optional trailing attribute dictionary. Fix the
spelling of 'delimeter' to 'delimiter'.
PiperOrigin-RevId: 207189892
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index d84bdc3..221f453 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -579,6 +579,8 @@
}
void printOperand(const SSAValue *value) { printValueID(value); }
+ /// Print 'attrs' as an optional " {name: value, ...}" dictionary. Attributes
+ /// whose names start with ':' or appear in 'elidedAttrs' are skipped, and
+ /// nothing at all is printed when no attributes remain.
+ void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
+ ArrayRef<const char *> elidedAttrs = {}) override;
enum { nameSentinel = ~0U };
@@ -711,6 +713,44 @@
};
} // end anonymous namespace
+/// Print the attributes that survive filtering as " {name: value, ...}".
+/// Internal attributes (names starting with ':') and any names listed in
+/// 'elidedAttrs' are dropped. When the input is empty, or everything is
+/// filtered out, nothing is printed (not even the leading space), so callers
+/// can invoke this unconditionally.
+void FunctionPrinter::printOptionalAttrDict(
+ ArrayRef<NamedAttribute> attrs, ArrayRef<const char *> elidedAttrs) {
+ // If there are no attributes, then there is nothing to be done.
+ if (attrs.empty())
+ return;
+
+ // Filter out any attributes that shouldn't be included.
+ SmallVector<NamedAttribute, 8> filteredAttrs;
+ for (auto attr : attrs) {
+ auto attrName = attr.first.str();
+ // Never print attributes that start with a colon. These are internal
+ // attributes that represent location or other internal metadata.
+ if (attrName.startswith(":"))
+ continue;
+
+ // If the caller has requested that this attribute be ignored, then drop it.
+ // Elided names are compared by string contents, not by pointer identity.
+ bool ignore = false;
+ for (const char *elide : elidedAttrs)
+ ignore |= attrName == StringRef(elide);
+
+ // Otherwise add it to our filteredAttrs list.
+ if (!ignore)
+ filteredAttrs.push_back(attr);
+ }
+
+ // If there are no attributes left to print after filtering, then we're done.
+ if (filteredAttrs.empty())
+ return;
+
+ // Otherwise, print them all out in braces.
+ os << " {";
+ interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
+ os << attr.first << ": ";
+ printAttribute(attr.second);
+ });
+ os << '}';
+}
+
void FunctionPrinter::printOperation(const Operation *op) {
if (op->getNumResults()) {
printValueID(op->getResult(0), /*printResultNo=*/false);
@@ -737,14 +777,7 @@
os << ')';
auto attrs = op->getAttrs();
- if (!attrs.empty()) {
- os << '{';
- interleaveComma(attrs, [&](NamedAttribute attr) {
- os << attr.first << ": ";
- printAttribute(attr.second);
- });
- os << '}';
- }
+ // Delegate to the shared helper. Unlike the removed inline code, it skips
+ // ':'-prefixed internal attributes and prints a space before the '{'.
+ printOptionalAttrDict(attrs);
// Print the type signature of the operation.
os << " : (";
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 8f4f1b3..1458ce4 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -46,16 +46,15 @@
parseDimAndSymbolList(OpAsmParser *parser,
SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
- if (parser->parseOperandList(opInfos, -1,
- OpAsmParser::Delimeter::ParenDelimeter))
+ // Dimension operands are a required '(...)' operand list.
+ if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
return true;
// Store number of dimensions for validation by caller.
numDims = opInfos.size();
// Parse the optional symbol operands.
auto *affineIntTy = parser->getBuilder().getAffineIntType();
- if (parser->parseOperandList(
- opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
+ // If no '[' follows, no symbol operands are appended to 'opInfos'; dims and
+ // symbols are then all resolved as affine integers.
+ if (parser->parseOperandList(opInfos, -1,
+ OpAsmParser::Delimiter::OptionalSquare) ||
parser->resolveOperands(opInfos, affineIntTy, operands))
return true;
return false;
@@ -67,17 +66,21 @@
SmallVector<OpAsmParser::OperandType, 2> ops;
Type *type;
SSAValue *lhs, *rhs;
- if (parser->parseOperandList(ops, 2) || parser->parseColonType(type) ||
+ // Extra attributes, if any, appear as a dictionary before the ':' type.
+ SmallVector<NamedAttribute, 4> attrs;
+ if (parser->parseOperandList(ops, 2) ||
+ parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseColonType(type) ||
parser->resolveOperand(ops[0], type, lhs) ||
parser->resolveOperand(ops[1], type, rhs))
return {};
- return OpAsmParserResult({lhs, rhs}, type);
+ return OpAsmParserResult({lhs, rhs}, type, attrs);
}
void AddFOp::print(OpAsmPrinter *p) const {
- *p << "addf " << *getOperand(0) << ", " << *getOperand(1) << " : "
- << *getType();
+ *p << "addf " << *getOperand(0) << ", " << *getOperand(1);
+ // Print all attributes between the operands and the type, matching the
+ // order that AddFOp::parse expects.
+ p->printOptionalAttrDict(getAttrs());
+ *p << " : " << *getType();
}
// Return an error message on failure.
@@ -91,14 +94,16 @@
OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
SmallVector<OpAsmParser::OperandType, 2> opInfos;
SmallVector<SSAValue *, 4> operands;
+ SmallVector<NamedAttribute, 4> attrs;
auto &builder = parser->getBuilder();
auto *affineIntTy = builder.getAffineIntType();
AffineMapAttr *mapAttr;
unsigned numDims;
- if (parser->parseAttribute(mapAttr) ||
- parseDimAndSymbolList(parser, opInfos, operands, numDims))
+ // The leading affine map is recorded in 'attrs' under "map"; any other
+ // attributes come from an optional dictionary after the operand lists.
+ if (parser->parseAttribute(mapAttr, "map", attrs) ||
+ parseDimAndSymbolList(parser, opInfos, operands, numDims) ||
+ parser->parseOptionalAttributeDict(attrs))
return {};
auto *map = mapAttr->getValue();
@@ -110,15 +115,14 @@
}
SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
- return OpAsmParserResult(
- operands, resultTypes,
- NamedAttribute(builder.getIdentifier("map"), mapAttr));
+ return OpAsmParserResult(operands, resultTypes, attrs);
}
void AffineApplyOp::print(OpAsmPrinter *p) const {
auto *map = getAffineMap();
*p << "affine_apply " << *map;
printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
+ // "map" was already printed inline above, so keep it out of the dictionary.
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
}
const char *AffineApplyOp::verify() const {
@@ -147,7 +151,7 @@
// Print dynamic dimension operands.
printDimAndSymbolList(operand_begin(), operand_end(),
type->getNumDynamicDims(), p);
- // Print memref type.
+ // AllocOp::parse reads no attribute inline, so every attribute must go into
+ // the dictionary. Eliding "map" here (as AffineApplyOp does) would silently
+ // drop a "map" entry on a print/parse round trip.
+ p->printOptionalAttrDict(getAttrs());
*p << " : " << *type;
}
@@ -155,12 +159,13 @@
MemRefType *type;
SmallVector<SSAValue *, 4> operands;
SmallVector<OpAsmParser::OperandType, 4> operandsInfo;
+ // Attributes from the optional dictionary that precedes the ':' type.
+ SmallVector<NamedAttribute, 4> attrs;
// Parse the dimension operands and optional symbol operands, followed by a
// memref type.
unsigned numDimOperands;
if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands) ||
- parser->parseColonType(type))
+ parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type))
return {};
// Check numDynamicDims against number of question marks in memref type.
@@ -182,7 +187,7 @@
return {};
}
- return OpAsmParserResult(operands, type);
+ return OpAsmParserResult(operands, type, attrs);
}
const char *AllocOp::verify() const {
@@ -191,19 +196,20 @@
}
void ConstantOp::print(OpAsmPrinter *p) const {
- *p << "constant " << *getValue() << " : " << *getType();
+ *p << "constant " << *getValue();
+ // "value" was already printed inline above, so keep it out of the dictionary.
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
+ *p << " : " << *getType();
}
OpAsmParserResult ConstantOp::parse(OpAsmParser *parser) {
Attribute *valueAttr;
Type *type;
- if (parser->parseAttribute(valueAttr) || parser->parseColonType(type))
- return {};
+ SmallVector<NamedAttribute, 4> attrs;
- auto &builder = parser->getBuilder();
- return OpAsmParserResult(
- /*operands=*/{}, type,
- NamedAttribute(builder.getIdentifier("value"), valueAttr));
+ // The leading attribute is recorded in 'attrs' under "value"; any other
+ // attributes come from an optional dictionary before the ':' type.
+ if (parser->parseAttribute(valueAttr, "value", attrs) ||
+ parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type))
+ return {};
+ return OpAsmParserResult(/*operands=*/{}, type, attrs);
}
/// The constant op requires an attribute, and furthermore requires that it
@@ -236,8 +242,9 @@
}
void DimOp::print(OpAsmPrinter *p) const {
- *p << "dim " << *getOperand() << ", " << getIndex() << " : "
- << *getOperand()->getType();
+ *p << "dim " << *getOperand() << ", " << getIndex();
+ // "index" was already printed inline above, so keep it out of the dictionary.
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
+ *p << " : " << *getOperand()->getType();
}
OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
@@ -245,15 +252,17 @@
IntegerAttr *indexAttr;
Type *type;
SSAValue *operand;
+ // Holds the inline "index" attribute plus any from the optional dictionary.
+ SmallVector<NamedAttribute, 4> attrs;
+
if (parser->parseOperand(operandInfo) || parser->parseComma() ||
- parser->parseAttribute(indexAttr) || parser->parseColonType(type) ||
+ parser->parseAttribute(indexAttr, "index", attrs) ||
+ parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseColonType(type) ||
parser->resolveOperand(operandInfo, type, operand))
return {};
auto &builder = parser->getBuilder();
- return OpAsmParserResult(
- operand, builder.getAffineIntType(),
- NamedAttribute(builder.getIdentifier("index"), indexAttr));
+ return OpAsmParserResult(operand, builder.getAffineIntType(), attrs);
}
const char *DimOp::verify() const {
@@ -283,7 +292,9 @@
void LoadOp::print(OpAsmPrinter *p) const {
*p << "load " << *getMemRef() << '[';
p->printOperands(getIndices());
- *p << "] : " << *getMemRef()->getType();
+ *p << ']';
+ // Attributes go between the index list and the type, as LoadOp::parse expects.
+ p->printOptionalAttrDict(getAttrs());
+ *p << " : " << *getMemRef()->getType();
}
OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
@@ -291,17 +302,18 @@
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType *type;
SmallVector<SSAValue *, 4> operands;
+ // Attributes from the optional dictionary between the ']' and the ':' type.
+ SmallVector<NamedAttribute, 4> attrs;
auto affineIntTy = parser->getBuilder().getAffineIntType();
if (parser->parseOperand(memrefInfo) ||
- parser->parseOperandList(indexInfo, -1,
- OpAsmParser::Delimeter::SquareDelimeter) ||
+ parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(attrs) ||
parser->parseColonType(type) ||
parser->resolveOperands(memrefInfo, type, operands) ||
parser->resolveOperands(indexInfo, affineIntTy, operands))
return {};
- return OpAsmParserResult(operands, type->getElementType());
+ return OpAsmParserResult(operands, type->getElementType(), attrs);
}
const char *LoadOp::verify() const {
@@ -327,7 +339,9 @@
*p << "store " << *getValueToStore();
*p << ", " << *getMemRef() << '[';
p->printOperands(getIndices());
- *p << "] : " << *getMemRef()->getType();
+ *p << ']';
+ // Attributes go between the index list and the type, as StoreOp::parse expects.
+ p->printOptionalAttrDict(getAttrs());
+ *p << " : " << *getMemRef()->getType();
}
OpAsmParserResult StoreOp::parse(OpAsmParser *parser) {
@@ -336,12 +350,13 @@
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
SmallVector<SSAValue *, 4> operands;
MemRefType *memrefType;
+ // Attributes from the optional dictionary between the ']' and the ':' type.
+ SmallVector<NamedAttribute, 4> attrs;
auto affineIntTy = parser->getBuilder().getAffineIntType();
if (parser->parseOperand(storeValueInfo) || parser->parseComma() ||
parser->parseOperand(memrefInfo) ||
- parser->parseOperandList(indexInfo, -1,
- OpAsmParser::Delimeter::SquareDelimeter) ||
+ parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(attrs) ||
parser->parseColonType(memrefType) ||
parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
operands) ||
@@ -349,7 +364,7 @@
parser->resolveOperands(indexInfo, affineIntTy, operands))
return {};
- return OpAsmParserResult(operands, {});
+ return OpAsmParserResult(operands, {}, attrs);
}
const char *StoreOp::verify() const {
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 24bda89..c146a93 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1667,11 +1667,31 @@
return false;
}
- bool parseAttribute(Attribute *&result, llvm::SMLoc *loc = nullptr) override {
+ /// Parse an arbitrary attribute and return it in result. This also adds the
+ /// attribute to the specified attribute list with the specified name. This
+ /// captures the location of the attribute in 'loc' if it is non-null.
+ /// Returns true on failure, in which case nothing is added to 'attrs'.
+ bool parseAttribute(Attribute *&result, const char *attrName,
+ SmallVectorImpl<NamedAttribute> &attrs,
+ llvm::SMLoc *loc = nullptr) override {
if (loc)
*loc = parser.getToken().getLoc();
result = parser.parseAttribute();
- return result == nullptr;
+ if (!result)
+ return true;
+
+ attrs.push_back(
+ NamedAttribute(parser.builder.getIdentifier(attrName), result));
+ return false;
+ }
+
+ /// If a named attribute list is present, parse it into result. A missing
+ /// list (the next token is not '{') is not an error and leaves 'result'
+ /// untouched. If 'loc' is non-null and a list is present, it receives the
+ /// location of the opening '{'. Returns true on failure.
+ bool parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result,
+ llvm::SMLoc *loc = nullptr) override {
+ if (parser.getToken().isNot(Token::l_brace))
+ return false;
+ if (loc)
+ *loc = parser.getToken().getLoc();
+ return parser.parseAttributeDict(result) == ParseFailure;
}
bool parseOperand(OperandType &result) override {
@@ -1685,26 +1705,26 @@
bool parseOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
- Delimeter delimeter = Delimeter::NoDelimeter) override {
+ Delimiter delimiter = Delimiter::None) override {
auto startLoc = parser.getToken().getLoc();
- // Handle delimeters.
- switch (delimeter) {
- case Delimeter::NoDelimeter:
+ // Handle opening delimiters. An absent optional delimiter returns success
+ // without consuming tokens or adding operands; a missing required
+ // delimiter is reported as an error.
+ switch (delimiter) {
+ case Delimiter::None:
break;
- case Delimeter::OptionalParenDelimeter:
+ case Delimiter::OptionalParen:
if (parser.getToken().isNot(Token::l_paren))
return false;
LLVM_FALLTHROUGH;
- case Delimeter::ParenDelimeter:
+ case Delimiter::Paren:
if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
return true;
break;
- case Delimeter::OptionalSquareDelimeter:
+ case Delimiter::OptionalSquare:
if (parser.getToken().isNot(Token::l_square))
return false;
LLVM_FALLTHROUGH;
- case Delimeter::SquareDelimeter:
+ case Delimiter::Square:
if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
return true;
break;
@@ -1720,18 +1740,18 @@
} while (parser.consumeIf(Token::comma));
}
- // Handle delimeters. If we reach here, the optional delimiters were
+ // Handle closing delimiters. If we reach here, the optional delimiters were
// present, so we need to parse their closing one.
- switch (delimeter) {
- case Delimeter::NoDelimeter:
+ switch (delimiter) {
+ case Delimiter::None:
break;
- case Delimeter::OptionalParenDelimeter:
- case Delimeter::ParenDelimeter:
+ case Delimiter::OptionalParen:
+ case Delimiter::Paren:
if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
return true;
break;
- case Delimeter::OptionalSquareDelimeter:
- case Delimeter::SquareDelimeter:
+ case Delimiter::OptionalSquare:
+ case Delimiter::Square:
if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
return true;
break;