SkSL support for static ifs & switches

Bug: skia:
Change-Id: Ic9e01a3a18efddb19bab26e92bfb473cad294fc1
Reviewed-on: https://skia-review.googlesource.com/16144
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: Ben Wagner <benjaminwagner@google.com>
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index c6140d3..a283e30 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -764,12 +764,70 @@
     }
 }
 
+
+// Reports whether running this statement may execute a break aimed at the current loop/switch.
+// Only blocks and if-statements (both branches) are searched; nested loops and switches are
+// skipped because a break inside them merely exits that inner construct.
+static bool contains_break(Statement& s) {
+    if (s.fKind == Statement::kBreak_Kind) {
+        return true;
+    }
+    if (s.fKind == Statement::kIf_Kind) {
+        const IfStatement& ifStmt = (IfStatement&) s;
+        return contains_break(*ifStmt.fIfTrue) ||
+               (ifStmt.fIfFalse && contains_break(*ifStmt.fIfFalse));
+    }
+    if (s.fKind == Statement::kBlock_Kind) {
+        for (const auto& child : ((Block&) s).fStatements) {
+            if (contains_break(*child)) {
+                return true;
+            }
+        }
+    }
+    return false;
+}
+
+// Returns a block of every statement that runs when case 'c' matches: its statements followed by
+// those of the cases it falls through into, up to the first top-level break. The statements are
+// moved out of the switch, which must then be discarded. Returns null, leaving the switch intact,
+// if a captured statement contains a nested break (e.g. in a conditional or an inner block).
+static std::unique_ptr<Statement> block_for_case(SwitchStatement* s, SwitchCase* c) {
+    bool capturing = false;
+    std::vector<std::unique_ptr<Statement>*> statementPtrs;
+    for (const auto& current : s->fCases) {
+        if (current.get() == c) {
+            capturing = true;
+        }
+        if (capturing) {
+            for (auto& stmt : current->fStatements) {
+                if (stmt->fKind == Statement::kBreak_Kind) {
+                    capturing = false;
+                    break;
+                }
+                if (contains_break(*stmt)) {
+                    return nullptr;
+                }
+                statementPtrs.push_back(&stmt);
+            }
+            if (!capturing) {
+                break;
+            }
+        }
+    }
+    // Move only after the scan succeeds, so the early null return leaves the switch untouched.
+    std::vector<std::unique_ptr<Statement>> statements;
+    for (std::unique_ptr<Statement>* stmtPtr : statementPtrs) {
+        statements.push_back(std::move(*stmtPtr));
+    }
+    return std::unique_ptr<Statement>(new Block(Position(), std::move(statements)));
+}
+
 void Compiler::simplifyStatement(DefinitionMap& definitions,
-                                  BasicBlock& b,
-                                  std::vector<BasicBlock::Node>::iterator* iter,
-                                  std::unordered_set<const Variable*>* undefinedVariables,
-                                  bool* outUpdated,
-                                  bool* outNeedsRescan) {
+                                 BasicBlock& b,
+                                 std::vector<BasicBlock::Node>::iterator* iter,
+                                 std::unordered_set<const Variable*>* undefinedVariables,
+                                 bool* outUpdated,
+                                 bool* outNeedsRescan) {
     Statement* stmt = (*iter)->statement()->get();
     switch (stmt->fKind) {
         case Statement::kVarDeclarations_Kind: {
@@ -798,6 +856,22 @@
         }
         case Statement::kIf_Kind: {
             IfStatement& i = (IfStatement&) *stmt;
+            if (i.fTest->fKind == Expression::kBoolLiteral_Kind) {
+                // constant if, collapse down to a single branch
+                if (((BoolLiteral&) *i.fTest).fValue) {
+                    ASSERT(i.fIfTrue);
+                    (*iter)->setStatement(std::move(i.fIfTrue));
+                } else {
+                    if (i.fIfFalse) {
+                        (*iter)->setStatement(std::move(i.fIfFalse));
+                    } else {
+                        (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
+                    }
+                }
+                *outUpdated = true;
+                *outNeedsRescan = true;
+                break;
+            }
             if (i.fIfFalse && i.fIfFalse->isEmpty()) {
                 // else block doesn't do anything, remove it
                 i.fIfFalse.reset();
@@ -820,6 +894,57 @@
             }
             break;
         }
+        case Statement::kSwitch_Kind: {
+            SwitchStatement& s = (SwitchStatement&) *stmt;
+            if (s.fValue->isConstant()) {
+                // switch is constant, replace it with the case that matches
+                bool found = false;
+                SwitchCase* defaultCase = nullptr;
+                for (const auto& c : s.fCases) {
+                    if (!c->fValue) {
+                        defaultCase = c.get();
+                        continue;
+                    }
+                    ASSERT(c->fValue->fKind == s.fValue->fKind);
+                    found = c->fValue->compareConstant(fContext, *s.fValue);
+                    if (found) {
+                        std::unique_ptr<Statement> newBlock = block_for_case(&s, c.get());
+                        if (newBlock) {
+                            (*iter)->setStatement(std::move(newBlock));
+                            break;
+                        } else {
+                            if (s.fIsStatic) {
+                                this->error(s.fPosition,
+                                            "static switch contains non-static conditional break");
+                                s.fIsStatic = false;
+                            }
+                            return; // can't simplify
+                        }
+                    }
+                }
+                if (!found) {
+                    // no matching case. use default if it exists, or kill the whole thing
+                    if (defaultCase) {
+                        std::unique_ptr<Statement> newBlock = block_for_case(&s, defaultCase);
+                        if (newBlock) {
+                            (*iter)->setStatement(std::move(newBlock));
+                        } else {
+                            if (s.fIsStatic) {
+                                this->error(s.fPosition,
+                                            "static switch contains non-static conditional break");
+                                s.fIsStatic = false;
+                            }
+                            return; // can't simplify
+                        }
+                    } else {
+                        (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
+                    }
+                }
+                *outUpdated = true;
+                *outNeedsRescan = true;
+            }
+            break;
+        }
         case Statement::kExpression_Kind: {
             ExpressionStatement& e = (ExpressionStatement&) *stmt;
             ASSERT((*iter)->statement()->get() == &e);
@@ -892,6 +1017,31 @@
     } while (updated);
     ASSERT(!needsRescan);
 
+    // verify static ifs & switches
+    for (BasicBlock& b : cfg.fBlocks) {
+        DefinitionMap definitions = b.fBefore;
+
+        for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan; ++iter) {
+            if (iter->fKind == BasicBlock::Node::kStatement_Kind) {
+                const Statement& s = **iter->statement();
+                switch (s.fKind) {
+                    case Statement::kIf_Kind:
+                        if (((const IfStatement&) s).fIsStatic) {
+                            this->error(s.fPosition, "static if has non-static test");
+                        }
+                        break;
+                    case Statement::kSwitch_Kind:
+                        if (((const SwitchStatement&) s).fIsStatic) {
+                            this->error(s.fPosition, "static switch has non-static test");
+                        }
+                        break;
+                    default:
+                        break;
+                }
+            }
+        }
+    }
+
     // check for missing return
     if (f.fDeclaration.fReturnType != *fContext.fVoid_Type) {
         if (cfg.fBlocks[cfg.fExit].fEntrances.size()) {