Implement findMSB intrinsic in Metal.
findMSB has one special behavior for which Metal has no natural
equivalent: its treatment of negative numbers.
findMSB searches negative numbers for a zero bit, not a one bit!
We emulate this behavior in Metal using select(n, ~n, n<0).
Change-Id: I861c6b8fb3dc5427643cd8c68a39a53f1959bff3
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/343996
Commit-Queue: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index af94a88..8cafe20 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -50,6 +50,7 @@
fIntrinsicMap[String("dot")] = SPECIAL(Dot);
fIntrinsicMap[String("faceforward")] = SPECIAL(Faceforward);
fIntrinsicMap[String("findLSB")] = SPECIAL(FindLSB);
+ fIntrinsicMap[String("findMSB")] = SPECIAL(FindMSB);
fIntrinsicMap[String("length")] = SPECIAL(Length);
fIntrinsicMap[String("mod")] = SPECIAL(Mod);
fIntrinsicMap[String("normalize")] = SPECIAL(Normalize);
@@ -700,6 +701,54 @@
this->write("(0)))");
break;
}
+ case kFindMSB_SpecialIntrinsic: {
+ // Create a temp variable (of the argument's type) to hold the argument, so it is evaluated once.
+ String skTemp1 = this->getTempVariable(arguments[0]->type());
+ String exprType = this->typeName(arguments[0]->type());
+
+ // GLSL findMSB is actually quite different from Metal's clz:
+ // - For signed negative numbers, it returns the first zero bit, not the first one bit!
+ // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
+ // NOTE(review): findMSB is a bit index (LSB = 0); should clz(x) be (numbits - 1) - clz(x)?
+ // Emits the opening of a comma expression: (_skTemp1 = (<argument>),
+ this->write("(");
+ this->write(skTemp1);
+ this->write(" = (");
+ this->writeExpression(*arguments[0], kSequence_Precedence);
+ this->write("), ");
+
+ // Signed input types might be negative; we need another helper variable holding the
+ // bitwise complement (~) of a negative input (since we can only find one bits, not zero bits).
+ String skTemp2;
+ if (arguments[0]->type().isSigned()) {
+ // Emits: _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
+ skTemp2 = this->getTempVariable(arguments[0]->type());
+ this->write(skTemp2);
+ this->write(" = (select(");
+ this->write(skTemp1);
+ this->write(", ~");
+ this->write(skTemp1);
+ this->write(", ");
+ this->write(skTemp1);
+ this->write(" < 0)), ");
+ } else {
+ skTemp2 = skTemp1; // input type is not signed: reuse the first temp, no second variable
+ }
+
+ // Emits: select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0))) -- yields -1 when _skTemp2 is 0
+ this->write("select(");
+ this->write(this->typeName(c.type()));
+ this->write("(clz(");
+ this->write(skTemp2);
+ this->write(")), ");
+ this->write(this->typeName(c.type()));
+ this->write("(-1), ");
+ this->write(skTemp2);
+ this->write(" == ");
+ this->write(exprType); // the zero being compared against is constructed in the argument's type
+ this->write("(0)))");
+ break;
+ }
default:
ABORT("unsupported special intrinsic kind");
}
diff --git a/src/sksl/SkSLMetalCodeGenerator.h b/src/sksl/SkSLMetalCodeGenerator.h
index cfe955a..7fbe031 100644
--- a/src/sksl/SkSLMetalCodeGenerator.h
+++ b/src/sksl/SkSLMetalCodeGenerator.h
@@ -110,6 +110,7 @@
kDot_SpecialIntrinsic,
kFaceforward_SpecialIntrinsic,
kFindLSB_SpecialIntrinsic,
+ kFindMSB_SpecialIntrinsic,
kLength_SpecialIntrinsic,
kMod_SpecialIntrinsic,
kNormalize_SpecialIntrinsic,
diff --git a/tests/sksl/intrinsics/golden/FindMSB.metal b/tests/sksl/intrinsics/golden/FindMSB.metal
index a6d94e0..3a9e1e6 100644
--- a/tests/sksl/intrinsics/golden/FindMSB.metal
+++ b/tests/sksl/intrinsics/golden/FindMSB.metal
@@ -13,7 +13,10 @@
fragment Outputs fragmentMain(Inputs _in [[stage_in]], bool _frontFacing [[front_facing]], float4 _fragCoord [[position]]) {
Outputs _outputStruct;
thread Outputs* _out = &_outputStruct;
- _out->sk_FragColor.x = float(findMSB(_in.a));
- _out->sk_FragColor.y = float(findMSB(_in.b));
+ int _skTemp0;
+ int _skTemp1;
+ uint _skTemp2;
+ _out->sk_FragColor.x = float((_skTemp0 = (_in.a), _skTemp1 = (select(_skTemp0, ~_skTemp0, _skTemp0 < 0)), select(int(clz(_skTemp1)), int(-1), _skTemp1 == int(0))));
+ _out->sk_FragColor.y = float((_skTemp2 = (_in.b), select(int(clz(_skTemp2)), int(-1), _skTemp2 == uint(0))));
return *_out;
}