Implement gl_PointSize point sprite emulation using D3D11 geometry shaders.

TRAC #22412

Signed-off-by: Nicolas Capens
Signed-off-by: Shannon Woods
Author: Jamie Madill

git-svn-id: https://angleproject.googlecode.com/svn/branches/dx11proto@1786 736b8ea6-26fd-11df-bfd4-992fa37f6226
diff --git a/src/libGLESv2/Context.cpp b/src/libGLESv2/Context.cpp
index e45921d..441f2aa 100644
--- a/src/libGLESv2/Context.cpp
+++ b/src/libGLESv2/Context.cpp
@@ -1712,7 +1712,17 @@
 // Applies the fixed-function state (culling, depth test, alpha blending, stenciling, etc) to the Direct3D 9 device
 void Context::applyState(GLenum drawMode)
 {
-    mRenderer->setRasterizerState(mState.rasterizer);
+    // disable face culling for point sprite emulation (done as geometry shader quads)
+    if (getCurrentProgramBinary()->usesPointSpriteEmulation())
+    {
+        RasterizerState rasterizerStateCopy = mState.rasterizer;
+        rasterizerStateCopy.cullFace = false;
+        mRenderer->setRasterizerState(rasterizerStateCopy);
+    }
+    else
+    {
+        mRenderer->setRasterizerState(mState.rasterizer);
+    }
 
     unsigned int mask = 0;
     if (mState.sampleCoverage)
diff --git a/src/libGLESv2/ProgramBinary.cpp b/src/libGLESv2/ProgramBinary.cpp
index 6eaac91..48fef54 100644
--- a/src/libGLESv2/ProgramBinary.cpp
+++ b/src/libGLESv2/ProgramBinary.cpp
@@ -41,6 +41,7 @@
 {
     mPixelExecutable = NULL;
     mVertexExecutable = NULL;
+    mGeometryExecutable = NULL;
 
     mValidated = false;
 
@@ -67,7 +68,13 @@
 ProgramBinary::~ProgramBinary()
 {
     delete mPixelExecutable;
+    mPixelExecutable = NULL;
+
     delete mVertexExecutable;
+    mVertexExecutable = NULL;
+
+    delete mGeometryExecutable;
+    mGeometryExecutable = NULL;
 
     while (!mUniforms.empty())
     {
@@ -96,6 +103,11 @@
     return mVertexExecutable;
 }
 
+rx::ShaderExecutable *ProgramBinary::getGeometryExecutable()
+{
+    return mGeometryExecutable;
+}
+
 GLuint ProgramBinary::getAttributeLocation(const char *name)
 {
     if (name)
@@ -139,6 +151,16 @@
     return mUsesPointSize;
 }
 
+bool ProgramBinary::usesPointSpriteEmulation() const
+{
+    return mUsesPointSize && mRenderer->getMajorShaderModel() >= 4;
+}
+
+bool ProgramBinary::usesGeometryShader() const
+{
+    return usesPointSpriteEmulation();
+}
+
 // Returns the index of the texture image unit (0-19) corresponding to a Direct3D 9 sampler
 // index (0-15 for the pixel shader and 0-3 for the vertex shader).
 GLint ProgramBinary::getSamplerMapping(SamplerType type, unsigned int samplerIndex)
@@ -1420,7 +1442,7 @@
         }
     }
 
-    if (fragmentShader->mUsesPointCoord && shaderModel == 3)
+    if (fragmentShader->mUsesPointCoord && shaderModel >= 3)
     {
         pixelHLSL += "    float2 gl_PointCoord : " + pointCoordSemantic + ";\n";
     }
@@ -1477,7 +1499,7 @@
                      "    gl_FragCoord.w = rhw;\n";
     }
 
