[ARM/AArch64] Support FP16 +fp16fml instructions

Add the +fp16fml feature for the new FP16 instructions (FMLAL/FMLSL
and FMLAL2/FMLSL2 on AArch64, VFMAL/VFMSL on ARM), which are a
mandatory part of FP16 from v8.4-A and an optional part of FP16
from v8.2-A. It doesn't seem to be possible to model this
version-dependent relationship in LLVM, but the relationship
between the options is handled by the related clang patch.

In keeping with what I think is the usual practice, the fp16fml
extension is accepted regardless of base architecture version.

Builds on/replaces Sjoerd Meijer's patch to add these instructions at
https://reviews.llvm.org/D49839.

Differential Revision: https://reviews.llvm.org/D50228

llvm-svn: 340013
diff --git a/llvm/lib/Target/AArch64/AArch64.td b/llvm/lib/Target/AArch64/AArch64.td
index a69d381..2660846 100644
--- a/llvm/lib/Target/AArch64/AArch64.td
+++ b/llvm/lib/Target/AArch64/AArch64.td
@@ -71,6 +71,9 @@
 def FeatureFullFP16 : SubtargetFeature<"fullfp16", "HasFullFP16", "true",
   "Full FP16", [FeatureFPARMv8]>;
 
+def FeatureFP16FML : SubtargetFeature<"fp16fml", "HasFP16FML", "true",
+  "Enable FP16 FML instructions", [FeatureFullFP16]>;
+
 def FeatureSPE : SubtargetFeature<"spe", "HasSPE", "true",
   "Enable Statistical Profiling extension">;
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 7caf32d..123cddd 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -4790,6 +4790,14 @@
   let Inst{4-0}   = Rd;
 }
 
+let Predicates = [HasNEON, HasFP16FML] in
+class BaseSIMDThreeSameMult<bit Q, bit U, bit b13, bits<3> size, string asm, string kind1,
+                                 string kind2> :
+        BaseSIMDThreeSameVector<Q, U, size, 0b11101, V128, asm, kind1, [] > {
+  let AsmString = !strconcat(asm, "{\t$Rd" # kind1 # ", $Rn" # kind2 # ", $Rm" # kind2 # "}");
+  let Inst{13} = b13;
+}
+
 class BaseSIMDThreeSameVectorDot<bit Q, bit U, string asm, string kind1,
                                  string kind2, RegisterOperand RegType,
                                  ValueType AccumType, ValueType InputType,
@@ -7255,6 +7263,20 @@
   let Inst{11}    = idx{1};  // H
 }
 
