X86: Try to use a smaller encoding by transforming (X << C1) & C2 into (X & (C2 >> C1)) << C1. (Part of PR5039)

This tends to happen a lot with bitfield code generated by clang. A simple example for x86_64 is
uint64_t foo(uint64_t x) { return (x&1) << 42; }
which used to compile into bloated code:
	shlq	$42, %rdi               ## encoding: [0x48,0xc1,0xe7,0x2a]
	movabsq	$4398046511104, %rax    ## encoding: [0x48,0xb8,0x00,0x00,0x00,0x00,0x00,0x04,0x00,0x00]
	andq	%rdi, %rax              ## encoding: [0x48,0x21,0xf8]
	ret                             ## encoding: [0xc3]

with this patch we can fold the immediate into the and:
	andq	$1, %rdi                ## encoding: [0x48,0x83,0xe7,0x01]
	movq	%rdi, %rax              ## encoding: [0x48,0x89,0xf8]
	shlq	$42, %rax               ## encoding: [0x48,0xc1,0xe0,0x2a]
	ret                             ## encoding: [0xc3]

It's possible to save another byte by using 'andl' instead of 'andq', but I currently see no way of doing
that without making this code even more complicated. See the TODOs in the code.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@129990 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelDAGToDAG.cpp b/lib/Target/X86/X86ISelDAGToDAG.cpp
index 9b0ec6e..e156cef 100644
--- a/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -1580,6 +1580,81 @@
       return RetVal;
     break;
   }
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR: {
+    // For operations of the form (x << C1) op C2, check if we can use a smaller
+    // encoding for C2 by transforming it into (x op (C2>>C1)) << C1.
+    SDValue N0 = Node->getOperand(0);
+    SDValue N1 = Node->getOperand(1);
+    if (N0->getOpcode() != ISD::SHL || !N0->hasOneUse())
+      break;
+
+    // i8 is unshrinkable, i16 should be promoted to i32.
+    if (NVT != MVT::i32 && NVT != MVT::i64)
+      break;
+
+    // C1 and C2 must be constant; shifts by >= the type width are undefined.
+    ConstantSDNode *Cst = dyn_cast<ConstantSDNode>(N1);
+    ConstantSDNode *ShlCst = dyn_cast<ConstantSDNode>(N0->getOperand(1));
+    if (!Cst || !ShlCst || ShlCst->getZExtValue() >= NVT.getSizeInBits())
+      break;
+
+    int64_t Val = Cst->getSExtValue();
+    uint64_t ShlVal = ShlCst->getZExtValue();
+
+    // Make sure that we don't change the operation by removing bits. This only
+    // matters for OR and XOR; a mask avoids left-shifting a negative Val (UB).
+    uint64_t RemovedBitsMask = (1ULL << ShlVal) - 1;
+    if (Opcode != ISD::AND && (Val & RemovedBitsMask) != 0)
+      break;
+
+    unsigned ShlOp, Op;
+    EVT CstVT = NVT;
+
+    // Check the minimum bitwidth for the new constant.
+    // TODO: AND32ri is the same as AND64ri32 with zext imm.
+    // TODO: MOV32ri+OR64r is cheaper than MOV64ri64+OR64rr
+    // TODO: Using 16 and 8 bit operations is also possible for or32 & xor32.
+    if (!isInt<8>(Val) && isInt<8>(Val >> ShlVal))
+      CstVT = MVT::i8;
+    else if (!isInt<32>(Val) && isInt<32>(Val >> ShlVal))
+      CstVT = MVT::i32;
+
+    // Bail if there is no smaller encoding.
+    if (NVT == CstVT)
+      break;
+
+    switch (NVT.getSimpleVT().SimpleTy) {
+    default: llvm_unreachable("Unsupported VT!");
+    case MVT::i32:
+      assert(CstVT == MVT::i8);
+      ShlOp = X86::SHL32ri;
+      switch (Opcode) {
+      default: llvm_unreachable("Impossible opcode");
+      case ISD::AND: Op = X86::AND32ri8; break;
+      case ISD::OR:  Op =  X86::OR32ri8; break;
+      case ISD::XOR: Op = X86::XOR32ri8; break;
+      }
+      break;
+    case MVT::i64:
+      assert(CstVT == MVT::i8 || CstVT == MVT::i32);
+      ShlOp = X86::SHL64ri;
+      switch (Opcode) {
+      default: llvm_unreachable("Impossible opcode");
+      case ISD::AND: Op = CstVT==MVT::i8? X86::AND64ri8 : X86::AND64ri32; break;
+      case ISD::OR:  Op = CstVT==MVT::i8?  X86::OR64ri8 :  X86::OR64ri32; break;
+      case ISD::XOR: Op = CstVT==MVT::i8? X86::XOR64ri8 : X86::XOR64ri32; break;
+      }
+      break;
+    }
+
+    // Emit the smaller op and the shift.
+    SDValue NewCst = CurDAG->getTargetConstant(Val >> ShlVal, CstVT);
+    SDNode *New = CurDAG->getMachineNode(Op, dl, NVT, N0->getOperand(0),NewCst);
+    return CurDAG->SelectNodeTo(Node, ShlOp, NVT, SDValue(New, 0),
+                                getI8Imm(ShlVal));
+  }
   case X86ISD::UMUL: {
     SDValue N0 = Node->getOperand(0);
     SDValue N1 = Node->getOperand(1);