Match union field type when member expression is u->x


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@44879 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/CodeGen/CGExpr.cpp b/CodeGen/CGExpr.cpp
index 5026c83..ab5d059 100644
--- a/CodeGen/CGExpr.cpp
+++ b/CodeGen/CGExpr.cpp
@@ -387,16 +387,24 @@
 
 LValue CodeGenFunction::EmitMemberExpr(const MemberExpr *E) {
 
+  bool isUnion = false;
   Expr *BaseExpr = E->getBase();
   llvm::Value *BaseValue = NULL;
   
   // If this is s.x, emit s as an lvalue.  If it is s->x, emit s as a scalar.
-  if (E->isArrow())
+  if (E->isArrow()) {
     BaseValue = EmitScalarExpr(BaseExpr);
+    const PointerType *PTy = 
+      cast<PointerType>(BaseExpr->getType().getCanonicalType());
+    if (PTy->getPointeeType()->isUnionType())
+      isUnion = true;
+  }
   else {
     LValue BaseLV = EmitLValue(BaseExpr);
     // FIXME: this isn't right for bitfields.
     BaseValue = BaseLV.getAddress();
+    if (BaseExpr->getType()->isUnionType())
+      isUnion = true;
   }
 
   FieldDecl *Field = E->getMemberDecl();
@@ -409,7 +417,7 @@
 
   llvm::Value *V = Builder.CreateGEP(BaseValue,Idxs, Idxs + 2, "tmp");
   // Match union field type.
-  if (BaseExpr->getType()->isUnionType()) {
+  if (isUnion) {
     const llvm::Type * FieldTy = ConvertType(Field->getType());
     const llvm::PointerType * BaseTy = 
       cast<llvm::PointerType>(BaseValue->getType());