Improve loop unrolling for loops that index sampler arrays
1) Previously this workaround was hardwired on Mac; now it is moved behind a compiler option.
2) Fix the issue where a "break" inside the loop was not handled when the loop is unrolled.
BUG=338474
TEST=webgl conformance test sampler-array-using-loop-index.html
Change-Id: I4996a42c2dea39a8a5af772c256f8e3cb383f59a
Reviewed-on: https://chromium-review.googlesource.com/188079
Reviewed-by: Zhenyao Mo <zmo@chromium.org>
Tested-by: Zhenyao Mo <zmo@chromium.org>
Conflicts:
include/GLSLANG/ShaderLang.h
src/compiler/translator/ValidateLimitations.cpp
Change-Id: I546197bd7df1634ebccdd380be14c3250cd56151
Reviewed-on: https://chromium-review.googlesource.com/189061
Reviewed-by: Shannon Woods <shannonwoods@chromium.org>
Tested-by: Zhenyao Mo <zmo@chromium.org>
diff --git a/src/compiler/translator/ValidateLimitations.cpp b/src/compiler/translator/ValidateLimitations.cpp
index efaf177..e96a777 100644
--- a/src/compiler/translator/ValidateLimitations.cpp
+++ b/src/compiler/translator/ValidateLimitations.cpp
@@ -9,25 +9,8 @@
#include "compiler/translator/InitializeParseContext.h"
#include "compiler/translator/ParseContext.h"
-namespace {
-bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) {
- for (TLoopStack::const_iterator i = stack.begin(); i != stack.end(); ++i) {
- if (i->index.id == symbol->getId())
- return true;
- }
- return false;
-}
-
-void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) {
- for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) {
- if (i->index.id == symbol->getId()) {
- ASSERT(i->loop != NULL);
- i->loop->setUnrollFlag(true);
- return;
- }
- }
- UNREACHABLE();
-}
+namespace
+{
// Traverses a node to check if it represents a constant index expression.
// Definition:
@@ -38,114 +21,60 @@
// - Constant expressions
// - Loop indices as defined in section 4
// - Expressions composed of both of the above
-class ValidateConstIndexExpr : public TIntermTraverser {
-public:
- ValidateConstIndexExpr(const TLoopStack& stack)
+class ValidateConstIndexExpr : public TIntermTraverser
+{
+ public:
+ ValidateConstIndexExpr(TLoopStack& stack)
: mValid(true), mLoopStack(stack) {}
// Returns true if the parsed node represents a constant index expression.
bool isValid() const { return mValid; }
- virtual void visitSymbol(TIntermSymbol* symbol) {
+ virtual void visitSymbol(TIntermSymbol *symbol)
+ {
// Only constants and loop indices are allowed in a
// constant index expression.
- if (mValid) {
+ if (mValid)
+ {
mValid = (symbol->getQualifier() == EvqConst) ||
- IsLoopIndex(symbol, mLoopStack);
+ (mLoopStack.findLoop(symbol));
}
}
-private:
+ private:
bool mValid;
- const TLoopStack& mLoopStack;
-};
-
-// Traverses a node to check if it uses a loop index.
-// If an int loop index is used in its body as a sampler array index,
-// mark the loop for unroll.
-class ValidateLoopIndexExpr : public TIntermTraverser {
-public:
- ValidateLoopIndexExpr(TLoopStack& stack)
- : mUsesFloatLoopIndex(false),
- mUsesIntLoopIndex(false),
- mLoopStack(stack) {}
-
- bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; }
- bool usesIntLoopIndex() const { return mUsesIntLoopIndex; }
-
- virtual void visitSymbol(TIntermSymbol* symbol) {
- if (IsLoopIndex(symbol, mLoopStack)) {
- switch (symbol->getBasicType()) {
- case EbtFloat:
- mUsesFloatLoopIndex = true;
- break;
- case EbtUInt:
- mUsesIntLoopIndex = true;
- MarkLoopForUnroll(symbol, mLoopStack);
- break;
- case EbtInt:
- mUsesIntLoopIndex = true;
- MarkLoopForUnroll(symbol, mLoopStack);
- break;
- default:
- UNREACHABLE();
- }
- }
- }
-
-private:
- bool mUsesFloatLoopIndex;
- bool mUsesIntLoopIndex;
TLoopStack& mLoopStack;
};
-} // namespace
+
+} // namespace anonymous
ValidateLimitations::ValidateLimitations(ShShaderType shaderType,
- TInfoSinkBase& sink)
+ TInfoSinkBase &sink)
: mShaderType(shaderType),
mSink(sink),
mNumErrors(0)
{
}
-bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node)
+bool ValidateLimitations::visitBinary(Visit, TIntermBinary *node)
{
// Check if loop index is modified in the loop body.
validateOperation(node, node->getLeft());
// Check indexing.
- switch (node->getOp()) {
+ switch (node->getOp())
+ {
case EOpIndexDirect:
- validateIndexing(node);
- break;
case EOpIndexIndirect:
-#if defined(__APPLE__)
- // Loop unrolling is a work-around for a Mac Cg compiler bug where it
- // crashes when a sampler array's index is also the loop index.
- // Once Apple fixes this bug, we should remove the code in this CL.
- // See http://codereview.appspot.com/4331048/.
- if ((node->getLeft() != NULL) && (node->getRight() != NULL) &&
- (node->getLeft()->getAsSymbolNode())) {
- TIntermSymbol* symbol = node->getLeft()->getAsSymbolNode();
- if (IsSampler(symbol->getBasicType()) && symbol->isArray()) {
- ValidateLoopIndexExpr validate(mLoopStack);
- node->getRight()->traverse(&validate);
- if (validate.usesFloatLoopIndex()) {
- error(node->getLine(),
- "sampler array index is float loop index",
- "for");
- }
- }
- }
-#endif
validateIndexing(node);
break;
- default: break;
+ default:
+ break;
}
return true;
}
-bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node)
+bool ValidateLimitations::visitUnary(Visit, TIntermUnary *node)
{
// Check if loop index is modified in the loop body.
validateOperation(node, node->getOperand());
@@ -153,7 +82,7 @@
return true;
}
-bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node)
+bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate *node)
{
switch (node->getOp()) {
case EOpFunctionCall:
@@ -165,22 +94,20 @@
return true;
}
-bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node)
+bool ValidateLimitations::visitLoop(Visit, TIntermLoop *node)
{
if (!validateLoopType(node))
return false;
- TLoopInfo info;
- memset(&info, 0, sizeof(TLoopInfo));
- info.loop = node;
- if (!validateForLoopHeader(node, &info))
+ if (!validateForLoopHeader(node))
return false;
- TIntermNode* body = node->getBody();
- if (body != NULL) {
- mLoopStack.push_back(info);
+ TIntermNode *body = node->getBody();
+ if (body != NULL)
+ {
+ mLoopStack.push(node);
body->traverse(this);
- mLoopStack.pop_back();
+ mLoopStack.pop();
}
// The loop is fully processed - no need to visit children.
@@ -188,7 +115,7 @@
}
void ValidateLimitations::error(TSourceLoc loc,
- const char *reason, const char* token)
+ const char *reason, const char *token)
{
mSink.prefix(EPrefixError);
mSink.location(loc);
@@ -201,12 +128,13 @@
return !mLoopStack.empty();
}
-bool ValidateLimitations::isLoopIndex(const TIntermSymbol* symbol) const
+bool ValidateLimitations::isLoopIndex(TIntermSymbol *symbol)
{
- return IsLoopIndex(symbol, mLoopStack);
+ return mLoopStack.findLoop(symbol) != NULL;
}
-bool ValidateLimitations::validateLoopType(TIntermLoop* node) {
+bool ValidateLimitations::validateLoopType(TIntermLoop *node)
+{
TLoopType type = node->getType();
if (type == ELoopFor)
return true;
@@ -218,8 +146,7 @@
return false;
}
-bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node,
- TLoopInfo* info)
+bool ValidateLimitations::validateForLoopHeader(TIntermLoop *node)
{
ASSERT(node->getType() == ELoopFor);
@@ -227,74 +154,80 @@
// The for statement has the form:
// for ( init-declaration ; condition ; expression ) statement
//
- if (!validateForLoopInit(node, info))
+ int indexSymbolId = validateForLoopInit(node);
+ if (indexSymbolId < 0)
return false;
- if (!validateForLoopCond(node, info))
+ if (!validateForLoopCond(node, indexSymbolId))
return false;
- if (!validateForLoopExpr(node, info))
+ if (!validateForLoopExpr(node, indexSymbolId))
return false;
return true;
}
-bool ValidateLimitations::validateForLoopInit(TIntermLoop* node,
- TLoopInfo* info)
+int ValidateLimitations::validateForLoopInit(TIntermLoop *node)
{
- TIntermNode* init = node->getInit();
- if (init == NULL) {
+ TIntermNode *init = node->getInit();
+ if (init == NULL)
+ {
error(node->getLine(), "Missing init declaration", "for");
- return false;
+ return -1;
}
//
// init-declaration has the form:
// type-specifier identifier = constant-expression
//
- TIntermAggregate* decl = init->getAsAggregate();
- if ((decl == NULL) || (decl->getOp() != EOpDeclaration)) {
+ TIntermAggregate *decl = init->getAsAggregate();
+ if ((decl == NULL) || (decl->getOp() != EOpDeclaration))
+ {
error(init->getLine(), "Invalid init declaration", "for");
- return false;
+ return -1;
}
// To keep things simple do not allow declaration list.
- TIntermSequence& declSeq = decl->getSequence();
- if (declSeq.size() != 1) {
+ TIntermSequence &declSeq = decl->getSequence();
+ if (declSeq.size() != 1)
+ {
error(decl->getLine(), "Invalid init declaration", "for");
- return false;
+ return -1;
}
- TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
- if ((declInit == NULL) || (declInit->getOp() != EOpInitialize)) {
+ TIntermBinary *declInit = declSeq[0]->getAsBinaryNode();
+ if ((declInit == NULL) || (declInit->getOp() != EOpInitialize))
+ {
error(decl->getLine(), "Invalid init declaration", "for");
- return false;
+ return -1;
}
- TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
- if (symbol == NULL) {
+ TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
+ if (symbol == NULL)
+ {
error(declInit->getLine(), "Invalid init declaration", "for");
- return false;
+ return -1;
}
// The loop index has type int or float.
TBasicType type = symbol->getBasicType();
if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat)) {
error(symbol->getLine(),
"Invalid type for loop index", getBasicString(type));
- return false;
+ return -1;
}
// The loop index is initialized with constant expression.
- if (!isConstExpr(declInit->getRight())) {
+ if (!isConstExpr(declInit->getRight()))
+ {
error(declInit->getLine(),
"Loop index cannot be initialized with non-constant expression",
symbol->getSymbol().c_str());
- return false;
+ return -1;
}
- info->index.id = symbol->getId();
- return true;
+ return symbol->getId();
}
-bool ValidateLimitations::validateForLoopCond(TIntermLoop* node,
- TLoopInfo* info)
+bool ValidateLimitations::validateForLoopCond(TIntermLoop *node,
+ int indexSymbolId)
{
- TIntermNode* cond = node->getCondition();
- if (cond == NULL) {
+ TIntermNode *cond = node->getCondition();
+ if (cond == NULL)
+ {
error(node->getLine(), "Missing condition", "for");
return false;
}
@@ -302,24 +235,28 @@
// condition has the form:
// loop_index relational_operator constant_expression
//
- TIntermBinary* binOp = cond->getAsBinaryNode();
- if (binOp == NULL) {
+ TIntermBinary *binOp = cond->getAsBinaryNode();
+ if (binOp == NULL)
+ {
error(node->getLine(), "Invalid condition", "for");
return false;
}
// Loop index should be to the left of relational operator.
- TIntermSymbol* symbol = binOp->getLeft()->getAsSymbolNode();
- if (symbol == NULL) {
+ TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode();
+ if (symbol == NULL)
+ {
error(binOp->getLine(), "Invalid condition", "for");
return false;
}
- if (symbol->getId() != info->index.id) {
+ if (symbol->getId() != indexSymbolId)
+ {
error(symbol->getLine(),
"Expected loop index", symbol->getSymbol().c_str());
return false;
}
// Relational operator is one of: > >= < <= == or !=.
- switch (binOp->getOp()) {
+ switch (binOp->getOp())
+ {
case EOpEqual:
case EOpNotEqual:
case EOpLessThan:
@@ -334,7 +271,8 @@
break;
}
// Loop index must be compared with a constant.
- if (!isConstExpr(binOp->getRight())) {
+ if (!isConstExpr(binOp->getRight()))
+ {
error(binOp->getLine(),
"Loop index cannot be compared with non-constant expression",
symbol->getSymbol().c_str());
@@ -344,11 +282,12 @@
return true;
}
-bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node,
- TLoopInfo* info)
+bool ValidateLimitations::validateForLoopExpr(TIntermLoop *node,
+ int indexSymbolId)
{
- TIntermNode* expr = node->getExpression();
- if (expr == NULL) {
+ TIntermNode *expr = node->getExpression();
+ if (expr == NULL)
+ {
error(node->getLine(), "Missing expression", "for");
return false;
}
@@ -362,50 +301,58 @@
// --loop_index
// The last two forms are not specified in the spec, but I am assuming
// its an oversight.
- TIntermUnary* unOp = expr->getAsUnaryNode();
- TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode();
+ TIntermUnary *unOp = expr->getAsUnaryNode();
+ TIntermBinary *binOp = unOp ? NULL : expr->getAsBinaryNode();
TOperator op = EOpNull;
- TIntermSymbol* symbol = NULL;
- if (unOp != NULL) {
+ TIntermSymbol *symbol = NULL;
+ if (unOp != NULL)
+ {
op = unOp->getOp();
symbol = unOp->getOperand()->getAsSymbolNode();
- } else if (binOp != NULL) {
+ }
+ else if (binOp != NULL)
+ {
op = binOp->getOp();
symbol = binOp->getLeft()->getAsSymbolNode();
}
// The operand must be loop index.
- if (symbol == NULL) {
+ if (symbol == NULL)
+ {
error(expr->getLine(), "Invalid expression", "for");
return false;
}
- if (symbol->getId() != info->index.id) {
+ if (symbol->getId() != indexSymbolId)
+ {
error(symbol->getLine(),
"Expected loop index", symbol->getSymbol().c_str());
return false;
}
// The operator is one of: ++ -- += -=.
- switch (op) {
- case EOpPostIncrement:
- case EOpPostDecrement:
- case EOpPreIncrement:
- case EOpPreDecrement:
- ASSERT((unOp != NULL) && (binOp == NULL));
- break;
- case EOpAddAssign:
- case EOpSubAssign:
- ASSERT((unOp == NULL) && (binOp != NULL));
- break;
- default:
- error(expr->getLine(), "Invalid operator", getOperatorString(op));
- return false;
+ switch (op)
+ {
+ case EOpPostIncrement:
+ case EOpPostDecrement:
+ case EOpPreIncrement:
+ case EOpPreDecrement:
+ ASSERT((unOp != NULL) && (binOp == NULL));
+ break;
+ case EOpAddAssign:
+ case EOpSubAssign:
+ ASSERT((unOp == NULL) && (binOp != NULL));
+ break;
+ default:
+ error(expr->getLine(), "Invalid operator", getOperatorString(op));
+ return false;
}
// Loop index must be incremented/decremented with a constant.
- if (binOp != NULL) {
- if (!isConstExpr(binOp->getRight())) {
+ if (binOp != NULL)
+ {
+ if (!isConstExpr(binOp->getRight()))
+ {
error(binOp->getLine(),
"Loop index cannot be modified by non-constant expression",
symbol->getSymbol().c_str());
@@ -416,7 +363,7 @@
return true;
}
-bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node)
+bool ValidateLimitations::validateFunctionCall(TIntermAggregate *node)
{
ASSERT(node->getOp() == EOpFunctionCall);
@@ -428,8 +375,9 @@
typedef std::vector<size_t> ParamIndex;
ParamIndex pIndex;
TIntermSequence& params = node->getSequence();
- for (TIntermSequence::size_type i = 0; i < params.size(); ++i) {
- TIntermSymbol* symbol = params[i]->getAsSymbolNode();
+ for (TIntermSequence::size_type i = 0; i < params.size(); ++i)
+ {
+ TIntermSymbol *symbol = params[i]->getAsSymbolNode();
if (symbol && isLoopIndex(symbol))
pIndex.push_back(i);
}
@@ -442,12 +390,14 @@
TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable;
TSymbol* symbol = symbolTable.find(node->getName(), GetGlobalParseContext()->shaderVersion);
ASSERT(symbol && symbol->isFunction());
- TFunction* function = static_cast<TFunction*>(symbol);
+ TFunction *function = static_cast<TFunction *>(symbol);
for (ParamIndex::const_iterator i = pIndex.begin();
- i != pIndex.end(); ++i) {
- const TParameter& param = function->getParam(*i);
+ i != pIndex.end(); ++i)
+ {
+ const TParameter &param = function->getParam(*i);
TQualifier qual = param.type->getQualifier();
- if ((qual == EvqOut) || (qual == EvqInOut)) {
+ if ((qual == EvqOut) || (qual == EvqInOut))
+ {
error(params[*i]->getLine(),
"Loop index cannot be used as argument to a function out or inout parameter",
params[*i]->getAsSymbolNode()->getSymbol().c_str());
@@ -458,14 +408,16 @@
return valid;
}
-bool ValidateLimitations::validateOperation(TIntermOperator* node,
- TIntermNode* operand) {
+bool ValidateLimitations::validateOperation(TIntermOperator *node,
+ TIntermNode* operand)
+{
// Check if loop index is modified in the loop body.
if (!withinLoopBody() || !node->isAssignment())
return true;
- const TIntermSymbol* symbol = operand->getAsSymbolNode();
- if (symbol && isLoopIndex(symbol)) {
+ TIntermSymbol *symbol = operand->getAsSymbolNode();
+ if (symbol && isLoopIndex(symbol))
+ {
error(node->getLine(),
"Loop index cannot be statically assigned to within the body of the loop",
symbol->getSymbol().c_str());
@@ -473,13 +425,13 @@
return true;
}
-bool ValidateLimitations::isConstExpr(TIntermNode* node)
+bool ValidateLimitations::isConstExpr(TIntermNode *node)
{
ASSERT(node != NULL);
return node->getAsConstantUnion() != NULL;
}
-bool ValidateLimitations::isConstIndexExpr(TIntermNode* node)
+bool ValidateLimitations::isConstIndexExpr(TIntermNode *node)
{
ASSERT(node != NULL);
@@ -488,13 +440,13 @@
return validate.isValid();
}
-bool ValidateLimitations::validateIndexing(TIntermBinary* node)
+bool ValidateLimitations::validateIndexing(TIntermBinary *node)
{
ASSERT((node->getOp() == EOpIndexDirect) ||
(node->getOp() == EOpIndexIndirect));
bool valid = true;
- TIntermTyped* index = node->getRight();
+ TIntermTyped *index = node->getRight();
// The index expression must have integral type.
if (!index->isScalarInt()) {
error(index->getLine(),
@@ -504,10 +456,11 @@
}
// The index expession must be a constant-index-expression unless
// the operand is a uniform in a vertex shader.
- TIntermTyped* operand = node->getLeft();
+ TIntermTyped *operand = node->getLeft();
bool skip = (mShaderType == SH_VERTEX_SHADER) &&
(operand->getQualifier() == EvqUniform);
- if (!skip && !isConstIndexExpr(index)) {
+ if (!skip && !isConstIndexExpr(index))
+ {
error(index->getLine(), "Index expression must be constant", "[]");
valid = false;
}