[globalisel] Add G_SEXT_INREG

Summary:
Targets often have instructions that can sign-extend certain cases faster
than the equivalent shift-left/arithmetic-shift-right. Such cases can be
identified by matching a shift-left/shift-right pair but there are some
issues with this in the context of combines. For example, suppose you can
sign-extend 8-bit up to 32-bit with a target extend instruction.
  %1:_(s32) = G_SHL %0:_(s32), i32 24 # (I've inlined the G_CONSTANT for brevity)
  %2:_(s32) = G_ASHR %1:_(s32), i32 24
  %3:_(s32) = G_ASHR %2:_(s32), i32 1
would reasonably combine to:
  %1:_(s32) = G_SHL %0:_(s32), i32 24
  %2:_(s32) = G_ASHR %1:_(s32), i32 25
which no longer matches the special case. If your shifts and extend are
equal cost, this would break even as a pair of shifts but if your shift is
more expensive than the extend then it's cheaper as:
  %2:_(s32) = G_SEXT_INREG %0:_(s32), i32 8
  %3:_(s32) = G_ASHR %2:_(s32), i32 1
It's possible to match the shift-pair in ISel and emit an extend and ashr.
However, this is far from the only way to break this shift pair and make
it hard to match the extends. Another example is that with the right
known-zeros, this:
  %1:_(s32) = G_SHL %0:_(s32), i32 24
  %2:_(s32) = G_ASHR %1:_(s32), i32 24
  %3:_(s32) = G_MUL %2:_(s32), i32 2
can become:
  %1:_(s32) = G_SHL %0:_(s32), i32 24
  %2:_(s32) = G_ASHR %1:_(s32), i32 23

All upstream targets have been configured to lower G_SEXT_INREG to the
current G_SHL/G_ASHR pair, but they will likely want to make it legal in
some cases to handle their faster cases.

To follow up: Provide a way to legalize based on the constant. At the
moment, I'm thinking that the best way to achieve this is to provide the
MI in LegalityQuery but that opens the door to breaking core principles
of the legalizer (legality is not context sensitive). That said, it's
worth noting that looking at other instructions and acting on that
information doesn't violate this principle in itself. It's only a
violation if, at the end of legalization, a pass that checks legality
without being able to see the context would say an instruction might not be
legal. That's a fairly subtle distinction so to give a concrete example,
saying %2 in:
  %1 = G_CONSTANT 16
  %2 = G_SEXT_INREG %0, %1
is legal would violate that principle if the legality of %2 depends on
%1 being constant and/or being 16. However, legalizing to either:
  %2 = G_SEXT_INREG %0, 16
or:
  %1 = G_CONSTANT 16
  %2:_(s32) = G_SHL %0, %1
  %3:_(s32) = G_ASHR %2, %1
depending on whether %1 is constant and 16 does not violate that principle
since both outputs are genuinely legal.

Reviewers: bogner, aditya_nandakumar, volkan, aemerson, paquette, arsenm

Subscribers: sdardis, jvesely, wdng, nhaehnle, rovka, kristof.beyls, javed.absar, hiraditya, jrtc27, atanasyan, Petar.Avramovic, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D61289

llvm-svn: 368487
diff --git a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
index 461bc60..51a7479 100644
--- a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
@@ -162,6 +162,17 @@
       return buildConstant(DstOps[0], Cst->getSExtValue());
     break;
   }
+  case TargetOpcode::G_SEXT_INREG: {
+    assert(DstOps.size() == 1 && "Invalid dst ops");
+    assert(SrcOps.size() == 2 && "Invalid src ops");
+    const DstOp &Dst = DstOps[0];
+    const SrcOp &Src0 = SrcOps[0];
+    const SrcOp &Src1 = SrcOps[1];
+    if (auto MaybeCst =
+            ConstantFoldExtOp(Opc, Src0.getReg(), Src1.getImm(), *getMRI()))
+      return buildConstant(Dst, MaybeCst->getSExtValue());
+    break;
+  }
   }
   bool CanCopy = checkCopyToDefsPossible(DstOps);
   if (!canPerformCSEForOpc(Opc))
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 500bae4..e2b5082 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -861,6 +861,98 @@
     MI.eraseFromParent();
     return Legalized;
   }
