Use ShaderBitSet for active use bits on uniforms

BUG=angleproject:2169

Change-Id: I192c2e3c453540c8a6d7b0d066218ea3c9fbaab2
Reviewed-on: https://chromium-review.googlesource.com/989411
Commit-Queue: Jamie Madill <jmadill@chromium.org>
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Reviewed-by: Geoff Lang <geofflang@chromium.org>
diff --git a/src/common/bitset_utils.h b/src/common/bitset_utils.h
index 0ce8062..72e40c3 100644
--- a/src/common/bitset_utils.h
+++ b/src/common/bitset_utils.h
@@ -40,10 +40,10 @@
       private:
         friend class BitSetT;
 
-        Reference(BitSetT *parent, std::size_t bit) : mParent(parent), mBit(bit) {}
+        Reference(BitSetT *parent, ParamT bit) : mParent(parent), mBit(bit) {}
 
         BitSetT *mParent;
-        std::size_t mBit;
+        ParamT mBit;
     };
 
     class Iterator final
diff --git a/src/libANGLE/MemoryProgramCache.cpp b/src/libANGLE/MemoryProgramCache.cpp
index 809edca..e9340c61b 100644
--- a/src/libANGLE/MemoryProgramCache.cpp
+++ b/src/libANGLE/MemoryProgramCache.cpp
@@ -66,10 +66,10 @@
     stream->writeInt(var.binding);
     stream->writeInt(var.dataSize);
 
-    stream->writeInt(var.vertexActive);
-    stream->writeInt(var.fragmentActive);
-    stream->writeInt(var.computeActive);
-    stream->writeInt(var.geometryActive);
+    for (ShaderType shaderType : AllShaderTypes())
+    {
+        stream->writeInt(var.isActive(shaderType));
+    }
 
     stream->writeInt(var.memberIndexes.size());
     for (unsigned int memberCounterIndex : var.memberIndexes)
@@ -82,10 +82,11 @@
 {
     var->binding           = stream->readInt<int>();
     var->dataSize          = stream->readInt<unsigned int>();
-    var->vertexActive      = stream->readBool();
-    var->fragmentActive    = stream->readBool();
-    var->computeActive     = stream->readBool();
-    var->geometryActive    = stream->readBool();
+
+    for (ShaderType shaderType : AllShaderTypes())
+    {
+        var->setActive(shaderType, stream->readBool());
+    }
 
     unsigned int numMembers = stream->readInt<unsigned int>();
     for (unsigned int blockMemberIndex = 0; blockMemberIndex < numMembers; blockMemberIndex++)
@@ -105,9 +106,11 @@
     stream->writeInt(var.blockInfo.isRowMajorMatrix);
     stream->writeInt(var.blockInfo.topLevelArrayStride);
     stream->writeInt(var.topLevelArraySize);
-    stream->writeInt(var.vertexActive);
-    stream->writeInt(var.fragmentActive);
-    stream->writeInt(var.computeActive);
+
+    for (ShaderType shaderType : AllShaderTypes())
+    {
+        stream->writeInt(var.isActive(shaderType));
+    }
 }
 
 void LoadBufferVariable(BinaryInputStream *stream, BufferVariable *var)
@@ -121,9 +124,11 @@
     var->blockInfo.isRowMajorMatrix    = stream->readBool();
     var->blockInfo.topLevelArrayStride = stream->readInt<int>();
     var->topLevelArraySize             = stream->readInt<int>();
-    var->vertexActive                  = stream->readBool();
-    var->fragmentActive                = stream->readBool();
-    var->computeActive                 = stream->readBool();
+
+    for (ShaderType shaderType : AllShaderTypes())
+    {
+        var->setActive(shaderType, stream->readBool());
+    }
 }
 
 void WriteInterfaceBlock(BinaryOutputStream *stream, const InterfaceBlock &block)
@@ -426,7 +431,7 @@
 
     static_assert(static_cast<unsigned long>(ShaderType::EnumCount) <= sizeof(unsigned long) * 8,
                   "Too many shader types");
-    state->mLinkedShaderStages = stream.readInt<gl::ShaderStagesMask>();
+    state->mLinkedShaderStages = stream.readInt<gl::ShaderBitSet>();
 
     state->updateTransformFeedbackStrides();
 
