Add as<ProgramElementSubclass> to downcast ProgramElements more safely.

The as<T>() function asserts that the ProgramElement is of the correct
kind before performing the downcast. It is also generally easier to
read, since function calls flow naturally from left to right, whereas
C-style casts don't.

This CL converts several downcasts throughout SkSL to the as<T>
syntax, but does not attempt to replace all of them (although that
would be ideal). Where we SkASSERTed the element's fKind immediately
before a cast, the assert has been removed, because it would be
redundant with the check that as<T>() already performs.

Change-Id: I89a487aeaf56e56c720479fee0c2633377a202f1
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/312020
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 9fdee06..fa5d174 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -952,12 +952,12 @@
         }
         for (const auto& e : fProgram) {
             if (ProgramElement::kVar_Kind == e.fKind) {
-                VarDeclarations& decls = (VarDeclarations&) e;
+                const VarDeclarations& decls = e.as<VarDeclarations>();
                 if (!decls.fVars.size()) {
                     continue;
                 }
                 for (const auto& stmt: decls.fVars) {
-                    VarDeclaration& var = (VarDeclaration&) *stmt;
+                    VarDeclaration& var = stmt->as<VarDeclaration>();
                     if (var.fVar->fType.kind() == Type::kSampler_Kind) {
                         if (var.fVar->fModifiers.fLayout.fBinding < 0) {
                             fErrors.error(decls.fOffset,
@@ -1392,11 +1392,11 @@
 void MetalCodeGenerator::writeUniformStruct() {
     for (const auto& e : fProgram) {
         if (ProgramElement::kVar_Kind == e.fKind) {
-            VarDeclarations& decls = (VarDeclarations&) e;
+            const VarDeclarations& decls = e.as<VarDeclarations>();
             if (!decls.fVars.size()) {
                 continue;
             }
-            const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
+            const Variable& first = *decls.fVars[0]->as<VarDeclaration>().fVar;
             if (first.fModifiers.fFlags & Modifiers::kUniform_Flag &&
                 first.fType.kind() != Type::kSampler_Kind) {
                 if (-1 == fUniformBuffer) {
@@ -1415,7 +1415,7 @@
                 this->writeType(first.fType);
                 this->write(" ");
                 for (const auto& stmt : decls.fVars) {
-                    VarDeclaration& var = (VarDeclaration&) *stmt;
+                    const VarDeclaration& var = stmt->as<VarDeclaration>();
                     this->writeName(var.fVar->fName);
                 }
                 this->write(";\n");
@@ -1431,18 +1431,18 @@
     this->write("struct Inputs {\n");
     for (const auto& e : fProgram) {
         if (ProgramElement::kVar_Kind == e.fKind) {
-            VarDeclarations& decls = (VarDeclarations&) e;
+            const VarDeclarations& decls = e.as<VarDeclarations>();
             if (!decls.fVars.size()) {
                 continue;
             }
-            const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
+            const Variable& first = *decls.fVars[0]->as<VarDeclaration>().fVar;
             if (first.fModifiers.fFlags & Modifiers::kIn_Flag &&
                 -1 == first.fModifiers.fLayout.fBuiltin) {
                 this->write("    ");
                 this->writeType(first.fType);
                 this->write(" ");
                 for (const auto& stmt : decls.fVars) {
-                    VarDeclaration& var = (VarDeclaration&) *stmt;
+                    const VarDeclaration& var = stmt->as<VarDeclaration>();
                     this->writeName(var.fVar->fName);
                     if (-1 != var.fVar->fModifiers.fLayout.fLocation) {
                         if (fProgram.fKind == Program::kVertex_Kind) {
@@ -1470,18 +1470,18 @@
     }
     for (const auto& e : fProgram) {
         if (ProgramElement::kVar_Kind == e.fKind) {
-            VarDeclarations& decls = (VarDeclarations&) e;
+            const VarDeclarations& decls = e.as<VarDeclarations>();
             if (!decls.fVars.size()) {
                 continue;
             }
-            const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
+            const Variable& first = *decls.fVars[0]->as<VarDeclaration>().fVar;
             if (first.fModifiers.fFlags & Modifiers::kOut_Flag &&
                 -1 == first.fModifiers.fLayout.fBuiltin) {
                 this->write("    ");
                 this->writeType(first.fType);
                 this->write(" ");
                 for (const auto& stmt : decls.fVars) {
-                    VarDeclaration& var = (VarDeclaration&) *stmt;
+                    const VarDeclaration& var = stmt->as<VarDeclaration>();
                     this->writeName(var.fVar->fName);
                     if (fProgram.fKind == Program::kVertex_Kind) {
                         this->write("  [[user(locn" +
@@ -1510,7 +1510,7 @@
     bool wroteInterfaceBlock = false;
     for (const auto& e : fProgram) {
         if (ProgramElement::kInterfaceBlock_Kind == e.fKind) {
-            this->writeInterfaceBlock((InterfaceBlock&) e);
+            this->writeInterfaceBlock(e.as<InterfaceBlock>());
             wroteInterfaceBlock = true;
         }
     }
@@ -1662,9 +1662,9 @@
         case ProgramElement::kExtension_Kind:
             break;
         case ProgramElement::kVar_Kind: {
-            VarDeclarations& decl = (VarDeclarations&) e;
+            const VarDeclarations& decl = e.as<VarDeclarations>();
             if (decl.fVars.size() > 0) {
-                int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin;
+                int builtin = decl.fVars[0]->as<VarDeclaration>().fVar->fModifiers.fLayout.fBuiltin;
                 if (-1 == builtin) {
                     // normal var
                     this->writeVarDeclarations(decl, true);
@@ -1679,10 +1679,10 @@
             // handled in writeInterfaceBlocks, do nothing
             break;
         case ProgramElement::kFunction_Kind:
-            this->writeFunction((FunctionDefinition&) e);
+            this->writeFunction(e.as<FunctionDefinition>());
             break;
         case ProgramElement::kModifiers_Kind:
-            this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true);
+            this->writeModifiers(e.as<ModifiersDeclaration>().fModifiers, true);
             this->writeLine(";");
             break;
         default:
@@ -1699,7 +1699,7 @@
     }
     switch (e->fKind) {
         case Expression::kFunctionCall_Kind: {
-            const FunctionCall& f = (const FunctionCall&) *e;
+            const FunctionCall& f = e->as<FunctionCall>();
             Requirements result = this->requirements(f.fFunction);
             for (const auto& arg : f.fArguments) {
                 result |= this->requirements(arg.get());
@@ -1707,7 +1707,7 @@
             return result;
         }
         case Expression::kConstructor_Kind: {
-            const Constructor& c = (const Constructor&) *e;
+            const Constructor& c = e->as<Constructor>();
             Requirements result = kNo_Requirements;
             for (const auto& arg : c.fArguments) {
                 result |= this->requirements(arg.get());
@@ -1715,33 +1715,33 @@
             return result;
         }
         case Expression::kFieldAccess_Kind: {
-            const FieldAccess& f = (const FieldAccess&) *e;
+            const FieldAccess& f = e->as<FieldAccess>();
             if (FieldAccess::kAnonymousInterfaceBlock_OwnerKind == f.fOwnerKind) {
                 return kGlobals_Requirement;
             }
             return this->requirements(f.fBase.get());
         }
         case Expression::kSwizzle_Kind:
-            return this->requirements(((const Swizzle&) *e).fBase.get());
+            return this->requirements(e->as<Swizzle>().fBase.get());
         case Expression::kBinary_Kind: {
-            const BinaryExpression& b = (const BinaryExpression&) *e;
+            const BinaryExpression& b = e->as<BinaryExpression>();
             return this->requirements(b.fLeft.get()) | this->requirements(b.fRight.get());
         }
         case Expression::kIndex_Kind: {
-            const IndexExpression& idx = (const IndexExpression&) *e;
+            const IndexExpression& idx = e->as<IndexExpression>();
             return this->requirements(idx.fBase.get()) | this->requirements(idx.fIndex.get());
         }
         case Expression::kPrefix_Kind:
-            return this->requirements(((const PrefixExpression&) *e).fOperand.get());
+            return this->requirements(e->as<PrefixExpression>().fOperand.get());
         case Expression::kPostfix_Kind:
-            return this->requirements(((const PostfixExpression&) *e).fOperand.get());
+            return this->requirements(e->as<PostfixExpression>().fOperand.get());
         case Expression::kTernary_Kind: {
-            const TernaryExpression& t = (const TernaryExpression&) *e;
+            const TernaryExpression& t = e->as<TernaryExpression>();
             return this->requirements(t.fTest.get()) | this->requirements(t.fIfTrue.get()) |
                    this->requirements(t.fIfFalse.get());
         }
         case Expression::kVariableReference_Kind: {
-            const VariableReference& v = (const VariableReference&) *e;
+            const VariableReference& v = e->as<VariableReference>();
             Requirements result = kNo_Requirements;
             if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
                 result = kGlobals_Requirement | kFragCoord_Requirement;
@@ -1771,54 +1771,54 @@
     switch (s->fKind) {
         case Statement::kBlock_Kind: {
             Requirements result = kNo_Requirements;
-            for (const auto& child : ((const Block*) s)->fStatements) {
+            for (const auto& child : s->as<Block>().fStatements) {
                 result |= this->requirements(child.get());
             }
             return result;
         }
         case Statement::kVarDeclaration_Kind: {
-            const VarDeclaration& var = (const VarDeclaration&) *s;
+            const VarDeclaration& var = s->as<VarDeclaration>();
             return this->requirements(var.fValue.get());
         }
         case Statement::kVarDeclarations_Kind: {
             Requirements result = kNo_Requirements;
-            const VarDeclarations& decls = *((const VarDeclarationsStatement&) *s).fDeclaration;
+            const VarDeclarations& decls = *s->as<VarDeclarationsStatement>().fDeclaration;
             for (const auto& stmt : decls.fVars) {
                 result |= this->requirements(stmt.get());
             }
             return result;
         }
         case Statement::kExpression_Kind:
-            return this->requirements(((const ExpressionStatement&) *s).fExpression.get());
+            return this->requirements(s->as<ExpressionStatement>().fExpression.get());
         case Statement::kReturn_Kind: {
-            const ReturnStatement& r = (const ReturnStatement&) *s;
+            const ReturnStatement& r = s->as<ReturnStatement>();
             return this->requirements(r.fExpression.get());
         }
         case Statement::kIf_Kind: {
-            const IfStatement& i = (const IfStatement&) *s;
+            const IfStatement& i = s->as<IfStatement>();
             return this->requirements(i.fTest.get()) |
                    this->requirements(i.fIfTrue.get()) |
                    this->requirements(i.fIfFalse.get());
         }
         case Statement::kFor_Kind: {
-            const ForStatement& f = (const ForStatement&) *s;
+            const ForStatement& f = s->as<ForStatement>();
             return this->requirements(f.fInitializer.get()) |
                    this->requirements(f.fTest.get()) |
                    this->requirements(f.fNext.get()) |
                    this->requirements(f.fStatement.get());
         }
         case Statement::kWhile_Kind: {
-            const WhileStatement& w = (const WhileStatement&) *s;
+            const WhileStatement& w = s->as<WhileStatement>();
             return this->requirements(w.fTest.get()) |
                    this->requirements(w.fStatement.get());
         }
         case Statement::kDo_Kind: {
-            const DoStatement& d = (const DoStatement&) *s;
+            const DoStatement& d = s->as<DoStatement>();
             return this->requirements(d.fTest.get()) |
                    this->requirements(d.fStatement.get());
         }
         case Statement::kSwitch_Kind: {
-            const SwitchStatement& sw = (const SwitchStatement&) *s;
+            const SwitchStatement& sw = s->as<SwitchStatement>();
             Requirements result = this->requirements(sw.fValue.get());
             for (const auto& c : sw.fCases) {
                 for (const auto& st : c->fStatements) {
@@ -1841,7 +1841,7 @@
         fRequirements[&f] = kNo_Requirements;
         for (const auto& e : fProgram) {
             if (ProgramElement::kFunction_Kind == e.fKind) {
-                const FunctionDefinition& def = (const FunctionDefinition&) e;
+                const FunctionDefinition& def = e.as<FunctionDefinition>();
                 if (&def.fDeclaration == &f) {
                     Requirements reqs = this->requirements(def.fBody.get());
                     fRequirements[&f] = reqs;