+let Predicates = [HasNEON, HasFP16FML] in
+class BaseSIMDThreeSameMultIndex<bit Q, bit U, bits<4> opc, string asm,
+                                 string dst_kind, string lhs_kind,
+                                 string rhs_kind> :
+        BaseSIMDIndexedTied<Q, U, 0, 0b10, opc, V128, V128, V128,
+                            VectorIndexH, asm, "", dst_kind, lhs_kind,
+                            rhs_kind, []> {
+  //idx = H:L:M
+  bits<3> idx;
+  let Inst{11} = idx{2}; // H
+  let Inst{21} = idx{1}; // L
+  let Inst{20} = idx{0}; // M
+}
+
 multiclass SIMDThreeSameVectorDotIndex<bit U, string asm,
                                        SDPatternOperator OpNode> {
   def v8i8  : BaseSIMDThreeSameVectorDotIndex<0, U, asm, ".2s", ".8b", ".4b", V64,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index d6b8bb5..d89ff41 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -49,6 +49,8 @@
 def HasPerfMon       : Predicate<"Subtarget->hasPerfMon()">;
 def HasFullFP16      : Predicate<"Subtarget->hasFullFP16()">,
                                  AssemblerPredicate<"FeatureFullFP16", "fullfp16">;
+def HasFP16FML       : Predicate<"Subtarget->hasFP16FML()">,
+                                 AssemblerPredicate<"FeatureFP16FML", "fp16fml">;
 def HasSPE           : Predicate<"Subtarget->hasSPE()">,
                                  AssemblerPredicate<"FeatureSPE", "spe">;
 def HasFuseAES       : Predicate<"Subtarget->hasFuseAES()">,
@@ -3299,6 +3301,24 @@
 defm SQRDMLSH : SIMDThreeSameVectorSQRDMLxHTiedHS<1,0b10001,"sqrdmlsh",
                                                     int_aarch64_neon_sqsub>;
 
+// FP16FML
+def FMLAL_2S   : BaseSIMDThreeSameMult<0, 0, 1, 0b001, "fmlal", ".2s", ".2h">;
+def FMLSL_2S   : BaseSIMDThreeSameMult<0, 0, 1, 0b101, "fmlsl", ".2s", ".2h">;
+def FMLAL_4S   : BaseSIMDThreeSameMult<1, 0, 1, 0b001, "fmlal", ".4s", ".4h">;
+def FMLSL_4S   : BaseSIMDThreeSameMult<1, 0, 1, 0b101, "fmlsl", ".4s", ".4h">;
+def FMLAL2_2S  : BaseSIMDThreeSameMult<0, 1, 0, 0b001, "fmlal2", ".2s", ".2h">;
+def FMLSL2_2S  : BaseSIMDThreeSameMult<0, 1, 0, 0b101, "fmlsl2", ".2s", ".2h">;
+def FMLAL2_4S  : BaseSIMDThreeSameMult<1, 1, 0, 0b001, "fmlal2", ".4s", ".4h">;
+def FMLSL2_4S  : BaseSIMDThreeSameMult<1, 1, 0, 0b101, "fmlsl2", ".4s", ".4h">;
+def FMLALI_2s  : BaseSIMDThreeSameMultIndex<0, 0, 0b0000, "fmlal", ".2s", ".2h", ".h">;
+def FMLSLI_2s  : BaseSIMDThreeSameMultIndex<0, 0, 0b0100, "fmlsl", ".2s", ".2h", ".h">;
+def FMLALI_4s  : BaseSIMDThreeSameMultIndex<1, 0, 0b0000, "fmlal", ".4s", ".4h", ".h">;
+def FMLSLI_4s  : BaseSIMDThreeSameMultIndex<1, 0, 0b0100, "fmlsl", ".4s", ".4h", ".h">;
+def FMLALI2_2s : BaseSIMDThreeSameMultIndex<0, 1, 0b1000, "fmlal2", ".2s", ".2h", ".h">;
+def FMLSLI2_2s : BaseSIMDThreeSameMultIndex<0, 1, 0b1100, "fmlsl2", ".2s", ".2h", ".h">;
+def FMLALI2_4s : BaseSIMDThreeSameMultIndex<1, 1, 0b1000, "fmlal2", ".4s", ".4h", ".h">;
+def FMLSLI2_4s : BaseSIMDThreeSameMultIndex<1, 1, 0b1100, "fmlsl2", ".4s", ".4h", ".h">;
+
 defm AND : SIMDLogicalThreeVector<0, 0b00, "and", and>;
 defm BIC : SIMDLogicalThreeVector<0, 0b01, "bic",
                                   BinOpFrag<(and node:$LHS, (vnot node:$RHS))> >;
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index 5af4c0d..6f08ca0 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -78,6 +78,7 @@
   bool HasRDM = false;
   bool HasPerfMon = false;
   bool HasFullFP16 = false;
+  bool HasFP16FML = false;
   bool HasSPE = false;
 
   // ARMv8.4 Crypto extensions
@@ -291,6 +292,7 @@
 
   bool hasPerfMon() const { return HasPerfMon; }
   bool hasFullFP16() const { return HasFullFP16; }
+  bool hasFP16FML() const { return HasFP16FML; }
   bool hasSPE() const { return HasSPE; }
   bool hasLSLFast() const { return HasLSLFast; }
   bool hasSVE() const { return HasSVE; }
diff --git a/llvm/lib/Target/ARM/ARM.td b/llvm/lib/Target/ARM/ARM.td
index 43e28f6..f08afa8 100644
--- a/llvm/lib/Target/ARM/ARM.td
+++ b/llvm/lib/Target/ARM/ARM.td
@@ -61,6 +61,11 @@
                                              "floating point",
                                              [FeatureFPARMv8]>;
 
+def FeatureFP16FML        : SubtargetFeature<"fp16fml", "HasFP16FML", "true",
+                                             "Enable full half-precision "
+                                             "floating point fml instructions",
+                                             [FeatureFullFP16]>;
+
 def FeatureVFPOnlySP      : SubtargetFeature<"fp-only-sp", "FPOnlySP", "true",
                                              "Floating point unit supports "
                                              "single precision only">;
diff --git a/llvm/lib/Target/ARM/ARMInstrFormats.td b/llvm/lib/Target/ARM/ARMInstrFormats.td
index 70aded2..87a2802 100644
--- a/llvm/lib/Target/ARM/ARMInstrFormats.td
+++ b/llvm/lib/Target/ARM/ARMInstrFormats.td
@@ -2579,6 +2579,37 @@
   let Inst{3-0}   = Vm{3-0};
 }
 
