Implement operands for the 'if' statement.

This CL also includes two other minor changes:
- change the implemented syntax from 'if (cond)' to 'if cond', as specified by MLIR spec.
- a minor fix to the implementation of the ForStmt.

PiperOrigin-RevId: 210618122
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index a545fe1..c677f38 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -319,12 +319,14 @@
     return cloneStmt;
   }
 
-  // Creates for statement. When step is not specified, it is set to 1.
+  // Create for statement. When step is not specified, it is set to 1.
   ForStmt *createFor(Location *location, ArrayRef<MLValue *> lbOperands,
                      AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
                      AffineMap *ubMap, int64_t step = 1);
 
-  IfStmt *createIf(Location *location, IntegerSet *condition);
+  /// Create if statement.
+  IfStmt *createIf(Location *location, ArrayRef<MLValue *> operands,
+                   IntegerSet *set);
 
 private:
   StmtBlock *block = nullptr;
diff --git a/include/mlir/IR/IntegerSet.h b/include/mlir/IR/IntegerSet.h
index 7602d90..0ca2062 100644
--- a/include/mlir/IR/IntegerSet.h
+++ b/include/mlir/IR/IntegerSet.h
@@ -52,6 +52,7 @@
 
   unsigned getNumDims() const { return dimCount; }
   unsigned getNumSymbols() const { return symbolCount; }
