Add support for the SkSL unpremul function

Change-Id: I970f1ad0dd0859448c874498fe02342f8abc3aa3
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/242897
Reviewed-by: Brian Salomon <bsalomon@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index b7ffd1c..34728d5 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -141,6 +141,7 @@
                                                                 SpvOpUndef);
     fIntrinsicMap[String("EmitVertex")]       = ALL_SPIRV(EmitVertex);
     fIntrinsicMap[String("EndPrimitive")]     = ALL_SPIRV(EndPrimitive);
+    fIntrinsicMap[String("unpremul")]         = SPECIAL(Unpremul);
 // interpolateAt* not yet supported...
 }
 
@@ -1002,6 +1003,26 @@
                                                GLSLstd450UClamp, spvArgs, out);
             break;
         }
+        case kUnpremul_SpecialIntrinsic: {  // half4(color.rgb / max(color.a, MIN), color.a)
+            SpvId color = this->writeExpression(*c.fArguments[0], out);
+            SpvId a = this->writeSwizzle(*fContext.fHalf_Type, c.fArguments[0]->fType, color, { 3 },
+                                         out);
+            FloatLiteral min(fContext, -1, SKSL_UNPREMUL_MIN);
+            SpvId minId = this->writeFloatLiteral(min);
+            SpvId nonZeroAlpha = this->nextId();  // divisor only: alpha clamped away from zero
+            this->writeGLSLExtendedInstruction(*fContext.fHalf_Type, nonZeroAlpha, GLSLstd450FMax,
+                                               SpvOpUndef, SpvOpUndef, { a, minId }, out);
+            SpvId rgb = this->writeSwizzle(*fContext.fHalf3_Type, c.fArguments[0]->fType, color,
+                                           { 0, 1, 2 }, out);
+            SpvId scaled = this->writeBinaryExpression(*fContext.fHalf3_Type, rgb, Token::SLASH,
+                                                       *fContext.fHalf_Type, nonZeroAlpha,
+                                                       *fContext.fHalf3_Type, out);
+            this->writeOpCode(SpvOpCompositeConstruct, 5, out);  // (scaled.rgb, a)
+            this->writeWord(this->getType(c.fType), out);
+            this->writeWord(result, out);
+            this->writeWord(scaled, out);
+            this->writeWord(a, out);  // keep the ORIGINAL alpha, not the clamped divisor
+        }
     }
     return result;
 }
@@ -1912,19 +1933,24 @@
 }
 
 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
-    SpvId base = this->writeExpression(*swizzle.fBase, out);
+    return this->writeSwizzle(swizzle.fType, swizzle.fBase->fType,
+                              this->writeExpression(*swizzle.fBase, out), swizzle.fComponents, out);
+}
+
+SpvId SPIRVCodeGenerator::writeSwizzle(const Type& type, const Type& baseType, SpvId base,
+                                       const std::vector<int> components, OutputStream& out) {
     SpvId result = this->nextId();
-    size_t count = swizzle.fComponents.size();
+    size_t count = components.size();
     if (count == 1) {
-        this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
-                               swizzle.fComponents[0], out);
+        this->writeInstruction(SpvOpCompositeExtract, this->getType(type), result, base,
+                               components[0], out);
     } else {
         this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
-        this->writeWord(this->getType(swizzle.fType), out);
+        this->writeWord(this->getType(type), out);
         this->writeWord(result, out);
         this->writeWord(base, out);
         SpvId other = base;
-        for (int c : swizzle.fComponents) {
+        for (int c : components) {
             if (c < 0) {
                 if (!fConstantZeroOneVector) {
                     FloatLiteral zero(fContext, -1, 0);
@@ -1944,11 +1970,11 @@
             }
         }
         this->writeWord(other, out);
-        for (int component : swizzle.fComponents) {
+        for (int component : components) {
             if (component == SKSL_SWIZZLE_0) {
-                this->writeWord(swizzle.fBase->fType.columns(), out);
+                this->writeWord(baseType.columns(), out);
             } else if (component == SKSL_SWIZZLE_1) {
-                this->writeWord(swizzle.fBase->fType.columns() + 1, out);
+                this->writeWord(baseType.columns() + 1, out);
             } else {
                 this->writeWord(component, out);
             }
@@ -2068,6 +2094,11 @@
 SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
                                                 const Type& rightType, SpvId rhs,
                                                 const Type& resultType, OutputStream& out) {
+    // it's important to handle comma early, so we don't end up vectorizing the operands
+    if (op == Token::COMMA) {
+        return rhs;
+    }
+
     Type tmp("<invalid>");
     // overall type we are operating on: float2, int, uint4...
     const Type* operandType;
@@ -2260,8 +2291,6 @@
         case Token::BITWISEXOR:
             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                               SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
-        case Token::COMMA:
-            return rhs;
         default:
             SkASSERT(false);
             return -1;