[MLIR] Templated AffineExprBaseRef

This CL implements AffineExprBaseRef as a templated type to allow LLVM-style
casts to work properly. This also allows making AffineExprBaseRef::expr
private.

To achieve this, it is necessary to use llvm::simplify_type and make
AffineConstantExpr derive from both AffineExpr and
llvm::simplify_type<AffineExprRef>. Note that llvm::simplify_type is only an
interface that enables the proper template resolution of isa/cast/dyn_cast; it
otherwise holds no state.

Lastly, note that certain dyn_cast operations required the const AffineExpr*
form of AffineExprBaseRef, so I made the implicit constructor take that form by
default and documented the immutable behavior. I think this is consistent with
the decision to make uniqued types immutable by convention and never use const
on them.

PiperOrigin-RevId: 215642247
diff --git a/lib/Analysis/AffineAnalysis.cpp b/lib/Analysis/AffineAnalysis.cpp
index 2a09d3c..ebcc50a 100644
--- a/lib/Analysis/AffineAnalysis.cpp
+++ b/lib/Analysis/AffineAnalysis.cpp
@@ -32,10 +32,10 @@
 /// and is substituted into. The ArrayRef 'eq' is expected to be in the format
 /// [dims, symbols, locals, constant term].
 //  TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here.
-static AffineExprWrap toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
-                                   unsigned numSymbols,
-                                   ArrayRef<AffineExprWrap> localExprs,
-                                   MLIRContext *context) {
+static AffineExprRef toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
+                                  unsigned numSymbols,
+                                  ArrayRef<AffineExprRef> localExprs,
+                                  MLIRContext *context) {
   // Assert expected numLocals = eq.size() - numDims - numSymbols - 1
   assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
          "unexpected number of local expressions");
