Distribute (A + B) * C to (A * C) + (B * C) to make use of NEON multiplier
accumulator forwarding:
vadd d3, d0, d1
vmul d3, d3, d2
=>
vmul d3, d0, d2
vmla d3, d1, d2


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@128665 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/ARM/ARMISelLowering.cpp b/lib/Target/ARM/ARMISelLowering.cpp
index 16b110f..5838181 100644
--- a/lib/Target/ARM/ARMISelLowering.cpp
+++ b/lib/Target/ARM/ARMISelLowering.cpp
@@ -5224,6 +5224,42 @@
   return SDValue();
 }
 
+/// PerformVMULCombine - Distribute (A +/- B) * C to (A * C) +/- (B * C) to
+/// take advantage of the special multiplier accumulator forwarding; either
+/// multiply operand may be the add/sub. For example,
+///   vmul d3, d0, d2
+///   vmla d3, d1, d2
+/// is faster than
+///   vadd d3, d0, d1
+///   vmul d3, d3, d2
+static SDValue PerformVMULCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  const ARMSubtarget *Subtarget) {
+  if (!Subtarget->hasVMLxForwarding())
+    return SDValue();
+  // Find the add/sub to distribute over, starting with operand 0.
+  SelectionDAG &DAG = DCI.DAG;
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  unsigned Opcode = N0.getOpcode();
+  if (Opcode != ISD::ADD && Opcode != ISD::SUB &&
+      Opcode != ISD::FADD && Opcode != ISD::FSUB) {
+    Opcode = N1.getOpcode(); // Operand 0 did not match; try operand 1.
+    if (Opcode != ISD::ADD && Opcode != ISD::SUB &&
+        Opcode != ISD::FADD && Opcode != ISD::FSUB)
+      return SDValue();
+    std::swap(N0, N1);
+  }
+  // N0 is the add/sub, N1 the other factor: emit (N00 * N1) op (N01 * N1).
+  EVT VT = N->getValueType(0);
+  DebugLoc DL = N->getDebugLoc();
+  SDValue N00 = N0->getOperand(0);
+  SDValue N01 = N0->getOperand(1);
+  return DAG.getNode(Opcode, DL, VT,
+                     DAG.getNode(ISD::MUL, DL, VT, N00, N1),
+                     DAG.getNode(ISD::MUL, DL, VT, N01, N1));
+}
+
 static SDValue PerformMULCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const ARMSubtarget *Subtarget) {
@@ -5236,6 +5272,8 @@
     return SDValue();
 
   EVT VT = N->getValueType(0);
+  if (VT.is64BitVector() || VT.is128BitVector())
+    return PerformVMULCombine(N, DCI, Subtarget);
   if (VT != MVT::i32)
     return SDValue();