Implement operands for the lower and upper bounds of the for statement.

This revamps the implementation of loop bounds in ForStmt, using a general representation that supports operands. The frequent case of constant bounds is supported
via special accessor methods.

This also includes:
- Operand iterators for the Statement class.
- OpPointer::is() method to query the class of the Operation.
- Support for the bound shorthand notation parsing and printing.
- Validity checks for the bound operands used as dim ids and symbols

I didn't mean this CL to be so large. It just happened this way, as one thing led to another.

PiperOrigin-RevId: 210204858
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index 95cd8ff..2b6d680 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -39,6 +39,19 @@
   return true;
 }
 
+/// Returns true if this map has exactly one result and that result is a
+/// constant expression.
+bool AffineMap::isSingleConstant() const {
+  return getNumResults() == 1 && isa<AffineConstantExpr>(getResult(0));
+}
+
+/// Returns the value of the single constant result. The map must satisfy
+/// isSingleConstant(); cast<> re-checks the result kind in debug builds.
+int64_t AffineMap::getSingleConstantValue() const {
+  assert(isSingleConstant() && "map must have a single constant result");
+  return cast<AffineConstantExpr>(getResult(0))->getValue();
+}
+
 /// Simplify add expression. Return nullptr if it can't be simplified.
 AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
                                             MLIRContext *context) {
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index cd90e3e..2cc20ac 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -101,6 +101,18 @@
     }
   }
 
+  // Returns true if the bound map can be printed in shorthand form: either it
+  // has a single constant result, or it is a single dim or symbol identity
+  // map, i.e. (d0) -> (d0) or ()[s0] -> (s0).
+  static bool hasShorthandForm(const AffineMap *boundMap) {
+    if (boundMap->isSingleConstant())
+      return true;
+    if (boundMap->getNumOperands() != 1 || boundMap->getNumResults() != 1)
+      return false;
+    auto *result = boundMap->getResult(0);
+    return isa<AffineDimExpr>(result) || isa<AffineSymbolExpr>(result);
+  }
+
   // Visit functions.
   void visitFunction(const Function *fn);
   void visitExtFunction(const ExtFunction *fn);
@@ -183,6 +195,14 @@
 }
 
 void ModuleState::visitForStmt(const ForStmt *forStmt) {
+  // Only bound maps without a shorthand form need to be recorded.
+  auto recordIfNeeded = [&](AffineMap *boundMap) {
+    if (!hasShorthandForm(boundMap))
+      recordAffineMapReference(boundMap);
+  };
+  recordIfNeeded(forStmt->getLowerBoundMap());
+  recordIfNeeded(forStmt->getUpperBoundMap());
+
   for (auto &childStmt : *forStmt)
     visitStatement(&childStmt);
 }
@@ -1216,20 +1236,24 @@
 
   const MLFunction *getFunction() const { return function; }
 
-  // Prints ML function
+  // Prints ML function.
   void print();
 
-  // Prints ML function signature
+  // Prints ML function signature.
   void printFunctionSignature();
 
-  // Methods to print ML function statements
+  // Methods to print ML function statements.
   void print(const Statement *stmt);
   void print(const OperationStmt *stmt);
   void print(const ForStmt *stmt);
   void print(const IfStmt *stmt);
   void print(const StmtBlock *block);
 
-  // Number of spaces used for indenting nested statements
+  // Print loop bounds.
+  void printDimAndSymbolList(ArrayRef<StmtOperand> ops, unsigned numDims);
+  void printBound(AffineBound bound, const char *prefix);
+
+  // Number of spaces used for indenting nested statements.
   const static unsigned indentWidth = 2;
 
 private:
