sksl support for geometry shaders

BUG=skia:

Change-Id: I8541b98aadcf4c2484fef73e2f49be3ee38bc1e2
Reviewed-on: https://skia-review.googlesource.com/8409
Reviewed-by: Ben Wagner <benjaminwagner@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index 743745a..004eab7 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -37,6 +37,10 @@
 #include "sksl_frag.include"
 ;
 
+static const char* SKSL_GEOM_INCLUDE =
+#include "sksl_geom.include"
+;
+
 namespace SkSL {
 
 Compiler::Compiler()
@@ -459,6 +463,9 @@
         case Program::kFragment_Kind:
             this->internalConvertProgram(SkString(SKSL_FRAG_INCLUDE), &ignored, &elements);
             break;
+        case Program::kGeometry_Kind:
+            this->internalConvertProgram(SkString(SKSL_GEOM_INCLUDE), &ignored, &elements);
+            break;
     }
     fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
     Modifiers::Flag defaultPrecision;
diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h
index f7dcf2b..497d06d 100644
--- a/src/sksl/SkSLCompiler.h
+++ b/src/sksl/SkSLCompiler.h
@@ -18,9 +18,11 @@
 #include "SkSLIRGenerator.h"
 
 #define SK_FRAGCOLOR_BUILTIN    10001
+#define SK_IN_BUILTIN           10002
 #define SK_FRAGCOORD_BUILTIN       15
 #define SK_VERTEXID_BUILTIN         5
 #define SK_CLIPDISTANCE_BUILTIN     3
+#define SK_INVOCATIONID_BUILTIN     8
 
 namespace SkSL {
 
diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp
index 18cfba7..3f5f9d1 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.cpp
+++ b/src/sksl/SkSLGLSLCodeGenerator.cpp
@@ -332,6 +332,12 @@
         case SK_CLIPDISTANCE_BUILTIN:
             this->write("gl_ClipDistance");
             break;
+        case SK_IN_BUILTIN:
+            this->write("gl_in");
+            break;
+        case SK_INVOCATIONID_BUILTIN:
+            this->write("gl_InvocationID");
+            break;
         default:
             this->write(ref.fVariable.fName);
     }
@@ -584,7 +590,7 @@
 }
 
 void GLSLCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
-    if (intf.fTypeName == "gl_PerVertex") {
+    if (intf.fTypeName == "sk_PerVertex") {
         return;
     }
     this->writeModifiers(intf.fVariable.fModifiers, true);
diff --git a/src/sksl/SkSLMain.cpp b/src/sksl/SkSLMain.cpp
index f493b05..46e9c18 100644
--- a/src/sksl/SkSLMain.cpp
+++ b/src/sksl/SkSLMain.cpp
@@ -24,8 +24,10 @@
         kind = SkSL::Program::kVertex_Kind;
     } else if (len > 5 && !strcmp(argv[1] + strlen(argv[1]) - 5, ".frag")) {
         kind = SkSL::Program::kFragment_Kind;
+    } else if (len > 5 && !strcmp(argv[1] + strlen(argv[1]) - 5, ".geom")) {
+        kind = SkSL::Program::kGeometry_Kind;
     } else {
-        printf("input filename must end in '.vert' or '.frag'\n");
+        printf("input filename must end in '.vert', '.frag', or '.geom'\n");
         exit(1);
     }
 
diff --git a/src/sksl/SkSLParser.cpp b/src/sksl/SkSLParser.cpp
index 5bffe1e..cc47577 100644
--- a/src/sksl/SkSLParser.cpp
+++ b/src/sksl/SkSLParser.cpp
@@ -123,6 +123,15 @@
     fLayoutKeys[SkString("override_coverage")]           = kOverrideCoverage_LayoutKey;
     fLayoutKeys[SkString("blend_support_all_equations")] = kBlendSupportAllEquations_LayoutKey;
     fLayoutKeys[SkString("push_constant")]               = kPushConstant_LayoutKey;
