Build lookup tables for switches (PR884)

This adds a transformation to SimplifyCFG that attempts to turn switch
instructions into loads from lookup tables. It works on switches that
are only used to initialize one or more phi nodes in a common successor
basic block, for example:

  int f(int x) {
    switch (x) {
    case 0: return 5;
    case 1: return 4;
    case 2: return -2;
    case 5: return 7;
    case 6: return 9;
    default: return 42;
    }
  }

This speeds up the code by removing the hard-to-predict jump, and
reduces code size by removing the code for the jump targets.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@163302 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp
index 6cd3bbc..62b98cb 100644
--- a/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -22,6 +22,7 @@
 #include "llvm/LLVMContext.h"
 #include "llvm/MDBuilder.h"
 #include "llvm/Metadata.h"
+#include "llvm/Module.h"
 #include "llvm/Operator.h"
 #include "llvm/Type.h"
 #include "llvm/ADT/DenseMap.h"
@@ -54,6 +55,7 @@
        cl::desc("Duplicate return instructions into unconditional branches"));
 
 STATISTIC(NumSpeculations, "Number of speculative executed instructions");
+STATISTIC(NumLookupTables, "Number of switch instructions turned into lookup tables");
 
 namespace {
   /// ValueEqualityComparisonCase - Represents a case of a switch.
@@ -2977,6 +2979,287 @@
   return Changed;
 }
 
