[SelectionDAG] Optimization of BITREVERSE legalization for power-of-2 integer scalar/vector types

An extension of D19978, this patch replaces the default BITREVERSE expansion, which masks and shifts each bit individually, with block masks+shifts when the integer elements are a power-of-2 number of bits wide (8 bits or more).

After a BSWAP reverses the order of the constituent bytes (BSWAP itself is typically expanded with a similar approach), each pair of neighbouring 4-bit groups, then 2-bit groups, and finally single bits is masked off and swapped with shifts.

In doing so we can significantly reduce the number of operations required.

Differential Revision: https://reviews.llvm.org/D21578

llvm-svn: 276432
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 8163409..609ff86 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -2520,13 +2520,56 @@
   return DAG.getNode(ISD::TRUNCATE, dl, DestVT, Operation);
 }
 
-/// Open code the operations for BITREVERSE.
+/// Legalize a BITREVERSE scalar/vector operation as a series of mask + shifts.
 SDValue SelectionDAGLegalize::ExpandBITREVERSE(SDValue Op, const SDLoc &dl) {
   EVT VT = Op.getValueType();
   EVT SHVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
   unsigned Sz = VT.getScalarSizeInBits();
 
-  SDValue Tmp, Tmp2;
+  SDValue Tmp, Tmp2, Tmp3;
+
+  // If we can, BSWAP first and then mask+swap the i4, i2 and finally i1 pairs.
+  // Shift amount constants are built with SHVT (getShiftAmountTy), not VT.
+  // TODO: We can easily support i4/i2 legal types if any target ever does.
+  if (Sz >= 8 && isPowerOf2_32(Sz)) {
+    // Create the masks - repeating the pattern every byte. The byte pattern
+    // is shifted as an Sz-bit APInt, as a 64-bit shift by J >= 64 is UB.
+    APInt MaskHi4(Sz, 0), MaskHi2(Sz, 0), MaskHi1(Sz, 0);
+    APInt MaskLo4(Sz, 0), MaskLo2(Sz, 0), MaskLo1(Sz, 0);
+    for (unsigned J = 0; J != Sz; J += 8) {
+      MaskHi4 |= APInt(Sz, 0xF0).shl(J);
+      MaskLo4 |= APInt(Sz, 0x0F).shl(J);
+      MaskHi2 |= APInt(Sz, 0xCC).shl(J);
+      MaskLo2 |= APInt(Sz, 0x33).shl(J);
+      MaskHi1 |= APInt(Sz, 0xAA).shl(J);
+      MaskLo1 |= APInt(Sz, 0x55).shl(J);
+    }
+
+    // BSWAP if the type is wider than a single byte.
+    Tmp = (Sz > 8 ? DAG.getNode(ISD::BSWAP, dl, VT, Op) : Op);
+
+    // swap i4: ((V & 0xF0) >> 4) | ((V & 0x0F) << 4)
+    Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp, DAG.getConstant(MaskHi4, dl, VT));
+    Tmp3 = DAG.getNode(ISD::AND, dl, VT, Tmp, DAG.getConstant(MaskLo4, dl, VT));
+    Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Tmp2, DAG.getConstant(4, dl, SHVT));
+    Tmp3 = DAG.getNode(ISD::SHL, dl, VT, Tmp3, DAG.getConstant(4, dl, SHVT));
+    Tmp = DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp3);
+
+    // swap i2: ((V & 0xCC) >> 2) | ((V & 0x33) << 2)
+    Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp, DAG.getConstant(MaskHi2, dl, VT));
+    Tmp3 = DAG.getNode(ISD::AND, dl, VT, Tmp, DAG.getConstant(MaskLo2, dl, VT));
+    Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Tmp2, DAG.getConstant(2, dl, SHVT));
+    Tmp3 = DAG.getNode(ISD::SHL, dl, VT, Tmp3, DAG.getConstant(2, dl, SHVT));
+    Tmp = DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp3);
+
+    // swap i1: ((V & 0xAA) >> 1) | ((V & 0x55) << 1)
+    Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp, DAG.getConstant(MaskHi1, dl, VT));
+    Tmp3 = DAG.getNode(ISD::AND, dl, VT, Tmp, DAG.getConstant(MaskLo1, dl, VT));
+    Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Tmp2, DAG.getConstant(1, dl, SHVT));
+    Tmp3 = DAG.getNode(ISD::SHL, dl, VT, Tmp3, DAG.getConstant(1, dl, SHVT));
+    return DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp3);
+  }
+
   Tmp = DAG.getConstant(0, dl, VT);
   for (unsigned I = 0, J = Sz-1; I < Sz; ++I, --J) {
     if (I < J)
@@ -2550,7 +2593,7 @@
   EVT VT = Op.getValueType();
   EVT SHVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
   SDValue Tmp1, Tmp2, Tmp3, Tmp4, Tmp5, Tmp6, Tmp7, Tmp8;
-  switch (VT.getSimpleVT().SimpleTy) {
+  switch (VT.getSimpleVT().getScalarType().SimpleTy) {
   default: llvm_unreachable("Unhandled Expand type in BSWAP!");
   case MVT::i16:
     Tmp2 = DAG.getNode(ISD::SHL, dl, VT, Op, DAG.getConstant(8, dl, SHVT));