Add expression complexity and call stack depth limits.

git-svn-id: https://angleproject.googlecode.com/svn/trunk@2242 736b8ea6-26fd-11df-bfd4-992fa37f6226

TRAC #23333
Authored-by: gman@chromium.org
Signed-off-by: Shannon Woods
Signed-off-by: Nicolas Capens
Merged-by: Jamie Madill
Conflicts:

	src/common/version.h
diff --git a/src/compiler/Compiler.cpp b/src/compiler/Compiler.cpp
index d39d921..c7c792c 100644
--- a/src/compiler/Compiler.cpp
+++ b/src/compiler/Compiler.cpp
@@ -5,7 +5,7 @@
 //
 
 #include "compiler/BuiltInFunctionEmulator.h"
-#include "compiler/DetectRecursion.h"
+#include "compiler/DetectCallDepth.h"
 #include "compiler/ForLoopUnroll.h"
 #include "compiler/Initialize.h"
 #include "compiler/InitializeParseContext.h"
@@ -59,6 +59,9 @@
 TCompiler::TCompiler(ShShaderType type, ShShaderSpec spec)
     : shaderType(type),
       shaderSpec(spec),
+      maxUniformVectors(0),
+      maxExpressionComplexity(0),
+      maxCallStackDepth(0),
       fragmentPrecisionHigh(false),
       clampingStrategy(SH_CLAMP_WITH_CLAMP_INTRINSIC),
       builtInFunctionEmulator(type)
@@ -78,6 +81,8 @@
     maxUniformVectors = (shaderType == SH_VERTEX_SHADER) ?
         resources.MaxVertexUniformVectors :
         resources.MaxFragmentUniformVectors;
+    maxExpressionComplexity = resources.MaxExpressionComplexity;
+    maxCallStackDepth = resources.MaxCallStackDepth;
     TScopedPoolAllocator scopedAlloc(&allocator, false);
 
     // Generate built-in symbol table.
@@ -144,7 +149,7 @@
         success = intermediate.postProcess(root);
 
         if (success)
-            success = detectRecursion(root);
+            success = detectCallDepth(root, infoSink, (compileOptions & SH_LIMIT_CALL_STACK_DEPTH) != 0);
 
         if (success && shaderVersion == 300 && shaderType == SH_FRAGMENT_SHADER)
             success = validateOutputs(root);
@@ -170,6 +175,10 @@
         if (success && (compileOptions & SH_CLAMP_INDIRECT_ARRAY_BOUNDS))
             arrayBoundsClamper.MarkIndirectArrayBoundsForClamping(root);
 
+        // Disallow expressions deemed too complex.
+        if (success && (compileOptions & SH_LIMIT_EXPRESSION_COMPLEXITY))
+            success = limitExpressionComplexity(root);
+
         // Call mapLongVariableNames() before collectAttribsUniforms() so in
         // collectAttribsUniforms() we already have the mapped symbol names and
         // we could composite mapped and original variable names.
@@ -260,21 +269,25 @@
     nameMap.clear();
 }
 
