[MLIR] AffineExpr final cleanups

This CL:
1. performs the global codemod AffineXExpr -> AffineXExprClass and
AffineXExprRef -> AffineXExpr;
2. simplifies function calls by removing the redundant MLIRContext parameter;
3. adds the missing binary-operator overloads that take a scalar operand
(scalar op AffineExpr) where they make sense.

PiperOrigin-RevId: 216242674
diff --git a/lib/Analysis/AffineAnalysis.cpp b/lib/Analysis/AffineAnalysis.cpp
index 2d314e0..fa2541a 100644
--- a/lib/Analysis/AffineAnalysis.cpp
+++ b/lib/Analysis/AffineAnalysis.cpp
@@ -33,14 +33,14 @@
 
 /// Constructs an affine expression from a flat ArrayRef. If there are local
 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
-/// products expression, 'localExprs' is expected to have the AffineExpr for it,
-/// and is substituted into. The ArrayRef 'eq' is expected to be in the format
-/// [dims, symbols, locals, constant term].
+/// products expression, 'localExprs' is expected to have the AffineExprClass
+/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
+/// format [dims, symbols, locals, constant term].
 //  TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here.
-static AffineExprRef toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
-                                  unsigned numSymbols,
-                                  ArrayRef<AffineExprRef> localExprs,
-                                  MLIRContext *context) {
+static AffineExpr toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
+                               unsigned numSymbols,
+                               ArrayRef<AffineExpr> localExprs,
+                               MLIRContext *context) {
   // Assert expected numLocals = eq.size() - numDims - numSymbols - 1
   assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
          "unexpected number of local expressions");