+  unsigned getNumOperands() const { return dimCount + symbolCount; }
   unsigned getNumConstraints() const { return numConstraints; }
 
   ArrayRef<AffineExpr *> getConstraints() const {
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index c5af5d5..9570568 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -22,7 +22,6 @@
 #ifndef MLIR_IR_STATEMENTS_H
 #define MLIR_IR_STATEMENTS_H
 
-#include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLValue.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/StmtBlock.h"
@@ -32,6 +31,8 @@
 namespace mlir {
 class AffineMap;
 class AffineBound;
+class IntegerSet;
+class AffineCondition;
 
 /// Operation statements represent operations inside ML functions.
 class OperationStmt final
@@ -342,11 +343,11 @@
 
 /// AffineBound represents a lower or upper bound in the for statement.
 /// This class does not own the underlying operands. Instead, it refers
-/// to the operands stored in the ForStmt. It's life span should not exceed
+/// to the operands stored in the ForStmt. Its life span should not exceed
 /// that of the for statement it refers to.
 class AffineBound {
 public:
-  const ForStmt *getOwner() const { return &stmt; }
+  const ForStmt *getForStmt() const { return &stmt; }
   AffineMap *getMap() const { return map; }
 
   unsigned getNumOperands() const { return opEnd - opStart; }
@@ -374,8 +375,12 @@
   }
 
 private:
+  // 'for' statement that contains this bound.
   const ForStmt &stmt;
+  // Start and end positions of this affine bound operands in the list of
+  // the containing 'for' statement operands.
   unsigned opStart, opEnd;
+  // Affine map for this bound.
   AffineMap *map;
 
   AffineBound(const ForStmt &stmt, const unsigned opStart, const unsigned opEnd,
@@ -412,18 +417,74 @@
 /// If statement restricts execution to a subset of the loop iteration space.
 class IfStmt : public Statement {
 public:
-  explicit IfStmt(Location *location, IntegerSet *condition);
+  static IfStmt *create(Location *location, ArrayRef<MLValue *> operands,
+                        IntegerSet *set);
   ~IfStmt();
 
+  //===--------------------------------------------------------------------===//
+  // Then, else, condition.
+  //===--------------------------------------------------------------------===//
+
   IfClause *getThen() const { return thenClause; }
   IfClause *getElse() const { return elseClause; }
-  IntegerSet *getCondition() const { return condition; }
   bool hasElse() const { return elseClause != nullptr; }
+
   IfClause *createElse() {
     assert(elseClause == nullptr && "already has an else clause!");
     return (elseClause = new IfClause(this));
   }
 
+  const AffineCondition getCondition() const;
+
+  IntegerSet *getIntegerSet() const { return set; }
+
+  //===--------------------------------------------------------------------===//
+  // Operands
+  //===--------------------------------------------------------------------===//
+
+  /// Operand iterators.
+  using operand_iterator = OperandIterator<IfStmt, MLValue>;
+  using const_operand_iterator = OperandIterator<const IfStmt, const MLValue>;
+
+  /// Operand iterator range.
+  using operand_range = llvm::iterator_range<operand_iterator>;
+  using const_operand_range = llvm::iterator_range<const_operand_iterator>;
+
+  unsigned getNumOperands() const { return operands.size(); }
+
+  MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
+  const MLValue *getOperand(unsigned idx) const {
+    return getStmtOperand(idx).get();
+  }
+  void setOperand(unsigned idx, MLValue *value) {
+    getStmtOperand(idx).set(value);
+  }
+
+  operand_iterator operand_begin() { return operand_iterator(this, 0); }
+  operand_iterator operand_end() {
+    return operand_iterator(this, getNumOperands());
+  }
+
+  const_operand_iterator operand_begin() const {
+    return const_operand_iterator(this, 0);
+  }
+  const_operand_iterator operand_end() const {
+    return const_operand_iterator(this, getNumOperands());
+  }
+
+  ArrayRef<StmtOperand> getStmtOperands() const { return operands; }
+  MutableArrayRef<StmtOperand> getStmtOperands() { return operands; }
+  StmtOperand &getStmtOperand(unsigned idx) { return getStmtOperands()[idx]; }
+  const StmtOperand &getStmtOperand(unsigned idx) const {
+    return getStmtOperands()[idx];
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Other
+  //===--------------------------------------------------------------------===//
+
+  MLIRContext *getContext() const;
+
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Statement *stmt) {
     return stmt->getKind() == Kind::If;
@@ -434,11 +495,38 @@
   // store the IfClause object for it inline to save an extra allocation.
   IfClause *thenClause;
   IfClause *elseClause;
-  // TODO(shpeisman): please name the ifStmt's conditional encapsulating
-  // IntegerSet + its operands as AffineCondition.
+
   // The integer set capturing the conditional guard.
-  IntegerSet *condition;
-  // TODO: arguments to integer set
+  IntegerSet *set;
+
+  // Condition operands.
+  std::vector<StmtOperand> operands;
+
+  explicit IfStmt(Location *location, unsigned numOperands, IntegerSet *set);
+};
+
+/// AffineCondition represents a condition of the 'if' statement.
+/// Its life span should not exceed that of the objects it refers to.
+/// AffineCondition does not provide its own methods for iterating over
+/// the operands since the iterators of the if statement accomplish
+/// the same purpose.
+///
+/// AffineCondition is trivially copyable, so it should be passed by value.
+class AffineCondition {
+public:
+  const IfStmt *getIfStmt() const { return &stmt; }
+  IntegerSet *getSet() const { return set; }
+
+private:
+  // 'if' statement that contains this affine condition.
+  const IfStmt &stmt;
+  // Integer set for this affine condition.
+  IntegerSet *set;
+
+  AffineCondition(const IfStmt &stmt, const IntegerSet *set)
+      : stmt(stmt), set(const_cast<IntegerSet *>(set)) {}
+
+  friend class IfStmt;
 };
 } // end namespace mlir
 
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 2cc20ac..21dd4c6 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -186,7 +186,7 @@
 }
 
 void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
-  recordIntegerSetReference(ifStmt->getCondition());
+  recordIntegerSetReference(ifStmt->getIntegerSet());
   for (auto &childStmt : *ifStmt->getThen())
     visitStatement(&childStmt);
   if (ifStmt->hasElse())
@@ -1406,9 +1406,11 @@
 }
 
 void MLFunctionPrinter::print(const IfStmt *stmt) {
-  os.indent(numSpaces) << "if (";
-  printIntegerSetReference(stmt->getCondition());
-  os << ") {\n";
+  os.indent(numSpaces) << "if ";
+  IntegerSet *set = stmt->getIntegerSet();
+  printIntegerSetReference(set);
+  printDimAndSymbolList(stmt->getStmtOperands(), set->getNumDims());
+  os << " {\n";
   print(stmt->getThen());
   os.indent(numSpaces) << "}";
   if (stmt->hasElse()) {
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 8d0f223..a03de2f 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -254,8 +254,9 @@
   return stmt;
 }
 
-IfStmt *MLFuncBuilder::createIf(Location *location, IntegerSet *condition) {
-  auto *stmt = new IfStmt(location, condition);
+IfStmt *MLFuncBuilder::createIf(Location *location,
+                                ArrayRef<MLValue *> operands, IntegerSet *set) {
+  auto *stmt = IfStmt::create(location, operands, set);
   block->getStatements().insert(insertPoint, stmt);
   return stmt;
 }
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 3f95a94..673aa43 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -15,7 +15,9 @@
 // limitations under the License.
 // =============================================================================
 
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/StandardOps.h"
@@ -115,8 +117,7 @@
   case Kind::For:
     return cast<ForStmt>(this)->getNumOperands();
   case Kind::If:
-    // TODO: query IfStmt once it has operands.
-    return 0;
+    return cast<IfStmt>(this)->getNumOperands();
   }
 }
 
@@ -127,8 +128,7 @@
   case Kind::For:
     return cast<ForStmt>(this)->getStmtOperands();
   case Kind::If:
-    // TODO: query IfStmt once it has operands.
-    return {};
+    return cast<IfStmt>(this)->getStmtOperands();
   }
 }
 
