Improve Metal support for out parameters.

We now insert helper functions which defer the assignment of out-
parameters back into their original variables to the end of the
function call. This allows us to match the semantics described in
section 6.1.1 of the GLSL spec:

"All arguments are evaluated at call time, exactly once, in order, from
left to right. [...] Evaluation of an out parameter results in an
l-value that is used to copy out a value when the function returns.
Evaluation of an inout parameter results in both a value and an l-value;
the value is copied to the formal parameter at call time and the lvalue
is used to copy out a value when the function returns."

This technique also allows us to support swizzled out-parameters in
Metal, by reading the swizzle into a temp variable, calling the original
function, and then re-assigning the result back into the original
swizzle expression.

At present, we don't deduplicate these helper functions, so in theory
a fair amount of redundant code could be generated if a function with
out parameters is called many times. In the common case, the cost of
properly deduplicating them probably outweighs the benefit.

Change-Id: Iefc922ac9e2b24ef2ff1e9dacb17a735a75ec8ea
Bug: skia:10855, skia:11052
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/341162
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 7b17b22..cfc94ff 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -165,8 +165,6 @@
             this->write(to_string(type.columns()));
         }
         this->write("]");
-
-        this->writeArrayDimensions(type.componentType());
     }
 }
 
@@ -264,10 +262,123 @@
     }
 }
 
-String MetalCodeGenerator::getOutParamHelper(const FunctionDeclaration& function,
-                                             const ExpressionArray& arguments) {
-    // TODO: actually synthesize helper method.
-    return String::printf("/*needs swizzle fix*/ %s", String(function.name()).c_str());
+String MetalCodeGenerator::getOutParamHelper(const FunctionCall& call,
+                                             const ExpressionArray& arguments,
+                                             const SkTArray<VariableReference*>& outVars) {
+    AutoOutputStream outputToExtraFunctions(this, &fExtraFunctions, &fIndentation);
+    const FunctionDeclaration& function = call.function();
+
+    String name = "_skOutParamHelper" + to_string(fSwizzleHelperCount++) + "_" + function.name();
+    const char* separator = "";
+
+    // Prototype the (non-builtin) callee, since the helper is written into fExtraFunctions.
+    if (!function.isBuiltin()) {
+        this->writeFunctionDeclaration(function);
+        this->writeLine(";");
+    }
+
+    // Synthesize a helper mirroring `function`'s parameters, except where `outVars` is non-null:
+    // there the helper takes the referenced variable itself, by reference and under its own name.
+    // NOTE(review): two out-args naming the same variable would yield duplicate parameter names.
+    // float _skOutParamHelper0_originalFuncName(float _var0, float _var1, float& outParam) {
+    this->writeBaseType(call.type());
+    this->write(" ");
+    this->write(name);
+    this->write("(");
+    this->writeFunctionRequirementParams(function, separator);
+
+    SkASSERT(outVars.size() == arguments.size());
+    SkASSERT(outVars.size() == function.parameters().size());
+
+    for (int index = 0; index < arguments.count(); ++index) {
+        this->write(separator);
+        separator = ", ";
+
+        const Variable* param = function.parameters()[index];
+        this->writeModifiers(param->modifiers(), /*globalContext=*/false);
+
+        const Type* type = outVars[index] ? &outVars[index]->type() : &arguments[index]->type();
+        this->writeBaseType(*type);
+
+        if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
+            this->write("&");
+        }
+        if (outVars[index]) {
+            this->write(" ");
+            fIgnoreVariableReferenceModifiers = true;
+            this->writeVariableReference(*outVars[index]);
+            fIgnoreVariableReferenceModifiers = false;
+        } else {
+            this->write(" _var");
+            this->write(to_string(index));
+        }
+        this->writeArrayDimensions(*type);
+    }
+    this->writeLine(") {");
+
+    ++fIndentation;
+    for (int index = 0; index < outVars.count(); ++index) {
+        if (!outVars[index]) {
+            continue;
+        }
+        // Temp, seeded only for kIn_Flag params: float3 _var2[ = outParam.zyx];
+        this->writeBaseType(arguments[index]->type());
+        this->write(" _var");
+        this->write(to_string(index));
+
+        const Variable* param = function.parameters()[index];
+        if (param->modifiers().fFlags & Modifiers::kIn_Flag) {
+            this->write(" = ");
+            fIgnoreVariableReferenceModifiers = true;
+            this->writeExpression(*arguments[index], kAssignment_Precedence);
+            fIgnoreVariableReferenceModifiers = false;
+        }
+
+        this->writeLine(";");
+    }
+
+    // [int _skResult = ] myFunction(inputs, outputs, globals, _var0, _var1, _var2, _var3);
+    bool hasResult = (call.type().name() != "void");
+    if (hasResult) {
+        this->writeBaseType(call.type());
+        this->write(" _skResult = ");
+    }
+
+    this->writeName(function.name());
+    this->write("(");
+    separator = "";
+    this->writeFunctionRequirementArgs(function, separator);
+
+    for (int index = 0; index < arguments.count(); ++index) {
+        this->write(separator);
+        separator = ", ";
+
+        this->write("_var");
+        this->write(to_string(index));
+    }
+    this->writeLine(");");
+
+    for (int index = 0; index < outVars.count(); ++index) {
+        if (!outVars[index]) {
+            continue;
+        }
+        // Write back: outParam.zyx = _var2. NOTE(review): caller-local index vars may be unbound.
+        fIgnoreVariableReferenceModifiers = true;
+        this->writeExpression(*arguments[index], kAssignment_Precedence);
+        fIgnoreVariableReferenceModifiers = false;
+        this->write(" = _var");
+        this->write(to_string(index));
+        this->writeLine(";");
+    }
+
+    if (hasResult) {
+        this->writeLine("return _skResult;");
+    }
+
+    --fIndentation;
+    this->writeLine("}");
+
+    return name;
 }
 
 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
