Support for AffineMapAttr.

PiperOrigin-RevId: 205157390
diff --git a/include/mlir/IR/Attributes.h b/include/mlir/IR/Attributes.h
index 2ff129b..47d0b9f 100644
--- a/include/mlir/IR/Attributes.h
+++ b/include/mlir/IR/Attributes.h
@@ -22,7 +22,8 @@
 #include "llvm/ADT/ArrayRef.h"
 
 namespace mlir {
-  class MLIRContext;
+class MLIRContext;
+class AffineMap;
 
 /// Instances of the Attribute class are immutable, uniqued, immortal, and owned
 /// by MLIRContext.  As such, they are passed around by raw non-const pointer.
@@ -34,6 +35,7 @@
     Float,
     String,
     Array,
+    AffineMap,
     // TODO: Function references.
   };
 
@@ -147,7 +149,23 @@
   ArrayRef<Attribute*> value;
 };
 
+class AffineMapAttr : public Attribute {
+public:
+  static AffineMapAttr *get(AffineMap *value, MLIRContext *context);
+
+  AffineMap *getValue() const {
+    return value;
+  }
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(const Attribute *attr) {
+    return attr->getKind() == Kind::AffineMap;
+  }
+private:
+  AffineMapAttr(AffineMap *value) : Attribute(Kind::AffineMap), value(value) {}
+  AffineMap *value;
+};
+
 } // end namespace mlir.
 
 #endif
-
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 9297fd3..c41a886 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -37,6 +37,7 @@
 class FloatAttr;
 class StringAttr;
 class ArrayAttr;
+class AffineMapAttr;
 class AffineMap;
 class AffineExpr;
 class AffineConstantExpr;
@@ -74,6 +75,7 @@
   FloatAttr *getFloatAttr(double value);
   StringAttr *getStringAttr(StringRef bytes);
   ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
+  AffineMapAttr *getAffineMapAttr(AffineMap *value);
 
   // Affine Expressions and Affine Map.
   AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 6a2d932..0e5d101 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -50,6 +50,7 @@
   void initialize(const Module *module);
 
   void print(const Module *module);
+  void print(const Attribute *attr) const;
   void print(const Type *type) const;
   void print(const Function *fn);
   void print(const ExtFunction *fn);
@@ -77,6 +78,11 @@
   void visitCFGFunction(const CFGFunction *fn);
   void visitMLFunction(const MLFunction *fn);
   void visitType(const Type *type);
+  void visitAttribute(const Attribute *attr);
+  void visitOperation(const Operation *op);
+
+  void printAffineMapId(int affineMapId) const;
+  void printAffineMapReference(const AffineMap* affineMap) const;
 
   raw_ostream &os;
   DenseMap<const AffineMap *, int> affineMapIds;
@@ -113,6 +119,22 @@
   }
 }
 
+void ModuleState::visitAttribute(const Attribute *attr) {
+  if (isa<AffineMapAttr>(attr)) {
+    recordAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
+  } else if (isa<ArrayAttr>(attr)) {
+    for (auto elt : cast<ArrayAttr>(attr)->getValue()) {
+      visitAttribute(elt);
+    }
+  }
+}
+
+void ModuleState::visitOperation(const Operation *op) {
+  for (auto elt : op->getAttrs()) {
+    visitAttribute(elt.second);
+  }
+}
+
 void ModuleState::visitExtFunction(const ExtFunction *fn) {
   visitType(fn->getType());
 }
@@ -120,11 +142,16 @@
 void ModuleState::visitCFGFunction(const CFGFunction *fn) {
   visitType(fn->getType());
   // TODO Visit function body instructions.
+  for (auto &block : *fn) {
+    for (auto &op : block.getOperations()) {
+      visitOperation(&op);
+    }
+  }
 }
 
 void ModuleState::visitMLFunction(const MLFunction *fn) {
   visitType(fn->getType());
-  // TODO Visit function body statements.
+  // TODO Visit function body statements (and attributes if required).
 }
 
 void ModuleState::visitFunction(const Function *fn) {
@@ -151,13 +178,24 @@
 }
 
 // Prints affine map identifier.
