Parse affine map range sizes.

PiperOrigin-RevId: 204240947
diff --git a/include/mlir/IR/AffineMap.h b/include/mlir/IR/AffineMap.h
index 9119d4d..0b6e614 100644
--- a/include/mlir/IR/AffineMap.h
+++ b/include/mlir/IR/AffineMap.h
@@ -41,7 +41,14 @@
 class AffineMap {
 public:
   static AffineMap *get(unsigned dimCount, unsigned symbolCount,
-                        ArrayRef<AffineExpr *> results, MLIRContext *context);
+                        ArrayRef<AffineExpr *> results,
+                        ArrayRef<AffineExpr *> rangeSizes,
+                        MLIRContext *context);
+
+  /// Returns true if the co-domain (or more loosely speaking, range) of this
+  /// map is bounded. Bounded affine maps have a size (extent) for each of
+  /// their range dimensions (more accurately co-domain dimensions).
+  bool isBounded() const { return rangeSizes != nullptr; }
 
   // Prints affine map to 'os'.
   void print(raw_ostream &os) const;
@@ -55,12 +62,17 @@
     return ArrayRef<AffineExpr *>(results, numResults);
   }
 
- private:
-  AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
-            AffineExpr *const *results);
+  ArrayRef<AffineExpr *> getRangeSizes() const {
+    return rangeSizes ? ArrayRef<AffineExpr *>(rangeSizes, numResults)
+                      : ArrayRef<AffineExpr *>();
+  }
 
-  AffineMap(const AffineMap&) = delete;
-  void operator=(const AffineMap&) = delete;
+private:
+  AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
+            AffineExpr *const *results, AffineExpr *const *rangeSizes);
+
+  AffineMap(const AffineMap &) = delete;
+  void operator=(const AffineMap &) = delete;
 
   const unsigned numDims;
   const unsigned numSymbols;
@@ -69,6 +81,10 @@
   /// The affine expressions for this (multi-dimensional) map.
   /// TODO: use trailing objects for this.
   AffineExpr *const *const results;
+
+  /// The extents along each of the range dimensions if the map is bounded,
+  /// nullptr otherwise.
+  AffineExpr *const *const rangeSizes;
 };
 
 }  // end namespace mlir
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 730fe7d..1453a4a 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -75,7 +75,8 @@
 
   // Affine Expressions and Affine Map.
   AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
-                          ArrayRef<AffineExpr *> results);
+                          ArrayRef<AffineExpr *> results,
+                          ArrayRef<AffineExpr *> rangeSizes);
   AffineDimExpr *getDimExpr(unsigned position);
   AffineSymbolExpr *getSymbolExpr(unsigned position);
   AffineConstantExpr *getConstantExpr(int64_t constant);
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index 822b230..733c235 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -23,9 +23,9 @@
 using namespace mlir;
 
 AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
-                     AffineExpr *const *results)
+                     AffineExpr *const *results, AffineExpr *const *rangeSizes)
     : numDims(numDims), numSymbols(numSymbols), numResults(numResults),
-      results(results) {}
+      results(results), rangeSizes(rangeSizes) {}
 
 /// Fold to a constant when possible. Canonicalize so that only the RHS is a
 /// constant. (4 + d0 becomes d0 + 4). If only one of them is a symbolic
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index d284d0f..27ed3e4 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -392,6 +392,17 @@
   os << " -> (";
   interleave(getResults(), [&](AffineExpr *expr) { os << *expr; },
              [&]() { os << ", "; });
+  os << ")";
+
+  if (!isBounded()) {
+    os << "\n";
+    return;
+  }
+
+  // Print range sizes for bounded affine maps.
+  os << " size (";
+  interleave(getRangeSizes(), [&](AffineExpr *expr) { os << *expr; },
+             [&]() { os << ", "; });
   os << ")\n";
 }
 
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 71aa0a3..e9bea2a 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -99,8 +99,9 @@
 //===----------------------------------------------------------------------===//
 
 AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
-                                 ArrayRef<AffineExpr *> results) {
-  return AffineMap::get(dimCount, symbolCount, results, context);
+                                 ArrayRef<AffineExpr *> results,
+                                 ArrayRef<AffineExpr *> rangeSizes) {
+  return AffineMap::get(dimCount, symbolCount, results, rangeSizes, context);
 }
 
 AffineDimExpr *Builder::getDimExpr(unsigned position) {
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index f495bda..a2e967e 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -55,21 +55,23 @@
 struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
   // Affine maps are uniqued based on their dim/symbol counts and affine
   // expressions.
-  using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr *>>;
+  using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr *>,
+                           ArrayRef<AffineExpr *>>;
   using DenseMapInfo<AffineMap *>::getHashValue;
   using DenseMapInfo<AffineMap *>::isEqual;
 
   static unsigned getHashValue(KeyTy key) {
     return hash_combine(
         std::get<0>(key), std::get<1>(key),
-        hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()));
+        hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
+        hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
   }
 
   static bool isEqual(const KeyTy &lhs, const AffineMap *rhs) {
     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
       return false;
     return lhs == std::make_tuple(rhs->getNumDims(), rhs->getNumSymbols(),
-                                  rhs->getResults());
+                                  rhs->getResults(), rhs->getRangeSizes());
   }
 };
 
