Support for AffineMapAttr.
PiperOrigin-RevId: 205157390
diff --git a/include/mlir/IR/Attributes.h b/include/mlir/IR/Attributes.h
index 2ff129b..47d0b9f 100644
--- a/include/mlir/IR/Attributes.h
+++ b/include/mlir/IR/Attributes.h
@@ -22,7 +22,8 @@
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
- class MLIRContext;
+class MLIRContext;
+class AffineMap;
/// Instances of the Attribute class are immutable, uniqued, immortal, and owned
/// by MLIRContext. As such, they are passed around by raw non-const pointer.
@@ -34,6 +35,7 @@
Float,
String,
Array,
+ AffineMap,
// TODO: Function references.
};
@@ -147,7 +149,23 @@
ArrayRef<Attribute*> value;
};
+class AffineMapAttr : public Attribute {
+public:
+ static AffineMapAttr *get(AffineMap *value, MLIRContext *context);
+
+ AffineMap *getValue() const {
+ return value;
+ }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(const Attribute *attr) {
+ return attr->getKind() == Kind::AffineMap;
+ }
+private:
+ AffineMapAttr(AffineMap *value) : Attribute(Kind::AffineMap), value(value) {}
+ AffineMap *value;
+};
+
} // end namespace mlir.
#endif
-
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 9297fd3..c41a886 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -37,6 +37,7 @@
class FloatAttr;
class StringAttr;
class ArrayAttr;
+class AffineMapAttr;
class AffineMap;
class AffineExpr;
class AffineConstantExpr;
@@ -74,6 +75,7 @@
FloatAttr *getFloatAttr(double value);
StringAttr *getStringAttr(StringRef bytes);
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
+ AffineMapAttr *getAffineMapAttr(AffineMap *value);
// Affine Expressions and Affine Map.
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 6a2d932..0e5d101 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -50,6 +50,7 @@
void initialize(const Module *module);
void print(const Module *module);
+ void print(const Attribute *attr) const;
void print(const Type *type) const;
void print(const Function *fn);
void print(const ExtFunction *fn);
@@ -77,6 +78,11 @@
void visitCFGFunction(const CFGFunction *fn);
void visitMLFunction(const MLFunction *fn);
void visitType(const Type *type);
+ void visitAttribute(const Attribute *attr);
+ void visitOperation(const Operation *op);
+
+ void printAffineMapId(int affineMapId) const;
+ void printAffineMapReference(const AffineMap* affineMap) const;
raw_ostream &os;
DenseMap<const AffineMap *, int> affineMapIds;
@@ -113,6 +119,22 @@
}
}
+void ModuleState::visitAttribute(const Attribute *attr) {
+ if (isa<AffineMapAttr>(attr)) {
+ recordAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
+ } else if (isa<ArrayAttr>(attr)) {
+ for (auto elt : cast<ArrayAttr>(attr)->getValue()) {
+ visitAttribute(elt);
+ }
+ }
+}
+
+void ModuleState::visitOperation(const Operation *op) {
+ for (auto elt : op->getAttrs()) {
+ visitAttribute(elt.second);
+ }
+}
+
void ModuleState::visitExtFunction(const ExtFunction *fn) {
visitType(fn->getType());
}
@@ -120,11 +142,16 @@
void ModuleState::visitCFGFunction(const CFGFunction *fn) {
visitType(fn->getType());
// TODO Visit function body instructions.
+ for (auto &block : *fn) {
+ for (auto &op : block.getOperations()) {
+ visitOperation(&op);
+ }
+ }
}
void ModuleState::visitMLFunction(const MLFunction *fn) {
visitType(fn->getType());
- // TODO Visit function body statements.
+ // TODO Visit function body statements (and attributes if required).
}
void ModuleState::visitFunction(const Function *fn) {
@@ -151,13 +178,24 @@
}
// Prints affine map identifier.
-static void printAffineMapId(unsigned affineMapId, raw_ostream &os) {
+void ModuleState::printAffineMapId(int affineMapId) const {
os << "#map" << affineMapId;
}
+void ModuleState::printAffineMapReference(const AffineMap* affineMap) const {
+ const int mapId = getAffineMapId(affineMap);
+ if (mapId >= 0) {
+ // Map will be printed at top of module so print reference to its id.
+ printAffineMapId(mapId);
+ } else {
+ // Map not in module state so print inline.
+ affineMap->print(os);
+ }
+}
+
void ModuleState::print(const Module *module) {
for (const auto &mapAndId : affineMapIds) {
- printAffineMapId(mapAndId.second, os);
+ printAffineMapId(mapAndId.second);
os << " = ";
mapAndId.first->print(os);
os << '\n';
@@ -165,6 +203,37 @@
for (auto *fn : module->functionList) print(fn);
}
+void ModuleState::print(const Attribute *attr) const {
+ switch (attr->getKind()) {
+ case Attribute::Kind::Bool:
+ os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
+ break;
+ case Attribute::Kind::Integer:
+ os << cast<IntegerAttr>(attr)->getValue();
+ break;
+ case Attribute::Kind::Float:
+ // FIXME: this isn't precise, we should print with a hex format.
+ os << cast<FloatAttr>(attr)->getValue();
+ break;
+ case Attribute::Kind::String:
+ // FIXME: should escape the string.
+ os << '"' << cast<StringAttr>(attr)->getValue() << '"';
+ break;
+ case Attribute::Kind::Array: {
+ auto elts = cast<ArrayAttr>(attr)->getValue();
+ os << '[';
+ interleave(elts,
+ [&](Attribute *attr) { print(attr); },
+ [&]() { os << ", "; });
+ os << ']';
+ break;
+ }
+ case Attribute::Kind::AffineMap:
+ printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
+ break;
+ }
+}
+
void ModuleState::print(const Type *type) const {
switch (type->getKind()) {
case Type::Kind::AffineInt:
@@ -243,14 +312,7 @@
os << *v->getElementType();
for (auto map : v->getAffineMaps()) {
os << ", ";
- const int mapId = getAffineMapId(map);
- if (mapId >= 0) {
- // Map will be printed at top of module so print reference to its id.
- printAffineMapId(mapId, os);
- } else {
- // Map not in module state so print inline.
- map->print(os);
- }
+ printAffineMapReference(map);
}
os << ", " << v->getMemorySpace();
os << '>';
@@ -338,7 +400,9 @@
os << '{';
interleave(
attrs,
- [&](NamedAttribute attr) { os << attr.first << ": " << *attr.second; },
+ [&](NamedAttribute attr) {
+ os << attr.first << ": ";
+ moduleState->print(attr.second); },
[&]() { os << ", "; });
os << '}';
}
@@ -553,6 +617,15 @@
// print and dump methods
//===----------------------------------------------------------------------===//
+void Attribute::print(raw_ostream &os) const {
+ ModuleState moduleState(os);
+ moduleState.print(this);
+}
+
+void Attribute::dump() const {
+ print(llvm::errs());
+}
+
void Type::print(raw_ostream &os) const {
ModuleState moduleState(os);
moduleState.print(this);
diff --git a/lib/IR/Attributes.cpp b/lib/IR/Attributes.cpp
deleted file mode 100644
index df23424..0000000
--- a/lib/IR/Attributes.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-//===- Attributes.cpp - MLIR Attribute Implementation ---------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/Attributes.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/Support/STLExtras.h"
-using namespace mlir;
-
-void Attribute::print(raw_ostream &os) const {
- switch (getKind()) {
- case Kind::Bool:
- os << (cast<BoolAttr>(this)->getValue() ? "true" : "false");
- break;
- case Kind::Integer:
- os << cast<IntegerAttr>(this)->getValue();
- break;
- case Kind::Float:
- // FIXME: this isn't precise, we should print with a hex format.
- os << cast<FloatAttr>(this)->getValue();
- break;
- case Kind::String:
- // FIXME: should escape the string.
- os << '"' << cast<StringAttr>(this)->getValue() << '"';
- break;
- case Kind::Array: {
- auto elts = cast<ArrayAttr>(this)->getValue();
- os << '[';
- interleave(elts,
- [&](Attribute *attr) { attr->print(os); },
- [&]() { os << ", "; });
- os << ']';
- break;
- }
- }
-}
-
-void Attribute::dump() const {
- print(llvm::errs());
-}
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index e9bea2a..8d27991 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -94,6 +94,10 @@
return ArrayAttr::get(value, context);
}
+AffineMapAttr *Builder::getAffineMapAttr(AffineMap *value) {
+ return AffineMapAttr::get(value, context);
+}
+
//===----------------------------------------------------------------------===//
// Affine Expressions and Affine Map.
//===----------------------------------------------------------------------===//
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index eb06e94..df3d01a 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -230,6 +230,7 @@
StringMap<StringAttr*> stringAttrs;
using ArrayAttrSet = DenseSet<ArrayAttr*, ArrayAttrKeyInfo>;
ArrayAttrSet arrayAttrs;
+ DenseMap<AffineMap*, AffineMapAttr*> affineMapAttrs;
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
@@ -541,6 +542,16 @@
return *existing.first = result;
}
+AffineMapAttr *AffineMapAttr::get(AffineMap* value, MLIRContext *context) {
+ auto *&result = context->getImpl().affineMapAttrs[value];
+ if (result)
+ return result;
+
+ result = context->getImpl().allocator.Allocate<AffineMapAttr>();
+ new (result) AffineMapAttr(value);
+ return result;
+}
+
/// Perform a three-way comparison between the names of the specified
/// NamedAttributes.
static int compareNamedAttributes(const NamedAttribute *lhs,
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 3d5c908..d220787 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -598,6 +598,11 @@
return builder.getArrayAttr(elements);
}
default:
+ // Try to parse affine map reference.
+ auto* affineMap = parseAffineMapReference();
+ if (affineMap != nullptr)
+ return builder.getAffineMapAttr(affineMap);
+
// TODO: Handle floating point.
return (emitError("expected constant attribute value"), nullptr);
}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index 6e0be24..7a1cfce 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -136,6 +136,15 @@
// CHECK: "foo"(){a: 1, b: -423, c: [true, false]} : () -> ()
"foo"(){a: 1, b: -423, c: [true, false] } : () -> ()
+ // CHECK: "foo"(){map1: #map{{[0-9]+}}}
+ "foo"(){map1: #map1} : () -> ()
+
+ // CHECK: "foo"(){map2: #map{{[0-9]+}}}
+ "foo"(){map2: (d0, d1, d2) -> (d0, d1, d2)} : () -> ()
+
+ // CHECK: "foo"(){map12: [#map{{[0-9]+}}, #map{{[0-9]+}}]}
+ "foo"(){map12: [#map1, #map2]} : () -> ()
+
// CHECK: "foo"(){cfgfunc: [], i123: 7, if: "foo"} : () -> ()
"foo"(){if: "foo", cfgfunc: [], i123: 7} : () -> ()