AMDGPU/GlobalISel: Select G_TRUNC

llvm-svn: 364215
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index 590fd25..45bf1f6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -555,6 +555,96 @@
   return Ret;
 }
 
+/// Map a bit width to the sub-register index (sub0, sub0_sub1, ...) that
+/// selectG_TRUNC uses to narrow a wider source register.
+///
+/// Widths below 32 bits use sub0, widths without an exact index are rounded
+/// up to the next power of two, and -1 is returned for widths above 256 bits.
+static int sizeToSubRegIndex(unsigned Size) {
+  switch (Size) {
+  case 32:
+    return AMDGPU::sub0;
+  case 64:
+    return AMDGPU::sub0_sub1;
+  case 96:
+    return AMDGPU::sub0_sub1_sub2;
+  case 128:
+    return AMDGPU::sub0_sub1_sub2_sub3;
+  case 256:
+    return AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5_sub6_sub7;
+  default:
+    if (Size < 32)
+      return AMDGPU::sub0;
+    if (Size > 256)
+      return -1;
+    // E.g. 48 -> 64 and 160 -> 256. The rounded width is always one of the
+    // cases above, so this recurses at most once.
+    return sizeToSubRegIndex(PowerOf2Ceil(Size));
+  }
+}
+
+/// Select a scalar G_TRUNC as a COPY. When the source is wider than 32 bits,
+/// the copy reads the source through the sub-register index matching the
+/// destination width.
+///
+/// Returns false if the destination is not a scalar, the operands are on
+/// different or unassigned register banks, no register class or sub-register
+/// index exists for the sizes involved, or constraining the registers fails.
+bool AMDGPUInstructionSelector::selectG_TRUNC(MachineInstr &I) const {
+  MachineBasicBlock *BB = I.getParent();
+  MachineFunction *MF = BB->getParent();
+  MachineRegisterInfo &MRI = MF->getRegInfo();
+
+  unsigned DstReg = I.getOperand(0).getReg();
+  unsigned SrcReg = I.getOperand(1).getReg();
+  const LLT DstTy = MRI.getType(DstReg);
+  const LLT SrcTy = MRI.getType(SrcReg);
+  if (!DstTy.isScalar())
+    return false;
+
+  const RegisterBank *DstRB = RBI.getRegBank(DstReg, MRI, TRI);
+  const RegisterBank *SrcRB = RBI.getRegBank(SrcReg, MRI, TRI);
+  // Only same-bank truncates are handled. Also reject an unassigned bank,
+  // since both bank pointers are dereferenced below.
+  if (!SrcRB || SrcRB != DstRB)
+    return false;
+
+  unsigned DstSize = DstTy.getSizeInBits();
+  unsigned SrcSize = SrcTy.getSizeInBits();
+
+  const TargetRegisterClass *SrcRC
+    = TRI.getRegClassForSizeOnBank(SrcSize, *SrcRB, MRI);
+  const TargetRegisterClass *DstRC
+    = TRI.getRegClassForSizeOnBank(DstSize, *DstRB, MRI);
+  if (!SrcRC || !DstRC)
+    return false;
+
+  int SubRegIdx = -1;
+  if (SrcSize > 32) {
+    SubRegIdx = sizeToSubRegIndex(DstSize);
+    if (SubRegIdx == -1)
+      return false;
+
+    // Deal with weird cases where the class only partially supports the subreg
+    // index.
+    SrcRC = TRI.getSubClassWithSubReg(SrcRC, SubRegIdx);
+    if (!SrcRC)
+      return false;
+  }
+
+  if (!RBI.constrainGenericRegister(SrcReg, *SrcRC, MRI) ||
+      !RBI.constrainGenericRegister(DstReg, *DstRC, MRI)) {
+    LLVM_DEBUG(dbgs() << "Failed to constrain G_TRUNC\n");
+    return false;
+  }
+
+  // Only attach the sub-register index once selection can no longer fail.
+  if (SubRegIdx != -1)
+    I.getOperand(1).setSubReg(SubRegIdx);
+  I.setDesc(TII.get(TargetOpcode::COPY));
+  return true;
+}
+
 bool AMDGPUInstructionSelector::selectG_CONSTANT(MachineInstr &I) const {
   MachineBasicBlock *BB = I.getParent();
   MachineFunction *MF = BB->getParent();
@@ -770,6 +840,8 @@
     return selectG_SELECT(I);
   case TargetOpcode::G_STORE:
     return selectG_STORE(I);
+  case TargetOpcode::G_TRUNC:
+    return selectG_TRUNC(I);
   }
   return false;
 }
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
index 35f528f..8cee015 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
@@ -63,6 +63,7 @@
 
   MachineOperand getSubOperand64(MachineOperand &MO, unsigned SubIdx) const;
   bool selectCOPY(MachineInstr &I) const;
+  bool selectG_TRUNC(MachineInstr &I) const;
   bool selectG_CONSTANT(MachineInstr &I) const;
   bool selectG_ADD(MachineInstr &I) const;
   bool selectG_EXTRACT(MachineInstr &I) const;
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
index fe675b5..6e7d47b 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
@@ -1679,17 +1679,12 @@
 }
 
 const TargetRegisterClass *
