Implemented struct equality
TRAC #11727
Signed-off-by: Shannon Woods
Signed-off-by: Daniel Koch

Author:    Nicolas Capens

git-svn-id: https://angleproject.googlecode.com/svn/trunk@127 736b8ea6-26fd-11df-bfd4-992fa37f6226
diff --git a/src/compiler/OutputHLSL.cpp b/src/compiler/OutputHLSL.cpp
index d102e20..27a1db6 100644
--- a/src/compiler/OutputHLSL.cpp
+++ b/src/compiler/OutputHLSL.cpp
@@ -13,6 +13,18 @@
 {
 OutputHLSL::OutputHLSL(TParseContext &context) : TIntermTraverser(true, true, true), mContext(context)
 {
+    mUsesEqualMat2 = false;   // __equal helper flags: start cleared, set when an ==/!= on that type is traversed, read when emitting the header
+    mUsesEqualMat3 = false;
+    mUsesEqualMat4 = false;
+    mUsesEqualVec2 = false;
+    mUsesEqualVec3 = false;
+    mUsesEqualVec4 = false;
+    mUsesEqualIVec2 = false;
+    mUsesEqualIVec3 = false;
+    mUsesEqualIVec4 = false;
+    mUsesEqualBVec2 = false;
+    mUsesEqualBVec3 = false;
+    mUsesEqualBVec4 = false;
 }
 
 void OutputHLSL::output()
@@ -401,28 +413,109 @@
            "        return -N;\n"
            "    }\n"
            "}\n"
-           "\n"
-           "bool __equal(float2x2 m, float2x2 n)\n"
-           "{\n"
-           "    return m[0][0] == n[0][0] && m[0][1] == n[0][1] &&\n"
-           "           m[1][0] == n[1][0] && m[1][1] == n[1][1];\n"
-           "}\n"
-           "\n"
-           "bool __equal(float3x3 m, float3x3 n)\n"
-           "{\n"
-           "    return m[0][0] == n[0][0] && m[0][1] == n[0][1] && m[0][2] == n[0][2] &&\n"
-           "           m[1][0] == n[1][0] && m[1][1] == n[1][1] && m[1][2] == n[1][2] &&\n"
-           "           m[2][0] == n[2][0] && m[2][1] == n[2][1] && m[2][2] == n[2][2];\n"
-           "}\n"
-           "\n"
-           "bool __equal(float4x4 m, float4x4 n)\n"
-           "{\n"
-           "    return m[0][0] == n[0][0] && m[0][1] == n[0][1] && m[0][2] == n[0][2] && m[0][3] == n[0][3] &&\n"
-           "           m[1][0] == n[1][0] && m[1][1] == n[1][1] && m[1][2] == n[1][2] && m[1][3] == n[1][3] &&\n"
-           "           m[2][0] == n[2][0] && m[2][1] == n[2][1] && m[2][2] == n[2][2] && m[2][3] == n[2][3] &&\n"
-           "           m[3][0] == n[3][0] && m[3][1] == n[3][1] && m[3][2] == n[3][2] && m[3][3] == n[3][3];\n"
-           "}\n"
            "\n";
+
+    if (mUsesEqualMat2)
+    {
+        out << "bool __equal(float2x2 m, float2x2 n)\n"
+               "{\n"
+               "    return m[0][0] == n[0][0] && m[0][1] == n[0][1] &&\n"
+               "           m[1][0] == n[1][0] && m[1][1] == n[1][1];\n"
+               "}\n";
+    }
+
+    if (mUsesEqualMat3)
+    {
+        out << "bool __equal(float3x3 m, float3x3 n)\n"
+               "{\n"
+               "    return m[0][0] == n[0][0] && m[0][1] == n[0][1] && m[0][2] == n[0][2] &&\n"
+               "           m[1][0] == n[1][0] && m[1][1] == n[1][1] && m[1][2] == n[1][2] &&\n"
+               "           m[2][0] == n[2][0] && m[2][1] == n[2][1] && m[2][2] == n[2][2];\n"
+               "}\n";
+    }
+
+    if (mUsesEqualMat4)
+    {
+        out << "bool __equal(float4x4 m, float4x4 n)\n"
+               "{\n"
+               "    return m[0][0] == n[0][0] && m[0][1] == n[0][1] && m[0][2] == n[0][2] && m[0][3] == n[0][3] &&\n"
+               "           m[1][0] == n[1][0] && m[1][1] == n[1][1] && m[1][2] == n[1][2] && m[1][3] == n[1][3] &&\n"
+               "           m[2][0] == n[2][0] && m[2][1] == n[2][1] && m[2][2] == n[2][2] && m[2][3] == n[2][3] &&\n"
+               "           m[3][0] == n[3][0] && m[3][1] == n[3][1] && m[3][2] == n[3][2] && m[3][3] == n[3][3];\n"
+               "}\n";
+    }
+
+    if (mUsesEqualVec2)
+    {
+        out << "bool __equal(float2 v, float2 u)\n"
+               "{\n"
+               "    return v.x == u.x && v.y == u.y;\n"
+               "}\n";
+    }
+
+    if (mUsesEqualVec3)
+    {
+        out << "bool __equal(float3 v, float3 u)\n"
+               "{\n"
+               "    return v.x == u.x && v.y == u.y && v.z == u.z;\n"
+               "}\n";
+    }
+
+    if (mUsesEqualVec4)
+    {
+        out << "bool __equal(float4 v, float4 u)\n"
+               "{\n"
+               "    return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
+               "}\n";
+    }
+
+    if (mUsesEqualIVec2)
+    {
+        out << "bool __equal(int2 v, int2 u)\n"
+               "{\n"
+               "    return v.x == u.x && v.y == u.y;\n"
+               "}\n";
+    }
+
+    if (mUsesEqualIVec3)
+    {
+        out << "bool __equal(int3 v, int3 u)\n"
+               "{\n"
+               "    return v.x == u.x && v.y == u.y && v.z == u.z;\n"
+               "}\n";
+    }
+
+    if (mUsesEqualIVec4)
+    {
+        out << "bool __equal(int4 v, int4 u)\n"
+               "{\n"
+               "    return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
+               "}\n";
+    }
+
+    if (mUsesEqualBVec2)
+    {
+        out << "bool __equal(bool2 v, bool2 u)\n"
+               "{\n"
+               "    return v.x == u.x && v.y == u.y;\n"
+               "}\n";
+    }
+
+    if (mUsesEqualBVec3)
+    {
+        out << "bool __equal(bool3 v, bool3 u)\n"
+               "{\n"
+               "    return v.x == u.x && v.y == u.y && v.z == u.z;\n"
+               "}\n";
+    }
+
+    if (mUsesEqualBVec4)
+    {
+        out << "bool __equal(bool4 v, bool4 u)\n"
+               "{\n"
+               "    return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
+               "}\n";
+    }
 }
 
 void OutputHLSL::footer()
@@ -642,23 +735,106 @@
       case EOpMul:               outputTriplet(visit, "(", " * ", ")"); break;
       case EOpDiv:               outputTriplet(visit, "(", " / ", ")"); break;
       case EOpEqual:
-        if (!node->getLeft()->isMatrix())
-        {
-            outputTriplet(visit, "(", " == ", ")");
-        }
-        else
-        {
-            outputTriplet(visit, "__equal(", ", ", ")");
-        }
-        break;
       case EOpNotEqual:
-        if (!node->getLeft()->isMatrix())
+        if (node->getLeft()->isScalar())
         {
-            outputTriplet(visit, "(", " != ", ")");
+            if (node->getOp() == EOpEqual)
+            {
+                outputTriplet(visit, "(", " == ", ")");
+            }
+            else
+            {
+                outputTriplet(visit, "(", " != ", ")");
+            }
+        }
+        else if (node->getLeft()->getBasicType() == EbtStruct)
+        {
+            if (node->getOp() == EOpEqual)
+            {
+                out << "(";
+            }
+            else
+            {
+                out << "!(";
+            }
+
+            const TTypeList *fields = node->getLeft()->getType().getStruct();
+
+            for (size_t i = 0; i < fields->size(); i++)
+            {
+                const TType *fieldType = (*fields)[i].type;
+
+                node->getLeft()->traverse(this);   // NOTE(review): both operands are re-emitted once per field, duplicating their code (and any side effects) - confirm operands are side-effect free
+                out << "." + fieldType->getFieldName() + " == ";   // NOTE(review): plain == per field; vector/matrix fields likely need __equal (and the mUsesEqual* flag) like the branch below, and nested struct fields need recursion - TODO verify
+                node->getRight()->traverse(this);
+                out << "." + fieldType->getFieldName();
+
+                if (i < fields->size() - 1)
+                {
+                    out << " && ";
+                }
+            }
+
+            out << ")";
+
+            return false;
         }
         else
         {
-            outputTriplet(visit, "!__equal(", ", ", ")");
+            if (node->getLeft()->isMatrix())
+            {
+                switch (node->getLeft()->getSize())
+                {
+                  case 2 * 2: mUsesEqualMat2 = true; break;
+                  case 3 * 3: mUsesEqualMat3 = true; break;
+                  case 4 * 4: mUsesEqualMat4 = true; break;
+                  default: UNREACHABLE();
+                }
+            }
+            else if (node->getLeft()->isVector())
+            {
+                switch (node->getLeft()->getBasicType())
+                {
+                  case EbtFloat:
+                    switch (node->getLeft()->getSize())
+                    {
+                      case 2: mUsesEqualVec2 = true; break;
+                      case 3: mUsesEqualVec3 = true; break;
+                      case 4: mUsesEqualVec4 = true; break;
+                      default: UNREACHABLE();
+                    }
+                    break;
+                  case EbtInt:
+                    switch (node->getLeft()->getSize())
+                    {
+                      case 2: mUsesEqualIVec2 = true; break;
+                      case 3: mUsesEqualIVec3 = true; break;
+                      case 4: mUsesEqualIVec4 = true; break;
+                      default: UNREACHABLE();
+                    }
+                    break;
+                  case EbtBool:
+                    switch (node->getLeft()->getSize())
+                    {
+                      case 2: mUsesEqualBVec2 = true; break;
+                      case 3: mUsesEqualBVec3 = true; break;
+                      case 4: mUsesEqualBVec4 = true; break;
+                      default: UNREACHABLE();
+                    }
+                    break;
+                  default: UNREACHABLE();
+                }
+            }
+            else UNREACHABLE();
+
+            if (node->getOp() == EOpEqual)
+            {
+                outputTriplet(visit, "__equal(", ", ", ")");
+            }
+            else
+            {
+                outputTriplet(visit, "!__equal(", ", ", ")");
+            }
         }
         break;
       case EOpLessThan:          outputTriplet(visit, "(", " < ", ")");   break;
