Implemented complex vector/matrix construction
TRAC #11868
Signed-off-by: Shannon Woods
Signed-off-by: Daniel Koch

Author:    Nicolas Capens

git-svn-id: https://angleproject.googlecode.com/svn/trunk@199 736b8ea6-26fd-11df-bfd4-992fa37f6226
diff --git a/src/compiler/OutputHLSL.cpp b/src/compiler/OutputHLSL.cpp
index 9d1fa54..201dbd7 100644
--- a/src/compiler/OutputHLSL.cpp
+++ b/src/compiler/OutputHLSL.cpp
@@ -321,126 +321,6 @@
            "\n"
            "uniform gl_DepthRangeParameters gl_DepthRange;\n"
            "\n"
-           "float vec1(float x)\n"
-           "{\n"
-           "    return x;\n"
-           "}\n"
-           "\n"
-           "float vec1(float2 xy)\n"
-           "{\n"
-           "    return xy[0];\n"
-           "}\n"
-           "\n"
-           "float vec1(float3 xyz)\n"
-           "{\n"
-           "    return xyz[0];\n"
-           "}\n"
-           "\n"
-           "float vec1(float4 xyzw)\n"
-           "{\n"
-           "    return xyzw[0];\n"
-           "}\n"
-           "\n"
-           "float2 vec2(float x)\n"
-           "{\n"
-           "    return float2(x, x);\n"
-           "}\n"
-           "\n"
-           "float2 vec2(float x, float y)\n"
-           "{\n"
-           "    return float2(x, y);\n"
-           "}\n"
-           "\n"
-           "float2 vec2(float2 xy)\n"
-           "{\n"
-           "    return xy;\n"
-           "}\n"
-           "\n"
-           "float2 vec2(float3 xyz)\n"
-           "{\n"
-           "    return float2(xyz[0], xyz[1]);\n"
-           "}\n"
-           "\n"
-           "float2 vec2(float4 xyzw)\n"
-           "{\n"
-           "    return float2(xyzw[0], xyzw[1]);\n"
-           "}\n"
-           "\n"
-           "float3 vec3(float x)\n"
-           "{\n"
-           "    return float3(x, x, x);\n"
-           "}\n"
-           "\n"
-           "float3 vec3(float x, float y, float z)\n"
-           "{\n"
-           "    return float3(x, y, z);\n"
-           "}\n"
-           "\n"
-           "float3 vec3(float2 xy, float z)\n"
-           "{\n"
-           "    return float3(xy[0], xy[1], z);\n"
-           "}\n"
-           "\n"
-           "float3 vec3(float x, float2 yz)\n"
-           "{\n"
-           "    return float3(x, yz[0], yz[1]);\n"
-           "}\n"
-           "\n"
-           "float3 vec3(float3 xyz)\n"
-           "{\n"
-           "    return xyz;\n"
-           "}\n"
-           "\n"
-           "float3 vec3(float4 xyzw)\n"
-           "{\n"
-           "    return float3(xyzw[0], xyzw[1], xyzw[2]);\n"
-           "}\n"
-           "\n"
-           "float4 vec4(float x)\n"
-           "{\n"
-           "    return float4(x, x, x, x);\n"
-           "}\n"
-           "\n"
-           "float4 vec4(float x, float y, float z, float w)\n"
-           "{\n"
-           "    return float4(x, y, z, w);\n"
-           "}\n"
-           "\n"
-           "float4 vec4(float2 xy, float z, float w)\n"
-           "{\n"
-           "    return float4(xy[0], xy[1], z, w);\n"
-           "}\n"
-           "\n"
-           "float4 vec4(float x, float2 yz, float w)\n"
-           "{\n"
-           "    return float4(x, yz[0], yz[1], w);\n"
-           "}\n"
-           "\n"
-           "float4 vec4(float x, float y, float2 zw)\n"
-           "{\n"
-           "    return float4(x, y, zw[0], zw[1]);\n"
-           "}\n"
-           "\n"
-           "float4 vec4(float2 xy, float2 zw)\n"
-           "{\n"
-           "    return float4(xy[0], xy[1], zw[0], zw[1]);\n"
-           "}\n"
-           "\n"
-           "float4 vec4(float3 xyz, float w)\n"
-           "{\n"
-           "    return float4(xyz[0], xyz[1], xyz[2], w);\n"
-           "}\n"
-           "\n"
-           "float4 vec4(float x, float3 yzw)\n"
-           "{\n"
-           "    return float4(x, yzw[0], yzw[1], yzw[2]);\n"
-           "}\n"
-           "\n"
-           "float4 vec4(float4 xyzw)\n"
-           "{\n"
-           "    return xyzw;\n"
-           "}\n"
-           "\n"
            "bool xor(bool p, bool q)\n"
            "{\n"
            "    return (p || q) && !(p && q);\n"
