Support swizzled out-params in SPIR-V intrinsic calls.

The code in writeFunctionCall which supported swizzled out-params has
been factored out into a pair of helper functions
(writeFunctionCallArgument and copyBackTempVars). The code for emitting
intrinsics now relies on these helpers when emitting out-params.

Change-Id: I4436185ae107d70b529e7e1ea0dd89f844c6a673
Bug: skia:11052
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/446719
Auto-Submit: John Stiles <johnstiles@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/codegen/SkSLSPIRVCodeGenerator.cpp b/src/sksl/codegen/SkSLSPIRVCodeGenerator.cpp
index 43f8661..0175625 100644
--- a/src/sksl/codegen/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/codegen/SkSLSPIRVCodeGenerator.cpp
@@ -195,12 +195,12 @@
     return (type.isScalar() || type.isVector()) && type.componentType().isBoolean();
 }
 
-static bool is_out(const Variable& var) {
-    return (var.modifiers().fFlags & Modifiers::kOut_Flag) != 0;
+static bool is_out(const Modifiers& m) {
+    return (m.fFlags & Modifiers::kOut_Flag) != 0;
 }
 
-static bool is_in(const Variable& var) {
-    switch (var.modifiers().fFlags & (Modifiers::kOut_Flag | Modifiers::kIn_Flag)) {
+static bool is_in(const Modifiers& m) {
+    switch (m.fFlags & (Modifiers::kOut_Flag | Modifiers::kIn_Flag)) {
         case Modifiers::kOut_Flag:                       // out
             return false;
 
@@ -703,12 +703,8 @@
             // crashes. Making it take a float* and storing the argument in a temporary variable,
             // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
             // the spec makes this make sense.
-//            if (is_out(function->fParameters[i])) {
-                parameterTypes.push_back(this->getPointerType(parameters[i]->type(),
-                                                              SpvStorageClassFunction));
-//            } else {
-//                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
-//            }
+            parameterTypes.push_back(this->getPointerType(parameters[i]->type(),
+                                                          SpvStorageClassFunction));
         }
         this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
         this->writeWord(result, fConstantBuffer);
@@ -822,10 +818,15 @@
         case kGLSL_STD_450_IntrinsicOpcodeKind: {
             SpvId result = this->nextId(&c.type());
             std::vector<SpvId> argumentIds;
+            std::vector<TempVar> tempVars;
+            argumentIds.reserve(arguments.size());
             for (size_t i = 0; i < arguments.size(); i++) {
-                if (function.parameters()[i]->modifiers().fFlags & Modifiers::kOut_Flag) {
-                    // TODO(skia:11052): swizzled lvalues won't work with getPointer()
-                    argumentIds.push_back(this->getLValue(*arguments[i], out)->getPointer());
+                if (is_out(function.parameters()[i]->modifiers())) {
+                    argumentIds.push_back(
+                            this->writeFunctionCallArgument(*arguments[i],
+                                                            function.parameters()[i]->modifiers(),
+                                                            &tempVars,
+                                                            out));
                 } else {
                     argumentIds.push_back(this->writeExpression(*arguments[i], out));
                 }
@@ -838,6 +839,7 @@
             for (SpvId id : argumentIds) {
                 this->writeWord(id, out);
             }
+            this->copyBackTempVars(tempVars, out);
             return result;
         }
         case kSPIRV_IntrinsicOpcodeKind: {
@@ -847,10 +849,15 @@
             }
             SpvId result = this->nextId(&c.type());
             std::vector<SpvId> argumentIds;
+            std::vector<TempVar> tempVars;
+            argumentIds.reserve(arguments.size());
             for (size_t i = 0; i < arguments.size(); i++) {
-                if (function.parameters()[i]->modifiers().fFlags & Modifiers::kOut_Flag) {
-                    // TODO(skia:11052): swizzled lvalues won't work with getPointer()
-                    argumentIds.push_back(this->getLValue(*arguments[i], out)->getPointer());
+                if (is_out(function.parameters()[i]->modifiers())) {
+                    argumentIds.push_back(
+                            this->writeFunctionCallArgument(*arguments[i],
+                                                            function.parameters()[i]->modifiers(),
+                                                            &tempVars,
+                                                            out));
                 } else {
                     argumentIds.push_back(this->writeExpression(*arguments[i], out));
                 }
@@ -865,6 +872,7 @@
             for (SpvId id : argumentIds) {
                 this->writeWord(id, out);
             }
+            this->copyBackTempVars(tempVars, out);
             return result;
         }
         case kSpecial_IntrinsicOpcodeKind:
@@ -948,6 +956,7 @@
     switch (kind) {
         case kAtan_SpecialIntrinsic: {
             std::vector<SpvId> argumentIds;
+            argumentIds.reserve(arguments.size());
             for (const std::unique_ptr<Expression>& arg : arguments) {
                 argumentIds.push_back(this->writeExpression(*arg, out));
             }
@@ -1170,12 +1179,52 @@
     return result;
 }
 
-namespace {
-struct TempVar {
-    SpvId spvId;
-    const Type* type;
-    std::unique_ptr<SPIRVCodeGenerator::LValue> lvalue;
-};
+SpvId SPIRVCodeGenerator::writeFunctionCallArgument(const Expression& arg,
+                                                    const Modifiers& paramModifiers,
+                                                    std::vector<TempVar>* tempVars,
+                                                    OutputStream& out) {
+    // ID of temporary variable that we will use to hold this argument; left unassigned if the
+    // argument is being passed directly
+    SpvId tmpVar;
+    // if we need a temporary var to store this argument, this is the value to store in the var
+    SpvId tmpValueId = -1;
+
+    if (is_out(paramModifiers)) {
+        std::unique_ptr<LValue> lv = this->getLValue(arg, out);
+        SpvId ptr = lv->getPointer();
+        if (ptr != (SpvId) -1 && lv->isMemoryObjectPointer()) {
+            return ptr;
+        }
+
+        // lvalue cannot simply be read and written via a pointer (e.g. it's a swizzle). We need to
+        // use a temp variable.
+        if (is_in(paramModifiers)) {
+            tmpValueId = lv->load(out);
+        }
+        tmpVar = this->nextId(&arg.type());
+        tempVars->push_back(TempVar{tmpVar, &arg.type(), std::move(lv)});
+    } else {
+        // See getFunctionType for an explanation of why we're always using pointer parameters.
+        tmpValueId = this->writeExpression(arg, out);
+        tmpVar = this->nextId(nullptr);
+    }
+    this->writeInstruction(SpvOpVariable,
+                           this->getPointerType(arg.type(), SpvStorageClassFunction),
+                           tmpVar,
+                           SpvStorageClassFunction,
+                           fVariableBuffer);
+    if (tmpValueId != (SpvId)-1) {
+        this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
+    }
+    return tmpVar;
+}
+
+void SPIRVCodeGenerator::copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out) {
+    for (const TempVar& tempVar : tempVars) {
+        SpvId load = this->nextId(tempVar.type);
+        this->writeInstruction(SpvOpLoad, this->getType(*tempVar.type), load, tempVar.spvId, out);
+        tempVar.lvalue->store(load, out);
+    }
 }
 
 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
@@ -1193,42 +1242,12 @@
     // Temp variables are used to write back out-parameters after the function call is complete.
     std::vector<TempVar> tempVars;
     std::vector<SpvId> argumentIds;
+    argumentIds.reserve(arguments.size());
     for (size_t i = 0; i < arguments.size(); i++) {
-        // id of temporary variable that we will use to hold this argument, or 0 if it is being
-        // passed directly
-        SpvId tmpVar;
-        // if we need a temporary var to store this argument, this is the value to store in the var
-        SpvId tmpValueId = -1;
-        if (is_out(*function.parameters()[i])) {
-            std::unique_ptr<LValue> lv = this->getLValue(*arguments[i], out);
-            SpvId ptr = lv->getPointer();
-            if (ptr != (SpvId) -1 && lv->isMemoryObjectPointer()) {
-                argumentIds.push_back(ptr);
-                continue;
-            } else {
-                // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
-                // copy it into a temp, call the function, read the value out of the temp, and then
-                // update the lvalue.
-                if (is_in(*function.parameters()[i])) {
-                    tmpValueId = lv->load(out);
-                }
-                tmpVar = this->nextId(&arguments[i]->type());
-                tempVars.push_back(TempVar{tmpVar, &arguments[i]->type(), std::move(lv)});
-            }
-        } else {
-            // See getFunctionType for an explanation of why we're always using pointer parameters.
-            tmpValueId = this->writeExpression(*arguments[i], out);
-            tmpVar = this->nextId(nullptr);
-        }
-        this->writeInstruction(SpvOpVariable,
-                               this->getPointerType(arguments[i]->type(), SpvStorageClassFunction),
-                               tmpVar,
-                               SpvStorageClassFunction,
-                               fVariableBuffer);
-        if (tmpValueId != (SpvId)-1) {
-            this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
-        }
-        argumentIds.push_back(tmpVar);
+        argumentIds.push_back(this->writeFunctionCallArgument(*arguments[i],
+                                                              function.parameters()[i]->modifiers(),
+                                                              &tempVars,
+                                                              out));
     }
     SpvId result = this->nextId(nullptr);
     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) arguments.size(), out);
@@ -1239,11 +1258,7 @@
         this->writeWord(id, out);
     }
     // Now that the call is complete, we copy temp out-variables back to their real lvalues.
-    for (const TempVar& tempVar : tempVars) {
-        SpvId load = this->nextId(tempVar.type);
-        this->writeInstruction(SpvOpLoad, getType(*tempVar.type), load, tempVar.spvId, out);
-        tempVar.lvalue->store(load, out);
-    }
+    this->copyBackTempVars(tempVars, out);
     return result;
 }
 
diff --git a/src/sksl/codegen/SkSLSPIRVCodeGenerator.h b/src/sksl/codegen/SkSLSPIRVCodeGenerator.h
index 55bdc74..5c08c82 100644
--- a/src/sksl/codegen/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/codegen/SkSLSPIRVCodeGenerator.h
@@ -172,6 +172,12 @@
         kRelaxed,
     };
 
+    struct TempVar {
+        SpvId spvId;
+        const Type* type;
+        std::unique_ptr<SPIRVCodeGenerator::LValue> lvalue;
+    };
+
     void setupIntrinsics();
 
     /**
@@ -229,6 +235,13 @@
 
     SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out);
 
+    SpvId writeFunctionCallArgument(const Expression& arg,
+                                    const Modifiers& paramModifiers,
+                                    std::vector<TempVar>* tempVars,
+                                    OutputStream& out);
+
+    void copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out);
+
     SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out);