Implement induction variables. Pretty print induction variable operands as %i<ssa value number>. Add support for future pretty printing of ML function arguments as %arg<ssa value number>.

Induction variables are implemented by inheriting ForStmt from MLValue. ForStmt provides APIs that make this design decision invisible to the ForStmt users.

This CL in combination with cl/206253643 resolves  http://b/111769060.

PiperOrigin-RevId: 206655937
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index cc31c12..f77f1d7 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -187,19 +187,17 @@
 };
 
 /// For statement represents an affine loop nest.
-class ForStmt : public Statement, public StmtBlock {
+class ForStmt : public Statement, public StmtBlock, private MLValue {
 public:
   // TODO: lower and upper bounds should be affine maps with
   // dimension and symbol use lists.
   explicit ForStmt(AffineConstantExpr *lowerBound,
-                   AffineConstantExpr *upperBound, AffineConstantExpr *step)
-      : Statement(Kind::For), StmtBlock(StmtBlockKind::For),
-        lowerBound(lowerBound), upperBound(upperBound), step(step) {}
+                   AffineConstantExpr *upperBound, AffineConstantExpr *step,
+                   MLIRContext *context);
 
   // Loop bounds and step are immortal objects and don't need to be deleted.
   ~ForStmt() {}
 
-  // TODO: represent induction variable
   AffineConstantExpr *getLowerBound() const { return lowerBound; }
   AffineConstantExpr *getUpperBound() const { return upperBound; }
   AffineConstantExpr *getStep() const { return step; }
@@ -213,6 +211,16 @@
     return block->getStmtBlockKind() == StmtBlockKind::For;
   }
 
+  // For statement represents induction variable by inheriting
+  // from MLValue. This design is hidden behind interfaces.
+  static bool classof(const SSAValue *value) {
+    return value->getKind() == SSAValueKind::InductionVar;
+  }
+
+  /// MLValue methods
+  MLValue *getInductionVar() { return this; }
+  const MLValue *getInductionVar() const { return this; }
+
 private:
   AffineConstantExpr *lowerBound;
   AffineConstantExpr *upperBound;
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index c928b34..2e09850 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -542,22 +542,40 @@
 protected:
   void numberValueID(const SSAValue *value) {
     assert(!valueIDs.count(value) && "Value numbered multiple times");
-    valueIDs[value] = nextValueID++;
+    unsigned id;
+    switch (value->getKind()) {
+    case SSAValueKind::BBArgument:
+    case SSAValueKind::InstResult:
+    case SSAValueKind::StmtResult:
+      id = nextValueID++;
+      break;
+    case SSAValueKind::FnArgument:
+      id = nextFnArgumentID++;
+      break;
+    case SSAValueKind::InductionVar:
+      id = nextInductionVarID++;
+      break;
+    }
+    valueIDs[value] = id;
   }
 
-  void printValueID(const SSAValue *value,
-                    bool dontPrintResultNo = false) const {
+  void printValueID(const SSAValue *value, bool printResultNo = true) const {
     int resultNo = -1;
     auto lookupValue = value;
 
-    // If this is a reference to the result of a multi-result instruction, print
-    // out the # identifier and make sure to map our lookup to the first result
-    // of the instruction.
+    // If this is a reference to the result of a multi-result instruction or
+    // statement, print out the # identifier and make sure to map our lookup
+    // to the first result of the instruction.
     if (auto *result = dyn_cast<InstResult>(value)) {
       if (result->getOwner()->getNumResults() != 1) {
         resultNo = result->getResultNumber();
         lookupValue = result->getOwner()->getResult(0);
       }
+    } else if (auto *result = dyn_cast<StmtResult>(value)) {
+      if (result->getOwner()->getNumResults() != 1) {
+        resultNo = result->getResultNumber();
+        lookupValue = result->getOwner()->getResult(0);
+      }
     }
 
     auto it = valueIDs.find(lookupValue);
@@ -566,8 +584,14 @@
       return;
     }
 
-    os << '%' << it->getSecond();
-    if (resultNo != -1 && !dontPrintResultNo)
+    os << '%';
+    if (isa<ForStmt>(value))
+
+      os << 'i';
+    else if (isa<FnArgument>(value))
+      os << "arg";
+    os << it->getSecond();
+    if (resultNo != -1 && printResultNo)
       os << '#' << resultNo;
   }
 
