Add support for a union type in LLVM IR.  Patch by Talin!


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@96011 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/AsmParser/LLLexer.cpp b/lib/AsmParser/LLLexer.cpp
index 8ad658d..46f3cbc 100644
--- a/lib/AsmParser/LLLexer.cpp
+++ b/lib/AsmParser/LLLexer.cpp
@@ -570,6 +570,7 @@
 
   KEYWORD(type);
   KEYWORD(opaque);
+  KEYWORD(union);
 
   KEYWORD(eq); KEYWORD(ne); KEYWORD(slt); KEYWORD(sgt); KEYWORD(sle);
   KEYWORD(sge); KEYWORD(ult); KEYWORD(ugt); KEYWORD(ule); KEYWORD(uge);
diff --git a/lib/AsmParser/LLParser.cpp b/lib/AsmParser/LLParser.cpp
index 4dab118..5cff310 100644
--- a/lib/AsmParser/LLParser.cpp
+++ b/lib/AsmParser/LLParser.cpp
@@ -1295,6 +1295,11 @@
     if (ParseStructType(Result, false))
       return true;
     break;
+  case lltok::kw_union:
+    // TypeRec ::= 'union' '{' ... '}'
+    if (ParseUnionType(Result))
+      return true;
+    break;
   case lltok::lsquare:
     // TypeRec ::= '[' ... ']'
     Lex.Lex(); // eat the lsquare.
@@ -1604,6 +1609,38 @@
   return false;
 }
 
+/// ParseUnionType - Parse a union type; the current token must be 'union'.
+///   TypeRec
+///     ::= 'union' '{' TypeRec (',' TypeRec)* '}'
+bool LLParser::ParseUnionType(PATypeHolder &Result) {
+  assert(Lex.getKind() == lltok::kw_union);
+  Lex.Lex(); // Skip over the 'union' keyword.
+  if (ParseToken(lltok::lbrace, "'{' expected after 'union'"))
+    return true;
+
+  SmallVector<PATypeHolder, 8> Members;
+  for (;;) { // At least one member type is always parsed.
+    LocTy MemberLoc = Lex.getLoc();
+    if (ParseTypeRec(Result))
+      return true;
+    Members.push_back(Result);
+    if (Result->isVoidTy())
+      return Error(MemberLoc, "union element can not have void type");
+    if (!UnionType::isValidElementType(Result))
+      return Error(MemberLoc, "invalid element type for union");
+    if (!EatIfPresent(lltok::comma))
+      break; // No comma, so the member list is finished.
+  }
+  if (ParseToken(lltok::rbrace, "expected '}' at end of union"))
+    return true;
+
+  SmallVector<const Type*, 8> MemberTys; // Raw pointers for UnionType::get.
+  for (unsigned Idx = 0; Idx != Members.size(); ++Idx)
+    MemberTys.push_back(Members[Idx].get());
+  Result = HandleUpRefs(UnionType::get(&MemberTys[0], MemberTys.size()));
+  return false;
+}
+
 /// ParseArrayVectorType - Parse an array or vector type, assuming the first
 /// token has already been consumed.
 ///   TypeRec
@@ -2163,8 +2200,8 @@
         ParseToken(lltok::rparen, "expected ')' in extractvalue constantexpr"))
       return true;
 
-    if (!isa<StructType>(Val->getType()) && !isa<ArrayType>(Val->getType()))
-      return Error(ID.Loc, "extractvalue operand must be array or struct");
+    if (!Val->getType()->isAggregateType())
+      return Error(ID.Loc, "extractvalue operand must be aggregate type");
     if (!ExtractValueInst::getIndexedType(Val->getType(), Indices.begin(),
                                           Indices.end()))
       return Error(ID.Loc, "invalid indices for extractvalue");
@@ -2184,8 +2221,8 @@
         ParseIndexList(Indices) ||
         ParseToken(lltok::rparen, "expected ')' in insertvalue constantexpr"))
       return true;
-    if (!isa<StructType>(Val0->getType()) && !isa<ArrayType>(Val0->getType()))
-      return Error(ID.Loc, "extractvalue operand must be array or struct");
+    if (!Val0->getType()->isAggregateType())
+      return Error(ID.Loc, "insertvalue operand must be aggregate type");
     if (!ExtractValueInst::getIndexedType(Val0->getType(), Indices.begin(),
                                           Indices.end()))
       return Error(ID.Loc, "invalid indices for insertvalue");
@@ -2521,8 +2558,17 @@
     V = Constant::getNullValue(Ty);
     return false;
   case ValID::t_Constant:
