Make WholeProgramDevirt understand ConstantStruct vtables.

Based on a patch by LemonBoy!

Differential Revision: https://reviews.llvm.org/D26581

llvm-svn: 289162
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 7ef5f24..9c80a2a 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -293,6 +293,7 @@
   void buildTypeIdentifierMap(
       std::vector<VTableBits> &Bits,
       DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
+  Constant *getValueAtOffset(Constant *I, uint64_t Offset);
   bool
   tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                             const std::set<TypeMemberInfo> &TypeMemberInfos,
@@ -382,6 +383,38 @@
   }
 }
 
+/// Returns the constant found at byte \p Offset of \p I, recursing through
+/// nested ConstantStruct/ConstantArray values (e.g. a struct of arrays like
+/// { [N x i8*], [M x i8*] }); null if no value starts exactly at \p Offset.
+Constant *DevirtModule::getValueAtOffset(Constant *I, uint64_t Offset) {
+  const DataLayout &DL = M.getDataLayout();
+
+  if (auto *C = dyn_cast<ConstantStruct>(I)) {
+    const StructLayout *SL = DL.getStructLayout(C->getType());
+    if (Offset >= SL->getSizeInBytes())
+      return nullptr;
+    // Descend into the field covering Offset, rebased to that field's start.
+    unsigned Op = SL->getElementContainingOffset(Offset);
+    return getValueAtOffset(cast<Constant>(C->getOperand(Op)),
+                            Offset - SL->getElementOffset(Op));
+  }
+
+  if (auto *C = dyn_cast<ConstantArray>(I)) {
+    uint64_t ElemSize = DL.getTypeAllocSize(C->getType()->getElementType());
+    if (ElemSize == 0) // Zero-sized elements: avoid dividing by zero.
+      return nullptr;
+    // Keep the index 64-bit so a huge Offset cannot wrap back into range.
+    uint64_t Op = Offset / ElemSize;
+    if (Op >= C->getNumOperands())
+      return nullptr;
+    return getValueAtOffset(cast<Constant>(C->getOperand(Op)),
+                            Offset % ElemSize);
+  }
+
+  // Non-aggregate leaf: it only matches if it starts exactly at Offset.
+  return Offset == 0 ? I : nullptr;
+}
+
 bool DevirtModule::tryFindVirtualCallTargets(
     std::vector<VirtualCallTarget> &TargetsForSlot,
     const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) {
@@ -389,22 +422,13 @@
     if (!TM.Bits->GV->isConstant())
       return false;
 
-    auto Init = dyn_cast<ConstantArray>(TM.Bits->GV->getInitializer());
-    if (!Init)
-      return false;
-    ArrayType *VTableTy = Init->getType();
+    Constant *I = TM.Bits->GV->getInitializer();
+    Value *V = getValueAtOffset(I, TM.Offset + ByteOffset);
 
-    uint64_t ElemSize =
-        M.getDataLayout().getTypeAllocSize(VTableTy->getElementType());
-    uint64_t GlobalSlotOffset = TM.Offset + ByteOffset;
-    if (GlobalSlotOffset % ElemSize != 0)
+    if (!V)
       return false;
 
-    unsigned Op = GlobalSlotOffset / ElemSize;
-    if (Op >= Init->getNumOperands())
-      return false;
-
-    auto Fn = dyn_cast<Function>(Init->getOperand(Op)->stripPointerCasts());
+    auto Fn = dyn_cast<Function>(V->stripPointerCasts());
     if (!Fn)
       return false;