Implement proper tail calls in the X86 backend for all fastcc->fastcc
tail calls.
llvm-svn: 22046
diff --git a/llvm/lib/Target/X86/X86ISelPattern.cpp b/llvm/lib/Target/X86/X86ISelPattern.cpp
index 4c77ff6..28534a2 100644
--- a/llvm/lib/Target/X86/X86ISelPattern.cpp
+++ b/llvm/lib/Target/X86/X86ISelPattern.cpp
@@ -84,7 +84,8 @@
   class X86TargetLowering : public TargetLowering {
     int VarArgsFrameIndex;            // FrameIndex for start of varargs area.
     int ReturnAddrIndex;              // FrameIndex for return slot.
-    int BytesToPopOnReturn;           // Number of bytes ret should pop.
+    int BytesToPopOnReturn;           // Number of arg bytes ret should pop.
+    int BytesCallerReserves;          // Number of arg bytes caller reserves.
   public:
     X86TargetLowering(TargetMachine &TM) : TargetLowering(TM) {
       // Set up the TargetLowering object.
@@ -154,6 +155,10 @@
     //
     unsigned getBytesToPopOnReturn() const { return BytesToPopOnReturn; }
 
+    /// getBytesCallerReserves - Return the number of bytes the caller reserves
+    /// for this function's arguments (0 for fastcc, where the callee pops).
+    unsigned getBytesCallerReserves() const { return BytesCallerReserves; }
+
     /// LowerOperation - Provide custom lowering hooks for some operations.
     ///
     virtual SDOperand LowerOperation(SDOperand Op, SelectionDAG &DAG);
@@ -180,6 +185,9 @@
     virtual std::pair<SDOperand, SDOperand>
     LowerFrameReturnAddress(bool isFrameAddr, SDOperand Chain, unsigned Depth,
                             SelectionDAG &DAG);
+    /// getReturnAddressFrameIndex - Return-address slot, created on first use.
+    SDOperand getReturnAddressFrameIndex(SelectionDAG &DAG);
+
   private:
     // C Calling Convention implementation.
     std::vector<SDOperand> LowerCCCArguments(Function &F, SelectionDAG &DAG);
@@ -279,6 +287,7 @@
     VarArgsFrameIndex = MFI->CreateFixedObject(1, ArgOffset);
   ReturnAddrIndex = 0;     // No return address slot generated yet.
   BytesToPopOnReturn = 0;  // Callee pops nothing.
+  BytesCallerReserves = ArgOffset;  // Caller reserves the arg area.
 
   // Finally, inform the code generator which regs we return values in.
   switch (getValueType(F.getReturnType())) {
@@ -631,6 +640,7 @@
   VarArgsFrameIndex = 0xAAAAAAA;   // fastcc functions can't have varargs.
   ReturnAddrIndex = 0;             // No return address slot generated yet.
   BytesToPopOnReturn = ArgOffset;  // Callee pops all stack arguments.
+  BytesCallerReserves = 0;         // Callee pops; caller reserves none.
 
   // Finally, inform the code generator which regs we return values in.
   switch (getValueType(F.getReturnType())) {
@@ -838,6 +848,15 @@
   return std::make_pair(ResultVal, Chain);
 }
 
+SDOperand X86TargetLowering::getReturnAddressFrameIndex(SelectionDAG &DAG) {
+  // ReturnAddrIndex == 0 means no slot has been made yet for this function;
+  // create the fixed 4-byte object at offset -4 once and reuse it afterwards.
+  if (!ReturnAddrIndex)
+    ReturnAddrIndex = DAG.getMachineFunction().getFrameInfo()
+                        ->CreateFixedObject(4, -4);
+
+  return DAG.getFrameIndex(ReturnAddrIndex, MVT::i32);
+}
 
 
 
@@ -848,14 +867,7 @@
   if (Depth)        // Depths > 0 not supported yet!
     Result = DAG.getConstant(0, getPointerTy());
   else {
-    if (ReturnAddrIndex == 0) {
-      // Set up a frame object for the return address.
-      MachineFunction &MF = DAG.getMachineFunction();
-      ReturnAddrIndex = MF.getFrameInfo()->CreateFixedObject(4, -4);
-    }
-
-    SDOperand RetAddrFI = DAG.getFrameIndex(ReturnAddrIndex, MVT::i32);
-
+    SDOperand RetAddrFI = getReturnAddressFrameIndex(DAG);  // Lazily created.
     if (!isFrameAddress)
       // Just load the return address
       Result = DAG.getLoad(MVT::i32, DAG.getEntryNode(), RetAddrFI,
@@ -951,6 +963,8 @@
     /// tree.
     std::map<SDOperand, unsigned> ExprMap;
 
+    /// TheDAG - The DAG being selected; valid only during Select* operations.
+    SelectionDAG *TheDAG;  // NOTE(review): not initialized by the ctor.
   public:
     ISel(TargetMachine &TM) : SelectionDAGISel(X86Lowering), X86Lowering(TM) {
     }
@@ -974,7 +988,6 @@
                         bool FloatPromoteOk = false);
     void EmitFoldedLoad(SDOperand Op, X86AddressMode &AM);
     bool TryToFoldLoadOpStore(SDNode *Node);
-
     bool EmitOrOpOp(SDOperand Op1, SDOperand Op2, unsigned DestReg);
     void EmitCMP(SDOperand LHS, SDOperand RHS, bool isOnlyUse);
     bool EmitBranchCC(MachineBasicBlock *Dest, SDOperand Chain, SDOperand Cond);
@@ -985,6 +998,8 @@
     X86AddressMode SelectAddrExprs(const X86ISelAddressMode &IAM);
     bool MatchAddress(SDOperand N, X86ISelAddressMode &AM);
     void SelectAddress(SDOperand N, X86AddressMode &AM);
+    bool EmitPotentialTailCall(SDNode *RetNode);            // Try tail jmp.
+    void EmitFastCCToFastCCTailCall(SDNode *TailCallNode);  // Emit tail jmp.
     void Select(SDOperand N);
   };
 }
@@ -1063,9 +1078,13 @@
   // registers required to compute each node.
   ComputeRegPressure(DAG.getRoot());
 
+  TheDAG = &DAG;  // Used by tail-call emission during Select.
+
   // Codegen the basic block.
   Select(DAG.getRoot());
 
+  TheDAG = 0;     // Only valid while this block is being selected.
+
   // Finally, look at all of the successors of this block.  If any contain a PHI
   // node of FP type, we need to insert an FP_REG_KILL in this block.
   for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
@@ -3645,6 +3664,266 @@
   return true;
 }
 
+/// EmitPotentialTailCall - If RetNode is a RET in tail position after a
+/// fastcc->fastcc X86ISD::TAILCALL whose results it returns unchanged, emit
+/// the call as a tail jump and return true.  Otherwise emit nothing and
+/// return false.
+///
+/// FIXME: This whole thing should be a post-legalize optimization pass which
+/// recognizes and transforms the dag.  We don't want the selection phase doing
+/// this stuff!!
+bool ISel::EmitPotentialTailCall(SDNode *RetNode) {
+  assert(RetNode->getOpcode() == ISD::RET && "Not a return");
+
+  SDOperand Chain = RetNode->getOperand(0);
+
+  // If this is a token factor node where one operand is a call, dig into it.
+  SDOperand TokFactor;
+  unsigned TokFactorOperand = 0;
+  if (Chain.getOpcode() == ISD::TokenFactor) {
+    for (unsigned i = 0, e = Chain.getNumOperands(); i != e; ++i)
+      if (Chain.getOperand(i).getOpcode() == ISD::CALLSEQ_END ||
+          Chain.getOperand(i).getOpcode() == X86ISD::TAILCALL) {
+        TokFactorOperand = i;
+        TokFactor = Chain;
+        Chain = Chain.getOperand(i);
+        break;
+      }
+    if (TokFactor.Val == 0) return false;  // No call operand.
+  }
+
+  // Skip the CALLSEQ_END node if present.
+  if (Chain.getOpcode() == ISD::CALLSEQ_END)
+    Chain = Chain.getOperand(0);
+
+  // Is a tailcall the last control operation that occurs before the return?
+  if (Chain.getOpcode() != X86ISD::TAILCALL)
+    return false;
+
+  // If we return a value, it must be exactly the call's result: value 1 of
+  // the TAILCALL node, plus value 2 when RET returns two values.
+  if (RetNode->getNumOperands() > 1) {
+    if (Chain.Val->getNumValues() == 1 ||
+        RetNode->getOperand(1) != Chain.getValue(1))
+      return false;
+
+    if (RetNode->getNumOperands() > 2) {
+      if (Chain.Val->getNumValues() == 2 ||
+          RetNode->getOperand(2) != Chain.getValue(2))
+        return false;
+    }
+    assert(RetNode->getNumOperands() <= 3 && "Unknown return instruction!");
+  }
+
+  // CalleeCallArgAmt - The total number of bytes used for the callee arg area.
+  // For FastCC, this will always be > 0.
+  unsigned CalleeCallArgAmt =
+    cast<ConstantSDNode>(Chain.getOperand(2))->getValue();
+
+  // CalleeCallArgPopAmt - The number of bytes in the call area popped by the
+  // callee.  For FastCC this will always be > 0, for CCC this is always 0.
+  unsigned CalleeCallArgPopAmt =
+    cast<ConstantSDNode>(Chain.getOperand(3))->getValue();
+
+  // There are several cases we can handle here.  First, if the caller and
+  // callee are both CCC functions, we can tailcall if the callee takes <= the
+  // number of argument bytes that the caller does.
+  if (CalleeCallArgPopAmt == 0 &&                  // Callee is C CallingConv?
+      X86Lowering.getBytesToPopOnReturn() == 0) {  // Caller is C CallingConv?
+    // Check to see if caller arg area size >= callee arg area size.
+    if (X86Lowering.getBytesCallerReserves() >= CalleeCallArgAmt) {
+      // TODO: implement EmitCCCToCCCTailCall(Chain.Val).  Like the FastCC
+      // path below, it must first emit the other TokFactor operands.
+    }
+    return false;
+  }
+
+  // Second, if both are FastCC functions, we can always perform the tail call.
+  if (CalleeCallArgPopAmt && X86Lowering.getBytesToPopOnReturn()) {
+    // If TokFactor is non-null, emit its other operands before the call.
+    // NOTE(review): assumes none of them depends on the call's output chain
+    // (selecting such an operand would also select the call); verify.
+    if (TokFactor.Val) {
+      for (unsigned i = 0, e = TokFactor.getNumOperands(); i != e; ++i)
+        if (i != TokFactorOperand)
+          Select(TokFactor.getOperand(i));
+    }
+
+    EmitFastCCToFastCCTailCall(Chain.Val);
+    return true;
+  }
+
+  // We don't support mixed calls, due to issues with alignment.  We could in
+  // theory handle some mixed calls from CCC -> FastCC if the stack is properly
+  // aligned (which depends on the number of arguments to the callee).  TODO.
+  return false;
+}
+
+/// GetAdjustedArgumentStores - Rebuild the store chain back to CALLSEQ_START,
+/// retargeting each [ESP+C] argument store to fixed object offset C+Offset.
+static SDOperand GetAdjustedArgumentStores(SDOperand Chain, int Offset,
+                                           SelectionDAG &DAG) {
+  MVT::ValueType StoreVT;
+  switch (Chain.getOpcode()) {
+  default:  // Unexpected; release builds fall through and return it as-is.
+    assert(0 && "Unexpected node in the outgoing argument store chain!");
+  case ISD::CALLSEQ_START:  // Found the start of the call sequence; done.
+    return Chain;
+  case ISD::TokenFactor: {
+    std::vector<SDOperand> Ops;
+    Ops.reserve(Chain.getNumOperands());
+    for (unsigned i = 0, e = Chain.getNumOperands(); i != e; ++i)
+      Ops.push_back(GetAdjustedArgumentStores(Chain.getOperand(i), Offset,DAG));
+    return DAG.getNode(ISD::TokenFactor, MVT::Other, Ops);
+  }
+  case ISD::STORE:       // Normal store
+    StoreVT = Chain.getOperand(1).getValueType();
+    break;
+  case ISD::TRUNCSTORE:  // FLOAT store
+    StoreVT = cast<MVTSDNode>(Chain)->getExtraValueType();
+    break;
+  }
+
+  SDOperand OrigDest = Chain.getOperand(2);
+  unsigned OrigOffset;
+  if (OrigDest.getOpcode() == ISD::CopyFromReg) {
+    OrigOffset = 0;
+    assert(cast<RegSDNode>(OrigDest)->getReg() == X86::ESP);
+  } else {
+    // We expect only (ESP+C)
+    assert(OrigDest.getOpcode() == ISD::ADD &&
+           isa<ConstantSDNode>(OrigDest.getOperand(1)) &&
+           OrigDest.getOperand(0).getOpcode() == ISD::CopyFromReg &&
+           cast<RegSDNode>(OrigDest.getOperand(0))->getReg() == X86::ESP);
+    OrigOffset = cast<ConstantSDNode>(OrigDest.getOperand(1))->getValue();
+  }
+
+  // New offset from the incoming ESP; signed, since Offset may be negative.
+  int NewOffset = static_cast<int>(OrigOffset) + Offset;
+  unsigned OpSize = (MVT::getSizeInBits(StoreVT)+7)/8;  // Bits -> Bytes
+  MachineFunction &MF = DAG.getMachineFunction();
+  int FI = MF.getFrameInfo()->CreateFixedObject(OpSize, NewOffset);
+  SDOperand FIN = DAG.getFrameIndex(FI, MVT::i32);
+  SDOperand InChain = GetAdjustedArgumentStores(Chain.getOperand(0), Offset,
+                                                DAG);
+  if (Chain.getOpcode() == ISD::STORE)
+    return DAG.getNode(ISD::STORE, MVT::Other, InChain, Chain.getOperand(1),
+                       FIN);
+  assert(Chain.getOpcode() == ISD::TRUNCSTORE);
+  return DAG.getNode(ISD::TRUNCSTORE, MVT::Other, InChain, Chain.getOperand(1),
+                     FIN, DAG.getSrcValue(NULL), StoreVT);
+}
+
+
+/// EmitFastCCToFastCCTailCall - Given a tailcall in the tail position to a
+/// fastcc function from a fastcc function, emit the code to emit a 'proper'
+/// tail call.
+void ISel::EmitFastCCToFastCCTailCall(SDNode *TailCallNode) {
+  unsigned CalleeCallArgSize =
+    cast<ConstantSDNode>(TailCallNode->getOperand(2))->getValue();
+  unsigned CallerArgSize = X86Lowering.getBytesToPopOnReturn();
+
+  // Adjust argument stores.  Instead of storing to [ESP], f.e., store to frame
+  // indexes that are relative to the incoming ESP.  If the incoming and
+  // outgoing arg sizes are the same we will store to [InESP] instead of
+  // [CurESP] and the ESP referenced will be relative to the incoming function
+  // ESP.  Signed: the callee may need more argument bytes than we have.
+  int ESPOffset = int(CallerArgSize) - int(CalleeCallArgSize);
+  SDOperand AdjustedArgStores =
+    GetAdjustedArgumentStores(TailCallNode->getOperand(0), ESPOffset, *TheDAG);
+
+  // Select the callee and register args into vregs before any arg store: the
+  // stores may overwrite incoming stack args these values are loaded from.
+  SDOperand Callee = TailCallNode->getOperand(1);
+  bool isDirect = isa<GlobalAddressSDNode>(Callee) ||
+                  isa<ExternalSymbolSDNode>(Callee);
+  unsigned CalleeReg = 0;
+  if (!isDirect) CalleeReg = SelectExpr(Callee);
+
+  // The first value is passed in (a part of) EAX, the second in EDX.
+  unsigned RegOp1 = 0;
+  unsigned RegOp2 = 0;
+  if (TailCallNode->getNumOperands() > 4)
+    RegOp1 = SelectExpr(TailCallNode->getOperand(4));
+  if (TailCallNode->getNumOperands() > 5)
+    RegOp2 = SelectExpr(TailCallNode->getOperand(5));
+
+  // Copy the caller's return address into a vreg so we don't clobber it.
+  SDOperand RetVal;
+  if (ESPOffset) {
+    SDOperand RetValAddr = X86Lowering.getReturnAddressFrameIndex(*TheDAG);
+    RetVal = TheDAG->getLoad(MVT::i32, TheDAG->getEntryNode(),
+                             RetValAddr, TheDAG->getSrcValue(NULL));
+    SelectExpr(RetVal);
+  }
+
+  // Codegen all of the argument stores.  NOTE(review): a stored value loaded
+  // from an incoming stack arg may be read after another store clobbers it.
+  Select(AdjustedArgStores);
+
+  if (RetVal.Val) {
+    // Emit a store of the saved ret value to the new location.
+    MachineFunction &MF = TheDAG->getMachineFunction();
+    int ReturnAddrFI = MF.getFrameInfo()->CreateFixedObject(4, ESPOffset-4);
+    SDOperand RetValAddr = TheDAG->getFrameIndex(ReturnAddrFI, MVT::i32);
+    Select(TheDAG->getNode(ISD::STORE, MVT::Other, TheDAG->getEntryNode(),
+                           RetVal, RetValAddr));
+  }
+
+  // Move the register arguments into their physical registers last.
+  if (TailCallNode->getNumOperands() > 4) {
+    switch (TailCallNode->getOperand(4).getValueType()) {
+    default: assert(0 && "Bad thing to pass in regs");
+    case MVT::i1:
+    case MVT::i8:
+      BuildMI(BB, X86::MOV8rr, 1, X86::AL).addReg(RegOp1);
+      RegOp1 = X86::AL;
+      break;
+    case MVT::i16:
+      BuildMI(BB, X86::MOV16rr, 1,X86::AX).addReg(RegOp1);
+      RegOp1 = X86::AX;
+      break;
+    case MVT::i32:
+      BuildMI(BB, X86::MOV32rr, 1,X86::EAX).addReg(RegOp1);
+      RegOp1 = X86::EAX;
+      break;
+    }
+    if (RegOp2)
+      switch (TailCallNode->getOperand(5).getValueType()) {
+      default: assert(0 && "Bad thing to pass in regs");
+      case MVT::i1:
+      case MVT::i8:
+        BuildMI(BB, X86::MOV8rr, 1, X86::DL).addReg(RegOp2);
+        RegOp2 = X86::DL;
+        break;
+      case MVT::i16:
+        BuildMI(BB, X86::MOV16rr, 1, X86::DX).addReg(RegOp2);
+        RegOp2 = X86::DX;
+        break;
+      case MVT::i32:
+        BuildMI(BB, X86::MOV32rr, 1, X86::EDX).addReg(RegOp2);
+        RegOp2 = X86::EDX;
+        break;
+      }
+  }
+
+  // Adjust ESP.
+  if (ESPOffset)
+    BuildMI(BB, X86::ADJSTACKPTRri, 2,
+            X86::ESP).addReg(X86::ESP).addImm(ESPOffset);
+
+  // TODO: handle jmp [mem]
+  if (!isDirect) {
+    BuildMI(BB, X86::TAILJMPr, 1).addReg(CalleeReg);
+  } else if (GlobalAddressSDNode *GASD = dyn_cast<GlobalAddressSDNode>(Callee)){
+    BuildMI(BB, X86::TAILJMPd, 1).addGlobalAddress(GASD->getGlobal(),true);
+  } else {
+    ExternalSymbolSDNode *ESSDN = cast<ExternalSymbolSDNode>(Callee);
+    BuildMI(BB, X86::TAILJMPd, 1).addExternalSymbol(ESSDN->getSymbol(), true);
+  }
+  // FIXME: add implicit uses of RegOp1/RegOp2 to the TAILJMP.
+}
+
 
 void ISel::Select(SDOperand N) {
   unsigned Tmp1, Tmp2, Opc;
@@ -3698,6 +3977,12 @@
     }
     return;
   case ISD::RET:
+    { // Chains that may end in a call: try to emit this RET as a tail call.
+      unsigned ChainOpc = N.getOperand(0).getOpcode();
+      if ((ChainOpc == ISD::CALLSEQ_END || ChainOpc == X86ISD::TAILCALL ||
+           ChainOpc == ISD::TokenFactor) && EmitPotentialTailCall(Node))
+        return;
+    }
     switch (N.getNumOperands()) {
     default:
       assert(0 && "Unknown return instruction!");