[SveEmitter] Fix encoding/decoding of SVETypeFlags

Summary:
The encoding/decoding issue was introduced when reworking D75861. The bug
isn't actually hit by the current unit tests, because the contiguous
loads/stores infer the EltType and the MemEltType from the pointer and
result rather than from the flags. The fix will, however, be needed for
other intrinsics, such as gather/scatter.

Reviewers: SjoerdMeijer, Andrzej

Reviewed By: SjoerdMeijer

Subscribers: andwar, tschuett, cfe-commits, llvm-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D76617
diff --git a/clang/utils/TableGen/SveEmitter.cpp b/clang/utils/TableGen/SveEmitter.cpp
index 831dfa8..dc009e5 100644
--- a/clang/utils/TableGen/SveEmitter.cpp
+++ b/clang/utils/TableGen/SveEmitter.cpp
@@ -65,9 +65,6 @@
     applyModifier(CharMod);
   }
 
-  /// Return the value in SVETypeFlags for this type.
-  unsigned getTypeFlags() const;
-
   bool isPointer() const { return Pointer; }
   bool isVoidPointer() const { return Pointer && Void; }
   bool isSigned() const { return Signed; }
@@ -138,36 +135,22 @@
   /// The architectural #ifdef guard.
   std::string Guard;
 
+  /// The merge suffix, such as _m, _x or _z.
+  std::string MergeSuffix;
+
   /// The types of return value [0] and parameters [1..].
   std::vector<SVEType> Types;
 
   /// The "base type", which is VarType('d', BaseTypeSpec).
   SVEType BaseType;
 
-  unsigned Flags;
+  uint64_t Flags;
 
 public:
-  /// The type of predication.
-  enum MergeType {
-    MergeNone,
-    MergeAny,
-    MergeOp1,
-    MergeZero,
-    MergeAnyExp,
-    MergeZeroExp,
-    MergeInvalid
-  } Merge;
-
-  Intrinsic(StringRef Name, StringRef Proto, int64_t MT, StringRef LLVMName,
-            unsigned Flags, TypeSpec BT, ClassKind Class, SVEEmitter &Emitter,
-            StringRef Guard)
-      : Name(Name.str()), LLVMName(LLVMName), Proto(Proto.str()),
-        BaseTypeSpec(BT), Class(Class), Guard(Guard.str()), BaseType(BT, 'd'),
-        Flags(Flags), Merge(MergeType(MT)) {
-    // Types[0] is the return value.
-    for (unsigned I = 0; I < Proto.size(); ++I)
-      Types.emplace_back(BaseTypeSpec, Proto[I]);
-  }
+  Intrinsic(StringRef Name, StringRef Proto, uint64_t MergeTy,
+            StringRef MergeSuffix, uint64_t MemoryElementTy, StringRef LLVMName,
+            uint64_t Flags, TypeSpec BT, ClassKind Class, SVEEmitter &Emitter,
+            StringRef Guard);
 
   ~Intrinsic()=default;
 
@@ -179,14 +162,13 @@
 
   StringRef getGuard() const { return Guard; }
   ClassKind getClassKind() const { return Class; }
-  MergeType getMergeType() const { return Merge; }
 
   SVEType getReturnType() const { return Types[0]; }
   ArrayRef<SVEType> getTypes() const { return Types; }
   SVEType getParamType(unsigned I) const { return Types[I + 1]; }
   unsigned getNumParams() const { return Proto.size() - 1; }
 
-  unsigned getFlags() const { return Flags; }
+  uint64_t getFlags() const { return Flags; }
   bool isFlagSet(uint64_t Flag) const { return Flags & Flag;}
 
   /// Return the type string for a BUILTIN() macro in Builtins.def.
@@ -209,7 +191,7 @@
   void emitIntrinsic(raw_ostream &OS) const;
 
 private:
-  std::string getMergeSuffix() const;
+  std::string getMergeSuffix() const { return MergeSuffix; }
   std::string mangleName(ClassKind LocalCK) const;
   std::string replaceTemplatedArgs(std::string Name, TypeSpec TS,
                                    std::string Proto) const;
@@ -221,8 +203,8 @@
   llvm::StringMap<uint64_t> EltTypes;
   llvm::StringMap<uint64_t> MemEltTypes;
   llvm::StringMap<uint64_t> FlagTypes;
+  llvm::StringMap<uint64_t> MergeTypes;
 
-  unsigned getTypeFlags(const SVEType &T);
 public:
   SVEEmitter(RecordKeeper &R) : Records(R) {
     for (auto *RV : Records.getAllDerivedDefinitions("EltType"))
@@ -231,8 +213,42 @@
       MemEltTypes[RV->getNameInitAsString()] = RV->getValueAsInt("Value");
     for (auto *RV : Records.getAllDerivedDefinitions("FlagType"))
       FlagTypes[RV->getNameInitAsString()] = RV->getValueAsInt("Value");
+    for (auto *RV : Records.getAllDerivedDefinitions("MergeType"))
+      MergeTypes[RV->getNameInitAsString()] = RV->getValueAsInt("Value");
   }
 
