[ARM] Armv8.2-A FP16 code generation (part 2/3)
Half-precision arguments and return values are passed as if they were an int or
float for ARM. This results in truncates and bitcasts to/from i16 and f16
values, which are legalized very early to stack stores/loads. When FullFP16 is
enabled, we want to avoid codegen for these bitcasts as it is unnecessary and
inefficient.
Differential Revision: https://reviews.llvm.org/D42580
llvm-svn: 323861
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 096bee8..cddd957 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -524,9 +524,9 @@
if (Subtarget->hasFullFP16()) {
addRegisterClass(MVT::f16, &ARM::HPRRegClass);
- // Clean up bitcast of incoming arguments if hard float abi is enabled.
- if (Subtarget->isTargetHardFloat())
- setOperationAction(ISD::BITCAST, MVT::i16, Custom);
+ setOperationAction(ISD::BITCAST, MVT::i16, Custom);
+ setOperationAction(ISD::BITCAST, MVT::i32, Custom);
+ setOperationAction(ISD::BITCAST, MVT::f16, Custom);
}
for (MVT VT : MVT::vector_valuetypes()) {
@@ -1273,6 +1273,8 @@
case ARMISD::VMOVRRD: return "ARMISD::VMOVRRD";
case ARMISD::VMOVDRR: return "ARMISD::VMOVDRR";
+ case ARMISD::VMOVhr: return "ARMISD::VMOVhr";
+ case ARMISD::VMOVrh: return "ARMISD::VMOVrh";
case ARMISD::EH_SJLJ_SETJMP: return "ARMISD::EH_SJLJ_SETJMP";
case ARMISD::EH_SJLJ_LONGJMP: return "ARMISD::EH_SJLJ_LONGJMP";
@@ -5051,7 +5053,8 @@
/// use a VMOVDRR or VMOVRRD node. This should not be done when the non-i64
/// operand type is illegal (e.g., v2f32 for a target that doesn't support
/// vectors), since the legalizer won't know what to do with that.
-static SDValue ExpandBITCAST(SDNode *N, SelectionDAG &DAG) {
+static SDValue ExpandBITCAST(SDNode *N, SelectionDAG &DAG,
+ const ARMSubtarget *Subtarget) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SDLoc dl(N);
SDValue Op = N->getOperand(0);
@@ -5060,39 +5063,78 @@
// source or destination of the bit convert.
EVT SrcVT = Op.getValueType();
EVT DstVT = N->getValueType(0);
+ const bool HasFullFP16 = Subtarget->hasFullFP16();
- // Half-precision arguments can be passed in like this:
- //
- // t4: f32,ch = CopyFromReg t0, Register:f32 %1
- // t8: i32 = bitcast t4
- // t9: i16 = truncate t8
- // t10: f16 = bitcast t9 <~~~~ SDNode N
- //
- // but we want to avoid code generation for the bitcast, so transform this
- // into:
- //
- // t18: f16 = CopyFromReg t0, Register:f32 %0
- //
- if (SrcVT == MVT::i16 && DstVT == MVT::f16) {
- if (Op.getOpcode() != ISD::TRUNCATE)
- return SDValue();
+ if (SrcVT == MVT::f32 && DstVT == MVT::i32) {
+ // FullFP16: half values are passed in S-registers, and we don't
+ // need any of the bitcast and moves:
+ //
+ // t2: f32,ch = CopyFromReg t0, Register:f32 %0
+ // t5: i32 = bitcast t2
+ // t18: f16 = ARMISD::VMOVhr t5
+ if (Op.getOpcode() != ISD::CopyFromReg ||
+ Op.getValueType() != MVT::f32)
+ return SDValue();
- SDValue Bitcast = Op.getOperand(0);
- if (Bitcast.getOpcode() != ISD::BITCAST ||
- Bitcast.getValueType() != MVT::i32)
- return SDValue();
+ auto Move = N->use_begin();
+ if (Move->getOpcode() != ARMISD::VMOVhr)
+ return SDValue();
- SDValue Copy = Bitcast.getOperand(0);
- if (Copy.getOpcode() != ISD::CopyFromReg ||
- Copy.getValueType() != MVT::f32)
- return SDValue();
-
- SDValue Ops[] = { Copy->getOperand(0), Copy->getOperand(1) };
- return DAG.getNode(ISD::CopyFromReg, SDLoc(Copy), MVT::f16, Ops);
+ SDValue Ops[] = { Op.getOperand(0), Op.getOperand(1) };
+ SDValue Copy = DAG.getNode(ISD::CopyFromReg, SDLoc(Op), MVT::f16, Ops);
+ DAG.ReplaceAllUsesWith(*Move, &Copy);
+ return Copy;
}
- assert((SrcVT == MVT::i64 || DstVT == MVT::i64) &&
- "ExpandBITCAST called for non-i64 type");
+ if (SrcVT == MVT::i16 && DstVT == MVT::f16) {
+ if (!HasFullFP16)
+ return SDValue();
+ // SoftFP: read half-precision arguments:
+ //
+ // t2: i32,ch = ...
+ // t7: i16 = truncate t2 <~~~~ Op
+ // t8: f16 = bitcast t7 <~~~~ N
+ //
+ if (Op.getOperand(0).getValueType() == MVT::i32)
+ return DAG.getNode(ARMISD::VMOVhr, SDLoc(Op),
+ MVT::f16, Op.getOperand(0));
+
+ return SDValue();
+ }
+
+ // Half-precision return values
+ if (SrcVT == MVT::f16 && DstVT == MVT::i16) {
+ if (!HasFullFP16)
+ return SDValue();
+ //
+ // t11: f16 = fadd t8, t10
+ // t12: i16 = bitcast t11 <~~~ SDNode N
+ // t13: i32 = zero_extend t12
+ // t16: ch,glue = CopyToReg t0, Register:i32 %r0, t13
+ // t17: ch = ARMISD::RET_FLAG t16, Register:i32 %r0, t16:1
+ //
+ // transform this into:
+ //
+ // t20: i32 = ARMISD::VMOVrh t11
+ // t16: ch,glue = CopyToReg t0, Register:i32 %r0, t20
+ //
+ auto ZeroExtend = N->use_begin();
+ if (N->use_size() != 1 || ZeroExtend->getOpcode() != ISD::ZERO_EXTEND ||
+ ZeroExtend->getValueType(0) != MVT::i32)
+ return SDValue();
+
+ auto Copy = ZeroExtend->use_begin();
+ if (Copy->getOpcode() == ISD::CopyToReg &&
+ Copy->use_begin()->getOpcode() == ARMISD::RET_FLAG) {
+ SDValue Cvt = DAG.getNode(ARMISD::VMOVrh, SDLoc(Op), MVT::i32, Op);
+ DAG.ReplaceAllUsesWith(*ZeroExtend, &Cvt);
+ return Cvt;
+ }
+ return SDValue();
+ }
+
+ if (!(SrcVT == MVT::i64 || DstVT == MVT::i64))
+ return SDValue();
// Turn i64->f64 into VMOVDRR.
if (SrcVT == MVT::i64 && TLI.isTypeLegal(DstVT)) {
@@ -7982,7 +8024,7 @@
case ISD::EH_SJLJ_SETUP_DISPATCH: return LowerEH_SJLJ_SETUP_DISPATCH(Op, DAG);
case ISD::INTRINSIC_WO_CHAIN: return LowerINTRINSIC_WO_CHAIN(Op, DAG,
Subtarget);
- case ISD::BITCAST: return ExpandBITCAST(Op.getNode(), DAG);
+ case ISD::BITCAST: return ExpandBITCAST(Op.getNode(), DAG, Subtarget);
case ISD::SHL:
case ISD::SRL:
case ISD::SRA: return LowerShift(Op.getNode(), DAG, Subtarget);
@@ -8084,7 +8126,7 @@
ExpandREAD_REGISTER(N, Results, DAG);
break;
case ISD::BITCAST:
- Res = ExpandBITCAST(N, DAG);
+ Res = ExpandBITCAST(N, DAG, Subtarget);
break;
case ISD::SRL:
case ISD::SRA:
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
index aa80f9a..b096331 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -171,6 +171,10 @@
// Vector move f32 immediate:
VMOVFPIMM,
+ // Move H <-> R, clearing top 16 bits
+ VMOVrh,
+ VMOVhr,
+
// Vector duplicate:
VDUP,
VDUPLANE,
diff --git a/llvm/lib/Target/ARM/ARMInstrVFP.td b/llvm/lib/Target/ARM/ARMInstrVFP.td
index bfc6cd5..48c1a38 100644
--- a/llvm/lib/Target/ARM/ARMInstrVFP.td
+++ b/llvm/lib/Target/ARM/ARMInstrVFP.td
@@ -23,6 +23,11 @@
def arm_fmdrr : SDNode<"ARMISD::VMOVDRR", SDT_VMOVDRR>;
def arm_fmrrd : SDNode<"ARMISD::VMOVRRD", SDT_VMOVRRD>;
+def SDT_VMOVhr : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVT<1, i32>] >;
+def SDT_VMOVrh : SDTypeProfile<1, 1, [SDTCisVT<0, i32>, SDTCisFP<1>] >;
+def arm_vmovhr : SDNode<"ARMISD::VMOVhr", SDT_VMOVhr>;
+def arm_vmovrh : SDNode<"ARMISD::VMOVrh", SDT_VMOVrh>;
+
//===----------------------------------------------------------------------===//
// Operand Definitions.
//
@@ -1171,9 +1176,9 @@
// Move H->R, clearing top 16 bits
def VMOVRH : AVConv2I<0b11100001, 0b1001,
- (outs GPR:$Rt), (ins SPR:$Sn),
+ (outs GPR:$Rt), (ins HPR:$Sn),
IIC_fpMOVSI, "vmov", ".f16\t$Rt, $Sn",
- []>,
+ [(set GPR:$Rt, (arm_vmovrh HPR:$Sn))]>,
Requires<[HasFullFP16]>,
Sched<[WriteFPMOV]> {
// Instruction operands.
@@ -1191,9 +1196,9 @@
// Move R->H, clearing top 16 bits
def VMOVHR : AVConv4I<0b11100000, 0b1001,
- (outs SPR:$Sn), (ins GPR:$Rt),
+ (outs HPR:$Sn), (ins GPR:$Rt),
IIC_fpMOVIS, "vmov", ".f16\t$Sn, $Rt",
- []>,
+ [(set HPR:$Sn, (arm_vmovhr GPR:$Rt))]>,
Requires<[HasFullFP16]>,
Sched<[WriteFPMOV]> {
// Instruction operands.