+/// ValidLookupTableConstant - Return true if the backend will be able to handle
+/// initializing an array of constants like C.
+static bool ValidLookupTableConstant(Constant *C) {
+  // Among constant expressions, only GEPs without notional over-indexing pass.
+  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
+    return CE->isGEPWithNoNotionalOverIndexing();
+  return isa<ConstantFP>(C) ||
+      isa<ConstantInt>(C) ||
+      isa<ConstantPointerNull>(C) ||
+      isa<GlobalValue>(C) ||
+      isa<UndefValue>(C);
+}
+
+/// GetCaseResults - For one case destination of a switch instruction, append
+/// to Res the constant each phi node in the common destination block receives.
+/// Sets *CommonDest if it was NULL; returns false if the case does not fit.
+static bool GetCaseResults(SwitchInst *SI,
+                           BasicBlock *CaseDest,
+                           BasicBlock **CommonDest,
+                           SmallVector<std::pair<PHINode*,Constant*>, 4> &Res) {
+  // The block from which we enter the common destination.
+  BasicBlock *Pred = SI->getParent();
+
+  // If CaseDest is empty, continue to its successor.
+  if (CaseDest->getFirstNonPHIOrDbg() == CaseDest->getTerminator() &&
+      !isa<PHINode>(CaseDest->begin())) {
+    // Only look through the block when it has exactly one successor.
+    TerminatorInst *Terminator = CaseDest->getTerminator();
+    if (Terminator->getNumSuccessors() != 1)
+      return false;
+
+    Pred = CaseDest;
+    CaseDest = Terminator->getSuccessor(0);
+  }
+
+  // If we did not have a CommonDest before, use the current one.
+  if (!*CommonDest)
+    *CommonDest = CaseDest;
+  // If the destination isn't the common one, abort.
+  if (CaseDest != *CommonDest)
+    return false;
+
+  // Get the values for this case from phi nodes in the destination block.
+  BasicBlock::iterator I = (*CommonDest)->begin();
+  while (PHINode *PHI = dyn_cast<PHINode>(I++)) {
+    int Idx = PHI->getBasicBlockIndex(Pred);
+    if (Idx == -1)
+      continue;  // This phi has no incoming value from Pred.
+    // A non-constant incoming value cannot be put in a lookup table.
+    Constant *ConstVal = dyn_cast<Constant>(PHI->getIncomingValue(Idx));
+    if (!ConstVal)
+      return false;
+
+    // Be conservative about which kinds of constants we support.
+    if (!ValidLookupTableConstant(ConstVal))
+      return false;
+
+    Res.push_back(std::make_pair(PHI, ConstVal));
+  }
+
+  return true;
+}
+
+/// BuildLookupTable - Build a lookup table in M with TableSize elements from
+/// Results (indexed by case value minus Offset), using DefaultResult to fill
+/// the holes. If the table would contain the same result in each element, set
+/// *SingleResult to that value and return NULL instead.
+static GlobalVariable *BuildLookupTable(
+    Module &M,
+    uint64_t TableSize,
+    ConstantInt *Offset,
+    const std::vector<std::pair<ConstantInt*,Constant*> >& Results,
+    Constant *DefaultResult,
+    Constant **SingleResult) {
+  assert(Results.size() && "Need values to build lookup table");
+  assert(TableSize >= Results.size() && "Table needs to hold all values");
+
+  // If all values in the table are equal, this is that value.
+  Constant *SameResult = Results.begin()->second;
+
+  // Build up the table contents.
+  std::vector<Constant*> TableContents(TableSize);
+  for (size_t I = 0, E = Results.size(); I != E; ++I) {
+    ConstantInt *CaseVal = Results[I].first;
+    Constant *CaseRes = Results[I].second;
+    // The table slot of a case is its distance from Offset.
+    uint64_t Idx = (CaseVal->getValue() - Offset->getValue()).getLimitedValue();
+    TableContents[Idx] = CaseRes;
+
+    if (CaseRes != SameResult)
+      SameResult = NULL;
+  }
+
+  // Fill in any holes in the table with the default result.
+  if (Results.size() < TableSize) {
+    for (uint64_t i = 0; i < TableSize; ++i) {
+      if (!TableContents[i])
+        TableContents[i] = DefaultResult;
+    }
+    // The holes now hold DefaultResult, so it has to match as well.
+    if (DefaultResult != SameResult)
+      SameResult = NULL;
+  }
+
+  // Same result was used in the entire table; just return that.
+  if (SameResult) {
+    *SingleResult = SameResult;
+    return NULL;
+  }
+
+  ArrayType *ArrayTy = ArrayType::get(DefaultResult->getType(), TableSize);
+  Constant *Initializer = ConstantArray::get(ArrayTy, TableContents);
+
+  GlobalVariable *GV = new GlobalVariable(M, ArrayTy, /*constant=*/ true,
+                                          GlobalVariable::PrivateLinkage,
+                                          Initializer,
+                                          "switch.table");
+  GV->setUnnamedAddr(true);
+  return GV;
+}
+
+/// SwitchToLookupTable - If the switch is only used to initialize one or more
+/// phi nodes in a common successor block with different constant values,
+/// replace the switch with lookup tables. Returns true if SI was replaced.
+static bool SwitchToLookupTable(SwitchInst *SI,
+                                IRBuilder<> &Builder) {
+  assert(SI->getNumCases() > 1 && "Degenerate switch?");
+  // FIXME: Handle unreachable cases.
+
+  // FIXME: If the switch is too sparse for a lookup table, perhaps we could
+  // split off a dense part and build a lookup table for that.
+
+  // FIXME: If the results are all integers and the lookup table would fit in a
+  // target-legal register, we should store them as a bitmap and use shift/mask
+  // to look up the result.
+
+  // FIXME: This creates arrays of GEPs to constant strings, which means each
+  // GEP needs a runtime relocation in PIC code. We should just build one big
+  // string and lookup indices into that.
+
+  // Ignore the switch if the number of cases is too small.
+  // This is similar to the check when building jump tables in
+  // SelectionDAGBuilder::handleJTSwitchCase.
+  // FIXME: Determine the best cut-off.
+  if (SI->getNumCases() < 4)
+    return false;
+
+  // Figure out the corresponding result for each case value and phi node in the
+  // common destination, as well as the min and max case values.
+  assert(SI->case_begin() != SI->case_end());
+  SwitchInst::CaseIt CI = SI->case_begin();
+  ConstantInt *MinCaseVal = CI.getCaseValue();
+  ConstantInt *MaxCaseVal = CI.getCaseValue();
+
+  BasicBlock *CommonDest = NULL;
+  typedef std::vector<std::pair<ConstantInt*, Constant*> > ResultListTy;
+  SmallDenseMap<PHINode*, ResultListTy> ResultLists;
+  SmallDenseMap<PHINode*, Constant*> DefaultResults;
+  SmallVector<PHINode*, 4> PHIs;
+
+  for (SwitchInst::CaseIt E = SI->case_end(); CI != E; ++CI) {
+    ConstantInt *CaseVal = CI.getCaseValue();
+    if (CaseVal->getValue().slt(MinCaseVal->getValue()))
+      MinCaseVal = CaseVal;
+    if (CaseVal->getValue().sgt(MaxCaseVal->getValue()))
+      MaxCaseVal = CaseVal;
+
+    // Resulting value at phi nodes for this case value.
+    typedef SmallVector<std::pair<PHINode*, Constant*>, 4> ResultsTy;
+    ResultsTy Results;
+    if (!GetCaseResults(SI, CI.getCaseSuccessor(), &CommonDest, Results))
+      return false;
+
+    // Append the result from this case to the list for each phi.
+    for (ResultsTy::iterator I = Results.begin(), E = Results.end(); I!=E; ++I) {
+      if (!ResultLists.count(I->first))
+        PHIs.push_back(I->first);
+      ResultLists[I->first].push_back(std::make_pair(CaseVal, I->second));
+    }
+  }
+
+  // Get the resulting values for the default case.
+  {
+    SmallVector<std::pair<PHINode*, Constant*>, 4> DefaultResultsList;
+    if (!GetCaseResults(SI, SI->getDefaultDest(), &CommonDest, DefaultResultsList))
+      return false;
+    for (size_t I = 0, E = DefaultResultsList.size(); I != E; ++I) {
+      PHINode *PHI = DefaultResultsList[I].first;
+      Constant *Result = DefaultResultsList[I].second;
+      DefaultResults[PHI] = Result;
+    }
+  }
+
+  APInt RangeSpread = MaxCaseVal->getValue() - MinCaseVal->getValue();
+  // The table density should be at least 40%. This is the same criterion as for
+  // jump tables, see SelectionDAGBuilder::handleJTSwitchCase.
+  // FIXME: Find the best cut-off.
+  // Be careful to avoid overflow in the density computation. Also reject a
+  // spread with the sign bit set: GEP treats the table index as signed, and
+  // TableSize must be representable in the condition type for the range check.
+  if (RangeSpread.isNegative() ||
+      RangeSpread.zextOrSelf(64).ugt(UINT64_MAX / 4 - 1))
+    return false;
+  uint64_t TableSize = RangeSpread.getLimitedValue() + 1;
+  if (SI->getNumCases() * 10 < TableSize * 4)
+    return false;
+
+  // Build the lookup tables.
+  SmallDenseMap<PHINode*, GlobalVariable*> LookupTables;
+  SmallDenseMap<PHINode*, Constant*> SingleResults;
+
+  Module &Mod = *CommonDest->getParent()->getParent();
+  for (SmallDenseMap<PHINode*, ResultListTy>::iterator I = ResultLists.begin(),
+       E = ResultLists.end(); I != E; ++I) {
+    PHINode *PHI = I->first;
+
+    Constant *SingleResult = NULL;
+    LookupTables[PHI] = BuildLookupTable(Mod, TableSize, MinCaseVal, I->second,
+                                         DefaultResults[PHI], &SingleResult);
+    SingleResults[PHI] = SingleResult;
+  }
+
+  // Create the BB that does the lookups.
+  BasicBlock *LookupBB = BasicBlock::Create(Mod.getContext(),
+                                            "switch.lookup",
+                                            CommonDest->getParent(),
+                                            CommonDest);
+
+  // Check whether the condition value is within the case range, and branch to
+  // the new BB.
+  Builder.SetInsertPoint(SI);
+  Value *TableIndex = Builder.CreateSub(SI->getCondition(), MinCaseVal,
+                                        "switch.tableidx");
+  Value *Cmp = Builder.CreateICmpULT(TableIndex, ConstantInt::get(
+      MinCaseVal->getType(), TableSize));
+  Builder.CreateCondBr(Cmp, LookupBB, SI->getDefaultDest());
+
+  // Populate the BB that does the lookups.
+  Builder.SetInsertPoint(LookupBB);
+  bool ReturnedEarly = false;
+  for (SmallVector<PHINode*, 4>::iterator I = PHIs.begin(), E = PHIs.end();
+       I != E; ++I) {
+    PHINode *PHI = *I;
+    // There was a single result for this phi; just use that.
+    if (Constant *SingleResult = SingleResults[PHI]) {
+      PHI->addIncoming(SingleResult, LookupBB);
+      continue;
+    }
+
+    Value *GEPIndices[] = { Builder.getInt32(0), TableIndex };
+    Value *GEP = Builder.CreateInBoundsGEP(LookupTables[PHI], GEPIndices,
+                                           "switch.gep");
+    Value *Result = Builder.CreateLoad(GEP, "switch.load");
+
+    // If the result is used to return immediately from the function, i.e. by
+    // the return that directly follows the phis, we want to do that right here.
+    if (PHI->hasOneUse() && isa<ReturnInst>(*PHI->use_begin()) &&
+        *PHI->use_begin() == CommonDest->getFirstNonPHIOrDbg()) {
+      Builder.CreateRet(Result);
+      ReturnedEarly = true;
+      break;  // LookupBB is terminated; nothing may be emitted after the ret.
+    }
+
+    PHI->addIncoming(Result, LookupBB);
+  }
+
+  if (!ReturnedEarly)
+    Builder.CreateBr(CommonDest);
+
+  // Remove the switch; the default edge is kept by the new conditional branch.
+  for (unsigned i = 0; i < SI->getNumSuccessors(); ++i) {
+    BasicBlock *Succ = SI->getSuccessor(i);
+    if (Succ == SI->getDefaultDest()) continue;
+    Succ->removePredecessor(SI->getParent());
+  }
+  SI->eraseFromParent();
+
+  ++NumLookupTables;
+  return true;
+}
+
 bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
   // If this switch is too complex to want to look at, ignore it.
   if (!isValueEqualityComparison(SI))
@@ -3016,6 +3299,9 @@
   if (ForwardSwitchConditionToPHI(SI))
     return SimplifyCFG(BB) | true;
 
+  if (SwitchToLookupTable(SI, Builder))
+    return SimplifyCFG(BB) | true;
+
   return false;
 }