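Adds ModuleState to support printing outlined AffineMaps.

Module::print now collects the affine maps referenced by memref types in function signatures, prints each one once at the top of the module as "#mapN = <map>", and refers to it as "#mapN" wherever the memref type is printed. This replaces the previous "#N = <map>" listing of Module::affineMapList. Function bodies are not visited yet, so maps that appear only inside a body are still printed inline.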

PiperOrigin-RevId: 204999887
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 9a86b85..be6573f 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -44,27 +44,253 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Module printing
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ModuleState  {
+ public:
+  ModuleState(raw_ostream &os);
+
+  void initialize(const Module *module);
+
+  void print(const Module *module);
+  void print(const Type *type) const;
+  void print(const Function* fn);
+
+  void recordAffineMapReference(const AffineMap* affineMap) {
+    if (affineMapIds.count(affineMap) == 0) {
+      affineMapIds[affineMap] = nextAffineMapId++;
+    }
+  }
+
+  int getAffineMapId(const AffineMap* affineMap) const {
+    auto it = affineMapIds.find(affineMap);
+    if (it == affineMapIds.end()) {
+      return -1;
+    }
+    return it->second;
+  }
+
+ private:
+  // Visit functions.
+  void visitFunction(const Function *fn);
+  void visitExtFunction(const ExtFunction *fn);
+  void visitCFGFunction(const CFGFunction *fn);
+  void visitMLFunction(const MLFunction *fn);
+  void visitType(const Type *type);
+
+  raw_ostream &os;
+  DenseMap<const AffineMap*, int> affineMapIds;
+  int nextAffineMapId = 0;
+};
+} // end anonymous namespace
+
+ModuleState::ModuleState(raw_ostream &os) : os(os) {
+}
+
+// Initializes module state, populating affine map state.
+void ModuleState::initialize(const Module *module) {
+  for (auto fn : module->functionList) {
+    visitFunction(fn);
+  }
+}
+
+// TODO Support visiting other types/instructions when implemented.
+void ModuleState::visitType(const Type *type) {
+  if (type->getKind() == Type::Kind::Function) {
+    // Visit input and result types for functions.
+    auto *funcType = cast<FunctionType>(type);
+    for (auto* input : funcType->getInputs()) {
+      visitType(input);
+    }
+    for (auto* result : funcType->getResults()) {
+      visitType(result);
+    }
+  } else if (type->getKind() == Type::Kind::MemRef) {
+    // Visit affine maps in memref type.
+    auto *memref = cast<MemRefType>(type);
+    for (AffineMap* map : memref->getAffineMaps()) {
+      recordAffineMapReference(map);
+    }
+  }
+}
+
+void ModuleState::visitExtFunction(const ExtFunction *fn) {
+  visitType(fn->getType());
+}
+
+void ModuleState::visitCFGFunction(const CFGFunction *fn) {
+  visitType(fn->getType());
+  // TODO Visit function body instructions.
+}
+
+void ModuleState::visitMLFunction(const MLFunction *fn) {
+  visitType(fn->getType());
+  // TODO Visit function body statements.
+}
+
+void ModuleState::visitFunction(const Function *fn) {
+  switch (fn->getKind()) {
+    case Function::Kind::ExtFunc:
+      return visitExtFunction(cast<ExtFunction>(fn));
+    case Function::Kind::CFGFunc:
+      return visitCFGFunction(cast<CFGFunction>(fn));
+    case Function::Kind::MLFunc:
+      return visitMLFunction(cast<MLFunction>(fn));
+  }
+}
+
+static void printExtFunction(const ExtFunction* fn,
+                             const ModuleState* moduleState, raw_ostream &os);
+
+
+static void printCFGFunction(const CFGFunction* fn,
+                             const ModuleState* moduleState, raw_ostream &os);
+
+static void printMLFunction(const MLFunction* fn,
+                            const ModuleState* moduleState, raw_ostream &os);
+
+// Prints function with initialized module state.
+void ModuleState::print(const Function* fn) {
+  switch (fn->getKind()) {
+    case Function::Kind::ExtFunc:
+      return printExtFunction(cast<ExtFunction>(fn), this, os);
+    case Function::Kind::CFGFunc:
+      return printCFGFunction(cast<CFGFunction>(fn), this, os);
+    case Function::Kind::MLFunc:
+      return printMLFunction(cast<MLFunction>(fn), this, os);
+  }
+}
+
+// Prints affine map identifier.
+static void printAffineMapId(unsigned affineMapId, raw_ostream &os) {
+  os << "#map" << affineMapId;
+}
+
+void ModuleState::print(const Module *module) {
+  for (const auto& mapAndId : affineMapIds) {
+    printAffineMapId(mapAndId.second, os);
+    os << " = ";
+    mapAndId.first->print(os);
+    os << '\n';
+  }
+  for (auto *fn : module->functionList)
+    print(fn);
+}
+
+void ModuleState::print(const Type *type) const {
+  switch (type->getKind()) {
+  case Type::Kind::AffineInt: os << "affineint"; return;
+  case Type::Kind::BF16: os << "bf16"; return;
+  case Type::Kind::F16:  os << "f16"; return;
+  case Type::Kind::F32:  os << "f32"; return;
+  case Type::Kind::F64:  os << "f64"; return;
+
+  case Type::Kind::Integer: {
+    auto *integer = cast<IntegerType>(type);
+    os << 'i' << integer->getWidth();
+    return;
+  }
+  case Type::Kind::Function: {
+    auto *func = cast<FunctionType>(type);
+    os << '(';
+    interleave(func->getInputs(),
+               [&](Type *type) { os << *type; },
+               [&]() { os << ", "; });
+    os << ") -> ";
+    auto results = func->getResults();
+    if (results.size() == 1)
+      os << *results[0];
+    else {
+      os << '(';
+      interleave(results,
+                 [&](Type *type) { os << *type; },
+                 [&]() { os << ", "; });
+      os << ')';
+    }
+    return;
+  }
+  case Type::Kind::Vector: {
+    auto *v = cast<VectorType>(type);
+    os << "vector<";
+    for (auto dim : v->getShape())
+      os << dim << 'x';
+    os << *v->getElementType() << '>';
+    return;
+  }
+  case Type::Kind::RankedTensor: {
+    auto *v = cast<RankedTensorType>(type);
+    os << "tensor<";
+    for (auto dim : v->getShape()) {
+      if (dim < 0)
+        os << '?';
+      else
+        os << dim;
+      os << 'x';
+    }
+    os << *v->getElementType() << '>';
+    return;
+  }
+  case Type::Kind::UnrankedTensor: {
+    auto *v = cast<UnrankedTensorType>(type);
+    os << "tensor<??" << *v->getElementType() << '>';
+    return;
+  }
+  case Type::Kind::MemRef: {
+    auto *v = cast<MemRefType>(type);
+    os << "memref<";
+    for (auto dim : v->getShape()) {
+      if (dim < 0)
+        os << '?';
+      else
+        os << dim;
+      os << 'x';
+    }
+    os << *v->getElementType();
+    for (auto map : v->getAffineMaps()) {
+      os << ", ";
+      const int mapId = getAffineMapId(map);
+      if (mapId >= 0) {
+        // Map will be printed at top of module so print reference to its id.
+        printAffineMapId(mapId, os);
+      } else {
+        // Map not in module state so print inline.
+        map->print(os);
+      }
+    }
+    os << ", " << v->getMemorySpace();
+    os << '>';
+    return;
+  }
+  }
+}
+
+//===----------------------------------------------------------------------===//
 // Function printing
 //===----------------------------------------------------------------------===//
 
-static void printFunctionSignature(const Function *fn, raw_ostream &os) {
+static void printFunctionSignature(const Function *fn,
+                                   const ModuleState *moduleState,
+                                   raw_ostream &os) {
   auto type = fn->getType();
 
   os << "@" << fn->getName() << '(';
   interleave(type->getInputs(),
-             [&](Type *eltType) { os << *eltType; },
+             [&](Type *eltType) { moduleState->print(eltType); },
              [&]() { os << ", "; });
   os << ')';
 
   switch (type->getResults().size()) {
   case 0: break;
   case 1:
-    os << " -> " << *type->getResults()[0];
+    os << " -> ";
+    moduleState->print(type->getResults()[0]);
     break;
   default:
     os << " -> (";
     interleave(type->getResults(),
-               [&](Type *eltType) { os << *eltType; },
+               [&](Type *eltType) { moduleState->print(eltType); },
                [&]() { os << ", "; });
     os << ')';
     break;
@@ -72,8 +298,9 @@
 }
 
 void ExtFunction::print(raw_ostream &os) const {
+  ModuleState moduleState(os);
   os << "extfunc ";
-  printFunctionSignature(this, os);
+  printFunctionSignature(this, &moduleState, os);
   os << "\n";
 }
 
@@ -83,18 +310,23 @@
 // CFG and ML functions.
 class FunctionState {
 public:
-  FunctionState(MLIRContext *context, raw_ostream &os);
+  FunctionState(MLIRContext *context, const ModuleState *moduleState,
+                raw_ostream &os);
 
   void printOperation(const Operation *op);
 
 protected:
   raw_ostream &os;
+  const ModuleState *moduleState;
   const OperationSet &operationSet;
 };
 } // end anonymous namespace
 
-FunctionState::FunctionState(MLIRContext *context, raw_ostream &os)
-    : os(os), operationSet(OperationSet::get(context)) {}
+FunctionState::FunctionState(MLIRContext *context,
+                             const ModuleState *moduleState,
+                             raw_ostream &os)
+    : os(os), moduleState(moduleState),
+      operationSet(OperationSet::get(context)) {}
 
 void FunctionState::printOperation(const Operation *op) {
   // Check to see if this is a known operation.  If so, use the registered
@@ -126,7 +358,8 @@
 namespace {
 class CFGFunctionState : public FunctionState {
 public:
-  CFGFunctionState(const CFGFunction *function, raw_ostream &os);
+  CFGFunctionState(const CFGFunction *function, const ModuleState *moduleState,
+                   raw_ostream &os);
 
   const CFGFunction *getFunction() const { return function; }
 
@@ -150,8 +383,11 @@
 };
 } // end anonymous namespace
 
-CFGFunctionState::CFGFunctionState(const CFGFunction *function, raw_ostream &os)
-    : FunctionState(function->getContext(), os), function(function) {
+CFGFunctionState::CFGFunctionState(const CFGFunction *function,
+                                   const ModuleState *moduleState,
+                                   raw_ostream &os)
+    : FunctionState(function->getContext(), moduleState, os),
+      function(function) {
   // Each basic block gets a unique ID per function.
   unsigned blockID = 0;
   for (auto &block : *function)
@@ -160,7 +396,7 @@
 
 void CFGFunctionState::print() {
   os << "cfgfunc ";
-  printFunctionSignature(this->getFunction(), os);
+  printFunctionSignature(this->getFunction(), moduleState, os);
   os << " {\n";
 
   for (auto &block : *function)
@@ -210,7 +446,8 @@
 namespace {
 class MLFunctionState : public FunctionState {
 public:
-  MLFunctionState(const MLFunction *function, raw_ostream &os);
+  MLFunctionState(const MLFunction *function, const ModuleState *moduleState,
+                  raw_ostream &os);
 
   const MLFunction *getFunction() const { return function; }
 
@@ -233,14 +470,16 @@
 };
 } // end anonymous namespace
 
-MLFunctionState::MLFunctionState(const MLFunction *function, raw_ostream &os)
-    : FunctionState(function->getContext(), os), function(function),
-      numSpaces(0) {}
+MLFunctionState::MLFunctionState(const MLFunction *function,
+                                 const ModuleState *moduleState,
+                                 raw_ostream &os)
+    : FunctionState(function->getContext(), moduleState, os),
+      function(function), numSpaces(0) {}
 
 void MLFunctionState::print() {
   os << "mlfunc ";
   // FIXME: should print argument names rather than just signature
-  printFunctionSignature(function, os);
+  printFunctionSignature(function, moduleState, os);
   os << " {\n";
   print(function);
   os << "  return\n";
@@ -288,12 +527,41 @@
   }
 }
 
+void printExtFunction(const ExtFunction* fn, const ModuleState* moduleState,
+                      raw_ostream &os) {
+  os << "extfunc ";
+  printFunctionSignature(fn, moduleState, os);
+  os << '\n';
+}
+
+void printCFGFunction(const CFGFunction* fn, const ModuleState* moduleState,
+                      raw_ostream &os) {
+  CFGFunctionState state(fn, moduleState, os);
+  state.print();
+}
+
+void printMLFunction(const MLFunction* fn, const ModuleState* moduleState,
+                     raw_ostream &os) {
+  MLFunctionState state(fn, moduleState, os);
+  state.print();
+}
+
 //===----------------------------------------------------------------------===//
 // print and dump methods
 //===----------------------------------------------------------------------===//
 
+void Type::print(raw_ostream &os) const {
+  ModuleState moduleState(os);
+  moduleState.print(this);
+}
+
+void Type::dump() const {
+  print(llvm::errs());
+}
+
 void Instruction::print(raw_ostream &os) const {
-  CFGFunctionState state(getFunction(), os);
+  ModuleState moduleState(os);
+  CFGFunctionState state(getFunction(), &moduleState, os);
   state.print(this);
 }
 
@@ -406,7 +674,8 @@
 }
 
 void BasicBlock::print(raw_ostream &os) const {
-  CFGFunctionState state(getFunction(), os);
+  ModuleState moduleState(os);
+  CFGFunctionState state(getFunction(), &moduleState, os);
   state.print();
 }
 
@@ -415,7 +684,8 @@
 }
 
 void Statement::print(raw_ostream &os) const {
-  MLFunctionState state(getFunction(), os);
+  ModuleState moduleState(os);
+  MLFunctionState state(getFunction(), &moduleState, os);
   state.print(this);
 }
 
@@ -436,24 +706,21 @@
 }
 
 void CFGFunction::print(raw_ostream &os) const {
-  CFGFunctionState state(this, os);
+  ModuleState moduleState(os);
+  CFGFunctionState state(this, &moduleState, os);
   state.print();
 }
 
 void MLFunction::print(raw_ostream &os) const {
-  MLFunctionState state(this, os);
+  ModuleState moduleState(os);
+  MLFunctionState state(this, &moduleState, os);
   state.print();
 }
 
 void Module::print(raw_ostream &os) const {
-  unsigned id = 0;
-  for (auto *map : affineMapList) {
-    os << "#" << id++ << " = ";
-    map->print(os);
-    os << '\n';
-  }
-  for (auto *fn : functionList)
-    fn->print(os);
+  ModuleState moduleState(os);
+  moduleState.initialize(this);
+  moduleState.print(this);
 }
 
 void Module::dump() const {