-SIRegisterInfo::getConstrainedRegClassForOperand(const MachineOperand &MO,
+SIRegisterInfo::getRegClassForSizeOnBank(unsigned Size,
+                                         const RegisterBank &RB,
                                          const MachineRegisterInfo &MRI) const {
-  unsigned Size = getRegSizeInBits(MO.getReg(), MRI);
-  const RegisterBank *RB = MRI.getRegBankOrNull(MO.getReg());
-  if (!RB)
-    return nullptr;
-
-  Size = PowerOf2Ceil(Size);
   switch (Size) {
   case 1: {
-    switch (RB->getID()) {
+    switch (RB.getID()) {
     case AMDGPU::VGPRRegBankID:
       return &AMDGPU::VGPR_32RegClass;
     case AMDGPU::VCCRegBankID:
@@ -1706,30 +1701,46 @@
     }
   }
   case 32:
-    return RB->getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VGPR_32RegClass :
-                                                  &AMDGPU::SReg_32_XM0RegClass;
+    return RB.getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VGPR_32RegClass :
+                                                 &AMDGPU::SReg_32_XM0RegClass;
   case 64:
-    return RB->getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VReg_64RegClass :
-                                                   &AMDGPU::SReg_64_XEXECRegClass;
+    return RB.getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VReg_64RegClass :
+                                                  &AMDGPU::SReg_64_XEXECRegClass;
   case 96:
-    return RB->getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VReg_96RegClass :
-                                                  &AMDGPU::SReg_96RegClass;
+    return RB.getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VReg_96RegClass :
+                                                 &AMDGPU::SReg_96RegClass;
   case 128:
-    return RB->getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VReg_128RegClass :
-                                                  &AMDGPU::SReg_128RegClass;
+    return RB.getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VReg_128RegClass :
+                                                 &AMDGPU::SReg_128RegClass;
   case 160:
-    return RB->getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VReg_160RegClass :
-                                                  &AMDGPU::SReg_160RegClass;
+    return RB.getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VReg_160RegClass :
+                                                 &AMDGPU::SReg_160RegClass;
   case 256:
-    return RB->getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VReg_256RegClass :
-                                                  &AMDGPU::SReg_256RegClass;
+    return RB.getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VReg_256RegClass :
+                                                 &AMDGPU::SReg_256RegClass;
   case 512:
-    return RB->getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VReg_512RegClass :
-                                                  &AMDGPU::SReg_512RegClass;
+    return RB.getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VReg_512RegClass :
+                                                 &AMDGPU::SReg_512RegClass;
   default:
-    break;
+    if (Size < 32)
+      return RB.getID() == AMDGPU::VGPRRegBankID ? &AMDGPU::VGPR_32RegClass :
+                                                   &AMDGPU::SReg_32_XM0RegClass;
+    // Round any other size up to the next power of two. 512 bits is the
+    // widest class above; a wider size has no class and rounding it would
+    // recurse forever (PowerOf2Ceil(1024) == 1024), so return nullptr.
+    if (Size > 512)
+      return nullptr;
+    return getRegClassForSizeOnBank(PowerOf2Ceil(Size), RB, MRI);
   }
-  llvm_unreachable("not implemented");
+}
+
+const TargetRegisterClass *
+SIRegisterInfo::getConstrainedRegClassForOperand(
+    const MachineOperand &MO, const MachineRegisterInfo &MRI) const {
+  // Operands without an assigned register bank get no class (nullptr).
+  if (const RegisterBank *RB = MRI.getRegBankOrNull(MO.getReg()))
+    return getRegClassForTypeOnBank(MRI.getType(MO.getReg()), *RB, MRI);
+  return nullptr;
 }
 
 unsigned SIRegisterInfo::getVCC() const {
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.h b/llvm/lib/Target/AMDGPU/SIRegisterInfo.h
index 8a42b09..4f11403 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.h
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.h
@@ -229,6 +229,23 @@
   unsigned getReturnAddressReg(const MachineFunction &MF) const;
 
+  /// Return the register class for a value of \p Size bits assigned to
+  /// register bank \p Bank. Sizes below 32 bits (other than 1) use a 32-bit
+  /// class; other sizes without a dedicated class are rounded up to the next
+  /// power of two.
+  const TargetRegisterClass *
+  getRegClassForSizeOnBank(unsigned Size,
+                           const RegisterBank &Bank,
+                           const MachineRegisterInfo &MRI) const;
+
+  /// Same as getRegClassForSizeOnBank, using the size of \p Ty in bits.
+  const TargetRegisterClass *
+  getRegClassForTypeOnBank(LLT Ty,
+                           const RegisterBank &Bank,
+                           const MachineRegisterInfo &MRI) const {
+    return getRegClassForSizeOnBank(Ty.getSizeInBits(), Bank, MRI);
+  }
+
   const TargetRegisterClass *
   getConstrainedRegClassForOperand(const MachineOperand &MO,
                                  const MachineRegisterInfo &MRI) const override;