diff --git a/src/libANGLE/PackedGLEnums.h b/src/libANGLE/PackedGLEnums.h
index 4061440..3559845 100644
--- a/src/libANGLE/PackedGLEnums.h
+++ b/src/libANGLE/PackedGLEnums.h
@@ -172,6 +172,8 @@
     angle::EnumIterator<ShaderType> end() const { return kAfterShaderTypeMax; }
 };
 
+using ShaderBitSet = angle::PackedEnumBitSet<ShaderType>;
+
 TextureType SamplerTypeToTextureType(GLenum samplerType);
 
 }  // namespace gl
diff --git a/src/libANGLE/Program.h b/src/libANGLE/Program.h
index b7cba58..8f59c09 100644
--- a/src/libANGLE/Program.h
+++ b/src/libANGLE/Program.h
@@ -277,8 +277,6 @@
     std::vector<GLuint> boundImageUnits;
 };
 
-using ShaderStagesMask = angle::PackedEnumBitSet<ShaderType>;
-
 class ProgramState final : angle::NonCopyable
 {
   public:
@@ -352,7 +350,7 @@
     int getNumViews() const { return mNumViews; }
     bool usesMultiview() const { return mNumViews != -1; }
 
-    const ShaderStagesMask &getLinkedShaderStages() const { return mLinkedShaderStages; }
+    const ShaderBitSet &getLinkedShaderStages() const { return mLinkedShaderStages; }
 
   private:
     friend class MemoryProgramCache;
@@ -423,7 +421,7 @@
 
     bool mBinaryRetrieveableHint;
     bool mSeparable;
-    ShaderStagesMask mLinkedShaderStages;
+    ShaderBitSet mLinkedShaderStages;
 
     // ANGLE_multiview.
     int mNumViews;
diff --git a/src/libANGLE/Uniform.cpp b/src/libANGLE/Uniform.cpp
index 731ea4d..31cbc09 100644
--- a/src/libANGLE/Uniform.cpp
+++ b/src/libANGLE/Uniform.cpp
@@ -14,7 +14,6 @@
 {
 
 ActiveVariable::ActiveVariable()
-    : vertexActive(false), fragmentActive(false), computeActive(false), geometryActive(false)
 {
 }
 
@@ -27,50 +26,24 @@
 
 void ActiveVariable::setActive(ShaderType shaderType, bool used)
 {
-    switch (shaderType)
-    {
-        case ShaderType::Vertex:
-            vertexActive = used;
-            break;
+    ASSERT(shaderType != ShaderType::InvalidEnum);
+    mActiveUseBits.set(shaderType, used);
+}
 
-        case ShaderType::Fragment:
-            fragmentActive = used;
-            break;
-
-        case ShaderType::Compute:
-            computeActive = used;
-            break;
-
-        case ShaderType::Geometry:
-            geometryActive = used;
-            break;
-
-        default:
-            UNREACHABLE();
-    }
+bool ActiveVariable::isActive(ShaderType shaderType) const
+{
+    ASSERT(shaderType != ShaderType::InvalidEnum);
+    return mActiveUseBits[shaderType];
 }
 
 void ActiveVariable::unionReferencesWith(const ActiveVariable &other)
 {
-    vertexActive |= other.vertexActive;
-    fragmentActive |= other.fragmentActive;
-    computeActive |= other.computeActive;
-    geometryActive |= other.geometryActive;
+    mActiveUseBits |= other.mActiveUseBits;
 }
 
 ShaderType ActiveVariable::getFirstShaderTypeWhereActive() const
 {
-    if (vertexActive)
-        return ShaderType::Vertex;
-    if (fragmentActive)
-        return ShaderType::Fragment;
-    if (computeActive)
-        return ShaderType::Compute;
-    if (geometryActive)
-        return ShaderType::Geometry;
-
-    UNREACHABLE();
-    return ShaderType::InvalidEnum;
+    return static_cast<ShaderType>(gl::ScanForward(mActiveUseBits.bits()));
 }
 
 LinkedUniform::LinkedUniform()
diff --git a/src/libANGLE/Uniform.h b/src/libANGLE/Uniform.h
index 854df5b..a3ca589 100644
--- a/src/libANGLE/Uniform.h
+++ b/src/libANGLE/Uniform.h
@@ -31,11 +31,10 @@
     ShaderType getFirstShaderTypeWhereActive() const;
     void setActive(ShaderType shaderType, bool used);
     void unionReferencesWith(const ActiveVariable &other);
+    bool isActive(ShaderType shaderType) const;
 
-    bool vertexActive;
-    bool fragmentActive;
-    bool computeActive;
-    bool geometryActive;
+  private:
+    ShaderBitSet mActiveUseBits;
 };
 
 // Helper struct representing a single shader uniform
diff --git a/src/libANGLE/queryutils.cpp b/src/libANGLE/queryutils.cpp
index 845eccf..04792c6 100644
--- a/src/libANGLE/queryutils.cpp
+++ b/src/libANGLE/queryutils.cpp
@@ -727,13 +727,13 @@
             }
             break;
         case GL_REFERENCED_BY_VERTEX_SHADER:
