Simplify affine binary op expression class hierarchy

- Drop sub-classing of affine binary op expressions.
- Drop affine expr op kind sub. Represent it as multiply by -1 and add. This
  will also be in line with the math form when we'll need to represent a system of
  linear equalities/inequalities: the negative number goes into the coefficient
  of an affine form. (For eg. x_1 + (-1)*x_2 + 3*x_3 + (-2) >= 0). The folding
  simplification will transparently deal with multiplying the -1 with any other
  constants. This also means we won't need to simplify a multiply expression
  like in x_1 + (-2)*x_2 to a subtract expression (x_1 - 2*x_2) for
  canonicalization/uniquing.
- When we print the IR, we will still pretty print to a subtract when possible.

PiperOrigin-RevId: 205298958
diff --git a/include/mlir/IR/AffineExpr.h b/include/mlir/IR/AffineExpr.h
index 6ebb0ae..92bae8a 100644
--- a/include/mlir/IR/AffineExpr.h
+++ b/include/mlir/IR/AffineExpr.h
@@ -35,10 +35,13 @@
 public:
   enum class Kind {
     Add,
-    Sub,
+    // RHS of mul is always a constant or a symbolic expression.
     Mul,
+    // RHS of mod is always a constant or a symbolic expression.
     Mod,
+    // RHS of floordiv is always a constant or a symbolic expression.
     FloorDiv,
+    // RHS of ceildiv is always a constant or a symbolic expression.
     CeilDiv,
 
     /// This is a marker for the last affine binary op. The range of binary
@@ -83,9 +86,17 @@
   return os;
 }
 
-/// Binary affine expression.
+/// Affine binary operation expression. An affine binary operation could be an
+/// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
+/// represented through a multiply by -1 and add.) These expressions are always
+/// constructed in a simplified form. For eg., the LHS and RHS operands can't
+/// both be constants. There are additional canonicalizing rules depending on
+/// the op type: see checks in the constructor.
 class AffineBinaryOpExpr : public AffineExpr {
 public:
+  static AffineExpr *get(Kind kind, AffineExpr *lhs, AffineExpr *rhs,
+                         MLIRContext *context);
+
   AffineExpr *getLHS() const { return lhs; }
   AffineExpr *getRHS() const { return rhs; }
 
@@ -94,10 +105,9 @@
     return expr->getKind() <= Kind::LAST_AFFINE_BINARY_OP;
   }
 
-protected:
-  static AffineExpr *get(Kind kind, AffineExpr *lhs, AffineExpr *rhs,
-                         MLIRContext *context);
+  void print(raw_ostream &os) const;
 
+protected:
   explicit AffineBinaryOpExpr(Kind kind, AffineExpr *lhs, AffineExpr *rhs);
 
   AffineExpr *const lhs;
@@ -107,8 +117,6 @@
   // Simplification prior to construction of binary affine op expressions.
   static AffineExpr *simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
                                  MLIRContext *context);
-  static AffineExpr *simplifySub(AffineExpr *lhs, AffineExpr *rhs,
-                                 MLIRContext *context);
   static AffineExpr *simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
                                  MLIRContext *context);
   static AffineExpr *simplifyFloorDiv(AffineExpr *lhs, AffineExpr *rhs,
@@ -119,102 +127,6 @@
                                  MLIRContext *context);
 };
 
