Fix some issues where we weren't printing affine map references symbolically.
Two problems: 1) we didn't visit the types in ops correctly, and 2) the
general "T" version of the OpAsmPrinter inserter would match things like
MemRefType& and print it directly.
PiperOrigin-RevId: 206863642
diff --git a/include/mlir/IR/OpImplementation.h b/include/mlir/IR/OpImplementation.h
index fc15422..21c940f 100644
--- a/include/mlir/IR/OpImplementation.h
+++ b/include/mlir/IR/OpImplementation.h
@@ -25,6 +25,7 @@
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SMLoc.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir {
class AffineMap;
@@ -97,8 +98,32 @@
return p;
}
-template <typename T>
-inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, StringRef other) {
+ p.getStream() << other;
+ return p;
+}
+
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const char *other) {
+ p.getStream() << other;
+ return p;
+}
+
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, char other) {
+ p.getStream() << other;
+ return p;
+}
+
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, unsigned other) {
+ p.getStream() << other;
+ return p;
+}
+
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, int other) {
+ p.getStream() << other;
+ return p;
+}
+
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, float other) {
p.getStream() << other;
return p;
}
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 0bc573b..99fb1df 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -118,9 +118,15 @@
}
void ModuleState::visitOperation(const Operation *op) {
- for (auto elt : op->getAttrs()) {
+ // Visit all the types used in the operation.
+ for (auto *operand : op->getOperands())
+ visitType(operand->getType());
+ for (auto *result : op->getResults())
+ visitType(result->getType());
+
+ // Visit each of the attributes.
+ for (auto elt : op->getAttrs())
visitAttribute(elt.second);
- }
}
void ModuleState::visitExtFunction(const ExtFunction *fn) {
diff --git a/test/IR/memory-ops.mlir b/test/IR/memory-ops.mlir
index 5702632..eb07fa1 100644
--- a/test/IR/memory-ops.mlir
+++ b/test/IR/memory-ops.mlir
@@ -1,25 +1,28 @@
// RUN: %S/../../mlir-opt %s -o - | FileCheck %s
+// CHECK: #map0 = (d0, d1) -> (d0, d1)
+// CHECK: #map1 = (d0, d1)[s0] -> (d0 + s0, d1)
+
// CHECK-LABEL: cfgfunc @alloc() {
cfgfunc @alloc() {
bb0:
// Test simple alloc.
- // CHECK: %0 = alloc() : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
+ // CHECK: %0 = alloc() : memref<1024x64xf32, #map0, 1>
%0 = alloc() : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
%1 = "constant"() {value: 0} : () -> affineint
%2 = "constant"() {value: 1} : () -> affineint
// Test alloc with dynamic dimensions.
- // CHECK: %3 = alloc(%1, %2) : memref<?x?xf32, (d0, d1) -> (d0, d1), 1>
+ // CHECK: %3 = alloc(%1, %2) : memref<?x?xf32, #map0, 1>
%3 = alloc(%1, %2) : memref<?x?xf32, (d0, d1) -> (d0, d1), 1>
// Test alloc with no dynamic dimensions and one symbol.
- // CHECK: %4 = alloc()[%1] : memref<2x4xf32, (d0, d1)[s0] -> (d0 + s0, d1), 1>
+ // CHECK: %4 = alloc()[%1] : memref<2x4xf32, #map1, 1>
%4 = alloc()[%1] : memref<2x4xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
// Test alloc with dynamic dimensions and one symbol.
- // CHECK: %5 = alloc(%2)[%1] : memref<2x?xf32, (d0, d1)[s0] -> (d0 + s0, d1), 1>
+ // CHECK: %5 = alloc(%2)[%1] : memref<2x?xf32, #map1, 1>
%5 = alloc(%2)[%1] : memref<2x?xf32, (d0, d1)[s0] -> (d0 + s0, d1), 1>
// CHECK: return
@@ -29,16 +32,16 @@
// CHECK-LABEL: cfgfunc @load_store
cfgfunc @load_store() {
bb0:
- // CHECK: %0 = alloc() : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
+ // CHECK: %0 = alloc() : memref<1024x64xf32, #map0, 1>
%0 = alloc() : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
%1 = "constant"() {value: 0} : () -> affineint
%2 = "constant"() {value: 1} : () -> affineint
- // CHECK: %3 = load %0[%1, %2] : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
+ // CHECK: %3 = load %0[%1, %2] : memref<1024x64xf32, #map0, 1>
%3 = load %0[%1, %2] : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
- // CHECK: store %3, %0[%1, %2] : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
+ // CHECK: store %3, %0[%1, %2] : memref<1024x64xf32, #map0, 1>
store %3, %0[%1, %2] : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
return