// blob: 41897a0488e3fb64e792231554830aa620156989 [file] [log] [blame]
//
// Copyright 2020 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
#include <unordered_map>
#include "common/system_utils.h"
#include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
#include "compiler/translator/TranslatorMetalDirect/SeparateCompoundExpressions.h"
#include "compiler/translator/tree_ops/SimplifyLoopConditions.h"
#include "compiler/translator/tree_util/IntermRebuild.h"
using namespace sh;
////////////////////////////////////////////////////////////////////////////////
namespace
{
// Returns true if `op` is one of the index operators: direct, indirect, struct-field or
// interface-block-field access.
bool IsIndex(TOperator op)
{
    return op == TOperator::EOpIndexDirect || op == TOperator::EOpIndexIndirect ||
           op == TOperator::EOpIndexDirectStruct ||
           op == TOperator::EOpIndexDirectInterfaceBlock;
}
// Returns true if `expr` is a binary index expression (see IsIndex(TOperator)) or a swizzle.
bool IsIndex(TIntermTyped &expr)
{
    TIntermBinary *const binary = expr.getAsBinaryNode();
    if (binary == nullptr)
    {
        // Swizzles select components, so they are grouped with indexing here.
        return expr.getAsSwizzleNode() != nullptr;
    }
    return IsIndex(binary->getOp());
}
// If `node` is a binary `op` expression, appends the leaves of the `op` chain rooted at it to
// `out` and returns true. Leaves are appended left to right, descending through operands that
// are themselves `op` nodes. Otherwise `out` is left untouched and false is returned.
bool ViewBinaryChain(TOperator op, TIntermTyped &node, std::vector<TIntermTyped *> &out)
{
    TIntermBinary *const binary = node.getAsBinaryNode();
    if (binary == nullptr || binary->getOp() != op)
    {
        return false;
    }

    TIntermTyped *const operands[] = {binary->getLeft(), binary->getRight()};
    for (TIntermTyped *operand : operands)
    {
        // An operand that is not itself an `op` link is a leaf of the chain.
        if (!ViewBinaryChain(op, *operand, out))
        {
            out.push_back(operand);
        }
    }
    return true;
}
// Flattens the maximal chain of `node.getOp()` rooted at `node` into its operands, left to
// right. Since `node` is itself a link of that chain, at least two operands are produced.
std::vector<TIntermTyped *> ViewBinaryChain(TIntermBinary &node)
{
    const TOperator chainOp = node.getOp();
    std::vector<TIntermTyped *> operands;
    ViewBinaryChain(chainOp, node, operands);
    ASSERT(operands.size() >= 2);
    return operands;
}
// First pass: rewrites every chain of `&&` (or of `||`) into right-associated form, so the
// Separator pass always sees `x OP (y OP z)` regardless of the original grouping.
// This is sound for these operators: `(x OP y) OP z` and `x OP (y OP z)` evaluate the same
// operands in the same left-to-right order and short-circuit at the same point.
class PrePass : public TIntermRebuild
{
  public:
    explicit PrePass(TCompiler &compiler) : TIntermRebuild(compiler, true, true) {}

  private:
    // Change chains of
    //   x OP y OP z
    // to
    //   x OP (y OP z)
    // regardless of original parenthesization.
    //
    // Every operand of the chain, including the last one, is rebuilt first. This ensures
    // nested chains of the other logical operator are reassociated too, e.g. an `||` chain
    // used as the last operand of an `&&` chain.
    TIntermTyped &reassociateRight(TIntermBinary &node)
    {
        const TOperator op                = node.getOp();
        std::vector<TIntermTyped *> chain = ViewBinaryChain(node);

        TIntermTyped *result = rebuildOperand(*chain.back());
        chain.pop_back();

        // Fold the remaining operands right to left: result = part OP result.
        for (auto iter = chain.rbegin(); iter != chain.rend(); ++iter)
        {
            TIntermTyped *part = *iter;
            ASSERT(part);
            result = new TIntermBinary(op, rebuildOperand(*part), result);
        }
        return *result;
    }

    // Runs this pass over one chain operand and returns the resulting typed node.
    TIntermTyped *rebuildOperand(TIntermTyped &operand)
    {
        TIntermNode *rebuilt = rebuild(operand).single();
        ASSERT(rebuilt);
        TIntermTyped *typed = rebuilt->getAsTyped();
        ASSERT(typed);
        return typed;
    }

