boundschecking:
add support for select instructions
add experimental support for alloc_size metadata

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@157481 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/BoundsChecking.cpp b/lib/Transforms/Scalar/BoundsChecking.cpp
index 85c5e11..4c3dea2 100644
--- a/lib/Transforms/Scalar/BoundsChecking.cpp
+++ b/lib/Transforms/Scalar/BoundsChecking.cpp
@@ -26,6 +26,7 @@
 #include "llvm/GlobalVariable.h"
 #include "llvm/Instructions.h"
 #include "llvm/Intrinsics.h"
+#include "llvm/Metadata.h"
 #include "llvm/Operator.h"
 #include "llvm/Pass.h"
 using namespace llvm;
@@ -118,6 +119,9 @@
 /// incurr at run-time.
 ConstTriState BoundsChecking::computeAllocSize(Value *Alloc, uint64_t &Size,
                                      Value* &SizeValue) {
+  IntegerType *RetTy = TD->getIntPtrType(Fn->getContext());
+
+  // global variable with definitive size
   if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Alloc)) {
     if (GV->hasDefinitiveInitializer()) {
       Constant *C = GV->getInitializer();
@@ -126,6 +130,7 @@
     }
     return Dunno;
 
+  // stack allocation
   } else if (AllocaInst *AI = dyn_cast<AllocaInst>(Alloc)) {
     if (!AI->getAllocatedType()->isSized())
       return Dunno;
@@ -147,68 +152,117 @@
     SizeValue = Builder->CreateMul(SizeValue, ArraySize);
     return NotConst;
 
+  // ptr = select(ptr1, ptr2)
+  } else if (SelectInst *SI = dyn_cast<SelectInst>(Alloc)) {
+    uint64_t SizeFalse;
+    Value *SizeValueFalse;
+    ConstTriState TrueConst = computeAllocSize(SI->getTrueValue(), Size,
+                                               SizeValue);
+    ConstTriState FalseConst = computeAllocSize(SI->getFalseValue(), SizeFalse,
+                                                SizeValueFalse);
+
+    if (TrueConst == Const && FalseConst == Const && Size == SizeFalse)
+      return Const;
+
+    if (Penalty < 2 || (TrueConst == Dunno && FalseConst == Dunno))
+      return Dunno;
+
+    // if one of the branches is Dunno, assume it is ok and check just the other
+    APInt MaxSize = APInt::getMaxValue(TD->getTypeSizeInBits(RetTy));
+
+    if (TrueConst == Const)
+      SizeValue = ConstantInt::get(RetTy, Size);
+    else if (TrueConst == Dunno)
+      SizeValue = ConstantInt::get(RetTy, MaxSize);
+
+    if (FalseConst == Const)
+      SizeValueFalse = ConstantInt::get(RetTy, SizeFalse);
+    else if (FalseConst == Dunno)
+      SizeValueFalse = ConstantInt::get(RetTy, MaxSize);
+
+    SizeValue = Builder->CreateSelect(SI->getCondition(), SizeValue,
+                                      SizeValueFalse);
+    return NotConst;
+
+  // call allocation function
   } else if (CallInst *CI = dyn_cast<CallInst>(Alloc)) {
-    Function *Callee = CI->getCalledFunction();
-    if (!Callee || !Callee->isDeclaration())
-      return Dunno;
+    SmallVector<unsigned, 4> Args;
 
-    FunctionType *FTy = Callee->getFunctionType();
-    if (FTy->getNumParams() == 1) {
+    if (MDNode *MD = CI->getMetadata("alloc_size")) {
+      for (unsigned i = 0, e = MD->getNumOperands(); i != e; ++i)
+        Args.push_back(cast<ConstantInt>(MD->getOperand(i))->getZExtValue());
+
+    } else if (Function *Callee = CI->getCalledFunction()) {
+      FunctionType *FTy = Callee->getFunctionType();
+
       // alloc(size)
-      if ((FTy->getParamType(0)->isIntegerTy(32) ||
-           FTy->getParamType(0)->isIntegerTy(64)) &&
-          (Callee->getName() == "malloc" ||
-           Callee->getName() == "valloc" ||
-           Callee->getName() == "_Znwj"  || // operator new(unsigned int)
-           Callee->getName() == "_Znwm"  || // operator new(unsigned long)
-           Callee->getName() == "_Znaj"  || // operator new[](unsigned int)
-           Callee->getName() == "_Znam")) { // operator new[](unsigned long)
-        SizeValue = CI->getArgOperand(0);
-        if (ConstantInt *Arg = dyn_cast<ConstantInt>(SizeValue)) {
-          Size = Arg->getZExtValue();
-          return Const;
+      if (FTy->getNumParams() == 1 && FTy->getParamType(0)->isIntegerTy()) {
+        if ((Callee->getName() == "malloc" ||
+             Callee->getName() == "valloc" ||
+             Callee->getName() == "_Znwj"  || // operator new(unsigned int)
+             Callee->getName() == "_Znwm"  || // operator new(unsigned long)
+             Callee->getName() == "_Znaj"  || // operator new[](unsigned int)
+             Callee->getName() == "_Znam")) {
+          Args.push_back(0);
         }
-        return Penalty >= 2 ? NotConst : Dunno;
+      } else if (FTy->getNumParams() == 2) {
+        // alloc(_, x)
+        if (FTy->getParamType(1)->isIntegerTy() &&
+            ((Callee->getName() == "realloc" ||
+              Callee->getName() == "reallocf"))) {
+          Args.push_back(1);
+
+        // alloc(x, y)
+        } else if (FTy->getParamType(0)->isIntegerTy() &&
+                   FTy->getParamType(1)->isIntegerTy() &&
+                   Callee->getName() == "calloc") {
+          Args.push_back(0);
+          Args.push_back(1);
+        }
       }
+    }
+
+    if (Args.empty())
       return Dunno;
-    }
 
-    if (FTy->getNumParams() == 2) {
-      // alloc(x, y) and return buffer of size x * y
-      if (((FTy->getParamType(0)->isIntegerTy(32) &&
-            FTy->getParamType(1)->isIntegerTy(32)) ||
-           (FTy->getParamType(0)->isIntegerTy(64) &&
-            FTy->getParamType(1)->isIntegerTy(64))) &&
-          Callee->getName() == "calloc") {
-        Value *Arg1 = CI->getArgOperand(0);
-        Value *Arg2 = CI->getArgOperand(1);
-        if (ConstantInt *CI1 = dyn_cast<ConstantInt>(Arg1)) {
-          if (ConstantInt *CI2 = dyn_cast<ConstantInt>(Arg2)) {
-            Size = (CI1->getValue() * CI2->getValue()).getZExtValue();
-            return Const;
-          }
-        }
-
-        if (Penalty < 2)
-          return Dunno;
-
-        SizeValue = Builder->CreateMul(Arg1, Arg2);
-        return NotConst;
-      }
-
-      // realloc(ptr, size)
-      if ((FTy->getParamType(1)->isIntegerTy(32) ||
-           FTy->getParamType(1)->isIntegerTy(64)) &&
-          (Callee->getName() == "realloc" ||
-           Callee->getName() == "reallocf")) {
-        SizeValue = CI->getArgOperand(1);
-        if (ConstantInt *Arg = dyn_cast<ConstantInt>(SizeValue)) {
-          Size = Arg->getZExtValue();
-          return Const;
-        }
-        return Penalty >= 2 ? NotConst : Dunno;
+    // check if all arguments are constant. if so, the object size is also const
+    bool AllConst = true;
+    for (SmallVectorImpl<unsigned>::iterator I = Args.begin(), E = Args.end();
+         I != E; ++I) {
+      if (!isa<ConstantInt>(CI->getArgOperand(*I))) {
+        AllConst = false;
+        break;
       }
     }
+
+    if (AllConst) {
+      Size = 1;
+      for (SmallVectorImpl<unsigned>::iterator I = Args.begin(), E = Args.end();
+           I != E; ++I) {
+        ConstantInt *Arg = cast<ConstantInt>(CI->getArgOperand(*I));
+        Size *= (size_t)Arg->getZExtValue();
+      }
+      return Const;
+    }
+
+    if (Penalty < 2)
+      return Dunno;
+
+    // not all arguments are constant, so create a sequence of multiplications
+    bool First = true;
+    for (SmallVectorImpl<unsigned>::iterator I = Args.begin(), E = Args.end();
+         I != E; ++I) {
+      Value *Arg = CI->getArgOperand(*I);
+      if (First) {
+        SizeValue = Arg;
+        First = false;
+        continue;
+      }
+      SizeValue = Builder->CreateMul(SizeValue, Arg);
+    }
+
+    return NotConst;
+
     // TODO: handle more standard functions:
     // - strdup / strndup
     // - strcpy / strncpy
@@ -216,7 +270,7 @@
     // - strcat / strncat
   }
 
-  DEBUG(dbgs() << "computeAllocSize failed:\n" << *Alloc);
+  DEBUG(dbgs() << "computeAllocSize failed:\n" << *Alloc << "\n");
   return Dunno;
 }