Simplify affine binary op expression class hierarchy
- Drop sub-classing of affine binary op expressions.
- Drop the Sub affine expression op kind; represent a subtraction as an add of
a multiply by -1. This is also in line with the mathematical form we will use
when we need to represent a system of linear equalities/inequalities: the
negative number goes into the coefficient of an affine form (e.g.,
x_1 + (-1)*x_2 + 3*x_3 + (-2) >= 0). The folding simplification will
transparently handle multiplying -1 by any other constants. This also means we
won't need to simplify a multiply expression, as in x_1 + (-2)*x_2, into a
subtract expression (x_1 - 2*x_2) for canonicalization/uniquing.
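- When printing the IR, we still pretty-print a subtraction where possible: an
addition whose RHS is a negative constant, or a product with a negative constant
RHS, is printed using " - ".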
PiperOrigin-RevId: 205298958
diff --git a/lib/IR/AffineExpr.cpp b/lib/IR/AffineExpr.cpp
index 8d0ee3d..6bfbaf5 100644
--- a/lib/IR/AffineExpr.cpp
+++ b/lib/IR/AffineExpr.cpp
@@ -28,27 +28,19 @@
switch (kind) {
case Kind::Add:
assert(!isa<AffineConstantExpr>(lhs));
- // TODO (more verification)
- break;
- case Kind::Sub:
- // TODO (verification)
break;
case Kind::Mul:
assert(!isa<AffineConstantExpr>(lhs));
assert(rhs->isSymbolicOrConstant());
- // TODO (more verification)
break;
case Kind::FloorDiv:
assert(rhs->isSymbolicOrConstant());
- // TODO (more verification)
break;
case Kind::CeilDiv:
assert(rhs->isSymbolicOrConstant());
- // TODO (more verification)
break;
case Kind::Mod:
assert(rhs->isSymbolicOrConstant());
- // TODO (more verification)
break;
default:
llvm_unreachable("unexpected binary affine expr");
@@ -67,7 +59,6 @@
return true;
case Kind::Add:
- case Kind::Sub:
case Kind::Mul:
case Kind::FloorDiv:
case Kind::CeilDiv:
@@ -87,16 +78,15 @@
case Kind::DimId:
case Kind::Constant:
return true;
- case Kind::Add:
- case Kind::Sub: {
- auto op = cast<AffineBinaryOpExpr>(this);
+ case Kind::Add: {
+ auto *op = cast<AffineBinaryOpExpr>(this);
return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine();
}
case Kind::Mul: {
// TODO: Canonicalize the constants in binary operators to the RHS when
// possible, allowing this to merge into the next case.
- auto op = cast<AffineBinaryOpExpr>(this);
+ auto *op = cast<AffineBinaryOpExpr>(this);
return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine() &&
(isa<AffineConstantExpr>(op->getLHS()) ||
isa<AffineConstantExpr>(op->getRHS()));
@@ -104,7 +94,7 @@
case Kind::FloorDiv:
case Kind::CeilDiv:
case Kind::Mod: {
- auto op = cast<AffineBinaryOpExpr>(this);
+ auto *op = cast<AffineBinaryOpExpr>(this);
return op->getLHS()->isPureAffine() &&
isa<AffineConstantExpr>(op->getRHS());
}
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index d8e09b5..8972510 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -38,7 +38,7 @@
if (isa<AffineConstantExpr>(lhs) ||
(lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant()))
- return AffineAddExpr::get(rhs, lhs, context);
+ return AffineBinaryOpExpr::get(Kind::Add, rhs, lhs, context);
return nullptr;
// TODO(someone): implement more simplification like x + 0 -> x; (x + 2) + 4
@@ -46,16 +46,6 @@
// simplifications as opposed to incremental hacks.
}
-AffineExpr *AffineBinaryOpExpr::simplifySub(AffineExpr *lhs, AffineExpr *rhs,
- MLIRContext *context) {
- if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
- if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
- return AffineConstantExpr::get(l->getValue() - r->getValue(), context);
-
- return nullptr;
- // TODO(someone): implement more simplification like mentioned for add.
-}
-
/// Simplify a multiply expression. Fold it to a constant when possible, and
/// make the symbolic/constant operand the RHS.
AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
@@ -71,7 +61,7 @@
// constant. (Note that a constant is trivially symbolic).
if (!rhs->isSymbolicOrConstant() || isa<AffineConstantExpr>(lhs)) {
// At least one of them has to be symbolic.
- return AffineMulExpr::get(rhs, lhs, context);
+ return AffineBinaryOpExpr::get(Kind::Mul, rhs, lhs, context);
}
return nullptr;
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 570ae49..03d9b1d 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -661,38 +661,62 @@
llvm::errs() << "\n";
}
-void AffineAddExpr::print(raw_ostream &os) const {
- os << "(" << *getLHS() << " + " << *getRHS() << ")";
-}
-
-void AffineSubExpr::print(raw_ostream &os) const {
- os << "(" << *getLHS() << " - " << *getRHS() << ")";
-}
-
-void AffineMulExpr::print(raw_ostream &os) const {
- os << "(" << *getLHS() << " * " << *getRHS() << ")";
-}
-
-void AffineModExpr::print(raw_ostream &os) const {
- os << "(" << *getLHS() << " mod " << *getRHS() << ")";
-}
-
-void AffineFloorDivExpr::print(raw_ostream &os) const {
- os << "(" << *getLHS() << " floordiv " << *getRHS() << ")";
-}
-
-void AffineCeilDivExpr::print(raw_ostream &os) const {
- os << "(" << *getLHS() << " ceildiv " << *getRHS() << ")";
-}
-
void AffineSymbolExpr::print(raw_ostream &os) const {
- os << "s" << getPosition();
+ os << 's' << getPosition();
}
-void AffineDimExpr::print(raw_ostream &os) const { os << "d" << getPosition(); }
+void AffineDimExpr::print(raw_ostream &os) const { os << 'd' << getPosition(); }
void AffineConstantExpr::print(raw_ostream &os) const { os << getValue(); }
+static void printAdd(const AffineBinaryOpExpr *addExpr, raw_ostream &os) {
+ os << '(' << *addExpr->getLHS();
+
+ // Pretty print addition to a product that has a negative operand as a
+ // subtraction.
+ if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(addExpr->getRHS())) {
+ if (rhs->getKind() == AffineExpr::Kind::Mul) {
+ if (auto *rrhs = dyn_cast<AffineConstantExpr>(rhs->getRHS())) {
+ if (rrhs->getValue() < 0) {
+ os << " - (" << *rhs->getLHS() << " * " << -rrhs->getValue() << "))";
+ return;
+ }
+ }
+ }
+ }
+
+ // Pretty print addition to a negative number as a subtraction.
+ if (auto *rhs = dyn_cast<AffineConstantExpr>(addExpr->getRHS())) {
+ if (rhs->getValue() < 0) {
+ os << " - " << -rhs->getValue() << ")";
+ return;
+ }
+ }
+
+ os << " + " << *addExpr->getRHS() << ")";
+}
+
+void AffineBinaryOpExpr::print(raw_ostream &os) const {
+ switch (getKind()) {
+ case Kind::Add:
+ return printAdd(this, os);
+ case Kind::Mul:
+ os << "(" << *getLHS() << " * " << *getRHS() << ")";
+ return;
+ case Kind::FloorDiv:
+ os << "(" << *getLHS() << " floordiv " << *getRHS() << ")";
+ return;
+ case Kind::CeilDiv:
+ os << "(" << *getLHS() << " ceildiv " << *getRHS() << ")";
+ return;
+ case Kind::Mod:
+ os << "(" << *getLHS() << " mod " << *getRHS() << ")";
+ return;
+ default:
+ llvm_unreachable("unexpected affine binary op expression");
+ }
+}
+
void AffineExpr::print(raw_ostream &os) const {
switch (getKind()) {
case Kind::SymbolId:
@@ -702,17 +726,11 @@
case Kind::Constant:
return cast<AffineConstantExpr>(this)->print(os);
case Kind::Add:
- return cast<AffineAddExpr>(this)->print(os);
- case Kind::Sub:
- return cast<AffineSubExpr>(this)->print(os);
case Kind::Mul:
- return cast<AffineMulExpr>(this)->print(os);
case Kind::FloorDiv:
- return cast<AffineFloorDivExpr>(this)->print(os);
case Kind::CeilDiv:
- return cast<AffineCeilDivExpr>(this)->print(os);
case Kind::Mod:
- return cast<AffineModExpr>(this)->print(os);
+ return cast<AffineBinaryOpExpr>(this)->print(os);
}
}
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 3d7e023..dc5b8e2 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -121,27 +121,23 @@
}
AffineExpr *Builder::getAddExpr(AffineExpr *lhs, AffineExpr *rhs) {
- return AffineAddExpr::get(lhs, rhs, context);
-}
-
-AffineExpr *Builder::getSubExpr(AffineExpr *lhs, AffineExpr *rhs) {
- return AffineSubExpr::get(lhs, rhs, context);
+ return AffineBinaryOpExpr::get(AffineExpr::Kind::Add, lhs, rhs, context);
}
AffineExpr *Builder::getMulExpr(AffineExpr *lhs, AffineExpr *rhs) {
- return AffineMulExpr::get(lhs, rhs, context);
+ return AffineBinaryOpExpr::get(AffineExpr::Kind::Mul, lhs, rhs, context);
}
AffineExpr *Builder::getModExpr(AffineExpr *lhs, AffineExpr *rhs) {
- return AffineModExpr::get(lhs, rhs, context);
+ return AffineBinaryOpExpr::get(AffineExpr::Kind::Mod, lhs, rhs, context);
}
AffineExpr *Builder::getFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs) {
- return AffineFloorDivExpr::get(lhs, rhs, context);
+ return AffineBinaryOpExpr::get(AffineExpr::Kind::FloorDiv, lhs, rhs, context);
}
AffineExpr *Builder::getCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs) {
- return AffineCeilDivExpr::get(lhs, rhs, context);
+ return AffineBinaryOpExpr::get(AffineExpr::Kind::CeilDiv, lhs, rhs, context);
}
//===----------------------------------------------------------------------===//
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index df3d01a..8d2de10 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -684,9 +684,6 @@
case Kind::Add:
simplified = AffineBinaryOpExpr::simplifyAdd(lhs, rhs, context);
break;
- case Kind::Sub:
- simplified = AffineBinaryOpExpr::simplifySub(lhs, rhs, context);
- break;
case Kind::Mul:
simplified = AffineBinaryOpExpr::simplifyMul(lhs, rhs, context);
break;
@@ -720,36 +717,6 @@
return result;
}
-AffineExpr *AffineAddExpr::get(AffineExpr *lhs, AffineExpr *rhs,
- MLIRContext *context) {
- return AffineBinaryOpExpr::get(Kind::Add, lhs, rhs, context);
-}
-
-AffineExpr *AffineSubExpr::get(AffineExpr *lhs, AffineExpr *rhs,
- MLIRContext *context) {
- return AffineBinaryOpExpr::get(Kind::Sub, lhs, rhs, context);
-}
-
-AffineExpr *AffineMulExpr::get(AffineExpr *lhs, AffineExpr *rhs,
- MLIRContext *context) {
- return AffineBinaryOpExpr::get(Kind::Mul, lhs, rhs, context);
-}
-
-AffineExpr *AffineFloorDivExpr::get(AffineExpr *lhs, AffineExpr *rhs,
- MLIRContext *context) {
- return AffineBinaryOpExpr::get(Kind::FloorDiv, lhs, rhs, context);
-}
-
-AffineExpr *AffineCeilDivExpr::get(AffineExpr *lhs, AffineExpr *rhs,
- MLIRContext *context) {
- return AffineBinaryOpExpr::get(Kind::CeilDiv, lhs, rhs, context);
-}
-
-AffineExpr *AffineModExpr::get(AffineExpr *lhs, AffineExpr *rhs,
- MLIRContext *context) {
- return AffineBinaryOpExpr::get(Kind::Mod, lhs, rhs, context);
-}
-
AffineDimExpr *AffineDimExpr::get(unsigned position, MLIRContext *context) {
// TODO(bondhugula): complete this
// FIXME: this should be POD