SPV: Implement the extension SPV_KHR_shader_ballot
diff --git a/SPIRV/CMakeLists.txt b/SPIRV/CMakeLists.txt
index 48a6c46..ad72276 100755
--- a/SPIRV/CMakeLists.txt
+++ b/SPIRV/CMakeLists.txt
@@ -13,6 +13,7 @@
 set(HEADERS
     spirv.hpp
     GLSL.std.450.h
+    GLSL.ext.KHR.h
     GlslangToSpv.h
     Logger.h
     SpvBuilder.h
diff --git a/SPIRV/GLSL.ext.KHR.h b/SPIRV/GLSL.ext.KHR.h
new file mode 100644
index 0000000..7ce795f
--- /dev/null
+++ b/SPIRV/GLSL.ext.KHR.h
@@ -0,0 +1,51 @@
+/*
+** Copyright (c) 2014-2016 The Khronos Group Inc.
+**
+** Permission is hereby granted, free of charge, to any person obtaining a copy
+** of this software and/or associated documentation files (the "Materials"),
+** to deal in the Materials without restriction, including without limitation
+** the rights to use, copy, modify, merge, publish, distribute, sublicense,
+** and/or sell copies of the Materials, and to permit persons to whom the
+** Materials are furnished to do so, subject to the following conditions:
+**
+** The above copyright notice and this permission notice shall be included in
+** all copies or substantial portions of the Materials.
+**
+** MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS KHRONOS
+** STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS SPECIFICATIONS AND
+** HEADER INFORMATION ARE LOCATED AT https://www.khronos.org/registry/
+**
+** THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+** OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+** FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+** THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+** LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+** FROM,OUT OF OR IN CONNECTION WITH THE MATERIALS OR THE USE OR OTHER DEALINGS
+** IN THE MATERIALS.
+*/
+
+#ifndef GLSLextKHR_H
+#define GLSLextKHR_H
+
+enum BuiltIn;
+enum Op;
+enum Capability;
+
+static const int GLSLextKHRVersion = 100;
+static const int GLSLextKHRRevision = 1;
+
+// SPV_KHR_shader_ballot
+static const char* const E_SPV_KHR_shader_ballot        = "SPV_KHR_shader_ballot";
+
+static const BuiltIn BuiltInSubgroupEqMaskKHR           = static_cast<BuiltIn>(4416);
+static const BuiltIn BuiltInSubgroupGeMaskKHR           = static_cast<BuiltIn>(4417);
+static const BuiltIn BuiltInSubgroupGtMaskKHR           = static_cast<BuiltIn>(4418);
+static const BuiltIn BuiltInSubgroupLeMaskKHR           = static_cast<BuiltIn>(4419);
+static const BuiltIn BuiltInSubgroupLtMaskKHR           = static_cast<BuiltIn>(4420);
+
+static const Op OpSubgroupBallotKHR                     = static_cast<Op>(4421);
+static const Op OpSubgroupFirstInvocationKHR            = static_cast<Op>(4422);
+
+static const Capability CapabilitySubgroupBallotKHR     = static_cast<Capability>(4423);
+
+#endif  // #ifndef GLSLextKHR_H
diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp
index 7286588..8372dbb 100755
--- a/SPIRV/GlslangToSpv.cpp
+++ b/SPIRV/GlslangToSpv.cpp
@@ -42,9 +42,10 @@
 #include "GlslangToSpv.h"
 #include "SpvBuilder.h"
 namespace spv {
-   #include "GLSL.std.450.h"
+    #include "GLSL.std.450.h"
+    #include "GLSL.ext.KHR.h"
 #ifdef AMD_EXTENSIONS
-   #include "GLSL.ext.AMD.h"
+    #include "GLSL.ext.AMD.h"
 #endif
 }
 
@@ -154,7 +155,7 @@
     spv::Id createConversion(glslang::TOperator op, spv::Decoration precision, spv::Decoration noContraction, spv::Id destTypeId, spv::Id operand, glslang::TBasicType typeProxy);
     spv::Id makeSmearedConstant(spv::Id constant, int vectorSize);
     spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy);