-bool TCompiler::detectRecursion(TIntermNode* root)
+bool TCompiler::detectCallDepth(TIntermNode* root, TInfoSink& infoSink, bool limitCallStackDepth)
 {
-    DetectRecursion detect;
+    DetectCallDepth detect(infoSink, limitCallStackDepth, maxCallStackDepth);
     root->traverse(&detect);
-    switch (detect.detectRecursion()) {
-        case DetectRecursion::kErrorNone:
+    switch (detect.detectCallDepth()) {
+        case DetectCallDepth::kErrorNone:
             return true;
-        case DetectRecursion::kErrorMissingMain:
+        case DetectCallDepth::kErrorMissingMain:
             infoSink.info.prefix(EPrefixError);
             infoSink.info << "Missing main()";
             return false;
-        case DetectRecursion::kErrorRecursion:
+        case DetectCallDepth::kErrorRecursion:
             infoSink.info.prefix(EPrefixError);
             infoSink.info << "Function recursion detected";
             return false;
+        case DetectCallDepth::kErrorMaxDepthExceeded:
+            infoSink.info.prefix(EPrefixError);
+            infoSink.info << "Function call stack too deep";
+            return false;
         default:
             UNREACHABLE();
             return false;
@@ -326,6 +339,28 @@
     }
 }
 
+// Returns false, logging an error, when the intermediate tree nests deeper
+// than maxExpressionComplexity (copied from resources.MaxExpressionComplexity);
+// a tree exactly at the limit is accepted.
+bool TCompiler::limitExpressionComplexity(TIntermNode* root)
+{
+    // A plain TIntermTraverser has no visit overrides, so walking the tree
+    // with it only updates its depth bookkeeping: incrementDepth() keeps the
+    // running maximum nesting level, which is used as the complexity measure.
+    TIntermTraverser traverser;
+    root->traverse(&traverser);
+
+    const int depth = traverser.getMaxDepth();
+    if (depth <= maxExpressionComplexity)
+        return true;
+
+    // Use the same EPrefixError prefix as the compiler's other error messages
+    // so this failure is reported consistently.
+    infoSink.info.prefix(EPrefixError);
+    infoSink.info << "Expression too complex.";
+    return false;
+}
+
 bool TCompiler::enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph)
 {
     RestrictFragmentShaderTiming restrictor(infoSink.info);
diff --git a/src/compiler/DetectCallDepth.cpp b/src/compiler/DetectCallDepth.cpp
new file mode 100644
index 0000000..60df52c
--- /dev/null
+++ b/src/compiler/DetectCallDepth.cpp
@@ -0,0 +1,204 @@
+//
+// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+
+#include "compiler/DetectCallDepth.h"
+#include "compiler/InfoSink.h"
+
+#include <algorithm>
+
+DetectCallDepth::FunctionNode::FunctionNode(const TString& fname)
+    : name(fname),
+      visit(PreVisit),
+      height(0)
+{
+}
+
+const TString& DetectCallDepth::FunctionNode::getName() const
+{
+    return name;
+}
+
+void DetectCallDepth::FunctionNode::addCallee(
+    DetectCallDepth::FunctionNode* callee)
+{
+    for (size_t i = 0; i < callees.size(); ++i) {
+        if (callees[i] == callee)
+            return;
+    }
+    callees.push_back(callee);
+}
+
+int DetectCallDepth::FunctionNode::detectCallDepth(DetectCallDepth* detectCallDepth, int depth)
+{
+    ASSERT(visit == PreVisit);
+    ASSERT(detectCallDepth);
+
+    // Deepest frame reached so far; |depth| is this function's own frame.
+    int maxDepth = depth;
+    visit = InVisit;
+    for (size_t i = 0; i < callees.size(); ++i) {
+        FunctionNode* callee = callees[i];
+        int callDepth = depth;
+        switch (callee->visit) {
+            case InVisit:
+                // cycle detected, i.e., recursion detected.
+                return kInfiniteCallDepth;
+            case PostVisit:
+                // Already fully explored through another caller. Its call
+                // chain still adds to the depth of *this* path, so use its
+                // memoized height instead of skipping it; skipping would
+                // undercount whenever a function is reachable along call
+                // paths of different lengths, letting the limit be bypassed.
+                ASSERT(callee->height > 0);
+                callDepth = depth + callee->height;
+                break;
+            case PreVisit:
+                // Check before we recurse so we don't go too deep.
+                if (detectCallDepth->checkExceedsMaxDepth(depth))
+                    return depth;
+                callDepth = callee->detectCallDepth(detectCallDepth, depth + 1);
+                break;
+            default:
+                UNREACHABLE();
+                continue;
+        }
+        // Check after each callee so we can exit immediately and provide info.
+        if (detectCallDepth->checkExceedsMaxDepth(callDepth)) {
+            detectCallDepth->getInfoSink().info << "<-" << callee->getName();
+            return callDepth;
+        }
+        maxDepth = std::max(callDepth, maxDepth);
+    }
+    visit = PostVisit;
+    // Frames on the longest call chain that starts here, this one included.
+    height = maxDepth - depth + 1;
+    return maxDepth;
+}
+
+void DetectCallDepth::FunctionNode::reset()
+{
+    visit = PreVisit;
+    height = 0;
+}
+
+DetectCallDepth::DetectCallDepth(TInfoSink& infoSink, bool limitCallStackDepth, int maxCallStackDepth)
+    : TIntermTraverser(true, false, true, false),
+      currentFunction(NULL),
+      infoSink(infoSink),
+      maxDepth(FunctionNode::kInfiniteCallDepth)
+{
+    // Assigned here rather than with "limit ? maxCallStackDepth : kInfiniteCallDepth":
+    // a conditional with two lvalue operands may odr-use the static member, which
+    // has no out-of-line definition, and some toolchains then fail to link.
+    if (limitCallStackDepth)
+        maxDepth = maxCallStackDepth;
+}
+
+DetectCallDepth::~DetectCallDepth()
+{
+    for (size_t i = 0; i < functions.size(); ++i)
+        delete functions[i];
+}
+
+bool DetectCallDepth::visitAggregate(Visit visit, TIntermAggregate* node)
+{
+    switch (node->getOp())
+    {
+        case EOpPrototype:
+            // Function declaration.
+            // Don't add FunctionNode here because node->getName() is the
+            // unmangled function name.
+            break;
+        case EOpFunction: {
+            // Function definition.
+            if (visit == PreVisit) {
+                currentFunction = findFunctionByName(node->getName());
+                if (currentFunction == NULL) {
+                    currentFunction = new FunctionNode(node->getName());
+                    functions.push_back(currentFunction);
+                }
+            } else if (visit == PostVisit) {
+                currentFunction = NULL;
+            }
+            break;
+        }
+        case EOpFunctionCall: {
+            // Function call.
+            if (visit == PreVisit) {
+                FunctionNode* func = findFunctionByName(node->getName());
+                if (func == NULL) {
+                    func = new FunctionNode(node->getName());
+                    functions.push_back(func);
+                }
+                if (currentFunction)
+                    currentFunction->addCallee(func);
+            }
+            break;
+        }
+        default:
+            break;
+    }
+    return true;
+}
+
+bool DetectCallDepth::checkExceedsMaxDepth(int depth)
+{
+    return depth >= maxDepth;
+}
+
+void DetectCallDepth::resetFunctionNodes()
+{
+    for (size_t i = 0; i < functions.size(); ++i) {
+        functions[i]->reset();
+    }
+}
+
+DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepthForFunction(FunctionNode* func)
+{
+    currentFunction = NULL;
+    resetFunctionNodes();
+
+    int maxCallDepth = func->detectCallDepth(this, 1);
+
+    if (maxCallDepth == FunctionNode::kInfiniteCallDepth)
+        return kErrorRecursion;
+
+    if (maxCallDepth >= maxDepth)
+        return kErrorMaxDepthExceeded;
+
+    return kErrorNone;
+}
+
+DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepth()
+{
+    // main() is required whether or not the call stack depth is limited.
+    FunctionNode* main = findFunctionByName("main(");
+    if (main == NULL)
+        return kErrorMissingMain;
+
+    if (maxDepth == FunctionNode::kInfiniteCallDepth)
+        return detectCallDepthForFunction(main);
+
+    // Check all functions because the driver may fail on them
+    // TODO: Before detectingRecursion, strip unused functions.
+    for (size_t i = 0; i < functions.size(); ++i) {
+        ErrorCode error = detectCallDepthForFunction(functions[i]);
+        if (error != kErrorNone)
+            return error;
+    }
+    return kErrorNone;
+}
+
+DetectCallDepth::FunctionNode* DetectCallDepth::findFunctionByName(
+    const TString& name)
+{
+    for (size_t i = 0; i < functions.size(); ++i) {
+        if (functions[i]->getName() == name)
+            return functions[i];
+    }
+    return NULL;
+}
+
diff --git a/src/compiler/DetectCallDepth.h b/src/compiler/DetectCallDepth.h
new file mode 100644
index 0000000..89e85f8
--- /dev/null
+++ b/src/compiler/DetectCallDepth.h
@@ -0,0 +1,88 @@
+//
+// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+
+#ifndef COMPILER_DETECT_CALL_DEPTH_H_
+#define COMPILER_DETECT_CALL_DEPTH_H_
+
+#include "GLSLANG/ShaderLang.h"
+
+#include <limits.h>
+#include "compiler/intermediate.h"
+#include "compiler/VariableInfo.h"
+
+class TInfoSink;
+
+// Traverses intermediate tree to detect function recursion and, when
+// requested, call chains deeper than a maximum call stack depth.
+class DetectCallDepth : public TIntermTraverser {
+public:
+    enum ErrorCode {
+        kErrorMissingMain,
+        kErrorRecursion,
+        kErrorMaxDepthExceeded,
+        kErrorNone
+    };
+
+    DetectCallDepth(TInfoSink& infoSink, bool limitCallStackDepth, int maxCallStackDepth);
+    ~DetectCallDepth();
+
+    virtual bool visitAggregate(Visit, TIntermAggregate*);
+
+    bool checkExceedsMaxDepth(int depth);
+
+    ErrorCode detectCallDepth();
+
+private:
+    class FunctionNode {
+    public:
+        static const int kInfiniteCallDepth = INT_MAX;
+
+        FunctionNode(const TString& fname);
+
+        const TString& getName() const;
+
+        // If a function is already in the callee list, this becomes a no-op.
+        void addCallee(FunctionNode* callee);
+
+        // Returns the deepest call depth reachable from this function, where
+        // |depth| is the depth of this function's own frame. Returns early
+        // with a value at or above the limit once the limit is hit, and
+        // returns kInfiniteCallDepth if recursive function calls are detected.
+        int detectCallDepth(DetectCallDepth* detectCallDepth, int depth);
+
+        // Reset state.
+        void reset();
+
+    private:
+        // mangled function name is unique.
+        TString name;
+
+        // functions that are directly called by this function.
+        TVector<FunctionNode*> callees;
+
+        Visit visit;
+
+        // Number of frames on the longest call chain starting at this
+        // function, itself included. Valid only once visit == PostVisit.
+        int height;
+    };
+
+    ErrorCode detectCallDepthForFunction(FunctionNode* func);
+    FunctionNode* findFunctionByName(const TString& name);
+    void resetFunctionNodes();
+
+    TInfoSink& getInfoSink() { return infoSink; }
+
+    TVector<FunctionNode*> functions;
+    FunctionNode* currentFunction;
+    TInfoSink& infoSink;
+    int maxDepth;
+
+    DetectCallDepth(const DetectCallDepth&);
+    void operator=(const DetectCallDepth&);
+};
+
+#endif  // COMPILER_DETECT_CALL_DEPTH_H_
diff --git a/src/compiler/DetectRecursion.cpp b/src/compiler/DetectRecursion.cpp
deleted file mode 100644
index c09780d..0000000
--- a/src/compiler/DetectRecursion.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-//
-// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-//
-
-#include "compiler/DetectRecursion.h"
-
-DetectRecursion::FunctionNode::FunctionNode(const TString& fname)
-    : name(fname),
-      visit(PreVisit)
-{
-}
-
-const TString& DetectRecursion::FunctionNode::getName() const
-{
-    return name;
-}
-
-void DetectRecursion::FunctionNode::addCallee(
-    DetectRecursion::FunctionNode* callee)
-{
-    for (size_t i = 0; i < callees.size(); ++i) {
-        if (callees[i] == callee)
-            return;
-    }
-    callees.push_back(callee);
-}
-
-bool DetectRecursion::FunctionNode::detectRecursion()
-{
-    ASSERT(visit == PreVisit);
-    visit = InVisit;
-    for (size_t i = 0; i < callees.size(); ++i) {
-        switch (callees[i]->visit) {
-            case InVisit:
-                // cycle detected, i.e., recursion detected.
-                return true;
-            case PostVisit:
-                break;
-            case PreVisit: {
-                bool recursion = callees[i]->detectRecursion();
-                if (recursion)
-                    return true;
-                break;
-            }
-            default:
-                UNREACHABLE();
-                break;
-        }
-    }
-    visit = PostVisit;
-    return false;
-}
-
-DetectRecursion::DetectRecursion()
-    : currentFunction(NULL)
-{
-}
-
-DetectRecursion::~DetectRecursion()
-{
-    for (size_t i = 0; i < functions.size(); ++i)
-        delete functions[i];
-}
-
-bool DetectRecursion::visitAggregate(Visit visit, TIntermAggregate* node)
-{
-    switch (node->getOp())
-    {
-        case EOpPrototype:
-            // Function declaration.
-            // Don't add FunctionNode here because node->getName() is the
-            // unmangled function name.
-            break;
-        case EOpFunction: {
-            // Function definition.
-            if (visit == PreVisit) {
-                currentFunction = findFunctionByName(node->getName());
-                if (currentFunction == NULL) {
-                    currentFunction = new FunctionNode(node->getName());
-                    functions.push_back(currentFunction);
-                }
-            }
-            break;
-        }
-        case EOpFunctionCall: {
-            // Function call.
-            if (visit == PreVisit) {
-                ASSERT(currentFunction != NULL);
-                FunctionNode* func = findFunctionByName(node->getName());
-                if (func == NULL) {
-                    func = new FunctionNode(node->getName());
-                    functions.push_back(func);
-                }
-                currentFunction->addCallee(func);
-            }
-            break;
-        }
-        default:
-            break;
-    }
-    return true;
-}
-
-DetectRecursion::ErrorCode DetectRecursion::detectRecursion()
-{
-    FunctionNode* main = findFunctionByName("main(");
-    if (main == NULL)
-        return kErrorMissingMain;
-    if (main->detectRecursion())
-        return kErrorRecursion;
-    return kErrorNone;
-}
-
-DetectRecursion::FunctionNode* DetectRecursion::findFunctionByName(
-    const TString& name)
-{
-    for (size_t i = 0; i < functions.size(); ++i) {
-        if (functions[i]->getName() == name)
-            return functions[i];
-    }
-    return NULL;
-}
-
diff --git a/src/compiler/DetectRecursion.h b/src/compiler/DetectRecursion.h
deleted file mode 100644
index bbac79d..0000000
--- a/src/compiler/DetectRecursion.h
+++ /dev/null
@@ -1,60 +0,0 @@
-//
-// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-//
-
-#ifndef COMPILER_DETECT_RECURSION_H_
-#define COMPILER_DETECT_RECURSION_H_
-
-#include "GLSLANG/ShaderLang.h"
-
-#include "compiler/intermediate.h"
-#include "compiler/VariableInfo.h"
-
-// Traverses intermediate tree to detect function recursion.
-class DetectRecursion : public TIntermTraverser {
-public:
-    enum ErrorCode {
-        kErrorMissingMain,
-        kErrorRecursion,
-        kErrorNone
-    };
-
-    DetectRecursion();
-    ~DetectRecursion();
-
-    virtual bool visitAggregate(Visit, TIntermAggregate*);
-
-    ErrorCode detectRecursion();
-
-private:
-    class FunctionNode {
-    public:
-        FunctionNode(const TString& fname);
-
-        const TString& getName() const;
-
-        // If a function is already in the callee list, this becomes a no-op.
-        void addCallee(FunctionNode* callee);
-
-        // Return true if recursive function calls are detected.
-        bool detectRecursion();
-
-    private:
-        // mangled function name is unique.
-        TString name;
-
-        // functions that are directly called by this function.
-        TVector<FunctionNode*> callees;
-
-        Visit visit;
-    };
-
-    FunctionNode* findFunctionByName(const TString& name);
-
-    TVector<FunctionNode*> functions;
-    FunctionNode* currentFunction;
-};
-
-#endif  // COMPILER_DETECT_RECURSION_H_
diff --git a/src/compiler/ShHandle.h b/src/compiler/ShHandle.h
index 5f8d9d0..b523a77 100644
--- a/src/compiler/ShHandle.h
+++ b/src/compiler/ShHandle.h
@@ -84,8 +84,8 @@
     bool InitBuiltInSymbolTable(const ShBuiltInResources& resources);
     // Clears the results from the previous compilation.
     void clearResults();
-    // Return true if function recursion is detected.
-    bool detectRecursion(TIntermNode* root);
+    // Returns false if main() is missing, recursion is found, or the call depth exceeds the limit.
+    bool detectCallDepth(TIntermNode* root, TInfoSink& infoSink, bool limitCallStackDepth);
     // Returns true if a program has no conflicting or missing fragment outputs
     bool validateOutputs(TIntermNode* root);
     // Rewrites a shader's intermediate tree according to the CSS Shaders spec.
@@ -109,6 +109,8 @@
     // Returns true if the shader does not use sampler dependent values to affect control 
     // flow or in operations whose time can depend on the input values.
     bool enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph);
