[SVE][IR] Scalable Vector size queries and IR instruction support

* Adds a TypeSize struct to represent the known minimum size of a type
  along with a flag to indicate that the runtime size is an integer multiple
  of that size
* Converts existing size query functions from Type.h and DataLayout.h to
  return a TypeSize result
* Adds convenience methods (including a transparent conversion operator
  to uint64_t) so that most existing code 'just works' as if the return
  values were still scalars.
* Uses the new size queries along with ElementCount to ensure that all
  supported instructions used with scalable vectors can be constructed
  in IR.

Reviewers: hfinkel, lattner, rkruppe, greened, rovka, rengolin, sdesmalen

Reviewed By: rovka, sdesmalen

Differential Revision: https://reviews.llvm.org/D53137

llvm-svn: 374042
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index 57dee45..89811ec 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -436,7 +436,8 @@
     if (auto *AllocSize = dyn_cast_or_null<ConstantInt>(Size)) {
       Type *Ty = I.getAllocatedType();
       AllocatedSize = SaturatingMultiplyAdd(
-          AllocSize->getLimitedValue(), DL.getTypeAllocSize(Ty), AllocatedSize);
+          AllocSize->getLimitedValue(), DL.getTypeAllocSize(Ty).getFixedSize(),
+          AllocatedSize);
       return Base::visitAlloca(I);
     }
   }
@@ -444,7 +445,8 @@
   // Accumulate the allocated size.
   if (I.isStaticAlloca()) {
     Type *Ty = I.getAllocatedType();
-    AllocatedSize = SaturatingAdd(DL.getTypeAllocSize(Ty), AllocatedSize);
+    AllocatedSize = SaturatingAdd(DL.getTypeAllocSize(Ty).getFixedSize(),
+                                  AllocatedSize);
   }
 
   // We will happily inline static alloca instructions.
diff --git a/llvm/lib/CodeGen/Analysis.cpp b/llvm/lib/CodeGen/Analysis.cpp
index 6c05966..4f24f07 100644
--- a/llvm/lib/CodeGen/Analysis.cpp
+++ b/llvm/lib/CodeGen/Analysis.cpp
@@ -309,7 +309,8 @@
         NoopInput = Op;
     } else if (isa<TruncInst>(I) &&
                TLI.allowTruncateForTailCall(Op->getType(), I->getType())) {
-      DataBits = std::min(DataBits, I->getType()->getPrimitiveSizeInBits());
+      DataBits = std::min((uint64_t)DataBits,
+                         I->getType()->getPrimitiveSizeInBits().getFixedSize());
       NoopInput = Op;
     } else if (auto CS = ImmutableCallSite(I)) {
       const Value *ReturnedOp = CS.getReturnedArgOperand();
diff --git a/llvm/lib/IR/DataLayout.cpp b/llvm/lib/IR/DataLayout.cpp
index b125d15..5fe7a2e 100644
--- a/llvm/lib/IR/DataLayout.cpp
+++ b/llvm/lib/IR/DataLayout.cpp
@@ -29,6 +29,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/TypeSize.h"
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
@@ -745,7 +746,10 @@
     llvm_unreachable("Bad type for getAlignment!!!");
   }
 
-  return getAlignmentInfo(AlignType, getTypeSizeInBits(Ty), abi_or_pref, Ty);
+  // If we're dealing with a scalable vector, we just need the known minimum
+  // size for determining alignment. If not, we'll get the exact size.
+  return getAlignmentInfo(AlignType, getTypeSizeInBits(Ty).getKnownMinSize(),
+                          abi_or_pref, Ty);
 }
 
 unsigned DataLayout::getABITypeAlignment(Type *Ty) const {
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index de1317e..2033180 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -38,6 +38,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/TypeSize.h"
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
@@ -1792,7 +1793,7 @@
                                      const Twine &Name,
                                      Instruction *InsertBefore)
 : Instruction(VectorType::get(cast<VectorType>(V1->getType())->getElementType(),
-                cast<VectorType>(Mask->getType())->getNumElements()),
+                cast<VectorType>(Mask->getType())->getElementCount()),
               ShuffleVector,
               OperandTraits<ShuffleVectorInst>::op_begin(this),
               OperandTraits<ShuffleVectorInst>::operands(this),
@@ -1809,7 +1810,7 @@
                                      const Twine &Name,
                                      BasicBlock *InsertAtEnd)
 : Instruction(VectorType::get(cast<VectorType>(V1->getType())->getElementType(),
-                cast<VectorType>(Mask->getType())->getNumElements()),
+                cast<VectorType>(Mask->getType())->getElementCount()),
               ShuffleVector,
               OperandTraits<ShuffleVectorInst>::op_begin(this),
               OperandTraits<ShuffleVectorInst>::operands(this),
@@ -2982,8 +2983,8 @@
       }
 
   // Get the bit sizes, we'll need these