@@ -298,58 +409,50 @@
         name = "dfdy";
     }
 
-    // GLSL supports passing swizzled variables to out params; Metal doesn't. To emulate that
-    // support, we synthesize a helper function which performs the swizzle into a temporary
-    // variable, calls the original function, then writes the temp var back into the out param.
+    // GLSL supports passing swizzled variables to out params; Metal doesn't. Walk the list of
+    // parameters and see if any are out parameters; if so, check if the passed-in expression is a
+    // swizzle. Take a note of all the swizzled variables that we find.
     const std::vector<const Variable*>& parameters = function.parameters();
     SkASSERT(arguments.size() == parameters.size());
-    for (size_t index = 0; index < arguments.size(); ++index) {
+
+    bool foundOutParam = false;
+    SkSTArray<16, VariableReference*> outVars;
+    outVars.push_back_n(arguments.size(), (VariableReference*)nullptr);
+
+    for (int index = 0; index < arguments.count(); ++index) {
         // If this is an out parameter...
         if (parameters[index]->modifiers().fFlags & Modifiers::kOut_Flag) {
-            // Inspect the expression to see if it contains a swizzle.
+            // Find the expression's inner variable being written to.
             Analysis::AssignmentInfo info;
-            bool outParamIsAssignable = Analysis::IsAssignable(*arguments[index], &info, nullptr);
-            SkASSERT(outParamIsAssignable);  // assignability was verified at IRGeneration time
-            if (outParamIsAssignable && info.fIsSwizzled) {
-                // Found a swizzle; we need to use a helper function here.
-                name = this->getOutParamHelper(function, arguments);
-                break;
-            }
+            // Assignability was verified at IRGeneration time, so this should always succeed.
+            SkAssertResult(Analysis::IsAssignable(*arguments[index], &info));
+            outVars[index] = info.fAssignedVar;
+            foundOutParam = true;
         }
     }
 
+    if (foundOutParam) {
+        // Out parameters need to be written back to at the end of the function. To do this, we
+        // synthesize a helper function which evaluates the out-param expression into a temporary
+        // variable, calls the original function, then writes the temp var back into the out param
+        // using the original out-param expression. (This lets us support things like swizzles and
+        // array indices.)
+        name = getOutParamHelper(c, arguments, outVars);
+    }
+
     this->write(name);
     this->write("(");
     const char* separator = "";
-    if (this->requirements(function) & kInputs_Requirement) {
-        this->write("_in");
-        separator = ", ";
-    }
-    if (this->requirements(function) & kOutputs_Requirement) {
-        this->write(separator);
-        this->write("_out");
-        separator = ", ";
-    }
-    if (this->requirements(function) & kUniforms_Requirement) {
-        this->write(separator);
-        this->write("_uniforms");
-        separator = ", ";
-    }
-    if (this->requirements(function) & kGlobals_Requirement) {
-        this->write(separator);
-        this->write("_globals");
-        separator = ", ";
-    }
-    if (this->requirements(function) & kFragCoord_Requirement) {
-        this->write(separator);
-        this->write("_fragCoord");
-        separator = ", ";
-    }
-    for (size_t i = 0; i < arguments.size(); ++i) {
-        const Expression& arg = *arguments[i];
+    this->writeFunctionRequirementArgs(function, separator);
+    for (int i = 0; i < arguments.count(); ++i) {
         this->write(separator);
         separator = ", ";
-        this->writeExpression(arg, kSequence_Precedence);
+
+        if (outVars[i]) {
+            this->writeExpression(*outVars[i], kSequence_Precedence);
+        } else {
+            this->writeExpression(*arguments[i], kSequence_Precedence);
+        }
     }
     this->write(")");
 }
@@ -807,6 +910,14 @@
 }
 
 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
