Refactored swizzle domain testing

This moves the swizzle domain test into the path used by the DSL so that
the error check benefits both sides. To make this possible we need to
be able to distinguish between equivalent swizzle components like x and
r, so they aren't collapsed down to the same component until the very
end.

Change-Id: I48f2582886391eabd7ce6eae949babdeead6051e
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/445280
Reviewed-by: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/include/private/SkSLDefines.h b/include/private/SkSLDefines.h
index 71611a1..cbf5511 100644
--- a/include/private/SkSLDefines.h
+++ b/include/private/SkSLDefines.h
@@ -44,8 +44,10 @@
 namespace SwizzleComponent {
 
 enum Type : int8_t {
-    X = 0, Y = 1, Z = 2, W = 3,
-    R = 0, G = 1, B = 2, A = 3,
+    X  =  0,  Y =  1,  Z =  2,  W =  3,
+    R  =  4,  G =  5,  B =  6,  A =  7,
+    S  =  8,  T =  9,  P = 10,  Q = 11,
+    UL = 12, UT = 13, UR = 14, UB = 15,
     ZERO,
     ONE
 };
diff --git a/src/sksl/SkSLDSLParser.cpp b/src/sksl/SkSLDSLParser.cpp
index cd4c736..fd62d56 100644
--- a/src/sksl/SkSLDSLParser.cpp
+++ b/src/sksl/SkSLDSLParser.cpp
@@ -1504,22 +1504,22 @@
         switch (swizzleMask[i]) {
             case '0': components[i] = SwizzleComponent::ZERO; break;
             case '1': components[i] = SwizzleComponent::ONE;  break;
-            case 'r':
-            case 'x':
-            case 's':
-            case 'L': components[i] = SwizzleComponent::R;    break;
-            case 'g':
-            case 'y':
-            case 't':
-            case 'T': components[i] = SwizzleComponent::G;    break;
-            case 'b':
-            case 'z':
-            case 'p':
-            case 'R': components[i] = SwizzleComponent::B;    break;
-            case 'a':
-            case 'w':
-            case 'q':
-            case 'B': components[i] = SwizzleComponent::A;    break;
+            case 'r': components[i] = SwizzleComponent::R;    break;
+            case 'x': components[i] = SwizzleComponent::X;    break;
+            case 's': components[i] = SwizzleComponent::S;    break;
+            case 'L': components[i] = SwizzleComponent::UL;   break;
+            case 'g': components[i] = SwizzleComponent::G;    break;
+            case 'y': components[i] = SwizzleComponent::Y;    break;
+            case 't': components[i] = SwizzleComponent::T;    break;
+            case 'T': components[i] = SwizzleComponent::UT;   break;
+            case 'b': components[i] = SwizzleComponent::B;    break;
+            case 'z': components[i] = SwizzleComponent::Z;    break;
+            case 'p': components[i] = SwizzleComponent::P;    break;
+            case 'R': components[i] = SwizzleComponent::UR;   break;
+            case 'a': components[i] = SwizzleComponent::A;    break;
+            case 'w': components[i] = SwizzleComponent::W;    break;
+            case 'q': components[i] = SwizzleComponent::Q;    break;
+            case 'B': components[i] = SwizzleComponent::UB;   break;
             default:
                 this->error(offset,
                         String::printf("invalid swizzle component '%c'", swizzleMask[i]).c_str());
diff --git a/src/sksl/ir/SkSLSwizzle.cpp b/src/sksl/ir/SkSLSwizzle.cpp
index 7e56a94..33c6e39 100644
--- a/src/sksl/ir/SkSLSwizzle.cpp
+++ b/src/sksl/ir/SkSLSwizzle.cpp
@@ -14,7 +14,7 @@
 
 namespace SkSL {
 
-static bool validate_swizzle_domain(skstd::string_view fields) {
+static bool validate_swizzle_domain(const ComponentArray& fields) {
     enum SwizzleDomain {
         kCoordinate,
         kColor,
@@ -24,35 +24,35 @@
 
     skstd::optional<SwizzleDomain> domain;
 
-    for (char field : fields) {
+    for (int8_t field : fields) {
         SwizzleDomain fieldDomain;
         switch (field) {
-            case 'x':
-            case 'y':
-            case 'z':
-            case 'w':
+            case SwizzleComponent::X:
+            case SwizzleComponent::Y:
+            case SwizzleComponent::Z:
+            case SwizzleComponent::W:
                 fieldDomain = kCoordinate;
                 break;
-            case 'r':
-            case 'g':
-            case 'b':
-            case 'a':
+            case SwizzleComponent::R:
+            case SwizzleComponent::G:
+            case SwizzleComponent::B:
+            case SwizzleComponent::A:
                 fieldDomain = kColor;
                 break;
-            case 's':
-            case 't':
-            case 'p':
-            case 'q':
+            case SwizzleComponent::S:
+            case SwizzleComponent::T:
+            case SwizzleComponent::P:
+            case SwizzleComponent::Q:
                 fieldDomain = kUV;
                 break;
-            case 'L':
-            case 'T':
-            case 'R':
-            case 'B':
+            case SwizzleComponent::UL:
+            case SwizzleComponent::UT:
+            case SwizzleComponent::UR:
+            case SwizzleComponent::UB:
                 fieldDomain = kRectangle;
                 break;
-            case '0':
-            case '1':
+            case SwizzleComponent::ZERO:
+            case SwizzleComponent::ONE:
                 continue;
             default:
                 return false;
@@ -68,6 +68,38 @@
     return true;
 }
 
+static char mask_char(int8_t component) {
+    switch (component) {
+        case SwizzleComponent::X:    return 'x';
+        case SwizzleComponent::Y:    return 'y';
+        case SwizzleComponent::Z:    return 'z';
+        case SwizzleComponent::W:    return 'w';
+        case SwizzleComponent::R:    return 'r';
+        case SwizzleComponent::G:    return 'g';
+        case SwizzleComponent::B:    return 'b';
+        case SwizzleComponent::A:    return 'a';
+        case SwizzleComponent::S:    return 's';
+        case SwizzleComponent::T:    return 't';
+        case SwizzleComponent::P:    return 'p';
+        case SwizzleComponent::Q:    return 'q';
+        case SwizzleComponent::UL:   return 'L';
+        case SwizzleComponent::UT:   return 'T';
+        case SwizzleComponent::UR:   return 'R';
+        case SwizzleComponent::UB:   return 'B';
+        case SwizzleComponent::ZERO: return '0';
+        case SwizzleComponent::ONE:  return '1';
+        default: SkUNREACHABLE;
+    }
+}
+
+static String mask_string(const ComponentArray& components) {
+    String result;
+    for (int8_t component : components) {
+        result += mask_char(component);
+    }
+    return result;
+}
+
 static std::unique_ptr<Expression> optimize_constructor_swizzle(const Context& context,
                                                                 const AnyConstructor& base,
                                                                 ComponentArray components) {
@@ -199,59 +231,36 @@
 std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
                                              std::unique_ptr<Expression> base,
                                              skstd::string_view maskString) {
-    if (!validate_swizzle_domain(maskString)) {
-        context.fErrors->error(base->fOffset, "invalid swizzle mask '" + maskString + "'");
-        return nullptr;
-    }
-
     ComponentArray components;
     for (char field : maskString) {
         switch (field) {
-            case '0':
-                components.push_back(SwizzleComponent::ZERO);
-                break;
-            case '1':
-                components.push_back(SwizzleComponent::ONE);
-                break;
-            case 'x':
-            case 'r':
-            case 's':
-            case 'L':
-                components.push_back(SwizzleComponent::X);
-                break;
-            case 'y':
-            case 'g':
-            case 't':
-            case 'T':
-                components.push_back(SwizzleComponent::Y);
-                break;
-            case 'z':
-            case 'b':
-            case 'p':
-            case 'R':
-                components.push_back(SwizzleComponent::Z);
-                break;
-            case 'w':
-            case 'a':
-            case 'q':
-            case 'B':
-                components.push_back(SwizzleComponent::W);
-                break;
+            case '0': components.push_back(SwizzleComponent::ZERO); break;
+            case '1': components.push_back(SwizzleComponent::ONE);  break;
+            case 'x': components.push_back(SwizzleComponent::X);    break;
+            case 'r': components.push_back(SwizzleComponent::R);    break;
+            case 's': components.push_back(SwizzleComponent::S);    break;
+            case 'L': components.push_back(SwizzleComponent::UL);   break;
+            case 'y': components.push_back(SwizzleComponent::Y);    break;
+            case 'g': components.push_back(SwizzleComponent::G);    break;
+            case 't': components.push_back(SwizzleComponent::T);    break;
+            case 'T': components.push_back(SwizzleComponent::UT);   break;
+            case 'z': components.push_back(SwizzleComponent::Z);    break;
+            case 'b': components.push_back(SwizzleComponent::B);    break;
+            case 'p': components.push_back(SwizzleComponent::P);    break;
+            case 'R': components.push_back(SwizzleComponent::UR);   break;
+            case 'w': components.push_back(SwizzleComponent::W);    break;
+            case 'a': components.push_back(SwizzleComponent::A);    break;
+            case 'q': components.push_back(SwizzleComponent::Q);    break;
+            case 'B': components.push_back(SwizzleComponent::UB);   break;
             default:
-                SkDEBUGFAIL("unexpected swizzle component");
+                context.fErrors->error(base->fOffset,
+                        String::printf("invalid swizzle component '%c'", field));
                 return nullptr;
         }
     }
-    return Convert(context, std::move(base), std::move(components), maskString);
+    return Convert(context, std::move(base), std::move(components));
 }
 
-std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
-                                             std::unique_ptr<Expression> base,
-                                             ComponentArray inComponents) {
-    return Convert(context, std::move(base), std::move(inComponents), "");
-}
-
-
 // Swizzles are complicated due to constant components. The most difficult case is a mask like
 // '.x1w0'. A naive approach might turn that into 'float4(base.x, 1, base.w, 0)', but that evaluates
 // 'base' twice. We instead group the swizzle mask ('xw') and constants ('1, 0') together and use a
@@ -259,9 +268,12 @@
 // 'float4(base.xw, 1, 0).xzyw'.
 std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
                                              std::unique_ptr<Expression> base,
-                                             ComponentArray inComponents,
-                                             skstd::string_view maskString) {
-    SkASSERT(maskString.empty() || (int) maskString.length() == inComponents.count());
+                                             ComponentArray inComponents) {
+    if (!validate_swizzle_domain(inComponents)) {
+        context.fErrors->error(base->fOffset,
+                "invalid swizzle mask '" + mask_string(inComponents) + "'");
+        return nullptr;
+    }
 
     const int offset = base->fOffset;
     const Type& baseType = base->type();
@@ -273,11 +285,8 @@
     }
 
     if (inComponents.count() > 4) {
-        String error = "too many components in swizzle mask";
-        if (!maskString.empty()) {
-            error += " '" + maskString + "'";
-        }
-        context.fErrors->error(offset, error.c_str());
+        context.fErrors->error(offset,
+                "too many components in swizzle mask '" + mask_string(inComponents) + "'");
         return nullptr;
     }
 
@@ -290,10 +299,16 @@
                 // Skip over constant fields for now.
                 break;
             case SwizzleComponent::X:
+            case SwizzleComponent::R:
+            case SwizzleComponent::S:
+            case SwizzleComponent::UL:
                 foundXYZW = true;
                 maskComponents.push_back(SwizzleComponent::X);
                 break;
             case SwizzleComponent::Y:
+            case SwizzleComponent::G:
+            case SwizzleComponent::T:
+            case SwizzleComponent::UT:
                 foundXYZW = true;
                 if (baseType.columns() >= 2) {
                     maskComponents.push_back(SwizzleComponent::Y);
@@ -301,6 +316,9 @@
                 }
                 [[fallthrough]];
             case SwizzleComponent::Z:
+            case SwizzleComponent::B:
+            case SwizzleComponent::P:
+            case SwizzleComponent::UR:
                 foundXYZW = true;
                 if (baseType.columns() >= 3) {
                     maskComponents.push_back(SwizzleComponent::Z);
@@ -308,6 +326,9 @@
                 }
                 [[fallthrough]];
             case SwizzleComponent::W:
+            case SwizzleComponent::A:
+            case SwizzleComponent::Q:
+            case SwizzleComponent::UB:
                 foundXYZW = true;
                 if (baseType.columns() >= 4) {
                     maskComponents.push_back(SwizzleComponent::W);
@@ -317,9 +338,8 @@
             default:
                 // The swizzle component references a field that doesn't exist in the base type.
                 context.fErrors->error(offset,
-                        maskString.empty() ? "invalid swizzle component"
-                                           : String::printf("invalid swizzle component '%c'",
-                                                            maskString[i]));
+                       String::printf("invalid swizzle component '%c'",
+                            mask_char(inComponents[i])));
                 return nullptr;
         }
     }
diff --git a/src/sksl/ir/SkSLSwizzle.h b/src/sksl/ir/SkSLSwizzle.h
index a82c11e..107c420 100644
--- a/src/sksl/ir/SkSLSwizzle.h
+++ b/src/sksl/ir/SkSLSwizzle.h
@@ -79,11 +79,6 @@
     }
 
 private:
-    static std::unique_ptr<Expression> Convert(const Context& context,
-                                               std::unique_ptr<Expression> base,
-                                               ComponentArray inComponents,
-                                               skstd::string_view maskString);
-
     Swizzle(const Type* type, std::unique_ptr<Expression> base, const ComponentArray& components)
         : INHERITED(base->fOffset, kExpressionKind, type)
         , fBase(std::move(base))
diff --git a/tests/sksl/errors/BadFieldAccess.glsl b/tests/sksl/errors/BadFieldAccess.glsl
index f93c670..fcfdd1b 100644
--- a/tests/sksl/errors/BadFieldAccess.glsl
+++ b/tests/sksl/errors/BadFieldAccess.glsl
@@ -3,6 +3,6 @@
 error: 3: type 'S' does not have a field named 'missing'
 error: 4: not a function
 error: 5: type mismatch: '=' cannot operate on 'float', 'bool3'
-error: 6: invalid swizzle mask 'missing'
+error: 6: invalid swizzle component 'm'
 error: 7: expected array, but found 'float'
 5 errors