Refactor the AsmPrinter to follow the pattern established in the parser:
there is now an explicit state class, which has exactly one instance per
top-level FooThing::print call.  The function printers now subclass
ModulePrinter (via FunctionState), so they can just call print on their types
and other global stuff.  This also makes the contract strict that the global
FooThing::print calls are the public entrypoints and that the printer
implementation is otherwise self-contained.

No Functionality Change.

PiperOrigin-RevId: 205409317
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 31fd05c..b56c775 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -38,37 +38,23 @@
 
 void Identifier::dump() const { print(llvm::errs()); }
 
-template <typename Container, typename UnaryFunctor>
-inline void interleaveComma(raw_ostream &os, const Container &c,
-                            UnaryFunctor each_fn) {
-  interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
-}
-
 //===----------------------------------------------------------------------===//
-// Module printing
+// ModuleState
 //===----------------------------------------------------------------------===//
 
 namespace {
 class ModuleState {
 public:
-  ModuleState(raw_ostream &os);
+  /// This is the operation set for the current context if it is knowable (a
+  /// context could be determined), otherwise this is null.
+  OperationSet *const operationSet;
 
+  explicit ModuleState(MLIRContext *context)
+      : operationSet(context ? &OperationSet::get(context) : nullptr) {}
+
+  // Initializes module state, populating affine map state.
   void initialize(const Module *module);
 
-  void print(const Module *module);
-  void print(const Attribute *attr) const;
-  void print(const Type *type) const;
-  void print(const Function *fn);
-  void print(const ExtFunction *fn);
-  void print(const CFGFunction *fn);
-  void print(const MLFunction *fn);
-
-  void recordAffineMapReference(const AffineMap *affineMap) {
-    if (affineMapIds.count(affineMap) == 0) {
-      affineMapIds[affineMap] = nextAffineMapId++;
-    }
-  }
-
   int getAffineMapId(const AffineMap *affineMap) const {
     auto it = affineMapIds.find(affineMap);
     if (it == affineMapIds.end()) {
@@ -77,7 +63,17 @@
     return it->second;
   }
 
+  const DenseMap<const AffineMap *, int> &getAffineMapIds() const {
+    return affineMapIds;
+  }
+
 private:
+  void recordAffineMapReference(const AffineMap *affineMap) {
+    if (affineMapIds.count(affineMap) == 0) {
+      affineMapIds[affineMap] = nextAffineMapId++;
+    }
+  }
+
   // Visit functions.
   void visitFunction(const Function *fn);
   void visitExtFunction(const ExtFunction *fn);
@@ -87,23 +83,11 @@
   void visitAttribute(const Attribute *attr);
   void visitOperation(const Operation *op);
 
-  void printAffineMapId(int affineMapId) const;
-  void printAffineMapReference(const AffineMap* affineMap) const;
-
-  raw_ostream &os;
   DenseMap<const AffineMap *, int> affineMapIds;
   int nextAffineMapId = 0;
 };
 }  // end anonymous namespace
 
-ModuleState::ModuleState(raw_ostream &os) : os(os) {}
-
-// Initializes module state, populating affine map state.
-void ModuleState::initialize(const Module *module) {
-  for (auto fn : module->functionList) {
-    visitFunction(fn);
-  }
-}
 
 // TODO Support visiting other types/instructions when implemented.
 void ModuleState::visitType(const Type *type) {
@@ -171,8 +155,54 @@
   }
 }
 
+// Initializes module state, populating affine map state.
+void ModuleState::initialize(const Module *module) {
+  for (auto fn : module->functionList) {
+    visitFunction(fn);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// ModulePrinter
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ModulePrinter {
+public:
+  ModulePrinter(raw_ostream &os, ModuleState &state) : os(os), state(state) {}
+  explicit ModulePrinter(const ModulePrinter &printer)
+      : os(printer.os), state(printer.state) {}
+
+  template <typename Container, typename UnaryFunctor>
+  inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
+    interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
+  }
+
+  void print(const Module *module);
+  void print(const Attribute *attr) const;
+  void print(const Type *type) const;
+  void print(const Function *fn);
+  void print(const ExtFunction *fn);
+  void print(const CFGFunction *fn);
+  void print(const MLFunction *fn);
+
+  void print(const AffineMap *map);
+  void print(const AffineExpr *expr) const;
+
+protected:
+  raw_ostream &os;
+  ModuleState &state;
+
+  void printFunctionSignature(const Function *fn);
+  void printAffineMapId(int affineMapId) const;
+  void printAffineMapReference(const AffineMap *affineMap) const;
+
+  void print(const AffineBinaryOpExpr *expr) const;
+};
+} // end anonymous namespace
+
 // Prints function with initialized module state.
