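Expose custom AsmPrinter support to core operations and have them adopt it,
fixing the printing syntax for dim, constant, fadd, etc.  ModulePrinter's
overloaded print() methods are renamed to printAttribute, printType,
printAffineMap and printAffineExpr to match the OpAsmPrinter hooks that
FunctionPrinter now implements.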
PiperOrigin-RevId: 205908627
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 1d61213..fa4462d 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Types.h"
@@ -38,6 +39,8 @@
void Identifier::dump() const { print(llvm::errs()); }
+OpAsmPrinter::~OpAsmPrinter() {}
+
//===----------------------------------------------------------------------===//
// ModuleState
//===----------------------------------------------------------------------===//
@@ -176,15 +179,15 @@
}
void print(const Module *module);
- void print(const Attribute *attr) const;
- void print(const Type *type) const;
+ void printAttribute(const Attribute *attr);
+ void printType(const Type *type);
void print(const Function *fn);
void print(const ExtFunction *fn);
void print(const CFGFunction *fn);
void print(const MLFunction *fn);
- void print(const AffineMap *map);
- void print(const AffineExpr *expr) const;
+ void printAffineMap(const AffineMap *map);
+ void printAffineExpr(const AffineExpr *expr);
protected:
raw_ostream &os;
@@ -192,9 +195,9 @@
void printFunctionSignature(const Function *fn);
void printAffineMapId(int affineMapId) const;
- void printAffineMapReference(const AffineMap *affineMap) const;
+ void printAffineMapReference(const AffineMap *affineMap);
- void print(const AffineBinaryOpExpr *expr) const;
+ void printAffineBinaryOpExpr(const AffineBinaryOpExpr *expr);
};
} // end anonymous namespace
@@ -215,7 +218,7 @@
os << "#map" << affineMapId;
}
-void ModulePrinter::printAffineMapReference(const AffineMap *affineMap) const {
+void ModulePrinter::printAffineMapReference(const AffineMap *affineMap) {
int mapId = state.getAffineMapId(affineMap);
if (mapId >= 0) {
// Map will be printed at top of module so print reference to its id.
@@ -237,7 +240,7 @@
print(fn);
}
-void ModulePrinter::print(const Attribute *attr) const {
+void ModulePrinter::printAttribute(const Attribute *attr) {
switch (attr->getKind()) {
case Attribute::Kind::Bool:
os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
@@ -256,7 +259,7 @@
case Attribute::Kind::Array: {
auto elts = cast<ArrayAttr>(attr)->getValue();
os << '[';
- interleaveComma(elts, [&](Attribute *attr) { print(attr); });
+ interleaveComma(elts, [&](Attribute *attr) { printAttribute(attr); });
os << ']';
break;
}
@@ -266,7 +269,7 @@
}
}
-void ModulePrinter::print(const Type *type) const {
+void ModulePrinter::printType(const Type *type) {
switch (type->getKind()) {
case Type::Kind::AffineInt:
os << "affineint";
@@ -356,7 +359,7 @@
// Affine expressions and maps
//===----------------------------------------------------------------------===//
-void ModulePrinter::print(const AffineExpr *expr) const {
+void ModulePrinter::printAffineExpr(const AffineExpr *expr) {
switch (expr->getKind()) {
case AffineExpr::Kind::SymbolId:
os << 's' << cast<AffineSymbolExpr>(expr)->getPosition();
@@ -372,14 +375,14 @@
case AffineExpr::Kind::FloorDiv:
case AffineExpr::Kind::CeilDiv:
case AffineExpr::Kind::Mod:
- return print(cast<AffineBinaryOpExpr>(expr));
+ return printAffineBinaryOpExpr(cast<AffineBinaryOpExpr>(expr));
}
}
-void ModulePrinter::print(const AffineBinaryOpExpr *expr) const {
+void ModulePrinter::printAffineBinaryOpExpr(const AffineBinaryOpExpr *expr) {
if (expr->getKind() != AffineExpr::Kind::Add) {
os << '(';
- print(expr->getLHS());
+ printAffineExpr(expr->getLHS());
switch (expr->getKind()) {
case AffineExpr::Kind::Mul:
os << " * ";
@@ -397,14 +400,14 @@
llvm_unreachable("unexpected affine binary op expression");
}
- print(expr->getRHS());
+ printAffineExpr(expr->getRHS());
os << ')';
return;
}
// Print out special "pretty" forms for add.
os << '(';
- print(expr->getLHS());
+ printAffineExpr(expr->getLHS());
// Pretty print addition to a product that has a negative operand as a
// subtraction.
@@ -413,7 +416,7 @@
if (auto *rrhs = dyn_cast<AffineConstantExpr>(rhs->getRHS())) {
if (rrhs->getValue() < 0) {
os << " - (";
- print(rhs->getLHS());
+ printAffineExpr(rhs->getLHS());
os << " * " << -rrhs->getValue() << "))";
return;
}
@@ -430,11 +433,11 @@
}
os << " + ";
- print(expr->getRHS());
+ printAffineExpr(expr->getRHS());
os << ')';
}
-void ModulePrinter::print(const AffineMap *map) {
+void ModulePrinter::printAffineMap(const AffineMap *map) {
// Dimension identifiers.
os << '(';
for (int i = 0; i < (int)map->getNumDims() - 1; i++)
@@ -457,7 +460,8 @@
assert(!map->getResults().empty());
// Result affine expressions.
os << " -> (";
- interleaveComma(map->getResults(), [&](AffineExpr *expr) { print(expr); });
+ interleaveComma(map->getResults(),
+ [&](AffineExpr *expr) { printAffineExpr(expr); });
os << ")";
if (!map->isBounded()) {
@@ -466,7 +470,8 @@
// Print range sizes for bounded affine maps.
os << " size (";
- interleaveComma(map->getRangeSizes(), [&](AffineExpr *expr) { print(expr); });
+ interleaveComma(map->getRangeSizes(),
+ [&](AffineExpr *expr) { printAffineExpr(expr); });
os << ")";
}
@@ -478,7 +483,8 @@
auto type = fn->getType();
os << "@" << fn->getName() << '(';
- interleaveComma(type->getInputs(), [&](Type *eltType) { print(eltType); });
+ interleaveComma(type->getInputs(),
+ [&](Type *eltType) { printType(eltType); });
os << ')';
switch (type->getResults().size()) {
@@ -486,11 +492,12 @@
break;
case 1:
os << " -> ";
- print(type->getResults()[0]);
+ printType(type->getResults()[0]);
break;
default:
os << " -> (";
- interleaveComma(type->getResults(), [&](Type *eltType) { print(eltType); });
+ interleaveComma(type->getResults(),
+ [&](Type *eltType) { printType(eltType); });
os << ')';
break;
}
@@ -504,13 +511,29 @@
namespace {
-// FunctionState contains common functionality for printing
+// FunctionPrinter contains common functionality for printing
// CFG and ML functions.
-class FunctionState : public ModulePrinter {
+class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
public:
- FunctionState(const ModulePrinter &other) : ModulePrinter(other) {}
+ FunctionPrinter(const ModulePrinter &other) : ModulePrinter(other) {}
void printOperation(const Operation *op);
+ void printDefaultOp(const Operation *op);
+
+ // Implement OpAsmPrinter.
+ raw_ostream &getStream() const { return os; }
+ void printType(const Type *type) { ModulePrinter::printType(type); }
+ void printAttribute(const Attribute *attr) {
+ ModulePrinter::printAttribute(attr);
+ }
+ void printAffineMap(const AffineMap *map) { // Forwards to ModulePrinter.
+ ModulePrinter::printAffineMap(map);
+ }
+ void printAffineExpr(const AffineExpr *expr) { // Forwards to ModulePrinter.
+ ModulePrinter::printAffineExpr(expr);
+ }
+
+ void printOperand(const SSAValue *value) { printValueID(value); }
protected:
void numberValueID(const SSAValue *value) {
@@ -551,7 +574,7 @@
};
} // end anonymous namespace
-void FunctionState::printOperation(const Operation *op) {
+void FunctionPrinter::printOperation(const Operation *op) {
os << " ";
if (op->getNumResults()) {
@@ -562,12 +585,15 @@
// Check to see if this is a known operation. If so, use the registered
// custom printer hook.
if (auto opInfo = state.operationSet->lookup(op->getName().str())) {
- opInfo->printAssembly(op, os);
+ opInfo->printAssembly(op, this);
return;
}
// Otherwise use the standard verbose printing approach.
+ printDefaultOp(op);
+}
+void FunctionPrinter::printDefaultOp(const Operation *op) {
// TODO: escape name if necessary.
os << "\"" << op->getName().str() << "\"(";
@@ -580,7 +606,7 @@
os << '{';
interleaveComma(attrs, [&](NamedAttribute attr) {
os << attr.first << ": ";
- print(attr.second);
+ printAttribute(attr.second);
});
os << '}';
}
@@ -588,15 +614,16 @@
// Print the type signature of the operation.
os << " : (";
interleaveComma(op->getOperands(),
- [&](const SSAValue *value) { print(value->getType()); });
+ [&](const SSAValue *value) { printType(value->getType()); });
os << ") -> ";
if (op->getNumResults() == 1) {
- print(op->getResult(0)->getType());
+ printType(op->getResult(0)->getType());
} else {
os << '(';
- interleaveComma(op->getResults(),
- [&](const SSAValue *result) { print(result->getType()); });
+ interleaveComma(op->getResults(), [&](const SSAValue *result) {
+ printType(result->getType());
+ });
os << ')';
}
}
@@ -606,7 +633,7 @@
//===----------------------------------------------------------------------===//
namespace {
-class CFGFunctionPrinter : public FunctionState {
+class CFGFunctionPrinter : public FunctionPrinter {
public:
CFGFunctionPrinter(const CFGFunction *function, const ModulePrinter &other);
@@ -637,7 +664,7 @@
CFGFunctionPrinter::CFGFunctionPrinter(const CFGFunction *function,
const ModulePrinter &other)
- : FunctionState(other), function(function) {
+ : FunctionPrinter(other), function(function) {
// Each basic block gets a unique ID per function.
unsigned blockID = 0;
for (auto &block : *function) {
@@ -679,7 +706,7 @@
interleaveComma(block->getArguments(), [&](const BBArgument *arg) {
printValueID(arg);
os << ": ";
- ModulePrinter::print(arg->getType());
+ printType(arg->getType());
});
os << ')';
}
@@ -722,7 +749,7 @@
});
os << ") : ";
interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
- ModulePrinter::print(operand.get()->getType());
+ printType(operand.get()->getType());
});
}
}
@@ -738,7 +765,7 @@
[&](const CFGValue *operand) { printValueID(operand); });
os << " : ";
interleaveComma(inst->getTrueOperands(), [&](const CFGValue *operand) {
- ModulePrinter::print(operand->getType());
+ printType(operand->getType());
});
os << ")";
}
@@ -750,7 +777,7 @@
[&](const CFGValue *operand) { printValueID(operand); });
os << " : ";
interleaveComma(inst->getFalseOperands(), [&](const CFGValue *operand) {
- ModulePrinter::print(operand->getType());
+ printType(operand->getType());
});
os << ")";
}
@@ -766,7 +793,7 @@
[&](const CFGValue *operand) { printValueID(operand); });
os << " : ";
interleaveComma(inst->getOperands(), [&](const CFGValue *operand) {
- ModulePrinter::print(operand->getType());
+ printType(operand->getType());
});
}
@@ -779,7 +806,7 @@
//===----------------------------------------------------------------------===//
namespace {
-class MLFunctionPrinter : public FunctionState {
+class MLFunctionPrinter : public FunctionPrinter {
public:
MLFunctionPrinter(const MLFunction *function, const ModulePrinter &other);
@@ -806,7 +833,7 @@
MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
const ModulePrinter &other)
- : FunctionState(other), function(function), numSpaces(0) {}
+ : FunctionPrinter(other), function(function), numSpaces(0) {}
void MLFunctionPrinter::print() {
os << "mlfunc ";
@@ -874,14 +901,14 @@
void Attribute::print(raw_ostream &os) const {
ModuleState state(/*no context is known*/ nullptr);
- ModulePrinter(os, state).print(this);
+ ModulePrinter(os, state).printAttribute(this);
}
void Attribute::dump() const { print(llvm::errs()); }
void Type::print(raw_ostream &os) const {
ModuleState state(getContext());
- ModulePrinter(os, state).print(this);
+ ModulePrinter(os, state).printType(this);
}
void Type::dump() const { print(llvm::errs()); }
@@ -898,12 +925,12 @@
void AffineExpr::print(raw_ostream &os) const {
ModuleState state(/*no context is known*/ nullptr);
- ModulePrinter(os, state).print(this);
+ ModulePrinter(os, state).printAffineExpr(this);
}
void AffineMap::print(raw_ostream &os) const {
ModuleState state(/*no context is known*/ nullptr);
- ModulePrinter(os, state).print(this);
+ ModulePrinter(os, state).printAffineMap(this);
}
void Instruction::print(raw_ostream &os) const {