More simplification for affine binary op expressions.

- simplify operations with identity elements (multiply by 1, add with 0).
- simplify successive add/mul: fold constants, propagate constants to the
  right.
- simplify floordiv and ceildiv when the divisor is a constant and the LHS is a
  multiply expression whose RHS constant is a multiple of the divisor.
- fix an affine expression printing bug in parenthesis emission (e.g.
  d0 - (d1 * 2) is now printed as d0 - d1 * 2).

- while at it, fix the affine-map test case file (memrefs whose layout maps
  duplicate existing ones should be printed referring to the uniqued map).

PiperOrigin-RevId: 207046738
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index 8972510..4bd224d 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -27,32 +27,65 @@
     : numDims(numDims), numSymbols(numSymbols), numResults(numResults),
       results(results), rangeSizes(rangeSizes) {}
 
-/// Fold to a constant when possible. Canonicalize so that only the RHS is a
-/// constant. (4 + d0 becomes d0 + 4). If only one of them is a symbolic
-/// expressions, make it the RHS. Return nullptr if it can't be simplified.
+/// Simplify lhs + rhs using the rules below. Return nullptr if none apply.
 AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
                                             MLIRContext *context) {
-  if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
-    if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
-      return AffineConstantExpr::get(l->getValue() + r->getValue(), context);
+  auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+  auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
 
+  // Fold to a single constant if both LHS and RHS are constants.
+  if (lhsConst && rhsConst)
+    return AffineConstantExpr::get(lhsConst->getValue() + rhsConst->getValue(),
+                                   context);
+
+  // Canonicalize so that only the RHS is a constant (4 + d0 becomes d0 + 4).
+  // If only one operand is symbolic or constant, make it the RHS.
   if (isa<AffineConstantExpr>(lhs) ||
-      (lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant()))
+      (lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant())) {
     return AffineBinaryOpExpr::get(Kind::Add, rhs, lhs, context);
+  }
+
+  // At this point, if there was a constant, it would be on the right.
+
+  // Adding zero is a no-op: return the other operand.
+  if (rhsConst) {
+    if (rhsConst->getValue() == 0)
+      return lhs;
+  }
+  // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
+  auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
+  if (lBin && rhsConst && lBin->getKind() == Kind::Add) {
+    if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS()))
+      return AffineBinaryOpExpr::get(
+          Kind::Add, lBin->getLHS(),
+          AffineConstantExpr::get(lrhs->getValue() + rhsConst->getValue(),
+                                  context),
+          context);
+  }
+
+  // For successive additions, move the constant outward to the right:
+  // turn (d0 + 2) + d1 into (d0 + d1) + 2.
+  if (lBin && lBin->getKind() == Kind::Add) {
+    if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
+      return AffineBinaryOpExpr::get(
+          Kind::Add,
+          AffineBinaryOpExpr::get(Kind::Add, lBin->getLHS(), rhs, context),
+          lrhs, context);
+    }
+  }
 
   return nullptr;
-  // TODO(someone): implement more simplification like x + 0 -> x; (x + 2) + 4
-  // -> (x + 6). Do this in a systematic way in conjunction with other
-  // simplifications as opposed to incremental hacks.
 }
 
-/// Simplify a multiply expression. Fold it to a constant when possible, and
-/// make the symbolic/constant operand the RHS.
+/// Simplify lhs * rhs using the rules below. Return nullptr if none apply.
 AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
                                             MLIRContext *context) {
-  if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
-    if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
-      return AffineConstantExpr::get(l->getValue() * r->getValue(), context);
+  auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+  auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+
+  if (lhsConst && rhsConst)
+    return AffineConstantExpr::get(lhsConst->getValue() * rhsConst->getValue(),
+                                   context);
 
   assert(lhs->isSymbolicOrConstant() || rhs->isSymbolicOrConstant());
 
@@ -64,33 +97,116 @@
     return AffineBinaryOpExpr::get(Kind::Mul, rhs, lhs, context);
   }
 
+  // At this point, if there was a constant, it would be on the right.
+
+  // Multiplying by one is a no-op: return the other operand.
+  if (rhsConst) {
+    if (rhsConst->getValue() == 1)
+      return lhs;
+    // Multiplying by zero yields the constant zero.
+    if (rhsConst->getValue() == 0)
+      return rhsConst;
+  }
+
+  // Fold successive multiplications, e.g. (d0 * 2) * 3 into d0 * 6.
+  auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
+  if (lBin && rhsConst && lBin->getKind() == Kind::Mul) {
+    if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS()))
+      return AffineBinaryOpExpr::get(
+          Kind::Mul, lBin->getLHS(),
+          AffineConstantExpr::get(lrhs->getValue() * rhsConst->getValue(),
+                                  context),
+          context);
+  }
+
+  // For successive multiplications, move the constant outward to the right:
+  // turn (d0 * 2) * s0 into (d0 * s0) * 2.
+  if (lBin && lBin->getKind() == Kind::Mul) {
+    if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
+      return AffineBinaryOpExpr::get(
+          Kind::Mul,
+          AffineBinaryOpExpr::get(Kind::Mul, lBin->getLHS(), rhs, context),
+          lrhs, context);
+    }
+  }
+
   return nullptr;
