Enhance the type system to support arbitrary precision integers, which are
important for low-bitwidth inference cases and hardware synthesis targets.

Rename 'int' to 'affineint' to avoid confusion between "the integers" and "the int
type".

PiperOrigin-RevId: 202751508
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index a9ee039..5f2ca9d 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -24,18 +24,12 @@
 namespace mlir {
   class MLIRContext;
   class PrimitiveType;
+  class IntegerType;
 
 /// Integer identifier for all the concrete type kinds.
 enum class TypeKind {
-  // Integer.
-  I1,
-  I8,
-  I16,
-  I32,
-  I64,
-
   // Target pointer sized integer.
-  Int,
+  AffineInt,
 
   // Floating point.
   BF16,
@@ -48,6 +42,7 @@
   LAST_PRIMITIVE_TYPE = F64,
 
   // Derived types.
+  Integer,
   Function,
   Vector,
   RankedTensor,
@@ -80,12 +75,8 @@
   void dump() const;
 
   // Convenience factories.
-  static PrimitiveType *getI1(MLIRContext *ctx);
-  static PrimitiveType *getI8(MLIRContext *ctx);
-  static PrimitiveType *getI16(MLIRContext *ctx);
-  static PrimitiveType *getI32(MLIRContext *ctx);
-  static PrimitiveType *getI64(MLIRContext *ctx);
-  static PrimitiveType *getInt(MLIRContext *ctx);
+  static IntegerType *getInt(unsigned width, MLIRContext *ctx);
+  static PrimitiveType *getAffineInt(MLIRContext *ctx);
   static PrimitiveType *getBF16(MLIRContext *ctx);
   static PrimitiveType *getF16(MLIRContext *ctx);
   static PrimitiveType *getF32(MLIRContext *ctx);
@@ -140,23 +131,9 @@
   PrimitiveType(TypeKind kind, MLIRContext *context);
 };
 
-inline PrimitiveType *Type::getI1(MLIRContext *ctx) {
-  return PrimitiveType::get(TypeKind::I1, ctx);
-}
-inline PrimitiveType *Type::getI8(MLIRContext *ctx) {
-  return PrimitiveType::get(TypeKind::I8, ctx);
-}
-inline PrimitiveType *Type::getI16(MLIRContext *ctx) {
-  return PrimitiveType::get(TypeKind::I16, ctx);
-}
-inline PrimitiveType *Type::getI32(MLIRContext *ctx) {
-  return PrimitiveType::get(TypeKind::I32, ctx);
-}
-inline PrimitiveType *Type::getI64(MLIRContext *ctx) {
-  return PrimitiveType::get(TypeKind::I64, ctx);
-}
-inline PrimitiveType *Type::getInt(MLIRContext *ctx) {
-  return PrimitiveType::get(TypeKind::Int, ctx);
+
+inline PrimitiveType *Type::getAffineInt(MLIRContext *ctx) {
+  return PrimitiveType::get(TypeKind::AffineInt, ctx);
 }
 inline PrimitiveType *Type::getBF16(MLIRContext *ctx) {
   return PrimitiveType::get(TypeKind::BF16, ctx);
@@ -171,6 +148,30 @@
   return PrimitiveType::get(TypeKind::F64, ctx);
 }
 
+/// Integer types can have arbitrary bitwidth up to a large fixed limit of 4096.
+class IntegerType : public Type {
+public:
+  static IntegerType *get(unsigned width, MLIRContext *context);
+
+  /// Return the bitwidth of this integer type.
+  unsigned getWidth() const {
+    return width;
+  }
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(const Type *type) {
+    return type->getKind() == TypeKind::Integer;
+  }
+private:
+  unsigned width;
+  IntegerType(unsigned width, MLIRContext *context);
+};
+
+inline IntegerType *Type::getInt(unsigned width, MLIRContext *ctx) {
+  return IntegerType::get(width, ctx);
+}
+
+
 
 /// Function types map from a list of inputs to a list of results.
 class FunctionType : public Type {
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 7c1112b..85ff432 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -120,6 +120,9 @@
   using AffineMapSet = DenseSet<AffineMap *, AffineMapKeyInfo>;
   AffineMapSet affineMaps;
 
+  /// Integer type uniquing.
+  DenseMap<unsigned, IntegerType*> integers;
+
   /// Function type uniquing.
   using FunctionTypeSet = DenseSet<FunctionType*, FunctionTypeKeyInfo>;
   FunctionTypeSet functions;
@@ -173,15 +176,10 @@
   return Identifier(it->getKeyData());
 }
 
-
 //===----------------------------------------------------------------------===//
 // Types
 //===----------------------------------------------------------------------===//
 
-PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context)
-  : Type(kind, context) {
-}
-
 PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) {
   assert(kind <= TypeKind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
   auto &impl = context->getImpl();
@@ -200,10 +198,16 @@
   return impl.primitives[(int)kind] = ptr;
 }
 
-FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
-                           unsigned numResults, MLIRContext *context)
-  : Type(TypeKind::Function, context, numInputs),
-    numResults(numResults), inputsAndResults(inputsAndResults) {
+IntegerType *IntegerType::get(unsigned width, MLIRContext *context) {
+  auto &impl = context->getImpl();
+
+  auto *&result = impl.integers[width];
+  if (!result) {
+    result = impl.allocator.Allocate<IntegerType>();
+    new (result) IntegerType(width, context);
+  }
+
+  return result;
 }
 
 FunctionType *FunctionType::get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
@@ -236,18 +240,9 @@
   return *existing.first = result;
 }
 
