Make UnfoldShortCircuit change the AST instead of writing output
This is needed to make way for further AST transformations to handle array
expressions that need to work correctly together with unfolding short-
circuiting operators. This also improves the maintainability of HLSL output
by isolating the unfolding into a separate compilation step.
The new version of the UnfoldShortCircuit traverser traverses the tree until it
encounters an expression that needs to be unfolded. It then unfolds that
expression and is reset. The traverser is run repeatedly until no more
operations to unfold are found. This helps keep the traverser's design
relatively simple.
All declarations are separated into single declarations before short-circuit
unfolding is run. Previously, OutputHLSL already output every declaration
separately.
BUG=angleproject:960
TEST=WebGL conformance tests, angle_unittests, angle_end2end_tests
Change-Id: Id769be396adbd4c0223e418980dc464dd855f019
Reviewed-on: https://chromium-review.googlesource.com/270460
Tested-by: Olli Etuaho <oetuaho@nvidia.com>
Reviewed-by: Jamie Madill <jmadill@chromium.org>
diff --git a/src/compiler/translator/IntermNode.cpp b/src/compiler/translator/IntermNode.cpp
index b339798..2c29144 100644
--- a/src/compiler/translator/IntermNode.cpp
+++ b/src/compiler/translator/IntermNode.cpp
@@ -209,6 +209,21 @@
return false;
}
+bool TIntermAggregate::insertChildNodes(TIntermSequence::size_type position, TIntermSequence insertions)
+{
+ TIntermSequence::size_type itPosition = 0;
+ for (auto it = mSequence.begin(); it < mSequence.end(); ++it)
+ {
+ if (itPosition == position)
+ {
+ mSequence.insert(it, insertions.begin(), insertions.end());
+ return true;
+ }
+ ++itPosition;
+ }
+ return false;
+}
+
void TIntermAggregate::setPrecisionFromChildren()
{
if (getBasicType() == EbtBool)
@@ -1488,6 +1503,14 @@
void TIntermTraverser::updateTree()
{
+ for (size_t ii = 0; ii < mInsertions.size(); ++ii)
+ {
+ const NodeInsertMultipleEntry &insertion = mInsertions[ii];
+ ASSERT(insertion.parent);
+ bool inserted = insertion.parent->insertChildNodes(insertion.position, insertion.insertions);
+ ASSERT(inserted);
+ UNUSED_ASSERTION_VARIABLE(inserted);
+ }
for (size_t ii = 0; ii < mReplacements.size(); ++ii)
{
const NodeUpdateEntry &replacement = mReplacements[ii];
diff --git a/src/compiler/translator/IntermNode.h b/src/compiler/translator/IntermNode.h
index 5b8145f..e650485 100644
--- a/src/compiler/translator/IntermNode.h
+++ b/src/compiler/translator/IntermNode.h
@@ -434,6 +434,7 @@
virtual bool replaceChildNode(
TIntermNode *original, TIntermNode *replacement);
bool replaceChildNodeWithMultiple(TIntermNode *original, TIntermSequence replacements);
+ bool insertChildNodes(TIntermSequence::size_type position, TIntermSequence insertions);
// Conservatively assume function calls and other aggregate operators have side-effects
virtual bool hasSideEffects() const { return true; }
@@ -680,11 +681,27 @@
TIntermSequence replacements;
};
+ // To insert multiple nodes on the parent aggregate node
+ struct NodeInsertMultipleEntry
+ {
+ NodeInsertMultipleEntry(TIntermAggregate *_parent, TIntermSequence::size_type _position, TIntermSequence _insertions)
+ : parent(_parent),
+ position(_position),
+ insertions(_insertions)
+ {
+ }
+
+ TIntermAggregate *parent;
+ TIntermSequence::size_type position;
+ TIntermSequence insertions;
+ };
+
// During traversing, save all the changes that need to happen into
// mReplacements/mMultiReplacements, then do them by calling updateTree().
// Multi replacements are processed after single replacements.
std::vector<NodeUpdateEntry> mReplacements;
std::vector<NodeReplaceWithMultipleEntry> mMultiReplacements;
+ std::vector<NodeInsertMultipleEntry> mInsertions;
};
//
diff --git a/src/compiler/translator/OutputHLSL.cpp b/src/compiler/translator/OutputHLSL.cpp
index cb7931a..903f6e3 100644
--- a/src/compiler/translator/OutputHLSL.cpp
+++ b/src/compiler/translator/OutputHLSL.cpp
@@ -23,7 +23,6 @@
#include "compiler/translator/SearchSymbol.h"
#include "compiler/translator/StructureHLSL.h"
#include "compiler/translator/TranslatorHLSL.h"
-#include "compiler/translator/UnfoldShortCircuit.h"
#include "compiler/translator/UniformHLSL.h"
#include "compiler/translator/UtilsHLSL.h"
#include "compiler/translator/blocklayout.h"
@@ -111,7 +110,6 @@
mCompileOptions(compileOptions),
mCurrentFunctionMetadata(nullptr)
{
- mUnfoldShortCircuit = new UnfoldShortCircuit(this);
mInsideFunction = false;
mUsesFragColor = false;
@@ -153,7 +151,6 @@
OutputHLSL::~OutputHLSL()
{
- SafeDelete(mUnfoldShortCircuit);
SafeDelete(mStructureHLSL);
SafeDelete(mUniformHLSL);
for (auto &eqFunction : mStructEqualityFunctions)
@@ -1672,31 +1669,19 @@
case EOpMatrixTimesVector: outputTriplet(visit, "mul(transpose(", "), ", ")"); break;
case EOpMatrixTimesMatrix: outputTriplet(visit, "transpose(mul(transpose(", "), transpose(", ")))"); break;
case EOpLogicalOr:
- if (node->getRight()->hasSideEffects())
- {
- out << "s" << mUnfoldShortCircuit->getNextTemporaryIndex();
- return false;
- }
- else
- {
- outputTriplet(visit, "(", " || ", ")");
- return true;
- }
+ // HLSL doesn't short-circuit ||, so we assume that || affected by short-circuiting have been unfolded.
+ ASSERT(!node->getRight()->hasSideEffects());
+ outputTriplet(visit, "(", " || ", ")");
+ return true;
case EOpLogicalXor:
mUsesXor = true;
outputTriplet(visit, "xor(", ", ", ")");
break;
case EOpLogicalAnd:
- if (node->getRight()->hasSideEffects())
- {
- out << "s" << mUnfoldShortCircuit->getNextTemporaryIndex();
- return false;
- }
- else
- {
- outputTriplet(visit, "(", " && ", ")");
- return true;
- }
+ // HLSL doesn't short-circuit &&, so we assume that && affected by short-circuiting have been unfolded.
+ ASSERT(!node->getRight()->hasSideEffects());
+ outputTriplet(visit, "(", " && ", ")");
+ return true;
default: UNREACHABLE();
}
@@ -1854,12 +1839,15 @@
{
outputLineDirective((*sit)->getLine().first_line);
- traverseStatements(*sit);
+ (*sit)->traverse(this);
// Don't output ; after case labels, they're terminated by :
// This is needed especially since outputting a ; after a case statement would turn empty
// case statements into non-empty case statements, disallowing fall-through from them.
- if ((*sit)->getAsCaseNode() == nullptr)
+ // Also no need to output ; after selection (if) statements. This is done just for code clarity.
+ TIntermSelection *asSelection = (*sit)->getAsSelectionNode();
+ ASSERT(asSelection == nullptr || !asSelection->usesTernaryOperator());
+ if ((*sit)->getAsCaseNode() == nullptr && asSelection == nullptr)
out << ";\n";
}
@@ -1876,6 +1864,7 @@
{
TIntermSequence *sequence = node->getSequence();
TIntermTyped *variable = (*sequence)[0]->getAsTyped();
+ ASSERT(sequence->size() == 1);
if (variable && (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal))
{
@@ -1883,37 +1872,24 @@
if (!variable->getAsSymbolNode() || variable->getAsSymbolNode()->getSymbol() != "") // Variable declaration
{
- for (const auto &seqElement : *sequence)
+ if (!mInsideFunction)
{
- if (isSingleStatement(seqElement))
- {
- mUnfoldShortCircuit->traverse(seqElement);
- }
+ out << "static ";
+ }
- if (!mInsideFunction)
- {
- out << "static ";
- }
+ out << TypeString(variable->getType()) + " ";
- out << TypeString(variable->getType()) + " ";
+ TIntermSymbol *symbol = variable->getAsSymbolNode();
- TIntermSymbol *symbol = seqElement->getAsSymbolNode();
-
- if (symbol)
- {
- symbol->traverse(this);
- out << ArrayString(symbol->getType());
- out << " = " + initializer(symbol->getType());
- }
- else
- {
- seqElement->traverse(this);
- }
-
- if (seqElement != sequence->back())
- {
- out << ";\n";
- }
+ if (symbol)
+ {
+ symbol->traverse(this);
+ out << ArrayString(symbol->getType());
+ out << " = " + initializer(symbol->getType());
+ }
+ else
+ {
+ variable->traverse(this);
}
}
else if (variable->getAsSymbolNode() && variable->getAsSymbolNode()->getSymbol() == "") // Type (struct) declaration
@@ -2300,66 +2276,63 @@
{
TInfoSinkBase &out = getInfoSink();
- if (node->usesTernaryOperator())
+ ASSERT(!node->usesTernaryOperator());
+
+ // D3D errors when there is a gradient operation in a loop in an unflattened if.
+ // We check for null mCurrentFunctionMetadata to prevent crashing in the case that the translator has generated if
+ // statements in the global scope when unfolding global initializers. This is a bug that should be addressed by
+ // moving the unfolded global initializers into a function.
+ if (mShaderType == GL_FRAGMENT_SHADER
+ && mCurrentFunctionMetadata != nullptr
+ && mCurrentFunctionMetadata->hasDiscontinuousLoop(node)
+ && mCurrentFunctionMetadata->hasGradientInCallGraph(node))
{
- out << "s" << mUnfoldShortCircuit->getNextTemporaryIndex();
+ out << "FLATTEN ";
}
- else // if/else statement
+
+ out << "if (";
+
+ node->getCondition()->traverse(this);
+
+ out << ")\n";
+
+ outputLineDirective(node->getLine().first_line);
+ out << "{\n";
+
+ bool discard = false;
+
+ if (node->getTrueBlock())
{
- mUnfoldShortCircuit->traverse(node->getCondition());
+ node->getTrueBlock()->traverse(this);
- // D3D errors when there is a gradient operation in a loop in an unflattened if.
- if (mShaderType == GL_FRAGMENT_SHADER
- && mCurrentFunctionMetadata->hasDiscontinuousLoop(node)
- && mCurrentFunctionMetadata->hasGradientInCallGraph(node))
- {
- out << "FLATTEN ";
- }
+ // Detect true discard
+ discard = (discard || FindDiscard::search(node->getTrueBlock()));
+ }
- out << "if (";
+ outputLineDirective(node->getLine().first_line);
+ out << ";\n}\n";
- node->getCondition()->traverse(this);
+ if (node->getFalseBlock())
+ {
+ out << "else\n";
- out << ")\n";
-
- outputLineDirective(node->getLine().first_line);
+ outputLineDirective(node->getFalseBlock()->getLine().first_line);
out << "{\n";
- bool discard = false;
+ outputLineDirective(node->getFalseBlock()->getLine().first_line);
+ node->getFalseBlock()->traverse(this);
- if (node->getTrueBlock())
- {
- traverseStatements(node->getTrueBlock());
-
- // Detect true discard
- discard = (discard || FindDiscard::search(node->getTrueBlock()));
- }
-
- outputLineDirective(node->getLine().first_line);
+ outputLineDirective(node->getFalseBlock()->getLine().first_line);
out << ";\n}\n";
- if (node->getFalseBlock())
- {
- out << "else\n";
+ // Detect false discard
+ discard = (discard || FindDiscard::search(node->getFalseBlock()));
+ }
- outputLineDirective(node->getFalseBlock()->getLine().first_line);
- out << "{\n";
-
- outputLineDirective(node->getFalseBlock()->getLine().first_line);
- traverseStatements(node->getFalseBlock());
-
- outputLineDirective(node->getFalseBlock()->getLine().first_line);
- out << ";\n}\n";
-
- // Detect false discard
- discard = (discard || FindDiscard::search(node->getFalseBlock()));
- }
-
- // ANGLE issue 486: Detect problematic conditional discard
- if (discard && FindSideEffectRewriting::search(node))
- {
- mUsesDiscardRewriting = true;
- }
+ // ANGLE issue 486: Detect problematic conditional discard
+ if (discard && FindSideEffectRewriting::search(node))
+ {
+ mUsesDiscardRewriting = true;
}
return false;
@@ -2461,7 +2434,7 @@
if (node->getBody())
{
- traverseStatements(node->getBody());
+ node->getBody()->traverse(this);
}
outputLineDirective(node->getLine().first_line);
@@ -2541,16 +2514,6 @@
return true;
}
-void OutputHLSL::traverseStatements(TIntermNode *node)
-{
- if (isSingleStatement(node))
- {
- mUnfoldShortCircuit->traverse(node);
- }
-
- node->traverse(this);
-}
-
bool OutputHLSL::isSingleStatement(TIntermNode *node)
{
TIntermAggregate *aggregate = node->getAsAggregate();
diff --git a/src/compiler/translator/OutputHLSL.h b/src/compiler/translator/OutputHLSL.h
index 59f305c..ba1d2ad 100644
--- a/src/compiler/translator/OutputHLSL.h
+++ b/src/compiler/translator/OutputHLSL.h
@@ -63,7 +63,6 @@
bool visitLoop(Visit visit, TIntermLoop*);
bool visitBranch(Visit visit, TIntermBranch*);
- void traverseStatements(TIntermNode *node);
bool isSingleStatement(TIntermNode *node);
bool handleExcessiveLoop(TIntermLoop *node);
@@ -103,7 +102,6 @@
const ShShaderOutput mOutputType;
int mCompileOptions;
- UnfoldShortCircuit *mUnfoldShortCircuit;
bool mInsideFunction;
// Output streams
diff --git a/src/compiler/translator/SeparateDeclarations.cpp b/src/compiler/translator/SeparateDeclarations.cpp
index 7ea1ce0..d33747f 100644
--- a/src/compiler/translator/SeparateDeclarations.cpp
+++ b/src/compiler/translator/SeparateDeclarations.cpp
@@ -3,9 +3,10 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
-// The SeparateArrayDeclarations function processes declarations that contain array declarators. Each declarator in
-// such declarations gets its own declaration.
-// This is useful as an intermediate step when initialization needs to be separated from declaration.
+// The SeparateDeclarations function processes declarations, so that in the end each declaration
+// contains only one declarator.
+// This is useful as an intermediate step when initialization needs to be separated from declaration,
+// or when things need to be unfolded out of the initializer.
// Example:
// int a[1] = int[1](1), b[1] = int[1](2);
// gets transformed when run through this class into the AST equivalent of:
@@ -19,43 +20,33 @@
namespace
{
-class SeparateDeclarations : private TIntermTraverser
+class SeparateDeclarationsTraverser : private TIntermTraverser
{
public:
static void apply(TIntermNode *root);
private:
- SeparateDeclarations();
+ SeparateDeclarationsTraverser();
bool visitAggregate(Visit, TIntermAggregate *node) override;
};
-void SeparateDeclarations::apply(TIntermNode *root)
+void SeparateDeclarationsTraverser::apply(TIntermNode *root)
{
- SeparateDeclarations separateDecl;
+ SeparateDeclarationsTraverser separateDecl;
root->traverse(&separateDecl);
separateDecl.updateTree();
}
-SeparateDeclarations::SeparateDeclarations()
+SeparateDeclarationsTraverser::SeparateDeclarationsTraverser()
: TIntermTraverser(true, false, false)
{
}
-bool SeparateDeclarations::visitAggregate(Visit, TIntermAggregate *node)
+bool SeparateDeclarationsTraverser::visitAggregate(Visit, TIntermAggregate *node)
{
if (node->getOp() == EOpDeclaration)
{
TIntermSequence *sequence = node->getSequence();
- bool sequenceContainsArrays = false;
- for (size_t ii = 0; ii < sequence->size(); ++ii)
- {
- TIntermTyped *typed = sequence->at(ii)->getAsTyped();
- if (typed != nullptr && typed->isArray())
- {
- sequenceContainsArrays = true;
- break;
- }
- }
- if (sequence->size() > 1 && sequenceContainsArrays)
+ if (sequence->size() > 1)
{
TIntermAggregate *parentAgg = getParentNode()->getAsAggregate();
ASSERT(parentAgg != nullptr);
@@ -80,7 +71,7 @@
} // namespace
-void SeparateArrayDeclarations(TIntermNode *root)
+void SeparateDeclarations(TIntermNode *root)
{
- SeparateDeclarations::apply(root);
+ SeparateDeclarationsTraverser::apply(root);
}
diff --git a/src/compiler/translator/SeparateDeclarations.h b/src/compiler/translator/SeparateDeclarations.h
index 2aa5294..77913ab 100644
--- a/src/compiler/translator/SeparateDeclarations.h
+++ b/src/compiler/translator/SeparateDeclarations.h
@@ -3,9 +3,10 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
-// The SeparateArrayDeclarations function processes declarations that contain array declarators. Each declarator in
-// such declarations gets its own declaration.
-// This is useful as an intermediate step when initialization needs to be separated from declaration.
+// The SeparateDeclarations function processes declarations, so that in the end each declaration
+// contains only one declarator.
+// This is useful as an intermediate step when initialization needs to be separated from declaration,
+// or when things need to be unfolded out of the initializer.
// Example:
// int a[1] = int[1](1), b[1] = int[1](2);
// gets transformed when run through this class into the AST equivalent of:
@@ -17,6 +18,6 @@
class TIntermNode;
-void SeparateArrayDeclarations(TIntermNode *root);
+void SeparateDeclarations(TIntermNode *root);
#endif // COMPILER_TRANSLATOR_SEPARATEDECLARATIONS_H_
diff --git a/src/compiler/translator/TranslatorHLSL.cpp b/src/compiler/translator/TranslatorHLSL.cpp
index a59b65e..2c4224c 100644
--- a/src/compiler/translator/TranslatorHLSL.cpp
+++ b/src/compiler/translator/TranslatorHLSL.cpp
@@ -11,6 +11,7 @@
#include "compiler/translator/SeparateArrayInitialization.h"
#include "compiler/translator/SeparateDeclarations.h"
#include "compiler/translator/SimplifyArrayAssignment.h"
+#include "compiler/translator/UnfoldShortCircuitToIf.h"
TranslatorHLSL::TranslatorHLSL(sh::GLenum type, ShShaderSpec spec, ShShaderOutput output)
: TCompiler(type, spec, output)
@@ -22,7 +23,10 @@
const ShBuiltInResources &resources = getResources();
int numRenderTargets = resources.EXT_draw_buffers ? resources.MaxDrawBuffers : 1;
- SeparateArrayDeclarations(root);
+ SeparateDeclarations(root);
+
+ // Note that SeparateDeclarations needs to be run before UnfoldShortCircuitToIf.
+ UnfoldShortCircuitToIf(root);
// Note that SeparateDeclarations needs to be run before SeparateArrayInitialization.
SeparateArrayInitialization(root);
diff --git a/src/compiler/translator/UnfoldShortCircuit.cpp b/src/compiler/translator/UnfoldShortCircuit.cpp
deleted file mode 100644
index f79f9dd..0000000
--- a/src/compiler/translator/UnfoldShortCircuit.cpp
+++ /dev/null
@@ -1,185 +0,0 @@
-//
-// Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-//
-// UnfoldShortCircuit is an AST traverser to output short-circuiting operators as if-else statements.
-// The results are assigned to s# temporaries, which are used by the main translator instead of
-// the original expression.
-//
-
-#include "compiler/translator/UnfoldShortCircuit.h"
-
-#include "compiler/translator/InfoSink.h"
-#include "compiler/translator/OutputHLSL.h"
-#include "compiler/translator/UtilsHLSL.h"
-
-namespace sh
-{
-UnfoldShortCircuit::UnfoldShortCircuit(OutputHLSL *outputHLSL) : mOutputHLSL(outputHLSL)
-{
- mTemporaryIndex = 0;
-}
-
-void UnfoldShortCircuit::traverse(TIntermNode *node)
-{
- int rewindIndex = mTemporaryIndex;
- node->traverse(this);
- mTemporaryIndex = rewindIndex;
-}
-
-bool UnfoldShortCircuit::visitBinary(Visit visit, TIntermBinary *node)
-{
- TInfoSinkBase &out = mOutputHLSL->getInfoSink();
-
- // If our right node doesn't have side effects, we know we don't need to unfold this
- // expression: there will be no short-circuiting side effects to avoid
- // (note: unfolding doesn't depend on the left node -- it will always be evaluated)
- if (!node->getRight()->hasSideEffects())
- {
- return true;
- }
-
- switch (node->getOp())
- {
- case EOpLogicalOr:
- // "x || y" is equivalent to "x ? true : y", which unfolds to "bool s; if(x) s = true; else s = y;",
- // and then further simplifies down to "bool s = x; if(!s) s = y;".
- {
- int i = mTemporaryIndex;
-
- out << "bool s" << i << ";\n";
-
- out << "{\n";
-
- mTemporaryIndex = i + 1;
- node->getLeft()->traverse(this);
- out << "s" << i << " = ";
- mTemporaryIndex = i + 1;
- node->getLeft()->traverse(mOutputHLSL);
- out << ";\n";
- out << "if (!s" << i << ")\n"
- "{\n";
- mTemporaryIndex = i + 1;
- node->getRight()->traverse(this);
- out << " s" << i << " = ";
- mTemporaryIndex = i + 1;
- node->getRight()->traverse(mOutputHLSL);
- out << ";\n"
- "}\n";
-
- out << "}\n";
-
- mTemporaryIndex = i + 1;
- }
- return false;
- case EOpLogicalAnd:
- // "x && y" is equivalent to "x ? y : false", which unfolds to "bool s; if(x) s = y; else s = false;",
- // and then further simplifies down to "bool s = x; if(s) s = y;".
- {
- int i = mTemporaryIndex;
-
- out << "bool s" << i << ";\n";
-
- out << "{\n";
-
- mTemporaryIndex = i + 1;
- node->getLeft()->traverse(this);
- out << "s" << i << " = ";
- mTemporaryIndex = i + 1;
- node->getLeft()->traverse(mOutputHLSL);
- out << ";\n";
- out << "if (s" << i << ")\n"
- "{\n";
- mTemporaryIndex = i + 1;
- node->getRight()->traverse(this);
- out << " s" << i << " = ";
- mTemporaryIndex = i + 1;
- node->getRight()->traverse(mOutputHLSL);
- out << ";\n"
- "}\n";
-
- out << "}\n";
-
- mTemporaryIndex = i + 1;
- }
- return false;
- default:
- return true;
- }
-}
-
-bool UnfoldShortCircuit::visitSelection(Visit visit, TIntermSelection *node)
-{
- TInfoSinkBase &out = mOutputHLSL->getInfoSink();
-
- // Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;"
- if (node->usesTernaryOperator())
- {
- int i = mTemporaryIndex;
-
- out << TypeString(node->getType()) << " s" << i << ";\n";
-
- out << "{\n";
-
- mTemporaryIndex = i + 1;
- node->getCondition()->traverse(this);
- out << "if (";
- mTemporaryIndex = i + 1;
- node->getCondition()->traverse(mOutputHLSL);
- out << ")\n"
- "{\n";
- mTemporaryIndex = i + 1;
- node->getTrueBlock()->traverse(this);
- out << " s" << i << " = ";
- mTemporaryIndex = i + 1;
- node->getTrueBlock()->traverse(mOutputHLSL);
- out << ";\n"
- "}\n"
- "else\n"
- "{\n";
- mTemporaryIndex = i + 1;
- node->getFalseBlock()->traverse(this);
- out << " s" << i << " = ";
- mTemporaryIndex = i + 1;
- node->getFalseBlock()->traverse(mOutputHLSL);
- out << ";\n"
- "}\n";
-
- out << "}\n";
-
- mTemporaryIndex = i + 1;
- }
-
- return false;
-}
-
-bool UnfoldShortCircuit::visitLoop(Visit visit, TIntermLoop *node)
-{
- int rewindIndex = mTemporaryIndex;
-
- if (node->getInit())
- {
- node->getInit()->traverse(this);
- }
-
- if (node->getCondition())
- {
- node->getCondition()->traverse(this);
- }
-
- if (node->getExpression())
- {
- node->getExpression()->traverse(this);
- }
-
- mTemporaryIndex = rewindIndex;
-
- return false;
-}
-
-int UnfoldShortCircuit::getNextTemporaryIndex()
-{
- return mTemporaryIndex++;
-}
-}
diff --git a/src/compiler/translator/UnfoldShortCircuit.h b/src/compiler/translator/UnfoldShortCircuit.h
deleted file mode 100644
index eaceb0a..0000000
--- a/src/compiler/translator/UnfoldShortCircuit.h
+++ /dev/null
@@ -1,38 +0,0 @@
-//
-// Copyright (c) 2002-2012 The ANGLE Project Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-//
-// UnfoldShortCircuit is an AST traverser to output short-circuiting operators as if-else statements
-//
-
-#ifndef COMPILER_TRANSLATOR_UNFOLDSHORTCIRCUIT_H_
-#define COMPILER_TRANSLATOR_UNFOLDSHORTCIRCUIT_H_
-
-#include "compiler/translator/IntermNode.h"
-#include "compiler/translator/ParseContext.h"
-
-namespace sh
-{
-class OutputHLSL;
-
-class UnfoldShortCircuit : public TIntermTraverser
-{
- public:
- UnfoldShortCircuit(OutputHLSL *outputHLSL);
-
- void traverse(TIntermNode *node);
- bool visitBinary(Visit visit, TIntermBinary*);
- bool visitSelection(Visit visit, TIntermSelection *node);
- bool visitLoop(Visit visit, TIntermLoop *node);
-
- int getNextTemporaryIndex();
-
- protected:
- OutputHLSL *const mOutputHLSL;
-
- int mTemporaryIndex;
-};
-}
-
-#endif // COMPILER_TRANSLATOR_UNFOLDSHORTCIRCUIT_H_
diff --git a/src/compiler/translator/UnfoldShortCircuitToIf.cpp b/src/compiler/translator/UnfoldShortCircuitToIf.cpp
new file mode 100644
index 0000000..c7a74a5
--- /dev/null
+++ b/src/compiler/translator/UnfoldShortCircuitToIf.cpp
@@ -0,0 +1,253 @@
+//
+// Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// UnfoldShortCircuitToIf is an AST traverser to convert short-circuiting operators to if-else statements.
+// The results are assigned to s# temporaries, which are used by the main translator instead of
+// the original expression.
+//
+
+#include "compiler/translator/UnfoldShortCircuitToIf.h"
+
+#include "compiler/translator/InfoSink.h"
+#include "compiler/translator/IntermNode.h"
+
+namespace
+{
+
+// Traverser that unfolds one short-circuiting operation at a time.
+class UnfoldShortCircuitTraverser : public TIntermTraverser
+{
+ public:
+ UnfoldShortCircuitTraverser();
+
+ void traverse(TIntermNode *node);
+ bool visitBinary(Visit visit, TIntermBinary *node) override;
+ bool visitAggregate(Visit visit, TIntermAggregate *node) override;
+ bool visitSelection(Visit visit, TIntermSelection *node) override;
+
+ void nextIteration();
+ bool foundShortCircuit() const { return mFoundShortCircuit; }
+
+ protected:
+ int mTemporaryIndex;
+
+ // Marked to true once an operation that needs to be unfolded has been found.
+ // After that, no more unfolding is performed on that traversal.
+ bool mFoundShortCircuit;
+
+ struct ParentBlock
+ {
+ ParentBlock(TIntermAggregate *_node, TIntermSequence::size_type _pos)
+ : node(_node),
+ pos(_pos)
+ {
+ }
+
+ TIntermAggregate *node;
+ TIntermSequence::size_type pos;
+ };
+ std::vector<ParentBlock> mParentBlockStack;
+
+ TIntermSymbol *createTempSymbol(const TType &type);
+ TIntermAggregate *createTempInitDeclaration(const TType &type, TIntermTyped *initializer);
+ TIntermBinary *createTempAssignment(const TType &type, TIntermTyped *rightNode);
+};
+
+UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser()
+ : TIntermTraverser(true, true, true),
+ mTemporaryIndex(0),
+ mFoundShortCircuit(false)
+{
+}
+
+TIntermSymbol *UnfoldShortCircuitTraverser::createTempSymbol(const TType &type)
+{
+ // Each traversal uses at most one temporary variable, so the index stays the same within a single traversal.
+ TInfoSinkBase symbolNameOut;
+ symbolNameOut << "s" << mTemporaryIndex;
+ TString symbolName = symbolNameOut.c_str();
+
+ TIntermSymbol *node = new TIntermSymbol(0, symbolName, type);
+ node->setInternal(true);
+ return node;
+}
+
+TIntermAggregate *UnfoldShortCircuitTraverser::createTempInitDeclaration(const TType &type, TIntermTyped *initializer)
+{
+ ASSERT(initializer != nullptr);
+ TIntermSymbol *tempSymbol = createTempSymbol(type);
+ TIntermAggregate *tempDeclaration = new TIntermAggregate(EOpDeclaration);
+ TIntermBinary *tempInit = new TIntermBinary(EOpInitialize);
+ tempInit->setLeft(tempSymbol);
+ tempInit->setRight(initializer);
+ tempInit->setType(type);
+ tempDeclaration->getSequence()->push_back(tempInit);
+ return tempDeclaration;
+}
+
+TIntermBinary *UnfoldShortCircuitTraverser::createTempAssignment(const TType &type, TIntermTyped *rightNode)
+{
+ ASSERT(rightNode != nullptr);
+ TIntermSymbol *tempSymbol = createTempSymbol(type);
+ TIntermBinary *assignment = new TIntermBinary(EOpAssign);
+ assignment->setLeft(tempSymbol);
+ assignment->setRight(rightNode);
+ assignment->setType(type);
+ return assignment;
+}
+
+bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
+{
+ if (mFoundShortCircuit)
+ return false;
+ // If our right node doesn't have side effects, we know we don't need to unfold this
+ // expression: there will be no short-circuiting side effects to avoid
+ // (note: unfolding doesn't depend on the left node -- it will always be evaluated)
+ if (!node->getRight()->hasSideEffects())
+ {
+ return true;
+ }
+
+ switch (node->getOp())
+ {
+ case EOpLogicalOr:
+ mFoundShortCircuit = true;
+
+ // "x || y" is equivalent to "x ? true : y", which unfolds to "bool s; if(x) s = true; else s = y;",
+ // and then further simplifies down to "bool s = x; if(!s) s = y;".
+ {
+ TIntermSequence insertions;
+ TType boolType(EbtBool, EbpUndefined, EvqTemporary);
+
+ insertions.push_back(createTempInitDeclaration(boolType, node->getLeft()));
+
+ TIntermAggregate *assignRightBlock = new TIntermAggregate(EOpSequence);
+ assignRightBlock->getSequence()->push_back(createTempAssignment(boolType, node->getRight()));
+
+ TIntermUnary *notTempSymbol = new TIntermUnary(EOpLogicalNot, boolType);
+ notTempSymbol->setOperand(createTempSymbol(boolType));
+ TIntermSelection *ifNode = new TIntermSelection(notTempSymbol, assignRightBlock, nullptr);
+ insertions.push_back(ifNode);
+
+ NodeInsertMultipleEntry insert(mParentBlockStack.back().node, mParentBlockStack.back().pos, insertions);
+ mInsertions.push_back(insert);
+
+ NodeUpdateEntry replaceVariable(getParentNode(), node, createTempSymbol(boolType), false);
+ mReplacements.push_back(replaceVariable);
+ }
+ return false;
+ case EOpLogicalAnd:
+ mFoundShortCircuit = true;
+
+ // "x && y" is equivalent to "x ? y : false", which unfolds to "bool s; if(x) s = y; else s = false;",
+ // and then further simplifies down to "bool s = x; if(s) s = y;".
+ {
+ TIntermSequence insertions;
+ TType boolType(EbtBool, EbpUndefined, EvqTemporary);
+
+ insertions.push_back(createTempInitDeclaration(boolType, node->getLeft()));
+
+ TIntermAggregate *assignRightBlock = new TIntermAggregate(EOpSequence);
+ assignRightBlock->getSequence()->push_back(createTempAssignment(boolType, node->getRight()));
+
+ TIntermSelection *ifNode = new TIntermSelection(createTempSymbol(boolType), assignRightBlock, nullptr);
+ insertions.push_back(ifNode);
+
+ NodeInsertMultipleEntry insert(mParentBlockStack.back().node, mParentBlockStack.back().pos, insertions);
+ mInsertions.push_back(insert);
+
+ NodeUpdateEntry replaceVariable(getParentNode(), node, createTempSymbol(boolType), false);
+ mReplacements.push_back(replaceVariable);
+ }
+ return false;
+ default:
+ return true;
+ }
+}
+
+bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection *node)
+{
+ if (mFoundShortCircuit)
+ return false;
+
+ // Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;"
+ if (visit == PreVisit && node->usesTernaryOperator())
+ {
+ mFoundShortCircuit = true;
+ TIntermSequence insertions;
+
+ TIntermSymbol *tempSymbol = createTempSymbol(node->getType());
+ TIntermAggregate *tempDeclaration = new TIntermAggregate(EOpDeclaration);
+ tempDeclaration->getSequence()->push_back(tempSymbol);
+ insertions.push_back(tempDeclaration);
+
+ TIntermAggregate *trueBlock = new TIntermAggregate(EOpSequence);
+ TIntermBinary *trueAssignment = createTempAssignment(node->getType(), node->getTrueBlock()->getAsTyped());
+ trueBlock->getSequence()->push_back(trueAssignment);
+
+ TIntermAggregate *falseBlock = new TIntermAggregate(EOpSequence);
+ TIntermBinary *falseAssignment = createTempAssignment(node->getType(), node->getFalseBlock()->getAsTyped());
+ falseBlock->getSequence()->push_back(falseAssignment);
+
+ TIntermSelection *ifNode = new TIntermSelection(node->getCondition()->getAsTyped(), trueBlock, falseBlock);
+ insertions.push_back(ifNode);
+
+ NodeInsertMultipleEntry insert(mParentBlockStack.back().node, mParentBlockStack.back().pos, insertions);
+ mInsertions.push_back(insert);
+
+ TIntermSymbol *ternaryResult = createTempSymbol(node->getType());
+ NodeUpdateEntry replaceVariable(getParentNode(), node, ternaryResult, false);
+ mReplacements.push_back(replaceVariable);
+ return false;
+ }
+
+ return true;
+}
+
+bool UnfoldShortCircuitTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
+{
+ if (node->getOp() == EOpSequence)
+ {
+ if (visit == PreVisit)
+ {
+ mParentBlockStack.push_back(ParentBlock(node, 0));
+ }
+ else if (visit == InVisit)
+ {
+ ++mParentBlockStack.back().pos;
+ }
+ else
+ {
+ ASSERT(visit == PostVisit);
+ mParentBlockStack.pop_back();
+ }
+ }
+ return true;
+}
+
+void UnfoldShortCircuitTraverser::nextIteration()
+{
+ mFoundShortCircuit = false;
+ mTemporaryIndex++;
+ mReplacements.clear();
+ mMultiReplacements.clear();
+ mInsertions.clear();
+}
+
+} // namespace
+
+void UnfoldShortCircuitToIf(TIntermNode *root)
+{
+ UnfoldShortCircuitTraverser traverser;
+ // Unfold one operator at a time, and reset the traverser between iterations.
+ do
+ {
+ traverser.nextIteration();
+ root->traverse(&traverser);
+ if (traverser.foundShortCircuit())
+ traverser.updateTree();
+ }
+ while (traverser.foundShortCircuit());
+}
diff --git a/src/compiler/translator/UnfoldShortCircuitToIf.h b/src/compiler/translator/UnfoldShortCircuitToIf.h
new file mode 100644
index 0000000..1d67b00
--- /dev/null
+++ b/src/compiler/translator/UnfoldShortCircuitToIf.h
@@ -0,0 +1,18 @@
+//
+// Copyright (c) 2002-2012 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// UnfoldShortCircuitToIf is an AST traverser to convert short-circuiting operators to if-else statements.
+// The results are assigned to s# temporaries, which are used by the main translator instead of
+// the original expression.
+//
+
+#ifndef COMPILER_TRANSLATOR_UNFOLDSHORTCIRCUIT_H_
+#define COMPILER_TRANSLATOR_UNFOLDSHORTCIRCUIT_H_
+
+class TIntermNode;
+
+void UnfoldShortCircuitToIf(TIntermNode *root);
+
+#endif // COMPILER_TRANSLATOR_UNFOLDSHORTCIRCUIT_H_