+    // When assembling out-param helper functions, we copy variables into local clones with matching
+    // names. We never want to prepend "_in." or "_globals->" when writing these variables since
+    // we're actually targeting the clones.
+    if (fIgnoreVariableReferenceModifiers) {
+        this->writeName(ref.variable()->name());
+        return;
+    }
+
     switch (ref.variable()->modifiers().fLayout.fBuiltin) {
         case SK_FRAGCOLOR_BUILTIN:
             this->write("_out->sk_FragColor");
@@ -1047,6 +1158,44 @@
     ABORT("internal error; setting was not folded to a constant during compilation\n");
 }
 
+void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
+                                                      const char*& separator) {
+    // Pass along each implicit argument the callee requires, in the same order that
+    // writeFunctionRequirementParams declares them.
+    const Requirements needed = this->requirements(f);
+    auto emitIfNeeded = [&](Requirements flag, const char* text) {
+        if (needed & flag) {
+            this->write(separator);
+            this->write(text);
+            separator = ", ";
+        }
+    };
+    emitIfNeeded(kInputs_Requirement, "_in");
+    emitIfNeeded(kOutputs_Requirement, "_out");
+    emitIfNeeded(kUniforms_Requirement, "_uniforms");
+    emitIfNeeded(kGlobals_Requirement, "_globals");
+    emitIfNeeded(kFragCoord_Requirement, "_fragCoord");
+}
+
+void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
+                                                        const char*& separator) {
+    // Declare each implicit parameter the function requires; the order must match
+    // writeFunctionRequirementArgs, which supplies the corresponding arguments.
+    const Requirements needed = this->requirements(f);
+    auto declareIfNeeded = [&](Requirements flag, const char* declaration) {
+        if (needed & flag) {
+            this->write(separator);
+            this->write(declaration);
+            separator = ", ";
+        }
+    };
+    declareIfNeeded(kInputs_Requirement, "Inputs _in");
+    declareIfNeeded(kOutputs_Requirement, "thread Outputs* _out");
+    declareIfNeeded(kUniforms_Requirement, "Uniforms _uniforms");
+    declareIfNeeded(kGlobals_Requirement, "thread Globals* _globals");
+    declareIfNeeded(kFragCoord_Requirement, "float4 _fragCoord");
+}
+
 bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
     fRTHeightName = fProgram.fInputs.fRTHeight ? "_globals->_anonInterface0->u_skRTHeight" : "";
     const char* separator = "";
@@ -1130,36 +1301,12 @@
         this->write(" ");
         this->writeName(f.name());
         this->write("(");
-        Requirements requirements = this->requirements(f);
-        if (requirements & kInputs_Requirement) {
-            this->write("Inputs _in");
-            separator = ", ";
-        }
-        if (requirements & kOutputs_Requirement) {
-            this->write(separator);
-            this->write("thread Outputs* _out");
-            separator = ", ";
-        }
-        if (requirements & kUniforms_Requirement) {
-            this->write(separator);
-            this->write("Uniforms _uniforms");
-            separator = ", ";
-        }
-        if (requirements & kGlobals_Requirement) {
-            this->write(separator);
-            this->write("thread Globals* _globals");
-            separator = ", ";
-        }
-        if (requirements & kFragCoord_Requirement) {
-            this->write(separator);
-            this->write("float4 _fragCoord");
-            separator = ", ";
-        }
+        this->writeFunctionRequirementParams(f, separator);
     }
     for (const auto& param : f.parameters()) {
         this->write(separator);
         separator = ", ";
-        this->writeModifiers(param->modifiers(), false);
+        this->writeModifiers(param->modifiers(), /*globalContext=*/false);
         const Type* type = &param->type();
         this->writeBaseType(*type);
         if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
@@ -1243,7 +1390,7 @@
 }
 
 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers,
