Implement initial support for function attributes, including parser, printer,
and resolver support.
Still to do: verifier support (to ensure an attribute does not refer to a
function in another module), and the TODO in ModuleParser::finalizeModule, which
I will handle in the next patch.
PiperOrigin-RevId: 209361648
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 644a74a..a960223 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OperationSet.h"
@@ -254,6 +255,7 @@
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
+ DenseMap<Function *, FunctionAttr *> functionAttrs;
public:
MLIRContextImpl() : identifiers(allocator) {
@@ -633,6 +635,34 @@
return result;
}
+FunctionAttr *FunctionAttr::get(Function *value, MLIRContext *context) {
+ // Keep one attribute per function in this context: look up (or create) the
+ // map slot for `value` and allocate the attribute only on first request.
+ auto &impl = context->getImpl();
+ auto *&slot = impl.functionAttrs[value];
+ if (!slot)
+ slot = new (impl.allocator.Allocate<FunctionAttr>()) FunctionAttr(value);
+ return slot;
+}
+
+/// This function is used by the internals of the Function class to null out
+/// attributes referring to functions that are about to be deleted.
+void FunctionAttr::dropFunctionReference(Function *value) {
+ // Attributes for this function are tracked in the function's context.
+ auto &functionAttrs = value->getContext()->getImpl().functionAttrs;
+
+ // Check whether an attribute refers to this function; if not, we're done.
+ auto it = functionAttrs.find(value);
+ if (it == functionAttrs.end())
+ return;
+
+ // If so, null out the function reference in the attribute (to avoid dangling
+ // pointers) and remove the entry from the map so the map doesn't contain
+ // dangling keys.
+ it->second->value = nullptr;
+ functionAttrs.erase(it);
+}
+
/// Perform a three-way comparison between the names of the specified
/// NamedAttributes.
static int compareNamedAttributes(const NamedAttribute *lhs,