[mlir] Fix a use-after-free iterator error found by asan

While fixing this the parser-affine-map.mlir test started failing due to ordering of the printed affine maps. Even the existing CHECK-DAGs weren't enough to disambiguate; a partial match on one line precluded a total match on a following line.

The fix for this was easy - print the affine maps in reference order rather than in DenseMap iteration order.

PiperOrigin-RevId: 205843770
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 55cc98c..4a21b31 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -63,14 +63,13 @@
     return it->second;
   }
 
-  const DenseMap<const AffineMap *, int> &getAffineMapIds() const {
-    return affineMapIds;
-  }
+  ArrayRef<const AffineMap *> getAffineMapIds() const { return affineMapsById; }
 
 private:
   void recordAffineMapReference(const AffineMap *affineMap) {
     if (affineMapIds.count(affineMap) == 0) {
-      affineMapIds[affineMap] = nextAffineMapId++;
+      affineMapIds[affineMap] = affineMapsById.size();
+      affineMapsById.push_back(affineMap);
     }
   }
 
@@ -84,7 +83,7 @@
   void visitOperation(const Operation *op);
 
   DenseMap<const AffineMap *, int> affineMapIds;
-  int nextAffineMapId = 0;
+  std::vector<const AffineMap *> affineMapsById;
 };
 } // end anonymous namespace
 
@@ -228,10 +227,10 @@
 }
 
 void ModulePrinter::print(const Module *module) {
-  for (const auto &mapAndId : state.getAffineMapIds()) {
-    printAffineMapId(mapAndId.second);
+  for (const auto &map : state.getAffineMapIds()) {
+    printAffineMapId(state.getAffineMapId(map));
     os << " = ";
-    mapAndId.first->print(os);
+    map->print(os);
     os << '\n';
   }
   for (auto *fn : module->functionList)
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index a74be26..31b94e8 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -669,11 +669,11 @@
 
   // Check if we already have this affine expression.
   auto keyValue = std::make_tuple((unsigned)kind, lhs, rhs);
-  auto *&result = impl.affineExprs[keyValue];
+  auto **result = &impl.affineExprs[keyValue];
 
   // If we already have it, return that value.
-  if (result)
-    return result;
+  if (*result)
+    return *result;
 
   // Simplify the expression if possible.
   AffineExpr *simplified;
@@ -697,6 +697,9 @@
     llvm_unreachable("unexpected binary affine expr");
   }
 
+  // The recursive calls above may have invalidated the 'result' pointer.
+  result = &impl.affineExprs[keyValue];
+
   // If simplified to a non-binary affine op expr, don't store it.
   if (simplified && !isa<AffineBinaryOpExpr>(simplified)) {
     // 'affineExprs' only contains uniqued AffineBinaryOpExpr's.
@@ -705,13 +708,13 @@
 
   if (simplified)
     // We know that it's a binary op expression.
-    return result = simplified;
+    return *result = simplified;
 
   // On the first use, we allocate them into the bump pointer.
-  result = impl.allocator.Allocate<AffineBinaryOpExpr>();
+  *result = impl.allocator.Allocate<AffineBinaryOpExpr>();
   // Initialize the memory using placement new.
-  new (result) AffineBinaryOpExpr(kind, lhs, rhs);
-  return result;
+  new (*result) AffineBinaryOpExpr(kind, lhs, rhs);
+  return *result;
 }
 
 AffineDimExpr *AffineDimExpr::get(unsigned position, MLIRContext *context) {
diff --git a/test/IR/parser-affine-map.mlir b/test/IR/parser-affine-map.mlir
index 030b86d..b8a5553 100644
--- a/test/IR/parser-affine-map.mlir
+++ b/test/IR/parser-affine-map.mlir
@@ -1,123 +1,123 @@
 // RUN: %S/../../mlir-opt %s -o - | FileCheck %s
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0, d1)
 #map0 = (i, j) -> (i, j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, d1)
 #map1 = (i, j) [s0] -> (i, j)
 
-// CHECK-DAG: #map{{[0-9]+}} = () -> (0)
+// CHECK: #map{{[0-9]+}} = () -> (0)
 #map2 = () -> (0)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
 #map3 = (i, j) -> (i+1, j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), d1)
 #map4 = (i, j) [s0] -> (i + s0, j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
 #map5 = (i, j) -> (1+i, j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), (d1 + 5))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), (d1 + 5))
 #map6 = (i, j) [s0] -> (i + s0, j + 5)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + s0), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + s0), d1)
 #map7 = (i, j) [s0] -> (i + j + s0, j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((((d0 + 5) + d1) + s0), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((((d0 + 5) + d1) + s0), d1)
 #map8 = (i, j) [s0] -> (5 + i + j + s0, j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), d1)
 #map9 = (i, j) [s0] -> ((i + j) + 5, j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + (d1 + 5)), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + (d1 + 5)), d1)
 #map10 = (i, j) [s0] -> (i + (j + 5), j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 * 2), (d1 * 3))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 * 2), (d1 * 3))
 #map11 = (i, j) [s0] -> (2*i, 3*j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + 12) + ((d1 + (s0 * 3)) * 5)), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + 12) + ((d1 + (s0 * 3)) * 5)), d1)
 #map12 = (i, j) [s0] -> (i + 2*6 + 5*(j+s0*3), j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 * 5) + d1), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 * 5) + d1), d1)
 #map13 = (i, j) [s0] -> (5*i + j, j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + d1), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + d1), d1)
 #map14 = (i, j) [s0] -> ((i + j), (j))
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), (d1 + 3))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), (d1 + 3))
 #map15 = (i, j) [s0] -> ((i + j)+5, (j)+3)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, 0)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, 0)
 #map16 = (i, j) [s1] -> (i, 0)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, (d1 * s0))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, (d1 * s0))
 #map17 = (i, j) [s0] -> (i, s0*j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, ((d0 * 3) + d1))
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0, ((d0 * 3) + d1))
 #map19 = (i, j) -> (i, 3*i + j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, (d0 + (d1 * 3)))
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0, (d0 + (d1 * 3)))
 #map20 = (i, j)  -> (i, i + 3*j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, (((d0 * ((s0 * s0) * 9)) + 2) + 1))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, (((d0 * ((s0 * s0) * 9)) + 2) + 1))
 #map18 = (i, j) [N] -> (i, 2 + N*N*9*i + 1)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (1, ((d0 + (d1 * 3)) + 5))
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (1, ((d0 + (d1 * 3)) + 5))
 #map21 = (i, j)  -> (1, i + 3*j + 5)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((s0 * 5), ((d0 + (d1 * 3)) + (d0 * 5)))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0] -> ((s0 * 5), ((d0 + (d1 * 3)) + (d0 * 5)))
 #map22 = (i, j) [s0] -> (5*s0, i + 3*j + 5*i)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * (s0 * s1)), d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * (s0 * s1)), d1)
 #map23 = (i, j) [s0, s1] -> (i*(s0*s1), j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 mod 5))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 mod 5))
 #map24 = (i, j) [s0, s1] -> (i, j mod 5)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv 5))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv 5))
 #map25 = (i, j) [s0, s1] -> (i, j floordiv 5)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 ceildiv 5))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 ceildiv 5))
 #map26 = (i, j) [s0, s1] -> (i, j ceildiv 5)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - (d1 * 1)) - 5))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - (d1 * 1)) - 5))
 #map29 = (i, j) [s0, s1] -> (i, i - j - 5)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 -  ((d1 * s1) * 1)) + 2))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 -  ((d1 * s1) * 1)) + 2))
 #map30 = (i, j) [M, N] -> (i, i - N*j + 2)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * -5), (d1 * -3), -2, ((d0 + d1) * -1), (s0 * -1))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * -5), (d1 * -3), -2, ((d0 + d1) * -1), (s0 * -1))
 #map32 = (i, j) [s0, s1] -> (-5*i, -3*j, -2, -1*(i+j), -1*s0)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (-4, (d0 * -1))
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (-4, (d0 * -1))
 #map33 = (i, j) -> (-2+-5-(-3), -1*i)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv s0), (d1 mod s0))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv s0), (d1 mod s0))
 #map34 = (i, j) [s0, s1] -> (i, j floordiv s0, j mod s0)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1, d2) [s0, s1, s2] -> (((((d0 * s1) * s2) + (d1 * s1)) + d2))
