More simplification for affine binary op exprs.
- simplify operations with identity elements (multiply by 1, add 0), and fold
multiplication by 0.
- simplify successive adds/muls: fold constants and propagate constants to
the right.
- simplify floordiv and ceildiv when the divisor is a constant and the LHS is
a multiply expression whose constant RHS is a multiple of the divisor.
- fix an affine expression printing bug on paren emission.
- while at it, fix the affine-map test cases file (memrefs using layout maps
that duplicate existing ones should be emitted pointing to the uniqued one).
PiperOrigin-RevId: 207046738
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index 8972510..4bd224d 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -27,32 +27,65 @@
: numDims(numDims), numSymbols(numSymbols), numResults(numResults),
results(results), rangeSizes(rangeSizes) {}
-/// Fold to a constant when possible. Canonicalize so that only the RHS is a
-/// constant. (4 + d0 becomes d0 + 4). If only one of them is a symbolic
-/// expressions, make it the RHS. Return nullptr if it can't be simplified.
+/// Simplify an add expression. Returns nullptr if it can't be simplified.
AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context) {
- if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
- if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
- return AffineConstantExpr::get(l->getValue() + r->getValue(), context);
+ auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+ // Fold if both LHS and RHS are constants.
+ if (lhsConst && rhsConst)
+ return AffineConstantExpr::get(lhsConst->getValue() + rhsConst->getValue(),
+ context);
+
+ // Canonicalize so that only the RHS can be a constant (4 + d0 -> d0 + 4).
+ // If only the LHS is symbolic or constant, swap it to the RHS.
if (isa<AffineConstantExpr>(lhs) ||
- (lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant()))
+ (lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant())) {
return AffineBinaryOpExpr::get(Kind::Add, rhs, lhs, context);
+ }
+
+ // From here on, a constant operand (if any) is on the RHS; lhsConst is null.
+
+ // Adding zero is a no-op; return the LHS.
+ if (rhsConst) {
+ if (rhsConst->getValue() == 0)
+ return lhs;
+ }
+ // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
+ auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
+ if (lBin && rhsConst && lBin->getKind() == Kind::Add) {
+ if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS()))
+ return AffineBinaryOpExpr::get(
+ Kind::Add, lBin->getLHS(),
+ AffineConstantExpr::get(lrhs->getValue() + rhsConst->getValue(),
+ context),
+ context);
+ }
+
+ // For successive additions, move the constant to the outermost RHS: turn
+ // (d0 + 2) + d1 into (d0 + d1) + 2.
+ if (lBin && lBin->getKind() == Kind::Add) {
+ if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
+ return AffineBinaryOpExpr::get(
+ Kind::Add,
+ AffineBinaryOpExpr::get(Kind::Add, lBin->getLHS(), rhs, context),
+ lrhs, context);
+ }
+ }
return nullptr;
- // TODO(someone): implement more simplification like x + 0 -> x; (x + 2) + 4
- // -> (x + 6). Do this in a systematic way in conjunction with other
- // simplifications as opposed to incremental hacks.
}
-/// Simplify a multiply expression. Fold it to a constant when possible, and
-/// make the symbolic/constant operand the RHS.
+/// Simplify a multiply expression. Return nullptr if it can't be simplified.
AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context) {
- if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
- if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
- return AffineConstantExpr::get(l->getValue() * r->getValue(), context);
+ auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+
+ if (lhsConst && rhsConst)
+ return AffineConstantExpr::get(lhsConst->getValue() * rhsConst->getValue(),
+ context);
assert(lhs->isSymbolicOrConstant() || rhs->isSymbolicOrConstant());
@@ -64,33 +97,100 @@
return AffineBinaryOpExpr::get(Kind::Mul, rhs, lhs, context);
}
+ // At this point, if there was a constant, it would be on the right.
+
+ // Multiplication with a one is a noop, return the other input.
+ if (rhsConst) {
+ if (rhsConst->getValue() == 1)
+ return lhs;
+ // Multiplication with zero.
+ if (rhsConst->getValue() == 0)
+ return rhsConst;
+ }
+
+ // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
+ auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
+ if (lBin && rhsConst && lBin->getKind() == Kind::Mul) {
+ if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS()))
+ return AffineBinaryOpExpr::get(
+ Kind::Mul, lBin->getLHS(),
+ AffineConstantExpr::get(lrhs->getValue() * rhsConst->getValue(),
+ context),
+ context);
+ }
+
+ // When doing successive multiplication, bring constant to the right: turn (d0
+ // * 2) * d1 into (d0 * d1) * 2.
+ if (lBin && lBin->getKind() == Kind::Mul) {
+ if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
+ return AffineBinaryOpExpr::get(
+ Kind::Mul,
+ AffineBinaryOpExpr::get(Kind::Mul, lBin->getLHS(), rhs, context),
+ lrhs, context);
+ }
+ }
+
return nullptr;
- // TODO(someone): implement some more simplification/canonicalization such as
- // 1*x is same as x, and in general, move it in the form d_i*expr where d_i is
- // a dimensional identifier. So, 2*(d0 + 4) + s0*d0 becomes (2 + s0)*d0 + 8.
}
AffineExpr *AffineBinaryOpExpr::simplifyFloorDiv(AffineExpr *lhs,
AffineExpr *rhs,
MLIRContext *context) {
- if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
- if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
- return AffineConstantExpr::get(l->getValue() / r->getValue(), context);
+ auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+ // NOTE(review): '/' truncates toward zero; wrong for negative inexact results.
+ if (lhsConst && rhsConst)
+ return AffineConstantExpr::get(lhsConst->getValue() / rhsConst->getValue(),
+ context);
+
+ // Fold floordiv of a multiply whose constant RHS is a multiple of the
+ // divisor (the division is exact), e.g. (i * 128) floordiv 64 = i * 2.
+ if (rhsConst) {
+ auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
+ if (lBin && lBin->getKind() == Kind::Mul) {
+ if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
+ // NOTE(review): assumes a non-zero divisor (unchecked); '%' by 0 is UB.
+ if (lrhs->getValue() % rhsConst->getValue() == 0)
+ return AffineBinaryOpExpr::get(
+ Kind::Mul, lBin->getLHS(),
+ AffineConstantExpr::get(lrhs->getValue() / rhsConst->getValue(),
+ context),
+ context);
+ }
+ }
+ }
return nullptr;
- // TODO(someone): implement more simplification along the lines described in
- // simplifyMod TODO. For eg: 128*N floordiv 128 is N.
}
AffineExpr *AffineBinaryOpExpr::simplifyCeilDiv(AffineExpr *lhs,
AffineExpr *rhs,
MLIRContext *context) {
- if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
- if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
- return AffineConstantExpr::get(
- (int64_t)llvm::divideCeil((uint64_t)l->getValue(),
- (uint64_t)r->getValue()),
- context);
+ auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+
+ if (lhsConst && rhsConst)
+ return AffineConstantExpr::get(
+ (int64_t)llvm::divideCeil((uint64_t)lhsConst->getValue(),
+ (uint64_t)rhsConst->getValue()),
+ context);
+
+ // Fold ceildiv of a multiply with a constant that is a multiple of the
+ // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
+ if (rhsConst) {
+ auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
+ if (lBin && lBin->getKind() == Kind::Mul) {
+ if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
+ // rhsConst is known to be positive if a constant.
+ if (lrhs->getValue() % rhsConst->getValue() == 0)
+ return AffineBinaryOpExpr::get(
+ Kind::Mul, lBin->getLHS(),
+ AffineConstantExpr::get(lrhs->getValue() / rhsConst->getValue(),
+ context),
+ context);
+ }
+ }
+ }
return nullptr;
// TODO(someone): implement more simplification along the lines described in