AMDGPU: Split wide vectors of i16/f16 into 32-bit regs on calls
This improves the generated code for the same reasons as scalarizing
vectors with 32-bit elements.
llvm-svn: 338418
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 097b935..d6647c7 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -703,6 +703,11 @@
unsigned Size = ScalarVT.getSizeInBits();
if (Size == 32 || Size == 64)
return ScalarVT.getSimpleVT();
+
+ if (Size == 16 &&
+ Subtarget->has16BitInsts() &&
+ isPowerOf2_32(VT.getVectorNumElements()))
+ return VT.isInteger() ? MVT::v2i16 : MVT::v2f16;
}
return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
@@ -712,10 +717,16 @@
CallingConv::ID CC,
EVT VT) const {
if (CC != CallingConv::AMDGPU_KERNEL && VT.isVector()) {
+ unsigned NumElts = VT.getVectorNumElements();
EVT ScalarVT = VT.getScalarType();
unsigned Size = ScalarVT.getSizeInBits();
+
if (Size == 32 || Size == 64)
- return VT.getVectorNumElements();
+ return NumElts;
+
+ // FIXME: 3-element vectors (v3) fail to break down as we want, so only power-of-2 element counts are handled here.
+ if (Size == 16 && Subtarget->has16BitInsts() && isPowerOf2_32(NumElts))
+ return VT.getVectorNumElements() / 2;
}
return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
@@ -726,12 +737,23 @@
EVT VT, EVT &IntermediateVT,
unsigned &NumIntermediates, MVT &RegisterVT) const {
if (CC != CallingConv::AMDGPU_KERNEL && VT.isVector()) {
+ unsigned NumElts = VT.getVectorNumElements();
EVT ScalarVT = VT.getScalarType();
unsigned Size = ScalarVT.getSizeInBits();
if (Size == 32 || Size == 64) {
RegisterVT = ScalarVT.getSimpleVT();
IntermediateVT = RegisterVT;
- NumIntermediates = VT.getVectorNumElements();
+ NumIntermediates = NumElts;
+ return NumIntermediates;
+ }
+
+ // FIXME: We should fix the ABI to be the same on targets without 16-bit
+ // support, but unless we can properly handle 3-vectors, it will still be
+ // inconsistent.
+ if (Size == 16 && Subtarget->has16BitInsts() && isPowerOf2_32(NumElts)) {
+ RegisterVT = VT.isInteger() ? MVT::v2i16 : MVT::v2f16;
+ IntermediateVT = RegisterVT;
+ NumIntermediates = NumElts / 2;
return NumIntermediates;
}
}