Implement memset -> rep stos*


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@19467 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelPattern.cpp b/lib/Target/X86/X86ISelPattern.cpp
index fe56208..1f4e866 100644
--- a/lib/Target/X86/X86ISelPattern.cpp
+++ b/lib/Target/X86/X86ISelPattern.cpp
@@ -50,7 +50,6 @@
       
       computeRegisterProperties();
 
-      setOperationUnsupported(ISD::MEMSET, MVT::Other);
       setOperationUnsupported(ISD::MEMCPY, MVT::Other);
       setOperationUnsupported(ISD::MEMMOVE, MVT::Other);
 
@@ -1858,9 +1857,11 @@
   if (/*!N->hasOneUse() &&*/ !LoweredTokens.insert(N).second)
     return;  // Already selected.
 
-  switch (N.getOpcode()) {
+  SDNode *Node = N.Val;
+
+  switch (Node->getOpcode()) {
   default:
-    N.Val->dump(); std::cerr << "\n";
+    Node->dump(); std::cerr << "\n";
     assert(0 && "Node not handled yet!");
   case ISD::EntryToken: return;  // Noop
   case ISD::CopyToReg:
@@ -2027,6 +2028,66 @@
                                                    X86::ADJCALLSTACKUP;
     BuildMI(BB, Opc, 1).addImm(Tmp1);
     return;
+  case ISD::MEMSET: {
+    Select(N.getOperand(0));  // Select the chain.
+    unsigned Align =
+      (unsigned)cast<ConstantSDNode>(Node->getOperand(4))->getValue();
+    if (Align == 0) Align = 1;
+
+    // Turn the byte count into # iterations
+    unsigned CountReg;
+    unsigned Opcode;
+    if (ConstantSDNode *ValC = dyn_cast<ConstantSDNode>(Node->getOperand(2))) {
+      unsigned Val = ValC->getValue() & 255;
+
+      // If the value is a constant, then we can potentially use larger sets.
+      switch (Align & 3) {
+      case 2:   // WORD aligned
+        CountReg = MakeReg(MVT::i32);
+        if (ConstantSDNode *I = dyn_cast<ConstantSDNode>(Node->getOperand(3))) {
+          BuildMI(BB, X86::MOV32ri, 1, CountReg).addImm(I->getValue()/2);
+        } else {
+          unsigned ByteReg = SelectExpr(Node->getOperand(3));
+          BuildMI(BB, X86::SHR32ri, 2, CountReg).addReg(ByteReg).addImm(1);
+        }
+        BuildMI(BB, X86::MOV16ri, 1, X86::AX).addImm((Val << 8) | Val);
+        Opcode = X86::REP_STOSW;
+        break;
+      case 0:   // DWORD aligned
+        CountReg = MakeReg(MVT::i32);
+        if (ConstantSDNode *I = dyn_cast<ConstantSDNode>(Node->getOperand(3))) {
+          BuildMI(BB, X86::MOV32ri, 1, CountReg).addImm(I->getValue()/4);
+        } else {
+          unsigned ByteReg = SelectExpr(Node->getOperand(3));
+          BuildMI(BB, X86::SHR32ri, 2, CountReg).addReg(ByteReg).addImm(2);
+        }
+        Val = (Val << 8) | Val;
+        BuildMI(BB, X86::MOV32ri, 1, X86::EAX).addImm((Val << 16) | Val);
+        Opcode = X86::REP_STOSD;
+        break;
+      default:  // BYTE aligned
+        CountReg = SelectExpr(Node->getOperand(3));
+        BuildMI(BB, X86::MOV8ri, 1, X86::AL).addImm(Val);
+        Opcode = X86::REP_STOSB;
+        break;
+      }
+    } else {
+      // If it's not a constant value we are storing, just fall back.  We could
+      // try to be clever to form 16 bit and 32 bit values, but we don't yet.
+      unsigned ValReg = SelectExpr(Node->getOperand(2));
+      BuildMI(BB, X86::MOV8rr, 1, X86::AL).addReg(ValReg);
+      CountReg = SelectExpr(Node->getOperand(3));
+      Opcode = X86::REP_STOSB;
+    }
+
+    // No matter what the alignment is, the fill value is already in
+    // AL/AX/EAX; we put the destination in EDI and the count in ECX.
+    unsigned TmpReg1 = SelectExpr(Node->getOperand(1));
+    BuildMI(BB, X86::MOV32rr, 1, X86::ECX).addReg(CountReg);
+    BuildMI(BB, X86::MOV32rr, 1, X86::EDI).addReg(TmpReg1);
+    BuildMI(BB, Opcode, 0);
+    return;
+  }
   }
   assert(0 && "Should not be reached!");
 }