-
-
-VectorType::VectorType(ArrayRef<unsigned> shape, PrimitiveType *elementType,
-                       MLIRContext *context)
-  : Type(TypeKind::Vector, context, shape.size()),
-    shapeElements(shape.data()), elementType(elementType) {
-}
-
-
 VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
   assert(!shape.empty() && "vector types must have at least one dimension");
-  assert(isa<PrimitiveType>(elementType) &&
+  assert((isa<PrimitiveType>(elementType) || isa<IntegerType>(elementType)) &&
          "vectors elements must be primitives");
 
   auto *context = elementType->getContext();
@@ -277,22 +272,12 @@
 
 TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
   : Type(kind, context), elementType(elementType) {
-  assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType)) &&
+  assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType) ||
+          isa<IntegerType>(elementType)) &&
          "tensor elements must be primitives or vectors");
   assert(isa<TensorType>(this));
 }
 
-RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
-                                   MLIRContext *context)
-  : TensorType(TypeKind::RankedTensor, elementType, context),
-    shapeElements(shape.data()) {
-  setSubclassData(shape.size());
-}
-
-UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
-  : TensorType(TypeKind::UnrankedTensor, elementType, context) {
-}
-
 RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
                                         Type *elementType) {
   auto *context = elementType->getContext();
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index a5578b8..e16c6eb 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -21,18 +21,50 @@
 #include "mlir/Support/STLExtras.h"
 using namespace mlir;
 
+PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context)
+  : Type(kind, context) {
+}
+
+IntegerType::IntegerType(unsigned width, MLIRContext *context)
+  : Type(TypeKind::Integer, context), width(width) {
+}
+
+FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
+                           unsigned numResults, MLIRContext *context)
+  : Type(TypeKind::Function, context, numInputs),
+    numResults(numResults), inputsAndResults(inputsAndResults) {
+}
+
+VectorType::VectorType(ArrayRef<unsigned> shape, PrimitiveType *elementType,
+                       MLIRContext *context)
+  : Type(TypeKind::Vector, context, shape.size()),
+    shapeElements(shape.data()), elementType(elementType) {
+}
+
+RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
+                                   MLIRContext *context)
+  : TensorType(TypeKind::RankedTensor, elementType, context),
+    shapeElements(shape.data()) {
+  setSubclassData(shape.size());
+}
+
+UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
+  : TensorType(TypeKind::UnrankedTensor, elementType, context) {
+}
+
 void Type::print(raw_ostream &os) const {
   switch (getKind()) {
-  case TypeKind::I1:   os << "i1"; return;
-  case TypeKind::I8:   os << "i8"; return;
-  case TypeKind::I16:  os << "i16"; return;
-  case TypeKind::I32:  os << "i32"; return;
-  case TypeKind::I64:  os << "i64"; return;
-  case TypeKind::Int:  os << "int"; return;
+  case TypeKind::AffineInt: os << "affineint"; return;
   case TypeKind::BF16: os << "bf16"; return;
   case TypeKind::F16:  os << "f16"; return;
   case TypeKind::F32:  os << "f32"; return;
   case TypeKind::F64:  os << "f64"; return;
+
+  case TypeKind::Integer: {
+    auto *integer = cast<IntegerType>(this);
+    os << 'i' << integer->getWidth();
+    return;
+  }
   case TypeKind::Function: {
     auto *func = cast<FunctionType>(this);
     os << '(';
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 17755e0..8943200 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -140,6 +140,7 @@
 /// Lex a bare identifier or keyword that starts with a letter.
 ///
 ///   bare-id ::= letter (letter|digit|[_])*
+///   integer-type ::= `i[1-9][0-9]*`
 ///
 Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
   // Match the rest of the identifier regex: [0-9a-zA-Z_]*
@@ -149,6 +150,15 @@
   // Check to see if this identifier is a keyword.
   StringRef spelling(tokStart, curPtr-tokStart);
 
+  // Check for i123.
+  if (tokStart[0] == 'i') {
+    bool allDigits = true;
+    for (auto c : spelling.drop_front())
+      allDigits &= isdigit(c) != 0;
+    if (allDigits && spelling.size() != 1)
+      return Token(Token::inttype, spelling);
+  }
+
   Token::Kind kind = llvm::StringSwitch<Token::Kind>(spelling)
 #define TOK_KEYWORD(SPELLING) \
     .Case(#SPELLING, Token::kw_##SPELLING)
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 6927050..1bfa331 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -121,7 +121,7 @@
   // as the results of their action.
 
   // Type parsing.
-  PrimitiveType *parsePrimitiveType();
+  Type *parsePrimitiveType();
   Type *parseElementType();
   VectorType *parseVectorType();
   ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
@@ -218,12 +218,11 @@
 
 /// Parse the low-level fixed dtypes in the system.
 ///
-///   primitive-type
-///      ::= `f16` | `bf16` | `f32` | `f64`      // Floating point
-///        | `i1` | `i8` | `i16` | `i32` | `i64` // Sized integers
-///        | `int`
+///   primitive-type ::= `f16` | `bf16` | `f32` | `f64`
+///   primitive-type ::= integer-type
+///   primitive-type ::= `affineint`
 ///
-PrimitiveType *Parser::parsePrimitiveType() {
+Type *Parser::parsePrimitiveType() {
   switch (curToken.getKind()) {
   default:
     return (emitError("expected type"), nullptr);
@@ -239,24 +238,16 @@
   case Token::kw_f64:
     consumeToken(Token::kw_f64);
     return Type::getF64(context);
-  case Token::kw_i1:
-    consumeToken(Token::kw_i1);
-    return Type::getI1(context);
-  case Token::kw_i8:
-    consumeToken(Token::kw_i8);
-    return Type::getI8(context);
-  case Token::kw_i16:
-    consumeToken(Token::kw_i16);
-    return Type::getI16(context);
-  case Token::kw_i32:
-    consumeToken(Token::kw_i32);
-    return Type::getI32(context);
-  case Token::kw_i64:
-    consumeToken(Token::kw_i64);
-    return Type::getI64(context);
-  case Token::kw_int:
-    consumeToken(Token::kw_int);
-    return Type::getInt(context);
+  case Token::kw_affineint:
+    consumeToken(Token::kw_affineint);
+    return Type::getAffineInt(context);
+  case Token::inttype: {
+    auto width = curToken.getIntTypeBitwidth();
+    if (!width.hasValue())
+      return (emitError("invalid integer width"), nullptr);
+    consumeToken(Token::inttype);
+    return Type::getInt(width.getValue(), context);
+  }
   }
 }
 
@@ -419,11 +410,9 @@
     return (emitError("expected '>' in memref type"), nullptr);
 
   // FIXME: Add an IR representation for memref types.
-  return Type::getI1(context);
+  return Type::getInt(1, context);
 }
 
-
-
 /// Parse a function type.
 ///
 ///   function-type ::= type-list-parens `->` type-list
@@ -445,7 +434,6 @@
   return FunctionType::get(arguments, results, context);
 }
 
-
 /// Parse an arbitrary type.
 ///
 ///   type ::= primitive-type
diff --git a/lib/Parser/Token.cpp b/lib/Parser/Token.cpp
index 5563255..e1e4bed 100644
--- a/lib/Parser/Token.cpp
+++ b/lib/Parser/Token.cpp
@@ -48,6 +48,18 @@
   return result;
 }
 
+/// For an inttype token, return its bitwidth.
+Optional<unsigned> Token::getIntTypeBitwidth() const {
+ unsigned result = 0;
+  if (spelling[1] == '0' ||
+      spelling.drop_front().getAsInteger(10, result) ||
+      // Arbitrary but large limit on bitwidth.
+      result > 4096 || result == 0)
+    return None;
+  return result;
+}
+
+
 /// Given a 'string' token, return its value, including removing the quote
 /// characters and unescaping the contents of the string.
 std::string Token::getStringValue() const {
diff --git a/lib/Parser/Token.h b/lib/Parser/Token.h
index e5e4fc4..bc9e8e4 100644
--- a/lib/Parser/Token.h
+++ b/lib/Parser/Token.h
@@ -73,6 +73,9 @@
   /// return None.
   Optional<unsigned> getUnsignedIntegerValue() const;
 
+  /// For an inttype token, return its bitwidth.
+  Optional<unsigned> getIntTypeBitwidth() const;
+
   /// Given a 'string' token, return its value, including removing the quote
   /// characters and unescaping the contents of the string.
   std::string getStringValue() const;
diff --git a/lib/Parser/TokenKinds.def b/lib/Parser/TokenKinds.def
index 72d769a..73a30df 100644
--- a/lib/Parser/TokenKinds.def
+++ b/lib/Parser/TokenKinds.def
@@ -58,6 +58,7 @@
 // Literals
 TOK_LITERAL(integer)                    // 42
 TOK_LITERAL(string)                     // "foo"
+TOK_LITERAL(inttype)                    // i421
 
 // Punctuation.
 TOK_PUNCTUATION(arrow,            "->")
@@ -84,6 +85,7 @@
 // TODO: More operator tokens
 
 // Keywords.  These turn "foo" into Token::kw_foo enums.
+TOK_KEYWORD(affineint)
 TOK_KEYWORD(bf16)
 TOK_KEYWORD(br)
 TOK_KEYWORD(cfgfunc)
@@ -91,12 +93,6 @@
 TOK_KEYWORD(f16)
 TOK_KEYWORD(f32)
 TOK_KEYWORD(f64)
-TOK_KEYWORD(i1)
-TOK_KEYWORD(i16)
-TOK_KEYWORD(i32)
-TOK_KEYWORD(i64)
-TOK_KEYWORD(i8)
-TOK_KEYWORD(int)
 TOK_KEYWORD(memref)
 TOK_KEYWORD(mlfunc)
 TOK_KEYWORD(return)
diff --git a/test/IR/parser-errors.mlir b/test/IR/parser-errors.mlir
index 408fe13..fb4c2cc 100644
--- a/test/IR/parser-errors.mlir
+++ b/test/IR/parser-errors.mlir
@@ -6,7 +6,7 @@
 ; Check different error cases.
 ; -----
 
-extfunc @illegaltype(i42) ; expected-error {{expected type}}
+extfunc @illegaltype(i) ; expected-error {{expected type}}
 
 ; -----
 
@@ -19,7 +19,7 @@
 
 ; -----
 
-extfunc missingsigil() -> (i1, int, f32) ; expected-error {{expected a function identifier like}}
+extfunc missingsigil() -> (i1, affineint, f32) ; expected-error {{expected a function identifier like}}
 
 
 ; -----
@@ -75,3 +75,9 @@
   ""()   ; expected-error {{empty operation name is invalid}}
   return
 }
+
+; -----
+
+extfunc @illegaltype(i0) ; expected-error {{invalid integer width}}
+
+
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index d69192d..1b9d9cb 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -10,25 +10,28 @@
 ; CHECK: extfunc @bar()
 extfunc @bar() -> ()
 
-; CHECK: extfunc @baz() -> (i1, int, f32)
-extfunc @baz() -> (i1, int, f32)
+; CHECK: extfunc @baz() -> (i1, affineint, f32)
+extfunc @baz() -> (i1, affineint, f32)
 
 ; CHECK: extfunc @missingReturn()
 extfunc @missingReturn()
 
+; CHECK: extfunc @int_types(i1, i2, i4, i7, i87) -> (i1, affineint, i19)
+extfunc @int_types(i1, i2, i4, i7, i87) -> (i1, affineint, i19)
+
 
 ; CHECK: extfunc @vectors(vector<1xf32>, vector<2x4xf32>)
 extfunc @vectors(vector<1 x f32>, vector<2x4xf32>)
 
-; CHECK: extfunc @tensors(tensor<??f32>, tensor<??vector<2x4xf32>>, tensor<1x?x4x?x?xint>, tensor<i8>)
+; CHECK: extfunc @tensors(tensor<??f32>, tensor<??vector<2x4xf32>>, tensor<1x?x4x?x?xaffineint>, tensor<i8>)
 extfunc @tensors(tensor<?? f32>, tensor<?? vector<2x4xf32>>,
-                 tensor<1x?x4x?x?xint>, tensor<i8>)
+                 tensor<1x?x4x?x?xaffineint>, tensor<i8>)
 
 ; CHECK: extfunc @memrefs(i1, i1)
-extfunc @memrefs(memref<1x?x4x?x?xint>, memref<i8>)
+extfunc @memrefs(memref<1x?x4x?x?xaffineint>, memref<i8>)
 
 ; CHECK: extfunc @functions((i1, i1) -> (), () -> ())
-extfunc @functions((memref<1x?x4x?x?xint>, memref<i8>) -> (), ()->())
+extfunc @functions((memref<1x?x4x?x?xaffineint>, memref<i8>) -> (), ()->())
 
 
 ; CHECK-LABEL: cfgfunc @simpleCFG() {