Affine expression analysis and simplification.

Outside of IR:
- simplify a MutableAffineMap by flattening the affine expressions
- add a simplify affine expression pass that uses this analysis
- update the FlatAffineConstraints API (to be used in the next CL)

In IR:
- add isMultipleOf and getKnownGcd for AffineExpr, and make the in-IR
  simplification of simplifyMod simpler and more powerful.
- rename the AffineExpr visitor methods to distinguish between visiting and
  walking, and to simplify API names based on context.

The next CL will use some of these in loop unrolling/unroll-and-jam to make
detection of the need for a cleanup loop more powerful (non-trivial).

A future CL will finally move this simplification to FlatAffineConstraints to
make it more powerful. For example, currently, if a mod expression appearing
in part of the expression tree can't be simplified, the whole expression won't
be simplified.

PiperOrigin-RevId: 211012256
diff --git a/include/mlir/Analysis/AffineStructures.h b/include/mlir/Analysis/AffineStructures.h
index 6d28260..ac6e728 100644
--- a/include/mlir/Analysis/AffineStructures.h
+++ b/include/mlir/Analysis/AffineStructures.h
@@ -41,23 +41,41 @@
 /// A mutable affine map. Its affine expressions are however unique.
 struct MutableAffineMap {
 public:
-  explicit MutableAffineMap(AffineMap *map);
+  MutableAffineMap(AffineMap *map, MLIRContext *context);
 
   AffineExpr *getResult(unsigned idx) const { return results[idx]; }
   unsigned getNumResults() const { return results.size(); }
+  unsigned getNumDims() const { return numDims; }
+  unsigned getNumSymbols() const { return numSymbols; }
+  /// Returns true if the idx'th result is known to be a multiple of 'factor'.
+  bool isMultipleOf(unsigned idx, int64_t factor) const;
+
+  /// Simplify the result expressions in this map using analysis (used by the
+  /// -simplify-affine-expr pass).
+  void simplify();
+  /// Get the AffineMap corresponding to this MutableAffineMap. Note that an
+  /// AffineMap * will be uniqued and stored in context, while a mutable one
+  /// isn't.
+  AffineMap *getAffineMap();
 
 private:
+  // Same meaning as AffineMap's fields.
   SmallVector<AffineExpr *, 8> results;
   SmallVector<AffineExpr *, 8> rangeSizes;
+  unsigned numDims;
+  unsigned numSymbols;
+  /// A pointer to the IR's context to store all newly created AffineExprs.
+  MLIRContext *context;
 };
 
 /// A mutable integer set. Its affine expressions are however unique.
 struct MutableIntegerSet {
 public:
-  explicit MutableIntegerSet(IntegerSet *set);
+  MutableIntegerSet(IntegerSet *set, MLIRContext *context);
 
   /// Create a universal set (no constraints).
-  explicit MutableIntegerSet(unsigned numDims, unsigned numSymbols);
+  MutableIntegerSet(unsigned numDims, unsigned numSymbols,
+                    MLIRContext *context);
 
   unsigned getNumDims() const { return numDims; }
   unsigned getNumSymbols() const { return numSymbols; }
@@ -74,6 +92,8 @@
 
   SmallVector<AffineExpr *, 8> constraints;
   SmallVector<bool, 8> eqFlags;
+  /// A pointer to the IR's context to store all newly created AffineExprs.
+  MLIRContext *context;
 };
 
 /// An AffineValueMap is an affine map plus its ML value operands and
@@ -89,9 +109,9 @@
 // TODO(bondhugula): Some of these classes could go into separate files.
 class AffineValueMap {
 public:
-  explicit AffineValueMap(const AffineApplyOp &op);
-  explicit AffineValueMap(const AffineBound &bound);
-  explicit AffineValueMap(AffineMap *map);
+  AffineValueMap(const AffineApplyOp &op, MLIRContext *context);
+  AffineValueMap(const AffineBound &bound, MLIRContext *context);
+  AffineValueMap(AffineMap *map, MLIRContext *context);
 
   ~AffineValueMap();
 
@@ -128,8 +148,6 @@
   SmallVector<MLValue *, 4> operands;
   /// The SSA results binding to the results of 'map'.
   SmallVector<MLValue *, 4> results;
-  /// A pointer to the IR's context to store all newly created AffineExpr's.
-  MLIRContext *context;
 };
 
 /// An IntegerValueSet is an integer set plus its operands.
