Expose custom asm printer support (the OpAsmPrinter interface) to core
operations and have them adopt it, fixing the printed syntax of dim, constant,
fadd, etc.

PiperOrigin-RevId: 205908627
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 1d61213..fa4462d 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -26,6 +26,7 @@
 #include "mlir/IR/CFGFunction.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSet.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/Types.h"
@@ -38,6 +39,8 @@
 
 void Identifier::dump() const { print(llvm::errs()); }
 
+OpAsmPrinter::~OpAsmPrinter() {}
+
 //===----------------------------------------------------------------------===//
 // ModuleState
 //===----------------------------------------------------------------------===//
@@ -176,15 +179,15 @@
   }
 
   void print(const Module *module);
-  void print(const Attribute *attr) const;
-  void print(const Type *type) const;
+  void printAttribute(const Attribute *attr);
+  void printType(const Type *type);
   void print(const Function *fn);
   void print(const ExtFunction *fn);
   void print(const CFGFunction *fn);
   void print(const MLFunction *fn);
 
-  void print(const AffineMap *map);
-  void print(const AffineExpr *expr) const;
+  void printAffineMap(const AffineMap *map);
+  void printAffineExpr(const AffineExpr *expr);
 
 protected:
   raw_ostream &os;
@@ -192,9 +195,9 @@
 
   void printFunctionSignature(const Function *fn);
   void printAffineMapId(int affineMapId) const;
-  void printAffineMapReference(const AffineMap *affineMap) const;
+  void printAffineMapReference(const AffineMap *affineMap);
 
-  void print(const AffineBinaryOpExpr *expr) const;
+  void printAffineBinaryOpExpr(const AffineBinaryOpExpr *expr);
 };
 } // end anonymous namespace
 
@@ -215,7 +218,7 @@
   os << "#map" << affineMapId;
 }
 
