Improve Metal support for out parameters.
We now insert helper functions which defer the assignment of out-
parameters back into their original variables until the end of the
function call. This allows us to match the semantics listed in the GLSL
spec, section 6.1.1:
"All arguments are evaluated at call time, exactly once, in order, from
left to right. [...] Evaluation of an out parameter results in an
l-value that is used to copy out a value when the function returns.
Evaluation of an inout parameter results in both a value and an l-value;
the value is copied to the formal parameter at call time and the lvalue
is used to copy out a value when the function returns."
This technique also allows us to support swizzled out-parameters in
Metal, by reading the swizzle into a temp variable, calling the original
function, and then re-assigning the result back into the original
swizzle expression.
At present, we don't deduplicate these helper functions, so in theory
there could be a fair amount of redundant code generated if a function
with out parameters is called many times, since every call site gets
its own helper. The cost of properly deduplicating them is probably
larger than the benefit in the common (99%) case.
Change-Id: Iefc922ac9e2b24ef2ff1e9dacb17a735a75ec8ea
Bug: skia:10855, skia:11052
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/341162
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 7b17b22..cfc94ff 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -165,8 +165,8 @@
this->write(to_string(type.columns()));
}
this->write("]");
-
- this->writeArrayDimensions(type.componentType());
+ // NOTE(review): type.componentType() is no longer recursed into, so only one dimension is
+ // written; this assumes multi-dimensional array types never reach this point - confirm.
}
}
@@ -264,10 +262,138 @@
}
}
-String MetalCodeGenerator::getOutParamHelper(const FunctionDeclaration& function,
- const ExpressionArray& arguments) {
- // TODO: actually synthesize helper method.
- return String::printf("/*needs swizzle fix*/ %s", String(function.name()).c_str());
+// Synthesizes (into fExtraFunctions) a helper function that forwards to `call.function()`, and
+// returns the helper's name. For each argument whose `outVars` entry is non-null, the helper
+// receives the underlying variable as a parameter, declares a local `_varN` (initialized from the
+// argument expression when the parameter is also `in`), passes `_varN` to the real function, and
+// afterwards assigns `_varN` back through the original argument expression. This gives GLSL
+// copy-out semantics and supports swizzled out-params, which Metal can't pass directly.
+//
+// NOTE(review): known limitations of this scheme:
+// - Helper parameters are named after the out variables, so passing the same variable to two
+//   out-params (e.g. `f(v.x, v.y)`) produces duplicate parameter names.
+// - Argument expressions are re-emitted inside the helper using bare variable names, so any other
+//   variable they mention (e.g. `i` in `a[i]`) is not in scope there, and they are evaluated
+//   twice (copy-in and copy-out) rather than exactly once at call time.
+String MetalCodeGenerator::getOutParamHelper(const FunctionCall& call,
+ const ExpressionArray& arguments,
+ const SkTArray<VariableReference*>& outVars) {
+ AutoOutputStream outputToExtraFunctions(this, &fExtraFunctions, &fIndentation);
+ const FunctionDeclaration& function = call.function();
+
+ String name = "_skOutParamHelper" + to_string(fSwizzleHelperCount++) + "_" + function.name();
+ const char* separator = "";
+
+ // Emit a prototype for the function we'll be calling through to in our helper.
+ if (!function.isBuiltin()) {
+ this->writeFunctionDeclaration(function);
+ this->writeLine(";");
+ }
+
+ // Synthesize a helper function that takes the same inputs as `function`, except in places where
+ // `outVars` is non-null; in those places, we take the type of the VariableReference.
+ //
+ // float _skOutParamHelper0_originalFuncName(float _var0, float _var1, float& outParam) {
+ this->writeBaseType(call.type());
+ this->write(" ");
+ this->write(name);
+ this->write("(");
+ this->writeFunctionRequirementParams(function, separator);
+
+ SkASSERT(outVars.size() == arguments.size());
+ SkASSERT(outVars.size() == function.parameters().size());
+
+ for (int index = 0; index < arguments.count(); ++index) {
+ this->write(separator);
+ separator = ", ";
+
+ const Variable* param = function.parameters()[index];
+ this->writeModifiers(param->modifiers(), /*globalContext=*/false);
+
+ const Type* type = outVars[index] ? &outVars[index]->type() : &arguments[index]->type();
+ this->writeBaseType(*type);
+
+ if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
+ this->write("&");
+ }
+ if (outVars[index]) {
+ this->write(" ");
+ fIgnoreVariableReferenceModifiers = true;
+ this->writeVariableReference(*outVars[index]);
+ fIgnoreVariableReferenceModifiers = false;
+ } else {
+ this->write(" _var");
+ this->write(to_string(index));
+ }
+ this->writeArrayDimensions(*type);
+ }
+ this->writeLine(") {");
+
+ ++fIndentation;
+ for (int index = 0; index < outVars.count(); ++index) {
+ if (!outVars[index]) {
+ continue;
+ }
+ // float3 _var2[ = outParam.zyx];
+ this->writeBaseType(arguments[index]->type());
+ this->write(" _var");
+ this->write(to_string(index));
+ // Array-typed arguments need their dimensions after the name, as in the parameter list above.
+ this->writeArrayDimensions(arguments[index]->type());
+
+ const Variable* param = function.parameters()[index];
+ if (param->modifiers().fFlags & Modifiers::kIn_Flag) {
+ this->write(" = ");
+ fIgnoreVariableReferenceModifiers = true;
+ this->writeExpression(*arguments[index], kAssignment_Precedence);
+ fIgnoreVariableReferenceModifiers = false;
+ }
+
+ this->writeLine(";");
+ }
+
+ // [int _skResult = ] myFunction(inputs, outputs, globals, _var0, _var1, _var2, _var3);
+ bool hasResult = (call.type().name() != "void");
+ if (hasResult) {
+ this->writeBaseType(call.type());
+ this->write(" _skResult = ");
+ }
+
+ this->writeName(function.name());
+ this->write("(");
+ separator = "";
+ this->writeFunctionRequirementArgs(function, separator);
+
+ for (int index = 0; index < arguments.count(); ++index) {
+ this->write(separator);
+ separator = ", ";
+
+ this->write("_var");
+ this->write(to_string(index));
+ }
+ this->writeLine(");");
+
+ for (int index = 0; index < outVars.count(); ++index) {
+ if (!outVars[index]) {
+ continue;
+ }
+ // outParam.zyx = _var2;
+ fIgnoreVariableReferenceModifiers = true;
+ this->writeExpression(*arguments[index], kAssignment_Precedence);
+ fIgnoreVariableReferenceModifiers = false;
+ this->write(" = _var");
+ this->write(to_string(index));
+ this->writeLine(";");
+ }
+
+ if (hasResult) {
+ this->writeLine("return _skResult;");
+ }
+
+ --fIndentation;
+ this->writeLine("}");
+
+ return name;
}
void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
@@ -298,58 +409,50 @@
name = "dfdy";
}
- // GLSL supports passing swizzled variables to out params; Metal doesn't. To emulate that
- // support, we synthesize a helper function which performs the swizzle into a temporary
- // variable, calls the original function, then writes the temp var back into the out param.
+ // GLSL supports passing swizzled variables to out params; Metal doesn't. Walk the list of
+ // parameters and, for every out (or inout) parameter, note the underlying variable that the
+ // passed-in expression writes to. Any such parameter is routed through an out-param helper.
const std::vector<const Variable*>& parameters = function.parameters();
SkASSERT(arguments.size() == parameters.size());
- for (size_t index = 0; index < arguments.size(); ++index) {
+
+ bool foundOutParam = false;
+ SkSTArray<16, VariableReference*> outVars;
+ outVars.push_back_n(arguments.size(), static_cast<VariableReference*>(nullptr));
+
+ for (int index = 0; index < arguments.count(); ++index) {
// If this is an out parameter...
if (parameters[index]->modifiers().fFlags & Modifiers::kOut_Flag) {
- // Inspect the expression to see if it contains a swizzle.
+ // Find the expression's inner variable being written to.
Analysis::AssignmentInfo info;
- bool outParamIsAssignable = Analysis::IsAssignable(*arguments[index], &info, nullptr);
- SkASSERT(outParamIsAssignable); // assignability was verified at IRGeneration time
- if (outParamIsAssignable && info.fIsSwizzled) {
- // Found a swizzle; we need to use a helper function here.
- name = this->getOutParamHelper(function, arguments);
- break;
- }
+ // Assignability was verified at IRGeneration time, so this should always succeed.
+ SkAssertResult(Analysis::IsAssignable(*arguments[index], &info));
+ outVars[index] = info.fAssignedVar;
+ foundOutParam = true;
}
}
+ if (foundOutParam) {
+ // Out parameters need to be written back to at the end of the function. To do this, we
+ // synthesize a helper function which copies each out-param into a temporary variable (reading
+ // its current value only for inout params), calls the original function, then writes the temp
+ // var back using the original out-param expression. (This lets us support things like
+ // swizzles and array indices.)
+ name = this->getOutParamHelper(c, arguments, outVars);
+ }
+
this->write(name);
this->write("(");
const char* separator = "";
- if (this->requirements(function) & kInputs_Requirement) {
- this->write("_in");
- separator = ", ";
- }
- if (this->requirements(function) & kOutputs_Requirement) {
- this->write(separator);
- this->write("_out");
- separator = ", ";
- }
- if (this->requirements(function) & kUniforms_Requirement) {
- this->write(separator);
- this->write("_uniforms");
- separator = ", ";
- }
- if (this->requirements(function) & kGlobals_Requirement) {
- this->write(separator);
- this->write("_globals");
- separator = ", ";
- }
- if (this->requirements(function) & kFragCoord_Requirement) {
- this->write(separator);
- this->write("_fragCoord");
- separator = ", ";
- }
- for (size_t i = 0; i < arguments.size(); ++i) {
- const Expression& arg = *arguments[i];
+ this->writeFunctionRequirementArgs(function, separator);
+ for (int i = 0; i < arguments.count(); ++i) {
this->write(separator);
separator = ", ";
- this->writeExpression(arg, kSequence_Precedence);
+
+ if (outVars[i]) {
+ this->writeExpression(*outVars[i], kSequence_Precedence);
+ } else {
+ this->writeExpression(*arguments[i], kSequence_Precedence);
+ }
}
this->write(")");
}
@@ -807,6 +910,14 @@
}
void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
+ // When assembling out-param helper functions, the helper's parameters are named after the
+ // original variables, so write the bare name: never prepend "_in." or "_globals->" (and skip
+ // builtin remapping such as "_out->sk_FragColor"), since we target the helper's own parameters.
+ if (fIgnoreVariableReferenceModifiers) {
+ this->writeName(ref.variable()->name());
+ return;
+ }
+
switch (ref.variable()->modifiers().fLayout.fBuiltin) {
case SK_FRAGCOLOR_BUILTIN:
this->write("_out->sk_FragColor");
@@ -1047,6 +1158,46 @@
ABORT("internal error; setting was not folded to a constant during compilation\n");
}
+// Writes the implicit arguments (_in, _out, _uniforms, _globals, _fragCoord) that a call to `f`
+// must pass, in the same order writeFunctionRequirementParams() declares them. Every argument is
+// preceded by `separator`, which is then set to ", " for whatever gets written next.
+void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
+ const char*& separator) {
+ const Requirements needed = this->requirements(f);
+ auto writeIfNeeded = [&](Requirements flag, const char* text) {
+ if (needed & flag) {
+ this->write(separator);
+ this->write(text);
+ separator = ", ";
+ }
+ };
+ writeIfNeeded(kInputs_Requirement, "_in");
+ writeIfNeeded(kOutputs_Requirement, "_out");
+ writeIfNeeded(kUniforms_Requirement, "_uniforms");
+ writeIfNeeded(kGlobals_Requirement, "_globals");
+ writeIfNeeded(kFragCoord_Requirement, "_fragCoord");
+}
+
+// Writes the implicit parameter declarations matching writeFunctionRequirementArgs(), for the
+// signature of `f` (or of a helper that forwards to `f`). Reads and updates `separator` the same
+// way.
+void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
+ const char*& separator) {
+ const Requirements needed = this->requirements(f);
+ auto writeIfNeeded = [&](Requirements flag, const char* declaration) {
+ if (needed & flag) {
+ this->write(separator);
+ this->write(declaration);
+ separator = ", ";
+ }
+ };
+ writeIfNeeded(kInputs_Requirement, "Inputs _in");
+ writeIfNeeded(kOutputs_Requirement, "thread Outputs* _out");
+ writeIfNeeded(kUniforms_Requirement, "Uniforms _uniforms");
+ writeIfNeeded(kGlobals_Requirement, "thread Globals* _globals");
+ writeIfNeeded(kFragCoord_Requirement, "float4 _fragCoord");
+}
+
bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
fRTHeightName = fProgram.fInputs.fRTHeight ? "_globals->_anonInterface0->u_skRTHeight" : "";
const char* separator = "";
@@ -1130,36 +1301,12 @@
this->write(" ");
this->writeName(f.name());
this->write("(");
- Requirements requirements = this->requirements(f);
- if (requirements & kInputs_Requirement) {
- this->write("Inputs _in");
- separator = ", ";
- }
- if (requirements & kOutputs_Requirement) {
- this->write(separator);
- this->write("thread Outputs* _out");
- separator = ", ";
- }
- if (requirements & kUniforms_Requirement) {
- this->write(separator);
- this->write("Uniforms _uniforms");
- separator = ", ";
- }
- if (requirements & kGlobals_Requirement) {
- this->write(separator);
- this->write("thread Globals* _globals");
- separator = ", ";
- }
- if (requirements & kFragCoord_Requirement) {
- this->write(separator);
- this->write("float4 _fragCoord");
- separator = ", ";
- }
+ this->writeFunctionRequirementParams(f, separator); // implicit _in/_out/... params of f
}
for (const auto& param : f.parameters()) {
this->write(separator);
separator = ", ";
- this->writeModifiers(param->modifiers(), false);
+ this->writeModifiers(param->modifiers(), /*globalContext=*/false);
const Type* type = &param->type();
this->writeBaseType(*type);
if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
@@ -1243,7 +1390,7 @@
}
void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers,
- bool globalContext) {
+ bool globalContext) { // true for interface blocks and top-level modifier declarations
if (modifiers.fFlags & Modifiers::kOut_Flag) {
this->write("thread ");
}
@@ -1256,7 +1403,7 @@
if ("sk_PerVertex" == intf.typeName()) {
return;
}
- this->writeModifiers(intf.variable().modifiers(), true);
+ this->writeModifiers(intf.variable().modifiers(), /*globalContext=*/true);
this->write("struct ");
this->writeLine(intf.typeName() + " {");
const Type* structType = &intf.variable().type();
@@ -1327,7 +1474,7 @@
return;
}
currentOffset += fieldSize;
- this->writeModifiers(field.fModifiers, false);
+ this->writeModifiers(field.fModifiers, /*globalContext=*/false);
this->writeBaseType(*fieldType);
this->write(" ");
this->writeName(field.fName);
@@ -1849,7 +1996,8 @@
this->writeFunctionPrototype(e.as<FunctionPrototype>());
break;
case ProgramElement::Kind::kModifiers:
- this->writeModifiers(e.as<ModifiersDeclaration>().modifiers(), true);
+ this->writeModifiers(e.as<ModifiersDeclaration>().modifiers(),
+ /*globalContext=*/true);
this->writeLine(";");
break;
case ProgramElement::Kind::kEnum:
diff --git a/src/sksl/SkSLMetalCodeGenerator.h b/src/sksl/SkSLMetalCodeGenerator.h
index e6b1dae..1f49b63 100644
--- a/src/sksl/SkSLMetalCodeGenerator.h
+++ b/src/sksl/SkSLMetalCodeGenerator.h
@@ -178,6 +178,13 @@
void writeFunctionStart(const FunctionDeclaration& f);
+ // Write the implicit _in/_out/_uniforms/_globals/_fragCoord parameter declarations (Params) or
+ // call-site arguments (Args) that `f` requires, each preceded by `separator`.
+ void writeFunctionRequirementParams(const FunctionDeclaration& f,
+ const char*& separator);
+
+ void writeFunctionRequirementArgs(const FunctionDeclaration& f, const char*& separator);
+
bool writeFunctionDeclaration(const FunctionDeclaration& f);
void writeFunction(const FunctionDefinition& f);
@@ -204,7 +211,11 @@
void writeMinAbsHack(Expression& absExpr, Expression& otherExpr);
- String getOutParamHelper(const FunctionDeclaration& function, const ExpressionArray& arguments);
+ // Emits a helper that calls `call.function()` with GLSL copy-out semantics for out-params, and
+ // returns the helper's name.
+ String getOutParamHelper(const FunctionCall& call,
+ const ExpressionArray& arguments,
+ const SkTArray<VariableReference*>& outVars);
String getInverseHack(const Expression& mat);
@@ -284,7 +295,6 @@
int fPaddingCount = 0;
const char* fLineEnding;
const Context& fContext;
- StringStream fHeader;
String fFunctionHeader;
StringStream fExtraFunctions;
Program::Kind fProgramKind;
@@ -305,6 +315,8 @@
int fUniformBuffer = -1;
String fRTHeightName;
const FunctionDeclaration* fCurrentFunction = nullptr;
+ int fSwizzleHelperCount = 0; // Unique suffix for every out-param helper name.
+ bool fIgnoreVariableReferenceModifiers = false; // Set while emitting out-param helpers.
using INHERITED = CodeGenerator;
};
diff --git a/src/sksl/SkSLString.cpp b/src/sksl/SkSLString.cpp
index 67b4d2d..5f070b3 100644
--- a/src/sksl/SkSLString.cpp
+++ b/src/sksl/SkSLString.cpp
@@ -210,15 +210,13 @@
}
String to_string(int64_t value) {
- std::stringstream buffer;
- buffer << value;
- return String(buffer.str().c_str());
+ // int64_t may be `long` rather than `long long`; cast so the argument matches %lld exactly.
+ return SkSL::String::printf("%lld", static_cast<long long>(value));
}
String to_string(uint64_t value) {
- std::stringstream buffer;
- buffer << value;
- return String(buffer.str().c_str());
+ // uint64_t may be `unsigned long`; cast so the argument matches %llu exactly.
+ return SkSL::String::printf("%llu", static_cast<unsigned long long>(value));
}
String to_string(double value) {
String to_string(double value) {