Support for affine integer sets

- introduce affine integer sets into the IR
- parse and print affine integer sets (both inline and outlined), similar to
  affine maps
- use an integer set for IfStmt's condition, and implement parsing of IfStmt's
  condition

- fix an affine expr printing bug while on this: `lhs - c` always emitted a
  closing ')', which is now printed only in strongly-binding contexts.

TODO: parse/represent/print MLValue operands to affine integer set references.
PiperOrigin-RevId: 207779408
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 2f67c27..c2cd257 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
@@ -72,6 +73,18 @@

   ArrayRef<const AffineMap *> getAffineMapIds() const { return affineMapsById; }

+  /// Returns the id recorded for `integerSet`, or -1 if the set has not been
+  /// recorded in this module state.
+  int getIntegerSetId(const IntegerSet *integerSet) const {
+    auto it = integerSetIds.find(integerSet);
+    return it == integerSetIds.end() ? -1 : it->second;
+  }
+
+  /// Returns every recorded integer set; position i holds the set with id i.
+  ArrayRef<const IntegerSet *> getIntegerSetIds() const {
+    return integerSetsById;
+  }
+
 private:
   void recordAffineMapReference(const AffineMap *affineMap) {
     if (affineMapIds.count(affineMap) == 0) {
@@ -80,17 +93,33 @@
     }
   }

+  /// Gives `integerSet` the next sequential id unless it already has one.
+  void recordIntegerSetReference(const IntegerSet *integerSet) {
+    if (integerSetIds.count(integerSet) != 0)
+      return;
+    integerSetIds[integerSet] = integerSetsById.size();
+    integerSetsById.push_back(integerSet);
+  }
+
   // Visit functions.
   void visitFunction(const Function *fn);
   void visitExtFunction(const ExtFunction *fn);
   void visitCFGFunction(const CFGFunction *fn);
   void visitMLFunction(const MLFunction *fn);
+  void visitStatement(const Statement *stmt);
+  void visitForStmt(const ForStmt *forStmt);
+  void visitIfStmt(const IfStmt *ifStmt);
+  void visitOperationStmt(const OperationStmt *opStmt);
   void visitType(const Type *type);
   void visitAttribute(const Attribute *attr);
   void visitOperation(const Operation *op);

   DenseMap<const AffineMap *, int> affineMapIds;
   std::vector<const AffineMap *> affineMapsById;
+
+  // Integer sets seen so far: set -> id, and id -> set (first-visit order).
+  DenseMap<const IntegerSet *, int> integerSetIds;
+  std::vector<const IntegerSet *> integerSetsById;
 };
 } // end anonymous namespace

@@ -113,8 +140,8 @@
 void ModuleState::visitAttribute(const Attribute *attr) {
   if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) {
     recordAffineMapReference(mapAttr->getValue());
-  } else if (auto *array = dyn_cast<ArrayAttr>(attr)) {
-    for (auto elt : array->getValue()) {
+  } else if (auto *arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    for (auto elt : arrayAttr->getValue()) {
       visitAttribute(elt);
     }
   }
@@ -145,9 +172,49 @@
   }
 }

+/// Records the condition of `ifStmt`, then visits both of its clauses.
+void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
+  recordIntegerSetReference(ifStmt->getCondition());
+  for (auto &thenStmt : *ifStmt->getThenClause())
+    visitStatement(&thenStmt);
+  if (!ifStmt->hasElseClause())
+    return;
+  for (auto &elseStmt : *ifStmt->getElseClause())
+    visitStatement(&elseStmt);
+}
+
+/// Visits every statement in the body of `forStmt`.
+void ModuleState::visitForStmt(const ForStmt *forStmt) {
+  for (auto &bodyStmt : *forStmt)
+    visitStatement(&bodyStmt);
+}
+
+/// Operation statements do not contribute to the module state yet.
+void ModuleState::visitOperationStmt(const OperationStmt *opStmt) {
+  // TODO: visit any attributes if necessary.
+}
+
+/// Forwards `stmt` to the visitor for its concrete statement kind.
+void ModuleState::visitStatement(const Statement *stmt) {
+  switch (stmt->getKind()) {
+  case Statement::Kind::If:
+    visitIfStmt(cast<IfStmt>(stmt));
+    break;
+  case Statement::Kind::For:
+    visitForStmt(cast<ForStmt>(stmt));
+    break;
+  case Statement::Kind::Operation:
+    visitOperationStmt(cast<OperationStmt>(stmt));
+    break;
+  default:
+    break;
+  }
+}
+
 void ModuleState::visitMLFunction(const MLFunction *fn) {
   visitType(fn->getType());
-  // TODO Visit function body statements (and attributes if required).
+  for (auto &stmt : *fn)
+    visitStatement(&stmt);
 }

 void ModuleState::visitFunction(const Function *fn) {
@@ -161,7 +221,7 @@
   }
 }
 
-// Initializes module state, populating affine map state.
+// Initializes module state, populating affine map and integer set state.
 void ModuleState::initialize(const Module *module) {
   for (auto &fn : *module) {
     visitFunction(&fn);
@@ -194,6 +254,8 @@
 
   void printAffineMap(const AffineMap *map);
   void printAffineExpr(const AffineExpr *expr);
+  void printAffineConstraint(const AffineExpr *expr, bool isEq);
+  void printIntegerSet(const IntegerSet *set);
 
 protected:
   raw_ostream &os;
@@ -203,6 +265,8 @@
   void printFunctionResultType(const FunctionType *type);
   void printAffineMapId(int affineMapId) const;
   void printAffineMapReference(const AffineMap *affineMap);
+  void printIntegerSetId(int integerSetId) const;
+  void printIntegerSetReference(const IntegerSet *integerSet);
 
   /// This enum is used to represent the binding stength of the enclosing
   /// context that an AffineExpr is being printed in, so we can intelligently
@@ -244,6 +308,24 @@
   }
 }

