[NVPTX] Handle ldg created from sign-/zero-extended load

Reviewers: jingyue

Subscribers: jholewinski

Differential Revision: http://reviews.llvm.org/D18053

llvm-svn: 265389
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 2d0098b..b1ed2df 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1286,7 +1286,7 @@
   MemSDNode *Mem;
   bool IsLDG = true;
 
-  // If this is an LDG intrinsic, the address is the third operand. Its its an
+  // If this is an LDG intrinsic, the address is the third operand. If it's an
   // LDG/LDU SD node (from custom vector handling), then its the second operand
   if (N->getOpcode() == ISD::INTRINSIC_W_CHAIN) {
     Op1 = N->getOperand(2);
@@ -1317,10 +1317,23 @@
   SDValue Base, Offset, Addr;
 
   EVT EltVT = Mem->getMemoryVT();
+  unsigned NumElts = 1;
   if (EltVT.isVector()) {
+    NumElts = EltVT.getVectorNumElements();
     EltVT = EltVT.getVectorElementType();
   }
 
+  // Build the "promoted" result VTList for the load. If we are really loading
+  // i8s, then the return type will be promoted to i16 since we do not expose
+  // 8-bit registers in NVPTX.
+  EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
+  SmallVector<EVT, 5> InstVTs;
+  for (unsigned i = 0; i != NumElts; ++i) {
+    InstVTs.push_back(NodeVT);
+  }
+  InstVTs.push_back(MVT::Other);
+  SDVTList InstVTList = CurDAG->getVTList(InstVTs);
+
   if (SelectDirectAddr(Op1, Addr)) {
     switch (N->getOpcode()) {
     default:
@@ -1461,7 +1474,7 @@
     }
 
     SDValue Ops[] = { Addr, Chain };
-    LD = CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops);
+    LD = CurDAG->getMachineNode(Opcode, DL, InstVTList, Ops);
   } else if (TM.is64Bit() ? SelectADDRri64(Op1.getNode(), Op1, Base, Offset)
                           : SelectADDRri(Op1.getNode(), Op1, Base, Offset)) {
     if (TM.is64Bit()) {
@@ -1750,7 +1763,7 @@
 
     SDValue Ops[] = { Base, Offset, Chain };
 
-    LD = CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops);
+    LD = CurDAG->getMachineNode(Opcode, DL, InstVTList, Ops);
   } else {
     if (TM.is64Bit()) {
       switch (N->getOpcode()) {
@@ -2037,13 +2050,77 @@
     }
 
     SDValue Ops[] = { Op1, Chain };
-    LD = CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops);
+    LD = CurDAG->getMachineNode(Opcode, DL, InstVTList, Ops);
   }
 
   MachineSDNode::mmo_iterator MemRefs0 = MF->allocateMemRefsArray(1);
   MemRefs0[0] = Mem->getMemOperand();
   cast<MachineSDNode>(LD)->setMemRefs(MemRefs0, MemRefs0 + 1);
 
+  // For automatic generation of LDG (through SelectLoad[Vector], not the
+  // intrinsics), we may have an extending load like:
+  //
+  //   i32,ch = load<LD1[%data1(addrspace=1)], zext from i8> t0, t7, undef:i64
+  //
+  // Since we load an i8 value, the matching logic above will have selected an
+  // LDG instruction that reads i8 and stores it in an i16 register (NVPTX does
+  // not expose 8-bit registers):
+  //
+  //   i16,ch = INT_PTX_LDG_GLOBAL_i8areg64 t7, t0
+  //
+  // To get the correct type in this case, truncate back to i8 and then extend
+  // to the original load type.
+  EVT OrigType = N->getValueType(0);
+  LoadSDNode *LDSD = dyn_cast<LoadSDNode>(N);
+  if (LDSD && EltVT == MVT::i8 && OrigType.getScalarSizeInBits() >= 32) {
+    unsigned CvtOpc = 0;
+
+    switch (LDSD->getExtensionType()) {
+    default:
+      llvm_unreachable("An extension is required for i8 loads");
+      break;
+    case ISD::SEXTLOAD:
+      switch (OrigType.getSimpleVT().SimpleTy) {
+      default:
+        llvm_unreachable("Unhandled integer load type");
+        break;
+      case MVT::i32:
+        CvtOpc = NVPTX::CVT_s32_s8;
+        break;
+      case MVT::i64:
+        CvtOpc = NVPTX::CVT_s64_s8;
+        break;
+      }
+      break;
+    case ISD::EXTLOAD:
+    case ISD::ZEXTLOAD:
+      switch (OrigType.getSimpleVT().SimpleTy) {
+      default:
+        llvm_unreachable("Unhandled integer load type");
+        break;
+      case MVT::i32:
+        CvtOpc = NVPTX::CVT_u32_u8;
+        break;
+      case MVT::i64:
+        CvtOpc = NVPTX::CVT_u64_u8;
+        break;
+      }
+      break;
+    }
+
+    // For each output value, truncate to i8 (since the upper 8 bits are
+    // undefined) and then extend to the desired type.
+    for (unsigned i = 0; i != NumElts; ++i) {
+      SDValue Res(LD, i);
+      SDValue OrigVal(N, i);
+
+      SDNode *CvtNode =
+        CurDAG->getMachineNode(CvtOpc, DL, OrigType, Res,
+                               CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32));
+      ReplaceUses(OrigVal, SDValue(CvtNode, 0));
+    }
+  }
+
   return LD;
 }