Add an option to unfold short-circuiting operators in the AST.

We replace "a || b" with "a ? true : b", and
"a && b" with "a ? b : false".

This works around a short-circuiting bug in Mac drivers.

ANGLEBUG=482
TEST=webgl conformance tests
R=alokp@chromium.org, kbr@chromium.org

Review URL: https://codereview.appspot.com/14529048

Conflicts:

	src/build_angle.gypi
	src/compiler/translator/Compiler.cpp

Change-Id: Ic2384a97d58f54294efcb3a012deb2007a9fc658
Reviewed-on: https://chromium-review.googlesource.com/178996
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Commit-Queue: Jamie Madill <jmadill@chromium.org>
Tested-by: Shannon Woods <shannonwoods@chromium.org>
diff --git a/src/compiler/translator/Compiler.cpp b/src/compiler/translator/Compiler.cpp
index 80c97ea..909f35b 100644
--- a/src/compiler/translator/Compiler.cpp
+++ b/src/compiler/translator/Compiler.cpp
@@ -14,6 +14,7 @@
 #include "compiler/translator/ParseContext.h"
 #include "compiler/translator/RenameFunction.h"
 #include "compiler/translator/ShHandle.h"
+#include "compiler/translator/UnfoldShortCircuitAST.h"
 #include "compiler/translator/ValidateLimitations.h"
 #include "compiler/translator/ValidateOutputs.h"
 #include "compiler/translator/VariablePacker.h"
@@ -202,6 +203,12 @@
             root->traverse(&initGLPosition);
         }
 