+    // Returns true if the maximum expression complexity does not exceed the limit.
+    bool limitExpressionComplexity(TIntermNode* root);
     // Get built-in extensions with default behavior.
     const TExtensionBehavior& getExtensionBehavior() const;
     // Get the resources set by InitBuiltInSymbolTable
@@ -123,6 +125,8 @@
     ShShaderSpec shaderSpec;
 
     int maxUniformVectors;
+    int maxExpressionComplexity;
+    int maxCallStackDepth;
 
     ShBuiltInResources compileResources;
 
diff --git a/src/compiler/intermediate.h b/src/compiler/intermediate.h
index a8a0e31..d7f4ace 100644
--- a/src/compiler/intermediate.h
+++ b/src/compiler/intermediate.h
@@ -18,6 +18,7 @@
 
 #include "GLSLANG/ShaderLang.h"
 
+#include <algorithm>
 #include "compiler/Common.h"
 #include "compiler/Types.h"
 #include "compiler/ConstantUnion.h"
@@ -562,7 +563,8 @@
             inVisit(inVisit),
             postVisit(postVisit),
             rightToLeft(rightToLeft),
-            depth(0) {}
+            depth(0),
+            maxDepth(0) {}
     virtual ~TIntermTraverser() {};
 
     virtual void visitSymbol(TIntermSymbol*) {}
