[AArch64][GlobalISel] Use TST for comparisons when possible

Porting over the part of `emitComparison` in AArch64ISelLowering where we use
TST to represent a compare.

- Rename `tryOptCMN` to `tryFoldIntegerCompare`, since it now also emits TSTs
  when possible.

- Add a utility function for emitting a TST with register operands.

- Rename opt-fold-cmn.mir to opt-fold-compare.mir, since it now tests the TST
  fold as well.

Differential Revision: https://reviews.llvm.org/D64371

llvm-svn: 365404
diff --git a/llvm/lib/Target/AArch64/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/AArch64InstructionSelector.cpp
index 46d6ccb..bef690c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstructionSelector.cpp
@@ -130,6 +130,8 @@
                                    MachineIRBuilder &MIRBuilder) const;
   MachineInstr *emitCMN(MachineOperand &LHS, MachineOperand &RHS,
                         MachineIRBuilder &MIRBuilder) const;
+  MachineInstr *emitTST(const Register &LHS, const Register &RHS,
+                        MachineIRBuilder &MIRBuilder) const;
   MachineInstr *emitExtractVectorElt(Optional<Register> DstReg,
                                      const RegisterBank &DstRB, LLT ScalarTy,
                                      Register VecReg, unsigned LaneIdx,
@@ -202,9 +204,9 @@
   bool tryOptVectorShuffle(MachineInstr &I) const;
   bool tryOptVectorDup(MachineInstr &MI) const;
   bool tryOptSelect(MachineInstr &MI) const;
-  MachineInstr *tryOptCMN(MachineOperand &LHS, MachineOperand &RHS,
-                          MachineOperand &Predicate,
-                          MachineIRBuilder &MIRBuilder) const;
+  MachineInstr *tryFoldIntegerCompare(MachineOperand &LHS, MachineOperand &RHS,
+                                      MachineOperand &Predicate,
+                                      MachineIRBuilder &MIRBuilder) const;
 
   const AArch64TargetMachine &TM;
   const AArch64Subtarget &STI;
@@ -801,6 +803,19 @@
   return CmpOpcTbl[ShouldUseImm][OpSize == 64];
 }
 