-  unsigned SrcBits = SrcTy->getPrimitiveSizeInBits();   // 0 for ptr
-  unsigned DestBits = DestTy->getPrimitiveSizeInBits(); // 0 for ptr
+  auto SrcBits = SrcTy->getPrimitiveSizeInBits();   // 0 for ptr
+  auto DestBits = DestTy->getPrimitiveSizeInBits(); // 0 for ptr
 
   // Run through the possibilities ...
   if (DestTy->isIntegerTy()) {               // Casting to integral
@@ -3030,7 +3031,7 @@
 
   if (VectorType *SrcVecTy = dyn_cast<VectorType>(SrcTy)) {
     if (VectorType *DestVecTy = dyn_cast<VectorType>(DestTy)) {
-      if (SrcVecTy->getNumElements() == DestVecTy->getNumElements()) {
+      if (SrcVecTy->getElementCount() == DestVecTy->getElementCount()) {
         // An element by element cast. Valid if casting the elements is valid.
         SrcTy = SrcVecTy->getElementType();
         DestTy = DestVecTy->getElementType();
@@ -3044,12 +3045,12 @@
     }
   }
 
-  unsigned SrcBits = SrcTy->getPrimitiveSizeInBits();   // 0 for ptr
-  unsigned DestBits = DestTy->getPrimitiveSizeInBits(); // 0 for ptr
+  auto SrcBits = SrcTy->getPrimitiveSizeInBits();   // 0 for ptr
+  auto DestBits = DestTy->getPrimitiveSizeInBits(); // 0 for ptr
 
   // Could still have vectors of pointers if the number of elements doesn't
   // match
-  if (SrcBits == 0 || DestBits == 0)
+  if (SrcBits.getKnownMinSize() == 0 || DestBits.getKnownMinSize() == 0)
     return false;
 
   if (SrcBits != DestBits)
diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
index 8ece7f2..3eab504 100644
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -26,6 +26,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Support/TypeSize.h"
 #include <cassert>
 #include <utility>
 
@@ -111,18 +112,22 @@
   return false;
 }
 
-unsigned Type::getPrimitiveSizeInBits() const {
+TypeSize Type::getPrimitiveSizeInBits() const {
   switch (getTypeID()) {
-  case Type::HalfTyID: return 16;
-  case Type::FloatTyID: return 32;
-  case Type::DoubleTyID: return 64;
-  case Type::X86_FP80TyID: return 80;
-  case Type::FP128TyID: return 128;
-  case Type::PPC_FP128TyID: return 128;
-  case Type::X86_MMXTyID: return 64;
-  case Type::IntegerTyID: return cast<IntegerType>(this)->getBitWidth();
-  case Type::VectorTyID:  return cast<VectorType>(this)->getBitWidth();
-  default: return 0;
+  case Type::HalfTyID: return TypeSize::Fixed(16);
+  case Type::FloatTyID: return TypeSize::Fixed(32);
+  case Type::DoubleTyID: return TypeSize::Fixed(64);
+  case Type::X86_FP80TyID: return TypeSize::Fixed(80);
+  case Type::FP128TyID: return TypeSize::Fixed(128);
+  case Type::PPC_FP128TyID: return TypeSize::Fixed(128);
+  case Type::X86_MMXTyID: return TypeSize::Fixed(64);
+  case Type::IntegerTyID:
+    return TypeSize::Fixed(cast<IntegerType>(this)->getBitWidth());
+  case Type::VectorTyID: {
+    const VectorType *VTy = cast<VectorType>(this);
+    return TypeSize(VTy->getBitWidth(), VTy->isScalable());
+  }
+  default: return TypeSize::Fixed(0);
   }
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c7302b4..ee62b6d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8526,7 +8526,7 @@
       // Get the shift amount based on the scaling factor:
       // log2(sizeof(IdxTy)) - log2(8).
       uint64_t ShiftAmt =
-          countTrailingZeros(DL.getTypeStoreSizeInBits(IdxTy)) - 3;
+        countTrailingZeros(DL.getTypeStoreSizeInBits(IdxTy).getFixedSize()) - 3;
       // Is the constant foldable in the shift of the addressing mode?
       // I.e., shift amount is between 1 and 4 inclusive.
       if (ShiftAmt == 0 || ShiftAmt > 4)
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index c1e935f..4b81683 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -959,14 +959,16 @@
       std::tie(UsedI, I) = Uses.pop_back_val();
 
       if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
-        Size = std::max(Size, DL.getTypeStoreSize(LI->getType()));
+        Size = std::max(Size,
+                        DL.getTypeStoreSize(LI->getType()).getFixedSize());
         continue;
       }
       if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
         Value *Op = SI->getOperand(0);
         if (Op == UsedI)
           return SI;
-        Size = std::max(Size, DL.getTypeStoreSize(Op->getType()));
+        Size = std::max(Size,
+                        DL.getTypeStoreSize(Op->getType()).getFixedSize());
         continue;
       }