@@ -301,7 +301,9 @@
                  AffineMap *ubMap, int64_t step, MLIRContext *context)
     : Statement(Kind::For, location),
       MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
-      StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {}
+      StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {
+  operands.reserve(numOperands);
+}
 
 const AffineBound ForStmt::getLowerBound() const {
   return AffineBound(*this, 0, lbMap->getNumOperands(), lbMap);
@@ -357,18 +359,47 @@
 // IfStmt
 //===----------------------------------------------------------------------===//
 
-IfStmt::IfStmt(Location *location, IntegerSet *condition)
+IfStmt::IfStmt(Location *location, unsigned numOperands, IntegerSet *set)
     : Statement(Kind::If, location), thenClause(new IfClause(this)),
-      elseClause(nullptr), condition(condition) {}
+      elseClause(nullptr), set(set) {
+  operands.reserve(numOperands);
+}
 
 IfStmt::~IfStmt() {
   delete thenClause;
   if (elseClause)
     delete elseClause;
-  // An IfStmt's IntegerSet 'condition' should not be deleted since it is
+  // An IfStmt's IntegerSet 'set' should not be deleted since it is
   // allocated through MLIRContext's bump pointer allocator.
 }
 
+IfStmt *IfStmt::create(Location *location, ArrayRef<MLValue *> operands,
+                       IntegerSet *set) {
+  unsigned numOperands = operands.size();
+  assert(numOperands == set->getNumOperands() &&
+         "operand cound does not match the integer set operand count");
+
+  IfStmt *stmt = new IfStmt(location, numOperands, set);
+
+  for (auto *op : operands)
+    stmt->operands.emplace_back(StmtOperand(stmt, op));
+
+  return stmt;
+}
+
+const AffineCondition IfStmt::getCondition() const {
+  return AffineCondition(*this, set);
+}
+
+MLIRContext *IfStmt::getContext() const {
+  // Check for degenerate case of if statement with no operands.
+  // This is unlikely, but legal.
+  if (operands.empty())
+    return findFunction()->getContext();
+
+  return getOperand(0)->getType()->getContext();
+}
+
 //===----------------------------------------------------------------------===//
 // Statement Cloning
 //===----------------------------------------------------------------------===//
@@ -428,9 +459,7 @@
 
   // Otherwise, we must have an If statement.
   auto *ifStmt = cast<IfStmt>(this);
-  auto *newIf = new IfStmt(getLoc(), ifStmt->getCondition());
-
-  // TODO: remap operands with remapOperand when if statements have them.
+  auto *newIf = IfStmt::create(getLoc(), operands, ifStmt->getIntegerSet());
 
   auto *resultThen = newIf->getThen();
   for (auto &childStmt : *ifStmt->getThen())
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index deda932..48f3643 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -372,7 +372,7 @@
 
     // TODO: check that operation is not a return statement unless it's
     // the last one in the function.
-    // TODO: check that loop bounds are properly formed.
+    // TODO: check that loop bounds and if conditions are properly formed.
     if (verifyReturn())
       return true;
 
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index ebeb196..0935090 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/MLIRContext.h"
@@ -2196,12 +2197,12 @@
   ParseResult parseForStmt();
   ParseResult parseIntConstant(int64_t &val);
   ParseResult parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
-                                    const AffineMap *map);
+                                    unsigned numDims, unsigned numOperands,
+                                    const char *affineStructName);
   ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap *&map,
                          bool isLower);
   ParseResult parseIfStmt();
   ParseResult parseElseClause(IfClause *elseClause);
