Simplify affine binary op expression class hierarchy

- Drop sub-classing of affine binary op expressions.
- Drop the affine expr op kind Sub. Represent subtraction as multiplication by
  -1 followed by addition. This is also in line with the mathematical form we
  will need when representing a system of linear equalities/inequalities: the
  negative number goes into the coefficient of an affine form (e.g.,
  x_1 + (-1)*x_2 + 3*x_3 + (-2) >= 0). The folding simplification will
  transparently multiply the -1 into any other constants. This also means we
  won't need to simplify a multiply expression like x_1 + (-2)*x_2 into a
  subtract expression (x_1 - 2*x_2) for canonicalization/uniquing.
- When printing the IR, we still pretty print an addition as a subtraction when
  possible, i.e., when its RHS is a negative constant or a product whose RHS is
  a negative constant.

PiperOrigin-RevId: 205298958
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 570ae49..03d9b1d 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -661,38 +661,77 @@
   llvm::errs() << "\n";
 }
 
-void AffineAddExpr::print(raw_ostream &os) const {
-  os << "(" << *getLHS() << " + " << *getRHS() << ")";
-}
-
-void AffineSubExpr::print(raw_ostream &os) const {
-  os << "(" << *getLHS() << " - " << *getRHS() << ")";
-}
-
-void AffineMulExpr::print(raw_ostream &os) const {
-  os << "(" << *getLHS() << " * " << *getRHS() << ")";
-}
-
-void AffineModExpr::print(raw_ostream &os) const {
-  os << "(" << *getLHS() << " mod " << *getRHS() << ")";
-}
-
-void AffineFloorDivExpr::print(raw_ostream &os) const {
-  os << "(" << *getLHS() << " floordiv " << *getRHS() << ")";
-}
-
-void AffineCeilDivExpr::print(raw_ostream &os) const {
-  os << "(" << *getLHS() << " ceildiv " << *getRHS() << ")";
-}
-
 void AffineSymbolExpr::print(raw_ostream &os) const {
-  os << "s" << getPosition();
+  os << 's' << getPosition();
 }
 
-void AffineDimExpr::print(raw_ostream &os) const { os << "d" << getPosition(); }
+void AffineDimExpr::print(raw_ostream &os) const { os << 'd' << getPosition(); }
 
 void AffineConstantExpr::print(raw_ostream &os) const { os << getValue(); }
 
+/// Returns the magnitude of the negative integer `value` as an unsigned
+/// number. Computing it in unsigned arithmetic avoids the signed overflow
+/// (undefined behavior) that `-value` would incur for the most negative
+/// int64_t.
+static uint64_t getNegatedMagnitude(int64_t value) {
+  return 0 - static_cast<uint64_t>(value);
+}
+
+/// Prints an addition, pretty printing it as a subtraction when the RHS is a
+/// negative constant or a product whose RHS is a negative constant. Such
+/// forms arise because subtraction is represented as addition of a term
+/// multiplied by -1.
+static void printAdd(const AffineBinaryOpExpr *addExpr, raw_ostream &os) {
+  os << '(' << *addExpr->getLHS();
+
+  // Pretty print `lhs + (x * -c)` as `lhs - (x * c)`.
+  if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(addExpr->getRHS())) {
+    if (rhs->getKind() == AffineExpr::Kind::Mul) {
+      if (auto *rrhs = dyn_cast<AffineConstantExpr>(rhs->getRHS())) {
+        if (rrhs->getValue() < 0) {
+          os << " - (" << *rhs->getLHS() << " * "
+             << getNegatedMagnitude(rrhs->getValue()) << "))";
+          return;
+        }
+      }
+    }
+  }
+
+  // Pretty print `lhs + -c` as `lhs - c`.
+  if (auto *rhs = dyn_cast<AffineConstantExpr>(addExpr->getRHS())) {
+    if (rhs->getValue() < 0) {
+      os << " - " << getNegatedMagnitude(rhs->getValue()) << ')';
+      return;
+    }
+  }
+
+  os << " + " << *addExpr->getRHS() << ')';
+}
+
+/// Prints this binary op expression in fully parenthesized infix form, e.g.
+/// `(d0 floordiv 4)`. Additions are delegated to printAdd, which may render
+/// them as subtractions.
+void AffineBinaryOpExpr::print(raw_ostream &os) const {
+  switch (getKind()) {
+  case Kind::Add:
+    return printAdd(this, os);
+  case Kind::Mul:
+    os << '(' << *getLHS() << " * " << *getRHS() << ')';
+    return;
+  case Kind::FloorDiv:
+    os << '(' << *getLHS() << " floordiv " << *getRHS() << ')';
+    return;
+  case Kind::CeilDiv:
+    os << '(' << *getLHS() << " ceildiv " << *getRHS() << ')';
+    return;
+  case Kind::Mod:
+    os << '(' << *getLHS() << " mod " << *getRHS() << ')';
+    return;
+  default:
+    llvm_unreachable("unexpected affine binary op expression");
+  }
+}
+
 void AffineExpr::print(raw_ostream &os) const {
   switch (getKind()) {
   case Kind::SymbolId:
@@ -702,17 +741,11 @@
   case Kind::Constant:
     return cast<AffineConstantExpr>(this)->print(os);
   case Kind::Add:
-    return cast<AffineAddExpr>(this)->print(os);
-  case Kind::Sub:
-    return cast<AffineSubExpr>(this)->print(os);
   case Kind::Mul:
-    return cast<AffineMulExpr>(this)->print(os);
   case Kind::FloorDiv:
-    return cast<AffineFloorDivExpr>(this)->print(os);
   case Kind::CeilDiv:
-    return cast<AffineCeilDivExpr>(this)->print(os);
   case Kind::Mod:
-    return cast<AffineModExpr>(this)->print(os);
+    return cast<AffineBinaryOpExpr>(this)->print(os);
   }
 }