Refactor CollectVariables

New helper functions are added for collecting built-in variables, and
the traverser is encapsulated inside VariableInfo.cpp. The helper
functions get data for built-in variables from the symbol table, so a
duplicate copy of the data doesn't need to be maintained in
CollectVariables any more.

BUG=angleproject:2068
TEST=angle_unittests

Change-Id: I42595d0da0e5d4fb634a3d92f38db1dd6dd9efab
Reviewed-on: https://chromium-review.googlesource.com/549323
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
diff --git a/src/compiler/translator/VariableInfo.cpp b/src/compiler/translator/VariableInfo.cpp
index d41d451..cb22d40 100644
--- a/src/compiler/translator/VariableInfo.cpp
+++ b/src/compiler/translator/VariableInfo.cpp
@@ -4,11 +4,13 @@
 // found in the LICENSE file.
 //
 
-#include "angle_gl.h"
-#include "compiler/translator/SymbolTable.h"
 #include "compiler/translator/VariableInfo.h"
-#include "compiler/translator/util.h"
+
+#include "angle_gl.h"
 #include "common/utilities.h"
+#include "compiler/translator/IntermNode.h"
+#include "compiler/translator/SymbolTable.h"
+#include "compiler/translator/util.h"
 
 namespace sh
 {
@@ -62,16 +64,85 @@
 
     return nullptr;
 }
-}
 
-CollectVariables::CollectVariables(std::vector<sh::Attribute> *attribs,
-                                   std::vector<sh::OutputVariable> *outputVariables,
-                                   std::vector<sh::Uniform> *uniforms,
-                                   std::vector<sh::Varying> *varyings,
-                                   std::vector<sh::InterfaceBlock> *interfaceBlocks,
-                                   ShHashFunction64 hashFunction,
-                                   const TSymbolTable &symbolTable,
-                                   const TExtensionBehavior &extensionBehavior)
+// Traverses the intermediate tree to collect all attributes, uniforms, varyings, fragment outputs,
+// and interface blocks.
+class CollectVariablesTraverser : public TIntermTraverser
+{
+  public:
+    CollectVariablesTraverser(std::vector<Attribute> *attribs,
+                              std::vector<OutputVariable> *outputVariables,
+                              std::vector<Uniform> *uniforms,
+                              std::vector<Varying> *varyings,
+                              std::vector<InterfaceBlock> *interfaceBlocks,
+                              ShHashFunction64 hashFunction,
+                              const TSymbolTable &symbolTable,
+                              int shaderVersion,
+                              const TExtensionBehavior &extensionBehavior);
+
+    void visitSymbol(TIntermSymbol *symbol) override;
+    bool visitDeclaration(Visit, TIntermDeclaration *node) override;
+    bool visitBinary(Visit visit, TIntermBinary *binaryNode) override;
+
+  private:
+    void setCommonVariableProperties(const TType &type,
+                                     const TString &name,
+                                     ShaderVariable *variableOut) const;
+
+    Attribute recordAttribute(const TIntermSymbol &variable) const;
+    OutputVariable recordOutputVariable(const TIntermSymbol &variable) const;
+    Varying recordVarying(const TIntermSymbol &variable) const;
+    InterfaceBlock recordInterfaceBlock(const TIntermSymbol &variable) const;
+    Uniform recordUniform(const TIntermSymbol &variable) const;
+    // Fills the common fields of |info| from the symbol table entry of the built-in |name|.
+    void setBuiltInInfoFromSymbolTable(const char *name, ShaderVariable *info);
+    // Append the built-in |name| to the matching list once; |addedFlag| is set when it is added.
+    void recordBuiltInVaryingUsed(const char *name, bool *addedFlag);
+    void recordBuiltInFragmentOutputUsed(const char *name, bool *addedFlag);
+    void recordBuiltInAttributeUsed(const char *name, bool *addedFlag);
+    // Output lists supplied through the constructor; collected variables are pushed onto them.
+    std::vector<Attribute> *mAttribs;
+    std::vector<OutputVariable> *mOutputVariables;
+    std::vector<Uniform> *mUniforms;
+    std::vector<Varying> *mVaryings;
+    std::vector<InterfaceBlock> *mInterfaceBlocks;
+
+    std::map<std::string, InterfaceBlockField *> mInterfaceBlockFields;
+    // "Added" flags: true once the matching built-in has been recorded, so it is added only once.
+    bool mDepthRangeAdded;
+    bool mPointCoordAdded;
+    bool mFrontFacingAdded;
+    bool mFragCoordAdded;
+
+    bool mInstanceIDAdded;
+    bool mVertexIDAdded;
+    bool mPositionAdded;
+    bool mPointSizeAdded;
+    bool mLastFragDataAdded;
+    bool mFragColorAdded;
+    bool mFragDataAdded;
+    bool mFragDepthEXTAdded;
+    bool mFragDepthAdded;
+    bool mSecondaryFragColorEXTAdded;
+    bool mSecondaryFragDataEXTAdded;
+
+    ShHashFunction64 mHashFunction;
+
+    const TSymbolTable &mSymbolTable;
+    int mShaderVersion;  // Passed to TSymbolTable::findBuiltIn() when looking up built-ins.
+    const TExtensionBehavior &mExtensionBehavior;  // Queried for GL_EXT_draw_buffers.
+};
+
+CollectVariablesTraverser::CollectVariablesTraverser(
+    std::vector<sh::Attribute> *attribs,
+    std::vector<sh::OutputVariable> *outputVariables,
+    std::vector<sh::Uniform> *uniforms,
+    std::vector<sh::Varying> *varyings,
+    std::vector<sh::InterfaceBlock> *interfaceBlocks,
+    ShHashFunction64 hashFunction,
+    const TSymbolTable &symbolTable,
+    int shaderVersion,
+    const TExtensionBehavior &extensionBehavior)
     : TIntermTraverser(true, false, false),
       mAttribs(attribs),
       mOutputVariables(outputVariables),
