More simplifications for affine binary op exprs.
- simplify operations with identity elements (multiply by 1, add with 0).
- simplify successive add/mul: fold constants, propagate constants to the
right.
- simplify floordiv and ceildiv when the divisor is a constant and the LHS is a
multiply expression whose RHS is a constant multiple of that divisor.
- fix an affine expression printing bug on paren emission.
- while at it, fix the affine-map test case file (memrefs using layout maps
that were duplicates of existing ones should be emitted pointing to the
unique'd one).
PiperOrigin-RevId: 207046738
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index 8972510..4bd224d 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -27,32 +27,65 @@
: numDims(numDims), numSymbols(numSymbols), numResults(numResults),
results(results), rangeSizes(rangeSizes) {}
-/// Fold to a constant when possible. Canonicalize so that only the RHS is a
-/// constant. (4 + d0 becomes d0 + 4). If only one of them is a symbolic
-/// expressions, make it the RHS. Return nullptr if it can't be simplified.
+/// Simplify add expression. Return nullptr if it can't be simplified.
AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context) {
- if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
- if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
- return AffineConstantExpr::get(l->getValue() + r->getValue(), context);
+ auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+ // Fold if both LHS, RHS are a constant.
+ if (lhsConst && rhsConst)
+ return AffineConstantExpr::get(lhsConst->getValue() + rhsConst->getValue(),
+ context);
+
+ // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
+ // If only one of them is a symbolic expression, make it the RHS.
if (isa<AffineConstantExpr>(lhs) ||
- (lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant()))
+ (lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant())) {
return AffineBinaryOpExpr::get(Kind::Add, rhs, lhs, context);
+ }
+
+ // At this point, if there was a constant, it would be on the right.
+
+ // Addition with a zero is a noop, return the other input.
+ if (rhsConst) {
+ if (rhsConst->getValue() == 0)
+ return lhs;
+ }
+ // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
+ auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
+ if (lBin && rhsConst && lBin->getKind() == Kind::Add) {
+ if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS()))
+ return AffineBinaryOpExpr::get(
+ Kind::Add, lBin->getLHS(),
+ AffineConstantExpr::get(lrhs->getValue() + rhsConst->getValue(),
+ context),
+ context);
+ }
+
+ // When doing successive additions, bring constant to the right: turn (d0 + 2)
+ // + d1 into (d0 + d1) + 2.
+ if (lBin && lBin->getKind() == Kind::Add) {
+ if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
+ return AffineBinaryOpExpr::get(
+ Kind::Add,
+ AffineBinaryOpExpr::get(Kind::Add, lBin->getLHS(), rhs, context),
+ lrhs, context);
+ }
+ }
return nullptr;
- // TODO(someone): implement more simplification like x + 0 -> x; (x + 2) + 4
- // -> (x + 6). Do this in a systematic way in conjunction with other
- // simplifications as opposed to incremental hacks.
}
-/// Simplify a multiply expression. Fold it to a constant when possible, and
-/// make the symbolic/constant operand the RHS.
+/// Simplify a multiply expression. Return nullptr if it can't be simplified.
AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context) {
- if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
- if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
- return AffineConstantExpr::get(l->getValue() * r->getValue(), context);
+ auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+
+ if (lhsConst && rhsConst)
+ return AffineConstantExpr::get(lhsConst->getValue() * rhsConst->getValue(),
+ context);
assert(lhs->isSymbolicOrConstant() || rhs->isSymbolicOrConstant());
@@ -64,33 +97,100 @@
return AffineBinaryOpExpr::get(Kind::Mul, rhs, lhs, context);
}
+ // At this point, if there was a constant, it would be on the right.
+
+ // Multiplication with a one is a noop, return the other input.
+ if (rhsConst) {
+ if (rhsConst->getValue() == 1)
+ return lhs;
+ // Multiplication with zero.
+ if (rhsConst->getValue() == 0)
+ return rhsConst;
+ }
+
+ // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
+ auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
+ if (lBin && rhsConst && lBin->getKind() == Kind::Mul) {
+ if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS()))
+ return AffineBinaryOpExpr::get(
+ Kind::Mul, lBin->getLHS(),
+ AffineConstantExpr::get(lrhs->getValue() * rhsConst->getValue(),
+ context),
+ context);
+ }
+
+ // When doing successive multiplication, bring constant to the right: turn (d0
+ // * 2) * d1 into (d0 * d1) * 2.
+ if (lBin && lBin->getKind() == Kind::Mul) {
+ if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
+ return AffineBinaryOpExpr::get(
+ Kind::Mul,
+ AffineBinaryOpExpr::get(Kind::Mul, lBin->getLHS(), rhs, context),
+ lrhs, context);
+ }
+ }
+
return nullptr;
- // TODO(someone): implement some more simplification/canonicalization such as
- // 1*x is same as x, and in general, move it in the form d_i*expr where d_i is
- // a dimensional identifier. So, 2*(d0 + 4) + s0*d0 becomes (2 + s0)*d0 + 8.
}
AffineExpr *AffineBinaryOpExpr::simplifyFloorDiv(AffineExpr *lhs,
AffineExpr *rhs,
MLIRContext *context) {
- if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
- if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
- return AffineConstantExpr::get(l->getValue() / r->getValue(), context);
+ auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+
+ if (lhsConst && rhsConst)
+ return AffineConstantExpr::get(lhsConst->getValue() / rhsConst->getValue(),
+ context);
+
+ // Fold floordiv of a multiply with a constant that is a multiple of the
+ // divisor. Eg: (i * 128) floordiv 64 = i * 2.
+ if (rhsConst) {
+ auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
+ if (lBin && lBin->getKind() == Kind::Mul) {
+ if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
+ // rhsConst is known to be positive if a constant.
+ if (lrhs->getValue() % rhsConst->getValue() == 0)
+ return AffineBinaryOpExpr::get(
+ Kind::Mul, lBin->getLHS(),
+ AffineConstantExpr::get(lrhs->getValue() / rhsConst->getValue(),
+ context),
+ context);
+ }
+ }
+ }
return nullptr;
- // TODO(someone): implement more simplification along the lines described in
- // simplifyMod TODO. For eg: 128*N floordiv 128 is N.
}
AffineExpr *AffineBinaryOpExpr::simplifyCeilDiv(AffineExpr *lhs,
AffineExpr *rhs,
MLIRContext *context) {
- if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
- if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
- return AffineConstantExpr::get(
- (int64_t)llvm::divideCeil((uint64_t)l->getValue(),
- (uint64_t)r->getValue()),
- context);
+ auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+
+ if (lhsConst && rhsConst)
+ return AffineConstantExpr::get(
+ (int64_t)llvm::divideCeil((uint64_t)lhsConst->getValue(),
+ (uint64_t)rhsConst->getValue()),
+ context);
+
+ // Fold ceildiv of a multiply with a constant that is a multiple of the
+ // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
+ if (rhsConst) {
+ auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
+ if (lBin && lBin->getKind() == Kind::Mul) {
+ if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
+ // rhsConst is known to be positive if a constant.
+ if (lrhs->getValue() % rhsConst->getValue() == 0)
+ return AffineBinaryOpExpr::get(
+ Kind::Mul, lBin->getLHS(),
+ AffineConstantExpr::get(lrhs->getValue() / rhsConst->getValue(),
+ context),
+ context);
+ }
+ }
+ }
return nullptr;
// TODO(someone): implement more simplification along the lines described in
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index f881872..3db030b 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -454,9 +454,9 @@
if (rrhs->getValue() < -1) {
printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
- os << " - (";
+ os << " - ";
printAffineExprInternal(rhs->getLHS(), BindingStrength::Strong);
- os << " * " << -rrhs->getValue() << ')';
+ os << " * " << -rrhs->getValue();
if (enclosingTightness == BindingStrength::Strong)
os << ')';
return;
diff --git a/test/IR/affine-map.mlir b/test/IR/affine-map.mlir
index 161883b..4a8ca6b 100644
--- a/test/IR/affine-map.mlir
+++ b/test/IR/affine-map.mlir
@@ -7,16 +7,26 @@
#map1 = (i, j)[s0] -> (i, j)
// CHECK: #map{{[0-9]+}} = () -> (0)
+// A map may have 0 inputs. However, an affine_apply always takes at least one input.
#map2 = () -> (0)
// All three maps are unique'd as one map and so there
// should be only one output.
-// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 + 1, d1)
-#map3 = (i, j) -> (i+1, j)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 + 1, d1 * 4 + 2)
+#map3 = (i, j) -> (i+1, 4*j + 2)
// CHECK-EMPTY
-#map3a = (i, j) -> (1+i, j)
+#map3a = (i, j) -> (1+i, 4*j + 2)
// CHECK-EMPTY
-#map3b = (i, j) -> (2+3-2*2+i, j)
+#map3b = (i, j) -> (2 + 3 - 2*2 + i, 4*j + 2)
+#map3c = (i, j) -> (i +1 + 0, 4*j + 2)
+#map3d = (i, j) -> (i + 3 + 2 - 4, 4*j + 2)
+#map3e = (i, j) -> (1*i+3*2-2*2-1, 4*j + 2)
+#map3f = (i, j) -> (i + 1, 4*j*1 + 2)
+#map3g = (i, j) -> (i + 1, 2*2*j + 2)
+#map3h = (i, j) -> (i + 1, 2*j*2 + 2)
+#map3i = (i, j) -> (i + 1, j*2*2 + 2)
+#map3j = (i, j) -> (i + 1, j*1*4 + 2)
+#map3k = (i, j) -> (i + 1, j*4*1 + 2)
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 + 2, d1)
#map4 = (i, j) -> (3+3-2*2+i, j)
@@ -30,7 +40,7 @@
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1 + s0, d1)
#map7 = (i, j)[s0] -> (i + j + s0, j)
-// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + 5 + d1 + s0, d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1 + s0 + 5, d1)
#map8 = (i, j)[s0] -> (5 + i + j + s0, j)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1 + 5, d1)
@@ -42,7 +52,7 @@
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 * 2, d1 * 3)
#map11 = (i, j)[s0] -> (2*i, 3*j)
-// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + 12 + (d1 + s0 * 3) * 5, d1)
+// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + (d1 + s0 * 3) * 5 + 12, d1)
#map12 = (i, j)[s0] -> (i + 2*6 + 5*(j+s0*3), j)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 * 5 + d1, d1)
@@ -51,8 +61,8 @@
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1, d1)
#map14 = (i, j)[s0] -> ((i + j), (j))
-// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1 + 5, d1 + 3)
-#map15 = (i, j)[s0] -> ((i + j)+5, (j)+3)
+// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1 + 7, d1 + 3)
+#map15 = (i, j)[s0] -> ((i + j + 2) + 5, (j)+3)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0, 0)
#map16 = (i, j)[s1] -> (i, 0)
@@ -66,7 +76,7 @@
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0, d0 + d1 * 3)
#map20 = (i, j) -> (i, i + 3*j)
-// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0, d0 * ((s0 * s0) * 9) + 2 + 1)
+// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0, d0 * ((s0 * s0) * 9) + 3)
#map18 = (i, j)[N] -> (i, 2 + N*N*9*i + 1)
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (1, d0 + d1 * 3 + 5)
@@ -105,9 +115,9 @@
// CHECK: #map{{[0-9]+}} = (d0, d1, d2)[s0, s1, s2] -> ((d0 * s1) * s2 + d1 * s1 + d2)
#map35 = (i, j, k)[s0, s1, s2] -> (i*s1*s2 + j*s1 + k)
+// Constant folding.
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (8, 4, 1, 3, 2, 4)
#map36 = (i, j) -> (5+3, 2*2, 8-7, 100 floordiv 32, 5 mod 3, 10 ceildiv 3)
-
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (4, 11, 512, 15)
#map37 = (i, j) -> (5 mod 3 + 2, 5*3 - 4, 128 * (500 ceildiv 128), 40 floordiv 7 * 3)
@@ -123,14 +133,21 @@
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0, s1] -> (d0, d1) size (s0, s1 + 10)
#map41 = (i, j)[N, M] -> (i, j) size (N, M+10)
-// CHECK: #map{{[0-9]+}} = (d0, d1)[s0, s1] -> (d0, d1) size (128, s0 * 2 + 5 + s1)
+// CHECK: #map{{[0-9]+}} = (d0, d1)[s0, s1] -> (d0, d1) size (128, s0 * 2 + s1 + 5)
#map42 = (i, j)[N, M] -> (i, j) size (64 + 64, 5 + 2*N + M)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> ((d0 * 5) floordiv 4, (d1 ceildiv 7) mod s0)
#map43 = (i, j) [s0] -> ( i * 5 floordiv 4, j ceildiv 7 mod s0)
-// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 - d1 * 2)
-#map44 = (i, j) -> (i - 2*j)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 - d1 * 2, (d1 * 6) floordiv 4)
+#map44 = (i, j) -> (i - 2*j, j * 6 floordiv 4)
+
+// Simplifications
+// CHECK: #map{{[0-9]+}} = (d0, d1, d2)[s0] -> (d0 + d1 + d2 + 1, d2 + d1, (d0 * s0) * 8)
+#map45 = (i, j, k) [N] -> (1 + i + 3 + j - 3 + k, k + 5 + j - 5, 2*i*4*N)
+
+// CHECK: #map{{[0-9]+}} = (d0, d1, d2) -> (0, d0 * 2, 0, d0, d0 * 4)
+#map46 = (i, j, k) -> (i*0, i * 128 floordiv 64, j * 0 floordiv 64, i * 64 ceildiv 64, i * 512 ceildiv 128)
// CHECK: extfunc @f0(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f0(memref<2x4xi8, #map0, 1>)
@@ -143,6 +160,28 @@
// CHECK: extfunc @f3(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3(memref<2x4xi8, #map3, 1>)
+// CHECK: extfunc @f3a(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f3a(memref<2x4xi8, #map3a, 1>)
+// CHECK: extfunc @f3b(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f3b(memref<2x4xi8, #map3b, 1>)
+// CHECK: extfunc @f3c(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f3c(memref<2x4xi8, #map3c, 1>)
+// CHECK: extfunc @f3d(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f3d(memref<2x4xi8, #map3d, 1>)
+// CHECK: extfunc @f3e(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f3e(memref<2x4xi8, #map3e, 1>)
+// CHECK: extfunc @f3f(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f3f(memref<2x4xi8, #map3f, 1>)
+// CHECK: extfunc @f3g(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f3g(memref<2x4xi8, #map3g, 1>)
+// CHECK: extfunc @f3h(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f3h(memref<2x4xi8, #map3h, 1>)
+// CHECK: extfunc @f3i(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f3i(memref<2x4xi8, #map3i, 1>)
+// CHECK: extfunc @f3j(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f3j(memref<2x4xi8, #map3j, 1>)
+// CHECK: extfunc @f3k(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f3k(memref<2x4xi8, #map3k, 1>)
// CHECK: extfunc @f4(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f4(memref<2x4xi8, #map4, 1>)
@@ -253,11 +292,13 @@
extfunc @f42(memref<2x4xi8, #map42, 1>)
// CHECK: extfunc @f43(memref<2x4xi8, #map{{[0-9]+}}>)
-extfunc @f43(memref<2x4xi8, #map42>)
+extfunc @f43(memref<2x4xi8, #map43>)
// CHECK: extfunc @f44(memref<2x4xi8, #map{{[0-9]+}}>)
-extfunc @f44(memref<2x4xi8, #map43>)
+extfunc @f44(memref<2x4xi8, #map44>)
-// CHECK: extfunc @f45(memref<2xi8, #map{{[0-9]+}}>)
-extfunc @f45(memref<2xi8, #map44>)
+// CHECK: extfunc @f45(memref<100x100x100xi8, #map{{[0-9]+}}>)
+extfunc @f45(memref<100x100x100xi8, #map45>)
+// CHECK: extfunc @f45(memref<100x100x100xi8, #map{{[0-9]+}}>)
+extfunc @f45(memref<100x100x100xi8, #map46>)