-            params[(*outputPosition)++] = static_cast<GLint>(buffer.vertexActive);
+            params[(*outputPosition)++] = static_cast<GLint>(buffer.isActive(ShaderType::Vertex));
             break;
         case GL_REFERENCED_BY_FRAGMENT_SHADER:
-            params[(*outputPosition)++] = static_cast<GLint>(buffer.fragmentActive);
+            params[(*outputPosition)++] = static_cast<GLint>(buffer.isActive(ShaderType::Fragment));
             break;
         case GL_REFERENCED_BY_COMPUTE_SHADER:
-            params[(*outputPosition)++] = static_cast<GLint>(buffer.computeActive);
+            params[(*outputPosition)++] = static_cast<GLint>(buffer.isActive(ShaderType::Compute));
             break;
         default:
             UNREACHABLE();
@@ -1425,13 +1425,13 @@
             return static_cast<GLint>(uniform.blockInfo.isRowMajorMatrix);
 
         case GL_REFERENCED_BY_VERTEX_SHADER:
-            return uniform.vertexActive;
+            return uniform.isActive(ShaderType::Vertex);
 
         case GL_REFERENCED_BY_FRAGMENT_SHADER:
-            return uniform.fragmentActive;
+            return uniform.isActive(ShaderType::Fragment);
 
         case GL_REFERENCED_BY_COMPUTE_SHADER:
-            return uniform.computeActive;
+            return uniform.isActive(ShaderType::Compute);
 
         case GL_ATOMIC_COUNTER_BUFFER_INDEX:
             return (uniform.isAtomicCounter() ? uniform.bufferIndex : -1);
@@ -1468,13 +1468,13 @@
             return static_cast<GLint>(bufferVariable.blockInfo.isRowMajorMatrix);
 
         case GL_REFERENCED_BY_VERTEX_SHADER:
-            return bufferVariable.vertexActive;
+            return bufferVariable.isActive(ShaderType::Vertex);
 
         case GL_REFERENCED_BY_FRAGMENT_SHADER:
-            return bufferVariable.fragmentActive;
+            return bufferVariable.isActive(ShaderType::Fragment);
 
         case GL_REFERENCED_BY_COMPUTE_SHADER:
-            return bufferVariable.computeActive;
+            return bufferVariable.isActive(ShaderType::Compute);
 
         case GL_TOP_LEVEL_ARRAY_SIZE:
             return bufferVariable.topLevelArraySize;
diff --git a/src/libANGLE/renderer/d3d/ProgramD3D.cpp b/src/libANGLE/renderer/d3d/ProgramD3D.cpp
index 74fecd0..afa61de 100644
--- a/src/libANGLE/renderer/d3d/ProgramD3D.cpp
+++ b/src/libANGLE/renderer/d3d/ProgramD3D.cpp
@@ -622,10 +622,7 @@
       mDirtySamplerMapping(true),
       mUsedComputeImageRange(0),
       mUsedComputeReadonlyImageRange(0),
-      mSerial(issueSerial()),
-      mVertexUniformsDirty(true),
-      mFragmentUniformsDirty(true),
-      mComputeUniformsDirty(true)
+      mSerial(issueSerial())
 {
     mDynamicHLSL = new DynamicHLSL(renderer);
 }