-static void printAffineMapId(unsigned affineMapId, raw_ostream &os) {
+void ModuleState::printAffineMapId(int affineMapId) const {
   os << "#map" << affineMapId;
 }
 
+void ModuleState::printAffineMapReference(const AffineMap* affineMap) const {
+  const int mapId = getAffineMapId(affineMap);
+  if (mapId >= 0) {
+    // Map will be printed at top of module so print reference to its id.
+    printAffineMapId(mapId);
+  } else {
+    // Map not in module state so print inline.
+    affineMap->print(os);
+  }
+}
+
 void ModuleState::print(const Module *module) {
   for (const auto &mapAndId : affineMapIds) {
-    printAffineMapId(mapAndId.second, os);
+    printAffineMapId(mapAndId.second);
     os << " = ";
     mapAndId.first->print(os);
     os << '\n';
@@ -165,6 +203,37 @@
   for (auto *fn : module->functionList) print(fn);
 }
 
+void ModuleState::print(const Attribute *attr) const {
+  switch (attr->getKind()) {
+  case Attribute::Kind::Bool:
+    os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
+    break;
+  case Attribute::Kind::Integer:
+    os << cast<IntegerAttr>(attr)->getValue();
+    break;
+  case Attribute::Kind::Float:
+    // FIXME: this isn't precise, we should print with a hex format.
+    os << cast<FloatAttr>(attr)->getValue();
+    break;
+  case Attribute::Kind::String:
+    // FIXME: should escape the string.
+    os << '"' << cast<StringAttr>(attr)->getValue() << '"';
+    break;
+  case Attribute::Kind::Array: {
+    auto elts = cast<ArrayAttr>(attr)->getValue();
+    os << '[';
+    interleave(elts,
+                [&](Attribute *attr) { print(attr); },
+                [&]() { os << ", "; });
+    os << ']';
+    break;
+  }
+  case Attribute::Kind::AffineMap:
+    printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
+    break;
+  }
+}
+
 void ModuleState::print(const Type *type) const {
   switch (type->getKind()) {
   case Type::Kind::AffineInt:
@@ -243,14 +312,7 @@
     os << *v->getElementType();
     for (auto map : v->getAffineMaps()) {
       os << ", ";
-      const int mapId = getAffineMapId(map);
-      if (mapId >= 0) {
-        // Map will be printed at top of module so print reference to its id.
-        printAffineMapId(mapId, os);
-      } else {
-        // Map not in module state so print inline.
-        map->print(os);
-      }
+      printAffineMapReference(map);
     }
     os << ", " << v->getMemorySpace();
     os << '>';
@@ -338,7 +400,9 @@
     os << '{';
     interleave(
         attrs,
-        [&](NamedAttribute attr) { os << attr.first << ": " << *attr.second; },
+        [&](NamedAttribute attr) {
+          os << attr.first << ": ";
+          moduleState->print(attr.second); },
         [&]() { os << ", "; });
     os << '}';
   }
@@ -553,6 +617,15 @@
 // print and dump methods
 //===----------------------------------------------------------------------===//
 
+void Attribute::print(raw_ostream &os) const {
+  ModuleState moduleState(os);
+  moduleState.print(this);
+}
+
+void Attribute::dump() const {
+  print(llvm::errs());
+}
+
 void Type::print(raw_ostream &os) const {
   ModuleState moduleState(os);
   moduleState.print(this);
diff --git a/lib/IR/Attributes.cpp b/lib/IR/Attributes.cpp
deleted file mode 100644
index df23424..0000000
--- a/lib/IR/Attributes.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-//===- Attributes.cpp - MLIR Attribute Implementation ---------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/Attributes.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/Support/STLExtras.h"
-using namespace mlir;
-
-void Attribute::print(raw_ostream &os) const {
-  switch (getKind()) {
-  case Kind::Bool:
-    os << (cast<BoolAttr>(this)->getValue() ? "true" : "false");
-    break;
-  case Kind::Integer:
-    os << cast<IntegerAttr>(this)->getValue();
-    break;
-  case Kind::Float:
-    // FIXME: this isn't precise, we should print with a hex format.
-    os << cast<FloatAttr>(this)->getValue();
-    break;
-  case Kind::String:
-    // FIXME: should escape the string.
-    os << '"' << cast<StringAttr>(this)->getValue() << '"';
-    break;
-  case Kind::Array: {
-    auto elts = cast<ArrayAttr>(this)->getValue();
-    os << '[';
-    interleave(elts,
-                [&](Attribute *attr) { attr->print(os); },
-                [&]() { os << ", "; });
-    os << ']';
-    break;
-  }
-  }
-}
-
-void Attribute::dump() const {
-  print(llvm::errs());
-}
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index e9bea2a..8d27991 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -94,6 +94,10 @@
   return ArrayAttr::get(value, context);
 }
 
+AffineMapAttr *Builder::getAffineMapAttr(AffineMap *value) {
+  return AffineMapAttr::get(value, context);
+}
+
 //===----------------------------------------------------------------------===//
 // Affine Expressions and Affine Map.
 //===----------------------------------------------------------------------===//
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index eb06e94..df3d01a 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -230,6 +230,7 @@
   StringMap<StringAttr*> stringAttrs;
   using ArrayAttrSet = DenseSet<ArrayAttr*, ArrayAttrKeyInfo>;
   ArrayAttrSet arrayAttrs;
+  DenseMap<AffineMap*, AffineMapAttr*> affineMapAttrs;
   using AttributeListSet =
       DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
   AttributeListSet attributeLists;
@@ -541,6 +542,16 @@
   return *existing.first = result;
 }
 
+AffineMapAttr *AffineMapAttr::get(AffineMap* value, MLIRContext *context) {
+  auto *&result = context->getImpl().affineMapAttrs[value];
+  if (result)
+    return result;
+
+  result = context->getImpl().allocator.Allocate<AffineMapAttr>();
+  new (result) AffineMapAttr(value);
+  return result;
+}
+
 /// Perform a three-way comparison between the names of the specified
 /// NamedAttributes.
 static int compareNamedAttributes(const NamedAttribute *lhs,
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 3d5c908..d220787 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -598,6 +598,11 @@
     return builder.getArrayAttr(elements);
   }
   default:
+    // Try to parse affine map reference.
+    auto* affineMap = parseAffineMapReference();
+    if (affineMap != nullptr)
+      return builder.getAffineMapAttr(affineMap);
+
     // TODO: Handle floating point.
     return (emitError("expected constant attribute value"), nullptr);
   }
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index 6e0be24..7a1cfce 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -136,6 +136,15 @@
   // CHECK: "foo"(){a: 1, b: -423, c: [true, false]}  : () -> ()
   "foo"(){a: 1, b: -423, c: [true, false] } : () -> ()
 
+  // CHECK: "foo"(){map1: #map{{[0-9]+}}}
+  "foo"(){map1: #map1} : () -> ()
+
+  // CHECK: "foo"(){map2: #map{{[0-9]+}}}
+  "foo"(){map2: (d0, d1, d2) -> (d0, d1, d2)} : () -> ()
+
+  // CHECK: "foo"(){map12: [#map{{[0-9]+}}, #map{{[0-9]+}}]}
+  "foo"(){map12: [#map1, #map2]} : () -> ()
+
   // CHECK: "foo"(){cfgfunc: [], i123: 7, if: "foo"} : () -> ()
   "foo"(){if: "foo", cfgfunc: [], i123: 7} : () -> ()