@@ -95,16 +166,70 @@
       mSecondaryFragDataEXTAdded(false),
       mHashFunction(hashFunction),
       mSymbolTable(symbolTable),
+      mShaderVersion(shaderVersion),
       mExtensionBehavior(extensionBehavior)
 {
 }
 
+void CollectVariablesTraverser::setBuiltInInfoFromSymbolTable(const char *name,
+                                                              ShaderVariable *info)
+{
+    // Fills |info| from the symbol table entry of built-in |name| for mShaderVersion.
+    const TVariable *symbolTableVar =
+        static_cast<const TVariable *>(mSymbolTable.findBuiltIn(name, mShaderVersion));
+    ASSERT(symbolTableVar);
+    const TType &type = symbolTableVar->getType();
+    info->name       = name;  // Built-in names are reported verbatim in both name fields.
+    info->mappedName = name;
+    info->type       = GLVariableType(type);
+    info->arraySize  = type.isArray() ? type.getArraySize() : 0;  // 0 for non-arrays.
+    info->precision  = GLVariablePrecision(type);
+}
+
+void CollectVariablesTraverser::recordBuiltInVaryingUsed(const char *name, bool *addedFlag)
+{
+    if (!(*addedFlag))  // Append the built-in varying |name| only the first time it is seen.
+    {
+        Varying info;
+        setBuiltInInfoFromSymbolTable(name, &info);
+        info.staticUse   = true;
+        info.isInvariant = mSymbolTable.isVaryingInvariant(name);
+        mVaryings->push_back(info);
+        (*addedFlag) = true;  // Later calls with this flag become no-ops.
+    }
+}
+
+void CollectVariablesTraverser::recordBuiltInFragmentOutputUsed(const char *name, bool *addedFlag)
+{
+    if (!(*addedFlag))  // Append the built-in output |name| only the first time it is seen.
+    {
+        OutputVariable info;
+        setBuiltInInfoFromSymbolTable(name, &info);
+        info.staticUse = true;
+        mOutputVariables->push_back(info);
+        (*addedFlag) = true;  // Later calls with this flag become no-ops.
+    }
+}
+
+void CollectVariablesTraverser::recordBuiltInAttributeUsed(const char *name, bool *addedFlag)
+{
+    if (!(*addedFlag))  // Append the built-in attribute |name| only the first time it is seen.
+    {
+        Attribute info;
+        setBuiltInInfoFromSymbolTable(name, &info);
+        info.staticUse = true;
+        info.location  = -1;
+        mAttribs->push_back(info);
+        (*addedFlag) = true;  // Later calls with this flag become no-ops.
+    }
+}
+
 // We want to check whether a uniform/varying is statically used
 // because we only count the used ones in packing computing.
 // Also, gl_FragCoord, gl_PointCoord, and gl_FrontFacing count
 // toward varying counting if they are statically used in a fragment
 // shader.