+  case TargetOpcode::G_SEXT_INREG: {
+    if (TypeIdx != 0)
+      return UnableToLegalize;
+
+    if (!MI.getOperand(2).isImm())
+      return UnableToLegalize;
+    int64_t SizeInBits = MI.getOperand(2).getImm();
+
+    // So long as the new type has more bits than the bits we're extending we
+    // don't need to break it apart.
+    if (NarrowTy.getScalarSizeInBits() >= SizeInBits) {
+      Observer.changingInstr(MI);
+      // We don't lose any non-extension bits by truncating the src and
+      // sign-extending the dst.
+      MachineOperand &MO1 = MI.getOperand(1);
+      auto TruncMIB = MIRBuilder.buildTrunc(NarrowTy, MO1.getReg());
+      MO1.setReg(TruncMIB->getOperand(0).getReg());
+
+      MachineOperand &MO2 = MI.getOperand(0);
+      Register DstExt = MRI.createGenericVirtualRegister(NarrowTy);
+      MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
+      MIRBuilder.buildInstr(TargetOpcode::G_SEXT, {MO2.getReg()}, {DstExt});
+      MO2.setReg(DstExt);
+      Observer.changedInstr(MI);
+      return Legalized;
+    }
+
+    // Break it apart. Components below the extension point are unmodified. The
+    // component containing the extension point becomes a narrower SEXT_INREG.
+    // Components above it are ashr'd from the component containing the
+    // extension point.
+    if (SizeOp0 % NarrowSize != 0)
+      return UnableToLegalize;
+    int NumParts = SizeOp0 / NarrowSize;
+
+    // List the registers where the destination will be scattered.
+    SmallVector<Register, 2> DstRegs;
+    // List the registers where the source will be split.
+    SmallVector<Register, 2> SrcRegs;
+
+    // Create all the temporary registers.
+    for (int i = 0; i < NumParts; ++i) {
+      Register SrcReg = MRI.createGenericVirtualRegister(NarrowTy);
+
+      SrcRegs.push_back(SrcReg);
+    }
+
+    // Explode the big arguments into smaller chunks.
+    MIRBuilder.buildUnmerge(SrcRegs, MI.getOperand(1).getReg());
+
+    Register AshrCstReg =
+        MIRBuilder.buildConstant(NarrowTy, NarrowTy.getScalarSizeInBits() - 1)
+            ->getOperand(0)
+            .getReg();
+    Register FullExtensionReg = 0;
+    Register PartialExtensionReg = 0;
+
+    // Do the operation on each small part.
+    for (int i = 0; i < NumParts; ++i) {
+      if ((i + 1) * NarrowTy.getScalarSizeInBits() < SizeInBits)
+        DstRegs.push_back(SrcRegs[i]);
+      else if (i * NarrowTy.getScalarSizeInBits() > SizeInBits) {
+        assert(PartialExtensionReg &&
+               "Expected to visit partial extension before full");
+        if (FullExtensionReg) {
+          DstRegs.push_back(FullExtensionReg);
+          continue;
+        }
+        DstRegs.push_back(MIRBuilder
+                              .buildInstr(TargetOpcode::G_ASHR, {NarrowTy},
+                                          {PartialExtensionReg, AshrCstReg})
+                              ->getOperand(0)
+                              .getReg());
+        FullExtensionReg = DstRegs.back();
+      } else {
+        DstRegs.push_back(
+            MIRBuilder
+                .buildInstr(
+                    TargetOpcode::G_SEXT_INREG, {NarrowTy},
+                    {SrcRegs[i], SizeInBits % NarrowTy.getScalarSizeInBits()})
+                ->getOperand(0)
+                .getReg());
+        PartialExtensionReg = DstRegs.back();
+      }
+    }
+
+    // Gather the destination registers into the final destination.
+    Register DstReg = MI.getOperand(0).getReg();
+    MIRBuilder.buildMerge(DstReg, DstRegs);
+    MI.eraseFromParent();
+    return Legalized;
+  }
   }
 }
 
@@ -1633,6 +1725,15 @@
     Observer.changedInstr(MI);
     return Legalized;
   }
+  case TargetOpcode::G_SEXT_INREG:
+    if (TypeIdx != 0)
+      return UnableToLegalize;
+
+    Observer.changingInstr(MI);
+    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
+    widenScalarDst(MI, WideTy, 0, TargetOpcode::G_TRUNC);
+    Observer.changedInstr(MI);
+    return Legalized;
   }
 }
 
@@ -1980,6 +2081,21 @@
     return lowerFMinNumMaxNum(MI);
   case G_UNMERGE_VALUES:
     return lowerUnmergeValues(MI);