@@ -555,14 +557,17 @@
 
 AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
                           ArrayRef<AffineExpr *> results,
+                          ArrayRef<AffineExpr *> rangeSizes,
                           MLIRContext *context) {
   // The number of results can't be zero.
   assert(!results.empty());
 
+  assert(rangeSizes.empty() || results.size() == rangeSizes.size());
+
   auto &impl = context->getImpl();
 
   // Check if we already have this affine map.
-  auto key = std::make_tuple(dimCount, symbolCount, results);
+  auto key = std::make_tuple(dimCount, symbolCount, results, rangeSizes);
   auto existing = impl.affineMaps.insert_as(nullptr, key);
 
   // If we already have it, return that value.
@@ -575,8 +580,12 @@
   // Copy the results into the bump pointer.
   results = impl.copyInto(ArrayRef<AffineExpr *>(results));
 
+  // Copy the results into the bump pointer.
+  rangeSizes = impl.copyInto(ArrayRef<AffineExpr *>(rangeSizes));
+
   // Initialize the memory using placement new.
-  new (res) AffineMap(dimCount, symbolCount, results.size(), results.data());
+  new (res) AffineMap(dimCount, symbolCount, results.size(), results.data(),
+                      rangeSizes.empty() ? nullptr : rangeSizes.data());
 
   // Cache and return it.
   return *existing.first = res;
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 3eff98c..1d73d5b 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -628,6 +628,11 @@
   unsigned getNumDims() const { return dims.size(); }
   unsigned getNumSymbols() const { return symbols.size(); }
 
+  /// Returns true if the only identifiers the parser accepts in affine
+  /// expressions are symbolic identifiers.
+  bool isPureSymbolic() const { return pureSymbolic; }
+  void setSymbolicParsing(bool val) { pureSymbolic = val; }
+
   // Binary affine op parsing.
   AffineLowPrecOp consumeIfLowPrecOp();
   AffineHighPrecOp consumeIfHighPrecOp();
@@ -657,6 +662,9 @@
   // TODO(bondhugula): could just use an vector/ArrayRef and scan the numbers.
   llvm::StringMap<unsigned> dims;
   llvm::StringMap<unsigned> symbols;
+  /// True if the parser should allow only symbolic identifiers in affine
+  /// expressions.
+  bool pureSymbolic = false;
 };
 } // end anonymous namespace
 
