Emit top-level StructDefinition for every struct

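Previously, structs that were defined as part of a variable declaration
would be emitted the same way (inline, inside the variable declaration)
in the generated code. Now, a global variable declaration that includes
a struct definition generates two separate program elements: a top-level
StructDefinition (emitted only once per struct type), followed by the
GlobalVarDeclaration itself. Because every struct now has its own
StructDefinition element, the GLSL and Metal code generators no longer
need to track which structs they have already written.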

Bug: skia:11228
Change-Id: Id7ddde6931fe07a250c2c9c46153879005535fb3
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/361359
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp
index a075539..87f5ffb 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.cpp
+++ b/src/sksl/SkSLGLSLCodeGenerator.cpp
@@ -156,14 +156,8 @@
     }
 }
 
-bool GLSLCodeGenerator::writeStructDefinition(const Type& type) {
-    for (const Type* search : fWrittenStructs) {
-        if (*search == type) {
-            // already written
-            return false;
-        }
-    }
-    fWrittenStructs.push_back(&type);
+void GLSLCodeGenerator::writeStructDefinition(const StructDefinition& s) {
+    const Type& type = s.type();
     this->write("struct ");
     this->write(type.name());
     this->writeLine(" {");
@@ -178,18 +172,11 @@
         this->writeLine(";");
     }
     fIndentation--;
-    this->write("}");
-    return true;
+    this->writeLine("};");
 }
 
 void GLSLCodeGenerator::writeType(const Type& type) {
-    if (type.isStruct()) {
-        if (!this->writeStructDefinition(type)) {
-            this->write(type.name());
-        }
-    } else {
-        this->write(this->getTypeName(type));
-    }
+    this->write(this->getTypeName(type));
 }
 
 void GLSLCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
@@ -1542,9 +1529,7 @@
         case ProgramElement::Kind::kEnum:
             break;
         case ProgramElement::Kind::kStructDefinition:
-            if (this->writeStructDefinition(e.as<StructDefinition>().type())) {
-                this->writeLine(";");
-            }
+            this->writeStructDefinition(e.as<StructDefinition>());
             break;
         default:
             SkDEBUGFAILF("unsupported program element %s\n", e.description().c_str());
diff --git a/src/sksl/SkSLGLSLCodeGenerator.h b/src/sksl/SkSLGLSLCodeGenerator.h
index 1e32c80..917c04f 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.h
+++ b/src/sksl/SkSLGLSLCodeGenerator.h
@@ -98,7 +98,7 @@
 
     virtual String getTypeName(const Type& type);
 
-    bool writeStructDefinition(const Type& type);
+    void writeStructDefinition(const StructDefinition& s);
 
     void writeType(const Type& type);
 
@@ -204,10 +204,6 @@
     int fVarCount = 0;
     int fIndentation = 0;
     bool fAtLineStart = false;
-    // Keeps track of which struct types we have written. Given that we are unlikely to ever write
-    // more than one or two structs per shader, a simple linear search will be faster than anything
-    // fancier.
-    std::vector<const Type*> fWrittenStructs;
     std::set<String> fWrittenIntrinsics;
     // true if we have run into usages of dFdx / dFdy
     bool fFoundDerivatives = false;
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index a0fdc7a..336136c 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -1305,6 +1305,8 @@
                                     "expected a struct here, found '" + type->name() + "'");
         return nullptr;
     }
+    SkDEBUGCODE(auto [iter, wasInserted] =) fDefinedStructs.insert(type);
+    SkASSERT(wasInserted);
     return std::make_unique<StructDefinition>(node.fOffset, *type);
 }
 
@@ -1415,6 +1417,14 @@
 void IRGenerator::convertGlobalVarDeclarations(const ASTNode& decl) {
     StatementArray decls = this->convertVarDeclarations(decl, Variable::Storage::kGlobal);
     for (std::unique_ptr<Statement>& stmt : decls) {
+        const Type* type = &stmt->as<VarDeclaration>().baseType();
+        if (type->isStruct()) {
+            auto [iter, wasInserted] = fDefinedStructs.insert(type);
+            if (wasInserted) {
+                fProgramElements->push_back(
+                        std::make_unique<StructDefinition>(decl.fOffset, *type));
+            }
+        }
         fProgramElements->push_back(std::make_unique<GlobalVarDeclaration>(decl.fOffset,
                                                                            std::move(stmt)));
     }
@@ -3006,6 +3016,7 @@
     fInvocations = -1;
     fRTAdjust = nullptr;
     fRTAdjustInterfaceBlock = nullptr;
+    fDefinedStructs.clear();
 
     AutoSymbolTable table(this);
 
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index 06b373a..a9a40b4 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -289,6 +289,7 @@
     int fLoopLevel = 0;
     int fSwitchLevel = 0;
     int fInvocations;
+    std::unordered_set<const Type*> fDefinedStructs;
     std::vector<std::unique_ptr<ProgramElement>>* fProgramElements = nullptr;
     std::vector<const ProgramElement*>*           fSharedElements = nullptr;
     const Variable* fRTAdjust = nullptr;
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 4a9fd40..9efdccb 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -132,20 +132,13 @@
     }
 }
 
-bool MetalCodeGenerator::writeStructDefinition(const Type& type) {
-    for (const Type* search : fWrittenStructs) {
-        if (*search == type) {
-            // already written
-            return false;
-        }
-    }
-    fWrittenStructs.push_back(&type);
+void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
+    const Type& type = s.type();
     this->writeLine("struct " + type.name() + " {");
     fIndentation++;
     this->writeFields(type.fields(), type.fOffset);
     fIndentation--;
-    this->write("}");
-    return true;
+    this->writeLine("};");
 }
 
 // Flags an error if an array type is found. Meant to be used in places where an array type might
@@ -160,11 +153,6 @@
 // Call `writeArrayDimensions` to write the type's accompanying array sizes.
 void MetalCodeGenerator::writeBaseType(const Type& type) {
     switch (type.typeKind()) {
-        case Type::TypeKind::kStruct:
-            if (!this->writeStructDefinition(type)) {
-                this->write(type.name());
-            }
-            break;
         case Type::TypeKind::kArray:
             this->writeBaseType(type.componentType());
             break;
@@ -1646,7 +1634,6 @@
     if (structType->isArray()) {
         structType = &structType->componentType();
     }
-    fWrittenStructs.push_back(structType);
     fIndentation++;
     this->writeFields(structType->fields(), structType->fOffset, &intf);
     if (fProgram.fInputs.fRTHeight) {
@@ -2058,19 +2045,7 @@
 void MetalCodeGenerator::writeStructDefinitions() {
     for (const ProgramElement* e : fProgram.elements()) {
         if (e->is<StructDefinition>()) {
-            if (this->writeStructDefinition(e->as<StructDefinition>().type())) {
-                this->writeLine(";");
-            }
-        } else if (e->is<GlobalVarDeclaration>()) {
-            // If a global var declaration introduces a struct type, we need to write that type
-            // here, since globals are all embedded in a sub-struct.
-            const Type* type = &e->as<GlobalVarDeclaration>().declaration()
-                                 ->as<VarDeclaration>().baseType();
-            if (type->isStruct()) {
-                if (this->writeStructDefinition(*type)) {
-                    this->writeLine(";");
-                }
-            }
+            this->writeStructDefinition(e->as<StructDefinition>());
         }
     }
 }
diff --git a/src/sksl/SkSLMetalCodeGenerator.h b/src/sksl/SkSLMetalCodeGenerator.h
index eee2d4c..51ef73d 100644
--- a/src/sksl/SkSLMetalCodeGenerator.h
+++ b/src/sksl/SkSLMetalCodeGenerator.h
@@ -170,7 +170,7 @@
 
     String typeName(const Type& type);
 
-    bool writeStructDefinition(const Type& type);
+    void writeStructDefinition(const StructDefinition& s);
 
     void disallowArrayTypes(const Type& type, int offset);
 
@@ -311,10 +311,6 @@
     int fVarCount = 0;
     int fIndentation = 0;
     bool fAtLineStart = false;
-    // Keeps track of which struct types we have written. Given that we are unlikely to ever write
-    // more than one or two structs per shader, a simple linear search will be faster than anything
-    // fancier.
-    std::vector<const Type*> fWrittenStructs;
     std::set<String> fWrittenIntrinsics;
     // true if we have run into usages of dFdx / dFdy
     bool fFoundDerivatives = false;