    PreResult visitBinaryPre(TIntermBinary &node) override
    {
        const TOperator op = node.getOp();
        if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
        {
            // reassociateRight already rebuilt every operand, so the replacement is not
            // traversed again.
            return {reassociateRight(node), VisitBits::Neither};
        }
        return node;
    }
};
// Second pass: splits compound expressions into sequences of simpler statements.
//
// Every scope being rebuilt owns a statement list on mStmtsStack. A scope is either a block,
// or one operand/branch of a `&&`, `||` or `?:` being lowered. The visit callbacks append to
// the innermost list in the order they run.
//
// A sub-expression whose value is consumed by its parent is generally bound to a fresh
// temporary (`T tN = expr;`, see pushBinding); mExprMap then tells the parent to read the
// temporary instead. `&&`, `||` and `?:` are lowered to if/else statements, so the operands
// they evaluate conditionally are only evaluated inside the corresponding branch.
class Separator : public TIntermRebuild
{
    // Source of names for the temporaries introduced by this pass.
    IdGen &mIdGen;
    // One list of already-separated statements per open scope; the innermost scope is last.
    std::vector<std::vector<TIntermNode *>> mStmtsStack;
    // Parallel to mStmtsStack: the temporaries declared in that scope by pushBinding, each
    // mapped to the declaration that introduced it. pullMappedExpr uses this to fold a
    // temporary back into its use.
    std::vector<std::unordered_map<const TVariable *, TIntermDeclaration *>> mBindingMapStack;
    // Original expression node -> expression its parent must use in its place. Entries are
    // consumed (erased) by pullMappedExpr.
    std::unordered_map<TIntermTyped *, TIntermTyped *> mExprMap;
    // Temporary declarations whose initializer was folded into its use. pushStmtsIntoBlock
    // drops these instead of emitting them.
    // NOTE(review): only <unordered_map> is included explicitly at the top of the file;
    // std::unordered_set presumably arrives transitively. Consider including it directly.
    std::unordered_set<TIntermDeclaration *> mMaskedDecls;

  public:
    // NOTE(review): `symbolEnv` is accepted but not used by this class.
    Separator(TCompiler &compiler, SymbolEnv &symbolEnv, IdGen &idGen)
        : TIntermRebuild(compiler, true, true), mIdGen(idGen)
    {}

    // By the end of the rebuild, every scope must have been closed and every recorded
    // replacement consumed by its parent.
    ~Separator() override
    {
        ASSERT(mStmtsStack.empty());
        ASSERT(mExprMap.empty());
        ASSERT(mBindingMapStack.empty());
    }

  private:
    // Statement list of the innermost open scope.
    std::vector<TIntermNode *> &getCurrStmts()
    {
        ASSERT(!mStmtsStack.empty());
        return mStmtsStack.back();
    }

    // Temporary-binding map of the innermost open scope.
    std::unordered_map<const TVariable *, TIntermDeclaration *> &getCurrBindingMap()
    {
        ASSERT(!mBindingMapStack.empty());
        return mBindingMapStack.back();
    }

    // Appends `node` to the innermost scope's statement list.
    void pushStmt(TIntermNode &node) { getCurrStmts().push_back(&node); }

    // Returns true for symbol and constant-union nodes. Declarations initialized with such an
    // expression, and case labels, are emitted without further separation.
    bool isTerminalExpr(TIntermNode &node)
    {
        NodeType nodeType = getNodeType(node);
        switch (nodeType)
        {
            case NodeType::Symbol:
            case NodeType::ConstantUnion:
                return true;
            default:
                return false;
        }
    }

    // Returns the expression the parent of `node` must use in place of `node`: either the
    // replacement recorded in mExprMap (the entry is erased), or `node` itself when nothing
    // was recorded. A null `node` is never a key, so it is returned unchanged.
    //
    // With `allowBacktrack`, if the replacement is a temporary found in the innermost scope's
    // binding map, that temporary's declaration is masked (so it is not emitted) and its
    // initializer is used instead. This repeats while the result is again such a temporary,
    // e.g. `T t0 = f(); x = t0;` becomes `x = f();`.
    TIntermTyped *pullMappedExpr(TIntermTyped *node, bool allowBacktrack)
    {
        TIntermTyped *expr;
        {
            auto iter = mExprMap.find(node);
            if (iter == mExprMap.end())
            {
                return node;
            }
            ASSERT(node);
            expr = iter->second;
            ASSERT(expr);
            mExprMap.erase(iter);
        }
        if (allowBacktrack)
        {
            auto &bindingMap = getCurrBindingMap();
            while (TIntermSymbol *symbol = expr->getAsSymbolNode())
            {
                const TVariable &var = symbol->variable();
                auto iter            = bindingMap.find(&var);
                if (iter == bindingMap.end())
                {
                    return expr;
                }
                // Only temporaries created by pushBinding live in the binding map.
                ASSERT(var.symbolType() == SymbolType::AngleInternal);
                TIntermDeclaration *decl = iter->second;
                ASSERT(decl);
                expr = ViewDeclaration(*decl).initExpr;
                ASSERT(expr);
                // Fold the temporary into its use: forget it and suppress its declaration.
                bindingMap.erase(iter);
                mMaskedDecls.insert(decl);
            }
        }
        return expr;
    }