+  case TargetOpcode::G_SEXT_INREG: {
+    assert(MI.getOperand(2).isImm() && "Expected immediate");
+    int64_t SizeInBits = MI.getOperand(2).getImm();
+
+    Register DstReg = MI.getOperand(0).getReg();
+    Register SrcReg = MI.getOperand(1).getReg();
+    LLT DstTy = MRI.getType(DstReg);
+    Register TmpRes = MRI.createGenericVirtualRegister(DstTy);
+
+    auto MIBSz = MIRBuilder.buildConstant(DstTy, DstTy.getScalarSizeInBits() - SizeInBits);
+    MIRBuilder.buildInstr(TargetOpcode::G_SHL, {TmpRes}, {SrcReg, MIBSz->getOperand(0).getReg()});
+    MIRBuilder.buildInstr(TargetOpcode::G_ASHR, {DstReg}, {TmpRes, MIBSz->getOperand(0).getReg()});
+    MI.eraseFromParent();
+    return Legalized;
+  }
   }
 }
 
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerInfo.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerInfo.cpp
index 6e1de95..ebe3b7c 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerInfo.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerInfo.cpp
@@ -215,7 +215,30 @@
     return true;
   }
   const bool AllCovered = (FirstUncovered >= NumTypeIdxs);
-  LLVM_DEBUG(dbgs() << ".. the first uncovered type index: " << FirstUncovered
+  if (NumTypeIdxs > 0)
+    LLVM_DEBUG(dbgs() << ".. the first uncovered type index: " << FirstUncovered
+                      << ", " << (AllCovered ? "OK" : "FAIL") << "\n");
+  return AllCovered;
+#else
+  return true;
+#endif
+}
+
+bool LegalizeRuleSet::verifyImmIdxsCoverage(unsigned NumImmIdxs) const {
+#ifndef NDEBUG
+  if (Rules.empty()) {
+    LLVM_DEBUG(
+        dbgs() << ".. imm index coverage check SKIPPED: no rules defined\n");
+    return true;
+  }
+  const int64_t FirstUncovered = ImmIdxsCovered.find_first_unset();
+  if (FirstUncovered < 0) {
+    LLVM_DEBUG(dbgs() << ".. imm index coverage check SKIPPED:"
+                         " user-defined predicate detected\n");
+    return true;
+  }
+  const bool AllCovered = (FirstUncovered >= NumImmIdxs);
+  LLVM_DEBUG(dbgs() << ".. the first uncovered imm index: " << FirstUncovered
                     << ", " << (AllCovered ? "OK" : "FAIL") << "\n");
   return AllCovered;
 #else
@@ -387,8 +410,6 @@
     LLVM_DEBUG(dbgs() << ".. opcode " << Opcode << " is aliased to " << Alias
                       << "\n");
     OpcodeIdx = getOpcodeIdxForOpcode(Alias);
-    LLVM_DEBUG(dbgs() << ".. opcode " << Alias << " is aliased to "
-                      << RulesForOpcode[OpcodeIdx].getAlias() << "\n");
     assert(RulesForOpcode[OpcodeIdx].getAlias() == 0 && "Cannot chain aliases");
   }
 
@@ -677,12 +698,23 @@
                      ? std::max(OpInfo.getGenericTypeIndex() + 1U, Acc)
                      : Acc;
         });
+    const unsigned NumImmIdxs = std::accumulate(
+        MCID.opInfo_begin(), MCID.opInfo_end(), 0U,
+        [](unsigned Acc, const MCOperandInfo &OpInfo) {
+          return OpInfo.isGenericImm()
+                     ? std::max(OpInfo.getGenericImmIndex() + 1U, Acc)
+                     : Acc;
+        });
     LLVM_DEBUG(dbgs() << MII.getName(Opcode) << " (opcode " << Opcode
                       << "): " << NumTypeIdxs << " type ind"
-                      << (NumTypeIdxs == 1 ? "ex" : "ices") << "\n");
+                      << (NumTypeIdxs == 1 ? "ex" : "ices") << ", "
+                      << NumImmIdxs << " imm ind"
+                      << (NumImmIdxs == 1 ? "ex" : "ices") << "\n");
     const LegalizeRuleSet &RuleSet = getActionDefinitions(Opcode);
     if (!RuleSet.verifyTypeIdxsCoverage(NumTypeIdxs))
       FailedOpcodes.push_back(Opcode);
+    else if (!RuleSet.verifyImmIdxsCoverage(NumImmIdxs))
+      FailedOpcodes.push_back(Opcode);
   }
   if (!FailedOpcodes.empty()) {
     errs() << "The following opcodes have ill-defined legalization rules:";
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index 712f0db..907cb67 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -392,6 +392,23 @@
   return false;
 }
 