-    if (ID.ConstantVal->getType() != Ty)
+    if (ID.ConstantVal->getType() != Ty) {
+      // Allow a constant struct with a single member to be converted
+      // to a union, if the union has a member which is the same type
+      // as the struct member.
+      if (const UnionType* utype = dyn_cast<UnionType>(Ty)) {
+        return ParseUnionValue(utype, ID, V);
+      }
+
       return Error(ID.Loc, "constant expression type mismatch");
+    }
+
     V = ID.ConstantVal;
     return false;
   }
@@ -2552,6 +2598,22 @@
   return false;
 }
 
+/// ParseUnionValue - Wrap ID's one-member struct constant as a utype constant.
+bool LLParser::ParseUnionValue(const UnionType* utype, ValID &ID, Value *&V) {
+  const StructType *stype = dyn_cast<StructType>(ID.ConstantVal->getType());
+  if (!stype || stype->getNumContainedTypes() != 1)
+    return Error(ID.Loc, "constant expression type mismatch");
+  if (utype->getElementTypeIndex(stype->getContainedType(0)) < 0)
+    return Error(ID.Loc, "initializer type is not a member of the union");
+  // An all-zero struct folds to ConstantAggregateZero, which has no operands.
+  Constant *Elt = Constant::getNullValue(stype->getContainedType(0));
+  if (isa<ConstantStruct>(ID.ConstantVal))
+    Elt = cast<Constant>(ID.ConstantVal->getOperand(0));
+  else if (!isa<ConstantAggregateZero>(ID.ConstantVal))
+    return Error(ID.Loc, "constant expression type mismatch");
+  V = ConstantUnion::get(utype, Elt);
+  return false; // Success; every failure path above returns via Error().
+}
 
 /// FunctionHeader
 ///   ::= OptionalLinkage OptionalVisibility OptionalCallingConv OptRetAttrs
@@ -3811,8 +3873,8 @@
       ParseIndexList(Indices, AteExtraComma))
     return true;
 
-  if (!isa<StructType>(Val->getType()) && !isa<ArrayType>(Val->getType()))
-    return Error(Loc, "extractvalue operand must be array or struct");
+  if (!Val->getType()->isAggregateType())
+    return Error(Loc, "extractvalue operand must be aggregate type");
 
   if (!ExtractValueInst::getIndexedType(Val->getType(), Indices.begin(),
                                         Indices.end()))
@@ -3833,8 +3895,8 @@
       ParseIndexList(Indices, AteExtraComma))
     return true;
   
-  if (!isa<StructType>(Val0->getType()) && !isa<ArrayType>(Val0->getType()))
-    return Error(Loc0, "extractvalue operand must be array or struct");
+  if (!Val0->getType()->isAggregateType())
+    return Error(Loc0, "insertvalue operand must be aggregate type");
 
   if (!ExtractValueInst::getIndexedType(Val0->getType(), Indices.begin(),
                                         Indices.end()))
diff --git a/lib/AsmParser/LLParser.h b/lib/AsmParser/LLParser.h
index ed7a1d7..9abe404 100644
--- a/lib/AsmParser/LLParser.h
+++ b/lib/AsmParser/LLParser.h
@@ -31,6 +31,7 @@
   class GlobalValue;
   class MDString;
   class MDNode;
+  class UnionType;
 
   /// ValID - Represents a reference of a definition of some sort with no type.
   /// There are several cases where we have to parse the value but where the
@@ -212,6 +213,7 @@
     }
     bool ParseTypeRec(PATypeHolder &H);
     bool ParseStructType(PATypeHolder &H, bool Packed);
+    bool ParseUnionType(PATypeHolder &H);
     bool ParseArrayVectorType(PATypeHolder &H, bool isVector);
     bool ParseFunctionType(PATypeHolder &Result);
     PATypeHolder HandleUpRefs(const Type *Ty);
@@ -280,6 +282,8 @@
       return ParseTypeAndBasicBlock(BB, Loc, PFS);
     }
 
+    bool ParseUnionValue(const UnionType* utype, ValID &ID, Value *&V);
+
     struct ParamInfo {
       LocTy Loc;
       Value *V;
diff --git a/lib/AsmParser/LLToken.h b/lib/AsmParser/LLToken.h
index 7f1807c..3ac9169 100644
--- a/lib/AsmParser/LLToken.h
+++ b/lib/AsmParser/LLToken.h
@@ -97,6 +97,7 @@
 
     kw_type,
     kw_opaque,
+    kw_union,
 
     kw_eq, kw_ne, kw_slt, kw_sgt, kw_sle, kw_sge, kw_ult, kw_ugt, kw_ule,
     kw_uge, kw_oeq, kw_one, kw_olt, kw_ogt, kw_ole, kw_oge, kw_ord, kw_uno,