-  IntegerSet *parseCondition();
   ParseResult parseStatements(StmtBlock *block);
   ParseResult parseStmtBlock(StmtBlock *block);
 };
@@ -2306,7 +2307,8 @@
 ///
 ParseResult
 MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
-                                        const AffineMap *map) {
+                                        unsigned numDims, unsigned numOperands,
+                                        const char *affineStructName) {
   if (parseToken(Token::l_paren, "expected '('"))
     return ParseFailure;
 
@@ -2316,8 +2318,9 @@
   if (parseToken(Token::r_paren, "expected ')'"))
     return ParseFailure;
 
-  if (map->getNumDims() != opInfo.size())
-    return emitError("dim operand count and affine map dim count must match");
+  if (numDims != opInfo.size())
+    return emitError("dim operand count and " + Twine(affineStructName) +
+                     " dim count must match");
 
   if (consumeIf(Token::l_square)) {
     parseOptionalSSAUseList(opInfo);
@@ -2325,13 +2328,12 @@
       return ParseFailure;
   }
 
-  if (map->getNumOperands() != opInfo.size())
-    return emitError(
-        "symbol operand count and affine map symbol count must match");
+  if (numOperands != opInfo.size())
+    return emitError("symbol operand count and " + Twine(affineStructName) +
+                     " symbol count must match");
 
   // Resolve SSA uses.
   Type *affineIntType = builder.getAffineIntType();
-  unsigned numDims = map->getNumDims();
   for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
     SSAValue *sval = resolveSSAUse(opInfo[i], affineIntType);
     if (!sval)
@@ -2340,10 +2342,10 @@
     auto *v = cast<MLValue>(sval);
     if (i < numDims && !v->isValidDim())
       return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
-                                          "' cannot be used as dimension id");
+                                          "' cannot be used as a dimension id");
     if (i >= numDims && !v->isValidSymbol())
       return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
-                                          "' cannot be used as symbol");
+                                          "' cannot be used as a symbol");
     operands.push_back(v);
   }
 
@@ -2370,7 +2372,8 @@
     if (!map)
       return ParseFailure;
 
-    if (parseDimAndSymbolList(operands, map))
+    if (parseDimAndSymbolList(operands, map->getNumDims(),
+                              map->getNumOperands(), "affine map"))
       return ParseFailure;
     return ParseSuccess;
   }
@@ -2408,13 +2411,6 @@
   return ParseSuccess;
 }
 
-/// Parse condition.
-IntegerSet *MLFunctionParser::parseCondition() {
-  return parseIntegerSetReference();
-
-  // TODO: Parse operands to the integer set.
-}
-
 /// Parse an affine constraint.
 ///  affine-constraint ::= affine-expr `>=` `0`
 ///                      | affine-expr `==` `0`