@@ -632,6 +512,132 @@
                "    return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
                "}\n";
     }
+    // Emit one HLSL helper function for each constructor recorded by addConstructor(); its arguments are named x0, x1, ...
+    for (ConstructorSet::iterator constructor = mConstructors.begin(); constructor != mConstructors.end(); constructor++)
+    {
+        out << typeString(constructor->type) + " " + constructor->name + "(";
+
+        for (unsigned int parameter = 0; parameter < constructor->parameters.size(); parameter++)
+        {
+            const TType &type = constructor->parameters[parameter];
+
+            out << typeString(type) + " x" + str(parameter);
+
+            if (parameter < constructor->parameters.size() - 1)
+            {
+                out << ", ";
+            }
+        }
+
+        out << ")\n"
+               "{\n";
+        // The body simply forwards to the native HLSL constructor of the result type.
+        out << "    return " + typeString(constructor->type) + "(";
+        // Only a lone scalar (diagonal matrix) or a lone matrix (copy/resize) needs special handling here; a lone vector, e.g. mat2(vec4), is consumed component-wise below.
+        if (constructor->type.isMatrix() && constructor->parameters.size() == 1 && (constructor->parameters[0].isScalar() || constructor->parameters[0].isMatrix()))
+        {
+            int dim = constructor->type.getNominalSize();
+            const TType &parameter = constructor->parameters[0];
+
+            if (parameter.isScalar())
+            {
+                for (int row = 0; row < dim; row++)
+                {
+                    for (int col = 0; col < dim; col++)
+                    {
+                        out << TString((row == col) ? "x0" : "0.0");
+
+                        if (row < dim - 1 || col < dim - 1)
+                        {
+                            out << ", ";
+                        }
+                    }
+                }
+            }
+            else if (parameter.isMatrix())
+            {
+                for (int row = 0; row < dim; row++)
+                {
+                    for (int col = 0; col < dim; col++)
+                    {
+                        if (row < parameter.getNominalSize() && col < parameter.getNominalSize())
+                        {
+                            out << TString("x0") + "[" + str(row) + "]" + "[" + str(col) + "]";
+                        }
+                        else
+                        {
+                            out << TString((row == col) ? "1.0" : "0.0");
+                        }
+
+                        if (row < dim - 1 || col < dim - 1)
+                        {
+                            out << ", ";
+                        }
+                    }
+                }
+            }
+            else UNREACHABLE();
+        }
+        else
+        {
+            int remainingComponents = constructor->type.getObjectSize();
+            int parameterIndex = 0;
+            // Consume arguments left to right; the last one is reused while components remain (scalar replication) or swizzled down (vector truncation).
+            while (remainingComponents > 0)
+            {
+                const TType &parameter = constructor->parameters[parameterIndex];
+                bool moreParameters = parameterIndex < (int)constructor->parameters.size() - 1;
+
+                out << "x" + str(parameterIndex);
+
+                if (parameter.isScalar())
+                {
+                    remainingComponents -= 1;
+                }
+                else if (parameter.isVector())
+                {
+                    if (remainingComponents == parameter.getInstanceSize() || moreParameters)
+                    {
+                        remainingComponents -= parameter.getInstanceSize();
+                    }
+                    else if (remainingComponents < parameter.getNominalSize())
+                    {
+                        switch (remainingComponents)
+                        {
+                        case 1: out << ".x";    break;
+                        case 2: out << ".xy";   break;
+                        case 3: out << ".xyz";  break;
+                        case 4: out << ".xyzw"; break;
+                        default: UNREACHABLE();
+                        }
+
+                        remainingComponents = 0;
+                    }
+                    else { UNREACHABLE(); remainingComponents = 0; }   // stop rather than loop forever if UNREACHABLE() returns
+                }
+                else if (parameter.isMatrix() || parameter.getStruct())
+                {
+                    ASSERT(remainingComponents == parameter.getInstanceSize() || moreParameters);
+
+                    remainingComponents -= parameter.getInstanceSize();
+                }
+                else { UNREACHABLE(); remainingComponents = 0; }   // same: never spin on an unexpected argument kind
+
+                if (moreParameters)
+                {
+                    parameterIndex++;
+                }
+
+                if (remainingComponents)
+                {
+                    out << ", ";
+                }
+            }
+        }
+
+        out << ");\n"
+               "}\n";
+    }
 }
 
 void OutputHLSL::footer()
@@ -1242,12 +1248,15 @@
 
                 sequence.erase(sequence.begin());
 
-                out << ")\n";
+                out << ")\n"
+                       "{\n";
 
                 mInsideFunction = true;
             }
             else if (visit == PostVisit)
             {
+                out << "}\n";
+
                 mInsideFunction = false;
             }
         }
@@ -1332,21 +1341,66 @@
         }
         break;
       case EOpParameters:       outputTriplet(visit, "(", ", ", ")\n{\n");             break;
-      case EOpConstructFloat:   outputTriplet(visit, "vec1(", NULL, ")");              break;
-      case EOpConstructVec2:    outputTriplet(visit, "vec2(", ", ", ")");              break;
-      case EOpConstructVec3:    outputTriplet(visit, "vec3(", ", ", ")");              break;
-      case EOpConstructVec4:    outputTriplet(visit, "vec4(", ", ", ")");              break;
-      case EOpConstructBool:    UNIMPLEMENTED(); /* FIXME */ out << "Construct bool";  break;
-      case EOpConstructBVec2:   UNIMPLEMENTED(); /* FIXME */ out << "Construct bvec2"; break;
-      case EOpConstructBVec3:   UNIMPLEMENTED(); /* FIXME */ out << "Construct bvec3"; break;
-      case EOpConstructBVec4:   UNIMPLEMENTED(); /* FIXME */ out << "Construct bvec4"; break;
-      case EOpConstructInt:     UNIMPLEMENTED(); /* FIXME */ out << "Construct int";   break;
-      case EOpConstructIVec2:   UNIMPLEMENTED(); /* FIXME */ out << "Construct ivec2"; break;
-      case EOpConstructIVec3:   UNIMPLEMENTED(); /* FIXME */ out << "Construct ivec3"; break;
-      case EOpConstructIVec4:   UNIMPLEMENTED(); /* FIXME */ out << "Construct ivec4"; break;
-      case EOpConstructMat2:    outputTriplet(visit, "float2x2(", ", ", ")");          break;
-      case EOpConstructMat3:    outputTriplet(visit, "float3x3(", ", ", ")");          break;
-      case EOpConstructMat4:    outputTriplet(visit, "float4x4(", ", ", ")");          break;
+      case EOpConstructFloat:
+          if (visit == PreVisit) addConstructor(node->getType(), "vec1", node->getSequence());   // register the helper once per node (the visit that opens the call), not on every visit
+          outputTriplet(visit, "vec1(", "", ")");
+          break;
+      case EOpConstructVec2:
+          if (visit == PreVisit) addConstructor(node->getType(), "vec2", node->getSequence());
+          outputTriplet(visit, "vec2(", ", ", ")");
+          break;
+      case EOpConstructVec3:
+          if (visit == PreVisit) addConstructor(node->getType(), "vec3", node->getSequence());
+          outputTriplet(visit, "vec3(", ", ", ")");
+          break;
+      case EOpConstructVec4:
+          if (visit == PreVisit) addConstructor(node->getType(), "vec4", node->getSequence());
+          outputTriplet(visit, "vec4(", ", ", ")");
+          break;
+      case EOpConstructBool:
+          if (visit == PreVisit) addConstructor(node->getType(), "bvec1", node->getSequence());
+          outputTriplet(visit, "bvec1(", "", ")");
+          break;
+      case EOpConstructBVec2:
+          if (visit == PreVisit) addConstructor(node->getType(), "bvec2", node->getSequence());
+          outputTriplet(visit, "bvec2(", ", ", ")");
+          break;
+      case EOpConstructBVec3:
+          if (visit == PreVisit) addConstructor(node->getType(), "bvec3", node->getSequence());
+          outputTriplet(visit, "bvec3(", ", ", ")");
+          break;
+      case EOpConstructBVec4:
+          if (visit == PreVisit) addConstructor(node->getType(), "bvec4", node->getSequence());
+          outputTriplet(visit, "bvec4(", ", ", ")");
+          break;
+      case EOpConstructInt:
+          if (visit == PreVisit) addConstructor(node->getType(), "ivec1", node->getSequence());
+          outputTriplet(visit, "ivec1(", "", ")");
+          break;
+      case EOpConstructIVec2:
+          if (visit == PreVisit) addConstructor(node->getType(), "ivec2", node->getSequence());
+          outputTriplet(visit, "ivec2(", ", ", ")");
+          break;
+      case EOpConstructIVec3:
+          if (visit == PreVisit) addConstructor(node->getType(), "ivec3", node->getSequence());
+          outputTriplet(visit, "ivec3(", ", ", ")");
+          break;
+      case EOpConstructIVec4:
+          if (visit == PreVisit) addConstructor(node->getType(), "ivec4", node->getSequence());
+          outputTriplet(visit, "ivec4(", ", ", ")");
+          break;
+      case EOpConstructMat2:
+          if (visit == PreVisit) addConstructor(node->getType(), "mat2", node->getSequence());
+          outputTriplet(visit, "mat2(", ", ", ")");
+          break;
+      case EOpConstructMat3:
+          if (visit == PreVisit) addConstructor(node->getType(), "mat3", node->getSequence());
+          outputTriplet(visit, "mat3(", ", ", ")");
+          break;
+      case EOpConstructMat4:
+          if (visit == PreVisit) addConstructor(node->getType(), "mat4", node->getSequence());
+          outputTriplet(visit, "mat4(", ", ", ")");
+          break;
       case EOpConstructStruct:  outputTriplet(visit, "{", ", ", "}");                  break;
       case EOpLessThan:         outputTriplet(visit, "(", " < ", ")");                 break;
       case EOpGreaterThan:      outputTriplet(visit, "(", " > ", ")");                 break;
@@ -2020,6 +2074,49 @@
     return string;
 }
 
+bool OutputHLSL::CompareConstructor::operator()(const Constructor &x, const Constructor &y) const
+{   // "Less than" for constructor records: result type, then helper name, then parameter count, then each parameter type.
+    if (x.type != y.type)
+    {
+        return memcmp(&x.type, &y.type, sizeof(TType)) < 0;   // NOTE(review): orders by the raw bytes of the whole TType object (every member and any padding). If TType::operator!= ignores some members, two types it treats as equal can order differently against a third, so this may not be a strict weak ordering - TODO confirm / compare field-wise.
+    }
+    // Same result type (per TType::operator!=): order by helper name.
+    if (x.name != y.name)
+    {
+        return x.name < y.name;
+    }
+    // Same name: the constructor with fewer parameters sorts first.
+    if (x.parameters.size() != y.parameters.size())
+    {
+        return x.parameters.size() < y.parameters.size();
+    }
+    // Same arity: the first parameter position whose types differ decides.
+    for (unsigned int i = 0; i < x.parameters.size(); i++)
+    {
+        if (x.parameters[i] != y.parameters[i])
+        {
+            return memcmp(&x.parameters[i], &y.parameters[i], sizeof(TType)) < 0;   // NOTE(review): same raw-byte caveat as for the result type.
+        }
+    }
+    // Every compared field matched: x is not less than y.
+    return false;
+}
+
+void OutputHLSL::addConstructor(const TType &type, const TString &name, const TIntermSequence &parameters)
+{
+    // Capture the call signature: result type, helper name and argument types in call order.
+    Constructor signature;
+    signature.type = type;
+    signature.name = name;
+    for (TIntermSequence::const_iterator it = parameters.begin(); it != parameters.end(); ++it)
+    {
+        const TType &argumentType = (*it)->getAsTyped()->getType();
+        signature.parameters.push_back(argumentType);
+    }
+
+    mConstructors.insert(signature);   // the constructor-emission loop generates one HLSL helper per entry
+}
+
 TString OutputHLSL::decorate(const TString &string)
 {
     if (string.substr(0, 3) != "gl_" && string.substr(0, 3) != "dx_")