[Sema][AArch64] Support arm_sve_vector_bits attribute

This patch implements the semantics for the 'arm_sve_vector_bits' type
attribute, defined by the Arm C Language Extensions (ACLE) for SVE [1].
The purpose of this attribute is to define vector-length-specific (VLS)
versions of existing vector-length-agnostic (VLA) types.

The semantics were already implemented by D83551, although the
implementation approach has since changed to represent VLSTs as
VectorType in the AST and fixed-length vectors in the IR everywhere
except in function args/returns. This is described in the prototype
patch D85128 demonstrating the new approach.

The semantic changes added in D83551 need revising now that the
AttributedType is replaced by VectorType in the AST. Only minimal changes
were necessary in the previous patch, as the canonical type for both VLA
and VLS was the same (i.e. sizeless), except in constructs such as
globals and structs where sizeless types are unsupported. This patch
reverts the changes that permitted VLS types represented as sizeless
types in such circumstances, and adds support for implicit conversions
between VLA and VLS types as described in section 3.7.3.2 of the ACLE.

Since fixed-length versions of the SVE builtin types for bool and uint8
both use BuiltinType::UChar as their element type, two new vector kinds
are implemented to distinguish predicate vectors from data vectors.

[1] https://developer.arm.com/documentation/100987/latest

Reviewed By: aaron.ballman

Differential Revision: https://reviews.llvm.org/D85736
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 172a901..2b411cd 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -1871,50 +1871,6 @@
   return TI;
 }
 
-static unsigned getSveVectorWidth(const Type *T) {
-  // Get the vector size from the 'arm_sve_vector_bits' attribute via the
-  // AttributedTypeLoc associated with the typedef decl.
-  if (const auto *TT = T->getAs<TypedefType>()) {
-    const TypedefNameDecl *Typedef = TT->getDecl();
-    TypeSourceInfo *TInfo = Typedef->getTypeSourceInfo();
-    TypeLoc TL = TInfo->getTypeLoc();
-    if (AttributedTypeLoc ATL = TL.getAs<AttributedTypeLoc>())
-      if (const auto *Attr = ATL.getAttrAs<ArmSveVectorBitsAttr>())
-        return Attr->getNumBits();
-  }
-
-  llvm_unreachable("bad 'arm_sve_vector_bits' attribute!");
-}
-
-static unsigned getSvePredWidth(const ASTContext &Context, const Type *T) {
-  return getSveVectorWidth(T) / Context.getCharWidth();
-}
-
-unsigned ASTContext::getBitwidthForAttributedSveType(const Type *T) const {
-  assert(T->isVLST() &&
-         "getBitwidthForAttributedSveType called for non-attributed type!");
-
-  switch (T->castAs<BuiltinType>()->getKind()) {
-  default:
-    llvm_unreachable("unknown builtin type!");
-  case BuiltinType::SveInt8:
-  case BuiltinType::SveInt16:
-  case BuiltinType::SveInt32:
-  case BuiltinType::SveInt64:
-  case BuiltinType::SveUint8:
-  case BuiltinType::SveUint16:
-  case BuiltinType::SveUint32:
-  case BuiltinType::SveUint64:
-  case BuiltinType::SveFloat16:
-  case BuiltinType::SveFloat32:
-  case BuiltinType::SveFloat64:
-  case BuiltinType::SveBFloat16:
-    return getSveVectorWidth(T);
-  case BuiltinType::SveBool:
-    return getSvePredWidth(*this, T);
-  }
-}
-
 /// getTypeInfoImpl - Return the size of the specified type, in bits.  This
 /// method does not work on incomplete types.
 ///
@@ -1981,6 +1937,13 @@
     uint64_t TargetVectorAlign = Target->getMaxVectorAlign();
     if (TargetVectorAlign && TargetVectorAlign < Align)
       Align = TargetVectorAlign;
