Use ShaderMap in ProgramD3D - Part II

This patch refactors ProgramD3D to store all per-shader-stage information
(samplers, used sampler ranges, HLSL and compiler workarounds) in
ShaderMaps, which simplifies the code structure.

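This patch also fixes a bug in checking the maximum number of uniform
blocks: a uniform block's register index is now validated against the
limit of its own shader stage instead of always against
maxVertexUniformBlocks.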

BUG=angleproject:2169

Change-Id: I5b9fbfd70a18f8731ce19efed0df88037d495389
Reviewed-on: https://chromium-review.googlesource.com/1024749
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/libANGLE/renderer/d3d/ProgramD3D.cpp b/src/libANGLE/renderer/d3d/ProgramD3D.cpp
index 40aa279..e2c2935 100644
--- a/src/libANGLE/renderer/d3d/ProgramD3D.cpp
+++ b/src/libANGLE/renderer/d3d/ProgramD3D.cpp
@@ -164,7 +164,7 @@
     }
 }
 
-bool FindFlatInterpolationVarying(const std::vector<sh::Varying> &varyings)
+bool HasFlatInterpolationVarying(const std::vector<sh::Varying> &varyings)
 {
     // Note: this assumes nested structs can only be packed with one interpolation.
     for (const auto &varying : varyings)
@@ -178,6 +178,46 @@
     return false;
 }
 
+bool FindFlatInterpolationVaryingPerShader(const gl::Context *context, gl::Shader *shader)
+{
+    ASSERT(context && shader);
+    switch (shader->getType())
+    {
+        case gl::ShaderType::Vertex:  // checks only the varyings it writes
+            return HasFlatInterpolationVarying(shader->getOutputVaryings(context));
+        case gl::ShaderType::Fragment:  // checks only the varyings it reads
+            return HasFlatInterpolationVarying(shader->getInputVaryings(context));
+        case gl::ShaderType::Geometry:  // checks both the varyings it reads and writes
+            return HasFlatInterpolationVarying(shader->getInputVaryings(context)) ||
+                   HasFlatInterpolationVarying(shader->getOutputVaryings(context));
+        default:
+            UNREACHABLE();
+            return false;
+    }
+}
+
+bool FindFlatInterpolationVarying(const gl::Context *context,
+                                  const gl::ShaderMap<gl::Shader *> &shaders)
+{
+    ASSERT(context);
+    // True if any attached graphics shader has a flat-interpolated varying.
+    for (gl::ShaderType shaderType : gl::kAllGraphicsShaderTypes)
+    {
+        gl::Shader *shader = shaders[shaderType];
+        if (!shader)
+        {
+            continue;
+        }
+
+        if (FindFlatInterpolationVaryingPerShader(context, shader))
+        {
+            return true;
+        }
+    }
+
+    return false;
+}
+
 // Helper method to de-tranpose a matrix uniform for an API query.
 void GetMatrixUniform(GLint columns, GLint rows, GLfloat *dataOut, const GLfloat *source)
 {
@@ -289,6 +329,44 @@
     return true;
 };
 
+// TODO(jiawei.shao@intel.com): remove this function once we use ShaderMap in gl::Caps.
+GLuint GetMaximumSamplersPerShader(gl::ShaderType shaderType, const gl::Caps &caps)
+{
+    switch (shaderType)  // texture image unit limit of the given stage
+    {
+        case gl::ShaderType::Fragment:
+            return caps.maxTextureImageUnits;  // fragment limit has no stage prefix in gl::Caps
+        case gl::ShaderType::Vertex:
+            return caps.maxVertexTextureImageUnits;
+        case gl::ShaderType::Compute:
+            return caps.maxComputeTextureImageUnits;
+        case gl::ShaderType::Geometry:
+            return caps.maxGeometryTextureImageUnits;
+        default:
+            UNREACHABLE();
+            return 0u;
+    }
+}
+
+// TODO(jiawei.shao@intel.com): remove this function once we use ShaderMap in gl::Caps.
+GLuint GetMaximumShaderUniformBlocksPerShader(gl::ShaderType shaderType, const gl::Caps &caps)
+{
+    switch (shaderType)  // uniform block limit of the given stage
+    {
+        case gl::ShaderType::Vertex:
+            return caps.maxVertexUniformBlocks;
+        case gl::ShaderType::Fragment:
+            return caps.maxFragmentUniformBlocks;
+        case gl::ShaderType::Compute:
+            return caps.maxComputeUniformBlocks;
+        case gl::ShaderType::Geometry:
+            return caps.maxGeometryUniformBlocks;
+        default:
+            UNREACHABLE();
+            return 0u;
+    }
+}
+
 }  // anonymous namespace
 
 // D3DUniform Implementation
@@ -604,9 +682,7 @@
       mComputeExecutable(nullptr),
       mUsesPointSize(false),
       mUsesFlatInterpolation(false),
-      mUsedVertexSamplerRange(0),
-      mUsedPixelSamplerRange(0),
-      mUsedComputeSamplerRange(0),
+      mUsedShaderSamplerRanges({}),  // one used-sampler range per shader stage
       mDirtySamplerMapping(true),
       mUsedComputeImageRange(0),
       mUsedComputeReadonlyImageRange(0),
@@ -655,31 +731,14 @@
 {
     GLint logicalTextureUnit = -1;
 
-    switch (type)
+    ASSERT(type != gl::ShaderType::InvalidEnum);
+
+    ASSERT(samplerIndex < GetMaximumSamplersPerShader(type, caps));
+    // Inactive or out-of-range sampler slots leave logicalTextureUnit at -1.
+    const auto &samplers = mShaderSamplers[type];
+    if (samplerIndex < samplers.size() && samplers[samplerIndex].active)
     {
-        case gl::ShaderType::Fragment:
-            ASSERT(samplerIndex < caps.maxTextureImageUnits);
-            if (samplerIndex < mSamplersPS.size() && mSamplersPS[samplerIndex].active)
-            {
-                logicalTextureUnit = mSamplersPS[samplerIndex].logicalTextureUnit;
-            }
-            break;
-        case gl::ShaderType::Vertex:
-            ASSERT(samplerIndex < caps.maxVertexTextureImageUnits);
-            if (samplerIndex < mSamplersVS.size() && mSamplersVS[samplerIndex].active)
-            {
-                logicalTextureUnit = mSamplersVS[samplerIndex].logicalTextureUnit;
-            }
-            break;
-        case gl::ShaderType::Compute:
-            ASSERT(samplerIndex < caps.maxComputeTextureImageUnits);
-            if (samplerIndex < mSamplersCS.size() && mSamplersCS[samplerIndex].active)
-            {
-                logicalTextureUnit = mSamplersCS[samplerIndex].logicalTextureUnit;
-            }
-            break;
-        default:
-            UNREACHABLE();
+        logicalTextureUnit = samplers[samplerIndex].logicalTextureUnit;
     }
 
     if (logicalTextureUnit >= 0 &&
@@ -696,41 +755,19 @@
 gl::TextureType ProgramD3D::getSamplerTextureType(gl::ShaderType type,
                                                   unsigned int samplerIndex) const
 {
-    switch (type)
-    {
-        case gl::ShaderType::Fragment:
-            ASSERT(samplerIndex < mSamplersPS.size());
-            ASSERT(mSamplersPS[samplerIndex].active);
-            return mSamplersPS[samplerIndex].textureType;
-        case gl::ShaderType::Vertex:
-            ASSERT(samplerIndex < mSamplersVS.size());
-            ASSERT(mSamplersVS[samplerIndex].active);
-            return mSamplersVS[samplerIndex].textureType;
-        case gl::ShaderType::Compute:
-            ASSERT(samplerIndex < mSamplersCS.size());
-            ASSERT(mSamplersCS[samplerIndex].active);
-            return mSamplersCS[samplerIndex].textureType;
-        default:
-            UNREACHABLE();
-            return gl::TextureType::InvalidEnum;
-    }
+    ASSERT(type != gl::ShaderType::InvalidEnum);
 
+    const auto &samplers = mShaderSamplers[type];
+    ASSERT(samplerIndex < samplers.size());
+    ASSERT(samplers[samplerIndex].active);
+    // Callers may only query sampler slots that exist and are active for this stage.
+    return samplers[samplerIndex].textureType;
 }
 
 GLuint ProgramD3D::getUsedSamplerRange(gl::ShaderType type) const
 {
-    switch (type)
-    {
-        case gl::ShaderType::Fragment:
-            return mUsedPixelSamplerRange;
-        case gl::ShaderType::Vertex:
-            return mUsedVertexSamplerRange;
-        case gl::ShaderType::Compute:
-            return mUsedComputeSamplerRange;
-        default:
-            UNREACHABLE();
-            return 0u;
-    }
+    ASSERT(type != gl::ShaderType::InvalidEnum);
+    return mUsedShaderSamplerRanges[type];
 }
 
 ProgramD3D::SamplerMapping ProgramD3D::updateSamplerMapping()
@@ -750,50 +787,24 @@
 
         int count = d3dUniform->getArraySizeProduct();
 
-        if (d3dUniform->isReferencedByShader(gl::ShaderType::Fragment))
+        for (gl::ShaderType shaderType : gl::AllShaderTypes())
         {
-            unsigned int firstIndex = d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Fragment];
-
-            for (int i = 0; i < count; i++)
+            if (!d3dUniform->isReferencedByShader(shaderType))  // unused by this stage
             {
-                unsigned int samplerIndex = firstIndex + i;
-
-                if (samplerIndex < mSamplersPS.size())
-                {
-                    ASSERT(mSamplersPS[samplerIndex].active);
-                    mSamplersPS[samplerIndex].logicalTextureUnit = d3dUniform->mSamplerData[i];
-                }
+                continue;
             }
-        }
 
-        if (d3dUniform->isReferencedByShader(gl::ShaderType::Vertex))
-        {
-            unsigned int firstIndex = d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Vertex];
+            unsigned int firstIndex = d3dUniform->mShaderRegisterIndexes[shaderType];
 
+            std::vector<Sampler> &samplers = mShaderSamplers[shaderType];
             for (int i = 0; i < count; i++)
             {
                 unsigned int samplerIndex = firstIndex + i;
 
-                if (samplerIndex < mSamplersVS.size())
+                if (samplerIndex < samplers.size())  // slots past the table are ignored
                 {
-                    ASSERT(mSamplersVS[samplerIndex].active);
-                    mSamplersVS[samplerIndex].logicalTextureUnit = d3dUniform->mSamplerData[i];
-                }
-            }
-        }
-
-        if (d3dUniform->isReferencedByShader(gl::ShaderType::Compute))
-        {
-            unsigned int firstIndex = d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Compute];
-
-            for (int i = 0; i < count; i++)
-            {
-                unsigned int samplerIndex = firstIndex + i;
-
-                if (samplerIndex < mSamplersCS.size())
-                {
-                    ASSERT(mSamplersCS[samplerIndex].active);
-                    mSamplersCS[samplerIndex].logicalTextureUnit = d3dUniform->mSamplerData[i];
+                    ASSERT(samplers[samplerIndex].active);
+                    samplers[samplerIndex].logicalTextureUnit = d3dUniform->mSamplerData[i];
                 }
             }
         }
@@ -879,33 +890,19 @@
         stream->readInt(&index);
     }
 
-    const unsigned int psSamplerCount = stream->readInt<unsigned int>();
-    for (unsigned int i = 0; i < psSamplerCount; ++i)
+    for (gl::ShaderType shaderType : gl::AllShaderTypes())  // same stage order as save()
     {
-        Sampler sampler;
-        stream->readBool(&sampler.active);
-        stream->readInt(&sampler.logicalTextureUnit);
-        stream->readEnum(&sampler.textureType);
-        mSamplersPS.push_back(sampler);
-    }
-    const unsigned int vsSamplerCount = stream->readInt<unsigned int>();
-    for (unsigned int i = 0; i < vsSamplerCount; ++i)
-    {
-        Sampler sampler;
-        stream->readBool(&sampler.active);
-        stream->readInt(&sampler.logicalTextureUnit);
-        stream->readEnum(&sampler.textureType);
-        mSamplersVS.push_back(sampler);
-    }
+        const unsigned int samplerCount = stream->readInt<unsigned int>();
+        for (unsigned int i = 0; i < samplerCount; ++i)
+        {
+            Sampler sampler;
+            stream->readBool(&sampler.active);
+            stream->readInt(&sampler.logicalTextureUnit);
+            stream->readEnum(&sampler.textureType);
+            mShaderSamplers[shaderType].push_back(sampler);
+        }
 
-    const unsigned int csSamplerCount = stream->readInt<unsigned int>();
-    for (unsigned int i = 0; i < csSamplerCount; ++i)
-    {
-        Sampler sampler;
-        stream->readBool(&sampler.active);
-        stream->readInt(&sampler.logicalTextureUnit);
-        stream->readEnum(&sampler.textureType);
-        mSamplersCS.push_back(sampler);
+        stream->readInt(&mUsedShaderSamplerRanges[shaderType]);
     }
 
     const unsigned int csImageCount = stream->readInt<unsigned int>();
@@ -926,9 +923,6 @@
         mReadonlyImagesCS.push_back(image);
     }
 
-    stream->readInt(&mUsedVertexSamplerRange);
-    stream->readInt(&mUsedPixelSamplerRange);
-    stream->readInt(&mUsedComputeSamplerRange);
     stream->readInt(&mUsedComputeImageRange);
     stream->readInt(&mUsedComputeReadonlyImageRange);
 
@@ -989,12 +983,13 @@
         stream->readInt(&varying->outputSlot);
     }
 
-    stream->readString(&mVertexHLSL);
-    stream->readBytes(reinterpret_cast<unsigned char *>(&mVertexWorkarounds),
-                      sizeof(angle::CompilerWorkaroundsD3D));
-    stream->readString(&mPixelHLSL);
-    stream->readBytes(reinterpret_cast<unsigned char *>(&mPixelWorkarounds),
-                      sizeof(angle::CompilerWorkaroundsD3D));
+    for (gl::ShaderType shaderType : gl::AllShaderTypes())  // same stage order as save()
+    {
+        stream->readString(&mShaderHLSL[shaderType]);
+        stream->readBytes(reinterpret_cast<unsigned char *>(&mShaderWorkarounds[shaderType]),
+                          sizeof(angle::CompilerWorkaroundsD3D));
+    }
+
     stream->readBool(&mUsesFragDepth);
     stream->readBool(&mHasANGLEMultiviewEnabled);
     stream->readBool(&mUsesViewID);
@@ -1156,28 +1151,17 @@
         stream->writeInt(d3dSemantic);
     }
 
-    stream->writeInt(mSamplersPS.size());
-    for (unsigned int i = 0; i < mSamplersPS.size(); ++i)
+    for (gl::ShaderType shaderType : gl::AllShaderTypes())  // same stage order as load()
     {
-        stream->writeInt(mSamplersPS[i].active);
-        stream->writeInt(mSamplersPS[i].logicalTextureUnit);
-        stream->writeEnum(mSamplersPS[i].textureType);
-    }
+        stream->writeInt(mShaderSamplers[shaderType].size());
+        for (const Sampler &sampler : mShaderSamplers[shaderType])
+        {
+            stream->writeInt(sampler.active);
+            stream->writeInt(sampler.logicalTextureUnit);
+            stream->writeEnum(sampler.textureType);
+        }
 
-    stream->writeInt(mSamplersVS.size());
-    for (unsigned int i = 0; i < mSamplersVS.size(); ++i)
-    {
-        stream->writeInt(mSamplersVS[i].active);
-        stream->writeInt(mSamplersVS[i].logicalTextureUnit);
-        stream->writeEnum(mSamplersVS[i].textureType);
-    }
-
-    stream->writeInt(mSamplersCS.size());
-    for (unsigned int i = 0; i < mSamplersCS.size(); ++i)
-    {
-        stream->writeInt(mSamplersCS[i].active);
-        stream->writeInt(mSamplersCS[i].logicalTextureUnit);
-        stream->writeEnum(mSamplersCS[i].textureType);
+        stream->writeInt(mUsedShaderSamplerRanges[shaderType]);
     }
 
     stream->writeInt(mImagesCS.size());
@@ -1194,9 +1178,6 @@
         stream->writeInt(mReadonlyImagesCS[i].logicalImageUnit);
     }
 
-    stream->writeInt(mUsedVertexSamplerRange);
-    stream->writeInt(mUsedPixelSamplerRange);
-    stream->writeInt(mUsedComputeSamplerRange);
     stream->writeInt(mUsedComputeImageRange);
     stream->writeInt(mUsedComputeReadonlyImageRange);
 
@@ -1231,12 +1212,13 @@
         stream->writeInt(varying.outputSlot);
     }
 
-    stream->writeString(mVertexHLSL);
-    stream->writeBytes(reinterpret_cast<unsigned char *>(&mVertexWorkarounds),
-                       sizeof(angle::CompilerWorkaroundsD3D));
-    stream->writeString(mPixelHLSL);
-    stream->writeBytes(reinterpret_cast<unsigned char *>(&mPixelWorkarounds),
-                       sizeof(angle::CompilerWorkaroundsD3D));
+    for (gl::ShaderType shaderType : gl::AllShaderTypes())  // same stage order as load()
+    {
+        stream->writeString(mShaderHLSL[shaderType]);
+        stream->writeBytes(reinterpret_cast<unsigned char *>(&mShaderWorkarounds[shaderType]),
+                           sizeof(angle::CompilerWorkaroundsD3D));
+    }
+
     stream->writeInt(mUsesFragDepth);
     stream->writeInt(mHasANGLEMultiviewEnabled);
     stream->writeInt(mUsesViewID);
@@ -1341,7 +1323,8 @@
     }
 
     std::string finalPixelHLSL = mDynamicHLSL->generatePixelShaderForOutputSignature(
-        mPixelHLSL, mPixelShaderKey, mUsesFragDepth, mPixelShaderOutputLayoutCache);
+        mShaderHLSL[gl::ShaderType::Fragment], mPixelShaderKey, mUsesFragDepth,
+        mPixelShaderOutputLayoutCache);  // mShaderHLSL is filled by generateShaderLinkHLSL()
 
     // Generate new pixel executable
     ShaderExecutableD3D *pixelExecutable = nullptr;
@@ -1351,8 +1334,8 @@
 
     ANGLE_TRY(mRenderer->compileToExecutable(
         *currentInfoLog, finalPixelHLSL, gl::ShaderType::Fragment, mStreamOutVaryings,
-        (mState.getTransformFeedbackBufferMode() == GL_SEPARATE_ATTRIBS), mPixelWorkarounds,
-        &pixelExecutable));
+        (mState.getTransformFeedbackBufferMode() == GL_SEPARATE_ATTRIBS),
+        mShaderWorkarounds[gl::ShaderType::Fragment], &pixelExecutable));
 
     if (pixelExecutable)
     {
@@ -1382,7 +1365,7 @@
 
     // Generate new dynamic layout with attribute conversions
     std::string finalVertexHLSL = mDynamicHLSL->generateVertexShaderForInputLayout(
-        mVertexHLSL, mCachedInputLayout, mState.getAttributes());
+        mShaderHLSL[gl::ShaderType::Vertex], mCachedInputLayout, mState.getAttributes());
 
     // Generate new vertex executable
     ShaderExecutableD3D *vertexExecutable = nullptr;
@@ -1392,8 +1375,8 @@
 
     ANGLE_TRY(mRenderer->compileToExecutable(
         *currentInfoLog, finalVertexHLSL, gl::ShaderType::Vertex, mStreamOutVaryings,
-        (mState.getTransformFeedbackBufferMode() == GL_SEPARATE_ATTRIBS), mVertexWorkarounds,
-        &vertexExecutable));
+        (mState.getTransformFeedbackBufferMode() == GL_SEPARATE_ATTRIBS),
+        mShaderWorkarounds[gl::ShaderType::Vertex], &vertexExecutable));
 
     if (vertexExecutable)
     {
@@ -1684,7 +1667,7 @@
     gl::Shader *computeShader = mState.getAttachedShader(gl::ShaderType::Compute);
     if (computeShader)
     {
-        mSamplersCS.resize(data.getCaps().maxComputeTextureImageUnits);
+        mShaderSamplers[gl::ShaderType::Compute].resize(data.getCaps().maxComputeTextureImageUnits);
         mImagesCS.resize(data.getCaps().maxImageUnits);
         mReadonlyImagesCS.resize(data.getCaps().maxImageUnits);
 
@@ -1705,18 +1688,21 @@
     }
     else
     {
-        gl::Shader *vertexShader   = mState.getAttachedShader(gl::ShaderType::Vertex);
-        gl::Shader *fragmentShader = mState.getAttachedShader(gl::ShaderType::Fragment);
-
         gl::ShaderMap<const ShaderD3D *> shadersD3D = {};
-        shadersD3D[gl::ShaderType::Vertex]          = GetImplAs<ShaderD3D>(vertexShader);
-        shadersD3D[gl::ShaderType::Fragment]        = GetImplAs<ShaderD3D>(fragmentShader);
+        for (gl::ShaderType shaderType : gl::kAllGraphicsShaderTypes)
+        {
+            if (gl::Shader *attachedShader = mState.getAttachedShader(shaderType))
+            {
+                shadersD3D[shaderType] = GetImplAs<ShaderD3D>(attachedShader);
 
-        mSamplersVS.resize(data.getCaps().maxVertexTextureImageUnits);
-        mSamplersPS.resize(data.getCaps().maxTextureImageUnits);
+                mShaderSamplers[shaderType].resize(
+                    GetMaximumSamplersPerShader(shaderType, data.getCaps()));
 
-        shadersD3D[gl::ShaderType::Vertex]->generateWorkarounds(&mVertexWorkarounds);
-        shadersD3D[gl::ShaderType::Fragment]->generateWorkarounds(&mPixelWorkarounds);
+                shadersD3D[shaderType]->generateWorkarounds(&mShaderWorkarounds[shaderType]);
+                // Uniforms of every attached stage (geometry included) start dirty.
+                mShaderUniformsDirty.set(shaderType);
+            }
+        }
 
         if (mRenderer->getNativeLimitations().noFrontFacingSupport)
         {
@@ -1731,7 +1717,7 @@
         BuiltinVaryingsD3D builtins(metadata, resources.varyingPacking);
 
         mDynamicHLSL->generateShaderLinkHLSL(context, mState, metadata, resources.varyingPacking,
-                                             builtins, &mPixelHLSL, &mVertexHLSL);
+                                             builtins, &mShaderHLSL);
 
         mUsesPointSize = shadersD3D[gl::ShaderType::Vertex]->usesPointSize();
         mDynamicHLSL->getPixelShaderOutputKey(data, mState, metadata, &mPixelShaderKey);
@@ -1740,9 +1726,7 @@
         mHasANGLEMultiviewEnabled = metadata.hasANGLEMultiviewEnabled();
 
         // Cache if we use flat shading
-        mUsesFlatInterpolation =
-            (FindFlatInterpolationVarying(fragmentShader->getInputVaryings(context)) ||
-             FindFlatInterpolationVarying(vertexShader->getOutputVaryings(context)));
+        mUsesFlatInterpolation = FindFlatInterpolationVarying(context, mState.getAttachedShaders());
 
         if (mRenderer->getMajorShaderModel() >= 4)
         {
@@ -1753,10 +1737,6 @@
 
         initAttribLocationsToD3DSemantic(context);
 
-        // TODO(jiawei.shao@intel.com): set geometry uniforms dirty if user-defined geometry shader
-        // exists. Tracking bug: http://anglebug.com/1941
-        mShaderUniformsDirty.set(gl::ShaderType::Vertex);
-        mShaderUniformsDirty.set(gl::ShaderType::Fragment);
         defineUniformsAndAssignRegisters(context);
 
         gatherTransformFeedbackVaryings(resources.varyingPacking, builtins[gl::ShaderType::Vertex]);
@@ -1909,7 +1889,7 @@
 
             unsigned int registerIndex = uniformBlock.mShaderRegisterIndexes[shaderType] -
                                          reservedShaderRegisterIndexes[shaderType];
-            ASSERT(registerIndex < caps.maxVertexUniformBlocks);
+            ASSERT(registerIndex < GetMaximumShaderUniformBlocksPerShader(shaderType, caps));
 
             std::vector<int> &shaderUBOcache = mShaderUBOCaches[shaderType];
             if (shaderUBOcache.size() <= registerIndex)
@@ -2485,48 +2465,29 @@
     unsigned int registerOffset = mState.getUniforms()[uniformIndex].flattenedOffsetInParentArrays *
                                   d3dUniform->getArraySizeProduct();
 
-    // TODO(jiawei.shao@intel.com): refactor this code when using ShaderMap on mSamplers(VS|PS|CS)
-    // and mUsed(Vertex|Pixel|Compute)SamplerRange.
-    const gl::Shader *computeShader = mState.getAttachedShader(gl::ShaderType::Compute);
-    if (computeShader)
+    bool hasUniform = false;
+    for (gl::ShaderType shaderType : gl::AllShaderTypes())
     {
-        const ShaderD3D *computeShaderD3D =
-            GetImplAs<ShaderD3D>(mState.getAttachedShader(gl::ShaderType::Compute));
-        ASSERT(computeShaderD3D->hasUniform(baseName));
-        d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Compute] =
-            computeShaderD3D->getUniformRegister(baseName) + registerOffset;
-        ASSERT(d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Compute] != GL_INVALID_INDEX);
-        AssignSamplers(d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Compute],
-                       d3dUniform->typeInfo, d3dUniform->getArraySizeProduct(), mSamplersCS,
-                       &mUsedComputeSamplerRange);
-    }
-    else
-    {
-        const ShaderD3D *vertexShaderD3D =
-            GetImplAs<ShaderD3D>(mState.getAttachedShader(gl::ShaderType::Vertex));
-        const ShaderD3D *fragmentShaderD3D =
-            GetImplAs<ShaderD3D>(mState.getAttachedShader(gl::ShaderType::Fragment));
-        ASSERT(vertexShaderD3D->hasUniform(baseName) || fragmentShaderD3D->hasUniform(baseName));
-        if (vertexShaderD3D->hasUniform(baseName))
+        if (!mState.getAttachedShader(shaderType))
         {
-            d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Vertex] =
-                vertexShaderD3D->getUniformRegister(baseName) + registerOffset;
-            ASSERT(d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Vertex] != GL_INVALID_INDEX);
-            AssignSamplers(d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Vertex],
-                           d3dUniform->typeInfo, d3dUniform->getArraySizeProduct(), mSamplersVS,
-                           &mUsedVertexSamplerRange);
+            continue;
         }
-        if (fragmentShaderD3D->hasUniform(baseName))
+
+        const ShaderD3D *shaderD3D = GetImplAs<ShaderD3D>(mState.getAttachedShader(shaderType));
+        if (shaderD3D->hasUniform(baseName))
         {
-            d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Fragment] =
-                fragmentShaderD3D->getUniformRegister(baseName) + registerOffset;
-            ASSERT(d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Fragment] !=
-                   GL_INVALID_INDEX);
-            AssignSamplers(d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Fragment],
-                           d3dUniform->typeInfo, d3dUniform->getArraySizeProduct(), mSamplersPS,
-                           &mUsedPixelSamplerRange);
+            d3dUniform->mShaderRegisterIndexes[shaderType] =
+                shaderD3D->getUniformRegister(baseName) + registerOffset;
+            ASSERT(d3dUniform->mShaderRegisterIndexes[shaderType] != GL_INVALID_INDEX);
+
+            AssignSamplers(d3dUniform->mShaderRegisterIndexes[shaderType], d3dUniform->typeInfo,
+                           d3dUniform->getArraySizeProduct(), mShaderSamplers[shaderType],
+                           &mUsedShaderSamplerRanges[shaderType]);
+            hasUniform = true;
         }
     }
+    // The uniform must be declared by at least one attached shader stage.
+    ASSERT(hasUniform);
 }
 
 // static
@@ -2653,11 +2614,12 @@
 
     mComputeExecutable.reset(nullptr);
 
-    mVertexHLSL.clear();
-    mVertexWorkarounds = angle::CompilerWorkaroundsD3D();
+    for (gl::ShaderType shaderType : gl::AllShaderTypes())  // reset per-stage HLSL/workarounds
+    {
+        mShaderHLSL[shaderType].clear();
+        mShaderWorkarounds[shaderType] = angle::CompilerWorkaroundsD3D();
+    }
 
-    mPixelHLSL.clear();
-    mPixelWorkarounds         = angle::CompilerWorkaroundsD3D();
     mUsesFragDepth            = false;
     mHasANGLEMultiviewEnabled = false;
     mUsesViewID               = false;
@@ -2671,17 +2633,13 @@
     for (gl::ShaderType shaderType : gl::AllShaderTypes())
     {
         mShaderUniformStorages[shaderType].reset();
+        mShaderSamplers[shaderType].clear();
     }
 
-    mSamplersPS.clear();
-    mSamplersVS.clear();
-    mSamplersCS.clear();
     mImagesCS.clear();
     mReadonlyImagesCS.clear();
 
-    mUsedVertexSamplerRange        = 0;
-    mUsedPixelSamplerRange         = 0;
-    mUsedComputeSamplerRange       = 0;
+    mUsedShaderSamplerRanges.fill(0);  // every stage's sampler range back to 0
     mDirtySamplerMapping           = true;
     mUsedComputeImageRange         = 0;
     mUsedComputeReadonlyImageRange = 0;