@@ -803,12 +979,6 @@
     EShLanguage language = mContext.language;
     TInfoSinkBase &out = mBody;
 
-    if (node->getOp() == EOpNull)
-    {
-        out.message(EPrefixError, "node is still EOpNull!");
-        return true;
-    }
-
     switch (node->getOp())
     {
       case EOpSequence: outputTriplet(visit, NULL, ";\n", ";\n"); break;
@@ -823,6 +993,11 @@
             {
                 if (!variable->getAsSymbolNode() || variable->getAsSymbolNode()->getSymbol() != "")   // Variable declaration
                 {
+                    if (variable->getQualifier() == EvqGlobal)
+                    {
+                        out << "static ";
+                    }
+
                     out << typeString(variable->getType()) + " ";
 
                     for (TIntermSequence::iterator sit = sequence.begin(); sit != sequence.end(); sit++)
@@ -1144,10 +1319,7 @@
                       default: UNREACHABLE();
                     }
                 }
-                else
-                {
-                    UNIMPLEMENTED();
-                }
+                else UNREACHABLE();
                 break;
               case EbtFloat:
                 if (!matrix)
@@ -1184,10 +1356,7 @@
                       default: UNREACHABLE();
                     }
                 }
-                else
-                {
-                    UNIMPLEMENTED();
-                }
+                else UNREACHABLE();
                 break;
               default:
                 UNIMPLEMENTED();   // FIXME
diff --git a/src/compiler/OutputHLSL.h b/src/compiler/OutputHLSL.h
index f4876f6..bb6203f 100644
--- a/src/compiler/OutputHLSL.h
+++ b/src/compiler/OutputHLSL.h
@@ -48,6 +48,20 @@
     TInfoSinkBase mHeader;
     TInfoSinkBase mBody;
     TInfoSinkBase mFooter;
+
+    // Parameters determining what goes in the header output
+    bool mUsesEqualMat2;
+    bool mUsesEqualMat3;
+    bool mUsesEqualMat4;
+    bool mUsesEqualVec2;
+    bool mUsesEqualVec3;
+    bool mUsesEqualVec4;
+    bool mUsesEqualIVec2;
+    bool mUsesEqualIVec3;
+    bool mUsesEqualIVec4;
+    bool mUsesEqualBVec2;
+    bool mUsesEqualBVec3;
+    bool mUsesEqualBVec4;
 };
 }
 
diff --git a/src/compiler/Types.h b/src/compiler/Types.h
index 7adab52..e06d83f 100644
--- a/src/compiler/Types.h
+++ b/src/compiler/Types.h
@@ -214,6 +214,7 @@
 	void setArrayInformationType(TType* t) { arrayInformationType = t; }
 	TType* getArrayInformationType() const { return arrayInformationType; }
 	virtual bool isVector() const { return size > 1 && !matrix; }
+	virtual bool isScalar() const { return size == 1 && !matrix && !structure; }   // single component: not a matrix, vector or struct (array-ness is not checked)
 	static const char* getBasicString(TBasicType t) {
 		switch (t) {
 		case EbtVoid:              return "void";              break;
diff --git a/src/compiler/intermediate.h b/src/compiler/intermediate.h
index 47cfc10..525f92a 100644
--- a/src/compiler/intermediate.h
+++ b/src/compiler/intermediate.h
@@ -243,6 +243,7 @@
     virtual bool isMatrix() const { return type.isMatrix(); }
     virtual bool isArray()  const { return type.isArray(); }
     virtual bool isVector() const { return type.isVector(); }
+    virtual bool isScalar() const { return type.isScalar(); }   // delegates to the node's TType::isScalar()
     const char* getBasicString()      const { return type.getBasicString(); }
     const char* getQualifierString()  const { return type.getQualifierString(); }
     TString getCompleteString() const { return type.getCompleteString(); }