Use BuiltInFunctionEmulatorHLSL for all emulated functions

Implementing missing built-in functions is a separate concern from
outputting the intermediate tree as HLSL, so all of the built-in
emulation belongs in a class that is separate from OutputHLSL. This
moves the mod, faceforward and two-parameter atan emulation out of
OutputHLSL and into BuiltInFunctionEmulatorHLSL. Reusing the same logic
for the different emulated functions also makes the code more compact.

Change-Id: Id503dc3a5c5e743ec65722add56d6ba216a03a7f
Reviewed-on: https://chromium-review.googlesource.com/239872
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Reviewed-by: Nicolas Capens <capn@chromium.org>
Reviewed-by: Olli Etuaho <oetuaho@nvidia.com>
Tested-by: Olli Etuaho <oetuaho@nvidia.com>
diff --git a/src/compiler/translator/BuiltInFunctionEmulatorHLSL.cpp b/src/compiler/translator/BuiltInFunctionEmulatorHLSL.cpp
index 4857513..4de954c 100644
--- a/src/compiler/translator/BuiltInFunctionEmulatorHLSL.cpp
+++ b/src/compiler/translator/BuiltInFunctionEmulatorHLSL.cpp
@@ -16,6 +16,133 @@
     TType float3(EbtFloat, 3);
     TType float4(EbtFloat, 4);
 