-    if (fragmentShader->mUsesPointCoord && shaderModel == 3)
+    if (fragmentShader->mUsesPointCoord && shaderModel >= 3)
     {
         pixelHLSL += "    gl_PointCoord.x = input.gl_PointCoord.x;\n";
         pixelHLSL += "    gl_PointCoord.y = 1.0 - input.gl_PointCoord.y;\n";
@@ -1643,6 +1665,9 @@
     unsigned int vertexShaderSize;
     stream.read(&vertexShaderSize);
 
+    unsigned int geometryShaderSize;
+    stream.read(&geometryShaderSize);
+
     const char *ptr = (const char*) binary + stream.offset();
 
     const GUID *binaryIdentifier = (const GUID *) ptr;
@@ -1661,6 +1686,9 @@
     const char *vertexShaderFunction = ptr;
     ptr += vertexShaderSize;
 
+    const char *geometryShaderFunction = geometryShaderSize > 0 ? ptr : NULL;
+    ptr += geometryShaderSize;
+
     mPixelExecutable = mRenderer->loadExecutable(reinterpret_cast<const DWORD*>(pixelShaderFunction),
                                                  pixelShaderSize, rx::SHADER_PIXEL);
     if (!mPixelExecutable)
@@ -1679,6 +1707,25 @@
         return false;
     }
 
+    if (geometryShaderFunction != NULL && geometryShaderSize > 0)
+    {
+        mGeometryExecutable = mRenderer->loadExecutable(reinterpret_cast<const DWORD*>(geometryShaderFunction),
+                                                        geometryShaderSize, rx::SHADER_GEOMETRY);
+        if (!mGeometryExecutable)
+        {
+            infoLog.append("Could not create geometry shader.");
+            delete mPixelExecutable;
+            mPixelExecutable = NULL;
+            delete mVertexExecutable;
+            mVertexExecutable = NULL;
+            return false;
+        }
+    }
+    else
+    {
+        mGeometryExecutable = NULL;
+    }
+
     return true;
 }
 
@@ -1740,12 +1787,15 @@
     UINT vertexShaderSize = mVertexExecutable->getLength();
     stream.write(vertexShaderSize);
 
+    UINT geometryShaderSize = (mGeometryExecutable != NULL) ? mGeometryExecutable->getLength() : 0;
+    stream.write(geometryShaderSize);
+
     GUID identifier = mRenderer->getAdapterIdentifier();
 
     GLsizei streamLength = stream.length();
     const void *streamData = stream.data();
 
-    GLsizei totalLength = streamLength + sizeof(GUID) + pixelShaderSize + vertexShaderSize;
+    GLsizei totalLength = streamLength + sizeof(GUID) + pixelShaderSize + vertexShaderSize + geometryShaderSize;
     if (totalLength > bufSize)
     {
         if (length)
@@ -1772,6 +1822,12 @@
         memcpy(ptr, mVertexExecutable->getFunction(), vertexShaderSize);
         ptr += vertexShaderSize;
 
+        if (mGeometryExecutable != NULL && geometryShaderSize > 0)
+        {
+            memcpy(ptr, mGeometryExecutable->getFunction(), geometryShaderSize);
+            ptr += geometryShaderSize;
+        }
+
         ASSERT(ptr - totalLength == binary);
     }
 
@@ -1829,7 +1885,13 @@
     mVertexExecutable = mRenderer->compileToExecutable(infoLog, vertexHLSL.c_str(), rx::SHADER_VERTEX);
     mPixelExecutable = mRenderer->compileToExecutable(infoLog, pixelHLSL.c_str(), rx::SHADER_PIXEL);
 
-    if (!mVertexExecutable || !mPixelExecutable)
+    if (usesGeometryShader())
+    {
+        std::string geometryHLSL = generateGeometryShaderHLSL(registers, packing, fragmentShader, vertexShader);
+        mGeometryExecutable = mRenderer->compileToExecutable(infoLog, geometryHLSL.c_str(), rx::SHADER_GEOMETRY);
+    }
+
+    if (!mVertexExecutable || !mPixelExecutable || (usesGeometryShader() && !mGeometryExecutable))
     {
         infoLog.append("Failed to create D3D shaders.");
         success = false;
@@ -1838,6 +1900,8 @@
         mVertexExecutable = NULL;
         delete mPixelExecutable;
         mPixelExecutable = NULL;
+        delete mGeometryExecutable;
+        mGeometryExecutable = NULL;
     }
 
     if (!linkAttributes(infoLog, attributeBindings, fragmentShader, vertexShader))
@@ -2038,6 +2102,159 @@
     return true;
 }
 
