[AArch64] Armv8.2-A: add the crypto extensions

This adds MC support for the crypto instructions that were made optional
extensions in Armv8.2-A (AArch64 only).

Differential Revision: https://reviews.llvm.org/D49370

llvm-svn: 338010
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 1ba2f38..1060c64 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -5720,7 +5720,7 @@
   def v16i8  : BaseSIMDDifferentThreeVector<U, 0b001, opc,
                                             V128, V128, V128,
                                             asm#"2", ".8h", ".16b", ".16b", []>;
-  let Predicates = [HasCrypto] in {
+  let Predicates = [HasAES] in {
     def v1i64  : BaseSIMDDifferentThreeVector<U, 0b110, opc,
                                               V128, V64, V64,
                                               asm, ".1q", ".1d", ".1d", []>;
@@ -9920,7 +9920,6 @@
 // Crypto extensions
 //----------------------------------------------------------------------------
 
-let Predicates = [HasCrypto] in {
 let mayLoad = 0, mayStore = 0, hasSideEffects = 0 in
 class AESBase<bits<4> opc, string asm, dag outs, dag ins, string cstr,
               list<dag> pat>
@@ -10010,7 +10009,103 @@
 class SHAInstSS<bits<4> opc, string asm, Intrinsic OpNode>
   : SHA2OpInst<opc, asm, "", "", (outs FPR32:$Rd), (ins FPR32:$Rn),
                [(set (i32 FPR32:$Rd), (OpNode (i32 FPR32:$Rn)))]>;
-} // end of 'let Predicates = [HasCrypto]'
+
+// Armv8.2-A Crypto extensions
+// Base format for the Armv8.2-A optional crypto instructions: fixes the
+// opcode bits common to all of them (Inst{31-25} = 0b1100111) and the Vn
+// (first source) / Vd (destination) register fields. Subclasses encode
+// bits 24-10 and any extra operands.
+class BaseCryptoV82<dag oops, dag iops, string asm, string asmops, string cst,
+                    list<dag> pattern>
+  : I <oops, iops, asm, asmops, cst, pattern>, Sched<[WriteV]> {
+  bits<5> Vd;
+  bits<5> Vn;
+  let Inst{31-25} = 0b1100111;
+  let Inst{9-5}   = Vn;
+  let Inst{4-0}   = Vd;
+}
+
+// Two-register crypto ops whose second input register is tied to the
+// destination ("$Vm = $Vd"), e.g. sha512su0 and sm4e below.
+class CryptoRRTied<bits<1>op0, bits<2>op1, string asm, string asmops>
+  : BaseCryptoV82<(outs V128:$Vd), (ins V128:$Vn, V128:$Vm), asm, asmops,
+                  "$Vm = $Vd", []> {
+  // NOTE(review): Inst{31-25} is already set by BaseCryptoV82; restating it
+  // here is redundant (but harmless).
+  let Inst{31-25} = 0b1100111;
+  let Inst{24-21} = 0b0110;
+  let Inst{20-15} = 0b000001;
+  let Inst{14}    = op0;
+  let Inst{13-12} = 0b00;
+  let Inst{11-10} = op1;
+}
+// Element-size flavours of the tied two-register form (.2d / .4s).
+class CryptoRRTied_2D<bits<1>op0, bits<2>op1, string asm>
+  : CryptoRRTied<op0, op1, asm, "{\t$Vd.2d, $Vn.2d}">;
+class CryptoRRTied_4S<bits<1>op0, bits<2>op1, string asm>
+  : CryptoRRTied<op0, op1, asm, "{\t$Vd.4s, $Vn.4s}">;
+
+// Three-register crypto ops (Vd, Vn, Vm). The "Tied" variants model a
+// read-modify-write destination via a separate $Vdst result operand
+// constrained to $Vd.
+class CryptoRRR<bits<1> op0, bits<2>op1, dag oops, dag iops, string asm,
+                string asmops, string cst>
+  : BaseCryptoV82<oops, iops, asm , asmops, cst, []> {
+  bits<5> Vm;
+  let Inst{24-21} = 0b0011;
+  let Inst{20-16} = Vm;
+  let Inst{15}    = 0b1;
+  let Inst{14}    = op0;
+  let Inst{13-12} = 0b00;
+  let Inst{11-10} = op1;
+}
+// Untied, .2d elements (e.g. rax1).
+class CryptoRRR_2D<bits<1> op0, bits<2>op1, string asm>
+  : CryptoRRR<op0, op1, (outs V128:$Vd), (ins V128:$Vn, V128:$Vm), asm,
+              "{\t$Vd.2d, $Vn.2d, $Vm.2d}", "">;
+// Destination-tied, .2d elements (e.g. sha512su1).
+class CryptoRRRTied_2D<bits<1> op0, bits<2>op1, string asm>
+  : CryptoRRR<op0, op1, (outs V128:$Vdst), (ins V128:$Vd, V128:$Vn, V128:$Vm), asm,
+              "{\t$Vd.2d, $Vn.2d, $Vm.2d}", "$Vd = $Vdst">;
+// Untied, .4s elements (e.g. sm4ekey).
+class CryptoRRR_4S<bits<1> op0, bits<2>op1, string asm>
+  : CryptoRRR<op0, op1, (outs V128:$Vd), (ins V128:$Vn, V128:$Vm), asm,
+              "{\t$Vd.4s, $Vn.4s, $Vm.4s}", "">;
+// Destination-tied, .4s elements (e.g. sm3partw1/sm3partw2).
+class CryptoRRRTied_4S<bits<1> op0, bits<2>op1, string asm>
+  : CryptoRRR<op0, op1, (outs V128:$Vdst), (ins V128:$Vd, V128:$Vn, V128:$Vm), asm,
+              "{\t$Vd.4s, $Vn.4s, $Vm.4s}", "$Vd = $Vdst">;
+// Destination-tied, scalar Qd/Qn with a .2d Vm (sha512h/sha512h2 syntax).
+class CryptoRRRTied<bits<1> op0, bits<2>op1, string asm>
+  : CryptoRRR<op0, op1, (outs FPR128:$Vdst), (ins FPR128:$Vd, FPR128:$Vn, V128:$Vm),
+              asm, "{\t$Vd, $Vn, $Vm.2d}", "$Vd = $Vdst">;
+
+// Four-register crypto ops (Vd, Vn, Vm, Va), used by eor3/bcax/sm3ss1.
+class CryptoRRRR<bits<2>op0, string asm, string asmops>
+  : BaseCryptoV82<(outs V128:$Vd), (ins V128:$Vn, V128:$Vm, V128:$Va), asm,
+                  asmops, "", []> {
+  bits<5> Vm;
+  bits<5> Va;
+  let Inst{24-23} = 0b00;
+  let Inst{22-21} = op0;
+  let Inst{20-16} = Vm;
+  let Inst{15}    = 0b0;
+  let Inst{14-10} = Va;
+}
+// .16b element flavour (eor3, bcax).
+class CryptoRRRR_16B<bits<2>op0, string asm>
+ : CryptoRRRR<op0, asm, "{\t$Vd.16b, $Vn.16b, $Vm.16b, $Va.16b}"> {
+}
+// .4s element flavour (sm3ss1).
+class CryptoRRRR_4S<bits<2>op0, string asm>
+ : CryptoRRRR<op0, asm, "{\t$Vd.4s, $Vn.4s, $Vm.4s, $Va.4s}"> {
+}
+
+// Three registers plus a 6-bit unsigned immediate rotate amount (xar).
+class CryptoRRRi6<string asm>
+  : BaseCryptoV82<(outs V128:$Vd), (ins V128:$Vn, V128:$Vm, uimm6:$imm), asm,
+                  "{\t$Vd.2d, $Vn.2d, $Vm.2d, $imm}", "", []> {
+  bits<6> imm;
+  bits<5> Vm;
+  let Inst{24-21} = 0b0100;
+  let Inst{20-16} = Vm;
+  let Inst{15-10} = imm;
+  // NOTE(review): the Vn/Vd field assignments below are already made by
+  // BaseCryptoV82 and are redundant here (but harmless).
+  let Inst{9-5}   = Vn;
+  let Inst{4-0}   = Vd;
+}
+
+// Destination-tied three-register form with a 2-bit lane index selecting a
+// .s element of Vm (the sm3tt1a/sm3tt1b/sm3tt2a/sm3tt2b instructions).
+class CryptoRRRi2Tied<bits<1>op0, bits<2>op1, string asm>
+  : BaseCryptoV82<(outs V128:$Vdst),
+                  (ins V128:$Vd, V128:$Vn, V128:$Vm, VectorIndexS:$imm),
+                  asm, "{\t$Vd.4s, $Vn.4s, $Vm.s$imm}", "$Vd = $Vdst", []> {
+  bits<2> imm;
+  bits<5> Vm;
+  let Inst{24-21} = 0b0010;
+  let Inst{20-16} = Vm;
+  let Inst{15}    = 0b1;
+  let Inst{14}    = op0;
+  // The lane index is encoded in bits 13-12.
+  let Inst{13-12} = imm;
+  let Inst{11-10} = op1;
+}
 
 //----------------------------------------------------------------------------
 // v8.1 atomic instructions extension:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 6ea7b01..62c9599 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -28,6 +28,14 @@
                                  AssemblerPredicate<"FeatureNEON", "neon">;
 def HasCrypto        : Predicate<"Subtarget->hasCrypto()">,
                                  AssemblerPredicate<"FeatureCrypto", "crypto">;
+// Armv8.2-A splits the monolithic "crypto" extension into finer-grained
+// sm4/sha3/sha2/aes sub-features, each with its own subtarget predicate.
+def HasSM4           : Predicate<"Subtarget->hasSM4()">,
+                                 AssemblerPredicate<"FeatureSM4", "sm4">;
+def HasSHA3          : Predicate<"Subtarget->hasSHA3()">,
+                                 AssemblerPredicate<"FeatureSHA3", "sha3">;
+def HasSHA2          : Predicate<"Subtarget->hasSHA2()">,
+                                 AssemblerPredicate<"FeatureSHA2", "sha2">;
+def HasAES           : Predicate<"Subtarget->hasAES()">,
+                                 AssemblerPredicate<"FeatureAES", "aes">;
 def HasDotProd       : Predicate<"Subtarget->hasDotProd()">,
                                  AssemblerPredicate<"FeatureDotProd", "dotprod">;
 def HasCRC           : Predicate<"Subtarget->hasCRC()">,
@@ -467,6 +475,30 @@
 defm UDOTlane : SIMDThreeSameVectorDotIndex<1, "udot", int_aarch64_neon_udot>;
 }
 
+// Armv8.2-A Crypto extensions
+// SHA-512 and SHA3 instructions, gated on the sha3 sub-feature.
+let Predicates = [HasSHA3] in {
+def SHA512H   : CryptoRRRTied<0b0, 0b00, "sha512h">;
+def SHA512H2  : CryptoRRRTied<0b0, 0b01, "sha512h2">;
+def SHA512SU0 : CryptoRRTied_2D<0b0, 0b00, "sha512su0">;
+def SHA512SU1 : CryptoRRRTied_2D<0b0, 0b10, "sha512su1">;
+def RAX1      : CryptoRRR_2D<0b0,0b11, "rax1">;
+def EOR3      : CryptoRRRR_16B<0b00, "eor3">;
+def BCAX      : CryptoRRRR_16B<0b01, "bcax">;
+def XAR       : CryptoRRRi6<"xar">;
+} // HasSHA3
+
+// SM3 and SM4 instructions, gated on the sm4 sub-feature.
+let Predicates = [HasSM4] in {
+def SM3TT1A   : CryptoRRRi2Tied<0b0, 0b00, "sm3tt1a">;
+def SM3TT1B   : CryptoRRRi2Tied<0b0, 0b01, "sm3tt1b">;
+def SM3TT2A   : CryptoRRRi2Tied<0b0, 0b10, "sm3tt2a">;
+def SM3TT2B   : CryptoRRRi2Tied<0b0, 0b11, "sm3tt2b">;
+def SM3SS1    : CryptoRRRR_4S<0b10, "sm3ss1">;
+def SM3PARTW1 : CryptoRRRTied_4S<0b1, 0b00, "sm3partw1">;
+def SM3PARTW2 : CryptoRRRTied_4S<0b1, 0b01, "sm3partw2">;
+// Note: the mnemonic for SM4ENCKEY is "sm4ekey".
+def SM4ENCKEY : CryptoRRR_4S<0b1, 0b10, "sm4ekey">;
+def SM4E      : CryptoRRTied_4S<0b0, 0b01, "sm4e">;
+} // HasSM4
+
 let Predicates = [HasRCPC] in {
   // v8.3 Release Consistent Processor Consistent support, optional in v8.2.
   def LDAPRB  : RCPCLoad<0b00, "ldaprb", GPR32>;
@@ -555,7 +587,7 @@
     let Inst{31} = 0;
   }
 
-} // HasV8_3A
+} // HasV8_3a
 
 // v8.4 Flag manipulation instructions
 let Predicates = [HasV8_4a] in {
@@ -5606,10 +5638,12 @@
 // Crypto extensions
 //----------------------------------------------------------------------------
 
+// The AES instructions are now guarded by the finer-grained aes sub-feature
+// instead of the monolithic crypto feature.
+let Predicates = [HasAES] in {
 def AESErr   : AESTiedInst<0b0100, "aese",   int_aarch64_crypto_aese>;
 def AESDrr   : AESTiedInst<0b0101, "aesd",   int_aarch64_crypto_aesd>;
 def AESMCrr  : AESInst<    0b0110, "aesmc",  int_aarch64_crypto_aesmc>;
 def AESIMCrr : AESInst<    0b0111, "aesimc", int_aarch64_crypto_aesimc>;
+} // HasAES
 
 // Pseudo instructions for AESMCrr/AESIMCrr with a register constraint required
 // for AES fusion on some CPUs.
@@ -5636,6 +5670,7 @@
                                               (v16i8 V128:$src2)))))>,
           Requires<[HasFuseAES]>;
 