    // Returns true if the node currently being visited is an expression statement, i.e. its
    // parent is a block and its value is not consumed. Only the parent of the current node is
    // inspected. `expr` is used solely to assert that a consumed value is not void.
    bool isStandaloneExpr(TIntermTyped &expr)
    {
        if (getParentNode()->getAsBlock())
        {
            return true;
        }
        ASSERT(expr.getType().getBasicType() != TBasicType::EbtVoid);
        return false;
    }

    // Records how the parent sees the value of `oldExpr`, whose separated form is `newExpr`:
    //  - for an expression statement, `newExpr` is simply emitted;
    //  - index and swizzle expressions are not bound to a temporary; the parent uses
    //    `newExpr` directly (presumably so they stay usable as l-values -- TODO confirm);
    //  - otherwise `T tN = newExpr;` is emitted in the innermost scope, registered in its
    //    binding map, and the parent will read `tN`.
    void pushBinding(TIntermTyped &oldExpr, TIntermTyped &newExpr)
    {
        if (isStandaloneExpr(newExpr))
        {
            pushStmt(newExpr);
            return;
        }
        if (IsIndex(newExpr))
        {
            mExprMap[&oldExpr] = &newExpr;
            return;
        }
        auto &bindingMap = getCurrBindingMap();
        const Name name  = mIdGen.createNewName();
        auto *var =
            new TVariable(&mSymbolTable, name.rawName(), &newExpr.getType(), name.symbolType());
        auto *decl = new TIntermDeclaration(var, &newExpr);
        pushStmt(*decl);
        mExprMap[&oldExpr] = new TIntermSymbol(var);
        bindingMap[var]    = decl;
    }

    // Opens a new scope: an empty statement list plus an empty binding map.
    void pushStacks()
    {
        mStmtsStack.emplace_back();
        mBindingMapStack.emplace_back();
    }

    // Closes the innermost scope. Its statements must already have been moved out.
    void popStacks()
    {
        ASSERT(!mBindingMapStack.empty());
        ASSERT(!mStmtsStack.empty());
        ASSERT(mStmtsStack.back().empty());
        mBindingMapStack.pop_back();
        mStmtsStack.pop_back();
    }

    // Appends `stmts` to `block`, skipping (and un-registering) masked temporary declarations.
    void pushStmtsIntoBlock(TIntermBlock &block, std::vector<TIntermNode *> &stmts)
    {
        TIntermSequence &seq = *block.getSequence();
        for (TIntermNode *stmt : stmts)
        {
            if (TIntermDeclaration *decl = stmt->getAsDeclarationNode())
            {
                auto iter = mMaskedDecls.find(decl);
                if (iter != mMaskedDecls.end())
                {
                    mMaskedDecls.erase(iter);
                    continue;
                }
            }
            seq.push_back(stmt);
        }
    }

    // Closes the innermost scope and returns a new block holding its statements followed by
    // `var = newExpr;`. Callers must pull `newExpr` from that scope before calling, because
    // the scope's binding map is discarded here.
    TIntermBlock &buildBlockWithTailAssign(const TVariable &var, TIntermTyped &newExpr)
    {
        std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
        popStacks();
        auto &block = *new TIntermBlock();
        auto &seq   = *block.getSequence();
        seq.reserve(1 + stmts.size());
        pushStmtsIntoBlock(block, stmts);
        seq.push_back(new TIntermBinary(TOperator::EOpAssign, new TIntermSymbol(&var), &newExpr));
        return block;
    }

  private:
    // Each block gets its own scope.
    PreResult visitBlockPre(TIntermBlock &node) override
    {
        pushStacks();
        return node;
    }

