Add expression complexity and call stack depth limits.
git-svn-id: https://angleproject.googlecode.com/svn/trunk@2242 736b8ea6-26fd-11df-bfd4-992fa37f6226
TRAC #23333
Authored-by: gman@chromium.org
Signed-off-by: Shannon Woods
Signed-off-by: Nicolas Capens
Merged-by: Jamie Madill
Conflicts:
src/common/version.h
diff --git a/src/compiler/Compiler.cpp b/src/compiler/Compiler.cpp
index d39d921..c7c792c 100644
--- a/src/compiler/Compiler.cpp
+++ b/src/compiler/Compiler.cpp
@@ -5,7 +5,7 @@
//
#include "compiler/BuiltInFunctionEmulator.h"
-#include "compiler/DetectRecursion.h"
+#include "compiler/DetectCallDepth.h"
#include "compiler/ForLoopUnroll.h"
#include "compiler/Initialize.h"
#include "compiler/InitializeParseContext.h"
@@ -59,6 +59,9 @@
TCompiler::TCompiler(ShShaderType type, ShShaderSpec spec)
: shaderType(type),
shaderSpec(spec),
+ maxUniformVectors(0),
+ maxExpressionComplexity(0),
+ maxCallStackDepth(0),
fragmentPrecisionHigh(false),
clampingStrategy(SH_CLAMP_WITH_CLAMP_INTRINSIC),
builtInFunctionEmulator(type)
@@ -78,6 +81,8 @@
maxUniformVectors = (shaderType == SH_VERTEX_SHADER) ?
resources.MaxVertexUniformVectors :
resources.MaxFragmentUniformVectors;
+ maxExpressionComplexity = resources.MaxExpressionComplexity;
+ maxCallStackDepth = resources.MaxCallStackDepth;
TScopedPoolAllocator scopedAlloc(&allocator, false);
// Generate built-in symbol table.
@@ -144,7 +149,7 @@
success = intermediate.postProcess(root);
if (success)
- success = detectRecursion(root);
+ success = detectCallDepth(root, infoSink, (compileOptions & SH_LIMIT_CALL_STACK_DEPTH) != 0);
if (success && shaderVersion == 300 && shaderType == SH_FRAGMENT_SHADER)
success = validateOutputs(root);
@@ -170,6 +175,10 @@
if (success && (compileOptions & SH_CLAMP_INDIRECT_ARRAY_BOUNDS))
arrayBoundsClamper.MarkIndirectArrayBoundsForClamping(root);
+ // Disallow expressions deemed too complex.
+ if (success && (compileOptions & SH_LIMIT_EXPRESSION_COMPLEXITY))
+ success = limitExpressionComplexity(root);
+
// Call mapLongVariableNames() before collectAttribsUniforms() so in
// collectAttribsUniforms() we already have the mapped symbol names and
// we could composite mapped and original variable names.
@@ -260,21 +269,25 @@
nameMap.clear();
}
-bool TCompiler::detectRecursion(TIntermNode* root)
+bool TCompiler::detectCallDepth(TIntermNode* root, TInfoSink& infoSink, bool limitCallStackDepth)
{
- DetectRecursion detect;
+ DetectCallDepth detect(infoSink, limitCallStackDepth, maxCallStackDepth);
root->traverse(&detect);
- switch (detect.detectRecursion()) {
- case DetectRecursion::kErrorNone:
+ switch (detect.detectCallDepth()) {
+ case DetectCallDepth::kErrorNone:
return true;
- case DetectRecursion::kErrorMissingMain:
+ case DetectCallDepth::kErrorMissingMain:
infoSink.info.prefix(EPrefixError);
infoSink.info << "Missing main()";
return false;
- case DetectRecursion::kErrorRecursion:
+ case DetectCallDepth::kErrorRecursion:
infoSink.info.prefix(EPrefixError);
infoSink.info << "Function recursion detected";
return false;
+ case DetectCallDepth::kErrorMaxDepthExceeded:
+ infoSink.info.prefix(EPrefixError);
+ infoSink.info << "Function call stack too deep";
+ return false;
default:
UNREACHABLE();
return false;
@@ -326,6 +339,38 @@
}
}
+// Returns false (after logging an error) if the deepest nesting reached
+// while traversing |root| exceeds maxExpressionComplexity; true otherwise.
+bool TCompiler::limitExpressionComplexity(TIntermNode* root)
+{
+    // TIntermTraverser::incrementDepth() records the deepest level reached.
+    TIntermTraverser traverser;
+    root->traverse(&traverser);
+
+    // Reject before building the dependency graph, which presumably walks
+    // the same overly deep tree again.
+    if (traverser.getMaxDepth() > maxExpressionComplexity) {
+        infoSink.info.prefix(EPrefixError);
+        infoSink.info << "Expression too complex.";
+        return false;
+    }
+
+    TDependencyGraph graph(root);
+
+    // NOTE(review): the traversal results are not read below; confirm
+    // whether this loop is still needed.
+    for (TFunctionCallVector::const_iterator iter = graph.beginUserDefinedFunctionCalls();
+         iter != graph.endUserDefinedFunctionCalls();
+         ++iter)
+    {
+        TGraphFunctionCall* functionCall = *iter;
+        TDependencyGraphTraverser graphTraverser;
+        functionCall->traverse(&graphTraverser);
+    }
+
+    return true;
+}
+
bool TCompiler::enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph)
{
RestrictFragmentShaderTiming restrictor(infoSink.info);
diff --git a/src/compiler/DetectCallDepth.cpp b/src/compiler/DetectCallDepth.cpp
new file mode 100644
index 0000000..60df52c
--- /dev/null
+++ b/src/compiler/DetectCallDepth.cpp
@@ -0,0 +1,204 @@
+//
+// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+
+#include "compiler/DetectCallDepth.h"
+#include "compiler/InfoSink.h"
+
+DetectCallDepth::FunctionNode::FunctionNode(const TString& fname)
+    : name(fname),
+      visit(PreVisit),
+      height(0)
+{
+}
+
+const TString& DetectCallDepth::FunctionNode::getName() const
+{
+    return name;
+}
+
+void DetectCallDepth::FunctionNode::addCallee(
+    DetectCallDepth::FunctionNode* callee)
+{
+    for (size_t i = 0; i < callees.size(); ++i) {
+        if (callees[i] == callee)
+            return;
+    }
+    callees.push_back(callee);
+}
+
+// Returns the deepest call stack depth reachable from this function, where
+// this function itself sits at |depth| (the root function is at depth 1).
+// Returns kInfiniteCallDepth if a cycle (recursion) is found. If the depth
+// limit is hit, "<-callee" entries for the offending chain are appended to
+// the info sink (innermost first) and the offending depth is returned.
+int DetectCallDepth::FunctionNode::detectCallDepth(DetectCallDepth* detectCallDepth, int depth)
+{
+    ASSERT(visit == PreVisit);
+    ASSERT(detectCallDepth);
+
+    int maxDepth = depth;
+    visit = InVisit;
+    for (size_t i = 0; i < callees.size(); ++i) {
+        FunctionNode* callee = callees[i];
+        int callDepth = depth;
+        switch (callee->visit) {
+            case InVisit:
+                // cycle detected, i.e., recursion detected.
+                return kInfiniteCallDepth;
+            case PostVisit:
+                // The callee's call tree was already fully explored via
+                // another caller. Reaching it from here may still produce a
+                // deeper stack than the path it was first found on, so count
+                // it using its memoized height instead of skipping it.
+                callDepth = depth + callee->height;
+                break;
+            case PreVisit:
+                // Check before we recurse so we don't go too deep.
+                if (detectCallDepth->checkExceedsMaxDepth(depth))
+                    return depth;
+                callDepth = callee->detectCallDepth(detectCallDepth, depth + 1);
+                // Propagate recursion untouched; the "<-" trail below only
+                // describes an exceeded depth limit.
+                if (callDepth == kInfiniteCallDepth)
+                    return kInfiniteCallDepth;
+                break;
+            default:
+                UNREACHABLE();
+                break;
+        }
+        // Check after visiting the callee so we can exit immediately and provide info.
+        if (detectCallDepth->checkExceedsMaxDepth(callDepth)) {
+            detectCallDepth->getInfoSink().info << "<-" << callee->getName();
+            return callDepth;
+        }
+        maxDepth = std::max(callDepth, maxDepth);
+    }
+    visit = PostVisit;
+    // Frames from this function (inclusive) down to its deepest callee.
+    height = maxDepth - depth + 1;
+    return maxDepth;
+}
+
+void DetectCallDepth::FunctionNode::reset()
+{
+    visit = PreVisit;
+    height = 0;
+}
+
+DetectCallDepth::DetectCallDepth(TInfoSink& infoSink, bool limitCallStackDepth, int maxCallStackDepth)
+    : TIntermTraverser(true, false, true, false),
+      currentFunction(NULL),
+      infoSink(infoSink),
+      maxDepth(limitCallStackDepth ? maxCallStackDepth : FunctionNode::kInfiniteCallDepth)
+{
+}
+
+DetectCallDepth::~DetectCallDepth()
+{
+    for (size_t i = 0; i < functions.size(); ++i)
+        delete functions[i];
+}
+
+bool DetectCallDepth::visitAggregate(Visit visit, TIntermAggregate* node)
+{
+    switch (node->getOp())
+    {
+    case EOpPrototype:
+        // Function declaration.
+        // Don't add FunctionNode here because node->getName() is the
+        // unmangled function name.
+        break;
+    case EOpFunction: {
+        // Function definition.
+        if (visit == PreVisit) {
+            currentFunction = findFunctionByName(node->getName());
+            if (currentFunction == NULL) {
+                currentFunction = new FunctionNode(node->getName());
+                functions.push_back(currentFunction);
+            }
+        } else if (visit == PostVisit) {
+            currentFunction = NULL;
+        }
+        break;
+    }
+    case EOpFunctionCall: {
+        // Function call.
+        if (visit == PreVisit) {
+            FunctionNode* func = findFunctionByName(node->getName());
+            if (func == NULL) {
+                func = new FunctionNode(node->getName());
+                functions.push_back(func);
+            }
+            if (currentFunction)
+                currentFunction->addCallee(func);
+        }
+        break;
+    }
+    default:
+        break;
+    }
+    return true;
+}
+
+bool DetectCallDepth::checkExceedsMaxDepth(int depth)
+{
+    return depth >= maxDepth;
+}
+
+void DetectCallDepth::resetFunctionNodes()
+{
+    for (size_t i = 0; i < functions.size(); ++i) {
+        functions[i]->reset();
+    }
+}
+
+DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepthForFunction(FunctionNode* func)
+{
+    currentFunction = NULL;
+    resetFunctionNodes();
+
+    int maxCallDepth = func->detectCallDepth(this, 1);
+
+    if (maxCallDepth == FunctionNode::kInfiniteCallDepth)
+        return kErrorRecursion;
+
+    if (maxCallDepth >= maxDepth)
+        return kErrorMaxDepthExceeded;
+
+    return kErrorNone;
+}
+
+DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepth()
+{
+    if (maxDepth != FunctionNode::kInfiniteCallDepth) {
+        // Check all functions because the driver may fail on them
+        // TODO: Before detectingRecursion, strip unused functions.
+        for (size_t i = 0; i < functions.size(); ++i) {
+            ErrorCode error = detectCallDepthForFunction(functions[i]);
+            if (error != kErrorNone)
+                return error;
+        }
+    } else {
+        FunctionNode* main = findFunctionByName("main(");
+        if (main == NULL)
+            return kErrorMissingMain;
+
+        return detectCallDepthForFunction(main);
+    }
+
+    return kErrorNone;
+}
+
+DetectCallDepth::FunctionNode* DetectCallDepth::findFunctionByName(
+    const TString& name)
+{
+    for (size_t i = 0; i < functions.size(); ++i) {
+        if (functions[i]->getName() == name)
+            return functions[i];
+    }
+    return NULL;
+}
+
diff --git a/src/compiler/DetectCallDepth.h b/src/compiler/DetectCallDepth.h
new file mode 100644
index 0000000..89e85f8
--- /dev/null
+++ b/src/compiler/DetectCallDepth.h
@@ -0,0 +1,92 @@
+//
+// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+
+#ifndef COMPILER_DETECT_CALL_DEPTH_H_
+#define COMPILER_DETECT_CALL_DEPTH_H_
+
+#include "GLSLANG/ShaderLang.h"
+
+#include <limits.h>
+#include "compiler/intermediate.h"
+#include "compiler/VariableInfo.h"
+
+class TInfoSink;
+
+// Traverses the intermediate tree to build a call graph of user-defined
+// functions, then detects function recursion and, optionally, call chains
+// that reach a maximum call stack depth.
+class DetectCallDepth : public TIntermTraverser {
+public:
+    enum ErrorCode {
+        kErrorMissingMain,
+        kErrorRecursion,
+        kErrorMaxDepthExceeded,
+        kErrorNone
+    };
+
+    // With limitCallStackDepth, every function is checked against
+    // maxCallStackDepth; without it, only the call graph reachable from
+    // main() is checked, and only for recursion.
+    DetectCallDepth(TInfoSink& infoSink, bool limitCallStackDepth, int maxCallStackDepth);
+    ~DetectCallDepth();
+
+    virtual bool visitAggregate(Visit, TIntermAggregate*);
+
+    // Returns true if |depth| is at or beyond the maximum allowed call depth.
+    bool checkExceedsMaxDepth(int depth);
+
+    ErrorCode detectCallDepth();
+
+private:
+    class FunctionNode {
+    public:
+        static const int kInfiniteCallDepth = INT_MAX;
+
+        FunctionNode(const TString& fname);
+
+        const TString& getName() const;
+
+        // If a function is already in the callee list, this becomes a no-op.
+        void addCallee(FunctionNode* callee);
+
+        // Returns the deepest call depth reached from this function when it
+        // sits at |depth|, or kInfiniteCallDepth if recursive function calls
+        // are detected.
+        int detectCallDepth(DetectCallDepth* detectCallDepth, int depth);
+
+        // Reset state.
+        void reset();
+
+    private:
+        // mangled function name is unique.
+        TString name;
+
+        // functions that are directly called by this function.
+        TVector<FunctionNode*> callees;
+
+        Visit visit;
+
+        // Memoized number of stack frames from this function (inclusive)
+        // down to its deepest callee; only meaningful once visit == PostVisit.
+        int height;
+    };
+
+    ErrorCode detectCallDepthForFunction(FunctionNode* func);
+    FunctionNode* findFunctionByName(const TString& name);
+    void resetFunctionNodes();
+
+    TInfoSink& getInfoSink() { return infoSink; }
+
+    TVector<FunctionNode*> functions;
+    FunctionNode* currentFunction;
+    TInfoSink& infoSink;
+    int maxDepth;
+
+    DetectCallDepth(const DetectCallDepth&);
+    void operator=(const DetectCallDepth&);
+};
+
+#endif // COMPILER_DETECT_CALL_DEPTH_H_
diff --git a/src/compiler/DetectRecursion.cpp b/src/compiler/DetectRecursion.cpp
deleted file mode 100644
index c09780d..0000000
--- a/src/compiler/DetectRecursion.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-//
-// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-//
-
-#include "compiler/DetectRecursion.h"
-
-DetectRecursion::FunctionNode::FunctionNode(const TString& fname)
- : name(fname),
- visit(PreVisit)
-{
-}
-
-const TString& DetectRecursion::FunctionNode::getName() const
-{
- return name;
-}
-
-void DetectRecursion::FunctionNode::addCallee(
- DetectRecursion::FunctionNode* callee)
-{
- for (size_t i = 0; i < callees.size(); ++i) {
- if (callees[i] == callee)
- return;
- }
- callees.push_back(callee);
-}
-
-bool DetectRecursion::FunctionNode::detectRecursion()
-{
- ASSERT(visit == PreVisit);
- visit = InVisit;
- for (size_t i = 0; i < callees.size(); ++i) {
- switch (callees[i]->visit) {
- case InVisit:
- // cycle detected, i.e., recursion detected.
- return true;
- case PostVisit:
- break;
- case PreVisit: {
- bool recursion = callees[i]->detectRecursion();
- if (recursion)
- return true;
- break;
- }
- default:
- UNREACHABLE();
- break;
- }
- }
- visit = PostVisit;
- return false;
-}
-
-DetectRecursion::DetectRecursion()
- : currentFunction(NULL)
-{
-}
-
-DetectRecursion::~DetectRecursion()
-{
- for (size_t i = 0; i < functions.size(); ++i)
- delete functions[i];
-}
-
-bool DetectRecursion::visitAggregate(Visit visit, TIntermAggregate* node)
-{
- switch (node->getOp())
- {
- case EOpPrototype:
- // Function declaration.
- // Don't add FunctionNode here because node->getName() is the
- // unmangled function name.
- break;
- case EOpFunction: {
- // Function definition.
- if (visit == PreVisit) {
- currentFunction = findFunctionByName(node->getName());
- if (currentFunction == NULL) {
- currentFunction = new FunctionNode(node->getName());
- functions.push_back(currentFunction);
- }
- }
- break;
- }
- case EOpFunctionCall: {
- // Function call.
- if (visit == PreVisit) {
- ASSERT(currentFunction != NULL);
- FunctionNode* func = findFunctionByName(node->getName());
- if (func == NULL) {
- func = new FunctionNode(node->getName());
- functions.push_back(func);
- }
- currentFunction->addCallee(func);
- }
- break;
- }
- default:
- break;
- }
- return true;
-}
-
-DetectRecursion::ErrorCode DetectRecursion::detectRecursion()
-{
- FunctionNode* main = findFunctionByName("main(");
- if (main == NULL)
- return kErrorMissingMain;
- if (main->detectRecursion())
- return kErrorRecursion;
- return kErrorNone;
-}
-
-DetectRecursion::FunctionNode* DetectRecursion::findFunctionByName(
- const TString& name)
-{
- for (size_t i = 0; i < functions.size(); ++i) {
- if (functions[i]->getName() == name)
- return functions[i];
- }
- return NULL;
-}
-
diff --git a/src/compiler/DetectRecursion.h b/src/compiler/DetectRecursion.h
deleted file mode 100644
index bbac79d..0000000
--- a/src/compiler/DetectRecursion.h
+++ /dev/null
@@ -1,60 +0,0 @@
-//
-// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-//
-
-#ifndef COMPILER_DETECT_RECURSION_H_
-#define COMPILER_DETECT_RECURSION_H_
-
-#include "GLSLANG/ShaderLang.h"
-
-#include "compiler/intermediate.h"
-#include "compiler/VariableInfo.h"
-
-// Traverses intermediate tree to detect function recursion.
-class DetectRecursion : public TIntermTraverser {
-public:
- enum ErrorCode {
- kErrorMissingMain,
- kErrorRecursion,
- kErrorNone
- };
-
- DetectRecursion();
- ~DetectRecursion();
-
- virtual bool visitAggregate(Visit, TIntermAggregate*);
-
- ErrorCode detectRecursion();
-
-private:
- class FunctionNode {
- public:
- FunctionNode(const TString& fname);
-
- const TString& getName() const;
-
- // If a function is already in the callee list, this becomes a no-op.
- void addCallee(FunctionNode* callee);
-
- // Return true if recursive function calls are detected.
- bool detectRecursion();
-
- private:
- // mangled function name is unique.
- TString name;
-
- // functions that are directly called by this function.
- TVector<FunctionNode*> callees;
-
- Visit visit;
- };
-
- FunctionNode* findFunctionByName(const TString& name);
-
- TVector<FunctionNode*> functions;
- FunctionNode* currentFunction;
-};
-
-#endif // COMPILER_DETECT_RECURSION_H_
diff --git a/src/compiler/ShHandle.h b/src/compiler/ShHandle.h
index 5f8d9d0..b523a77 100644
--- a/src/compiler/ShHandle.h
+++ b/src/compiler/ShHandle.h
@@ -84,8 +84,8 @@
bool InitBuiltInSymbolTable(const ShBuiltInResources& resources);
// Clears the results from the previous compilation.
void clearResults();
- // Return true if function recursion is detected.
- bool detectRecursion(TIntermNode* root);
+ // Returns false and logs an error on recursion, missing main(), or too-deep calls.
+ bool detectCallDepth(TIntermNode* root, TInfoSink& infoSink, bool limitCallStackDepth);
// Returns true if a program has no conflicting or missing fragment outputs
bool validateOutputs(TIntermNode* root);
// Rewrites a shader's intermediate tree according to the CSS Shaders spec.
@@ -109,6 +109,8 @@
// Returns true if the shader does not use sampler dependent values to affect control
// flow or in operations whose time can depend on the input values.
bool enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph);
+ // Returns false and logs an error if the tree's maximum depth exceeds maxExpressionComplexity.
+ bool limitExpressionComplexity(TIntermNode* root);
// Get built-in extensions with default behavior.
const TExtensionBehavior& getExtensionBehavior() const;
// Get the resources set by InitBuiltInSymbolTable
@@ -123,6 +125,8 @@
ShShaderSpec shaderSpec;
int maxUniformVectors;
+ int maxExpressionComplexity;
+ int maxCallStackDepth;
ShBuiltInResources compileResources;
diff --git a/src/compiler/intermediate.h b/src/compiler/intermediate.h
index a8a0e31..d7f4ace 100644
--- a/src/compiler/intermediate.h
+++ b/src/compiler/intermediate.h
@@ -18,6 +18,7 @@
#include "GLSLANG/ShaderLang.h"
+#include <algorithm>
#include "compiler/Common.h"
#include "compiler/Types.h"
#include "compiler/ConstantUnion.h"
@@ -562,7 +563,8 @@
inVisit(inVisit),
postVisit(postVisit),
rightToLeft(rightToLeft),
- depth(0) {}
+ depth(0),
+ maxDepth(0) {}
virtual ~TIntermTraverser() {};
virtual void visitSymbol(TIntermSymbol*) {}
@@ -574,7 +576,8 @@
virtual bool visitLoop(Visit visit, TIntermLoop*) {return true;}
virtual bool visitBranch(Visit visit, TIntermBranch*) {return true;}
- void incrementDepth() {depth++;}
+ int getMaxDepth() const {return maxDepth;}
+ void incrementDepth() {depth++; maxDepth = std::max(maxDepth, depth); }
void decrementDepth() {depth--;}
// Return the original name if hash function pointer is NULL;
@@ -588,6 +591,7 @@
protected:
int depth;
+ int maxDepth;
};
#endif // __INTERMEDIATE_H
diff --git a/src/compiler/translator_common.vcxproj b/src/compiler/translator_common.vcxproj
index 2a318e2..ffa9cfd 100644
--- a/src/compiler/translator_common.vcxproj
+++ b/src/compiler/translator_common.vcxproj
@@ -143,7 +143,7 @@
<ClCompile Include="BuiltInFunctionEmulator.cpp" />
<ClCompile Include="Compiler.cpp" />
<ClCompile Include="debug.cpp" />
- <ClCompile Include="DetectRecursion.cpp" />
+ <ClCompile Include="DetectCallDepth.cpp" />
<ClCompile Include="Diagnostics.cpp" />
<ClCompile Include="DirectiveHandler.cpp" />
<ClCompile Include="ForLoopUnroll.cpp" />
@@ -235,7 +235,7 @@
<ClInclude Include="Common.h" />
<ClInclude Include="ConstantUnion.h" />
<ClInclude Include="debug.h" />
- <ClInclude Include="DetectRecursion.h" />
+ <ClInclude Include="DetectCallDepth.h" />
<ClInclude Include="Diagnostics.h" />
<ClInclude Include="DirectiveHandler.h" />
<ClInclude Include="ForLoopUnroll.h" />
diff --git a/src/compiler/translator_common.vcxproj.filters b/src/compiler/translator_common.vcxproj.filters
index cab4fa4..0e05db6 100644
--- a/src/compiler/translator_common.vcxproj.filters
+++ b/src/compiler/translator_common.vcxproj.filters
@@ -38,7 +38,7 @@
<ClCompile Include="debug.cpp">
<Filter>Source Files</Filter>
</ClCompile>
- <ClCompile Include="DetectRecursion.cpp">
+ <ClCompile Include="DetectCallDepth.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="Diagnostics.cpp">
@@ -163,7 +163,7 @@
<ClInclude Include="debug.h">
<Filter>Header Files</Filter>
</ClInclude>
- <ClInclude Include="DetectRecursion.h">
+ <ClInclude Include="DetectCallDepth.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="Diagnostics.h">