-void CollectVariables::visitSymbol(TIntermSymbol *symbol)
+void CollectVariablesTraverser::visitSymbol(TIntermSymbol *symbol)
 {
     ASSERT(symbol != nullptr);
     ShaderVariable *var       = nullptr;
@@ -202,241 +327,59 @@
             }
             break;
             case EvqFragCoord:
-                if (!mFragCoordAdded)
-                {
-                    Varying info;
-                    const char kName[] = "gl_FragCoord";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_FLOAT_VEC4;
-                    info.arraySize     = 0;
-                    info.precision     = GL_MEDIUM_FLOAT;  // Defined by spec.
-                    info.staticUse     = true;
-                    info.isInvariant   = mSymbolTable.isVaryingInvariant(kName);
-                    mVaryings->push_back(info);
-                    mFragCoordAdded = true;
-                }
+                recordBuiltInVaryingUsed("gl_FragCoord", &mFragCoordAdded);
                 return;
             case EvqFrontFacing:
-                if (!mFrontFacingAdded)
-                {
-                    Varying info;
-                    const char kName[] = "gl_FrontFacing";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_BOOL;
-                    info.arraySize     = 0;
-                    info.precision     = GL_NONE;
-                    info.staticUse     = true;
-                    info.isInvariant   = mSymbolTable.isVaryingInvariant(kName);
-                    mVaryings->push_back(info);
-                    mFrontFacingAdded = true;
-                }
+                recordBuiltInVaryingUsed("gl_FrontFacing", &mFrontFacingAdded);
                 return;
             case EvqPointCoord:
-                if (!mPointCoordAdded)
-                {
-                    Varying info;
-                    const char kName[] = "gl_PointCoord";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_FLOAT_VEC2;
-                    info.arraySize     = 0;
-                    info.precision     = GL_MEDIUM_FLOAT;  // Defined by spec.
-                    info.staticUse     = true;
-                    info.isInvariant   = mSymbolTable.isVaryingInvariant(kName);
-                    mVaryings->push_back(info);
-                    mPointCoordAdded = true;
-                }
+                recordBuiltInVaryingUsed("gl_PointCoord", &mPointCoordAdded);
                 return;
             case EvqInstanceID:
-                if (!mInstanceIDAdded)
-                {
-                    Attribute info;
-                    const char kName[] = "gl_InstanceID";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_INT;
-                    info.arraySize     = 0;
-                    info.precision     = GL_HIGH_INT;  // Defined by spec.
-                    info.staticUse     = true;
-                    info.location      = -1;
-                    mAttribs->push_back(info);
-                    mInstanceIDAdded = true;
-                }
+                recordBuiltInAttributeUsed("gl_InstanceID", &mInstanceIDAdded);
                 return;
             case EvqVertexID:
