Update certain instrinsic calls in SkSL SPIR-V gen to not mix vectors and scalars.

Functions like min, max, clamp, and mix cannot intermix vectors and scalars as
operands and return value. This updates SkSL to promote scalars to vectors if
need be before calling these functions.

Bug: skia:7653
Change-Id: I13f98d452978e3f15bafddea638b7bbe313d98ce
Reviewed-on: https://skia-review.googlesource.com/110660
Commit-Queue: Greg Daniel <egdaniel@google.com>
Reviewed-by: Greg Daniel <egdaniel@google.com>
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index d01a82d..fbe8bf0 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -66,12 +66,12 @@
     fIntrinsicMap[String("determinant")]   = ALL_GLSL(Determinant);
     fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse);
     fIntrinsicMap[String("mod")]           = SPECIAL(Mod);
-    fIntrinsicMap[String("min")]           = BY_TYPE_GLSL(FMin, SMin, UMin);
-    fIntrinsicMap[String("max")]           = BY_TYPE_GLSL(FMax, SMax, UMax);
-    fIntrinsicMap[String("clamp")]         = BY_TYPE_GLSL(FClamp, SClamp, UClamp);
+    fIntrinsicMap[String("min")]           = SPECIAL(Min);
+    fIntrinsicMap[String("max")]           = SPECIAL(Max);
+    fIntrinsicMap[String("clamp")]         = SPECIAL(Clamp);
     fIntrinsicMap[String("dot")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
                                                              SpvOpUndef, SpvOpUndef, SpvOpUndef);
-    fIntrinsicMap[String("mix")]           = ALL_GLSL(FMix);
+    fIntrinsicMap[String("mix")]           = SPECIAL(Mix);
     fIntrinsicMap[String("step")]          = ALL_GLSL(Step);
     fIntrinsicMap[String("smoothstep")]    = ALL_GLSL(SmoothStep);
     fIntrinsicMap[String("fma")]           = ALL_GLSL(Fma);
@@ -721,6 +721,62 @@
     }
 }
 