@@ -1249,7 +1273,7 @@
 
 /// Number all of the SSA values in this ML function.
 void MLFunctionPrinter::numberValues() {
-  // Numbers ML function arguments
+  // Numbers ML function arguments.
   for (auto *arg : function->getArguments())
     numberValueID(arg);
 
@@ -1323,8 +1347,11 @@
 void MLFunctionPrinter::print(const ForStmt *stmt) {
   os.indent(numSpaces) << "for ";
   printOperand(stmt);
-  os << " = " << *stmt->getLowerBound();
-  os << " to " << *stmt->getUpperBound();
+  os << " = ";
+  printBound(stmt->getLowerBound(), "max");
+  os << " to ";
+  printBound(stmt->getUpperBound(), "min");
+
   if (stmt->getStep() != 1)
     os << " step " << stmt->getStep();
 
@@ -1333,6 +1360,63 @@
   os.indent(numSpaces) << "}";
 }
 
+/// Print the first 'numDims' operands as a parenthesized, comma-separated
+/// list of dimension operands, followed by the remaining operands as a
+/// bracketed list of symbol operands if there are any.
+void MLFunctionPrinter::printDimAndSymbolList(ArrayRef<StmtOperand> ops,
+                                              unsigned numDims) {
+  auto printComma = [&]() { os << ", "; };
+  os << '(';
+  interleave(ops.begin(), ops.begin() + numDims,
+             [&](const StmtOperand &v) { printOperand(v.get()); }, printComma);
+  os << ')';
+
+  if (numDims < ops.size()) {
+    os << '[';
+    interleave(ops.begin() + numDims, ops.end(),
+               [&](const StmtOperand &v) { printOperand(v.get()); },
+               printComma);
+    os << ']';
+  }
+}
+
+/// Print a loop bound. Single-result bounds use the shorthand form when
+/// possible: a constant result is printed as its value, and a single dim or
+/// symbol identity map, (d0) -> (d0) or ()[s0] -> (s0), as its one operand.
+/// Other bounds are printed as a map reference followed by the operand list;
+/// a map that does not have exactly one result is preceded by 'prefix' (the
+/// caller passes "max" for lower bounds and "min" for upper bounds).
+void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
+  AffineMap *map = bound.getMap();
+
+  // Check if this bound should be printed using short-hand notation.
+  if (map->getNumResults() == 1) {
+    AffineExpr *expr = map->getResult(0);
+
+    // Print constant bound.
+    if (auto *constExpr = dyn_cast<AffineConstantExpr>(expr)) {
+      os << constExpr->getValue();
+      return;
+    }
+
+    // Print bound that consists of a single SSA id. This is only correct for
+    // identity maps with exactly one operand: with more operands the result
+    // expression need not refer to operand #0 (e.g. (d0, d1) -> (d1)).
+    if (map->getNumOperands() == 1 &&
+        (isa<AffineDimExpr>(expr) || isa<AffineSymbolExpr>(expr))) {
+      printOperand(bound.getOperand(0));
+      return;
+    }
+  } else {
+    // Map has multiple results. Print 'min' or 'max' prefix.
+    os << prefix << ' ';
+  }
+
+  // Print the map and the operands.
+  printAffineMapReference(map);
+  printDimAndSymbolList(bound.getStmtOperands(), map->getNumDims());
+}
+
 void MLFunctionPrinter::print(const IfStmt *stmt) {
   os.indent(numSpaces) << "if (";
   printIntegerSetReference(stmt->getCondition());
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index ef468de..376d54e 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -163,6 +163,24 @@
   return IntegerSet::get(dimCount, symbolCount, constraints, isEq, context);
 }
 
+// Returns the map () -> (val), whose single result is the constant 'val'.
+AffineMap *Builder::getConstantMap(int64_t val) {
+  auto *result = getConstantExpr(val);
+  return AffineMap::get(0, 0, result, {}, context);
+}
+
+// Returns the single-dimension identity map (d0) -> (d0).
+AffineMap *Builder::getDimIdentityMap() {
+  auto *result = getDimExpr(0);
+  return AffineMap::get(1, 0, result, {}, context);
+}
+
+// Returns the single-symbol identity map ()[s0] -> (s0).
+AffineMap *Builder::getSymbolIdentityMap() {
+  auto *result = getSymbolExpr(0);
+  return AffineMap::get(0, 1, result, {}, context);
+}
+
 //===----------------------------------------------------------------------===//
 // CFG function elements.
 //===----------------------------------------------------------------------===//
@@ -216,10 +228,12 @@
 }
 
 ForStmt *MLFuncBuilder::createFor(Attribute *location,
-                                  AffineConstantExpr *lowerBound,
-                                  AffineConstantExpr *upperBound,
-                                  int64_t step) {
-  auto *stmt = new ForStmt(location, lowerBound, upperBound, step, context);
+                                  ArrayRef<MLValue *> lbOperands,
+                                  AffineMap *lbMap,
+                                  ArrayRef<MLValue *> ubOperands,
+                                  AffineMap *ubMap, int64_t step) {
+  auto *stmt = ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap,
+                               step, context);
   block->getStatements().insert(insertPoint, stmt);
   return stmt;
 }
diff --git a/lib/IR/SSAValue.cpp b/lib/IR/SSAValue.cpp
index 50f7fb0..04de6db 100644
--- a/lib/IR/SSAValue.cpp
+++ b/lib/IR/SSAValue.cpp
@@ -1,4 +1,4 @@
-//===- Instructions.cpp - MLIR CFGFunction Instruction Classes ------------===//
+//===- SSAValue.cpp - MLIR SSAValue Classes ------------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -17,6 +17,7 @@
 
 #include "mlir/IR/SSAValue.h"
 #include "mlir/IR/Instructions.h"