@@ -575,12 +599,14 @@
   /// This is the value ID for each SSA value in the current function.
   DenseMap<const SSAValue *, unsigned> valueIDs;
   unsigned nextValueID = 0;
+  unsigned nextInductionVarID = 0;
+  unsigned nextFnArgumentID = 0;
 };
 } // end anonymous namespace
 
 void FunctionPrinter::printOperation(const Operation *op) {
   if (op->getNumResults()) {
-    printValueID(op->getResult(0), /*dontPrintResultNo*/ true);
+    printValueID(op->getResult(0), /*printResultNo=*/false);
     os << " = ";
   }
 
@@ -874,6 +900,9 @@
       if (stmt->getNumResults() != 0)
         printer->numberValueID(stmt->getResult(0));
     }
+    void visitForStmt(ForStmt *stmt) {
+      printer->numberValueID(stmt->getInductionVar());
+    }
     MLFunctionPrinter *printer;
   };
 
@@ -918,7 +947,9 @@
 }
 
 void MLFunctionPrinter::print(const ForStmt *stmt) {
-  os.indent(numSpaces) << "for x = " << *stmt->getLowerBound();
+  os.indent(numSpaces) << "for ";
+  printOperand(stmt->getInductionVar());
+  os << " = " << *stmt->getLowerBound();
   os << " to " << *stmt->getUpperBound();
   if (stmt->getStep()->getValue() != 1)
     os << " step " << *stmt->getStep();
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index e7cb9cb..1a094d9 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -161,7 +161,7 @@
                                   AffineConstantExpr *step) {
   if (!step)
     step = getConstantExpr(1);
-  auto *stmt = new ForStmt(lowerBound, upperBound, step);
+  auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
   block->getStatements().push_back(stmt);
   return stmt;
 }
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 0731e54..6ace8ff 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/Types.h"
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -191,6 +192,16 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ForStmt
+//===----------------------------------------------------------------------===//
+
+ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
+                 AffineConstantExpr *step, MLIRContext *context)
+    : Statement(Kind::For), StmtBlock(StmtBlockKind::For),
+      MLValue(MLValueKind::InductionVar, Type::getAffineInt(context)),
+      lowerBound(lowerBound), upperBound(upperBound), step(step) {}
+
+//===----------------------------------------------------------------------===//
 // IfStmt
 //===----------------------------------------------------------------------===//
 
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 654db98..3bcec37 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -2086,10 +2086,8 @@
   if (getToken().isNot(Token::percent_identifier))
     return emitError("expected SSA identifier for the loop variable");
 
-  // TODO: create SSA value definition from name
-  StringRef name = getTokenSpelling().drop_front();
-  (void)name;
-
+  auto loc = getToken().getLoc();
+  StringRef inductionVariableName = getTokenSpelling().drop_front();
   consumeToken(Token::percent_identifier);
 
   if (parseToken(Token::equal, "expected ="))
@@ -2116,14 +2114,20 @@
   }
 
   // Create for statement.
-  ForStmt *stmt = builder.createFor(lowerBound, upperBound, step);
+  ForStmt *forStmt = builder.createFor(lowerBound, upperBound, step);
+
+  // Create SSA value definition for the induction variable.
+  addDefinition({inductionVariableName, 0, loc}, forStmt->getInductionVar());
 
   // If parsing of the for statement body fails,
   // MLIR contains for statement with those nested statements that have been
   // successfully parsed.