@@ -1148,6 +1145,8 @@
 
     initializeUniformStorage();
 
+    dirtyAllUniforms();
+
     return true;
 }
 
@@ -1697,6 +1696,7 @@
         mImagesCS.resize(data.getCaps().maxImageUnits);
         mReadonlyImagesCS.resize(data.getCaps().maxImageUnits);
 
+        mShaderUniformsDirty.set(gl::ShaderType::Compute);
         defineUniformsAndAssignRegisters(context);
 
         gl::LinkResult result = compileComputeExecutable(context, infoLog);
@@ -1760,6 +1760,10 @@
 
         initAttribLocationsToD3DSemantic(context);
 
+        // TODO(jiawei.shao@intel.com): set geometry uniforms dirty if user-defined geometry shader
+        // exists. Tracking bug: http://anglebug.com/1941
+        mShaderUniformsDirty.set(gl::ShaderType::Vertex);
+        mShaderUniformsDirty.set(gl::ShaderType::Fragment);
         defineUniformsAndAssignRegisters(context);
 
         gatherTransformFeedbackVaryings(resources.varyingPacking, builtins[gl::ShaderType::Vertex]);
@@ -1811,14 +1815,14 @@
 
         D3DUniformBlock d3dUniformBlock;
 
-        if (uniformBlock.vertexActive)
+        if (uniformBlock.isActive(gl::ShaderType::Vertex))
         {
             ASSERT(vertexShaderD3D != nullptr);
             unsigned int baseRegister = vertexShaderD3D->getUniformBlockRegister(uniformBlock.name);
             d3dUniformBlock.vsRegisterIndex = baseRegister + uniformBlockElement;
         }
 
-        if (uniformBlock.fragmentActive)
+        if (uniformBlock.isActive(gl::ShaderType::Fragment))
         {
             ASSERT(fragmentShaderD3D != nullptr);
             unsigned int baseRegister =
@@ -1826,7 +1830,7 @@
             d3dUniformBlock.psRegisterIndex = baseRegister + uniformBlockElement;
         }
 
-        if (uniformBlock.computeActive)
+        if (uniformBlock.isActive(gl::ShaderType::Compute))
         {
             ASSERT(computeShaderD3D != nullptr);
             unsigned int baseRegister =
@@ -1968,16 +1972,12 @@
 
 void ProgramD3D::dirtyAllUniforms()
 {
-    mVertexUniformsDirty   = true;
-    mFragmentUniformsDirty = true;
-    mComputeUniformsDirty  = true;
+    mShaderUniformsDirty = mState.getLinkedShaderStages();
 }
 
 void ProgramD3D::markUniformsClean()
 {
-    mVertexUniformsDirty   = false;
-    mFragmentUniformsDirty = false;
-    mComputeUniformsDirty  = false;
+    mShaderUniformsDirty.reset();
 }
 
 void ProgramD3D::setUniform1fv(GLint location, GLsizei count, const GLfloat *v)
@@ -2442,19 +2442,19 @@
     if (targetUniform->vsData)
     {
         setUniformImpl(locationInfo, count, v, targetUniform->vsData, uniformType);
-        mVertexUniformsDirty = true;
+        mShaderUniformsDirty.set(gl::ShaderType::Vertex);
     }
 
     if (targetUniform->psData)
     {
         setUniformImpl(locationInfo, count, v, targetUniform->psData, uniformType);
-        mFragmentUniformsDirty = true;
+        mShaderUniformsDirty.set(gl::ShaderType::Fragment);
     }
 
     if (targetUniform->csData)
     {
         setUniformImpl(locationInfo, count, v, targetUniform->csData, uniformType);
-        mComputeUniformsDirty = true;
+        mShaderUniformsDirty.set(gl::ShaderType::Compute);
     }
 }
 
@@ -2511,7 +2511,7 @@
         if (setUniformMatrixfvImpl<cols, rows>(location, countIn, transpose, value,
                                                targetUniform->vsData, targetUniformType))
         {
-            mVertexUniformsDirty = true;
+            mShaderUniformsDirty.set(gl::ShaderType::Vertex);
         }
     }
 