+// The SHA-1/SHA-256 instructions are now guarded by the finer-grained sha2
+// sub-feature instead of the monolithic crypto feature.
+let Predicates = [HasSHA2] in {
 def SHA1Crrr     : SHATiedInstQSV<0b000, "sha1c",   int_aarch64_crypto_sha1c>;
 def SHA1Prrr     : SHATiedInstQSV<0b001, "sha1p",   int_aarch64_crypto_sha1p>;
 def SHA1Mrrr     : SHATiedInstQSV<0b010, "sha1m",   int_aarch64_crypto_sha1m>;
@@ -5647,6 +5682,7 @@
 def SHA1Hrr     : SHAInstSS<    0b0000, "sha1h",    int_aarch64_crypto_sha1h>;
 def SHA1SU1rr   : SHATiedInstVV<0b0001, "sha1su1",  int_aarch64_crypto_sha1su1>;
 def SHA256SU0rr : SHATiedInstVV<0b0010, "sha256su0",int_aarch64_crypto_sha256su0>;
+} // HasSHA2
 
 //----------------------------------------------------------------------------
 // Compiler-pseudos
diff --git a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
index 9c8f9d1..a51c41d 100644
--- a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
+++ b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
@@ -4749,7 +4749,11 @@
   const char *Name;
   const FeatureBitset Features;
 } ExtensionMap[] = {
-  { "crc", {AArch64::FeatureCRC} },
+  { "crc",  {AArch64::FeatureCRC} },
+  { "sm4",  {AArch64::FeatureSM4} },
+  { "sha3", {AArch64::FeatureSHA3} },
+  { "sha2", {AArch64::FeatureSHA2} },
+  { "aes",  {AArch64::FeatureAES} },
   { "crypto", {AArch64::FeatureCrypto} },
   { "fp", {AArch64::FeatureFPARMv8} },
   { "simd", {AArch64::FeatureNEON} },
@@ -4763,6 +4767,54 @@
   { "profile", {} },
 };
 
