SkSL now supports uniform array types

Change-Id: I809e9c424ee92b05f0a87d75d1384c92849e1474
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/308498
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
diff --git a/src/gpu/effects/generated/GrColorMatrixFragmentProcessor.cpp b/src/gpu/effects/generated/GrColorMatrixFragmentProcessor.cpp
index 99cc75b..8dd0b60 100644
--- a/src/gpu/effects/generated/GrColorMatrixFragmentProcessor.cpp
+++ b/src/gpu/effects/generated/GrColorMatrixFragmentProcessor.cpp
@@ -77,6 +77,7 @@ regenerated: kSkM44 set template now emits static_assert(${count} == 1); count is 1 here
             const SkM44& mValue = _outer.m;
             if (mPrev != (mValue)) {
                 mPrev = mValue;
+                static_assert(1 == 1);
                 pdman.setSkM44(mVar, mValue);
             }
             const SkV4& vValue = _outer.v;
diff --git a/src/sksl/SkSLCPPCodeGenerator.cpp b/src/sksl/SkSLCPPCodeGenerator.cpp
index c641fbd..21ac37a 100644
--- a/src/sksl/SkSLCPPCodeGenerator.cpp
+++ b/src/sksl/SkSLCPPCodeGenerator.cpp
@@ -168,6 +168,18 @@
                                             const Layout& layout,
                                             const String& cppCode,
                                             std::vector<String>* formatArgs) {
+    if (type.kind() == Type::kArray_Kind) {  // format each element in turn: "[<e0>,<e1>,...]"
+        String result("[");
+        const char* separator = "";
+        for (int i = 0; i < type.columns(); i++) {
+            result += separator + this->formatRuntimeValue(type.componentType(), layout,
+                                                           "(" + cppCode + ")[" + to_string(i) +
+                                                           "]", formatArgs);
+            separator = ",";
+        }
+        result += "]";
+        return result;
+    }
     if (type.isFloat()) {
         formatArgs->push_back(cppCode);
         return "%f";
@@ -678,11 +690,21 @@
     if (var.fModifiers.fLayout.fWhen.fLength) {
         this->writef("        if (%s) {\n    ", String(var.fModifiers.fLayout.fWhen).c_str());
     }
-    const char* type = glsltype_string(fContext, var.fType);
     String name(var.fName);
-    this->writef("        %sVar = args.fUniformHandler->addUniform(&_outer, kFragment_GrShaderFlag,"
-                 " %s, \"%s\");\n", HCodeGenerator::FieldName(name.c_str()).c_str(), type,
-                 name.c_str());
+    if (var.fType.kind() != Type::kArray_Kind) {  // non-array: addUniform(type, name)
+        this->writef("        %sVar = args.fUniformHandler->addUniform(&_outer, "
+                     "kFragment_GrShaderFlag, %s, \"%s\");\n",
+                     HCodeGenerator::FieldName(name.c_str()).c_str(),
+                     glsltype_string(fContext, var.fType),
+                     name.c_str());
+    } else {  // array: addUniformArray(element type, name, element count)
+        this->writef("        %sVar = args.fUniformHandler->addUniformArray(&_outer, "
+                     "kFragment_GrShaderFlag, %s, \"%s\", %d);\n",
+                     HCodeGenerator::FieldName(name.c_str()).c_str(),
+                     glsltype_string(fContext, var.fType.componentType()),
+                     name.c_str(),
+                     var.fType.columns());
+    }
     if (var.fModifiers.fLayout.fWhen.fLength) {
         this->write("        }\n");
     }
diff --git a/src/sksl/SkSLCPPUniformCTypes.cpp b/src/sksl/SkSLCPPUniformCTypes.cpp
index b36ae1e..4103137 100644
--- a/src/sksl/SkSLCPPUniformCTypes.cpp
+++ b/src/sksl/SkSLCPPUniformCTypes.cpp
@@ -9,6 +9,7 @@
 #include "src/sksl/SkSLHCodeGenerator.h"
 #include "src/sksl/SkSLStringStream.h"

+#include <map>  // NOTE(review): no std::map use is visible in this patch; confirm it is needed
 #include <vector>

 #if defined(SKSL_STANDALONE) || defined(GR_TEST_UTILS)
@@ -109,38 +110,75 @@

 String UniformCTypeMapper::setUniform(const String& pdman, const String& uniform,
                                       const String& var) const {
-    std::vector<String> tokens = { "pdman", "uniform", "var" };
-    std::vector<const String*> values = { &pdman, &uniform, &var };
-    return eval_template(fUniformTemplate, tokens, values);
+    std::vector<String> tokens = { "pdman", "uniform", "var", "count" };
+    String count;
+    String finalVar;
+    const String* activeTemplate;
+    if (fArrayCount != -1) {  // array mapper (see arrayMapper)
+        count = to_string(fArrayCount);
+        finalVar = var + "[0]";  // array templates address the data via its first element
+        activeTemplate = &fUniformArrayTemplate;
+    } else {
+        count = "1";  // single-value templates may still reference ${count}
+        finalVar = var;
+        activeTemplate = &fUniformSingleTemplate;
+    }
+    std::vector<const String*> values = { &pdman, &uniform, &finalVar, &count };
+    return eval_template(*activeTemplate, tokens, values);
 }

 UniformCTypeMapper::UniformCTypeMapper(
-        Layout::CType ctype, const std::vector<String>& skslTypes, const String& setUniformFormat,
+        Layout::CType ctype, const std::vector<String>& skslTypes,
+        const String& setUniformSingleFormat, const String& setUniformArrayFormat,
         bool enableTracking, const String& defaultValue, const String& dirtyExpressionFormat,
         const String& saveStateFormat)
     : fCType(ctype)
     , fSKSLTypes(skslTypes)
-    , fUniformTemplate(setUniformFormat)
-    , fInlineValue(determine_inline_from_template(setUniformFormat))
+    , fUniformSingleTemplate(setUniformSingleFormat)
+    , fUniformArrayTemplate(setUniformArrayFormat)
+    , fInlineValue(determine_inline_from_template(setUniformSingleFormat) &&
+                   determine_inline_from_template(setUniformArrayFormat))
     , fSupportsTracking(enableTracking)
     , fDefaultValue(defaultValue)
     , fDirtyExpressionTemplate(dirtyExpressionFormat)
-    , fSaveStateTemplate(saveStateFormat) { }
+    , fSaveStateTemplate(saveStateFormat) {}

-// NOTE: These would be macros, but C++ initialization lists for the sksl type names do not play
-// well with macro parsing.
+const UniformCTypeMapper* UniformCTypeMapper::arrayMapper(int count) const {
+    // We leak an object here, but since this code only ever runs as part of the build process and
+    // is rarely executed, it doesn't really matter.
+#if !defined(SKSL_STANDALONE) && !defined(GR_TEST_UTILS)
+    #error This code leaks memory and should not be present in a release build.
+#endif
+    UniformCTypeMapper* result = new UniformCTypeMapper(*this);
+    result->fArrayCount = count;  // copy of *this, switched into array mode
+    return result;
+}

-static UniformCTypeMapper REGISTER(Layout::CType ctype, const std::vector<String>& skslTypes,
-                                   const char* uniformFormat, const char* defaultValue,
-                                   const char* dirtyExpression) {
-    return UniformCTypeMapper(ctype, skslTypes, uniformFormat, defaultValue, dirtyExpression,
+// register_array: separate setUniform templates for single values and for arrays.
+static UniformCTypeMapper register_array(Layout::CType ctype, const std::vector<String>& skslTypes,
+                                         const char* singleSet, const char* arraySet,
+                                         const char* defaultValue, const char* dirtyExpression) {
+    return UniformCTypeMapper(ctype, skslTypes, singleSet, arraySet, defaultValue, dirtyExpression,
                               "${oldVar} = ${newVar}");
 }

-static UniformCTypeMapper REGISTER(Layout::CType ctype, const std::vector<String>& skslTypes,
+static UniformCTypeMapper register_array(Layout::CType ctype, const std::vector<String>& skslTypes,
+                                         const char* singleSet, const char* arraySet,
+                                         const char* defaultValue) {
+    return register_array(ctype, skslTypes, singleSet, arraySet, defaultValue,
+                          "${oldVar} != ${newVar}");
+}
+
+static UniformCTypeMapper register_type(Layout::CType ctype, const std::vector<String>& skslTypes,
+                                        const char* uniformFormat, const char* defaultValue,
+                                        const char* dirtyExpression) {
+    return register_array(ctype, skslTypes, uniformFormat, uniformFormat, defaultValue,
+                          dirtyExpression);
+}
+
+static UniformCTypeMapper register_type(Layout::CType ctype, const std::vector<String>& skslTypes,
                                    const char* uniformFormat, const char* defaultValue) {
-    return REGISTER(ctype, skslTypes, uniformFormat, defaultValue,
-                    "${oldVar} != ${newVar}");
+    return register_array(ctype, skslTypes, uniformFormat, uniformFormat, defaultValue);
 }

 //////////////////////////////
@@ -149,49 +187,53 @@ get_mappers(): array templates receive ${count} and ${var} = <array>[0]

 static const std::vector<UniformCTypeMapper>& get_mappers() {
     static const std::vector<UniformCTypeMapper> registeredMappers = {
-    REGISTER(Layout::CType::kSkRect, { "half4", "float4", "double4" },
-        "${pdman}.set4fv(${uniform}, 1, reinterpret_cast<const float*>(&${var}))", // to gpu
+    register_type(Layout::CType::kSkRect, { "half4", "float4", "double4" },
+        "${pdman}.set4fv(${uniform}, ${count}, reinterpret_cast<const float*>(&${var}))", // to gpu
         "SkRect::MakeEmpty()",                                                     // default value
         "${oldVar}.isEmpty() || ${oldVar} != ${newVar}"),                          // dirty check

-    REGISTER(Layout::CType::kSkIRect, { "int4", "short4", "byte4" },
-        "${pdman}.set4iv(${uniform}, 1, reinterpret_cast<const int*>(&${var}))",   // to gpu
+    register_type(Layout::CType::kSkIRect, { "int4", "short4", "byte4" },
+        "${pdman}.set4iv(${uniform}, ${count}, reinterpret_cast<const int*>(&${var}))", // to gpu
         "SkIRect::MakeEmpty()",                                                    // default value
         "${oldVar}.isEmpty() || ${oldVar} != ${newVar}"),                          // dirty check

-    REGISTER(Layout::CType::kSkPMColor4f, { "half4", "float4", "double4" },
-        "${pdman}.set4fv(${uniform}, 1, ${var}.vec())",                            // to gpu
+    register_type(Layout::CType::kSkPMColor4f, { "half4", "float4", "double4" },
+        "${pdman}.set4fv(${uniform}, ${count}, ${var}.vec())",                     // to gpu
         "{SK_FloatNaN, SK_FloatNaN, SK_FloatNaN, SK_FloatNaN}"),                   // default value

-    REGISTER(Layout::CType::kSkV4, { "half4", "float4", "double4" },
-        "${pdman}.set4fv(${uniform}, 1, ${var}.ptr())",                            // to gpu
+    register_type(Layout::CType::kSkV4, { "half4", "float4", "double4" },
+        "${pdman}.set4fv(${uniform}, ${count}, ${var}.ptr())",                     // to gpu
         "SkV4{SK_FloatNaN, SK_FloatNaN, SK_FloatNaN, SK_FloatNaN}",                // default value
         "${oldVar} != (${newVar})"),                                               // dirty check

-    REGISTER(Layout::CType::kSkPoint, { "half2", "float2", "double2" } ,
-        "${pdman}.set2f(${uniform}, ${var}.fX, ${var}.fY)",                        // to gpu
+    register_array(Layout::CType::kSkPoint, { "half2", "float2", "double2" } ,
+        "${pdman}.set2f(${uniform}, ${var}.fX, ${var}.fY)",                        // single
+        "${pdman}.set2fv(${uniform}, ${count}, &${var}.fX)",                       // array
         "SkPoint::Make(SK_FloatNaN, SK_FloatNaN)"),                                // default value

-    REGISTER(Layout::CType::kSkIPoint, { "int2", "short2", "byte2" },
-        "${pdman}.set2i(${uniform}, ${var}.fX, ${var}.fY)",                        // to gpu
+    register_array(Layout::CType::kSkIPoint, { "int2", "short2", "byte2" },
+        "${pdman}.set2i(${uniform}, ${var}.fX, ${var}.fY)",                        // single
+        "${pdman}.set2iv(${uniform}, ${count}, &${var}.fX)",                       // array
         "SkIPoint::Make(SK_NaN32, SK_NaN32)"),                                     // default value

-    REGISTER(Layout::CType::kSkMatrix, { "half3x3", "float3x3", "double3x3" },
-        "${pdman}.setSkMatrix(${uniform}, ${var})",                                // to gpu
+    register_type(Layout::CType::kSkMatrix, { "half3x3", "float3x3", "double3x3" },
+        "static_assert(${count} == 1); ${pdman}.setSkMatrix(${uniform}, ${var})",  // to gpu
         "SkMatrix::Scale(SK_FloatNaN, SK_FloatNaN)",                               // default value
         "!${oldVar}.cheapEqualTo(${newVar})"),                                     // dirty check

-    REGISTER(Layout::CType::kSkM44,  { "half4x4", "float4x4", "double4x4" },
-        "${pdman}.setSkM44(${uniform}, ${var})",                                   // to gpu
+    register_type(Layout::CType::kSkM44,  { "half4x4", "float4x4", "double4x4" },
+        "static_assert(${count} == 1); ${pdman}.setSkM44(${uniform}, ${var})",     // to gpu
         "SkM44(SkM44::kNaN_Constructor)",                                          // default value
         "${oldVar} != (${newVar})"),                                               // dirty check

-    REGISTER(Layout::CType::kFloat,  { "half", "float", "double" },
-        "${pdman}.set1f(${uniform}, ${var})",                                      // to gpu
+    register_array(Layout::CType::kFloat,  { "half", "float", "double" },
+        "${pdman}.set1f(${uniform}, ${var})",                                      // single
+        "${pdman}.set1fv(${uniform}, ${count}, &${var})",                          // array
         "SK_FloatNaN"),                                                            // default value

-    REGISTER(Layout::CType::kInt32, { "int", "short", "byte" },
-        "${pdman}.set1i(${uniform}, ${var})",                                      // to gpu
+    register_array(Layout::CType::kInt32, { "int", "short", "byte" },
+        "${pdman}.set1i(${uniform}, ${var})",                                      // single
+        "${pdman}.set1iv(${uniform}, ${count}, &${var})",                          // array
         "SK_NaN32"),                                                               // default value
     };

@@ -204,6 +246,10 @@
 // ctype and supports the sksl type of the variable.
 const UniformCTypeMapper* UniformCTypeMapper::Get(const Context& context, const Type& type,
                                                   const Layout& layout) {
+    if (type.kind() == Type::kArray_Kind) {  // arrays reuse the element type's mapper
+        const UniformCTypeMapper* base = Get(context, type.componentType(), layout);
+        return base ? base->arrayMapper(type.columns()) : nullptr;
+    }  // NOTE: arrayMapper() heap-allocates a new, never-freed mapper on every call
     const std::vector<UniformCTypeMapper>& registeredMappers = get_mappers();

     Layout::CType ctype = layout.fCType;
diff --git a/src/sksl/SkSLCPPUniformCTypes.h b/src/sksl/SkSLCPPUniformCTypes.h
index 076dda3..c6d3686 100644
--- a/src/sksl/SkSLCPPUniformCTypes.h
+++ b/src/sksl/SkSLCPPUniformCTypes.h
@@ -34,17 +34,11 @@
 // semicolons or newlines, which will be handled by the code generation itself.
 class UniformCTypeMapper {
 public:
-    // Create a templated mapper that does not support state tracking
     UniformCTypeMapper(Layout::CType ctype, const std::vector<String>& skslTypes,
-            const char* setUniformFormat)
-        : UniformCTypeMapper(ctype, skslTypes, setUniformFormat, false, "", "", "") { }
-
-    // Create a templated mapper that provides extra patterns for the state
-    // tracking expressions.
-    UniformCTypeMapper(Layout::CType ctype, const std::vector<String>& skslTypes,
-            const String& setUniformFormat, const String& defaultValue,
-            const String& dirtyExpressionFormat, const String& saveStateFormat)
-        : UniformCTypeMapper(ctype, skslTypes, setUniformFormat,
+            const String& setUniformSingleFormat, const String& setUniformArrayFormat,
+            const String& defaultValue = "", const String& dirtyExpressionFormat = "",
+            const String& saveStateFormat = "")  // always delegates with enableTracking = true
+        : UniformCTypeMapper(ctype, skslTypes, setUniformSingleFormat, setUniformArrayFormat,
                 true, defaultValue, dirtyExpressionFormat, saveStateFormat) { }

     // Returns nullptr if the type and layout are not supported; the returned pointer's ownership
@@ -116,12 +110,17 @@

 private:
     UniformCTypeMapper(Layout::CType ctype, const std::vector<String>& skslTypes,
-            const String& setUniformFormat, bool enableTracking, const String& defaultValue,
-            const String& dirtyExpressionFormat, const String& saveStateFormat);
+            const String& setUniformSingleFormat, const String& setUniformArrayFormat,
+            bool enableTracking, const String& defaultValue, const String& dirtyExpressionFormat,
+            const String& saveStateFormat);
+
+    const UniformCTypeMapper* arrayMapper(int arrayCount) const;  // never-freed array-mode copy

     Layout::CType fCType;
+    int fArrayCount = -1;  // -1: not an array; otherwise the element count
     std::vector<String> fSKSLTypes;
-    String fUniformTemplate;
+    String fUniformSingleTemplate;  // used when fArrayCount == -1
+    String fUniformArrayTemplate;  // used otherwise (array mappers)
     bool fInlineValue; // Cached value calculated from fUniformTemplate

     bool fSupportsTracking;
diff --git a/src/sksl/SkSLHCodeGenerator.cpp b/src/sksl/SkSLHCodeGenerator.cpp
index 012b302..40b6c69 100644
--- a/src/sksl/SkSLHCodeGenerator.cpp
+++ b/src/sksl/SkSLHCodeGenerator.cpp
@@ -33,6 +33,10 @@

 String HCodeGenerator::ParameterType(const Context& context, const Type& type,
                                      const Layout& layout) {
+    if (type.kind() == Type::kArray_Kind) {  // T[N] -> std::array<ParameterType(T), N>
+        String component = ParameterType(context, type.componentType(), layout);
+        return String::printf("std::array<%s, %d>", component.c_str(), type.columns());
+    }
     Layout::CType ctype = ParameterCType(context, type, layout);
     if (ctype != Layout::CType::kDefault) {
         return Layout::CTypeToStr(ctype);
@@ -42,6 +46,7 @@

 Layout::CType HCodeGenerator::ParameterCType(const Context& context, const Type& type,
                                      const Layout& layout) {
+    SkASSERT(type.kind() != Type::kArray_Kind);
     if (layout.fCType != Layout::CType::kDefault) {
         return layout.fCType;
     }
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index ee95407..f0f21da 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -400,9 +400,6 @@ 'in' arrays no longer rejected here; see CPP/H code generator array support
         const Type* type = baseType;
         std::vector<std::unique_ptr<Expression>> sizes;
         auto iter = varDecl.begin();
-        if (varData.fSizeCount > 0 && (modifiers.fFlags & Modifiers::kIn_Flag)) {
-            fErrors.error(varDecl.fOffset, "'in' variables may not have array type");
-        }
         for (size_t i = 0; i < varData.fSizeCount; ++i, ++iter) {
             const ASTNode& rawSize = *iter;
             if (rawSize) {