[SystemZ] Improve getMemoryOpCost() to find foldable loads that are converted.

The SystemZ backend can perform arithmetic with a memory operand that is
loaded and then sign- or zero-extended. Similarly, a load followed by a
truncate can be folded into a memory operand.

This patch improves the SystemZ TTI cost function getMemoryOpCost() to recognize these cases.

Review: Ulrich Weigand
https://reviews.llvm.org/D52692

llvm-svn: 345327
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index f52c9ca6..670a8d3 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -849,54 +849,121 @@
   return BaseT::getVectorInstrCost(Opcode, Val, Index);
 }
 
+// Check if a load may be folded as a memory operand in its user.
+//
+// The load must have a single use. If that use is a trunc/sext/zext which
+// itself has a single use, the conversion is looked through and its user is
+// the instruction that has to fold the memory operand. FoldedValue is set to
+// the value actually consumed by that instruction (Ld or its conversion); it
+// is only meaningful when true is returned.
+bool SystemZTTIImpl::
+isFoldableLoad(const LoadInst *Ld, const Instruction *&FoldedValue) {
+  if (!Ld->hasOneUse())
+    return false;
+  FoldedValue = Ld;
+  const Instruction *UserI = cast<Instruction>(*Ld->user_begin());
+  unsigned LoadedBits = getScalarSizeInBits(Ld->getType());
+  unsigned TruncBits = 0;
+  unsigned SExtBits = 0;
+  unsigned ZExtBits = 0;
+  if (UserI->hasOneUse()) {
+    unsigned UserBits = UserI->getType()->getScalarSizeInBits();
+    if (isa<TruncInst>(UserI))
+      TruncBits = UserBits;
+    else if (isa<SExtInst>(UserI))
+      SExtBits = UserBits;
+    else if (isa<ZExtInst>(UserI))
+      ZExtBits = UserBits;
+  }
+  if (TruncBits || SExtBits || ZExtBits) {
+    FoldedValue = UserI;
+    UserI = cast<Instruction>(*UserI->user_begin());
+    // Load (single use) -> trunc/extend (single use) -> UserI
+  }
+
+  // Sub, SDiv and UDiv are not commutative: SystemZ can only take their
+  // second operand (the subtrahend or divisor) from memory.
+  if ((UserI->getOpcode() == Instruction::Sub ||
+       UserI->getOpcode() == Instruction::SDiv ||
+       UserI->getOpcode() == Instruction::UDiv) &&
+      UserI->getOperand(1) != FoldedValue)
+    return false;
+
+  switch (UserI->getOpcode()) {
+  case Instruction::Add: // SE: 16->32, 16/32->64, z14:16->64. ZE: 32->64
+  case Instruction::Sub:
+    if (LoadedBits == 32 && ZExtBits == 64)
+      return true;
+    LLVM_FALLTHROUGH;
+  case Instruction::Mul: // SE: 16->32, 32->64, z14:16->64
+    if (LoadedBits == 16 &&
+        (SExtBits == 32 ||
+         (SExtBits == 64 && ST->hasMiscellaneousExtensions2())))
+      return true;
+    LLVM_FALLTHROUGH;
+  case Instruction::SDiv:// SE: 32->64
+    if (LoadedBits == 32 && SExtBits == 64)
+      return true;
+    LLVM_FALLTHROUGH;
+  case Instruction::UDiv:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+  case Instruction::ICmp:
+    // This also makes sense for float operations, but disabled for now due
+    // to regressions.
+    // case Instruction::FCmp:
+    // case Instruction::FAdd:
+    // case Instruction::FSub:
+    // case Instruction::FMul:
+    // case Instruction::FDiv:
+
+    // All possible extensions of memory checked above.
+    if (SExtBits || ZExtBits)
+      return false;
+
+    // Otherwise only a 32- or 64-bit value (loaded directly or truncated
+    // from a wider load) can be used as a memory operand.
+    unsigned LoadOrTruncBits = (TruncBits ? TruncBits : LoadedBits);
+    return (LoadOrTruncBits == 32 || LoadOrTruncBits == 64);
+  }
+  return false;
+}
+
 int SystemZTTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
                                      unsigned Alignment, unsigned AddressSpace,
                                      const Instruction *I) {
   assert(!Src->isVoidTy() && "Invalid type");
 
-  if (!Src->isVectorTy() && Opcode == Instruction::Load &&
-      I != nullptr && I->hasOneUse()) {
-      const Instruction *UserI = cast<Instruction>(*I->user_begin());
-      unsigned Bits = getScalarSizeInBits(Src);
-      bool FoldsLoad = false;
-      switch (UserI->getOpcode()) {
-      case Instruction::ICmp:
-      case Instruction::Add:
-      case Instruction::Sub:
-      case Instruction::Mul:
-      case Instruction::SDiv:
-      case Instruction::UDiv:
-      case Instruction::And:
-      case Instruction::Or:
-      case Instruction::Xor:
-      // This also makes sense for float operations, but disabled for now due
-      // to regressions.
-      // case Instruction::FCmp:
-      // case Instruction::FAdd:
-      // case Instruction::FSub:
-      // case Instruction::FMul:
-      // case Instruction::FDiv:
-        FoldsLoad = (Bits == 32 || Bits == 64);
-        break;
-      }
+  if (!Src->isVectorTy() && Opcode == Instruction::Load && I != nullptr) {
+    // Store the load or its truncated or extended value in FoldedValue.
+    const Instruction *FoldedValue = nullptr;
+    if (isFoldableLoad(cast<LoadInst>(I), FoldedValue)) {
+      const Instruction *UserI = cast<Instruction>(*FoldedValue->user_begin());
+      assert (UserI->getNumOperands() == 2 && "Expected a binop.");
 
-      if (FoldsLoad) {
-        assert (UserI->getNumOperands() == 2 &&
-                "Expected to only handle binops.");
+      // UserI can't fold two loads, so in that case return 0 cost only
+      // half of the time.
+      for (unsigned i = 0; i < 2; ++i) {
+        if (UserI->getOperand(i) == FoldedValue)
+          continue;
 
-        // UserI can't fold two loads, so in that case return 0 cost only
-        // half of the time.
-        for (unsigned i = 0; i < 2; ++i) {
-          if (UserI->getOperand(i) == I)
-            continue;
-          if (LoadInst *LI = dyn_cast<LoadInst>(UserI->getOperand(i))) {
-            if (LI->hasOneUse())
-              return i == 0;
-          }
+        if (Instruction *OtherOp = dyn_cast<Instruction>(UserI->getOperand(i))){
+          LoadInst *OtherLoad = dyn_cast<LoadInst>(OtherOp);
+          if (!OtherLoad &&
+              (isa<TruncInst>(OtherOp) || isa<SExtInst>(OtherOp) ||
+               isa<ZExtInst>(OtherOp)))
+            OtherLoad = dyn_cast<LoadInst>(OtherOp->getOperand(0));
+          // Use a separate out-parameter: FoldedValue may still be compared
+          // against UserI's operands in the next loop iteration.
+          const Instruction *OtherFoldedValue = nullptr;
+          if (OtherLoad && isFoldableLoad(OtherLoad, OtherFoldedValue))
+            return i == 0; // Both operands foldable.
         }
-
-        return 0;
       }
+
+      return 0; // Only I is foldable in user.
+    }
   }
 
   unsigned NumOps =
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
index 92b2b9b..347a8a6 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
@@ -85,6 +85,7 @@
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
                          const Instruction *I = nullptr);
   int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index);
+  bool isFoldableLoad(const LoadInst *Ld, const Instruction *&FoldedValue);
   int getMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment,
                       unsigned AddressSpace, const Instruction *I = nullptr);