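Follow up on the AsmPrinter changes from the last CL.

Replace the static printExtFunction/printCFGFunction/printMLFunction helpers
with ModuleState::print overloads, and reformat the file.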

PiperOrigin-RevId: 205096519
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index be6573f..baddac9 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -34,37 +34,35 @@
 #include "llvm/Support/raw_ostream.h"
 using namespace mlir;
 
+void Identifier::print(raw_ostream &os) const { os << str(); }
 
-void Identifier::print(raw_ostream &os) const {
-  os << str();
-}
-
-void Identifier::dump() const {
-  print(llvm::errs());
-}
+void Identifier::dump() const { print(llvm::errs()); }
 
 //===----------------------------------------------------------------------===//
 // Module printing
 //===----------------------------------------------------------------------===//
 
 namespace {
-class ModuleState  {
- public:
+class ModuleState {
+public:
   ModuleState(raw_ostream &os);
 
   void initialize(const Module *module);
 
   void print(const Module *module);
   void print(const Type *type) const;
-  void print(const Function* fn);
+  void print(const Function *fn);
+  void print(const ExtFunction *fn);
+  void print(const CFGFunction *fn);
+  void print(const MLFunction *fn);
 
-  void recordAffineMapReference(const AffineMap* affineMap) {
+  void recordAffineMapReference(const AffineMap *affineMap) {
     if (affineMapIds.count(affineMap) == 0) {
       affineMapIds[affineMap] = nextAffineMapId++;
     }
   }
 
-  int getAffineMapId(const AffineMap* affineMap) const {
+  int getAffineMapId(const AffineMap *affineMap) const {
     auto it = affineMapIds.find(affineMap);
     if (it == affineMapIds.end()) {
       return -1;
@@ -72,7 +70,7 @@
     return it->second;
   }
 
- private:
+private:
   // Visit functions.
   void visitFunction(const Function *fn);
   void visitExtFunction(const ExtFunction *fn);
@@ -81,13 +79,12 @@
   void visitType(const Type *type);
 
   raw_ostream &os;
-  DenseMap<const AffineMap*, int> affineMapIds;
+  DenseMap<const AffineMap *, int> affineMapIds;
   int nextAffineMapId = 0;
 };
-} // end anonymous namespace
+}  // end anonymous namespace
 
-ModuleState::ModuleState(raw_ostream &os) : os(os) {
-}
+ModuleState::ModuleState(raw_ostream &os) : os(os) {}
 
 // Initializes module state, populating affine map state.
 void ModuleState::initialize(const Module *module) {
@@ -101,16 +98,16 @@
   if (type->getKind() == Type::Kind::Function) {
     // Visit input and result types for functions.
     auto *funcType = cast<FunctionType>(type);
-    for (auto* input : funcType->getInputs()) {
+    for (auto *input : funcType->getInputs()) {
       visitType(input);
     }
-    for (auto* result : funcType->getResults()) {
+    for (auto *result : funcType->getResults()) {
       visitType(result);
     }
   } else if (type->getKind() == Type::Kind::MemRef) {
     // Visit affine maps in memref type.
     auto *memref = cast<MemRefType>(type);
-    for (AffineMap* map : memref->getAffineMaps()) {
+    for (AffineMap *map : memref->getAffineMaps()) {
       recordAffineMapReference(map);
     }
   }
@@ -132,34 +129,24 @@
 
 void ModuleState::visitFunction(const Function *fn) {
   switch (fn->getKind()) {
-    case Function::Kind::ExtFunc:
-      return visitExtFunction(cast<ExtFunction>(fn));
-    case Function::Kind::CFGFunc:
-      return visitCFGFunction(cast<CFGFunction>(fn));
-    case Function::Kind::MLFunc:
-      return visitMLFunction(cast<MLFunction>(fn));
+  case Function::Kind::ExtFunc:
+    return visitExtFunction(cast<ExtFunction>(fn));
+  case Function::Kind::CFGFunc:
+    return visitCFGFunction(cast<CFGFunction>(fn));
+  case Function::Kind::MLFunc:
+    return visitMLFunction(cast<MLFunction>(fn));
   }
 }
 
-static void printExtFunction(const ExtFunction* fn,
-                             const ModuleState* moduleState, raw_ostream &os);
-
-
-static void printCFGFunction(const CFGFunction* fn,
-                             const ModuleState* moduleState, raw_ostream &os);
-
-static void printMLFunction(const MLFunction* fn,
-                            const ModuleState* moduleState, raw_ostream &os);
-
 // Prints function with initialized module state.
-void ModuleState::print(const Function* fn) {
+void ModuleState::print(const Function *fn) {
   switch (fn->getKind()) {
-    case Function::Kind::ExtFunc:
-      return printExtFunction(cast<ExtFunction>(fn), this, os);
-    case Function::Kind::CFGFunc:
-      return printCFGFunction(cast<CFGFunction>(fn), this, os);
-    case Function::Kind::MLFunc:
-      return printMLFunction(cast<MLFunction>(fn), this, os);
+  case Function::Kind::ExtFunc:
+    return print(cast<ExtFunction>(fn));
+  case Function::Kind::CFGFunc:
+    return print(cast<CFGFunction>(fn));
+  case Function::Kind::MLFunc:
+    return print(cast<MLFunction>(fn));
   }
 }
 
@@ -169,23 +156,32 @@
 }
 
 void ModuleState::print(const Module *module) {
-  for (const auto& mapAndId : affineMapIds) {
+  for (const auto &mapAndId : affineMapIds) {
     printAffineMapId(mapAndId.second, os);
     os << " = ";
     mapAndId.first->print(os);
     os << '\n';
   }
-  for (auto *fn : module->functionList)
-    print(fn);
+  for (auto *fn : module->functionList) print(fn);
 }
 
 void ModuleState::print(const Type *type) const {
   switch (type->getKind()) {
-  case Type::Kind::AffineInt: os << "affineint"; return;
-  case Type::Kind::BF16: os << "bf16"; return;
-  case Type::Kind::F16:  os << "f16"; return;
-  case Type::Kind::F32:  os << "f32"; return;
-  case Type::Kind::F64:  os << "f64"; return;
+  case Type::Kind::AffineInt:
+    os << "affineint";
+    return;
+  case Type::Kind::BF16:
+    os << "bf16";
+    return;
+  case Type::Kind::F16:
+    os << "f16";
+    return;
+  case Type::Kind::F32:
+    os << "f32";
+    return;
+  case Type::Kind::F64:
+    os << "f64";
+    return;
 
   case Type::Kind::Integer: {
     auto *integer = cast<IntegerType>(type);
@@ -195,8 +191,7 @@
   case Type::Kind::Function: {
     auto *func = cast<FunctionType>(type);
     os << '(';
-    interleave(func->getInputs(),
-               [&](Type *type) { os << *type; },
+    interleave(func->getInputs(), [&](Type *type) { os << *type; },
                [&]() { os << ", "; });
     os << ") -> ";
     auto results = func->getResults();
@@ -204,8 +199,7 @@
       os << *results[0];
     else {
       os << '(';
-      interleave(results,
-                 [&](Type *type) { os << *type; },
+      interleave(results, [&](Type *type) { os << *type; },
                  [&]() { os << ", "; });
       os << ')';
     }
@@ -214,8 +208,7 @@
   case Type::Kind::Vector: {
     auto *v = cast<VectorType>(type);
     os << "vector<";
-    for (auto dim : v->getShape())
-      os << dim << 'x';
+    for (auto dim : v->getShape()) os << dim << 'x';
     os << *v->getElementType() << '>';
     return;
   }
@@ -282,7 +275,8 @@
   os << ')';
 
   switch (type->getResults().size()) {
-  case 0: break;
+  case 0:
+    break;
   case 1:
     os << " -> ";
     moduleState->print(type->getResults()[0]);
@@ -297,11 +291,10 @@
   }
 }
 
-void ExtFunction::print(raw_ostream &os) const {
-  ModuleState moduleState(os);
+void ModuleState::print(const ExtFunction *fn) {
   os << "extfunc ";
-  printFunctionSignature(this, &moduleState, os);
-  os << "\n";
+  printFunctionSignature(fn, this, os);
+  os << '\n';
 }
 
 namespace {
@@ -320,12 +313,12 @@
   const ModuleState *moduleState;
   const OperationSet &operationSet;
 };
-} // end anonymous namespace
+}  // end anonymous namespace
 
 FunctionState::FunctionState(MLIRContext *context,
-                             const ModuleState *moduleState,
-                             raw_ostream &os)
-    : os(os), moduleState(moduleState),
+                             const ModuleState *moduleState, raw_ostream &os)
+    : os(os),
+      moduleState(moduleState),
       operationSet(OperationSet::get(context)) {}
 
 void FunctionState::printOperation(const Operation *op) {
@@ -379,9 +372,9 @@
 
 private:
   const CFGFunction *function;
-  DenseMap<const BasicBlock*, unsigned> basicBlockIDs;
+  DenseMap<const BasicBlock *, unsigned> basicBlockIDs;
 };
-} // end anonymous namespace
+}  // end anonymous namespace
 
 CFGFunctionState::CFGFunctionState(const CFGFunction *function,
                                    const ModuleState *moduleState,
@@ -390,8 +383,7 @@
       function(function) {
   // Each basic block gets a unique ID per function.
   unsigned blockID = 0;
-  for (auto &block : *function)
-    basicBlockIDs[&block] = blockID++;
+  for (auto &block : *function) basicBlockIDs[&block] = blockID++;
 }
 
 void CFGFunctionState::print() {
@@ -399,8 +391,7 @@
   printFunctionSignature(this->getFunction(), moduleState, os);
   os << " {\n";
 
-  for (auto &block : *function)
-    print(&block);
+  for (auto &block : *function) print(&block);
   os << "}\n\n";
 }
 
@@ -435,8 +426,11 @@
 void CFGFunctionState::print(const BranchInst *inst) {
   os << "  br bb" << getBBID(inst->getDest());
 }
-void CFGFunctionState::print(const ReturnInst *inst) {
-  os << "  return";
+void CFGFunctionState::print(const ReturnInst *inst) { os << "  return"; }
+
+void ModuleState::print(const CFGFunction *fn) {
+  CFGFunctionState state(fn, this, os);
+  state.print();
 }
 
 //===----------------------------------------------------------------------===//
@@ -468,13 +462,14 @@
   const MLFunction *function;
   int numSpaces;
 };
-} // end anonymous namespace
+}  // end anonymous namespace
 
 MLFunctionState::MLFunctionState(const MLFunction *function,
                                  const ModuleState *moduleState,
                                  raw_ostream &os)
     : FunctionState(function->getContext(), moduleState, os),
-      function(function), numSpaces(0) {}
+      function(function),
+      numSpaces(0) {}
 
 void MLFunctionState::print() {
   os << "mlfunc ";
@@ -506,9 +501,7 @@
   }
 }
 
-void MLFunctionState::print(const OperationStmt *stmt) {
-  printOperation(stmt);
-}
+void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
 
 void MLFunctionState::print(const ForStmt *stmt) {
   os.indent(numSpaces) << "for {\n";
@@ -527,22 +520,8 @@
   }
 }
 
-void printExtFunction(const ExtFunction* fn, const ModuleState* moduleState,
-                      raw_ostream &os) {
-  os << "extfunc ";
-  printFunctionSignature(fn, moduleState, os);
-  os << '\n';
-}
-
-void printCFGFunction(const CFGFunction* fn, const ModuleState* moduleState,
-                      raw_ostream &os) {
-  CFGFunctionState state(fn, moduleState, os);
-  state.print();
-}
-
-void printMLFunction(const MLFunction* fn, const ModuleState* moduleState,
-                     raw_ostream &os) {
-  MLFunctionState state(fn, moduleState, os);
+void ModuleState::print(const MLFunction *fn) {
+  MLFunctionState state(fn, this, os);
   state.print();
 }
 
@@ -555,9 +534,7 @@
   moduleState.print(this);
 }
 
-void Type::dump() const {
-  print(llvm::errs());
-}
+void Type::dump() const { print(llvm::errs()); }
 
 void Instruction::print(raw_ostream &os) const {
   ModuleState moduleState(os);
@@ -638,19 +615,15 @@
 void AffineMap::print(raw_ostream &os) const {
   // Dimension identifiers.
   os << "(";
-  for (int i = 0; i < (int)getNumDims() - 1; i++)
-    os << "d" << i << ", ";
-  if (getNumDims() >= 1)
-    os << "d" << getNumDims() - 1;
+  for (int i = 0; i < (int)getNumDims() - 1; i++) os << "d" << i << ", ";
+  if (getNumDims() >= 1) os << "d" << getNumDims() - 1;
   os << ")";
 
   // Symbolic identifiers.
   if (getNumSymbols() >= 1) {
     os << " [";
-    for (int i = 0; i < (int)getNumSymbols() - 1; i++)
-      os << "s" << i << ", ";
-    if (getNumSymbols() >= 1)
-      os << "s" << getNumSymbols() - 1;
+    for (int i = 0; i < (int)getNumSymbols() - 1; i++) os << "s" << i << ", ";
+    if (getNumSymbols() >= 1) os << "s" << getNumSymbols() - 1;
     os << "]";
   }
 
@@ -679,9 +652,7 @@
   state.print();
 }
 
-void BasicBlock::dump() const {
-  print(llvm::errs());
-}
+void BasicBlock::dump() const { print(llvm::errs()); }
 
 void Statement::print(raw_ostream &os) const {
   ModuleState moduleState(os);
@@ -689,20 +660,26 @@
   state.print(this);
 }
 
-void Statement::dump() const {
-  print(llvm::errs());
-}
+void Statement::dump() const { print(llvm::errs()); }
 
 void Function::print(raw_ostream &os) const {
   switch (getKind()) {
-  case Kind::ExtFunc: return cast<ExtFunction>(this)->print(os);
-  case Kind::CFGFunc: return cast<CFGFunction>(this)->print(os);
-  case Kind::MLFunc:  return cast<MLFunction>(this)->print(os);
+  case Kind::ExtFunc:
+    return cast<ExtFunction>(this)->print(os);
+  case Kind::CFGFunc:
+    return cast<CFGFunction>(this)->print(os);
+  case Kind::MLFunc:
+    return cast<MLFunction>(this)->print(os);
   }
 }
 
-void Function::dump() const {
-  print(llvm::errs());
+void Function::dump() const { print(llvm::errs()); }
+
+void ExtFunction::print(raw_ostream &os) const {
+  // Reuse ModuleState's ExtFunction printer (the one used when printing a
+  // whole module) instead of duplicating the "extfunc" formatting here.
+  ModuleState moduleState(os);
+  moduleState.print(this);
 }
 
 void CFGFunction::print(raw_ostream &os) const {
@@ -723,6 +700,4 @@
   moduleState.print(this);
 }
 
-void Module::dump() const {
-  print(llvm::errs());
-}
+void Module::dump() const { print(llvm::errs()); }