+// In Armv8.2-A, some NEON instructions are added that encode Vn and Vm
+// differently:
+//    if Q == '1' then UInt(N:Vn) else UInt(Vn:N);
+//    if Q == '1' then UInt(M:Vm) else UInt(Vm:M);
+// Class N3VCP8 above describes the Q=1 case, and this class the Q=0 case.
+class N3VCP8Q0<bits<2> op24_23, bits<2> op21_20, bit op6, bit op4,
+             dag oops, dag iops, InstrItinClass itin,
+             string opc, string dt, string asm, string cstr, list<dag> pattern>
+  : NeonInp<oops, iops, AddrModeNone, IndexModeNone, N3RegCplxFrm, itin, opc, dt, asm, cstr, pattern> {
+  bits<5> Vd;
+  bits<5> Vn;
+  bits<5> Vm;
+
+  let DecoderNamespace = "VFPV8";
+  // These have the same encodings in ARM and Thumb2
+  let PostEncoderMethod = "";
+
+  let Inst{31-25} = 0b1111110;
+  let Inst{24-23} = op24_23;
+  let Inst{22}    = Vd{4};
+  let Inst{21-20} = op21_20;
+  let Inst{19-16} = Vn{4-1};
+  let Inst{15-12} = Vd{3-0};
+  let Inst{11-8}  = 0b1000;
+  let Inst{7}     = Vn{0};
+  let Inst{6}     = op6;
+  let Inst{5}     = Vm{0};
+  let Inst{4}     = op4;
+  let Inst{3-0}   = Vm{4-1};
+}
+
 // Operand types for complex instructions
 class ComplexRotationOperand<int Angle, int Remainder, string Type, string Diag>
   : AsmOperandClass {
diff --git a/llvm/lib/Target/ARM/ARMInstrInfo.td b/llvm/lib/Target/ARM/ARMInstrInfo.td
index d4c342c..e6d85be 100644
--- a/llvm/lib/Target/ARM/ARMInstrInfo.td
+++ b/llvm/lib/Target/ARM/ARMInstrInfo.td
@@ -285,6 +285,8 @@
                                  AssemblerPredicate<"FeatureFP16","half-float conversions">;
 def HasFullFP16      : Predicate<"Subtarget->hasFullFP16()">,
                                  AssemblerPredicate<"FeatureFullFP16","full half-float">;
+def HasFP16FML       : Predicate<"Subtarget->hasFP16FML()">,
+                                 AssemblerPredicate<"FeatureFP16FML","full half-float fml">;
 def HasDivideInThumb : Predicate<"Subtarget->hasDivideInThumbMode()">,
                                  AssemblerPredicate<"FeatureHWDivThumb", "divide in THUMB">;
 def HasDivideInARM   : Predicate<"Subtarget->hasDivideInARMMode()">,
diff --git a/llvm/lib/Target/ARM/ARMInstrNEON.td b/llvm/lib/Target/ARM/ARMInstrNEON.td
index 5a300c7..a7bb32d 100644
--- a/llvm/lib/Target/ARM/ARMInstrNEON.td
+++ b/llvm/lib/Target/ARM/ARMInstrNEON.td
@@ -5109,6 +5109,54 @@
                    (VACGEhq QPR:$Vd, QPR:$Vm, QPR:$Vn, pred:$p)>;
 }
 
