AMDGPU/GlobalISel: Legalize G_FRINT

llvm-svn: 361026
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index 63b517f..64ae29e 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -293,6 +293,18 @@
     .legalFor({S32, S64})
     .scalarize(0);
 
+  // f64 G_FRINT is legal only from SEA_ISLANDS onwards; older targets
+  // custom-expand it in legalizeFrint. f32 is legal on every generation.
+  auto &FRintActions = getActionDefinitionsBuilder(G_FRINT);
+  if (ST.getGeneration() >= AMDGPUSubtarget::SEA_ISLANDS) {
+    FRintActions.legalFor({S32, S64});
+  } else {
+    FRintActions.legalFor({S32})
+      .customFor({S64});
+  }
+  FRintActions
+    .clampScalar(0, S32, S64)
+    .scalarize(0);
 
   getActionDefinitionsBuilder(G_GEP)
     .legalForCartesianProduct(AddrSpaces64, {S64})
@@ -675,6 +687,8 @@
   switch (MI.getOpcode()) {
   case TargetOpcode::G_ADDRSPACE_CAST:
     return legalizeAddrSpaceCast(MI, MRI, MIRBuilder);
+  case TargetOpcode::G_FRINT: // Custom only for s64 before SEA_ISLANDS.
+    return legalizeFrint(MI, MRI, MIRBuilder);
   default:
     return false;
   }
@@ -831,3 +845,45 @@
   MI.eraseFromParent();
   return true;
 }
+
+/// Custom legalization of an s64 G_FRINT (only requested for targets older
+/// than SEA_ISLANDS, per the G_FRINT rule set).
+///
+/// For |Src| < 2^52, adding copysign(2^52, Src) shifts every fractional bit
+/// out of the 52-bit significand, so the add itself rounds to an integer and
+/// subtracting the same constant recovers it. Any |Src| greater than
+/// 0x1.fffffffffffffp+51 (the largest double below 2^52) is already integral,
+/// as are the infinities, so Src is selected unchanged in that case.
+///
+/// NOTE(review): under round-to-nearest the add/sub path yields +0.0 for Src
+/// in [-0.5, -0.0], where rint gives -0.0; confirm if signed zero matters.
+bool AMDGPULegalizerInfo::legalizeFrint(
+  MachineInstr &MI, MachineRegisterInfo &MRI,
+  MachineIRBuilder &MIRBuilder) const {
+  MIRBuilder.setInstr(MI);
+
+  unsigned Src = MI.getOperand(1).getReg();
+  LLT Ty = MRI.getType(Src);
+  assert(Ty.isScalar() && Ty.getSizeInBits() == 64);
+
+  APFloat C1Val(APFloat::IEEEdouble(), "0x1.0p+52");
+  APFloat C2Val(APFloat::IEEEdouble(), "0x1.fffffffffffffp+51");
+
+  auto C1 = MIRBuilder.buildFConstant(Ty, C1Val);
+  auto CopySign = MIRBuilder.buildFCopysign(Ty, C1, Src);
+
+  // TODO: Should this propagate fast-math-flags?
+  auto Tmp1 = MIRBuilder.buildFAdd(Ty, Src, CopySign);
+  auto Tmp2 = MIRBuilder.buildFSub(Ty, Tmp1, CopySign);
+
+  auto C2 = MIRBuilder.buildFConstant(Ty, C2Val);
+  auto Fabs = MIRBuilder.buildFAbs(Ty, Src);
+
+  auto Cond = MIRBuilder.buildFCmp(CmpInst::FCMP_OGT, LLT::scalar(1), Fabs, C2);
+  MIRBuilder.buildSelect(MI.getOperand(0).getReg(), Cond, Src, Tmp2);
+
+  // The select above now defines the result register; remove the original
+  // G_FRINT so the destination keeps a single def.
+  MI.eraseFromParent();
+  return true;
+}