-/// Binary affine add expression.
-class AffineAddExpr : public AffineBinaryOpExpr {
-public:
-  static AffineExpr *get(AffineExpr *lhs, AffineExpr *rhs,
-                         MLIRContext *context);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const AffineExpr *expr) {
-    return expr->getKind() == Kind::Add;
-  }
-  void print(raw_ostream &os) const;
-
-private:
-  // No constructor; use AffineBinaryOpExpr::get
-  AffineAddExpr(AffineExpr *lhs, AffineExpr *rhs) = delete;
-};
-
-/// Binary affine subtract expression.
-class AffineSubExpr : public AffineBinaryOpExpr {
-public:
-  static AffineExpr *get(AffineExpr *lhs, AffineExpr *rhs,
-                         MLIRContext *context);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const AffineExpr *expr) {
-    return expr->getKind() == Kind::Sub;
-  }
-  void print(raw_ostream &os) const;
-
-private:
-  AffineSubExpr(AffineExpr *lhs, AffineExpr *rhs) = delete;
-};
-
-/// Binary affine multiplication expression.
-class AffineMulExpr : public AffineBinaryOpExpr {
-public:
-  static AffineExpr *get(AffineExpr *lhs, AffineExpr *rhs,
-                         MLIRContext *context);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const AffineExpr *expr) {
-    return expr->getKind() == Kind::Mul;
-  }
-  void print(raw_ostream &os) const;
-
-private:
-  AffineMulExpr(AffineExpr *lhs, AffineExpr *rhs) = delete;
-};
-
-/// Binary affine modulo operation expression.
-class AffineModExpr : public AffineBinaryOpExpr {
-public:
-  static AffineExpr *get(AffineExpr *lhs, AffineExpr *rhs,
-                         MLIRContext *context);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const AffineExpr *expr) {
-    return expr->getKind() == Kind::Mod;
-  }
-  void print(raw_ostream &os) const;
-
-private:
-  AffineModExpr(AffineExpr *lhs, AffineExpr *rhs) = delete;
-};
-
-/// Binary affine floordiv expression.
-class AffineFloorDivExpr : public AffineBinaryOpExpr {
-public:
-  static AffineExpr *get(AffineExpr *lhs, AffineExpr *rhs,
-                         MLIRContext *context);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const AffineExpr *expr) {
-    return expr->getKind() == Kind::FloorDiv;
-  }
-  void print(raw_ostream &os) const;
-
-private:
-  AffineFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs) = delete;
-};
-
-/// Binary affine ceildiv expression.
-class AffineCeilDivExpr : public AffineBinaryOpExpr {
-public:
-  static AffineExpr *get(AffineExpr *lhs, AffineExpr *rhs,
-                         MLIRContext *context);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const AffineExpr *expr) {
-    return expr->getKind() == Kind::CeilDiv;
-  }
-  void print(raw_ostream &os) const;
-
-private:
-  AffineCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs) = delete;
-};
 
 /// A dimensional identifier appearing in an affine expression.
 ///
