[mlir] Fix a use-after-free of a DenseMap reference found by ASan
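The binary affine expression uniquing code in MLIRContext.cpp held a reference into the `affineExprs` DenseMap across recursive simplification calls, which may insert into that map and invalidate the reference; the slot is now looked up again after those calls. While fixing this, the parser-affine-map.mlir test started failing because the order in which the affine maps were printed changed. The existing CHECK-DAG directives were not enough to disambiguate them: a partial match on one line prevented a full match on a subsequent line.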
The fix for this was simple: print the affine maps in the order in which they are first referenced, rather than in DenseMap iteration order.
PiperOrigin-RevId: 205843770
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 55cc98c..4a21b31 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -63,14 +63,13 @@
return it->second;
}
- const DenseMap<const AffineMap *, int> &getAffineMapIds() const {
- return affineMapIds;
- }
+ ArrayRef<const AffineMap *> getAffineMapIds() const { return affineMapsById; }
private:
void recordAffineMapReference(const AffineMap *affineMap) {
if (affineMapIds.count(affineMap) == 0) {
- affineMapIds[affineMap] = nextAffineMapId++;
+ affineMapIds[affineMap] = affineMapsById.size();
+ affineMapsById.push_back(affineMap);
}
}
@@ -84,7 +83,7 @@
void visitOperation(const Operation *op);
DenseMap<const AffineMap *, int> affineMapIds;
- int nextAffineMapId = 0;
+ std::vector<const AffineMap *> affineMapsById;
};
} // end anonymous namespace
@@ -228,10 +227,10 @@
}
void ModulePrinter::print(const Module *module) {
- for (const auto &mapAndId : state.getAffineMapIds()) {
- printAffineMapId(mapAndId.second);
+ for (const auto &map : state.getAffineMapIds()) {
+ printAffineMapId(state.getAffineMapId(map));
os << " = ";
- mapAndId.first->print(os);
+ map->print(os);
os << '\n';
}
for (auto *fn : module->functionList)
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index a74be26..31b94e8 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -669,11 +669,11 @@
// Check if we already have this affine expression.
auto keyValue = std::make_tuple((unsigned)kind, lhs, rhs);
- auto *&result = impl.affineExprs[keyValue];
+ auto **result = &impl.affineExprs[keyValue];
// If we already have it, return that value.
- if (result)
- return result;
+ if (*result)
+ return *result;
// Simplify the expression if possible.
AffineExpr *simplified;
@@ -697,6 +697,9 @@
llvm_unreachable("unexpected binary affine expr");
}
+ // The recursive calls above may have invalidated the 'result' pointer.
+ result = &impl.affineExprs[keyValue];
+
// If simplified to a non-binary affine op expr, don't store it.
if (simplified && !isa<AffineBinaryOpExpr>(simplified)) {
// 'affineExprs' only contains uniqued AffineBinaryOpExpr's.
@@ -705,13 +708,13 @@
if (simplified)
// We know that it's a binary op expression.
- return result = simplified;
+ return *result = simplified;
// On the first use, we allocate them into the bump pointer.
- result = impl.allocator.Allocate<AffineBinaryOpExpr>();
+ *result = impl.allocator.Allocate<AffineBinaryOpExpr>();
// Initialize the memory using placement new.
- new (result) AffineBinaryOpExpr(kind, lhs, rhs);
- return result;
+ new (*result) AffineBinaryOpExpr(kind, lhs, rhs);
+ return *result;
}
AffineDimExpr *AffineDimExpr::get(unsigned position, MLIRContext *context) {