[InstCombine] factorize left shifts of add/sub

We do similar factorization folds in SimplifyUsingDistributiveLaws,
but that drops no-wrap properties. Propagating those optimally may
help solve:
https://llvm.org/PR47430

The propagation is all-or-nothing for these patterns: only when all
3 incoming ops (both shifts and the add/sub) have nsw do the 2 new
ops get nsw, and likewise, independently, for nuw:
https://alive2.llvm.org/ce/z/Dv8wsU

This also solves:
https://llvm.org/PR47584
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 5ce32bc..c1cacdf 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1199,6 +1199,42 @@
   return TruncInst::CreateTruncOrBitCast(NewAShr, I.getType());
 }
 
+/// Fold add/sub of two shl by the same amount into one shl of the add/sub,
+/// or return nullptr. Specializes SimplifyUsingDistributiveLaws, which does
+/// not optimally handle multi-use operands or nsw/nuw propagation.
+static Instruction *factorizeMathWithShlOps(BinaryOperator &I,
+                                            InstCombiner::BuilderTy &Builder) {
+  // TODO: Also handle mul by doubling the shift amount?
+  assert((I.getOpcode() == Instruction::Add ||
+          I.getOpcode() == Instruction::Sub) && "Expected add/sub");
+  auto *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0));
+  auto *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1));
+  if (!Op0 || !Op1 || !(Op0->hasOneUse() || Op1->hasOneUse()))
+    return nullptr;
+
+  Value *X, *Y, *ShAmt;
+  if (!match(Op0, m_Shl(m_Value(X), m_Value(ShAmt))) ||
+      !match(Op1, m_Shl(m_Value(Y), m_Specific(ShAmt))))
+    return nullptr;
+
+  // A no-wrap flag propagates only if I and both shifts all carry that flag.
+  bool HasNSW = I.hasNoSignedWrap() && Op0->hasNoSignedWrap() &&
+                Op1->hasNoSignedWrap();
+  bool HasNUW = I.hasNoUnsignedWrap() && Op0->hasNoUnsignedWrap() &&
+                Op1->hasNoUnsignedWrap();
+
+  // add/sub (X << ShAmt), (Y << ShAmt) --> (add/sub X, Y) << ShAmt
+  Value *NewMath = Builder.CreateBinOp(I.getOpcode(), X, Y);
+  if (auto *NewI = dyn_cast<BinaryOperator>(NewMath)) {
+    NewI->setHasNoSignedWrap(HasNSW);
+    NewI->setHasNoUnsignedWrap(HasNUW);
+  }
+  auto *NewShl = BinaryOperator::CreateShl(NewMath, ShAmt);
+  NewShl->setHasNoSignedWrap(HasNSW);
+  NewShl->setHasNoUnsignedWrap(HasNUW);
+  return NewShl;
+}
+
 Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1),
                                  I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
@@ -1215,6 +1251,9 @@
   if (Value *V = SimplifyUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
+  if (Instruction *R = factorizeMathWithShlOps(I, Builder))
+    return R;
+
   if (Instruction *X = foldAddWithConstant(I))
     return X;
 
@@ -1753,6 +1792,9 @@
   if (Value *V = SimplifyUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
+  if (Instruction *R = factorizeMathWithShlOps(I, Builder))
+    return R;
+
   if (I.getType()->isIntOrIntVectorTy(1))
     return BinaryOperator::CreateXor(Op0, Op1);