Patch to implement UMLAL/SMLAL instructions for the ARM architecture

This patch corrects the definition of the UMLAL/SMLAL instructions and adds
support for matching them in the ARM DAG combiner.

Bug 12213

Patch by Yin Ma!


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@161581 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/ARM/ARMISelLowering.cpp b/lib/Target/ARM/ARMISelLowering.cpp
index 7d82229..bb2116f 100644
--- a/lib/Target/ARM/ARMISelLowering.cpp
+++ b/lib/Target/ARM/ARMISelLowering.cpp
@@ -571,6 +571,11 @@
     }
   }
 
+  // ARM and Thumb2 support UMLAL/SMLAL.
+  if (!Subtarget->isThumb1Only())
+    setTargetDAGCombine(ISD::ADDC);
+
+
   computeRegisterProperties();
 
   // ARM does not have f32 extending load.
@@ -989,6 +994,8 @@
   case ARMISD::VTBL2:         return "ARMISD::VTBL2";
   case ARMISD::VMULLs:        return "ARMISD::VMULLs";
   case ARMISD::VMULLu:        return "ARMISD::VMULLu";
+  case ARMISD::UMLAL:         return "ARMISD::UMLAL";
+  case ARMISD::SMLAL:         return "ARMISD::SMLAL";
   case ARMISD::BUILD_VECTOR:  return "ARMISD::BUILD_VECTOR";
   case ARMISD::FMAX:          return "ARMISD::FMAX";
   case ARMISD::FMIN:          return "ARMISD::FMIN";
@@ -7127,6 +7134,154 @@
   return DAG.getNode(ISD::TRUNCATE, N->getDebugLoc(), VT, tmp);
 }
 
