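Add support for AffineMapAttr.

Introduce AffineMapAttr (with a Builder helper and uniquing in MLIRContext),
move attribute printing into the AsmPrinter, and record affine maps referenced
by CFG function operation attributes so they print as #mapN references.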

PiperOrigin-RevId: 205157390
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 6a2d932..0e5d101 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -50,6 +50,7 @@
   void initialize(const Module *module);
 
   void print(const Module *module);
+  void print(const Attribute *attr) const;
   void print(const Type *type) const;
   void print(const Function *fn);
   void print(const ExtFunction *fn);
@@ -77,6 +78,13 @@
   void visitCFGFunction(const CFGFunction *fn);
   void visitMLFunction(const MLFunction *fn);
   void visitType(const Type *type);
+  void visitAttribute(const Attribute *attr);
+  void visitOperation(const Operation *op);
+
+  // Prints the "#mapN" identifier for the given affine map id.
+  void printAffineMapId(int affineMapId) const;
+  // Prints "#mapN" if the map has been assigned an id, else the map inline.
+  void printAffineMapReference(const AffineMap *affineMap) const;
 
   raw_ostream &os;
   DenseMap<const AffineMap *, int> affineMapIds;
@@ -113,6 +119,25 @@
   }
 }
 
+// Records every affine map referenced by `attr`, descending into array
+// attributes, so the maps can be printed by id at the top of the module.
+void ModuleState::visitAttribute(const Attribute *attr) {
+  if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) {
+    recordAffineMapReference(mapAttr->getValue());
+  } else if (auto *arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    for (auto elt : arrayAttr->getValue()) {
+      visitAttribute(elt);
+    }
+  }
+}
+
+// Visits every attribute attached to `op`.
+void ModuleState::visitOperation(const Operation *op) {
+  for (auto namedAttr : op->getAttrs()) {
+    visitAttribute(namedAttr.second);
+  }
+}
+
 void ModuleState::visitExtFunction(const ExtFunction *fn) {
   visitType(fn->getType());
 }
@@ -120,11 +142,16 @@
 void ModuleState::visitCFGFunction(const CFGFunction *fn) {
   visitType(fn->getType());
-  // TODO Visit function body instructions.
+  // Record the affine maps referenced by the attributes of each operation.
+  for (auto &block : *fn) {
+    for (auto &op : block.getOperations()) {
+      visitOperation(&op);
+    }
+  }
 }
 
 void ModuleState::visitMLFunction(const MLFunction *fn) {
   visitType(fn->getType());
-  // TODO Visit function body statements.
+  // TODO Visit function body statements (and attributes if required).
 }
 
 void ModuleState::visitFunction(const Function *fn) {
@@ -151,13 +178,26 @@
 }
 
 // Prints affine map identifier.
-static void printAffineMapId(unsigned affineMapId, raw_ostream &os) {
+void ModuleState::printAffineMapId(int affineMapId) const {
   os << "#map" << affineMapId;
 }
 
+// Prints a reference to `affineMap`: its "#mapN" identifier if the map was
+// recorded in the module state, otherwise the map itself inline.
+void ModuleState::printAffineMapReference(const AffineMap *affineMap) const {
+  const int mapId = getAffineMapId(affineMap);
+  if (mapId >= 0) {
+    // Map will be printed at top of module so print reference to its id.
+    printAffineMapId(mapId);
+  } else {
+    // Map not in module state so print inline.
+    affineMap->print(os);
+  }
+}
+
 void ModuleState::print(const Module *module) {
   for (const auto &mapAndId : affineMapIds) {
-    printAffineMapId(mapAndId.second, os);
+    printAffineMapId(mapAndId.second);
     os << " = ";
     mapAndId.first->print(os);
     os << '\n';
@@ -165,6 +203,39 @@
   for (auto *fn : module->functionList) print(fn);
 }
 
+// Prints `attr` in textual form. Affine maps go through printAffineMapReference
+// so they appear as "#mapN" when they have been assigned an id.
+void ModuleState::print(const Attribute *attr) const {
+  switch (attr->getKind()) {
+  case Attribute::Kind::Bool:
+    os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
+    break;
+  case Attribute::Kind::Integer:
+    os << cast<IntegerAttr>(attr)->getValue();
+    break;
+  case Attribute::Kind::Float:
+    // FIXME: this isn't precise, we should print with a hex format.
+    os << cast<FloatAttr>(attr)->getValue();
+    break;
+  case Attribute::Kind::String:
+    // FIXME: should escape the string.
+    os << '"' << cast<StringAttr>(attr)->getValue() << '"';
+    break;
+  case Attribute::Kind::Array: {
+    auto elts = cast<ArrayAttr>(attr)->getValue();
+    os << '[';
+    interleave(elts,
+               [&](Attribute *elt) { print(elt); },
+               [&]() { os << ", "; });
+    os << ']';
+    break;
+  }
+  case Attribute::Kind::AffineMap:
+    printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
+    break;
+  }
+}
+
 void ModuleState::print(const Type *type) const {
   switch (type->getKind()) {
   case Type::Kind::AffineInt:
@@ -243,14 +312,7 @@
     os << *v->getElementType();
     for (auto map : v->getAffineMaps()) {
       os << ", ";
-      const int mapId = getAffineMapId(map);
-      if (mapId >= 0) {
-        // Map will be printed at top of module so print reference to its id.
-        printAffineMapId(mapId, os);
-      } else {
-        // Map not in module state so print inline.
-        map->print(os);
-      }
+      printAffineMapReference(map);
     }
     os << ", " << v->getMemorySpace();
     os << '>';
@@ -338,7 +400,10 @@
     os << '{';
     interleave(
         attrs,
-        [&](NamedAttribute attr) { os << attr.first << ": " << *attr.second; },
+        [&](NamedAttribute attr) {
+          os << attr.first << ": ";
+          moduleState->print(attr.second);
+        },
         [&]() { os << ", "; });
     os << '}';
   }
@@ -553,6 +617,15 @@
 // print and dump methods
 //===----------------------------------------------------------------------===//
 
+void Attribute::print(raw_ostream &os) const {
+  // This ModuleState never visits a module, so no affine map ids are
+  // recorded; affine maps inside the attribute are expected to print inline.
+  ModuleState state(os);
+  state.print(this);
+}
+
+void Attribute::dump() const { print(llvm::errs()); }
+
 void Type::print(raw_ostream &os) const {
   ModuleState moduleState(os);
   moduleState.print(this);
diff --git a/lib/IR/Attributes.cpp b/lib/IR/Attributes.cpp
deleted file mode 100644
index df23424..0000000
--- a/lib/IR/Attributes.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-//===- Attributes.cpp - MLIR Attribute Implementation ---------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/Attributes.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/Support/STLExtras.h"
-using namespace mlir;
-
-void Attribute::print(raw_ostream &os) const {
-  switch (getKind()) {
-  case Kind::Bool:
-    os << (cast<BoolAttr>(this)->getValue() ? "true" : "false");
-    break;
-  case Kind::Integer:
-    os << cast<IntegerAttr>(this)->getValue();
-    break;
-  case Kind::Float:
-    // FIXME: this isn't precise, we should print with a hex format.
-    os << cast<FloatAttr>(this)->getValue();
-    break;
-  case Kind::String:
-    // FIXME: should escape the string.
-    os << '"' << cast<StringAttr>(this)->getValue() << '"';
-    break;
-  case Kind::Array: {
-    auto elts = cast<ArrayAttr>(this)->getValue();
-    os << '[';
-    interleave(elts,
-                [&](Attribute *attr) { attr->print(os); },
-                [&]() { os << ", "; });
-    os << ']';
-    break;
-  }
-  }
-}
-
-void Attribute::dump() const {
-  print(llvm::errs());
-}
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index e9bea2a..8d27991 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -94,6 +94,12 @@
   return ArrayAttr::get(value, context);
 }
 
+// Returns the AffineMapAttr wrapping `value`. AffineMapAttr::get caches it in
+// the builder's context, so repeated calls with the same map return one object.
+AffineMapAttr *Builder::getAffineMapAttr(AffineMap *value) {
+  return AffineMapAttr::get(value, context);
+}
+
 //===----------------------------------------------------------------------===//
 // Affine Expressions and Affine Map.
 //===----------------------------------------------------------------------===//
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index eb06e94..df3d01a 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -230,6 +230,8 @@
   StringMap<StringAttr*> stringAttrs;
   using ArrayAttrSet = DenseSet<ArrayAttr*, ArrayAttrKeyInfo>;
   ArrayAttrSet arrayAttrs;
+  // AffineMapAttr instances, uniqued by the AffineMap pointer they wrap.
+  DenseMap<AffineMap*, AffineMapAttr*> affineMapAttrs;
   using AttributeListSet =
       DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
   AttributeListSet attributeLists;
@@ -541,6 +542,18 @@
   return *existing.first = result;
 }
 
+/// Returns the AffineMapAttr for `value`, creating it on first use. The
+/// attribute is cached in the context, so later calls return the same object.
+AffineMapAttr *AffineMapAttr::get(AffineMap *value, MLIRContext *context) {
+  auto *&result = context->getImpl().affineMapAttrs[value];
+  if (result)
+    return result;
+
+  result = context->getImpl().allocator.Allocate<AffineMapAttr>();
+  new (result) AffineMapAttr(value);
+  return result;
+}
+
 /// Perform a three-way comparison between the names of the specified
 /// NamedAttributes.
 static int compareNamedAttributes(const NamedAttribute *lhs,