+// Generates the HLSL source for this program's geometry shader stage.
+// Only point sprite emulation needs a geometry shader so far, so this may
+// only be called when usesPointSpriteEmulation() is true.
+std::string ProgramBinary::generateGeometryShaderHLSL(int registers, const Varying *packing[][4], FragmentShader *fragmentShader, VertexShader *vertexShader) const
+{
+    // for now we only handle point sprite emulation
+    ASSERT(usesPointSpriteEmulation());
+    return generatePointSpriteHLSL(registers, packing, fragmentShader, vertexShader);
+}
+
+// Generates a gs_4_0 geometry shader that expands every input point into a
+// four-vertex triangle strip forming a screen-aligned square intended to cover
+// gl_PointSize pixels (clamped to [ALIASED_POINT_SIZE_RANGE_MIN, getMaxPointSize()]).
+//
+// The packed varyings v0..v(registers-1) (semantics TEXCOORD0..) and, if used,
+// gl_FragCoord are copied unchanged to every emitted vertex. gl_FragCoord and
+// gl_PointCoord take the next free TEXCOORD indices, in that order; these must
+// match the semantics the pixel shader declares for them.
+//
+// Corner offsets are computed in normalized device coordinates from
+// dx_viewportCoords (c1 of the constant buffer bound to GS slot 0; assumed to
+// hold half the viewport size in .xy - TODO confirm) and multiplied by the
+// point's clip-space w, so the sprite keeps its pixel size after the
+// perspective divide.
+std::string ProgramBinary::generatePointSpriteHLSL(int registers, const Varying *packing[][4], FragmentShader *fragmentShader, VertexShader *vertexShader) const
+{
+    ASSERT(registers >= 0);
+    ASSERT(vertexShader->mUsesPointSize);
+    ASSERT(mRenderer->getMajorShaderModel() >= 4);
+
+    std::string varyingSemantic = "TEXCOORD";
+
+    // gl_FragCoord and gl_PointCoord are assigned the TEXCOORD registers
+    // following the packed varyings.
+    std::string fragCoordSemantic;
+    std::string pointCoordSemantic;
+
+    int reservedRegisterIndex = registers;
+
+    if (fragmentShader->mUsesFragCoord)
+    {
+        fragCoordSemantic = varyingSemantic + str(reservedRegisterIndex++);
+    }
+
+    if (fragmentShader->mUsesPointCoord)
+    {
+        pointCoordSemantic = varyingSemantic + str(reservedRegisterIndex++);
+    }
+
+    // Members shared by GS_INPUT and GS_OUTPUT: the packed varyings and
+    // gl_FragCoord, which are passed through unchanged.
+    std::string passThroughMembers;
+
+    for (int r = 0; r < registers; r++)
+    {
+        int registerSize = packing[r][3] ? 4 : (packing[r][2] ? 3 : (packing[r][1] ? 2 : 1));
+
+        passThroughMembers += "    float" + str(registerSize) + " v" + str(r) + " : " + varyingSemantic + str(r) + ";\n";
+    }
+
+    if (fragmentShader->mUsesFragCoord)
+    {
+        passThroughMembers += "    float4 gl_FragCoord : " + fragCoordSemantic + ";\n";
+    }
+
+    std::string geomHLSL;
+
+    geomHLSL += "uniform float4 dx_viewportCoords : register(c1);\n"
+                "\n";
+
+    geomHLSL += "struct GS_INPUT\n"
+                "{\n";
+    geomHLSL += passThroughMembers;
+    geomHLSL += "    float gl_PointSize : PSIZE;\n"
+                "    float4 gl_Position : SV_Position;\n"
+                "};\n"
+                "\n"
+                "struct GS_OUTPUT\n"
+                "{\n";
+    geomHLSL += passThroughMembers;
+
+    if (fragmentShader->mUsesPointCoord)
+    {
+        geomHLSL += "    float2 gl_PointCoord : " + pointCoordSemantic + ";\n";
+    }
+
+    geomHLSL += "    float4 gl_Position : SV_Position;\n"
+                "};\n"
+                "\n"
+                "static float2 pointSpriteCorners[] = \n"
+                "{\n"
+                "    float2( 0.5f, -0.5f),\n"
+                "    float2( 0.5f,  0.5f),\n"
+                "    float2(-0.5f, -0.5f),\n"
+                "    float2(-0.5f,  0.5f)\n"
+                "};\n"
+                "\n"
+                "static float2 pointSpriteTexcoords[] = \n"
+                "{\n"
+                "    float2(1.0f, 1.0f),\n"
+                "    float2(1.0f, 0.0f),\n"
+                "    float2(0.0f, 1.0f),\n"
+                "    float2(0.0f, 0.0f)\n"
+                "};\n"
+                "\n"
+                "static float minPointSize = " + str(ALIASED_POINT_SIZE_RANGE_MIN) + ".0f;\n"
+                "static float maxPointSize = " + str(mRenderer->getMaxPointSize()) + ".0f;\n"
+                "\n"
+                "[maxvertexcount(4)]\n"
+                "void main(point GS_INPUT input[1], inout TriangleStream<GS_OUTPUT> outStream)\n"
+                "{\n"
+                "    GS_OUTPUT output = (GS_OUTPUT)0;\n";
+
+    for (int r = 0; r < registers; r++)
+    {
+        geomHLSL += "    output.v" + str(r) + " = input[0].v" + str(r) + ";\n";
+    }
+
+    if (fragmentShader->mUsesFragCoord)
+    {
+        geomHLSL += "    output.gl_FragCoord = input[0].gl_FragCoord;\n";
+    }
+
+    // viewportScale converts pixels to NDC units, pre-multiplied by w so the
+    // offsets applied to the clip-space position survive the perspective
+    // divide; without it, sprites shrink or grow with 1/w under perspective.
+    geomHLSL += "    \n"
+                "    float gl_PointSize = clamp(input[0].gl_PointSize, minPointSize, maxPointSize);\n"
+                "    float4 gl_Position = input[0].gl_Position;\n"
+                "    float2 viewportScale = float2(1.0f / dx_viewportCoords.x, 1.0f / dx_viewportCoords.y) * gl_Position.w;\n";
+
+    // Emit the four corners as one triangle strip; gl_PointCoord is only
+    // written when the fragment shader reads it.
+    for (int corner = 0; corner < 4; corner++)
+    {
+        geomHLSL += "    \n"
+                    "    output.gl_Position = gl_Position + float4(pointSpriteCorners[" + str(corner) + "] * viewportScale * gl_PointSize, 0.0f, 0.0f);\n";
+
+        if (fragmentShader->mUsesPointCoord)
+        {
+            geomHLSL += "    output.gl_PointCoord = pointSpriteTexcoords[" + str(corner) + "];\n";
+        }
+
+        geomHLSL += "    outStream.Append(output);\n";
+    }
+
+    geomHLSL += "    \n"
+                "    outStream.RestartStrip();\n"
+                "}\n";
+
+    return geomHLSL;
+}
+
 // This method needs to match OutputHLSL::decorate
 std::string ProgramBinary::decorateAttribute(const std::string &name)
 {
diff --git a/src/libGLESv2/ProgramBinary.h b/src/libGLESv2/ProgramBinary.h
index e0a3f15..d2ee73a 100644
--- a/src/libGLESv2/ProgramBinary.h
+++ b/src/libGLESv2/ProgramBinary.h
@@ -53,6 +53,7 @@
 
     rx::ShaderExecutable *getPixelExecutable();
     rx::ShaderExecutable *getVertexExecutable();
+    rx::ShaderExecutable *getGeometryExecutable();
 
     GLuint getAttributeLocation(const char *name);
     int getSemanticIndex(int attributeIndex);
@@ -61,6 +62,8 @@
     TextureType getSamplerTextureType(SamplerType type, unsigned int samplerIndex);
     GLint getUsedSamplerRange(SamplerType type);
     bool usesPointSize() const;
+    bool usesPointSpriteEmulation() const;
+    bool usesGeometryShader() const;
 
     GLint getUniformLocation(std::string name);
     bool setUniform1fv(GLint location, GLsizei count, const GLfloat *v);
@@ -117,10 +120,14 @@
     bool linkUniforms(InfoLog &infoLog, const sh::ActiveUniforms &vertexUniforms, const sh::ActiveUniforms &fragmentUniforms);
     bool defineUniform(GLenum shader, const sh::Uniform &constant, InfoLog &infoLog);
     
+    std::string generateGeometryShaderHLSL(int registers, const Varying *packing[][4], FragmentShader *fragmentShader, VertexShader *vertexShader) const;
+    std::string generatePointSpriteHLSL(int registers, const Varying *packing[][4], FragmentShader *fragmentShader, VertexShader *vertexShader) const;
+
     rx::Renderer *const mRenderer;
 
     rx::ShaderExecutable *mPixelExecutable;
     rx::ShaderExecutable *mVertexExecutable;
+    rx::ShaderExecutable *mGeometryExecutable;
 
     Attribute mLinkedAttribute[MAX_VERTEX_ATTRIBS];
     int mSemanticIndex[MAX_VERTEX_ATTRIBS];
diff --git a/src/libGLESv2/renderer/Renderer.h b/src/libGLESv2/renderer/Renderer.h
index 242308e..55b5997 100644
--- a/src/libGLESv2/renderer/Renderer.h
+++ b/src/libGLESv2/renderer/Renderer.h
@@ -83,7 +83,8 @@
 enum ShaderType
 {
     SHADER_VERTEX,
-    SHADER_PIXEL
+    SHADER_PIXEL,
+    SHADER_GEOMETRY
 };
 
 class Renderer
diff --git a/src/libGLESv2/renderer/Renderer11.cpp b/src/libGLESv2/renderer/Renderer11.cpp
index 1d4ae88..76ddc9e 100644
--- a/src/libGLESv2/renderer/Renderer11.cpp
+++ b/src/libGLESv2/renderer/Renderer11.cpp
@@ -1044,6 +1044,7 @@
     {
         ShaderExecutable *vertexExe = programBinary->getVertexExecutable();
         ShaderExecutable *pixelExe = programBinary->getPixelExecutable();
+        ShaderExecutable *geometryExe = programBinary->getGeometryExecutable();
 
         ID3D11VertexShader *vertexShader = NULL;
         if (vertexExe) vertexShader = ShaderExecutable11::makeShaderExecutable11(vertexExe)->getVertexShader();
@@ -1051,8 +1052,21 @@
         ID3D11PixelShader *pixelShader = NULL;
         if (pixelExe) pixelShader = ShaderExecutable11::makeShaderExecutable11(pixelExe)->getPixelShader();
 
+        ID3D11GeometryShader *geometryShader = NULL;
+        if (geometryExe) geometryShader = ShaderExecutable11::makeShaderExecutable11(geometryExe)->getGeometryShader();
+
         mDeviceContext->PSSetShader(pixelShader, NULL, 0);
         mDeviceContext->VSSetShader(vertexShader, NULL, 0);
+
+        if (geometryShader)
+        {
+            mDeviceContext->GSSetShader(geometryShader, NULL, 0);
+        }
+        else
+        {
+            mDeviceContext->GSSetShader(NULL, NULL, 0);
+        }
+
         programBinary->dirtyAllUniforms();
 
         mAppliedProgramBinarySerial = programBinarySerial;
@@ -1221,7 +1235,7 @@
     
     mDeviceContext->VSSetConstantBuffers(0, 1, &vertexConstantBuffer);
     mDeviceContext->PSSetConstantBuffers(0, 1, &pixelConstantBuffer);
-    
+
     delete[] mapVS;
     delete[] mapPS;
 
@@ -1269,6 +1283,9 @@
         mDeviceContext->UpdateSubresource(mDriverConstantBufferPS, 0, NULL, &mPixelConstants, 16, 0);
         memcpy(&mAppliedPixelConstants, &mPixelConstants, sizeof(dx_PixelConstants));
     }
+
+    // needed for the point sprite geometry shader
+    mDeviceContext->GSSetConstantBuffers(0, 1, &mDriverConstantBufferPS);
 }
 
 void Renderer11::clear(const gl::ClearParameters &clearParams, gl::Framebuffer *frameBuffer)