+// +fp16fml Floating Point Multiplication Variants
+let Predicates = [HasNEON, HasFP16FML], DecoderNamespace= "VFPV8" in {
+
+class N3VCP8F16Q1<string asm, RegisterClass Td, RegisterClass Tn,
+                RegisterClass Tm, bits<2> op1, bits<2> op2, bit op3>
+  : N3VCP8<op1, op2, 1, op3, (outs Td:$Vd), (ins Tn:$Vn, Tm:$Vm), NoItinerary,
+           asm, "f16", "$Vd, $Vn, $Vm", "", []>;
+
+class N3VCP8F16Q0<string asm, RegisterClass Td, RegisterClass Tn,
+                RegisterClass Tm, bits<2> op1, bits<2> op2, bit op3>
+  : N3VCP8Q0<op1, op2, 0, op3, (outs Td:$Vd), (ins Tn:$Vn, Tm:$Vm), NoItinerary,
+           asm, "f16", "$Vd, $Vn, $Vm", "", []>;
+
+class VFMQ0<string opc, bits<2> S>
+  : N3VLaneCP8<0, S, 0, 1, (outs DPR:$Vd),
+               (ins SPR:$Vn, SPR:$Vm, VectorIndex32:$idx),
+               IIC_VMACD, opc, "f16", "$Vd, $Vn, $Vm$idx", "", []> {
+  bit idx;
+  let Inst{3} = idx;
+  let Inst{19-16} = Vn{4-1};
+  let Inst{7}     = Vn{0};
+  let Inst{5}     = Vm{0};
+  let Inst{2-0}   = Vm{3-1};
+}
+
+class VFMQ1<string opc, bits<2> S>
+  : N3VLaneCP8<0, S, 1, 1, (outs QPR:$Vd),
+               (ins DPR:$Vn, DPR:$Vm, VectorIndex16:$idx),
+               IIC_VMACD, opc, "f16", "$Vd, $Vn, $Vm$idx", "", []> {
+  bits<2> idx;
+  let Inst{5} = idx{1};
+  let Inst{3} = idx{0};
+}
+
+let hasNoSchedulingInfo = 1 in {
+//                                                op1   op2   op3
+def VFMALD  : N3VCP8F16Q0<"vfmal", DPR, SPR, SPR, 0b00, 0b10, 1>;
+def VFMSLD  : N3VCP8F16Q0<"vfmsl", DPR, SPR, SPR, 0b01, 0b10, 1>;
+def VFMALQ  : N3VCP8F16Q1<"vfmal", QPR, DPR, DPR, 0b00, 0b10, 1>;
+def VFMSLQ  : N3VCP8F16Q1<"vfmsl", QPR, DPR, DPR, 0b01, 0b10, 1>;
+def VFMALDI : VFMQ0<"vfmal", 0b00>;
+def VFMSLDI : VFMQ0<"vfmsl", 0b01>;
+def VFMALQI : VFMQ1<"vfmal", 0b00>;
+def VFMSLQI : VFMQ1<"vfmsl", 0b01>;
+}
+} // HasNEON, HasFP16FML
+
+
 def: NEONInstAlias<"vaclt${p}.f32 $Vd, $Vm",
                    (VACGTfd DPR:$Vd, DPR:$Vm, DPR:$Vd, pred:$p)>;
 def: NEONInstAlias<"vaclt${p}.f32 $Vd, $Vm",