-  // TODO(someone): implement some more simplification/canonicalization such as
-  // 1*x is same as x, and in general, move it in the form d_i*expr where d_i is
-  // a dimensional identifier. So, 2*(d0 + 4) + s0*d0 becomes (2 + s0)*d0 + 8.
 }
 
 AffineExpr *AffineBinaryOpExpr::simplifyFloorDiv(AffineExpr *lhs,
                                                  AffineExpr *rhs,
                                                  MLIRContext *context) {
-  if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
-    if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
-      return AffineConstantExpr::get(l->getValue() / r->getValue(), context);
+  auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+  auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+
+  // Fold if both LHS, RHS are constants. C++ '/' truncates toward zero while
+  // floordiv rounds toward negative infinity, so step the quotient down when
+  // the division is inexact and the operands have opposite signs (e.g.
+  // -7 floordiv 2 is -4, not -3).
+  if (lhsConst && rhsConst) {
+    int64_t lhsVal = lhsConst->getValue();
+    int64_t rhsVal = rhsConst->getValue();
+    int64_t quotient = lhsVal / rhsVal;
+    if (lhsVal % rhsVal != 0 && ((lhsVal < 0) != (rhsVal < 0)))
+      --quotient;
+    return AffineConstantExpr::get(quotient, context);
+  }
+
+  // Fold floordiv of a multiply with a constant that is a multiple of the
+  // divisor. Eg: (i * 128) floordiv 64 = i * 2.
+  if (rhsConst) {
+    auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
+    if (lBin && lBin->getKind() == Kind::Mul) {
+      if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
+        // rhsConst is known to be positive if a constant.
+        if (lrhs->getValue() % rhsConst->getValue() == 0)
+          return AffineBinaryOpExpr::get(
+              Kind::Mul, lBin->getLHS(),
+              AffineConstantExpr::get(lrhs->getValue() / rhsConst->getValue(),
+                                      context),
+              context);
+      }
+    }
+  }
 
   return nullptr;
-  // TODO(someone): implement more simplification along the lines described in
-  // simplifyMod TODO. For eg: 128*N floordiv 128 is N.
 }
 
 AffineExpr *AffineBinaryOpExpr::simplifyCeilDiv(AffineExpr *lhs,
                                                 AffineExpr *rhs,
                                                 MLIRContext *context) {
-  if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
-    if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
-      return AffineConstantExpr::get(
-          (int64_t)llvm::divideCeil((uint64_t)l->getValue(),
-                                    (uint64_t)r->getValue()),
-          context);
+  auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+  auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+
+  // Fold if both LHS, RHS are constants. Use signed arithmetic (casting to
+  // unsigned breaks negative operands): C++ '/' truncates toward zero, so
+  // step the quotient up when the division is inexact and the operands have
+  // the same sign (e.g. 7 ceildiv 2 is 4, -7 ceildiv 2 is -3).
+  if (lhsConst && rhsConst) {
+    int64_t lhsVal = lhsConst->getValue();
+    int64_t rhsVal = rhsConst->getValue();
+    int64_t quotient = lhsVal / rhsVal;
+    if (lhsVal % rhsVal != 0 && ((lhsVal < 0) == (rhsVal < 0)))
+      ++quotient;
+    return AffineConstantExpr::get(quotient, context);
+  }
+
+  // Fold ceildiv of a multiply with a constant that is a multiple of the
+  // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
+  if (rhsConst) {
+    auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
+    if (lBin && lBin->getKind() == Kind::Mul) {
+      if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
+        // rhsConst is known to be positive if a constant.
+        if (lrhs->getValue() % rhsConst->getValue() == 0)
+          return AffineBinaryOpExpr::get(
+              Kind::Mul, lBin->getLHS(),
+              AffineConstantExpr::get(lrhs->getValue() / rhsConst->getValue(),
+                                      context),
+              context);
+      }
+    }
+  }
 
   return nullptr;
   // TODO(someone): implement more simplification along the lines described in
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index f881872..3db030b 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -454,9 +454,9 @@
 
         if (rrhs->getValue() < -1) {
           printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
-          os << " - (";
+          os << " - ";
           printAffineExprInternal(rhs->getLHS(), BindingStrength::Strong);
-          os << " * " << -rrhs->getValue() << ')';
+          os << " * " << -rrhs->getValue();
           if (enclosingTightness == BindingStrength::Strong)
             os << ')';
           return;