Support for AffineMapAttr.
PiperOrigin-RevId: 205157390
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 6a2d932..0e5d101 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -50,6 +50,7 @@
void initialize(const Module *module);
void print(const Module *module);
+ void print(const Attribute *attr) const;
void print(const Type *type) const;
void print(const Function *fn);
void print(const ExtFunction *fn);
@@ -77,6 +78,11 @@
void visitCFGFunction(const CFGFunction *fn);
void visitMLFunction(const MLFunction *fn);
void visitType(const Type *type);
+ void visitAttribute(const Attribute *attr);
+ void visitOperation(const Operation *op);
+
+ void printAffineMapId(int affineMapId) const;
+ void printAffineMapReference(const AffineMap* affineMap) const;
raw_ostream &os;
DenseMap<const AffineMap *, int> affineMapIds;
@@ -113,6 +119,22 @@
}
}
+// Records every affine map reachable from `attr`: the map held by an
+// AffineMapAttr, or the maps held by the elements of an ArrayAttr (visited
+// recursively). Attributes of any other kind are ignored.
+void ModuleState::visitAttribute(const Attribute *attr) {
+  // dyn_cast folds the isa<>/cast<> pair into a single checked conversion.
+  if (const auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) {
+    recordAffineMapReference(mapAttr->getValue());
+  } else if (const auto *arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    for (auto elt : arrayAttr->getValue())
+      visitAttribute(elt);
+  }
+}
+
+// Visits the value of each named attribute attached to `op`; the attribute
+// names themselves are not inspected.
+void ModuleState::visitOperation(const Operation *op) {
+  for (const auto &namedAttr : op->getAttrs())
+    visitAttribute(namedAttr.second);
+}
+
void ModuleState::visitExtFunction(const ExtFunction *fn) {
visitType(fn->getType());
}
@@ -120,11 +142,16 @@
void ModuleState::visitCFGFunction(const CFGFunction *fn) {
visitType(fn->getType());
// TODO Visit function body instructions.
+  // Walk the operations of every block so that their attributes are visited.
+  for (auto &block : *fn)
+    for (auto &op : block.getOperations())
+      visitOperation(&op);
}
void ModuleState::visitMLFunction(const MLFunction *fn) {
visitType(fn->getType());
- // TODO Visit function body statements.
+ // TODO Visit function body statements (and attributes if required).
}
void ModuleState::visitFunction(const Function *fn) {
@@ -151,13 +178,24 @@
}
// Prints affine map identifier.
-static void printAffineMapId(unsigned affineMapId, raw_ostream &os) {
+void ModuleState::printAffineMapId(int affineMapId) const {
os << "#map" << affineMapId;
}
+// Prints `affineMap` as a `#map<N>` reference when getAffineMapId() yields a
+// non-negative id for it, and as the map's own printed form otherwise.
+void ModuleState::printAffineMapReference(const AffineMap* affineMap) const {
+  const int id = getAffineMapId(affineMap);
+  if (id < 0) {
+    // No id is recorded for this map, so print it inline.
+    affineMap->print(os);
+    return;
+  }
+  // The map is printed at the top of the module; refer to it by its id.
+  printAffineMapId(id);
+}
+
void ModuleState::print(const Module *module) {
for (const auto &mapAndId : affineMapIds) {
- printAffineMapId(mapAndId.second, os);
+ printAffineMapId(mapAndId.second);
os << " = ";
mapAndId.first->print(os);
os << '\n';
@@ -165,6 +203,37 @@
for (auto *fn : module->functionList) print(fn);
}
+// Prints `attr` to the output stream, dispatching on the attribute kind.
+// Affine map attributes go through printAffineMapReference(), so they are
+// printed by id when one is known and inline otherwise.
+void ModuleState::print(const Attribute *attr) const {
+  switch (attr->getKind()) {
+  case Attribute::Kind::Bool: {
+    const bool flag = cast<BoolAttr>(attr)->getValue();
+    os << (flag ? "true" : "false");
+    return;
+  }
+  case Attribute::Kind::Integer:
+    os << cast<IntegerAttr>(attr)->getValue();
+    return;
+  case Attribute::Kind::Float:
+    // FIXME: this output is not precise; a hex format should be used.
+    os << cast<FloatAttr>(attr)->getValue();
+    return;
+  case Attribute::Kind::String:
+    // FIXME: the string contents should be escaped.
+    os << '"' << cast<StringAttr>(attr)->getValue() << '"';
+    return;
+  case Attribute::Kind::Array: {
+    // Elements are printed recursively, separated by ", ", inside brackets.
+    auto elements = cast<ArrayAttr>(attr)->getValue();
+    os << '[';
+    interleave(elements,
+               [&](Attribute *elt) { print(elt); },
+               [&]() { os << ", "; });
+    os << ']';
+    return;
+  }
+  case Attribute::Kind::AffineMap:
+    printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
+    return;
+  }
+}
+
void ModuleState::print(const Type *type) const {
switch (type->getKind()) {
case Type::Kind::AffineInt:
@@ -243,14 +312,7 @@
os << *v->getElementType();
for (auto map : v->getAffineMaps()) {
os << ", ";
- const int mapId = getAffineMapId(map);
- if (mapId >= 0) {
- // Map will be printed at top of module so print reference to its id.
- printAffineMapId(mapId, os);
- } else {
- // Map not in module state so print inline.
- map->print(os);
- }
+ printAffineMapReference(map);
}
os << ", " << v->getMemorySpace();
os << '>';
@@ -338,7 +400,9 @@
os << '{';
interleave(
attrs,
- [&](NamedAttribute attr) { os << attr.first << ": " << *attr.second; },
+ [&](NamedAttribute attr) {
+ os << attr.first << ": ";
+ moduleState->print(attr.second); },
[&]() { os << ", "; });
os << '}';
}
@@ -553,6 +617,15 @@
// print and dump methods
//===----------------------------------------------------------------------===//
+// Prints this attribute to `os` through a ModuleState created just for this
+// call.
+void Attribute::print(raw_ostream &os) const {
+  ModuleState(os).print(this);
+}
+
+// Debugging helper: prints this attribute to llvm::errs().
+void Attribute::dump() const { print(llvm::errs()); }
+
void Type::print(raw_ostream &os) const {
ModuleState moduleState(os);
moduleState.print(this);