+std::vector<SpvId> SPIRVCodeGenerator::vectorize(
+                                               const std::vector<std::unique_ptr<Expression>>& args,
+                                               OutputStream& out) {
+    int vectorSize = 0;
+    for (const auto& a : args) {
+        if (a->fType.kind() == Type::kVector_Kind) {
+            if (vectorSize) {
+                ASSERT(a->fType.columns() == vectorSize);
+            }
+            else {
+                vectorSize = a->fType.columns();
+            }
+        }
+    }
+    std::vector<SpvId> result;
+    for (const auto& a : args) {
+        SpvId raw = this->writeExpression(*a, out);
+        if (vectorSize && a->fType.kind() == Type::kScalar_Kind) {
+            SpvId vector = this->nextId();
+            this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out);
+            this->writeWord(this->getType(a->fType.toCompound(fContext, vectorSize, 1)), out);
+            this->writeWord(vector, out);
+            for (int i = 0; i < vectorSize; i++) {
+                this->writeWord(raw, out);
+            }
+            result.push_back(vector);
+        } else {
+            result.push_back(raw);
+        }
+    }
+    return result;
+}
+
+void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
+                                                      SpvId signedInst, SpvId unsignedInst,
+                                                      const std::vector<SpvId>& args,
+                                                      OutputStream& out) {
+    this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
+    this->writeWord(this->getType(type), out);
+    this->writeWord(id, out);
+    this->writeWord(fGLSLExtendedInstructions, out);
+
+    if (is_float(fContext, type)) {
+        this->writeWord(floatInst, out);
+    } else if (is_signed(fContext, type)) {
+        this->writeWord(signedInst, out);
+    } else if (is_unsigned(fContext, type)) {
+        this->writeWord(unsignedInst, out);
+    } else {
+        ASSERT(false);
+    }
+    for (SpvId a : args) {
+        this->writeWord(a, out);
+    }
+}
+
 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
                                                 OutputStream& out) {
     SpvId result = this->nextId();
@@ -838,22 +894,8 @@
             break;
         }
         case kMod_SpecialIntrinsic: {
-            ASSERT(c.fArguments.size() == 2);
-            SpvId arg1 = this->writeExpression(*c.fArguments[0], out);
-            SpvId arg2 = this->writeExpression(*c.fArguments[1], out);
-            if (c.fArguments[0]->fType != c.fArguments[1]->fType) {
-                // we have mod(vector, scalar), but SPIR-V wants mod(vector, vector)
-                ASSERT(c.fArguments[0]->fType.componentType() == c.fArguments[1]->fType);
-                SpvId scalar = arg2;
-                const Type& type = c.fArguments[0]->fType;
-                arg2 = this->nextId();
-                this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(), out);
-                this->writeWord(this->getType(type), out);
-                this->writeWord(arg2, out);
-                for (int i = 0; i < type.columns(); i++) {
-                    this->writeWord(scalar, out);
-                }
-            }
+            std::vector<SpvId> args = this->vectorize(c.fArguments, out);
+            ASSERT(args.size() == 2);
             const Type& operandType = c.fArguments[0]->fType;
             SpvOp_ op;
             if (is_float(fContext, operandType)) {
@@ -869,8 +911,37 @@
             this->writeOpCode(op, 5, out);
             this->writeWord(this->getType(operandType), out);
             this->writeWord(result, out);
-            this->writeWord(arg1, out);
-            this->writeWord(arg2, out);
+            this->writeWord(args[0], out);
+            this->writeWord(args[1], out);
+            break;
+        }
+        case kClamp_SpecialIntrinsic: {
+            std::vector<SpvId> args = this->vectorize(c.fArguments, out);
+            ASSERT(args.size() == 3);
+            this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
+                                               GLSLstd450UClamp, args, out);
+            break;
+        }
+        case kMax_SpecialIntrinsic: {
+            std::vector<SpvId> args = this->vectorize(c.fArguments, out);
+            ASSERT(args.size() == 2);
+            this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMax, GLSLstd450SMax,
+                                               GLSLstd450UMax, args, out);
+            break;
+        }
+        case kMin_SpecialIntrinsic: {
+            std::vector<SpvId> args = this->vectorize(c.fArguments, out);
+            ASSERT(args.size() == 2);
+            this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMin, GLSLstd450SMin,
+                                               GLSLstd450UMin, args, out);
+            break;
+        }
+        case kMix_SpecialIntrinsic: {
+            std::vector<SpvId> args = this->vectorize(c.fArguments, out);
+            ASSERT(args.size() == 3);
+            this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMix, SpvOpUndef,
+                                               SpvOpUndef, args, out);
+            break;
         }
     }
     return result;
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h
index 0d51af5..4bd8d86 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/SkSLSPIRVCodeGenerator.h
@@ -90,6 +90,10 @@
 
     enum SpecialIntrinsic {
         kAtan_SpecialIntrinsic,
+        kClamp_SpecialIntrinsic,
+        kMax_SpecialIntrinsic,
+        kMin_SpecialIntrinsic,
+        kMix_SpecialIntrinsic,
         kMod_SpecialIntrinsic,
         kSubpassLoad_SpecialIntrinsic,
         kTexelFetch_SpecialIntrinsic,
@@ -149,6 +153,20 @@
 
     SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out);
 
+
+    void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
+                                      SpvId signedInst, SpvId unsignedInst,
+                                      const std::vector<SpvId>& args, OutputStream& out);
+
+    /**
+     * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the
+     * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2),
+     * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float,
+     * vec2, vec3).
+     */
+    std::vector<SpvId> vectorize(const std::vector<std::unique_ptr<Expression>>& args,
+                                 OutputStream& out);
+
     SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
 
     SpvId writeConstantVector(const Constructor& c);