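move a trunc-specific transform, trunc(shl(X, C)) -> shl(trunc(X), C), out of commonIntCastTransforms into visitTrunc; it now also folds to zero when C >= the destination bit width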


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@92768 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 103630e..e2bb3fb 100644
--- a/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -28,19 +28,25 @@
     Offset = CI->getZExtValue();
     Scale  = 0;
     return ConstantInt::get(Type::getInt32Ty(Val->getContext()), 0);
-  } else if (BinaryOperator *I = dyn_cast<BinaryOperator>(Val)) {
+  }
+  
+  if (BinaryOperator *I = dyn_cast<BinaryOperator>(Val)) {
     if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) {
       if (I->getOpcode() == Instruction::Shl) {
         // This is a value scaled by '1 << the shift amt'.
         Scale = 1U << RHS->getZExtValue();
         Offset = 0;
         return I->getOperand(0);
-      } else if (I->getOpcode() == Instruction::Mul) {
+      }
+      
+      if (I->getOpcode() == Instruction::Mul) {
         // This value is scaled by 'RHS'.
         Scale = RHS->getZExtValue();
         Offset = 0;
         return I->getOperand(0);
-      } else if (I->getOpcode() == Instruction::Add) {
+      }
+      
+      if (I->getOpcode() == Instruction::Add) {
         // We have X+C.  Check to see if we really have (X*C2)+C1, 
         // where C1 is divisible by C2.
         unsigned SubScale;
@@ -650,18 +656,6 @@
                                       ConstantInt::get(CI.getType(), 1));
     }
     break;
-
-  case Instruction::Shl: {
-    // Canonicalize trunc inside shl, if we can.
-    ConstantInt *CI = dyn_cast<ConstantInt>(Op1);
-    if (CI && DestBitSize < SrcBitSize &&
-        CI->getLimitedValue(DestBitSize) < DestBitSize) {
-      Value *Op0c = Builder->CreateTrunc(Op0, DestTy, Op0->getName());
-      Value *Op1c = Builder->CreateTrunc(Op1, DestTy, Op1->getName());
-      return BinaryOperator::CreateShl(Op0c, Op1c);
-    }
-    break;
-  }
   }
   return 0;
 }
@@ -684,7 +678,7 @@
     return new ICmpInst(ICmpInst::ICMP_NE, Src, Zero);
   }
 
-  // Optimize trunc(lshr(), c) to pull the shift through the truncate.
+  // Optimize trunc(lshr(x, c)) to pull the shift through the truncate.
   ConstantInt *ShAmtV = 0;
   Value *ShiftOp = 0;
   if (Src->hasOneUse() &&
@@ -704,6 +698,21 @@
       return BinaryOperator::CreateLShr(V1, V2);
     }
   }
+  
+  // Transform trunc(shl(X, C)) -> shl(trunc(X), C)
+  if (Src->hasOneUse() &&
+      match(Src, m_Shl(m_Value(ShiftOp), m_ConstantInt(ShAmtV)))) {
+    uint32_t ShAmt = ShAmtV->getLimitedValue(SrcBitWidth);
+    if (ShAmt >= DestBitWidth)        // All zeros.
+      return ReplaceInstUsesWith(CI, Constant::getNullValue(Ty));
+      
+    // Okay, we can shrink this.  Truncate the input, then return a new
+    // shift.
+    Value *V1 = Builder->CreateTrunc(ShiftOp, Ty, ShiftOp->getName());
+    Value *V2 = ConstantExpr::getTrunc(ShAmtV, Ty);
+    return BinaryOperator::CreateShl(V1, V2);
+  }
+
  
   return 0;
 }