Add an option to unfold short circuiting in AST.
We replace "a || b" with "a ? true : b" and
"a && b" with "a ? b : false".
This works around a short-circuiting bug in Mac drivers.
ANGLEBUG=482
TEST=webgl conformance tests
R=alokp@chromium.org, kbr@chromium.org
Review URL: https://codereview.appspot.com/14529048
Conflicts:
src/build_angle.gypi
src/compiler/translator/Compiler.cpp
Change-Id: Ic2384a97d58f54294efcb3a012deb2007a9fc658
Reviewed-on: https://chromium-review.googlesource.com/178996
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Commit-Queue: Jamie Madill <jmadill@chromium.org>
Tested-by: Shannon Woods <shannonwoods@chromium.org>
diff --git a/src/compiler/translator/Compiler.cpp b/src/compiler/translator/Compiler.cpp
index 80c97ea..909f35b 100644
--- a/src/compiler/translator/Compiler.cpp
+++ b/src/compiler/translator/Compiler.cpp
@@ -14,6 +14,7 @@
#include "compiler/translator/ParseContext.h"
#include "compiler/translator/RenameFunction.h"
#include "compiler/translator/ShHandle.h"
+#include "compiler/translator/UnfoldShortCircuitAST.h"
#include "compiler/translator/ValidateLimitations.h"
#include "compiler/translator/ValidateOutputs.h"
#include "compiler/translator/VariablePacker.h"
@@ -202,6 +203,12 @@
root->traverse(&initGLPosition);
}
+ if (success && (compileOptions & SH_UNFOLD_SHORT_CIRCUIT)) {
+ UnfoldShortCircuitAST unfoldShortCircuit;
+ root->traverse(&unfoldShortCircuit);
+ unfoldShortCircuit.updateTree();
+ }
+
if (success && (compileOptions & SH_VARIABLES)) {
collectVariables(root);
if (compileOptions & SH_ENFORCE_PACKING_RESTRICTIONS) {
diff --git a/src/compiler/translator/IntermTraverse.cpp b/src/compiler/translator/IntermTraverse.cpp
index 0e345d2..a579488 100644
--- a/src/compiler/translator/IntermTraverse.cpp
+++ b/src/compiler/translator/IntermTraverse.cpp
@@ -51,7 +51,7 @@
//
if (visit)
{
- it->incrementDepth();
+ it->incrementDepth(this);
if (it->rightToLeft)
{
@@ -98,7 +98,7 @@
visit = it->visitUnary(PreVisit, this);
if (visit) {
- it->incrementDepth();
+ it->incrementDepth(this);
operand->traverse(it);
it->decrementDepth();
}
@@ -119,7 +119,7 @@
if (visit)
{
- it->incrementDepth();
+ it->incrementDepth(this);
if (it->rightToLeft)
{
@@ -166,7 +166,7 @@
visit = it->visitSelection(PreVisit, this);
if (visit) {
- it->incrementDepth();
+ it->incrementDepth(this);
if (it->rightToLeft) {
if (falseBlock)
falseBlock->traverse(it);
@@ -199,7 +199,7 @@
if (visit)
{
- it->incrementDepth();
+ it->incrementDepth(this);
if (it->rightToLeft)
{
@@ -248,7 +248,7 @@
visit = it->visitBranch(PreVisit, this);
if (visit && expression) {
- it->incrementDepth();
+ it->incrementDepth(this);
expression->traverse(it);
it->decrementDepth();
}
diff --git a/src/compiler/translator/Intermediate.cpp b/src/compiler/translator/Intermediate.cpp
index 3d07459..bec0e29 100644
--- a/src/compiler/translator/Intermediate.cpp
+++ b/src/compiler/translator/Intermediate.cpp
@@ -20,11 +20,13 @@
bool CompareStructure(const TType& leftNodeType, ConstantUnion* rightUnionArray, ConstantUnion* leftUnionArray);
-static TPrecision GetHigherPrecision( TPrecision left, TPrecision right ){
+// Returns whichever of |left| and |right| compares greater. That is the higher
+// precision only if TPrecision enumerators are declared from lowest to highest
+// precision -- assumed here; confirm against the TPrecision declaration.
+static TPrecision GetHigherPrecision(TPrecision left, TPrecision right)
+{
return left > right ? left : right;
}
-const char* getOperatorString(TOperator op) {
+const char* getOperatorString(TOperator op)
+{
switch (op) {
case EOpInitialize: return "=";
case EOpAssign: return "=";
@@ -769,6 +771,63 @@
//
////////////////////////////////////////////////////////////////
+// Helper for the replaceChildNode() overrides below. If the child slot |node|
+// currently holds |original|, store |replacement| in it (cast to the slot's
+// declared pointer type |type|) and return true from the enclosing function.
+// The static_cast is unchecked: callers must supply a replacement of the right
+// kind for the slot (e.g. a TIntermTyped for expression slots).
+// Wrapped in do/while (0) so every use is a single statement that is safe in
+// an unbraced if/else and consumes the trailing semicolon.
+#define REPLACE_IF_IS(node, type, original, replacement) \
+    do { \
+        if ((node) == (original)) \
+        { \
+            (node) = static_cast<type *>(replacement); \
+            return true; \
+        } \
+    } while (0)
+
+// Checks the init, condition, expression and body slots, in that order.
+bool TIntermLoop::replaceChildNode(
+    TIntermNode *original, TIntermNode *replacement)
+{
+    REPLACE_IF_IS(init, TIntermNode, original, replacement);
+    REPLACE_IF_IS(cond, TIntermTyped, original, replacement);
+    REPLACE_IF_IS(expr, TIntermTyped, original, replacement);
+    REPLACE_IF_IS(body, TIntermNode, original, replacement);
+    return false;
+}
+
+// The only child slot of a branch is its (optional) expression.
+bool TIntermBranch::replaceChildNode(
+    TIntermNode *original, TIntermNode *replacement)
+{
+    REPLACE_IF_IS(expression, TIntermTyped, original, replacement);
+    return false;
+}
+
+// Checks the left operand first, then the right operand.
+bool TIntermBinary::replaceChildNode(
+    TIntermNode *original, TIntermNode *replacement)
+{
+    REPLACE_IF_IS(left, TIntermTyped, original, replacement);
+    REPLACE_IF_IS(right, TIntermTyped, original, replacement);
+    return false;
+}
+
+// The only child slot of a unary node is its operand.
+bool TIntermUnary::replaceChildNode(
+    TIntermNode *original, TIntermNode *replacement)
+{
+    REPLACE_IF_IS(operand, TIntermTyped, original, replacement);
+    return false;
+}
+
+// Scans the sequence front to back; only the first matching entry is
+// replaced because the macro returns as soon as it finds a match.
+bool TIntermAggregate::replaceChildNode(
+    TIntermNode *original, TIntermNode *replacement)
+{
+    for (size_t ii = 0; ii < sequence.size(); ++ii)
+    {
+        REPLACE_IF_IS(sequence[ii], TIntermNode, original, replacement);
+    }
+    return false;
+}
+
+// Checks the condition, then the true block, then the false block.
+bool TIntermSelection::replaceChildNode(
+    TIntermNode *original, TIntermNode *replacement)
+{
+    REPLACE_IF_IS(condition, TIntermTyped, original, replacement);
+    REPLACE_IF_IS(trueBlock, TIntermNode, original, replacement);
+    REPLACE_IF_IS(falseBlock, TIntermNode, original, replacement);
+    return false;
+}
+
+// Keep the helper macro local to the replaceChildNode() definitions above.
+#undef REPLACE_IF_IS
+
//
// Say whether or not an operation node changes the value of a variable.
//
@@ -825,6 +884,7 @@
return false;
}
}
+
//
// Make sure the type of a unary operator is appropriate for its
// combination of operation and operand type.
diff --git a/src/compiler/translator/OutputGLSLBase.cpp b/src/compiler/translator/OutputGLSLBase.cpp
index 8a2b77a..e51bd12 100644
--- a/src/compiler/translator/OutputGLSLBase.cpp
+++ b/src/compiler/translator/OutputGLSLBase.cpp
@@ -437,7 +437,7 @@
node->getCondition()->traverse(this);
out << ")\n";
- incrementDepth();
+ incrementDepth(node);
visitCodeBlock(node->getTrueBlock());
if (node->getFalseBlock())
@@ -462,7 +462,7 @@
// Scope the sequences except when at the global scope.
if (depth > 0) out << "{\n";
- incrementDepth();
+ incrementDepth(node);
const TIntermSequence& sequence = node->getSequence();
for (TIntermSequence::const_iterator iter = sequence.begin();
iter != sequence.end(); ++iter)
@@ -500,7 +500,7 @@
writeVariableType(node->getType());
out << " " << hashFunctionName(node->getName());
- incrementDepth();
+ incrementDepth(node);
// Function definition node contains one or two children nodes
// representing function parameters and function body. The latter
// is not present in case of empty function bodies.
@@ -640,7 +640,7 @@
{
TInfoSinkBase& out = objSink();
- incrementDepth();
+ incrementDepth(node);
// Loop header.
TLoopType loopType = node->getType();
if (loopType == ELoopFor) // for loop
diff --git a/src/compiler/translator/UnfoldShortCircuitAST.cpp b/src/compiler/translator/UnfoldShortCircuitAST.cpp
new file mode 100644
index 0000000..29c4397
--- /dev/null
+++ b/src/compiler/translator/UnfoldShortCircuitAST.cpp
@@ -0,0 +1,81 @@
+//
+// Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+
+#include "compiler/translator/UnfoldShortCircuitAST.h"
+
+namespace
+{
+
+// Creates a constant scalar bool node holding |value|.
+TIntermConstantUnion *MakeBoolConstantNode(bool value)
+{
+    ConstantUnion *constant = new ConstantUnion;
+    constant->setBConst(value);
+    return new TIntermConstantUnion(
+        constant, TType(EbtBool, EbpUndefined, EvqConst, 1));
+}
+
+// Rewrites "x || y" as the equivalent selection "x ? true : y".
+TIntermSelection *UnfoldOR(TIntermTyped *x, TIntermTyped *y)
+{
+    TIntermConstantUnion *trueNode = MakeBoolConstantNode(true);
+    return new TIntermSelection(x, trueNode, y, TType(EbtBool, EbpUndefined));
+}
+
+// Rewrites "x && y" as the equivalent selection "x ? y : false".
+TIntermSelection *UnfoldAND(TIntermTyped *x, TIntermTyped *y)
+{
+    TIntermConstantUnion *falseNode = MakeBoolConstantNode(false);
+    return new TIntermSelection(x, y, falseNode, TType(EbtBool, EbpUndefined));
+}
+
+}  // namespace anonymous
+
+// Queues a replacement for every logical-or / logical-and node: the node is
+// to be swapped, in its parent, for an equivalent ternary selection. The tree
+// is not modified here; updateTree() applies the queued replacements later.
+bool UnfoldShortCircuitAST::visitBinary(Visit visit, TIntermBinary *node)
+{
+    TIntermSelection *replacement = NULL;
+
+    switch (node->getOp())
+    {
+      case EOpLogicalOr:
+        replacement = UnfoldOR(node->getLeft(), node->getRight());
+        break;
+      case EOpLogicalAnd:
+        replacement = UnfoldAND(node->getLeft(), node->getRight());
+        break;
+      default:
+        break;
+    }
+    if (replacement)
+    {
+        // Carry over the source location of the "||" / "&&" expression so
+        // anything that later reads the new node's line (diagnostics, line
+        // directives) still points at the original code.
+        replacement->setLine(node->getLine());
+        replacements.push_back(
+            NodeUpdateEntry(getParentNode(), node, replacement));
+    }
+    // Keep descending: the operands may contain further short-circuit
+    // operators that also need unfolding.
+    return true;
+}
+
+// Applies, in recording order, every replacement queued by visitBinary().
+void UnfoldShortCircuitAST::updateTree()
+{
+    const size_t count = replacements.size();
+    for (size_t current = 0; current < count; ++current)
+    {
+        TIntermNode *parent = replacements[current].parent;
+        TIntermNode *original = replacements[current].original;
+        TIntermNode *replacement = replacements[current].replacement;
+
+        ASSERT(parent);
+        bool replaced = parent->replaceChildNode(original, replacement);
+        ASSERT(replaced);
+
+        // A parent is visited (and so recorded) before its children. Once
+        // |original| has been swapped out, any later entry that names it as
+        // the parent must instead target |replacement|, which now holds the
+        // original node's operands.
+        for (size_t later = current + 1; later < count; ++later)
+        {
+            if (replacements[later].parent == original)
+            {
+                replacements[later].parent = replacement;
+            }
+        }
+    }
+}
+
diff --git a/src/compiler/translator/UnfoldShortCircuitAST.h b/src/compiler/translator/UnfoldShortCircuitAST.h
new file mode 100644
index 0000000..24c14a6
--- /dev/null
+++ b/src/compiler/translator/UnfoldShortCircuitAST.h
@@ -0,0 +1,51 @@
+//
+// Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// UnfoldShortCircuitAST is an AST traverser to replace short-circuiting
+// operations with ternary operations.
+//
+
+#ifndef COMPILER_UNFOLD_SHORT_CIRCUIT_AST_H_
+#define COMPILER_UNFOLD_SHORT_CIRCUIT_AST_H_
+
+#include "common/angleutils.h"
+#include "compiler/translator/intermediate.h"
+
+// This traverser identifies all the short circuit binary nodes that need to
+// be replaced, and creates the corresponding replacement nodes. However,
+// the actual replacements happen after the traverse through updateTree().
+//
+// Usage:
+//   UnfoldShortCircuitAST unfoldShortCircuit;
+//   root->traverse(&unfoldShortCircuit);
+//   unfoldShortCircuit.updateTree();
+
+class UnfoldShortCircuitAST : public TIntermTraverser
+{
+  public:
+    UnfoldShortCircuitAST() { }
+
+    // Queues a replacement for each "||" / "&&" node visited; the tree is
+    // not modified during traversal. Always returns true so that operators
+    // nested inside the operands are visited too.
+    virtual bool visitBinary(Visit visit, TIntermBinary *);
+
+    // Applies every replacement queued during traversal. Intended to be
+    // called once, after the traversal of the tree has completed.
+    void updateTree();
+
+  private:
+    // One pending replacement: |original| is a child of |parent| and is to
+    // be swapped for |replacement|.
+    struct NodeUpdateEntry
+    {
+        NodeUpdateEntry(TIntermNode *_parent,
+                        TIntermNode *_original,
+                        TIntermNode *_replacement)
+            : parent(_parent),
+              original(_original),
+              replacement(_replacement) {}
+
+        TIntermNode *parent;
+        TIntermNode *original;
+        TIntermNode *replacement;
+    };
+
+    // During traversing, save all the replacements that need to happen;
+    // then replace them by calling updateTree().
+    std::vector<NodeUpdateEntry> replacements;
+
+    DISALLOW_COPY_AND_ASSIGN(UnfoldShortCircuitAST);
+};
+
+#endif // COMPILER_UNFOLD_SHORT_CIRCUIT_AST_H_
diff --git a/src/compiler/translator/intermediate.h b/src/compiler/translator/intermediate.h
index 2dd8fad..4ec4109 100644
--- a/src/compiler/translator/intermediate.h
+++ b/src/compiler/translator/intermediate.h
@@ -238,6 +238,11 @@
virtual TIntermSymbol* getAsSymbolNode() { return 0; }
virtual TIntermLoop* getAsLoopNode() { return 0; }
+ // Replace a child node. Return true if |original| is a child
+ // node and it is replaced; otherwise, return false.
+ virtual bool replaceChildNode(
+ TIntermNode *original, TIntermNode *replacement) = 0;
+
protected:
TSourceLoc line;
};
@@ -311,6 +316,8 @@
virtual TIntermLoop* getAsLoopNode() { return this; }
virtual void traverse(TIntermTraverser*);
+ virtual bool replaceChildNode(
+ TIntermNode *original, TIntermNode *replacement);
TLoopType getType() const { return type; }
TIntermNode* getInit() { return init; }
@@ -341,6 +348,8 @@
expression(e) { }
virtual void traverse(TIntermTraverser*);
+ virtual bool replaceChildNode(
+ TIntermNode *original, TIntermNode *replacement);
TOperator getFlowOp() { return flowOp; }
TIntermTyped* getExpression() { return expression; }
@@ -373,6 +382,7 @@
virtual void traverse(TIntermTraverser*);
virtual TIntermSymbol* getAsSymbolNode() { return this; }
+ virtual bool replaceChildNode(TIntermNode *, TIntermNode *) { return false; }
protected:
int id;
@@ -395,6 +405,7 @@
virtual TIntermConstantUnion* getAsConstantUnion() { return this; }
virtual void traverse(TIntermTraverser*);
+ virtual bool replaceChildNode(TIntermNode *, TIntermNode *) { return false; }
TIntermTyped* fold(TOperator, TIntermTyped*, TInfoSink&);
@@ -430,6 +441,8 @@
virtual TIntermBinary* getAsBinaryNode() { return this; }
virtual void traverse(TIntermTraverser*);
+ virtual bool replaceChildNode(
+ TIntermNode *original, TIntermNode *replacement);
virtual bool hasSideEffects() const { return (isAssignment() || left->hasSideEffects() || right->hasSideEffects()); }
@@ -460,6 +473,8 @@
virtual void traverse(TIntermTraverser*);
virtual TIntermUnary* getAsUnaryNode() { return this; }
+ virtual bool replaceChildNode(
+ TIntermNode *original, TIntermNode *replacement);
virtual bool hasSideEffects() const { return (isAssignment() || operand->hasSideEffects()); }
@@ -492,6 +507,8 @@
virtual TIntermAggregate* getAsAggregate() { return this; }
virtual void traverse(TIntermTraverser*);
+ virtual bool replaceChildNode(
+ TIntermNode *original, TIntermNode *replacement);
// Conservatively assume function calls and other aggregate operators have side-effects
virtual bool hasSideEffects() const { return true; }
@@ -538,6 +555,8 @@
TIntermTyped(type), condition(cond), trueBlock(trueB), falseBlock(falseB) {}
virtual void traverse(TIntermTraverser*);
+ virtual bool replaceChildNode(
+ TIntermNode *original, TIntermNode *replacement);
// Conservatively assume selections have side-effects
virtual bool hasSideEffects() const { return true; }
@@ -580,7 +599,7 @@
rightToLeft(rightToLeft),
depth(0),
maxDepth(0) {}
- virtual ~TIntermTraverser() {};
+ virtual ~TIntermTraverser() {}
virtual void visitSymbol(TIntermSymbol*) {}
virtual void visitConstantUnion(TIntermConstantUnion*) {}
@@ -592,8 +611,24 @@
virtual bool visitBranch(Visit visit, TIntermBranch*) {return true;}
int getMaxDepth() const {return maxDepth;}
- void incrementDepth() {depth++; maxDepth = std::max(maxDepth, depth); }
- void decrementDepth() {depth--;}
+
+    // Called before the children of |current| are traversed: bumps the
+    // nesting depth, tracks the maximum depth reached, and pushes |current|
+    // onto the root-to-parent path.
+    void incrementDepth(TIntermNode *current)
+    {
+        depth++;
+        maxDepth = std::max(maxDepth, depth);
+        path.push_back(current);
+    }
+
+    // Undoes the matching incrementDepth() once those children are done.
+    // Must be paired with a prior incrementDepth(); popping an empty path
+    // is undefined behavior.
+    void decrementDepth()
+    {
+        depth--;
+        path.pop_back();
+    }
+
+    // Returns the most recently pushed node, which is the parent of the node
+    // passed to a PreVisit callback. Returns NULL when the path is empty,
+    // i.e. while visiting the root.
+    TIntermNode *getParentNode() const
+    {
+        return path.empty() ? NULL : path.back();
+    }
// Return the original name if hash function pointer is NULL;
// otherwise return the hashed name.
@@ -607,6 +642,9 @@
protected:
int depth;
int maxDepth;
+
+ // All the nodes from root to the current node's parent during traversing.
+ TVector<TIntermNode *> path;
};
#endif // __INTERMEDIATE_H