+    AddEmulatedFunction(EOpMod, float1, float1,
+        "float webgl_mod_emu(float x, float y)\n"
+        "{\n"
+        "    return x - y * floor(x / y);\n"
+        "}\n"
+        "\n");
+    AddEmulatedFunction(EOpMod, float2, float2,
+        "float2 webgl_mod_emu(float2 x, float2 y)\n"
+        "{\n"
+        "    return x - y * floor(x / y);\n"
+        "}\n"
+        "\n");
+    AddEmulatedFunction(EOpMod, float2, float1,
+        "float2 webgl_mod_emu(float2 x, float y)\n"
+        "{\n"
+        "    return x - y * floor(x / y);\n"
+        "}\n"
+        "\n");
+    AddEmulatedFunction(EOpMod, float3, float3,
+        "float3 webgl_mod_emu(float3 x, float3 y)\n"
+        "{\n"
+        "    return x - y * floor(x / y);\n"
+        "}\n"
+        "\n");
+    AddEmulatedFunction(EOpMod, float3, float1,
+        "float3 webgl_mod_emu(float3 x, float y)\n"
+        "{\n"
+        "    return x - y * floor(x / y);\n"
+        "}\n"
+        "\n");
+    AddEmulatedFunction(EOpMod, float4, float4,
+        "float4 webgl_mod_emu(float4 x, float4 y)\n"
+        "{\n"
+        "    return x - y * floor(x / y);\n"
+        "}\n"
+        "\n");
+    AddEmulatedFunction(EOpMod, float4, float1,
+        "float4 webgl_mod_emu(float4 x, float y)\n"
+        "{\n"
+        "    return x - y * floor(x / y);\n"
+        "}\n"
+        "\n");
+
+    AddEmulatedFunction(EOpFaceForward, float1, float1, float1,
+        "float webgl_faceforward_emu(float N, float I, float Nref)\n"
+        "{\n"
+        "    if(dot(Nref, I) >= 0)\n"
+        "    {\n"
+        "        return -N;\n"
+        "    }\n"
+        "    else\n"
+        "    {\n"
+        "        return N;\n"
+        "    }\n"
+        "}\n"
+        "\n");
+    AddEmulatedFunction(EOpFaceForward, float2, float2, float2,
+        "float2 webgl_faceforward_emu(float2 N, float2 I, float2 Nref)\n"
+        "{\n"
+        "    if(dot(Nref, I) >= 0)\n"
+        "    {\n"
+        "        return -N;\n"
+        "    }\n"
+        "    else\n"
+        "    {\n"
+        "        return N;\n"
+        "    }\n"
+        "}\n"
+        "\n");
+    AddEmulatedFunction(EOpFaceForward, float3, float3, float3,
+        "float3 webgl_faceforward_emu(float3 N, float3 I, float3 Nref)\n"
+        "{\n"
+        "    if(dot(Nref, I) >= 0)\n"
+        "    {\n"
+        "        return -N;\n"
+        "    }\n"
+        "    else\n"
+        "    {\n"
+        "        return N;\n"
+        "    }\n"
+        "}\n"
+        "\n");
+    AddEmulatedFunction(EOpFaceForward, float4, float4, float4,
+        "float4 webgl_faceforward_emu(float4 N, float4 I, float4 Nref)\n"
+        "{\n"
+        "    if(dot(Nref, I) >= 0)\n"
+        "    {\n"
+        "        return -N;\n"
+        "    }\n"
+        "    else\n"
+        "    {\n"
+        "        return N;\n"
+        "    }\n"
+        "}\n"
+        "\n");
+
+    AddEmulatedFunction(EOpAtan, float1, float1,
+        "float webgl_atan_emu(float y, float x)\n"
+        "{\n"
+        "    if(x == 0 && y == 0) x = 1;\n"   // Avoid producing a NaN
+        "    return atan2(y, x);\n"
+        "}\n");
+    AddEmulatedFunction(EOpAtan, float2, float2,
+        "float2 webgl_atan_emu(float2 y, float2 x)\n"
+        "{\n"
+        "    if(x[0] == 0 && y[0] == 0) x[0] = 1;\n"
+        "    if(x[1] == 0 && y[1] == 0) x[1] = 1;\n"
+        "    return float2(atan2(y[0], x[0]), atan2(y[1], x[1]));\n"
+        "}\n");
+    AddEmulatedFunction(EOpAtan, float3, float3,
+        "float3 webgl_atan_emu(float3 y, float3 x)\n"
+        "{\n"
+        "    if(x[0] == 0 && y[0] == 0) x[0] = 1;\n"
+        "    if(x[1] == 0 && y[1] == 0) x[1] = 1;\n"
+        "    if(x[2] == 0 && y[2] == 0) x[2] = 1;\n"
+        "    return float3(atan2(y[0], x[0]), atan2(y[1], x[1]), atan2(y[2], x[2]));\n"
+        "}\n");
+    AddEmulatedFunction(EOpAtan, float4, float4,
+        "float4 webgl_atan_emu(float4 y, float4 x)\n"
+        "{\n"
+        "    if(x[0] == 0 && y[0] == 0) x[0] = 1;\n"
+        "    if(x[1] == 0 && y[1] == 0) x[1] = 1;\n"
+        "    if(x[2] == 0 && y[2] == 0) x[2] = 1;\n"
+        "    if(x[3] == 0 && y[3] == 0) x[3] = 1;\n"
+        "    return float4(atan2(y[0], x[0]), atan2(y[1], x[1]), atan2(y[2], x[2]), atan2(y[3], x[3]));\n"
+        "}\n");
+
     AddEmulatedFunction(EOpAsinh, float1,
         "float webgl_asinh_emu(in float x) {\n"
         "    return log(x + sqrt(pow(x, 2.0) + 1.0));\n"
diff --git a/src/compiler/translator/OutputHLSL.cpp b/src/compiler/translator/OutputHLSL.cpp
index 88d1e2d..fe79221 100644
--- a/src/compiler/translator/OutputHLSL.cpp
+++ b/src/compiler/translator/OutputHLSL.cpp
@@ -112,21 +112,6 @@
     mUsesPointSize = false;
     mUsesFragDepth = false;
     mUsesXor = false;
-    mUsesMod1 = false;
-    mUsesMod2v = false;
-    mUsesMod2f = false;
-    mUsesMod3v = false;
-    mUsesMod3f = false;
-    mUsesMod4v = false;
-    mUsesMod4f = false;
-    mUsesFaceforward1 = false;
-    mUsesFaceforward2 = false;
-    mUsesFaceforward3 = false;
-    mUsesFaceforward4 = false;
-    mUsesAtan2_1 = false;
-    mUsesAtan2_2 = false;
-    mUsesAtan2_3 = false;
-    mUsesAtan2_4 = false;
     mUsesDiscardRewriting = false;
     mUsesNestedBreak = false;
 
@@ -188,14 +173,13 @@
     BuiltInFunctionEmulatorHLSL builtInFunctionEmulator;
     builtInFunctionEmulator.MarkBuiltInFunctionsForEmulation(mContext.treeRoot);
     mContext.treeRoot->traverse(this);   // Output the body first to determine what has to go in the header
-    header();
-    TInfoSinkBase& sink = mContext.infoSink().obj;
-    // Write emulated built-in functions if needed.
-    builtInFunctionEmulator.OutputEmulatedFunctionDefinition(sink, false);
-    builtInFunctionEmulator.Cleanup();
+    header(&builtInFunctionEmulator);
 
+    TInfoSinkBase& sink = mContext.infoSink().obj;
     sink << mHeader.c_str();
     sink << mBody.c_str();
+
+    builtInFunctionEmulator.Cleanup();
 }
 
 void OutputHLSL::makeFlaggedStructMaps(const std::vector<TIntermTyped *> &flaggedStructs)