+  // Returns V shifted into the SVETypeFlags field of the mask named MaskName.
+  uint64_t encodeFlag(uint64_t V, StringRef MaskName) const {
+    auto It = FlagTypes.find(MaskName);
+    if (It != FlagTypes.end()) {
+      uint64_t Mask = It->getValue();
+      unsigned Shift = llvm::countTrailingZeros(Mask);
+      return (V << Shift) & Mask;
+    }
+    llvm_unreachable("Unsupported flag");
+  }
+
+  // Returns the SVETypeFlags for the given element type.
+  uint64_t encodeEltType(StringRef EltName) {
+    auto It = EltTypes.find(EltName);
+    if (It != EltTypes.end())
+      return encodeFlag(It->getValue(), "EltTypeMask");
+    llvm_unreachable("Unsupported EltType");
+  }
+
+  // Returns the SVETypeFlags for the given memory element type.
+  uint64_t encodeMemoryElementType(uint64_t MT) {
+    return encodeFlag(MT, "MemEltTypeMask");
+  }
+
+  // Returns the SVETypeFlags for the given merge type.
+  uint64_t encodeMergeType(uint64_t MT) {
+    return encodeFlag(MT, "MergeTypeMask");
+  }
+
+  // Returns the SVETypeFlags value for the given SVEType.
+  uint64_t encodeTypeFlags(const SVEType &T);
+
   /// Emit arm_sve.h.
   void createHeader(raw_ostream &o);
 
@@ -256,36 +272,6 @@
 // Type implementation
 //===----------------------------------------------------------------------===//
 
