Support for AffineMapAttr.
PiperOrigin-RevId: 205157390
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 6a2d932..0e5d101 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -50,6 +50,7 @@
void initialize(const Module *module);
void print(const Module *module);
+ void print(const Attribute *attr) const;
void print(const Type *type) const;
void print(const Function *fn);
void print(const ExtFunction *fn);
@@ -77,6 +78,11 @@
void visitCFGFunction(const CFGFunction *fn);
void visitMLFunction(const MLFunction *fn);
void visitType(const Type *type);
+ void visitAttribute(const Attribute *attr);
+ void visitOperation(const Operation *op);
+
+ void printAffineMapId(int affineMapId) const;
+ void printAffineMapReference(const AffineMap* affineMap) const;
raw_ostream &os;
DenseMap<const AffineMap *, int> affineMapIds;
@@ -113,6 +119,22 @@
}
}
+/// Records every affine map referenced by `attr`, either directly (an
+/// AffineMapAttr) or nested inside an ArrayAttr, via recordAffineMapReference.
+/// Other attribute kinds reference no affine maps and are ignored.
+void ModuleState::visitAttribute(const Attribute *attr) {
+  if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) {
+    recordAffineMapReference(mapAttr->getValue());
+  } else if (auto *arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    // Arrays may contain further arrays; recurse into every element.
+    for (auto *elt : arrayAttr->getValue())
+      visitAttribute(elt);
+  }
+}
+
+/// Visits each attribute attached to `op` so that any affine maps they
+/// reference are recorded in this module state.
+void ModuleState::visitOperation(const Operation *op) {
+  for (const auto &namedAttr : op->getAttrs())
+    visitAttribute(namedAttr.second);
+}
+
void ModuleState::visitExtFunction(const ExtFunction *fn) {
visitType(fn->getType());
}
@@ -120,11 +142,16 @@
void ModuleState::visitCFGFunction(const CFGFunction *fn) {
visitType(fn->getType());
- // TODO Visit function body instructions.
+  // Record the affine maps referenced by the attributes of every operation in
+  // the function body.
+  // NOTE(review): only block.getOperations() is walked; confirm terminators
+  // cannot carry attributes that reference affine maps.
+  for (auto &block : *fn) {
+    for (auto &op : block.getOperations())
+      visitOperation(&op);
+  }
}
void ModuleState::visitMLFunction(const MLFunction *fn) {
visitType(fn->getType());
- // TODO Visit function body statements.
+ // TODO Visit function body statements (and attributes if required).
}
void ModuleState::visitFunction(const Function *fn) {
@@ -151,13 +178,24 @@
}
// Prints affine map identifier.
-static void printAffineMapId(unsigned affineMapId, raw_ostream &os) {
+void ModuleState::printAffineMapId(int affineMapId) const {
os << "#map" << affineMapId;
}
+/// Prints a reference to `affineMap`. If the map was recorded in this module
+/// state (it is printed at the top of the module), its `#mapN` id is printed;
+/// otherwise the full map is printed inline.
+void ModuleState::printAffineMapReference(const AffineMap *affineMap) const {
+  int id = getAffineMapId(affineMap);
+  if (id < 0) {
+    // No id was assigned, so there is nothing to refer to: print inline.
+    affineMap->print(os);
+    return;
+  }
+  // The map is emitted at the top of the module; refer to it by its id.
+  printAffineMapId(id);
+}
+
void ModuleState::print(const Module *module) {
for (const auto &mapAndId : affineMapIds) {
- printAffineMapId(mapAndId.second, os);
+ printAffineMapId(mapAndId.second);
os << " = ";
mapAndId.first->print(os);
os << '\n';
@@ -165,6 +203,37 @@
for (auto *fn : module->functionList) print(fn);
}
+/// Prints the textual form of `attr`. Array elements are printed recursively,
+/// separated by ", ". Affine map attributes go through
+/// printAffineMapReference, so they appear as `#mapN` when the map was
+/// recorded in this module state and inline otherwise.
+void ModuleState::print(const Attribute *attr) const {
+  switch (attr->getKind()) {
+  case Attribute::Kind::Bool:
+    os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
+    break;
+  case Attribute::Kind::Integer:
+    os << cast<IntegerAttr>(attr)->getValue();
+    break;
+  case Attribute::Kind::Float:
+    // FIXME: this isn't precise, we should print with a hex format.
+    os << cast<FloatAttr>(attr)->getValue();
+    break;
+  case Attribute::Kind::String:
+    // FIXME: should escape the string.
+    os << '"' << cast<StringAttr>(attr)->getValue() << '"';
+    break;
+  case Attribute::Kind::Array: {
+    auto elts = cast<ArrayAttr>(attr)->getValue();
+    os << '[';
+    // The element is named `elt` so it does not shadow the `attr` parameter.
+    interleave(elts,
+               [&](Attribute *elt) { print(elt); },
+               [&]() { os << ", "; });
+    os << ']';
+    break;
+  }
+  case Attribute::Kind::AffineMap:
+    printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
+    break;
+  }
+}
+
void ModuleState::print(const Type *type) const {
switch (type->getKind()) {
case Type::Kind::AffineInt:
@@ -243,14 +312,7 @@
os << *v->getElementType();
for (auto map : v->getAffineMaps()) {
os << ", ";
- const int mapId = getAffineMapId(map);
- if (mapId >= 0) {
- // Map will be printed at top of module so print reference to its id.
- printAffineMapId(mapId, os);
- } else {
- // Map not in module state so print inline.
- map->print(os);
- }
+ printAffineMapReference(map);
}
os << ", " << v->getMemorySpace();
os << '>';
@@ -338,7 +400,9 @@
os << '{';
interleave(
attrs,
- [&](NamedAttribute attr) { os << attr.first << ": " << *attr.second; },
+ [&](NamedAttribute attr) {
+ os << attr.first << ": ";
+ moduleState->print(attr.second); },
[&]() { os << ", "; });
os << '}';
}
@@ -553,6 +617,15 @@
// print and dump methods
//===----------------------------------------------------------------------===//
+/// Prints this attribute to `os` using a standalone ModuleState. No module is
+/// visited first, so (assuming the ModuleState constructor records nothing)
+/// affine map attributes are printed inline rather than as `#mapN` references.
+void Attribute::print(raw_ostream &os) const {
+  ModuleState(os).print(this);
+}
+
+/// Prints this attribute to llvm::errs().
+void Attribute::dump() const { print(llvm::errs()); }
+
void Type::print(raw_ostream &os) const {
ModuleState moduleState(os);
moduleState.print(this);
diff --git a/lib/IR/Attributes.cpp b/lib/IR/Attributes.cpp
deleted file mode 100644
index df23424..0000000
--- a/lib/IR/Attributes.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-//===- Attributes.cpp - MLIR Attribute Implementation ---------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/Attributes.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/Support/STLExtras.h"
-using namespace mlir;
-
-void Attribute::print(raw_ostream &os) const {
- switch (getKind()) {
- case Kind::Bool:
- os << (cast<BoolAttr>(this)->getValue() ? "true" : "false");
- break;
- case Kind::Integer:
- os << cast<IntegerAttr>(this)->getValue();
- break;
- case Kind::Float:
- // FIXME: this isn't precise, we should print with a hex format.
- os << cast<FloatAttr>(this)->getValue();
- break;
- case Kind::String:
- // FIXME: should escape the string.
- os << '"' << cast<StringAttr>(this)->getValue() << '"';
- break;
- case Kind::Array: {
- auto elts = cast<ArrayAttr>(this)->getValue();
- os << '[';
- interleave(elts,
- [&](Attribute *attr) { attr->print(os); },
- [&]() { os << ", "; });
- os << ']';
- break;
- }
- }
-}
-
-void Attribute::dump() const {
- print(llvm::errs());
-}
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index e9bea2a..8d27991 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -94,6 +94,10 @@
return ArrayAttr::get(value, context);
}
+/// Returns the AffineMapAttr wrapping `value` in this builder's context.
+AffineMapAttr *Builder::getAffineMapAttr(AffineMap *value) {
+  auto *attr = AffineMapAttr::get(value, context);
+  return attr;
+}
+
//===----------------------------------------------------------------------===//
// Affine Expressions and Affine Map.
//===----------------------------------------------------------------------===//
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index eb06e94..df3d01a 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -230,6 +230,7 @@
StringMap<StringAttr*> stringAttrs;
using ArrayAttrSet = DenseSet<ArrayAttr*, ArrayAttrKeyInfo>;
ArrayAttrSet arrayAttrs;
+ DenseMap<AffineMap*, AffineMapAttr*> affineMapAttrs;
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
@@ -541,6 +542,16 @@
return *existing.first = result;
}
+/// Returns the unique AffineMapAttr for `value` in `context`. The attribute is
+/// allocated from the context's allocator the first time it is requested and
+/// cached in the context's affineMapAttrs table for later lookups.
+AffineMapAttr *AffineMapAttr::get(AffineMap *value, MLIRContext *context) {
+  // Reference to the cache slot; operator[] inserts a null entry on a miss.
+  auto *&cached = context->getImpl().affineMapAttrs[value];
+  if (!cached) {
+    cached = context->getImpl().allocator.Allocate<AffineMapAttr>();
+    new (cached) AffineMapAttr(value);
+  }
+  return cached;
+}
+
/// Perform a three-way comparison between the names of the specified
/// NamedAttributes.
static int compareNamedAttributes(const NamedAttribute *lhs,