+// Prints the identifier of an outlined integer set, e.g. `@@set0`.
+void ModulePrinter::printIntegerSetId(int integerSetId) const {
+  os << "@@set" << integerSetId;
+}
+
+// Prints a use of `integerSet`: its `@@setN` id when the module state has
+// recorded it, otherwise the whole set inline.
+void ModulePrinter::printIntegerSetReference(const IntegerSet *integerSet) {
+  int setId = state.getIntegerSetId(integerSet);
+  if (setId < 0) {
+    // No id was assigned, so the set is written out in full here.
+    integerSet->print(os);
+    return;
+  }
+  // The set is emitted at the top of the module; refer to it by id.
+  printIntegerSetId(setId);
+}
+
 void ModulePrinter::print(const Module *module) {
   for (const auto &map : state.getAffineMapIds()) {
     printAffineMapId(state.getAffineMapId(map));
@@ -251,6 +331,12 @@
     map->print(os);
     os << '\n';
   }
+  for (const auto &set : state.getIntegerSetIds()) {
+    printIntegerSetId(state.getIntegerSetId(set));
+    os << " = ";
+    set->print(os);
+    os << '\n';
+  }
   for (auto const &fn : *module)
     print(&fn);
 }
@@ -473,7 +559,9 @@
   if (auto *rhs = dyn_cast<AffineConstantExpr>(binOp->getRHS())) {
     if (rhs->getValue() < 0) {
       printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
-      os << " - " << -rhs->getValue() << ')';
+      os << " - " << -rhs->getValue();
+      if (enclosingTightness == BindingStrength::Strong)
+        os << ')';
       return;
     }
   }
@@ -486,6 +574,13 @@
     os << ')';
 }

+// Prints one integer set constraint: `expr == 0` when isEq is true, otherwise
+// `expr >= 0`.
+void ModulePrinter::printAffineConstraint(const AffineExpr *expr, bool isEq) {
+  printAffineExprInternal(expr, BindingStrength::Weak);
+  os << (isEq ? " == 0" : " >= 0");
+}
+
 void ModulePrinter::printAffineMap(const AffineMap *map) {
   // Dimension identifiers.
   os << '(';
@@ -524,6 +617,39 @@
   os << ')';
 }

+// Prints an integer set as `(d0, ..., dN)[s0, ..., sM] : (c0, ..., cK)`, where
+// each ci is printed by printAffineConstraint; the `[...]` symbol list is
+// omitted when the set has no symbols.
+void ModulePrinter::printIntegerSet(const IntegerSet *set) {
+  // Dimension identifiers.
+  os << '(';
+  for (unsigned i = 0, e = set->getNumDims(); i < e; ++i) {
+    if (i != 0)
+      os << ", ";
+    os << 'd' << i;
+  }
+  os << ')';
+
+  // Symbolic identifiers.
+  if (set->getNumSymbols() != 0) {
+    os << '[';
+    for (unsigned i = 0, e = set->getNumSymbols(); i < e; ++i) {
+      if (i != 0)
+        os << ", ";
+      os << 's' << i;
+    }
+    os << ']';
+  }
+
+  // Constraints.
+  os << " : (";
+  for (unsigned i = 0, e = set->getNumConstraints(); i < e; ++i) {
+    if (i != 0)
+      os << ", ";
+    printAffineConstraint(set->getConstraint(i), set->isEq(i));
+  }
+  os << ')';
+}
+
 //===----------------------------------------------------------------------===//
 // Function printing
 //===----------------------------------------------------------------------===//
@@ -582,6 +707,10 @@
   void printAffineMap(const AffineMap *map) {
     return ModulePrinter::printAffineMapReference(map);
   }
+  // Prints `set` as a reference: its outlined id if recorded, else inline.
+  void printIntegerSet(const IntegerSet *set) {
+    ModulePrinter::printIntegerSetReference(set);
+  }
   void printAffineExpr(const AffineExpr *expr) {
     return ModulePrinter::printAffineExpr(expr);
   }
@@ -1139,7 +1267,9 @@
 }
 
 void MLFunctionPrinter::print(const IfStmt *stmt) {
-  os.indent(numSpaces) << "if () {\n";
+  os.indent(numSpaces) << "if (";
+  printIntegerSetReference(stmt->getCondition());
+  os << ") {\n";
   print(stmt->getThenClause());
   os.indent(numSpaces) << "}";
   if (stmt->hasElseClause()) {
@@ -1181,6 +1311,13 @@
   llvm::errs() << "\n";
 }

+/// Prints this integer set to llvm::errs(), followed by a newline.
+void IntegerSet::dump() const {
+  raw_ostream &stream = llvm::errs();
+  print(stream);
+  stream << "\n";
+}
+
 void AffineExpr::print(raw_ostream &os) const {
   ModuleState state(/*no context is known*/ nullptr);
   ModulePrinter(os, state).printAffineExpr(this);
@@ -1191,6 +1326,13 @@
   ModulePrinter(os, state).printAffineMap(this);
 }

+/// Prints this integer set to `os` with no enclosing module context.
+void IntegerSet::print(raw_ostream &os) const {
+  ModuleState emptyState(/*no context is known*/ nullptr);
+  ModulePrinter printer(os, emptyState);
+  printer.printIntegerSet(this);
+}
+
 void SSAValue::print(raw_ostream &os) const {
   switch (getKind()) {
   case SSAValueKind::BBArgument: