Make HLSL shaders use only one main function

Instead of having separate main() and gl_main() functions in HLSL
shaders, add the initialization of outputs and inputs directly to the
main function that's in the AST.

This works around some HLSL bugs and should not introduce name
conflicts inside main() since all the user-defined variables are
prefixed.

BUG=angleproject:2325
TEST=angle_end2end_tests

Change-Id: I5b000c96aac8f321cefe50b6a893008498eac0d5
Reviewed-on: https://chromium-review.googlesource.com/1146647
Reviewed-by: Geoff Lang <geofflang@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
diff --git a/src/compiler/translator/OutputHLSL.cpp b/src/compiler/translator/OutputHLSL.cpp
index 5e8e2cb..5e85824 100644
--- a/src/compiler/translator/OutputHLSL.cpp
+++ b/src/compiler/translator/OutputHLSL.cpp
@@ -182,6 +182,7 @@
                        int numRenderTargets,
                        const std::vector<Uniform> &uniforms,
                        ShCompileOptions compileOptions,
+                       sh::WorkGroupSize workGroupSize,
                        TSymbolTable *symbolTable,
                        PerformanceDiagnostics *perfDiagnostics)
     : TIntermTraverser(true, true, true, symbolTable),
@@ -191,12 +192,13 @@
       mSourcePath(sourcePath),
       mOutputType(outputType),
       mCompileOptions(compileOptions),
+      mInsideFunction(false),
+      mInsideMain(false),
       mNumRenderTargets(numRenderTargets),
       mCurrentFunctionMetadata(nullptr),
+      mWorkGroupSize(workGroupSize),
       mPerfDiagnostics(perfDiagnostics)
 {
-    mInsideFunction = false;
-
     mUsesFragColor   = false;
     mUsesFragData    = false;
     mUsesDepthRange  = false;
@@ -1743,10 +1745,16 @@
 {
     TInfoSinkBase &out = getInfoSink();
 
+    bool isMainBlock = mInsideMain && getParentNode()->getAsFunctionDefinition();
+
     if (mInsideFunction)
     {
         outputLineDirective(out, node->getLine().first_line);
         out << "{\n";
+        if (isMainBlock)
+        {
+            out << "@@ MAIN PROLOGUE @@\n";
+        }
     }
 
     for (TIntermNode *statement : *node->getSequence())
@@ -1781,6 +1789,19 @@
     if (mInsideFunction)
     {
         outputLineDirective(out, node->getLine().last_line);
+        if (isMainBlock && shaderNeedsGenerateOutput())
+        {
+            // We could have an empty main, a main function without a branch at the end, or a main
+            // function with a discard statement at the end. In these cases we need to add a return
+            // statement.
+            bool needReturnStatement =
+                node->getSequence()->empty() || !node->getSequence()->back()->getAsBranchNode() ||
+                node->getSequence()->back()->getAsBranchNode()->getFlowOp() != EOpReturn;
+            if (needReturnStatement)
+            {
+                out << "return " << generateOutputCall() << ";\n";
+            }
+        }
         out << "}\n";
     }
 
@@ -1797,40 +1818,64 @@
     ASSERT(index != CallDAG::InvalidIndex);
     mCurrentFunctionMetadata = &mASTMetadataList[index];
 
-    out << TypeString(node->getFunctionPrototype()->getType()) << " ";
-
     const TFunction *func = node->getFunction();
 
     if (func->isMain())
     {
-        out << "gl_main(";
+        // The stub strings below are replaced when the shader is dynamically defined by its layout:
+        switch (mShaderType)
+        {
+            case GL_VERTEX_SHADER:
+                out << "@@ VERTEX ATTRIBUTES @@\n\n"
+                    << "@@ VERTEX OUTPUT @@\n\n"
+                    << "VS_OUTPUT main(VS_INPUT input)";
+                break;
+            case GL_FRAGMENT_SHADER:
+                out << "@@ PIXEL OUTPUT @@\n\n"
+                    << "PS_OUTPUT main(@@ PIXEL MAIN PARAMETERS @@)";
+                break;
+            case GL_COMPUTE_SHADER:
+                out << "[numthreads(" << mWorkGroupSize[0] << ", " << mWorkGroupSize[1] << ", "
+                    << mWorkGroupSize[2] << ")]\n";
+                out << "void main(CS_INPUT input)";
+                break;
+            default:
+                UNREACHABLE();
+                break;
+        }
     }
     else
     {
+        out << TypeString(node->getFunctionPrototype()->getType()) << " ";
         out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func)
             << (mOutputLod0Function ? "Lod0(" : "(");
-    }
 
-    size_t paramCount = func->getParamCount();
-    for (unsigned int i = 0; i < paramCount; i++)
-    {
-        const TVariable *param = func->getParam(i);
-        ensureStructDefined(param->getType());
-
-        writeParameter(param, out);
-
-        if (i < paramCount - 1)
+        size_t paramCount = func->getParamCount();
+        for (unsigned int i = 0; i < paramCount; i++)
         {
-            out << ", ";
-        }
-    }
+            const TVariable *param = func->getParam(i);
+            ensureStructDefined(param->getType());
 
-    out << ")\n";
+            writeParameter(param, out);
+
+            if (i < paramCount - 1)
+            {
+                out << ", ";
+            }
+        }
+
+        out << ")\n";
+    }
 
     mInsideFunction = true;