+        if (success && (compileOptions & SH_UNFOLD_SHORT_CIRCUIT)) {
+            UnfoldShortCircuitAST unfoldShortCircuit;  // rewrites "a || b" / "a && b" as "?:" selections
+            root->traverse(&unfoldShortCircuit);
+            unfoldShortCircuit.updateTree();  // replacements are only recorded during traversal; apply them now
+        }
+
         if (success && (compileOptions & SH_VARIABLES)) {
             collectVariables(root);
             if (compileOptions & SH_ENFORCE_PACKING_RESTRICTIONS) {
diff --git a/src/compiler/translator/IntermTraverse.cpp b/src/compiler/translator/IntermTraverse.cpp
index 0e345d2..a579488 100644
--- a/src/compiler/translator/IntermTraverse.cpp
+++ b/src/compiler/translator/IntermTraverse.cpp
@@ -51,7 +51,7 @@
     //
     if (visit)
     {
-        it->incrementDepth();
+        it->incrementDepth(this);  // |this| is the parent of the children traversed below
 
         if (it->rightToLeft)
         {
@@ -98,7 +98,7 @@
         visit = it->visitUnary(PreVisit, this);
 
     if (visit) {
-        it->incrementDepth();
+        it->incrementDepth(this);  // |this| is the parent of the operand traversed below
         operand->traverse(it);
         it->decrementDepth();
     }
@@ -119,7 +119,7 @@
 
     if (visit)
     {
-        it->incrementDepth();
+        it->incrementDepth(this);  // |this| is the parent of the children traversed below
 
         if (it->rightToLeft)
         {
@@ -166,7 +166,7 @@
         visit = it->visitSelection(PreVisit, this);
 
     if (visit) {
-        it->incrementDepth();
+        it->incrementDepth(this);  // |this| is the parent of the blocks traversed below
         if (it->rightToLeft) {
             if (falseBlock)
                 falseBlock->traverse(it);
@@ -199,7 +199,7 @@
 
     if (visit)
     {
-        it->incrementDepth();
+        it->incrementDepth(this);  // |this| is the parent of the children traversed below
 
         if (it->rightToLeft)
         {
@@ -248,7 +248,7 @@
         visit = it->visitBranch(PreVisit, this);
 
     if (visit && expression) {
-        it->incrementDepth();
+        it->incrementDepth(this);  // |this| is the parent of the expression traversed below
         expression->traverse(it);
         it->decrementDepth();
     }
diff --git a/src/compiler/translator/Intermediate.cpp b/src/compiler/translator/Intermediate.cpp
index 3d07459..bec0e29 100644
--- a/src/compiler/translator/Intermediate.cpp
+++ b/src/compiler/translator/Intermediate.cpp
@@ -20,11 +20,13 @@
 
 bool CompareStructure(const TType& leftNodeType, ConstantUnion* rightUnionArray, ConstantUnion* leftUnionArray);
 
-static TPrecision GetHigherPrecision( TPrecision left, TPrecision right ){
+static TPrecision GetHigherPrecision(TPrecision left, TPrecision right)
+{
     return left > right ? left : right;
 }
 
-const char* getOperatorString(TOperator op) {
+const char* getOperatorString(TOperator op)
+{
     switch (op) {
       case EOpInitialize: return "=";
       case EOpAssign: return "=";
@@ -769,6 +771,72 @@
 //
 ////////////////////////////////////////////////////////////////
 
+// Helper for the replaceChildNode() overrides below. If the child slot
+// |node| currently points at |original|, store |replacement| in it and
+// return true from the enclosing function. The static_cast is unchecked,
+// so a replacement for a TIntermTyped slot must itself be a TIntermTyped.
+// The do/while(0) wrapper makes each use behave as a single statement.
+#define REPLACE_IF_IS(node, type, original, replacement) \
+    do { \
+        if (node == original) { \
+            node = static_cast<type *>(replacement); \
+            return true; \
+        } \
+    } while (0)
+
+bool TIntermLoop::replaceChildNode(
+    TIntermNode *original, TIntermNode *replacement)
+{
+    REPLACE_IF_IS(init, TIntermNode, original, replacement);
+    REPLACE_IF_IS(cond, TIntermTyped, original, replacement);
+    REPLACE_IF_IS(expr, TIntermTyped, original, replacement);
+    REPLACE_IF_IS(body, TIntermNode, original, replacement);
+    return false;
+}
+
+bool TIntermBranch::replaceChildNode(
+    TIntermNode *original, TIntermNode *replacement)
+{
+    REPLACE_IF_IS(expression, TIntermTyped, original, replacement);
+    return false;
+}
+
+bool TIntermBinary::replaceChildNode(
+    TIntermNode *original, TIntermNode *replacement)
+{
+    REPLACE_IF_IS(left, TIntermTyped, original, replacement);
+    REPLACE_IF_IS(right, TIntermTyped, original, replacement);
+    return false;
+}
+
+bool TIntermUnary::replaceChildNode(
+    TIntermNode *original, TIntermNode *replacement)
+{
+    REPLACE_IF_IS(operand, TIntermTyped, original, replacement);
+    return false;
+}
+
+bool TIntermAggregate::replaceChildNode(
+    TIntermNode *original, TIntermNode *replacement)
+{
+    for (size_t ii = 0; ii < sequence.size(); ++ii)
+    {
+        REPLACE_IF_IS(sequence[ii], TIntermNode, original, replacement);
+    }
+    return false;
+}
+
+bool TIntermSelection::replaceChildNode(
+    TIntermNode *original, TIntermNode *replacement)
+{
+    REPLACE_IF_IS(condition, TIntermTyped, original, replacement);
+    REPLACE_IF_IS(trueBlock, TIntermNode, original, replacement);
+    REPLACE_IF_IS(falseBlock, TIntermNode, original, replacement);
+    return false;
+}
+
+#undef REPLACE_IF_IS
+
 //
 // Say whether or not an operation node changes the value of a variable.
 //
@@ -825,6 +893,7 @@
             return false;
     }
 }
+
 //
 // Make sure the type of a unary operator is appropriate for its
 // combination of operation and operand type.
diff --git a/src/compiler/translator/OutputGLSLBase.cpp b/src/compiler/translator/OutputGLSLBase.cpp
index 8a2b77a..e51bd12 100644
--- a/src/compiler/translator/OutputGLSLBase.cpp
+++ b/src/compiler/translator/OutputGLSLBase.cpp
@@ -437,7 +437,7 @@
         node->getCondition()->traverse(this);
         out << ")\n";
 
-        incrementDepth();
+        incrementDepth(node);  // children are traversed by hand below; |node| is their parent
         visitCodeBlock(node->getTrueBlock());
 
         if (node->getFalseBlock())
@@ -462,7 +462,7 @@
             // Scope the sequences except when at the global scope.
             if (depth > 0) out << "{\n";
 
-            incrementDepth();
+            incrementDepth(node);  // children are traversed by hand below; |node| is their parent
             const TIntermSequence& sequence = node->getSequence();
             for (TIntermSequence::const_iterator iter = sequence.begin();
                  iter != sequence.end(); ++iter)
@@ -500,7 +500,7 @@
             writeVariableType(node->getType());
             out << " " << hashFunctionName(node->getName());
 
-            incrementDepth();
+            incrementDepth(node);  // children are traversed by hand below; |node| is their parent
             // Function definition node contains one or two children nodes
             // representing function parameters and function body. The latter
             // is not present in case of empty function bodies.
@@ -640,7 +640,7 @@
 {
     TInfoSinkBase& out = objSink();
 
-    incrementDepth();
+    incrementDepth(node);  // children are traversed by hand below; |node| is their parent
     // Loop header.
     TLoopType loopType = node->getType();
     if (loopType == ELoopFor)  // for loop
diff --git a/src/compiler/translator/UnfoldShortCircuitAST.cpp b/src/compiler/translator/UnfoldShortCircuitAST.cpp
new file mode 100644
index 0000000..29c4397
--- /dev/null
+++ b/src/compiler/translator/UnfoldShortCircuitAST.cpp
@@ -0,0 +1,94 @@
+//
+// Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+
+#include "compiler/translator/UnfoldShortCircuitAST.h"
+
+namespace
+{
+
+// Creates a scalar boolean constant node holding |value|.
+TIntermConstantUnion *CreateBoolConstant(bool value)
+{
+    ConstantUnion *u = new ConstantUnion;
+    u->setBConst(value);
+    return new TIntermConstantUnion(
+        u, TType(EbtBool, EbpUndefined, EvqConst, 1));
+}
+
+// "x || y" is equivalent to "x ? true : y".
+TIntermSelection *UnfoldOR(TIntermTyped *x, TIntermTyped *y)
+{
+    const TType boolType(EbtBool, EbpUndefined);
+    return new TIntermSelection(x, CreateBoolConstant(true), y, boolType);
+}
+
+// "x && y" is equivalent to "x ? y : false".
+TIntermSelection *UnfoldAND(TIntermTyped *x, TIntermTyped *y)
+{
+    const TType boolType(EbtBool, EbpUndefined);
+    return new TIntermSelection(x, y, CreateBoolConstant(false), boolType);
+}
+
+}  // namespace anonymous
+
+// Queues a replacement for every "||" and "&&" node; the tree itself is
+// only modified later, by updateTree().
+bool UnfoldShortCircuitAST::visitBinary(Visit visit, TIntermBinary *node)
+{
+    // Queue each node at most once, even if in- or post-visits are enabled.
+    if (visit != PreVisit)
+        return true;
+
+    TIntermSelection *replacement = NULL;
+
+    switch (node->getOp())
+    {
+      case EOpLogicalOr:
+        replacement = UnfoldOR(node->getLeft(), node->getRight());
+        break;
+      case EOpLogicalAnd:
+        replacement = UnfoldAND(node->getLeft(), node->getRight());
+        break;
+      default:
+        break;
+    }
+    if (replacement)
+    {
+        // Keep the original expression's source location on the new node.
+        replacement->setLine(node->getLine());
+        replacements.push_back(
+            NodeUpdateEntry(getParentNode(), node, replacement));
+    }
+    return true;
+}
+
+void UnfoldShortCircuitAST::updateTree()
+{
+    for (size_t ii = 0; ii < replacements.size(); ++ii)
+    {
+        const NodeUpdateEntry& entry = replacements[ii];
+        ASSERT(entry.parent);
+        bool replaced = entry.parent->replaceChildNode(
+            entry.original, entry.replacement);
+        ASSERT(replaced);
+
+        // Parents are visited (and therefore queued) before their children.
+        // The node just replaced may itself be the recorded parent of a
+        // later entry; its children now hang off the replacement, so
+        // redirect those entries to the replacement node.
+        for (size_t jj = ii + 1; jj < replacements.size(); ++jj)
+        {
+            NodeUpdateEntry& entry2 = replacements[jj];
+            if (entry2.parent == entry.original)
+                entry2.parent = entry.replacement;
+        }
+    }
+
+    // Every queued replacement has now been applied. Drop them so that a
+    // second call does not try to replace nodes that left the tree.
+    replacements.clear();
+}
+
diff --git a/src/compiler/translator/UnfoldShortCircuitAST.h b/src/compiler/translator/UnfoldShortCircuitAST.h
new file mode 100644
index 0000000..24c14a6
--- /dev/null
+++ b/src/compiler/translator/UnfoldShortCircuitAST.h
@@ -0,0 +1,51 @@
+//
+// Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// UnfoldShortCircuitAST is an AST traverser to replace short-circuiting
+// operations with ternary operations.
+//
+
+#ifndef COMPILER_UNFOLD_SHORT_CIRCUIT_AST_H_
+#define COMPILER_UNFOLD_SHORT_CIRCUIT_AST_H_
+
+#include "common/angleutils.h"
+#include "compiler/translator/intermediate.h"
+
+// This traverser identifies all the short-circuit binary nodes that need to
+// be replaced and creates the corresponding replacement nodes. The actual
+// replacements happen after the traversal, in updateTree().
+
+class UnfoldShortCircuitAST : public TIntermTraverser
+{
+  public:
+    UnfoldShortCircuitAST() { }
+
+    virtual bool visitBinary(Visit visit, TIntermBinary *);
+
+    void updateTree();  // applies the replacements recorded by visitBinary()
+
+  private:
+    struct NodeUpdateEntry
+    {
+        NodeUpdateEntry(TIntermNode *_parent,
+                        TIntermNode *_original,
+                        TIntermNode *_replacement)
+            : parent(_parent),
+              original(_original),
+              replacement(_replacement) {}
+
+        TIntermNode *parent;       // node whose child pointer gets rewritten
+        TIntermNode *original;     // child to be replaced
+        TIntermNode *replacement;  // node that takes the child's place
+    };
+
+    // During traversal, record all the replacements that need to happen;
+    // they are applied afterwards by calling updateTree().
+    std::vector<NodeUpdateEntry> replacements;
+
+    DISALLOW_COPY_AND_ASSIGN(UnfoldShortCircuitAST);
+};
+
+#endif  // COMPILER_UNFOLD_SHORT_CIRCUIT_AST_H_
diff --git a/src/compiler/translator/intermediate.h b/src/compiler/translator/intermediate.h
index 2dd8fad..4ec4109 100644
--- a/src/compiler/translator/intermediate.h
+++ b/src/compiler/translator/intermediate.h
@@ -238,6 +238,11 @@
     virtual TIntermSymbol* getAsSymbolNode() { return 0; }
     virtual TIntermLoop* getAsLoopNode() { return 0; }
 
+    // Replaces the direct child |original| with |replacement|. Returns true
+    // if |original| was a child of this node and was replaced, else false.
+    virtual bool replaceChildNode(
+        TIntermNode *original, TIntermNode *replacement) = 0;
+
 protected:
     TSourceLoc line;
 };
@@ -311,6 +316,8 @@
 
     virtual TIntermLoop* getAsLoopNode() { return this; }
     virtual void traverse(TIntermTraverser*);
+    virtual bool replaceChildNode(
+        TIntermNode *original, TIntermNode *replacement);
 
     TLoopType getType() const { return type; }
     TIntermNode* getInit() { return init; }
@@ -341,6 +348,8 @@
             expression(e) { }
 
     virtual void traverse(TIntermTraverser*);
+    virtual bool replaceChildNode(
+        TIntermNode *original, TIntermNode *replacement);
 
     TOperator getFlowOp() { return flowOp; }
     TIntermTyped* getExpression() { return expression; }
@@ -373,6 +382,7 @@
 
     virtual void traverse(TIntermTraverser*);
     virtual TIntermSymbol* getAsSymbolNode() { return this; }
+    virtual bool replaceChildNode(TIntermNode *, TIntermNode *) { return false; }  // never replaces anything
 
 protected:
     int id;
@@ -395,6 +405,7 @@
 
     virtual TIntermConstantUnion* getAsConstantUnion()  { return this; }
     virtual void traverse(TIntermTraverser*);
+    virtual bool replaceChildNode(TIntermNode *, TIntermNode *) { return false; }  // never replaces anything
 
     TIntermTyped* fold(TOperator, TIntermTyped*, TInfoSink&);
 
@@ -430,6 +441,8 @@
 
     virtual TIntermBinary* getAsBinaryNode() { return this; }
     virtual void traverse(TIntermTraverser*);
+    virtual bool replaceChildNode(
+        TIntermNode *original, TIntermNode *replacement);
 
     virtual bool hasSideEffects() const { return (isAssignment() || left->hasSideEffects() || right->hasSideEffects()); }
 
@@ -460,6 +473,8 @@
 
     virtual void traverse(TIntermTraverser*);
     virtual TIntermUnary* getAsUnaryNode() { return this; }
+    virtual bool replaceChildNode(
+        TIntermNode *original, TIntermNode *replacement);
 
     virtual bool hasSideEffects() const { return (isAssignment() || operand->hasSideEffects()); }
 
@@ -492,6 +507,8 @@
 
     virtual TIntermAggregate* getAsAggregate() { return this; }
     virtual void traverse(TIntermTraverser*);
+    virtual bool replaceChildNode(
+        TIntermNode *original, TIntermNode *replacement);
 
     // Conservatively assume function calls and other aggregate operators have side-effects
     virtual bool hasSideEffects() const { return true; }
@@ -538,6 +555,8 @@
             TIntermTyped(type), condition(cond), trueBlock(trueB), falseBlock(falseB) {}
 
     virtual void traverse(TIntermTraverser*);
+    virtual bool replaceChildNode(
+        TIntermNode *original, TIntermNode *replacement);
 
     // Conservatively assume selections have side-effects
     virtual bool hasSideEffects() const { return true; }
@@ -580,7 +599,7 @@
             rightToLeft(rightToLeft),
             depth(0),
             maxDepth(0) {}
-    virtual ~TIntermTraverser() {};
+    virtual ~TIntermTraverser() {}
 
     virtual void visitSymbol(TIntermSymbol*) {}
     virtual void visitConstantUnion(TIntermConstantUnion*) {}
@@ -592,8 +611,29 @@
     virtual bool visitBranch(Visit visit, TIntermBranch*) {return true;}
 
     int getMaxDepth() const {return maxDepth;}
-    void incrementDepth() {depth++; maxDepth = std::max(maxDepth, depth); }
-    void decrementDepth() {depth--;}
+
+    // Called before traversing |current|'s children: |current| becomes the
+    // node reported by getParentNode() until the matching decrementDepth().
+    void incrementDepth(TIntermNode *current)
+    {
+        depth++;
+        maxDepth = std::max(maxDepth, depth);
+        path.push_back(current);
+    }
+
+    // Undoes the matching incrementDepth() call.
+    void decrementDepth()
+    {
+        depth--;
+        path.pop_back();
+    }
+
+    // Returns the most recently entered node, or NULL if none (the root).
+    // Inside a PreVisit callback this is the parent of the visited node.
+    TIntermNode *getParentNode() const
+    {
+        return path.empty() ? NULL : path.back();
+    }
 
     // Return the original name if hash function pointer is NULL;
     // otherwise return the hashed name.
@@ -607,6 +647,9 @@
 protected:
     int depth;
     int maxDepth;
+
+    // Nodes entered via incrementDepth() and not yet left, root first.
+    TVector<TIntermNode *> path;
 };
 
 #endif // __INTERMEDIATE_H