Add as<ProgramElementSubclass> to downcast ProgramElements more safely.
The as<T>() function asserts that the ProgramElement is of the correct
kind before performing the downcast. It is also generally easier to
read: method calls flow naturally from left to right, whereas C-style
casts do not.
This CL converts several downcasts throughout SkSL to the as<T>
syntax, but it is not intended to replace all of them exhaustively
(although that would be ideal). Where we previously SkASSERTed the
element's fKind immediately before a cast, the assert has been removed,
because it is redundant with the check already performed by as<T>().
Change-Id: I89a487aeaf56e56c720479fee0c2633377a202f1
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/312020
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 9fdee06..fa5d174 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -952,12 +952,12 @@
}
for (const auto& e : fProgram) {
if (ProgramElement::kVar_Kind == e.fKind) {
- VarDeclarations& decls = (VarDeclarations&) e;
+ const VarDeclarations& decls = e.as<VarDeclarations>();
if (!decls.fVars.size()) {
continue;
}
for (const auto& stmt: decls.fVars) {
- VarDeclaration& var = (VarDeclaration&) *stmt;
+ VarDeclaration& var = stmt->as<VarDeclaration>();
if (var.fVar->fType.kind() == Type::kSampler_Kind) {
if (var.fVar->fModifiers.fLayout.fBinding < 0) {
fErrors.error(decls.fOffset,
@@ -1392,11 +1392,11 @@
void MetalCodeGenerator::writeUniformStruct() {
for (const auto& e : fProgram) {
if (ProgramElement::kVar_Kind == e.fKind) {
- VarDeclarations& decls = (VarDeclarations&) e;
+ const VarDeclarations& decls = e.as<VarDeclarations>();
if (!decls.fVars.size()) {
continue;
}
- const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
+ const Variable& first = *decls.fVars[0]->as<VarDeclaration>().fVar;
if (first.fModifiers.fFlags & Modifiers::kUniform_Flag &&
first.fType.kind() != Type::kSampler_Kind) {
if (-1 == fUniformBuffer) {
@@ -1415,7 +1415,7 @@
this->writeType(first.fType);
this->write(" ");
for (const auto& stmt : decls.fVars) {
- VarDeclaration& var = (VarDeclaration&) *stmt;
+ const VarDeclaration& var = stmt->as<VarDeclaration>();
this->writeName(var.fVar->fName);
}
this->write(";\n");
@@ -1431,18 +1431,18 @@
this->write("struct Inputs {\n");
for (const auto& e : fProgram) {
if (ProgramElement::kVar_Kind == e.fKind) {
- VarDeclarations& decls = (VarDeclarations&) e;
+ const VarDeclarations& decls = e.as<VarDeclarations>();
if (!decls.fVars.size()) {
continue;
}
- const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
+ const Variable& first = *decls.fVars[0]->as<VarDeclaration>().fVar;
if (first.fModifiers.fFlags & Modifiers::kIn_Flag &&
-1 == first.fModifiers.fLayout.fBuiltin) {
this->write(" ");
this->writeType(first.fType);
this->write(" ");
for (const auto& stmt : decls.fVars) {
- VarDeclaration& var = (VarDeclaration&) *stmt;
+ const VarDeclaration& var = stmt->as<VarDeclaration>();
this->writeName(var.fVar->fName);
if (-1 != var.fVar->fModifiers.fLayout.fLocation) {
if (fProgram.fKind == Program::kVertex_Kind) {
@@ -1470,18 +1470,18 @@
}
for (const auto& e : fProgram) {
if (ProgramElement::kVar_Kind == e.fKind) {
- VarDeclarations& decls = (VarDeclarations&) e;
+ const VarDeclarations& decls = e.as<VarDeclarations>();
if (!decls.fVars.size()) {
continue;
}
- const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
+ const Variable& first = *decls.fVars[0]->as<VarDeclaration>().fVar;
if (first.fModifiers.fFlags & Modifiers::kOut_Flag &&
-1 == first.fModifiers.fLayout.fBuiltin) {
this->write(" ");
this->writeType(first.fType);
this->write(" ");
for (const auto& stmt : decls.fVars) {
- VarDeclaration& var = (VarDeclaration&) *stmt;
+ const VarDeclaration& var = stmt->as<VarDeclaration>();
this->writeName(var.fVar->fName);
if (fProgram.fKind == Program::kVertex_Kind) {
this->write(" [[user(locn" +
@@ -1510,7 +1510,7 @@
bool wroteInterfaceBlock = false;
for (const auto& e : fProgram) {
if (ProgramElement::kInterfaceBlock_Kind == e.fKind) {
- this->writeInterfaceBlock((InterfaceBlock&) e);
+ this->writeInterfaceBlock(e.as<InterfaceBlock>());
wroteInterfaceBlock = true;
}
}
@@ -1662,9 +1662,9 @@
case ProgramElement::kExtension_Kind:
break;
case ProgramElement::kVar_Kind: {
- VarDeclarations& decl = (VarDeclarations&) e;
+ const VarDeclarations& decl = e.as<VarDeclarations>();
if (decl.fVars.size() > 0) {
- int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin;
+ int builtin = decl.fVars[0]->as<VarDeclaration>().fVar->fModifiers.fLayout.fBuiltin;
if (-1 == builtin) {
// normal var
this->writeVarDeclarations(decl, true);
@@ -1679,10 +1679,10 @@
// handled in writeInterfaceBlocks, do nothing
break;
case ProgramElement::kFunction_Kind:
- this->writeFunction((FunctionDefinition&) e);
+ this->writeFunction(e.as<FunctionDefinition>());
break;
case ProgramElement::kModifiers_Kind:
- this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true);
+ this->writeModifiers(e.as<ModifiersDeclaration>().fModifiers, true);
this->writeLine(";");
break;
default:
@@ -1699,7 +1699,7 @@
}
switch (e->fKind) {
case Expression::kFunctionCall_Kind: {
- const FunctionCall& f = (const FunctionCall&) *e;
+ const FunctionCall& f = e->as<FunctionCall>();
Requirements result = this->requirements(f.fFunction);
for (const auto& arg : f.fArguments) {
result |= this->requirements(arg.get());
@@ -1707,7 +1707,7 @@
return result;
}
case Expression::kConstructor_Kind: {
- const Constructor& c = (const Constructor&) *e;
+ const Constructor& c = e->as<Constructor>();
Requirements result = kNo_Requirements;
for (const auto& arg : c.fArguments) {
result |= this->requirements(arg.get());
@@ -1715,33 +1715,33 @@
return result;
}
case Expression::kFieldAccess_Kind: {
- const FieldAccess& f = (const FieldAccess&) *e;
+ const FieldAccess& f = e->as<FieldAccess>();
if (FieldAccess::kAnonymousInterfaceBlock_OwnerKind == f.fOwnerKind) {
return kGlobals_Requirement;
}
return this->requirements(f.fBase.get());
}
case Expression::kSwizzle_Kind:
- return this->requirements(((const Swizzle&) *e).fBase.get());
+ return this->requirements(e->as<Swizzle>().fBase.get());
case Expression::kBinary_Kind: {
- const BinaryExpression& b = (const BinaryExpression&) *e;
+ const BinaryExpression& b = e->as<BinaryExpression>();
return this->requirements(b.fLeft.get()) | this->requirements(b.fRight.get());
}
case Expression::kIndex_Kind: {
- const IndexExpression& idx = (const IndexExpression&) *e;
+ const IndexExpression& idx = e->as<IndexExpression>();
return this->requirements(idx.fBase.get()) | this->requirements(idx.fIndex.get());
}
case Expression::kPrefix_Kind:
- return this->requirements(((const PrefixExpression&) *e).fOperand.get());
+ return this->requirements(e->as<PrefixExpression>().fOperand.get());
case Expression::kPostfix_Kind:
- return this->requirements(((const PostfixExpression&) *e).fOperand.get());
+ return this->requirements(e->as<PostfixExpression>().fOperand.get());
case Expression::kTernary_Kind: {
- const TernaryExpression& t = (const TernaryExpression&) *e;
+ const TernaryExpression& t = e->as<TernaryExpression>();
return this->requirements(t.fTest.get()) | this->requirements(t.fIfTrue.get()) |
this->requirements(t.fIfFalse.get());
}
case Expression::kVariableReference_Kind: {
- const VariableReference& v = (const VariableReference&) *e;
+ const VariableReference& v = e->as<VariableReference>();
Requirements result = kNo_Requirements;
if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
result = kGlobals_Requirement | kFragCoord_Requirement;
@@ -1771,54 +1771,54 @@
switch (s->fKind) {
case Statement::kBlock_Kind: {
Requirements result = kNo_Requirements;
- for (const auto& child : ((const Block*) s)->fStatements) {
+ for (const auto& child : s->as<Block>().fStatements) {
result |= this->requirements(child.get());
}
return result;
}
case Statement::kVarDeclaration_Kind: {
- const VarDeclaration& var = (const VarDeclaration&) *s;
+ const VarDeclaration& var = s->as<VarDeclaration>();
return this->requirements(var.fValue.get());
}
case Statement::kVarDeclarations_Kind: {
Requirements result = kNo_Requirements;
- const VarDeclarations& decls = *((const VarDeclarationsStatement&) *s).fDeclaration;
+ const VarDeclarations& decls = *s->as<VarDeclarationsStatement>().fDeclaration;
for (const auto& stmt : decls.fVars) {
result |= this->requirements(stmt.get());
}
return result;
}
case Statement::kExpression_Kind:
- return this->requirements(((const ExpressionStatement&) *s).fExpression.get());
+ return this->requirements(s->as<ExpressionStatement>().fExpression.get());
case Statement::kReturn_Kind: {
- const ReturnStatement& r = (const ReturnStatement&) *s;
+ const ReturnStatement& r = s->as<ReturnStatement>();
return this->requirements(r.fExpression.get());
}
case Statement::kIf_Kind: {
- const IfStatement& i = (const IfStatement&) *s;
+ const IfStatement& i = s->as<IfStatement>();
return this->requirements(i.fTest.get()) |
this->requirements(i.fIfTrue.get()) |
this->requirements(i.fIfFalse.get());
}
case Statement::kFor_Kind: {
- const ForStatement& f = (const ForStatement&) *s;
+ const ForStatement& f = s->as<ForStatement>();
return this->requirements(f.fInitializer.get()) |
this->requirements(f.fTest.get()) |
this->requirements(f.fNext.get()) |
this->requirements(f.fStatement.get());
}
case Statement::kWhile_Kind: {
- const WhileStatement& w = (const WhileStatement&) *s;
+ const WhileStatement& w = s->as<WhileStatement>();
return this->requirements(w.fTest.get()) |
this->requirements(w.fStatement.get());
}
case Statement::kDo_Kind: {
- const DoStatement& d = (const DoStatement&) *s;
+ const DoStatement& d = s->as<DoStatement>();
return this->requirements(d.fTest.get()) |
this->requirements(d.fStatement.get());
}
case Statement::kSwitch_Kind: {
- const SwitchStatement& sw = (const SwitchStatement&) *s;
+ const SwitchStatement& sw = s->as<SwitchStatement>();
Requirements result = this->requirements(sw.fValue.get());
for (const auto& c : sw.fCases) {
for (const auto& st : c->fStatements) {
@@ -1841,7 +1841,7 @@
fRequirements[&f] = kNo_Requirements;
for (const auto& e : fProgram) {
if (ProgramElement::kFunction_Kind == e.fKind) {
- const FunctionDefinition& def = (const FunctionDefinition&) e;
+ const FunctionDefinition& def = e.as<FunctionDefinition>();
if (&def.fDeclaration == &f) {
Requirements reqs = this->requirements(def.fBody.get());
fRequirements[&f] = reqs;