Create a diagonal-matrix Constructor class.

This constructor takes a single argument and splats it diagonally across
an otherwise-zero matrix. These are also sometimes referred to as a
uniform-scale matrix.

Change-Id: I1ed8140f55f5cad4029015807b220d6475401daa
Bug: skia:11032
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/390716
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Commit-Queue: John Stiles <johnstiles@google.com>
diff --git a/gn/sksl.gni b/gn/sksl.gni
index 2c5bc51..c50080a 100644
--- a/gn/sksl.gni
+++ b/gn/sksl.gni
@@ -94,6 +94,8 @@
   "$_src/sksl/ir/SkSLBreakStatement.h",
   "$_src/sksl/ir/SkSLConstructor.cpp",
   "$_src/sksl/ir/SkSLConstructor.h",
+  "$_src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp",
+  "$_src/sksl/ir/SkSLConstructorDiagonalMatrix.h",
   "$_src/sksl/ir/SkSLContinueStatement.h",
   "$_src/sksl/ir/SkSLDiscardStatement.h",
   "$_src/sksl/ir/SkSLDoStatement.cpp",
diff --git a/resources/sksl/folding/MatrixFoldingES2.sksl b/resources/sksl/folding/MatrixFoldingES2.sksl
index 3090754..0194440 100644
--- a/resources/sksl/folding/MatrixFoldingES2.sksl
+++ b/resources/sksl/folding/MatrixFoldingES2.sksl
@@ -7,8 +7,12 @@
     ok = ok && !(float2x2(float2(1.0, 0.0), float2(1.0, 1.0)) ==
                  float2x2(float2(1.0, 0.0), float2(0.0, 1.0)));
 
-    ok = ok &&  (float2x2(1) == float2x2(1));
-    ok = ok && !(float2x2(1) == float2x2(0));
+    ok = ok &&  ( float2x2(1)  == float2x2(1));
+    ok = ok && !( float2x2(1)  == float2x2(0));
+    ok = ok &&  ( float2x2(-1) == -float2x2(1));
+    ok = ok &&  ( float2x2(0)  == -float2x2(0));
+    ok = ok &&  (-float2x2(-1) ==  float2x2(1));
+    ok = ok &&  (-float2x2(0)  == -float2x2(-0));
 
     ok = ok &&  (float2x2(1) == float2x2(float2(1.0, 0.0), float2(0.0, 1.0)));
     ok = ok && !(float2x2(2) == float2x2(float2(1.0, 0.0), float2(0.0, 1.0)));
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index e6f505a..47be736 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -41,6 +41,7 @@
 #include "src/sksl/ir/SkSLBinaryExpression.h"
 #include "src/sksl/ir/SkSLBoolLiteral.h"
 #include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
 #include "src/sksl/ir/SkSLExternalFunctionCall.h"
 #include "src/sksl/ir/SkSLExternalFunctionReference.h"
 #include "src/sksl/ir/SkSLFieldAccess.h"
