[opaque pointer types] Pass the function type explicitly to CallBase::setCalledFunction.

Differential Revision: https://reviews.llvm.org/D57174

llvm-svn: 352914
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 80eb51b..b34b3fd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -4214,7 +4214,8 @@
       // We cannot remove an invoke, because it would change the CFG, just
       // change the callee to a null pointer.
       cast<InvokeInst>(OldCall)->setCalledFunction(
-                                    Constant::getNullValue(CalleeF->getType()));
+          CalleeF->getFunctionType(),
+          Constant::getNullValue(CalleeF->getType()));
       return nullptr;
     }
   }
@@ -4555,8 +4556,8 @@
 InstCombiner::transformCallThroughTrampoline(CallBase &Call,
                                              IntrinsicInst &Tramp) {
   Value *Callee = Call.getCalledValue();
-  PointerType *PTy = cast<PointerType>(Callee->getType());
-  FunctionType *FTy = cast<FunctionType>(PTy->getElementType());
+  Type *CalleeTy = Callee->getType();
+  FunctionType *FTy = Call.getFunctionType();
   AttributeList Attrs = Call.getAttributes();
 
   // If the call already has the 'nest' attribute somewhere then give up -
@@ -4565,7 +4566,7 @@
     return nullptr;
 
   Function *NestF = cast<Function>(Tramp.getArgOperand(1)->stripPointerCasts());
-  FunctionType *NestFTy = cast<FunctionType>(NestF->getValueType());
+  FunctionType *NestFTy = NestF->getFunctionType();
 
   AttributeList NestAttrs = NestF->getAttributes();
   if (!NestAttrs.isEmpty()) {
@@ -4689,9 +4690,7 @@
   // Replace the trampoline call with a direct call.  Since there is no 'nest'
   // parameter, there is no need to adjust the argument list.  Let the generic
   // code sort out any function type mismatches.
-  Constant *NewCallee =
-    NestF->getType() == PTy ? NestF :
-                              ConstantExpr::getBitCast(NestF, PTy);
-  Call.setCalledFunction(NewCallee);
+  Constant *NewCallee = ConstantExpr::getBitCast(NestF, CalleeTy);
+  Call.setCalledFunction(FTy, NewCallee);
   return &Call;
 }