-void ModulePrinter::printAffineMapReference(const AffineMap *affineMap) const {
+void ModulePrinter::printAffineMapReference(const AffineMap *affineMap) {
   int mapId = state.getAffineMapId(affineMap);
   if (mapId >= 0) {
     // Map will be printed at top of module so print reference to its id.
@@ -237,7 +240,7 @@
     print(fn);
 }
 
-void ModulePrinter::print(const Attribute *attr) const {
+void ModulePrinter::printAttribute(const Attribute *attr) {
   switch (attr->getKind()) {
   case Attribute::Kind::Bool:
     os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
@@ -256,7 +259,7 @@
   case Attribute::Kind::Array: {
     auto elts = cast<ArrayAttr>(attr)->getValue();
     os << '[';
-    interleaveComma(elts, [&](Attribute *attr) { print(attr); });
+    interleaveComma(elts, [&](Attribute *attr) { printAttribute(attr); });
     os << ']';
     break;
   }
@@ -266,7 +269,7 @@
   }
 }
 
-void ModulePrinter::print(const Type *type) const {
+void ModulePrinter::printType(const Type *type) {
   switch (type->getKind()) {
   case Type::Kind::AffineInt:
     os << "affineint";
@@ -356,7 +359,7 @@
 // Affine expressions and maps
 //===----------------------------------------------------------------------===//
 
-void ModulePrinter::print(const AffineExpr *expr) const {
+void ModulePrinter::printAffineExpr(const AffineExpr *expr) {
   switch (expr->getKind()) {
   case AffineExpr::Kind::SymbolId:
     os << 's' << cast<AffineSymbolExpr>(expr)->getPosition();
@@ -372,14 +375,14 @@
   case AffineExpr::Kind::FloorDiv:
   case AffineExpr::Kind::CeilDiv:
   case AffineExpr::Kind::Mod:
-    return print(cast<AffineBinaryOpExpr>(expr));
+    return printAffineBinaryOpExpr(cast<AffineBinaryOpExpr>(expr));
   }
 }
 
-void ModulePrinter::print(const AffineBinaryOpExpr *expr) const {
+void ModulePrinter::printAffineBinaryOpExpr(const AffineBinaryOpExpr *expr) {
   if (expr->getKind() != AffineExpr::Kind::Add) {
     os << '(';
-    print(expr->getLHS());
+    printAffineExpr(expr->getLHS());
     switch (expr->getKind()) {
     case AffineExpr::Kind::Mul:
       os << " * ";
@@ -397,14 +400,14 @@
       llvm_unreachable("unexpected affine binary op expression");
     }
 
-    print(expr->getRHS());
+    printAffineExpr(expr->getRHS());
     os << ')';
     return;
   }
 
   // Print out special "pretty" forms for add.
   os << '(';
-  print(expr->getLHS());
+  printAffineExpr(expr->getLHS());
 
   // Pretty print addition to a product that has a negative operand as a
   // subtraction.
@@ -413,7 +416,7 @@
       if (auto *rrhs = dyn_cast<AffineConstantExpr>(rhs->getRHS())) {
         if (rrhs->getValue() < 0) {
           os << " - (";
-          print(rhs->getLHS());
+          printAffineExpr(rhs->getLHS());
           os << " * " << -rrhs->getValue() << "))";
           return;
         }
@@ -430,11 +433,11 @@
   }
 
   os << " + ";
-  print(expr->getRHS());
+  printAffineExpr(expr->getRHS());
   os << ')';
 }
 
-void ModulePrinter::print(const AffineMap *map) {
+void ModulePrinter::printAffineMap(const AffineMap *map) {
   // Dimension identifiers.
   os << '(';
   for (int i = 0; i < (int)map->getNumDims() - 1; i++)
@@ -457,7 +460,8 @@
   assert(!map->getResults().empty());
   // Result affine expressions.
   os << " -> (";
-  interleaveComma(map->getResults(), [&](AffineExpr *expr) { print(expr); });
+  interleaveComma(map->getResults(),
+                  [&](AffineExpr *expr) { printAffineExpr(expr); });
   os << ")";
 
   if (!map->isBounded()) {
@@ -466,7 +470,8 @@
 
   // Print range sizes for bounded affine maps.
   os << " size (";
-  interleaveComma(map->getRangeSizes(), [&](AffineExpr *expr) { print(expr); });
+  interleaveComma(map->getRangeSizes(),
+                  [&](AffineExpr *expr) { printAffineExpr(expr); });
   os << ")";
 }
 
@@ -478,7 +483,8 @@
   auto type = fn->getType();
 
   os << "@" << fn->getName() << '(';
-  interleaveComma(type->getInputs(), [&](Type *eltType) { print(eltType); });
+  interleaveComma(type->getInputs(),
+                  [&](Type *eltType) { printType(eltType); });
   os << ')';
 
   switch (type->getResults().size()) {
@@ -486,11 +492,12 @@
     break;
   case 1:
     os << " -> ";
-    print(type->getResults()[0]);
+    printType(type->getResults()[0]);
     break;
   default:
     os << " -> (";
-    interleaveComma(type->getResults(), [&](Type *eltType) { print(eltType); });
+    interleaveComma(type->getResults(),
+                    [&](Type *eltType) { printType(eltType); });
     os << ')';
     break;
   }
@@ -504,13 +511,29 @@
 
 namespace {
 
-// FunctionState contains common functionality for printing
+// FunctionPrinter contains common functionality for printing
 // CFG and ML functions.
-class FunctionState : public ModulePrinter {
+class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
 public:
-  FunctionState(const ModulePrinter &other) : ModulePrinter(other) {}
+  FunctionPrinter(const ModulePrinter &other) : ModulePrinter(other) {}
 
   void printOperation(const Operation *op);
+  void printDefaultOp(const Operation *op);
+
+  // Implement OpAsmPrinter.
+  raw_ostream &getStream() const { return os; }
+  void printType(const Type *type) { ModulePrinter::printType(type); }
+  void printAttribute(const Attribute *attr) {
+    ModulePrinter::printAttribute(attr);
+  }
+  void printAffineMap(const AffineMap *map) {
+    return ModulePrinter::printAffineMap(map);
+  }
+  void printAffineExpr(const AffineExpr *expr) {
+    return ModulePrinter::printAffineExpr(expr);
+  }
+
+  void printOperand(const SSAValue *value) { printValueID(value); }
 
 protected:
   void numberValueID(const SSAValue *value) {
@@ -551,7 +574,7 @@
 };
 } // end anonymous namespace
 
-void FunctionState::printOperation(const Operation *op) {
+void FunctionPrinter::printOperation(const Operation *op) {
   os << "  ";
 
   if (op->getNumResults()) {
@@ -562,12 +585,15 @@
   // Check to see if this is a known operation.  If so, use the registered
   // custom printer hook.
   if (auto opInfo = state.operationSet->lookup(op->getName().str())) {
-    opInfo->printAssembly(op, os);
+    opInfo->printAssembly(op, this);
     return;
   }
 
   // Otherwise use the standard verbose printing approach.
+  printDefaultOp(op);
+}
 
+void FunctionPrinter::printDefaultOp(const Operation *op) {
   // TODO: escape name if necessary.
   os << "\"" << op->getName().str() << "\"(";
 
@@ -580,7 +606,7 @@
     os << '{';
     interleaveComma(attrs, [&](NamedAttribute attr) {
       os << attr.first << ": ";
-      print(attr.second);
+      printAttribute(attr.second);
     });
     os << '}';
   }
@@ -588,15 +614,16 @@
   // Print the type signature of the operation.
   os << " : (";
   interleaveComma(op->getOperands(),
-                  [&](const SSAValue *value) { print(value->getType()); });
+                  [&](const SSAValue *value) { printType(value->getType()); });
   os << ") -> ";
 
   if (op->getNumResults() == 1) {
-    print(op->getResult(0)->getType());
+    printType(op->getResult(0)->getType());
   } else {
     os << '(';
-    interleaveComma(op->getResults(),
-                    [&](const SSAValue *result) { print(result->getType()); });
+    interleaveComma(op->getResults(), [&](const SSAValue *result) {
+      printType(result->getType());
+    });
     os << ')';
   }
 }
@@ -606,7 +633,7 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-class CFGFunctionPrinter : public FunctionState {
+class CFGFunctionPrinter : public FunctionPrinter {
 public:
   CFGFunctionPrinter(const CFGFunction *function, const ModulePrinter &other);
 
@@ -637,7 +664,7 @@
 
 CFGFunctionPrinter::CFGFunctionPrinter(const CFGFunction *function,
                                        const ModulePrinter &other)
-    : FunctionState(other), function(function) {
+    : FunctionPrinter(other), function(function) {
   // Each basic block gets a unique ID per function.
   unsigned blockID = 0;
   for (auto &block : *function) {
@@ -679,7 +706,7 @@
     interleaveComma(block->getArguments(), [&](const BBArgument *arg) {
       printValueID(arg);
       os << ": ";
-      ModulePrinter::print(arg->getType());
+      printType(arg->getType());
     });
     os << ')';
   }
@@ -722,7 +749,7 @@
     });
     os << ") : ";
     interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
-      ModulePrinter::print(operand.get()->getType());
+      printType(operand.get()->getType());
     });
   }
 }
@@ -738,7 +765,7 @@
                     [&](const CFGValue *operand) { printValueID(operand); });
     os << " : ";
     interleaveComma(inst->getTrueOperands(), [&](const CFGValue *operand) {
-      ModulePrinter::print(operand->getType());
+      printType(operand->getType());
     });
     os << ")";
   }
@@ -750,7 +777,7 @@
                     [&](const CFGValue *operand) { printValueID(operand); });
     os << " : ";
     interleaveComma(inst->getFalseOperands(), [&](const CFGValue *operand) {
-      ModulePrinter::print(operand->getType());
+      printType(operand->getType());
     });
     os << ")";
   }
@@ -766,7 +793,7 @@
                   [&](const CFGValue *operand) { printValueID(operand); });
   os << " : ";
   interleaveComma(inst->getOperands(), [&](const CFGValue *operand) {
-    ModulePrinter::print(operand->getType());
+    printType(operand->getType());
   });
 }
 
@@ -779,7 +806,7 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-class MLFunctionPrinter : public FunctionState {
+class MLFunctionPrinter : public FunctionPrinter {
 public:
   MLFunctionPrinter(const MLFunction *function, const ModulePrinter &other);
 
@@ -806,7 +833,7 @@
 
 MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
                                      const ModulePrinter &other)
-    : FunctionState(other), function(function), numSpaces(0) {}
+    : FunctionPrinter(other), function(function), numSpaces(0) {}
 
 void MLFunctionPrinter::print() {
   os << "mlfunc ";
@@ -874,14 +901,14 @@
 
 void Attribute::print(raw_ostream &os) const {
   ModuleState state(/*no context is known*/ nullptr);
-  ModulePrinter(os, state).print(this);
+  ModulePrinter(os, state).printAttribute(this);
 }
 
 void Attribute::dump() const { print(llvm::errs()); }
 
 void Type::print(raw_ostream &os) const {
   ModuleState state(getContext());
-  ModulePrinter(os, state).print(this);
+  ModulePrinter(os, state).printType(this);
 }
 
 void Type::dump() const { print(llvm::errs()); }
@@ -898,12 +925,12 @@
 
 void AffineExpr::print(raw_ostream &os) const {
   ModuleState state(/*no context is known*/ nullptr);
-  ModulePrinter(os, state).print(this);
+  ModulePrinter(os, state).printAffineExpr(this);
 }
 
 void AffineMap::print(raw_ostream &os) const {
   ModuleState state(/*no context is known*/ nullptr);
-  ModulePrinter(os, state).print(this);
+  ModulePrinter(os, state).printAffineMap(this);
 }
 
 void Instruction::print(raw_ostream &os) const {