@@ -99,8 +100,9 @@
                 } else if (lastArg->type() == *fContext.fTypes.fFloat3x3) {
                     // Determine the type of matrix for this call site
                     if (lastArg->isConstantOrUniform()) {
-                        if (lastArg->kind() == Expression::Kind::kVariableReference ||
-                            lastArg->kind() == Expression::Kind::kConstructor) {
+                        if (lastArg->is<VariableReference>() ||
+                            lastArg->is<Constructor>() ||
+                            lastArg->is<ConstructorDiagonalMatrix>()) {
                             // FIXME if this is a constant, we should parse the float3x3 constructor
                             // and determine if the resulting matrix introduces perspective.
                             fUsage.merge(SampleUsage::UniformMatrix(lastArg->description()));
@@ -759,6 +761,11 @@
             }
             return true;
         }
+        case Expression::Kind::kConstructorDiagonalMatrix: {
+            const ConstructorDiagonalMatrix& leftCtor = left.as<ConstructorDiagonalMatrix>();
+            const ConstructorDiagonalMatrix& rightCtor = right.as<ConstructorDiagonalMatrix>();
+            return IsSameExpressionTree(*leftCtor.argument(), *rightCtor.argument());
+        }
         case Expression::Kind::kFieldAccess:
             return left.as<FieldAccess>().fieldIndex() == right.as<FieldAccess>().fieldIndex() &&
                    IsSameExpressionTree(*left.as<FieldAccess>().base(),
@@ -1012,6 +1019,7 @@
             // ... expressions composed of both of the above
             case Expression::Kind::kBinary:
             case Expression::Kind::kConstructor:
+            case Expression::Kind::kConstructorDiagonalMatrix:
             case Expression::Kind::kFieldAccess:
             case Expression::Kind::kIndex:
             case Expression::Kind::kPrefix:
@@ -1140,6 +1148,9 @@
             }
             return false;
         }
+        case Expression::Kind::kConstructorDiagonalMatrix: {
+            return this->visitExpressionPtr(e.template as<ConstructorDiagonalMatrix>().argument());
+        }
         case Expression::Kind::kExternalFunctionCall: {
             auto& c = e.template as<ExternalFunctionCall>();
             for (auto& arg : c.arguments()) {
diff --git a/src/sksl/SkSLDehydrator.cpp b/src/sksl/SkSLDehydrator.cpp
index 88b9c0e..51d6b46 100644
--- a/src/sksl/SkSLDehydrator.cpp
+++ b/src/sksl/SkSLDehydrator.cpp
@@ -16,6 +16,7 @@
 #include "src/sksl/ir/SkSLBinaryExpression.h"
 #include "src/sksl/ir/SkSLBreakStatement.h"
 #include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
 #include "src/sksl/ir/SkSLContinueStatement.h"
 #include "src/sksl/ir/SkSLDiscardStatement.h"
 #include "src/sksl/ir/SkSLDoStatement.h"
@@ -283,6 +284,13 @@
                 }
                 break;
             }
+            case Expression::Kind::kConstructorDiagonalMatrix: {
+                const ConstructorDiagonalMatrix& c = e->as<ConstructorDiagonalMatrix>();
+                this->writeCommand(Rehydrator::kConstructorDiagonalMatrix_Command);
+                this->write(c.type());
+                this->write(c.argument().get());
+                break;
+            }
             case Expression::Kind::kExternalFunctionCall:
             case Expression::Kind::kExternalFunctionReference:
                 SkDEBUGFAIL("unimplemented--not expected to be used from within an include file");
diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp
index 68102d2..69d29cd 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.cpp
+++ b/src/sksl/SkSLGLSLCodeGenerator.cpp
@@ -199,6 +199,10 @@
         case Expression::Kind::kConstructor:
             this->writeConstructor(expr.as<Constructor>(), parentPrecedence);
             break;
+        case Expression::Kind::kConstructorDiagonalMatrix:
+            this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(),
+                                                 parentPrecedence);
+            break;
         case Expression::Kind::kIntLiteral:
             this->writeIntLiteral(expr.as<IntLiteral>());
             break;
@@ -740,6 +744,14 @@
     this->write(")");
 }
 