+    if (VT->getVectorKind() == VectorType::SveFixedLengthDataVector)
+      // Adjust the alignment for fixed-length SVE vectors. This is important
+      // for non-power-of-2 vector lengths.
+      Align = 128;
+    else if (VT->getVectorKind() == VectorType::SveFixedLengthPredicateVector)
+      // Adjust the alignment for fixed-length SVE predicates.
+      Align = 16;
     break;
   }
 
@@ -2319,10 +2282,7 @@
       Align = Info.Align;
       AlignIsRequired = Info.AlignIsRequired;
     }
-    if (T->isVLST())
-      Width = getBitwidthForAttributedSveType(T);
-    else
-      Width = Info.Width;
+    Width = Info.Width;
     break;
   }
 
@@ -8540,6 +8500,38 @@
   return false;
 }
 
+/// Return true if one of the given types is a sizeless SVE builtin type and
+/// the other is its fixed-length (VLS) counterpart, in either order, so that
+/// the two may be implicitly converted to one another.
+bool ASTContext::areCompatibleSveTypes(QualType FirstType,
+                                       QualType SecondType) {
+  assert(((FirstType->isSizelessBuiltinType() && SecondType->isVectorType()) ||
+          (FirstType->isVectorType() && SecondType->isSizelessBuiltinType())) &&
+         "Expected SVE builtin type and vector type!");
+
+  auto IsValidCast = [this](QualType FirstType, QualType SecondType) {
+    if (const auto *BT = FirstType->getAs<BuiltinType>()) {
+      if (const auto *VT = SecondType->getAs<VectorType>()) {
+        // Predicates have the same representation as uint8 so we also have to
+        // check the kind to make these types incompatible.
+        if (VT->getVectorKind() == VectorType::SveFixedLengthPredicateVector)
+          return BT->getKind() == BuiltinType::SveBool;
+        // Only VLST builtin types have a fixed-length counterpart; reject any
+        // other sizeless builtin (e.g. SVE tuple types) here, since
+        // getSveEltType() asserts isVLSTBuiltinType().
+        if (VT->getVectorKind() == VectorType::SveFixedLengthDataVector)
+          return FirstType->isVLSTBuiltinType() &&
+                 VT->getElementType().getCanonicalType() ==
+                     FirstType->getSveEltType(*this);
+      }
+    }
+    return false;
+  };
+
+  return IsValidCast(FirstType, SecondType) ||
+         IsValidCast(SecondType, FirstType);
+}
+
 bool ASTContext::hasDirectOwnershipQualifier(QualType Ty) const {
   while (true) {
     // __strong id
diff --git a/clang/lib/AST/JSONNodeDumper.cpp b/clang/lib/AST/JSONNodeDumper.cpp
index 4bd00ec..a9136a9 100644
--- a/clang/lib/AST/JSONNodeDumper.cpp
+++ b/clang/lib/AST/JSONNodeDumper.cpp
@@ -616,6 +616,12 @@
   case VectorType::NeonPolyVector:
     JOS.attribute("vectorKind", "neon poly");
     break;
+  case VectorType::SveFixedLengthDataVector:
+    JOS.attribute("vectorKind", "fixed-length sve data vector");
+    break;
+  case VectorType::SveFixedLengthPredicateVector:
+    JOS.attribute("vectorKind", "fixed-length sve predicate vector");
+    break;
   }
 }
 
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index 47a7e43..16c4c37 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -1408,6 +1408,12 @@
   case VectorType::NeonPolyVector:
     OS << " neon poly";
     break;
+  case VectorType::SveFixedLengthDataVector:
+    OS << " fixed-length sve data vector";
+    break;
+  case VectorType::SveFixedLengthPredicateVector:
+    OS << " fixed-length sve predicate vector";
+    break;
   }
   OS << " " << T->getNumElements();
 }
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 71e6db0..801f89a 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -2313,11 +2313,46 @@
   return false;
 }
 