    // Replaces the block's contents with the separated statements collected for it. A block
    // nested directly in another block is also emitted as a statement of the enclosing block.
    // For any other parent, the returned node is used as-is.
    PostResult visitBlockPost(TIntermBlock &node) override
    {
        std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
        popStacks();
        TIntermSequence &seq = *node.getSequence();
        seq.clear();
        seq.reserve(stmts.size());
        pushStmtsIntoBlock(node, stmts);
        TIntermNode *parent = getParentNode();
        if (parent && parent->getAsBlock())
        {
            pushStmt(node);
        }
        return node;
    }

    // Declarations without an initializer, or with a symbol/constant initializer, need no
    // separation. They are emitted unchanged and their children are not visited.
    PreResult visitDeclarationPre(TIntermDeclaration &node) override
    {
        Declaration decl = ViewDeclaration(node);
        if (!decl.initExpr || isTerminalExpr(*decl.initExpr))
        {
            pushStmt(node);
            return {node, VisitBits::Neither};
        }
        return node;
    }

    // Emits the declaration with its separated initializer. Backtracking lets a temporary that
    // only carries the initializer's value be folded back, e.g. `T t0 = f(); T x = t0;`
    // becomes `T x = f();`.
    PostResult visitDeclarationPost(TIntermDeclaration &node) override
    {
        Declaration decl = ViewDeclaration(node);
        ASSERT(decl.symbol.variable().symbolType() != SymbolType::Empty);
        ASSERT(!decl.symbol.variable().getType().isStructSpecifier());
        TIntermTyped *newInitExpr = pullMappedExpr(decl.initExpr, true);
        if (decl.initExpr == newInitExpr)
        {
            pushStmt(node);
        }
        else
        {
            auto &newNode = *new TIntermDeclaration();
            newNode.appendDeclarator(
                new TIntermBinary(TOperator::EOpInitialize, &decl.symbol, newInitExpr));
            pushStmt(newNode);
        }
        return node;
    }

    // Rebuilds the unary if its operand was replaced, then binds the result.
    PostResult visitUnaryPost(TIntermUnary &node) override
    {
        TIntermTyped *expr    = node.getOperand();
        TIntermTyped *newExpr = pullMappedExpr(expr, false);
        if (expr == newExpr)
        {
            pushBinding(node, node);
        }
        else
        {
            pushBinding(node, *new TIntermUnary(node.getOp(), newExpr, node.getFunction()));
        }
        return node;
    }

    // `&&` and `||` are traversed manually. The left operand is separated into the current
    // scope; the right operand is separated into a new scope, because it is only evaluated
    // conditionally. That scope is closed in visitBinaryPost.
    PreResult visitBinaryPre(TIntermBinary &node) override
    {
        const TOperator op = node.getOp();
        if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
        {
            TIntermTyped *left      = node.getLeft();
            TIntermTyped *right     = node.getRight();
            PostResult leftResult   = rebuild(*left);
            ASSERT(leftResult.single());
            pushStacks();
            PostResult rightResult  = rebuild(*right);
            ASSERT(rightResult.single());
            // Request only the post-visit; the operands have already been rebuilt above.
            return {node, VisitBits::Post};
        }
        return node;
    }

    // `&&` / `||` are lowered to
    //     bool tN = <left>;
    //     if (tN)            // `if (!tN)` for ||
    //     {
    //         <statements of right>
    //         tN = <right>;
    //     }
    // and the parent reads `tN`, unless this is an expression statement.
    // A comma yields its right operand; the left operand's statements were already emitted.
    // Every other binary is rebuilt from its separated operands and bound via pushBinding.
    PostResult visitBinaryPost(TIntermBinary &node) override
    {
        const TOperator op = node.getOp();
        if (op == TOperator::EOpInitialize && getParentNode()->getAsDeclarationNode())
        {
            // Special case is handled by visitDeclarationPost
            return node;
        }
        TIntermTyped *left  = node.getLeft();
        TIntermTyped *right = node.getRight();
        if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
        {
            const Name name = mIdGen.createNewName();
            auto *var = new TVariable(&mSymbolTable, name.rawName(), new TType(TBasicType::EbtBool),
                                      name.symbolType());
            // The right operand's scope is innermost, so it is pulled and closed before the
            // left operand is pulled from the enclosing scope.
            TIntermTyped *newRight   = pullMappedExpr(right, true);
            TIntermBlock *rightBlock = &buildBlockWithTailAssign(*var, *newRight);
            TIntermTyped *newLeft    = pullMappedExpr(left, true);
            TIntermTyped *cond       = new TIntermSymbol(var);
            if (op == TOperator::EOpLogicalOr)
            {
                cond = new TIntermUnary(TOperator::EOpLogicalNot, cond, nullptr);
            }
            pushStmt(*new TIntermDeclaration(var, newLeft));
            pushStmt(*new TIntermIfElse(cond, rightBlock, nullptr));
            if (!isStandaloneExpr(node))
            {
                mExprMap[&node] = new TIntermSymbol(var);
            }
            return node;
        }
        // Only the right-hand side of an assignment may be folded back (backtracking).
        const bool isAssign    = IsAssignment(op);
        TIntermTyped *newLeft  = pullMappedExpr(left, false);
        TIntermTyped *newRight = pullMappedExpr(right, isAssign);
        if (op == TOperator::EOpComma)
        {
            // `newLeft` is unused here, but pulling it above consumed its mExprMap entry.
            pushBinding(node, *newRight);
            return node;
        }
        else
        {
            TIntermBinary *newNode;
            if (left == newLeft && right == newRight)
            {
                newNode = &node;
            }
            else
            {
                newNode = new TIntermBinary(op, newLeft, newRight);
            }
            pushBinding(node, *newNode);
            return node;
        }
    }