@@ -1964,9 +1981,9 @@
 
 float Renderer11::getMaxPointSize() const
 {
-    // TODO
-    // UNIMPLEMENTED();
-    return 1.0f;
+    // Fixed, chosen cap (nb: DX9 on a Radeon 2600xt reports 256); the point
+    // sprite geometry shader clamps gl_PointSize to this value.
+    return 1024.0f;
 }
 
 int Renderer11::getMaxTextureWidth() const
@@ -2406,6 +2423,18 @@
             }
         }
         break;
+      case rx::SHADER_GEOMETRY:
+        {
+            ID3D11GeometryShader *gshader = NULL;
+            HRESULT result = mDevice->CreateGeometryShader(function, length, NULL, &gshader);
+            ASSERT(SUCCEEDED(result));
+
+            if (gshader)
+            {
+                executable = new ShaderExecutable11(function, length, gshader);
+            }
+        }
+        break;
       default:
         UNREACHABLE();
         break;
@@ -2426,6 +2455,9 @@
       case rx::SHADER_PIXEL:
         profile = "ps_4_0";
         break;
+      case rx::SHADER_GEOMETRY:
+        profile = "gs_4_0";
+        break;
       default:
         UNREACHABLE();
         return NULL;
diff --git a/src/libGLESv2/renderer/ShaderExecutable11.cpp b/src/libGLESv2/renderer/ShaderExecutable11.cpp
index d0903be..4a944fe 100644
--- a/src/libGLESv2/renderer/ShaderExecutable11.cpp
+++ b/src/libGLESv2/renderer/ShaderExecutable11.cpp
@@ -19,6 +19,7 @@
 {
     mPixelExecutable = executable;
     mVertexExecutable = NULL;
+    mGeometryExecutable = NULL;
 
     mConstantBuffer = NULL;
 }