+void GLSLCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+                                                       Precedence parentPrecedence) {
+    this->writeType(c.type());
+    this->write("(");
+    this->writeExpression(*c.argument(), Precedence::kSequence);
+    this->write(")");
+}
+
 void GLSLCodeGenerator::writeFragCoord() {
     if (!this->caps().canUseFragCoord()) {
         if (!fSetupFragCoordWorkaround) {
diff --git a/src/sksl/SkSLGLSLCodeGenerator.h b/src/sksl/SkSLGLSLCodeGenerator.h
index 3c81c39..33638e1 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.h
+++ b/src/sksl/SkSLGLSLCodeGenerator.h
@@ -21,6 +21,7 @@
 #include "src/sksl/ir/SkSLBinaryExpression.h"
 #include "src/sksl/ir/SkSLBoolLiteral.h"
 #include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
 #include "src/sksl/ir/SkSLDoStatement.h"
 #include "src/sksl/ir/SkSLExtension.h"
 #include "src/sksl/ir/SkSLFieldAccess.h"
@@ -136,6 +137,9 @@
 
     void writeConstructor(const Constructor& c, Precedence parentPrecedence);
 
+    void writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+                                        Precedence parentPrecedence);
+
     virtual void writeFieldAccess(const FieldAccess& f);
 
     virtual void writeSwizzle(const Swizzle& swizzle);
diff --git a/src/sksl/SkSLInliner.cpp b/src/sksl/SkSLInliner.cpp
index 41c7972..7915776 100644
--- a/src/sksl/SkSLInliner.cpp
+++ b/src/sksl/SkSLInliner.cpp
@@ -17,6 +17,7 @@
 #include "src/sksl/ir/SkSLBoolLiteral.h"
 #include "src/sksl/ir/SkSLBreakStatement.h"
 #include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
 #include "src/sksl/ir/SkSLContinueStatement.h"
 #include "src/sksl/ir/SkSLDiscardStatement.h"
 #include "src/sksl/ir/SkSLDoStatement.h"
@@ -312,6 +313,12 @@
             SkASSERT(inlinedCtor);
             return inlinedCtor;
         }
+        case Expression::Kind::kConstructorDiagonalMatrix: {
+            const ConstructorDiagonalMatrix& ctor = expression.as<ConstructorDiagonalMatrix>();
+            return ConstructorDiagonalMatrix::Make(*fContext, offset,
+                                                   *ctor.type().clone(symbolTableForExpression),
+                                                   expr(ctor.argument()));
+        }
         case Expression::Kind::kExternalFunctionCall: {
             const ExternalFunctionCall& externalCall = expression.as<ExternalFunctionCall>();
             return std::make_unique<ExternalFunctionCall>(offset, &externalCall.function(),
@@ -913,6 +920,11 @@
                 }
                 break;
             }
+            case Expression::Kind::kConstructorDiagonalMatrix: {
+                ConstructorDiagonalMatrix& ctorExpr = (*expr)->as<ConstructorDiagonalMatrix>();
+                this->visitExpression(&ctorExpr.argument());
+                break;
+            }
             case Expression::Kind::kExternalFunctionCall: {
                 ExternalFunctionCall& funcCallExpr = (*expr)->as<ExternalFunctionCall>();
                 for (std::unique_ptr<Expression>& arg : funcCallExpr.arguments()) {
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 0fddebe..ee87d9b 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -171,6 +171,10 @@
         case Expression::Kind::kConstructor:
             this->writeConstructor(expr.as<Constructor>(), parentPrecedence);
             break;
+        case Expression::Kind::kConstructorDiagonalMatrix:
+            this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(),
+                                                 parentPrecedence);
+            break;
         case Expression::Kind::kIntLiteral:
             this->writeIntLiteral(expr.as<IntLiteral>());
             break;
@@ -1104,6 +1108,14 @@
     this->write(constructorType.isArray() ? "}" : ")");
 }
 
+void MetalCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+                                                        Precedence parentPrecedence) {
+    this->writeType(c.type());
+    this->write("(");
+    this->writeExpression(*c.argument(), Precedence::kSequence);
+    this->write(")");
+}
+
 void MetalCodeGenerator::writeFragCoord() {
     if (fRTHeightName.length()) {
         this->write("float4(_fragCoord.x, ");
@@ -2234,6 +2246,9 @@
             }
             return result;
         }
+        case Expression::Kind::kConstructorDiagonalMatrix: {
+            return this->requirements(e->as<ConstructorDiagonalMatrix>().argument().get());
+        }
         case Expression::Kind::kFieldAccess: {
             const FieldAccess& f = e->as<FieldAccess>();
             if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
diff --git a/src/sksl/SkSLMetalCodeGenerator.h b/src/sksl/SkSLMetalCodeGenerator.h
index 6b38c03..9bd500f 100644
--- a/src/sksl/SkSLMetalCodeGenerator.h
+++ b/src/sksl/SkSLMetalCodeGenerator.h
@@ -22,6 +22,7 @@
 #include "src/sksl/ir/SkSLBinaryExpression.h"
 #include "src/sksl/ir/SkSLBoolLiteral.h"
 #include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
 #include "src/sksl/ir/SkSLDoStatement.h"
 #include "src/sksl/ir/SkSLExtension.h"
 #include "src/sksl/ir/SkSLFieldAccess.h"
@@ -229,6 +230,9 @@
 
     void writeConstructor(const Constructor& c, Precedence parentPrecedence);
 
+    void writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+                                        Precedence parentPrecedence);
+
     void writeFieldAccess(const FieldAccess& f);
 
     void writeSwizzle(const Swizzle& swizzle);
