Simplify multiplications by vectors whose elements are powers of 2.

Patch by Andrea Di Biagio.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@183005 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 2628f4b..2761bc2 100644
--- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -95,6 +95,25 @@
   return MulExt.slt(Min) || MulExt.sgt(Max);
 }
 
+/// \brief A helper routine of InstCombiner::visitMul().
+///
+/// If every element of CV is a constant power of 2, returns a new constant
+/// vector in which each element of CV is replaced by its logBase2.
+/// Returns a null pointer as soon as any element is not a power of 2.
+static Constant *getLogBase2Vector(ConstantDataVector *CV) {
+  const APInt *IVal;
+  SmallVector<Constant *, 4> Elts;
+
+  for (unsigned I = 0, E = CV->getNumElements(); I != E; ++I) {
+    Constant *Elt = CV->getElementAsConstant(I);
+    if (!match(Elt, m_APInt(IVal)) || !IVal->isPowerOf2())
+      return 0;
+    Elts.push_back(ConstantInt::get(Elt->getType(), IVal->logBase2()));
+  }
+
+  return ConstantVector::get(Elts);
+}
+
 Instruction *InstCombiner::visitMul(BinaryOperator &I) {
   bool Changed = SimplifyAssociativeOrCommutative(I);
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -108,24 +127,37 @@
   if (match(Op1, m_AllOnes()))  // X * -1 == 0 - X
     return BinaryOperator::CreateNeg(Op0, I.getName());
 
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+  // Also allow combining multiply instructions on vectors.
+  {
+    Value *NewOp;
+    Constant *C1, *C2;
+    const APInt *IVal;
+    if (match(&I, m_Mul(m_Shl(m_Value(NewOp), m_Constant(C2)),
+                        m_Constant(C1))) &&
+        match(C1, m_APInt(IVal)))
+      // ((X << C2)*C1) == (X * (C1 << C2))
+      return BinaryOperator::CreateMul(NewOp, ConstantExpr::getShl(C1, C2));
 
-    // ((X << C1)*C2) == (X * (C2 << C1))
-    if (BinaryOperator *SI = dyn_cast<BinaryOperator>(Op0))
-      if (SI->getOpcode() == Instruction::Shl)
-        if (Constant *ShOp = dyn_cast<Constant>(SI->getOperand(1)))
-          return BinaryOperator::CreateMul(SI->getOperand(0),
-                                           ConstantExpr::getShl(CI, ShOp));
+    if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) {
+      Constant *NewCst = 0;
+      if (match(C1, m_APInt(IVal)) && IVal->isPowerOf2())
+        // Replace X*(2^C) with X << C, where C is either a scalar or a splat.
+        NewCst = ConstantInt::get(NewOp->getType(), IVal->logBase2());
+      else if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(C1))
+        // Replace X*(2^C) with X << C, where C is a vector of known
+        // constant powers of 2.
+        NewCst = getLogBase2Vector(CV);
 
-    const APInt &Val = CI->getValue();
-    if (Val.isPowerOf2()) {          // Replace X*(2^C) with X << C
-      Constant *NewCst = ConstantInt::get(Op0->getType(), Val.logBase2());
-      BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, NewCst);
-      if (I.hasNoSignedWrap()) Shl->setHasNoSignedWrap();
-      if (I.hasNoUnsignedWrap()) Shl->setHasNoUnsignedWrap();
-      return Shl;
+      if (NewCst) {
+        BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst);
+        if (I.hasNoSignedWrap()) Shl->setHasNoSignedWrap();
+        if (I.hasNoUnsignedWrap()) Shl->setHasNoUnsignedWrap();
+        return Shl;
+      }
     }
+  }
 
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
     // Canonicalize (X+C1)*CI -> X*CI+C1*CI.
     { Value *X; ConstantInt *C1;
       if (Op0->hasOneUse() &&