Affine expression analysis and simplification.

Outside of IR/
- simplify a MutableAffineMap by flattening the affine expressions
- add a simplify affine expression pass that uses this analysis (see the
  sample invocation after this list)
- update the FlatAffineConstraints API (to be used in the next CL)
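
The new pass is registered with mlir-opt under -simplify-affine-expr; a minimal
invocation (with 'input.mlir' being just a placeholder file name) looks like:

  mlir-opt -simplify-affine-expr input.mlir

See the new test/Transforms/simplify.mlir added below for concrete cases.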

In IR:
- add isMultipleOf and getKnownGcd for AffineExpr, and make the in-IR
  simplification done by simplifyMod simpler and more powerful (see the example
  after this list).
- rename the AffineExpr visitor methods to distinguish between visiting and
  walking, and to simplify API names based on context.
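
  For example (illustrative pseudo-notation over AffineExpr objects, following
  the semantics implemented in this CL):

    (d0 * 128)->getKnownGcd()       // 128
    (d0 * 4 + 8)->getKnownGcd()     // gcd(4, 8) = 4
    (d0 * 4 + 8)->isMultipleOf(4)   // true: simplifyMod folds (d0*4 + 8) mod 4 to 0
    (d0 * 4 + 3)->isMultipleOf(2)   // false: (d0*4 + 3) mod 2 is left as is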

The next CL will use some of these utilities in loop unrolling/unroll-jam to
make the detection of whether a cleanup loop is needed more powerful/non-trivial.

A future CL will move this simplification to FlatAffineConstraints to make it
more powerful. For example, currently, if a mod expr appearing in one part of
the expression tree can't be simplified away, the whole expression is left
unsimplified (see the illustration below).
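
As an illustration of that limitation (a hypothetical map, not from the tests):

  (d0, d1) -> ((d0 mod 3) + d1 + d1 - d1 * 2)

is left untouched, since the 'd0 mod 3' term can't be eliminated and introduces
a local variable in the flattened form, even though the d1 terms would
otherwise cancel out.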

PiperOrigin-RevId: 211012256
diff --git a/include/mlir/Analysis/AffineStructures.h b/include/mlir/Analysis/AffineStructures.h
index 6d28260..ac6e728 100644
--- a/include/mlir/Analysis/AffineStructures.h
+++ b/include/mlir/Analysis/AffineStructures.h
@@ -41,23 +41,41 @@
 /// A mutable affine map. Its affine expressions are however unique.
 struct MutableAffineMap {
 public:
-  explicit MutableAffineMap(AffineMap *map);
+  MutableAffineMap(AffineMap *map, MLIRContext *context);
 
   AffineExpr *getResult(unsigned idx) const { return results[idx]; }
   unsigned getNumResults() const { return results.size(); }
+  unsigned getNumDims() const { return numDims; }
+  unsigned getNumSymbols() const { return numSymbols; }
+  /// Returns true if the idx'th result expression is a multiple of factor.
+  bool isMultipleOf(unsigned idx, int64_t factor) const;
+
+  /// Simplify the (result) expressions in this map using analysis (used by
+  /// -simplify-affine-expr pass).
+  void simplify();
+  /// Get the AffineMap corresponding to this MutableAffineMap. Note that an
+  /// AffineMap * will be uniqued and stored in context, while a mutable one
+  /// isn't.
+  AffineMap *getAffineMap();
 
 private:
+  // Same meaning as AffineMap's fields.
   SmallVector<AffineExpr *, 8> results;
   SmallVector<AffineExpr *, 8> rangeSizes;
+  unsigned numDims;
+  unsigned numSymbols;
+  /// A pointer to the IR's context to store all newly created AffineExpr's.
+  MLIRContext *context;
 };
 
 /// A mutable integer set. Its affine expressions are however unique.
 struct MutableIntegerSet {
 public:
-  explicit MutableIntegerSet(IntegerSet *set);
+  MutableIntegerSet(IntegerSet *set, MLIRContext *context);
 
   /// Create a universal set (no constraints).
-  explicit MutableIntegerSet(unsigned numDims, unsigned numSymbols);
+  MutableIntegerSet(unsigned numDims, unsigned numSymbols,
+                    MLIRContext *context);
 
   unsigned getNumDims() const { return numDims; }
   unsigned getNumSymbols() const { return numSymbols; }
@@ -74,6 +92,8 @@
 
   SmallVector<AffineExpr *, 8> constraints;
   SmallVector<bool, 8> eqFlags;
+  /// A pointer to the IR's context to store all newly created AffineExpr's.
+  MLIRContext *context;
 };
 
 /// An AffineValueMap is an affine map plus its ML value operands and
@@ -89,9 +109,9 @@
 // TODO(bondhugula): Some of these classes could go into separate files.
 class AffineValueMap {
 public:
-  explicit AffineValueMap(const AffineApplyOp &op);
-  explicit AffineValueMap(const AffineBound &bound);
-  explicit AffineValueMap(AffineMap *map);
+  AffineValueMap(const AffineApplyOp &op, MLIRContext *context);
+  AffineValueMap(const AffineBound &bound, MLIRContext *context);
+  AffineValueMap(AffineMap *map, MLIRContext *context);
 
   ~AffineValueMap();
 
@@ -110,7 +130,7 @@
 
   /// Return true if the idx^th result can be proved to be a multiple of
   /// 'factor', false otherwise.
-  bool isMultipleOf(unsigned idx, int64_t factor) const;
+  inline bool isMultipleOf(unsigned idx, int64_t factor) const;
 
   /// Return true if the result at 'idx' is a constant, false
   /// otherwise.
@@ -128,8 +148,6 @@
   SmallVector<MLValue *, 4> operands;
   /// The SSA results binding to the results of 'map'.
   SmallVector<MLValue *, 4> results;
-  /// A pointer to the IR's context to store all newly created AffineExpr's.
-  MLIRContext *context;
 };
 
 /// An IntegerValueSet is an integer set plus its operands.
@@ -155,8 +173,6 @@
   MutableIntegerSet set;
   /// The SSA operands binding to the dim's and symbols of 'set'.
   SmallVector<MLValue *, 4> operands;
-  /// A pointer to the IR's context to store all newly created AffineExpr's.
-  MLIRContext *context;
 };
 
 /// A flat list of affine equalities and inequalities in the form.
@@ -190,8 +206,7 @@
   /// constraints and identifiers..
   FlatAffineConstraints(unsigned numReservedInequalities,
                         unsigned numReservedEqualities, unsigned numReservedIds)
-      : numEqualities(0), numInequalities(0),
-        numReservedEqualities(numReservedEqualities),
+      : numReservedEqualities(numReservedEqualities),
         numReservedInequalities(numReservedInequalities),
         numReservedIds(numReservedIds) {
     equalities.reserve(numReservedIds * numReservedEqualities);
@@ -208,23 +223,50 @@
 
   /// Create an affine constraint system from an IntegerValueSet.
   // TODO(bondhugula)
-  FlatAffineConstraints(const IntegerValueSet &set);
+  explicit FlatAffineConstraints(const IntegerValueSet &set);
 
   FlatAffineConstraints(ArrayRef<const AffineValueMap *> avmRef,
                         const IntegerSet &set);
 
+  FlatAffineConstraints(const MutableAffineMap &map);
+
   ~FlatAffineConstraints() {}
 
-  inline int64_t atEq(unsigned i, unsigned j) {
-    return equalities[i * numIds + j];
+  inline int64_t atEq(unsigned i, unsigned j) const {
+    return equalities[i * (numIds + 1) + j];
   }
 
-  inline int64_t atIneq(unsigned i, unsigned j) {
-    return inequalities[i * numIds + j];
+  inline int64_t &atEq(unsigned i, unsigned j) {
+    return equalities[i * (numIds + 1) + j];
   }
 
-  unsigned getNumEqualities() const { return equalities.size(); }
-  unsigned getNumInequalities() const { return inequalities.size(); }
+  inline int64_t atIneq(unsigned i, unsigned j) const {
+    return inequalities[i * (numIds + 1) + j];
+  }
+
+  inline int64_t &atIneq(unsigned i, unsigned j) {
+    return inequalities[i * (numIds + 1) + j];
+  }
+
+  inline unsigned getNumCols() const { return numIds + 1; }
+
+  inline unsigned getNumEqualities() const {
+    return equalities.size() / getNumCols();
+  }
+
+  inline unsigned getNumInequalities() const {
+    return inequalities.size() / getNumCols();
+  }
+
+  ArrayRef<int64_t> getEquality(unsigned idx) {
+    return ArrayRef<int64_t>(&equalities[idx * getNumCols()], getNumCols());
+  }
+
+  ArrayRef<int64_t> getInequality(unsigned idx) {
+    return ArrayRef<int64_t>(&inequalities[idx * getNumCols()], getNumCols());
+  }
+
+  AffineExpr *toAffineExpr(unsigned idx, MLIRContext *context);
 
   void addInequality(ArrayRef<int64_t> inEq);
   void addEquality(ArrayRef<int64_t> eq);
@@ -239,11 +281,19 @@
   void removeEquality(unsigned pos);
   void removeInequality(unsigned pos);
 
-  unsigned getNumConstraints() const { return numEqualities + numInequalities; }
-  unsigned getNumIds() const { return numIds; }
-  unsigned getNumDimIds() const { return numDims; }
-  unsigned getNumSymbolIds() const { return numSymbols; }
-  unsigned getNumLocalIds() const { return numDims - numSymbols - numDims; }
+  unsigned getNumConstraints() const {
+    return equalities.size() + inequalities.size();
+  }
+  inline unsigned getNumIds() const { return numIds; }
+  inline unsigned getNumResultDimIds() const { return numResultDims; }
+  inline unsigned getNumDimIds() const { return numDims; }
+  inline unsigned getNumSymbolIds() const { return numSymbols; }
+  inline unsigned getNumLocalIds() const {
+    return numIds - numResultDims - numDims - numSymbols;
+  }
+
+  void print(raw_ostream &os) const;
+  void dump() const;
 
 private:
   /// Coefficients of affine equalities (in == 0 form).
@@ -252,12 +302,6 @@
   /// Coefficients of affine inequalities (in >= 0 form).
   SmallVector<int64_t, 64> inequalities;
 
-  /// Number of equalities in this system.
-  unsigned numEqualities;
-
-  /// Number of inequalities in this system.
-  unsigned numInequalities;
-
   // Pre-allocated space.
   unsigned numReservedEqualities;
   unsigned numReservedInequalities;
@@ -267,6 +311,9 @@
   unsigned numIds;
 
   /// Number of identifiers corresponding to real dimensions.
+  unsigned numResultDims;
+
+  /// Number of identifiers corresponding to real dimensions.
   unsigned numDims;
 
   /// Number of identifiers corresponding to symbols (unknown but constant for
diff --git a/include/mlir/Analysis/HyperRectangularSet.h b/include/mlir/Analysis/HyperRectangularSet.h
index ba28b3e..8956aeb 100644
--- a/include/mlir/Analysis/HyperRectangularSet.h
+++ b/include/mlir/Analysis/HyperRectangularSet.h
@@ -95,6 +95,7 @@
   HyperRectangularSet(unsigned numDims, unsigned numSymbols,
                       ArrayRef<ArrayRef<AffineExpr *>> lbs,
                       ArrayRef<ArrayRef<AffineExpr *>> ubs,
+                      MLIRContext *context,
                       IntegerSet *symbolContext = nullptr);
 
   unsigned getNumDims() const { return numDims; }
diff --git a/include/mlir/IR/AffineExpr.h b/include/mlir/IR/AffineExpr.h
index 825783d..6ef62dc 100644
--- a/include/mlir/IR/AffineExpr.h
+++ b/include/mlir/IR/AffineExpr.h
@@ -70,6 +70,12 @@
   /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
   bool isPureAffine() const;
 
+  /// Returns the greatest known common divisor of this affine expression.
+  uint64_t getKnownGcd() const;
+
+  /// Return true if the affine expression is a multiple of 'factor'.
+  bool isMultipleOf(int64_t factor) const;
+
 protected:
   explicit AffineExpr(Kind kind) : kind(kind) {}
   ~AffineExpr() {}
diff --git a/include/mlir/IR/AffineExprVisitor.h b/include/mlir/IR/AffineExprVisitor.h
index 7f0bc15..5a2753d 100644
--- a/include/mlir/IR/AffineExprVisitor.h
+++ b/include/mlir/IR/AffineExprVisitor.h
@@ -26,7 +26,7 @@
 
 namespace mlir {
 
-/// Base class for AffineExpr visitors.
+/// Base class for AffineExpr visitors/walkers.
 ///
 /// AffineExpr visitors are used when you want to perform different actions
 /// for different kinds of AffineExprs without having to use lots of casts
@@ -72,54 +72,86 @@
 /// just as efficient as having your own switch statement over the statement
 /// opcode.
 
-template <typename SubClass> class AffineExprVisitor {
+template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
   //===--------------------------------------------------------------------===//
   // Interface code - This is the public interface of the AffineExprVisitor
   // that you use to visit affine expressions...
 public:
-  // Function to visit an AffineExpr.
-  void visit(AffineExpr *expr) {
+  // Function to walk an AffineExpr (in post order).
+  RetTy walkPostOrder(AffineExpr *expr) {
     static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
                   "Must instantiate with a derived type of AffineExprVisitor");
     switch (expr->getKind()) {
     case AffineExpr::Kind::Add: {
       auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      visitAffineBinaryOpOperands(binOpExpr);
-      return static_cast<SubClass *>(this)->visitAffineBinaryAddOpExpr(
-          binOpExpr);
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
     }
     case AffineExpr::Kind::Mul: {
       auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      visitAffineBinaryOpOperands(binOpExpr);
-      return static_cast<SubClass *>(this)->visitAffineBinaryMulOpExpr(
-          binOpExpr);
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
     }
     case AffineExpr::Kind::Mod: {
       auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      visitAffineBinaryOpOperands(binOpExpr);
-      return static_cast<SubClass *>(this)->visitAffineBinaryModOpExpr(
-          binOpExpr);
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
     }
     case AffineExpr::Kind::FloorDiv: {
       auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      visitAffineBinaryOpOperands(binOpExpr);
-      return static_cast<SubClass *>(this)->visitAffineBinaryFloorDivOpExpr(
-          binOpExpr);
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
     }
     case AffineExpr::Kind::CeilDiv: {
       auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      visitAffineBinaryOpOperands(binOpExpr);
-      return static_cast<SubClass *>(this)->visitAffineBinaryCeilDivOpExpr(
-          binOpExpr);
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
     }
     case AffineExpr::Kind::Constant:
-      return static_cast<SubClass *>(this)->visitAffineConstantExpr(
+      return static_cast<SubClass *>(this)->visitConstantExpr(
           cast<AffineConstantExpr>(expr));
     case AffineExpr::Kind::DimId:
-      return static_cast<SubClass *>(this)->visitAffineDimExpr(
+      return static_cast<SubClass *>(this)->visitDimExpr(
           cast<AffineDimExpr>(expr));
     case AffineExpr::Kind::SymbolId:
-      return static_cast<SubClass *>(this)->visitAffineSymbolExpr(
+      return static_cast<SubClass *>(this)->visitSymbolExpr(
+          cast<AffineSymbolExpr>(expr));
+    }
+  }
+
+  // Function to visit an AffineExpr.
+  RetTy visit(AffineExpr *expr) {
+    static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
+                  "Must instantiate with a derived type of AffineExprVisitor");
+    switch (expr->getKind()) {
+    case AffineExpr::Kind::Add: {
+      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+    }
+    case AffineExpr::Kind::Mul: {
+      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+    }
+    case AffineExpr::Kind::Mod: {
+      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+    }
+    case AffineExpr::Kind::FloorDiv: {
+      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+    }
+    case AffineExpr::Kind::CeilDiv: {
+      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+    }
+    case AffineExpr::Kind::Constant:
+      return static_cast<SubClass *>(this)->visitConstantExpr(
+          cast<AffineConstantExpr>(expr));
+    case AffineExpr::Kind::DimId:
+      return static_cast<SubClass *>(this)->visitDimExpr(
+          cast<AffineDimExpr>(expr));
+    case AffineExpr::Kind::SymbolId:
+      return static_cast<SubClass *>(this)->visitSymbolExpr(
           cast<AffineSymbolExpr>(expr));
     }
   }
@@ -135,29 +167,30 @@
   // Default visit methods. Note that the default op-specific binary op visit
   // methods call the general visitAffineBinaryOpExpr visit method.
   void visitAffineBinaryOpExpr(AffineBinaryOpExpr *expr) {}
-  void visitAffineBinaryAddOpExpr(AffineBinaryOpExpr *expr) {
+  void visitAddExpr(AffineBinaryOpExpr *expr) {
     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitAffineBinaryMulOpExpr(AffineBinaryOpExpr *expr) {
+  void visitMulExpr(AffineBinaryOpExpr *expr) {
     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitAffineBinaryModOpExpr(AffineBinaryOpExpr *expr) {
+  void visitModExpr(AffineBinaryOpExpr *expr) {
     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitAffineBinaryFloorDivOpExpr(AffineBinaryOpExpr *expr) {
+  void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitAffineBinaryCeilDivOpExpr(AffineBinaryOpExpr *expr) {
+  void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitAffineConstantExpr(AffineConstantExpr *expr) {}
+  void visitConstantExpr(AffineConstantExpr *expr) {}
   void visitAffineDimExpr(AffineDimExpr *expr) {}
   void visitAffineSymbolExpr(AffineSymbolExpr *expr) {}
 
 private:
-  void visitAffineBinaryOpOperands(AffineBinaryOpExpr *expr) {
-    visit(expr->getLHS());
-    visit(expr->getRHS());
+  // Walk the operands - each operand is itself walked in post order.
+  void walkOperandsPostOrder(AffineBinaryOpExpr *expr) {
+    walkPostOrder(expr->getLHS());
+    walkPostOrder(expr->getRHS());
   }
 };
 
diff --git a/include/mlir/Transforms/Passes.h b/include/mlir/Transforms/Passes.h
index b9320c8..30ece71 100644
--- a/include/mlir/Transforms/Passes.h
+++ b/include/mlir/Transforms/Passes.h
@@ -25,6 +25,7 @@
 
 namespace mlir {
 
+class FunctionPass;
 class MLFunctionPass;
 class ModulePass;
 
@@ -38,6 +39,9 @@
 /// line if provided.
 MLFunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
 
+/// Creates an affine expression simplification pass.
+FunctionPass *createSimplifyAffineExprPass();
+
 /// Replaces all ML functions in the module with equivalent CFG functions.
 /// Function references are appropriately patched to refer to the newly
 /// generated CFG functions.
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index 782257e..965004b 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -20,47 +20,66 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/AffineStructures.h"
+
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/StandardOps.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 
-MutableAffineMap::MutableAffineMap(AffineMap *map) {
+MutableAffineMap::MutableAffineMap(AffineMap *map, MLIRContext *context)
+    : numDims(map->getNumDims()), numSymbols(map->getNumSymbols()),
+      context(context) {
   for (auto *result : map->getResults())
     results.push_back(result);
   for (auto *rangeSize : map->getRangeSizes())
     results.push_back(rangeSize);
 }
 
-MutableIntegerSet::MutableIntegerSet(IntegerSet *set)
-    : numDims(set->getNumDims()), numSymbols(set->getNumSymbols()) {
-  // TODO(bondhugula)
-}
+bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
+  if (results[idx]->isMultipleOf(factor))
+    return true;
 
-// Universal set.
-MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols)
-    : numDims(numDims), numSymbols(numSymbols) {}
-
-AffineValueMap::AffineValueMap(const AffineApplyOp &op)
-    : map(op.getAffineMap()) {
-  // TODO: pull operands and results in.
-}
-
-bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
-  // Check if the (first result expr) % factor becomes 0.
-  if (auto *expr = dyn_cast<AffineConstantExpr>(AffineBinaryOpExpr::get(
-          AffineExpr::Kind::Mod, map.getResult(idx),
-          AffineConstantExpr::get(factor, context), context)))
-    return expr->getValue() == 0;
-
-  // TODO(bondhugula): use FlatAffineConstraints to complete this.
+  // TODO(bondhugula): use FlatAffineConstraints to complete this (for a more
+  // powerful analysis).
   assert(0 && "isMultipleOf implementation incomplete");
   return false;
 }
 
+MutableIntegerSet::MutableIntegerSet(IntegerSet *set, MLIRContext *context)
+    : numDims(set->getNumDims()), numSymbols(set->getNumSymbols()),
+      context(context) {
+  // TODO(bondhugula)
+}
+
+// Universal set.
+MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
+                                     MLIRContext *context)
+    : numDims(numDims), numSymbols(numSymbols), context(context) {}
+
+AffineValueMap::AffineValueMap(const AffineApplyOp &op, MLIRContext *context)
+    : map(op.getAffineMap(), context) {
+  // TODO: pull operands and results in.
+}
+
+inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
+  return map.isMultipleOf(idx, factor);
+}
+
 AffineValueMap::~AffineValueMap() {}
 
+void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
+  assert(eq.size() == getNumCols());
+  unsigned offset = equalities.size();
+  equalities.resize(equalities.size() + eq.size());
+  for (unsigned i = 0, e = eq.size(); i < e; i++) {
+    equalities[offset + i] = eq[i];
+  }
+}
+
 } // end namespace mlir
diff --git a/lib/Analysis/HyperRectangularSet.cpp b/lib/Analysis/HyperRectangularSet.cpp
index 1a06515..14e180b 100644
--- a/lib/Analysis/HyperRectangularSet.cpp
+++ b/lib/Analysis/HyperRectangularSet.cpp
@@ -110,9 +110,10 @@
 HyperRectangularSet::HyperRectangularSet(unsigned numDims, unsigned numSymbols,
                                          ArrayRef<ArrayRef<AffineExpr *>> lbs,
                                          ArrayRef<ArrayRef<AffineExpr *>> ubs,
+                                         MLIRContext *context,
                                          IntegerSet *symbolContext)
-    : context(symbolContext ? MutableIntegerSet(symbolContext)
-                            : MutableIntegerSet(numDims, numSymbols)) {
+    : context(symbolContext ? MutableIntegerSet(symbolContext, context)
+                            : MutableIntegerSet(numDims, numSymbols, context)) {
   unsigned d = 0;
   for (auto boundList : lbs) {
     AffineBoundExprList lb;
diff --git a/lib/IR/AffineExpr.cpp b/lib/IR/AffineExpr.cpp
index 6bfbaf5..6393b3e 100644
--- a/lib/IR/AffineExpr.cpp
+++ b/lib/IR/AffineExpr.cpp
@@ -100,3 +100,56 @@
   }
   }
 }
+
+uint64_t AffineExpr::getKnownGcd() const {
+  AffineBinaryOpExpr *binExpr = nullptr;
+  switch (kind) {
+  case Kind::SymbolId:
+    LLVM_FALLTHROUGH;
+  case Kind::DimId:
+    return 1;
+  case Kind::Constant:
+    return std::abs(cast<AffineConstantExpr>(this)->getValue());
+  case Kind::Mul:
+    binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
+    return binExpr->getLHS()->getKnownGcd() * binExpr->getRHS()->getKnownGcd();
+  case Kind::Add:
+    LLVM_FALLTHROUGH;
+  case Kind::FloorDiv:
+  case Kind::CeilDiv:
+  case Kind::Mod:
+    binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
+    return llvm::GreatestCommonDivisor64(binExpr->getLHS()->getKnownGcd(),
+                                         binExpr->getRHS()->getKnownGcd());
+  }
+}
+
+bool AffineExpr::isMultipleOf(int64_t factor) const {
+  AffineBinaryOpExpr *binExpr = nullptr;
+  uint64_t l, u;
+  switch (kind) {
+  case Kind::SymbolId:
+    LLVM_FALLTHROUGH;
+  case Kind::DimId:
+    return factor * factor == 1;
+  case Kind::Constant:
+    return cast<AffineConstantExpr>(this)->getValue() % factor == 0;
+  case Kind::Mul:
+    binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
+    // It's probably not worth optimizing this further (to avoid traversing the
+    // whole sub-tree under it - that would require a version of isMultipleOf
+    // that also returns the known GCD on a 'false' return).
+    return (l = binExpr->getLHS()->getKnownGcd()) % factor == 0 ||
+           (u = binExpr->getRHS()->getKnownGcd()) % factor == 0 ||
+           (l * u) % factor == 0;
+  case Kind::Add:
+  case Kind::FloorDiv:
+  case Kind::CeilDiv:
+  case Kind::Mod:
+    binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
+    return llvm::GreatestCommonDivisor64(binExpr->getLHS()->getKnownGcd(),
+                                         binExpr->getRHS()->getKnownGcd()) %
+               factor ==
+           0;
+  }
+}
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index 2b6d680..37a3b45 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -225,17 +225,13 @@
     return AffineConstantExpr::get(lhsConst->getValue() % rhsConst->getValue(),
                                    context);
 
-  // Fold modulo of a multiply with a constant that is a multiple of the
-  // modulo factor to zero. Eg: (i * 128) mod 64 = 0.
+  // Fold modulo of an expression that is known to be a multiple of a constant
+  // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
+  // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
   if (rhsConst) {
-    auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
-    if (lBin && lBin->getKind() == Kind::Mul) {
-      if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
-        // rhsConst is known to be positive if a constant.
-        if (lrhs->getValue() % rhsConst->getValue() == 0)
-          return AffineConstantExpr::get(0, context);
-      }
-    }
+    // rhsConst is known to be positive if a constant.
+    if (lhs->getKnownGcd() % rhsConst->getValue() == 0)
+      return AffineConstantExpr::get(0, context);
   }
 
   return nullptr;
diff --git a/lib/Transforms/SimplifyAffineExpr.cpp b/lib/Transforms/SimplifyAffineExpr.cpp
new file mode 100644
index 0000000..3c72887
--- /dev/null
+++ b/lib/Transforms/SimplifyAffineExpr.cpp
@@ -0,0 +1,252 @@
+//===- SimplifyAffineExpr.cpp - Simplify affine expressions ----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to simplify affine expressions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StmtVisitor.h"
+
+#include "mlir/Transforms/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+using llvm::report_fatal_error;
+
+namespace {
+
+/// Simplify all affine expressions appearing in the operation statements of the
+/// MLFunction.
+//  TODO(someone): Gradually, extend this to all affine map references found in
+//  ML functions and CFG functions.
+struct SimplifyAffineExpr : public FunctionPass {
+  explicit SimplifyAffineExpr() {}
+
+  void runOnMLFunction(MLFunction *f);
+  // Does nothing on CFG functions for now. No reusable walkers/visitors exist
+  // for these yet (TODO(someone)).
+  void runOnCFGFunction(CFGFunction *f) {}
+};
+
+// This class is used to flatten a pure affine expression into a sum of
+// products (w.r.t constants) when possible, and in that process accumulate
+// the contributions of each dimensional and symbolic identifier together. Note
+// that an affine expression may not always be expressible that way due to the
+// presence of modulo, floordiv, and ceildiv expressions. A simplification that
+// this flattening naturally performs is to fold a modulo expression to zero,
+// if possible. Two examples are below:
+//
+// ((d0 + 3 * d1) + d0 - 2 * d1) - d0   simplified to  d0 + d1
+// (d0 - d0 mod 4 + 4) mod 4            simplified to  0
+//
+// For modulo and floordiv expressions, an additional variable is introduced to
+// rewrite them as a sum of products (w.r.t constants). For example, in the
+// second example above, d0 mod 4 is replaced by d0 - 4*q, with q being the
+// newly introduced variable; the expression then simplifies to
+// (d0 - (d0 - 4*q) + 4) = 4*q + 4, which taken modulo 4 simplifies to zero.
+//
+// This is a linear-time post-order walk over an affine expression that attempts
+// the above simplifications through visit methods, with partial results being
+// stored in 'operandExprStack'. When a parent expr is visited, the flattened
+// expressions corresponding to its two operands are already on the stack - the
+// parent expr looks at the two flattened expressions and combines them. It
+// pops off the operand expressions and pushes the combined result (although
+// this is done in place on its LHS operand expr). When the walk is completed,
+// the flattened form of the top-level expression is left on the stack.
+//
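+// As a small illustration of the flattened form (with numDims = 1 and
+// numSymbols = 0, so that the columns are laid out as [d0, locals..., const]):
+//
+//   d0         flattens to  [1, 0]
+//   d0 + 2     flattens to  [1, 2]
+//   d0 mod 4   flattens to  [1, -4, 0]  (one local 'q' standing for
+//                                        d0 floordiv 4, i.e. the value is
+//                                        d0 - 4*q)
+//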
+class AffineExprFlattener : public AffineExprVisitor<AffineExprFlattener> {
+public:
+  std::vector<SmallVector<int64_t, 32>> operandExprStack;
+
+  // The layout of the flattened expressions is dimensions, symbols, locals,
+  // and constant term.
+  unsigned getNumCols() const { return numDims + numSymbols + numLocals + 1; }
+
+  AffineExprFlattener(unsigned numDims, unsigned numSymbols)
+      : numDims(numDims), numSymbols(numSymbols), numLocals(0) {}
+
+  void visitMulExpr(AffineBinaryOpExpr *expr) {
+    assert(expr->isPureAffine());
+    // Get the RHS constant.
+    auto rhsConst = operandExprStack.back()[getNumCols() - 1];
+    operandExprStack.pop_back();
+    // Update the LHS in place instead of pop and push.
+    auto &lhs = operandExprStack.back();
+    for (unsigned i = 0, e = lhs.size(); i < e; i++) {
+      lhs[i] *= rhsConst;
+    }
+  }
+  void visitAddExpr(AffineBinaryOpExpr *expr) {
+    const auto &rhs = operandExprStack.back();
+    auto &lhs = operandExprStack[operandExprStack.size() - 2];
+    assert(lhs.size() == rhs.size());
+    // Update the LHS in place.
+    for (unsigned i = 0; i < rhs.size(); i++) {
+      lhs[i] += rhs[i];
+    }
+    // Pop off the RHS.
+    operandExprStack.pop_back();
+  }
+  void visitModExpr(AffineBinaryOpExpr *expr) {
+    assert(expr->isPureAffine());
+    // This is a pure affine expr; the RHS is a constant.
+    auto rhsConst = operandExprStack.back()[getNumCols() - 1];
+    operandExprStack.pop_back();
+    auto &lhs = operandExprStack.back();
+    assert(rhsConst != 0 && "RHS constant can't be zero");
+    unsigned i;
+    for (i = 0; i < lhs.size(); i++)
+      if (lhs[i] % rhsConst != 0)
+        break;
+    if (i == lhs.size()) {
+      // The modulo expression here simplifies to zero.
+      lhs.assign(lhs.size(), 0);
+      return;
+    }
+    // Add an existential quantifier. expr1 % expr2 is replaced by (expr1 -
+    // q * expr2) where q is the existential quantifier introduced.
+    addExistentialQuantifier();
+    lhs = operandExprStack.back();
+    lhs[numDims + numSymbols + numLocals - 1] = -rhsConst;
+  }
+  void visitConstantExpr(AffineConstantExpr *expr) {
+    operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
+    auto &eq = operandExprStack.back();
+    eq[getNumCols() - 1] = expr->getValue();
+  }
+  void visitDimExpr(AffineDimExpr *expr) {
+    SmallVector<int64_t, 32> eq(getNumCols(), 0);
+    eq[expr->getPosition()] = 1;
+    operandExprStack.push_back(eq);
+  }
+  void visitSymbolExpr(AffineSymbolExpr *expr) {
+    SmallVector<int64_t, 32> eq(getNumCols(), 0);
+    eq[numDims + expr->getPosition()] = 1;
+    operandExprStack.push_back(eq);
+  }
+  void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
+    // TODO(bondhugula): handle ceildiv as well; won't simplify further through
+    // this analysis but will be handled (rest of the expr will simplify).
+    report_fatal_error("ceildiv expr simplification not supported here");
+  }
+  void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
+    // TODO(bondhugula): handle floordiv as well; it won't simplify further
+    // through this analysis but will be handled (rest of the expr will simplify).
+    report_fatal_error("floordiv expr simplification unimplemented");
+  }
+  // Add an existential quantifier (used to flatten a mod or a floordiv expr).
+  void addExistentialQuantifier() {
+    for (auto &subExpr : operandExprStack) {
+      subExpr.insert(subExpr.begin() + numDims + numSymbols + numLocals, 0);
+    }
+    numLocals++;
+  }
+
+  unsigned numDims;
+  unsigned numSymbols;
+  unsigned numLocals;
+};
+
+} // end anonymous namespace
+
+FunctionPass *mlir::createSimplifyAffineExprPass() {
+  return new SimplifyAffineExpr();
+}
+
+AffineMap *MutableAffineMap::getAffineMap() {
+  return AffineMap::get(numDims, numSymbols, results, rangeSizes, context);
+}
+
+void SimplifyAffineExpr::runOnMLFunction(MLFunction *f) {
+  struct MapSimplifier : public StmtWalker<MapSimplifier> {
+    MLIRContext *context;
+    MapSimplifier(MLIRContext *context) : context(context) {}
+
+    void visitOperationStmt(OperationStmt *opStmt) {
+      for (auto attr : opStmt->getAttrs()) {
+        if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr.second)) {
+          MutableAffineMap mMap(mapAttr->getValue(), context);
+          mMap.simplify();
+          auto *map = mMap.getAffineMap();
+          opStmt->setAttr(attr.first, AffineMapAttr::get(map, context));
+        }
+      }
+    }
+  };
+
+  MapSimplifier v(f->getContext());
+  v.walkPostOrder(f);
+}
+
+/// Get an affine expression from a flat ArrayRef. If there are local variables
+/// (existential quantifiers introduced during the flattening) that appear in
+/// the sum of products expression, we can't readily express it as an affine
+/// expression of dimension and symbol id's; return nullptr in such cases.
+static AffineExpr *toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
+                                unsigned numSymbols, MLIRContext *context) {
+  // Check if any local variable has a non-zero coefficient.
+  for (unsigned j = numDims + numSymbols; j < eq.size() - 1; j++) {
+    if (eq[j] != 0)
+      return nullptr;
+  }
+
+  AffineExpr *expr = AffineConstantExpr::get(0, context);
+  for (unsigned j = 0; j < numDims + numSymbols; j++) {
+    if (eq[j] != 0) {
+      AffineExpr *id =
+          j < numDims
+              ? static_cast<AffineExpr *>(AffineDimExpr::get(j, context))
+              : AffineSymbolExpr::get(j - numDims, context);
+      expr = AffineBinaryOpExpr::get(
+          AffineExpr::Kind::Add, expr,
+          AffineBinaryOpExpr::get(AffineExpr::Kind::Mul,
+                                  AffineConstantExpr::get(eq[j], context), id,
+                                  context),
+          context);
+    }
+  }
+  int64_t constTerm = eq[eq.size() - 1];
+  if (constTerm != 0)
+    expr = AffineBinaryOpExpr::get(AffineExpr::Kind::Add, expr,
+                                   AffineConstantExpr::get(constTerm, context),
+                                   context);
+  return expr;
+}
+
+// Simplify the result affine expressions of this map. The expressions have to
+// be pure for the simplification implemented.
+void MutableAffineMap::simplify() {
+  // Simplify each of the results if possible.
+  for (unsigned i = 0, e = getNumResults(); i < e; i++) {
+    AffineExpr *result = getResult(i);
+    if (!result->isPureAffine())
+      continue;
+
+    AffineExprFlattener flattener(numDims, numSymbols);
+    flattener.walkPostOrder(result);
+    const auto &flattenedExpr = flattener.operandExprStack.back();
+    auto *expr = toAffineExpr(flattenedExpr, numDims, numSymbols, context);
+    if (expr)
+      results[i] = expr;
+    flattener.operandExprStack.pop_back();
+    assert(flattener.operandExprStack.empty());
+  }
+}
diff --git a/test/IR/affine-map.mlir b/test/IR/affine-map.mlir
index 4aacbd3..5d16f13 100644
--- a/test/IR/affine-map.mlir
+++ b/test/IR/affine-map.mlir
@@ -29,8 +29,9 @@
 #map3j = (i, j) -> (i + 1, j*1*4 + 2)
 #map3k = (i, j) -> (i + 1, j*4*1 + 2)
 
-// The following reduction should be unique'd out too but the expression
-// simplifier is not powerful enough.
+// The following reduction should be unique'd out too, but such expression
+// simplification is not performed during IR parsing; it is only done through
+// analyses and transforms.
 // CHECK: #map{{[0-9]+}} = (d0, d1) -> (d1 - d0 + (d0 - d1 + 1) * 2 + d1 - 1, d1 + d1 + d1 + d1 + 2)
 #map3l = (i, j) -> ((j - i) + 2*(i - j + 1) + j - 1 + 0, j + j + 1 + j + j + 1)
 
@@ -155,13 +156,17 @@
 // CHECK: #map{{[0-9]+}} = (d0, d1, d2) -> (0, d1, d0 * 2, 0)
 #map46 = (i, j, k) -> (i*0, 1*j, i * 128 floordiv 64, j * 0 floordiv 64)
 
-// CHECK: #map{{[0-9]+}} = (d0, d1, d2) -> (d0, d0 * 4, 0, 0)
-#map47 = (i, j, k) -> (i * 64 ceildiv 64, i * 512 ceildiv 128, 4 * j mod 4, 4*j*4 mod 8)
+// CHECK: #map{{[0-9]+}} = (d0, d1, d2) -> (d0, d0 * 4, 0, 0, 0)
+#map47 = (i, j, k) -> (i * 64 ceildiv 64, i * 512 ceildiv 128, 4 * j mod 4, 4*j*4 mod 8, k mod 1)
 
-// floordiv should resolve similarly to ceildiv and be unique'd out
+// floordiv should resolve similarly to ceildiv and be unique'd out.
 // CHECK-NOT: #map48{{[a-z]}}
 #map48 = (i, j, k) -> (i * 64 floordiv 64, i * 512 floordiv 128, 4 * j mod 4, 4*j*4 mod 8)
 
+// Simplifications for mod using known GCD's of the LHS expr.
+// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (0, 0, 0, (d0 * 4 + 3) mod 2)
+#map49 = (i, j)[s0] -> ( (i * 4 + 8) mod 4, 32 * j * s0 * 8 mod 256, (4*i + (j * (s0 * 2))) mod 2, (4*i + 3) mod 2)
+
 // CHECK: extfunc @f0(memref<2x4xi8, #map{{[0-9]+}}, 1>)
 extfunc @f0(memref<2x4xi8, #map0, 1>)
 
@@ -323,3 +328,6 @@
 
 // CHECK: extfunc @f48(memref<100x100x100xi8, #map{{[0-9]+}}>)
 extfunc @f48(memref<100x100x100xi8, #map48>)
+
+// CHECK: extfunc @f49(memref<100x100xi8, #map{{[0-9]+}}>)
+extfunc @f49(memref<100x100xi8, #map49>)
diff --git a/test/Transforms/simplify.mlir b/test/Transforms/simplify.mlir
new file mode 100644
index 0000000..19d9cb1
--- /dev/null
+++ b/test/Transforms/simplify.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -o - -simplify-affine-expr | FileCheck %s
+
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (0, 0)
+#map0 = (d0, d1) -> ((d0 - d0 mod 4) mod 4, (d0 - d0 mod 128 - 64) mod 64)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 + 1, d1 * 5 + 3)
+#map1 = (d0, d1) -> (d1 - d0 + (d0 - d1 + 1) * 2 + d1 - 1, 1 + 2*d1 + d1 + d1 + d1 + 2)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (0, 0, 0)
+#map2 = (d0, d1) -> (((d0 - d0 mod 2) * 2) mod 4, (5*d1 + 8 - (5*d1 + 4) mod 4) mod 4, 0)
+
+mlfunc @test() {
+  for %n0 = 0 to 127 {
+    for %n1 = 0 to 7 {
+      %x  = affine_apply #map0(%n0, %n1)
+      %y  = affine_apply #map1(%n0, %n1)
+      %z  = affine_apply #map2(%n0, %n1)
+    }
+  }
+  return
+}
+
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index e1cba40..5bbf6b7 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -37,6 +37,7 @@
 #include "llvm/Support/Regex.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/ToolOutputFile.h"
+
 using namespace mlir;
 using namespace llvm;
 
@@ -55,6 +56,7 @@
   ConvertToCFG,
   LoopUnroll,
   LoopUnrollAndJam,
+  SimplifyAffineExpr,
   TFRaiseControlFlow,
 };
 
@@ -65,6 +67,8 @@
                clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
                clEnumValN(LoopUnrollAndJam, "loop-unroll-jam",
                           "Unroll and jam loops"),
+               clEnumValN(SimplifyAffineExpr, "simplify-affine-expr",
+                          "Simplify affine expressions"),
                clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
                           "Dynamic TensorFlow Switch/Match nodes to a CFG")));
 
@@ -117,6 +121,9 @@
     case LoopUnrollAndJam:
       pass = createLoopUnrollAndJamPass();
       break;
+    case SimplifyAffineExpr:
+      pass = createSimplifyAffineExprPass();
+      break;
     case TFRaiseControlFlow:
       pass = createRaiseTFControlFlowPass();
       break;