diff --git a/src/sksl/SkSLPipelineStageCodeGenerator.cpp b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
index 187ba20..c67d634 100644
--- a/src/sksl/SkSLPipelineStageCodeGenerator.cpp
+++ b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
@@ -14,6 +14,7 @@
 #include "src/sksl/SkSLStringStream.h"
 #include "src/sksl/ir/SkSLBinaryExpression.h"
 #include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
 #include "src/sksl/ir/SkSLExpressionStatement.h"
 #include "src/sksl/ir/SkSLFieldAccess.h"
 #include "src/sksl/ir/SkSLForStatement.h"
@@ -74,6 +75,8 @@
     void writeExpression(const Expression& expr, Precedence parentPrecedence);
     void writeFunctionCall(const FunctionCall& c);
     void writeConstructor(const Constructor& c, Precedence parentPrecedence);
+    void writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+                                        Precedence parentPrecedence);
     void writeFieldAccess(const FieldAccess& f);
     void writeSwizzle(const Swizzle& swizzle);
     void writeBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);
@@ -411,6 +414,10 @@
         case Expression::Kind::kConstructor:
             this->writeConstructor(expr.as<Constructor>(), parentPrecedence);
             break;
+        case Expression::Kind::kConstructorDiagonalMatrix:
+            this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(),
+                                                 parentPrecedence);
+            break;
         case Expression::Kind::kFieldAccess:
             this->writeFieldAccess(expr.as<FieldAccess>());
             break;
@@ -455,6 +462,14 @@
     this->write(")");
 }
 
+void PipelineStageCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+                                                                Precedence parentPrecedence) {
+    this->writeType(c.type());
+    this->write("(");
+    this->writeExpression(*c.argument(), Precedence::kSequence);
+    this->write(")");
+}
+
 void PipelineStageCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
     this->writeExpression(*expr.base(), Precedence::kPostfix);
     this->write("[");
diff --git a/src/sksl/SkSLRehydrator.cpp b/src/sksl/SkSLRehydrator.cpp
index e289792..a8cc648 100644
--- a/src/sksl/SkSLRehydrator.cpp
+++ b/src/sksl/SkSLRehydrator.cpp
@@ -15,6 +15,8 @@
 #include "include/private/SkSLStatement.h"
 #include "src/sksl/ir/SkSLBinaryExpression.h"
 #include "src/sksl/ir/SkSLBreakStatement.h"
+#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
 #include "src/sksl/ir/SkSLContinueStatement.h"
 #include "src/sksl/ir/SkSLDiscardStatement.h"
 #include "src/sksl/ir/SkSLDoStatement.h"
@@ -451,6 +453,11 @@
             SkASSERT(ctor);
             return ctor;
         }
+        case Rehydrator::kConstructorDiagonalMatrix_Command: {
+            const Type* type = this->type();
+            return ConstructorDiagonalMatrix::Make(fContext, /*offset=*/-1, *type,
+                                                   this->expression());
+        }
         case Rehydrator::kFieldAccess_Command: {
             std::unique_ptr<Expression> base = this->expression();
             int index = this->readU8();
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index daab463..9bedcee 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -711,6 +711,8 @@
             return this->writeBoolLiteral(expr.as<BoolLiteral>());
         case Expression::Kind::kConstructor:
             return this->writeConstructor(expr.as<Constructor>(), out);
+        case Expression::Kind::kConstructorDiagonalMatrix:
+            return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
         case Expression::Kind::kIntLiteral:
             return this->writeIntLiteral(expr.as<IntLiteral>());
         case Expression::Kind::kFieldAccess:
@@ -1688,6 +1690,21 @@
     }
 }
 
+SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+                                                         OutputStream& out) {
+    const Type& type = c.type();
+    SkASSERT(type.isMatrix());
+    SkASSERT(c.argument()->type().isScalar());
+
+    // Write out the scalar argument.
+    SpvId argument = this->writeExpression(*c.argument(), out);
+
+    // Build the diagonal matrix.
+    SpvId result = this->nextId(&type);
+    this->writeUniformScaleMatrix(result, argument, type, out);
+    return result;
+}
+
 static SpvStorageClass_ get_storage_class(const Variable& var,
                                           SpvStorageClass_ fallbackStorageClass) {
     const Modifiers& modifiers = var.modifiers();
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h
index 7b032c4..88fe274 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/SkSLSPIRVCodeGenerator.h
@@ -22,6 +22,7 @@
 #include "src/sksl/ir/SkSLBinaryExpression.h"
 #include "src/sksl/ir/SkSLBoolLiteral.h"
 #include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
 #include "src/sksl/ir/SkSLDoStatement.h"
 #include "src/sksl/ir/SkSLFieldAccess.h"
 #include "src/sksl/ir/SkSLFloatLiteral.h"
@@ -287,6 +288,8 @@
 
     SpvId writeConstructor(const Constructor& c, OutputStream& out);
 
+    SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out);
+
     SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
 
     SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