+    fLayoutKeys[SkString("points")]                      = kPoints_LayoutKey;
+    fLayoutKeys[SkString("lines")]                       = kLines_LayoutKey;
+    fLayoutKeys[SkString("line_strip")]                  = kLineStrip_LayoutKey;
+    fLayoutKeys[SkString("lines_adjacency")]             = kLinesAdjacency_LayoutKey;
+    fLayoutKeys[SkString("triangles")]                   = kTriangles_LayoutKey;
+    fLayoutKeys[SkString("triangle_strip")]              = kTriangleStrip_LayoutKey;
+    fLayoutKeys[SkString("triangles_adjacency")]         = kTrianglesAdjacency_LayoutKey;
+    fLayoutKeys[SkString("max_vertices")]                = kMaxVertices_LayoutKey;
+    fLayoutKeys[SkString("invocations")]                 = kInvocations_LayoutKey;
 }
 
 Parser::~Parser() {
@@ -570,12 +579,15 @@
     bool blendSupportAllEquations = false;
     Layout::Format format = Layout::Format::kUnspecified;
     bool pushConstant = false;
+    Layout::Primitive primitive = Layout::kUnspecified_Primitive;
+    int maxVertices = -1;
+    int invocations = -1;
     if (this->peek().fKind == Token::LAYOUT) {
         this->nextToken();
         if (!this->expect(Token::LPAREN, "'('")) {
             return Layout(location, offset, binding, index, set, builtin, inputAttachmentIndex,
                           originUpperLeft, overrideCoverage, blendSupportAllEquations, format,
-                          pushConstant);
+                          pushConstant, primitive, maxVertices, invocations);
         }
         for (;;) {
             Token t = this->nextToken();
@@ -615,6 +627,33 @@
                     case kPushConstant_LayoutKey:
                         pushConstant = true;
                         break;
+                    case kPoints_LayoutKey:
+                        primitive = Layout::kPoints_Primitive;
+                        break;
+                    case kLines_LayoutKey:
+                        primitive = Layout::kLines_Primitive;
+                        break;
+                    case kLineStrip_LayoutKey:
+                        primitive = Layout::kLineStrip_Primitive;
+                        break;
+                    case kLinesAdjacency_LayoutKey:
+                        primitive = Layout::kLinesAdjacency_Primitive;
+                        break;
+                    case kTriangles_LayoutKey:
+                        primitive = Layout::kTriangles_Primitive;
+                        break;
+                    case kTriangleStrip_LayoutKey:
+                        primitive = Layout::kTriangleStrip_Primitive;
+                        break;
+                    case kTrianglesAdjacency_LayoutKey:
+                        primitive = Layout::kTrianglesAdjacency_Primitive;
+                        break;
+                    case kMaxVertices_LayoutKey:
+                        maxVertices = this->layoutInt();
+                        break;
+                    case kInvocations_LayoutKey:
+                        invocations = this->layoutInt();
+                        break;
                 }
             } else if (Layout::ReadFormat(t.fText, &format)) {
                // AST::ReadFormat stored the result in 'format'.
@@ -633,7 +672,7 @@
     }
     return Layout(location, offset, binding, index, set, builtin, inputAttachmentIndex,
                   originUpperLeft, overrideCoverage, blendSupportAllEquations, format,
-                  pushConstant);
+                  pushConstant, primitive, maxVertices, invocations);
 }
 
 /* layout? (UNIFORM | CONST | IN | OUT | INOUT | LOWP | MEDIUMP | HIGHP | FLAT | NOPERSPECTIVE |
diff --git a/src/sksl/SkSLParser.h b/src/sksl/SkSLParser.h
index f277745..78d9933 100644
--- a/src/sksl/SkSLParser.h
+++ b/src/sksl/SkSLParser.h
@@ -218,7 +218,16 @@
         kOriginUpperLeft_LayoutKey,
         kOverrideCoverage_LayoutKey,
         kBlendSupportAllEquations_LayoutKey,
-        kPushConstant_LayoutKey
+        kPushConstant_LayoutKey,
+        kPoints_LayoutKey,
+        kLines_LayoutKey,
+        kLineStrip_LayoutKey,
+        kLinesAdjacency_LayoutKey,
+        kTriangles_LayoutKey,
+        kTriangleStrip_LayoutKey,
+        kTrianglesAdjacency_LayoutKey,
+        kMaxVertices_LayoutKey,
+        kInvocations_LayoutKey
     };
     std::unordered_map<SkString, LayoutKey> fLayoutKeys;
 
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index 0b7da57..feb66a6 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -1873,7 +1873,7 @@
             SkString name("sksl_synthetic_uniforms");
             Type intfStruct(Position(), name, fields);
             Layout layout(-1, -1, 1, -1, -1, -1, -1, false, false, false, Layout::Format::kUnspecified,
-                          false);
+                          false, Layout::kUnspecified_Primitive, -1, -1);
             Variable intfVar(Position(), Modifiers(layout, Modifiers::kUniform_Flag), name,
                              intfStruct, Variable::kGlobal_Storage);
             InterfaceBlock intf(Position(), intfVar, name, SkString(""),
@@ -2859,6 +2859,9 @@
         case Program::kFragment_Kind:
             this->writeWord(SpvExecutionModelFragment, out);
             break;
+        case Program::kGeometry_Kind:
+            this->writeWord(SpvExecutionModelGeometry, out);
+            break;
     }
     this->writeWord(fFunctionMap[main], out);
     this->writeString(main->fName.c_str(), out);
diff --git a/src/sksl/ir/SkSLLayout.h b/src/sksl/ir/SkSLLayout.h
index 75650e2..5e7ec44 100644
--- a/src/sksl/ir/SkSLLayout.h
+++ b/src/sksl/ir/SkSLLayout.h
@@ -19,6 +19,17 @@
  * layout (location = 0) int x;
  */
 struct Layout {
+    enum Primitive {
+        kUnspecified_Primitive = -1,
+        kPoints_Primitive,
+        kLines_Primitive,
+        kLineStrip_Primitive,
+        kLinesAdjacency_Primitive,
+        kTriangles_Primitive,
+        kTriangleStrip_Primitive,
+        kTrianglesAdjacency_Primitive
+    };
+
     // These are used by images in GLSL. We only support a subset of what GL supports.
     enum class Format {
         kUnspecified = -1,
@@ -79,7 +90,8 @@
 
     Layout(int location, int offset, int binding, int index, int set, int builtin,
            int inputAttachmentIndex, bool originUpperLeft, bool overrideCoverage,
-           bool blendSupportAllEquations, Format format, bool pushconstant)
+           bool blendSupportAllEquations, Format format, bool pushconstant, Primitive primitive,
+           int maxVertices, int invocations)
     : fLocation(location)
     , fOffset(offset)
     , fBinding(binding)
@@ -91,7 +103,10 @@
     , fOverrideCoverage(overrideCoverage)
     , fBlendSupportAllEquations(blendSupportAllEquations)
     , fFormat(format)
-    , fPushConstant(pushconstant) {}
+    , fPushConstant(pushconstant)
+    , fPrimitive(primitive)
+    , fMaxVertices(maxVertices)
+    , fInvocations(invocations) {}
 
     Layout()
     : fLocation(-1)
@@ -105,7 +120,10 @@
     , fOverrideCoverage(false)
     , fBlendSupportAllEquations(false)
     , fFormat(Format::kUnspecified)
-    , fPushConstant(false) {}
+    , fPushConstant(false)
+    , fPrimitive(kUnspecified_Primitive)
+    , fMaxVertices(-1)
+    , fInvocations(-1) {}
 
     SkString description() const {
         SkString result;
@@ -158,6 +176,46 @@
             result += separator + "push_constant";
             separator = ", ";
         }
+        switch (fPrimitive) {
+            case kPoints_Primitive:
+                result += separator + "points";
+                separator = ", ";
+                break;
+            case kLines_Primitive:
+                result += separator + "lines";
+                separator = ", ";
+                break;
+            case kLineStrip_Primitive:
+                result += separator + "line_strip";
+                separator = ", ";
+                break;
+            case kLinesAdjacency_Primitive:
+                result += separator + "lines_adjacency";
+                separator = ", ";
+                break;
+            case kTriangles_Primitive:
+                result += separator + "triangles";
+                separator = ", ";
+                break;
+            case kTriangleStrip_Primitive:
+                result += separator + "triangle_strip";
+                separator = ", ";
+                break;
+            case kTrianglesAdjacency_Primitive:
+                result += separator + "triangles_adjacency";
+                separator = ", ";
+                break;
+            case kUnspecified_Primitive:
+                break;
+        }
+        if (fMaxVertices >= 0) {
+            result += separator + "max_vertices = " + to_string(fMaxVertices);
+            separator = ", ";
+        }
+        if (fInvocations >= 0) {
+            result += separator + "invocations = " + to_string(fInvocations);
+            separator = ", ";
+        }
         if (result.size() > 0) {
             result = "layout (" + result + ")";
         }
@@ -175,7 +233,10 @@
                fOriginUpperLeft          == other.fOriginUpperLeft &&
                fOverrideCoverage         == other.fOverrideCoverage &&
                fBlendSupportAllEquations == other.fBlendSupportAllEquations &&
-               fFormat                   == other.fFormat;
+               fFormat                   == other.fFormat &&
+               fPrimitive                == other.fPrimitive &&
+               fMaxVertices              == other.fMaxVertices &&
+               fInvocations              == other.fInvocations;
     }
 
     bool operator!=(const Layout& other) const {
@@ -198,6 +259,9 @@
     bool fBlendSupportAllEquations;
     Format fFormat;
     bool fPushConstant;
+    Primitive fPrimitive;
+    int fMaxVertices;
+    int fInvocations;
 };
 
 } // namespace
diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h
index e4a975b..2ca9372 100644
--- a/src/sksl/ir/SkSLProgram.h
+++ b/src/sksl/ir/SkSLProgram.h
@@ -52,7 +52,8 @@
 
     enum Kind {
         kFragment_Kind,
-        kVertex_Kind
+        kVertex_Kind,
+        kGeometry_Kind
     };
 
     Program(Kind kind,
diff --git a/src/sksl/sksl.include b/src/sksl/sksl.include
index fc5e40f..11e3710 100644
--- a/src/sksl/sksl.include
+++ b/src/sksl/sksl.include
@@ -516,10 +516,6 @@
 $genType fwidth($genType p);
 $genType fwidthCoarse($genType p);
 $genType fwidthFine($genType p);
-void EmitStreamVertex(int stream);
-void EndStreamPrimitive(int stream);
-void EmitVertex();
-void EndPrimitive();
 void barrier();
 void memoryBarrier();
 void memoryBarrierAtomicCounter();
diff --git a/src/sksl/sksl_geom.include b/src/sksl/sksl_geom.include
new file mode 100644
index 0000000..18e779f
--- /dev/null
+++ b/src/sksl/sksl_geom.include
@@ -0,0 +1,24 @@
+STRINGIFY(
+
+// defines built-in interfaces supported by SkiaSL geometry shaders
+
+layout(builtin=10002) in sk_PerVertex {
+  layout(builtin=0) vec4 gl_Position;
+  layout(builtin=1) float gl_PointSize;
+  layout(builtin=3) float sk_ClipDistance[];
+} sk_in[];
+
+out sk_PerVertex {
+    layout(builtin=0) vec4 gl_Position;
+    layout(builtin=1) float gl_PointSize;
+    layout(builtin=3) float sk_ClipDistance[];
+};
+
+layout(builtin=8) int sk_InvocationID;
+
+void EmitStreamVertex(int stream);
+void EndStreamPrimitive(int stream);
+void EmitVertex();
+void EndPrimitive();
+
+)
diff --git a/src/sksl/sksl_vert.include b/src/sksl/sksl_vert.include
index b5ccfcb..e7e9d59 100644
--- a/src/sksl/sksl_vert.include
+++ b/src/sksl/sksl_vert.include
@@ -2,8 +2,8 @@
 
 // defines built-in interfaces supported by SkiaSL vertex shaders
 
-out gl_PerVertex {
-  	layout(builtin=0) vec4 gl_Position;
+out sk_PerVertex {
+    layout(builtin=0) vec4 gl_Position;
     layout(builtin=1) float gl_PointSize;
     layout(builtin=3) float sk_ClipDistance[1];
 };
diff --git a/tests/SkSLGLSLTest.cpp b/tests/SkSLGLSLTest.cpp
index a57fd9d..a0fdb98 100644
--- a/tests/SkSLGLSLTest.cpp
+++ b/tests/SkSLGLSLTest.cpp
@@ -750,4 +750,31 @@
          "}\n");
 }
 