@@ -284,7 +268,7 @@
     return init;
 }
 
-void OutputHLSL::header()
+void OutputHLSL::header(const BuiltInFunctionEmulatorHLSL *builtInFunctionEmulator)
 {
     TInfoSinkBase &out = mHeader;
 
@@ -1219,174 +1203,7 @@
                "\n";
     }
 
-    if (mUsesMod1)
-    {
-        out << "float mod(float x, float y)\n"
-               "{\n"
-               "    return x - y * floor(x / y);\n"
-               "}\n"
-               "\n";
-    }
-
-    if (mUsesMod2v)
-    {
-        out << "float2 mod(float2 x, float2 y)\n"
-               "{\n"
-               "    return x - y * floor(x / y);\n"
-               "}\n"
-               "\n";
-    }
-
-    if (mUsesMod2f)
-    {
-        out << "float2 mod(float2 x, float y)\n"
-               "{\n"
-               "    return x - y * floor(x / y);\n"
-               "}\n"
-               "\n";
-    }
-
-    if (mUsesMod3v)
-    {
-        out << "float3 mod(float3 x, float3 y)\n"
-               "{\n"
-               "    return x - y * floor(x / y);\n"
-               "}\n"
-               "\n";
-    }
-
-    if (mUsesMod3f)
-    {
-        out << "float3 mod(float3 x, float y)\n"
-               "{\n"
-               "    return x - y * floor(x / y);\n"
-               "}\n"
-               "\n";
-    }
-
-    if (mUsesMod4v)
-    {
-        out << "float4 mod(float4 x, float4 y)\n"
-               "{\n"
-               "    return x - y * floor(x / y);\n"
-               "}\n"
-               "\n";
-    }
-
-    if (mUsesMod4f)
-    {
-        out << "float4 mod(float4 x, float y)\n"
-               "{\n"
-               "    return x - y * floor(x / y);\n"
-               "}\n"
-               "\n";
-    }
-
-    if (mUsesFaceforward1)
-    {
-        out << "float faceforward(float N, float I, float Nref)\n"
-               "{\n"
-               "    if(dot(Nref, I) >= 0)\n"
-               "    {\n"
-               "        return -N;\n"
-               "    }\n"
-               "    else\n"
-               "    {\n"
-               "        return N;\n"
-               "    }\n"
-               "}\n"
-               "\n";
-    }
-
-    if (mUsesFaceforward2)
-    {
-        out << "float2 faceforward(float2 N, float2 I, float2 Nref)\n"
-               "{\n"
-               "    if(dot(Nref, I) >= 0)\n"
-               "    {\n"
-               "        return -N;\n"
-               "    }\n"
-               "    else\n"
-               "    {\n"
-               "        return N;\n"
-               "    }\n"
-               "}\n"
-               "\n";
-    }
-
-    if (mUsesFaceforward3)
-    {
-        out << "float3 faceforward(float3 N, float3 I, float3 Nref)\n"
-               "{\n"
-               "    if(dot(Nref, I) >= 0)\n"
-               "    {\n"
-               "        return -N;\n"
-               "    }\n"
-               "    else\n"
-               "    {\n"
-               "        return N;\n"
-               "    }\n"
-               "}\n"
-               "\n";
-    }
-
-    if (mUsesFaceforward4)
-    {
-        out << "float4 faceforward(float4 N, float4 I, float4 Nref)\n"
-               "{\n"
-               "    if(dot(Nref, I) >= 0)\n"
-               "    {\n"
-               "        return -N;\n"
-               "    }\n"
-               "    else\n"
-               "    {\n"
-               "        return N;\n"
-               "    }\n"
-               "}\n"
-               "\n";
-    }
-
-    if (mUsesAtan2_1)
-    {
-        out << "float atanyx(float y, float x)\n"
-               "{\n"
-               "    if(x == 0 && y == 0) x = 1;\n"   // Avoid producing a NaN
-               "    return atan2(y, x);\n"
-               "}\n";
-    }
-
-    if (mUsesAtan2_2)
-    {
-        out << "float2 atanyx(float2 y, float2 x)\n"
-               "{\n"
-               "    if(x[0] == 0 && y[0] == 0) x[0] = 1;\n"
-               "    if(x[1] == 0 && y[1] == 0) x[1] = 1;\n"
-               "    return float2(atan2(y[0], x[0]), atan2(y[1], x[1]));\n"
-               "}\n";
-    }
-
-    if (mUsesAtan2_3)
-    {
-        out << "float3 atanyx(float3 y, float3 x)\n"
-               "{\n"
-               "    if(x[0] == 0 && y[0] == 0) x[0] = 1;\n"
-               "    if(x[1] == 0 && y[1] == 0) x[1] = 1;\n"
-               "    if(x[2] == 0 && y[2] == 0) x[2] = 1;\n"
-               "    return float3(atan2(y[0], x[0]), atan2(y[1], x[1]), atan2(y[2], x[2]));\n"
-               "}\n";
-    }
-
-    if (mUsesAtan2_4)
-    {
-        out << "float4 atanyx(float4 y, float4 x)\n"
-               "{\n"
-               "    if(x[0] == 0 && y[0] == 0) x[0] = 1;\n"
-               "    if(x[1] == 0 && y[1] == 0) x[1] = 1;\n"
-               "    if(x[2] == 0 && y[2] == 0) x[2] = 1;\n"
-               "    if(x[3] == 0 && y[3] == 0) x[3] = 1;\n"
-               "    return float4(atan2(y[0], x[0]), atan2(y[1], x[1]), atan2(y[2], x[2]), atan2(y[3], x[3]));\n"
-               "}\n";
-    }
+    builtInFunctionEmulator->OutputEmulatedFunctionDefinition(out, false);
 }
 
 void OutputHLSL::visitSymbol(TIntermSymbol *node)