@@ -2520,7 +2520,7 @@
         if (setUniformMatrixfvImpl<cols, rows>(location, countIn, transpose, value,
                                                targetUniform->psData, targetUniformType))
         {
-            mFragmentUniformsDirty = true;
+            mShaderUniformsDirty.set(gl::ShaderType::Fragment);
         }
     }
 
@@ -2529,7 +2529,7 @@
         if (setUniformMatrixfvImpl<cols, rows>(location, countIn, transpose, value,
                                                targetUniform->csData, targetUniformType))
         {
-            mComputeUniformsDirty = true;
+            mShaderUniformsDirty.set(gl::ShaderType::Compute);
         }
     }
 }
@@ -2757,7 +2757,7 @@
 
     mGeometryShaderPreamble.clear();
 
-    dirtyAllUniforms();
+    markUniformsClean();
 
     mCachedPixelExecutableIndex.reset();
     mCachedVertexExecutableIndex.reset();
@@ -2951,6 +2951,11 @@
     return mCachedPixelExecutableIndex.valid();
 }
 
+bool ProgramD3D::anyShaderUniformsDirty() const
+{
+    return mShaderUniformsDirty.any();
+}
+
 template <typename DestT>
 void ProgramD3D::getUniformInternal(GLint location, DestT *dataOut) const
 {
diff --git a/src/libANGLE/renderer/d3d/ProgramD3D.h b/src/libANGLE/renderer/d3d/ProgramD3D.h
index 9605450..4082a67 100644
--- a/src/libANGLE/renderer/d3d/ProgramD3D.h
+++ b/src/libANGLE/renderer/d3d/ProgramD3D.h
@@ -308,9 +308,11 @@
     bool hasGeometryExecutableForPrimitiveType(GLenum drawMode);
     bool hasPixelExecutableForCachedOutputLayout();
 
-    bool areVertexUniformsDirty() const { return mVertexUniformsDirty; }
-    bool areFragmentUniformsDirty() const { return mFragmentUniformsDirty; }
-    bool areComputeUniformsDirty() const { return mComputeUniformsDirty; }
+    bool anyShaderUniformsDirty() const;
+    bool areShaderUniformsDirty(gl::ShaderType shaderType) const
+    {
+        return mShaderUniformsDirty[shaderType];
+    }
     const std::vector<D3DUniform *> &getD3DUniforms() const { return mD3DUniforms; }
     void markUniformsClean();
 
@@ -550,9 +552,7 @@
     std::map<std::string, int> mImageBindingMap;
     std::vector<D3DUniformBlock> mD3DUniformBlocks;
 
-    bool mVertexUniformsDirty;
-    bool mFragmentUniformsDirty;
-    bool mComputeUniformsDirty;
+    gl::ShaderBitSet mShaderUniformsDirty;
 
     static unsigned int issueSerial();
     static unsigned int mCurrentSerial;
diff --git a/src/libANGLE/renderer/d3d/d3d11/StateManager11.cpp b/src/libANGLE/renderer/d3d/d3d11/StateManager11.cpp
index f6777d2..5864731 100644
--- a/src/libANGLE/renderer/d3d/d3d11/StateManager11.cpp
+++ b/src/libANGLE/renderer/d3d/d3d11/StateManager11.cpp
@@ -261,13 +261,9 @@
 
 // ShaderConstants11 implementation
 ShaderConstants11::ShaderConstants11()
-    : mVertexDirty(true),
-      mPixelDirty(true),
-      mComputeDirty(true),
-      mNumActiveVSSamplers(0),
-      mNumActivePSSamplers(0),
-      mNumActiveCSSamplers(0)
+    : mNumActiveVSSamplers(0), mNumActivePSSamplers(0), mNumActiveCSSamplers(0)
 {
+    mShaderConstantsDirty.set();
 }
 
 ShaderConstants11::~ShaderConstants11()
@@ -299,9 +295,7 @@
 
 void ShaderConstants11::markDirty()
 {
-    mVertexDirty         = true;
-    mPixelDirty          = true;
-    mComputeDirty        = true;
+    mShaderConstantsDirty.set();
     mNumActiveVSSamplers = 0;
     mNumActivePSSamplers = 0;
     mNumActiveCSSamplers = 0;
@@ -398,15 +392,15 @@
     mCompute.numWorkGroups[0] = numGroupsX;
     mCompute.numWorkGroups[1] = numGroupsY;
     mCompute.numWorkGroups[2] = numGroupsZ;
-    mComputeDirty             = true;
+    mShaderConstantsDirty.set(gl::ShaderType::Compute);
 }
 
 void ShaderConstants11::setMultiviewWriteToViewportIndex(GLfloat index)
 {
     mVertex.multiviewWriteToViewportIndex = index;
-    mVertexDirty                          = true;
     mPixel.multiviewWriteToViewportIndex  = index;
-    mPixelDirty                           = true;
+    mShaderConstantsDirty.set(gl::ShaderType::Vertex);
+    mShaderConstantsDirty.set(gl::ShaderType::Fragment);
 }
 
 void ShaderConstants11::onViewportChange(const gl::Rectangle &glViewport,
@@ -414,8 +408,8 @@
                                          bool is9_3,
                                          bool presentPathFast)
 {
-    mVertexDirty = true;
-    mPixelDirty  = true;
+    mShaderConstantsDirty.set(gl::ShaderType::Vertex);
+    mShaderConstantsDirty.set(gl::ShaderType::Fragment);
 
     // On Feature Level 9_*, we must emulate large and/or negative viewports in the shaders
     // using viewAdjust (like the D3D9 renderer).
@@ -513,27 +507,30 @@
     switch (shaderType)
     {
         case gl::ShaderType::Vertex:
-            dirty                   = mVertexDirty || (mNumActiveVSSamplers < numSamplers);
+            dirty = mShaderConstantsDirty[gl::ShaderType::Vertex] ||
+                    (mNumActiveVSSamplers < numSamplers);
             dataSize                = sizeof(Vertex);
             data                    = reinterpret_cast<const uint8_t *>(&mVertex);
             samplerData             = reinterpret_cast<const uint8_t *>(mSamplerMetadataVS.data());
-            mVertexDirty            = false;
+            mShaderConstantsDirty.set(gl::ShaderType::Vertex, false);
             mNumActiveVSSamplers    = numSamplers;
             break;
         case gl::ShaderType::Fragment:
-            dirty                   = mPixelDirty || (mNumActivePSSamplers < numSamplers);
+            dirty = mShaderConstantsDirty[gl::ShaderType::Fragment] ||
+                    (mNumActivePSSamplers < numSamplers);
             dataSize                = sizeof(Pixel);
             data                    = reinterpret_cast<const uint8_t *>(&mPixel);
             samplerData             = reinterpret_cast<const uint8_t *>(mSamplerMetadataPS.data());
-            mPixelDirty             = false;
+            mShaderConstantsDirty.set(gl::ShaderType::Fragment, false);
             mNumActivePSSamplers    = numSamplers;
             break;
         case gl::ShaderType::Compute:
-            dirty                   = mComputeDirty || (mNumActiveCSSamplers < numSamplers);
+            dirty = mShaderConstantsDirty[gl::ShaderType::Compute] ||
+                    (mNumActiveCSSamplers < numSamplers);
             dataSize                = sizeof(Compute);
             data                    = reinterpret_cast<const uint8_t *>(&mCompute);
             samplerData             = reinterpret_cast<const uint8_t *>(mSamplerMetadataCS.data());
-            mComputeDirty           = false;
+            mShaderConstantsDirty.set(gl::ShaderType::Compute, false);
             mNumActiveCSSamplers    = numSamplers;
             break;
         default:
@@ -2008,7 +2005,7 @@
     }
 
     // TODO(jmadill): Use dirty bits.
-    if (programD3D->areVertexUniformsDirty() || programD3D->areFragmentUniformsDirty())
+    if (programD3D->anyShaderUniformsDirty())
     {
         mInternalDirtyBits.set(DIRTY_BIT_PROGRAM_UNIFORMS);
     }
@@ -3091,12 +3088,14 @@
     const d3d11::Buffer *pixelConstantBuffer = nullptr;
     ANGLE_TRY(fragmentUniformStorage->getConstantBuffer(mRenderer, &pixelConstantBuffer));
 
-    if (vertexUniformStorage->size() > 0 && programD3D->areVertexUniformsDirty())
+    if (vertexUniformStorage->size() > 0 &&
+        programD3D->areShaderUniformsDirty(gl::ShaderType::Vertex))
     {
         UpdateUniformBuffer(deviceContext, vertexUniformStorage, vertexConstantBuffer);
     }
 
-    if (fragmentUniformStorage->size() > 0 && programD3D->areFragmentUniformsDirty())
+    if (fragmentUniformStorage->size() > 0 &&
+        programD3D->areShaderUniformsDirty(gl::ShaderType::Fragment))
     {
         UpdateUniformBuffer(deviceContext, fragmentUniformStorage, pixelConstantBuffer);
     }
@@ -3187,7 +3186,8 @@
 
     ID3D11DeviceContext *deviceContext = mRenderer->getDeviceContext();
 
-    if (computeUniformStorage->size() > 0 && programD3D->areComputeUniformsDirty())
+    if (computeUniformStorage->size() > 0 &&
+        programD3D->areShaderUniformsDirty(gl::ShaderType::Compute))
     {
         UpdateUniformBuffer(deviceContext, computeUniformStorage, constantBuffer);
         programD3D->markUniformsClean();
diff --git a/src/libANGLE/renderer/d3d/d3d11/StateManager11.h b/src/libANGLE/renderer/d3d/d3d11/StateManager11.h
index 2400a3c..a7fbfab 100644
--- a/src/libANGLE/renderer/d3d/d3d11/StateManager11.h
+++ b/src/libANGLE/renderer/d3d/d3d11/StateManager11.h
@@ -127,11 +127,9 @@
     bool updateSamplerMetadata(SamplerMetadata *data, const gl::Texture &texture);
 
     Vertex mVertex;
-    bool mVertexDirty;
     Pixel mPixel;
-    bool mPixelDirty;
     Compute mCompute;
-    bool mComputeDirty;
+    gl::ShaderBitSet mShaderConstantsDirty;
 
     std::vector<SamplerMetadata> mSamplerMetadataVS;
     int mNumActiveVSSamplers;
diff --git a/src/libANGLE/renderer/d3d/d3d9/Renderer9.cpp b/src/libANGLE/renderer/d3d/d3d9/Renderer9.cpp
index 810e448..01341fe 100644
--- a/src/libANGLE/renderer/d3d/d3d9/Renderer9.cpp
+++ b/src/libANGLE/renderer/d3d/d3d9/Renderer9.cpp
@@ -1806,8 +1806,8 @@
 
 gl::Error Renderer9::applyUniforms(ProgramD3D *programD3D)
 {
-    // Skip updates if we're not dirty. Note that D3D9 cannot have compute.
-    if (!programD3D->areVertexUniformsDirty() && !programD3D->areFragmentUniformsDirty())
+    // Skip updates if we're not dirty. Note that D3D9 cannot have compute or geometry.
+    if (!programD3D->anyShaderUniformsDirty())
     {
         return gl::NoError();
     }
diff --git a/src/libANGLE/renderer/vulkan/GlslangWrapper.cpp b/src/libANGLE/renderer/vulkan/GlslangWrapper.cpp
index cec43b0..ff3d986 100644
--- a/src/libANGLE/renderer/vulkan/GlslangWrapper.cpp
+++ b/src/libANGLE/renderer/vulkan/GlslangWrapper.cpp
@@ -138,13 +138,14 @@
 
         std::string setBindingString = "set = 1, binding = " + Str(textureCount);
 
-        ASSERT(samplerUniform.vertexActive || samplerUniform.fragmentActive);
-        if (samplerUniform.vertexActive)
+        ASSERT(samplerUniform.isActive(gl::ShaderType::Vertex) ||
+               samplerUniform.isActive(gl::ShaderType::Fragment));
+        if (samplerUniform.isActive(gl::ShaderType::Vertex))
         {
             InsertLayoutSpecifierString(&vertexSource, samplerUniform.name, setBindingString);
         }
 
-        if (samplerUniform.fragmentActive)
+        if (samplerUniform.isActive(gl::ShaderType::Fragment))
         {
             InsertLayoutSpecifierString(&fragmentSource, samplerUniform.name, setBindingString);
         }