-bool Type::isVLST() const {
-  if (!isVLSTBuiltinType())
-    return false;
+/// Return the element type used for the fixed-length (VLS) equivalent of
+/// this sizeless SVE builtin type, e.g. 'int' for svint32_t.
+QualType Type::getSveEltType(const ASTContext &Ctx) const {
+  assert(isVLSTBuiltinType() && "unsupported type!");
 
-  return hasAttr(attr::ArmSveVectorBits);
+  const BuiltinType *BTy = castAs<BuiltinType>();
+  switch (BTy->getKind()) {
+  default:
+    llvm_unreachable("Unknown builtin SVE type!");
+  case BuiltinType::SveInt8:
+    return Ctx.SignedCharTy;
+  case BuiltinType::SveUint8:
+  case BuiltinType::SveBool:
+    // Represent predicates as i8 rather than i1 to avoid any layout issues.
+    // The type is bitcasted to a scalable predicate type when casting between
+    // scalable and fixed-length vectors.
+    return Ctx.UnsignedCharTy;
+  case BuiltinType::SveInt16:
+    return Ctx.ShortTy;
+  case BuiltinType::SveUint16:
+    return Ctx.UnsignedShortTy;
+  case BuiltinType::SveInt32:
+    return Ctx.IntTy;
+  case BuiltinType::SveUint32:
+    return Ctx.UnsignedIntTy;
+  case BuiltinType::SveInt64:
+    // 'long' is only 64 bits wide on LP64 targets, so ask the target for its
+    // 64-bit integer type to keep elements correctly sized on ILP32 and LLP64.
+    return Ctx.getIntTypeForBitwidth(64, /*Signed=*/true);
+  case BuiltinType::SveUint64:
+    return Ctx.getIntTypeForBitwidth(64, /*Signed=*/false);
+  case BuiltinType::SveFloat16:
+    return Ctx.Float16Ty;
+  case BuiltinType::SveBFloat16:
+    return Ctx.BFloat16Ty;
+  case BuiltinType::SveFloat32:
+    return Ctx.FloatTy;
+  case BuiltinType::SveFloat64:
+    return Ctx.DoubleTy;
+  }
 }
 
 bool QualType::isPODType(const ASTContext &Context) const {
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index 7286a88..5e9b226 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -655,6 +655,24 @@
     printBefore(T->getElementType(), OS);
     break;
   }
+  case VectorType::SveFixedLengthDataVector:
+  case VectorType::SveFixedLengthPredicateVector:
+    // FIXME: We prefer to print the size directly here, but have no way
+    // to get the size of the type.
+    OS << "__attribute__((__arm_sve_vector_bits__(";
+
+    if (T->getVectorKind() == VectorType::SveFixedLengthPredicateVector)
+      // Predicates take a bit per byte of the vector size, multiply by 8 to
+      // get the number of bits passed to the attribute.
+      OS << T->getNumElements() * 8;
+    else
+      OS << T->getNumElements();
+
+    OS << " * sizeof(";
+    print(T->getElementType(), OS, StringRef());
+    // Multiply by 8 for the number of bits.
+    OS << ") * 8))) ";
+    printBefore(T->getElementType(), OS);
   }
 }
 
@@ -702,6 +720,24 @@
     printBefore(T->getElementType(), OS);
     break;
   }
+  case VectorType::SveFixedLengthDataVector:
+  case VectorType::SveFixedLengthPredicateVector:
+    // FIXME: We prefer to print the size directly here, but have no way
+    // to get the size of the type.
+    OS << "__attribute__((__arm_sve_vector_bits__(";
+    if (T->getSizeExpr()) {
+      T->getSizeExpr()->printPretty(OS, nullptr, Policy);
+      if (T->getVectorKind() == VectorType::SveFixedLengthPredicateVector)
+        // Predicates take a bit per byte of the vector size, multiply by 8 to
+        // get the number of bits passed to the attribute.
+        OS << " * 8";
+      OS << " * sizeof(";
+      print(T->getElementType(), OS, StringRef());
+      // Multiply by 8 for the number of bits.
+      OS << ") * 8";
+    }
+    OS << "))) ";
+    printBefore(T->getElementType(), OS);
   }
 }
 
@@ -1634,9 +1670,6 @@
   case attr::ArmMveStrictPolymorphism:
     OS << "__clang_arm_mve_strict_polymorphism";
     break;