@@ -2262,37 +2079,14 @@
       case EOpVectorEqual:      outputTriplet(visit, "(", " == ", ")");                break;
       case EOpVectorNotEqual:   outputTriplet(visit, "(", " != ", ")");                break;
       case EOpMod:
-        {
-            // We need to look at the number of components in both arguments
-            const int modValue = (*node->getSequence())[0]->getAsTyped()->getNominalSize() * 10 +
-                (*node->getSequence())[1]->getAsTyped()->getNominalSize();
-            switch (modValue)
-            {
-              case 11: mUsesMod1 = true; break;
-              case 22: mUsesMod2v = true; break;
-              case 21: mUsesMod2f = true; break;
-              case 33: mUsesMod3v = true; break;
-              case 31: mUsesMod3f = true; break;
-              case 44: mUsesMod4v = true; break;
-              case 41: mUsesMod4f = true; break;
-              default: UNREACHABLE();
-            }
-
-            outputTriplet(visit, "mod(", ", ", ")");
-        }
+        ASSERT(node->getUseEmulatedFunction());
+        writeEmulatedFunctionTriplet(visit, "mod(");
         break;
       case EOpPow:              outputTriplet(visit, "pow(", ", ", ")");               break;
       case EOpAtan:
         ASSERT(node->getSequence()->size() == 2);   // atan(x) is a unary operator
