[mlir] Fix a use-after-free iterator error found by asan
While fixing this the parser-affine-map.mlir test started failing due to ordering of the printed affine maps. Even the existing CHECK-DAGs weren't enough to disambiguate; a partial match on one line precluded a total match on a following line.
The fix for this was easy - print the affine maps in reference order rather than in DenseMap iteration order.
PiperOrigin-RevId: 205843770
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 55cc98c..4a21b31 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -63,14 +63,13 @@
return it->second;
}
- const DenseMap<const AffineMap *, int> &getAffineMapIds() const {
- return affineMapIds;
- }
+ ArrayRef<const AffineMap *> getAffineMapIds() const { return affineMapsById; }
private:
void recordAffineMapReference(const AffineMap *affineMap) {
if (affineMapIds.count(affineMap) == 0) {
- affineMapIds[affineMap] = nextAffineMapId++;
+ affineMapIds[affineMap] = affineMapsById.size();
+ affineMapsById.push_back(affineMap);
}
}
@@ -84,7 +83,7 @@
void visitOperation(const Operation *op);
DenseMap<const AffineMap *, int> affineMapIds;
- int nextAffineMapId = 0;
+ std::vector<const AffineMap *> affineMapsById;
};
} // end anonymous namespace
@@ -228,10 +227,10 @@
}
void ModulePrinter::print(const Module *module) {
- for (const auto &mapAndId : state.getAffineMapIds()) {
- printAffineMapId(mapAndId.second);
+ for (const auto &map : state.getAffineMapIds()) {
+ printAffineMapId(state.getAffineMapId(map));
os << " = ";
- mapAndId.first->print(os);
+ map->print(os);
os << '\n';
}
for (auto *fn : module->functionList)
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index a74be26..31b94e8 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -669,11 +669,11 @@
// Check if we already have this affine expression.
auto keyValue = std::make_tuple((unsigned)kind, lhs, rhs);
- auto *&result = impl.affineExprs[keyValue];
+ auto **result = &impl.affineExprs[keyValue];
// If we already have it, return that value.
- if (result)
- return result;
+ if (*result)
+ return *result;
// Simplify the expression if possible.
AffineExpr *simplified;
@@ -697,6 +697,9 @@
llvm_unreachable("unexpected binary affine expr");
}
+ // The recursive calls above may have invalidated the 'result' pointer.
+ result = &impl.affineExprs[keyValue];
+
// If simplified to a non-binary affine op expr, don't store it.
if (simplified && !isa<AffineBinaryOpExpr>(simplified)) {
// 'affineExprs' only contains uniqued AffineBinaryOpExpr's.
@@ -705,13 +708,13 @@
if (simplified)
// We know that it's a binary op expression.
- return result = simplified;
+ return *result = simplified;
// On the first use, we allocate them into the bump pointer.
- result = impl.allocator.Allocate<AffineBinaryOpExpr>();
+ *result = impl.allocator.Allocate<AffineBinaryOpExpr>();
// Initialize the memory using placement new.
- new (result) AffineBinaryOpExpr(kind, lhs, rhs);
- return result;
+ new (*result) AffineBinaryOpExpr(kind, lhs, rhs);
+ return *result;
}
AffineDimExpr *AffineDimExpr::get(unsigned position, MLIRContext *context) {
diff --git a/test/IR/parser-affine-map.mlir b/test/IR/parser-affine-map.mlir
index 030b86d..b8a5553 100644
--- a/test/IR/parser-affine-map.mlir
+++ b/test/IR/parser-affine-map.mlir
@@ -1,123 +1,123 @@
// RUN: %S/../../mlir-opt %s -o - | FileCheck %s
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0, d1)
#map0 = (i, j) -> (i, j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, d1)
#map1 = (i, j) [s0] -> (i, j)
-// CHECK-DAG: #map{{[0-9]+}} = () -> (0)
+// CHECK: #map{{[0-9]+}} = () -> (0)
#map2 = () -> (0)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
#map3 = (i, j) -> (i+1, j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), d1)
#map4 = (i, j) [s0] -> (i + s0, j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
#map5 = (i, j) -> (1+i, j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), (d1 + 5))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), (d1 + 5))
#map6 = (i, j) [s0] -> (i + s0, j + 5)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + s0), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + s0), d1)
#map7 = (i, j) [s0] -> (i + j + s0, j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((((d0 + 5) + d1) + s0), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((((d0 + 5) + d1) + s0), d1)
#map8 = (i, j) [s0] -> (5 + i + j + s0, j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), d1)
#map9 = (i, j) [s0] -> ((i + j) + 5, j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + (d1 + 5)), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + (d1 + 5)), d1)
#map10 = (i, j) [s0] -> (i + (j + 5), j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 * 2), (d1 * 3))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 * 2), (d1 * 3))
#map11 = (i, j) [s0] -> (2*i, 3*j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + 12) + ((d1 + (s0 * 3)) * 5)), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + 12) + ((d1 + (s0 * 3)) * 5)), d1)
#map12 = (i, j) [s0] -> (i + 2*6 + 5*(j+s0*3), j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 * 5) + d1), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 * 5) + d1), d1)
#map13 = (i, j) [s0] -> (5*i + j, j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + d1), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + d1), d1)
#map14 = (i, j) [s0] -> ((i + j), (j))
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), (d1 + 3))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), (d1 + 3))
#map15 = (i, j) [s0] -> ((i + j)+5, (j)+3)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, 0)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, 0)
#map16 = (i, j) [s1] -> (i, 0)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, (d1 * s0))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, (d1 * s0))
#map17 = (i, j) [s0] -> (i, s0*j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, ((d0 * 3) + d1))
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0, ((d0 * 3) + d1))
#map19 = (i, j) -> (i, 3*i + j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, (d0 + (d1 * 3)))
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0, (d0 + (d1 * 3)))
#map20 = (i, j) -> (i, i + 3*j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, (((d0 * ((s0 * s0) * 9)) + 2) + 1))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, (((d0 * ((s0 * s0) * 9)) + 2) + 1))
#map18 = (i, j) [N] -> (i, 2 + N*N*9*i + 1)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (1, ((d0 + (d1 * 3)) + 5))
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (1, ((d0 + (d1 * 3)) + 5))
#map21 = (i, j) -> (1, i + 3*j + 5)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((s0 * 5), ((d0 + (d1 * 3)) + (d0 * 5)))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((s0 * 5), ((d0 + (d1 * 3)) + (d0 * 5)))
#map22 = (i, j) [s0] -> (5*s0, i + 3*j + 5*i)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * (s0 * s1)), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * (s0 * s1)), d1)
#map23 = (i, j) [s0, s1] -> (i*(s0*s1), j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 mod 5))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 mod 5))
#map24 = (i, j) [s0, s1] -> (i, j mod 5)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv 5))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv 5))
#map25 = (i, j) [s0, s1] -> (i, j floordiv 5)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 ceildiv 5))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 ceildiv 5))
#map26 = (i, j) [s0, s1] -> (i, j ceildiv 5)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - (d1 * 1)) - 5))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - (d1 * 1)) - 5))
#map29 = (i, j) [s0, s1] -> (i, i - j - 5)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - ((d1 * s1) * 1)) + 2))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - ((d1 * s1) * 1)) + 2))
#map30 = (i, j) [M, N] -> (i, i - N*j + 2)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * -5), (d1 * -3), -2, ((d0 + d1) * -1), (s0 * -1))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * -5), (d1 * -3), -2, ((d0 + d1) * -1), (s0 * -1))
#map32 = (i, j) [s0, s1] -> (-5*i, -3*j, -2, -1*(i+j), -1*s0)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (-4, (d0 * -1))
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (-4, (d0 * -1))
#map33 = (i, j) -> (-2+-5-(-3), -1*i)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv s0), (d1 mod s0))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv s0), (d1 mod s0))
#map34 = (i, j) [s0, s1] -> (i, j floordiv s0, j mod s0)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1, d2) [s0, s1, s2] -> (((((d0 * s1) * s2) + (d1 * s1)) + d2))
+// CHECK: #map{{[0-9]+}} = (d0, d1, d2) [s0, s1, s2] -> (((((d0 * s1) * s2) + (d1 * s1)) + d2))
#map35 = (i, j, k) [s0, s1, s2] -> (i*s1*s2 + j*s1 + k)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (8, 4, 1, 3, 2, 4)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (8, 4, 1, 3, 2, 4)
#map36 = (i, j) -> (5+3, 2*2, 8-7, 100 floordiv 32, 5 mod 3, 10 ceildiv 3)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (4, 11, 512, 15)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (4, 11, 512, 15)
#map37 = (i, j) -> (5 mod 3 + 2, 5*3 - 4, 128 * (500 ceildiv 128), 40 floordiv 7 * 3)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (((d0 * 2) + 1), (d1 + 2))
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (((d0 * 2) + 1), (d1 + 2))
#map38 = (i, j) -> (1 + i*2, 2 + j)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * s0), (d0 + s0), (d0 + 2), (d1 * 2), (s1 * 2), (s0 + 2))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * s0), (d0 + s0), (d0 + 2), (d1 * 2), (s1 * 2), (s0 + 2))
#map39 = (i, j) [M, N] -> (i*M, M + i, 2+i, j*2, N*2, 2 + M)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, d1) size (10, 20)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0, d1) size (10, 20)
#map40 = (i, j) -> (i, j) size (10, 20)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (s0, (s1 + 10))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (s0, (s1 + 10))
#map41 = (i, j) [N, M] -> (i, j) size (N, M+10)
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (128, (((s0 * 2) + 5) + s1))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (128, (((s0 * 2) + 5) + s1))
#map42 = (i, j) [N, M] -> (i, j) size (64 + 64, 5 + 2*N + M)
// CHECK: extfunc @f0(memref<2x4xi8, #map{{[0-9]+}}, 1>)