@@ -74,7 +74,7 @@
 
 namespace {
 
-// This class is used to flatten a pure affine expression (AffineExprRef,
+// This class is used to flatten a pure affine expression (AffineExpr,
 // which is in a tree form) into a sum of products (w.r.t constants) when
 // possible, and in that process simplifying the expression. The simplification
 // performed includes the accumulation of contributions for each dimensional and
@@ -127,14 +127,14 @@
   // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv
   // expressions that could not be simplified.
   unsigned numLocals;
-  // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
+  // AffineExprClass's corresponding to the floordiv/ceildiv/mod expressions for
   // which new identifiers were introduced; if the latter do not get canceled
-  // out, these expressions are needed to reconstruct the AffineExprRef / tree
+  // out, these expressions are needed to reconstruct the AffineExpr / tree
   // form. Note that these expressions themselves would have been simplified
   // (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 will be
   // simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) ceildiv 2
   // would be the local expression stored for q.
-  SmallVector<AffineExprRef, 4> localExprs;
+  SmallVector<AffineExpr, 4> localExprs;
   MLIRContext *context;
 
   AffineExprFlattener(unsigned numDims, unsigned numSymbols,
@@ -144,10 +144,10 @@
     operandExprStack.reserve(8);
   }
 
-  void visitMulExpr(AffineBinaryOpExprRef expr) {
+  void visitMulExpr(AffineBinaryOpExpr expr) {
     assert(operandExprStack.size() >= 2);
     // This is a pure affine expr; the RHS will be a constant.
-    assert(expr->getRHS().isa<AffineConstantExprRef>());
+    assert(expr->getRHS().isa<AffineConstantExpr>());
     // Get the RHS constant.
     auto rhsConst = operandExprStack.back()[getConstantIndex()];
     operandExprStack.pop_back();
@@ -158,7 +158,7 @@
     }
   }
 
-  void visitAddExpr(AffineBinaryOpExprRef expr) {
+  void visitAddExpr(AffineBinaryOpExpr expr) {
     assert(operandExprStack.size() >= 2);
     const auto &rhs = operandExprStack.back();
     auto &lhs = operandExprStack[operandExprStack.size() - 2];
@@ -171,10 +171,10 @@
     operandExprStack.pop_back();
   }
 
-  void visitModExpr(AffineBinaryOpExprRef expr) {
+  void visitModExpr(AffineBinaryOpExpr expr) {
     assert(operandExprStack.size() >= 2);
     // This is a pure affine expr; the RHS will be a constant.
-    assert(expr->getRHS().isa<AffineConstantExprRef>());
+    assert(expr->getRHS().isa<AffineConstantExpr>());
     auto rhsConst = operandExprStack.back()[getConstantIndex()];
     operandExprStack.pop_back();
     auto &lhs = operandExprStack.back();
@@ -200,32 +200,32 @@
     addLocalId(a.floorDiv(b));
     lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
   }
-  void visitCeilDivExpr(AffineBinaryOpExprRef expr) {
+  void visitCeilDivExpr(AffineBinaryOpExpr expr) {
     visitDivExpr(expr, /*isCeil=*/true);
   }
-  void visitFloorDivExpr(AffineBinaryOpExprRef expr) {
+  void visitFloorDivExpr(AffineBinaryOpExpr expr) {
     visitDivExpr(expr, /*isCeil=*/false);
   }
-  void visitDimExpr(AffineDimExprRef expr) {
+  void visitDimExpr(AffineDimExpr expr) {
     operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
     auto &eq = operandExprStack.back();
     eq[getDimStartIndex() + expr->getPosition()] = 1;
   }
-  void visitSymbolExpr(AffineSymbolExprRef expr) {
+  void visitSymbolExpr(AffineSymbolExpr expr) {
     operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
     auto &eq = operandExprStack.back();
     eq[getSymbolStartIndex() + expr->getPosition()] = 1;
   }
-  void visitConstantExpr(AffineConstantExprRef expr) {
+  void visitConstantExpr(AffineConstantExpr expr) {
     operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
     auto &eq = operandExprStack.back();
     eq[getConstantIndex()] = expr->getValue();
   }
 
 private:
-  void visitDivExpr(AffineBinaryOpExprRef expr, bool isCeil) {
+  void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil) {
     assert(operandExprStack.size() >= 2);
-    assert(expr->getRHS().isa<AffineConstantExprRef>());
+    assert(expr->getRHS().isa<AffineConstantExpr>());
     // This is a pure affine expr; the RHS is a positive constant.
     auto rhsConst = operandExprStack.back()[getConstantIndex()];
     // TODO(bondhugula): handle division by zero at the same time the issue is
@@ -266,9 +266,9 @@
   }
 
   // Add an existential quantifier (used to flatten a mod, floordiv, ceildiv
-  // expr). localExpr is the simplified tree expression (AffineExprRef)
+  // expr). localExpr is the simplified tree expression (AffineExpr)
   // corresponding to the quantifier.
-  void addLocalId(AffineExprRef localExpr) {
+  void addLocalId(AffineExpr localExpr) {
     for (auto &subExpr : operandExprStack) {
       subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
     }
@@ -284,8 +284,8 @@
 
 } // end anonymous namespace
 
-AffineExprRef mlir::simplifyAffineExpr(AffineExprRef expr, unsigned numDims,
-                                       unsigned numSymbols) {
+AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
+                                    unsigned numSymbols) {
   // TODO(bondhugula): only pure affine for now. The simplification here can be
   // extended to semi-affine maps in the future.
   if (!expr->isPureAffine())
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index ac0ba3b..463a64b 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -47,7 +47,7 @@
 struct AffineMapCompositionUpdate {
   using PositionMap = DenseMap<unsigned, unsigned>;
 
-  explicit AffineMapCompositionUpdate(ArrayRef<AffineExprRef> inputResults)
+  explicit AffineMapCompositionUpdate(ArrayRef<AffineExpr> inputResults)
       : inputResults(inputResults), outputNumDims(0), outputNumSymbols(0) {}
 
   // Map from 'curr' affine map dim position to 'output' affine map
@@ -65,7 +65,7 @@
   // symbol position.
   PositionMap inputSymbolMap;
   // Results of 'input' affine map.
-  ArrayRef<AffineExprRef> inputResults;
+  ArrayRef<AffineExpr> inputResults;
   // Number of dimension operands for 'output' affine map.
   unsigned outputNumDims;
   // Number of symbol operands for 'output' affine map.
@@ -80,29 +80,29 @@
   AffineExprComposer(const AffineMapCompositionUpdate &mapUpdate)
       : mapUpdate(mapUpdate), walkingInputMap(false) {}
 
-  AffineExprRef walk(AffineExprRef expr) {
+  AffineExpr walk(AffineExpr expr) {
     switch (expr->getKind()) {
     case AffineExprKind::Add:
       return walkBinExpr(
-          expr, [](AffineExprRef lhs, AffineExprRef rhs) { return lhs + rhs; });
+          expr, [](AffineExpr lhs, AffineExpr rhs) { return lhs + rhs; });
     case AffineExprKind::Mul:
       return walkBinExpr(
-          expr, [](AffineExprRef lhs, AffineExprRef rhs) { return lhs * rhs; });
+          expr, [](AffineExpr lhs, AffineExpr rhs) { return lhs * rhs; });
     case AffineExprKind::Mod:
       return walkBinExpr(
-          expr, [](AffineExprRef lhs, AffineExprRef rhs) { return lhs % rhs; });
+          expr, [](AffineExpr lhs, AffineExpr rhs) { return lhs % rhs; });
     case AffineExprKind::FloorDiv:
-      return walkBinExpr(expr, [](AffineExprRef lhs, AffineExprRef rhs) {
+      return walkBinExpr(expr, [](AffineExpr lhs, AffineExpr rhs) {
         return lhs.floorDiv(rhs);
       });
     case AffineExprKind::CeilDiv:
-      return walkBinExpr(expr, [](AffineExprRef lhs, AffineExprRef rhs) {
+      return walkBinExpr(expr, [](AffineExpr lhs, AffineExpr rhs) {
         return lhs.ceilDiv(rhs);
       });
     case AffineExprKind::Constant:
       return expr;
     case AffineExprKind::DimId: {
-      unsigned dimPosition = expr.cast<AffineDimExprRef>()->getPosition();
+      unsigned dimPosition = expr.cast<AffineDimExpr>()->getPosition();
       if (walkingInputMap) {
         return getAffineDimExpr(mapUpdate.inputDimMap.lookup(dimPosition),
                                 expr->getContext());
@@ -123,7 +123,7 @@
       return composer.walk(mapUpdate.inputResults[inputResultIndex]);
     }
     case AffineExprKind::SymbolId:
-      unsigned symbolPosition = expr.cast<AffineSymbolExprRef>()->getPosition();
+      unsigned symbolPosition = expr.cast<AffineSymbolExpr>()->getPosition();
       if (walkingInputMap) {
         return getAffineSymbolExpr(
             mapUpdate.inputSymbolMap.lookup(symbolPosition),
@@ -139,10 +139,9 @@
                      bool walkingInputMap)
       : mapUpdate(mapUpdate), walkingInputMap(walkingInputMap) {}
 
-  AffineExprRef
-  walkBinExpr(AffineExprRef expr,
-              std::function<AffineExprRef(AffineExprRef, AffineExprRef)> op) {
-    auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
+  AffineExpr walkBinExpr(AffineExpr expr,
+                         std::function<AffineExpr(AffineExpr, AffineExpr)> op) {
+    auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
     return op(walk(binOpExpr->getLHS()), walk(binOpExpr->getRHS()));
   }
 
@@ -197,7 +196,7 @@
 }
 
 AffineMap *MutableAffineMap::getAffineMap() {
-  return AffineMap::get(numDims, numSymbols, results, rangeSizes, context);
+  return AffineMap::get(numDims, numSymbols, results, rangeSizes);
 }
 
 MutableIntegerSet::MutableIntegerSet(IntegerSet *set, MLIRContext *context)
@@ -295,10 +294,10 @@
     DenseSet<unsigned> *positions;
     AffineExprPositionGatherer(unsigned numDims, DenseSet<unsigned> *positions)
         : numDims(numDims), positions(positions) {}
-    void visitDimExpr(AffineDimExprRef expr) {
+    void visitDimExpr(AffineDimExpr expr) {
       positions->insert(expr->getPosition());
     }
-    void visitSymbolExpr(AffineSymbolExprRef expr) {
+    void visitSymbolExpr(AffineSymbolExpr expr) {
       positions->insert(numDims + expr->getPosition());
     }
   };
diff --git a/lib/Analysis/HyperRectangularSet.cpp b/lib/Analysis/HyperRectangularSet.cpp
index 4d72808..7fc5b29 100644
--- a/lib/Analysis/HyperRectangularSet.cpp
+++ b/lib/Analysis/HyperRectangularSet.cpp
@@ -38,7 +38,7 @@
     unsigned j = 0;
     AffineBoundExprList::const_iterator it, e;
     for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
-      if (auto cExpr = it->dyn_cast<AffineConstantExprRef>()) {
+      if (auto cExpr = it->dyn_cast<AffineConstantExpr>()) {
         if (val == None) {
           val = cExpr->getValue();
           *idx = j;
@@ -52,8 +52,9 @@
   return val;
 }
 
-// Merge the two lists of AffineExpr's into a single one, avoiding duplicates.
-// lb specifies whether the bound lists are for a lower bound or an upper bound.
+// Merge the two lists of AffineExprClass's into a single one, avoiding
+// duplicates. lb specifies whether the bound lists are for a lower bound or an
+// upper bound.
 // TODO(bondhugula): clean this code up.
 static void mergeBounds(const HyperRectangularSet &set,
                         AffineBoundExprList &lhsList,
@@ -68,7 +69,7 @@
     }
     if (it == lhsList.end()) {
       // There can only be one constant affine expr in this bound list.
-      if (auto cExpr = expr.dyn_cast<AffineConstantExprRef>()) {
+      if (auto cExpr = expr.dyn_cast<AffineConstantExpr>()) {
         unsigned idx;
         if (lb) {
           auto cb = getReducedConstBound(
@@ -105,8 +106,8 @@
 }
 
 HyperRectangularSet::HyperRectangularSet(unsigned numDims, unsigned numSymbols,
-                                         ArrayRef<ArrayRef<AffineExprRef>> lbs,
-                                         ArrayRef<ArrayRef<AffineExprRef>> ubs,
+                                         ArrayRef<ArrayRef<AffineExpr>> lbs,
+                                         ArrayRef<ArrayRef<AffineExpr>> ubs,
                                          MLIRContext *context,
                                          IntegerSet *symbolContext)
     : context(symbolContext ? MutableIntegerSet(symbolContext, context)
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 522720e..b3e3afe 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -32,7 +32,7 @@
 /// Returns the trip count of the loop as an affine expression if the latter is
 /// expressible as an affine expression, and nullptr otherwise. The trip count
 /// expression is simplified before returning.
-AffineExprRef mlir::getTripCountExpr(const ForStmt &forStmt) {
+AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
   // upper_bound - lower_bound + 1
   int64_t loopSpan;
 
@@ -56,12 +56,12 @@
       return nullptr;
 
     // ub_expr - lb_expr + 1
-    AffineExprRef lbExpr(lbMap->getResult(0));
-    AffineExprRef ubExpr(ubMap->getResult(0));
+    AffineExpr lbExpr(lbMap->getResult(0));
+    AffineExpr ubExpr(ubMap->getResult(0));
     auto loopSpanExpr = simplifyAffineExpr(
         ubExpr - lbExpr + 1, std::max(lbMap->getNumDims(), ubMap->getNumDims()),
         std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols()));
-    auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExprRef>();
+    auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
     if (!cExpr)
       return loopSpanExpr.ceilDiv(step);
     loopSpan = cExpr->getValue();
@@ -84,7 +84,7 @@
   if (!tripCountExpr)
     return None;
 
-  if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExprRef>())
+  if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExpr>())
     return constExpr->getValue();
 
   return None;
@@ -99,7 +99,7 @@
   if (!tripCountExpr)
     return 1;
 
-  if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExprRef>()) {
+  if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExpr>()) {
     uint64_t tripCount = constExpr->getValue();
 
     // 0 iteration loops (greatest divisor is 2^64 - 1).