-        switch ((*node->getSequence())[0]->getAsTyped()->getNominalSize())
-        {
-          case 1: mUsesAtan2_1 = true; break;
-          case 2: mUsesAtan2_2 = true; break;
-          case 3: mUsesAtan2_3 = true; break;
-          case 4: mUsesAtan2_4 = true; break;
-          default: UNREACHABLE();
-        }
-        outputTriplet(visit, "atanyx(", ", ", ")");
+        ASSERT(node->getUseEmulatedFunction());
+        writeEmulatedFunctionTriplet(visit, "atan(");
         break;
       case EOpMin:           outputTriplet(visit, "min(", ", ", ")");           break;
       case EOpMax:           outputTriplet(visit, "max(", ", ", ")");           break;
@@ -2304,18 +2098,8 @@
       case EOpDot:           outputTriplet(visit, "dot(", ", ", ")");           break;
       case EOpCross:         outputTriplet(visit, "cross(", ", ", ")");         break;
       case EOpFaceForward:
-        {
-            switch ((*node->getSequence())[0]->getAsTyped()->getNominalSize())   // Number of components in the first argument
-            {
-            case 1: mUsesFaceforward1 = true; break;
-            case 2: mUsesFaceforward2 = true; break;
-            case 3: mUsesFaceforward3 = true; break;
-            case 4: mUsesFaceforward4 = true; break;
-            default: UNREACHABLE();
-            }
-
-            outputTriplet(visit, "faceforward(", ", ", ")");
-        }
+        ASSERT(node->getUseEmulatedFunction());
+        writeEmulatedFunctionTriplet(visit, "faceforward(");
         break;
       case EOpReflect:       outputTriplet(visit, "reflect(", ", ", ")");       break;
       case EOpRefract:       outputTriplet(visit, "refract(", ", ", ")");       break;
diff --git a/src/compiler/translator/OutputHLSL.h b/src/compiler/translator/OutputHLSL.h
index f409015..b2b3b80 100644
--- a/src/compiler/translator/OutputHLSL.h
+++ b/src/compiler/translator/OutputHLSL.h
@@ -15,6 +15,8 @@
 #include "compiler/translator/IntermNode.h"
 #include "compiler/translator/ParseContext.h"
 
+class BuiltInFunctionEmulatorHLSL;
+
 namespace sh
 {
 class UnfoldShortCircuit;
@@ -39,7 +41,7 @@
     static TString initializer(const TType &type);
 
   protected:
-    void header();
+    void header(const BuiltInFunctionEmulatorHLSL *builtInFunctionEmulator);
 
     // Visit AST nodes and output their code to the body stream
     void visitSymbol(TIntermSymbol*);
@@ -122,21 +124,6 @@
     bool mUsesPointSize;
     bool mUsesFragDepth;
     bool mUsesXor;
-    bool mUsesMod1;
-    bool mUsesMod2v;
-    bool mUsesMod2f;
-    bool mUsesMod3v;
-    bool mUsesMod3f;
-    bool mUsesMod4v;
-    bool mUsesMod4f;
-    bool mUsesFaceforward1;
-    bool mUsesFaceforward2;
-    bool mUsesFaceforward3;
-    bool mUsesFaceforward4;
-    bool mUsesAtan2_1;
-    bool mUsesAtan2_2;
-    bool mUsesAtan2_3;
-    bool mUsesAtan2_4;
     bool mUsesDiscardRewriting;
     bool mUsesNestedBreak;