Implemented support for user-defined structures
TRAC #11730
Signed-off-by: Andrew Lewycky
Signed-off-by: Daniel Koch

Author:    Nicolas Capens

git-svn-id: https://angleproject.googlecode.com/svn/trunk@116 736b8ea6-26fd-11df-bfd4-992fa37f6226
diff --git a/src/compiler/Intermediate.cpp b/src/compiler/Intermediate.cpp
index e7aa5aa..c85aaaf 100644
--- a/src/compiler/Intermediate.cpp
+++ b/src/compiler/Intermediate.cpp
@@ -774,7 +774,7 @@
     // Base assumption:  just make the type the same as the left
     // operand.  Then only deviations from this need be coded.
     //
-    setType(TType(type, EvqTemporary, left->getNominalSize(), left->isMatrix()));
+    setType(left->getType());
 
     //
     // Array operations.
@@ -1394,4 +1394,3 @@
     pragmaTable = new TPragmaTable();
     *pragmaTable = pTable;
 }
-
diff --git a/src/compiler/OutputHLSL.cpp b/src/compiler/OutputHLSL.cpp
index 2a3d854..9875f0a 100644
--- a/src/compiler/OutputHLSL.cpp
+++ b/src/compiler/OutputHLSL.cpp
@@ -54,7 +54,7 @@
                     varyingInput += "    " + typeString(type) + " " + name + arrayString(type) + semantic + ";\n";
                     varyingGlobals += "static " + typeString(type) + " " + name + arrayString(type) + " = " + initializer(type) + ";\n";
                 }
-                else if (qualifier == EvqGlobal)
+                else if (qualifier == EvqGlobal || qualifier == EvqTemporary)
                 {
                     // Globals are declared and intialized as an aggregate node
                 }
@@ -152,7 +152,7 @@
                     varyingOutput += "    " + typeString(type) + " " + name + arrayString(type) + " : TEXCOORD0;\n";   // Actual semantic index assigned during link
                     varyingGlobals += "static " + typeString(type) + " " + name + arrayString(type) + " = " + initializer(type) + ";\n";
                 }
-                else if (qualifier == EvqGlobal)
+                else if (qualifier == EvqGlobal || qualifier == EvqTemporary)
                 {
                     // Globals are declared and intialized as an aggregate node
                 }
@@ -810,36 +810,57 @@
 
             if (variable && (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal))
             {
-                out << typeString(variable->getType()) + " ";
-
-                for (TIntermSequence::iterator sit = sequence.begin(); sit != sequence.end(); sit++)
+                if (!variable->getAsSymbolNode() || variable->getAsSymbolNode()->getSymbol() != "")   // Variable declaration
                 {
-                    TIntermSymbol *symbol = (*sit)->getAsSymbolNode();
+                    out << typeString(variable->getType()) + " ";
 
-                    if (symbol)
+                    for (TIntermSequence::iterator sit = sequence.begin(); sit != sequence.end(); sit++)
                     {
-                        symbol->traverse(this);
+                        TIntermSymbol *symbol = (*sit)->getAsSymbolNode();
 
-                        out << arrayString(symbol->getType());
-                    }
-                    else
-                    {
-                        (*sit)->traverse(this);
-                    }
-
-                    if (visit && this->inVisit)
-                    {
-                        if (*sit != sequence.back())
+                        if (symbol)
                         {
-                            visit = this->visitAggregate(InVisit, node);
+                            symbol->traverse(this);
+
+                            out << arrayString(symbol->getType());
+                        }
+                        else
+                        {
+                            (*sit)->traverse(this);
+                        }
+
+                        if (visit && this->inVisit)
+                        {
+                            if (*sit != sequence.back())
+                            {
+                                visit = this->visitAggregate(InVisit, node);
+                            }
                         }
                     }
-                }
 
-                if (visit && this->postVisit)
-                {
-                    this->visitAggregate(PostVisit, node);
+                    if (visit && this->postVisit)
+                    {
+                        this->visitAggregate(PostVisit, node);
+                    }
                 }
+                else if (variable->getAsSymbolNode() && variable->getAsSymbolNode()->getSymbol() == "")   // Type (struct) declaration
+                {
+                    const TType &type = variable->getType();
+                    const TTypeList &fields = *type.getStruct();
+
+                    out << "struct " + type.getTypeName() + "\n"
+                           "{\n";
+
+                    for (unsigned int i = 0; i < fields.size(); i++)
+                    {
+                        const TType &field = *fields[i].type;
+
+                        out << "    " + typeString(field) + " " + field.getFieldName() + ";\n";
+                    }
+
+                    out << "};\n";
+                }
+                else UNREACHABLE();
             }
             
             return false;