@@ -2509,6 +2505,7 @@
 ///  integer-set-id ::= `@@` suffix-id
 ///
 IntegerSet *Parser::parseIntegerSetReference() {
+  // TODO: change '@@' integer set prefix to '#'.
   if (getToken().is(Token::double_at_identifier)) {
     // Parse integer set identifier and verify that it exists.
     StringRef integerSetId = getTokenSpelling().drop_front(2);
@@ -2533,17 +2530,18 @@
   auto loc = getToken().getLoc();
   consumeToken(Token::kw_if);
 
-  if (parseToken(Token::l_paren, "expected '('"))
+  IntegerSet *set = parseIntegerSetReference();
+  if (!set)
     return ParseFailure;
 
-  IntegerSet *condition = parseCondition();
-  if (!condition)
+  SmallVector<MLValue *, 4> operands;
+  if (parseDimAndSymbolList(operands, set->getNumDims(), set->getNumOperands(),
+                            "integer set"))
     return ParseFailure;
 
-  if (parseToken(Token::r_paren, "expected ')'"))
-    return ParseFailure;
+  IfStmt *ifStmt =
+      builder.createIf(getEncodedSourceLocation(loc), operands, set);
 
-  IfStmt *ifStmt = builder.createIf(getEncodedSourceLocation(loc), condition);
   IfClause *thenClause = ifStmt->getThen();
 
   // When parsing of an if statement body fails, the IR contains
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index 6b3a1b6..ab113e9 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -181,7 +181,7 @@
 
 mlfunc @invalid_if_conditional1() {
   for %i = 1 to 10 {
-    if () // expected-error {{expected '(' at start of dimensional identifiers list}}
+    if () { // expected-error {{expected ':' or '['}}
   }
 }
 
@@ -189,7 +189,7 @@
 
 mlfunc @invalid_if_conditional2() {
   for %i = 1 to 10 {
-    if ((i)[N] : (i >= ))  // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}}
+    if (i)[N] : (i >= )  // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}}
   }
 }
 
@@ -197,7 +197,7 @@
 
 mlfunc @invalid_if_conditional3() {
   for %i = 1 to 10 {
-    if ((i)[N] : (i == 1)) // expected-error {{expected '0' after '=='}}
+    if (i)[N] : (i == 1) // expected-error {{expected '0' after '=='}}
   }
 }
 
@@ -205,7 +205,7 @@
 
 mlfunc @invalid_if_conditional4() {
   for %i = 1 to 10 {
-    if ((i)[N] : (i >= 2)) // expected-error {{expected '0' after '>='}}
+    if (i)[N] : (i >= 2) // expected-error {{expected '0' after '>='}}
   }
 }
 
@@ -213,7 +213,7 @@
 
 mlfunc @invalid_if_conditional5() {
   for %i = 1 to 10 {
-    if ((i)[N] : (i <= 0 )) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}}
+    if (i)[N] : (i <= 0 ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}}
   }
 }
 
@@ -221,7 +221,7 @@
 
 mlfunc @invalid_if_conditional6() {
   for %i = 1 to 10 {
-    if ((i) : (i)) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}}
+    if (i) : (i) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}}
   }
 }
 
@@ -229,7 +229,7 @@
 // TODO (support if (1)?
 mlfunc @invalid_if_conditional7() {
   for %i = 1 to 10 {
-    if ((i) : (1)) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}}
+    if (i) : (1) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}}
   }
 }
 
@@ -446,7 +446,7 @@
   for %i = 1 to 100 {
     %a = "foo"(%N) : (affineint)->(affineint)
     for %j = 1 to #map1(%a)[%i] {
-    // expected-error@-1 {{value '%a' cannot be used as dimension id}}
+    // expected-error@-1 {{value '%a' cannot be used as a dimension id}}
     }
   }
   return
@@ -461,7 +461,7 @@
     %a = "foo"(%N) : (affineint)->(affineint)
     %w = affine_apply (i)->(i+1) (%a)
     for %j = 1 to #map1(%w)[%i] {
-    // expected-error@-1 {{value '%w' cannot be used as dimension id}}
+    // expected-error@-1 {{value '%w' cannot be used as a dimension id}}
     }
   }
   return
@@ -475,7 +475,7 @@
   for %i = 1 to 100 {
     %a = "foo"(%N) : (affineint)->(affineint)
     for %j = 1 to #map1(%N)[%i] {
-    // expected-error@-1 {{value '%i' cannot be used as symbol}}
+    // expected-error@-1 {{value '%i' cannot be used as a symbol}}
     }
   }
   return
@@ -489,7 +489,7 @@
   for %i = 1 to 100 {
     %a = "foo"(%N) : (affineint)->(affineint)
     for %j = 1 to #map1(%N)[%a] {
-    // expected-error@-1 {{value '%a' cannot be used as symbol}}
+    // expected-error@-1 {{value '%a' cannot be used as a symbol}}
     }
   }
   return
@@ -503,7 +503,7 @@
   for %i = 1 to 100 {
     %w = affine_apply (i)->(i+1) (%i)
     for %j = 1 to #map1(%i)[%w] {
-    // expected-error@-1 {{value '%w' cannot be used as symbol}}
+    // expected-error@-1 {{value '%w' cannot be used as a symbol}}
     }
   }
   return
@@ -541,3 +541,28 @@
   }
   return
 }