@@ -28,6 +29,17 @@
 {
     mVertexExecutable = executable;
     mPixelExecutable = NULL;
+    mGeometryExecutable = NULL;
+
+    mConstantBuffer = NULL;
+}
+
+ShaderExecutable11::ShaderExecutable11(const void *function, size_t length, ID3D11GeometryShader *executable)
+    : ShaderExecutable(function, length)
+{
+    mGeometryExecutable = executable;
+    mVertexExecutable = NULL;
+    mPixelExecutable = NULL;
 
     mConstantBuffer = NULL;
 }
@@ -42,6 +54,10 @@
     {
         mPixelExecutable->Release();
     }
+    if (mGeometryExecutable)
+    {
+        mGeometryExecutable->Release();
+    }
     
     if (mConstantBuffer)
     {
@@ -65,6 +81,11 @@
     return mPixelExecutable;
 }
 
+ID3D11GeometryShader *ShaderExecutable11::getGeometryShader() const
+{
+    return mGeometryExecutable;
+}
+
 ID3D11Buffer *ShaderExecutable11::getConstantBuffer(ID3D11Device *device, unsigned int registerCount)
 {
     if (!mConstantBuffer && registerCount > 0)
diff --git a/src/libGLESv2/renderer/ShaderExecutable11.h b/src/libGLESv2/renderer/ShaderExecutable11.h
index c95ecfc..0801e07 100644
--- a/src/libGLESv2/renderer/ShaderExecutable11.h
+++ b/src/libGLESv2/renderer/ShaderExecutable11.h
@@ -22,6 +22,7 @@
   public:
     ShaderExecutable11(const void *function, size_t length, ID3D11PixelShader *executable);
     ShaderExecutable11(const void *function, size_t length, ID3D11VertexShader *executable);
+    ShaderExecutable11(const void *function, size_t length, ID3D11GeometryShader *executable);
 
     virtual ~ShaderExecutable11();
 
@@ -29,6 +30,7 @@
 
     ID3D11PixelShader *getPixelShader() const;
     ID3D11VertexShader *getVertexShader() const;
+    ID3D11GeometryShader *getGeometryShader() const;
 
     ID3D11Buffer *getConstantBuffer(ID3D11Device *device, unsigned int registerCount);
 
@@ -37,6 +39,7 @@
 
     ID3D11PixelShader *mPixelExecutable;
     ID3D11VertexShader *mVertexExecutable;
+    ID3D11GeometryShader *mGeometryExecutable;
 
     ID3D11Buffer *mConstantBuffer;
 };
diff --git a/src/libGLESv2/renderer/SwapChain11.cpp b/src/libGLESv2/renderer/SwapChain11.cpp
index f3faae4..2d6f742 100644
--- a/src/libGLESv2/renderer/SwapChain11.cpp
+++ b/src/libGLESv2/renderer/SwapChain11.cpp
@@ -507,6 +507,7 @@
     deviceContext->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP);
     deviceContext->VSSetShader(mPassThroughVS, NULL, 0);
     deviceContext->PSSetShader(mPassThroughPS, NULL, 0);
+    deviceContext->GSSetShader(NULL, NULL, 0);
 
     // Apply render targets
     deviceContext->OMSetRenderTargets(1, &mBackBufferRTView, NULL);