[NVPTX] Handle ldg created from sign-/zero-extended load
Reviewers: jingyue
Subscribers: jholewinski
Differential Revision: http://reviews.llvm.org/D18053
llvm-svn: 265389
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 2d0098b..b1ed2df 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1286,7 +1286,7 @@
MemSDNode *Mem;
bool IsLDG = true;
- // If this is an LDG intrinsic, the address is the third operand. Its its an
+ // If this is an LDG intrinsic, the address is the third operand. If it's an
// LDG/LDU SD node (from custom vector handling), then its the second operand
if (N->getOpcode() == ISD::INTRINSIC_W_CHAIN) {
Op1 = N->getOperand(2);
@@ -1317,10 +1317,23 @@
SDValue Base, Offset, Addr;
EVT EltVT = Mem->getMemoryVT();
+ unsigned NumElts = 1;
if (EltVT.isVector()) {
+ NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
}
+ // Build the "promoted" result VTList for the load. If we are really loading
+ // i8s, then the return type will be promoted to i16 since we do not expose
+ // 8-bit registers in NVPTX.
+ EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
+ SmallVector<EVT, 5> InstVTs;
+ for (unsigned i = 0; i != NumElts; ++i) {
+ InstVTs.push_back(NodeVT);
+ }
+ InstVTs.push_back(MVT::Other);
+ SDVTList InstVTList = CurDAG->getVTList(InstVTs);
+
if (SelectDirectAddr(Op1, Addr)) {
switch (N->getOpcode()) {
default:
@@ -1461,7 +1474,7 @@
}
SDValue Ops[] = { Addr, Chain };
- LD = CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops);
+ LD = CurDAG->getMachineNode(Opcode, DL, InstVTList, Ops);
} else if (TM.is64Bit() ? SelectADDRri64(Op1.getNode(), Op1, Base, Offset)
: SelectADDRri(Op1.getNode(), Op1, Base, Offset)) {
if (TM.is64Bit()) {
@@ -1750,7 +1763,7 @@
SDValue Ops[] = { Base, Offset, Chain };
- LD = CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops);
+ LD = CurDAG->getMachineNode(Opcode, DL, InstVTList, Ops);
} else {
if (TM.is64Bit()) {
switch (N->getOpcode()) {
@@ -2037,13 +2050,77 @@
}
SDValue Ops[] = { Op1, Chain };
- LD = CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops);
+ LD = CurDAG->getMachineNode(Opcode, DL, InstVTList, Ops);
}
MachineSDNode::mmo_iterator MemRefs0 = MF->allocateMemRefsArray(1);
MemRefs0[0] = Mem->getMemOperand();
cast<MachineSDNode>(LD)->setMemRefs(MemRefs0, MemRefs0 + 1);
+ // For automatic generation of LDG (through SelectLoad[Vector], not the
+ // intrinsics), we may have an extending load like:
+ //
+ // i32,ch = load<LD1[%data1(addrspace=1)], zext from i8> t0, t7, undef:i64
+ //
+ // Since we load an i8 value, the matching logic above will have selected an
+ // LDG instruction that reads i8 and stores it in an i16 register (NVPTX does
+ // not expose 8-bit registers):
+ //
+ // i16,ch = INT_PTX_LDG_GLOBAL_i8areg64 t7, t0
+ //
+ // To get the correct type in this case, truncate back to i8 and then extend
+ // to the original load type.
+ EVT OrigType = N->getValueType(0);
+ LoadSDNode *LDSD = dyn_cast<LoadSDNode>(N);
+ if (LDSD && EltVT == MVT::i8 && OrigType.getScalarSizeInBits() >= 32) {
+ unsigned CvtOpc = 0;
+
+ switch (LDSD->getExtensionType()) {
+ default:
+ llvm_unreachable("An extension is required for i8 loads");
+ break;
+ case ISD::SEXTLOAD:
+ switch (OrigType.getSimpleVT().SimpleTy) {
+ default:
+ llvm_unreachable("Unhandled integer load type");
+ break;
+ case MVT::i32:
+ CvtOpc = NVPTX::CVT_s32_s8;
+ break;
+ case MVT::i64:
+ CvtOpc = NVPTX::CVT_s64_s8;
+ break;
+ }
+ break;
+ case ISD::EXTLOAD:
+ case ISD::ZEXTLOAD:
+ switch (OrigType.getSimpleVT().SimpleTy) {
+ default:
+ llvm_unreachable("Unhandled integer load type");
+ break;
+ case MVT::i32:
+ CvtOpc = NVPTX::CVT_u32_u8;
+ break;
+ case MVT::i64:
+ CvtOpc = NVPTX::CVT_u64_u8;
+ break;
+ }
+ break;
+ }
+
+ // For each output value, truncate to i8 (since the upper 8 bits are
+ // undefined) and then extend to the desired type.
+ for (unsigned i = 0; i != NumElts; ++i) {
+ SDValue Res(LD, i);
+ SDValue OrigVal(N, i);
+
+ SDNode *CvtNode =
+ CurDAG->getMachineNode(CvtOpc, DL, OrigType, Res,
+ CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32));
+ ReplaceUses(OrigVal, SDValue(CvtNode, 0));
+ }
+ }
+
return LD;
}