diff --git a/llvm/lib/Target/ARM/ARMSubtarget.h b/llvm/lib/Target/ARM/ARMSubtarget.h
index 69bc3ea..d7bfa89 100644
--- a/llvm/lib/Target/ARM/ARMSubtarget.h
+++ b/llvm/lib/Target/ARM/ARMSubtarget.h
@@ -227,6 +227,9 @@
   /// HasFullFP16 - True if subtarget supports half-precision FP operations
   bool HasFullFP16 = false;
 
+  /// HasFP16FML - True if subtarget supports half-precision FP fml operations
+  bool HasFP16FML = false;
+
   /// HasD16 - True if subtarget is limited to 16 double precision
   /// FP registers for VFPv3.
   bool HasD16 = false;
@@ -622,6 +625,7 @@
   bool hasFP16() const { return HasFP16; }
   bool hasD16() const { return HasD16; }
   bool hasFullFP16() const { return HasFullFP16; }
+  bool hasFP16FML() const { return HasFP16FML; }
 
   bool hasFuseAES() const { return HasFuseAES; }
   bool hasFuseLiterals() const { return HasFuseLiterals; }
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index e0cd2d8..7d14bd7 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -57,7 +57,7 @@
   const FeatureBitset InlineFeatureWhitelist = {
       ARM::FeatureVFP2, ARM::FeatureVFP3, ARM::FeatureNEON, ARM::FeatureThumb2,
       ARM::FeatureFP16, ARM::FeatureVFP4, ARM::FeatureFPARMv8,
-      ARM::FeatureFullFP16, ARM::FeatureHWDivThumb,
+      ARM::FeatureFullFP16, ARM::FeatureFP16FML, ARM::FeatureHWDivThumb,
       ARM::FeatureHWDivARM, ARM::FeatureDB, ARM::FeatureV7Clrex,
       ARM::FeatureAcquireRelease, ARM::FeatureSlowFPBrcc,
       ARM::FeaturePerfMon, ARM::FeatureTrustZone, ARM::Feature8MSecExt,
diff --git a/llvm/lib/Target/ARM/AsmParser/ARMAsmParser.cpp b/llvm/lib/Target/ARM/AsmParser/ARMAsmParser.cpp
index a5fbbbf..da6296a 100644
--- a/llvm/lib/Target/ARM/AsmParser/ARMAsmParser.cpp
+++ b/llvm/lib/Target/ARM/AsmParser/ARMAsmParser.cpp
@@ -5626,7 +5626,8 @@
       Mnemonic.startswith("vsel") || Mnemonic == "vins" || Mnemonic == "vmovx" ||
       Mnemonic == "bxns"  || Mnemonic == "blxns" ||
       Mnemonic == "vudot" || Mnemonic == "vsdot" ||
-      Mnemonic == "vcmla" || Mnemonic == "vcadd")
+      Mnemonic == "vcmla" || Mnemonic == "vcadd" ||
+      Mnemonic == "vfmal" || Mnemonic == "vfmsl")
     return Mnemonic;
 
   // First, split out any predication code. Ignore mnemonics we know aren't
@@ -5716,7 +5717,8 @@
       (FullInst.startswith("vmull") && FullInst.endswith(".p64")) ||
       Mnemonic == "vmovx" || Mnemonic == "vins" ||
       Mnemonic == "vudot" || Mnemonic == "vsdot" ||
-      Mnemonic == "vcmla" || Mnemonic == "vcadd") {
+      Mnemonic == "vcmla" || Mnemonic == "vcadd" ||
+      Mnemonic == "vfmal" || Mnemonic == "vfmsl") {
     // These mnemonics are never predicable
     CanAcceptPredicationCode = false;
   } else if (!isThumb()) {