-    spv::Id createInvocationsOperation(glslang::TOperator op, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy);
+    spv::Id createInvocationsOperation(glslang::TOperator op, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy);
 #ifdef AMD_EXTENSIONS
     spv::Id CreateInvocationsVectorOperation(spv::Op op, spv::Id typeId, spv::Id operand);
 #endif
@@ -521,16 +522,44 @@
     case glslang::EbvLocalInvocationId:    return spv::BuiltInLocalInvocationId;
     case glslang::EbvLocalInvocationIndex: return spv::BuiltInLocalInvocationIndex;
     case glslang::EbvGlobalInvocationId:   return spv::BuiltInGlobalInvocationId;
+
+    // The subgroup built-ins that follow are all enabled through CapabilitySubgroupBallotKHR.
+    // That capability comes from SPV_KHR_shader_ballot, so the extension is declared with it.
     case glslang::EbvSubGroupSize:
+        builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+        builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+        return spv::BuiltInSubgroupSize;
+
     case glslang::EbvSubGroupInvocation:
+        builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+        builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+        return spv::BuiltInSubgroupLocalInvocationId;
+
     case glslang::EbvSubGroupEqMask:
+        builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+        builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+        return spv::BuiltInSubgroupEqMaskKHR;
+
     case glslang::EbvSubGroupGeMask:
+        builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+        builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+        return spv::BuiltInSubgroupGeMaskKHR;
+
     case glslang::EbvSubGroupGtMask:
+        builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+        builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+        return spv::BuiltInSubgroupGtMaskKHR;
+
     case glslang::EbvSubGroupLeMask:
+        builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+        builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+        return spv::BuiltInSubgroupLeMaskKHR;
+
     case glslang::EbvSubGroupLtMask:
-        // TODO: Add SPIR-V builtin ID.
-        logger->missingFunctionality("shader ballot");
-        return spv::BuiltInMax;
+        builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+        builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+        return spv::BuiltInSubgroupLtMaskKHR;
+
 #ifdef AMD_EXTENSIONS
     case glslang::EbvBaryCoordNoPersp:          return spv::BuiltInBaryCoordNoPerspAMD;
     case glslang::EbvBaryCoordNoPerspCentroid:  return spv::BuiltInBaryCoordNoPerspCentroidAMD;
@@ -3610,10 +3635,6 @@
 
     case glslang::EOpBallot:
     case glslang::EOpReadFirstInvocation:
-        logger->missingFunctionality("shader ballot");
-        libCall = spv::GLSLstd450Bad;
-        break;
-
     case glslang::EOpAnyInvocation:
     case glslang::EOpAllInvocations:
     case glslang::EOpAllInvocationsEqual:
@@ -3625,7 +3646,11 @@
     case glslang::EOpMaxInvocationsNonUniform:
     case glslang::EOpAddInvocationsNonUniform:
 #endif
-        return createInvocationsOperation(op, typeId, operand, typeProxy);
+    {
+        std::vector<spv::Id> operands;
+        operands.push_back(operand);
+        return createInvocationsOperation(op, typeId, operands, typeProxy);
+    }
 
 #ifdef AMD_EXTENSIONS
     case glslang::EOpMbcnt:
@@ -3959,113 +3984,149 @@
 }
 
 // Create group invocation operations.
-spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy)
+spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy)
 {
     bool isUnsigned = typeProxy == glslang::EbtUint || typeProxy == glslang::EbtUint64;
     bool isFloat = typeProxy == glslang::EbtFloat || typeProxy == glslang::EbtDouble;
 
-    builder.addCapability(spv::CapabilityGroups);
+    spv::Op opCode = spv::OpNop;
 
-    std::vector<spv::Id> operands;
-    operands.push_back(builder.makeUintConstant(spv::ScopeSubgroup));
+    std::vector<spv::Id> spvGroupOperands;
+    if (op == glslang::EOpBallot || op == glslang::EOpReadFirstInvocation) {
+        builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+        builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+    } else {
+        builder.addCapability(spv::CapabilityGroups);
+
+        spvGroupOperands.push_back(builder.makeUintConstant(spv::ScopeSubgroup));
 #ifdef AMD_EXTENSIONS
-    if (op == glslang::EOpMinInvocations || op == glslang::EOpMaxInvocations || op == glslang::EOpAddInvocations ||
-        op == glslang::EOpMinInvocationsNonUniform || op == glslang::EOpMaxInvocationsNonUniform || op == glslang::EOpAddInvocationsNonUniform)
-        operands.push_back(spv::GroupOperationReduce);
+        if (op == glslang::EOpMinInvocations || op == glslang::EOpMaxInvocations || op == glslang::EOpAddInvocations ||
+            op == glslang::EOpMinInvocationsNonUniform || op == glslang::EOpMaxInvocationsNonUniform || op == glslang::EOpAddInvocationsNonUniform)
+            spvGroupOperands.push_back(spv::GroupOperationReduce);
 #endif
-    operands.push_back(operand);
+    }
+
+    for (auto opIt = operands.begin(); opIt != operands.end(); ++opIt)
+        spvGroupOperands.push_back(*opIt);
 
     switch (op) {
     case glslang::EOpAnyInvocation:
+        opCode = spv::OpGroupAny;
+        break;
     case glslang::EOpAllInvocations:
-        return builder.createOp(op == glslang::EOpAnyInvocation ? spv::OpGroupAny : spv::OpGroupAll, typeId, operands);
-
+        opCode = spv::OpGroupAll;
+        break;
     case glslang::EOpAllInvocationsEqual:
     {
-        spv::Id groupAll = builder.createOp(spv::OpGroupAll, typeId, operands);
-        spv::Id groupAny = builder.createOp(spv::OpGroupAny, typeId, operands);
+        spv::Id groupAll = builder.createOp(spv::OpGroupAll, typeId, spvGroupOperands);
+        spv::Id groupAny = builder.createOp(spv::OpGroupAny, typeId, spvGroupOperands);
 
         return builder.createBinOp(spv::OpLogicalOr, typeId, groupAll,
                                    builder.createUnaryOp(spv::OpLogicalNot, typeId, groupAny));
     }
+
+    case glslang::EOpReadInvocation:
+        opCode = spv::OpGroupBroadcast;
+        break;
+    case glslang::EOpReadFirstInvocation:
+        opCode = spv::OpSubgroupFirstInvocationKHR;
+        break;
+    case glslang::EOpBallot:
+    {
+        // NOTE: According to the spec, the result type of "OpSubgroupBallotKHR" must be a 4 component vector of 32
+        // bit integer types. The GLSL built-in function "ballotARB()" assumes the maximum number of invocations in
+        // a subgroup is 64. Thus, we have to convert uvec4.xy to uint64_t as follows:
+        //
+        //     result = Bitcast(SubgroupBallotKHR(Predicate).xy)
+        //
+        spv::Id uintType  = builder.makeUintType(32);
+        spv::Id uvec4Type = builder.makeVectorType(uintType, 4);
+        spv::Id result = builder.createOp(spv::OpSubgroupBallotKHR, uvec4Type, spvGroupOperands);
+
+        std::vector<spv::Id> components;
+        components.push_back(builder.createCompositeExtract(result, uintType, 0));
+        components.push_back(builder.createCompositeExtract(result, uintType, 1));
+
+        spv::Id uvec2Type = builder.makeVectorType(uintType, 2);
+        return builder.createUnaryOp(spv::OpBitcast, typeId,
+                                     builder.createCompositeConstruct(uvec2Type, components));
+    }
+
 #ifdef AMD_EXTENSIONS
     case glslang::EOpMinInvocations:
     case glslang::EOpMaxInvocations:
     case glslang::EOpAddInvocations:
-    {
-        spv::Op spvOp = spv::OpNop;
         if (op == glslang::EOpMinInvocations) {
             if (isFloat)
-                spvOp = spv::OpGroupFMin;
+                opCode = spv::OpGroupFMin;
             else {
                 if (isUnsigned)
-                    spvOp = spv::OpGroupUMin;
+                    opCode = spv::OpGroupUMin;
                 else
-                    spvOp = spv::OpGroupSMin;
+                    opCode = spv::OpGroupSMin;
             }
         } else if (op == glslang::EOpMaxInvocations) {
             if (isFloat)
-                spvOp = spv::OpGroupFMax;
+                opCode = spv::OpGroupFMax;
             else {
                 if (isUnsigned)
-                    spvOp = spv::OpGroupUMax;
+                    opCode = spv::OpGroupUMax;
                 else
-                    spvOp = spv::OpGroupSMax;
+                    opCode = spv::OpGroupSMax;
             }
         } else {
             if (isFloat)
-                spvOp = spv::OpGroupFAdd;
+                opCode = spv::OpGroupFAdd;
             else
-                spvOp = spv::OpGroupIAdd;
+                opCode = spv::OpGroupIAdd;
         }
 
         if (builder.isVectorType(typeId))
-            return CreateInvocationsVectorOperation(spvOp, typeId, operand);
-        else
-            return builder.createOp(spvOp, typeId, operands);
-    }
+            return CreateInvocationsVectorOperation(opCode, typeId, operands[0]);
+
+        break;
     case glslang::EOpMinInvocationsNonUniform:
     case glslang::EOpMaxInvocationsNonUniform:
     case glslang::EOpAddInvocationsNonUniform:
-    {
-        spv::Op spvOp = spv::OpNop;
         if (op == glslang::EOpMinInvocationsNonUniform) {
             if (isFloat)
-                spvOp = spv::OpGroupFMinNonUniformAMD;
+                opCode = spv::OpGroupFMinNonUniformAMD;
             else {
                 if (isUnsigned)
-                    spvOp = spv::OpGroupUMinNonUniformAMD;
+                    opCode = spv::OpGroupUMinNonUniformAMD;
                 else
-                    spvOp = spv::OpGroupSMinNonUniformAMD;
+                    opCode = spv::OpGroupSMinNonUniformAMD;
             }
         }
         else if (op == glslang::EOpMaxInvocationsNonUniform) {
             if (isFloat)
-                spvOp = spv::OpGroupFMaxNonUniformAMD;
+                opCode = spv::OpGroupFMaxNonUniformAMD;
             else {
                 if (isUnsigned)
-                    spvOp = spv::OpGroupUMaxNonUniformAMD;
+                    opCode = spv::OpGroupUMaxNonUniformAMD;
                 else
-                    spvOp = spv::OpGroupSMaxNonUniformAMD;
+                    opCode = spv::OpGroupSMaxNonUniformAMD;
             }
         }
         else {
             if (isFloat)
-                spvOp = spv::OpGroupFAddNonUniformAMD;
+                opCode = spv::OpGroupFAddNonUniformAMD;
             else
-                spvOp = spv::OpGroupIAddNonUniformAMD;
+                opCode = spv::OpGroupIAddNonUniformAMD;
         }
 
         if (builder.isVectorType(typeId))
-            return CreateInvocationsVectorOperation(spvOp, typeId, operand);
-        else
-            return builder.createOp(spvOp, typeId, operands);
-    }
+            return CreateInvocationsVectorOperation(opCode, typeId, operands[0]);
+
+        break;
 #endif
     default:
         logger->missingFunctionality("invocation operation");
         return spv::NoResult;
     }
+
+    assert(opCode != spv::OpNop);
+    return builder.createOp(opCode, typeId, spvGroupOperands);
 }
 
 #ifdef AMD_EXTENSIONS