+#include "mlir/IR/StandardOps.h"
 #include "mlir/IR/Statements.h"
 using namespace mlir;
 
@@ -43,3 +44,43 @@
     return stmt;
   return nullptr;
 }
+
+//===----------------------------------------------------------------------===//
+// MLValue implementation.
+//===----------------------------------------------------------------------===//
+
+// An MLValue can be used as a dimension id if it is valid as a symbol, if it
+// is an induction variable, or if it is the result of an affine apply
+// operation whose arguments are valid dimension ids.
+bool MLValue::isValidDim() const {
+  auto *stmt = getDefiningStmt();
+  // Values without a defining statement are function arguments or induction
+  // variables; both may be used as dimension ids.
+  if (!stmt)
+    return true;
+  // Values defined at the top level or by a constant operation are fine.
+  if (!stmt->getParentStmt() || stmt->is<ConstantOp>())
+    return true;
+  // An affine apply result is fine if all of its operands are.
+  if (auto applyOp = stmt->getAs<AffineApplyOp>())
+    return applyOp->isValidDim();
+  return false;
+}
+
+// An MLValue can be used as a symbol if it is a constant, if it is defined at
+// the top level, or if it is the result of an affine apply operation whose
+// arguments are valid symbols.
+bool MLValue::isValidSymbol() const {
+  auto *stmt = getDefiningStmt();
+  // Without a defining statement the value is either a function argument,
+  // which is a valid symbol, or an induction variable, which is not.
+  if (!stmt)
+    return isa<MLFuncArgument>(this);
+  // Values defined at the top level or by a constant operation are fine.
+  if (!stmt->getParentStmt() || stmt->is<ConstantOp>())
+    return true;
+  // An affine apply result is fine if all of its operands are.
+  if (auto applyOp = stmt->getAs<AffineApplyOp>())
+    return applyOp->isValidSymbol();
+  return false;
+}
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 18507b0..427de7a 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -152,6 +152,29 @@
   return nullptr;
 }
 
+// The result of an affine apply operation can be used as a dimension id if
+// every operand that is an MLValue is a valid dimension id. Operands that are
+// not MLValues (i.e. CFG values) impose no restriction.
+bool AffineApplyOp::isValidDim() const {
+  for (auto *operand : getOperands()) {
+    auto *mlValue = dyn_cast<MLValue>(operand);
+    if (mlValue && !mlValue->isValidDim())
+      return false;
+  }
+  return true;
+}
+
+// Likewise, the result can be used as a symbol if every MLValue operand is a
+// valid symbol; non-MLValue (CFG) operands impose no restriction.
+bool AffineApplyOp::isValidSymbol() const {
+  for (auto *operand : getOperands()) {
+    auto *mlValue = dyn_cast<MLValue>(operand);
+    if (mlValue && !mlValue->isValidSymbol())
+      return false;
+  }
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 8a6248e..3746f71 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -15,12 +15,14 @@
 // limitations under the License.
 // =============================================================================
 
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardOps.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/StmtVisitor.h"
-#include "mlir/IR/Types.h"
 #include "llvm/ADT/DenseMap.h"
+
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -66,7 +68,7 @@
   case Kind::Operation:
     return cast<OperationStmt>(this)->getContext();
   case Kind::For:
-    return cast<ForStmt>(this)->getType()->getContext();
+    return cast<ForStmt>(this)->getContext();
   case Kind::If:
     // TODO(shpeisman): When if statement has value operands, we can get a
     // context from their type.
@@ -94,6 +96,52 @@
   return nlc.numNestedLoops == 1;
 }
 
