[AArch64] Improve load/store optimizer to handle LDUR + LDR.

This patch allows scaled and unscaled loads/stores (e.g. LDR and LDUR) to
be combined into load/store pairs.

This reapplies r259812, which had an incorrect assert.  The
test_stur_str_no_assert() test is a reduced test case for the issue hit
during the AArch64 self-host build.

PR24465

llvm-svn: 260523
diff --git a/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp b/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp
index b46f829..fb65048 100644
--- a/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp
+++ b/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp
@@ -259,6 +259,10 @@
   return isNarrowLoad(MI->getOpcode());
 }
 
+static bool isNarrowLoadOrStore(unsigned Opc) {
+  return isNarrowLoad(Opc) || isNarrowStore(Opc);
+}
+
 // Scaling factor for unscaled load or store.
 static int getMemScale(MachineInstr *MI) {
   switch (MI->getOpcode()) {
@@ -825,10 +829,28 @@
   const MachineOperand &BaseRegOp =
       MergeForward ? getLdStBaseOp(Paired) : getLdStBaseOp(I);
 
+  int Offset = getLdStOffsetOp(I).getImm();
+  int PairedOffset = getLdStOffsetOp(Paired).getImm();
+  bool PairedIsUnscaled = isUnscaledLdSt(Paired->getOpcode());
+  if (IsUnscaled != PairedIsUnscaled) {
+    // We're trying to pair instructions that differ in how they are scaled.  If
+    // I is scaled then scale the offset of Paired accordingly.  Otherwise, do
+    // the opposite (i.e., make Paired's offset unscaled).
+    int MemSize = getMemScale(Paired);
+    if (PairedIsUnscaled) {
+      // If the unscaled offset isn't a multiple of the MemSize, we can't
+      // pair the operations together.
+      assert(!(PairedOffset % getMemScale(Paired)) &&
+             "Offset should be a multiple of the stride!");
+      PairedOffset /= MemSize;
+    } else {
+      PairedOffset *= MemSize;
+    }
+  }
+
   // Which register is Rt and which is Rt2 depends on the offset order.
   MachineInstr *RtMI, *Rt2MI;
-  if (getLdStOffsetOp(I).getImm() ==
-      getLdStOffsetOp(Paired).getImm() + OffsetStride) {
+  if (Offset == PairedOffset + OffsetStride) {
     RtMI = Paired;
     Rt2MI = I;
     // Here we swapped the assumption made for SExtIdx.
@@ -841,10 +863,11 @@
     Rt2MI = Paired;
   }
   int OffsetImm = getLdStOffsetOp(RtMI).getImm();
-  // Handle Unscaled.
-  if (IsUnscaled) {
-    assert (!(OffsetImm % OffsetStride) && "Unscaled offset cannot be scaled.");
-    OffsetImm /= OffsetStride;
+  // Scale the immediate offset, if necessary.
+  if (isUnscaledLdSt(RtMI->getOpcode())) {
+    assert(!(OffsetImm % getMemScale(RtMI)) &&
+           "Unscaled offset cannot be scaled.");
+    OffsetImm /= getMemScale(RtMI);
   }
 
   // Construct the new instruction.
@@ -1039,9 +1062,13 @@
 static bool inBoundsForPair(bool IsUnscaled, int Offset, int OffsetStride) {
   // Convert the byte-offset used by unscaled into an "element" offset used
   // by the scaled pair load/store instructions.
-  if (IsUnscaled)
+  if (IsUnscaled) {
+    // If the byte-offset isn't a multiple of the stride, there's no point
+    // trying to match it.
+    if (Offset % OffsetStride)
+      return false;
     Offset /= OffsetStride;
-
+  }
   return Offset <= 63 && Offset >= -64;
 }
 
@@ -1148,7 +1175,21 @@
     Flags.setSExtIdx(NonSExtOpc == (unsigned)OpcA ? 1 : 0);
     return true;
   }
-  return false;
+
+  // If the second instruction isn't even a load/store, bail out.
+  if (!PairIsValidLdStrOpc)
+    return false;
+
+  // FIXME: We don't support merging narrow loads/stores with mixed
+  // scaled/unscaled offsets.
+  if (isNarrowLoadOrStore(OpcA) || isNarrowLoadOrStore(OpcB))
+    return false;
+
+  // Try to match an unscaled load/store with a scaled load/store.
+  return isUnscaledLdSt(OpcA) != isUnscaledLdSt(OpcB) &&
+         getMatchingPairOpcode(OpcA) == getMatchingPairOpcode(OpcB);
+
+  // FIXME: Can we also match a mixed sext/zext unscaled/scaled pair?
 }
 
 /// Scan the instructions looking for a load/store that can be combined with the
@@ -1204,6 +1245,23 @@
       // final offset must be in range.
       unsigned MIBaseReg = getLdStBaseOp(MI).getReg();
       int MIOffset = getLdStOffsetOp(MI).getImm();
+      bool MIIsUnscaled = isUnscaledLdSt(MI);
+      if (IsUnscaled != MIIsUnscaled) {
+        // We're trying to pair instructions that differ in how they are scaled.
+        // If FirstMI is scaled then scale the offset of MI accordingly.
+        // Otherwise, do the opposite (i.e., make MI's offset unscaled).
+        int MemSize = getMemScale(MI);
+        if (MIIsUnscaled) {
+          // If the unscaled offset isn't a multiple of the MemSize, we can't
+          // pair the operations together: bail and keep looking.
+          if (MIOffset % MemSize)
+            continue;
+          MIOffset /= MemSize;
+        } else {
+          MIOffset *= MemSize;
+        }
+      }
+
       if (BaseReg == MIBaseReg && ((Offset == MIOffset + OffsetStride) ||
                                    (Offset + OffsetStride == MIOffset))) {
         int MinOffset = Offset < MIOffset ? Offset : MIOffset;
@@ -1214,10 +1272,9 @@
           return E;
         // If the resultant immediate offset of merging these instructions
         // is out of range for a pairwise instruction, bail and keep looking.
-        bool MIIsUnscaled = isUnscaledLdSt(MI);
         bool IsNarrowLoad = isNarrowLoad(MI->getOpcode());
         if (!IsNarrowLoad &&
-            !inBoundsForPair(MIIsUnscaled, MinOffset, OffsetStride)) {
+            !inBoundsForPair(IsUnscaled, MinOffset, OffsetStride)) {
           trackRegDefsUses(MI, ModifiedRegs, UsedRegs, TRI);
           MemInsns.push_back(MI);
           continue;