+/// Expand the legacy "crypto"/"nocrypto" architecture extension into the
+/// finer-grained sub-features it stands for on the given architecture:
+/// for 'generic' and architectures up to v8.3, "crypto" traditionally
+/// meant sha2+aes; from v8.4 it also covers sm4 and sha3. The expansion
+/// appends the corresponding (no-prefixed when disabling) extension names
+/// to \p RequestedExtensions.
+static void ExpandCryptoAEK(AArch64::ArchKind ArchKind,
+                            SmallVector<StringRef, 4> &RequestedExtensions) {
+  // Use matched begin()/end() pairs (the original mixed .begin() with
+  // std::end(), which worked but read inconsistently).
+  const bool NoCrypto =
+      std::find(RequestedExtensions.begin(), RequestedExtensions.end(),
+                "nocrypto") != RequestedExtensions.end();
+  const bool Crypto =
+      std::find(RequestedExtensions.begin(), RequestedExtensions.end(),
+                "crypto") != RequestedExtensions.end();
+
+  if (!Crypto && !NoCrypto)
+    return;
+
+  // "nocrypto" takes precedence when both are requested, matching the
+  // original behaviour; a single switch keeps the enable/disable mappings
+  // from drifting apart.
+  const bool Enable = !NoCrypto;
+  switch (ArchKind) {
+  default:
+    // Map 'generic' (and others) to sha2 and aes, because
+    // that was the traditional meaning of crypto.
+  case AArch64::ArchKind::ARMV8_1A:
+  case AArch64::ArchKind::ARMV8_2A:
+  case AArch64::ArchKind::ARMV8_3A:
+    RequestedExtensions.push_back(Enable ? "sha2" : "nosha2");
+    RequestedExtensions.push_back(Enable ? "aes" : "noaes");
+    break;
+  case AArch64::ArchKind::ARMV8_4A:
+    // From v8.4, crypto additionally implies the SM4 and SHA3 instructions.
+    RequestedExtensions.push_back(Enable ? "sm4" : "nosm4");
+    RequestedExtensions.push_back(Enable ? "sha3" : "nosha3");
+    RequestedExtensions.push_back(Enable ? "sha2" : "nosha2");
+    RequestedExtensions.push_back(Enable ? "aes" : "noaes");
+    break;
+  }
+}
+
 /// parseDirectiveArch
 ///   ::= .arch token
 bool AArch64AsmParser::parseDirectiveArch(SMLoc L) {
@@ -4793,6 +4845,8 @@
   if (!ExtensionString.empty())
     ExtensionString.split(RequestedExtensions, '+');
 
+  ExpandCryptoAEK(ID, RequestedExtensions);
+
   FeatureBitset Features = STI.getFeatureBits();
   for (auto Name : RequestedExtensions) {
     bool EnableFeature = true;
@@ -4852,6 +4906,8 @@
   STI.setDefaultFeatures(CPU, "");
   CurLoc = incrementLoc(CurLoc, CPU.size());
 
+  ExpandCryptoAEK(llvm::AArch64::getCPUArchKind(CPU), RequestedExtensions);
+
   FeatureBitset Features = STI.getFeatureBits();
   for (auto Name : RequestedExtensions) {
     // Advance source location past '+'.