+Optional<APInt> llvm::ConstantFoldExtOp(unsigned Opcode, const unsigned Op1,
+                                        uint64_t Imm,
+                                        const MachineRegisterInfo &MRI) {
+  auto MaybeOp1Cst = getConstantVRegVal(Op1, MRI);
+  if (MaybeOp1Cst) {
+    LLT Ty = MRI.getType(Op1);
+    APInt C1(Ty.getSizeInBits(), *MaybeOp1Cst, true);
+    switch (Opcode) {
+    default:
+      break;
+    case TargetOpcode::G_SEXT_INREG:
+      return C1.trunc(Imm).sext(C1.getBitWidth());
+    }
+  }
+  return None;
+}
+
 void llvm::getSelectionDAGFallbackAnalysisUsage(AnalysisUsage &AU) {
   AU.addPreserved<StackProtector>();
 }
diff --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp
index 9346638..49f0c02 100644
--- a/llvm/lib/CodeGen/MachineVerifier.cpp
+++ b/llvm/lib/CodeGen/MachineVerifier.cpp
@@ -1368,7 +1368,23 @@
         break;
       }
     }
+    break;
+  }
+  case TargetOpcode::G_SEXT_INREG: {
+    if (!MI->getOperand(2).isImm()) {
+      report("G_SEXT_INREG expects an immediate operand #2", MI);
+      break;
+    }
 
+    LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
+    LLT SrcTy = MRI->getType(MI->getOperand(1).getReg());
+    verifyVectorElementMatch(DstTy, SrcTy, MI);
+
+    int64_t Imm = MI->getOperand(2).getImm();
+    if (Imm <= 0)
+      report("G_SEXT_INREG size must be >= 1", MI);
+    if (Imm >= SrcTy.getScalarSizeInBits())
+      report("G_SEXT_INREG size must be less than source bit width", MI);
     break;
   }
   default:
diff --git a/llvm/lib/Target/AArch64/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/AArch64LegalizerInfo.cpp
index 79a2167..3992e0e 100644
--- a/llvm/lib/Target/AArch64/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64LegalizerInfo.cpp
@@ -370,6 +370,8 @@
 
   getActionDefinitionsBuilder(G_TRUNC).alwaysLegal();
 
+  getActionDefinitionsBuilder(G_SEXT_INREG).lower();
+
   // FP conversions
   getActionDefinitionsBuilder(G_FPTRUNC).legalFor(
       {{s16, s32}, {s16, s64}, {s32, s64}, {v4s16, v4s32}, {v2s32, v2s64}});
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index 663cdd7..37222d9 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -774,6 +774,8 @@
       .scalarize(1);
   }
 
+  getActionDefinitionsBuilder(G_SEXT_INREG).lower();
+
   computeTables();
   verify(*ST.getInstrInfo());
 }
diff --git a/llvm/lib/Target/ARM/ARMLegalizerInfo.cpp b/llvm/lib/Target/ARM/ARMLegalizerInfo.cpp
index 73a57b2..81414e6 100644
--- a/llvm/lib/Target/ARM/ARMLegalizerInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMLegalizerInfo.cpp
@@ -84,6 +84,8 @@
   getActionDefinitionsBuilder({G_SEXT, G_ZEXT, G_ANYEXT})
       .legalForCartesianProduct({s8, s16, s32}, {s1, s8, s16});
 
+  getActionDefinitionsBuilder(G_SEXT_INREG).lower();
+
   getActionDefinitionsBuilder({G_MUL, G_AND, G_OR, G_XOR})
       .legalFor({s32})
       .minScalar(0, s32);
diff --git a/llvm/lib/Target/Mips/MipsLegalizerInfo.cpp b/llvm/lib/Target/Mips/MipsLegalizerInfo.cpp
index ea7cc09..558af42 100644
--- a/llvm/lib/Target/Mips/MipsLegalizerInfo.cpp
+++ b/llvm/lib/Target/Mips/MipsLegalizerInfo.cpp
@@ -144,6 +144,8 @@
       .libcallForCartesianProduct({s64, s32}, {s64})
       .minScalar(1, s32);
 
+  getActionDefinitionsBuilder(G_SEXT_INREG).lower();
+
   computeTables();
   verify(*ST.getInstrInfo());
 }
diff --git a/llvm/lib/Target/X86/X86LegalizerInfo.cpp b/llvm/lib/Target/X86/X86LegalizerInfo.cpp
index 9690056..04121f8 100644
--- a/llvm/lib/Target/X86/X86LegalizerInfo.cpp
+++ b/llvm/lib/Target/X86/X86LegalizerInfo.cpp
@@ -177,6 +177,7 @@
     setAction({G_ANYEXT, Ty}, Legal);
   }
   setAction({G_ANYEXT, s128}, Legal);
+  getActionDefinitionsBuilder(G_SEXT_INREG).lower();
 
   // Comparison
   setAction({G_ICMP, s1}, Legal);