moved SkSL Ternary data into IRNode
Change-Id: I70e63aaa73082024c8f0887a941d54cfd12aa2b6
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/323883
Reviewed-by: John Stiles <johnstiles@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index 0a7714e..7187f15 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -408,8 +408,8 @@
case Expression::Kind::kTernary: {
auto& t = e.template as<TernaryExpression>();
- return this->visitExpression(*t.fTest) || this->visitExpression(*t.fIfTrue) ||
- this->visitExpression(*t.fIfFalse);
+ return this->visitExpression(*t.test()) || this->visitExpression(*t.ifTrue()) ||
+ this->visitExpression(*t.ifFalse());
}
default:
SkUNREACHABLE;
diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp
index fa55aa9..9e2b778 100644
--- a/src/sksl/SkSLByteCodeGenerator.cpp
+++ b/src/sksl/SkSLByteCodeGenerator.cpp
@@ -1427,14 +1427,14 @@
void ByteCodeGenerator::writeTernaryExpression(const TernaryExpression& t) {
int count = SlotCount(t.type());
- SkASSERT(count == SlotCount(t.fIfTrue->type()));
- SkASSERT(count == SlotCount(t.fIfFalse->type()));
+ SkASSERT(count == SlotCount(t.ifTrue()->type()));
+ SkASSERT(count == SlotCount(t.ifFalse()->type()));
- this->writeExpression(*t.fTest);
+ this->writeExpression(*t.test());
this->write(ByteCodeInstruction::kMaskPush);
- this->writeExpression(*t.fIfTrue);
+ this->writeExpression(*t.ifTrue());
this->write(ByteCodeInstruction::kMaskNegate);
- this->writeExpression(*t.fIfFalse);
+ this->writeExpression(*t.ifFalse());
this->write(ByteCodeInstruction::kMaskBlend, count);
}
diff --git a/src/sksl/SkSLCFGGenerator.cpp b/src/sksl/SkSLCFGGenerator.cpp
index 084073c..cd6a17d 100644
--- a/src/sksl/SkSLCFGGenerator.cpp
+++ b/src/sksl/SkSLCFGGenerator.cpp
@@ -141,13 +141,13 @@
}
case Expression::Kind::kTernary: {
TernaryExpression& ternary = lvalue->as<TernaryExpression>();
- if (!this->tryRemoveExpressionBefore(iter, ternary.fTest.get())) {
+ if (!this->tryRemoveExpressionBefore(iter, ternary.test().get())) {
return false;
}
- if (!this->tryRemoveLValueBefore(iter, ternary.fIfTrue.get())) {
+ if (!this->tryRemoveLValueBefore(iter, ternary.ifTrue().get())) {
return false;
}
- return this->tryRemoveLValueBefore(iter, ternary.fIfFalse.get());
+ return this->tryRemoveLValueBefore(iter, ternary.ifFalse().get());
}
default:
#ifdef SK_DEBUG
@@ -413,15 +413,15 @@
break;
case Expression::Kind::kTernary: {
TernaryExpression& t = e->get()->as<TernaryExpression>();
- this->addExpression(cfg, &t.fTest, constantPropagate);
+ this->addExpression(cfg, &t.test(), constantPropagate);
cfg.currentBlock().fNodes.push_back(BasicBlock::MakeExpression(e, constantPropagate));
BlockId start = cfg.fCurrent;
cfg.newBlock();
- this->addExpression(cfg, &t.fIfTrue, constantPropagate);
+ this->addExpression(cfg, &t.ifTrue(), constantPropagate);
BlockId next = cfg.newBlock();
cfg.fCurrent = start;
cfg.newBlock();
- this->addExpression(cfg, &t.fIfFalse, constantPropagate);
+ this->addExpression(cfg, &t.ifFalse(), constantPropagate);
cfg.addExit(cfg.fCurrent, next);
cfg.fCurrent = next;
break;
@@ -454,12 +454,12 @@
break;
case Expression::Kind::kTernary: {
TernaryExpression& ternary = e->get()->as<TernaryExpression>();
- this->addExpression(cfg, &ternary.fTest, /*constantPropagate=*/true);
+ this->addExpression(cfg, &ternary.test(), /*constantPropagate=*/true);
// Technically we will of course only evaluate one or the other, but if the test turns
// out to be constant, the ternary will get collapsed down to just one branch anyway. So
// it should be ok to pretend that we always evaluate both branches here.
- this->addLValue(cfg, &ternary.fIfTrue);
- this->addLValue(cfg, &ternary.fIfFalse);
+ this->addLValue(cfg, &ternary.ifTrue());
+ this->addLValue(cfg, &ternary.ifFalse());
break;
}
default:
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index adbf374..85dbbc7 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -466,10 +466,10 @@
// To simplify analysis, we just pretend that we write to both sides of the ternary.
// This allows for false positives (meaning we fail to detect that a variable might not
// have been assigned), but is preferable to false negatives.
- this->addDefinition(lvalue->as<TernaryExpression>().fIfTrue.get(),
+ this->addDefinition(lvalue->as<TernaryExpression>().ifTrue().get(),
(std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
definitions);
- this->addDefinition(lvalue->as<TernaryExpression>().fIfFalse.get(),
+ this->addDefinition(lvalue->as<TernaryExpression>().ifFalse().get(),
(std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
definitions);
break;
@@ -633,7 +633,7 @@
}
case Expression::Kind::kTernary: {
const TernaryExpression& t = lvalue.as<TernaryExpression>();
- return !t.fTest->hasSideEffects() && is_dead(*t.fIfTrue) && is_dead(*t.fIfFalse);
+ return !t.test()->hasSideEffects() && is_dead(*t.ifTrue()) && is_dead(*t.ifFalse());
}
case Expression::Kind::kExternalValue:
return false;
@@ -926,13 +926,13 @@
}
case Expression::Kind::kTernary: {
TernaryExpression* t = &expr->as<TernaryExpression>();
- if (t->fTest->kind() == Expression::Kind::kBoolLiteral) {
+ if (t->test()->is<BoolLiteral>()) {
// ternary has a constant test, replace it with either the true or
// false branch
- if (t->fTest->as<BoolLiteral>().value()) {
- (*iter)->setExpression(std::move(t->fIfTrue));
+ if (t->test()->as<BoolLiteral>().value()) {
+ (*iter)->setExpression(std::move(t->ifTrue()));
} else {
- (*iter)->setExpression(std::move(t->fIfFalse));
+ (*iter)->setExpression(std::move(t->ifFalse()));
}
*outUpdated = true;
*outNeedsRescan = true;
diff --git a/src/sksl/SkSLDehydrator.cpp b/src/sksl/SkSLDehydrator.cpp
index 679e5ed..2b81eaa 100644
--- a/src/sksl/SkSLDehydrator.cpp
+++ b/src/sksl/SkSLDehydrator.cpp
@@ -370,9 +370,9 @@
case Expression::Kind::kTernary: {
const TernaryExpression& t = e->as<TernaryExpression>();
this->writeU8(Rehydrator::kTernary_Command);
- this->write(t.fTest.get());
- this->write(t.fIfTrue.get());
- this->write(t.fIfFalse.get());
+ this->write(t.test().get());
+ this->write(t.ifTrue().get());
+ this->write(t.ifFalse().get());
break;
}
case Expression::Kind::kVariableReference: {
diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp
index e259727..8eef942 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.cpp
+++ b/src/sksl/SkSLGLSLCodeGenerator.cpp
@@ -976,11 +976,11 @@
if (kTernary_Precedence >= parentPrecedence) {
this->write("(");
}
- this->writeExpression(*t.fTest, kTernary_Precedence);
+ this->writeExpression(*t.test(), kTernary_Precedence);
this->write(" ? ");
- this->writeExpression(*t.fIfTrue, kTernary_Precedence);
+ this->writeExpression(*t.ifTrue(), kTernary_Precedence);
this->write(" : ");
- this->writeExpression(*t.fIfFalse, kTernary_Precedence);
+ this->writeExpression(*t.ifFalse(), kTernary_Precedence);
if (kTernary_Precedence >= parentPrecedence) {
this->write(")");
}
diff --git a/src/sksl/SkSLInliner.cpp b/src/sksl/SkSLInliner.cpp
index c705dde..41f35f0 100644
--- a/src/sksl/SkSLInliner.cpp
+++ b/src/sksl/SkSLInliner.cpp
@@ -408,8 +408,8 @@
}
case Expression::Kind::kTernary: {
const TernaryExpression& t = expression.as<TernaryExpression>();
- return std::make_unique<TernaryExpression>(offset, expr(t.fTest),
- expr(t.fIfTrue), expr(t.fIfFalse));
+ return std::make_unique<TernaryExpression>(offset, expr(t.test()),
+ expr(t.ifTrue()), expr(t.ifFalse()));
}
case Expression::Kind::kTypeReference:
return expression.clone();
@@ -1043,7 +1043,7 @@
case Expression::Kind::kTernary: {
TernaryExpression& ternaryExpr = (*expr)->as<TernaryExpression>();
// The test expression is a candidate for inlining.
- this->visitExpression(&ternaryExpr.fTest);
+ this->visitExpression(&ternaryExpr.test());
// The true- and false-expressions cannot be inlined, because we are only allowed to
// evaluate one side.
break;
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index c5df516..980d499 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -876,11 +876,11 @@
if (kTernary_Precedence >= parentPrecedence) {
this->write("(");
}
- this->writeExpression(*t.fTest, kTernary_Precedence);
+ this->writeExpression(*t.test(), kTernary_Precedence);
this->write(" ? ");
- this->writeExpression(*t.fIfTrue, kTernary_Precedence);
+ this->writeExpression(*t.ifTrue(), kTernary_Precedence);
this->write(" : ");
- this->writeExpression(*t.fIfFalse, kTernary_Precedence);
+ this->writeExpression(*t.ifFalse(), kTernary_Precedence);
if (kTernary_Precedence >= parentPrecedence) {
this->write(")");
}
@@ -1711,8 +1711,8 @@
return this->requirements(e->as<PostfixExpression>().fOperand.get());
case Expression::Kind::kTernary: {
const TernaryExpression& t = e->as<TernaryExpression>();
- return this->requirements(t.fTest.get()) | this->requirements(t.fIfTrue.get()) |
- this->requirements(t.fIfFalse.get());
+ return this->requirements(t.test().get()) | this->requirements(t.ifTrue().get()) |
+ this->requirements(t.ifFalse().get());
}
case Expression::Kind::kVariableReference: {
const VariableReference& v = e->as<VariableReference>();
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index 9f45c5b..63a3543 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -1788,18 +1788,18 @@
}
case Expression::Kind::kTernary: {
TernaryExpression& t = (TernaryExpression&) expr;
- SpvId test = this->writeExpression(*t.fTest, out);
+ SpvId test = this->writeExpression(*t.test(), out);
SpvId end = this->nextId();
SpvId ifTrueLabel = this->nextId();
SpvId ifFalseLabel = this->nextId();
this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out);
this->writeLabel(ifTrueLabel, out);
- SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer();
+ SpvId ifTrue = this->getLValue(*t.ifTrue(), out)->getPointer();
SkASSERT(ifTrue);
this->writeInstruction(SpvOpBranch, end, out);
ifTrueLabel = fCurrentBlock;
- SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer();
+ SpvId ifFalse = this->getLValue(*t.ifFalse(), out)->getPointer();
SkASSERT(ifFalse);
ifFalseLabel = fCurrentBlock;
this->writeInstruction(SpvOpBranch, end, out);
@@ -2389,14 +2389,14 @@
SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
const Type& type = t.type();
- SpvId test = this->writeExpression(*t.fTest, out);
- if (t.fIfTrue->type().columns() == 1 &&
- t.fIfTrue->isCompileTimeConstant() &&
- t.fIfFalse->isCompileTimeConstant()) {
+ SpvId test = this->writeExpression(*t.test(), out);
+ if (t.ifTrue()->type().columns() == 1 &&
+ t.ifTrue()->isCompileTimeConstant() &&
+ t.ifFalse()->isCompileTimeConstant()) {
// both true and false are constants, can just use OpSelect
SpvId result = this->nextId();
- SpvId trueId = this->writeExpression(*t.fIfTrue, out);
- SpvId falseId = this->writeExpression(*t.fIfFalse, out);
+ SpvId trueId = this->writeExpression(*t.ifTrue(), out);
+ SpvId falseId = this->writeExpression(*t.ifFalse(), out);
this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
out);
return result;
@@ -2412,10 +2412,10 @@
this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
this->writeLabel(trueLabel, out);
- this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
+ this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifTrue(), out), out);
this->writeInstruction(SpvOpBranch, end, out);
this->writeLabel(falseLabel, out);
- this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
+ this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifFalse(), out), out);
this->writeInstruction(SpvOpBranch, end, out);
this->writeLabel(end, out);
SpvId result = this->nextId();
diff --git a/src/sksl/ir/SkSLBinaryExpression.h b/src/sksl/ir/SkSLBinaryExpression.h
index 5ef6240..dcfce09 100644
--- a/src/sksl/ir/SkSLBinaryExpression.h
+++ b/src/sksl/ir/SkSLBinaryExpression.h
@@ -31,7 +31,7 @@
return check_ref(*expr.as<Swizzle>().fBase);
case Expression::Kind::kTernary: {
const TernaryExpression& t = expr.as<TernaryExpression>();
- return check_ref(*t.fIfTrue) && check_ref(*t.fIfFalse);
+ return check_ref(*t.ifTrue()) && check_ref(*t.ifFalse());
}
case Expression::Kind::kVariableReference: {
const VariableReference& ref = expr.as<VariableReference>();
diff --git a/src/sksl/ir/SkSLTernaryExpression.h b/src/sksl/ir/SkSLTernaryExpression.h
index f590bf8..c5aef10 100644
--- a/src/sksl/ir/SkSLTernaryExpression.h
+++ b/src/sksl/ir/SkSLTernaryExpression.h
@@ -16,43 +16,66 @@
/**
* A ternary expression (test ? ifTrue : ifFalse).
*/
-struct TernaryExpression : public Expression {
+class TernaryExpression : public Expression {
+public:
static constexpr Kind kExpressionKind = Kind::kTernary;
TernaryExpression(int offset, std::unique_ptr<Expression> test,
std::unique_ptr<Expression> ifTrue, std::unique_ptr<Expression> ifFalse)
- : INHERITED(offset, kExpressionKind, &ifTrue->type())
- , fTest(std::move(test))
- , fIfTrue(std::move(ifTrue))
- , fIfFalse(std::move(ifFalse)) {
- SkASSERT(fIfTrue->type() == fIfFalse->type());
+ : INHERITED(offset, kExpressionKind, &ifTrue->type()) {
+ SkASSERT(ifTrue->type() == ifFalse->type());
+ fExpressionChildren.reserve(3);
+ fExpressionChildren.push_back(std::move(test));
+ fExpressionChildren.push_back(std::move(ifTrue));
+ fExpressionChildren.push_back(std::move(ifFalse));
+ }
+
+ std::unique_ptr<Expression>& test() {
+ return fExpressionChildren[0];
+ }
+
+ const std::unique_ptr<Expression>& test() const {
+ return fExpressionChildren[0];
+ }
+
+ std::unique_ptr<Expression>& ifTrue() {
+ return fExpressionChildren[1];
+ }
+
+ const std::unique_ptr<Expression>& ifTrue() const {
+ return fExpressionChildren[1];
+ }
+
+ std::unique_ptr<Expression>& ifFalse() {
+ return fExpressionChildren[2];
+ }
+
+ const std::unique_ptr<Expression>& ifFalse() const {
+ return fExpressionChildren[2];
}
bool hasProperty(Property property) const override {
- return fTest->hasProperty(property) || fIfTrue->hasProperty(property) ||
- fIfFalse->hasProperty(property);
+ return this->test()->hasProperty(property) || this->ifTrue()->hasProperty(property) ||
+ this->ifFalse()->hasProperty(property);
}
bool isConstantOrUniform() const override {
- return fTest->isConstantOrUniform() && fIfTrue->isConstantOrUniform() &&
- fIfFalse->isConstantOrUniform();
+ return this->test()->isConstantOrUniform() && this->ifTrue()->isConstantOrUniform() &&
+ this->ifFalse()->isConstantOrUniform();
}
std::unique_ptr<Expression> clone() const override {
- return std::unique_ptr<Expression>(new TernaryExpression(fOffset, fTest->clone(),
- fIfTrue->clone(),
- fIfFalse->clone()));
+ return std::unique_ptr<Expression>(new TernaryExpression(fOffset, this->test()->clone(),
+ this->ifTrue()->clone(),
+ this->ifFalse()->clone()));
}
String description() const override {
- return "(" + fTest->description() + " ? " + fIfTrue->description() + " : " +
- fIfFalse->description() + ")";
+ return "(" + this->test()->description() + " ? " + this->ifTrue()->description() + " : " +
+ this->ifFalse()->description() + ")";
}
- std::unique_ptr<Expression> fTest;
- std::unique_ptr<Expression> fIfTrue;
- std::unique_ptr<Expression> fIfFalse;
-
+private:
using INHERITED = Expression;
};