+/// Returns the value referenced by operand #idx.
+MLValue *Statement::getOperand(unsigned idx) {
+  return getStmtOperand(idx).get();
+}
+
+const MLValue *Statement::getOperand(unsigned idx) const {
+  return getStmtOperand(idx).get();
+}
+
+/// Makes operand #idx reference 'value'.
+void Statement::setOperand(unsigned idx, MLValue *value) {
+  getStmtOperand(idx).set(value);
+}
+
+/// Returns the number of operands, dispatching on the statement kind.
+unsigned Statement::getNumOperands() const {
+  switch (getKind()) {
+  case Kind::Operation:
+    return cast<OperationStmt>(this)->getNumOperands();
+  case Kind::For:
+    return cast<ForStmt>(this)->getNumOperands();
+  case Kind::If:
+    // TODO: query IfStmt once it has operands.
+    break;
+  }
+  // Only 'if' statements reach this point; they have no operands yet.
+  // Returning after the switch keeps -Wswitch coverage of every kind while
+  // ensuring control never reaches the end of the function without a value.
+  return 0;
+}
+
+/// Returns the operands of this statement, dispatching on the statement kind.
+MutableArrayRef<StmtOperand> Statement::getStmtOperands() {
+  switch (getKind()) {
+  case Kind::Operation:
+    return cast<OperationStmt>(this)->getStmtOperands();
+  case Kind::For:
+    return cast<ForStmt>(this)->getStmtOperands();
+  case Kind::If:
+    // TODO: query IfStmt once it has operands.
+    break;
+  }
+  // Only 'if' statements reach this point; they have no operands yet.
+  return {};
+}
+
 /// Emit a note about this statement, reporting up to any diagnostic
 /// handlers that may be listening.
 void Statement::emitNote(const Twine &message) const {
@@ -231,17 +269,112 @@
   return findFunction()->getContext();
 }
 
+bool OperationStmt::isReturn() const { return is<ReturnOp>(); }
+
 //===----------------------------------------------------------------------===//
 // ForStmt
 //===----------------------------------------------------------------------===//
 
-ForStmt::ForStmt(Attribute *location, AffineConstantExpr *lowerBound,
-                 AffineConstantExpr *upperBound, int64_t step,
-                 MLIRContext *context)
+/// Creates a 'for' statement. The statement's operand list holds the lower
+/// bound operands followed by the upper bound operands; each list must have as
+/// many entries as its bound map has operands.
+ForStmt *ForStmt::create(Attribute *location, ArrayRef<MLValue *> lbOperands,
+                         AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
+                         AffineMap *ubMap, int64_t step, MLIRContext *context) {
+  assert(lbOperands.size() == lbMap->getNumOperands() &&
+         "lower bound operand count does not match the affine map");
+  assert(ubOperands.size() == ubMap->getNumOperands() &&
+         "upper bound operand count does not match the affine map");
+
+  unsigned numOperands = lbOperands.size() + ubOperands.size();
+  ForStmt *stmt =
+      new ForStmt(location, numOperands, lbMap, ubMap, step, context);
+
+  // Lower bound operands come first; getLowerBound() and getUpperBound()
+  // depend on this layout.
+  for (MLValue *operand : lbOperands)
+    stmt->operands.emplace_back(StmtOperand(stmt, operand));
+  for (MLValue *operand : ubOperands)
+    stmt->operands.emplace_back(StmtOperand(stmt, operand));
+
+  return stmt;
+}
+
+ForStmt::ForStmt(Attribute *location, unsigned numOperands, AffineMap *lbMap,
+                 AffineMap *ubMap, int64_t step, MLIRContext *context)
     : Statement(Kind::For, location),
       MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
-      StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
-      upperBound(upperBound), step(step) {}
+      StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {
+  // Reserve room for every bound operand so that create() appends them
+  // without reallocating the operand storage.
+  operands.reserve(numOperands);
+}
+
+/// Returns the lower bound: lbMap applied to operands
+/// [0, lbMap->getNumOperands()).
+const AffineBound ForStmt::getLowerBound() const {
+  return AffineBound(*this, 0, lbMap->getNumOperands(), lbMap);
+}
+
+/// Returns the upper bound: ubMap applied to the operands that follow the
+/// lower bound operands.
+const AffineBound ForStmt::getUpperBound() const {
+  return AffineBound(*this, lbMap->getNumOperands(), getNumOperands(), ubMap);
+}
+
+/// Replaces the lower bound map. Bound operands are not supported yet: the
+/// statement must have no operands and 'operands' must be empty.
+void ForStmt::setLowerBound(ArrayRef<MLValue *> operands, AffineMap *map) {
+  // TODO: handle the case when number of existing or new operands is non-zero.
+  assert(getNumOperands() == 0 && operands.empty() &&
+         "replacing bounds that have operands is not supported yet");
+
+  this->lbMap = map;
+}
+
+/// Replaces the upper bound map. Bound operands are not supported yet: the
+/// statement must have no operands and 'operands' must be empty.
+void ForStmt::setUpperBound(ArrayRef<MLValue *> operands, AffineMap *map) {
+  // TODO: handle the case when number of existing or new operands is non-zero.
+  assert(getNumOperands() == 0 && operands.empty() &&
+         "replacing bounds that have operands is not supported yet");
+
+  this->ubMap = map;
+}
+
+bool ForStmt::hasConstantLowerBound() const {
+  return lbMap->isSingleConstant();
+}
+
+bool ForStmt::hasConstantUpperBound() const {
+  return ubMap->isSingleConstant();
+}
+
+/// Returns the constant lower bound; requires hasConstantLowerBound().
+int64_t ForStmt::getConstantLowerBound() const {
+  return lbMap->getSingleConstantValue();
+}
+
+/// Returns the constant upper bound; requires hasConstantUpperBound().
+int64_t ForStmt::getConstantUpperBound() const {
+  return ubMap->getSingleConstantValue();
+}
+
+/// Sets the lower bound to the constant map () -> (value). Like
+/// setLowerBound(), this currently requires the statement to have no operands.
+void ForStmt::setConstantLowerBound(int64_t value) {
+  MLIRContext *context = getContext();
+  auto *expr = AffineConstantExpr::get(value, context);
+  setLowerBound({}, AffineMap::get(0, 0, expr, {}, context));
+}
+
+/// Sets the upper bound to the constant map () -> (value). Like
+/// setUpperBound(), this currently requires the statement to have no operands.
+void ForStmt::setConstantUpperBound(int64_t value) {
+  MLIRContext *context = getContext();
+  auto *expr = AffineConstantExpr::get(value, context);
+  setUpperBound({}, AffineMap::get(0, 0, expr, {}, context));
+}
 
 //===----------------------------------------------------------------------===//
 // IfStmt