@@ -574,7 +576,8 @@
     virtual bool visitLoop(Visit visit, TIntermLoop*) {return true;}
     virtual bool visitBranch(Visit visit, TIntermBranch*) {return true;}
 
-    void incrementDepth() {depth++;}
+    int getMaxDepth() const {return maxDepth;}
+    void incrementDepth() {depth++; maxDepth = std::max(maxDepth, depth); }
     void decrementDepth() {depth--;}
 
     // Return the original name if hash function pointer is NULL;
@@ -588,6 +591,7 @@
 
 protected:
     int depth;
+    int maxDepth;
 };
 
 #endif // __INTERMEDIATE_H
diff --git a/src/compiler/translator_common.vcxproj b/src/compiler/translator_common.vcxproj
index 2a318e2..ffa9cfd 100644
--- a/src/compiler/translator_common.vcxproj
+++ b/src/compiler/translator_common.vcxproj
@@ -143,7 +143,7 @@
     <ClCompile Include="BuiltInFunctionEmulator.cpp" />

     <ClCompile Include="Compiler.cpp" />

     <ClCompile Include="debug.cpp" />

-    <ClCompile Include="DetectRecursion.cpp" />

+    <ClCompile Include="DetectCallDepth.cpp" />

     <ClCompile Include="Diagnostics.cpp" />

     <ClCompile Include="DirectiveHandler.cpp" />

     <ClCompile Include="ForLoopUnroll.cpp" />