-  case attr::ArmSveVectorBits:
-    OS << "arm_sve_vector_bits";
-    break;
   }
   OS << "))";
 }
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 566a2f9..fba590f 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -8032,7 +8032,7 @@
     return;
   }
 
-  if (!NewVD->hasLocalStorage() && T->isSizelessType() && !T->isVLST()) {
+  if (!NewVD->hasLocalStorage() && T->isSizelessType()) {
     Diag(NewVD->getLocation(), diag::err_sizeless_nonlocal) << T;
     NewVD->setInvalidDecl();
     return;
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 0df54c0..4501857 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -9009,6 +9009,14 @@
       }
     }
 
+    // Allow assignments between fixed-length and sizeless SVE vectors.
+    if (((LHSType->isSizelessBuiltinType() && RHSType->isVectorType()) ||
+         (LHSType->isVectorType() && RHSType->isSizelessBuiltinType())) &&
+        Context.areCompatibleSveTypes(LHSType, RHSType)) {
+      Kind = CK_BitCast;
+      return Compatible;
+    }
+
     return Incompatible;
   }
 
@@ -9899,6 +9907,22 @@
 
   // Okay, the expression is invalid.
 
+  // Returns true if the operands are SVE VLA and VLS types.
+  auto IsSveConversion = [](QualType FirstType, QualType SecondType) {
+    const VectorType *VecType = SecondType->getAs<VectorType>();
+    return FirstType->isSizelessBuiltinType() && VecType &&
+           (VecType->getVectorKind() == VectorType::SveFixedLengthDataVector ||
+            VecType->getVectorKind() ==
+                VectorType::SveFixedLengthPredicateVector);
+  };
+
+  // If there's a sizeless and fixed-length operand, diagnose that.
+  if (IsSveConversion(LHSType, RHSType) || IsSveConversion(RHSType, LHSType)) {
+    Diag(Loc, diag::err_typecheck_vector_not_convertable_sizeless)
+        << LHSType << RHSType;
+    return QualType();
+  }
+
   // If there's a non-vector, non-real operand, diagnose that.
   if ((!RHSVecType && !RHSType->isRealType()) ||
       (!LHSVecType && !LHSType->isRealType())) {
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index 132f5b0..d1fcdf3 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -4323,6 +4323,12 @@
                              VK_RValue, /*BasePath=*/nullptr, CCK).get();
     break;
 
+  case ICK_SVE_Vector_Conversion:
+    From = ImpCastExprToType(From, ToType, CK_BitCast, VK_RValue,
+                             /*BasePath=*/nullptr, CCK)
+               .get();
+    break;
+
   case ICK_Vector_Splat: {
     // Vector splat from any arithmetic type to a vector.
     Expr *Elem = prepareVectorSplat(ToType, From).get();
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index dc00989..ec7c41e 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -137,6 +137,7 @@
     ICR_Conversion,
     ICR_Conversion,
     ICR_Conversion,
+    ICR_Conversion,
     ICR_OCL_Scalar_Widening,
     ICR_Complex_Real_Conversion,
     ICR_Conversion,
@@ -174,6 +175,7 @@
     "Compatible-types conversion",
     "Derived-to-base conversion",
     "Vector conversion",
+    "SVE Vector conversion",
     "Vector splat",
     "Complex-real conversion",
     "Block Pointer conversion",
@@ -1650,6 +1652,12 @@
     }
   }
 
+  if ((ToType->isSizelessBuiltinType() || FromType->isSizelessBuiltinType()) &&
+      S.Context.areCompatibleSveTypes(FromType, ToType)) {
+    ICK = ICK_SVE_Vector_Conversion;
+    return true;
+  }
+
   // We can perform the conversion between vector types in the following cases:
   // 1)vector types are equivalent AltiVec and GCC vector types
   // 2)lax vector conversions are permitted and the vector types are of the
@@ -4104,6 +4112,20 @@
                  : ImplicitConversionSequence::Worse;
   }
 
