[InstCombine] Catch more bswap cases missed due to zext and truncs.

Fixes PR27824.
Differential Revision: http://reviews.llvm.org/D20591.

llvm-svn: 270853
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index a536ddc..edde166 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -1773,7 +1773,23 @@
         // If the AndMask is zero for this bit, clear the bit.
         if ((AndMask & Bit) == 0)
           Result->Provenance[i] = BitPart::Unset;
+      return Result;
+    }
 
+    // If this is a zext instruction, zero-extend the result.
+    if (I->getOpcode() == Instruction::ZExt) {
+      auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
+                                  MatchBitReversals, BPS);
+      if (!Res)
+        return Result;
+
+      Result = BitPart(Res->Provider, BitWidth);
+      auto NarrowBitWidth =
+          cast<IntegerType>(cast<ZExtInst>(I)->getSrcTy())->getBitWidth();
+      for (unsigned i = 0; i < NarrowBitWidth; ++i)
+        Result->Provenance[i] = Res->Provenance[i];
+      for (unsigned i = NarrowBitWidth; i < BitWidth; ++i)
+        Result->Provenance[i] = BitPart::Unset;
       return Result;
     }
   }
@@ -1816,6 +1832,15 @@
     return false;   // Can't do vectors or integers > 128 bits.
   unsigned BW = ITy->getBitWidth();
 
+  unsigned DemandedBW = BW;
+  IntegerType *DemandedTy = ITy;
+  if (I->hasOneUse()) {
+    if (TruncInst *Trunc = dyn_cast<TruncInst>(I->user_back())) {
+      DemandedTy = cast<IntegerType>(Trunc->getType());
+      DemandedBW = DemandedTy->getBitWidth();
+    }
+  }
+
   // Try to find all the pieces corresponding to the bswap.
   std::map<Value *, Optional<BitPart>> BPS;
   auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS);
@@ -1825,11 +1850,12 @@
 
   // Now, is the bit permutation correct for a bswap or a bitreverse? We can
   // only byteswap values with an even number of bytes.
-  bool OKForBSwap = BW % 16 == 0, OKForBitReverse = true;
-  for (unsigned i = 0; i < BW; ++i) {
-    OKForBSwap &= bitTransformIsCorrectForBSwap(BitProvenance[i], i, BW);
+  bool OKForBSwap = DemandedBW % 16 == 0, OKForBitReverse = true;
+  for (unsigned i = 0; i < DemandedBW; ++i) {
+    OKForBSwap &=
+        bitTransformIsCorrectForBSwap(BitProvenance[i], i, DemandedBW);
     OKForBitReverse &=
-        bitTransformIsCorrectForBitReverse(BitProvenance[i], i, BW);
+        bitTransformIsCorrectForBitReverse(BitProvenance[i], i, DemandedBW);
   }
 
   Intrinsic::ID Intrin;
@@ -1840,6 +1866,24 @@
   else
     return false;
 
+  if (ITy != DemandedTy) {
+    Function *F = Intrinsic::getDeclaration(I->getModule(), Intrin, DemandedTy);
+    Value *Provider = Res->Provider;
+    IntegerType *ProviderTy = cast<IntegerType>(Provider->getType());
+    // We may need to truncate the provider.
+    if (DemandedTy != ProviderTy) {
+      auto *Trunc = CastInst::Create(Instruction::Trunc, Provider, DemandedTy,
+                                     "trunc", I);
+      InsertedInsts.push_back(Trunc);
+      Provider = Trunc;
+    }
+    auto *CI = CallInst::Create(F, Provider, "rev", I);
+    InsertedInsts.push_back(CI);
+    auto *ExtInst = CastInst::Create(Instruction::ZExt, CI, ITy, "zext", I);
+    InsertedInsts.push_back(ExtInst);
+    return true;
+  }
+
   Function *F = Intrinsic::getDeclaration(I->getModule(), Intrin, ITy);
   InsertedInsts.push_back(CallInst::Create(F, Res->Provider, "rev", I));
   return true;