Match union field type when member expression is u->x


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@44879 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/CodeGen/CGExpr.cpp b/CodeGen/CGExpr.cpp
index 5026c83..ab5d059 100644
--- a/CodeGen/CGExpr.cpp
+++ b/CodeGen/CGExpr.cpp
@@ -387,16 +387,24 @@
 
 LValue CodeGenFunction::EmitMemberExpr(const MemberExpr *E) {
 
+  bool isUnion = false;
   Expr *BaseExpr = E->getBase();
   llvm::Value *BaseValue = NULL;
   
   // If this is s.x, emit s as an lvalue.  If it is s->x, emit s as a scalar.
-  if (E->isArrow())
+  if (E->isArrow()) {
     BaseValue = EmitScalarExpr(BaseExpr);
+    const PointerType *PTy = 
+      cast<PointerType>(BaseExpr->getType().getCanonicalType());
+    if (PTy->getPointeeType()->isUnionType())
+      isUnion = true;
+  }
   else {
     LValue BaseLV = EmitLValue(BaseExpr);
     // FIXME: this isn't right for bitfields.
     BaseValue = BaseLV.getAddress();
+    if (BaseExpr->getType()->isUnionType())
+      isUnion = true;
   }
 
   FieldDecl *Field = E->getMemberDecl();
@@ -409,7 +417,7 @@
 
   llvm::Value *V = Builder.CreateGEP(BaseValue,Idxs, Idxs + 2, "tmp");
   // Match union field type.
-  if (BaseExpr->getType()->isUnionType()) {
+  if (isUnion) {
     const llvm::Type * FieldTy = ConvertType(Field->getType());
     const llvm::PointerType * BaseTy = 
       cast<llvm::PointerType>(BaseValue->getType());
diff --git a/test/CodeGen/union.c b/test/CodeGen/union.c
index 4732938..4d32abd 100644
--- a/test/CodeGen/union.c
+++ b/test/CodeGen/union.c
@@ -1,6 +1,6 @@
 // RUN: clang %s -emit-llvm
 
-union {
+union u_tag {
   int a;
   float b;
 } u;
@@ -9,6 +9,10 @@
   u.b = 11;
 }
 
+float get_b(union u_tag *my_u) {
+  return my_u->b;
+}
+
 int f2( float __x ) { 
   union{ 
     float __f;