+  if (SCS1.Second == ICK_SVE_Vector_Conversion &&
+      SCS2.Second == ICK_SVE_Vector_Conversion) {
+    bool SCS1IsCompatibleSVEVectorConversion =
+        S.Context.areCompatibleSveTypes(SCS1.getFromType(), SCS1.getToType(2));
+    bool SCS2IsCompatibleSVEVectorConversion =
+        S.Context.areCompatibleSveTypes(SCS2.getFromType(), SCS2.getToType(2));
+
+    if (SCS1IsCompatibleSVEVectorConversion !=
+        SCS2IsCompatibleSVEVectorConversion)
+      return SCS1IsCompatibleSVEVectorConversion
+                 ? ImplicitConversionSequence::Better
+                 : ImplicitConversionSequence::Worse;
+  }
+
   return ImplicitConversionSequence::Indistinguishable;
 }
 
@@ -5524,6 +5546,7 @@
   case ICK_Compatible_Conversion:
   case ICK_Derived_To_Base:
   case ICK_Vector_Conversion:
+  case ICK_SVE_Vector_Conversion:
   case ICK_Vector_Splat:
   case ICK_Complex_Real:
   case ICK_Block_Pointer_Conversion:
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index c08d442..03442fb 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2342,7 +2342,7 @@
       return QualType();
   }
 
-  if (T->isSizelessType() && !T->isVLST()) {
+  if (T->isSizelessType()) {
     Diag(Loc, diag::err_array_incomplete_or_sizeless_type) << 1 << T;
     return QualType();
   }
@@ -7810,14 +7810,10 @@
 /// HandleArmSveVectorBitsTypeAttr - The "arm_sve_vector_bits" attribute is
 /// used to create fixed-length versions of sizeless SVE types defined by
 /// the ACLE, such as svint32_t and svbool_t.
-static void HandleArmSveVectorBitsTypeAttr(TypeProcessingState &State,
-                                           QualType &CurType,
-                                           ParsedAttr &Attr) {
-  Sema &S = State.getSema();
-  ASTContext &Ctx = S.Context;
-
+static void HandleArmSveVectorBitsTypeAttr(QualType &CurType, ParsedAttr &Attr,
+                                           Sema &S) {
   // Target must have SVE.
-  if (!Ctx.getTargetInfo().hasFeature("sve")) {
+  if (!S.Context.getTargetInfo().hasFeature("sve")) {
     S.Diag(Attr.getLoc(), diag::err_attribute_unsupported) << Attr;
     Attr.setInvalid();
     return;
@@ -7862,8 +7858,18 @@
     return;
   }
 
-  auto *A = ::new (Ctx) ArmSveVectorBitsAttr(Ctx, Attr, VecSize);
-  CurType = State.getAttributedType(A, CurType, CurType);
+  const auto *BT = CurType->castAs<BuiltinType>();
+
+  QualType EltType = CurType->getSveEltType(S.Context);
+  unsigned TypeSize = S.Context.getTypeSize(EltType);
+  VectorType::VectorKind VecKind = VectorType::SveFixedLengthDataVector;
+  if (BT->getKind() == BuiltinType::SveBool) {
+    // Predicates are represented as i8.
+    VecSize /= S.Context.getCharWidth() * S.Context.getCharWidth();
+    VecKind = VectorType::SveFixedLengthPredicateVector;
+  } else
+    VecSize /= TypeSize;
+  CurType = S.Context.getVectorType(EltType, VecSize, VecKind);
 }
 
 static void HandleArmMveStrictPolymorphismAttr(TypeProcessingState &State,
@@ -8134,7 +8140,7 @@
       attr.setUsedAsTypeAttr();
       break;
     case ParsedAttr::AT_ArmSveVectorBits:
-      HandleArmSveVectorBitsTypeAttr(state, type, attr);
+      HandleArmSveVectorBitsTypeAttr(type, attr, state.getSema());
       attr.setUsedAsTypeAttr();
       break;
     case ParsedAttr::AT_ArmMveStrictPolymorphism: {