@@ -4256,9 +4317,7 @@
         break;
 
     case glslang::EOpReadInvocation:
-        logger->missingFunctionality("shader ballot");
-        libCall = spv::GLSLstd450Bad;
-        break;
+        return createInvocationsOperation(op, typeId, operands, typeProxy);
 
 #ifdef AMD_EXTENSIONS
     case glslang::EOpSwizzleInvocations:
@@ -4825,7 +4884,7 @@
     if (extBuiltinMap.find(name) != extBuiltinMap.end())
         return extBuiltinMap[name];
     else {
-        builder.addExtensions(name);
+        builder.addExtension(name);
         spv::Id extBuiltins = builder.import(name);
         extBuiltinMap[name] = extBuiltins;
         return extBuiltins;
diff --git a/SPIRV/SpvBuilder.cpp b/SPIRV/SpvBuilder.cpp
index a881d1b..7aaa51f 100644
--- a/SPIRV/SpvBuilder.cpp
+++ b/SPIRV/SpvBuilder.cpp
@@ -2318,9 +2318,9 @@
         capInst.dump(out);
     }
 
-    for (int e = 0; e < (int)extensions.size(); ++e) {
+    for (auto it = extensions.cbegin(); it != extensions.cend(); ++it) {
         Instruction extInst(0, 0, OpExtension);
-        extInst.addStringOperand(extensions[e]);
+        extInst.addStringOperand(it->c_str());
         extInst.dump(out);
     }
 
diff --git a/SPIRV/SpvBuilder.h b/SPIRV/SpvBuilder.h
index 38dc1fa..6e709ea 100755
--- a/SPIRV/SpvBuilder.h
+++ b/SPIRV/SpvBuilder.h
@@ -71,7 +71,7 @@
         sourceVersion = version;
     }
     void addSourceExtension(const char* ext) { sourceExtensions.push_back(ext); }
