Implement initial support for function attributes, including parser, printer,
and resolver support.
Still to do: verifier support (to make sure an attribute is not used for a
function in another module) and the TODO in ModuleParser::finalizeModule, which
I will handle in the next patch.
PiperOrigin-RevId: 209361648
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 896b76e..8303f57 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -415,6 +415,16 @@
case Attribute::Kind::Type:
printType(cast<TypeAttr>(attr)->getValue());
break;
+ case Attribute::Kind::Function: {
+ auto *function = cast<FunctionAttr>(attr)->getValue();
+ if (!function) {
+ os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
+ } else {
+ os << '@' << function->getName() << " : ";
+ printType(function->getType());
+ }
+ break;
+ }
}
}
@@ -784,6 +794,11 @@
}
} else if (auto intOp = op->getAs<ConstantAffineIntOp>()) {
specialName << 'c' << intOp->getValue();
+ } else if (auto constant = op->getAs<ConstantOp>()) {
+ if (isa<FunctionAttr>(constant->getValue()))
+ specialName << 'f';
+ else
+ specialName << "cst";
}
}
@@ -909,7 +924,7 @@
// Filter out any attributes that shouldn't be included.
SmallVector<NamedAttribute, 8> filteredAttrs;
for (auto attr : attrs) {
- auto attrName = attr.first.ref();
+ auto attrName = attr.first.strref();
// Never print attributes that start with a colon. These are internal
// attributes that represent location or other internal metadata.
if (attrName.startswith(":"))
@@ -946,7 +961,7 @@
// Check to see if this is a known operation. If so, use the registered
// custom printer hook.
- if (auto *opInfo = state.operationSet->lookup(op->getName().ref())) {
+ if (auto *opInfo = state.operationSet->lookup(op->getName())) {
opInfo->printAssembly(op, this);
return;
}
@@ -957,7 +972,7 @@
void FunctionPrinter::printDefaultOp(const Operation *op) {
os << '"';
- printEscapedString(op->getName().ref(), os);
+ printEscapedString(op->getName(), os);
os << "\"(";
interleaveComma(op->getOperands(),
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index be334ef..5fce94c 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -111,6 +111,10 @@
return TypeAttr::get(type, context);
}
+FunctionAttr *Builder::getFunctionAttr(Function *value) {
+ return FunctionAttr::get(value, context);
+}
+
//===----------------------------------------------------------------------===//
// Affine Expressions, Affine Maps, and Integet Sets.
//===----------------------------------------------------------------------===//
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index df3d7c8..bfcfd6f 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -15,6 +15,7 @@
// limitations under the License.
// =============================================================================
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
@@ -27,6 +28,11 @@
Function::Function(StringRef name, FunctionType *type, Kind kind)
: kind(kind), name(Identifier::get(name, type->getContext())), type(type) {}
+Function::~Function() {
+ // Clean up function attributes referring to this function.
+ FunctionAttr::dropFunctionReference(this);
+}
+
MLIRContext *Function::getContext() const { return getType()->getContext(); }
/// Delete this object.
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 644a74a..a960223 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OperationSet.h"
@@ -254,6 +255,7 @@
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
+ DenseMap<Function *, FunctionAttr *> functionAttrs;
public:
MLIRContextImpl() : identifiers(allocator) {
@@ -633,6 +635,34 @@
return result;
}
+FunctionAttr *FunctionAttr::get(Function *value, MLIRContext *context) {
+ auto *&result = context->getImpl().functionAttrs[value];
+ if (result)
+ return result;
+
+ result = context->getImpl().allocator.Allocate<FunctionAttr>();
+ new (result) FunctionAttr(value);
+ return result;
+}
+
+/// This function is used by the internals of the Function class to null out
+/// attributes referring to functions that are about to be deleted.
+void FunctionAttr::dropFunctionReference(Function *value) {
+ // Check to see if there was an attribute referring to this function.
+ auto &functionAttrs = value->getContext()->getImpl().functionAttrs;
+
+ // If not, then we're done.
+ auto it = functionAttrs.find(value);
+ if (it == functionAttrs.end())
+ return;
+
+ // If so, null out the function reference in the attribute (to avoid dangling
+ // pointers) and remove the entry from the map so the map doesn't contain
+ // dangling keys.
+ it->second->value = nullptr;
+ functionAttrs.erase(it);
+}
+
/// Perform a three-way comparison between the names of the specified
/// NamedAttributes.
static int compareNamedAttributes(const NamedAttribute *lhs,
diff --git a/lib/IR/Module.cpp b/lib/IR/Module.cpp
index 5a33d26..36b892a 100644
--- a/lib/IR/Module.cpp
+++ b/lib/IR/Module.cpp
@@ -21,13 +21,13 @@
Module::Module(MLIRContext *context) : context(context) {}
/// Look up a function with the specified name, returning null if no such
-/// name exists.
+/// name exists. Function names never include the '@' prefix.
Function *Module::getNamedFunction(StringRef name) {
return getNamedFunction(Identifier::get(name, context));
}
/// Look up a function with the specified name, returning null if no such
-/// name exists.
+/// name exists. Function names never include the '@' prefix.
Function *Module::getNamedFunction(Identifier name) {
auto it = symbolTable.find(name);
return it != symbolTable.end() ? it->second : nullptr;
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 723dcc2..6d8c366 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -204,16 +204,25 @@
void ConstantOp::print(OpAsmPrinter *p) const {
*p << "constant " << *getValue();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
- *p << " : " << *getType();
+
+ if (!isa<FunctionAttr>(getValue()))
+ *p << " : " << *getType();
}
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Attribute *valueAttr;
Type *type;
- return parser->parseAttribute(valueAttr, "value", result->attributes) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(type) ||
+ if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes))
+ return true;
+
+ // 'constant' taking a function reference doesn't get a redundant type
+ // specifier. The attribute itself carries it.
+ if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr))
+ return parser->addTypeToList(fnAttr->getValue()->getType(), result->types);
+
+ return parser->parseColonType(type) ||
parser->addTypeToList(type, result->types);
}
@@ -244,7 +253,9 @@
}
if (isa<FunctionType>(type)) {
- // TODO: Verify a function attr.
+ if (!isa<FunctionAttr>(value))
+ return "requires 'value' to be a function reference";
+ return nullptr;
}
return "requires a result type that aligns with the 'value' attribute";