-  if (parseStmtBlock(static_cast<StmtBlock *>(stmt)))
+  if (parseStmtBlock(forStmt))
     return ParseFailure;
 
+  // Reset insertion point to the current block.
+  builder.setInsertionPoint(forStmt->getBlock());
+
   return ParseSuccess;
 }
 
@@ -2174,6 +2178,9 @@
       return ParseFailure;
   }
 
+  // Reset insertion point to the current block.
+  builder.setInsertionPoint(ifStmt->getBlock());
+
   return ParseSuccess;
 }
 
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index 38a79f8..6d1f15d 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -281,3 +281,12 @@
   return
 }
 
+// -----
+
+mlfunc @duplicate_induction_var() {
+  for %i = 1 to 10 {   // expected-error {{previously defined here}}
+    for %i = 1 to 10 { // expected-error {{redefinition of SSA value 'i'}}
+    }
+  }
+  return
+}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index 61604fb..fe82238 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -116,7 +116,7 @@
 mlfunc @mlfunc_with_ops() {
   // CHECK: %0 = "foo"() : () -> i64
   %a = "foo"() : ()->i64
-  // CHECK: for x = 1 to 10 {
+  // CHECK: for %i0 = 1 to 10 {
   for %i = 1 to 10 {
     // CHECK: %1 = "doo"() : () -> f32
     %b = "doo"() : ()->f32
@@ -132,18 +132,34 @@
 
 // CHECK-LABEL: mlfunc @loops() {
 mlfunc @loops() {
-  // CHECK: for x = 1 to 100 step 2 {
+  // CHECK: for %i0 = 1 to 100 step 2 {
   for %i = 1 to 100 step 2 {
-    // CHECK: for x = 1 to 200 {
+    // CHECK: for %i1 = 1 to 200 {
     for %j = 1 to 200 {
     }        // CHECK:     }
   }          // CHECK:   }
   return     // CHECK:   return
 }            // CHECK: }
 
+// CHECK-LABEL: mlfunc @complex_loops() {
+mlfunc @complex_loops() {
+  for %i1 = 1 to 100 {      // CHECK:   for %i0 = 1 to 100 {
+    for %j1 = 1 to 100 {    // CHECK:     for %i1 = 1 to 100 {
+       "foo"() : () -> ()   // CHECK:       "foo"() : () -> ()
+    }                       // CHECK:     }
+    "boo"() : () -> ()      // CHECK:     "boo"() : () -> ()
+    for %j2 = 1 to 10 {     // CHECK:     for %i2 = 1 to 10 {
+      for %k2 = 1 to 10 {   // CHECK:       for %i3 = 1 to 10 {
+        "goo"() : () -> ()  // CHECK:         "goo"() : () -> ()
+      }                     // CHECK:       }
+    }                       // CHECK:     }
+  }                         // CHECK:   }
+  return                    // CHECK:   return
+}                           // CHECK: }
+
 // CHECK-LABEL: mlfunc @ifstmt() {
 mlfunc @ifstmt() {
-  for %i = 1 to 10 {    // CHECK   for x = 1 to 10 {
+  for %i = 1 to 10 {    // CHECK   for %i0 = 1 to 10 {
     if () {             // CHECK     if () {
     } else if () {      // CHECK     } else if () {
     } else {            // CHECK     } else {
diff --git a/test/Transforms/unroll.mlir b/test/Transforms/unroll.mlir
index 3c42142..2d42e4a 100644
--- a/test/Transforms/unroll.mlir
+++ b/test/Transforms/unroll.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: mlfunc @loops() {
 mlfunc @loops() {
-  // CHECK: for x = 1 to 100 step 2 {
+  // CHECK: for %i0 = 1 to 100 step 2 {
   for %i = 1 to 100 step 2 {
     // CHECK: "custom"(){value: 1} : () -> ()
     // CHECK-NEXT: "custom"(){value: 1} : () -> ()