@@ -954,11 +975,11 @@
             }
         }
         break;
-      case EOpParameters:       outputTriplet(visit, "(", ", ", ")\n{\n");  break;
-      case EOpConstructFloat:   outputTriplet(visit, "vec1(", NULL, ")");   break;
-      case EOpConstructVec2:    outputTriplet(visit, "vec2(", ", ", ")");   break;
-      case EOpConstructVec3:    outputTriplet(visit, "vec3(", ", ", ")");   break;
-      case EOpConstructVec4:    outputTriplet(visit, "vec4(", ", ", ")");   break;
+      case EOpParameters:       outputTriplet(visit, "(", ", ", ")\n{\n");             break;
+      case EOpConstructFloat:   outputTriplet(visit, "vec1(", NULL, ")");              break;
+      case EOpConstructVec2:    outputTriplet(visit, "vec2(", ", ", ")");              break;
+      case EOpConstructVec3:    outputTriplet(visit, "vec3(", ", ", ")");              break;
+      case EOpConstructVec4:    outputTriplet(visit, "vec4(", ", ", ")");              break;
       case EOpConstructBool:    UNIMPLEMENTED(); /* FIXME */ out << "Construct bool";  break;
       case EOpConstructBVec2:   UNIMPLEMENTED(); /* FIXME */ out << "Construct bvec2"; break;
       case EOpConstructBVec3:   UNIMPLEMENTED(); /* FIXME */ out << "Construct bvec3"; break;
@@ -967,18 +988,18 @@
       case EOpConstructIVec2:   UNIMPLEMENTED(); /* FIXME */ out << "Construct ivec2"; break;
       case EOpConstructIVec3:   UNIMPLEMENTED(); /* FIXME */ out << "Construct ivec3"; break;
       case EOpConstructIVec4:   UNIMPLEMENTED(); /* FIXME */ out << "Construct ivec4"; break;