@@ -235,7 +235,7 @@
     <ClInclude Include="Common.h" />

     <ClInclude Include="ConstantUnion.h" />

     <ClInclude Include="debug.h" />

-    <ClInclude Include="DetectRecursion.h" />

+    <ClInclude Include="DetectCallDepth.h" />

     <ClInclude Include="Diagnostics.h" />

     <ClInclude Include="DirectiveHandler.h" />

     <ClInclude Include="ForLoopUnroll.h" />

diff --git a/src/compiler/translator_common.vcxproj.filters b/src/compiler/translator_common.vcxproj.filters
index cab4fa4..0e05db6 100644
--- a/src/compiler/translator_common.vcxproj.filters
+++ b/src/compiler/translator_common.vcxproj.filters
@@ -38,7 +38,7 @@
     <ClCompile Include="debug.cpp">

       <Filter>Source Files</Filter>

     </ClCompile>

-    <ClCompile Include="DetectRecursion.cpp">

+    <ClCompile Include="DetectCallDepth.cpp">

       <Filter>Source Files</Filter>

     </ClCompile>

     <ClCompile Include="Diagnostics.cpp">

@@ -163,7 +163,7 @@
     <ClInclude Include="debug.h">

       <Filter>Header Files</Filter>

     </ClInclude>

-    <ClInclude Include="DetectRecursion.h">

+    <ClInclude Include="DetectCallDepth.h">

       <Filter>Header Files</Filter>

     </ClInclude>

     <ClInclude Include="Diagnostics.h">