@@ -664,6 +672,7 @@
 AffineExpr *AffineMapParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
                                                    AffineExpr *lhs,
                                                    AffineExpr *rhs) {
+  // TODO: make the error location info accurate.
   switch (op) {
   case Mul:
     if (!lhs->isSymbolic() && !rhs->isSymbolic()) {
@@ -828,15 +837,21 @@
     return (emitError("expected bare identifier"), nullptr);
 
   StringRef sRef = getTokenSpelling();
+  // dims, symbols are all pairwise distinct.
   if (dims.count(sRef)) {
+    if (isPureSymbolic())
+      return (emitError("identifier used is not a symbolic identifier"),
+              nullptr);
     consumeToken(Token::bare_identifier);
     return builder.getDimExpr(dims.lookup(sRef));
   }
+
   if (symbols.count(sRef)) {
     consumeToken(Token::bare_identifier);
     return builder.getSymbolExpr(symbols.lookup(sRef));
   }
-  return (emitError("identifier is neither dimensional nor symbolic"), nullptr);
+
+  return (emitError("use of undeclared identifier"), nullptr);
 }
 
 /// Parse a positive integral constant appearing in an affine expression.
@@ -1053,8 +1068,36 @@
   if (parseCommaSeparatedList(Token::r_paren, parseElt, false))
     return nullptr;
 
+  // Parse optional range sizes.
+  //  (`size` `(` dim-size (`,` dim-size)* `)`)?
+  // TODO: check if sizes are non-negative whenever they are constant.
+  SmallVector<AffineExpr *, 4> rangeSizes;
+  if (consumeIf(Token::kw_size)) {
+    // Location of the l_paren token (if it exists) for error reporting later.
+    auto loc = getToken().getLoc();
+    if (!consumeIf(Token::l_paren))
+      return (emitError("expected '(' at start of affine map range"), nullptr);
+
+    auto parseRangeSize = [&]() -> ParseResult {
+      auto *elt = parseAffineExpr();
+      ParseResult res = elt ? ParseSuccess : ParseFailure;
+      rangeSizes.push_back(elt);
+      return res;
+    };
+
+    setSymbolicParsing(true);
+    if (parseCommaSeparatedList(Token::r_paren, parseRangeSize, false))
+      return nullptr;
+    if (exprs.size() > rangeSizes.size())
+      return (emitError(loc, "fewer range sizes than range expressions"),
+              nullptr);
+    if (exprs.size() < rangeSizes.size())
+      return (emitError(loc, "more range sizes than range expressions"),
+              nullptr);
+  }
+
   // Parsed a valid affine map.
-  return builder.getAffineMap(dims.size(), symbols.size(), exprs);
+  return builder.getAffineMap(dims.size(), symbols.size(), exprs, rangeSizes);
 }
 
 AffineMap *Parser::parseAffineMapInline() {
diff --git a/lib/Parser/TokenKinds.def b/lib/Parser/TokenKinds.def
index 1637a95..dda5cae 100644
--- a/lib/Parser/TokenKinds.def
+++ b/lib/Parser/TokenKinds.def
@@ -103,6 +103,7 @@
 TOK_KEYWORD(mlfunc)
 TOK_KEYWORD(mod)
 TOK_KEYWORD(return)
+TOK_KEYWORD(size)
 TOK_KEYWORD(tensor)
 TOK_KEYWORD(true)
 TOK_KEYWORD(vector)
diff --git a/test/IR/parser-affine-map-negative.mlir b/test/IR/parser-affine-map-negative.mlir
index 5ce4850..a795b3b 100644
--- a/test/IR/parser-affine-map-negative.mlir
+++ b/test/IR/parser-affine-map-negative.mlir
@@ -33,7 +33,7 @@
 #hello_world = (i, j) [s0] -> i + s0, j) ; expected-error {{expected '(' at start of affine map range}}
 
 ; -----
-#hello_world = (i, j) [s0] -> (x) ; expected-error {{identifier is neither dimensional nor symbolic}}
+#hello_world = (i, j) [s0] -> (x) ; expected-error {{use of undeclared identifier}}
 
 ; -----
 #hello_world = (i, j, i) [s0] -> (i) ; expected-error {{dimensional identifier name reused}}
@@ -98,7 +98,22 @@
 #hello_world = (i, j) [s0, s1] -> (-1*i j, j) ; expected-error {{expected ',' or ')'}}
 
 ; -----
-#hello_world = (i, j) -> (i, 3*d0 + ) ; expected-error {{identifier is neither dimensional nor symbolic}}
+#hello_world = (i, j) -> (i, 3*d0 + ) ; expected-error {{use of undeclared identifier}}
+
+; -----
+#hello_world = (i, j) -> (i, j) size (10, x) ; expected-error {{use of undeclared identifier}}
+
+; -----
+#hello_world = (i, j) [M] -> (i, j) size (10, j) ; expected-error {{identifier used is not a symbolic identifier}}
+
+; -----
+#hello_world = (i, j) [M] -> (i, j) size (10, M+i) ; expected-error {{identifier used is not a symbolic identifier}}
+
+; -----
+#hello_world = (i, j) -> (i, j) size (10) ; expected-error {{fewer range sizes than range expressions}}
+
+; -----
+#hello_world = (i, j) -> (i, j) size (10, 20, 30) ; expected-error {{more range sizes than range expressions}}
 
 ; TODO(bondhugula): Add more tests; coverage of error messages emitted not complete
 
diff --git a/test/IR/parser-affine-map.mlir b/test/IR/parser-affine-map.mlir
index 7f55f69..f999afe 100644
--- a/test/IR/parser-affine-map.mlir
+++ b/test/IR/parser-affine-map.mlir
@@ -110,3 +110,12 @@
 
 ; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * s0), (d0 + s0), (d0 + 2), (d1 * 2), (s1 * 2), (s0 + 2))
 #hello_world39 = (i, j) [M, N] -> (i*M, M + i, 2+i, j*2, N*2, 2 + M)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) -> (d0, d1) size (10, 20)
+#hello_world40 = (i, j) -> (i, j) size (10, 20)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (s0, (s1 + 10))
+#hello_world41 = (i, j) [N, M] -> (i, j) size (N, M+10)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (128, (((s0 * 2) + 5) + s1))
+#hello_world42 = (i, j) [N, M] -> (i, j) size (64 + 64, 5 + 2*N + M)