Added dead variable and dead code elimination to skslc.

BUG=skia:

Change-Id: Ib037730803a8f222f099de0e001fe06ad452a22c
Reviewed-on: https://skia-review.googlesource.com/7584
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: Ben Wagner <benjaminwagner@google.com>
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index 743745a..92261a4 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -14,9 +14,12 @@
 #include "SkSLParser.h"
 #include "SkSLSPIRVCodeGenerator.h"
 #include "ir/SkSLExpression.h"
+#include "ir/SkSLExpressionStatement.h"
 #include "ir/SkSLIntLiteral.h"
 #include "ir/SkSLModifiersDeclaration.h"
+#include "ir/SkSLNop.h"
 #include "ir/SkSLSymbolTable.h"
+#include "ir/SkSLTernaryExpression.h"
 #include "ir/SkSLUnresolvedFunction.h"
 #include "ir/SkSLVarDeclarations.h"
 #include "SkMutex.h"
@@ -233,22 +236,30 @@
                                        p->fOperand.get(),
                                        (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
                                        definitions);
-
                     }
                     break;
                 }
+                case Expression::kVariableReference_Kind: {
+                    const VariableReference* v = (VariableReference*) expr;
+                    if (v->fRefKind != VariableReference::kRead_RefKind) {
+                        this->addDefinition(
+                                       v,
+                                       (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
+                                       definitions);
+                    }
+                }
                 default:
                     break;
             }
             break;
         }
         case BasicBlock::Node::kStatement_Kind: {
-            const Statement* stmt = (Statement*) node.fStatement;
+            const Statement* stmt = (Statement*) node.fStatement->get();
             if (stmt->fKind == Statement::kVarDeclarations_Kind) {
                 VarDeclarationsStatement* vd = (VarDeclarationsStatement*) stmt;
-                for (VarDeclaration& decl : vd->fDeclaration->fVars) {
-                    if (decl.fValue) {
-                        (*definitions)[decl.fVar] = &decl.fValue;
+                for (const auto& decl : vd->fDeclaration->fVars) {
+                    if (decl->fValue) {
+                        (*definitions)[decl->fVar] = &decl->fValue;
                     }
                 }
             }
@@ -298,11 +309,11 @@
         for (const auto& node : block.fNodes) {
             if (node.fKind == BasicBlock::Node::kStatement_Kind) {
                 ASSERT(node.fStatement);
-                const Statement* s = node.fStatement;
+                const Statement* s = node.fStatement->get();
                 if (s->fKind == Statement::kVarDeclarations_Kind) {
                     const VarDeclarationsStatement* vd = (const VarDeclarationsStatement*) s;
-                    for (const VarDeclaration& decl : vd->fDeclaration->fVars) {
-                        result[decl.fVar] = nullptr;
+                    for (const auto& decl : vd->fDeclaration->fVars) {
+                        result[decl->fVar] = nullptr;
                     }
                 }
             }
@@ -311,20 +322,74 @@
     return result;
 }
 
-void Compiler::scanCFG(const FunctionDefinition& f) {
-    CFG cfg = CFGGenerator().getCFG(f);
+/**
+ * Returns true if a write through this lvalue has no observable effect and can be dropped.
+ */
+static bool is_dead(const Expression& lvalue) {
+    if (lvalue.fKind == Expression::kVariableReference_Kind) {
+        return ((VariableReference&) lvalue).fVariable.dead();
+    }
+    if (lvalue.fKind == Expression::kSwizzle_Kind) {
+        return is_dead(*((Swizzle&) lvalue).fBase);
+    }
+    if (lvalue.fKind == Expression::kFieldAccess_Kind) {
+        return is_dead(*((FieldAccess&) lvalue).fBase);
+    }
+    if (lvalue.fKind == Expression::kIndex_Kind) {
+        const IndexExpression& idx = (IndexExpression&) lvalue;
+        return is_dead(*idx.fBase) && !idx.fIndex->hasSideEffects();
+    }
+    SkDebugf("invalid lvalue: %s\n", lvalue.description().c_str());
+    ASSERT(false);
+    return false;
+}
 
-    // compute the data flow
-    cfg.fBlocks[cfg.fStart].fBefore = compute_start_state(cfg);
+/**
+ * Returns true if `b` is an assignment to a dead target: the store can be discarded, and at
+ * most the right-hand side still has to be evaluated. A target whose index expressions have
+ * side effects is never treated as dead.
+ */
+static bool dead_assignment(const BinaryExpression& b) {
+    // Only assignment operators have a target whose store could be eliminated.
+    return Token::IsAssignment(b.fOperator) &&
+           is_dead(*b.fLeft);
+}
+/** Data-flow pass: seeds the start block's fBefore, then runs scanCFG over a block worklist. */
+void Compiler::computeDataFlow(CFG* cfg) {
+    cfg->fBlocks[cfg->fStart].fBefore = compute_start_state(*cfg);
     std::set<BlockId> workList;
-    for (BlockId i = 0; i < cfg.fBlocks.size(); i++) {
+    for (BlockId i = 0; i < cfg->fBlocks.size(); i++) {
         workList.insert(i);
     }
     while (workList.size()) {
         BlockId next = *workList.begin();
         workList.erase(workList.begin());
-        this->scanCFG(&cfg, next, &workList);
+        this->scanCFG(cfg, next, &workList);  // may add blocks back onto workList
     }
+}
+
+
+/**
+ * Replaces the expression referenced by *iter with newExpression. The IR is always updated.
+ * The CFG is kept in sync only if the old node can be cleanly removed from `b` and the new
+ * one cleanly inserted; then this returns true and *iter points at the newly-inserted node.
+ * A false return means the CFG no longer matches the IR and must be regenerated.
+ */
+static bool SK_WARN_UNUSED_RESULT try_replace_expression(
+        BasicBlock* b,
+        std::vector<BasicBlock::Node>::iterator* iter,
+        std::unique_ptr<Expression> newExpression) {
+    std::unique_ptr<Expression>* target = (*iter)->fExpression;
+    // Detach from the CFG while the old tree is still alive; the assignment below frees it.
+    const bool removed = b->tryRemoveExpression(iter);
+    *target = std::move(newExpression);
+    // Re-inserting is only meaningful if the CFG still matches the IR.
+    return removed && b->tryInsertExpression(iter, target);
+}
+
+void Compiler::scanCFG(FunctionDefinition& f) {
+    CFG cfg = CFGGenerator().getCFG(f);
+    this->computeDataFlow(&cfg);
 
     // check for unreachable code
     for (size_t i = 0; i < cfg.fBlocks.size(); i++) {
@@ -333,7 +398,7 @@
             Position p;
             switch (cfg.fBlocks[i].fNodes[0].fKind) {
                 case BasicBlock::Node::kStatement_Kind:
-                    p = cfg.fBlocks[i].fNodes[0].fStatement->fPosition;
+                    p = (*cfg.fBlocks[i].fNodes[0].fStatement)->fPosition;
                     break;
                 case BasicBlock::Node::kExpression_Kind:
                     p = (*cfg.fBlocks[i].fNodes[0].fExpression)->fPosition;
@@ -346,33 +411,170 @@
         return;
     }
 
-    // check for undefined variables, perform constant propagation
-    for (BasicBlock& b : cfg.fBlocks) {
-        DefinitionMap definitions = b.fBefore;
-        for (BasicBlock::Node& n : b.fNodes) {
-            if (n.fKind == BasicBlock::Node::kExpression_Kind) {
-                ASSERT(n.fExpression);
-                Expression* expr = n.fExpression->get();
-                if (n.fConstantPropagation) {
-                    std::unique_ptr<Expression> optimized = expr->constantPropagate(*fIRGenerator,
-                                                                                    definitions);
-                    if (optimized) {
-                        n.fExpression->reset(optimized.release());
-                        expr = n.fExpression->get();
-                    }
-                }
-                if (expr->fKind == Expression::kVariableReference_Kind) {
-                    const Variable& var = ((VariableReference*) expr)->fVariable;
-                    if (var.fStorage == Variable::kLocal_Storage &&
-                        !definitions[&var]) {
-                        this->error(expr->fPosition,
-                                    "'" + var.fName + "' has not been assigned");
-                    }
-                }
-            }
-            this->addDefinitions(n, &definitions);
+    // check for dead code & undefined variables, perform constant propagation
+    std::unordered_set<const Variable*> undefinedVariables;
+    bool updated;
+    bool needsRescan = false;
+    do {
+        if (needsRescan) {
+            cfg = CFGGenerator().getCFG(f);
+            this->computeDataFlow(&cfg);
+            needsRescan = false;
         }
-    }
+
+        updated = false;
+        for (BasicBlock& b : cfg.fBlocks) {
+            DefinitionMap definitions = b.fBefore;
+
+            for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan; ++iter) {
+                if (iter->fKind == BasicBlock::Node::kExpression_Kind) {
+                    Expression* expr = iter->fExpression->get();
+                    ASSERT(expr);
+                    if (iter->fConstantPropagation) {
+                        std::unique_ptr<Expression> optimized = expr->constantPropagate(
+                                                                                      *fIRGenerator,
+                                                                                      definitions);
+                        if (optimized) {
+                            if (!try_replace_expression(&b, &iter, std::move(optimized))) {
+                                needsRescan = true;
+                            }
+                            ASSERT(iter->fKind == BasicBlock::Node::kExpression_Kind);
+                            expr = iter->fExpression->get();
+                            updated = true;
+                        }
+                    }
+                    switch (expr->fKind) {
+                        case Expression::kVariableReference_Kind: {
+                            const Variable& var = ((VariableReference*) expr)->fVariable;
+                            if (var.fStorage == Variable::kLocal_Storage &&
+                                !definitions[&var] &&
+                                undefinedVariables.find(&var) == undefinedVariables.end()) {
+                                undefinedVariables.insert(&var);
+                                this->error(expr->fPosition,
+                                            "'" + var.fName + "' has not been assigned");
+                            }
+                            break;
+                        }
+                        case Expression::kTernary_Kind: {
+                            TernaryExpression* t = (TernaryExpression*) expr;
+                            if (t->fTest->fKind == Expression::kBoolLiteral_Kind) {
+                                // ternary has a constant test, replace it with either the true or
+                                // false branch
+                                if (((BoolLiteral&) *t->fTest).fValue) {
+                                    *iter->fExpression = std::move(t->fIfTrue);
+                                } else {
+                                    *iter->fExpression = std::move(t->fIfFalse);
+                                }
+                                updated = true;
+                                needsRescan = true;
+                            }
+                            break;
+                        }
+                        default:
+                            break;
+                    }
+                } else {
+                    ASSERT(iter->fKind == BasicBlock::Node::kStatement_Kind);
+                    Statement* stmt = iter->fStatement->get();
+                    switch (stmt->fKind) {
+                        case Statement::kVarDeclarations_Kind: {
+                            VarDeclarations& vd = *((VarDeclarationsStatement&) *stmt).fDeclaration;
+                            for (auto varIter = vd.fVars.begin(); varIter != vd.fVars.end(); ) {
+                                const auto& varDecl = **varIter;
+                                if (varDecl.fVar->dead() &&
+                                    (!varDecl.fValue ||
+                                     !varDecl.fValue->hasSideEffects())) {
+                                    if (varDecl.fValue) {
+                                        ASSERT(iter->fKind == BasicBlock::Node::kStatement_Kind &&
+                                               iter->fStatement->get() == stmt);
+                                        if (!b.tryRemoveExpressionBefore(
+                                                                        &iter,
+                                                                        varDecl.fValue.get())) {
+                                            needsRescan = true;
+                                        }
+                                    }
+                                    varIter = vd.fVars.erase(varIter);
+                                    updated = true;
+                                } else {
+                                    ++varIter;
+                                }
+                            }
+                            if (vd.fVars.size() == 0) {
+                                iter->fStatement->reset(new Nop());
+                            }
+                            break;
+                        }
+                        case Statement::kIf_Kind: {
+                            IfStatement& i = (IfStatement&) *stmt;
+                            if (i.fIfFalse && i.fIfFalse->isEmpty()) {
+                                // else block doesn't do anything, remove it
+                                i.fIfFalse.reset();
+                                updated = true;
+                                needsRescan = true;
+                            }
+                            if (!i.fIfFalse && i.fIfTrue->isEmpty()) {
+                                // if block doesn't do anything, no else block
+                                if (i.fTest->hasSideEffects()) {
+                                    // test has side effects, keep it
+                                    iter->fStatement->reset(new ExpressionStatement(
+                                                                               std::move(i.fTest)));
+                                } else {
+                                    // no if, no else, no test side effects, kill the whole if
+                                    // statement
+                                    iter->fStatement->reset(new Nop());
+                                }
+                                updated = true;
+                                needsRescan = true;
+                            }
+                            break;
+                        }
+                        case Statement::kExpression_Kind: {
+                            ExpressionStatement& e = (ExpressionStatement&) *stmt;
+                            ASSERT(iter->fStatement->get() == &e);
+                            if (e.fExpression->fKind == Expression::kBinary_Kind) {
+                                BinaryExpression& bin = (BinaryExpression&) *e.fExpression;
+                                if (dead_assignment(bin)) {
+                                    if (!b.tryRemoveExpressionBefore(&iter, &bin)) {
+                                        needsRescan = true;
+                                    }
+                                    if (bin.fRight->hasSideEffects()) {
+                                        // still have to evaluate the right due to side effects,
+                                        // replace the binary expression with just the right side
+                                        e.fExpression = std::move(bin.fRight);
+                                        if (!b.tryInsertExpression(&iter, &e.fExpression)) {
+                                            needsRescan = true;
+                                        }
+                                    } else {
+                                        // no side effects, kill the whole statement
+                                        ASSERT(iter->fKind == BasicBlock::Node::kStatement_Kind &&
+                                                              iter->fStatement->get() == stmt);
+                                        iter->fStatement->reset(new Nop());
+                                    }
+                                    updated = true;
+                                    break;
+                                }
+                            }
+                            if (!e.fExpression->hasSideEffects()) {
+                                // Expression statement with no side effects, kill it
+                                if (!b.tryRemoveExpressionBefore(&iter, e.fExpression.get())) {
+                                    needsRescan = true;
+                                }
+                                ASSERT(iter->fKind == BasicBlock::Node::kStatement_Kind &&
+                                       iter->fStatement->get() == stmt);
+                                iter->fStatement->reset(new Nop());
+                                updated = true;
+                            }
+                            break;
+                        }
+                        default:
+                            break;
+                    }
+                }
+                this->addDefinitions(*iter, &definitions);
+            }
+        }
+    } while (updated);
+    ASSERT(!needsRescan);
 
     // check for missing return
     if (f.fDeclaration.fReturnType != *fContext.fVoid_Type) {