+    if (func->isMain())
+    {
+        mInsideMain = true;
+    }
     // The function body node will output braces.
     node->getBody()->traverse(this);
     mInsideFunction = false;
+    mInsideMain     = false;
 
     mCurrentFunctionMetadata = nullptr;
 
@@ -2455,11 +2500,19 @@
             case EOpReturn:
                 if (node->getExpression())
                 {
+                    ASSERT(!mInsideMain);
                     out << "return ";
                 }
                 else
                 {
-                    out << "return";
+                    if (mInsideMain && shaderNeedsGenerateOutput())
+                    {
+                        out << "return " << generateOutputCall();
+                    }
+                    else
+                    {
+                        out << "return";
+                    }
                 }
                 break;
             default:
@@ -3154,4 +3207,21 @@
     }
 }
 
+bool OutputHLSL::shaderNeedsGenerateOutput() const
+{
+    return mShaderType == GL_VERTEX_SHADER || mShaderType == GL_FRAGMENT_SHADER;
+}
+
+const char *OutputHLSL::generateOutputCall() const
+{
+    if (mShaderType == GL_VERTEX_SHADER)
+    {
+        return "generateOutput(input)";
+    }
+    else
+    {
+        return "generateOutput()";
+    }
+}
+
 }  // namespace sh
diff --git a/src/compiler/translator/OutputHLSL.h b/src/compiler/translator/OutputHLSL.h
index 6685f63..d2cfb87 100644
--- a/src/compiler/translator/OutputHLSL.h
+++ b/src/compiler/translator/OutputHLSL.h
@@ -53,6 +53,7 @@
                int numRenderTargets,
                const std::vector<Uniform> &uniforms,
                ShCompileOptions compileOptions,
+               sh::WorkGroupSize workGroupSize,
                TSymbolTable *symbolTable,
                PerformanceDiagnostics *perfDiagnostics);
 
@@ -146,6 +147,9 @@
     // Ensures if the type is a struct, the struct is defined
     void ensureStructDefined(const TType &type);
 
+    bool shaderNeedsGenerateOutput() const;
+    const char *generateOutputCall() const;
+
     sh::GLenum mShaderType;
     int mShaderVersion;
     const TExtensionBehavior &mExtensionBehavior;
@@ -154,6 +158,7 @@
     ShCompileOptions mCompileOptions;
 
     bool mInsideFunction;
+    bool mInsideMain;
 
     // Output streams
     TInfoSinkBase mHeader;
@@ -250,6 +255,8 @@
     // arrays can't be return values in HLSL.
     std::vector<ArrayHelperFunction> mArrayConstructIntoFunctions;
 
+    sh::WorkGroupSize mWorkGroupSize;
+
     PerformanceDiagnostics *mPerfDiagnostics;
 
   private:
diff --git a/src/compiler/translator/TranslatorHLSL.cpp b/src/compiler/translator/TranslatorHLSL.cpp
index db2af8f..4990622 100644
--- a/src/compiler/translator/TranslatorHLSL.cpp
+++ b/src/compiler/translator/TranslatorHLSL.cpp
@@ -128,7 +128,8 @@
 
     sh::OutputHLSL outputHLSL(getShaderType(), getShaderVersion(), getExtensionBehavior(),
                               getSourcePath(), getOutputType(), numRenderTargets, getUniforms(),
-                              compileOptions, &getSymbolTable(), perfDiagnostics);
+                              compileOptions, getComputeShaderLocalSize(), &getSymbolTable(),
+                              perfDiagnostics);
 
     outputHLSL.output(root, getInfoSink().obj);