@@ -155,8 +173,6 @@
   MutableIntegerSet set;
   /// The SSA operands binding to the dim's and symbols of 'set'.
   SmallVector<MLValue *, 4> operands;
-  /// A pointer to the IR's context to store all newly created AffineExpr's.
-  MLIRContext *context;
 };
 
 /// A flat list of affine equalities and inequalities in the form.
@@ -190,8 +206,7 @@
   /// constraints and identifiers..
   FlatAffineConstraints(unsigned numReservedInequalities,
                         unsigned numReservedEqualities, unsigned numReservedIds)
-      : numEqualities(0), numInequalities(0),
-        numReservedEqualities(numReservedEqualities),
+      : numReservedEqualities(numReservedEqualities),
         numReservedInequalities(numReservedInequalities),
         numReservedIds(numReservedIds) {
     equalities.reserve(numReservedIds * numReservedEqualities);
@@ -208,23 +223,50 @@
 
   /// Create an affine constraint system from an IntegerValueSet.
   // TODO(bondhugula)
-  FlatAffineConstraints(const IntegerValueSet &set);
+  explicit FlatAffineConstraints(const IntegerValueSet &set);
 
   FlatAffineConstraints(ArrayRef<const AffineValueMap *> avmRef,
                         const IntegerSet &set);
 
+  explicit FlatAffineConstraints(const MutableAffineMap &map);
+
   ~FlatAffineConstraints() {}
 
-  inline int64_t atEq(unsigned i, unsigned j) {
-    return equalities[i * numIds + j];
+  inline int64_t atEq(unsigned i, unsigned j) const {
+    return equalities[i * getNumCols() + j];
   }
 
-  inline int64_t atIneq(unsigned i, unsigned j) {
-    return inequalities[i * numIds + j];
+  inline int64_t &atEq(unsigned i, unsigned j) {
+    return equalities[i * getNumCols() + j];
   }
 
-  unsigned getNumEqualities() const { return equalities.size(); }
-  unsigned getNumInequalities() const { return inequalities.size(); }
+  inline int64_t atIneq(unsigned i, unsigned j) const {
+    return inequalities[i * getNumCols() + j];
+  }
+
+  inline int64_t &atIneq(unsigned i, unsigned j) {
+    return inequalities[i * getNumCols() + j];
+  }
+
+  inline unsigned getNumCols() const { return numIds + 1; }
+
+  inline unsigned getNumEqualities() const {
+    return equalities.size() / getNumCols();
+  }
+
+  inline unsigned getNumInequalities() const {
+    return inequalities.size() / getNumCols();
+  }
+
+  ArrayRef<int64_t> getEquality(unsigned idx) const {
+    return ArrayRef<int64_t>(&equalities[idx * getNumCols()], getNumCols());
+  }
+
+  ArrayRef<int64_t> getInequality(unsigned idx) const {
+    return ArrayRef<int64_t>(&inequalities[idx * getNumCols()], getNumCols());
+  }
+
+  AffineExpr *toAffineExpr(unsigned idx, MLIRContext *context);
 
   void addInequality(ArrayRef<int64_t> inEq);
   void addEquality(ArrayRef<int64_t> eq);
@@ -239,11 +281,19 @@
   void removeEquality(unsigned pos);
   void removeInequality(unsigned pos);
 
-  unsigned getNumConstraints() const { return numEqualities + numInequalities; }
-  unsigned getNumIds() const { return numIds; }
-  unsigned getNumDimIds() const { return numDims; }
-  unsigned getNumSymbolIds() const { return numSymbols; }
-  unsigned getNumLocalIds() const { return numDims - numSymbols - numDims; }
+  unsigned getNumConstraints() const {
+    return getNumEqualities() + getNumInequalities();
+  }
+  inline unsigned getNumIds() const { return numIds; }
+  inline unsigned getNumResultDimIds() const { return numResultDims; }
+  inline unsigned getNumDimIds() const { return numDims; }
+  inline unsigned getNumSymbolIds() const { return numSymbols; }
+  inline unsigned getNumLocalIds() const {
+    return numIds - numResultDims - numDims - numSymbols;
+  }
+
+  void print(raw_ostream &os) const;
+  void dump() const;
 
 private:
   /// Coefficients of affine equalities (in == 0 form).
@@ -252,12 +302,6 @@
   /// Coefficients of affine inequalities (in >= 0 form).
   SmallVector<int64_t, 64> inequalities;
 
-  /// Number of equalities in this system.
-  unsigned numEqualities;
-
-  /// Number of inequalities in this system.
-  unsigned numInequalities;
-
   // Pre-allocated space.
   unsigned numReservedEqualities;
   unsigned numReservedInequalities;
@@ -267,6 +311,9 @@
   unsigned numIds;
 
+  /// Number of identifiers corresponding to result dimensions.
+  unsigned numResultDims;
+
   /// Number of identifiers corresponding to real dimensions.
   unsigned numDims;
 
   /// Number of identifiers corresponding to symbols (unknown but constant for
diff --git a/include/mlir/Analysis/HyperRectangularSet.h b/include/mlir/Analysis/HyperRectangularSet.h
index ba28b3e..8956aeb 100644
--- a/include/mlir/Analysis/HyperRectangularSet.h
+++ b/include/mlir/Analysis/HyperRectangularSet.h
@@ -95,6 +95,7 @@
   HyperRectangularSet(unsigned numDims, unsigned numSymbols,
                       ArrayRef<ArrayRef<AffineExpr *>> lbs,
                       ArrayRef<ArrayRef<AffineExpr *>> ubs,
+                      MLIRContext *context, // assumed: stores new AffineExprs
                       IntegerSet *symbolContext = nullptr);
 
   unsigned getNumDims() const { return numDims; }
diff --git a/include/mlir/IR/AffineExpr.h b/include/mlir/IR/AffineExpr.h
index 825783d..6ef62dc 100644
--- a/include/mlir/IR/AffineExpr.h
+++ b/include/mlir/IR/AffineExpr.h
@@ -70,6 +70,12 @@
   /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
   bool isPureAffine() const;
 
+  /// Returns the largest integer this expression is known to be a multiple of.
+  uint64_t getKnownGcd() const;
+
+  /// Returns true if this expression is known to be a multiple of 'factor'.
+  bool isMultipleOf(int64_t factor) const;
+
 protected:
   explicit AffineExpr(Kind kind) : kind(kind) {}
   ~AffineExpr() {}
diff --git a/include/mlir/IR/AffineExprVisitor.h b/include/mlir/IR/AffineExprVisitor.h
index 7f0bc15..5a2753d 100644
--- a/include/mlir/IR/AffineExprVisitor.h
+++ b/include/mlir/IR/AffineExprVisitor.h
@@ -26,7 +26,7 @@
 
 namespace mlir {
 
-/// Base class for AffineExpr visitors.
+/// Base class for AffineExpr visitors/walkers.
 ///
 /// AffineExpr visitors are used when you want to perform different actions
 /// for different kinds of AffineExprs without having to use lots of casts
@@ -72,54 +72,86 @@
 /// just as efficient as having your own switch statement over the statement
 /// opcode.
 
-template <typename SubClass> class AffineExprVisitor {
+template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
   //===--------------------------------------------------------------------===//
   // Interface code - This is the public interface of the AffineExprVisitor
   // that you use to visit affine expressions...
 public:
-  // Function to visit an AffineExpr.
-  void visit(AffineExpr *expr) {
+  // Function to walk an AffineExpr (in post order).
+  RetTy walkPostOrder(AffineExpr *expr) {
     static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
                   "Must instantiate with a derived type of AffineExprVisitor");
     switch (expr->getKind()) {
     case AffineExpr::Kind::Add: {
       auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      visitAffineBinaryOpOperands(binOpExpr);
-      return static_cast<SubClass *>(this)->visitAffineBinaryAddOpExpr(
-          binOpExpr);
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
     }
     case AffineExpr::Kind::Mul: {
       auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      visitAffineBinaryOpOperands(binOpExpr);
-      return static_cast<SubClass *>(this)->visitAffineBinaryMulOpExpr(
-          binOpExpr);
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
     }
     case AffineExpr::Kind::Mod: {
       auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      visitAffineBinaryOpOperands(binOpExpr);
-      return static_cast<SubClass *>(this)->visitAffineBinaryModOpExpr(
-          binOpExpr);
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
     }
     case AffineExpr::Kind::FloorDiv: {
       auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      visitAffineBinaryOpOperands(binOpExpr);
-      return static_cast<SubClass *>(this)->visitAffineBinaryFloorDivOpExpr(
-          binOpExpr);
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
     }
     case AffineExpr::Kind::CeilDiv: {
       auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      visitAffineBinaryOpOperands(binOpExpr);
-      return static_cast<SubClass *>(this)->visitAffineBinaryCeilDivOpExpr(
-          binOpExpr);
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
     }
     case AffineExpr::Kind::Constant:
-      return static_cast<SubClass *>(this)->visitAffineConstantExpr(
+      return static_cast<SubClass *>(this)->visitConstantExpr(
           cast<AffineConstantExpr>(expr));
     case AffineExpr::Kind::DimId:
-      return static_cast<SubClass *>(this)->visitAffineDimExpr(
+      return static_cast<SubClass *>(this)->visitDimExpr(
           cast<AffineDimExpr>(expr));
     case AffineExpr::Kind::SymbolId:
-      return static_cast<SubClass *>(this)->visitAffineSymbolExpr(
+      return static_cast<SubClass *>(this)->visitSymbolExpr(
+          cast<AffineSymbolExpr>(expr));
+    }
+  }
+
+  // Function to visit an AffineExpr without walking its operands.
+  RetTy visit(AffineExpr *expr) {
+    static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
+                  "Must instantiate with a derived type of AffineExprVisitor");
+    switch (expr->getKind()) {
+    case AffineExpr::Kind::Add: {
+      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+    }
+    case AffineExpr::Kind::Mul: {
+      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+    }
+    case AffineExpr::Kind::Mod: {
+      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+    }
+    case AffineExpr::Kind::FloorDiv: {
+      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+    }
+    case AffineExpr::Kind::CeilDiv: {
+      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+    }
+    case AffineExpr::Kind::Constant:
+      return static_cast<SubClass *>(this)->visitConstantExpr(
+          cast<AffineConstantExpr>(expr));
+    case AffineExpr::Kind::DimId:
+      return static_cast<SubClass *>(this)->visitDimExpr(
+          cast<AffineDimExpr>(expr));
+    case AffineExpr::Kind::SymbolId:
+      return static_cast<SubClass *>(this)->visitSymbolExpr(
           cast<AffineSymbolExpr>(expr));
     }
   }
@@ -135,29 +167,30 @@
   // Default visit methods. Note that the default op-specific binary op visit
   // methods call the general visitAffineBinaryOpExpr visit method.
-  void visitAffineBinaryOpExpr(AffineBinaryOpExpr *expr) {}
-  void visitAffineBinaryAddOpExpr(AffineBinaryOpExpr *expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr *expr) { return RetTy(); }
+  RetTy visitAddExpr(AffineBinaryOpExpr *expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitAffineBinaryMulOpExpr(AffineBinaryOpExpr *expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitMulExpr(AffineBinaryOpExpr *expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitAffineBinaryModOpExpr(AffineBinaryOpExpr *expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitModExpr(AffineBinaryOpExpr *expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitAffineBinaryFloorDivOpExpr(AffineBinaryOpExpr *expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitFloorDivExpr(AffineBinaryOpExpr *expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitAffineBinaryCeilDivOpExpr(AffineBinaryOpExpr *expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitCeilDivExpr(AffineBinaryOpExpr *expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitAffineConstantExpr(AffineConstantExpr *expr) {}
-  void visitAffineDimExpr(AffineDimExpr *expr) {}
-  void visitAffineSymbolExpr(AffineSymbolExpr *expr) {}
+  RetTy visitConstantExpr(AffineConstantExpr *expr) { return RetTy(); }
+  RetTy visitDimExpr(AffineDimExpr *expr) { return RetTy(); }
+  RetTy visitSymbolExpr(AffineSymbolExpr *expr) { return RetTy(); }
 
 private:
-  void visitAffineBinaryOpOperands(AffineBinaryOpExpr *expr) {
-    visit(expr->getLHS());
-    visit(expr->getRHS());
+  // Walk the operands - each operand is itself walked in post order.
+  void walkOperandsPostOrder(AffineBinaryOpExpr *expr) {
+    walkPostOrder(expr->getLHS());
+    walkPostOrder(expr->getRHS());
   }
 };

diff --git a/include/mlir/Transforms/Passes.h b/include/mlir/Transforms/Passes.h
index b9320c8..30ece71 100644
--- a/include/mlir/Transforms/Passes.h
+++ b/include/mlir/Transforms/Passes.h
@@ -25,6 +25,7 @@
 
 namespace mlir {
 
+class FunctionPass;
 class MLFunctionPass;
 class ModulePass;
 
@@ -38,6 +39,9 @@
 /// line if provided.
 MLFunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
 
+/// Creates a function pass that simplifies affine expressions.
+FunctionPass *createSimplifyAffineExprPass();
+
 /// Replaces all ML functions in the module with equivalent CFG functions.
 /// Function references are appropriately patched to refer to the newly
 /// generated CFG functions.