Simplify AST transformations that need to find main

Share code for finding the main function from the AST between
InitializeVariables, DeferGlobalInitializers,
EmulateGLFragColorBroadcast and UseInterfaceBlockFields. In particular,
this simplifies InitializeVariables, which no longer needs an AST
traverser.

BUG=angleproject:2033
TEST=angle_unittests, WebGL conformance tests

Change-Id: I14c994bbde58a904f6684d2f0b72bd8004f70902
Reviewed-on: https://chromium-review.googlesource.com/501166
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
diff --git a/src/compiler/translator/UseInterfaceBlockFields.cpp b/src/compiler/translator/UseInterfaceBlockFields.cpp
index 390e2b0..0819cab 100644
--- a/src/compiler/translator/UseInterfaceBlockFields.cpp
+++ b/src/compiler/translator/UseInterfaceBlockFields.cpp
@@ -10,6 +10,7 @@
 
 #include "compiler/translator/UseInterfaceBlockFields.h"
 
+#include "compiler/translator/FindMain.h"
 #include "compiler/translator/IntermNode.h"
 #include "compiler/translator/SymbolTable.h"
 #include "compiler/translator/util.h"
@@ -20,47 +21,9 @@
 namespace
 {
 
-class UseUniformBlockMembers : public TIntermTraverser
-{
-  public:
-    UseUniformBlockMembers(const InterfaceBlockList &blocks, const TSymbolTable &symbolTable)
-        : TIntermTraverser(true, false, false),
-          mBlocks(blocks),
-          mSymbolTable(symbolTable),
-          mCodeInserted(false)
-    {
-        ASSERT(mSymbolTable.atGlobalLevel());
-    }
-
-  protected:
-    bool visitAggregate(Visit visit, TIntermAggregate *node) override { return !mCodeInserted; }
-    bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
-
-  private:
-    void insertUseCode(TIntermSequence *sequence);
-    void AddFieldUseStatements(const ShaderVariable &var, TIntermSequence *sequence);
-
-    const InterfaceBlockList &mBlocks;
-    const TSymbolTable &mSymbolTable;
-    bool mCodeInserted;
-};
-
-bool UseUniformBlockMembers::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
-{
-    ASSERT(visit == PreVisit);
-    if (node->getFunctionSymbolInfo()->isMain())
-    {
-        TIntermBlock *body = node->getBody();
-        ASSERT(body);
-        insertUseCode(body->getSequence());
-        mCodeInserted = true;
-        return false;
-    }
-    return !mCodeInserted;
-}
-
-void UseUniformBlockMembers::AddFieldUseStatements(const ShaderVariable &var,
-                                                   TIntermSequence *sequence)
+void AddFieldUseStatements(const ShaderVariable &var,
+                           TIntermSequence *sequence,
+                           const TSymbolTable &symbolTable)
 {
     TString name = TString(var.name.c_str());
     if (var.isArray())
@@ -75,7 +38,7 @@
     TType basicType;
     if (var.isStruct())
     {
-        TVariable *structInfo = reinterpret_cast<TVariable *>(mSymbolTable.findGlobal(name));
+        TVariable *structInfo = reinterpret_cast<TVariable *>(symbolTable.findGlobal(name));
         ASSERT(structInfo);
         const TType &structType = structInfo->getType();
         type                    = &structType;
@@ -103,21 +66,23 @@
     }
 }
 
-void UseUniformBlockMembers::insertUseCode(TIntermSequence *sequence)
+void InsertUseCode(TIntermSequence *sequence,
+                   const InterfaceBlockList &blocks,
+                   const TSymbolTable &symbolTable)
 {
-    for (const auto &block : mBlocks)
+    for (const auto &block : blocks)
     {
         if (block.instanceName.empty())
         {
             for (const auto &var : block.fields)
             {
-                AddFieldUseStatements(var, sequence);
+                AddFieldUseStatements(var, sequence, symbolTable);
             }
         }
         else if (block.arraySize > 0)
         {
             TString name      = TString(block.instanceName.c_str());
-            TVariable *ubInfo = reinterpret_cast<TVariable *>(mSymbolTable.findGlobal(name));
+            TVariable *ubInfo = reinterpret_cast<TVariable *>(symbolTable.findGlobal(name));
             ASSERT(ubInfo);
             TIntermSymbol *arraySymbol = new TIntermSymbol(0, name, ubInfo->getType());
             for (unsigned int i = 0; i < block.arraySize; ++i)
@@ -136,7 +101,7 @@
         else
         {
             TString name      = TString(block.instanceName.c_str());
-            TVariable *ubInfo = reinterpret_cast<TVariable *>(mSymbolTable.findGlobal(name));
+            TVariable *ubInfo = reinterpret_cast<TVariable *>(symbolTable.findGlobal(name));
             ASSERT(ubInfo);
             TIntermSymbol *blockSymbol = new TIntermSymbol(0, name, ubInfo->getType());
             for (unsigned int i = 0; i < block.fields.size(); ++i)
@@ -152,12 +117,17 @@
 
 }  // namespace anonymous
 
-void UseInterfaceBlockFields(TIntermNode *root,
+// Inserts code that uses the fields of |blocks| into the body of main(), via InsertUseCode.
+// |root| must contain the definition of main(); FindMain() is used to locate it.
+void UseInterfaceBlockFields(TIntermBlock *root,
                              const InterfaceBlockList &blocks,
                              const TSymbolTable &symbolTable)
 {
-    UseUniformBlockMembers useUniformBlock(blocks, symbolTable);
-    root->traverse(&useUniformBlock);
+    TIntermFunctionDefinition *main = FindMain(root);
+    ASSERT(main != nullptr);
+    TIntermBlock *mainBody = main->getBody();
+    ASSERT(mainBody != nullptr);
+    InsertUseCode(mainBody->getSequence(), blocks, symbolTable);
 }
 
 }  // namespace sh