    // `?:` is traversed manually. The condition is separated into the current scope; the true
    // and false expressions each get their own new scope, closed in visitTernaryPost.
    PreResult visitTernaryPre(TIntermTernary &node) override
    {
        PostResult condResult = rebuild(*node.getCondition());
        ASSERT(condResult.single());
        pushStacks();
        PostResult thenResult = rebuild(*node.getTrueExpression());
        ASSERT(thenResult.single());
        pushStacks();
        PostResult elseResult = rebuild(*node.getFalseExpression());
        ASSERT(elseResult.single());
        // Request only the post-visit; the children have already been rebuilt above.
        return {node, VisitBits::Post};
    }

    // Lowers `c ? a : b` to
    //     T tN;
    //     if (<c>) { <statements of a>  tN = <a>; } else { <statements of b>  tN = <b>; }
    // and the parent reads `tN`, unless this is an expression statement. Scopes are closed
    // innermost-first (false branch, then true branch); the condition is then pulled from
    // the enclosing scope.
    // NOTE(review): the false expression is pulled without backtracking, while the true
    // expression and the condition allow it. The asymmetry looks unintentional -- confirm.
    PostResult visitTernaryPost(TIntermTernary &node) override
    {
        TIntermTyped *cond  = node.getCondition();
        TIntermTyped *then  = node.getTrueExpression();
        TIntermTyped *else_ = node.getFalseExpression();
        const Name name     = mIdGen.createNewName();
        auto *var =
            new TVariable(&mSymbolTable, name.rawName(), &node.getType(), name.symbolType());
        TIntermTyped *newElse   = pullMappedExpr(else_, false);
        TIntermBlock *elseBlock = &buildBlockWithTailAssign(*var, *newElse);
        TIntermTyped *newThen   = pullMappedExpr(then, true);
        TIntermBlock *thenBlock = &buildBlockWithTailAssign(*var, *newThen);
        TIntermTyped *newCond   = pullMappedExpr(cond, true);
        pushStmt(*new TIntermDeclaration{var});
        pushStmt(*new TIntermIfElse(newCond, thenBlock, elseBlock));
        if (!isStandaloneExpr(node))
        {
            mExprMap[&node] = new TIntermSymbol(var);
        }
        return node;
    }

    // Rebuilds the swizzle if its operand was replaced, then binds the result.
    PostResult visitSwizzlePost(TIntermSwizzle &node) override
    {
        TIntermTyped *expr    = node.getOperand();
        TIntermTyped *newExpr = pullMappedExpr(expr, false);
        if (expr == newExpr)
        {
            pushBinding(node, node);
        }
        else
        {
            pushBinding(node, *new TIntermSwizzle(newExpr, node.getSwizzleOffsets()));
        }
        return node;
    }

    // Replaces each argument in place with its separated form, then binds the aggregate.
    PostResult visitAggregatePost(TIntermAggregate &node) override
    {
        TIntermSequence &args = *node.getSequence();
        for (TIntermNode *&arg : args)
        {
            TIntermTyped *targ = arg->getAsTyped();
            ASSERT(targ);
            arg = pullMappedExpr(targ, false);
        }
        pushBinding(node, node);
        return node;
    }

