re-land of skslc type constructor cleanups

BUG=skia:

Change-Id: I953be07e2389dd4a9e7dcce0ddfd7505b309bda1
Reviewed-on: https://skia-review.googlesource.com/8265
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/gpu/glsl/GrGLSLGeometryProcessor.cpp b/src/gpu/glsl/GrGLSLGeometryProcessor.cpp
index 0d5ed93..87df896 100644
--- a/src/gpu/glsl/GrGLSLGeometryProcessor.cpp
+++ b/src/gpu/glsl/GrGLSLGeometryProcessor.cpp
@@ -108,7 +108,7 @@
                                                         &viewMatrixName);
         if (!mat.hasPerspective()) {
             gpArgs->fPositionVar.set(kVec2f_GrSLType, "pos2");
-            vertBuilder->codeAppendf("vec2 %s = vec2(%s * vec3(%s, 1));",
+            vertBuilder->codeAppendf("vec2 %s = (%s * vec3(%s, 1)).xy;",
                                      gpArgs->fPositionVar.c_str(), viewMatrixName, posName);
         } else {
             gpArgs->fPositionVar.set(kVec3f_GrSLType, "pos3");
diff --git a/src/gpu/instanced/InstanceProcessor.cpp b/src/gpu/instanced/InstanceProcessor.cpp
index 8626eb9..2ac5b8e 100644
--- a/src/gpu/instanced/InstanceProcessor.cpp
+++ b/src/gpu/instanced/InstanceProcessor.cpp
@@ -115,12 +115,26 @@
 
     void fetchNextParam(GrSLType type = kVec4f_GrSLType) const {
         SkASSERT(fParamsBuffer.isValid());
-        if (type != kVec4f_GrSLType) {
-            fVertexBuilder->codeAppendf("%s(", GrGLSLTypeString(type));
+        switch (type) {
+            case kVec2f_GrSLType: // fall through
+            case kVec3f_GrSLType: // fall through
+            case kVec4f_GrSLType:
+                break;
+            default:
+                fVertexBuilder->codeAppendf("%s(", GrGLSLTypeString(type));
         }
         fVertexBuilder->appendTexelFetch(fParamsBuffer, "paramsIdx++");
-        if (type != kVec4f_GrSLType) {
-            fVertexBuilder->codeAppend(")");
+        switch (type) {
+            case kVec2f_GrSLType:
+                fVertexBuilder->codeAppend(".xy");
+                break;
+            case kVec3f_GrSLType:
+                fVertexBuilder->codeAppend(".xyz");
+                break;
+            case kVec4f_GrSLType:
+                break;
+            default:
+                fVertexBuilder->codeAppend(")");
         }
     }
 
diff --git a/src/sksl/README b/src/sksl/README
index d78953f..98103fa 100644
--- a/src/sksl/README
+++ b/src/sksl/README
@@ -41,6 +41,9 @@
   have to be expressed "vec2(x, y) * 4.0". There is no performance penalty for 
   this, as the number is converted to a float at compile time)
 * type suffixes on numbers (1.0f, 0xFFu) are both unnecessary and unsupported
+* creating a smaller vector from a larger vector (e.g. vec2(vec3(1))) is
+  intentionally disallowed, as it is just a wordier way of performing a swizzle.
+  Use swizzles instead.
 * Use texture() instead of textureProj(), e.g. texture(sampler2D, vec3) is
   equivalent to GLSL's textureProj(sampler2D, vec3)
 * some built-in functions and one or two rarely-used language features are not
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index 55d9d2c..687ccb9 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -1104,15 +1104,15 @@
     return this->call(position, *ref->fFunctions[0], std::move(arguments));
 }
 
-std::unique_ptr<Expression> IRGenerator::convertConstructor(
+std::unique_ptr<Expression> IRGenerator::convertNumberConstructor(
                                                     Position position,
                                                     const Type& type,
                                                     std::vector<std::unique_ptr<Expression>> args) {
-    // FIXME: add support for structs and arrays
-    Type::Kind kind = type.kind();
-    if (!type.isNumber() && kind != Type::kVector_Kind && kind != Type::kMatrix_Kind &&
-        kind != Type::kArray_Kind) {
-        fErrors.error(position, "cannot construct '" + type.description() + "'");
+    ASSERT(type.isNumber());
+    if (args.size() != 1) {
+        fErrors.error(position, "invalid arguments to '" + type.description() +
+                                "' constructor, (expected exactly 1 argument, but found " +
+                                to_string((uint64_t) args.size()) + ")");
         return nullptr;
     }
     if (type == *fContext.fFloat_Type && args.size() == 1 &&
@@ -1120,51 +1120,59 @@
         int64_t value = ((IntLiteral&) *args[0]).fValue;
         return std::unique_ptr<Expression>(new FloatLiteral(fContext, position, (double) value));
     }
-    if (args.size() == 1 && args[0]->fType == type) {
-        // argument is already the right type, just return it
-        return std::move(args[0]);
+    if (args[0]->fKind == Expression::kIntLiteral_Kind && (type == *fContext.fInt_Type ||
+        type == *fContext.fUInt_Type)) {
+        return std::unique_ptr<Expression>(new IntLiteral(fContext,
+                                                          position,
+                                                          ((IntLiteral&) *args[0]).fValue,
+                                                          &type));
     }
-    if (type.isNumber()) {
-        if (args.size() != 1) {
-            fErrors.error(position, "invalid arguments to '" + type.description() +
-                                    "' constructor, (expected exactly 1 argument, but found " +
-                                    to_string((uint64_t) args.size()) + ")");
-            return nullptr;
-        }
-        if (args[0]->fType == *fContext.fBool_Type) {
-            std::unique_ptr<IntLiteral> zero(new IntLiteral(fContext, position, 0));
-            std::unique_ptr<IntLiteral> one(new IntLiteral(fContext, position, 1));
-            return std::unique_ptr<Expression>(
-                                         new TernaryExpression(position, std::move(args[0]),
-                                                               this->coerce(std::move(one), type),
-                                                               this->coerce(std::move(zero),
-                                                                            type)));
-        } else if (!args[0]->fType.isNumber()) {
-            fErrors.error(position, "invalid argument to '" + type.description() +
-                                    "' constructor (expected a number or bool, but found '" +
-                                    args[0]->fType.description() + "')");
-        }
-        if (args[0]->fKind == Expression::kIntLiteral_Kind && (type == *fContext.fInt_Type ||
-            type == *fContext.fUInt_Type)) {
-            return std::unique_ptr<Expression>(new IntLiteral(fContext,
-                                                              position,
-                                                              ((IntLiteral&) *args[0]).fValue,
-                                                              &type));
-        }
-    } else if (kind == Type::kArray_Kind) {
-        const Type& base = type.componentType();
+    if (args[0]->fType == *fContext.fBool_Type) {
+        std::unique_ptr<IntLiteral> zero(new IntLiteral(fContext, position, 0));
+        std::unique_ptr<IntLiteral> one(new IntLiteral(fContext, position, 1));
+        return std::unique_ptr<Expression>(
+                                     new TernaryExpression(position, std::move(args[0]),
+                                                           this->coerce(std::move(one), type),
+                                                           this->coerce(std::move(zero),
+                                                                        type)));
+    }
+    if (!args[0]->fType.isNumber()) {
+        fErrors.error(position, "invalid argument to '" + type.description() +
+                                "' constructor (expected a number or bool, but found '" +
+                                args[0]->fType.description() + "')");
+        return nullptr;
+    }
+    return std::unique_ptr<Expression>(new Constructor(position, std::move(type), std::move(args)));
+}
+
+int component_count(const Type& type) {
+    switch (type.kind()) {
+        case Type::kVector_Kind:
+            return type.columns();
+        case Type::kMatrix_Kind:
+            return type.columns() * type.rows();
+        default:
+            return 1;
+    }
+}
+
+std::unique_ptr<Expression> IRGenerator::convertCompoundConstructor(
+                                                    Position position,
+                                                    const Type& type,
+                                                    std::vector<std::unique_ptr<Expression>> args) {
+    ASSERT(type.kind() == Type::kVector_Kind || type.kind() == Type::kMatrix_Kind);
+    if (type.kind() == Type::kMatrix_Kind && args.size() == 1 &&
+        args[0]->fType.kind() == Type::kMatrix_Kind) {
+        // matrix from matrix is always legal
+        return std::unique_ptr<Expression>(new Constructor(position, std::move(type),
+                                                           std::move(args)));
+    }
+    int actual = 0;
+    int expected = type.rows() * type.columns();
+    if (args.size() != 1 || expected != component_count(args[0]->fType) ||
+        type.componentType().isNumber() != args[0]->fType.componentType().isNumber()) {
         for (size_t i = 0; i < args.size(); i++) {
-            args[i] = this->coerce(std::move(args[i]), base);
-            if (!args[i]) {
-                return nullptr;
-            }
-        }
-    } else {
-        ASSERT(kind == Type::kVector_Kind || kind == Type::kMatrix_Kind);
-        int actual = 0;
-        for (size_t i = 0; i < args.size(); i++) {
-            if (args[i]->fType.kind() == Type::kVector_Kind ||
-                args[i]->fType.kind() == Type::kMatrix_Kind) {
+            if (args[i]->fType.kind() == Type::kVector_Kind) {
                 if (type.componentType().isNumber() !=
                     args[i]->fType.componentType().isNumber()) {
                     fErrors.error(position, "'" + args[i]->fType.description() + "' is not a valid "
@@ -1172,7 +1180,7 @@
                                             "' constructor");
                     return nullptr;
                 }
-                actual += args[i]->fType.rows() * args[i]->fType.columns();
+                actual += args[i]->fType.columns();
             } else if (args[i]->fType.kind() == Type::kScalar_Kind) {
                 actual += 1;
                 if (type.kind() != Type::kScalar_Kind) {
@@ -1187,20 +1195,46 @@
                 return nullptr;
             }
         }
-        int min = type.rows() * type.columns();
-        int max = type.columns() > 1 ? INT_MAX : min;
-        if ((actual < min || actual > max) &&
-            !((kind == Type::kVector_Kind || kind == Type::kMatrix_Kind) && (actual == 1))) {
+        if (actual != 1 && actual != expected) {
             fErrors.error(position, "invalid arguments to '" + type.description() +
-                                    "' constructor (expected " + to_string(min) + " scalar" +
-                                    (min == 1 ? "" : "s") + ", but found " + to_string(actual) +
-                                    ")");
+                                    "' constructor (expected " + to_string(expected) +
+                                    " scalars, but found " + to_string(actual) + ")");
             return nullptr;
         }
     }
     return std::unique_ptr<Expression>(new Constructor(position, std::move(type), std::move(args)));
 }
 
+std::unique_ptr<Expression> IRGenerator::convertConstructor(
+                                                    Position position,
+                                                    const Type& type,
+                                                    std::vector<std::unique_ptr<Expression>> args) {
+    // FIXME: add support for structs
+    Type::Kind kind = type.kind();
+    if (args.size() == 1 && args[0]->fType == type) {
+        // argument is already the right type, just return it
+        return std::move(args[0]);
+    }
+    if (type.isNumber()) {
+        return this->convertNumberConstructor(position, type, std::move(args));
+    } else if (kind == Type::kArray_Kind) {
+        const Type& base = type.componentType();
+        for (size_t i = 0; i < args.size(); i++) {
+            args[i] = this->coerce(std::move(args[i]), base);
+            if (!args[i]) {
+                return nullptr;
+            }
+        }
+        return std::unique_ptr<Expression>(new Constructor(position, std::move(type),
+                                                           std::move(args)));
+    } else if (kind == Type::kVector_Kind || kind == Type::kMatrix_Kind) {
+        return this->convertCompoundConstructor(position, type, std::move(args));
+    } else {
+        fErrors.error(position, "cannot construct '" + type.description() + "'");
+        return nullptr;
+    }
+}
+
 std::unique_ptr<Expression> IRGenerator::convertPrefixExpression(
                                                             const ASTPrefixExpression& expression) {
     std::unique_ptr<Expression> base = this->convertExpression(*expression.fOperand);
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index 2ffcb0d..1336b68 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -126,6 +126,14 @@
     std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, const Type& type);
     std::unique_ptr<Block> convertBlock(const ASTBlock& block);
     std::unique_ptr<Statement> convertBreak(const ASTBreakStatement& b);
+    std::unique_ptr<Expression> convertNumberConstructor(
+                                                   Position position,
+                                                   const Type& type,
+                                                   std::vector<std::unique_ptr<Expression>> params);
+    std::unique_ptr<Expression> convertCompoundConstructor(
+                                                   Position position,
+                                                   const Type& type,
+                                                   std::vector<std::unique_ptr<Expression>> params);
     std::unique_ptr<Expression> convertConstructor(Position position,
                                                    const Type& type,
                                                    std::vector<std::unique_ptr<Expression>> params);
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index 9c8a5d0..1158f94 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -1500,6 +1500,37 @@
     return result;
 }
 
+void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
+                                                 SkWStream& out) {
+    FloatLiteral zero(fContext, Position(), 0);
+    SpvId zeroId = this->writeFloatLiteral(zero);
+    std::vector<SpvId> columnIds;
+    for (int column = 0; column < type.columns(); column++) {
+        this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
+                          out);
+        this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
+                        out);
+        SpvId columnId = this->nextId();
+        this->writeWord(columnId, out);
+        columnIds.push_back(columnId);
+        for (int row = 0; row < type.columns(); row++) {
+            this->writeWord(row == column ? diagonal : zeroId, out);
+        }
+    }
+    this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
+                      out);
+    this->writeWord(this->getType(type), out);
+    this->writeWord(id, out);
+    for (SpvId id : columnIds) {
+        this->writeWord(id, out);
+    }
+}
+
+void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
+                                         const Type& dstType, SkWStream& out) {
+    ABORT("unimplemented");
+}
+
 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, SkWStream& out) {
     ASSERT(c.fType.kind() == Type::kMatrix_Kind);
     // go ahead and write the arguments so we don't try to write new instructions in the middle of
@@ -1511,33 +1542,10 @@
     SpvId result = this->nextId();
     int rows = c.fType.rows();
     int columns = c.fType.columns();
-    // FIXME this won't work to create a matrix from another matrix
-    if (arguments.size() == 1) {
-        // with a single argument, a matrix will have all of its diagonal entries equal to the
-        // argument and its other values equal to zero
-        // FIXME this won't work for int matrices
-        FloatLiteral zero(fContext, Position(), 0);
-        SpvId zeroId = this->writeFloatLiteral(zero);
-        std::vector<SpvId> columnIds;
-        for (int column = 0; column < columns; column++) {
-            this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(),
-                              out);
-            this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows, 1)),
-                            out);
-            SpvId columnId = this->nextId();
-            this->writeWord(columnId, out);
-            columnIds.push_back(columnId);
-            for (int row = 0; row < c.fType.columns(); row++) {
-                this->writeWord(row == column ? arguments[0] : zeroId, out);
-            }
-        }
-        this->writeOpCode(SpvOpCompositeConstruct, 3 + columns,
-                          out);
-        this->writeWord(this->getType(c.fType), out);
-        this->writeWord(result, out);
-        for (SpvId id : columnIds) {
-            this->writeWord(id, out);
-        }
+    if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
+        this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
+    } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
+        this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
     } else {
         std::vector<SpvId> columnIds;
         int currentCount = 0;
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h
index 562bf27..fad7e31 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/SkSLSPIRVCodeGenerator.h
@@ -147,7 +147,21 @@
     SpvId writeFloatConstructor(const Constructor& c, SkWStream& out);
 
     SpvId writeIntConstructor(const Constructor& c, SkWStream& out);
-    
+
+    /**
+     * Writes a matrix with the diagonal entries all equal to the provided expression, and all other
+     * entries equal to zero.
+     */
+    void writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type, SkWStream& out);
+
+    /**
+     * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the
+     * source matrix are filled with zero; entries which do not exist in the destination matrix are
+     * ignored.
+     */
+    void writeMatrixCopy(SpvId id, SpvId src, const Type& srcType, const Type& dstType,
+                         SkWStream& out);
+
     SpvId writeMatrixConstructor(const Constructor& c, SkWStream& out);
 
     SpvId writeVectorConstructor(const Constructor& c, SkWStream& out);
diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h
index 691bea1..3360ace 100644
--- a/src/sksl/ir/SkSLConstructor.h
+++ b/src/sksl/ir/SkSLConstructor.h
@@ -17,6 +17,12 @@
 
 /**
  * Represents the construction of a compound type, such as "vec2(x, y)".
+ *
+ * Vector constructors will always consist of either exactly 1 scalar, or a collection of vectors
+ * and scalars totalling exactly the right number of scalar components.
+ *
+ * Matrix constructors will always consist of either exactly 1 scalar, exactly 1 matrix, or a
+ * collection of vectors and scalars totalling exactly the right number of scalar components.
  */
 struct Constructor : public Expression {
     Constructor(Position position, const Type& type,
diff --git a/tests/SkSLErrorTest.cpp b/tests/SkSLErrorTest.cpp
index a33e6f1..9c34ab4 100644
--- a/tests/SkSLErrorTest.cpp
+++ b/tests/SkSLErrorTest.cpp
@@ -134,7 +134,10 @@
                  "void main() { vec3 x = vec3(1.0, 2.0); }",
                  "error: 1: invalid arguments to 'vec3' constructor (expected 3 scalars, but "
                  "found 2)\n1 error\n");
-    test_success(r, "void main() { vec3 x = vec3(1.0, 2.0, 3.0, 4.0); }");
+    test_failure(r,
+                 "void main() { vec3 x = vec3(1.0, 2.0, 3.0, 4.0); }",
+                 "error: 1: invalid arguments to 'vec3' constructor (expected 3 scalars, but found "
+                 "4)\n1 error\n");
 }
 
 DEF_TEST(SkSLSwizzleScalar, r) {
diff --git a/tests/SkSLGLSLTest.cpp b/tests/SkSLGLSLTest.cpp
index a0cc7d5..533c203 100644
--- a/tests/SkSLGLSLTest.cpp
+++ b/tests/SkSLGLSLTest.cpp
@@ -373,24 +373,20 @@
          "vec2 v1 = vec2(1);"
          "vec2 v2 = vec2(1, 2);"
          "vec2 v3 = vec2(vec2(1));"
-         "vec2 v4 = vec2(vec3(1));"
-         "vec3 v5 = vec3(vec2(1), 1.0);"
-         "vec3 v6 = vec3(vec4(1, 2, 3, 4));"
-         "ivec2 v7 = ivec2(1);"
-         "ivec2 v8 = ivec2(vec2(1, 2));"
-         "vec2 v9 = vec2(ivec2(1, 2));",
+         "vec3 v4 = vec3(vec2(1), 1.0);"
+         "ivec2 v5 = ivec2(1);"
+         "ivec2 v6 = ivec2(vec2(1, 2));"
+         "vec2 v7 = vec2(ivec2(1, 2));",
          *SkSL::ShaderCapsFactory::Default(),
          "#version 400\n"
          "out vec4 sk_FragColor;\n"
          "vec2 v1 = vec2(1.0);\n"
          "vec2 v2 = vec2(1.0, 2.0);\n"
          "vec2 v3 = vec2(1.0);\n"
-         "vec2 v4 = vec2(vec3(1.0));\n"
-         "vec3 v5 = vec3(vec2(1.0), 1.0);\n"
-         "vec3 v6 = vec3(vec4(1.0, 2.0, 3.0, 4.0));\n"
-         "ivec2 v7 = ivec2(1);\n"
-         "ivec2 v8 = ivec2(vec2(1.0, 2.0));\n"
-         "vec2 v9 = vec2(ivec2(1, 2));\n");
+         "vec3 v4 = vec3(vec2(1.0), 1.0);\n"
+         "ivec2 v5 = ivec2(1);\n"
+         "ivec2 v6 = ivec2(vec2(1.0, 2.0));\n"
+         "vec2 v7 = vec2(ivec2(1, 2));\n");
 }
 
 DEF_TEST(SkSLArrayConstructors, r) {