X86: Custom-lower ISD::MEMCPY and ISD::MEMSET to REP MOVS and REP STOS.

llvm-svn: 25226
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 65622af..a6a95c7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -147,6 +147,9 @@
     setOperationAction(ISD::SHL_PARTS      , MVT::i32  , Custom);
     setOperationAction(ISD::SRA_PARTS      , MVT::i32  , Custom);
     setOperationAction(ISD::SRL_PARTS      , MVT::i32  , Custom);
+    // Custom-lower memcpy / memset to X86ISD::REP_MOVS / X86ISD::REP_STOS.
+    setOperationAction(ISD::MEMCPY         , MVT::Other, Custom);
+    setOperationAction(ISD::MEMSET         , MVT::Other, Custom);
   }
 
   // We don't have line number support yet.
@@ -1614,6 +1617,109 @@
     return DAG.getNode(X86ISD::RET_FLAG, MVT::Other, Op.getOperand(0),
                        DAG.getConstant(getBytesToPopOnReturn(), MVT::i16));
   }
+  case ISD::MEMSET: {
+    // Lower memset to X86ISD::REP_STOS with the fill value in AL/AX/EAX, the
+    // element count in ECX and the destination pointer in EDI.
+    //   Operand 0: chain        Operand 3: size in bytes
+    //   Operand 1: destination  Operand 4: alignment (constant, 0 means 1)
+    //   Operand 2: fill value
+    SDOperand InFlag;
+    SDOperand Chain = Op.getOperand(0);
+    unsigned Align =
+      (unsigned)cast<ConstantSDNode>(Op.getOperand(4))->getValue();
+    if (Align == 0) Align = 1;
+
+    // Default to byte stores, which are correct for every value, size and
+    // alignment: ECX = size, AL = fill value.
+    MVT::ValueType AVT = MVT::i8;
+    SDOperand Count = Op.getOperand(3);
+    SDOperand Value = Op.getOperand(2);
+    unsigned ValReg = X86::AL;
+
+    // WORD / DWORD stores need a constant fill byte (so it can be replicated
+    // into every byte of the element) and a constant size that is an exact
+    // multiple of the element size.  Dividing any other size by the element
+    // size would silently leave the trailing bytes unwritten, so such sizes
+    // (including every non-constant size) keep using byte stores.
+    ConstantSDNode *ValC = dyn_cast<ConstantSDNode>(Op.getOperand(2));
+    ConstantSDNode *SizeC = dyn_cast<ConstantSDNode>(Op.getOperand(3));
+    if (ValC) {
+      unsigned Val = ValC->getValue() & 255;
+      if (SizeC) {
+        uint64_t Bytes = SizeC->getValue();
+        if ((Align & 3) == 0 && (Bytes & 3) == 0) {
+          // DWORD aligned and sized: splat the byte into all of EAX.
+          AVT    = MVT::i32;
+          Count  = DAG.getConstant(Bytes / 4, MVT::i32);
+          Val    = (Val << 8)  | Val;
+          Val    = (Val << 16) | Val;
+          ValReg = X86::EAX;
+        } else if ((Align & 1) == 0 && (Bytes & 1) == 0) {
+          // WORD aligned and sized: splat the byte into both halves of AX.
+          AVT    = MVT::i16;
+          Count  = DAG.getConstant(Bytes / 2, MVT::i32);
+          Val    = (Val << 8) | Val;
+          ValReg = X86::AX;
+        }
+      }
+      // Materialize the (possibly replicated) constant at the chosen width.
+      Value = DAG.getConstant(Val, AVT);
+    }
+
+    // Each copy's flag result feeds the next node, chaining the register
+    // setup into the REP_STOS that consumes the value, ECX and EDI.
+    Chain  = DAG.getCopyToReg(Chain, ValReg, Value, InFlag);
+    InFlag = Chain.getValue(1);
+    Chain  = DAG.getCopyToReg(Chain, X86::ECX, Count, InFlag);
+    InFlag = Chain.getValue(1);
+    Chain  = DAG.getCopyToReg(Chain, X86::EDI, Op.getOperand(1), InFlag);
+    InFlag = Chain.getValue(1);
+
+    return DAG.getNode(X86ISD::REP_STOS, MVT::Other, Chain,
+                       DAG.getValueType(AVT), InFlag);
+  }
+  case ISD::MEMCPY: {
+    // Lower memcpy to X86ISD::REP_MOVS with the element count in ECX, the
+    // destination pointer in EDI and the source pointer in ESI.  Operands:
+    // chain, destination, source, size in bytes, alignment (0 means 1).
+    SDOperand Chain = Op.getOperand(0);
+    unsigned Align =
+      (unsigned)cast<ConstantSDNode>(Op.getOperand(4))->getValue();
+    if (Align == 0) Align = 1;
+
+    // Default to byte moves (ECX = byte count), which are correct for every
+    // size and alignment.
+    MVT::ValueType AVT = MVT::i8;
+    SDOperand Count = Op.getOperand(3);
+
+    // WORD / DWORD moves are only used when the size is a constant that is an
+    // exact multiple of the element size.  Dividing any other size by the
+    // element size would silently leave the trailing bytes uncopied, so such
+    // sizes (including every non-constant size) keep using byte moves.
+    if (ConstantSDNode *SizeC = dyn_cast<ConstantSDNode>(Op.getOperand(3))) {
+      uint64_t Bytes = SizeC->getValue();
+      if ((Align & 3) == 0 && (Bytes & 3) == 0) {
+        AVT   = MVT::i32;
+        Count = DAG.getConstant(Bytes / 4, MVT::i32);
+      } else if ((Align & 1) == 0 && (Bytes & 1) == 0) {
+        AVT   = MVT::i16;
+        Count = DAG.getConstant(Bytes / 2, MVT::i32);
+      }
+    }
+
+    // Each copy's flag result feeds the next node, chaining the register
+    // setup into the REP_MOVS that consumes ECX, EDI and ESI.
+    SDOperand InFlag;
+    Chain  = DAG.getCopyToReg(Chain, X86::ECX, Count, InFlag);
+    InFlag = Chain.getValue(1);
+    Chain  = DAG.getCopyToReg(Chain, X86::EDI, Op.getOperand(1), InFlag);
+    InFlag = Chain.getValue(1);
+    Chain  = DAG.getCopyToReg(Chain, X86::ESI, Op.getOperand(2), InFlag);
+    InFlag = Chain.getValue(1);
+
+    return DAG.getNode(X86ISD::REP_MOVS, MVT::Other, Chain,
+                       DAG.getValueType(AVT), InFlag);
+  }
   case ISD::GlobalAddress: {
     GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
     SDOperand GVOp = DAG.getTargetGlobalAddress(GV, getPointerTy());
@@ -1659,6 +1765,8 @@
   case X86ISD::CMOV:               return "X86ISD::CMOV";
   case X86ISD::BRCOND:             return "X86ISD::BRCOND";
   case X86ISD::RET_FLAG:           return "X86ISD::RET_FLAG";
+  case X86ISD::REP_STOS:           return "X86ISD::REP_STOS";
+  case X86ISD::REP_MOVS:           return "X86ISD::REP_MOVS";
   }
 }