+DEF_TEST(SkSLGeometry, r) {
+    test(r,
+         "layout(points) in;"
+         "layout(invocations = 2) in;"
+         "layout(line_strip, max_vertices = 2) out;"
+         "void main() {"
+         "gl_Position = sk_in[0].gl_Position + vec4(-0.5, 0, 0, sk_InvocationID);"
+         "EmitVertex();"
+         "gl_Position = sk_in[0].gl_Position + vec4(0.5, 0, 0, sk_InvocationID);"
+         "EmitVertex();"
+         "EndPrimitive();"
+         "}",
+         *SkSL::ShaderCapsFactory::Default(),
+         "#version 400\n"
+         "layout (points) in ;\n"
+         "layout (invocations = 2) in ;\n"
+         "layout (line_strip, max_vertices = 2) out ;\n"
+         "void main() {\n"
+         "    gl_Position = gl_in[0].gl_Position + vec4(-0.5, 0.0, 0.0, float(gl_InvocationID));\n"
+         "    EmitVertex();\n"
+         "    gl_Position = gl_in[0].gl_Position + vec4(0.5, 0.0, 0.0, float(gl_InvocationID));\n"
+         "    EmitVertex();\n"
+         "    EndPrimitive();\n"
+         "}\n",
+         SkSL::Program::kGeometry_Kind);
+}
+
 #endif