Enhance the logic in X86DAGToDAGISel::PreprocessForRMW that moves a load inside callseq_start so that it can be folded into a call. It did not consider cases where a token factor sits between the load and the callseq_start.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@63022 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelDAGToDAG.cpp b/lib/Target/X86/X86ISelDAGToDAG.cpp
index c1d886c..04ba430 100644
--- a/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -432,14 +432,27 @@
 /// MoveBelowCallSeqStart - Replace CALLSEQ_START operand with load's chain
 /// operand and move load below the call's chain operand.
 static void MoveBelowCallSeqStart(SelectionDAG *CurDAG, SDValue Load,
-                           SDValue Call, SDValue Chain) {
+                                  SDValue Call, SDValue CallSeqStart) {
   SmallVector<SDValue, 8> Ops;
-  for (unsigned i = 0, e = Chain.getNode()->getNumOperands(); i != e; ++i)
-    if (Load.getNode() == Chain.getOperand(i).getNode())
-      Ops.push_back(Load.getOperand(0));
-    else
-      Ops.push_back(Chain.getOperand(i));
-  CurDAG->UpdateNodeOperands(Chain, &Ops[0], Ops.size());
+  SDValue Chain = CallSeqStart.getOperand(0);
+  if (Chain.getNode() == Load.getNode())
+    Ops.push_back(Load.getOperand(0));
+  else {
+    assert(Chain.getOpcode() == ISD::TokenFactor &&
+           "Unexpected CallSeqStart chain operand");
+    for (unsigned i = 0, e = Chain.getNumOperands(); i != e; ++i)
+      if (Chain.getOperand(i).getNode() == Load.getNode())
+        Ops.push_back(Load.getOperand(0));
+      else
+        Ops.push_back(Chain.getOperand(i));
+    SDValue NewChain =
+      CurDAG->getNode(ISD::TokenFactor, MVT::Other, &Ops[0], Ops.size());
+    Ops.clear();
+    Ops.push_back(NewChain);
+  }
+  for (unsigned i = 1, e = CallSeqStart.getNumOperands(); i != e; ++i)
+    Ops.push_back(CallSeqStart.getOperand(i));
+  CurDAG->UpdateNodeOperands(CallSeqStart, &Ops[0], Ops.size());
   CurDAG->UpdateNodeOperands(Load, Call.getOperand(0),
                              Load.getOperand(1), Load.getOperand(2));
   Ops.clear();
@@ -468,7 +481,13 @@
       return false;
     Chain = Chain.getOperand(0);
   }
-  return Chain.getOperand(0).getNode() == Callee.getNode();
+  
+  if (Chain.getOperand(0).getNode() == Callee.getNode())
+    return true;
+  if (Chain.getOperand(0).getOpcode() == ISD::TokenFactor &&
+      Callee.getValue(1).isOperandOf(Chain.getOperand(0).getNode()))
+    return true;
+  return false;
 }