-void ModuleState::print(const Function *fn) {
+void ModulePrinter::print(const Function *fn) {
   switch (fn->getKind()) {
   case Function::Kind::ExtFunc:
     return print(cast<ExtFunction>(fn));
@@ -184,12 +214,12 @@
 }
 
 // Prints affine map identifier.
-void ModuleState::printAffineMapId(int affineMapId) const {
+void ModulePrinter::printAffineMapId(int affineMapId) const {
   os << "#map" << affineMapId;
 }
 
-void ModuleState::printAffineMapReference(const AffineMap* affineMap) const {
-  const int mapId = getAffineMapId(affineMap);
+void ModulePrinter::printAffineMapReference(const AffineMap *affineMap) const {
+  int mapId = state.getAffineMapId(affineMap);
   if (mapId >= 0) {
     // Map will be printed at top of module so print reference to its id.
     printAffineMapId(mapId);
@@ -199,8 +229,8 @@
   }
 }
 
-void ModuleState::print(const Module *module) {
-  for (const auto &mapAndId : affineMapIds) {
+void ModulePrinter::print(const Module *module) {
+  for (const auto &mapAndId : state.getAffineMapIds()) {
     printAffineMapId(mapAndId.second);
     os << " = ";
     mapAndId.first->print(os);
@@ -209,7 +239,7 @@
   for (auto *fn : module->functionList) print(fn);
 }
 
-void ModuleState::print(const Attribute *attr) const {
+void ModulePrinter::print(const Attribute *attr) const {
   switch (attr->getKind()) {
   case Attribute::Kind::Bool:
     os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
@@ -228,7 +258,7 @@
   case Attribute::Kind::Array: {
     auto elts = cast<ArrayAttr>(attr)->getValue();
     os << '[';
-    interleaveComma(os, elts, [&](Attribute *attr) { print(attr); });
+    interleaveComma(elts, [&](Attribute *attr) { print(attr); });
     os << ']';
     break;
   }
@@ -238,7 +268,7 @@
   }
 }
 
-void ModuleState::print(const Type *type) const {
+void ModulePrinter::print(const Type *type) const {
   switch (type->getKind()) {
   case Type::Kind::AffineInt:
     os << "affineint";
@@ -264,14 +294,14 @@
   case Type::Kind::Function: {
     auto *func = cast<FunctionType>(type);
     os << '(';
-    interleaveComma(os, func->getInputs(), [&](Type *type) { os << *type; });
+    interleaveComma(func->getInputs(), [&](Type *type) { os << *type; });
     os << ") -> ";
     auto results = func->getResults();
     if (results.size() == 1)
       os << *results[0];
     else {
       os << '(';
-      interleaveComma(os, results, [&](Type *type) { os << *type; });
+      interleaveComma(results, [&](Type *type) { os << *type; });
       os << ')';
     }
     return;
@@ -324,17 +354,132 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Affine expressions and maps
+//===----------------------------------------------------------------------===//
+
+void ModulePrinter::print(const AffineExpr *expr) const {
+  switch (expr->getKind()) {
+  case AffineExpr::Kind::SymbolId:
+    os << 's' << cast<AffineSymbolExpr>(expr)->getPosition();
+    return;
+  case AffineExpr::Kind::DimId:
+    os << 'd' << cast<AffineDimExpr>(expr)->getPosition();
+    return;
+  case AffineExpr::Kind::Constant:
+    os << cast<AffineConstantExpr>(expr)->getValue();
+    return;
+  case AffineExpr::Kind::Add:
+  case AffineExpr::Kind::Mul:
+  case AffineExpr::Kind::FloorDiv:
+  case AffineExpr::Kind::CeilDiv:
+  case AffineExpr::Kind::Mod:
+    return print(cast<AffineBinaryOpExpr>(expr));
+  }
+}
+
+void ModulePrinter::print(const AffineBinaryOpExpr *expr) const {
+  if (expr->getKind() != AffineExpr::Kind::Add) {
+    os << '(';
+    print(expr->getLHS());
+    switch (expr->getKind()) {
+    case AffineExpr::Kind::Mul:
+      os << " * ";
+      break;
+    case AffineExpr::Kind::FloorDiv:
+      os << " floordiv ";
+      break;
+    case AffineExpr::Kind::CeilDiv:
+      os << " ceildiv ";
+      break;
+    case AffineExpr::Kind::Mod:
+      os << " mod ";
+      break;
+    default:
+      llvm_unreachable("unexpected affine binary op expression");
+    }
+
+    print(expr->getRHS());
+    os << ')';
+    return;
+  }
+
+  // Print out special "pretty" forms for add.
+  os << '(';
+  print(expr->getLHS());
+
+  // Pretty print addition to a product that has a negative operand as a
+  // subtraction.
+  if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(expr->getRHS())) {
+    if (rhs->getKind() == AffineExpr::Kind::Mul) {
+      if (auto *rrhs = dyn_cast<AffineConstantExpr>(rhs->getRHS())) {
+        if (rrhs->getValue() < 0) {
+          os << " - (";
+          print(rhs->getLHS());
+          os << " * " << -rrhs->getValue() << "))";
+          return;
+        }
+      }
+    }
+  }
+
+  // Pretty print addition to a negative number as a subtraction.
+  if (auto *rhs = dyn_cast<AffineConstantExpr>(expr->getRHS())) {
+    if (rhs->getValue() < 0) {
+      os << " - " << -rhs->getValue() << ")";
+      return;
+    }
+  }
+
+  os << " + ";
+  print(expr->getRHS());
+  os << ')';
+}
+
+void ModulePrinter::print(const AffineMap *map) {
+  // Dimension identifiers.
+  os << '(';
+  for (int i = 0; i < (int)map->getNumDims() - 1; i++)
+    os << "d" << i << ", ";
+  if (map->getNumDims() >= 1)
+    os << "d" << map->getNumDims() - 1;
+  os << ")";
+
+  // Symbolic identifiers.
+  if (map->getNumSymbols() >= 1) {
+    os << " [";
+    for (int i = 0; i < (int)map->getNumSymbols() - 1; i++)
+      os << "s" << i << ", ";
+    if (map->getNumSymbols() >= 1)
+      os << "s" << map->getNumSymbols() - 1;
+    os << "]";
+  }
+
+  // AffineMap should have at least one result.
+  assert(!map->getResults().empty());
+  // Result affine expressions.
+  os << " -> (";
+  interleaveComma(map->getResults(), [&](AffineExpr *expr) { print(expr); });
+  os << ")";
+
+  if (!map->isBounded()) {
+    return;
+  }
+
+  // Print range sizes for bounded affine maps.
+  os << " size (";
+  interleaveComma(map->getRangeSizes(), [&](AffineExpr *expr) { print(expr); });
+  os << ")";
+}
+
+//===----------------------------------------------------------------------===//
 // Function printing
 //===----------------------------------------------------------------------===//
 
-static void printFunctionSignature(const Function *fn,
-                                   const ModuleState *moduleState,
-                                   raw_ostream &os) {
+void ModulePrinter::printFunctionSignature(const Function *fn) {
   auto type = fn->getType();
 
   os << "@" << fn->getName() << '(';
-  interleaveComma(os, type->getInputs(),
-                  [&](Type *eltType) { moduleState->print(eltType); });
+  interleaveComma(type->getInputs(), [&](Type *eltType) { print(eltType); });
   os << ')';
 
   switch (type->getResults().size()) {
@@ -342,20 +487,19 @@
     break;
   case 1:
     os << " -> ";
-    moduleState->print(type->getResults()[0]);
+    print(type->getResults()[0]);
     break;
   default:
     os << " -> (";
-    interleaveComma(os, type->getResults(),
-                    [&](Type *eltType) { moduleState->print(eltType); });
+    interleaveComma(type->getResults(), [&](Type *eltType) { print(eltType); });
     os << ')';
     break;
   }
 }
 
-void ModuleState::print(const ExtFunction *fn) {
+void ModulePrinter::print(const ExtFunction *fn) {
   os << "extfunc ";
-  printFunctionSignature(fn, this, os);
+  printFunctionSignature(fn);
   os << '\n';
 }
 
@@ -363,18 +507,13 @@
 
 // FunctionState contains common functionality for printing
 // CFG and ML functions.
-class FunctionState {
+class FunctionState : public ModulePrinter {
 public:
-  FunctionState(MLIRContext *context, const ModuleState *moduleState,
-                raw_ostream &os);
+  FunctionState(const ModulePrinter &other) : ModulePrinter(other) {}
 
   void printOperation(const Operation *op);
 
 protected:
-  raw_ostream &os;
-  const ModuleState *moduleState;
-  const OperationSet &operationSet;
-
   void numberValueID(const SSAValue *value) {
     assert(!valueIDs.count(value) && "Value numbered multiple times");
     valueIDs[value] = nextValueID++;
@@ -397,12 +536,6 @@
 };
 }  // end anonymous namespace
 
-FunctionState::FunctionState(MLIRContext *context,
-                             const ModuleState *moduleState, raw_ostream &os)
-    : os(os),
-      moduleState(moduleState),
-      operationSet(OperationSet::get(context)) {}
-
 void FunctionState::printOperation(const Operation *op) {
   os << "  ";
 
@@ -417,7 +550,7 @@
 
   // Check to see if this is a known operation.  If so, use the registered
   // custom printer hook.
-  if (auto opInfo = operationSet.lookup(op->getName().str())) {
+  if (auto opInfo = state.operationSet->lookup(op->getName().str())) {
     opInfo->printAssembly(op, os);
     return;
   }
@@ -431,18 +564,18 @@
   // Operation this check can go away.
   if (auto *inst = dyn_cast<OperationInst>(op)) {
     // TODO: Use getOperands() when we have it.
-    interleaveComma(
-        os, inst->getInstOperands(),
-        [&](const InstOperand &operand) { printValueID(operand.get()); });
+    interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
+      printValueID(operand.get());
+    });
   }
 
   os << ')';
   auto attrs = op->getAttrs();
   if (!attrs.empty()) {
     os << '{';
-    interleaveComma(os, attrs, [&](NamedAttribute attr) {
+    interleaveComma(attrs, [&](NamedAttribute attr) {
       os << attr.first << ": ";
-      moduleState->print(attr.second);
+      print(attr.second);
     });
     os << '}';
   }
@@ -453,20 +586,18 @@
     // Print the type signature of the operation.
     os << " : (";
     // TODO: Switch to getOperands() when we have it.
-    interleaveComma(os, inst->getInstOperands(), [&](const InstOperand &op) {
-      moduleState->print(op.get()->getType());
-    });
+    interleaveComma(inst->getInstOperands(),
+                    [&](const InstOperand &op) { print(op.get()->getType()); });
     os << ") -> ";
 
     // TODO: Switch to getResults() when we have it.
     if (inst->getNumResults() == 1) {
-      moduleState->print(inst->getInstResult(0).getType());
+      print(inst->getInstResult(0).getType());
     } else {
       os << '(';
-      interleaveComma(os, inst->getInstResults(),
-                      [&](const InstResult &result) {
-                        moduleState->print(result.getType());
-                      });
+      interleaveComma(inst->getInstResults(), [&](const InstResult &result) {
+        print(result.getType());
+      });
       os << ')';
     }
   }
@@ -477,10 +608,9 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-class CFGFunctionState : public FunctionState {
+class CFGFunctionPrinter : public FunctionState {
 public:
-  CFGFunctionState(const CFGFunction *function, const ModuleState *moduleState,
-                   raw_ostream &os);
+  CFGFunctionPrinter(const CFGFunction *function, const ModulePrinter &other);
 
   const CFGFunction *getFunction() const { return function; }
 
@@ -502,25 +632,23 @@
   const CFGFunction *function;
   DenseMap<const BasicBlock *, unsigned> basicBlockIDs;
 
-  void numberBlock(const BasicBlock *block);
+  void numberValuesInBlock(const BasicBlock *block);
 };
 }  // end anonymous namespace
 
-CFGFunctionState::CFGFunctionState(const CFGFunction *function,
-                                   const ModuleState *moduleState,
-                                   raw_ostream &os)
-    : FunctionState(function->getContext(), moduleState, os),
-      function(function) {
+CFGFunctionPrinter::CFGFunctionPrinter(const CFGFunction *function,
+                                       const ModulePrinter &other)
+    : FunctionState(other), function(function) {
   // Each basic block gets a unique ID per function.
   unsigned blockID = 0;
   for (auto &block : *function) {
     basicBlockIDs[&block] = blockID++;
-    numberBlock(&block);
+    numberValuesInBlock(&block);
   }
 }
 
 /// Number all of the SSA values in the specified basic block.
-void CFGFunctionState::numberBlock(const BasicBlock *block) {
+void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) {
   // TODO: basic block arguments.
   for (auto &op : *block) {
     // We number instruction that have results, and we only number the first
@@ -532,16 +660,16 @@
   // Terminators do not define values.
 }
 
-void CFGFunctionState::print() {
+void CFGFunctionPrinter::print() {
   os << "cfgfunc ";
-  printFunctionSignature(this->getFunction(), moduleState, os);
+  printFunctionSignature(getFunction());
   os << " {\n";
 
   for (auto &block : *function) print(&block);
   os << "}\n\n";
 }
 
-void CFGFunctionState::print(const BasicBlock *block) {
+void CFGFunctionPrinter::print(const BasicBlock *block) {
   os << "bb" << getBBID(block) << ":\n";
 
   // TODO Print arguments.
@@ -554,7 +682,7 @@
   os << "\n";
 }
 
-void CFGFunctionState::print(const Instruction *inst) {
+void CFGFunctionPrinter::print(const Instruction *inst) {
   switch (inst->getKind()) {
   case Instruction::Kind::Operation:
     return print(cast<OperationInst>(inst));
@@ -565,17 +693,16 @@
   }
 }
 
-void CFGFunctionState::print(const OperationInst *inst) {
+void CFGFunctionPrinter::print(const OperationInst *inst) {
   printOperation(inst);
 }
-void CFGFunctionState::print(const BranchInst *inst) {
+void CFGFunctionPrinter::print(const BranchInst *inst) {
   os << "  br bb" << getBBID(inst->getDest());
 }
-void CFGFunctionState::print(const ReturnInst *inst) { os << "  return"; }
+void CFGFunctionPrinter::print(const ReturnInst *inst) { os << "  return"; }
 
-void ModuleState::print(const CFGFunction *fn) {
-  CFGFunctionState state(fn, this, os);
-  state.print();
+void ModulePrinter::print(const CFGFunction *fn) {
+  CFGFunctionPrinter(fn, *this).print();
 }
 
 //===----------------------------------------------------------------------===//
@@ -583,10 +710,9 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-class MLFunctionState : public FunctionState {
+class MLFunctionPrinter : public FunctionState {
 public:
-  MLFunctionState(const MLFunction *function, const ModuleState *moduleState,
-                  raw_ostream &os);
+  MLFunctionPrinter(const MLFunction *function, const ModulePrinter &other);
 
   const MLFunction *getFunction() const { return function; }
 
@@ -609,24 +735,21 @@
 };
 }  // end anonymous namespace
 
-MLFunctionState::MLFunctionState(const MLFunction *function,
-                                 const ModuleState *moduleState,
-                                 raw_ostream &os)
-    : FunctionState(function->getContext(), moduleState, os),
-      function(function),
-      numSpaces(0) {}
+MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
+                                     const ModulePrinter &other)
+    : FunctionState(other), function(function), numSpaces(0) {}
 
-void MLFunctionState::print() {
+void MLFunctionPrinter::print() {
   os << "mlfunc ";
   // FIXME: should print argument names rather than just signature
-  printFunctionSignature(function, moduleState, os);
+  printFunctionSignature(function);
   os << " {\n";
   print(function);
   os << "  return\n";
   os << "}\n\n";
 }
 
-void MLFunctionState::print(const StmtBlock *block) {
+void MLFunctionPrinter::print(const StmtBlock *block) {
   numSpaces += indentWidth;
   for (auto &stmt : block->getStatements()) {
     print(&stmt);
@@ -635,7 +758,7 @@
   numSpaces -= indentWidth;
 }
 
-void MLFunctionState::print(const Statement *stmt) {
+void MLFunctionPrinter::print(const Statement *stmt) {
   switch (stmt->getKind()) {
   case Statement::Kind::Operation:
     return print(cast<OperationStmt>(stmt));
@@ -646,9 +769,11 @@
   }
 }
 
-void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
+void MLFunctionPrinter::print(const OperationStmt *stmt) {
+  printOperation(stmt);
+}
 
-void MLFunctionState::print(const ForStmt *stmt) {
+void MLFunctionPrinter::print(const ForStmt *stmt) {
   os.indent(numSpaces) << "for x = " << *stmt->getLowerBound();
   os << " to " << *stmt->getUpperBound();
   if (stmt->getStep()->getValue() != 1)
@@ -659,7 +784,7 @@
   os.indent(numSpaces) << "}";
 }
 
-void MLFunctionState::print(const IfStmt *stmt) {
+void MLFunctionPrinter::print(const IfStmt *stmt) {
   os.indent(numSpaces) << "if () {\n";
   print(stmt->getThenClause());
   os.indent(numSpaces) << "}";
@@ -670,9 +795,8 @@
   }
 }
 
-void ModuleState::print(const MLFunction *fn) {
-  MLFunctionState state(fn, this, os);
-  state.print();
+void ModulePrinter::print(const MLFunction *fn) {
+  MLFunctionPrinter(fn, *this).print();
 }
 
 //===----------------------------------------------------------------------===//
@@ -680,8 +804,8 @@
 //===----------------------------------------------------------------------===//
 
 void Attribute::print(raw_ostream &os) const {
-  ModuleState moduleState(os);
-  moduleState.print(this);
+  ModuleState state(/*no context is known*/ nullptr);
+  ModulePrinter(os, state).print(this);
 }
 
 void Attribute::dump() const {
@@ -689,23 +813,12 @@
 }
 
 void Type::print(raw_ostream &os) const {
-  ModuleState moduleState(os);
-  moduleState.print(this);
+  ModuleState state(getContext());
+  ModulePrinter(os, state).print(this);
 }
 
 void Type::dump() const { print(llvm::errs()); }
 
-void Instruction::print(raw_ostream &os) const {
-  ModuleState moduleState(os);
-  CFGFunctionState state(getFunction(), &moduleState, os);
-  state.print(this);
-}
-
-void Instruction::dump() const {
-  print(llvm::errs());
-  llvm::errs() << "\n";
-}
-
 void AffineMap::dump() const {
   print(llvm::errs());
   llvm::errs() << "\n";
@@ -716,163 +829,54 @@
   llvm::errs() << "\n";
 }
 
-void AffineSymbolExpr::print(raw_ostream &os) const {
-  os << 's' << getPosition();
-}
-
-void AffineDimExpr::print(raw_ostream &os) const { os << 'd' << getPosition(); }
-
-void AffineConstantExpr::print(raw_ostream &os) const { os << getValue(); }
-
-static void printAdd(const AffineBinaryOpExpr *addExpr, raw_ostream &os) {
-  os << '(' << *addExpr->getLHS();
-
-  // Pretty print addition to a product that has a negative operand as a
-  // subtraction.
-  if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(addExpr->getRHS())) {
-    if (rhs->getKind() == AffineExpr::Kind::Mul) {
-      if (auto *rrhs = dyn_cast<AffineConstantExpr>(rhs->getRHS())) {
-        if (rrhs->getValue() < 0) {
-          os << " - (" << *rhs->getLHS() << " * " << -rrhs->getValue() << "))";
-          return;
-        }
-      }
-    }
-  }
-
-  // Pretty print addition to a negative number as a subtraction.
-  if (auto *rhs = dyn_cast<AffineConstantExpr>(addExpr->getRHS())) {
-    if (rhs->getValue() < 0) {
-      os << " - " << -rhs->getValue() << ")";
-      return;
-    }
-  }
-
-  os << " + " << *addExpr->getRHS() << ")";
-}
-
-void AffineBinaryOpExpr::print(raw_ostream &os) const {
-  switch (getKind()) {
-  case Kind::Add:
-    return printAdd(this, os);
-  case Kind::Mul:
-    os << "(" << *getLHS() << " * " << *getRHS() << ")";
-    return;
-  case Kind::FloorDiv:
-    os << "(" << *getLHS() << " floordiv " << *getRHS() << ")";
-    return;
-  case Kind::CeilDiv:
-    os << "(" << *getLHS() << " ceildiv " << *getRHS() << ")";
-    return;
-  case Kind::Mod:
-    os << "(" << *getLHS() << " mod " << *getRHS() << ")";
-    return;
-  default:
-    llvm_unreachable("unexpected affine binary op expression");
-  }
-}
-
 void AffineExpr::print(raw_ostream &os) const {
-  switch (getKind()) {
-  case Kind::SymbolId:
-    return cast<AffineSymbolExpr>(this)->print(os);
-  case Kind::DimId:
-    return cast<AffineDimExpr>(this)->print(os);
-  case Kind::Constant:
-    return cast<AffineConstantExpr>(this)->print(os);
-  case Kind::Add:
-  case Kind::Mul:
-  case Kind::FloorDiv:
-  case Kind::CeilDiv:
-  case Kind::Mod:
-    return cast<AffineBinaryOpExpr>(this)->print(os);
-  }
+  ModuleState state(/*no context is known*/ nullptr);
+  ModulePrinter(os, state).print(this);
 }
 
 void AffineMap::print(raw_ostream &os) const {
-  // Dimension identifiers.
-  os << "(";
-  for (int i = 0; i < (int)getNumDims() - 1; i++) os << "d" << i << ", ";
-  if (getNumDims() >= 1) os << "d" << getNumDims() - 1;
-  os << ")";
+  ModuleState state(/*no context is known*/ nullptr);
+  ModulePrinter(os, state).print(this);
+}
 
-  // Symbolic identifiers.
-  if (getNumSymbols() >= 1) {
-    os << " [";
-    for (int i = 0; i < (int)getNumSymbols() - 1; i++) os << "s" << i << ", ";
-    if (getNumSymbols() >= 1) os << "s" << getNumSymbols() - 1;
-    os << "]";
-  }
+void Instruction::print(raw_ostream &os) const {
+  ModuleState state(getFunction()->getContext());
+  ModulePrinter modulePrinter(os, state);
+  CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
+}
 
-  // AffineMap should have at least one result.
-  assert(!getResults().empty());
-  // Result affine expressions.
-  os << " -> (";
-  interleaveComma(os, getResults(), [&](AffineExpr *expr) { os << *expr; });
-  os << ")";
-
-  if (!isBounded()) {
-    return;
-  }
-
-  // Print range sizes for bounded affine maps.
-  os << " size (";
-  interleaveComma(os, getRangeSizes(), [&](AffineExpr *expr) { os << *expr; });
-  os << ")";
+void Instruction::dump() const {
+  print(llvm::errs());
+  llvm::errs() << "\n";
 }
 
 void BasicBlock::print(raw_ostream &os) const {
-  ModuleState moduleState(os);
-  CFGFunctionState state(getFunction(), &moduleState, os);
-  state.print();
+  ModuleState state(getFunction()->getContext());
+  ModulePrinter modulePrinter(os, state);
+  CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
 }
 
 void BasicBlock::dump() const { print(llvm::errs()); }
 
 void Statement::print(raw_ostream &os) const {
-  ModuleState moduleState(os);
-  MLFunctionState state(getFunction(), &moduleState, os);
-  state.print(this);
+  ModuleState state(getFunction()->getContext());
+  ModulePrinter modulePrinter(os, state);
+  MLFunctionPrinter(getFunction(), modulePrinter).print(this);
 }
 
 void Statement::dump() const { print(llvm::errs()); }
 
 void Function::print(raw_ostream &os) const {
-  switch (getKind()) {
-  case Kind::ExtFunc:
-    return cast<ExtFunction>(this)->print(os);
-  case Kind::CFGFunc:
-    return cast<CFGFunction>(this)->print(os);
-  case Kind::MLFunc:
-    return cast<MLFunction>(this)->print(os);
-  }
+  ModuleState state(getContext());
+  ModulePrinter(os, state).print(this);
 }
 
 void Function::dump() const { print(llvm::errs()); }
 
-void ExtFunction::print(raw_ostream &os) const {
-  ModuleState moduleState(os);
-  os << "extfunc ";
-  printFunctionSignature(this, &moduleState, os);
-  os << "\n";
-}
-
-void CFGFunction::print(raw_ostream &os) const {
-  ModuleState moduleState(os);
-  CFGFunctionState state(this, &moduleState, os);
-  state.print();
-}
-
-void MLFunction::print(raw_ostream &os) const {
-  ModuleState moduleState(os);
-  MLFunctionState state(this, &moduleState, os);
-  state.print();
-}
-
 void Module::print(raw_ostream &os) const {
-  ModuleState moduleState(os);
-  moduleState.initialize(this);
-  moduleState.print(this);
+  ModuleState state(getContext());
+  state.initialize(this);
+  ModulePrinter(os, state).print(this);
 }
 
 void Module::dump() const { print(llvm::errs()); }