-                if (!mVertexIDAdded)
-                {
-                    Attribute info;
-                    const char kName[] = "gl_VertexID";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_INT;
-                    info.arraySize     = 0;
-                    info.precision     = GL_HIGH_INT;  // Defined by spec.
-                    info.staticUse     = true;
-                    info.location      = -1;
-                    mAttribs->push_back(info);
-                    mVertexIDAdded = true;
-                }
+                recordBuiltInAttributeUsed("gl_VertexID", &mVertexIDAdded);
                 return;
             case EvqPosition:
-                if (!mPositionAdded)
-                {
-                    Varying info;
-                    const char kName[] = "gl_Position";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_FLOAT_VEC4;
-                    info.arraySize     = 0;
-                    info.precision     = GL_HIGH_FLOAT;  // Defined by spec.
-                    info.staticUse     = true;
-                    info.isInvariant   = mSymbolTable.isVaryingInvariant(kName);
-                    mVaryings->push_back(info);
-                    mPositionAdded = true;
-                }
+                recordBuiltInVaryingUsed("gl_Position", &mPositionAdded);
                 return;
             case EvqPointSize:
-                if (!mPointSizeAdded)
-                {
-                    Varying info;
-                    const char kName[] = "gl_PointSize";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_FLOAT;
-                    info.arraySize     = 0;
-                    info.precision     = GL_MEDIUM_FLOAT;  // Defined by spec.
-                    info.staticUse     = true;
-                    info.isInvariant   = mSymbolTable.isVaryingInvariant(kName);
-                    mVaryings->push_back(info);
-                    mPointSizeAdded = true;
-                }
+                recordBuiltInVaryingUsed("gl_PointSize", &mPointSizeAdded);
                 return;
             case EvqLastFragData:
-                if (!mLastFragDataAdded)
-                {
-                    Varying info;
-                    const char kName[] = "gl_LastFragData";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_FLOAT_VEC4;
-                    info.arraySize     = static_cast<const TVariable *>(
-                                         mSymbolTable.findBuiltIn("gl_MaxDrawBuffers", 100))
-                                         ->getConstPointer()
-                                         ->getIConst();
-                    info.precision   = GL_MEDIUM_FLOAT;  // Defined by spec.
-                    info.staticUse   = true;
-                    info.isInvariant = mSymbolTable.isVaryingInvariant(kName);
-                    mVaryings->push_back(info);
-                    mLastFragDataAdded = true;
-                }
+                recordBuiltInVaryingUsed("gl_LastFragData", &mLastFragDataAdded);
                 return;
             case EvqFragColor:
-                if (!mFragColorAdded)
-                {
-                    OutputVariable info;
-                    const char kName[] = "gl_FragColor";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_FLOAT_VEC4;
-                    info.arraySize     = 0;
-                    info.precision     = GL_MEDIUM_FLOAT;  // Defined by spec.
-                    info.staticUse     = true;
-                    mOutputVariables->push_back(info);
-                    mFragColorAdded = true;
-                }
+                recordBuiltInFragmentOutputUsed("gl_FragColor", &mFragColorAdded);
                 return;
             case EvqFragData:
                 if (!mFragDataAdded)
                 {
                     OutputVariable info;
-                    const char kName[] = "gl_FragData";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_FLOAT_VEC4;
-                    if (::IsExtensionEnabled(mExtensionBehavior, "GL_EXT_draw_buffers"))
-                    {
-                        info.arraySize = static_cast<const TVariable *>(
-                                             mSymbolTable.findBuiltIn("gl_MaxDrawBuffers", 100))
-                                             ->getConstPointer()
-                                             ->getIConst();
-                    }
-                    else
+                    setBuiltInInfoFromSymbolTable("gl_FragData", &info);
+                    if (!::IsExtensionEnabled(mExtensionBehavior, "GL_EXT_draw_buffers"))
                     {
                         info.arraySize = 1;
                     }
-                    info.precision = GL_MEDIUM_FLOAT;  // Defined by spec.
                     info.staticUse = true;
                     mOutputVariables->push_back(info);
                     mFragDataAdded = true;
                 }
                 return;
             case EvqFragDepthEXT:
-                if (!mFragDepthEXTAdded)
-                {
-                    OutputVariable info;
-                    const char kName[] = "gl_FragDepthEXT";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_FLOAT;
-                    info.arraySize     = 0;
-                    info.precision =
-                        GLVariablePrecision(static_cast<const TVariable *>(
-                                                mSymbolTable.findBuiltIn("gl_FragDepthEXT", 100))
-                                                ->getType());
-                    info.staticUse = true;
-                    mOutputVariables->push_back(info);
-                    mFragDepthEXTAdded = true;
-                }
+                recordBuiltInFragmentOutputUsed("gl_FragDepthEXT", &mFragDepthEXTAdded);
                 return;
             case EvqFragDepth:
-                if (!mFragDepthAdded)
-                {
-                    OutputVariable info;
-                    const char kName[] = "gl_FragDepth";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_FLOAT;
-                    info.arraySize     = 0;
-                    info.precision     = GL_HIGH_FLOAT;
-                    info.staticUse     = true;
-                    mOutputVariables->push_back(info);
-                    mFragDepthAdded = true;
-                }
+                recordBuiltInFragmentOutputUsed("gl_FragDepth", &mFragDepthAdded);
                 return;
             case EvqSecondaryFragColorEXT:
-                if (!mSecondaryFragColorEXTAdded)
-                {
-                    OutputVariable info;
-                    const char kName[] = "gl_SecondaryFragColorEXT";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_FLOAT_VEC4;
-                    info.arraySize     = 0;
-                    info.precision     = GL_MEDIUM_FLOAT;  // Defined by spec.
-                    info.staticUse     = true;
-                    mOutputVariables->push_back(info);
-                    mSecondaryFragColorEXTAdded = true;
-                }
+                recordBuiltInFragmentOutputUsed("gl_SecondaryFragColorEXT",
+                                                &mSecondaryFragColorEXTAdded);
                 return;
             case EvqSecondaryFragDataEXT:
-                if (!mSecondaryFragDataEXTAdded)
-                {
-                    OutputVariable info;
-                    const char kName[] = "gl_SecondaryFragDataEXT";
-                    info.name          = kName;
-                    info.mappedName    = kName;
-                    info.type          = GL_FLOAT_VEC4;
-
-                    const TVariable *maxDualSourceDrawBuffersVar = static_cast<const TVariable *>(
-                        mSymbolTable.findBuiltIn("gl_MaxDualSourceDrawBuffersEXT", 100));
-                    info.arraySize = maxDualSourceDrawBuffersVar->getConstPointer()->getIConst();
-                    info.precision = GL_MEDIUM_FLOAT;  // Defined by spec.
-                    info.staticUse = true;
-                    mOutputVariables->push_back(info);
-                    mSecondaryFragDataEXTAdded = true;
-                }
+                recordBuiltInFragmentOutputUsed("gl_SecondaryFragDataEXT",
+                                                &mSecondaryFragDataEXTAdded);
                 return;
             default:
                 break;
@@ -448,9 +391,9 @@
     }
 }
 
-void CollectVariables::setCommonVariableProperties(const TType &type,
-                                                   const TString &name,
-                                                   ShaderVariable *variableOut) const
+void CollectVariablesTraverser::setCommonVariableProperties(const TType &type,
+                                                            const TString &name,
+                                                            ShaderVariable *variableOut) const
 {
     ASSERT(variableOut);
 
@@ -483,7 +426,7 @@
     variableOut->arraySize  = type.getArraySize();
 }
 