-    void addExtensions(const char* ext) { extensions.push_back(ext); }
+    void addExtension(const char* ext) { extensions.insert(ext); }
     Id import(const char*);
     void setMemoryModel(spv::AddressingModel addr, spv::MemoryModel mem)
     {
@@ -552,7 +552,8 @@
 
     SourceLanguage source;
     int sourceVersion;
-    std::vector<const char*> extensions;
+    // Stored by value: a pointer-keyed set would compare addresses, not extension names.
+    std::set<std::string> extensions;
     std::vector<const char*> sourceExtensions;
     AddressingModel addressModel;
     MemoryModel memoryModel;
diff --git a/SPIRV/doc.cpp b/SPIRV/doc.cpp
index a25f7c0..d2161dd 100755
--- a/SPIRV/doc.cpp
+++ b/SPIRV/doc.cpp
@@ -45,14 +45,15 @@
 #include <cstring>
 #include <algorithm>
 
-#ifdef AMD_EXTENSIONS
 namespace spv {
     extern "C" {
         // Include C-based headers that don't have a namespace
+        #include "GLSL.ext.KHR.h"
+#ifdef AMD_EXTENSIONS
         #include "GLSL.ext.AMD.h"
+#endif
     }
 }
-#endif
 
 namespace spv {
 
@@ -312,6 +313,12 @@
     case BuiltInCeiling:
     default: return "Bad";
 
+    case 4416: return "SubgroupEqMaskKHR";
+    case 4417: return "SubgroupGeMaskKHR";
+    case 4418: return "SubgroupGtMaskKHR";
+    case 4419: return "SubgroupLeMaskKHR";
+    case 4420: return "SubgroupLtMaskKHR";
+
 #ifdef AMD_EXTENSIONS
     case 4992: return "BaryCoordNoPerspAMD";
     case 4993: return "BaryCoordNoPerspCentroidAMD";
@@ -799,6 +806,8 @@
 
     case CapabilityCeiling:
     default: return "Bad";
+
+    case 4423: return "SubgroupBallotKHR";
     }
 }
 
@@ -1131,6 +1140,9 @@
     default:
         return "Bad";
 
+    case 4421: return "OpSubgroupBallotKHR";
+    case 4422: return "OpSubgroupFirstInvocationKHR";
+
 #ifdef AMD_EXTENSIONS
     case 5000: return "OpGroupIAddNonUniformAMD";
     case 5001: return "OpGroupFAddNonUniformAMD";
@@ -1146,11 +1158,7 @@
 
 // The set of objects that hold all the instruction/operand
 // parameterization information.
-#ifdef AMD_EXTENSIONS
 InstructionParameters InstructionDesc[OpCodeMask + 1];
-#else
-InstructionParameters InstructionDesc[OpcodeCeiling];
-#endif
 OperandParameters ExecutionModeOperands[ExecutionModeCeiling];
 OperandParameters DecorationOperands[DecorationCeiling];
 
@@ -2742,6 +2750,10 @@
     InstructionDesc[OpEnqueueMarker].operands.push(OperandId, "'Wait Events'");
     InstructionDesc[OpEnqueueMarker].operands.push(OperandId, "'Ret Event'");
 
+    InstructionDesc[OpSubgroupBallotKHR].operands.push(OperandId, "'Predicate'");
+
+    InstructionDesc[OpSubgroupFirstInvocationKHR].operands.push(OperandId, "'Value'");
+
 #ifdef AMD_EXTENSIONS
     InstructionDesc[OpGroupIAddNonUniformAMD].capabilities.push_back(CapabilityGroups);
     InstructionDesc[OpGroupIAddNonUniformAMD].operands.push(OperandScope, "'Execution'");