-unsigned SVEEmitter::getTypeFlags(const SVEType &T) {
-  unsigned FirstEltType = EltTypes["FirstEltType"];
-  if (T.isFloat()) {
-    switch (T.getElementSizeInBits()) {
-    case 16: return FirstEltType + EltTypes["EltTyFloat16"];
-    case 32: return FirstEltType + EltTypes["EltTyFloat32"];
-    case 64: return FirstEltType + EltTypes["EltTyFloat64"];
-    default: llvm_unreachable("Unhandled float element bitwidth!");
-    }
-  }
-
-  if (T.isPredicateVector()) {
-    switch (T.getElementSizeInBits()) {
-    case 8:  return FirstEltType + EltTypes["EltTyBool8"];
-    case 16: return FirstEltType + EltTypes["EltTyBool16"];
-    case 32: return FirstEltType + EltTypes["EltTyBool32"];
-    case 64: return FirstEltType + EltTypes["EltTyBool64"];
-    default: llvm_unreachable("Unhandled predicate element bitwidth!");
-    }
-  }
-
-  switch (T.getElementSizeInBits()) {
-  case 8:  return FirstEltType + EltTypes["EltTyInt8"];
-  case 16: return FirstEltType + EltTypes["EltTyInt16"];
-  case 32: return FirstEltType + EltTypes["EltTyInt32"];
-  case 64: return FirstEltType + EltTypes["EltTyInt64"];
-  default: llvm_unreachable("Unhandled integer element bitwidth!");
-  }
-}
-
 std::string SVEType::builtin_str() const {
   std::string S;
   if (isVoid())
@@ -543,6 +529,26 @@
 // Intrinsic implementation
 //===----------------------------------------------------------------------===//
 
+Intrinsic::Intrinsic(StringRef Name, StringRef Proto, uint64_t MergeTy,
+                     StringRef MergeSuffix, uint64_t MemoryElementTy,
+                     StringRef LLVMName, uint64_t Flags, TypeSpec BT,
+                     ClassKind Class, SVEEmitter &Emitter, StringRef Guard)
+    : Name(Name.str()), LLVMName(LLVMName), Proto(Proto.str()),
+      BaseTypeSpec(BT), Class(Class), Guard(Guard.str()),
+      MergeSuffix(MergeSuffix.str()), BaseType(BT, 'd'), Flags(Flags) {
+
+  // Types[0] is the return value.
+  for (unsigned I = 0; I < Proto.size(); ++I) {
+    SVEType T(BaseTypeSpec, Proto[I]);
+    Types.push_back(T);
+  }
+
+  // Fold the element, memory-element and merge types into the flags.
+  this->Flags |= Emitter.encodeTypeFlags(BaseType);
+  this->Flags |= Emitter.encodeMemoryElementType(MemoryElementTy);
+  this->Flags |= Emitter.encodeMergeType(MergeTy);
+}
+
 std::string Intrinsic::getBuiltinTypeStr() {
   std::string S;
 
@@ -601,20 +607,6 @@
   return Ret;
 }
 
-// ACLE function names have a merge style postfix.
-std::string Intrinsic::getMergeSuffix() const {
-  switch (getMergeType()) {
-    default:
-      llvm_unreachable("Unknown predication specifier");
-    case MergeNone:    return "";
-    case MergeAny:
-    case MergeAnyExp:  return "_x";
-    case MergeOp1:     return "_m";
-    case MergeZero:
-    case MergeZeroExp: return "_z";
-  }
-}
-
 std::string Intrinsic::mangleName(ClassKind LocalCK) const {
   std::string S = getName();
 
@@ -668,6 +660,49 @@
 //===----------------------------------------------------------------------===//
 // SVEEmitter implementation
 //===----------------------------------------------------------------------===//
+uint64_t SVEEmitter::encodeTypeFlags(const SVEType &T) {
+  if (T.isFloat()) {
+    switch (T.getElementSizeInBits()) {
+    case 16:
+      return encodeEltType("EltTyFloat16");
+    case 32:
+      return encodeEltType("EltTyFloat32");
+    case 64:
+      return encodeEltType("EltTyFloat64");
+    default:
+      llvm_unreachable("Unhandled float element bitwidth!");
+    }
+  }
+
+  if (T.isPredicateVector()) {
+    switch (T.getElementSizeInBits()) {
+    case 8:
+      return encodeEltType("EltTyBool8");
+    case 16:
+      return encodeEltType("EltTyBool16");
+    case 32:
+      return encodeEltType("EltTyBool32");
+    case 64:
+      return encodeEltType("EltTyBool64");
+    default:
+      llvm_unreachable("Unhandled predicate element bitwidth!");
+    }
+  }
+
+  switch (T.getElementSizeInBits()) {
+  case 8:
+    return encodeEltType("EltTyInt8");
+  case 16:
+    return encodeEltType("EltTyInt16");
+  case 32:
+    return encodeEltType("EltTyInt32");
+  case 64:
+    return encodeEltType("EltTyInt64");
+  default:
+    llvm_unreachable("Unhandled integer element bitwidth!");
+  }
+}
+
 void SVEEmitter::createIntrinsic(
     Record *R, SmallVectorImpl<std::unique_ptr<Intrinsic>> &Out) {
   StringRef Name = R->getValueAsString("Name");
@@ -675,13 +710,14 @@
   StringRef Types = R->getValueAsString("Types");
   StringRef Guard = R->getValueAsString("ArchGuard");
   StringRef LLVMName = R->getValueAsString("LLVMIntrinsic");
-  int64_t Merge = R->getValueAsInt("Merge");
+  uint64_t Merge = R->getValueAsInt("Merge");
+  StringRef MergeSuffix = R->getValueAsString("MergeSuffix");
+  uint64_t MemEltType = R->getValueAsInt("MemEltType");
   std::vector<Record*> FlagsList = R->getValueAsListOfDefs("Flags");
 
   int64_t Flags = 0;
   for (auto FlagRec : FlagsList)
     Flags |= FlagRec->getValueAsInt("Value");
-  Flags |= R->getValueAsInt("MemEltType") + MemEltTypes["FirstMemEltType"];
 
   // Extract type specs from string
   SmallVector<TypeSpec, 8> TypeSpecs;
@@ -701,14 +737,15 @@
 
   // Create an Intrinsic for each type spec.
   for (auto TS : TypeSpecs) {
-    Out.push_back(std::make_unique<Intrinsic>(Name, Proto, Merge,
-                                              LLVMName, Flags, TS, ClassS,
-                                              *this, Guard));
+    Out.push_back(std::make_unique<Intrinsic>(Name, Proto, Merge, MergeSuffix,
+                                              MemEltType, LLVMName, Flags, TS,
+                                              ClassS, *this, Guard));
 
     // Also generate the short-form (e.g. svadd_m) for the given type-spec.
     if (Intrinsic::isOverloadedIntrinsic(Name))
-      Out.push_back(std::make_unique<Intrinsic>(
-          Name, Proto, Merge, LLVMName, Flags, TS, ClassG, *this, Guard));
+      Out.push_back(std::make_unique<Intrinsic>(Name, Proto, Merge, MergeSuffix,
+                                                MemEltType, LLVMName, Flags, TS,
+                                                ClassG, *this, Guard));
   }
 }
 
@@ -846,7 +883,7 @@
     if (Def->getClassKind() == ClassG)
       continue;
 
-    uint64_t Flags = Def->getFlags() | getTypeFlags(Def->getBaseType());
+    uint64_t Flags = Def->getFlags();
     auto FlagString = std::to_string(Flags);
 
     std::string LLVMName = Def->getLLVMName();
@@ -876,6 +913,11 @@
   for (auto &KV : MemEltTypes)
     OS << "  " << KV.getKey() << " = " << KV.getValue() << ",\n";
   OS << "#endif\n\n";
+
+  OS << "#ifdef LLVM_GET_SVE_MERGETYPES\n";
+  for (auto &KV : MergeTypes)
+    OS << "  " << KV.getKey() << " = " << KV.getValue() << ",\n";
+  OS << "#endif\n\n";
 }
 
 namespace clang {