-Attribute CollectVariables::recordAttribute(const TIntermSymbol &variable) const
+Attribute CollectVariablesTraverser::recordAttribute(const TIntermSymbol &variable) const
 {
     const TType &type = variable.getType();
     ASSERT(!type.getStruct());
@@ -495,7 +438,7 @@
     return attribute;
 }
 
-OutputVariable CollectVariables::recordOutputVariable(const TIntermSymbol &variable) const
+OutputVariable CollectVariablesTraverser::recordOutputVariable(const TIntermSymbol &variable) const
 {
     const TType &type = variable.getType();
     ASSERT(!type.getStruct());
@@ -507,7 +450,7 @@
     return outputVariable;
 }
 
-Varying CollectVariables::recordVarying(const TIntermSymbol &variable) const
+Varying CollectVariablesTraverser::recordVarying(const TIntermSymbol &variable) const
 {
     const TType &type = variable.getType();
 
@@ -536,7 +479,7 @@
     return varying;
 }
 
-InterfaceBlock CollectVariables::recordInterfaceBlock(const TIntermSymbol &variable) const
+InterfaceBlock CollectVariablesTraverser::recordInterfaceBlock(const TIntermSymbol &variable) const
 {
     const TInterfaceBlock *blockType = variable.getType().getInterfaceBlock();
     ASSERT(blockType);
@@ -566,7 +509,7 @@
     return interfaceBlock;
 }
 
-Uniform CollectVariables::recordUniform(const TIntermSymbol &variable) const
+Uniform CollectVariablesTraverser::recordUniform(const TIntermSymbol &variable) const
 {
     Uniform uniform;
     setCommonVariableProperties(variable.getType(), variable.getSymbol(), &uniform);
@@ -576,7 +519,7 @@
     return uniform;
 }
 
-bool CollectVariables::visitDeclaration(Visit, TIntermDeclaration *node)
+bool CollectVariablesTraverser::visitDeclaration(Visit, TIntermDeclaration *node)
 {
     const TIntermSequence &sequence = *(node->getSequence());
     ASSERT(!sequence.empty());
@@ -630,7 +573,7 @@
     return false;
 }
 
-bool CollectVariables::visitBinary(Visit, TIntermBinary *binaryNode)
+bool CollectVariablesTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
 {
     if (binaryNode->getOp() == EOpIndexDirectInterfaceBlock)
     {
@@ -655,6 +598,25 @@
     return true;
 }
 
+}  // anonymous namespace
+
+void CollectVariables(TIntermBlock *root,
+                      std::vector<Attribute> *attributes,
+                      std::vector<OutputVariable> *outputVariables,
+                      std::vector<Uniform> *uniforms,
+                      std::vector<Varying> *varyings,
+                      std::vector<InterfaceBlock> *interfaceBlocks,
+                      ShHashFunction64 hashFunction,
+                      const TSymbolTable &symbolTable,
+                      int shaderVersion,
+                      const TExtensionBehavior &extensionBehavior)
+{
+    CollectVariablesTraverser collect(attributes, outputVariables, uniforms, varyings,
+                                      interfaceBlocks, hashFunction, symbolTable, shaderVersion,
+                                      extensionBehavior);
+    root->traverse(&collect);  // The traverser pushes what it collects onto the output vectors.
+}
+
 void ExpandVariable(const ShaderVariable &variable,
                     const std::string &name,
                     const std::string &mappedName,
@@ -703,10 +665,10 @@
 
 void ExpandUniforms(const std::vector<Uniform> &compact, std::vector<ShaderVariable> *expanded)
 {
-    for (size_t variableIndex = 0; variableIndex < compact.size(); variableIndex++)
+    for (const Uniform &variable : compact)  // Forward every uniform to ExpandVariable().
     {
-        const ShaderVariable &variable = compact[variableIndex];
         ExpandVariable(variable, variable.name, variable.mappedName, variable.staticUse, expanded);
     }
 }
-}
+
+}  // namespace sh