+// CHECK: #map{{[0-9]+}} = (d0, d1, d2) [s0, s1, s2] -> (((((d0 * s1) * s2) + (d1 * s1)) + d2))
 #map35 = (i, j, k) [s0, s1, s2] -> (i*s1*s2 + j*s1 + k)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (8, 4, 1, 3, 2, 4)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (8, 4, 1, 3, 2, 4)
 #map36 = (i, j) -> (5+3, 2*2, 8-7, 100 floordiv 32, 5 mod 3, 10 ceildiv 3)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (4, 11, 512, 15)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (4, 11, 512, 15)
 #map37 = (i, j) -> (5 mod 3 + 2, 5*3 - 4, 128 * (500 ceildiv 128), 40 floordiv 7 * 3)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (((d0 * 2) + 1), (d1 + 2))
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (((d0 * 2) + 1), (d1 + 2))
 #map38 = (i, j) -> (1 + i*2, 2 + j)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * s0), (d0 + s0), (d0 + 2), (d1 * 2), (s1 * 2), (s0 + 2))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * s0), (d0 + s0), (d0 + 2), (d1 * 2), (s1 * 2), (s0 + 2))
 #map39 = (i, j) [M, N] -> (i*M, M + i, 2+i, j*2, N*2, 2 + M)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, d1) size (10, 20)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0, d1) size (10, 20)
 #map40 = (i, j) -> (i, j) size (10, 20)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (s0, (s1 + 10))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (s0, (s1 + 10))
 #map41 = (i, j) [N, M] -> (i, j) size (N, M+10)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (128, (((s0 * 2) + 5) + s1))
+// CHECK: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (128, (((s0 * 2) + 5) + s1))
 #map42 = (i, j) [N, M] -> (i, j) size (64 + 64, 5 + 2*N + M)
 
 // CHECK: extfunc @f0(memref<2x4xi8, #map{{[0-9]+}}, 1>)