-      case EOpConstructMat2:    outputTriplet(visit, "float2x2(", ", ", ")"); break;
-      case EOpConstructMat3:    outputTriplet(visit, "float3x3(", ", ", ")"); break;
-      case EOpConstructMat4:    outputTriplet(visit, "float4x4(", ", ", ")"); break;
-      case EOpConstructStruct:  UNIMPLEMENTED(); /* FIXME */ out << "Construct structure";  break;
-      case EOpLessThan:         outputTriplet(visit, "(", " < ", ")");             break;
-      case EOpGreaterThan:      outputTriplet(visit, "(", " > ", ")");             break;
-      case EOpLessThanEqual:    outputTriplet(visit, "(", " <= ", ")");            break;
-      case EOpGreaterThanEqual: outputTriplet(visit, "(", " >= ", ")");            break;
-      case EOpVectorEqual:      outputTriplet(visit, "(", " == ", ")");            break;
-      case EOpVectorNotEqual:   outputTriplet(visit, "(", " != ", ")");            break;
-      case EOpMod:              outputTriplet(visit, "mod(", ", ", ")");           break;   // FIXME: Prevent name clashes
-      case EOpPow:              outputTriplet(visit, "pow(", ", ", ")");           break;
+      case EOpConstructMat2:    outputTriplet(visit, "float2x2(", ", ", ")");          break;
+      case EOpConstructMat3:    outputTriplet(visit, "float3x3(", ", ", ")");          break;
+      case EOpConstructMat4:    outputTriplet(visit, "float4x4(", ", ", ")");          break;
+      case EOpConstructStruct:  outputTriplet(visit, "{", ", ", "}");                  break;
+      case EOpLessThan:         outputTriplet(visit, "(", " < ", ")");                 break;
+      case EOpGreaterThan:      outputTriplet(visit, "(", " > ", ")");                 break;
+      case EOpLessThanEqual:    outputTriplet(visit, "(", " <= ", ")");                break;
+      case EOpGreaterThanEqual: outputTriplet(visit, "(", " >= ", ")");                break;
+      case EOpVectorEqual:      outputTriplet(visit, "(", " == ", ")");                break;
+      case EOpVectorNotEqual:   outputTriplet(visit, "(", " != ", ")");                break;
+      case EOpMod:              outputTriplet(visit, "mod(", ", ", ")");               break;   // FIXME: Prevent name clashes
+      case EOpPow:              outputTriplet(visit, "pow(", ", ", ")");               break;
       case EOpAtan:
         if (node->getSequence().size() == 1)
         {
@@ -1063,75 +1084,83 @@
     else
     {
         int size = type.getObjectSize();
-        bool matrix = type.isMatrix();
-        TBasicType basicType = node->getUnionArrayPointer()[0].getType();
 
-        switch (basicType)
+        if (type.getBasicType() == EbtStruct)
         {
-          case EbtBool:
-            if (!matrix)
+            out << "{";
+        }
+        else
+        {
+            bool matrix = type.isMatrix();
+            TBasicType elementType = node->getUnionArrayPointer()[0].getType();
+
+            switch (elementType)
             {
-                switch (size)
+              case EbtBool:
+                if (!matrix)
                 {
-                  case 1: out << "bool(";  break;
-                  case 2: out << "bool2("; break;
-                  case 3: out << "bool3("; break;
-                  case 4: out << "bool4("; break;
-                  default: UNREACHABLE();
+                    switch (size)
+                    {
+                      case 1: out << "bool(";  break;
+                      case 2: out << "bool2("; break;
+                      case 3: out << "bool3("; break;
+                      case 4: out << "bool4("; break;
+                      default: UNREACHABLE();
+                    }
                 }
-            }
-            else
-            {
-                UNIMPLEMENTED();
-            }
-            break;
-          case EbtFloat:
-            if (!matrix)
-            {
-                switch (size)
+                else
                 {
-                  case 1: out << "float(";  break;
-                  case 2: out << "float2("; break;
-                  case 3: out << "float3("; break;
-                  case 4: out << "float4("; break;
-                  default: UNREACHABLE();
+                    UNIMPLEMENTED();
                 }
-            }
-            else
-            {
-                switch (size)
+                break;
+              case EbtFloat:
+                if (!matrix)
                 {
-                  case 4:  out << "float2x2("; break;
-                  case 9:  out << "float3x3("; break;
-                  case 16: out << "float4x4("; break;
-                  default: UNREACHABLE();
+                    switch (size)
+                    {
+                      case 1: out << "float(";  break;
+                      case 2: out << "float2("; break;
+                      case 3: out << "float3("; break;
+                      case 4: out << "float4("; break;
+                      default: UNREACHABLE();
+                    }
                 }
-            }
-            break;
-          case EbtInt:
-            if (!matrix)
-            {
-                switch (size)
+                else
                 {
-                  case 1: out << "int(";  break;
-                  case 2: out << "int2("; break;
-                  case 3: out << "int3("; break;
-                  case 4: out << "int4("; break;
-                  default: UNREACHABLE();
+                    switch (size)
+                    {
+                      case 4:  out << "float2x2("; break;
+                      case 9:  out << "float3x3("; break;
+                      case 16: out << "float4x4("; break;
+                      default: UNREACHABLE();
+                    }
                 }
+                break;
+              case EbtInt:
+                if (!matrix)
+                {
+                    switch (size)
+                    {
+                      case 1: out << "int(";  break;
+                      case 2: out << "int2("; break;
+                      case 3: out << "int3("; break;
+                      case 4: out << "int4("; break;
+                      default: UNREACHABLE();
+                    }
+                }
+                else
+                {
+                    UNIMPLEMENTED();
+                }
+                break;
+              default:
+                UNIMPLEMENTED();   // FIXME
             }
-            else
-            {
-                UNIMPLEMENTED();
-            }
-            break;
-          default:
-            UNIMPLEMENTED();   // FIXME
         }
 
         for (int i = 0; i < size; i++)
         {
-            switch (basicType)
+            switch (node->getUnionArrayPointer()[i].getType())
             {
               case EbtBool:
                 if (node->getUnionArrayPointer()[i].getBConst())
@@ -1159,7 +1188,14 @@
             }
         }
 
-        out << ")";
+        if (type.getBasicType() == EbtStruct)
+        {
+            out << "}";
+        }
+        else
+        {
+            out << ")";
+        }
     }
 }
 
@@ -1443,7 +1479,11 @@
 
 TString OutputHLSL::typeString(const TType &type)
 {
-    if (type.isMatrix())
+    if (type.getBasicType() == EbtStruct)
+    {
+        return type.getTypeName();
+    }
+    else if (type.isMatrix())
     {
         switch (type.getNominalSize())
         {
diff --git a/src/compiler/glslang.y b/src/compiler/glslang.y
index 75467ac..07f9e0e 100644
--- a/src/compiler/glslang.y
+++ b/src/compiler/glslang.y
@@ -1331,7 +1331,7 @@
 single_declaration
     : fully_specified_type {
         $$.type = $1;
-        $$.intermAggregate = 0;
+        $$.intermAggregate = parseContext->intermediate.makeAggregate(parseContext->intermediate.addSymbol(0, "", TType($1), $1.line), $1.line);;
     }
     | fully_specified_type IDENTIFIER {
 		$$.intermAggregate = parseContext->intermediate.makeAggregate(parseContext->intermediate.addSymbol(0, *$2.string, TType($1), $2.line), $2.line);