AMDGPU: Fix creating invalid copy when adjusting dmask

Move the entire optimization to one place. Previously it was possible
to adjust the dmask without changing the register class of the output
instruction, because the two updates were done in separate places.
Handle all lane counts (not just results that start out as VReg_128)
and move all of the optimization into the DAG folding.

llvm-svn: 319705
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index bab7739..18dc0fb 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -6586,9 +6586,9 @@
 }
 
 /// \brief Adjust the writemask of MIMG instructions
-void SITargetLowering::adjustWritemask(MachineSDNode *&Node,
-                                       SelectionDAG &DAG) const {
-  SDNode *Users[4] = { };
+SDNode *SITargetLowering::adjustWritemask(MachineSDNode *&Node,
+                                          SelectionDAG &DAG) const {
+  SDNode *Users[4] = { nullptr };
   unsigned Lane = 0;
   unsigned DmaskIdx = (Node->getNumOperands() - Node->getNumValues() == 9) ? 2 : 3;
   unsigned OldDmask = Node->getConstantOperandVal(DmaskIdx);
@@ -6605,7 +6605,7 @@
     // Abort if we can't understand the usage
     if (!I->isMachineOpcode() ||
         I->getMachineOpcode() != TargetOpcode::EXTRACT_SUBREG)
-      return;
+      return Node;
 
     // Lane means which subreg of %vgpra_vgprb_vgprc_vgprd is used.
     // Note that subregs are packed, i.e. Lane==0 is the first bit set
@@ -6623,7 +6623,7 @@
 
     // Abort if we have more than one user per component
     if (Users[Lane])
-      return;
+      return Node;
 
     Users[Lane] = *I;
     NewDmask |= 1 << Comp;
@@ -6631,25 +6631,41 @@
 
   // Abort if there's no change
   if (NewDmask == OldDmask)
-    return;
+    return Node;
+
+  unsigned BitsSet = countPopulation(NewDmask);
+
+  const SIInstrInfo *TII = getSubtarget()->getInstrInfo();
+  int NewOpcode = TII->getMaskedMIMGOp(Node->getMachineOpcode(), BitsSet);
+  assert(NewOpcode != -1 &&
+         NewOpcode != static_cast<int>(Node->getMachineOpcode()) &&
+         "failed to find equivalent MIMG op");
 
   // Adjust the writemask in the node
-  std::vector<SDValue> Ops;
+  SmallVector<SDValue, 12> Ops;
   Ops.insert(Ops.end(), Node->op_begin(), Node->op_begin() + DmaskIdx);
   Ops.push_back(DAG.getTargetConstant(NewDmask, SDLoc(Node), MVT::i32));
   Ops.insert(Ops.end(), Node->op_begin() + DmaskIdx + 1, Node->op_end());
-  Node = (MachineSDNode*)DAG.UpdateNodeOperands(Node, Ops);
 
-  // If we only got one lane, replace it with a copy
-  // (if NewDmask has only one bit set...)
-  if (NewDmask && (NewDmask & (NewDmask-1)) == 0) {
-    SDValue RC = DAG.getTargetConstant(AMDGPU::VGPR_32RegClassID, SDLoc(),
-                                       MVT::i32);
-    SDNode *Copy = DAG.getMachineNode(TargetOpcode::COPY_TO_REGCLASS,
-                                      SDLoc(), Users[Lane]->getValueType(0),
-                                      SDValue(Node, 0), RC);
+  MVT SVT = Node->getValueType(0).getVectorElementType().getSimpleVT();
+
+  auto NewVTList =
+    DAG.getVTList(BitsSet == 1 ?
+                  SVT : MVT::getVectorVT(SVT, BitsSet == 3 ? 4 : BitsSet),
+                  MVT::Other);
+
+  MachineSDNode *NewNode = DAG.getMachineNode(NewOpcode, SDLoc(Node),
+                                              NewVTList, Ops);
+  // Update chain.
+  DAG.ReplaceAllUsesOfValueWith(SDValue(Node, 1), SDValue(NewNode, 1));
+
+  if (BitsSet == 1) {
+    assert(Node->hasNUsesOfValue(1, 0));
+    SDNode *Copy = DAG.getMachineNode(TargetOpcode::COPY,
+                                      SDLoc(Node), Users[Lane]->getValueType(0),
+                                      SDValue(NewNode, 0));
     DAG.ReplaceAllUsesWith(Users[Lane], Copy);
-    return;
+    return nullptr;
   }
 
   // Update the users of the node with the new indices
@@ -6659,7 +6675,7 @@
       continue;
 
     SDValue Op = DAG.getTargetConstant(Idx, SDLoc(User), MVT::i32);
-    DAG.UpdateNodeOperands(User, User->getOperand(0), Op);
+    DAG.UpdateNodeOperands(User, SDValue(NewNode, 0), Op);
 
     switch (Idx) {
     default: break;
@@ -6668,6 +6684,9 @@
     case AMDGPU::sub2: Idx = AMDGPU::sub3; break;
     }
   }
+
+  DAG.RemoveDeadNode(Node);
+  return nullptr;
 }
 
 static bool isFrameIndexOp(SDValue Op) {
@@ -6725,14 +6744,16 @@
 }
 
 /// \brief Fold the instructions after selecting them.
+/// Returns null if users were already updated.
 SDNode *SITargetLowering::PostISelFolding(MachineSDNode *Node,
                                           SelectionDAG &DAG) const {
   const SIInstrInfo *TII = getSubtarget()->getInstrInfo();
   unsigned Opcode = Node->getMachineOpcode();
 
   if (TII->isMIMG(Opcode) && !TII->get(Opcode).mayStore() &&
-      !TII->isGather4(Opcode))
-    adjustWritemask(Node, DAG);
+      !TII->isGather4(Opcode)) {
+    return adjustWritemask(Node, DAG);
+  }
 
   if (Opcode == AMDGPU::INSERT_SUBREG ||
       Opcode == AMDGPU::REG_SEQUENCE) {
@@ -6810,31 +6831,6 @@
     return;
   }
 
-  if (TII->isMIMG(MI)) {
-    unsigned VReg = MI.getOperand(0).getReg();
-    const TargetRegisterClass *RC = MRI.getRegClass(VReg);
-    // TODO: Need mapping tables to handle other cases (register classes).
-    if (RC != &AMDGPU::VReg_128RegClass)
-      return;
-
-    unsigned DmaskIdx = MI.getNumOperands() == 12 ? 3 : 4;
-    unsigned Writemask = MI.getOperand(DmaskIdx).getImm();
-    unsigned BitsSet = 0;
-    for (unsigned i = 0; i < 4; ++i)
-      BitsSet += Writemask & (1 << i) ? 1 : 0;
-    switch (BitsSet) {
-    default: return;
-    case 1:  RC = &AMDGPU::VGPR_32RegClass; break;
-    case 2:  RC = &AMDGPU::VReg_64RegClass; break;
-    case 3:  RC = &AMDGPU::VReg_96RegClass; break;
-    }
-
-    unsigned NewOpcode = TII->getMaskedMIMGOp(MI.getOpcode(), BitsSet);
-    MI.setDesc(TII->get(NewOpcode));
-    MRI.setRegClass(VReg, RC);
-    return;
-  }
-
   // Replace unused atomics with the no return version.
   int NoRetAtomicOp = AMDGPU::getAtomicNoRetOp(MI.getOpcode());
   if (NoRetAtomicOp != -1) {