Use TLValueTrackingTraverser in ValidateLimitations

Use TLValueTrackingTraverser to determine whether a loop index is used
as an l-value. This replaces custom logic in ValidateLimitations,
greatly simplifying the code. Also pass the symbol table to
ValidateLimitations as a parameter, which removes the need to store a
global pointer to the current ParseContext.

BUG=angleproject:1960
TEST=angle_unittests, WebGL conformance tests

Change-Id: I122c85c78bbea05833d7c787cd184de568c5c45f
Reviewed-on: https://chromium-review.googlesource.com/465606
Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/compiler/translator/ValidateLimitations.cpp b/src/compiler/translator/ValidateLimitations.cpp
index e647e4c..3a17614 100644
--- a/src/compiler/translator/ValidateLimitations.cpp
+++ b/src/compiler/translator/ValidateLimitations.cpp
@@ -6,10 +6,9 @@
 
 #include "compiler/translator/ValidateLimitations.h"
 
-#include "compiler/translator/Diagnostics.h"
-#include "compiler/translator/InitializeParseContext.h"
-#include "compiler/translator/ParseContext.h"
 #include "angle_gl.h"
+#include "compiler/translator/Diagnostics.h"
+#include "compiler/translator/ParseContext.h"
 
 namespace sh
 {
@@ -65,30 +64,73 @@
     const std::vector<int> mLoopSymbolIds;
 };
 
-}  // namespace anonymous
+// Traverses intermediate tree to ensure that the shader does not exceed the
+// minimum functionality mandated in GLSL 1.0 spec, Appendix A.
+class ValidateLimitationsTraverser : public TLValueTrackingTraverser
+{
+  public:
+    ValidateLimitationsTraverser(sh::GLenum shaderType,
+                                 const TSymbolTable &symbolTable,
+                                 int shaderVersion,
+                                 TDiagnostics *diagnostics);
 
-ValidateLimitations::ValidateLimitations(sh::GLenum shaderType, TDiagnostics *diagnostics)
-    : TIntermTraverser(true, false, false),
+    void visitSymbol(TIntermSymbol *node) override;
+    bool visitBinary(Visit, TIntermBinary *) override;
+    bool visitLoop(Visit, TIntermLoop *) override;
+
+  private:
+    void error(TSourceLoc loc, const char *reason, const char *token);
+
+    bool withinLoopBody() const;
+    bool isLoopIndex(TIntermSymbol *symbol);
+    bool validateLoopType(TIntermLoop *node);
+
+    bool validateForLoopHeader(TIntermLoop *node);
+    // If valid, return the index symbol id; otherwise, return -1.
+    int validateForLoopInit(TIntermLoop *node);
+    bool validateForLoopCond(TIntermLoop *node, int indexSymbolId);
+    bool validateForLoopExpr(TIntermLoop *node, int indexSymbolId);
+
+    // Returns true if indexing does not exceed the minimum functionality
+    // mandated in GLSL 1.0 spec, Appendix A, Section 5.
+    bool isConstExpr(TIntermNode *node);
+    bool isConstIndexExpr(TIntermNode *node);
+    bool validateIndexing(TIntermBinary *node);
+
+    sh::GLenum mShaderType;
+    TDiagnostics *mDiagnostics;
+    std::vector<int> mLoopSymbolIds;
+};
+
+ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType,
+                                                           const TSymbolTable &symbolTable,
+                                                           int shaderVersion,
+                                                           TDiagnostics *diagnostics)
+    : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
       mShaderType(shaderType),
-      mDiagnostics(diagnostics),
-      mValidateIndexing(true),
-      mValidateInnerLoops(true)
+      mDiagnostics(diagnostics)
 {
     ASSERT(diagnostics);
 }
 
-bool ValidateLimitations::visitBinary(Visit, TIntermBinary *node)
+void ValidateLimitationsTraverser::visitSymbol(TIntermSymbol *node)
 {
-    // Check if loop index is modified in the loop body.
-    validateOperation(node, node->getLeft());
+    if (isLoopIndex(node) && isLValueRequiredHere())
+    {
+        error(node->getLine(),
+              "Loop index cannot be statically assigned to within the body of the loop",
+              node->getSymbol().c_str());
+    }
+}
 
+bool ValidateLimitationsTraverser::visitBinary(Visit, TIntermBinary *node)
+{
     // Check indexing.
     switch (node->getOp())
     {
         case EOpIndexDirect:
         case EOpIndexIndirect:
-            if (mValidateIndexing)
-                validateIndexing(node);
+            validateIndexing(node);
             break;
         default:
             break;
@@ -96,33 +138,8 @@
     return true;
 }
 
-bool ValidateLimitations::visitUnary(Visit, TIntermUnary *node)
+bool ValidateLimitationsTraverser::visitLoop(Visit, TIntermLoop *node)
 {
-    // Check if loop index is modified in the loop body.
-    validateOperation(node, node->getOperand());
-
-    return true;
-}
-
-bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate *node)
-{
-    switch (node->getOp())
-    {
-        case EOpCallFunctionInAST:
-        case EOpCallBuiltInFunction:
-            validateFunctionCall(node);
-            break;
-        default:
-            break;
-    }
-    return true;
-}
-
-bool ValidateLimitations::visitLoop(Visit, TIntermLoop *node)
-{
-    if (!mValidateInnerLoops)
-        return true;
-
     if (!validateLoopType(node))
         return false;
 
@@ -141,23 +158,23 @@
     return false;
 }
 
-void ValidateLimitations::error(TSourceLoc loc, const char *reason, const char *token)
+void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, const char *token)
 {
     mDiagnostics->error(loc, reason, token);
 }
 
-bool ValidateLimitations::withinLoopBody() const
+bool ValidateLimitationsTraverser::withinLoopBody() const
 {
     return !mLoopSymbolIds.empty();
 }
 
-bool ValidateLimitations::isLoopIndex(TIntermSymbol *symbol)
+bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol)
 {
     return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->getId()) !=
            mLoopSymbolIds.end();
 }
 
-bool ValidateLimitations::validateLoopType(TIntermLoop *node)
+bool ValidateLimitationsTraverser::validateLoopType(TIntermLoop *node)
 {
     TLoopType type = node->getType();
     if (type == ELoopFor)
@@ -168,7 +185,7 @@
     return false;
 }
 
-bool ValidateLimitations::validateForLoopHeader(TIntermLoop *node)
+bool ValidateLimitationsTraverser::validateForLoopHeader(TIntermLoop *node)
 {
     ASSERT(node->getType() == ELoopFor);
 
@@ -187,7 +204,7 @@
     return true;
 }
 
-int ValidateLimitations::validateForLoopInit(TIntermLoop *node)
+int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node)
 {
     TIntermNode *init = node->getInit();
     if (init == NULL)
@@ -243,7 +260,7 @@
     return symbol->getId();
 }
 
-bool ValidateLimitations::validateForLoopCond(TIntermLoop *node, int indexSymbolId)
+bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId)
 {
     TIntermNode *cond = node->getCondition();
     if (cond == NULL)
@@ -299,7 +316,7 @@
     return true;
 }
 
-bool ValidateLimitations::validateForLoopExpr(TIntermLoop *node, int indexSymbolId)
+bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int indexSymbolId)
 {
     TIntermNode *expr = node->getExpression();
     if (expr == NULL)
@@ -377,77 +394,13 @@
     return true;
 }
 
-bool ValidateLimitations::validateFunctionCall(TIntermAggregate *node)
-{
-    ASSERT(node->getOp() == EOpCallFunctionInAST || node->getOp() == EOpCallBuiltInFunction);
-
-    // If not within loop body, there is nothing to check.
-    if (!withinLoopBody())
-        return true;
-
-    // List of param indices for which loop indices are used as argument.
-    typedef std::vector<size_t> ParamIndex;
-    ParamIndex pIndex;
-    TIntermSequence *params = node->getSequence();
-    for (TIntermSequence::size_type i = 0; i < params->size(); ++i)
-    {
-        TIntermSymbol *symbol = (*params)[i]->getAsSymbolNode();
-        if (symbol && isLoopIndex(symbol))
-            pIndex.push_back(i);
-    }
-    // If none of the loop indices are used as arguments,
-    // there is nothing to check.
-    if (pIndex.empty())
-        return true;
-
-    bool valid                = true;
-    TSymbolTable &symbolTable = GetGlobalParseContext()->symbolTable;
-    // TODO(oetuaho@nvidia.com): It would be neater to leverage TIntermLValueTrackingTraverser to
-    // keep track of out parameters, rather than doing a symbol table lookup here.
-    TString mangledName = TFunction::GetMangledNameFromCall(
-        node->getFunctionSymbolInfo()->getName(), *node->getSequence());
-    TSymbol *symbol = symbolTable.find(mangledName, GetGlobalParseContext()->getShaderVersion());
-    ASSERT(symbol && symbol->isFunction());
-    TFunction *function = static_cast<TFunction *>(symbol);
-    for (ParamIndex::const_iterator i = pIndex.begin(); i != pIndex.end(); ++i)
-    {
-        const TConstParameter &param = function->getParam(*i);
-        TQualifier qual              = param.type->getQualifier();
-        if ((qual == EvqOut) || (qual == EvqInOut))
-        {
-            error((*params)[*i]->getLine(),
-                  "Loop index cannot be used as argument to a function out or inout parameter",
-                  (*params)[*i]->getAsSymbolNode()->getSymbol().c_str());
-            valid = false;
-        }
-    }
-
-    return valid;
-}
-
-bool ValidateLimitations::validateOperation(TIntermOperator *node, TIntermNode *operand)
-{
-    // Check if loop index is modified in the loop body.
-    if (!withinLoopBody() || !node->isAssignment())
-        return true;
-
-    TIntermSymbol *symbol = operand->getAsSymbolNode();
-    if (symbol && isLoopIndex(symbol))
-    {
-        error(node->getLine(),
-              "Loop index cannot be statically assigned to within the body of the loop",
-              symbol->getSymbol().c_str());
-    }
-    return true;
-}
-
-bool ValidateLimitations::isConstExpr(TIntermNode *node)
+bool ValidateLimitationsTraverser::isConstExpr(TIntermNode *node)
 {
     ASSERT(node != nullptr);
     return node->getAsConstantUnion() != nullptr && node->getAsTyped()->getQualifier() == EvqConst;
 }
 
-bool ValidateLimitations::isConstIndexExpr(TIntermNode *node)
+bool ValidateLimitationsTraverser::isConstIndexExpr(TIntermNode *node)
 {
     ASSERT(node != NULL);
 
@@ -456,7 +409,7 @@
     return validate.isValid();
 }
 
-bool ValidateLimitations::validateIndexing(TIntermBinary *node)
+bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node)
 {
     ASSERT((node->getOp() == EOpIndexDirect) || (node->getOp() == EOpIndexIndirect));
 
@@ -474,4 +427,17 @@
     return valid;
 }
 
+}  // namespace anonymous
+
+bool ValidateLimitations(TIntermNode *root,
+                         GLenum shaderType,
+                         const TSymbolTable &symbolTable,
+                         int shaderVersion,
+                         TDiagnostics *diagnostics)
+{
+    ValidateLimitationsTraverser validate(shaderType, symbolTable, shaderVersion, diagnostics);
+    root->traverse(&validate);
+    return diagnostics->numErrors() == 0;
+}
+
 }  // namespace sh