Add support for the SkSL unpremul intrinsic to the SPIR-V code generator
Change-Id: I970f1ad0dd0859448c874498fe02342f8abc3aa3
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/242897
Reviewed-by: Brian Salomon <bsalomon@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index b7ffd1c..34728d5 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -141,6 +141,7 @@
SpvOpUndef);
fIntrinsicMap[String("EmitVertex")] = ALL_SPIRV(EmitVertex);
fIntrinsicMap[String("EndPrimitive")] = ALL_SPIRV(EndPrimitive);
+ fIntrinsicMap[String("unpremul")] = SPECIAL(Unpremul);
// interpolateAt* not yet supported...
}
@@ -1002,6 +1003,26 @@
GLSLstd450UClamp, spvArgs, out);
break;
}
+ case kUnpremul_SpecialIntrinsic: {
+ SpvId color = this->writeExpression(*c.fArguments[0], out);
+ SpvId a = this->writeSwizzle(*fContext.fHalf_Type, c.fArguments[0]->fType, color, { 3 },
+ out);
+ FloatLiteral min(fContext, -1, SKSL_UNPREMUL_MIN);
+ SpvId minId = this->writeFloatLiteral(min);
+ SpvId nonZeroAlpha = this->nextId();
+ this->writeGLSLExtendedInstruction(*fContext.fHalf_Type, nonZeroAlpha, GLSLstd450FMax,
+ SpvOpUndef, SpvOpUndef, { a, minId }, out);
+ SpvId rgb = this->writeSwizzle(*fContext.fHalf3_Type, *fContext.fHalf4_Type, color,
+ { 0, 1, 2 }, out);
+ SpvId scaled = this->writeBinaryExpression(*fContext.fHalf3_Type, rgb, Token::SLASH,
+ *fContext.fFloat_Type, nonZeroAlpha,
+ *fContext.fHalf3_Type, out);
+ this->writeOpCode(SpvOpCompositeConstruct, 5, out);
+ this->writeWord(this->getType(c.fType), out);
+ this->writeWord(result, out);
+ this->writeWord(scaled, out);
+ this->writeWord(nonZeroAlpha, out);
+ }
}
return result;
}
@@ -1912,19 +1933,24 @@
}
SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
- SpvId base = this->writeExpression(*swizzle.fBase, out);
+ return this->writeSwizzle(swizzle.fType, swizzle.fBase->fType,
+ this->writeExpression(*swizzle.fBase, out), swizzle.fComponents, out);
+}
+
+SpvId SPIRVCodeGenerator::writeSwizzle(const Type& type, const Type& baseType, SpvId base,
+ const std::vector<int> components, OutputStream& out) {
SpvId result = this->nextId();
- size_t count = swizzle.fComponents.size();
+ size_t count = components.size();
if (count == 1) {
- this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
- swizzle.fComponents[0], out);
+ this->writeInstruction(SpvOpCompositeExtract, this->getType(type), result, base,
+ components[0], out);
} else {
this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
- this->writeWord(this->getType(swizzle.fType), out);
+ this->writeWord(this->getType(type), out);
this->writeWord(result, out);
this->writeWord(base, out);
SpvId other = base;
- for (int c : swizzle.fComponents) {
+ for (int c : components) {
if (c < 0) {
if (!fConstantZeroOneVector) {
FloatLiteral zero(fContext, -1, 0);
@@ -1944,11 +1970,11 @@
}
}
this->writeWord(other, out);
- for (int component : swizzle.fComponents) {
+ for (int component : components) {
if (component == SKSL_SWIZZLE_0) {
- this->writeWord(swizzle.fBase->fType.columns(), out);
+ this->writeWord(baseType.columns(), out);
} else if (component == SKSL_SWIZZLE_1) {
- this->writeWord(swizzle.fBase->fType.columns() + 1, out);
+ this->writeWord(baseType.columns() + 1, out);
} else {
this->writeWord(component, out);
}
@@ -2068,6 +2094,11 @@
SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
const Type& rightType, SpvId rhs,
const Type& resultType, OutputStream& out) {
+ // it's important to handle comma early, so we don't end up vectorizing the operands
+ if (op == Token::COMMA) {
+ return rhs;
+ }
+
Type tmp("<invalid>");
// overall type we are operating on: float2, int, uint4...
const Type* operandType;
@@ -2260,8 +2291,6 @@
case Token::BITWISEXOR:
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
- case Token::COMMA:
- return rhs;
default:
SkASSERT(false);
return -1;