    // Emitted as-is.
    PostResult visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node) override
    {
        pushStmt(node);
        return node;
    }

    // Only prototypes visited outside a function (getParentFunction() is null) are emitted as
    // statements.
    PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &node) override
    {
        if (!getParentFunction())
        {
            pushStmt(node);
        }
        return node;
    }

    // Case labels are expected to be terminal expressions; the case is emitted as-is and its
    // children are not visited.
    PreResult visitCasePre(TIntermCase &node) override
    {
        if (TIntermTyped *cond = node.getCondition())
        {
            ASSERT(isTerminalExpr(*cond));
        }
        pushStmt(node);
        return {node, VisitBits::Neither};
    }

    // Emits the switch using its separated init expression. The init's own statements were
    // already emitted before it.
    PostResult visitSwitchPost(TIntermSwitch &node) override
    {
        TIntermTyped *init    = node.getInit();
        TIntermTyped *newInit = pullMappedExpr(init, false);
        if (init == newInit)
        {
            pushStmt(node);
        }
        else
        {
            pushStmt(*new TIntermSwitch(newInit, node.getStatementList()));
        }
        return node;
    }

    // Emitted as-is.
    PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &node) override
    {
        pushStmt(node);
        return node;
    }

    // Emits the if/else using its separated condition. The condition's own statements were
    // already emitted before it.
    PostResult visitIfElsePost(TIntermIfElse &node) override
    {
        TIntermTyped *cond    = node.getCondition();
        TIntermTyped *newCond = pullMappedExpr(cond, false);
        if (cond == newCond)
        {
            pushStmt(node);
        }
        else
        {
            pushStmt(*new TIntermIfElse(newCond, node.getTrueBlock(), node.getFalseBlock()));
        }
        return node;
    }

    // Emits the branch using its separated expression. A branch with no expression gets
    // nullptr back from pullMappedExpr and is emitted unchanged.
    PostResult visitBranchPost(TIntermBranch &node) override
    {
        TIntermTyped *expr    = node.getExpression();
        TIntermTyped *newExpr = pullMappedExpr(expr, false);
        if (expr == newExpr)
        {
            pushStmt(node);
        }
        else
        {
            pushStmt(*new TIntermBranch(node.getFlowOp(), newExpr));
        }
        return node;
    }

    // Only the loop body is rebuilt, and the loop's other children are not visited. This
    // presumably relies on SimplifyLoopConditions (run before this pass in
    // SeparateCompoundExpressions) having made the loop condition/expression simple --
    // TODO confirm.
    PreResult visitLoopPre(TIntermLoop &node) override
    {
        if (!rebuildInPlace(*node.getBody()))
        {
            UNREACHABLE();
        }
        pushStmt(node);
        return {node, VisitBits::Neither};
    }

    // Non-scalar constants are bound like any other expression. Scalar constants stay inline
    // in their parent, so a standalone scalar constant statement is not emitted.
    PostResult visitConstantUnionPost(TIntermConstantUnion &node) override
    {
        const TType &type = node.getType();
        if (!type.isScalar())
        {
            pushBinding(node, node);
        }
        return node;
    }

    PostResult visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration &node) override
    {
        ASSERT(false);  // These should be scrubbed from AST before rewriter is called.
        pushStmt(node);
        return node;
    }
};
} // anonymous namespace
////////////////////////////////////////////////////////////////////////////////
// Splits compound expressions in `root` into separate statements. The stages are
// SimplifyLoopConditions, then PrePass (right-reassociation of && / || chains), then
// Separator. Returns false as soon as any stage fails.
bool sh::SeparateCompoundExpressions(TCompiler &compiler,
                                     SymbolEnv &symbolEnv,
                                     IdGen &idGen,
                                     TIntermBlock &root)
{
    // Setting this environment variable skips the transformation entirely (reported as
    // success).
    const bool disabled =
        angle::GetBoolEnvironmentVar("GMT_DISABLE_SEPARATE_COMPOUND_EXPRESSIONS");
    if (disabled)
    {
        return true;
    }

    // Each stage runs only if the previous one succeeded. The PrePass temporary is destroyed
    // at the end of this full-expression, before the Separator is constructed.
    const bool prepared = SimplifyLoopConditions(&compiler, &root, &compiler.getSymbolTable()) &&
                          PrePass(compiler).rebuildRoot(root);
    if (!prepared)
    {
        return false;
    }

    return Separator(compiler, symbolEnv, idGen).rebuildRoot(root);
}