diff --git a/lib/IR/AffineExpr.cpp b/lib/IR/AffineExpr.cpp
index 8d0ee3d..6bfbaf5 100644
--- a/lib/IR/AffineExpr.cpp
+++ b/lib/IR/AffineExpr.cpp
@@ -28,27 +28,19 @@
   switch (kind) {
   case Kind::Add:
     assert(!isa<AffineConstantExpr>(lhs));
-    // TODO (more verification)
-    break;
-  case Kind::Sub:
-    // TODO (verification)
     break;
   case Kind::Mul:
     assert(!isa<AffineConstantExpr>(lhs));
     assert(rhs->isSymbolicOrConstant());
-    // TODO (more verification)
     break;
   case Kind::FloorDiv:
     assert(rhs->isSymbolicOrConstant());
-    // TODO (more verification)
     break;
   case Kind::CeilDiv:
     assert(rhs->isSymbolicOrConstant());
-    // TODO (more verification)
     break;
   case Kind::Mod:
     assert(rhs->isSymbolicOrConstant());
-    // TODO (more verification)
     break;
   default:
     llvm_unreachable("unexpected binary affine expr");
@@ -67,7 +59,6 @@
     return true;
 
   case Kind::Add:
-  case Kind::Sub:
   case Kind::Mul:
   case Kind::FloorDiv:
   case Kind::CeilDiv:
@@ -87,16 +78,15 @@
   case Kind::DimId:
   case Kind::Constant:
     return true;
-  case Kind::Add:
-  case Kind::Sub: {
-    auto op = cast<AffineBinaryOpExpr>(this);
+  case Kind::Add: {
+    auto *op = cast<AffineBinaryOpExpr>(this);
     return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine();
   }
 
   case Kind::Mul: {
     // TODO: Canonicalize the constants in binary operators to the RHS when
     // possible, allowing this to merge into the next case.
-    auto op = cast<AffineBinaryOpExpr>(this);
+    auto *op = cast<AffineBinaryOpExpr>(this);
     return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine() &&
            (isa<AffineConstantExpr>(op->getLHS()) ||
             isa<AffineConstantExpr>(op->getRHS()));
@@ -104,7 +94,7 @@
   case Kind::FloorDiv:
   case Kind::CeilDiv:
   case Kind::Mod: {
-    auto op = cast<AffineBinaryOpExpr>(this);
+    auto *op = cast<AffineBinaryOpExpr>(this);
     return op->getLHS()->isPureAffine() &&
            isa<AffineConstantExpr>(op->getRHS());
   }
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index d8e09b5..8972510 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -38,7 +38,7 @@
 
   if (isa<AffineConstantExpr>(lhs) ||
       (lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant()))
-    return AffineAddExpr::get(rhs, lhs, context);
+    return AffineBinaryOpExpr::get(Kind::Add, rhs, lhs, context);
 
   return nullptr;
   // TODO(someone): implement more simplification like x + 0 -> x; (x + 2) + 4
@@ -46,16 +46,6 @@
   // simplifications as opposed to incremental hacks.
 }
 
-AffineExpr *AffineBinaryOpExpr::simplifySub(AffineExpr *lhs, AffineExpr *rhs,
-                                            MLIRContext *context) {
-  if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
-    if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
-      return AffineConstantExpr::get(l->getValue() - r->getValue(), context);
-
-  return nullptr;
-  // TODO(someone): implement more simplification like mentioned for add.
-}
-
 /// Simplify a multiply expression. Fold it to a constant when possible, and
 /// make the symbolic/constant operand the RHS.
 AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
@@ -71,7 +61,7 @@
   // constant. (Note that a constant is trivially symbolic).
   if (!rhs->isSymbolicOrConstant() || isa<AffineConstantExpr>(lhs)) {
     // At least one of them has to be symbolic.
-    return AffineMulExpr::get(rhs, lhs, context);
+    return AffineBinaryOpExpr::get(Kind::Mul, rhs, lhs, context);
   }
 
   return nullptr;
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 570ae49..03d9b1d 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -661,38 +661,62 @@
   llvm::errs() << "\n";
 }
 
-void AffineAddExpr::print(raw_ostream &os) const {
-  os << "(" << *getLHS() << " + " << *getRHS() << ")";
-}
-
-void AffineSubExpr::print(raw_ostream &os) const {
-  os << "(" << *getLHS() << " - " << *getRHS() << ")";
-}
-
-void AffineMulExpr::print(raw_ostream &os) const {
-  os << "(" << *getLHS() << " * " << *getRHS() << ")";
-}
-
-void AffineModExpr::print(raw_ostream &os) const {
-  os << "(" << *getLHS() << " mod " << *getRHS() << ")";
-}
-
-void AffineFloorDivExpr::print(raw_ostream &os) const {
-  os << "(" << *getLHS() << " floordiv " << *getRHS() << ")";
-}
-
-void AffineCeilDivExpr::print(raw_ostream &os) const {
-  os << "(" << *getLHS() << " ceildiv " << *getRHS() << ")";
-}
-
 void AffineSymbolExpr::print(raw_ostream &os) const {
-  os << "s" << getPosition();
+  os << 's' << getPosition();
 }
 
-void AffineDimExpr::print(raw_ostream &os) const { os << "d" << getPosition(); }
+void AffineDimExpr::print(raw_ostream &os) const { os << 'd' << getPosition(); }
 
 void AffineConstantExpr::print(raw_ostream &os) const { os << getValue(); }
 
+static void printAdd(const AffineBinaryOpExpr *addExpr, raw_ostream &os) {
+  os << '(' << *addExpr->getLHS();
+
+  // Pretty print addition to a product that has a negative operand as a
+  // subtraction.
+  if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(addExpr->getRHS())) {
+    if (rhs->getKind() == AffineExpr::Kind::Mul) {
+      if (auto *rrhs = dyn_cast<AffineConstantExpr>(rhs->getRHS())) {
+        if (rrhs->getValue() < 0) {
+          os << " - (" << *rhs->getLHS() << " * " << -rrhs->getValue() << "))";
+          return;
+        }
+      }
+    }
+  }
+
+  // Pretty print addition to a negative number as a subtraction.
+  if (auto *rhs = dyn_cast<AffineConstantExpr>(addExpr->getRHS())) {
+    if (rhs->getValue() < 0) {
+      os << " - " << -rhs->getValue() << ")";
+      return;
+    }
+  }
+
+  os << " + " << *addExpr->getRHS() << ")";
+}
+
+void AffineBinaryOpExpr::print(raw_ostream &os) const {
+  switch (getKind()) {
+  case Kind::Add:
+    return printAdd(this, os);
+  case Kind::Mul:
+    os << "(" << *getLHS() << " * " << *getRHS() << ")";
+    return;
+  case Kind::FloorDiv:
+    os << "(" << *getLHS() << " floordiv " << *getRHS() << ")";
+    return;
+  case Kind::CeilDiv:
+    os << "(" << *getLHS() << " ceildiv " << *getRHS() << ")";
+    return;
+  case Kind::Mod:
+    os << "(" << *getLHS() << " mod " << *getRHS() << ")";
+    return;
+  default:
+    llvm_unreachable("unexpected affine binary op expression");
+  }
+}
+
 void AffineExpr::print(raw_ostream &os) const {
   switch (getKind()) {
   case Kind::SymbolId:
@@ -702,17 +726,11 @@
   case Kind::Constant:
     return cast<AffineConstantExpr>(this)->print(os);
   case Kind::Add:
-    return cast<AffineAddExpr>(this)->print(os);
-  case Kind::Sub:
-    return cast<AffineSubExpr>(this)->print(os);
   case Kind::Mul:
-    return cast<AffineMulExpr>(this)->print(os);
   case Kind::FloorDiv:
-    return cast<AffineFloorDivExpr>(this)->print(os);
   case Kind::CeilDiv:
-    return cast<AffineCeilDivExpr>(this)->print(os);
   case Kind::Mod:
-    return cast<AffineModExpr>(this)->print(os);
+    return cast<AffineBinaryOpExpr>(this)->print(os);
   }
 }
 
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 3d7e023..dc5b8e2 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -121,27 +121,23 @@
 }
 
 AffineExpr *Builder::getAddExpr(AffineExpr *lhs, AffineExpr *rhs) {
-  return AffineAddExpr::get(lhs, rhs, context);
-}
-
-AffineExpr *Builder::getSubExpr(AffineExpr *lhs, AffineExpr *rhs) {
-  return AffineSubExpr::get(lhs, rhs, context);
+  return AffineBinaryOpExpr::get(AffineExpr::Kind::Add, lhs, rhs, context);
 }
 
 AffineExpr *Builder::getMulExpr(AffineExpr *lhs, AffineExpr *rhs) {
-  return AffineMulExpr::get(lhs, rhs, context);
+  return AffineBinaryOpExpr::get(AffineExpr::Kind::Mul, lhs, rhs, context);
 }
 
 AffineExpr *Builder::getModExpr(AffineExpr *lhs, AffineExpr *rhs) {
-  return AffineModExpr::get(lhs, rhs, context);
+  return AffineBinaryOpExpr::get(AffineExpr::Kind::Mod, lhs, rhs, context);
 }
 
 AffineExpr *Builder::getFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs) {
-  return AffineFloorDivExpr::get(lhs, rhs, context);
+  return AffineBinaryOpExpr::get(AffineExpr::Kind::FloorDiv, lhs, rhs, context);
 }
 
 AffineExpr *Builder::getCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs) {
-  return AffineCeilDivExpr::get(lhs, rhs, context);
+  return AffineBinaryOpExpr::get(AffineExpr::Kind::CeilDiv, lhs, rhs, context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index df3d01a..8d2de10 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -684,9 +684,6 @@
   case Kind::Add:
     simplified = AffineBinaryOpExpr::simplifyAdd(lhs, rhs, context);
     break;
-  case Kind::Sub:
-    simplified = AffineBinaryOpExpr::simplifySub(lhs, rhs, context);
-    break;
   case Kind::Mul:
     simplified = AffineBinaryOpExpr::simplifyMul(lhs, rhs, context);
     break;
@@ -720,36 +717,6 @@
   return result;
 }
 
-AffineExpr *AffineAddExpr::get(AffineExpr *lhs, AffineExpr *rhs,
-                               MLIRContext *context) {
-  return AffineBinaryOpExpr::get(Kind::Add, lhs, rhs, context);
-}
-
-AffineExpr *AffineSubExpr::get(AffineExpr *lhs, AffineExpr *rhs,
-                               MLIRContext *context) {
-  return AffineBinaryOpExpr::get(Kind::Sub, lhs, rhs, context);
-}
-
-AffineExpr *AffineMulExpr::get(AffineExpr *lhs, AffineExpr *rhs,
-                               MLIRContext *context) {
-  return AffineBinaryOpExpr::get(Kind::Mul, lhs, rhs, context);
-}
-
-AffineExpr *AffineFloorDivExpr::get(AffineExpr *lhs, AffineExpr *rhs,
-                                    MLIRContext *context) {
-  return AffineBinaryOpExpr::get(Kind::FloorDiv, lhs, rhs, context);
-}
-
-AffineExpr *AffineCeilDivExpr::get(AffineExpr *lhs, AffineExpr *rhs,
-                                   MLIRContext *context) {
-  return AffineBinaryOpExpr::get(Kind::CeilDiv, lhs, rhs, context);
-}
-
-AffineExpr *AffineModExpr::get(AffineExpr *lhs, AffineExpr *rhs,
-                               MLIRContext *context) {
-  return AffineBinaryOpExpr::get(Kind::Mod, lhs, rhs, context);
-}
-
 AffineDimExpr *AffineDimExpr::get(unsigned position, MLIRContext *context) {
   // TODO(bondhugula): complete this
   // FIXME: this should be POD
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 98fd716..7b7ee89 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -758,7 +758,8 @@
   case AffineLowPrecOp::Add:
     return builder.getAddExpr(lhs, rhs);
   case AffineLowPrecOp::Sub:
-    return builder.getSubExpr(lhs, rhs);
+    return builder.getAddExpr(
+        lhs, builder.getMulExpr(rhs, builder.getConstantExpr(-1)));
   case AffineLowPrecOp::LNoOp:
     llvm_unreachable("can't create affine expression for null low prec op");
     return nullptr;
diff --git a/test/IR/parser-affine-map.mlir b/test/IR/parser-affine-map.mlir
index 50f2bd7..030b86d 100644
--- a/test/IR/parser-affine-map.mlir
+++ b/test/IR/parser-affine-map.mlir
@@ -81,10 +81,10 @@
 // CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 ceildiv 5))
 #map26 = (i, j) [s0, s1] -> (i, j ceildiv 5)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - d1) - 5))
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - (d1 * 1)) - 5))
 #map29 = (i, j) [s0, s1] -> (i, i - j - 5)
 
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - (d1 * s1)) + 2))
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 -  ((d1 * s1) * 1)) + 2))
 #map30 = (i, j) [M, N] -> (i, i - N*j + 2)
 
 // CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * -5), (d1 * -3), -2, ((d0 + d1) * -1), (s0 * -1))
@@ -238,4 +238,4 @@
 extfunc @f41(memref<2x4xi8, #map41, 1>)
 
 // CHECK: extfunc @f42(memref<2x4xi8, #map{{[0-9]+}}, 1>)
-extfunc @f42(memref<2x4xi8, #map42, 1>)
\ No newline at end of file
+extfunc @f42(memref<2x4xi8, #map42, 1>)