+/// Returns true if \p P is an unsigned ICMP predicate (UGT, UGE, ULT or ULE).
+static bool isUnsignedICMPPred(const CmpInst::Predicate P) {
+  switch (P) {
+  default:
+    return false; // Everything else, including ICMP_EQ and ICMP_NE.
+  case CmpInst::ICMP_UGT:
+  case CmpInst::ICMP_UGE:
+  case CmpInst::ICMP_ULT:
+  case CmpInst::ICMP_ULE:
+    return true;
+  }
+}
+
 static AArch64CC::CondCode changeICMPPredToAArch64CC(CmpInst::Predicate P) {
   switch (P) {
   default:
@@ -2919,16 +2934,45 @@
   return &*CmpMI;
 }
 
+MachineInstr *
+AArch64InstructionSelector::emitTST(const Register &LHS, const Register &RHS,
+                                    MachineIRBuilder &MIRBuilder) const {
+  MachineRegisterInfo &MRI = MIRBuilder.getMF().getRegInfo();
+  unsigned RegSize = MRI.getType(LHS).getSizeInBits();
+  bool Is32Bit = (RegSize == 32); // Any other size picks the X-register forms.
+  static const unsigned OpcTable[2][2]{{AArch64::ANDSXrr, AArch64::ANDSXri},
+                                       {AArch64::ANDSWrr, AArch64::ANDSWri}};
+  Register ZReg = Is32Bit ? AArch64::WZR : AArch64::XZR;
+
+  // We might be able to fold an immediate into the TST. It has to be a valid
+  // logical immediate though, since ANDS requires that.
+  auto ValAndVReg = getConstantVRegValWithLookThrough(RHS, MRI);
+  bool IsImmForm = ValAndVReg.hasValue() &&
+                   AArch64_AM::isLogicalImmediate(ValAndVReg->Value, RegSize);
+  unsigned Opc = OpcTable[Is32Bit][IsImmForm]; // Rows: X/W. Columns: rr/ri.
+  auto TstMI = MIRBuilder.buildInstr(Opc, {ZReg}, {LHS});
+
+  if (IsImmForm)
+    TstMI.addImm(
+        AArch64_AM::encodeLogicalImmediate(ValAndVReg->Value, RegSize));
+  else
+    TstMI.addUse(RHS);
+
+  constrainSelectedInstRegOperands(*TstMI, TII, TRI, RBI);
+  return &*TstMI;
+}
+
 MachineInstr *AArch64InstructionSelector::emitIntegerCompare(
     MachineOperand &LHS, MachineOperand &RHS, MachineOperand &Predicate,
     MachineIRBuilder &MIRBuilder) const {
   assert(LHS.isReg() && RHS.isReg() && "Expected LHS and RHS to be registers!");
   MachineRegisterInfo &MRI = MIRBuilder.getMF().getRegInfo();
 
-  // Fold the compare into a CMN if possible.
-  MachineInstr *Cmn = tryOptCMN(LHS, RHS, Predicate, MIRBuilder);
-  if (Cmn)
-    return Cmn;
+  // Fold the compare if possible.
+  MachineInstr *FoldCmp =
+      tryFoldIntegerCompare(LHS, RHS, Predicate, MIRBuilder);
+  if (FoldCmp)
+    return FoldCmp;
 
   // Can't fold into a CMN. Just emit a normal compare.
   unsigned CmpOpc = 0;
@@ -3170,10 +3214,9 @@
   return true;
 }
 
-MachineInstr *
-AArch64InstructionSelector::tryOptCMN(MachineOperand &LHS, MachineOperand &RHS,
-                                      MachineOperand &Predicate,
-                                      MachineIRBuilder &MIRBuilder) const {
+MachineInstr *AArch64InstructionSelector::tryFoldIntegerCompare(
+    MachineOperand &LHS, MachineOperand &RHS, MachineOperand &Predicate,
+    MachineIRBuilder &MIRBuilder) const {
   assert(LHS.isReg() && RHS.isReg() && Predicate.isPredicate() &&
          "Unexpected MachineOperand");
   MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
@@ -3228,42 +3271,52 @@
   // Check if the RHS or LHS of the G_ICMP is defined by a SUB
   MachineInstr *LHSDef = FindDef(LHS.getReg());
   MachineInstr *RHSDef = FindDef(RHS.getReg());
-  const AArch64CC::CondCode CC =
-      changeICMPPredToAArch64CC((CmpInst::Predicate)Predicate.getPredicate());
-  bool DidFold = false;
+  CmpInst::Predicate P = (CmpInst::Predicate)Predicate.getPredicate();
+  const AArch64CC::CondCode CC = changeICMPPredToAArch64CC(P);
 
-  MachineOperand CMNLHS = LHS;
-  MachineOperand CMNRHS = RHS;
-  if (IsCMN(LHSDef, CC)) {
-    // We're doing this:
-    //
-    // Given:
-    //
-    // x = G_SUB 0, y
-    // G_ICMP x, z
-    //
-    // Update the G_ICMP:
-    //
-    // G_ICMP y, z
-    CMNLHS = LHSDef->getOperand(2);
-    DidFold = true;
-  } else if (IsCMN(RHSDef, CC)) {
-    // Same idea here, but with the RHS of the compare instead:
-    //
-    // Given:
-    //
-    // x = G_SUB 0, y
-    // G_ICMP z, x
-    //
-    // Update the G_ICMP:
-    //
-    // G_ICMP z, y
-    CMNRHS = RHSDef->getOperand(2);
-    DidFold = true;
+  // Given this:
+  //
+  // x = G_SUB 0, y
+  // G_ICMP x, z
+  //
+  // Produce this:
+  //
+  // cmn y, z
+  if (IsCMN(LHSDef, CC))
+    return emitCMN(LHSDef->getOperand(2), RHS, MIRBuilder);
+
+  // Same idea here, but with the RHS of the compare instead:
+  //
+  // Given this:
+  //
+  // x = G_SUB 0, y
+  // G_ICMP z, x
+  //
+  // Produce this:
+  //
+  // cmn z, y
+  if (IsCMN(RHSDef, CC))
+    return emitCMN(LHS, RHSDef->getOperand(2), MIRBuilder);
+
+  // Given this:
+  //
+  // z = G_AND x, y
+  // G_ICMP z, 0
+  //
+  // Produce this if the compare is signed or an equality comparison:
+  //
+  // tst x, y
+  if (!isUnsignedICMPPred(P) && LHSDef &&
+      LHSDef->getOpcode() == TargetOpcode::G_AND) {
+    // Make sure that the RHS is 0.
+    auto ValAndVReg = getConstantVRegValWithLookThrough(RHS.getReg(), MRI);
+    if (!ValAndVReg || ValAndVReg->Value != 0)
+      return nullptr;
+
+    return emitTST(LHSDef->getOperand(1).getReg(),
+                   LHSDef->getOperand(2).getReg(), MIRBuilder);
   }
 
-  if (DidFold)
-    return emitCMN(CMNLHS, CMNRHS, MIRBuilder);
   return nullptr;
 }