Support for affine integer sets

- introduce affine integer sets into the IR
- parse and print affine integer sets (both inline and outlined) similar to
  affine maps
- use integer set for IfStmt's conditional, and implement parsing of IfStmt's
  conditional

- fixed an affine expr paren omission bug while on this.

TODO: parse/represent/print MLValue operands to affine integer set references.
PiperOrigin-RevId: 207779408
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index d89c62e..5ca4e56 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -58,6 +58,8 @@
 
   // A map from affine map identifier to AffineMap.
   llvm::StringMap<AffineMap *> affineMapDefinitions;
+  // A map from integer set identifier to IntegerSet.
+  llvm::StringMap<IntegerSet *> integerSetDefinitions;
 
 private:
   ParserState(const ParserState &) = delete;
@@ -177,6 +179,8 @@
   // Polyhedral structures.
   AffineMap *parseAffineMapInline();
   AffineMap *parseAffineMapReference();
+  IntegerSet *parseIntegerSetInline();
+  IntegerSet *parseIntegerSetReference();
 
 private:
   // The Parser is subclassed and reinstantiated.  Do not add additional
@@ -717,13 +721,15 @@
 };
 
 namespace {
-/// This is a specialized parser for AffineMap's, maintaining the state
-/// transient to their bodies.
-class AffineMapParser : public Parser {
+/// This is a specialized parser for affine structures (affine maps, affine
+/// expressions, and integer sets), maintaining the state transient to their
+/// bodies.
+class AffineParser : public Parser {
 public:
-  explicit AffineMapParser(ParserState &state) : Parser(state) {}
+  explicit AffineParser(ParserState &state) : Parser(state) {}
 
   AffineMap *parseAffineMapInline();
+  IntegerSet *parseIntegerSetInline();
 
 private:
   // Binary affine op parsing.
@@ -751,6 +757,7 @@
   AffineExpr *parseAffineHighPrecOpExpr(AffineExpr *llhs,
                                         AffineHighPrecOp llhsOp,
                                         SMLoc llhsOpLoc);
+  AffineExpr *parseAffineConstraint(bool *isEq);
 
 private:
   SmallVector<std::pair<StringRef, AffineExpr *>, 4> dimsAndSymbols;
@@ -760,10 +767,9 @@
 /// Create an affine binary high precedence op expression (mul's, div's, mod).
 /// opLoc is the location of the op token to be used to report errors
 /// for non-conforming expressions.
-AffineExpr *AffineMapParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
-                                                   AffineExpr *lhs,
-                                                   AffineExpr *rhs,
-                                                   SMLoc opLoc) {
+AffineExpr *AffineParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
+                                                AffineExpr *lhs,
+                                                AffineExpr *rhs, SMLoc opLoc) {
   // TODO: make the error location info accurate.
   switch (op) {
   case Mul:
@@ -801,9 +807,9 @@
 }
 
 /// Create an affine binary low precedence op expression (add, sub).
-AffineExpr *AffineMapParser::getBinaryAffineOpExpr(AffineLowPrecOp op,
-                                                   AffineExpr *lhs,
-                                                   AffineExpr *rhs) {
+AffineExpr *AffineParser::getBinaryAffineOpExpr(AffineLowPrecOp op,
+                                                AffineExpr *lhs,
+                                                AffineExpr *rhs) {
   switch (op) {
   case AffineLowPrecOp::Add:
     return builder.getAddExpr(lhs, rhs);
@@ -818,7 +824,7 @@
 
 /// Consume this token if it is a lower precedence affine op (there are only two
 /// precedence levels).
-AffineLowPrecOp AffineMapParser::consumeIfLowPrecOp() {
+AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
   switch (getToken().getKind()) {
   case Token::plus:
     consumeToken(Token::plus);
@@ -833,7 +839,7 @@
 
 /// Consume this token if it is a higher precedence affine op (there are only
 /// two precedence levels)
-AffineHighPrecOp AffineMapParser::consumeIfHighPrecOp() {
+AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
   switch (getToken().getKind()) {
   case Token::star:
     consumeToken(Token::star);
@@ -861,9 +867,9 @@
 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
 /// null. llhsOpLoc is the location of the llhsOp token that will be used to
 /// report an error for non-conforming expressions.
-AffineExpr *AffineMapParser::parseAffineHighPrecOpExpr(AffineExpr *llhs,
-                                                       AffineHighPrecOp llhsOp,
-                                                       SMLoc llhsOpLoc) {
+AffineExpr *AffineParser::parseAffineHighPrecOpExpr(AffineExpr *llhs,
+                                                    AffineHighPrecOp llhsOp,
+                                                    SMLoc llhsOpLoc) {
   AffineExpr *lhs = parseAffineOperandExpr(llhs);
   if (!lhs)
     return nullptr;
@@ -892,7 +898,7 @@
 /// Parse an affine expression inside parentheses.
 ///
 ///   affine-expr ::= `(` affine-expr `)`
-AffineExpr *AffineMapParser::parseParentheticalExpr() {
+AffineExpr *AffineParser::parseParentheticalExpr() {
   if (parseToken(Token::l_paren, "expected '('"))
     return nullptr;
   if (getToken().is(Token::r_paren))
@@ -910,7 +916,7 @@
 /// Parse the negation expression.
 ///
 ///   affine-expr ::= `-` affine-expr
-AffineExpr *AffineMapParser::parseNegateExpression(AffineExpr *lhs) {
+AffineExpr *AffineParser::parseNegateExpression(AffineExpr *lhs) {
   if (parseToken(Token::minus, "expected '-'"))
     return nullptr;
 
@@ -929,7 +935,7 @@
 /// Parse a bare id that may appear in an affine expression.
 ///
 ///   affine-expr ::= bare-id
-AffineExpr *AffineMapParser::parseBareIdExpr() {
+AffineExpr *AffineParser::parseBareIdExpr() {
   if (getToken().isNot(Token::bare_identifier))
     return (emitError("expected bare identifier"), nullptr);
 
@@ -947,7 +953,7 @@
 /// Parse a positive integral constant appearing in an affine expression.
 ///
 ///   affine-expr ::= integer-literal
-AffineExpr *AffineMapParser::parseIntegerExpr() {
+AffineExpr *AffineParser::parseIntegerExpr() {
   // No need to handle negative numbers separately here. They are naturally
   // handled via the unary negation operator, although (FIXME) MININT_64 still
   // not correctly handled.
@@ -971,7 +977,7 @@
 //  operand expression, it's an op expression and will be parsed via
 //  parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and -l
 //  are valid operands that will be parsed by this function.
-AffineExpr *AffineMapParser::parseAffineOperandExpr(AffineExpr *lhs) {
+AffineExpr *AffineParser::parseAffineOperandExpr(AffineExpr *lhs) {
   switch (getToken().getKind()) {
   case Token::bare_identifier:
     return parseBareIdExpr();
@@ -1021,8 +1027,8 @@
 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where (e2*e3)
 /// will be parsed using parseAffineHighPrecOpExpr().
-AffineExpr *AffineMapParser::parseAffineLowPrecOpExpr(AffineExpr *llhs,
-                                                      AffineLowPrecOp llhsOp) {
+AffineExpr *AffineParser::parseAffineLowPrecOpExpr(AffineExpr *llhs,
+                                                   AffineLowPrecOp llhsOp) {
   AffineExpr *lhs;
   if (!(lhs = parseAffineOperandExpr(llhs)))
     return nullptr;
@@ -1077,14 +1083,14 @@
 /// Additional conditions are checked depending on the production. For eg., one
 /// of the operands for `*` has to be either constant/symbolic; the second
 /// operand for floordiv, ceildiv, and mod has to be a positive integer.
-AffineExpr *AffineMapParser::parseAffineExpr() {
+AffineExpr *AffineParser::parseAffineExpr() {
   return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
 }
 
 /// Parse a dim or symbol from the lists appearing before the actual expressions
 /// of the affine map. Update our state to store the dimensional/symbolic
 /// identifier.
-ParseResult AffineMapParser::parseIdentifierDefinition(AffineExpr *idExpr) {
+ParseResult AffineParser::parseIdentifierDefinition(AffineExpr *idExpr) {
   if (getToken().isNot(Token::bare_identifier))
     return emitError("expected bare identifier");
 
@@ -1100,7 +1106,7 @@
 }
 
 /// Parse the list of symbolic identifiers to an affine map.
-ParseResult AffineMapParser::parseSymbolIdList(unsigned &numSymbols) {
+ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) {
   consumeToken(Token::l_square);
   auto parseElt = [&]() -> ParseResult {
     auto *symbol = AffineSymbolExpr::get(numSymbols++, getContext());
@@ -1110,7 +1116,7 @@
 }
 
 /// Parse the list of dimensional identifiers to an affine map.
-ParseResult AffineMapParser::parseDimIdList(unsigned &numDims) {
+ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
   if (parseToken(Token::l_paren,
                  "expected '(' at start of dimensional identifiers list"))
     return ParseFailure;
@@ -1129,7 +1135,7 @@
 ///  dim-size ::= affine-expr | `min` `(` affine-expr ( `,` affine-expr)+ `)`
 ///
 ///  multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
-AffineMap *AffineMapParser::parseAffineMapInline() {
+AffineMap *AffineParser::parseAffineMapInline() {
   unsigned numDims = 0, numSymbols = 0;
 
   // List of dimensional identifiers.
@@ -1201,7 +1207,7 @@
 }
 
 AffineMap *Parser::parseAffineMapInline() {
-  return AffineMapParser(state).parseAffineMapInline();
+  return AffineParser(state).parseAffineMapInline();
 }
 
 AffineMap *Parser::parseAffineMapReference() {
@@ -2100,6 +2106,7 @@
   AffineConstantExpr *parseIntConstant();
   ParseResult parseIfStmt();
   ParseResult parseElseClause(IfClause *elseClause);
+  IntegerSet *parseCondition();
   ParseResult parseStatements(StmtBlock *block);
   ParseResult parseStmtBlock(StmtBlock *block);
 };
@@ -2203,6 +2210,120 @@
   return builder.getConstantExpr((int64_t)val.getValue());
 }
 
+/// Parse the condition of an `if` statement; returns null on failure.
+IntegerSet *MLFunctionParser::parseCondition() {
+  return parseIntegerSetReference();
+
+  // TODO: Parse the MLValue operands to the integer set reference.
+}
+
+/// Parse an affine constraint.
+///  affine-constraint ::= affine-expr `>=` `0`
+///                      | affine-expr `==` `0`
+///
+/// On success, returns the parsed expression and sets *isEq to true for an
+/// equality or false for an inequality (greater than or equal). On failure,
+/// emits an error and returns nullptr without writing to *isEq.
+AffineExpr *AffineParser::parseAffineConstraint(bool *isEq) {
+  AffineExpr *expr = parseAffineExpr();
+  if (!expr)
+    return nullptr;
+
+  // The operator is lexed as two tokens ('>' '=' or '=' '='). Choose the
+  // alternative from the first token only, so a partially consumed `>=` is
+  // never re-read as `==` (previously `e >= == 0` parsed as an equality).
+  bool isEquality = false;
+  bool sawOperator = consumeIf(Token::greater);
+  if (!sawOperator)
+    sawOperator = isEquality = consumeIf(Token::equal);
+
+  if (sawOperator && consumeIf(Token::equal) &&
+      getToken().is(Token::integer)) {
+    // Only the literal zero is accepted on the right-hand side.
+    auto rhs = getToken().getUnsignedIntegerValue();
+    if (!rhs.hasValue() || rhs.getValue() != 0)
+      return (emitError(isEquality ? "expected '0' after '=='"
+                                   : "expected '0' after '>='"),
+              nullptr);
+    consumeToken(Token::integer);
+    *isEq = isEquality;
+    return expr;
+  }
+
+  return (emitError("expected '== 0' or '>= 0' at end of affine constraint"),
+          nullptr);
+}
+
+/// Parse an integer set definition.
+///  integer-set-inline
+///                ::= dim-and-symbol-id-lists `:` affine-constraint-conjunction
+///  affine-constraint-conjunction
+///                ::= `(` (affine-constraint (`,` affine-constraint)*)? `)`
+/// Returns the set built from the parsed constraints, or null on failure.
+IntegerSet *AffineParser::parseIntegerSetInline() {
+  unsigned numDims = 0, numSymbols = 0;
+
+  // List of dimensional identifiers.
+  if (parseDimIdList(numDims))
+    return nullptr;
+
+  // Symbols are optional.
+  if (getToken().is(Token::l_square)) {
+    if (parseSymbolIdList(numSymbols))
+      return nullptr;
+  }
+
+  if (parseToken(Token::colon, "expected ':' or '['") ||
+      parseToken(Token::l_paren,
+                 "expected '(' at start of integer set constraint list"))
+    return nullptr;
+
+  SmallVector<AffineExpr *, 4> constraints;
+  SmallVector<bool, 4> isEqs;
+  auto parseElt = [&]() -> ParseResult {
+    bool isEq;
+    auto *elt = parseAffineConstraint(&isEq);
+    ParseResult res = elt ? ParseSuccess : ParseFailure;
+    if (elt) {
+      constraints.push_back(elt);
+      isEqs.push_back(isEq);
+    }
+    return res;
+  };
+
+  // Parse the comma-separated constraints up to the closing `)`. NOTE(review):
+  // the trailing `true` presumably allows an empty list -- confirm in helper.
+  if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
+    return nullptr;
+
+  // Parsed a valid integer set.
+  return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs);
+}
+
+/// Parse an inline integer set using a transient AffineParser.
+IntegerSet *Parser::parseIntegerSetInline() {
+  return AffineParser(state).parseIntegerSetInline();
+}
+
+/// Parse a use of an integer set: either a previously defined id or an
+/// inline definition. Returns null on failure.
+///  integer-set ::= integer-set-id | integer-set-inline
+///  integer-set-id ::= `@@` suffix-id
+IntegerSet *Parser::parseIntegerSetReference() {
+  // Anything other than an `@@` identifier must be an inline definition.
+  if (getToken().isNot(Token::double_at_identifier))
+    return parseIntegerSetInline();
+
+  // Strip the leading `@@` and make sure the set was defined earlier.
+  StringRef setName = getTokenSpelling().drop_front(2);
+  auto &definitions = getState().integerSetDefinitions;
+  if (definitions.count(setName) == 0)
+    return (emitError("undefined integer set id '" + setName + "'"), nullptr);
+
+  consumeToken(Token::double_at_identifier);
+  return definitions[setName];
+}
+
 /// If statement.
 ///
 ///   ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`
@@ -2212,15 +2333,18 @@
 ///
 ParseResult MLFunctionParser::parseIfStmt() {
   consumeToken(Token::kw_if);
-  if (parseToken(Token::l_paren, "expected ("))
+
+  if (parseToken(Token::l_paren, "expected '('"))
     return ParseFailure;
 
-  // TODO: parse condition
-
-  if (parseToken(Token::r_paren, "expected )"))
+  IntegerSet *condition = parseCondition();
+  if (!condition)
     return ParseFailure;
 
-  IfStmt *ifStmt = builder.createIf();
+  if (parseToken(Token::r_paren, "expected ')'"))
+    return ParseFailure;
+
+  IfStmt *ifStmt = builder.createIf(condition);
   IfClause *thenClause = ifStmt->getThenClause();
 
   // When parsing of an if statement body fails, the IR contains
@@ -2308,6 +2432,7 @@
 
 private:
   ParseResult parseAffineMapDef();
+  ParseResult parseIntegerSetDef();
 
   // Functions.
   ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
@@ -2330,8 +2455,8 @@
   StringRef affineMapId = getTokenSpelling().drop_front();
 
   // Check for redefinitions.
-  auto *&entry = getState().affineMapDefinitions[affineMapId];
-  if (entry)
+  auto **entry = &getState().affineMapDefinitions[affineMapId];
+  if (*entry)
     return emitError("redefinition of affine map id '" + affineMapId + "'");
 
   consumeToken(Token::hash_identifier);
@@ -2341,8 +2466,36 @@
                  "expected '=' in affine map outlined definition"))
     return ParseFailure;
 
-  entry = parseAffineMapInline();
-  if (!entry)
+  *entry = parseAffineMapInline();
+  if (!*entry)
+    return ParseFailure;
+
+  return ParseSuccess;
+}
+
+/// Integer set declaration.
+///
+///  integer-set-decl ::= integer-set-id `=` integer-set-inline
+///
+ParseResult ModuleParser::parseIntegerSetDef() {
+  assert(getToken().is(Token::double_at_identifier));
+
+  StringRef integerSetId = getTokenSpelling().drop_front(2);
+
+  // Check for redefinitions (a default entry is created if one doesn't exist)
+  auto **entry = &getState().integerSetDefinitions[integerSetId];
+  if (*entry)
+    return emitError("redefinition of integer set id '" + integerSetId + "'");
+
+  consumeToken(Token::double_at_identifier);
+
+  // Parse the '='
+  if (parseToken(Token::equal,
+                 "expected '=' in outlined integer set definition"))
+    return ParseFailure;
+
+  *entry = parseIntegerSetInline();
+  if (!*entry)
     return ParseFailure;
 
   return ParseSuccess;
@@ -2509,6 +2662,11 @@
         return ParseFailure;
       break;
 
+    case Token::double_at_identifier:
+      if (parseIntegerSetDef())
+        return ParseFailure;
+      break;
+
     case Token::kw_extfunc:
       if (parseExtFunc())
         return ParseFailure;