Represent splat constructors with a dedicated ConstructorSplat class.

Change-Id: Ic9c3d688b571591d057ab6a4e998f1f9712a1b58
Bug: skia:11032
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/392117
Commit-Queue: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index d624cdf..bb8c0e0 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -718,10 +718,10 @@
             IsTrivialExpression(*expr.as<Swizzle>().base())) ||
            (expr.is<FieldAccess>() &&
             IsTrivialExpression(*expr.as<FieldAccess>().base())) ||
-           (expr.is<Constructor>() &&
-            expr.as<Constructor>().arguments().size() == 1 &&
-            IsTrivialExpression(*expr.as<Constructor>().arguments().front())) ||
-           (expr.is<Constructor>() &&
+           (expr.isAnyConstructor() &&
+            expr.asAnyConstructor().argumentSpan().size() == 1 &&
+            IsTrivialExpression(*expr.asAnyConstructor().argumentSpan().front())) ||
+           (expr.isAnyConstructor() &&
             expr.isConstantOrUniform()) ||
            (expr.is<IndexExpression>() &&
             expr.as<IndexExpression>().index()->is<IntLiteral>() &&
@@ -749,7 +749,8 @@
 
         case Expression::Kind::kConstructor:
         case Expression::Kind::kConstructorArray:
-        case Expression::Kind::kConstructorDiagonalMatrix: {
+        case Expression::Kind::kConstructorDiagonalMatrix:
+        case Expression::Kind::kConstructorSplat: {
             const AnyConstructor& leftCtor = left.asAnyConstructor();
             const AnyConstructor& rightCtor = right.asAnyConstructor();
             const auto leftSpan = leftCtor.argumentSpan();
@@ -1019,6 +1020,7 @@
             case Expression::Kind::kConstructor:
             case Expression::Kind::kConstructorArray:
             case Expression::Kind::kConstructorDiagonalMatrix:
+            case Expression::Kind::kConstructorSplat:
             case Expression::Kind::kFieldAccess:
             case Expression::Kind::kIndex:
             case Expression::Kind::kPrefix:
@@ -1142,7 +1144,8 @@
         }
         case Expression::Kind::kConstructor:
         case Expression::Kind::kConstructorArray:
-        case Expression::Kind::kConstructorDiagonalMatrix: {
+        case Expression::Kind::kConstructorDiagonalMatrix:
+        case Expression::Kind::kConstructorSplat: {
             auto& c = e.asAnyConstructor();
             for (auto& arg : c.argumentSpan()) {
                 if (this->visitExpressionPtr(arg)) { return true; }
diff --git a/src/sksl/SkSLConstantFolder.cpp b/src/sksl/SkSLConstantFolder.cpp
index 7105065..440b4cc 100644
--- a/src/sksl/SkSLConstantFolder.cpp
+++ b/src/sksl/SkSLConstantFolder.cpp
@@ -14,6 +14,7 @@
 #include "src/sksl/ir/SkSLBinaryExpression.h"
 #include "src/sksl/ir/SkSLBoolLiteral.h"
 #include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
 #include "src/sksl/ir/SkSLExpression.h"
 #include "src/sksl/ir/SkSLFloatLiteral.h"
 #include "src/sksl/ir/SkSLIntLiteral.h"
@@ -126,14 +127,12 @@
     return ctor;
 }
 
-static Constructor splat_scalar(const Expression& scalar, const Type& type) {
+static ConstructorSplat splat_scalar(const Expression& scalar, const Type& type) {
     SkASSERT(type.isVector());
     SkASSERT(type.componentType() == scalar.type());
 
-    // Use a Constructor to splat the scalar expression across a vector.
-    ExpressionArray arg;
-    arg.push_back(scalar.clone());
-    return Constructor{scalar.fOffset, type, std::move(arg)};
+    // Wrap a clone of the scalar in a ConstructorSplat, which repeats it in every component.
+    return ConstructorSplat{scalar.fOffset, type, scalar.clone()};
 }
 
 bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
@@ -161,8 +160,8 @@
 }
 
 static bool contains_constant_zero(const Expression& expr) {
-    if (expr.is<Constructor>()) {
-        for (const auto& arg : expr.as<Constructor>().arguments()) {
+    if (expr.isAnyConstructor()) {
+        for (const auto& arg : expr.asAnyConstructor().argumentSpan()) {
             if (contains_constant_zero(*arg)) {
                 return true;
             }
@@ -176,8 +175,8 @@
     // This check only supports scalars and vectors (and in particular, not matrices).
     SkASSERT(expr.type().isScalar() || expr.type().isVector());
 
-    if (expr.is<Constructor>()) {
-        for (const auto& arg : expr.as<Constructor>().arguments()) {
+    if (expr.isAnyConstructor()) {
+        for (const auto& arg : expr.asAnyConstructor().argumentSpan()) {
             if (!is_constant_value(*arg, value)) {
                 return false;
             }
diff --git a/src/sksl/SkSLDehydrator.cpp b/src/sksl/SkSLDehydrator.cpp
index e62e27b..0a79483 100644
--- a/src/sksl/SkSLDehydrator.cpp
+++ b/src/sksl/SkSLDehydrator.cpp
@@ -18,6 +18,7 @@
 #include "src/sksl/ir/SkSLConstructor.h"
 #include "src/sksl/ir/SkSLConstructorArray.h"
 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
 #include "src/sksl/ir/SkSLContinueStatement.h"
 #include "src/sksl/ir/SkSLDiscardStatement.h"
 #include "src/sksl/ir/SkSLDoStatement.h"
@@ -301,6 +302,12 @@
                 this->writeExpressionSpan(e->as<ConstructorDiagonalMatrix>().argumentSpan());
                 break;
 
+            case Expression::Kind::kConstructorSplat:
+                this->writeCommand(Rehydrator::kConstructorSplat_Command);
+                this->write(e->type());
+                this->writeExpressionSpan(e->as<ConstructorSplat>().argumentSpan());
+                break;
+
             case Expression::Kind::kExternalFunctionCall:
             case Expression::Kind::kExternalFunctionReference:
                 SkDEBUGFAIL("unimplemented--not expected to be used from within an include file");
diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp
index ab8a897..b475575 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.cpp
+++ b/src/sksl/SkSLGLSLCodeGenerator.cpp
@@ -203,6 +203,7 @@
             break;
         case Expression::Kind::kConstructorArray:
         case Expression::Kind::kConstructorDiagonalMatrix:
+        case Expression::Kind::kConstructorSplat:
             this->writeAnyConstructor(expr.asAnyConstructor(), parentPrecedence);
             break;
         case Expression::Kind::kIntLiteral:
diff --git a/src/sksl/SkSLInliner.cpp b/src/sksl/SkSLInliner.cpp
index 422013b..2d61939 100644
--- a/src/sksl/SkSLInliner.cpp
+++ b/src/sksl/SkSLInliner.cpp
@@ -19,6 +19,7 @@
 #include "src/sksl/ir/SkSLConstructor.h"
 #include "src/sksl/ir/SkSLConstructorArray.h"
 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
 #include "src/sksl/ir/SkSLContinueStatement.h"
 #include "src/sksl/ir/SkSLDiscardStatement.h"
 #include "src/sksl/ir/SkSLDoStatement.h"
@@ -326,6 +327,12 @@
                                                    *ctor.type().clone(symbolTableForExpression),
                                                    expr(ctor.argument()));
         }
+        case Expression::Kind::kConstructorSplat: {
+            const ConstructorSplat& ctor = expression.as<ConstructorSplat>();
+            return ConstructorSplat::Make(*fContext, offset,
+                                          *ctor.type().clone(symbolTableForExpression),
+                                          expr(ctor.argument()));
+        }
         case Expression::Kind::kExternalFunctionCall: {
             const ExternalFunctionCall& externalCall = expression.as<ExternalFunctionCall>();
             return std::make_unique<ExternalFunctionCall>(offset, &externalCall.function(),
@@ -922,7 +929,8 @@
             }
             case Expression::Kind::kConstructor:
             case Expression::Kind::kConstructorArray:
-            case Expression::Kind::kConstructorDiagonalMatrix: {
+            case Expression::Kind::kConstructorDiagonalMatrix:
+            case Expression::Kind::kConstructorSplat: {
                 AnyConstructor& constructorExpr = (*expr)->asAnyConstructor();
                 for (std::unique_ptr<Expression>& arg : constructorExpr.argumentSpan()) {
                     this->visitExpression(&arg);
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 60148cd..54bcae3 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -12,6 +12,7 @@
 #include "src/sksl/SkSLMemoryLayout.h"
 #include "src/sksl/ir/SkSLConstructorArray.h"
 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
 #include "src/sksl/ir/SkSLExpressionStatement.h"
 #include "src/sksl/ir/SkSLExtension.h"
 #include "src/sksl/ir/SkSLIndexExpression.h"
@@ -180,6 +181,9 @@
             this->writeSingleArgumentConstructor(expr.as<ConstructorDiagonalMatrix>(),
                                                  parentPrecedence);
             break;
+        case Expression::Kind::kConstructorSplat:
+            this->writeSingleArgumentConstructor(expr.as<ConstructorSplat>(), parentPrecedence);
+            break;
         case Expression::Kind::kIntLiteral:
             this->writeIntLiteral(expr.as<IntLiteral>());
             break;
@@ -2260,7 +2264,8 @@
         }
         case Expression::Kind::kConstructor:
         case Expression::Kind::kConstructorArray:
-        case Expression::Kind::kConstructorDiagonalMatrix: {
+        case Expression::Kind::kConstructorDiagonalMatrix:
+        case Expression::Kind::kConstructorSplat: {
             const AnyConstructor& c = e->asAnyConstructor();
             Requirements result = kNo_Requirements;
             for (const auto& arg : c.argumentSpan()) {
diff --git a/src/sksl/SkSLPipelineStageCodeGenerator.cpp b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
index 93a7bde..2a50909 100644
--- a/src/sksl/SkSLPipelineStageCodeGenerator.cpp
+++ b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
@@ -412,6 +412,7 @@
         case Expression::Kind::kConstructor:
         case Expression::Kind::kConstructorArray:
         case Expression::Kind::kConstructorDiagonalMatrix:
+        case Expression::Kind::kConstructorSplat:
             this->writeAnyConstructor(expr.asAnyConstructor(), parentPrecedence);
             break;
         case Expression::Kind::kFieldAccess:
diff --git a/src/sksl/SkSLRehydrator.cpp b/src/sksl/SkSLRehydrator.cpp
index 57eb9ea..e3eb784 100644
--- a/src/sksl/SkSLRehydrator.cpp
+++ b/src/sksl/SkSLRehydrator.cpp
@@ -18,6 +18,7 @@
 #include "src/sksl/ir/SkSLConstructor.h"
 #include "src/sksl/ir/SkSLConstructorArray.h"
 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
 #include "src/sksl/ir/SkSLContinueStatement.h"
 #include "src/sksl/ir/SkSLDiscardStatement.h"
 #include "src/sksl/ir/SkSLDoStatement.h"
@@ -470,6 +471,12 @@
             return ConstructorDiagonalMatrix::Make(fContext, /*offset=*/-1, *type,
                                                    std::move(args[0]));
         }
+        case Rehydrator::kConstructorSplat_Command: {
+            const Type* type = this->type();
+            ExpressionArray args = this->expressionArray();
+            SkASSERT(args.size() == 1);
+            return ConstructorSplat::Make(fContext, /*offset=*/-1, *type, std::move(args[0]));
+        }
         case Rehydrator::kFieldAccess_Command: {
             std::unique_ptr<Expression> base = this->expression();
             int index = this->readU8();
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index 7bdafc0..128d3d5 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -715,6 +715,8 @@
             return this->writeArrayConstructor(expr.as<ConstructorArray>(), out);
         case Expression::Kind::kConstructorDiagonalMatrix:
             return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
+        case Expression::Kind::kConstructorSplat:
+            return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
         case Expression::Kind::kIntLiteral:
             return this->writeIntLiteral(expr.as<IntLiteral>());
         case Expression::Kind::kFieldAccess:
@@ -1172,7 +1174,7 @@
     return result;
 }
 
-SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
+SpvId SPIRVCodeGenerator::writeConstantVector(const AnyConstructor& c) {
     const Type& type = c.type();
     SkASSERT(type.isVector() && c.isCompileTimeConstant());
 
@@ -1625,26 +1627,39 @@
             arguments.push_back(this->writeExpression(*c.arguments()[i], out));
         }
     }
+    SkASSERT((int)arguments.size() == type.columns());
+
     SpvId result = this->nextId(&type);
-    if (arguments.size() == 1 && c.arguments()[0]->type().isScalar()) {
-        this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(), out);
-        this->writeWord(this->getType(type), out);
-        this->writeWord(result, out);
-        for (int i = 0; i < type.columns(); i++) {
-            this->writeWord(arguments[0], out);
-        }
-    } else {
-        SkASSERT(arguments.size() > 1);
-        this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
-        this->writeWord(this->getType(type), out);
-        this->writeWord(result, out);
-        for (SpvId id : arguments) {
-            this->writeWord(id, out);
-        }
+    this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
+    this->writeWord(this->getType(type), out);
+    this->writeWord(result, out);
+    for (SpvId id : arguments) {
+        this->writeWord(id, out);
     }
     return result;
 }
 
+SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
+    // Compile-time-constant splats go through writeConstantVector, which deduplicates them.
+    if (c.isCompileTimeConstant()) {
+        return this->writeConstantVector(c);
+    }
+
+    const Type& vectorType = c.type();
+    const int componentCount = vectorType.columns();
+
+    // Evaluate the scalar once, then list its id once per component in OpCompositeConstruct.
+    const SpvId scalarId = this->writeExpression(*c.argument(), out);
+    const SpvId resultId = this->nextId(&vectorType);
+    this->writeOpCode(SpvOpCompositeConstruct, 3 + componentCount, out);
+    this->writeWord(this->getType(vectorType), out);
+    this->writeWord(resultId, out);
+    for (int component = 0; component < componentCount; ++component) {
+        this->writeWord(scalarId, out);
+    }
+    return resultId;
+}
+
 SpvId SPIRVCodeGenerator::writeArrayConstructor(const ConstructorArray& c, OutputStream& out) {
     const Type& type = c.type();
     SkASSERT(type.isArray());
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h
index 0789b2d..9552957 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/SkSLSPIRVCodeGenerator.h
@@ -24,6 +24,7 @@
 #include "src/sksl/ir/SkSLConstructor.h"
 #include "src/sksl/ir/SkSLConstructorArray.h"
 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
 #include "src/sksl/ir/SkSLDoStatement.h"
 #include "src/sksl/ir/SkSLFieldAccess.h"
 #include "src/sksl/ir/SkSLFloatLiteral.h"
@@ -241,7 +242,7 @@
 
     SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
 
-    SpvId writeConstantVector(const Constructor& c);
+    SpvId writeConstantVector(const AnyConstructor& c);
 
     SpvId writeFloatConstructor(const Constructor& c, OutputStream& out);
 
@@ -291,6 +292,8 @@
 
     SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out);
 
+    SpvId writeConstructorSplat(const ConstructorSplat& c, OutputStream& out);
+
     SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
 
     SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
diff --git a/src/sksl/SkSLVMGenerator.cpp b/src/sksl/SkSLVMGenerator.cpp
index 8348f58..3bd5331 100644
--- a/src/sksl/SkSLVMGenerator.cpp
+++ b/src/sksl/SkSLVMGenerator.cpp
@@ -20,6 +20,7 @@
 #include "src/sksl/ir/SkSLConstructor.h"
 #include "src/sksl/ir/SkSLConstructorArray.h"
 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
 #include "src/sksl/ir/SkSLContinueStatement.h"
 #include "src/sksl/ir/SkSLDoStatement.h"
 #include "src/sksl/ir/SkSLExpressionStatement.h"
@@ -247,6 +248,7 @@
     Value writeConstructor(const Constructor& c);
     Value writeMultiArgumentConstructor(const MultiArgumentConstructor& c);
     Value writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c);
+    Value writeConstructorSplat(const ConstructorSplat& c);
     Value writeFunctionCall(const FunctionCall& c);
     Value writeExternalFunctionCall(const ExternalFunctionCall& c);
     Value writeFieldAccess(const FieldAccess& expr);
@@ -768,19 +770,24 @@
         return dst;
     }
 
-    // We can splat scalars to all components of a vector
-    if (dstType.isVector() && srcType.isScalar()) {
-        Value dst(dstType.columns());
-        for (int i = 0; i < dstType.columns(); ++i) {
-            dst[i] = src[0];
-        }
-        return dst;
-    }
-
     SkDEBUGFAIL("Invalid constructor");
     return {};
 }
 
+Value SkVMGenerator::writeConstructorSplat(const ConstructorSplat& c) {
+    const Type& vectorType = c.type();
+    SkASSERT(vectorType.isVector());
+    SkASSERT(c.argument()->type().isScalar());
+
+    // Evaluate the scalar once, then copy its lone value into every component of the vector.
+    Value scalar = this->writeExpression(*c.argument());
+    Value splat(vectorType.columns());
+    for (int i = 0; i < vectorType.columns(); ++i) {
+        splat[i] = scalar[0];
+    }
+    return splat;
+}
+
 Value SkVMGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c) {
     const Type& dstType = c.type();
     SkASSERT(dstType.isMatrix());
@@ -1457,6 +1464,8 @@
             return this->writeMultiArgumentConstructor(e.as<ConstructorArray>());
         case Expression::Kind::kConstructorDiagonalMatrix:
             return this->writeConstructorDiagonalMatrix(e.as<ConstructorDiagonalMatrix>());
+        case Expression::Kind::kConstructorSplat:
+            return this->writeConstructorSplat(e.as<ConstructorSplat>());
         case Expression::Kind::kFieldAccess:
             return this->writeFieldAccess(e.as<FieldAccess>());
         case Expression::Kind::kIndex:
diff --git a/src/sksl/generated/sksl_gpu.dehydrated.sksl b/src/sksl/generated/sksl_gpu.dehydrated.sksl
index a2b6c7f..f6ff911 100644
--- a/src/sksl/generated/sksl_gpu.dehydrated.sksl
+++ b/src/sksl/generated/sksl_gpu.dehydrated.sksl
@@ -3899,7 +3899,7 @@
 2,
 45,0,0,0,0,1,
 37,
-6,
+9,
 43,15,2,1,
 22,
 43,176,0,0,0,0,0,1,0,
@@ -3948,11 +3948,11 @@
 47,
 1,
 52,167,3,0,66,
-6,
+9,
 43,15,2,1,
 22,
 43,176,0,0,0,0,0,
-6,
+9,
 43,15,2,1,
 22,
 43,176,0,0,0,0,0,
@@ -5115,7 +5115,7 @@
 2,
 45,0,0,0,0,1,
 37,
-6,
+9,
 43,172,1,1,
 22,
 43,176,0,0,0,0,0,1,1,1,211,3,
@@ -5823,7 +5823,7 @@
 52,18,4,0,
 53,
 37,
-6,
+9,
 43,15,2,1,
 22,
 43,176,0,0,0,0,0,1,29,154,3,157,3,160,3,163,3,166,3,169,3,172,3,175,3,178,3,181,3,184,3,187,3,190,3,193,3,196,3,202,3,205,3,208,3,221,3,227,3,230,3,236,3,239,3,242,3,245,3,6,4,9,4,12,4,15,4,
diff --git a/src/sksl/ir/SkSLConstructor.cpp b/src/sksl/ir/SkSLConstructor.cpp
index a704651..d16fcaf 100644
--- a/src/sksl/ir/SkSLConstructor.cpp
+++ b/src/sksl/ir/SkSLConstructor.cpp
@@ -10,6 +10,7 @@
 #include "src/sksl/ir/SkSLBoolLiteral.h"
 #include "src/sksl/ir/SkSLConstructorArray.h"
 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
 #include "src/sksl/ir/SkSLFloatLiteral.h"
 #include "src/sksl/ir/SkSLIntLiteral.h"
 #include "src/sksl/ir/SkSLPrefixExpression.h"
@@ -87,14 +88,10 @@
                                                                     std::move(args));
         SkASSERT(typecast);
 
-        if (type.isMatrix()) {
-            // Matrix-from-scalar creates a diagonal matrix.
-            return ConstructorDiagonalMatrix::Make(context, offset, type, std::move(typecast));
-        }
-
-        ExpressionArray typecastArgs;
-        typecastArgs.push_back(std::move(typecast));
-        return std::make_unique<Constructor>(offset, type, std::move(typecastArgs));
+        // Matrix-from-scalar creates a diagonal matrix; vector-from-scalar creates a splat.
+        return type.isMatrix()
+                       ? ConstructorDiagonalMatrix::Make(context, offset, type, std::move(typecast))
+                       : ConstructorSplat::Make(context, offset, type, std::move(typecast));
     }
 
     int expected = type.rows() * type.columns();
@@ -229,6 +226,9 @@
     if (other.is<ConstructorDiagonalMatrix>()) {
         return other.compareConstant(*this);
     }
+    if (other.is<ConstructorSplat>()) {
+        return other.compareConstant(*this);
+    }
     if (!other.is<Constructor>()) {
         return ComparisonResult::kUnknown;
     }
@@ -373,8 +373,7 @@
             return col == row ? this->getConstantValue<SKSL_FLOAT>(*this->arguments()[0]) : 0.0;
         }
         if (argType.isMatrix()) {
-            SkASSERT(this->arguments()[0]->is<Constructor>() ||
-                     this->arguments()[0]->is<ConstructorDiagonalMatrix>());
+            SkASSERT(this->arguments()[0]->isAnyConstructor());
             // single matrix argument. make sure we're within the argument's bounds.
             if (col < argType.columns() && row < argType.rows()) {
                 // within bounds, defer to argument
diff --git a/src/sksl/ir/SkSLConstructorSplat.cpp b/src/sksl/ir/SkSLConstructorSplat.cpp
new file mode 100644
index 0000000..f0306bf
--- /dev/null
+++ b/src/sksl/ir/SkSLConstructorSplat.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/ir/SkSLConstructorSplat.h"
+
+namespace SkSL {
+
+std::unique_ptr<Expression> ConstructorSplat::Make(const Context& context,
+                                                   int offset,
+                                                   const Type& type,
+                                                   std::unique_ptr<Expression> arg) {
+    // A splat repeats one scalar, typed as the vector's component type, in every component.
+    SkASSERT(type.isVector());
+    SkASSERT(arg->type() == type.componentType());
+    return std::make_unique<ConstructorSplat>(offset, type, std::move(arg));
+}
+
+Expression::ComparisonResult ConstructorSplat::compareConstant(const Expression& other) const {
+    SkASSERT(this->type() == other.type());
+    // Only constructors expose argument lists that can be matched against our single value.
+    return other.isAnyConstructor() ? this->compareConstantConstructor(other.asAnyConstructor())
+                                    : ComparisonResult::kUnknown;
+}
+
+Expression::ComparisonResult ConstructorSplat::compareConstantConstructor(
+        const AnyConstructor& other) const {
+    // Every component of a splat equals argument(), so `other` matches only if all of its
+    // arguments do; the first non-kEqual result is returned. Nested constructors, as in
+    // `half4(1) == half4(half2(1), 1, 1)`, are flattened by recursing into their arguments.
+    // NOTE(review): assumes nested constructors' arguments are exactly their components.
+    for (const std::unique_ptr<Expression>& arg : other.argumentSpan()) {
+        const ComparisonResult result =
+                arg->isAnyConstructor() ? this->compareConstantConstructor(arg->asAnyConstructor())
+                                        : this->argument()->compareConstant(*arg);
+        if (result != ComparisonResult::kEqual) {
+            return result;
+        }
+    }
+    return ComparisonResult::kEqual;
+}
+
+}  // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorSplat.h b/src/sksl/ir/SkSLConstructorSplat.h
new file mode 100644
index 0000000..a9da4fc
--- /dev/null
+++ b/src/sksl/ir/SkSLConstructorSplat.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_CONSTRUCTOR_SPLAT
+#define SKSL_CONSTRUCTOR_SPLAT
+
+#include "src/sksl/SkSLContext.h"
+#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLExpression.h"
+
+#include <memory>
+
+namespace SkSL {
+
+/**
+ * Represents the construction of a vector splat, such as `half3(n)`.
+ * These always contain exactly 1 scalar, of the vector's component type; it supplies the value
+ * of every component, which is why the get*VecComponent accessors below ignore their index.
+ */
+class ConstructorSplat final : public SingleArgumentConstructor {
+public:
+    static constexpr Kind kExpressionKind = Kind::kConstructorSplat;
+
+    ConstructorSplat(int offset, const Type& type, std::unique_ptr<Expression> arg)
+        : INHERITED(offset, kExpressionKind, &type, std::move(arg)) {}
+
+    static std::unique_ptr<Expression> Make(const Context& context,
+                                            int offset,
+                                            const Type& type,
+                                            std::unique_ptr<Expression> arg);
+
+    std::unique_ptr<Expression> clone() const override {
+        return std::make_unique<ConstructorSplat>(fOffset, this->type(), argument()->clone());
+    }
+
+    ComparisonResult compareConstant(const Expression& other) const override;
+
+    SKSL_FLOAT getFVecComponent(int) const override {
+        return this->argument()->getConstantFloat();
+    }
+
+    SKSL_INT getIVecComponent(int) const override {
+        return this->argument()->getConstantInt();
+    }
+
+    bool getBVecComponent(int) const override {
+        return this->argument()->getConstantBool();
+    }
+
+private:
+    Expression::ComparisonResult compareConstantConstructor(const AnyConstructor& other) const;
+
+    using INHERITED = SingleArgumentConstructor;
+};
+
+}  // namespace SkSL
+
+#endif
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index 300224a..71eeacf 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -33,6 +33,7 @@
         kConstructor,
         kConstructorArray,
         kConstructorDiagonalMatrix,
+        kConstructorSplat,
         kDefined,
         kExternalFunctionCall,
         kExternalFunctionReference,
@@ -83,9 +84,8 @@
     }
 
     bool isAnyConstructor() const {
-        static_assert((int)Kind::kConstructorDiagonalMatrix + 1 == (int)Kind::kDefined);
-        return this->kind() >= Kind::kConstructor &&
-               this->kind() <= Kind::kConstructorDiagonalMatrix;
+        static_assert((int)Kind::kConstructorSplat + 1 == (int)Kind::kDefined);
+        return Kind::kConstructor <= this->kind() && this->kind() <= Kind::kConstructorSplat;
     }
 
     /**
diff --git a/src/sksl/ir/SkSLPrefixExpression.cpp b/src/sksl/ir/SkSLPrefixExpression.cpp
index 0fe9502..0853e4d 100644
--- a/src/sksl/ir/SkSLPrefixExpression.cpp
+++ b/src/sksl/ir/SkSLPrefixExpression.cpp
@@ -12,6 +12,7 @@
 #include "src/sksl/ir/SkSLConstructor.h"
 #include "src/sksl/ir/SkSLConstructorArray.h"
 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
 #include "src/sksl/ir/SkSLFloatLiteral.h"
 #include "src/sksl/ir/SkSLIntLiteral.h"
 
@@ -65,6 +66,15 @@
             }
             break;
 
+        case Expression::Kind::kConstructorSplat:
+            // Convert `-vector(literal)` into `vector(-literal)`.
+            if (context.fConfig->fSettings.fOptimize && value->isCompileTimeConstant()) {
+                ConstructorSplat& ctor = operand->as<ConstructorSplat>();
+                return ConstructorSplat::Make(context, ctor.fOffset, ctor.type(),
+                                              negate_operand(context, std::move(ctor.argument())));
+            }
+            break;
+
         case Expression::Kind::kConstructor:
             // To be consistent with prior behavior, the conversion of a negated constructor into a
             // constructor of negative values is only performed when optimization is on.
diff --git a/src/sksl/ir/SkSLSwizzle.cpp b/src/sksl/ir/SkSLSwizzle.cpp
index 655a4c3..9b8bd8d 100644
--- a/src/sksl/ir/SkSLSwizzle.cpp
+++ b/src/sksl/ir/SkSLSwizzle.cpp
@@ -6,6 +6,7 @@
  */
 
 #include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
 #include "src/sksl/ir/SkSLSwizzle.h"
 
 namespace SkSL {
@@ -189,6 +190,21 @@
             return Swizzle::Make(context, std::move(base.base()), combined);
         }
 
+        // `half4(scalar).zyy` can be optimized to `half3(scalar)`, and `half3(scalar).y` can be
+        // optimized to just `scalar`. The swizzle components don't actually matter, as every field
+        // in a splat constructor holds the same value.
+        if (expr->is<ConstructorSplat>()) {
+            ConstructorSplat& splat = expr->as<ConstructorSplat>();
+            ExpressionArray ctorArgs;
+            ctorArgs.push_back(std::move(splat.argument()));
+            auto ctor = Constructor::Convert(
+                    context, splat.fOffset,
+                    splat.type().componentType().toCompound(context, components.size(), /*rows=*/1),
+                    std::move(ctorArgs));
+            SkASSERT(ctor);
+            return ctor;
+        }
+
         // Optimize swizzles of constructors.
         if (expr->is<Constructor>()) {
             Constructor& base = expr->as<Constructor>();
@@ -196,17 +212,6 @@
             const Type& componentType = exprType.componentType();
             int swizzleSize = components.size();
 
-            // `half4(scalar).zyy` can be optimized to `half3(scalar)`. The swizzle components don't
-            // actually matter since all fields are the same.
-            if (base.arguments().size() == 1 && base.arguments().front()->type().isScalar()) {
-                auto ctor = Constructor::Convert(
-                        context, base.fOffset,
-                        componentType.toCompound(context, swizzleSize, /*rows=*/1),
-                        std::move(base.arguments()));
-                SkASSERT(ctor);
-                return ctor;
-            }
-
             // Swizzles can duplicate some elements and discard others, e.g.
             // `half4(1, 2, 3, 4).xxz` --> `half3(1, 1, 3)`. However, there are constraints:
             // - Expressions with side effects need to occur exactly once, even if they