+/// Return V if its node is a [US]MUL_LOHI, otherwise an empty SDValue.
+static SDValue findMUL_LOHI(SDValue V) {
+  const unsigned Opc = V->getOpcode();
+  const bool IsMulLoHi = Opc == ISD::UMUL_LOHI || Opc == ISD::SMUL_LOHI;
+  return IsMulLoHi ? V : SDValue();
+}
+
+/// AddCombineTo64bitMLAL - Try to fold an ISD::ADDC, the ISD::ADDE glued to
+/// it, and the ISD::[US]MUL_LOHI feeding both into one ARMISD::[US]MLAL node.
+///
+/// \param AddcNode   The ISD::ADDC node being combined.
+/// \param DCI        Combiner state; supplies the SelectionDAG to rewrite.
+/// \param Subtarget  Used to skip Thumb1-only targets.
+/// \returns SDValue(AddcNode, 0) once the users of both adds were redirected
+/// to the new node, or an empty SDValue when the pattern does not match.
+static SDValue AddCombineTo64bitMLAL(SDNode *AddcNode,
+                                     TargetLowering::DAGCombinerInfo &DCI,
+                                     const ARMSubtarget *Subtarget) {
+  if (Subtarget->isThumb1Only()) return SDValue();
+
+  // Only perform the checks after legalize when the pattern is available.
+  if (DCI.isBeforeLegalize()) return SDValue();
+
+  // Look for multiply add opportunities.
+  // The pattern is an ISD::[US]MUL_LOHI followed by two add nodes, where
+  // each add node consumes a value from the ISD::[US]MUL_LOHI and there is
+  // a glue link from the first add to the second add.
+  // If we find this pattern, we can replace the U/SMUL_LOHI, ADDC, and ADDE by
+  // a S/UMLAL instruction.
+  //          loAdd   UMUL_LOHI
+  //            \    / :lo    \ :hi
+  //             \  /          \          [no multiline comment]
+  //              ADDC         |  hiAdd
+  //                 \ :glue  /  /
+  //                  \      /  /
+  //                    ADDE
+  //
+  assert(AddcNode->getOpcode() == ISD::ADDC && "Expect an ADDC");
+  SDValue AddcOp0 = AddcNode->getOperand(0);
+  SDValue AddcOp1 = AddcNode->getOperand(1);
+
+  // Bail out if both ADDC operands are produced by the same node.
+  if (AddcOp0.getNode() == AddcOp1.getNode())
+    return SDValue();
+
+  assert(AddcNode->getNumValues() == 2 &&
+         AddcNode->getValueType(0) == MVT::i32 &&
+         AddcNode->getValueType(1) == MVT::Glue &&
+         "Expect ADDC with two result values: i32, glue");
+
+  // Cheap early exit: at least one ADDC operand must come from a
+  // S/UMUL_LOHI. Which result it uses is verified further down.
+  if (AddcOp0->getOpcode() != ISD::UMUL_LOHI &&
+      AddcOp0->getOpcode() != ISD::SMUL_LOHI &&
+      AddcOp1->getOpcode() != ISD::UMUL_LOHI &&
+      AddcOp1->getOpcode() != ISD::SMUL_LOHI)
+    return SDValue();
+
+  // Look for the glued user and make sure it is really an ADDE.
+  SDNode* AddeNode = AddcNode->getGluedUser();
+  if (AddeNode == NULL || AddeNode->getOpcode() != ISD::ADDE)
+    return SDValue();
+
+  assert(AddeNode->getNumOperands() == 3 &&
+         AddeNode->getOperand(2).getValueType() == MVT::Glue &&
+         "ADDE node has the wrong inputs");
+
+  // Check for the triangle shape.
+  SDValue AddeOp0 = AddeNode->getOperand(0);
+  SDValue AddeOp1 = AddeNode->getOperand(1);
+
+  // Make sure that the ADDE operands are not coming from the same node.
+  if (AddeOp0.getNode() == AddeOp1.getNode())
+    return SDValue();
+
+  // Find the MUL_LOHI node among ADDE's operands.
+  bool IsLeftOperandMUL = false;
+  SDValue MULOp = findMUL_LOHI(AddeOp0);
+  if (MULOp == SDValue())
+    MULOp = findMUL_LOHI(AddeOp1);
+  else
+    IsLeftOperandMUL = true;
+  if (MULOp == SDValue())
+    return SDValue();
+
+  // The ADDE adds the upper word, so it has to consume the high result
+  // (value 1) of the multiply. Matching on the opcode alone would also accept
+  // the low result and build an MLAL with swapped halves.
+  if (MULOp.getResNo() != 1)
+    return SDValue();
+
+  // Figure out the right opcode.
+  unsigned Opc = MULOp->getOpcode();
+  unsigned FinalOpc = (Opc == ISD::SMUL_LOHI) ? ARMISD::SMLAL : ARMISD::UMLAL;
+
+  // Figure out the high and low input values to the MLAL node.
+  SDValue* HiAdd = IsLeftOperandMUL ? &AddeOp1 : &AddeOp0;
+  SDValue* LoMul = NULL;
+  SDValue* LowAdd = NULL;
+
+  // The ADDC must consume the low result (value 0) of that very same
+  // multiply; comparing SDValues checks both the node and the result number.
+  SDValue LoMulResult(MULOp.getNode(), 0);
+  if (AddcOp0 == LoMulResult) {
+    LoMul = &AddcOp0;
+    LowAdd = &AddcOp1;
+  } else if (AddcOp1 == LoMulResult) {
+    LoMul = &AddcOp1;
+    LowAdd = &AddcOp0;
+  } else {
+    return SDValue();
+  }
+
+  // Create the merged node.
+  SelectionDAG &DAG = DCI.DAG;
+
+  // Build operand list: the two factors, then the low and high addends.
+  SmallVector<SDValue, 8> Ops;
+  Ops.push_back(LoMul->getOperand(0));
+  Ops.push_back(LoMul->getOperand(1));
+  Ops.push_back(*LowAdd);
+  Ops.push_back(*HiAdd);
+
+  SDValue MLALNode = DAG.getNode(FinalOpc, AddcNode->getDebugLoc(),
+                                 DAG.getVTList(MVT::i32, MVT::i32),
+                                 &Ops[0], Ops.size());
+
+  // Replace the uses of both add results by the MLAL node's values.
+  SDValue HiMLALResult(MLALNode.getNode(), 1);
+  DAG.ReplaceAllUsesOfValueWith(SDValue(AddeNode, 0), HiMLALResult);
+
+  SDValue LoMLALResult(MLALNode.getNode(), 0);
+  DAG.ReplaceAllUsesOfValueWith(SDValue(AddcNode, 0), LoMLALResult);
+
+  // Return original node to notify the driver to stop replacing.
+  return SDValue(AddcNode, 0);
+}
+
+/// PerformADDCCombine - Target-specific DAG combine for ISD::ADDC nodes. The
+/// only transform attempted is the ADDC/ADDE/[US]MUL_LOHI to [US]MLAL fold
+/// implemented by AddCombineTo64bitMLAL, whose result is returned unchanged.
+static SDValue PerformADDCCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const ARMSubtarget *Subtarget) {
+  SDValue Combined = AddCombineTo64bitMLAL(N, DCI, Subtarget);
+  return Combined;
+}
+
 /// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
 /// operands N0 and N1.  This is a helper for PerformADDCombine that is
 /// called with the default operands, and if that fails, with commuted
@@ -8738,6 +8893,7 @@
                                              DAGCombinerInfo &DCI) const {
   switch (N->getOpcode()) {
   default: break;
+  case ISD::ADDC:       return PerformADDCCombine(N, DCI, Subtarget);
   case ISD::ADD:        return PerformADDCombine(N, DCI, Subtarget);
   case ISD::SUB:        return PerformSUBCombine(N, DCI);
   case ISD::MUL:        return PerformMULCombine(N, DCI, Subtarget);