Improved skslc optimizer, particularly around vectors.

BUG=skia:

Change-Id: Idb364d9198f2ff84aad1eb68e236fb45ec1c86b7
Reviewed-on: https://skia-review.googlesource.com/8000
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: Ben Wagner <benjaminwagner@google.com>
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index ea87e99..e4ab700 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -14,9 +14,12 @@
 #include "SkSLParser.h"
 #include "SkSLSPIRVCodeGenerator.h"
 #include "ir/SkSLExpression.h"
+#include "ir/SkSLExpressionStatement.h"
 #include "ir/SkSLIntLiteral.h"
 #include "ir/SkSLModifiersDeclaration.h"
+#include "ir/SkSLNop.h"
 #include "ir/SkSLSymbolTable.h"
+#include "ir/SkSLTernaryExpression.h"
 #include "ir/SkSLUnresolvedFunction.h"
 #include "ir/SkSLVarDeclarations.h"
 
@@ -207,8 +210,8 @@
                               DefinitionMap* definitions) {
     switch (node.fKind) {
         case BasicBlock::Node::kExpression_Kind: {
-            ASSERT(node.fExpression);
-            const Expression* expr = (Expression*) node.fExpression->get();
+            ASSERT(node.expression());
+            const Expression* expr = (Expression*) node.expression()->get();
             switch (expr->fKind) {
                 case Expression::kBinary_Kind: {
                     BinaryExpression* b = (BinaryExpression*) expr;
@@ -240,22 +243,30 @@
                                        p->fOperand.get(),
                                        (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
                                        definitions);
-
                     }
                     break;
                 }
+                case Expression::kVariableReference_Kind: {
+                    const VariableReference* v = (VariableReference*) expr;
+                    if (v->fRefKind != VariableReference::kRead_RefKind) {
+                        this->addDefinition(
+                                       v,
+                                       (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
+                                       definitions);
+                    }
+                }
                 default:
                     break;
             }
             break;
         }
         case BasicBlock::Node::kStatement_Kind: {
-            const Statement* stmt = (Statement*) node.fStatement;
+            const Statement* stmt = (Statement*) node.statement()->get();
             if (stmt->fKind == Statement::kVarDeclarations_Kind) {
                 VarDeclarationsStatement* vd = (VarDeclarationsStatement*) stmt;
-                for (VarDeclaration& decl : vd->fDeclaration->fVars) {
-                    if (decl.fValue) {
-                        (*definitions)[decl.fVar] = &decl.fValue;
+                for (const auto& decl : vd->fDeclaration->fVars) {
+                    if (decl->fValue) {
+                        (*definitions)[decl->fVar] = &decl->fValue;
                     }
                 }
             }
@@ -308,12 +319,12 @@
     for (const auto& block : cfg.fBlocks) {
         for (const auto& node : block.fNodes) {
             if (node.fKind == BasicBlock::Node::kStatement_Kind) {
-                ASSERT(node.fStatement);
-                const Statement* s = node.fStatement;
+                ASSERT(node.statement());
+                const Statement* s = node.statement()->get();
                 if (s->fKind == Statement::kVarDeclarations_Kind) {
                     const VarDeclarationsStatement* vd = (const VarDeclarationsStatement*) s;
-                    for (const VarDeclaration& decl : vd->fDeclaration->fVars) {
-                        result[decl.fVar] = nullptr;
+                    for (const auto& decl : vd->fDeclaration->fVars) {
+                        result[decl->fVar] = nullptr;
                     }
                 }
             }
@@ -322,20 +333,290 @@
     return result;
 }
 
-void Compiler::scanCFG(const FunctionDefinition& f) {
-    CFG cfg = CFGGenerator().getCFG(f);
+/**
+ * Reports whether storing through this lvalue has no effect: the variable at the root of the lvalue
+ * reports dead(), and any index expression on the way down has no side effects. Aborts if the
+ * expression is not one of the supported lvalue kinds (variable reference, swizzle, field access,
+ * index).
+ */
+static bool is_dead(const Expression& lvalue) {
+    if (lvalue.fKind == Expression::kVariableReference_Kind) {
+        return ((VariableReference&) lvalue).fVariable.dead();
+    }
+    if (lvalue.fKind == Expression::kSwizzle_Kind) {
+        return is_dead(*((Swizzle&) lvalue).fBase);
+    }
+    if (lvalue.fKind == Expression::kFieldAccess_Kind) {
+        return is_dead(*((FieldAccess&) lvalue).fBase);
+    }
+    if (lvalue.fKind == Expression::kIndex_Kind) {
+        const IndexExpression& indexExpr = (IndexExpression&) lvalue;
+        // base is checked first (it may abort); an index with side effects keeps the store alive
+        return is_dead(*indexExpr.fBase) && !indexExpr.fIndex->hasSideEffects();
+    }
+    ABORT("invalid lvalue: %s\n", lvalue.description().c_str());
+}
 
-    // compute the data flow
-    cfg.fBlocks[cfg.fStart].fBefore = compute_start_state(cfg);
+/**
+ * Reports whether b is an assignment (as classified by Token::IsAssignment) whose target lvalue is
+ * dead, i.e. the assignment can be reduced to just its right hand side.
+ */
+static bool dead_assignment(const BinaryExpression& b) {
+    // short-circuit matters: is_dead() aborts on non-lvalues, so only consult it for assignments
+    return Token::IsAssignment(b.fOperator) && is_dead(*b.fLeft);
+}
+
+/**
+ * Runs the data-flow analysis over cfg. The start block's fBefore state is seeded with
+ * compute_start_state(), every block is placed on a worklist, and scanCFG() is called on the
+ * lowest-numbered pending block until the worklist is empty. scanCFG() is handed the worklist and
+ * presumably re-queues blocks whose incoming state changed (TODO confirm against its definition).
+ */
+void Compiler::computeDataFlow(CFG* cfg) {
+    cfg->fBlocks[cfg->fStart].fBefore = compute_start_state(*cfg);
     std::set<BlockId> workList;
-    for (BlockId i = 0; i < cfg.fBlocks.size(); i++) {
+    for (BlockId i = 0; i < cfg->fBlocks.size(); i++) {
         workList.insert(i);
     }
+    // keep going until no block is pending; std::set::begin() yields the smallest pending BlockId
     while (workList.size()) {
         BlockId next = *workList.begin();
         workList.erase(workList.begin());
-        this->scanCFG(&cfg, next, &workList);
+        this->scanCFG(cfg, next, &workList);
     }
+}
+
+/**
+ * Replaces the expression pointed to by iter with *newExpression. The IR is always updated.
+ * Returns true if the CFG was also patched in place: the old node was removed and the new one
+ * inserted, with iter left pointing at the newly-inserted element. Returns false if either step
+ * failed, in which case the CFG is stale and must be regenerated by the caller.
+ *
+ * *newExpression may be owned by the expression being replaced (see delete_left/delete_right):
+ * unique_ptr move-assignment releases the source before destroying the old value, so that is safe.
+ */
+static bool try_replace_expression(BasicBlock* b,
+                                   std::vector<BasicBlock::Node>::iterator* iter,
+                                   std::unique_ptr<Expression>* newExpression) {
+    // grab the IR slot before the CFG is touched; it is written in both outcomes
+    std::unique_ptr<Expression>* target = (*iter)->expression();
+    const bool removed = b->tryRemoveExpression(iter);
+    *target = std::move(*newExpression);
+    return removed && b->tryInsertExpression(iter, target);
+}
+
+/**
+ * Returns true if expr is an int or float literal whose value compares equal to value. Every other
+ * kind of expression returns false, even if it would fold to that value.
+ */
+static bool is_constant(const Expression& expr, double value) {
+    switch (expr.fKind) {
+        case Expression::kIntLiteral_Kind:
+            return ((const IntLiteral&) expr).fValue == value;
+        case Expression::kFloatLiteral_Kind:
+            return ((const FloatLiteral&) expr).fValue == value;
+        default:
+            return false;
+    }
+}
+
+/**
+ * Collapses the binary expression pointed to by iter down to just its right operand, in the IR and,
+ * when possible, the CFG. Always sets *outUpdated; sets *outNeedsRescan if the CFG could not be
+ * patched in place and must be regenerated.
+ */
+static void delete_left(BasicBlock* b,
+                        std::vector<BasicBlock::Node>::iterator* iter,
+                        bool* outUpdated,
+                        bool* outNeedsRescan) {
+    *outUpdated = true;
+    // the binary expression (and its left operand) is destroyed by the replacement below
+    if (!try_replace_expression(b, iter, &((BinaryExpression&) **(*iter)->expression()).fRight)) {
+        *outNeedsRescan = true;
+    }
+}
+
+/**
+ * Collapses the binary expression pointed to by iter down to just its left operand, in the IR and,
+ * when possible, the CFG. Always sets *outUpdated; sets *outNeedsRescan if the CFG could not be
+ * patched in place and must be regenerated.
+ */
+static void delete_right(BasicBlock* b,
+                         std::vector<BasicBlock::Node>::iterator* iter,
+                         bool* outUpdated,
+                         bool* outNeedsRescan) {
+    *outUpdated = true;
+    // the binary expression (and its right operand) is destroyed by the replacement below
+    if (!try_replace_expression(b, iter, &((BinaryExpression&) **(*iter)->expression()).fLeft)) {
+        *outNeedsRescan = true;
+    }
+}
+
+/**
+ * Simplifies the expression node pointed to by iter, updating the IR and, where possible, the CFG:
+ *  - runs constant propagation when the node's fConstantPropagation flag is set,
+ *  - reports an error (once per variable) for reads of local variables that have no definition
+ *    recorded in definitions,
+ *  - replaces a ternary whose test is a boolean literal with the selected branch,
+ *  - collapses arithmetic identities: 1 * x, x * 1, 0 + x, x + 0, x - 0 and x / 1.
+ * Sets *outUpdated when the IR changed and *outNeedsRescan when the CFG could not be patched in
+ * place and must be regenerated.
+ */
+void Compiler::simplifyExpression(DefinitionMap& definitions,
+                                  BasicBlock& b,
+                                  std::vector<BasicBlock::Node>::iterator* iter,
+                                  std::unordered_set<const Variable*>* undefinedVariables,
+                                  bool* outUpdated,
+                                  bool* outNeedsRescan) {
+    Expression* expr = (*iter)->expression()->get();
+    ASSERT(expr);
+    if ((*iter)->fConstantPropagation) {
+        std::unique_ptr<Expression> optimized = expr->constantPropagate(*fIRGenerator, definitions);
+        if (optimized) {
+            if (!try_replace_expression(&b, iter, &optimized)) {
+                *outNeedsRescan = true;
+            }
+            ASSERT((*iter)->fKind == BasicBlock::Node::kExpression_Kind);
+            expr = (*iter)->expression()->get();
+            *outUpdated = true;
+        }
+    }
+    switch (expr->fKind) {
+        case Expression::kVariableReference_Kind: {
+            const Variable& var = ((VariableReference*) expr)->fVariable;
+            // insert() only succeeds the first time, so each variable is reported at most once
+            if (var.fStorage == Variable::kLocal_Storage && !definitions[&var] &&
+                undefinedVariables->insert(&var).second) {
+                this->error(expr->fPosition,
+                            "'" + var.fName + "' has not been assigned");
+            }
+            break;
+        }
+        case Expression::kTernary_Kind: {
+            TernaryExpression* t = (TernaryExpression*) expr;
+            if (t->fTest->fKind == Expression::kBoolLiteral_Kind) {
+                // ternary has a constant test, replace it with either the true or
+                // false branch
+                if (((BoolLiteral&) *t->fTest).fValue) {
+                    (*iter)->setExpression(std::move(t->fIfTrue));
+                } else {
+                    (*iter)->setExpression(std::move(t->fIfFalse));
+                }
+                *outUpdated = true;
+                *outNeedsRescan = true;
+            }
+            break;
+        }
+        case Expression::kBinary_Kind: {
+            // collapse useless expressions like x * 1 or x + 0
+            BinaryExpression* bin = (BinaryExpression*) expr;
+            switch (bin->fOperator) {
+                case Token::STAR:
+                    if (is_constant(*bin->fLeft, 1)) {
+                        // 1 * x
+                        delete_left(&b, iter, outUpdated, outNeedsRescan);
+                    } else if (is_constant(*bin->fRight, 1)) {
+                        // x * 1
+                        delete_right(&b, iter, outUpdated, outNeedsRescan);
+                    }
+                    break;
+                case Token::PLUS:
+                    if (is_constant(*bin->fLeft, 0)) {
+                        // 0 + x
+                        delete_left(&b, iter, outUpdated, outNeedsRescan);
+                    } else if (is_constant(*bin->fRight, 0)) {
+                        // x + 0
+                        delete_right(&b, iter, outUpdated, outNeedsRescan);
+                    }
+                    break;
+                case Token::MINUS:
+                    // x - 0 == x, but 0 - x is -x, so only a zero on the right may be dropped
+                    if (is_constant(*bin->fRight, 0)) {
+                        delete_right(&b, iter, outUpdated, outNeedsRescan);
+                    }
+                    break;
+                case Token::SLASH:
+                    if (is_constant(*bin->fRight, 1)) {
+                        // x / 1
+                        delete_right(&b, iter, outUpdated, outNeedsRescan);
+                    }
+                    break;
+                default:
+                    break;
+            }
+            break;
+        }
+        default:
+            break;
+    }
+}
+
+/**
+ * Simplifies the statement node pointed to by iter, updating the IR and, where possible, the CFG:
+ *  - var declarations: removes each variable whose dead() is true and whose initializer (if any)
+ *    has no side effects; a declaration left with no variables is replaced by a Nop,
+ *  - if statements: removes an empty else block; an if with an empty true block and no else is
+ *    replaced by its test (as an ExpressionStatement) if the test has side effects, else by a Nop,
+ *  - expression statements: an assignment to a dead lvalue (see dead_assignment) is reduced to
+ *    its right hand side if that has side effects, otherwise to a Nop; any other expression
+ *    statement without side effects is replaced by a Nop.
+ * Sets *outUpdated when it simplifies something and *outNeedsRescan when the CFG could not be
+ * patched in place and must be regenerated. definitions and undefinedVariables are not used here
+ * (NOTE(review): presumably kept for signature parity with simplifyExpression).
+ */
+void Compiler::simplifyStatement(DefinitionMap& definitions,
+                                  BasicBlock& b,
+                                  std::vector<BasicBlock::Node>::iterator* iter,
+                                  std::unordered_set<const Variable*>* undefinedVariables,
+                                  bool* outUpdated,
+                                  bool* outNeedsRescan) {
+    Statement* stmt = (*iter)->statement()->get();
+    switch (stmt->fKind) {
+        case Statement::kVarDeclarations_Kind: {
+            VarDeclarations& vd = *((VarDeclarationsStatement&) *stmt).fDeclaration;
+            for (auto varIter = vd.fVars.begin(); varIter != vd.fVars.end(); ) {
+                const auto& varDecl = **varIter;
+                // a dead variable can go unless its initializer must still run for side effects
+                if (varDecl.fVar->dead() &&
+                    (!varDecl.fValue ||
+                     !varDecl.fValue->hasSideEffects())) {
+                    if (varDecl.fValue) {
+                        ASSERT((*iter)->statement()->get() == stmt);
+                        // drop the initializer from the CFG before erase() below destroys it
+                        if (!b.tryRemoveExpressionBefore(iter, varDecl.fValue.get())) {
+                            *outNeedsRescan = true;
+                        }
+                    }
+                    // varDecl must not be used past this point; erase() destroys it
+                    varIter = vd.fVars.erase(varIter);
+                    *outUpdated = true;
+                } else {
+                    ++varIter;
+                }
+            }
+            if (vd.fVars.size() == 0) {
+                (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
+            }
+            break;
+        }
+        case Statement::kIf_Kind: {
+            IfStatement& i = (IfStatement&) *stmt;
+            if (i.fIfFalse && i.fIfFalse->isEmpty()) {
+                // else block doesn't do anything, remove it
+                i.fIfFalse.reset();
+                *outUpdated = true;
+                *outNeedsRescan = true;
+            }
+            if (!i.fIfFalse && i.fIfTrue->isEmpty()) {
+                // if block doesn't do anything, no else block
+                if (i.fTest->hasSideEffects()) {
+                    // test has side effects, keep it
+                    (*iter)->setStatement(std::unique_ptr<Statement>(
+                                                      new ExpressionStatement(std::move(i.fTest))));
+                } else {
+                    // no if, no else, no test side effects, kill the whole if
+                    // statement
+                    (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
+                }
+                *outUpdated = true;
+                *outNeedsRescan = true;
+            }
+            break;
+        }
+        case Statement::kExpression_Kind: {
+            ExpressionStatement& e = (ExpressionStatement&) *stmt;
+            ASSERT((*iter)->statement()->get() == &e);
+            if (e.fExpression->fKind == Expression::kBinary_Kind) {
+                BinaryExpression& bin = (BinaryExpression&) *e.fExpression;
+                if (dead_assignment(bin)) {
+                    if (!b.tryRemoveExpressionBefore(iter, &bin)) {
+                        *outNeedsRescan = true;
+                    }
+                    if (bin.fRight->hasSideEffects()) {
+                        // still have to evaluate the right due to side effects,
+                        // replace the binary expression with just the right side
+                        // (this assignment destroys bin, which must not be used afterwards)
+                        e.fExpression = std::move(bin.fRight);
+                        if (!b.tryInsertExpression(iter, &e.fExpression)) {
+                            *outNeedsRescan = true;
+                        }
+                    } else {
+                        // no side effects, kill the whole statement
+                        ASSERT((*iter)->statement()->get() == stmt);
+                        (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
+                    }
+                    *outUpdated = true;
+                    break;
+                }
+            }
+            if (!e.fExpression->hasSideEffects()) {
+                // Expression statement with no side effects, kill it
+                if (!b.tryRemoveExpressionBefore(iter, e.fExpression.get())) {
+                    *outNeedsRescan = true;
+                }
+                ASSERT((*iter)->statement()->get() == stmt);
+                (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
+                *outUpdated = true;
+            }
+            break;
+        }
+        default:
+            break;
+    }
+}
+
+void Compiler::scanCFG(FunctionDefinition& f) {
+    CFG cfg = CFGGenerator().getCFG(f);
+    this->computeDataFlow(&cfg);
 
     // check for unreachable code
     for (size_t i = 0; i < cfg.fBlocks.size(); i++) {
@@ -344,10 +625,10 @@
             Position p;
             switch (cfg.fBlocks[i].fNodes[0].fKind) {
                 case BasicBlock::Node::kStatement_Kind:
-                    p = cfg.fBlocks[i].fNodes[0].fStatement->fPosition;
+                    p = (*cfg.fBlocks[i].fNodes[0].statement())->fPosition;
                     break;
                 case BasicBlock::Node::kExpression_Kind:
-                    p = (*cfg.fBlocks[i].fNodes[0].fExpression)->fPosition;
+                    p = (*cfg.fBlocks[i].fNodes[0].expression())->fPosition;
                     break;
             }
             this->error(p, String("unreachable"));
@@ -357,33 +638,34 @@
         return;
     }
 
-    // check for undefined variables, perform constant propagation
-    for (BasicBlock& b : cfg.fBlocks) {
-        DefinitionMap definitions = b.fBefore;
-        for (BasicBlock::Node& n : b.fNodes) {
-            if (n.fKind == BasicBlock::Node::kExpression_Kind) {
-                ASSERT(n.fExpression);
-                Expression* expr = n.fExpression->get();
-                if (n.fConstantPropagation) {
-                    std::unique_ptr<Expression> optimized = expr->constantPropagate(*fIRGenerator,
-                                                                                    definitions);
-                    if (optimized) {
-                        n.fExpression->reset(optimized.release());
-                        expr = n.fExpression->get();
-                    }
-                }
-                if (expr->fKind == Expression::kVariableReference_Kind) {
-                    const Variable& var = ((VariableReference*) expr)->fVariable;
-                    if (var.fStorage == Variable::kLocal_Storage &&
-                        !definitions[&var]) {
-                        this->error(expr->fPosition,
-                                    "'" + var.fName + "' has not been assigned");
-                    }
-                }
-            }
-            this->addDefinitions(n, &definitions);
+    // check for dead code & undefined variables, perform constant propagation
+    std::unordered_set<const Variable*> undefinedVariables;
+    bool updated;
+    bool needsRescan = false;
+    do {
+        if (needsRescan) {
+            cfg = CFGGenerator().getCFG(f);
+            this->computeDataFlow(&cfg);
+            needsRescan = false;
         }
-    }
+
+        updated = false;
+        for (BasicBlock& b : cfg.fBlocks) {
+            DefinitionMap definitions = b.fBefore;
+
+            for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan; ++iter) {
+                if (iter->fKind == BasicBlock::Node::kExpression_Kind) {
+                    this->simplifyExpression(definitions, b, &iter, &undefinedVariables, &updated,
+                                             &needsRescan);
+                } else {
+                    this->simplifyStatement(definitions, b, &iter, &undefinedVariables, &updated,
+                                             &needsRescan);
+                }
+                this->addDefinitions(*iter, &definitions);
+            }
+        }
+    } while (updated);
+    ASSERT(!needsRescan);
 
     // check for missing return
     if (f.fDeclaration.fReturnType != *fContext.fVoid_Type) {