[ARM] Armv8.2-A FP16 code generation (part 2/3)

On ARM, half-precision arguments and return values are passed as if they
were an int or float. This results in truncates and bitcasts to/from i16
and f16 values, which are legalized very early to stack stores/loads. When
FullFP16 is enabled, we want to avoid generating code for these bitcasts,
as it is unnecessary and inefficient.
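
As a rough illustration (a sketch, not taken from this patch or its tests),
consider a trivial half-precision function written with Clang's _Float16
extension:

    _Float16 add(_Float16 a, _Float16 b) { return a + b; }

With the soft-float calling convention each argument arrives in a GPR as an
i32, so the pre-legalization DAG contains chains such as:

    t2: i32,ch = CopyFromReg t0, Register:i32 %0
        t7: i16 = truncate t2
      t8: f16 = bitcast t7

As noted above, the truncate/bitcast would otherwise be legalized to a stack
store/load; with FullFP16 it can instead be selected to a single register
move, which is what the new ARMISD::VMOVhr node (matched by VMOVHR, i.e.
vmov.f16) expresses.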

Differential Revision: https://reviews.llvm.org/D42580

llvm-svn: 323861
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 096bee8..cddd957 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -524,9 +524,9 @@
 
   if (Subtarget->hasFullFP16()) {
     addRegisterClass(MVT::f16, &ARM::HPRRegClass);
-    // Clean up bitcast of incoming arguments if hard float abi is enabled.
-    if (Subtarget->isTargetHardFloat())
-      setOperationAction(ISD::BITCAST, MVT::i16, Custom);
+    setOperationAction(ISD::BITCAST, MVT::i16, Custom);
+    setOperationAction(ISD::BITCAST, MVT::i32, Custom);
+    setOperationAction(ISD::BITCAST, MVT::f16, Custom);
   }
 
   for (MVT VT : MVT::vector_valuetypes()) {
@@ -1273,6 +1273,8 @@
 
   case ARMISD::VMOVRRD:       return "ARMISD::VMOVRRD";
   case ARMISD::VMOVDRR:       return "ARMISD::VMOVDRR";
+  case ARMISD::VMOVhr:        return "ARMISD::VMOVhr";
+  case ARMISD::VMOVrh:        return "ARMISD::VMOVrh";
 
   case ARMISD::EH_SJLJ_SETJMP: return "ARMISD::EH_SJLJ_SETJMP";
   case ARMISD::EH_SJLJ_LONGJMP: return "ARMISD::EH_SJLJ_LONGJMP";
@@ -5051,7 +5053,8 @@
 /// use a VMOVDRR or VMOVRRD node.  This should not be done when the non-i64
 /// operand type is illegal (e.g., v2f32 for a target that doesn't support
 /// vectors), since the legalizer won't know what to do with that.
-static SDValue ExpandBITCAST(SDNode *N, SelectionDAG &DAG) {
+static SDValue ExpandBITCAST(SDNode *N, SelectionDAG &DAG,
+                             const ARMSubtarget *Subtarget) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   SDLoc dl(N);
   SDValue Op = N->getOperand(0);
@@ -5060,39 +5063,78 @@
   // source or destination of the bit convert.
   EVT SrcVT = Op.getValueType();
   EVT DstVT = N->getValueType(0);
+  const bool HasFullFP16 = Subtarget->hasFullFP16();
 
-  // Half-precision arguments can be passed in like this:
-  //
-  //    t4: f32,ch = CopyFromReg t0, Register:f32 %1
-  //            t8: i32 = bitcast t4
-  //          t9: i16 = truncate t8
-  //        t10: f16 = bitcast t9   <~~~~ SDNode N
-  //
-  // but we want to avoid code generation for the bitcast, so transform this
-  // into:
-  //
-  // t18: f16 = CopyFromReg t0, Register:f32 %0
-  //
-  if (SrcVT == MVT::i16 && DstVT == MVT::f16) {
-     if (Op.getOpcode() != ISD::TRUNCATE)
-        return SDValue();
+  if (SrcVT == MVT::f32 && DstVT == MVT::i32) {
+     // FullFP16: half values are passed in S-registers, and we don't
+     // need any of the bitcasts and moves:
+     //
+     // t2: f32,ch = CopyFromReg t0, Register:f32 %0
+     //   t5: i32 = bitcast t2
+     // t18: f16 = ARMISD::VMOVhr t5
+     if (Op.getOpcode() != ISD::CopyFromReg ||
+         Op.getValueType() != MVT::f32)
+       return SDValue();
 
-    SDValue Bitcast = Op.getOperand(0);
-    if (Bitcast.getOpcode() != ISD::BITCAST ||
-        Bitcast.getValueType() != MVT::i32)
-      return SDValue();
+     auto Move = N->use_begin();
+     if (Move->getOpcode() != ARMISD::VMOVhr)
+       return SDValue();
 
-    SDValue Copy = Bitcast.getOperand(0);
-    if (Copy.getOpcode() != ISD::CopyFromReg ||
-        Copy.getValueType() != MVT::f32)
-      return SDValue();
-
-    SDValue Ops[] = { Copy->getOperand(0), Copy->getOperand(1) };
-    return DAG.getNode(ISD::CopyFromReg, SDLoc(Copy), MVT::f16, Ops);
+     SDValue Ops[] = { Op.getOperand(0), Op.getOperand(1) };
+     SDValue Copy = DAG.getNode(ISD::CopyFromReg, SDLoc(Op), MVT::f16, Ops);
+     DAG.ReplaceAllUsesWith(*Move, &Copy);
+     return Copy;
   }
 
-  assert((SrcVT == MVT::i64 || DstVT == MVT::i64) &&
-         "ExpandBITCAST called for non-i64 type");
+  if (SrcVT == MVT::i16 && DstVT == MVT::f16) {
+    if (!HasFullFP16)
+      return SDValue();
+    // SoftFP: read half-precision arguments:
+    //
+    // t2: i32,ch = ... 
+    //        t7: i16 = truncate t2 <~~~~ Op
+    //      t8: f16 = bitcast t7    <~~~~ N
+    //
+    if (Op.getOperand(0).getValueType() == MVT::i32)
+      return DAG.getNode(ARMISD::VMOVhr, SDLoc(Op),
+                         MVT::f16, Op.getOperand(0));
+
+    return SDValue();
+  }
+
+  // Half-precision return values:
+  if (SrcVT == MVT::f16 && DstVT == MVT::i16) {
+    if (!HasFullFP16)
+      return SDValue();
+    //
+    //          t11: f16 = fadd t8, t10
+    //        t12: i16 = bitcast t11       <~~~ SDNode N
+    //      t13: i32 = zero_extend t12
+    //    t16: ch,glue = CopyToReg t0, Register:i32 %r0, t13
+    //  t17: ch = ARMISD::RET_FLAG t16, Register:i32 %r0, t16:1
+    //
+    // transform this into:
+    //
+    //    t20: i32 = ARMISD::VMOVrh t11
+    //  t16: ch,glue = CopyToReg t0, Register:i32 %r0, t20
+    //
+    auto ZeroExtend = N->use_begin();
+    if (N->use_size() != 1 || ZeroExtend->getOpcode() != ISD::ZERO_EXTEND ||
+        ZeroExtend->getValueType(0) != MVT::i32)
+      return SDValue();
+
+    auto Copy = ZeroExtend->use_begin();
+    if (Copy->getOpcode() == ISD::CopyToReg &&
+        Copy->use_begin()->getOpcode() == ARMISD::RET_FLAG) {
+      SDValue Cvt = DAG.getNode(ARMISD::VMOVrh, SDLoc(Op), MVT::i32, Op);
+      DAG.ReplaceAllUsesWith(*ZeroExtend, &Cvt);
+      return Cvt;
+    }
+    return SDValue();
+  }
+
+  if (!(SrcVT == MVT::i64 || DstVT == MVT::i64))
+    return SDValue();
 
   // Turn i64->f64 into VMOVDRR.
   if (SrcVT == MVT::i64 && TLI.isTypeLegal(DstVT)) {
@@ -7982,7 +8024,7 @@
   case ISD::EH_SJLJ_SETUP_DISPATCH: return LowerEH_SJLJ_SETUP_DISPATCH(Op, DAG);
   case ISD::INTRINSIC_WO_CHAIN: return LowerINTRINSIC_WO_CHAIN(Op, DAG,
                                                                Subtarget);
-  case ISD::BITCAST:       return ExpandBITCAST(Op.getNode(), DAG);
+  case ISD::BITCAST:       return ExpandBITCAST(Op.getNode(), DAG, Subtarget);
   case ISD::SHL:
   case ISD::SRL:
   case ISD::SRA:           return LowerShift(Op.getNode(), DAG, Subtarget);
@@ -8084,7 +8126,7 @@
     ExpandREAD_REGISTER(N, Results, DAG);
     break;
   case ISD::BITCAST:
-    Res = ExpandBITCAST(N, DAG);
+    Res = ExpandBITCAST(N, DAG, Subtarget);
     break;
   case ISD::SRL:
   case ISD::SRA:
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
index aa80f9a..b096331 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -171,6 +171,10 @@
       // Vector move f32 immediate:
       VMOVFPIMM,
 
+      // Move H <-> R, clearing top 16 bits
+      VMOVrh,
+      VMOVhr,
+
       // Vector duplicate:
       VDUP,
       VDUPLANE,
diff --git a/llvm/lib/Target/ARM/ARMInstrVFP.td b/llvm/lib/Target/ARM/ARMInstrVFP.td
index bfc6cd5..48c1a38 100644
--- a/llvm/lib/Target/ARM/ARMInstrVFP.td
+++ b/llvm/lib/Target/ARM/ARMInstrVFP.td
@@ -23,6 +23,11 @@
 def arm_fmdrr  : SDNode<"ARMISD::VMOVDRR", SDT_VMOVDRR>;
 def arm_fmrrd  : SDNode<"ARMISD::VMOVRRD", SDT_VMOVRRD>;
 
+def SDT_VMOVhr : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVT<1, i32>] >;
+def SDT_VMOVrh : SDTypeProfile<1, 1, [SDTCisVT<0, i32>, SDTCisFP<1>] >;
+def arm_vmovhr : SDNode<"ARMISD::VMOVhr", SDT_VMOVhr>;
+def arm_vmovrh : SDNode<"ARMISD::VMOVrh", SDT_VMOVrh>;
+
 //===----------------------------------------------------------------------===//
 // Operand Definitions.
 //
@@ -1171,9 +1176,9 @@
 
 // Move H->R, clearing top 16 bits
 def VMOVRH : AVConv2I<0b11100001, 0b1001,
-                      (outs GPR:$Rt), (ins SPR:$Sn),
+                      (outs GPR:$Rt), (ins HPR:$Sn),
                       IIC_fpMOVSI, "vmov", ".f16\t$Rt, $Sn",
-                      []>,
+                      [(set GPR:$Rt, (arm_vmovrh HPR:$Sn))]>,
              Requires<[HasFullFP16]>,
              Sched<[WriteFPMOV]> {
   // Instruction operands.
@@ -1191,9 +1196,9 @@
 
 // Move R->H, clearing top 16 bits
 def VMOVHR : AVConv4I<0b11100000, 0b1001,
-                      (outs SPR:$Sn), (ins GPR:$Rt),
+                      (outs HPR:$Sn), (ins GPR:$Rt),
                       IIC_fpMOVIS, "vmov", ".f16\t$Sn, $Rt",
-                      []>,
+                      [(set HPR:$Sn, (arm_vmovhr GPR:$Rt))]>,
              Requires<[HasFullFP16]>,
              Sched<[WriteFPMOV]> {
   // Instruction operands.