@@ -277,12 +387,12 @@
     return it != operandMap.end() ? it->second : const_cast<MLValue *>(value);
   };
 
-  if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
-    SmallVector<MLValue *, 8> operands;
-    operands.reserve(opStmt->getNumOperands());
-    for (auto *opValue : opStmt->getOperands())
-      operands.push_back(remapOperand(opValue));
+  SmallVector<MLValue *, 8> operands;
+  operands.reserve(getNumOperands());
+  for (auto *opValue : getOperands())
+    operands.push_back(remapOperand(opValue));
 
+  if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
     SmallVector<Type *, 8> resultTypes;
     resultTypes.reserve(opStmt->getNumResults());
     for (auto *result : opStmt->getResults())
@@ -297,13 +407,18 @@
   }
 
   if (auto *forStmt = dyn_cast<ForStmt>(this)) {
-    auto *newFor =
-        new ForStmt(getLoc(), forStmt->getLowerBound(),
-                    forStmt->getUpperBound(), forStmt->getStep(), context);
+    auto *lbMap = forStmt->getLowerBoundMap();
+    auto *ubMap = forStmt->getUpperBoundMap();
+
+    auto *newFor = ForStmt::create(
+        getLoc(),
+        ArrayRef<MLValue *>(operands).take_front(lbMap->getNumOperands()),
+        lbMap, ArrayRef<MLValue *>(operands).take_back(ubMap->getNumOperands()),
+        ubMap, forStmt->getStep(), context);
+
     // Remember the induction variable mapping.
     operandMap[forStmt] = newFor;
 
-    // TODO: remap operands in loop bounds when they are added.
     // Recursively clone the body of the for loop.
     for (auto &subStmt : *forStmt)
       newFor->push_back(subStmt.clone(operandMap, context));
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index b74e113..deda932 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -372,6 +372,7 @@
 
     // TODO: check that operation is not a return statement unless it's
     // the last one in the function.
+    // TODO: check that loop bounds are properly formed.
     if (verifyReturn())
       return true;
 
@@ -409,23 +410,20 @@
       liveValues.insert(forStmt, true);
 
     for (auto &stmt : block) {
-      // TODO: For and If will eventually have operands, we need to check them.
-      // When this happens, Statement should have a general getOperands() method
-      // we can use here first.
-      if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
-        // Verify that each of the operands are live.
-        unsigned operandNo = 0;
-        for (auto *opValue : opStmt->getOperands()) {
-          if (!liveValues.count(opValue)) {
-            opStmt->emitError("operand #" + Twine(operandNo) +
-                              " does not dominate this use");
-            if (auto *useStmt = opValue->getDefiningStmt())
-              useStmt->emitNote("operand defined here");
-            return true;
-          }
-          ++operandNo;
+      // Verify that each of the operands are live.
+      unsigned operandNo = 0;
+      for (auto *opValue : stmt.getOperands()) {
+        if (!liveValues.count(opValue)) {
+          stmt.emitError("operand #" + Twine(operandNo) +
+                         " does not dominate this use");
+          if (auto *useStmt = opValue->getDefiningStmt())
+            useStmt->emitNote("operand defined here");
+          return true;
         }
+        ++operandNo;
+      }
 
+      if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
         // Operations define values, add them to the hash table.
         for (auto *result : opStmt->getResults())
           liveValues.insert(result, true);