+
+// -----
+@@set0 = (i)[N] : (i >= 0, N - i >= 0)
+
+mlfunc @invalid_if_operands1(%N : affineint) {
+  for %i = 1 to 10 {
+    if @@set0(%i) {
+    // expected-error@-1 {{symbol operand count and integer set symbol count must match}}
+
+// -----
+@@set0 = (i)[N] : (i >= 0, N - i >= 0)
+
+mlfunc @invalid_if_operands2(%N : affineint) {
+  for %i = 1 to 10 {
+    if @@set0()[%N] {
+    // expected-error@-1 {{dim operand count and integer set dim count must match}}
+
+// -----
+@@set0 = (i)[N] : (i >= 0, N - i >= 0)
+
+mlfunc @invalid_if_operands3(%N : affineint) {
+  for %i = 1 to 10 {
+    if @@set0(%i)[%i] {
+    // expected-error@-1 {{value '%i' cannot be used as a symbol}}
+
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index e12e2a6..9624b75 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -33,8 +33,8 @@
 // CHECK: #map{{[0-9]+}} = (d0)[s0] -> (d0 + s0, d0 - s0)
 #bound_map2 = (i)[s] -> (i + s, i - s)
 
-// CHECK-DAG: @@set0 = (d0)[s0] : (d0 >= 0, d0 * -1 + s0 >= 0, s0 - 5 == 0)
-@@set0 = (i)[N] : (i >= 0, -i + N >= 0, N - 5 == 0)
+// CHECK-DAG: @@set0 = (d0)[s0, s1] : (d0 >= 0, d0 * -1 + s0 >= 0, s0 - 5 == 0, d0 * -1 + s1 + 1 >= 0)
+@@set0 = (i)[N, M] : (i >= 0, -i + N >= 0, N - 5 == 0, -i + M + 1 >= 0)
 
 // CHECK-DAG: @@set1 = (d0)[s0] : (d0 - 2 >= 0, d0 * -1 + 4 >= 0)
 
@@ -237,15 +237,16 @@
   return    // CHECK:   return
 }           // CHECK: }
 
-// CHECK-LABEL: mlfunc @ifstmt(%arg0 : i32) {
-mlfunc @ifstmt(%N: i32) {
-  for %i = 1 to 10 {    // CHECK   for %i0 = 1 to 10 {
-    if (@@set0) {        // CHECK     if (@@set0) {
+// CHECK-LABEL: mlfunc @ifstmt(%arg0 : affineint) {
+mlfunc @ifstmt(%N: affineint) {
+  %c = constant 200 : affineint // CHECK   %c200 = constant 200
+  for %i = 1 to 10 {   	        // CHECK   for %i0 = 1 to 10 {
+    if @@set0(%i)[%N, %c] {     // CHECK     if @@set0(%i0)[%arg0, %c200] {
       %x = constant 1 : i32
        // CHECK: %c1_i32 = constant 1 : i32
       %y = "add"(%x, %i) : (i32, affineint) -> i32 // CHECK: %0 = "add"(%c1_i32, %i0) : (i32, affineint) -> i32
       %z = "mul"(%y, %y) : (i32, i32) -> i32 // CHECK: %1 = "mul"(%0, %0) : (i32, i32) -> i32
-    } else if ((i)[N] : (i - 2 >= 0, 4 - i >= 0))  {      // CHECK     } else if (@@set1) {
+    } else if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N]  {      // CHECK     } else if (@@set1(%i0)[%arg0]) {
       // CHECK: %c1 = constant 1 : affineint
       %u = constant 1 : affineint
       // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1]
@@ -280,8 +281,8 @@
   // CHECK: "foo"() {cfgfunc: [], d: 1.000000e-09, i123: 7, if: "foo"} : () -> ()
   "foo"() {if: "foo", cfgfunc: [], i123: 7, d: 1.e-9} : () -> ()
 
-  // CHECK: "foo"() {fn: @attributes : () -> (), if: @ifstmt : (i32) -> ()} : () -> ()
-  "foo"() {fn: @attributes : () -> (), if: @ifstmt : (i32) -> ()} : () -> ()
+  // CHECK: "foo"() {fn: @attributes : () -> (), if: @ifstmt : (affineint) -> ()} : () -> ()
+  "foo"() {fn: @attributes : () -> (), if: @ifstmt : (affineint) -> ()} : () -> ()
   return
 }