[AArch64][SVE] Asm: Add restricted register classes for SVE predicate vectors.

Summary:
Add a register class for SVE predicate operands that can only be p0-p7 (as opposed to p0-p15)

Patch [1/3] in a series to add predicated ADD/SUB instructions for SVE.

Reviewers: rengolin, mcrosier, evandro, fhahn, echristo, olista01, SjoerdMeijer, javed.absar

Reviewed By: fhahn

Subscribers: aemerson, javed.absar, tschuett, kristof.beyls, llvm-commits

Differential Revision: https://reviews.llvm.org/D41441

llvm-svn: 321699
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index 39e3e33..9023c3d 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -756,27 +756,31 @@
 
 //******************************************************************************
 
-// SVE predicate register class.
-def PPR : RegisterClass<"AArch64",
-                        [nxv16i1, nxv8i1, nxv4i1, nxv2i1],
-                        16, (sequence "P%u", 0, 15)> {
+// SVE predicate register classes.
+class PPRClass<int lastreg> : RegisterClass<
+                                  "AArch64",
+                                  [ nxv16i1, nxv8i1, nxv4i1, nxv2i1 ], 16,
+                                  (sequence "P%u", 0, lastreg)> {
   let Size = 16;
 }
 
-class PPRAsmOperand <string name, int Width>: AsmOperandClass {
+def PPR    : PPRClass<15>;
+def PPR_3b : PPRClass<7>; // Restricted 3 bit SVE predicate register class.
+
+class PPRAsmOperand <string name, string RegClass, int Width>: AsmOperandClass {
   let Name = "SVE" # name # "Reg";
   let PredicateMethod = "isSVEVectorRegOfWidth<"
-                            # Width # ", AArch64::PPRRegClassID>";
+                            # Width # ", " # "AArch64::" # RegClass # "RegClassID>";
   let DiagnosticType = "InvalidSVE" # name # "Reg";
   let RenderMethod = "addRegOperands";
   let ParserMethod = "tryParseSVEPredicateVector";
 }
 
-def PPRAsmOpAny : PPRAsmOperand<"PredicateAny", -1>;
-def PPRAsmOp8   : PPRAsmOperand<"PredicateB",  8>;
-def PPRAsmOp16  : PPRAsmOperand<"PredicateH", 16>;
-def PPRAsmOp32  : PPRAsmOperand<"PredicateS", 32>;
-def PPRAsmOp64  : PPRAsmOperand<"PredicateD", 64>;
+def PPRAsmOpAny : PPRAsmOperand<"PredicateAny", "PPR", -1>;
+def PPRAsmOp8   : PPRAsmOperand<"PredicateB",   "PPR",  8>;
+def PPRAsmOp16  : PPRAsmOperand<"PredicateH",   "PPR", 16>;
+def PPRAsmOp32  : PPRAsmOperand<"PredicateS",   "PPR", 32>;
+def PPRAsmOp64  : PPRAsmOperand<"PredicateD",   "PPR", 64>;
 
 def PPRAny : PPRRegOp<"",  PPRAsmOpAny, PPR>;
 def PPR8   : PPRRegOp<"b", PPRAsmOp8,   PPR>;
@@ -784,6 +788,18 @@
 def PPR32  : PPRRegOp<"s", PPRAsmOp32,  PPR>;
 def PPR64  : PPRRegOp<"d", PPRAsmOp64,  PPR>;
 
+def PPRAsmOp3bAny : PPRAsmOperand<"Predicate3bAny", "PPR_3b", -1>;
+def PPRAsmOp3b8   : PPRAsmOperand<"Predicate3bB",   "PPR_3b",  8>;
+def PPRAsmOp3b16  : PPRAsmOperand<"Predicate3bH",   "PPR_3b", 16>;
+def PPRAsmOp3b32  : PPRAsmOperand<"Predicate3bS",   "PPR_3b", 32>;
+def PPRAsmOp3b64  : PPRAsmOperand<"Predicate3bD",   "PPR_3b", 64>;
+
+def PPR3bAny : PPRRegOp<"",  PPRAsmOp3bAny, PPR_3b>;
+def PPR3b8   : PPRRegOp<"b", PPRAsmOp3b8,   PPR_3b>;
+def PPR3b16  : PPRRegOp<"h", PPRAsmOp3b16,  PPR_3b>;
+def PPR3b32  : PPRRegOp<"s", PPRAsmOp3b32,  PPR_3b>;
+def PPR3b64  : PPRRegOp<"d", PPRAsmOp3b64,  PPR_3b>;
+
 //******************************************************************************
 
 // SVE vector register class
diff --git a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
index a480fa3..ac9ff51 100644
--- a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
+++ b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
@@ -843,6 +843,7 @@
       RK = RegKind::SVEDataVector;
       break;
     case AArch64::PPRRegClassID:
+    case AArch64::PPR_3bRegClassID:
       RK = RegKind::SVEPredicateVector;
       break;
     default:
@@ -3652,6 +3653,12 @@
   case Match_InvalidSVEPredicateSReg:
   case Match_InvalidSVEPredicateDReg:
     return Error(Loc, "invalid predicate register.");
+  case Match_InvalidSVEPredicate3bAnyReg:
+  case Match_InvalidSVEPredicate3bBReg:
+  case Match_InvalidSVEPredicate3bHReg:
+  case Match_InvalidSVEPredicate3bSReg:
+  case Match_InvalidSVEPredicate3bDReg:
+    return Error(Loc, "restricted predicate has range [0, 7].");
   default:
     llvm_unreachable("unexpected error code!");
   }
@@ -4081,6 +4088,11 @@
   case Match_InvalidSVEPredicateHReg:
   case Match_InvalidSVEPredicateSReg:
   case Match_InvalidSVEPredicateDReg:
+  case Match_InvalidSVEPredicate3bAnyReg:
+  case Match_InvalidSVEPredicate3bBReg:
+  case Match_InvalidSVEPredicate3bHReg:
+  case Match_InvalidSVEPredicate3bSReg:
+  case Match_InvalidSVEPredicate3bDReg:
   case Match_MSR:
   case Match_MRS: {
     if (ErrorInfo >= Operands.size())
diff --git a/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp b/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp
index ae278ca..30438a1 100644
--- a/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp
+++ b/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp
@@ -91,6 +91,9 @@
 static DecodeStatus DecodePPRRegisterClass(MCInst &Inst, unsigned RegNo,
                                            uint64_t Address,
                                            const void *Decode);
+LLVM_ATTRIBUTE_UNUSED static DecodeStatus
+DecodePPR_3bRegisterClass(llvm::MCInst &Inst, unsigned RegNo, uint64_t Address,
+                          const void *Decode);
 
 static DecodeStatus DecodeFixedPointScaleImm32(MCInst &Inst, unsigned Imm,
                                                uint64_t Address,
@@ -481,6 +484,16 @@
   return Success;
 }
 
+static DecodeStatus DecodePPR_3bRegisterClass(MCInst &Inst, unsigned RegNo,
+                                              uint64_t Addr,
+                                              const void* Decoder) {
+  if (RegNo > 7)
+    return Fail;
+
+  // Just reuse the PPR decode table
+  return DecodePPRRegisterClass(Inst, RegNo, Addr, Decoder);
+}
+
 static const unsigned VectorDecoderTable[] = {
     AArch64::Q0,  AArch64::Q1,  AArch64::Q2,  AArch64::Q3,  AArch64::Q4,
     AArch64::Q5,  AArch64::Q6,  AArch64::Q7,  AArch64::Q8,  AArch64::Q9,