-                                       bool globalContext) {
+                                        bool globalContext) {
     if (modifiers.fFlags & Modifiers::kOut_Flag) {
         this->write("thread ");
     }
@@ -1256,7 +1403,7 @@
     if ("sk_PerVertex" == intf.typeName()) {
         return;
     }
-    this->writeModifiers(intf.variable().modifiers(), true);
+    this->writeModifiers(intf.variable().modifiers(), /*globalContext=*/true);
     this->write("struct ");
     this->writeLine(intf.typeName() + " {");
     const Type* structType = &intf.variable().type();
@@ -1327,7 +1474,7 @@
             return;
         }
         currentOffset += fieldSize;
-        this->writeModifiers(field.fModifiers, false);
+        this->writeModifiers(field.fModifiers, /*globalContext=*/false);
         this->writeBaseType(*fieldType);
         this->write(" ");
         this->writeName(field.fName);
@@ -1849,7 +1996,8 @@
             this->writeFunctionPrototype(e.as<FunctionPrototype>());
             break;
         case ProgramElement::Kind::kModifiers:
-            this->writeModifiers(e.as<ModifiersDeclaration>().modifiers(), true);
+            this->writeModifiers(e.as<ModifiersDeclaration>().modifiers(),
+                                 /*globalContext=*/true);
             this->writeLine(";");
             break;
         case ProgramElement::Kind::kEnum:
diff --git a/src/sksl/SkSLMetalCodeGenerator.h b/src/sksl/SkSLMetalCodeGenerator.h
index e6b1dae..1f49b63 100644
--- a/src/sksl/SkSLMetalCodeGenerator.h
+++ b/src/sksl/SkSLMetalCodeGenerator.h
@@ -178,6 +178,11 @@
 
     void writeFunctionStart(const FunctionDeclaration& f);
 
+    void writeFunctionRequirementParams(const FunctionDeclaration& f,
+                                        const char*& separator);
+
+    void writeFunctionRequirementArgs(const FunctionDeclaration& f, const char*& separator);
+
     bool writeFunctionDeclaration(const FunctionDeclaration& f);
 
     void writeFunction(const FunctionDefinition& f);
@@ -204,7 +209,9 @@
 
     void writeMinAbsHack(Expression& absExpr, Expression& otherExpr);
 
-    String getOutParamHelper(const FunctionDeclaration& function, const ExpressionArray& arguments);
+    String getOutParamHelper(const FunctionCall& c,
+                             const ExpressionArray& arguments,
+                             const SkTArray<VariableReference*>& outVars);
 
     String getInverseHack(const Expression& mat);
 
@@ -284,7 +291,6 @@
     int fPaddingCount = 0;
     const char* fLineEnding;
     const Context& fContext;
-    StringStream fHeader;
     String fFunctionHeader;
     StringStream fExtraFunctions;
     Program::Kind fProgramKind;
@@ -305,6 +311,8 @@
     int fUniformBuffer = -1;
     String fRTHeightName;
     const FunctionDeclaration* fCurrentFunction = nullptr;
+    int fSwizzleHelperCount = 0;
+    bool fIgnoreVariableReferenceModifiers = false;
 
     using INHERITED = CodeGenerator;
 };
diff --git a/src/sksl/SkSLString.cpp b/src/sksl/SkSLString.cpp
index 67b4d2d..5f070b3 100644
--- a/src/sksl/SkSLString.cpp
+++ b/src/sksl/SkSLString.cpp
@@ -210,15 +210,13 @@
 }
 
 String to_string(int64_t value) {
-    std::stringstream buffer;
-    buffer << value;
-    return String(buffer.str().c_str());
+    // int64_t may be `long` rather than `long long`; cast so the argument matches %lld.
+    return SkSL::String::printf("%lld", static_cast<long long>(value));
 }
 
 String to_string(uint64_t value) {
-    std::stringstream buffer;
-    buffer << value;
-    return String(buffer.str().c_str());
+    // uint64_t may be `unsigned long`; cast so the argument matches %llu.
+    return SkSL::String::printf("%llu", static_cast<unsigned long long>(value));
 }
 
 String to_string(double value) {