diff --git a/src/sksl/SkSLVMGenerator.cpp b/src/sksl/SkSLVMGenerator.cpp
index 250e123..3130ec7 100644
--- a/src/sksl/SkSLVMGenerator.cpp
+++ b/src/sksl/SkSLVMGenerator.cpp
@@ -18,6 +18,7 @@
 #include "src/sksl/ir/SkSLBoolLiteral.h"
 #include "src/sksl/ir/SkSLBreakStatement.h"
 #include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
 #include "src/sksl/ir/SkSLContinueStatement.h"
 #include "src/sksl/ir/SkSLDoStatement.h"
 #include "src/sksl/ir/SkSLExpressionStatement.h"
@@ -243,6 +244,7 @@
     Value writeExpression(const Expression& expr);
     Value writeBinaryExpression(const BinaryExpression& b);
     Value writeConstructor(const Constructor& c);
+    Value writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c);
     Value writeFunctionCall(const FunctionCall& c);
     Value writeExternalFunctionCall(const ExternalFunctionCall& c);
     Value writeFieldAccess(const FieldAccess& expr);
@@ -770,6 +772,26 @@
     return {};
 }
 
+Value SkVMGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c) {
+    const Type& dstType = c.type();
+    SkASSERT(dstType.isMatrix());
+    SkASSERT(c.argument()->type() == dstType.componentType());
+
+    Value src = this->writeExpression(*c.argument());
+    Value dst(dstType.rows() * dstType.columns());
+    size_t dstIndex = 0;
+
+    // Matrix-from-scalar builds a diagonal scale matrix
+    for (int c = 0; c < dstType.columns(); ++c) {
+        for (int r = 0; r < dstType.rows(); ++r) {
+            dst[dstIndex++] = (c == r ? f32(src) : fBuilder->splat(0.0f));
+        }
+    }
+
+    SkASSERT(dstIndex == dst.slots());
+    return dst;
+}
+
 size_t SkVMGenerator::fieldSlotOffset(const FieldAccess& expr) {
     size_t offset = 0;
     for (int i = 0; i < expr.fieldIndex(); ++i) {
@@ -1422,6 +1444,8 @@
             return fBuilder->splat(e.as<BoolLiteral>().value() ? ~0 : 0);
         case Expression::Kind::kConstructor:
             return this->writeConstructor(e.as<Constructor>());
+        case Expression::Kind::kConstructorDiagonalMatrix:
+            return this->writeConstructorDiagonalMatrix(e.as<ConstructorDiagonalMatrix>());
         case Expression::Kind::kFieldAccess:
             return this->writeFieldAccess(e.as<FieldAccess>());
         case Expression::Kind::kIndex:
diff --git a/src/sksl/ir/SkSLConstructor.cpp b/src/sksl/ir/SkSLConstructor.cpp
index 55e0218..e56a7cb 100644
--- a/src/sksl/ir/SkSLConstructor.cpp
+++ b/src/sksl/ir/SkSLConstructor.cpp
@@ -8,6 +8,7 @@
 #include "src/sksl/ir/SkSLConstructor.h"
 
 #include "src/sksl/ir/SkSLBoolLiteral.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
 #include "src/sksl/ir/SkSLFloatLiteral.h"
 #include "src/sksl/ir/SkSLIntLiteral.h"
 #include "src/sksl/ir/SkSLPrefixExpression.h"
@@ -80,11 +81,19 @@
         // A constructor containing a single scalar is a splat (for vectors) or diagonal matrix (for
         // matrices). In either event, it's legal regardless of the scalar's type. Synthesize an
         // explicit conversion to the proper type (this is a no-op if it's unnecessary).
-        ExpressionArray castArgs;
-        castArgs.push_back(Constructor::Convert(context, offset, type.componentType(),
-                                                std::move(args)));
-        SkASSERT(castArgs.front());
-        return std::make_unique<Constructor>(offset, type, std::move(castArgs));
+        std::unique_ptr<Expression> typecast = Constructor::Convert(context, offset,
+                                                                    type.componentType(),
+                                                                    std::move(args));
+        SkASSERT(typecast);
+
+        if (type.isMatrix()) {
+            // Matrix-from-scalar creates a diagonal matrix.
+            return ConstructorDiagonalMatrix::Make(context, offset, type, std::move(typecast));
+        }
+
+        ExpressionArray typecastArgs;
+        typecastArgs.push_back(std::move(typecast));
+        return std::make_unique<Constructor>(offset, type, std::move(typecastArgs));
     }
 
     int expected = type.rows() * type.columns();
@@ -249,6 +258,9 @@
 }
 
 Expression::ComparisonResult Constructor::compareConstant(const Expression& other) const {
+    if (other.is<ConstructorDiagonalMatrix>()) {
+        return other.compareConstant(*this);
+    }
     if (!other.is<Constructor>()) {
         return ComparisonResult::kUnknown;
     }
@@ -404,11 +416,12 @@
             return col == row ? this->getConstantValue<SKSL_FLOAT>(*this->arguments()[0]) : 0.0;
         }
         if (argType.isMatrix()) {
-            SkASSERT(this->arguments()[0]->is<Constructor>());
+            SkASSERT(this->arguments()[0]->is<Constructor>() ||
+                     this->arguments()[0]->is<ConstructorDiagonalMatrix>());
             // single matrix argument. make sure we're within the argument's bounds.
             if (col < argType.columns() && row < argType.rows()) {
                 // within bounds, defer to argument
-                return this->arguments()[0]->as<Constructor>().getMatComponent(col, row);
+                return this->arguments()[0]->getMatComponent(col, row);
             }
             // out of bounds
             return 0.0;
diff --git a/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp b/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp
new file mode 100644
index 0000000..4645fc4
--- /dev/null
+++ b/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+
+#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLType.h"
+
+namespace SkSL {
+
+std::unique_ptr<Expression> ConstructorDiagonalMatrix::Make(const Context& context,
+                                                            int offset,
+                                                            const Type& type,
+                                                            std::unique_ptr<Expression> arg) {
+    SkASSERT(type.isMatrix());
+    SkASSERT(arg->type() == type.componentType());
+    return std::make_unique<ConstructorDiagonalMatrix>(offset, type, std::move(arg));
+}
+
+Expression::ComparisonResult ConstructorDiagonalMatrix::compareConstant(
+        const Expression& other) const {
+    SkASSERT(this->type() == other.type());
+
+    // We know that these are the only SkSL constant expressions which can hold a matrix.
+    if (!other.is<Constructor>() && !other.is<ConstructorDiagonalMatrix>()) {
+        return ComparisonResult::kNotEqual;
+    }
+
+    // The other constructor might not be DiagonalMatrix-based, so we check each cell individually.
+    for (int col = 0; col < this->type().columns(); col++) {
+        for (int row = 0; row < this->type().rows(); row++) {
+            if (this->getMatComponent(col, row) != other.getMatComponent(col, row)) {
+                return ComparisonResult::kNotEqual;
+            }
+        }
+    }
+
+    return ComparisonResult::kEqual;
+}
+
+SKSL_FLOAT ConstructorDiagonalMatrix::getMatComponent(int col, int row) const {
+    SkASSERT(this->isCompileTimeConstant());
+    SkASSERT(col >= 0);
+    SkASSERT(row >= 0);
+    SkASSERT(col < this->type().columns());
+    SkASSERT(row < this->type().rows());
+
+    // Our matrix is of the form:
+    //  |x 0 0|
+    //  |0 x 0|
+    //  |0 0 x|
+    return (col == row) ? this->argument()->getConstantFloat() : 0.0;
+}
+
+}  // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorDiagonalMatrix.h b/src/sksl/ir/SkSLConstructorDiagonalMatrix.h
new file mode 100644
index 0000000..ae76848
--- /dev/null
+++ b/src/sksl/ir/SkSLConstructorDiagonalMatrix.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_CONSTRUCTOR_DIAGONAL_MATRIX
+#define SKSL_CONSTRUCTOR_DIAGONAL_MATRIX
+
+#include "include/private/SkSLDefines.h"
+#include "src/sksl/SkSLContext.h"
+#include "src/sksl/ir/SkSLExpression.h"
+
+#include <memory>
+
+namespace SkSL {
+
+/**
+ * Represents the construction of a diagonal matrix, such as `half3x3(n)`.
+ *
+ * These always contain exactly 1 scalar.
+ */
+class ConstructorDiagonalMatrix final : public Expression {
+public:
+    static constexpr Kind kExpressionKind = Kind::kConstructorDiagonalMatrix;
+
+    ConstructorDiagonalMatrix(int offset, const Type& type, std::unique_ptr<Expression> arg)
+        : INHERITED(offset, kExpressionKind, &type)
+        , fArgument(std::move(arg)) {}
+
+    static std::unique_ptr<Expression> Make(const Context& context,
+                                            int offset,
+                                            const Type& type,
+                                            std::unique_ptr<Expression> arg);
+
+    std::unique_ptr<Expression>& argument() {
+        return fArgument;
+    }
+
+    const std::unique_ptr<Expression>& argument() const {
+        return fArgument;
+    }
+
+    bool hasProperty(Property property) const override {
+        return argument()->hasProperty(property);
+    }
+
+    std::unique_ptr<Expression> clone() const override {
+        return std::make_unique<ConstructorDiagonalMatrix>(fOffset, this->type(),
+                                                           argument()->clone());
+    }
+
+    String description() const override {
+        return this->type().description() + "(" + argument()->description() + ")";
+    }
+
+    const Type& componentType() const {
+        return this->argument()->type();
+    }
+
+    bool isCompileTimeConstant() const override {
+        return argument()->isCompileTimeConstant();
+    }
+
+    bool isConstantOrUniform() const override {
+        return argument()->isConstantOrUniform();
+    }
+
+    ComparisonResult compareConstant(const Expression& other) const override;
+
+    SKSL_FLOAT getMatComponent(int col, int row) const override;
+
+private:
+    std::unique_ptr<Expression> fArgument;
+
+    using INHERITED = Expression;
+};
+
+}  // namespace SkSL
+
+#endif
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index cf9cdb0..af9c774 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -30,6 +30,7 @@
         kBoolLiteral,
         kCodeString,
         kConstructor,
+        kConstructorDiagonalMatrix,
         kDefined,
         kExternalFunctionCall,
         kExternalFunctionReference,
diff --git a/src/sksl/ir/SkSLPrefixExpression.cpp b/src/sksl/ir/SkSLPrefixExpression.cpp
index a06fccf..becba3c 100644
--- a/src/sksl/ir/SkSLPrefixExpression.cpp
+++ b/src/sksl/ir/SkSLPrefixExpression.cpp
@@ -10,6 +10,7 @@
 #include "src/sksl/SkSLConstantFolder.h"
 #include "src/sksl/ir/SkSLBoolLiteral.h"
 #include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
 #include "src/sksl/ir/SkSLFloatLiteral.h"
 #include "src/sksl/ir/SkSLIntLiteral.h"
 
@@ -41,6 +42,16 @@
             }
             break;
 
+        case Expression::Kind::kConstructorDiagonalMatrix:
+            // Convert `-matrix(literal)` into `matrix(-literal)`.
+            if (context.fConfig->fSettings.fOptimize && value->isCompileTimeConstant()) {
+                ConstructorDiagonalMatrix& ctor = operand->as<ConstructorDiagonalMatrix>();
+                return ConstructorDiagonalMatrix::Make(
+                        context, ctor.fOffset, ctor.type(),
+                        negate_operand(context, std::move(ctor.argument())));
+            }
+            break;
+
         case Expression::Kind::kConstructor:
             // To be consistent with prior behavior, the conversion of a negated constructor into a
             // constructor of negative values is only performed when optimization is on.
@@ -166,11 +177,10 @@
             break;
 
         default:
-            SK_ABORT("unsupported prefix operator\n");
+            SK_ABORT("unsupported prefix operator");
     }
 
     return PrefixExpression::Make(context, op, std::move(base));
-
 }
 
 std::unique_ptr<Expression> PrefixExpression::Make(const Context& context, Operator op,