@@ -69,7 +69,7 @@
 
 namespace {
 
-// This class is used to flatten a pure affine expression (AffineExprWrap,
+// This class is used to flatten a pure affine expression (AffineExprRef,
 // which is in a tree form) into a sum of products (w.r.t constants) when
 // possible, and in that process simplifying the expression. The simplification
 // performed includes the accumulation of contributions for each dimensional and
@@ -124,12 +124,12 @@
   unsigned numLocals;
   // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
   // which new identifiers were introduced; if the latter do not get canceled
-  // out, these expressions are needed to reconstruct the AffineExprWrap  / tree
+  // out, these expressions are needed to reconstruct the AffineExprRef / tree
   // form. Note that these expressions themselves would have been simplified
   // (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 will be
   // simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) ceildiv 2
   // would be the local expression stored for q.
-  SmallVector<AffineExprWrap, 4> localExprs;
+  SmallVector<AffineExprRef, 4> localExprs;
   MLIRContext *context;
 
   AffineExprFlattener(unsigned numDims, unsigned numSymbols,
@@ -142,7 +142,7 @@
   void visitMulExpr(AffineBinaryOpExpr *expr) {
     assert(operandExprStack.size() >= 2);
     // This is a pure affine expr; the RHS will be a constant.
-    assert(isa<AffineConstantExpr>(expr->getRHS().expr));
+    assert(isa<AffineConstantExpr>(expr->getRHS()));
     // Get the RHS constant.
     auto rhsConst = operandExprStack.back()[getConstantIndex()];
     operandExprStack.pop_back();
@@ -169,7 +169,7 @@
   void visitModExpr(AffineBinaryOpExpr *expr) {
     assert(operandExprStack.size() >= 2);
     // This is a pure affine expr; the RHS will be a constant.
-    assert(isa<AffineConstantExpr>(expr->getRHS().expr));
+    assert(isa<AffineConstantExpr>(expr->getRHS()));
     auto rhsConst = operandExprStack.back()[getConstantIndex()];
     operandExprStack.pop_back();
     auto &lhs = operandExprStack.back();
@@ -220,7 +220,7 @@
 private:
   void visitDivExpr(AffineBinaryOpExpr *expr, bool isCeil) {
     assert(operandExprStack.size() >= 2);
-    assert(isa<AffineConstantExpr>(expr->getRHS().expr));
+    assert(isa<AffineConstantExpr>(expr->getRHS()));
     // This is a pure affine expr; the RHS is a positive constant.
     auto rhsConst = operandExprStack.back()[getConstantIndex()];
     // TODO(bondhugula): handle division by zero at the same time the issue is
@@ -259,9 +259,9 @@
   }
 
   // Add an existential quantifier (used to flatten a mod, floordiv, ceildiv
-  // expr). localExpr is the simplified tree expression (AffineExprWrap )
+  // expr). localExpr is the simplified tree expression (AffineExprRef)
   // corresponding to the quantifier.
-  void addLocalId(AffineExprWrap localExpr) {
+  void addLocalId(AffineExprRef localExpr) {
     for (auto &subExpr : operandExprStack) {
       subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
     }
@@ -277,8 +277,8 @@
 
 } // end anonymous namespace
 
-AffineExprWrap mlir::simplifyAffineExpr(AffineExprWrap expr, unsigned numDims,
-                                        unsigned numSymbols) {
+AffineExprRef mlir::simplifyAffineExpr(AffineExprRef expr, unsigned numDims,
+                                       unsigned numSymbols) {
   // TODO(bondhugula): only pure affine for now. The simplification here can be
   // extended to semi-affine maps in the future.
   if (!expr->isPureAffine())
diff --git a/lib/Analysis/HyperRectangularSet.cpp b/lib/Analysis/HyperRectangularSet.cpp
index c745207..3a6183b 100644
--- a/lib/Analysis/HyperRectangularSet.cpp
+++ b/lib/Analysis/HyperRectangularSet.cpp
@@ -38,7 +38,7 @@
     unsigned j = 0;
     AffineBoundExprList::const_iterator it, e;
     for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
-      if (auto *cExpr = dyn_cast<AffineConstantExpr>(it->expr)) {
+      if (auto *cExpr = dyn_cast<AffineConstantExpr>(*it)) {
         if (val == None) {
           val = cExpr->getValue();
           *idx = j;
@@ -68,7 +68,7 @@
     }
     if (it == lhsList.end()) {
       // There can only be one constant affine expr in this bound list.
-      if (auto *cExpr = dyn_cast<AffineConstantExpr>(expr.expr)) {
+      if (auto cExpr = dyn_cast<AffineConstantExpr>(expr)) {
         unsigned idx;
         if (lb) {
           auto cb = getReducedConstBound(
@@ -105,8 +105,8 @@
 }
 
 HyperRectangularSet::HyperRectangularSet(unsigned numDims, unsigned numSymbols,
-                                         ArrayRef<ArrayRef<AffineExprWrap>> lbs,
-                                         ArrayRef<ArrayRef<AffineExprWrap>> ubs,
+                                         ArrayRef<ArrayRef<AffineExprRef>> lbs,
+                                         ArrayRef<ArrayRef<AffineExprRef>> ubs,
                                          MLIRContext *context,
                                          IntegerSet *symbolContext)
     : context(symbolContext ? MutableIntegerSet(symbolContext, context)
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 5c64b2f..babe95f 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -32,7 +32,7 @@
 /// Returns the trip count of the loop as an affine expression if the latter is
 /// expressible as an affine expression, and nullptr otherwise. The trip count
 /// expression is simplified before returning.
-AffineExprWrap mlir::getTripCountExpr(const ForStmt &forStmt) {
+AffineExprRef mlir::getTripCountExpr(const ForStmt &forStmt) {
   // upper_bound - lower_bound + 1
   int64_t loopSpan;
 
@@ -56,12 +56,12 @@
       return nullptr;
 
     // ub_expr - lb_expr + 1
-    AffineExprWrap lbExpr(lbMap->getResult(0));
-    AffineExprWrap ubExpr(ubMap->getResult(0));
+    AffineExprRef lbExpr(lbMap->getResult(0));
+    AffineExprRef ubExpr(ubMap->getResult(0));
     auto loopSpanExpr = simplifyAffineExpr(
         ubExpr - lbExpr + 1, std::max(lbMap->getNumDims(), ubMap->getNumDims()),
         std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols()));
-    auto *cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr.expr);
+    auto *cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
     if (!cExpr)
       return AffineBinaryOpExpr::getCeilDiv(loopSpanExpr, step, context);
     loopSpan = cExpr->getValue();
@@ -81,8 +81,7 @@
 llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
   auto tripCountExpr = getTripCountExpr(forStmt);
 
-  if (auto *constExpr =
-          dyn_cast_or_null<AffineConstantExpr>(tripCountExpr.expr))
+  if (auto constExpr = dyn_cast_or_null<AffineConstantExpr>(tripCountExpr))
     return constExpr->getValue();
 
   return None;
@@ -97,7 +96,7 @@
   if (!tripCountExpr)
     return 1;
 
-  if (auto *constExpr = dyn_cast<AffineConstantExpr>(tripCountExpr.expr)) {
+  if (auto constExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
     uint64_t tripCount = constExpr->getValue();
 
     // 0 iteration loops (greatest divisor is 2^64 - 1).
@@ -109,5 +108,5 @@
   }
 
   // Trip count is not a known constant; return its largest known divisor.
-  return tripCountExpr.expr->getLargestKnownDivisor();
+  return tripCountExpr->getLargestKnownDivisor();
 }