blob: cd397892101ea9dc1528b327f02e3bd0465edcd1 [file] [log] [blame]
// Copyright 2021 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// OutputSPIRV: Generate SPIR-V from the AST.
#include "compiler/translator/OutputSPIRV.h"
#include "angle_gl.h"
#include "common/debug.h"
#include "common/mathutil.h"
#include "common/spirv/spirv_instruction_builder_autogen.h"
#include "compiler/translator/BuildSPIRV.h"
#include "compiler/translator/Compiler.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
#include <cfloat>
// Extended instructions
namespace spv
#include <spirv/unified1/GLSL.std.450.h>
// SPIR-V tools include for disassembly
#include <spirv-tools/libspirv.hpp>
// Enable this for debug logging of pre-transform SPIR-V:
namespace sh
// A struct to hold either SPIR-V ids or literal constants. If id is not valid, a literal is
// assumed.
struct SpirvIdOrLiteral
SpirvIdOrLiteral() = default;
SpirvIdOrLiteral(const spirv::IdRef idIn) : id(idIn) {}
SpirvIdOrLiteral(const spirv::LiteralInteger literalIn) : literal(literalIn) {}
spirv::IdRef id;
spirv::LiteralInteger literal;
// A data structure to facilitate generating array indexing, block field selection, swizzle and
// such. Used in conjunction with NodeData which includes the access chain's baseId and idList.
// - rvalue[literal].field[literal] generates OpCompositeExtract
// - rvalue.x generates OpCompositeExtract
// - rvalue.xyz generates OpVectorShuffle
// - rvalue.xyz[i] generates OpVectorExtractDynamic (xyz[i] itself generates an
// OpVectorExtractDynamic as well)
// - rvalue[i].field[j] generates a temp variable OpStore'ing rvalue and then generating an
// OpAccessChain and OpLoad
// - lvalue[i].field[j].x generates OpAccessChain and OpStore
// - lvalue.xyz generates an OpLoad followed by OpVectorShuffle and OpStore
// - lvalue.xyz[i] generates OpAccessChain and OpStore (xyz[i] itself generates an
// OpVectorExtractDynamic as well)
// storageClass == Max implies an rvalue.
struct AccessChain
// The storage class for lvalues. If Max, it's an rvalue.
spv::StorageClass storageClass = spv::StorageClassMax;
// If the access chain ends in swizzle, the swizzle components are specified here. Swizzles
// select multiple components so need special treatment when used as lvalue.
std::vector<uint32_t> swizzles;
// If a vector component is selected dynamically (i.e. indexed with a non-literal index),
// dynamicComponent will contain the id of the index.
spirv::IdRef dynamicComponent;
// Type of base expression, before swizzle is applied, after swizzle is applied and after
// dynamic component is applied.
spirv::IdRef baseTypeId;
spirv::IdRef preSwizzleTypeId;
spirv::IdRef postSwizzleTypeId;
spirv::IdRef postDynamicComponentTypeId;
// If the OpAccessChain is already generated (done by accessChainCollapse()), this caches the
// id.
spirv::IdRef accessChainId;
// Whether all indices are literal. Avoids looping through indices to determine this
// information.
bool areAllIndicesLiteral = true;
// The number of components in the vector, if vector and swizzle is used. This is cached to
// avoid a type look up when handling swizzles.
uint8_t swizzledVectorComponentCount = 0;
// SPIR-V type specialization due to the base type. Used to correctly select the SPIR-V type
// id when visiting EOpIndex* binary nodes (i.e. reading from or writing to an access chain).
// This always corresponds to the specialization specific to the end result of the access chain,
// not the base or any intermediary types. For example, a struct nested in a column-major
// interface block, with a parent block qualified as row-major would specify row-major here.
SpirvTypeSpec typeSpec;
// As each node is traversed, it produces data. When visiting back the parent, this data is used to
// complete the data of the parent. For example, the children of a function call (i.e. the
// arguments) each produce a SPIR-V id corresponding to the result of their expression. The
// function call node itself in PostVisit uses those ids to generate the function call instruction.
struct NodeData
// An id whose meaning depends on the node. It could be a temporary id holding the result of an
// expression, a reference to a variable etc.
spirv::IdRef baseId;
// List of relevant SPIR-V ids accumulated while traversing the children. Meaning depends on
// the node, for example a list of parameters to be passed to a function, a set of ids used to
// construct an access chain etc.
std::vector<SpirvIdOrLiteral> idList;
// For constructing access chains.
AccessChain accessChain;
// The SPIR-V ids that make up a single function.
struct FunctionIds
{
    // Type of the function, followed by its return type and the types of its parameters.
    spirv::IdRef functionTypeId;
    spirv::IdRef returnTypeId;
    spirv::IdRefList parameterTypeIds;

    // The function's own id.
    spirv::IdRef functionId;
};
// Describes the struct some builtins produce as their result.  The struct always has exactly two
// fields, each a scalar or a vector.
struct BuiltInResultStruct
{
    // Basic type of each field.
    TBasicType lsbType;
    TBasicType msbType;

    // Primary size (component count) of each field.
    uint32_t lsbPrimarySize;
    uint32_t msbPrimarySize;
};
// Hash functor allowing BuiltInResultStruct to be used as a hash map key.  A 4-byte array
// summarizing the key is hashed with angle::ComputeGenericHash.
struct BuiltInResultStructHash
size_t operator()(const BuiltInResultStruct &key) const
// Each basic type must fit in a single byte of the summary.
static_assert(sh::EbtLast < 256, "Basic type doesn't fit in uint8_t");
// Field sizes are expected to be between 1 and 4 components.
ASSERT(key.lsbPrimarySize > 0 && key.lsbPrimarySize <= 4);
ASSERT(key.msbPrimarySize > 0 && key.msbPrimarySize <= 4);
// NOTE(review): the initializer list of |properties| is not visible in this view; presumably it
// holds the four fields of |key| narrowed to uint8_t -- confirm.
const uint8_t properties[4] = {
return angle::ComputeGenericHash(properties, sizeof(properties));
// Two builtin result structs are the same when both of their fields match in basic type and in
// component count.
bool operator==(const BuiltInResultStruct &a, const BuiltInResultStruct &b)
{
    const bool typesMatch = a.lsbType == b.lsbType && a.msbType == b.msbType;
    const bool sizesMatch =
        a.lsbPrimarySize == b.lsbPrimarySize && a.msbPrimarySize == b.msbPrimarySize;
    return typesMatch && sizesMatch;
}
bool IsAccessChainRValue(const AccessChain &accessChain)
return accessChain.storageClass == spv::StorageClassMax;
bool IsAccessChainUnindexedLValue(const NodeData &data)
return !IsAccessChainRValue(data.accessChain) && data.idList.empty() &&
data.accessChain.swizzles.empty() && !data.accessChain.dynamicComponent.valid();
// A traverser that generates SPIR-V as it walks the AST.
// Per-node results are accumulated on the mNodeData stack; symbols and functions are mapped to
// their SPIR-V ids through mSymbolIdMap and mFunctionIdMap, and instructions are emitted through
// mBuilder.
class OutputSPIRVTraverser : public TIntermTraverser
OutputSPIRVTraverser(TCompiler *compiler, ShCompileOptions compileOptions, bool forceHighp);
~OutputSPIRVTraverser() override;
spirv::Blob getSpirv();
void visitSymbol(TIntermSymbol *node) override;
void visitConstantUnion(TIntermConstantUnion *node) override;
bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
bool visitBinary(Visit visit, TIntermBinary *node) override;
bool visitUnary(Visit visit, TIntermUnary *node) override;
bool visitTernary(Visit visit, TIntermTernary *node) override;
bool visitIfElse(Visit visit, TIntermIfElse *node) override;
bool visitSwitch(Visit visit, TIntermSwitch *node) override;
bool visitCase(Visit visit, TIntermCase *node) override;
void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitBlock(Visit visit, TIntermBlock *node) override;
bool visitGlobalQualifierDeclaration(Visit visit,
TIntermGlobalQualifierDeclaration *node) override;
bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
bool visitLoop(Visit visit, TIntermLoop *node) override;
bool visitBranch(Visit visit, TIntermBranch *node) override;
void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
// Returns the id of |symbol| and writes the storage class derived from |type| to
// |*storageClass|.  Symbols not yet in mSymbolIdMap are declared as built-in variables on first
// use.
spirv::IdRef getSymbolIdAndStorageClass(const TSymbol *symbol,
const TType &type,
spv::StorageClass *storageClass);
// Access chain handling.
// Called before pushing indices to access chain to adjust |typeSpec| (which is then used to
// determine the typeId passed to |accessChainPush*|).
void accessChainOnPush(NodeData *data, const TType &parentType, size_t index);
void accessChainPush(NodeData *data, spirv::IdRef index, spirv::IdRef typeId) const;
void accessChainPushLiteral(NodeData *data,
spirv::LiteralInteger index,
spirv::IdRef typeId) const;
void accessChainPushSwizzle(NodeData *data,
const TVector<int> &swizzle,
spirv::IdRef typeId,
uint8_t componentCount) const;
void accessChainPushDynamicComponent(NodeData *data, spirv::IdRef index, spirv::IdRef typeId);
// Emits (at most once) the OpAccessChain of an lvalue chain and returns the pointer id.
spirv::IdRef accessChainCollapse(NodeData *data);
// Load a value through, or store a value to, the access chain of |data|.
spirv::IdRef accessChainLoad(NodeData *data,
const TType &valueType,
spirv::IdRef *resultTypeIdOut);
void accessChainStore(NodeData *data, spirv::IdRef value, const TType &valueType);
// Access chain helpers.
void makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut);
void makeAccessChainLiteralList(NodeData *data, spirv::LiteralIntegerList *literalsOut);
spirv::IdRef getAccessChainTypeId(NodeData *data);
// Node data handling.
void nodeDataInitLValue(NodeData *data,
spirv::IdRef baseId,
spirv::IdRef typeId,
spv::StorageClass storageClass,
const SpirvTypeSpec &typeSpec) const;
void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const;
void declareSpecConst(TIntermDeclaration *decl);
spirv::IdRef createConstant(const TType &type,
TBasicType expectedBasicType,
const TConstantUnion *constUnion,
bool isConstantNullValue);
spirv::IdRef createComplexConstant(const TType &type,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createConstructor(TIntermAggregate *node, spirv::IdRef typeId);
spirv::IdRef createArrayOrStructConstructor(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createConstructorVectorFromScalar(const TType &type,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createConstructorVectorFromMatrix(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createConstructorVectorFromScalarsAndVectors(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createConstructorMatrixFromScalar(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createConstructorMatrixFromVectors(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createConstructorMatrixFromMatrix(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
// Load N values where N is the number of node's children. In some cases, the last M values are
// lvalues which should be skipped.
spirv::IdRefList loadAllParams(TIntermOperator *node, size_t skipCount);
void extractComponents(TIntermAggregate *node,
size_t componentCount,
const spirv::IdRefList &parameters,
spirv::IdRefList *extractedComponentsOut);
void startShortCircuit(TIntermBinary *node);
spirv::IdRef endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId);
spirv::IdRef visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId);
spirv::IdRef createIncrementDecrement(TIntermOperator *node, spirv::IdRef resultTypeId);
spirv::IdRef createCompare(TIntermOperator *node, spirv::IdRef resultTypeId);
spirv::IdRef createAtomicBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
spirv::IdRef createImageTextureBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
spirv::IdRef createInterpolate(TIntermOperator *node, spirv::IdRef resultTypeId);
spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);
// Cast between types. There are two kinds of casts:
// - A constructor can cast between basic types, for example vec4(someInt).
// - Assignments, constructors, function calls etc may copy an array or struct between different
// block storages, invariance etc (which due to their decorations generate different SPIR-V
// types). For example:
// layout(std140) uniform U { invariant Struct s; } u; ... Struct s2 = u.s;
spirv::IdRef castBasicType(spirv::IdRef value,
const TType &valueType,
TBasicType expectedBasicType,
spirv::IdRef *resultTypeIdOut);
spirv::IdRef cast(spirv::IdRef value,
const TType &valueType,
const SpirvTypeSpec &valueTypeSpec,
const SpirvTypeSpec &expectedTypeSpec,
spirv::IdRef *resultTypeIdOut);
// Helper to reduce vector == and != with OpAll and OpAny respectively. If multiple ids are
// given, either OpLogicalAnd or OpLogicalOr is used (if two operands) or a bool vector is
// constructed and OpAll and OpAny used.
spirv::IdRef reduceBoolVector(TOperator op,
const spirv::IdRefList &valueIds,
spirv::IdRef typeId,
const SpirvDecorations &decorations);
// Helper to implement == and !=, supporting vectors, matrices, structs and arrays.
void createCompareImpl(TOperator op,
const TType &operandType,
spirv::IdRef resultTypeId,
spirv::IdRef leftId,
spirv::IdRef rightId,
const SpirvDecorations &operandDecorations,
const SpirvDecorations &resultDecorations,
spirv::LiteralIntegerList *currentAccessChain,
spirv::IdRefList *intermediateResultsOut);
// For some builtins, SPIR-V outputs two values in a struct. This function defines such a
// struct if not already defined.
spirv::IdRef makeBuiltInOutputStructType(TIntermOperator *node, size_t lvalueCount);
// Once the builtin instruction is generated, the two return values are extracted from the
// struct. These are written to the return value (if any) and the out parameters.
void storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator *node,
size_t lvalueCount,
spirv::IdRef structValue,
spirv::IdRef returnValue,
spirv::IdRef returnValueType);
void storeBuiltInStructOutputInParamHelper(NodeData *data,
TIntermTyped *param,
spirv::IdRef structValue,
uint32_t fieldIndex);
TCompiler *mCompiler;
ShCompileOptions mCompileOptions;
SPIRVBuilder mBuilder;
// Traversal state. Nodes generally push() once to this stack on PreVisit. On InVisit and
// PostVisit, they pop() once (data corresponding to the result of the child) and accumulate it
// in back() (data corresponding to the node itself). On PostVisit, code is generated.
std::vector<NodeData> mNodeData;
// A map of TSymbol to its SPIR-V id. This could be a:
// - TVariable, or
// - TInterfaceBlock: because TIntermSymbols referencing a field of an unnamed interface block
// don't reference the TVariable that defines the struct, but the TInterfaceBlock itself.
angle::HashMap<const TSymbol *, spirv::IdRef> mSymbolIdMap;
// A map of TFunction to its various SPIR-V ids.
angle::HashMap<const TFunction *, FunctionIds> mFunctionIdMap;
// A map of internally defined structs used to capture result of some SPIR-V instructions.
// NOTE(review): the name of this member is not visible in this view.
angle::HashMap<BuiltInResultStruct, spirv::IdRef, BuiltInResultStructHash>
// Whether the current symbol being visited is being declared.
bool mIsSymbolBeingDeclared = false;
// Returns the SPIR-V storage class that a variable of |type| is declared in.  Qualifiers that are
// not yet classified fall back to the Private storage class.
spv::StorageClass GetStorageClass(const TType &type)
{
    // Opaque uniforms (samplers and images) have the UniformConstant storage class
    if (type.isSampler() || type.isImage())
    {
        return spv::StorageClassUniformConstant;
    }

    const TQualifier qualifier = type.getQualifier();

    // Input varying and IO blocks have the Input storage class
    if (IsShaderIn(qualifier))
    {
        return spv::StorageClassInput;
    }

    // Output varying and IO blocks have the Output storage class
    if (IsShaderOut(qualifier))
    {
        return spv::StorageClassOutput;
    }

    // Uniform and storage buffers have the Uniform storage class.  Default uniforms are gathered in
    // a uniform block as well.
    if (type.getInterfaceBlock() != nullptr || qualifier == EvqUniform)
    {
        // I/O blocks must have already been classified as input or output above.
        return spv::StorageClassUniform;
    }

    switch (qualifier)
    {
        case EvqShared:
            // Compute shader shared memory has the Workgroup storage class
            return spv::StorageClassWorkgroup;

        case EvqGlobal:
            // Global variables have the Private class.
            return spv::StorageClassPrivate;

        case EvqTemporary:
        case EvqIn:
        case EvqOut:
        case EvqInOut:
            // Function-local variables have the Function class
            return spv::StorageClassFunction;

        // Built-in inputs.
        case EvqVertexID:
        case EvqInstanceID:
        case EvqFragCoord:
        case EvqFrontFacing:
        case EvqPointCoord:
        case EvqHelperInvocation:
        case EvqNumWorkGroups:
        case EvqWorkGroupID:
        case EvqLocalInvocationID:
        case EvqGlobalInvocationID:
        case EvqLocalInvocationIndex:
            return spv::StorageClassInput;

        // Built-in outputs.
        case EvqFragDepth:
            return spv::StorageClassOutput;

        default:
            // TODO: classify the remaining qualifiers.  Until then they fall back to Private.
            break;
    }

    return spv::StorageClassPrivate;
}
// Constructs the traverser on top of the compiler's symbol table.  The three |true| arguments to
// TIntermTraverser presumably enable pre-, in- and post-visits, matching the PreVisit/InVisit/
// PostVisit handling of mNodeData -- confirm against TIntermTraverser.
// NOTE(review): the remaining member initializers and the constructor body are not visible in
// this view.
OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler,
ShCompileOptions compileOptions,
bool forceHighp)
: TIntermTraverser(true, true, true, &compiler->getSymbolTable()),
// Returns the SPIR-V id of |symbol| and writes the storage class derived from |type| into
// |*storageClass|.  A symbol that is not yet in mSymbolIdMap is treated as an implicitly defined
// built-in: a variable is declared for it with the matching GLSL name, decorated with
// DecorationBuiltIn, and cached in mSymbolIdMap for later lookups.
spirv::IdRef OutputSPIRVTraverser::getSymbolIdAndStorageClass(const TSymbol *symbol,
const TType &type,
spv::StorageClass *storageClass)
*storageClass = GetStorageClass(type);
auto iter = mSymbolIdMap.find(symbol);
if (iter != mSymbolIdMap.end())
return iter->second;
// This must be an implicitly defined variable, define it now.
const char *name = nullptr;
spv::BuiltIn builtInDecoration = spv::BuiltInMax;
// Pick the variable name and BuiltIn decoration from the qualifier.
// NOTE(review): each case is expected to end the switch (e.g. with `break;`, not visible in
// this view); otherwise later cases would overwrite |name| and |builtInDecoration|.
switch (type.getQualifier())
// Vertex shader built-ins
case EvqVertexID:
name = "gl_VertexIndex";
builtInDecoration = spv::BuiltInVertexIndex;
case EvqInstanceID:
name = "gl_InstanceIndex";
builtInDecoration = spv::BuiltInInstanceIndex;
// Fragment shader built-ins
case EvqFragCoord:
name = "gl_FragCoord";
builtInDecoration = spv::BuiltInFragCoord;
case EvqFrontFacing:
name = "gl_FrontFacing";
builtInDecoration = spv::BuiltInFrontFacing;
case EvqPointCoord:
name = "gl_PointCoord";
builtInDecoration = spv::BuiltInPointCoord;
case EvqFragDepth:
name = "gl_FragDepth";
builtInDecoration = spv::BuiltInFragDepth;
case EvqHelperInvocation:
name = "gl_HelperInvocation";
builtInDecoration = spv::BuiltInHelperInvocation;
// Compute shader built-ins
case EvqNumWorkGroups:
name = "gl_NumWorkGroups";
builtInDecoration = spv::BuiltInNumWorkgroups;
case EvqWorkGroupID:
name = "gl_WorkGroupID";
builtInDecoration = spv::BuiltInWorkgroupId;
case EvqLocalInvocationID:
name = "gl_LocalInvocationID";
builtInDecoration = spv::BuiltInLocalInvocationId;
case EvqGlobalInvocationID:
name = "gl_GlobalInvocationID";
builtInDecoration = spv::BuiltInGlobalInvocationId;
case EvqLocalInvocationIndex:
name = "gl_LocalInvocationIndex";
builtInDecoration = spv::BuiltInLocalInvocationIndex;
// TODO: more built-ins.
// Qualifiers not handled above leave |name| as nullptr and |builtInDecoration| as
// spv::BuiltInMax.
const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
const spirv::IdRef varId = mBuilder.declareVariable(
typeId, *storageClass, mBuilder.getDecorations(type), nullptr, name);
// NOTE(review): the decoration operand list is not visible in this view; presumably it carries
// |builtInDecoration| -- confirm.
spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationBuiltIn,
mSymbolIdMap.insert({symbol, varId});
return varId;
// Clears |data| and sets it up as an lvalue access chain rooted at |baseId| (of type |typeId|) in
// |storageClass|, with |typeSpec| as the chain's type specialization.  This is handy when an
// access chain has been resolved but must be replaced by a reference to a temporary variable
// holding the result.
void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data,
                                              spirv::IdRef baseId,
                                              spirv::IdRef typeId,
                                              spv::StorageClass storageClass,
                                              const SpirvTypeSpec &typeSpec) const
{
    *data        = {};
    data->baseId = baseId;

    AccessChain &chain     = data->accessChain;
    chain.storageClass     = storageClass;
    chain.typeSpec         = typeSpec;
    chain.baseTypeId       = typeId;
    chain.preSwizzleTypeId = typeId;
}
// Clears |data| and sets it up as an rvalue whose value is |baseId| of type |typeId|.  The storage
// class keeps its reset value (the spv::StorageClassMax sentinel), which is what marks the chain
// as an rvalue.  This is handy when an access chain has been resolved and must be replaced by a
// reference to its result.
void OutputSPIRVTraverser::nodeDataInitRValue(NodeData *data,
                                              spirv::IdRef baseId,
                                              spirv::IdRef typeId) const
{
    *data = {};

    AccessChain &chain     = data->accessChain;
    chain.baseTypeId       = typeId;
    chain.preSwizzleTypeId = typeId;
    data->baseId           = baseId;
}
// Adjusts data->accessChain.typeSpec before an index into a value of |parentType| is pushed onto
// the access chain.  The kind of |parentType| (array, struct/interface block, matrix) determines
// what the index selects; |index| is only used to look up a struct or block field.
// NOTE(review): the bodies of the branches below are only partially visible in this view.
void OutputSPIRVTraverser::accessChainOnPush(NodeData *data, const TType &parentType, size_t index)
AccessChain &accessChain = data->accessChain;
// Adjust |typeSpec| based on the type (which implies what the index does; select an array
// element, a block field etc). Index is only meaningful for selecting block fields.
// Array element selection.
if (parentType.isArray())
(parentType.getStruct() != nullptr || parentType.isInterfaceBlock()),
// Field selection: the field's type is looked up in the interface block's field list, or in the
// struct's when the parent is not an interface block.
else if (parentType.isInterfaceBlock() || parentType.getStruct() != nullptr)
const TFieldListCollection *block = parentType.getInterfaceBlock();
if (!parentType.isInterfaceBlock())
block = parentType.getStruct();
const TType &fieldType = *block->fields()[index]->type();
// Matrix indexing (presumably column selection).
else if (parentType.isMatrix())
// Appends the non-literal index |index| to the access chain of |data|.  |typeId| is the type of
// the value the chain refers to after this index is applied.
void OutputSPIRVTraverser::accessChainPush(NodeData *data,
                                           spirv::IdRef index,
                                           spirv::IdRef typeId) const
{
    // Simply add the index to the chain of indices.  Previously |index| was never recorded, so
    // makeAccessChainIdList() and accessChainCollapse() silently dropped it.
    data->idList.push_back(index);

    // A non-literal index rules out using OpCompositeExtract on rvalues (see accessChainLoad()).
    data->accessChain.areAllIndicesLiteral = false;
    data->accessChain.preSwizzleTypeId     = typeId;
}
// Appends the literal index |index| to the access chain of |data|.  |typeId| is the type of the
// value the chain refers to after this index is applied.
void OutputSPIRVTraverser::accessChainPushLiteral(NodeData *data,
                                                  spirv::LiteralInteger index,
                                                  spirv::IdRef typeId) const
{
    // Add the literal integer in the chain of indices.  Since this is an id list, fake it as an id:
    // the entry's id is left invalid, which marks it as a literal.  Previously the literal was
    // never recorded.
    data->idList.push_back(index);
    data->accessChain.preSwizzleTypeId = typeId;
}
// Applies |swizzle| to the access chain of |data|.  |typeId| is the type of the swizzled result and
// |componentCount| the component count of the vector being swizzled.
void OutputSPIRVTraverser::accessChainPushSwizzle(NodeData *data,
                                                  const TVector<int> &swizzle,
                                                  spirv::IdRef typeId,
                                                  uint8_t componentCount) const
{
    AccessChain &accessChain = data->accessChain;

    // Record the swizzle as multi-component swizzles require special handling.  When loading
    // through the access chain, the swizzle is applied after loading the vector first (see
    // |accessChainLoad()|).  When storing through the access chain, the whole vector is loaded,
    // swizzled components overwritten and the whole vector written back (see |accessChainStore()|).
    if (swizzle.size() == 1)
    {
        // If this swizzle is selecting a single component, fold it into the access chain.  It must
        // not also be recorded in |swizzles|: the load and store paths assert that
        // single-component swizzles never appear there.
        accessChainPushLiteral(data, spirv::LiteralInteger(swizzle[0]), typeId);
    }
    else
    {
        // Otherwise keep them separate.
        accessChain.swizzles.insert(accessChain.swizzles.end(), swizzle.begin(), swizzle.end());
        accessChain.postSwizzleTypeId            = typeId;
        accessChain.swizzledVectorComponentCount = componentCount;
    }
}
// Records the selection of a vector component by the non-literal |index|; |typeId| is the type of
// the selected component.  For an rvalue chain whose indices are all literal, the index is kept
// separate so OpCompositeExtract can still be used up to it.  Otherwise, a preceding
// multi-component swizzle is folded into the index: a constant vector of the swizzle components
// is indexed with |index| via OpVectorExtractDynamic.  The resulting index is then pushed onto
// the access chain.
void OutputSPIRVTraverser::accessChainPushDynamicComponent(NodeData *data,
spirv::IdRef index,
spirv::IdRef typeId)
AccessChain &accessChain = data->accessChain;
// Record the index used to dynamically select a component of a vector.
// NOTE(review): this branch is expected to end the function right after the two assignments
// (the index is kept separate, not folded); the terminating statement is not visible here.
if (IsAccessChainRValue(accessChain) && accessChain.areAllIndicesLiteral)
// If the access chain is an rvalue with all-literal indices, keep this index separate so
// that OpCompositeExtract can be used for the access chain up to this index.
accessChain.dynamicComponent = index;
accessChain.postDynamicComponentTypeId = typeId;
if (!accessChain.swizzles.empty())
// Otherwise if there's a swizzle, fold the swizzle and dynamic component selection into a
// single dynamic component selection.
ASSERT(accessChain.swizzles.size() > 1);
// Create a vector constant from the swizzles.
// NOTE(review): the loop body is not visible in this view; presumably it appends a uint
// constant for each component to |swizzleIds| -- confirm.
spirv::IdRefList swizzleIds;
for (uint32_t component : accessChain.swizzles)
const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
const spirv::IdRef uvecTypeId = mBuilder.getBasicTypeId(EbtUInt, swizzleIds.size());
const spirv::IdRef swizzlesId = mBuilder.getNewId({});
spirv::WriteConstantComposite(mBuilder.getSpirvTypeAndConstantDecls(), uvecTypeId,
swizzlesId, swizzleIds);
// Index that vector constant with the dynamic index. For example, vec.ywxz[i] becomes the
// constant {1, 3, 0, 2} indexed with i, and that index used on vec.
const spirv::IdRef newIndex = mBuilder.getNewId({});
spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId,
newIndex, swizzlesId, index);
index = newIndex;
// Fold it into the access chain.
accessChainPush(data, index, typeId);
// Returns the id through which the lvalue access chain of |data| is loaded from or stored to.
// - Without indices, this is the base id itself.
// - Otherwise an OpAccessChain is written to the current function block.
// The result is cached in accessChainId, so repeated calls emit nothing new.  Swizzles are not
// part of the generated chain; loads and stores handle them separately.
spirv::IdRef OutputSPIRVTraverser::accessChainCollapse(NodeData *data)
{
    AccessChain &chain = data->accessChain;

    // Only lvalues can be collapsed into a pointer.
    ASSERT(chain.storageClass != spv::StorageClassMax);

    if (!chain.accessChainId.valid())
    {
        if (data->idList.empty())
        {
            // Accesses go straight to/from the base.
            chain.accessChainId = data->baseId;
        }
        else
        {
            spirv::IdRefList indices;
            makeAccessChainIdList(data, &indices);

            const spirv::IdRef pointerTypeId =
                mBuilder.getTypePointerId(chain.preSwizzleTypeId, chain.storageClass);

            chain.accessChainId = mBuilder.getNewId({});
            spirv::WriteAccessChain(mBuilder.getSpirvCurrentFunctionBlock(), pointerTypeId,
                                    chain.accessChainId, data->baseId, indices);
        }
    }

    return chain.accessChainId;
}
// Loads the value referenced by the access chain of |data| and returns its id.
// - |valueType| is the AST type of the loaded value; its decorations are applied to the generated
//   ids.
// - If |resultTypeIdOut| is non-null, it receives the type id of the result.
// The loaded value is finally cast from the chain's type specialization to the default one.
// Side effect: for an rvalue indexed with non-literal indices, |data| is rewritten to refer to a
// Function-storage temporary ("indexable") that holds the rvalue.
// NOTE(review): several instruction-writing calls below are only partially visible in this view.
spirv::IdRef OutputSPIRVTraverser::accessChainLoad(NodeData *data,
const TType &valueType,
spirv::IdRef *resultTypeIdOut)
const SpirvDecorations &decorations = mBuilder.getDecorations(valueType);
// Loading through the access chain can generate different instructions based on whether it's an
// rvalue, the indices are literal, there's a swizzle etc.
// - If rvalue:
// * With indices:
// + All literal: OpCompositeExtract which uses literal integers to access the rvalue.
// + Otherwise: Can't use OpAccessChain on an rvalue, so create a temporary variable, OpStore
// the rvalue into it, then use OpAccessChain and OpLoad to load from it.
// * Without indices: Take the base id.
// - If lvalue:
// * With indices: Use OpAccessChain and OpLoad
// * Without indices: Use OpLoad
// - With swizzle: Use OpVectorShuffle on the result of the previous step
// - With dynamic component: Use OpVectorExtractDynamic on the result of the previous step
AccessChain &accessChain = data->accessChain;
if (resultTypeIdOut)
*resultTypeIdOut = getAccessChainTypeId(data);
spirv::IdRef loadResult = data->baseId;
if (IsAccessChainRValue(accessChain))
if (data->idList.size() > 0)
if (accessChain.areAllIndicesLiteral)
// Use OpCompositeExtract on an rvalue with all literal indices.
spirv::LiteralIntegerList indexList;
makeAccessChainLiteralList(data, &indexList);
const spirv::IdRef result = mBuilder.getNewId(decorations);
accessChain.preSwizzleTypeId, result, loadResult,
loadResult = result;
// Create a temp variable to hold the rvalue so an access chain can be made on it.
const spirv::IdRef tempVar =
mBuilder.declareVariable(accessChain.baseTypeId, spv::StorageClassFunction,
decorations, nullptr, "indexable");
// Write the rvalue into the temp variable
spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVar, loadResult,
// Make the temp variable the source of the access chain.
data->baseId = tempVar;
data->accessChain.storageClass = spv::StorageClassFunction;
// Load from the temp variable.
const spirv::IdRef accessChainId = accessChainCollapse(data);
loadResult = mBuilder.getNewId(decorations);
accessChain.preSwizzleTypeId, loadResult, accessChainId, nullptr);
// Load from the access chain.
const spirv::IdRef accessChainId = accessChainCollapse(data);
loadResult = mBuilder.getNewId(decorations);
spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
loadResult, accessChainId, nullptr);
if (!accessChain.swizzles.empty())
// Single-component swizzles are already folded into the index list.
ASSERT(accessChain.swizzles.size() > 1);
// Take the loaded value and use OpVectorShuffle to create the swizzle.
spirv::LiteralIntegerList swizzleList;
for (uint32_t component : accessChain.swizzles)
const spirv::IdRef result = mBuilder.getNewId(decorations);
accessChain.postSwizzleTypeId, result, loadResult, loadResult,
loadResult = result;
if (accessChain.dynamicComponent.valid())
// Dynamic component in combination with swizzle is already folded.
// Use OpVectorExtractDynamic to select the component.
const spirv::IdRef result = mBuilder.getNewId(decorations);
accessChain.postDynamicComponentTypeId, result, loadResult,
loadResult = result;
// Upon loading values, cast them to the default SPIR-V variant.
const spirv::IdRef castResult =
cast(loadResult, valueType, accessChain.typeSpec, {}, resultTypeIdOut);
return castResult;
// Stores |value| (whose AST type is |valueType|) through the lvalue access chain of |data|.  The
// value is first cast from the default type specialization to the chain's specialization.
// NOTE(review): some lines of the swizzle path (loop headers/increments and the shuffle call) are
// only partially visible in this view.
void OutputSPIRVTraverser::accessChainStore(NodeData *data,
spirv::IdRef value,
const TType &valueType)
// Storing through the access chain can generate different instructions based on whether
// there's a swizzle.
// - Without swizzle: Use OpAccessChain and OpStore
// - With swizzle: Use OpAccessChain and OpLoad to load the vector, then use OpVectorShuffle to
// replace the components being overwritten. Finally, use OpStore to write the result back.
AccessChain &accessChain = data->accessChain;
// Single-component swizzles are already folded into the indices.
ASSERT(accessChain.swizzles.size() != 1);
// Since store can only happen through lvalues, it's impossible to have a dynamic component as
// that always gets folded into the indices except for rvalues.
const spirv::IdRef accessChainId = accessChainCollapse(data);
// Store through the access chain. The values are always cast to the default SPIR-V type
// variant when loaded from memory and operated on as such. When storing, we need to cast the
// result to the variant specified by the access chain.
value = cast(value, valueType, {}, accessChain.typeSpec, nullptr);
if (!accessChain.swizzles.empty())
// Load the vector before the swizzle.
const spirv::IdRef loadResult = mBuilder.getNewId({});
spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
loadResult, accessChainId, nullptr);
// Overwrite the components being written. This is done by first creating an identity
// swizzle, then replacing the components being written with a swizzle from the value. For
// example, take the following:
// vec4 v;
// v.zx = u;
// The OpVectorShuffle instruction takes two vectors (v and u) and selects components from
// each (in this example, swizzles in [0, 3] select from v and those in [4, 7] select from u).
// This algorithm first creates the identity swizzles {0, 1, 2, 3}, then replaces z and x (the
// 2nd and 0th elements respectively) with swizzles from u (4 + {0, 1}) to get the result
// {4+1, 1, 4+0, 3}.
spirv::LiteralIntegerList swizzleList;
for (uint32_t component = 0; component < accessChain.swizzledVectorComponentCount;
uint32_t srcComponent = 0;
for (uint32_t dstComponent : accessChain.swizzles)
swizzleList[dstComponent] =
spirv::LiteralInteger(accessChain.swizzledVectorComponentCount + srcComponent);
// Use the generated swizzle to select components from the loaded vector and the value to be
// written. Use the final result as the value to be written to the vector.
const spirv::IdRef result = mBuilder.getNewId({});
accessChain.preSwizzleTypeId, result, loadResult, value,
value = result;
spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), accessChainId, value, nullptr);
// Appends the indices of the access chain of |data| to |idsOut| as ids suitable for
// OpAccessChain.  Literal indices are turned into uint constants.
void OutputSPIRVTraverser::makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut)
{
    for (const SpirvIdOrLiteral &entry : data->idList)
    {
        spirv::IdRef indexId = entry.id;
        if (!indexId.valid())
        {
            // The index is a literal integer, so replace it with an OpConstant id.
            indexId = mBuilder.getUintConstant(entry.literal);
        }

        // Previously the computed id was discarded, leaving |idsOut| empty.
        idsOut->push_back(indexId);
    }
}
// Appends the indices of the access chain of |data| to |literalsOut| as literal integers, e.g. for
// OpCompositeExtract.  Only valid when every index is a literal (AccessChain::areAllIndicesLiteral).
void OutputSPIRVTraverser::makeAccessChainLiteralList(NodeData *data,
                                                      spirv::LiteralIntegerList *literalsOut)
{
    for (const SpirvIdOrLiteral &entry : data->idList)
    {
        // An entry with a valid id is a non-literal index and cannot be expressed as a literal.
        ASSERT(!entry.id.valid());
        literalsOut->push_back(entry.literal);
    }
}
// Returns the type id of the value the access chain of |data| finally produces.  A load or store
// may pass through up to three stages:
// - pre-swizzle
// - post-swizzle
// - post-dynamic-component
// The type of the latest stage that was recorded is the type of the whole expression.
spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data)
{
    const AccessChain &chain = data->accessChain;

    if (!chain.postDynamicComponentTypeId.valid())
    {
        // No dynamic component: the swizzled type if there is a swizzle, else the type before it.
        return chain.postSwizzleTypeId.valid() ? chain.postSwizzleTypeId : chain.preSwizzleTypeId;
    }

    return chain.postDynamicComponentTypeId;
}
// Declares the specialization constant introduced by |decl| and records its id in mSymbolIdMap.
// |decl| must consist of a single `symbol = constant` initialization (EOpInitialize) of an
// EvqSpecConst symbol.  The constant initializer is only validated here; as noted below, all spec
// consts are initialized to 0.
// NOTE(review): the trailing arguments of mBuilder.declareSpecConst() are not visible in this view.
void OutputSPIRVTraverser::declareSpecConst(TIntermDeclaration *decl)
const TIntermSequence &sequence = *decl->getSequence();
ASSERT(sequence.size() == 1);
TIntermBinary *assign = sequence.front()->getAsBinaryNode();
ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqSpecConst);
TIntermConstantUnion *initializer = assign->getRight()->getAsConstantUnion();
ASSERT(initializer != nullptr);
const TType &type = symbol->getType();
const TVariable *variable = &symbol->variable();
// All spec consts in ANGLE are initialized to 0.
const spirv::IdRef specConstId =
mBuilder.declareSpecConst(type.getBasicType(), type.getLayoutQualifier().location,
// Remember the id of the variable for future look up.
ASSERT(mSymbolIdMap.count(variable) == 0);
mSymbolIdMap[variable] = specConstId;
// Returns the id of a SPIR-V constant of |type| whose values start at |constUnion|.
//
// - |expectedBasicType|: basic type each scalar component is cast to before its constant is
//   created.  Struct fields are recursed into with the field's own basic type.
// - |constUnion|: first of type.getObjectSize() values; each struct field consumes the next
//   fieldType->getObjectSize() of them.
// - |isConstantNullValue|: if true and |type| is not a basic scalar, a single OpConstantNull is
//   returned instead of per-component constants.
//
// Composite results (structs, or more than one component) are assembled by
// createComplexConstant; a single scalar component is returned directly.
spirv::IdRef OutputSPIRVTraverser::createConstant(const TType &type,
TBasicType expectedBasicType,
const TConstantUnion *constUnion,
bool isConstantNullValue)
const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
spirv::IdRefList componentIds;
// If the object is all zeros, use OpConstantNull to avoid creating a bunch of constants. This
// is not done for basic scalar types as some instructions require an OpConstant and validation
// doesn't accept OpConstantNull (likely a spec bug).
const size_t size = type.getObjectSize();
const TBasicType basicType = type.getBasicType();
const bool isBasicScalar = size == 1 && (basicType == EbtFloat || basicType == EbtInt ||
basicType == EbtUInt || basicType == EbtBool);
const bool useOpConstantNull = isConstantNullValue && !isBasicScalar;
if (useOpConstantNull)
return mBuilder.getNullConstant(typeId);
if (type.getBasicType() == EbtStruct)
// If it's a struct constant, get the constant id for each field.
for (const TField *field : type.getStruct()->fields())
const TType *fieldType = field->type();
createConstant(*fieldType, fieldType->getBasicType(), constUnion, false));
constUnion += fieldType->getObjectSize();
// Otherwise get the constant id for each component.
ASSERT(expectedBasicType == EbtFloat || expectedBasicType == EbtInt ||
expectedBasicType == EbtUInt || expectedBasicType == EbtBool);
for (size_t component = 0; component < size; ++component, ++constUnion)
spirv::IdRef componentId;
// If the constant has a different type than expected, cast it right away.
TConstantUnion castConstant;
// NOTE(review): |valid| is not consumed anywhere visible here (presumably ASSERTed); confirm it
// doesn't trigger unused-variable warnings in release builds.
bool valid = castConstant.cast(expectedBasicType, *constUnion);
// Dispatch on the post-cast type so the matching typed constant is created.
switch (castConstant.getType())
case EbtFloat:
componentId = mBuilder.getFloatConstant(castConstant.getFConst());
case EbtInt:
componentId = mBuilder.getIntConstant(castConstant.getIConst());
case EbtUInt:
componentId = mBuilder.getUintConstant(castConstant.getUConst());
case EbtBool:
componentId = mBuilder.getBoolConstant(castConstant.getBConst());
// If this is a composite, create a composite constant from the components.
if (type.getBasicType() == EbtStruct || componentIds.size() > 1)
return createComplexConstant(type, typeId, componentIds);
// Otherwise return the sole component.
ASSERT(componentIds.size() == 1);
return componentIds[0];
spirv::IdRef OutputSPIRVTraverser::createComplexConstant(const TType &type,
                                                         spirv::IdRef typeId,
                                                         const spirv::IdRefList &parameters)
{
    // Composite constants are built straight from |parameters|, except for non-array matrices.
    // A SPIR-V matrix is made of column vectors, so the flat list of scalar ids is first grouped,
    // column by column (getRows() ids per column), into one constant vector per column.
    const bool isPlainMatrix = type.isMatrix() && !type.isArray();
    if (!isPlainMatrix)
    {
        return mBuilder.getCompositeConstant(typeId, parameters);
    }

    const int rowCount              = type.getRows();
    const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), rowCount);

    spirv::IdRefList columns;
    for (int col = 0; col < type.getCols(); ++col)
    {
        const auto columnBegin = parameters.begin() + col * rowCount;
        const spirv::IdRefList columnComponents(columnBegin, columnBegin + rowCount);
        columns.push_back(mBuilder.getCompositeConstant(columnTypeId, columnComponents));
    }

    return mBuilder.getCompositeConstant(typeId, columns);
}
// Generates the code for a constructor |node| and returns the id of the constructed value.
// |typeId| is the SPIR-V type id of node->getType().  The constructor's arguments must already
// have been visited; their values are loaded as rvalues from the NodeData stack.
spirv::IdRef OutputSPIRVTraverser::createConstructor(TIntermAggregate *node, spirv::IdRef typeId)
{
    const TType &type                = node->getType();
    const TIntermSequence &arguments = *node->getSequence();
    const TType &arg0Type            = arguments[0]->getAsTyped()->getType();

    // In some cases, constructors with constant value are not folded. If the constructor is a null
    // value, use OpConstantNull to avoid creating a bunch of instructions. Otherwise, the constant
    // is created below.
    if (node->isConstantNullValue())
    {
        return mBuilder.getNullConstant(typeId);
    }

    // Take each constructor argument that is visited and evaluate it as rvalue
    spirv::IdRefList parameters = loadAllParams(node, 0);

    // Constructors in GLSL can take various shapes, resulting in different translations to SPIR-V
    // (in each case, if the parameter doesn't match the type being constructed, it must be cast):
    //
    // - float(f): This should translate to just f
    // - vecN(f): This should translate to OpCompositeConstruct %vecN %f %f .. %f
    // - vecN(v1.zy, v2.x): This can technically translate to OpCompositeConstruct with two ids; the
    //   results of v1.zy and v2.x. However, for simplicity it's easier to generate that
    //   instruction with three ids; the results of v1.z, v1.y and v2.x (see below where a matrix is
    //   used as parameter).
    // - vecN(m): This takes N components from m in column-major order (for example, vec4
    //   constructed out of a 4x3 matrix would select components (0,0), (0,1), (0,2) and (1,0)).
    //   This translates to OpCompositeConstruct with the id of the individual components extracted
    //   from m.
    // - matNxM(f): This creates a diagonal matrix. It generates N OpCompositeConstruct
    //   instructions for each column (which are vecM), followed by an OpCompositeConstruct that
    //   constructs the final result.
    // - matNxM(m):
    //   * With m larger than NxM, this extracts a submatrix out of m. It generates
    //     OpCompositeExtracts for N columns of m, followed by an OpVectorShuffle (swizzle) if the
    //     rows of m are more than M. OpCompositeConstruct is used to construct the final result.
    //   * If m is not larger than NxM, an identity matrix is created and superimposed with m.
    //     OpCompositeExtract is used to extract each component of m (that is necessary), and
    //     together with the zero or one constants necessary used to create the columns (with
    //     OpCompositeConstruct). OpCompositeConstruct is used to construct the final result.
    // - matNxM(v1.zy, v2.x, ...): Similarly to constructing a vector, a list of single components
    //   are extracted from the parameters, which are divided up and used to construct each column,
    //   which is finally constructed into the final result.
    //
    // Additionally, array and structs are constructed by OpCompositeConstruct followed by ids of
    // each parameter which must enumerate every individual element / field.

    // In some cases, constructors with constant value are not folded. That is handled here.
    if (node->hasConstantValue())
    {
        return createComplexConstant(node->getType(), typeId, parameters);
    }

    if (type.isArray() || type.getStruct() != nullptr)
    {
        return createArrayOrStructConstructor(node, typeId, parameters);
    }

    // The following are simple casts:
    //
    // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
    // - gvecN(vN) (where the argument is a single vector with the same number of components).
    // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions). Note that
    //   matrices are always float, so there's no actual cast and this would be a no-op.
    const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
                                    arg0Type.isVector() &&
                                    type.getNominalSize() == arg0Type.getNominalSize();
    const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
                                    arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
                                    type.getRows() == arg0Type.getRows();
    if (type.isScalar() || isSingleVectorCast || isSingleMatrixCast)
    {
        return castBasicType(parameters[0], arg0Type, type.getBasicType(), nullptr);
    }

    if (type.isVector())
    {
        if (arguments.size() == 1 && arg0Type.isScalar())
        {
            parameters[0] = castBasicType(parameters[0], arg0Type, type.getBasicType(), nullptr);
            return createConstructorVectorFromScalar(node->getType(), typeId, parameters);
        }
        if (arguments.size() == 1 && arg0Type.isMatrix())
        {
            return createConstructorVectorFromMatrix(node, typeId, parameters);
        }
        return createConstructorVectorFromScalarsAndVectors(node, typeId, parameters);
    }

    ASSERT(type.isMatrix());

    // Only a *single* scalar argument produces a diagonal matrix.  matNxM(s0, s1, ...) with several
    // scalars lists every component; createConstructorMatrixFromScalar only reads parameters[0],
    // so that case must take the generic path below (which also handles scalar arguments).
    if (arguments.size() == 1 && arg0Type.isScalar())
    {
        parameters[0] = castBasicType(parameters[0], arg0Type, type.getBasicType(), nullptr);
        return createConstructorMatrixFromScalar(node, typeId, parameters);
    }
    if (arg0Type.isMatrix())
    {
        return createConstructorMatrixFromMatrix(node, typeId, parameters);
    }
    return createConstructorMatrixFromVectors(node, typeId, parameters);
}
// Constructs an array or struct value with a single OpCompositeConstruct of |typeId|.  As noted in
// createConstructor, the constituents are the ids of each parameter, which must enumerate every
// individual element / field.  The result id is decorated according to the node's type.
spirv::IdRef OutputSPIRVTraverser::createArrayOrStructConstructor(
TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters)
const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
return result;
// Builds vecN(f): |parameters| holds exactly one scalar id (already cast to the vector's basic
// type by createConstructor), which is replicated N = type.getNominalSize() times as the
// constituents of an OpCompositeConstruct of |typeId|.
spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalar(
const TType &type,
spirv::IdRef typeId,
const spirv::IdRefList &parameters)
// vecN(f) translates to OpCompositeConstruct %vecN %f ... %f
ASSERT(parameters.size() == 1);
spirv::IdRefList replicatedParameter(type.getNominalSize(), parameters[0]);
const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(type));
spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
return result;
// Builds vecN(m): the first N components of matrix m (column-major) become the vector.
// |parameters| holds exactly one id, the loaded value of m.  Components are extracted without a
// cast (see extractComponents), so when the constructed vector's basic type differs from m's, the
// vector is first built with m's basic type and cast as a whole at the end.
spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMatrix(
    TIntermAggregate *node,
    spirv::IdRef typeId,
    const spirv::IdRefList &parameters)
{
    // vecN(m) translates to OpCompositeConstruct %vecN %m[0][0] %m[0][1] ...
    spirv::IdRefList extractedComponents;
    extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);

    // Construct the vector with the basic type of the argument, and cast it at end if needed.
    ASSERT(parameters.size() == 1);
    const TType &arg0Type              = node->getChildNode(0)->getAsTyped()->getType();
    const TBasicType expectedBasicType = node->getType().getBasicType();
    const bool needsCast               = arg0Type.getBasicType() != expectedBasicType;

    // A vecN with the argument's basic type.  Setting only the primary size would keep m's other
    // dimension, leaving this a matrix type; the OpCompositeConstruct and the cast below would
    // then be emitted with a matrix type.  Collapse it to a vector explicitly.
    TType arg0TypeAsVector(arg0Type);
    arg0TypeAsVector.setPrimarySize(static_cast<unsigned char>(node->getType().getNominalSize()));
    arg0TypeAsVector.setSecondarySize(1);

    spirv::IdRef argumentTypeId = typeId;
    if (needsCast)
    {
        argumentTypeId = mBuilder.getTypeData(arg0TypeAsVector, {}).id;
    }

    spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
    spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), argumentTypeId, result,
                                   extractedComponents);

    if (needsCast)
    {
        result = castBasicType(result, arg0TypeAsVector, expectedBasicType, nullptr);
    }

    return result;
}
// Builds vecN(...) from a mix of scalar/vector (and possibly matrix) arguments: the first N
// components across all arguments are extracted one by one (cast to the vector's basic type by
// extractComponents) and used as the constituents of one OpCompositeConstruct of |typeId|.
spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalarsAndVectors(
TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters)
// vecN(v1.zy, v2.x) translates to OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
spirv::IdRefList extractedComponents;
extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
return result;
// Builds the diagonal matrix matNxM(f).  |parameters[0]| is the scalar, already cast to the
// matrix's basic type (createConstructor casts it before calling).  Each column is a vecM of
// zeros with f placed on the diagonal; the columns are then combined into the matrix.
spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromScalar(
TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters)
// matNxM(f) translates to
//
// %c0 = OpCompositeConstruct %vecM %f %zero %zero ..
// %c1 = OpCompositeConstruct %vecM %zero %f %zero ..
// %c2 = OpCompositeConstruct %vecM %zero %zero %f ..
// ...
// %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
const TType &type = node->getType();
const spirv::IdRef scalarId = parameters[0];
spirv::IdRef zeroId;
SpirvDecorations decorations = mBuilder.getDecorations(type);
// Zero constant of the matrix's basic type, used for every off-diagonal component.
switch (type.getBasicType())
case EbtFloat:
zeroId = mBuilder.getFloatConstant(0);
case EbtInt:
zeroId = mBuilder.getIntConstant(0);
case EbtUInt:
zeroId = mBuilder.getUintConstant(0);
case EbtBool:
zeroId = mBuilder.getBoolConstant(0);
// Reused for every column: starts as all zeros, and only the diagonal entry is changed per column.
spirv::IdRefList componentIds(type.getRows(), zeroId);
spirv::IdRefList columnIds;
const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
// Place the scalar at the correct index (diagonal of the matrix, i.e. row == col).
if (columnIndex < type.getRows())
componentIds[columnIndex] = scalarId;
// Restore the previous column's diagonal entry to zero, since componentIds is reused.
if (columnIndex > 0 && columnIndex <= type.getRows())
componentIds[columnIndex - 1] = zeroId;
// Create the column.
spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
columnIds.back(), componentIds);
// Create the matrix out of the columns.
const spirv::IdRef result = mBuilder.getNewId(decorations);
spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
return result;
// Builds matNxM(...) from scalar/vector arguments: the first N*M components across all arguments
// are extracted (and cast to the matrix's basic type) by extractComponents, split into N chunks of
// M components in column-major order, each chunk becomes a column vector, and the columns form
// the matrix.
spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromVectors(
TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters)
// matNxM(v1.zy, v2.x, ...) translates to:
//
// %c0 = OpCompositeConstruct %vecM %v1.z %v1.y %v2.x ..
// ...
// %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
const TType &type = node->getType();
SpirvDecorations decorations = mBuilder.getDecorations(type);
spirv::IdRefList extractedComponents;
extractComponents(node, type.getCols() * type.getRows(), parameters, &extractedComponents);
spirv::IdRefList columnIds;
const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
// Chunk up the extracted components by column and construct intermediary vectors.
for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
auto componentsStart = extractedComponents.begin() + columnIndex * type.getRows();
const spirv::IdRefList componentIds(componentsStart, componentsStart + type.getRows());
// Create the column.
spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
columnIds.back(), componentIds);
const spirv::IdRef result = mBuilder.getNewId(decorations);
spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
return result;
// Builds matNxM(m) from a single matrix argument m (|parameters| holds exactly its loaded id).
// If m covers the whole result, its columns are extracted and trimmed; otherwise the result is an
// identity matrix with the overlapping components of m superimposed on it.
spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromMatrix(
    TIntermAggregate *node,
    spirv::IdRef typeId,
    const spirv::IdRefList &parameters)
{
    // matNxM(m) translates to:
    //
    // - If m is SxR where S>=N and R>=M:
    //
    //      %c0 = OpCompositeExtract %vecR %m 0
    //      %c1 = OpCompositeExtract %vecR %m 1
    //      ...
    //      // If R (column size of m) != M, OpVectorShuffle to extract M components out of %ci.
    //      ...
    //      %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
    //
    // - Otherwise, an identity matrix is created and super imposed by m:
    //
    //      %c0 = OpCompositeConstruct %vecM %m[0][0] %m[0][1] %0 %0
    //      %c1 = OpCompositeConstruct %vecM %m[1][0] %m[1][1] %0 %0
    //      %c2 = OpCompositeConstruct %vecM %m[2][0] %m[2][1] %1 %0
    //      %c3 = OpCompositeConstruct %vecM %0 %0 %0 %1
    //      %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 %c3
    const TType &type          = node->getType();
    const TType &parameterType = (*node->getSequence())[0]->getAsTyped()->getType();

    SpirvDecorations decorations = mBuilder.getDecorations(type);

    ASSERT(parameters.size() == 1);

    spirv::IdRefList columnIds;

    const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());

    if (parameterType.getCols() >= type.getCols() && parameterType.getRows() >= type.getRows())
    {
        // If the parameter is a larger matrix than the constructor type, extract the columns
        // directly and potentially swizzle them.
        //
        // A column of m is a vector of m's row count (the secondary size of the matrix type), so
        // move that into the primary size before collapsing the type to a vector.
        SpirvType paramColumnType     = mBuilder.getSpirvType(parameterType, {});
        paramColumnType.primarySize   = paramColumnType.secondarySize;
        paramColumnType.secondarySize = 1;
        const spirv::IdRef paramColumnTypeId =
            mBuilder.getSpirvTypeData(paramColumnType, nullptr).id;

        // OpVectorShuffle needs exactly one component literal per result component, so keep only
        // the first M (= type.getRows()) entries of the identity swizzle.
        const bool needsSwizzle           = parameterType.getRows() > type.getRows();
        spirv::LiteralIntegerList swizzle = {spirv::LiteralInteger(0), spirv::LiteralInteger(1),
                                             spirv::LiteralInteger(2), spirv::LiteralInteger(3)};
        swizzle.resize(type.getRows());

        for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
        {
            // Extract the column.
            const spirv::IdRef parameterColumnId = mBuilder.getNewId(decorations);
            spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), paramColumnTypeId,
                                         parameterColumnId, parameters[0],
                                         {spirv::LiteralInteger(columnIndex)});

            // If the column has too many components, select the appropriate number of components.
            spirv::IdRef constructorColumnId = parameterColumnId;
            if (needsSwizzle)
            {
                constructorColumnId = mBuilder.getNewId(decorations);
                spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
                                          constructorColumnId, parameterColumnId, parameterColumnId,
                                          swizzle);
            }

            columnIds.push_back(constructorColumnId);
        }
    }
    else
    {
        // Otherwise create an identity matrix and fill in the components that can be taken from the
        // given parameter.
        SpirvType paramComponentType     = mBuilder.getSpirvType(parameterType, {});
        paramComponentType.primarySize   = 1;
        paramComponentType.secondarySize = 1;
        const spirv::IdRef paramComponentTypeId =
            mBuilder.getSpirvTypeData(paramComponentType, nullptr).id;

        for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
        {
            spirv::IdRefList componentIds;

            for (int componentIndex = 0; componentIndex < type.getRows(); ++componentIndex)
            {
                // Take the component from the constructor parameter if possible.  m may be smaller
                // in either dimension (e.g. mat4(mat3) has no column 3 in m), so both the column
                // and the row must be in range before extracting.
                spirv::IdRef componentId;
                if (columnIndex < parameterType.getCols() &&
                    componentIndex < parameterType.getRows())
                {
                    componentId = mBuilder.getNewId(decorations);
                    spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
                                                 paramComponentTypeId, componentId, parameters[0],
                                                 {spirv::LiteralInteger(columnIndex),
                                                  spirv::LiteralInteger(componentIndex)});
                }
                else
                {
                    // Outside of m, use the identity matrix: 1 on the diagonal, 0 elsewhere.
                    const bool isOnDiagonal = columnIndex == componentIndex;
                    switch (type.getBasicType())
                    {
                        case EbtFloat:
                            componentId = mBuilder.getFloatConstant(isOnDiagonal ? 1.0f : 0.0f);
                            break;
                        case EbtInt:
                            componentId = mBuilder.getIntConstant(isOnDiagonal ? 1 : 0);
                            break;
                        case EbtUInt:
                            componentId = mBuilder.getUintConstant(isOnDiagonal ? 1 : 0);
                            break;
                        case EbtBool:
                            componentId = mBuilder.getBoolConstant(isOnDiagonal);
                            break;
                        default:
                            UNREACHABLE();
                            break;
                    }
                }

                componentIds.push_back(componentId);
            }

            // Create the column vector.
            columnIds.push_back(mBuilder.getNewId(decorations));
            spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
                                           columnIds.back(), componentIds);
        }
    }

    const spirv::IdRef result = mBuilder.getNewId(decorations);
    spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                                   columnIds);
    return result;
}
// Loads, as rvalues, the values of |node|'s children except the last |skipCount| ones.  The
// children's NodeData are the last node->getChildCount() entries of mNodeData, in child order.
//
// NOTE(review): no append of |paramValue| to |parameters| is visible in this listing; confirm the
// returned list is populated with one id per loaded child, in order.
spirv::IdRefList OutputSPIRVTraverser::loadAllParams(TIntermOperator *node, size_t skipCount)
const size_t parameterCount = node->getChildCount();
spirv::IdRefList parameters;
for (size_t paramIndex = 0; paramIndex + skipCount < parameterCount; ++paramIndex)
// Take each parameter that is visited and evaluate it as rvalue
NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
const spirv::IdRef paramValue = accessChainLoad(
&param, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);
return parameters;
// Splits constructor arguments into individual scalar ids until |componentCount| ids have been
// written to |extractedComponentsOut|.
//
// - Scalar arguments are cast to the constructor's basic type when needed.
// - Vector arguments are cast as a whole, then split into components.
// - Matrix arguments are split in column-major order WITHOUT a cast (see the comment below).
//
// |parameters| holds the loaded value of each argument of |node|, in order.
void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
size_t componentCount,
const spirv::IdRefList &parameters,
spirv::IdRefList *extractedComponentsOut)
// A helper function that takes the list of parameters passed to a constructor (which may have
// more components than necessary) and extracts the first componentCount components.
const TIntermSequence &arguments = *node->getSequence();
const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
const TBasicType expectedBasicType = node->getType().getBasicType();
ASSERT(arguments.size() == parameters.size());
for (size_t argumentIndex = 0;
argumentIndex < arguments.size() && extractedComponentsOut->size() < componentCount;
TIntermNode *argument = arguments[argumentIndex];
const TType &argumentType = argument->getAsTyped()->getType();
const spirv::IdRef parameterId = parameters[argumentIndex];
if (argumentType.isScalar())
// For scalar parameters, there's nothing to do other than a potential cast.
const spirv::IdRef castParameterId =
? parameterId
: castBasicType(parameterId, argumentType, expectedBasicType, nullptr);
if (argumentType.isVector())
SpirvType componentType = mBuilder.getSpirvType(argumentType, {});
// The vector is cast before extraction, so its components already have the expected type.
componentType.type = expectedBasicType;
componentType.primarySize = 1;
const spirv::IdRef componentTypeId =
mBuilder.getSpirvTypeData(componentType, nullptr).id;
// Cast the whole vector parameter in one go.
const spirv::IdRef castParameterId =
? parameterId
: castBasicType(parameterId, argumentType, expectedBasicType, nullptr);
// For vector parameters, take components out of the vector one by one.
for (int componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
extractedComponentsOut->size() < componentCount;
const spirv::IdRef componentId = mBuilder.getNewId(decorations);
componentTypeId, componentId, castParameterId,
// Remaining case: a matrix argument.  Its component type keeps the argument's basic type.
SpirvType componentType = mBuilder.getSpirvType(argumentType, {});
componentType.primarySize = 1;
componentType.secondarySize = 1;
const spirv::IdRef componentTypeId = mBuilder.getSpirvTypeData(componentType, nullptr).id;
// For matrix parameters, take components out of the matrix one by one in column-major
// order. No cast is done here; it would only be required for vector constructors with
// matrix parameters, in which case the resulting vector is cast in the end.
for (int columnIndex = 0; columnIndex < argumentType.getCols() &&
extractedComponentsOut->size() < componentCount;
for (int componentIndex = 0; componentIndex < argumentType.getRows() &&
extractedComponentsOut->size() < componentCount;
const spirv::IdRef componentId = mBuilder.getNewId(decorations);
mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId, componentId,
{spirv::LiteralInteger(columnIndex), spirv::LiteralInteger(componentIndex)});
// First half of short-circuit evaluation of `left && right` / `left || right`, called after
// |left| has been visited and before |right| is.  It loads |left| as an rvalue and stores it back
// as the node's data.  It then opens a two-block conditional (the |if| block and the merge block)
// and branches on |left|, so that |right| is only evaluated in the |if| block.
// endShortCircuit completes the construct.
void OutputSPIRVTraverser::startShortCircuit(TIntermBinary *node)
// Emulate && and || as such:
//
// || => if (!left) result = right
// && => if ( left) result = right
//
// When this function is called, |left| has already been visited, so it creates the appropriate
// |if| construct in preparation for visiting |right|.
// Load |left| and replace the access chain with an rvalue that's the result.
spirv::IdRef typeId;
const spirv::IdRef left =
accessChainLoad(&mNodeData.back(), node->getLeft()->getType(), &typeId);
nodeDataInitRValue(&mNodeData.back(), left, typeId);
// Keep the id of the block |left| was evaluated in.
// Two blocks necessary, one for the |if| block, and one for the merge block.
mBuilder.startConditional(2, false, false);
// Generate the branch instructions.
const SpirvConditional *conditional = mBuilder.getCurrentConditional();
const spirv::IdRef mergeBlock = conditional->blockIds.back();
const spirv::IdRef ifBlock = conditional->blockIds.front();
// &&: evaluate |right| only when |left| is true.  ||: evaluate |right| only when |left| is false.
const spirv::IdRef trueBlock = node->getOp() == EOpLogicalAnd ? ifBlock : mergeBlock;
const spirv::IdRef falseBlock = node->getOp() == EOpLogicalOr ? ifBlock : mergeBlock;
// Note that no logical not is necessary. For ||, the branch will target the merge block in the
// true case.
mBuilder.writeBranchConditional(left, trueBlock, falseBlock, mergeBlock);
// Second half of short-circuit evaluation (see startShortCircuit), called after |right| has been
// visited.  It loads |right| and moves to the merge block.  It then emits an OpPhi that yields
// |left| when control comes from |left|'s block and |right| when it comes from the block |right|
// was evaluated in.
//
// Returns the OpPhi result id; |*typeId| receives the type id recorded for |left|.
spirv::IdRef OutputSPIRVTraverser::endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId)
// Load the right hand side.
const spirv::IdRef right =
accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
// Get the id of the block |right| is evaluated in.
const spirv::IdRef rightBlockId = mBuilder.getSpirvCurrentFunctionBlockId();
// And the cached id of the block |left| is evaluated in.
ASSERT(mNodeData.back().idList.size() == 1);
// idList[0] is the block id cached by startShortCircuit, not an access chain index.
const spirv::IdRef leftBlockId = mNodeData.back().idList[0].id;
// Move on to the merge block.
// Pop from the conditional stack.
// Get the previously loaded result of the left hand side.
*typeId = mNodeData.back().accessChain.baseTypeId;
const spirv::IdRef left = mNodeData.back().baseId;
// Create an OpPhi instruction that selects either the |left| or |right| based on which block
// was traversed.
const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
mBuilder.getSpirvCurrentFunctionBlock(), *typeId, result,
{spirv::PairIdRefIdRef{left, leftBlockId}, spirv::PairIdRefIdRef{right, rightBlockId}});
return result;
// Emits an OpFunctionCall for the user function called by |node| and returns its result id.
// |resultTypeId| is the SPIR-V type id of the call's result.  The call's arguments must already
// have been visited; their NodeData are the last node->getChildCount() entries of mNodeData.
spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node,
                                                      spirv::IdRef resultTypeId)
{
    const TFunction *function = node->getFunction();
    ASSERT(mFunctionIdMap.count(function) > 0);
    const spirv::IdRef functionId = mFunctionIdMap[function].functionId;

    // Get the list of parameters passed to the function. The function parameters can only be
    // memory variables, or if the function argument is |const|, an rvalue.
    //
    // For in variables:
    //
    // - If the parameter is const, pass it directly as rvalue, otherwise
    // - Write it to a temp variable first and pass that.  An unindexed lvalue is NOT passed
    //   directly: GLSL |in| parameters are copies, but the callee receives a pointer and may write
    //   through it, which would otherwise modify the caller's variable.
    //
    // For out variables:
    //
    // - If the parameter is an unindexed lvalue, pass it directly, otherwise
    // - Pass a temporary variable. After the function call, copy that variable to the parameter.
    //
    // For inout variables:
    //
    // - If the parameter is an unindexed lvalue, pass it directly, otherwise
    // - Write the parameter to a temp variable and pass that. After the function call, copy that
    //   variable back to the parameter.
    //
    // - For opaque uniforms, pass it directly as lvalue,
    //
    // Note: |const| function parameters carry the EvqConstReadOnly qualifier (EvqConst is the
    // qualifier of constant variables, not of parameters).
    const size_t parameterCount = node->getChildCount();
    spirv::IdRefList parameters;
    spirv::IdRefList tempVarIds(parameterCount);
    spirv::IdRefList tempVarTypeIds(parameterCount);

    for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
    {
        const TType &paramType           = function->getParam(paramIndex)->getType();
        const TQualifier &paramQualifier = paramType.getQualifier();
        NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];

        spirv::IdRef paramValue;

        if (paramQualifier == EvqConstReadOnly)
        {
            // |const| parameters are passed as rvalue.
            paramValue = accessChainLoad(&param, paramType, nullptr);
        }
        else if (IsOpaqueType(paramType.getBasicType()))
        {
            // Opaque uniforms are passed by pointer.
            paramValue = accessChainCollapse(&param);
        }
        else if (IsAccessChainUnindexedLValue(param) &&
                 (paramQualifier == EvqOut || paramQualifier == EvqInOut) &&
                 param.accessChain.storageClass == spv::StorageClassFunction &&
                 (mCompileOptions & SH_GENERATE_SPIRV_WORKAROUNDS) == 0)
        {
            // Unindexed out/inout lvalues are passed directly.
            //
            // This optimization is not applied on buggy drivers.
            paramValue = param.baseId;
        }
        else
        {
            ASSERT(paramQualifier == EvqIn || paramQualifier == EvqOut ||
                   paramQualifier == EvqInOut);

            // Need to create a temp variable and pass that.
            tempVarTypeIds[paramIndex] = mBuilder.getTypeData(paramType, {}).id;
            tempVarIds[paramIndex] =
                mBuilder.declareVariable(tempVarTypeIds[paramIndex], spv::StorageClassFunction,
                                         mBuilder.getDecorations(paramType), nullptr, "param");

            // If it's an in or inout parameter, the temp variable needs to be initialized with the
            // value of the parameter first.
            if (paramQualifier == EvqIn || paramQualifier == EvqInOut)
            {
                paramValue = accessChainLoad(&param, paramType, nullptr);
                spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVarIds[paramIndex],
                                  paramValue, nullptr);
            }

            paramValue = tempVarIds[paramIndex];
        }

        parameters.push_back(paramValue);
    }

    // Make the actual function call.
    const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
    spirv::WriteFunctionCall(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
                             functionId, parameters);

    // Copy from the out and inout temp variables back to the original parameters.
    for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
    {
        if (!tempVarIds[paramIndex].valid())
        {
            continue;
        }

        const TType &paramType           = function->getParam(paramIndex)->getType();
        const TQualifier &paramQualifier = paramType.getQualifier();
        NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];

        // Only out and inout parameters write back to the caller.
        const bool isOutput = paramQualifier == EvqOut || paramQualifier == EvqInOut;
        if (!isOutput)
        {
            continue;
        }

        // Copy from the temp variable to the parameter.
        NodeData tempVarData;
        nodeDataInitLValue(&tempVarData, tempVarIds[paramIndex], tempVarTypeIds[paramIndex],
                           spv::StorageClassFunction, {});
        const spirv::IdRef tempVarValue = accessChainLoad(&tempVarData, paramType, nullptr);
        accessChainStore(&param, tempVarValue, function->getParam(paramIndex)->getType());
    }

    return result;
}
// Returns whether |node| must be translated with short-circuit control flow (see
// startShortCircuit / endShortCircuit) rather than a plain logical instruction.
bool IsShortCircuitNeeded(TIntermOperator *node)
{
    // Only && and || have short-circuit semantics.
    const TOperator op          = node->getOp();
    const bool isLogicalAndOrOr = op == EOpLogicalAnd || op == EOpLogicalOr;
    if (!isLogicalAndOrOr)
    {
        return false;
    }

    ASSERT(node->getChildCount() == 2);

    // Evaluating the right operand unconditionally is only observable when it has side effects;
    // otherwise the branch-free form is equivalent.
    //
    // TODO: experiment with the performance of OpLogicalAnd/Or vs short-circuit based on the
    // complexity of the right hand side expression. We could potentially only allow
    // OpLogicalAnd/Or if the right hand side is a constant or an access chain and have more complex
    // expressions be placed inside an if block.
    TIntermTyped *rightOperand = node->getChildNode(1)->getAsTyped();
    return rightOperand->hasSideEffects();
}
// Pointer types matching the generated spirv::Write* instruction builders, grouped by operand
// count.  Each writes one instruction producing |resultId| of type |resultType| into |out|.
using WriteUnaryOp = void (*)(spirv::Blob *out,
                              spirv::IdResultType resultType,
                              spirv::IdResult resultId,
                              spirv::IdRef x);
using WriteBinaryOp = void (*)(spirv::Blob *out,
                               spirv::IdResultType resultType,
                               spirv::IdResult resultId,
                               spirv::IdRef x,
                               spirv::IdRef y);
using WriteTernaryOp = void (*)(spirv::Blob *out,
                                spirv::IdResultType resultType,
                                spirv::IdResult resultId,
                                spirv::IdRef x,
                                spirv::IdRef y,
                                spirv::IdRef z);
using WriteQuaternaryOp = void (*)(spirv::Blob *out,
                                   spirv::IdResultType resultType,
                                   spirv::IdResult resultId,
                                   spirv::IdRef x,
                                   spirv::IdRef y,
                                   spirv::IdRef z,
                                   spirv::IdRef w);
// Atomic instructions additionally take the memory scope and semantics as id operands.
using WriteAtomicOp = void (*)(spirv::Blob *out,
                               spirv::IdResultType resultType,
                               spirv::IdResult resultId,
                               spirv::IdRef pointerId,
                               spirv::IdScope scopeId,
                               spirv::IdMemorySemantics semanticsId,
                               spirv::IdRef valueId);
// Generates SPIR-V for a built-in operator or built-in function call (unary, binary, ternary or
// quaternary).
//
// ++/--, ==/!=, atomic, image/texture and fragment-shader interpolation built-ins are delegated
// to dedicated helpers.  For everything else, the switch below selects either a core SPIR-V
// instruction writer (writeUnaryOp / writeBinaryOp / writeTernaryOp / writeQuaternaryOp) or a
// GLSL.std.450 extended instruction (extendedInst).  The operands are then loaded, the
// instruction is emitted (column by column when operateOnColumns is set), compound assignments
// are stored back, and built-ins with out parameters (lvalueCount > 0) have their struct result
// split into the return value and the out parameters.
//
// @param node          The operator node; its children have already been visited.
// @param resultTypeId  SPIR-V type id of the node's result.
// @return              Id of the result; not assigned a new id when the node's type is void.
//
// NOTE(review): this copy of the function appears to be missing its braces, `else` keywords,
// `break` statements and several continuation lines, so as written the switch cases fall
// through and some expressions are incomplete.  The comments below describe the evident intent;
// restore the missing lines before relying on this code.
spirv::IdRef OutputSPIRVTraverser::visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId)
// Handle special groups.
const TOperator op = node->getOp();
if (op == EOpPostIncrement || op == EOpPreIncrement || op == EOpPostDecrement ||
op == EOpPreDecrement)
return createIncrementDecrement(node, resultTypeId);
if (op == EOpEqual || op == EOpNotEqual)
return createCompare(node, resultTypeId);
if (BuiltInGroup::IsAtomicMemory(op) || BuiltInGroup::IsImageAtomic(op))
return createAtomicBuiltIn(node, resultTypeId);
if (BuiltInGroup::IsImage(op) || BuiltInGroup::IsTexture(op))
return createImageTextureBuiltIn(node, resultTypeId);
if (BuiltInGroup::IsInterpolationFS(op))
return createInterpolate(node, resultTypeId);
const size_t childCount = node->getChildCount();
// The first operand's basic type drives the float / unsigned / bool instruction selection below.
TIntermTyped *firstChild = node->getChildNode(0)->getAsTyped();
const TType &firstOperandType = firstChild->getType();
const TBasicType basicType = firstOperandType.getBasicType();
const bool isFloat = basicType == EbtFloat || basicType == EbtDouble;
const bool isUnsigned = basicType == EbtUInt;
const bool isBool = basicType == EbtBool;
// Whether the operation needs to be applied column by column.
TIntermBinary *asBinary = node->getAsBinaryNode();
// NOTE(review): the rest of this expression (presumably the right operand's isMatrix() check)
// is missing in this copy.
bool operateOnColumns = asBinary && (asBinary->getLeft()->getType().isMatrix() ||
// Whether the operands need to be swapped in the (binary) instruction
bool binarySwapOperands = false;
// Whether the scalar operand needs to be extended to match the other operand which is a vector
// (in a binary operation).
bool binaryExtendScalarToVector = true;
// Some built-ins have out parameters at the end of the list of parameters.
size_t lvalueCount = 0;
WriteUnaryOp writeUnaryOp = nullptr;
WriteBinaryOp writeBinaryOp = nullptr;
WriteTernaryOp writeTernaryOp = nullptr;
WriteQuaternaryOp writeQuaternaryOp = nullptr;
// Some operators are implemented with an extended instruction.
spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
switch (op)
case EOpNegative:
if (isFloat)
writeUnaryOp = spirv::WriteFNegate;
writeUnaryOp = spirv::WriteSNegate;
case EOpPositive:
// This is a noop.
return accessChainLoad(&mNodeData.back(), firstOperandType, nullptr);
case EOpLogicalNot:
case EOpNotComponentWise:
writeUnaryOp = spirv::WriteLogicalNot;
case EOpBitwiseNot:
writeUnaryOp = spirv::WriteNot;
case EOpAdd:
case EOpAddAssign:
if (isFloat)
writeBinaryOp = spirv::WriteFAdd;
writeBinaryOp = spirv::WriteIAdd;
case EOpSub:
case EOpSubAssign:
if (isFloat)
writeBinaryOp = spirv::WriteFSub;
writeBinaryOp = spirv::WriteISub;
case EOpMul:
case EOpMulAssign:
case EOpMatrixCompMult:
if (isFloat)
writeBinaryOp = spirv::WriteFMul;
writeBinaryOp = spirv::WriteIMul;
case EOpDiv:
case EOpDivAssign:
if (isFloat)
writeBinaryOp = spirv::WriteFDiv;
else if (isUnsigned)
writeBinaryOp = spirv::WriteUDiv;
writeBinaryOp = spirv::WriteSDiv;
case EOpIMod:
case EOpIModAssign:
if (isFloat)
writeBinaryOp = spirv::WriteFMod;
else if (isUnsigned)
writeBinaryOp = spirv::WriteUMod;
writeBinaryOp = spirv::WriteSMod;
case EOpEqualComponentWise:
if (isFloat)
writeBinaryOp = spirv::WriteFOrdEqual;
else if (isBool)
writeBinaryOp = spirv::WriteLogicalEqual;
writeBinaryOp = spirv::WriteIEqual;
case EOpNotEqualComponentWise:
// Unordered comparison, so notEqual() yields true when either operand is NaN.
if (isFloat)
writeBinaryOp = spirv::WriteFUnordNotEqual;
else if (isBool)
writeBinaryOp = spirv::WriteLogicalNotEqual;
writeBinaryOp = spirv::WriteINotEqual;
case EOpLessThan:
case EOpLessThanComponentWise:
if (isFloat)
writeBinaryOp = spirv::WriteFOrdLessThan;
else if (isUnsigned)
writeBinaryOp = spirv::WriteULessThan;
writeBinaryOp = spirv::WriteSLessThan;
case EOpGreaterThan:
case EOpGreaterThanComponentWise:
if (isFloat)
writeBinaryOp = spirv::WriteFOrdGreaterThan;
else if (isUnsigned)
writeBinaryOp = spirv::WriteUGreaterThan;
writeBinaryOp = spirv::WriteSGreaterThan;
case EOpLessThanEqual:
case EOpLessThanEqualComponentWise:
if (isFloat)
writeBinaryOp = spirv::WriteFOrdLessThanEqual;
else if (isUnsigned)
writeBinaryOp = spirv::WriteULessThanEqual;
writeBinaryOp = spirv::WriteSLessThanEqual;
case EOpGreaterThanEqual:
case EOpGreaterThanEqualComponentWise:
if (isFloat)
writeBinaryOp = spirv::WriteFOrdGreaterThanEqual;
else if (isUnsigned)
writeBinaryOp = spirv::WriteUGreaterThanEqual;
writeBinaryOp = spirv::WriteSGreaterThanEqual;
case EOpVectorTimesScalar:
case EOpVectorTimesScalarAssign:
// Float: OpVectorTimesScalar takes the vector first (swap if it is on the right) and a scalar
// second (no scalar-to-vector extension).  Integer: OpIMul with the scalar extended to a vector.
if (isFloat)
writeBinaryOp = spirv::WriteVectorTimesScalar;
binarySwapOperands = node->getChildNode(1)->getAsTyped()->getType().isVector();
binaryExtendScalarToVector = false;
writeBinaryOp = spirv::WriteIMul;
case EOpVectorTimesMatrix:
case EOpVectorTimesMatrixAssign:
writeBinaryOp = spirv::WriteVectorTimesMatrix;
operateOnColumns = false;
case EOpMatrixTimesVector:
writeBinaryOp = spirv::WriteMatrixTimesVector;
operateOnColumns = false;
case EOpMatrixTimesScalar:
case EOpMatrixTimesScalarAssign:
// OpMatrixTimesScalar takes the matrix first; swap if the matrix is on the right.
writeBinaryOp = spirv::WriteMatrixTimesScalar;
binarySwapOperands = asBinary->getRight()->getType().isMatrix();
operateOnColumns = false;
case EOpMatrixTimesMatrix:
case EOpMatrixTimesMatrixAssign:
writeBinaryOp = spirv::WriteMatrixTimesMatrix;
operateOnColumns = false;
case EOpLogicalOr:
binaryExtendScalarToVector = false;
writeBinaryOp = spirv::WriteLogicalOr;
case EOpLogicalXor:
// a ^^ b on bools is implemented as a != b.
binaryExtendScalarToVector = false;
writeBinaryOp = spirv::WriteLogicalNotEqual;
case EOpLogicalAnd:
binaryExtendScalarToVector = false;
writeBinaryOp = spirv::WriteLogicalAnd;
case EOpBitShiftLeft:
case EOpBitShiftLeftAssign:
writeBinaryOp = spirv::WriteShiftLeftLogical;
case EOpBitShiftRight:
case EOpBitShiftRightAssign:
if (isUnsigned)
writeBinaryOp = spirv::WriteShiftRightLogical;
writeBinaryOp = spirv::WriteShiftRightArithmetic;
case EOpBitwiseAnd:
case EOpBitwiseAndAssign:
writeBinaryOp = spirv::WriteBitwiseAnd;
case EOpBitwiseXor:
case EOpBitwiseXorAssign:
writeBinaryOp = spirv::WriteBitwiseXor;
case EOpBitwiseOr:
case EOpBitwiseOrAssign:
writeBinaryOp = spirv::WriteBitwiseOr;
case EOpRadians:
extendedInst = spv::GLSLstd450Radians;
case EOpDegrees:
extendedInst = spv::GLSLstd450Degrees;
case EOpSin:
extendedInst = spv::GLSLstd450Sin;
case EOpCos:
extendedInst = spv::GLSLstd450Cos;
case EOpTan:
extendedInst = spv::GLSLstd450Tan;
case EOpAsin:
extendedInst = spv::GLSLstd450Asin;
case EOpAcos:
extendedInst = spv::GLSLstd450Acos;
case EOpAtan:
// atan(y_over_x) vs. atan(y, x).
extendedInst = childCount == 1 ? spv::GLSLstd450Atan : spv::GLSLstd450Atan2;
case EOpSinh:
extendedInst = spv::GLSLstd450Sinh;
case EOpCosh:
extendedInst = spv::GLSLstd450Cosh;
case EOpTanh:
extendedInst = spv::GLSLstd450Tanh;
case EOpAsinh:
extendedInst = spv::GLSLstd450Asinh;
case EOpAcosh:
extendedInst = spv::GLSLstd450Acosh;
case EOpAtanh:
extendedInst = spv::GLSLstd450Atanh;
case EOpPow:
extendedInst = spv::GLSLstd450Pow;
case EOpExp:
extendedInst = spv::GLSLstd450Exp;
case EOpLog:
extendedInst = spv::GLSLstd450Log;
case EOpExp2:
extendedInst = spv::GLSLstd450Exp2;
case EOpLog2:
extendedInst = spv::GLSLstd450Log2;
case EOpSqrt:
extendedInst = spv::GLSLstd450Sqrt;
case EOpInversesqrt:
extendedInst = spv::GLSLstd450InverseSqrt;
case EOpAbs:
if (isFloat)
extendedInst = spv::GLSLstd450FAbs;
extendedInst = spv::GLSLstd450SAbs;
case EOpSign:
if (isFloat)
extendedInst = spv::GLSLstd450FSign;
extendedInst = spv::GLSLstd450SSign;
case EOpFloor:
extendedInst = spv::GLSLstd450Floor;
case EOpTrunc:
extendedInst = spv::GLSLstd450Trunc;
case EOpRound:
extendedInst = spv::GLSLstd450Round;
case EOpRoundEven:
extendedInst = spv::GLSLstd450RoundEven;
case EOpCeil:
extendedInst = spv::GLSLstd450Ceil;
case EOpFract:
extendedInst = spv::GLSLstd450Fract;
case EOpMod:
// OpFMod's result takes the sign of the second operand, matching GLSL's
// mod(x, y) = x - y * floor(x / y).
if (isFloat)
writeBinaryOp = spirv::WriteFMod;
else if (isUnsigned)
writeBinaryOp = spirv::WriteUMod;
writeBinaryOp = spirv::WriteSMod;
case EOpMin:
if (isFloat)
extendedInst = spv::GLSLstd450FMin;
else if (isUnsigned)
extendedInst = spv::GLSLstd450UMin;
extendedInst = spv::GLSLstd450SMin;
case EOpMax:
if (isFloat)
extendedInst = spv::GLSLstd450FMax;
else if (isUnsigned)
extendedInst = spv::GLSLstd450UMax;
extendedInst = spv::GLSLstd450SMax;
case EOpClamp:
if (isFloat)
extendedInst = spv::GLSLstd450FClamp;
else if (isUnsigned)
extendedInst = spv::GLSLstd450UClamp;
extendedInst = spv::GLSLstd450SClamp;
case EOpMix:
// mix() with a bool selector (last argument) becomes OpSelect (operands are reordered further
// below); otherwise FMix.  NOTE(review): the right-hand side of the comparison on the next line
// (presumably EbtBool) is missing in this copy.
if (node->getChildNode(childCount - 1)->getAsTyped()->getType().getBasicType() ==
writeTernaryOp = spirv::WriteSelect;
extendedInst = spv::GLSLstd450FMix;
case EOpStep:
extendedInst = spv::GLSLstd450Step;
case EOpSmoothstep:
extendedInst = spv::GLSLstd450SmoothStep;
case EOpModf:
// The *Struct variant returns the out parameter as a struct member (see lvalueCount handling
// below).
extendedInst = spv::GLSLstd450ModfStruct;
lvalueCount = 1;
case EOpIsnan:
writeUnaryOp = spirv::WriteIsNan;
case EOpIsinf:
writeUnaryOp = spirv::WriteIsInf;
case EOpFloatBitsToInt:
case EOpFloatBitsToUint:
case EOpIntBitsToFloat:
case EOpUintBitsToFloat:
writeUnaryOp = spirv::WriteBitcast;
case EOpFma:
extendedInst = spv::GLSLstd450Fma;
case EOpFrexp:
extendedInst = spv::GLSLstd450FrexpStruct;
lvalueCount = 1;
case EOpLdexp:
extendedInst = spv::GLSLstd450Ldexp;
case EOpPackSnorm2x16:
extendedInst = spv::GLSLstd450PackSnorm2x16;
case EOpPackUnorm2x16:
extendedInst = spv::GLSLstd450PackUnorm2x16;
case EOpPackHalf2x16:
extendedInst = spv::GLSLstd450PackHalf2x16;
case EOpUnpackSnorm2x16:
extendedInst = spv::GLSLstd450UnpackSnorm2x16;
case EOpUnpackUnorm2x16:
extendedInst = spv::GLSLstd450UnpackUnorm2x16;
case EOpUnpackHalf2x16:
extendedInst = spv::GLSLstd450UnpackHalf2x16;
case EOpPackUnorm4x8:
extendedInst = spv::GLSLstd450PackUnorm4x8;
case EOpPackSnorm4x8:
extendedInst = spv::GLSLstd450PackSnorm4x8;
case EOpUnpackUnorm4x8:
extendedInst = spv::GLSLstd450UnpackUnorm4x8;
case EOpUnpackSnorm4x8:
extendedInst = spv::GLSLstd450UnpackSnorm4x8;
case EOpPackDouble2x32:
case EOpUnpackDouble2x32:
// TODO: support desktop GLSL.
case EOpLength:
extendedInst = spv::GLSLstd450Length;
case EOpDistance:
extendedInst = spv::GLSLstd450Distance;
case EOpDot:
// Use normal multiplication for scalars.
if (firstOperandType.isScalar())
if (isFloat)
writeBinaryOp = spirv::WriteFMul;
writeBinaryOp = spirv::WriteIMul;
writeBinaryOp = spirv::WriteDot;
case EOpCross:
extendedInst = spv::GLSLstd450Cross;
case EOpNormalize:
extendedInst = spv::GLSLstd450Normalize;
case EOpFaceforward:
extendedInst = spv::GLSLstd450FaceForward;
case EOpReflect:
extendedInst = spv::GLSLstd450Reflect;
case EOpRefract:
extendedInst = spv::GLSLstd450Refract;
case EOpFtransform:
// TODO: support desktop GLSL.
case EOpOuterProduct:
writeBinaryOp = spirv::WriteOuterProduct;
case EOpTranspose:
writeUnaryOp = spirv::WriteTranspose;
case EOpDeterminant:
extendedInst = spv::GLSLstd450Determinant;
case EOpInverse:
extendedInst = spv::GLSLstd450MatrixInverse;
case EOpAny:
writeUnaryOp = spirv::WriteAny;
case EOpAll:
writeUnaryOp = spirv::WriteAll;
case EOpBitfieldExtract:
if (isUnsigned)
writeTernaryOp = spirv::WriteBitFieldUExtract;
writeTernaryOp = spirv::WriteBitFieldSExtract;
case EOpBitfieldInsert:
writeQuaternaryOp = spirv::WriteBitFieldInsert;
case EOpBitfieldReverse:
writeUnaryOp = spirv::WriteBitReverse;
case EOpBitCount:
writeUnaryOp = spirv::WriteBitCount;
case EOpFindLSB:
extendedInst = spv::GLSLstd450FindILsb;
case EOpFindMSB:
if (isUnsigned)
extendedInst = spv::GLSLstd450FindUMsb;
extendedInst = spv::GLSLstd450FindSMsb;
case EOpUaddCarry:
writeBinaryOp = spirv::WriteIAddCarry;
lvalueCount = 1;
case EOpUsubBorrow:
writeBinaryOp = spirv::WriteISubBorrow;
lvalueCount = 1;
case EOpUmulExtended:
writeBinaryOp = spirv::WriteUMulExtended;
lvalueCount = 2;
case EOpImulExtended:
writeBinaryOp = spirv::WriteSMulExtended;
lvalueCount = 2;
case EOpRgb_2_yuv:
case EOpYuv_2_rgb:
// TODO: There doesn't seem to be an equivalent in SPIR-V, and should likley be emulated
// as an AST transformation. Not supported by the Vulkan at the moment.
case EOpDFdx:
writeUnaryOp = spirv::WriteDPdx;
case EOpDFdy:
writeUnaryOp = spirv::WriteDPdy;
case EOpFwidth:
writeUnaryOp = spirv::WriteFwidth;
case EOpDFdxFine:
writeUnaryOp = spirv::WriteDPdxFine;
case EOpDFdyFine:
writeUnaryOp = spirv::WriteDPdyFine;
case EOpDFdxCoarse:
writeUnaryOp = spirv::WriteDPdxCoarse;
case EOpDFdyCoarse:
writeUnaryOp = spirv::WriteDPdyCoarse;
case EOpFwidthFine:
writeUnaryOp = spirv::WriteFwidthFine;
case EOpFwidthCoarse:
writeUnaryOp = spirv::WriteFwidthCoarse;
case EOpNoise1:
case EOpNoise2:
case EOpNoise3:
case EOpNoise4:
// TODO: support desktop GLSL.
case EOpSubpassLoad:
// TODO: support framebuffer fetch.
case EOpAnyInvocation:
case EOpAllInvocations:
case EOpAllInvocationsEqual:
// TODO: support desktop GLSL.
// Load the parameters.
spirv::IdRefList parameters = loadAllParams(node, lvalueCount);
const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
spirv::IdRef result;
if (node->getType().getBasicType() != EbtVoid)
result = mBuilder.getNewId(decorations);
// In the case of modf, frexp, uaddCarry, usubBorrow, umulExtended and imulExtended, the SPIR-V
// result is expected to be a struct instead.
spirv::IdRef builtInResultTypeId = resultTypeId;
spirv::IdRef builtInResult;
// With out parameters the instruction writes a fresh struct-typed id; otherwise it writes
// |result| directly.  NOTE(review): the `else` before the second assignment is missing here.
if (lvalueCount > 0)
builtInResultTypeId = makeBuiltInOutputStructType(node, lvalueCount);
builtInResult = mBuilder.getNewId({});
builtInResult = result;
if (operateOnColumns)
// If negating a matrix, multiplying or comparing them, do that column by column.
// NOTE(review): columns are extracted from both operands using firstOperandType.  If a
// column-wise binary op can have a scalar operand (e.g. GLSL `mat + float`), extracting a column
// from it would be invalid - confirm such expressions are rewritten before reaching this point.
spirv::IdRefList columnIds;
const SpirvDecorations operandDecorations = mBuilder.getDecorations(firstOperandType);
const spirv::IdRef columnTypeId =
mBuilder.getBasicTypeId(firstOperandType.getBasicType(), firstOperandType.getRows());
if (binarySwapOperands)
std::swap(parameters[0], parameters[1]);
// Extract and apply the operator to each column.
// NOTE(review): in this copy the index-list argument of both OpCompositeExtract calls and the
// statements pushing a new id onto columnIds (read back through columnIds.back()) are missing.
for (int columnIndex = 0; columnIndex < firstOperandType.getCols(); ++columnIndex)
const spirv::IdRef columnIdA = mBuilder.getNewId(operandDecorations);
spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
columnIdA, parameters[0],
if (writeUnaryOp)
writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
columnIds.back(), columnIdA);
const spirv::IdRef columnIdB = mBuilder.getNewId(operandDecorations);
spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
columnIdB, parameters[1],
writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
columnIds.back(), columnIdA, columnIdB);
// Construct the result.
spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
builtInResult, columnIds);
else if (writeUnaryOp)
ASSERT(parameters.size() == 1);
// NOTE(review): the final argument (presumably parameters[0]) is missing in this copy.
writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
else if (writeBinaryOp)
ASSERT(parameters.size() == 2);
// For vector<op>scalar operations that require it, turn the scalar into a vector of the
// same size.
// NOTE(review): the trailing arguments of both createConstructorVectorFromScalar calls are
// missing in this copy.
if (binaryExtendScalarToVector)
const TType &leftType = node->getChildNode(0)->getAsTyped()->getType();
const TType &rightType = node->getChildNode(1)->getAsTyped()->getType();
if (leftType.isScalar() && rightType.isVector())
parameters[0] = createConstructorVectorFromScalar(rightType, builtInResultTypeId,
else if (rightType.isScalar() && leftType.isVector())
parameters[1] = createConstructorVectorFromScalar(leftType, builtInResultTypeId,
if (binarySwapOperands)
std::swap(parameters[0], parameters[1]);
// Write the operation that combines the left and right values.
writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
parameters[0], parameters[1]);
else if (writeTernaryOp)
ASSERT(parameters.size() == 3);
// mix(a, b, bool) is the same as bool ? b : a;
if (op == EOpMix)
std::swap(parameters[0], parameters[2]);
writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
parameters[0], parameters[1], parameters[2]);
else if (writeQuaternaryOp)
ASSERT(parameters.size() == 4);
writeQuaternaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
builtInResult, parameters[0], parameters[1], parameters[2],
// It's an extended instruction.
ASSERT(extendedInst != spv::GLSLstd450Bad);
spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
builtInResult, mBuilder.getExtInstImportIdStd(),
spirv::LiteralExtInstInteger(extendedInst), parameters);
// If it's an assignment, store the calculated value.
// The second-to-last mNodeData entry is presumably the left operand's access chain (the last
// two entries belong to the two operands).
if (IsAssignment(node->getOp()))
ASSERT(mNodeData.size() >= 2);
ASSERT(parameters.size() == 2);
accessChainStore(&mNodeData[mNodeData.size() - 2], builtInResult, firstOperandType);
// If the operation returns a struct, load the lsb and msb and store them in result/out
// parameters.
// NOTE(review): the trailing argument(s) of this call are missing in this copy.
if (lvalueCount > 0)
storeBuiltInStructOutputInParamsAndReturnValue(node, lvalueCount, builtInResult, result,
return result;
// Generates SPIR-V for the GLSL ++ and -- operators, both prefix and postfix.
//
// The operand is loaded and 1 is added to / subtracted from it with OpFAdd/OpIAdd or
// OpFSub/OpISub.  The new value is always stored back through the operand's access chain.
//
// @param node          The increment/decrement node.  The operand's access chain is the last
//                      entry of mNodeData.
// @param resultTypeId  SPIR-V type id of the operand (and of the result).
// @return              An rvalue: the updated value for the prefix forms, the value before the
//                      update for the postfix forms.
spirv::IdRef OutputSPIRVTraverser::createIncrementDecrement(TIntermOperator *node,
                                                            spirv::IdRef resultTypeId)
{
    TIntermTyped *operand       = node->getChildNode(0)->getAsTyped();
    const TType &operandType    = operand->getType();
    const TBasicType basicType  = operandType.getBasicType();
    const bool isFloat          = basicType == EbtFloat || basicType == EbtDouble;

    // ++ and -- are implemented with binary SPIR-V ops.  Each case must select exactly one
    // writer and stop; falling through would turn every increment into a decrement.
    WriteBinaryOp writeBinaryOp = nullptr;
    switch (node->getOp())
    {
        case EOpPostIncrement:
        case EOpPreIncrement:
            if (isFloat)
            {
                writeBinaryOp = spirv::WriteFAdd;
            }
            else
            {
                writeBinaryOp = spirv::WriteIAdd;
            }
            break;
        case EOpPostDecrement:
        case EOpPreDecrement:
            if (isFloat)
            {
                writeBinaryOp = spirv::WriteFSub;
            }
            else
            {
                writeBinaryOp = spirv::WriteISub;
            }
            break;
        default:
            // Only the four ++/-- operators are routed here by visitOperator().  Bail out rather
            // than call through a null function pointer in release builds.
            UNREACHABLE();
            return spirv::IdRef();
    }

    // Load the operand.
    spirv::IdRef value = accessChainLoad(&mNodeData.back(), operandType, nullptr);

    spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(operandType));

    // NOTE(review): |one| is a scalar constant.  SPIR-V requires both operands of
    // OpFAdd/OpIAdd/OpFSub/OpISub to have as many components as the result, so a vector or
    // matrix operand would need a matching composite constant here.  Confirm such operands never
    // reach this function.
    const spirv::IdRef one = isFloat ? mBuilder.getFloatConstant(1) : mBuilder.getIntConstant(1);

    writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, value, one);

    // The result is always written back.
    accessChainStore(&mNodeData.back(), result, operandType);

    // The prefix forms evaluate to the updated value, the postfix forms to the value before the
    // update.
    if (node->getOp() == EOpPostIncrement || node->getOp() == EOpPostDecrement)
    {
        result = value;
    }

    return result;
}
// Generates SPIR-V for the GLSL == and != operators.
//
// createCompareImpl() recursively breaks the operands down into scalar/vector comparisons.
// The resulting booleans are then reduced, up to 4 at a time, into the single bool the
// operator evaluates to.
//
// @param node          The EOpEqual / EOpNotEqual node; both operands are loaded through
//                      loadAllParams().
// @param resultTypeId  SPIR-V type id of the (bool) result.
// @return              Id of the single boolean result.
spirv::IdRef OutputSPIRVTraverser::createCompare(TIntermOperator *node, spirv::IdRef resultTypeId)
{
    const TOperator op       = node->getOp();
    TIntermTyped *operand    = node->getChildNode(0)->getAsTyped();
    const TType &operandType = operand->getType();

    const SpirvDecorations resultDecorations  = mBuilder.getDecorations(node->getType());
    const SpirvDecorations operandDecorations = mBuilder.getDecorations(operandType);

    // Load the left and right values.
    spirv::IdRefList parameters = loadAllParams(node, 0);
    ASSERT(parameters.size() == 2);

    // In GLSL, operators == and != can operate on the following:
    // - scalars: There's a SPIR-V instruction for this,
    // - vectors: The same SPIR-V instruction as scalars is used here, but the result is reduced
    //   with OpAll/OpAny for == and != respectively,
    // - matrices: Comparison must be done column by column and the result reduced,
    // - arrays: Comparison must be done on every array element and the result reduced,
    // - structs: Comparison must be done on each field and the result reduced.
    //
    // For the latter 3 cases, OpCompositeExtract is used to extract scalars and vectors out of the
    // more complex type, which is recursively traversed.  The results are accumulated in a list
    // that is then reduced 4 by 4 elements until a single boolean is produced.
    spirv::LiteralIntegerList currentAccessChain;
    spirv::IdRefList intermediateResults;

    createCompareImpl(op, operandType, resultTypeId, parameters[0], parameters[1],
                      operandDecorations, resultDecorations, &currentAccessChain,
                      &intermediateResults);

    // Make sure the function correctly pushes and pops access chain indices.
    ASSERT(currentAccessChain.empty());

    // At least one comparison is needed; the reduction below starts from the first result.
    ASSERT(!intermediateResults.empty());

    // Reduce the intermediate results.
    //
    // The following code implements this algorithm, assuming N bools are to be reduced:
    //
    //    Reduced           To Reduce
    //    {b1}              {b2, b3, ..., bN}      Initial state
    //                                             Loop
    //    {b1, b2, b3, b4}  {b5, b6, ..., bN}        Take up to 3 new bools
    //    {r1}              {b5, b6, ..., bN}        Reduce it
    //                                             Repeat
    //
    // In the end, a single value is left.
    size_t reducedCount        = 0;
    spirv::IdRefList toReduce = {intermediateResults[reducedCount++]};
    while (reducedCount < intermediateResults.size())
    {
        // Take up to 3 new bools.  Advancing |reducedCount| here is what terminates the loop.
        const size_t toTakeCount = std::min<size_t>(3, intermediateResults.size() - reducedCount);
        for (size_t i = 0; i < toTakeCount; ++i)
        {
            toReduce.push_back(intermediateResults[reducedCount++]);
        }

        // Reduce them to one bool.
        const spirv::IdRef result = reduceBoolVector(op, toReduce, resultTypeId, resultDecorations);

        // Replace the list of bools to reduce with the reduced one.
        toReduce.clear();
        toReduce.push_back(result);
    }

    ASSERT(toReduce.size() == 1 && reducedCount == intermediateResults.size());
    return toReduce[0];
}
// Generates SPIR-V for the atomic* (memory) and imageAtomic* built-ins.
//
// The first argument is used as a pointer rather than loaded.  For images,
// OpImageTexelPointer first turns the image plus coordinates (and sample, for multisampled
// images) into a pointer to the texel, which then becomes the atomic's pointer operand.  All
// atomics are emitted with Device scope and relaxed memory semantics.
//
// @param node          The atomic built-in call.
// @param resultTypeId  SPIR-V type id of the atomic's result.
// @return              Id of the value produced by the atomic instruction.
//
// NOTE(review): this copy appears to be missing braces, `break` statements, the increments of
// imagePointerParameterCount, the push_back of each loaded parameter, and the initializer of
// semanticsId.  The comments below describe the evident intent.
spirv::IdRef OutputSPIRVTraverser::createAtomicBuiltIn(TIntermOperator *node,
spirv::IdRef resultTypeId)
const TType &operandType = node->getChildNode(0)->getAsTyped()->getType();
const TBasicType operandBasicType = operandType.getBasicType();
const bool isImage = IsImage(operandBasicType);
// Most atomic instructions are in the form of:
// %result = OpAtomicX %pointer Scope MemorySemantics %value
// OpAtomicCompareSwap is exceptionally different (note that compare and value are in different
// order from GLSL):
// %result = OpAtomicCompareExchange %pointer
// Scope MemorySemantics MemorySemantics
// %value %comparator
// In all cases, the first parameter is the pointer, and the rest are rvalues.
// For images, OpImageTexelPointer is used to form a pointer to the texel on which the atomic
// operation is being performed.
const size_t parameterCount = node->getChildCount();
size_t imagePointerParameterCount = 0;
spirv::IdRef pointerId;
spirv::IdRefList imagePointerParameters;
spirv::IdRefList parameters;
// NOTE(review): the `++imagePointerParameterCount;` statements under the two comments below are
// missing in this copy.
if (isImage)
// One parameter for coordinates.
if (IsImageMS(operandBasicType))
// One parameter for samples.
// The pointer operand, any image pointer parameters and at least one value operand.
ASSERT(parameterCount >= 2 + imagePointerParameterCount);
// The first argument's access chain is collapsed into a pointer id instead of being loaded.
pointerId = accessChainCollapse(&mNodeData[mNodeData.size() - parameterCount]);
for (size_t paramIndex = 1; paramIndex < parameterCount; ++paramIndex)
NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
const spirv::IdRef parameter = accessChainLoad(
&param, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);
// imageAtomic* built-ins have a few additional parameters right after the image. These are
// kept separately for use with OpImageTexelPointer.
// NOTE(review): the statements appending |parameter| to imagePointerParameters (this branch)
// or to parameters (otherwise) are missing in this copy.
if (paramIndex <= imagePointerParameterCount)
// The scope of the operation is always Device as we don't enable the Vulkan memory model
// extension.
const spirv::IdScope scopeId = mBuilder.getUintConstant(spv::ScopeDevice);
// The memory semantics is always relaxed as we don't enable the Vulkan memory model extension.
// NOTE(review): the initializer of semanticsId (presumably a uint constant for
// spv::MemorySemanticsMaskNone) is missing; as written the declaration runs into the next line.
const spirv::IdMemorySemantics semanticsId =
WriteAtomicOp writeAtomicOp = nullptr;
const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
// Determine whether the operation is on ints or uints.
const bool isUnsigned = isImage ? IsUIntImage(operandBasicType) : operandBasicType == EbtUInt;
// For images, convert the pointer to the image to a pointer to a texel in the image.
if (isImage)
const spirv::IdRef texelTypePointerId =
mBuilder.getTypePointerId(resultTypeId, spv::StorageClassImage);
const spirv::IdRef texelPointerId = mBuilder.getNewId({});
const spirv::IdRef coordinate = imagePointerParameters[0];
// Non-multisampled images pass a constant sample index of 0.
spirv::IdRef sample = imagePointerParameters.size() > 1 ? imagePointerParameters[1]
: mBuilder.getUintConstant(0);
spirv::WriteImageTexelPointer(mBuilder.getSpirvCurrentFunctionBlock(), texelTypePointerId,
texelPointerId, pointerId, coordinate, sample);
pointerId = texelPointerId;
// Select the atomic writer; min/max pick the signed or unsigned variant.
switch (node->getOp())
case EOpAtomicAdd:
case EOpImageAtomicAdd:
writeAtomicOp = spirv::WriteAtomicIAdd;
case EOpAtomicMin:
case EOpImageAtomicMin:
writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMin : spirv::WriteAtomicSMin;
case EOpAtomicMax:
case EOpImageAtomicMax:
writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMax : spirv::WriteAtomicSMax;
case EOpAtomicAnd:
case EOpImageAtomicAnd:
writeAtomicOp = spirv::WriteAtomicAnd;
case EOpAtomicOr:
case EOpImageAtomicOr:
writeAtomicOp = spirv::WriteAtomicOr;
case EOpAtomicXor:
case EOpImageAtomicXor:
writeAtomicOp = spirv::WriteAtomicXor;
case EOpAtomicExchange:
case EOpImageAtomicExchange:
writeAtomicOp = spirv::WriteAtomicExchange;
case EOpAtomicCompSwap:
case EOpImageAtomicCompSwap:
// Generate this special instruction right here and early out. Note again that the
// value and compare parameters of OpAtomicCompareExchange are in the opposite order
// from GLSL.
ASSERT(parameters.size() == 2);
spirv::WriteAtomicCompareExchange(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
result, pointerId, scopeId, semanticsId, semanticsId,
parameters[1], parameters[0]);
return result;
// Write the instruction.
ASSERT(parameters.size() == 1);
writeAtomicOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, pointerId, scopeId,
semanticsId, parameters[0]);
return result;
spirv::IdRef OutputSPIRVTraverser::createImageTextureBuiltIn(TIntermOperator *node,
spirv::IdRef resultTypeId)
const TOperator op = node->getOp();
const TFunction *function = node->getAsAggregate()->getFunction();
const TType &samplerType = function->getParam(0)->getType();
const TBasicType samplerBasicType = samplerType.getBasicType();
// Load the parameters.
spirv::IdRefList parameters = loadAllParams(node, 0);
// GLSL texture* and image* built-ins map to the following SPIR-V instructions. Some of these
// instructions take a "sampled image" while the others take the image itself. In these
// functions, the image, coordinates and Dref (for shadow sampling) are specified as positional
// parameters while the rest are bundled in a list of image operands.
// Image operations that query:
// - OpImageQuerySizeLod
// - OpImageQuerySize
// - OpImageQueryLod <-- sampled image
// - OpImageQueryLevels
// - OpImageQuerySamples
// Image operations that read/write:
// - OpImageSampleImplicitLod <-- sampled image
// - OpImageSampleExplicitLod <-- sampled image
// - OpImageSampleDrefImplicitLod <-- sampled image
// - OpImageSampleDrefExplicitLod <-- sampled image
// - OpImageSampleProjImplicitLod <-- sampled image
// - OpImageSampleProjExplicitLod <-- sampled image
// - OpImageSampleProjDrefImplicitLod <-- sampled image
// - OpImageSampleProjDrefExplicitLod <-- sampled image
// - OpImageFetch
// - OpImageGather <-- sampled image
// - OpImageDrefGather <-- sampled image
// - OpImageRead
// - OpImageWrite
// The additional image parameters are:
// - Bias: Only used with ImplicitLod.
// - Lod: Only used with ExplicitLod.
// - Grad: 2x operands; dx and dy. Only used with ExplicitLod.
// - ConstOffset: Constant offset added to coordinates of OpImage*Gather.
// - Offset: Non-constant offset added to coordinates of OpImage*Gather.
// - ConstOffsets: Constant offsets added to coordinates of OpImage*Gather.
// - Sample: Only used with OpImageFetch, OpImageRead and OpImageWrite.
// Where GLSL's built-in takes a sampler but SPIR-V expects an image, OpImage can be used to get
// the SPIR-V image out of a SPIR-V sampled image.
// The first parameter, which is either a sampled image or an image. Some GLSL built-ins
// receive a sampled image but their SPIR-V equivalent expects an image. OpImage is used in
// that case.
spirv::IdRef image = parameters[0];
bool extractImageFromSampledImage = false;
// The argument index for different possible parameters. 0 indicates that the argument is
// unused. Coordinates are usually at index 1, so it's pre-initialized.
size_t coordinatesIndex = 1;
size_t biasIndex = 0;
size_t lodIndex = 0;
size_t compareIndex = 0;
size_t dPdxIndex = 0;
size_t dPdyIndex = 0;
size_t offsetIndex = 0;
size_t offsetsIndex = 0;
size_t gatherComponentIndex = 0;
size_t sampleIndex = 0;
size_t dataIndex = 0;
// Whether this is a Dref variant of a sample call.
bool isDref = IsShadowSampler(samplerBasicType);
// Whether this is a Proj variant of a sample call.
bool isProj = false;
// The SPIR-V op used to implement the built-in. For OpImageSample* instructions,
// OpImageSampleImplicitLod is initially specified, which is later corrected based on |isDref|
// and |isProj|.
spv::Op spirvOp = BuiltInGroup::IsTexture(op) ? spv::OpImageSampleImplicitLod : spv::OpNop;
// Organize the parameters and decide the SPIR-V Op to use.
switch (op)
case EOpTexture2D:
case EOpTextureCube:
case EOpTexture1D:
case EOpTexture3D:
case EOpShadow1D:
case EOpShadow2D:
case EOpShadow2DEXT:
case EOpTexture2DRect:
case EOpTextureVideoWEBGL:
case EOpTexture:
case EOpTexture2DBias:
case EOpTextureCubeBias:
case EOpTexture3DBias:
case EOpTexture1DBias:
case EOpShadow1DBias:
case EOpShadow2DBias:
case EOpTextureBias:
// For shadow cube arrays, the compare value is specified through an additional
// parameter, while for the rest is taken out of the coordinates.
if (function->getParamCount() == 3)
if (samplerBasicType == EbtSamplerCubeArrayShadow)
compareIndex = 2;
biasIndex = 2;
case EOpTexture2DProj:
case EOpTexture1DProj:
case EOpTexture3DProj:
case EOpShadow1DProj:
case EOpShadow2DProj:
case EOpShadow2DProjEXT:
case EOpTexture2DRectProj:
case EOpTextureProj:
case EOpTexture2DProjBias:
case EOpTexture3DProjBias:
case EOpTexture1DProjBias:
case EOpShadow1DProjBias:
case EOpShadow2DProjBias:
case EOpTextureProjBias:
isProj = true;
if (function->getParamCount() == 3)
biasIndex = 2;
case EOpTexture2DLod:
case EOpTextureCubeLod:
case EOpTexture1DLod:
case EOpShadow1DLod:
case EOpShadow2DLod:
case EOpTexture3DLod:
case EOpTexture2DLodVS:
case EOpTextureCubeLodVS:
case EOpTexture2DLodEXTFS:
case EOpTextureCubeLodEXTFS:
case EOpTextureLod:
ASSERT(function->getParamCount() == 3);
lodIndex = 2;
case EOpTexture2DProjLod:
case EOpTexture1DProjLod:
case EOpShadow1DProjLod:
case EOpShadow2DProjLod:
case EOpTexture3DProjLod:
case EOpTexture2DProjLodVS:
case EOpTexture2DProjLodEXTFS:
case EOpTextureProjLod:
ASSERT(function->getParamCount() == 3);
isProj = true;
lodIndex = 2;
case EOpTexelFetch:
case EOpTexelFetchOffset:
// texelFetch has the following forms:
// - texelFetch(sampler, P);
// - texelFetch(sampler, P, lod);
// - texelFetch(samplerMS, P, sample);
// texelFetchOffset has an additional offset parameter at the end.
// In SPIR-V, OpImageFetch is used which operates on the image itself.
spirvOp = spv::OpImageFetch;
extractImageFromSampledImage = true;
if (IsSamplerMS(samplerBasicType))
ASSERT(function->getParamCount() == 3);
sampleIndex = 2;
else if (function->getParamCount() >= 3)
lodIndex = 2;
if (op == EOpTexelFetchOffset)
offsetIndex = function->getParamCount() - 1;
case EOpTexture2DGradEXT:
case EOpTextureCubeGradEXT:
case EOpTextureGrad:
ASSERT(function->getParamCount() == 4);
dPdxIndex = 2;
dPdyIndex = 3;
case EOpTexture2DProjGradEXT:
case EOpTextureProjGrad:
ASSERT(function->getParamCount() == 4);
isProj = true;
dPdxIndex = 2;
dPdyIndex = 3;
case EOpTextureOffset:
case EOpTextureOffsetBias:
ASSERT(function->getParamCount() >= 3);
offsetIndex = 2;
if (function->getParamCount() == 4)
biasIndex = 3;
case EOpTextureProjOffset:
case EOpTextureProjOffsetBias:
ASSERT(function->getParamCount() >= 3);
isProj = true;
offsetIndex = 2;
if (function->getParamCount() == 4)
biasIndex = 3;
case EOpTextureLodOffset:
ASSERT(function->getParamCount() == 4);
lodIndex = 2;
offsetIndex = 3;
case EOpTextureProjLodOffset:
ASSERT(function->getParamCount() == 4);
isProj = true;
lodIndex = 2;
offsetIndex = 3;
case EOpTextureGradOffset:
ASSERT(function->getParamCount() == 5);
dPdxIndex = 2;
dPdyIndex = 3;
offsetIndex = 4;
case EOpTextureProjGradOffset:
ASSERT(function->getParamCount() == 5);
isProj = true;
dPdxIndex = 2;
dPdyIndex = 3;
offsetIndex = 4;
case EOpTextureGather:
// For shadow textures, refZ (same as Dref) is specified as the last argument.
// Otherwise a component may be specified which defaults to 0 if not specified.
spirvOp = spv::OpImageGather;
if (isDref)
ASSERT(function->getParamCount() == 3);
compareIndex = 2;
else if (function->getParamCount() == 3)
gatherComponentIndex = 2;
case EOpTextureGatherOffset:
case EOpTextureGatherOffsetComp:
case EOpTextureGatherOffsets:
case EOpTextureGatherOffsetsComp:
// textureGatherOffset and textureGatherOffsets have the following forms:
// - texelGatherOffset*(sampler, P, offset*);
// - texelGatherOffset*(sampler, P, offset*, component);
// - texelGatherOffset*(sampler, P, refZ, offset*);
spirvOp = spv::OpImageGather;
if (isDref)
ASSERT(function->getParamCount() == 4);
compareIndex = 2;
else if (function->getParamCount() == 4)
gatherComponentIndex = 3;
ASSERT(function->getParamCount() >= 3);
if (BuiltInGroup::IsTextureGatherOffset(op))
offsetIndex = isDref ? 3 : 2;
offsetsIndex = isDref ? 3 : 2;
case EOpImageStore:
// imageStore has the following forms:
// - imageStore(image, P, data);
// - imageStore(imageMS, P, sample, data);
spirvOp = spv::OpImageWrite;
if (IsSamplerMS(samplerBasicType))
ASSERT(function->getParamCount() == 4);
sampleIndex = 2;
dataIndex = 3;
ASSERT(function->getParamCount() == 3);
dataIndex = 2;
case EOpImageLoad:
// imageStore has the following forms:
// - imageLoad(image, P);
// - imageLoad(imageMS, P, sample);
spirvOp = spv::OpImageRead;
if (IsSamplerMS(samplerBasicType))
ASSERT(function->getParamCount() == 3);
sampleIndex = 2;
ASSERT(function->getParamCount() == 2);
// Queries:
case EOpTextureSize:
case EOpImageSize:
// textureSize has the following forms:
// - textureSize(sampler);
// - textureSize(sampler, lod);
// while imageSize has only one form:
// - imageSize(image);
extractImageFromSampledImage = true;
if (function->getParamCount() == 2)
spirvOp = spv::OpImageQuerySizeLod;
lodIndex = 1;
spirvOp = spv::OpImageQuerySize;
// No coordinates parameter.
coordinatesIndex = 0;
case EOpTextureSamples:
case EOpImageSamples:
extractImageFromSampledImage = true;
spirvOp = spv::OpImageQuerySamples;
// No coordinates parameter.
coordinatesIndex = 0;
case EOpTextureQueryLevels:
extractImageFromSampledImage = true;
spirvOp = spv::OpImageQueryLevels;
// No coordinates parameter.
coordinatesIndex = 0;
case EOpTextureQueryLod:
spirvOp = spv::OpImageQueryLod;
// If an implicit-lod instruction is used outside a fragment shader, change that to an explicit
// one as they are not allowed in SPIR-V outside fragment shaders.
bool makeLodExplicit =
mCompiler->getShaderType() != GL_FRAGMENT_SHADER && lodIndex == 0 &&
(spirvOp == spv::OpImageSampleImplicitLod || spirvOp == spv::OpImageFetch);
// Apply any necessary fix up.
if (extractImageFromSampledImage && IsSampler(samplerBasicType))
// Get the (non-sampled) image type.
SpirvType imageType = mBuilder.getSpirvType(samplerType, {});
imageType.isSamplerBaseImage = true;
const spirv::IdRef extractedImageTypeId = mBuilder.getSpirvTypeData(imageType, nullptr).id;
// Use OpImage to get the image out of the sampled image.
const spirv::IdRef extractedImage = mBuilder.getNewId({});
spirv::WriteImage(mBuilder.getSpirvCurrentFunctionBlock(), extractedImageTypeId,
extractedImage, image);
image = extractedImage;
// Gather operands as necessary.
// - Coordinates
int coordinatesChannelCount = 0;
spirv::IdRef coordinatesId;
const TType *coordinatesType = nullptr;
if (coordinatesIndex > 0)
coordinatesId = parameters[coordinatesIndex];
coordinatesType = &function->getParam(coordinatesIndex)->getType();
coordinatesChannelCount = coordinatesType->getNominalSize();
// - Dref; either specified as a compare/refz argument (cube array, gather), or:
// * coordinates.z for proj variants
// * coordinates.<last> for others
spirv::IdRef drefId;
if (compareIndex > 0)
drefId = parameters[compareIndex];
else if (isDref)
// Get the component index
ASSERT(coordinatesChannelCount > 0);
int drefComponent = isProj ? 2 : coordinatesChannelCount - 1;
// Get the component type
SpirvType drefSpirvType = mBuilder.getSpirvType(*coordinatesType, {});
drefSpirvType.primarySize = 1;
const spirv::IdRef drefTypeId = mBuilder.getSpirvTypeData(drefSpirvType, nullptr).id;
// Extract the dref component out of coordinates.
drefId = mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), drefTypeId, drefId,
coordinatesId, {spirv::LiteralInteger(drefComponent)});
// - Gather component
spirv::IdRef gatherComponentId;
if (gatherComponentIndex > 0)
gatherComponentId = parameters[gatherComponentIndex];
else if (spirvOp == spv::OpImageGather)
// If comp is not specified, component 0 is taken as default.
gatherComponentId = mBuilder.getIntConstant(0);
// - Image write data
spirv::IdRef dataId;
if (dataIndex > 0)
dataId = parameters[dataIndex];
// - Other operands
spv::ImageOperandsMask operandsMask = spv::ImageOperandsMaskNone;
spirv::IdRefList imageOperandsList;
if (biasIndex > 0)
operandsMask = operandsMask | spv::ImageOperandsBiasMask;
if (lodIndex > 0)
operandsMask = operandsMask | spv::ImageOperandsLodMask;
else if (makeLodExplicit)
// If the implicit-lod variant is used outside fragment shaders, switch to explicit and use
// lod 0.
operandsMask = operandsMask | spv::ImageOperandsLodMask;
imageOperandsList.push_back(spirvOp == spv::OpImageFetch ? mBuilder.getUintConstant(0)
: mBuilder.getFloatConstant(0));
if (dPdxIndex > 0)
ASSERT(dPdyIndex > 0);
operandsMask = operandsMask | spv::ImageOperandsGradMask;
if (offsetIndex > 0)
// Non-const offsets require the ImageGatherExtended feature.
if (node->getChildNode(offsetIndex)->getAsTyped()->hasConstantValue())
operandsMask = operandsMask | spv::ImageOperandsConstOffsetMask;
ASSERT(spirvOp == spv::OpImageGather);
operandsMask = operandsMask | spv::ImageOperandsOffsetMask;
if (offsetsIndex > 0)
operandsMask = operandsMask | spv::ImageOperandsConstOffsetsMask;
if (sampleIndex > 0)
operandsMask = operandsMask | spv::ImageOperandsSampleMask;
const spv::ImageOperandsMask *imageOperands =
imageOperandsList.empty() ? nullptr : &operandsMask;
// GLSL and SPIR-V are different in the way the projective component is specified:
// In GLSL:
// > The texture coordinates consumed from P, not including the last component of P, are divided
// > by the last component of P.
// In SPIR-V, there's a similar language (division by last element), but with the following
// added:
// > ... all unused components will appear after all used components.
// So for example for textureProj(sampler, vec4 P), the projective coordinates are P.xy/P.w,
// where P.z is ignored. In SPIR-V instead that would be P.xy/P.z and P.w is ignored.
if (isProj)
int requiredChannelCount = coordinatesChannelCount;
// texture*Proj* operate on the following parameters:
// - sampler1D, vec2 P
// - sampler1D, vec4 P
// - sampler2D, vec3 P
// - sampler2D, vec4 P
// - sampler2DRect, vec3 P
// - sampler2DRect, vec4 P
// - sampler3D, vec4 P
// - sampler1DShadow, vec4 P
// - sampler2DShadow, vec4 P
// - sampler2DRectShadow, vec4 P
// Of these cases, only (sampler1D*, vec4 P) and (sampler2D*, vec4 P) require moving the
// proj channel from .w to the appropriate location (.y for 1D and .z for 2D).
if (IsSampler2D(samplerBasicType))
requiredChannelCount = 3;
else if (IsSampler1D(samplerBasicType))
requiredChannelCount = 2;
if (requiredChannelCount != coordinatesChannelCount)
ASSERT(coordinatesChannelCount == 4);
// Get the component type
SpirvType spirvType = mBuilder.getSpirvType(*coordinatesType, {});
const spirv::IdRef coordinatesTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
spirvType.primarySize = 1;
const spirv::IdRef channelTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
// Extract the last component out of coordinates.
const spirv::IdRef projChannelId =
spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), channelTypeId,
projChannelId, coordinatesId,
{spirv::LiteralInteger(coordinatesChannelCount - 1)});
// Insert it after the channels that are consumed. The extra channels are ignored per
// the SPIR-V spec.
const spirv::IdRef newCoordinatesId =
spirv::WriteCompositeInsert(mBuilder.getSpirvCurrentFunctionBlock(), coordinatesTypeId,
newCoordinatesId, coordinatesId, projChannelId,
{spirv::LiteralInteger(requiredChannelCount - 1)});
coordinatesId = newCoordinatesId;
// Select the correct sample Op based on whether the Proj, Dref or Explicit variants are used.
if (spirvOp == spv::OpImageSampleImplicitLod)
const bool isExplicitLod = lodIndex != 0 || makeLodExplicit || dPdxIndex != 0;
if (isDref)
if (isProj)
spirvOp = isExplicitLod ? spv::OpImageSampleProjDrefExplicitLod
: spv::OpImageSampleProjDrefImplicitLod;
spirvOp = isExplicitLod ? spv::OpImageSampleDrefExplicitLod
: spv::OpImageSampleDrefImplicitLod;
if (isProj)
spirvOp = isExplicitLod ? spv::OpImageSampleProjExplicitLod
: spv::OpImageSampleProjImplicitLod;
spirvOp =
isExplicitLod ? spv::OpImageSampleExplicitLod : spv::OpImageSampleImplicitLod;
if (spirvOp == spv::OpImageGather && isDref)
spirvOp = spv::OpImageDrefGather;
spirv::IdRef result;
if (spirvOp != spv::OpImageWrite)
result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
switch (spirvOp)
case spv::OpImageQuerySizeLod:
ASSERT(imageOperandsList.size() == 1);
spirv::WriteImageQuerySizeLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
result, image, imageOperandsList[0]);
case spv::OpImageQuerySize:
spirv::WriteImageQuerySize(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
result, image);
case spv::OpImageQueryLod:
spirv::WriteImageQueryLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
image, coordinatesId);
case spv::OpImageQueryLevels:
spirv::WriteImageQueryLevels(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
result, image);
case spv::OpImageQuerySamples:
spirv::WriteImageQuerySamples(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
result, image);
case spv::OpImageSampleImplicitLod:
resultTypeId, result, image, coordinatesId,
imageOperands, imageOperandsList);
case spv::OpImageSampleExplicitLod:
resultTypeId, result, image, coordinatesId,
*imageOperands, imageOperandsList);
case spv::OpImageSampleDrefImplicitLod:
resultTypeId, result, image, coordinatesId,
drefId, imageOperands, imageOperandsList);
case spv::OpImageSampleDrefExplicitLod:
resultTypeId, result, image, coordinatesId,
drefId, *imageOperands, imageOperandsList);
case spv::OpImageSampleProjImplicitLod:
resultTypeId, result, image, coordinatesId,
imageOperands, imageOperandsList);
case spv::OpImageSampleProjExplicitLod:
resultTypeId, result, image, coordinatesId,
*imageOperands, imageOperandsList);
case spv::OpImageSampleProjDrefImplicitLod:
resultTypeId, result, image, coordinatesId,
drefId, imageOperands, imageOperandsList);
case spv::OpImageSampleProjDrefExplicitLod:
resultTypeId, result, image, coordinatesId,
drefId, *imageOperands, imageOperandsList);
case spv::OpImageFetch:
spirv::WriteImageFetch(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
image, coordinatesId, imageOperands, imageOperandsList);
case spv::OpImageGather:
spirv::WriteImageGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
image, coordinatesId, gatherComponentId, imageOperands,
case spv::OpImageDrefGather:
spirv::WriteImageDrefGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
result, image, coordinatesId, drefId, imageOperands,
case spv::OpImageRead:
spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
image, coordinatesId, imageOperands, imageOperandsList);
case spv::OpImageWrite:
spirv::WriteImageWrite(mBuilder.getSpirvCurrentFunctionBlock(), image, coordinatesId,
dataId, imageOperands, imageOperandsList);
// In Desktop GLSL, the legacy shadow* built-ins produce a vec4, while SPIR-V
// OpImageSample*Dref* instructions produce a scalar. EXT_shadow_samplers in ESSL introduces
// similar functions but which return a scalar.
// TODO: For desktop GLSL, the result must be turned into a vec4.
return result;
// Emits the GLSL.std.450 extended instruction implementing interpolateAtCentroid,
// interpolateAtSample or interpolateAtOffset and returns the id of its result.
//
// The interpolant (first argument) is handed to the instruction by *pointer*, so its access chain
// is collapsed rather than loaded.  interpolateAtSample and interpolateAtOffset take one more
// argument (the sample index or the offset), which is loaded as a regular value.
spirv::IdRef OutputSPIRVTraverser::createInterpolate(TIntermOperator *node,
                                                     spirv::IdRef resultTypeId)
{
    const size_t argCount = node->getChildCount();

    // Gather the instruction's operands in order.
    spirv::IdRefList operands;
    operands.push_back(accessChainCollapse(&mNodeData[mNodeData.size() - argCount]));
    if (argCount > 1)
    {
        const TType &secondArgType = node->getChildNode(1)->getAsTyped()->getType();
        operands.push_back(accessChainLoad(&mNodeData.back(), secondArgType, nullptr));
    }

    // Map the GLSL built-in to its GLSL.std.450 counterpart.
    spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
    switch (node->getOp())
    {
        case EOpInterpolateAtCentroid:
            extendedInst = spv::GLSLstd450InterpolateAtCentroid;
            break;
        case EOpInterpolateAtSample:
            extendedInst = spv::GLSLstd450InterpolateAtSample;
            break;
        case EOpInterpolateAtOffset:
            extendedInst = spv::GLSLstd450InterpolateAtOffset;
            break;
        default:
            UNREACHABLE();
    }

    const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
    spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
                        mBuilder.getExtInstImportIdStd(),
                        spirv::LiteralExtInstInteger(extendedInst), operands);
    return result;
}
// Converts |value|, of GLSL type |valueType|, to the same scalar/vector shape with
// |expectedBasicType| as the component type, and returns the id of the converted value.  If the
// basic types already match, |value| is returned unchanged and |resultTypeIdOut| is not written.
//
// The instruction depends on the source and destination types:
//
// - float -> int/uint:      OpConvertFToS / OpConvertFToU
// - int/uint -> float:      OpConvertSToF / OpConvertUToF
// - int <-> uint:           OpBitcast
// - float -> bool:          OpFUnordNotEqual against 0
// - int/uint -> bool:       OpINotEqual against 0
// - bool -> int/uint/float: OpSelect between 1 and 0
//
// If |resultTypeIdOut| is not null, it receives the type id of the converted value.
spirv::IdRef OutputSPIRVTraverser::castBasicType(spirv::IdRef value,
                                                 const TType &valueType,
                                                 TBasicType expectedBasicType,
                                                 spirv::IdRef *resultTypeIdOut)
{
    const TBasicType sourceBasicType = valueType.getBasicType();
    if (sourceBasicType == expectedBasicType)
    {
        return value;
    }

    // Same shape as |valueType| with a different component type.  The converted value is a plain
    // temporary, so it is never the uint-emulated bool of an interface block.
    SpirvType castSpirvType                            = mBuilder.getSpirvType(valueType, {});
    castSpirvType.type                                 = expectedBasicType;
    castSpirvType.typeSpec.isOrHasBoolInInterfaceBlock = false;

    const spirv::IdRef castTypeId = mBuilder.getSpirvTypeData(castSpirvType, nullptr).id;
    const spirv::IdRef castValue  = mBuilder.getNewId(mBuilder.getDecorations(valueType));
    const auto componentCount     = valueType.getNominalSize();

    // One of these is chosen below, according to how many operands the conversion takes.
    WriteUnaryOp writeUnaryOp     = nullptr;
    WriteBinaryOp writeBinaryOp   = nullptr;
    WriteTernaryOp writeTernaryOp = nullptr;

    // Constants for the "x != 0" comparison and "b ? 1 : 0" selection forms.
    spirv::IdRef zero;
    spirv::IdRef one;

    switch (sourceBasicType)
    {
        case EbtBool:
            writeTernaryOp = spirv::WriteSelect;
            switch (expectedBasicType)
            {
                case EbtFloat:
                    zero = mBuilder.getVecConstant(0, componentCount);
                    one  = mBuilder.getVecConstant(1, componentCount);
                    break;
                case EbtInt:
                    zero = mBuilder.getIvecConstant(0, componentCount);
                    one  = mBuilder.getIvecConstant(1, componentCount);
                    break;
                case EbtUInt:
                    zero = mBuilder.getUvecConstant(0, componentCount);
                    one  = mBuilder.getUvecConstant(1, componentCount);
                    break;
                default:
                    UNREACHABLE();
            }
            break;

        case EbtFloat:
            switch (expectedBasicType)
            {
                case EbtBool:
                    zero          = mBuilder.getVecConstant(0, componentCount);
                    writeBinaryOp = spirv::WriteFUnordNotEqual;
                    break;
                case EbtInt:
                    writeUnaryOp = spirv::WriteConvertFToS;
                    break;
                case EbtUInt:
                    writeUnaryOp = spirv::WriteConvertFToU;
                    break;
                default:
                    UNREACHABLE();
            }
            break;

        case EbtInt:
        case EbtUInt:
            switch (expectedBasicType)
            {
                case EbtBool:
                    zero          = mBuilder.getUvecConstant(0, componentCount);
                    writeBinaryOp = spirv::WriteINotEqual;
                    break;
                case EbtFloat:
                    if (sourceBasicType == EbtInt)
                    {
                        writeUnaryOp = spirv::WriteConvertSToF;
                    }
                    else
                    {
                        writeUnaryOp = spirv::WriteConvertUToF;
                    }
                    break;
                case EbtInt:
                case EbtUInt:
                    // int <-> uint: reinterpret the bits.
                    writeUnaryOp = spirv::WriteBitcast;
                    break;
                default:
                    UNREACHABLE();
            }
            break;

        default:
            // TODO: support desktop GLSL.
            UNIMPLEMENTED();
    }

    if (writeUnaryOp != nullptr)
    {
        writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value);
    }
    else if (writeBinaryOp != nullptr)
    {
        writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, zero);
    }
    else
    {
        ASSERT(writeTernaryOp != nullptr);
        // OpSelect takes |one| where |value| is true and |zero| where it is false.
        writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, one,
                       zero);
    }

    if (resultTypeIdOut != nullptr)
    {
        *resultTypeIdOut = castTypeId;
    }

    return castValue;
}
// Re-creates |value| (of GLSL type |valueType|, laid out per |valueTypeSpec|) as a value of the
// same GLSL type laid out per |expectedTypeSpec|, and returns its id.
//
// This is needed when a value is loaded with one SPIR-V type specialization but is used where the
// same GLSL type is expected with a different specialization.  Examples are assignment through
// operator=, use in a constructor, and passing as a function argument.  SPIR-V 1.4 has
// OpCopyLogical for exactly this, but SPIR-V 1.0 is generated at the moment.  Instead, arrays and
// structs are taken apart recursively and rebuilt with OpCompositeConstruct.  At the leaves, the
// only remaining difference is a bool versus its uint emulation in interface blocks, which is
// converted with castBasicType().
//
// If |resultTypeIdOut| is not null, it receives the type id of the returned value.  It is not
// written when |value| is returned as-is.
spirv::IdRef OutputSPIRVTraverser::cast(spirv::IdRef value,
                                        const TType &valueType,
                                        const SpirvTypeSpec &valueTypeSpec,
                                        const SpirvTypeSpec &expectedTypeSpec,
                                        spirv::IdRef *resultTypeIdOut)
{
    const bool isSameSpecialization =
        valueTypeSpec.blockStorage == expectedTypeSpec.blockStorage &&
        valueTypeSpec.isInvariantBlock == expectedTypeSpec.isInvariantBlock &&
        valueTypeSpec.isRowMajorQualifiedBlock == expectedTypeSpec.isRowMajorQualifiedBlock &&
        valueTypeSpec.isRowMajorQualifiedArray == expectedTypeSpec.isRowMajorQualifiedArray &&
        valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock;
    if (isSameSpecialization)
    {
        // Nothing to cast.
        return value;
    }

    // Interface blocks cannot be copied or passed as parameters in GLSL.
    ASSERT(!valueType.isInterfaceBlock());

    const bool isArray          = valueType.isArray();
    const TStructure *structure = valueType.getStruct();

    // Leaf: a non-array, non-struct value.  Only bool differs here; it is emulated with uint in
    // interface blocks.
    if (!isArray && structure == nullptr)
    {
        ASSERT(valueType.getBasicType() == EbtBool);
        ASSERT(valueTypeSpec.isOrHasBoolInInterfaceBlock ||
               expectedTypeSpec.isOrHasBoolInInterfaceBlock);

        if (!valueTypeSpec.isOrHasBoolInInterfaceBlock)
        {
            // A real bool that is about to be stored: change it to uint first.
            return castBasicType(value, valueType, EbtUInt, resultTypeIdOut);
        }

        // The value was loaded as uint: change it back to bool.
        TType emulatedValueType(valueType);
        emulatedValueType.setBasicType(EbtUInt);
        return castBasicType(value, emulatedValueType, EbtBool, resultTypeIdOut);
    }

    spirv::IdRefList constituents;

    if (isArray)
    {
        // Every element has the same type, so derive the element type and its specializations once.
        const bool isElementBlock = structure != nullptr;
        const bool isElementArray = valueType.isArrayOfArrays();

        SpirvTypeSpec valueElementTypeSpec    = valueTypeSpec;
        SpirvTypeSpec expectedElementTypeSpec = expectedTypeSpec;
        valueElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
        expectedElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);

        TType elementType(valueType);
        elementType.toArrayElementType();

        const spirv::IdRef elementTypeId =
            mBuilder.getTypeData(elementType, valueElementTypeSpec).id;
        const SpirvDecorations elementDecorations = mBuilder.getDecorations(elementType);

        // Extract each element and cast it to the expected type.
        const unsigned int elementCount = valueType.getOutermostArraySize();
        for (unsigned int index = 0; index < elementCount; ++index)
        {
            const spirv::IdRef elementId = mBuilder.getNewId(elementDecorations);
            spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), elementTypeId,
                                         elementId, value, {spirv::LiteralInteger(index)});
            constituents.push_back(cast(elementId, elementType, valueElementTypeSpec,
                                        expectedElementTypeSpec, nullptr));
        }
    }
    else
    {
        // Extract each field and cast it to the expected type.
        uint32_t fieldIndex = 0;
        for (const TField *field : structure->fields())
        {
            const TType &fieldType = *field->type();

            // Specialize the parent's type spec for this field.
            SpirvTypeSpec valueFieldTypeSpec    = valueTypeSpec;
            SpirvTypeSpec expectedFieldTypeSpec = expectedTypeSpec;
            valueFieldTypeSpec.onBlockFieldSelection(fieldType);
            expectedFieldTypeSpec.onBlockFieldSelection(fieldType);

            const spirv::IdRef fieldTypeId = mBuilder.getTypeData(fieldType, valueFieldTypeSpec).id;
            const spirv::IdRef fieldId     = mBuilder.getNewId(mBuilder.getDecorations(fieldType));
            spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId,
                                         fieldId, value, {spirv::LiteralInteger(fieldIndex)});
            ++fieldIndex;

            constituents.push_back(
                cast(fieldId, fieldType, valueFieldTypeSpec, expectedFieldTypeSpec, nullptr));
        }
    }

    // Rebuild the value in the expected type from its cast constituents.
    const spirv::IdRef expectedTypeId = mBuilder.getTypeData(valueType, expectedTypeSpec).id;
    const spirv::IdRef expectedId     = mBuilder.getNewId(mBuilder.getDecorations(valueType));
    spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId,
                                   expectedId, constituents);

    if (resultTypeIdOut != nullptr)
    {
        *resultTypeIdOut = expectedTypeId;
    }

    return expectedId;
}
// Reduces the booleans in |valueIds| to a single bool.  EOpEqual gives the logical AND of them
// and EOpNotEqual the logical OR.  The result has type |typeId| and is decorated with
// |decorations|.
//
// |valueIds| is interpreted as follows:
//
// - One id: a bool vector, reduced with OpAll (EOpEqual) or OpAny (EOpNotEqual).
// - Two ids: two bools combined with OpLogicalAnd / OpLogicalOr.
// - More ids: bools first gathered into a bool vector with OpCompositeConstruct, then reduced with
//   OpAll / OpAny.
//
// NOTE(review): SPIR-V vectors have at most 4 components without the Vector16 capability, so
// callers presumably never pass more than 4 ids.
spirv::IdRef OutputSPIRVTraverser::reduceBoolVector(TOperator op,
                                                    const spirv::IdRefList &valueIds,
                                                    spirv::IdRef typeId,
                                                    const SpirvDecorations &decorations)
{
    const bool isEqual = op == EOpEqual;

    // Two bools: a single logical operation is enough.
    if (valueIds.size() == 2)
    {
        const WriteBinaryOp combine = isEqual ? spirv::WriteLogicalAnd : spirv::WriteLogicalOr;
        const spirv::IdRef combined = mBuilder.getNewId(decorations);
        combine(mBuilder.getSpirvCurrentFunctionBlock(), typeId, combined, valueIds[0],
                valueIds[1]);
        return combined;
    }

    // Otherwise OpAll/OpAny is applied to a bool vector.  A single id is expected to be that
    // vector already.  With more ids, the vector is built from them first.
    spirv::IdRef vectorId = valueIds[0];
    if (valueIds.size() > 2)
    {
        const spirv::IdRef bvecTypeId = mBuilder.getBasicTypeId(EbtBool, valueIds.size());
        vectorId                      = mBuilder.getNewId(decorations);
        spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), bvecTypeId,
                                       vectorId, valueIds);
    }

    const WriteUnaryOp reduce  = isEqual ? spirv::WriteAll : spirv::WriteAny;
    const spirv::IdRef reduced = mBuilder.getNewId(decorations);
    reduce(mBuilder.getSpirvCurrentFunctionBlock(), typeId, reduced, vectorId);
    return reduced;
}
// Recursively compares the parts of |leftId| and |rightId| selected by |currentAccessChain| (of
// type |operandType|) with |op|, which is EOpEqual or EOpNotEqual.  One bool result is appended
// to |intermediateResultsOut| per scalar/vector comparison.
//
// Arrays, structs and matrices are walked element by element, field by field and column by
// column.  Each recursion level appends one index to |currentAccessChain| and removes it when
// done, so the chain is left as it was found.  Scalars and vectors are extracted from the
// composite with OpCompositeExtract and compared with a single instruction.  The extraction is
// skipped when the chain is empty, i.e. when the operands themselves are compared.  Vector
// results are reduced to one bool with reduceBoolVector().
void OutputSPIRVTraverser::createCompareImpl(TOperator op,
                                             const TType &operandType,
                                             spirv::IdRef resultTypeId,
                                             spirv::IdRef leftId,
                                             spirv::IdRef rightId,
                                             const SpirvDecorations &operandDecorations,
                                             const SpirvDecorations &resultDecorations,
                                             spirv::LiteralIntegerList *currentAccessChain,
                                             spirv::IdRefList *intermediateResultsOut)
{
    // Compares constituent |index| (of type |constituentType|) of the current composite.  The
    // caller must have reserved the last access chain slot for the index.
    auto compareConstituent = [&](uint32_t index, const TType &constituentType) {
        currentAccessChain->back() = spirv::LiteralInteger(index);
        createCompareImpl(op, constituentType, resultTypeId, leftId, rightId, operandDecorations,
                          resultDecorations, currentAccessChain, intermediateResultsOut);
    };

    // Arrays: element by element.
    if (operandType.isArray())
    {
        TType elementType(operandType);
        elementType.toArrayElementType();

        currentAccessChain->push_back(spirv::LiteralInteger(0));
        for (unsigned int elementIndex = 0; elementIndex < operandType.getOutermostArraySize();
             ++elementIndex)
        {
            compareConstituent(elementIndex, elementType);
        }
        currentAccessChain->pop_back();
        return;
    }

    // Structs: field by field.
    if (const TStructure *structure = operandType.getStruct())
    {
        uint32_t fieldIndex = 0;

        currentAccessChain->push_back(spirv::LiteralInteger(0));
        for (const TField *field : structure->fields())
        {
            compareConstituent(fieldIndex++, *field->type());
        }
        currentAccessChain->pop_back();
        return;
    }

    // Matrices: column by column.
    if (operandType.isMatrix())
    {
        TType columnType(operandType);
        columnType.toMatrixColumnType();

        currentAccessChain->push_back(spirv::LiteralInteger(0));
        for (int columnIndex = 0; columnIndex < operandType.getCols(); ++columnIndex)
        {
            compareConstituent(static_cast<uint32_t>(columnIndex), columnType);
        }
        currentAccessChain->pop_back();
        return;
    }

    // Scalars and vectors are compared with a single instruction.
    const TBasicType basicType = operandType.getBasicType();
    const bool isFloat         = basicType == EbtFloat || basicType == EbtDouble;
    const bool isBool          = basicType == EbtBool;
    const bool isEqual         = op == EOpEqual;
    ASSERT(isEqual || op == EOpNotEqual);

    // With NaN operands, FOrdEqual yields false and FUnordNotEqual yields true, so == and != stay
    // complementary.
    WriteBinaryOp writeBinaryOp = nullptr;
    if (isFloat)
    {
        writeBinaryOp = isEqual ? spirv::WriteFOrdEqual : spirv::WriteFUnordNotEqual;
    }
    else if (isBool)
    {
        writeBinaryOp = isEqual ? spirv::WriteLogicalEqual : spirv::WriteLogicalNotEqual;
    }
    else
    {
        writeBinaryOp = isEqual ? spirv::WriteIEqual : spirv::WriteINotEqual;
    }

    // Pull the scalar/vector out of the composite operands, if inside one.
    spirv::IdRef leftComponentId  = leftId;
    spirv::IdRef rightComponentId = rightId;
    if (!currentAccessChain->empty())
    {
        leftComponentId  = mBuilder.getNewId(operandDecorations);
        rightComponentId = mBuilder.getNewId(operandDecorations);

        const spirv::IdRef componentTypeId =
            mBuilder.getBasicTypeId(basicType, operandType.getNominalSize());

        spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
                                     leftComponentId, leftId, *currentAccessChain);
        spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
                                     rightComponentId, rightId, *currentAccessChain);
    }

    // A vector comparison yields a bool vector that is then reduced to a single bool.
    const bool isVector = !operandType.isScalar();
    spirv::IdRef result = mBuilder.getNewId({});
    const spirv::IdRef compareTypeId =
        isVector ? mBuilder.getBasicTypeId(EbtBool, operandType.getNominalSize()) : resultTypeId;

    writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), compareTypeId, result, leftComponentId,
                  rightComponentId);

    if (isVector)
    {
        result = reduceBoolVector(op, {result}, resultTypeId, resultDecorations);
    }

    intermediateResultsOut->push_back(result);
}
// Returns the SPIR-V type id of the two-field struct { lsb; msb; } that SPIR-V uses as the result
// of built-ins producing two values (forms listed below).  The type is created on first request and
// cached in mBuiltInResultStructMap, so later requests with an equal key reuse the same id.
//
// |lvalueCount| selects the form.  It decides whether the lsb type comes from the node's own
// (return) type or from its last argument.
spirv::IdRef OutputSPIRVTraverser::makeBuiltInOutputStructType(TIntermOperator *node,
size_t lvalueCount)
// The built-ins with lvalues are in one of the following forms:
// - lsb = builtin(..., out msb): These are identified by lvalueCount == 1
// - builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
// In SPIR-V, the result of all these instructions is a struct { lsb; msb; }.
const size_t childCount = node->getChildCount();
ASSERT(childCount >= 2);
TIntermTyped *lastChild = node->getChildNode(childCount - 1)->getAsTyped();
TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
// Pick the lsb/msb types according to the form of the built-in.
const TType &lsbType = lvalueCount == 1 ? node->getType() : lastChild->getType();
const TType &msbType = lvalueCount == 1 ? lastChild->getType() : beforeLastChild->getType();
ASSERT(lsbType.isScalar() || lsbType.isVector());
ASSERT(msbType.isScalar() || msbType.isVector());
// Cache lookup.  NOTE(review): presumably the key records the basic type and vector size of lsb
// and msb; confirm against the BuiltInResultStruct definition.
const BuiltInResultStruct key = {
auto iter = mBuiltInResultStructMap.find(key);
if (iter == mBuiltInResultStructMap.end())
// Create a TStructure and TType for the required structure.
// Only the basic type and nominal (vector) size of the original types are copied into the field
// types.  NOTE(review): these objects are allocated with plain new and not freed here; presumably
// ANGLE's pool allocator reclaims them.  Confirm.
TType *lsbTypeCopy = new TType(lsbType.getBasicType(),
static_cast<unsigned char>(lsbType.getNominalSize()), 1);
TType *msbTypeCopy = new TType(msbType.getBasicType(),
static_cast<unsigned char>(msbType.getNominalSize()), 1);
TFieldList *fields = new TFieldList;
new TField(lsbTypeCopy, ImmutableString("lsb"), {}, SymbolType::AngleInternal));
new TField(msbTypeCopy, ImmutableString("msb"), {}, SymbolType::AngleInternal));
TStructure *structure =
new TStructure(&mCompiler->getSymbolTable(), ImmutableString("BuiltInResultType"),
fields, SymbolType::AngleInternal);
TType structType(structure, true);
// Get an id for the type and store in the hash map.
const spirv::IdRef structTypeId = mBuilder.getTypeData(structType, {}).id;
iter = mBuiltInResultStructMap.insert({key, structTypeId}).first;
return iter->second;
// After the built-in instruction is generated, its two results are extracted from the returned
// { lsb; msb; } struct (|structValue|).  They are written to the return value (if any) and to the
// out parameters:
//
// - lsb = builtin(..., out msb)    (lvalueCount == 1): lsb is extracted into |returnValue|, of
//                                   type |returnValueType|.  msb is stored through the last
//                                   argument.
// - builtin(..., out msb, out lsb) (lvalueCount == 2): lsb is stored through the last argument
//                                   and msb through the one before it.
//
// The out arguments' lvalues are expected at the top of mNodeData, with the last argument's entry
// on top.
void OutputSPIRVTraverser::storeBuiltInStructOutputInParamsAndReturnValue(
    TIntermOperator *node,
    size_t lvalueCount,
    spirv::IdRef structValue,
    spirv::IdRef returnValue,
    spirv::IdRef returnValueType)
{
    // Field indices within the result struct.
    constexpr uint32_t kLsbField = 0;
    constexpr uint32_t kMsbField = 1;

    const size_t childCount = node->getChildCount();
    ASSERT(childCount >= 2);

    TIntermTyped *lastChild = node->getChildNode(childCount - 1)->getAsTyped();

    if (lvalueCount == 1)
    {
        // lsb is the built-in's return value; msb goes to the single out parameter.
        spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), returnValueType,
                                     returnValue, structValue, {spirv::LiteralInteger(kLsbField)});
        storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
                                              structValue, kMsbField);
        return;
    }

    ASSERT(lvalueCount == 2);
    TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();

    // lsb goes to the last out parameter, msb to the one before it.
    storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild, structValue,
                                          kLsbField);
    storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 2], beforeLastChild,
                                          structValue, kMsbField);
}
// Extracts field |fieldIndex| of |structValue| and stores it through the out parameter |param|.
// The parameter's lvalue access chain is held in |data|.  The field's type is taken to be the
// parameter's type.
void OutputSPIRVTraverser::storeBuiltInStructOutputInParamHelper(NodeData *data,
                                                                 TIntermTyped *param,
                                                                 spirv::IdRef structValue,
                                                                 uint32_t fieldIndex)
{
    const TType &paramType = param->getType();

    const spirv::IdRef fieldTypeId  = mBuilder.getTypeData(paramType, {}).id;
    const spirv::IdRef fieldValueId = mBuilder.getNewId(mBuilder.getDecorations(paramType));

    spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId, fieldValueId,
                                 structValue, {spirv::LiteralInteger(fieldIndex)});

    accessChainStore(data, fieldValueId, paramType);
}
// Pushes a node data entry for |node| and initializes it:
//
// - const variables (such as const function parameters) and specialization constants become an
//   rvalue holding the id recorded for the symbol in mSymbolIdMap;
// - anything else (local variables, varyings, interface blocks) becomes an lvalue of the symbol,
//   with the storage class reported by getSymbolIdAndStorageClass();
// - for a field of a nameless interface block, the lvalue is the block and the field is selected
//   through the access chain.
//
// Specialization constants are thus treated much like constants in visitConstantUnion().  Symbols
// that are being declared are skipped, because visitDeclaration handles them.
void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
{
    // Constants are expected to be folded.
    ASSERT(!node->hasConstantValue());

    if (mIsSymbolBeingDeclared)
    {
        // No-op; visitDeclaration handles this symbol.  Reset the flag so it does not affect other
        // symbols, for example in the initializer expression.
        mIsSymbolBeingDeclared = false;
        return;
    }

    mNodeData.emplace_back();

    const TType &type = node->getType();

    // Interface blocks, and fields of nameless ones, are identified by the block symbol.  Everything
    // else is identified by its variable.
    const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
    const TSymbol *symbol                 = &node->variable();
    if (interfaceBlock != nullptr)
    {
        symbol = interfaceBlock;
    }

    // The type specialization is tracked next to the GLSL type.  It is needed to derive the SPIR-V
    // types along an access chain, but is not promoted into intermediate nodes' TTypes.
    SpirvTypeSpec typeSpec;
    typeSpec.inferDefaults(type, mCompiler);
    const spirv::IdRef typeId = mBuilder.getTypeData(type, typeSpec).id;

    const TQualifier qualifier = type.getQualifier();
    if (qualifier == EvqConst || qualifier == EvqSpecConst)
    {
        ASSERT(interfaceBlock == nullptr);
        ASSERT(mSymbolIdMap.count(symbol) > 0);
        nodeDataInitRValue(&mNodeData.back(), mSymbolIdMap[symbol], typeId);
        return;
    }

    spv::StorageClass storageClass;
    const spirv::IdRef symbolId = getSymbolIdAndStorageClass(symbol, type, &storageClass);
    nodeDataInitLValue(&mNodeData.back(), symbolId, typeId, storageClass, typeSpec);

    // A field of a nameless interface block is selected out of the block variable.
    if (interfaceBlock != nullptr && !type.isInterfaceBlock())
    {
        const uint32_t fieldIndex = static_cast<uint32_t>(type.getInterfaceBlockFieldIndex());
        accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(fieldIndex), typeId);
    }
}
// Creates the SPIR-V constant for a constant-union node and records it as an rvalue in the node
// data entry on top of mNodeData.
//
// The constant is created directly in the basic type its parent expects (see below), so no
// separate conversion instruction is needed.  Only aggregate parents are examined.  Otherwise the
// constant keeps its own basic type.
void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
const TType &type = node->getType();
// Find out the expected type for this constant, so it can be cast right away and not need an
// instruction to do that.
TIntermNode *parent = getParentNode();
const size_t childIndex = getParentChildIndex(PreVisit);
// Defaults to the constant's own basic type.
TBasicType expectedBasicType = type.getBasicType();
if (parent->getAsAggregate())
TIntermAggregate *parentAggregate = parent->getAsAggregate();
// There are three possibilities:
// - It's a struct constructor: The basic type must match that of the corresponding field of
// the struct.
// - It's a non struct constructor: The basic type must match that of the type being
// constructed.
// - It's a function call: The basic type must match that of the corresponding argument.
if (parentAggregate->isConstructor())
const TStructure *structure = parentAggregate->getType().getStruct();
if (structure != nullptr)
expectedBasicType = structure->fields()[childIndex]->type()->getBasicType();
expectedBasicType = parentAggregate->getType().getBasicType();
expectedBasicType =
// TODO: also derive the expected type from other parent node types, such as binary and ternary.
// Note that |typeId| is that of the node's own type, while the constant itself is created with
// |expectedBasicType|.
const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
const spirv::IdRef constId = createConstant(type, expectedBasicType, node->getConstantValue(),
nodeDataInitRValue(&mNodeData.back(), constId, typeId);
// Handles a swizzle by appending it to the access chain of the operand's node data entry.  That
// entry becomes this node's entry; no new entry is pushed.  A swizzle that selects every component
// of the vector in order is a no-op and leaves the access chain untouched.
bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
{
    // Constants are expected to be folded.
    ASSERT(!node->hasConstantValue());

    // Don't push an entry on pre-visit.  The child creates one, which this node reuses and does
    // not pop.
    if (visit == PreVisit)
    {
        return true;
    }

    ASSERT(visit == PostVisit);
    ASSERT(mNodeData.size() >= 1);

    const TType &vectorType            = node->getOperand()->getType();
    const uint8_t vectorComponentCount = static_cast<uint8_t>(vectorType.getNominalSize());
    const TVector<int> &swizzle        = node->getSwizzleOffsets();

    // An in-order swizzle of all components (e.g. .xyz on a 3-component vector) needs no code.
    bool isIdentity = swizzle.size() == vectorComponentCount;
    for (size_t component = 0; isIdentity && component < swizzle.size(); ++component)
    {
        isIdentity = static_cast<size_t>(swizzle[component]) == component;
    }
    if (isIdentity)
    {
        return true;
    }

    accessChainOnPush(&mNodeData.back(), vectorType, 0);

    const spirv::IdRef swizzledTypeId =
        mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
    accessChainPushSwizzle(&mNodeData.back(), swizzle, swizzledTypeId, vectorComponentCount);

    return true;
}
// Generates code for a binary expression:
// - EOpIndex* extend the left operand's access chain (literal index for direct indexing, id
//   otherwise) instead of producing a value.
// - EOpAssign stores the right value through the left access chain and leaves that value as the
//   result.
// - EOpComma keeps only the right operand's entry.
// - Short-circuiting && and || are finished with endShortCircuit.
// - Any other operator is evaluated by visitOperator.
// EOpInitialize generates nothing here; visitDeclaration handles it.
bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
// Constants are expected to be folded.
if (visit == PreVisit)
// Don't add an entry to the stack. The left child will create one, which we won't pop.
return true;
// If this is a variable initialization node, defer any code generation to visitDeclaration.
if (node->getOp() == EOpInitialize)
ASSERT(getParentNode()->getAsDeclarationNode() != nullptr);
return true;
if (IsShortCircuitNeeded(node))
// For && and ||, if short-circuiting behavior is needed, we need to emulate it with an
// |if| construct. At this point, the left-hand side is already evaluated, so we need to
// create an appropriate conditional on in-visit and visit the right-hand-side inside the
// conditional block. On post-visit, OpPhi is used to calculate the result.
if (visit == InVisit)
return true;
spirv::IdRef typeId;
const spirv::IdRef result = endShortCircuit(node, &typeId);
// Replace the access chain with an rvalue that's the result.
nodeDataInitRValue(&mNodeData.back(), result, typeId);
return true;
if (visit == InVisit)
// Left child visited. Take the entry it created as the current node's.
ASSERT(mNodeData.size() >= 1);
// As an optimization, if the index is EOpIndexDirect*, take the constant index directly and
// add it to the access chain as literal.
switch (node->getOp())
case EOpIndexDirect:
case EOpIndexDirectStruct:
case EOpIndexDirectInterfaceBlock:
const uint32_t index = node->getRight()->getAsConstantUnion()->getIConst(0);
accessChainOnPush(&mNodeData.back(), node->getLeft()->getType(), index);
const spirv::IdRef typeId =
mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(index), typeId);
// Don't visit the right child, it's already processed.
return false;
return true;
// There are at least two entries, one for the left node and one for the right one.
ASSERT(mNodeData.size() >= 2);
// For indirect indexing and assignment, the result type is derived from the left operand's
// access chain type spec. Other operators use a default-constructed spec.
SpirvTypeSpec resultTypeSpec;
if (node->getOp() == EOpIndexIndirect || node->getOp() == EOpAssign)
if (node->getOp() == EOpIndexIndirect)
accessChainOnPush(&mNodeData[mNodeData.size() - 2], node->getLeft()->getType(), 0);
resultTypeSpec = mNodeData[mNodeData.size() - 2].accessChain.typeSpec;
const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), resultTypeSpec).id;
// For EOpIndex* operations, push the right value as an index to the left value's access chain.
// For the other operations, evaluate the expression.
switch (node->getOp())
case EOpIndexDirect:
case EOpIndexDirectStruct:
case EOpIndexDirectInterfaceBlock:
case EOpIndexIndirect:
// Load the index.
const spirv::IdRef rightValue =
accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
// Dynamically indexing a non-array vector selects a component; everything else is a
// regular access chain index.
if (!node->getLeft()->getType().isArray() && node->getLeft()->getType().isVector())
accessChainPushDynamicComponent(&mNodeData.back(), rightValue, resultTypeId);
accessChainPush(&mNodeData.back(), rightValue, resultTypeId);
case EOpAssign:
// Load the right hand side of assignment.
const spirv::IdRef rightValue =
accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
// Store into the access chain. Since the result of the (a = b) expression is b, change
// the access chain to an unindexed rvalue which is |rightValue|.
accessChainStore(&mNodeData.back(), rightValue, node->getLeft()->getType());
nodeDataInitRValue(&mNodeData.back(), rightValue, resultTypeId);
case EOpComma:
// When the expression a,b is visited, all side effects of a and b are already
// processed. What's left is to replace the expression with the result of b. This
// is simply done by dropping the left node and placing the right node as the result.
mNodeData.erase(mNodeData.begin() + mNodeData.size() - 2);
const spirv::IdRef result = visitOperator(node, resultTypeId);
nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
// TODO: Handle NoContraction decoration.
return true;
// Generates code for a unary expression. The operand's entry on mNodeData is reused for the
// result. EOpArrayLength is special-cased: it emits OpArrayLength followed by an OpBitcast from
// uint to int. All other operators are evaluated by visitOperator.
bool OutputSPIRVTraverser::visitUnary(Visit visit, TIntermUnary *node)
// Constants are expected to be folded.
if (visit == PreVisit)
// Don't add an entry to the stack. The child will create one, which we won't pop.
return true;
// It's a unary operation, so there can't be an InVisit.
ASSERT(visit != InVisit);
// There is at least one entry for the child.
ASSERT(mNodeData.size() >= 1);
// Special case EOpArrayLength. .length() on sized arrays is already constant folded, so this
// operation only applies to ssbo.last_member.length(). OpArrayLength takes the ssbo block
// *type* and the field index of last_member, so those need to be extracted from the access
// chain. Additionally, OpArrayLength produces an unsigned int while GLSL produces an int, so a
// final cast is necessary.
if (node->getOp() == EOpArrayLength)
// The access chain must only include the base ssbo + one literal field index.
ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());
const spirv::IdRef baseId = mNodeData.back().baseId;
const spirv::LiteralInteger fieldIndex = mNodeData.back().idList.back().literal;
// Get the int and uint type ids.
const spirv::IdRef intTypeId = mBuilder.getBasicTypeId(EbtInt, 1);
const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
// Generate the instruction.
const spirv::IdRef resultId = mBuilder.getNewId({});
spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
baseId, fieldIndex);
// Cast to int.
const spirv::IdRef castResultId = mBuilder.getNewId({});
spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId,
// Replace the access chain with an rvalue that's the result.
nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
return true;
const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
const spirv::IdRef result = visitOperator(node, resultTypeId);
// Keep the result as rvalue.
nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
return true;
// Generates code for `cond ? a : b`. When the result is a scalar or vector and the operands have
// no side effects, both operands are evaluated unconditionally and combined with OpSelect.
// Otherwise a three-block if-else construct is emitted and the result is merged with OpPhi.
// The condition's entry on mNodeData is reused: the condition value is kept as its baseId, and
// the two operand values are collected in its idList (asserted below to hold exactly two ids).
bool OutputSPIRVTraverser::visitTernary(Visit visit, TIntermTernary *node)
if (visit == PreVisit)
// Don't add an entry to the stack. The condition will create one, which we won't pop.
return true;
size_t lastChildIndex = getLastTraversedChildIndex(visit);
// If the condition was just visited, evaluate it and decide if OpSelect could be used or an
// if-else must be emitted. OpSelect is only used if the type is scalar or vector (required by
// OpSelect) and if neither side has a side effect.
const TType &type = node->getType();
const bool canUseOpSelect = (type.isScalar() || type.isVector()) &&
!node->getTrueExpression()->hasSideEffects() &&
if (lastChildIndex == 0)
spirv::IdRef typeId;
spirv::IdRef conditionValue =
accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), &typeId);
// If OpSelect can be used, keep the condition for later usage.
if (canUseOpSelect)
// SPIR-V 1.0 requires that the condition value have as many components as the result.
// So when selecting between vectors, we must replicate the condition scalar.
if (type.isVector())
typeId = mBuilder.getBasicTypeId(node->getCondition()->getType().getBasicType(),
conditionValue =
createConstructorVectorFromScalar(type, typeId, {{conditionValue}});
nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
return true;
// Otherwise generate an if-else construct.
// Three blocks necessary; the true, false and merge.
mBuilder.startConditional(3, false, false);
// Generate the branch instructions.
const SpirvConditional *conditional = mBuilder.getCurrentConditional();
const spirv::IdRef trueBlockId = conditional->blockIds[0];
const spirv::IdRef falseBlockId = conditional->blockIds[1];
const spirv::IdRef mergeBlockId = conditional->blockIds.back();
mBuilder.writeBranchConditional(conditionValue, trueBlockId, falseBlockId, mergeBlockId);
nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
return true;
// Load the result of the true or false part, and keep it for the end. It's either used in
// OpSelect or OpPhi.
spirv::IdRef typeId;
const spirv::IdRef value = accessChainLoad(&mNodeData.back(), type, &typeId);
if (!canUseOpSelect)
// Move on to the next block.
// When done, generate either OpSelect or OpPhi.
if (visit == PostVisit)
const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
ASSERT(mNodeData.back().idList.size() == 2);
const spirv::IdRef trueValue = mNodeData.back().idList[0].id;
const spirv::IdRef falseValue = mNodeData.back().idList[1].id;
if (canUseOpSelect)
const spirv::IdRef conditionValue = mNodeData.back().baseId;
spirv::WriteSelect(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
conditionValue, trueValue, falseValue);
// OpPhi: each incoming value is paired with the block it comes from.
const SpirvConditional *conditional = mBuilder.getCurrentConditional();
const spirv::IdRef trueBlockId = conditional->blockIds[0];
const spirv::IdRef falseBlockId = conditional->blockIds[1];
spirv::WritePhi(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
{spirv::PairIdRefIdRef{trueValue, trueBlockId},
spirv::PairIdRefIdRef{falseValue, falseBlockId}});
// Replace the access chain with an rvalue that's the result.
nodeDataInitRValue(&mNodeData.back(), result, typeId);
return true;
// Generates an if/else selection construct. After the condition is visited, a conditional is
// started with one block per present branch plus a merge block, and OpBranchConditional is
// emitted. A missing true or false branch jumps straight to the merge block.
bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
if (visit == PreVisit)
// Don't add an entry to the stack. The condition will create one, which we won't pop.
return true;
const size_t lastChildIndex = getLastTraversedChildIndex(visit);
// If the condition was just visited, evaluate it and create the branch instructions.
if (lastChildIndex == 0)
const spirv::IdRef conditionValue =
accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
// Create a conditional with maximum 3 blocks, one for the true block (if any), one for the
// else block (if any), and one for the merge block. getChildCount() works here as it
// produces an identical count.
mBuilder.startConditional(node->getChildCount(), false, false);
// Generate the branch instructions.
const SpirvConditional *conditional = mBuilder.getCurrentConditional();
const spirv::IdRef mergeBlock = conditional->blockIds.back();
spirv::IdRef trueBlock = mergeBlock;
spirv::IdRef falseBlock = mergeBlock;
size_t nextBlockIndex = 0;
if (node->getTrueBlock())
trueBlock = conditional->blockIds[nextBlockIndex++];
if (node->getFalseBlock())
falseBlock = conditional->blockIds[nextBlockIndex++];
mBuilder.writeBranchConditional(conditionValue, trueBlock, falseBlock, mergeBlock);
return true;
// Otherwise move on to the next block, inserting a branch to the merge block at the end of each
// block.
// Pop from the conditional stack when done.
if (visit == PostVisit)
return true;
// Generates an OpSwitch selection construct. After the switch condition is visited, the
// statement list is scanned to assign each run of statements a block index. Each case label
// maps to the block of the statements that follow it, and the default label's block (if any)
// becomes OpSwitch's default target. Labels with no statements after them fall through to the
// merge block.
bool OutputSPIRVTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
// Take the following switch:
// switch (c)
// {
// case A:
// ABlock;
// break;
// case B:
// default:
// BBlock;
// break;
// case C:
// CBlock;
// // fallthrough
// case D:
// DBlock;
// }
// In SPIR-V, this is implemented similarly to the following pseudo-code:
// switch c:
// A -> jump %A
// B -> jump %B
// C -> jump %C
// D -> jump %D
// default -> jump %B
// %A:
// ABlock
// jump %merge
// %B:
// BBlock
// jump %merge
// %C:
// CBlock
// jump %D
// %D:
// DBlock
// jump %merge
// The OpSwitch instruction contains the jump labels for the default and other cases. Each
// block either terminates with a jump to the merge block or the next block as fallthrough.
// // pre-switch block
// OpSelectionMerge %merge None
// OpSwitch %cond %B A %A B %B C %C D %D
// %A = OpLabel
// ABlock
// OpBranch %merge
// %B = OpLabel
// BBlock
// OpBranch %merge
// %C = OpLabel
// CBlock
// OpBranch %D
// %D = OpLabel
// DBlock
// OpBranch %merge
if (visit == PreVisit)
// Don't add an entry to the stack. The condition will create one, which we won't pop.
return true;
// If the condition was just visited, evaluate it and create the switch instruction.
if (visit == InVisit)
ASSERT(getLastTraversedChildIndex(visit) == 0);
const spirv::IdRef conditionValue =
accessChainLoad(&mNodeData.back(), node->getInit()->getType(), nullptr);
// First, need to find out how many blocks are there in the switch.
const TIntermSequence &statements = *node->getStatementList()->getSequence();
bool lastWasCase = true;
size_t blockIndex = 0;
// size_t's max means "no default label"; such switches default to the merge block.
size_t defaultBlockIndex = std::numeric_limits<size_t>::max();
TVector<uint32_t> caseValues;
TVector<size_t> caseBlockIndices;
for (TIntermNode *statement : statements)
TIntermCase *caseLabel = statement->getAsCaseNode();
const bool isCaseLabel = caseLabel != nullptr;
if (isCaseLabel)
// For every case label, remember its block index. This is used later to generate
// the OpSwitch instruction.
if (caseLabel->hasCondition())
// All switch conditions are literals.
TIntermConstantUnion *condition =
ASSERT(condition != nullptr);
TConstantUnion caseValue;
caseValue.cast(EbtUInt, *condition->getConstantValue());
// Remember the block index of the default case.
defaultBlockIndex = blockIndex;
lastWasCase = true;
else if (lastWasCase)
// Every time a non-case node is visited and the previous statement was a case node,
// it's a new block.
lastWasCase = false;
// Block count is the number of blocks based on cases + 1 for the merge block.
const size_t blockCount = blockIndex + 1;
mBuilder.startConditional(blockCount, false, true);
// Generate the switch instructions.
const SpirvConditional *conditional = mBuilder.getCurrentConditional();
// Generate the list of caseValue->blockIndex mapping used by the OpSwitch instruction. If
// the switch ends in a number of cases with no statements following them, they will
// naturally jump to the merge block!
spirv::PairLiteralIntegerIdRefList switchTargets;
for (size_t caseIndex = 0; caseIndex < caseValues.size(); ++caseIndex)
uint32_t value = caseValues[caseIndex];
size_t caseBlockIndex = caseBlockIndices[caseIndex];
{spirv::LiteralInteger(value), conditional->blockIds[caseBlockIndex]});
const spirv::IdRef mergeBlock = conditional->blockIds.back();
// defaultBlockIndex is a *block* index, so it must be bounded by the number of blocks, not by
// the number of non-default case values. Comparing against caseValues.size() made e.g.
// |case 0: a; default: b;| (default in block 1, one case value) branch straight to the merge
// block and skip |b|. The size_t max sentinel (no default label) still selects the merge block.
const spirv::IdRef defaultBlock = defaultBlockIndex < conditional->blockIds.size()
? conditional->blockIds[defaultBlockIndex]
: mergeBlock;
mBuilder.writeSwitch(conditionValue, defaultBlock, switchTargets, mergeBlock);
return true;
// Terminate the last block if not already and end the conditional.
return true;
// Handles a case/default label inside a switch body. Labels only mark block boundaries: the case
// values were already collected by visitSwitch, so the label's condition is never traversed
// (this function always returns false).
bool OutputSPIRVTraverser::visitCase(Visit visit, TIntermCase *node)
ASSERT(visit == PreVisit);
TIntermBlock *parent = getParentNode()->getAsBlock();
const size_t childIndex = getParentChildIndex(PreVisit);
const TIntermSequence &parentStatements = *parent->getSequence();
// Check the previous statement. If it was not a |case|, then a new block is being started so
// handle fallthrough:
// ...
// statement;
// case X: <--- end the previous block here
// case Y:
if (childIndex > 0 && parentStatements[childIndex - 1]->getAsCaseNode() == nullptr)
// Don't traverse the condition, as it was processed in visitSwitch.
return false;
// Handles a sequence of statements. No SPIR-V blocks are created here: enclosing constructs
// (if/switch/loop/function) already made the right block current. After each child statement,
// any node data the statement left behind is cleaned up. Children of the global block (depth 0)
// are not processed here.
bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node)
// If global block, nothing to do.
if (getCurrentTraversalDepth() == 0)
return true;
// Any construct that needs code blocks must have already handled creating the necessary blocks
// and setting the right one "current". If there's a block opened in GLSL for scoping reasons,
// it's ignored here as there are no scopes within a function in SPIR-V.
if (visit == PreVisit)
return node->getChildCount() > 0;
// Any node that needed to generate code has already done so, just clean up its data. If
// the child node has no effect, it's automatically discarded (such as variable.field[n].x,
// side effects of n already having generated code).
// Blocks inside blocks like:
// {
// statement;
// {
// statement2;
// }
// }
// don't generate nodes.
const size_t childIndex = getLastTraversedChildIndex(visit);
const TIntermSequence &statements = *node->getSequence();
if (statements[childIndex]->getAsBlock() == nullptr)
return true;
// Emits a function definition. On in-visit (after the prototype), it writes OpFunction using the
// ids cached by visitFunctionPrototype, records each parameter's id in mSymbolIdMap and names it,
// then starts the function's first block. On post-visit, if the current block is not yet
// terminated, an implicit return is added. For non-void functions that return is a zero/null
// constant, since GLSL allows falling off the end of a value-returning function.
bool OutputSPIRVTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
if (visit == PreVisit)
return true;
// After the prototype is visited, generate the initial code for the function.
if (visit == InVisit)
const TFunction *function = node->getFunction();
ASSERT(mFunctionIdMap.count(function) > 0);
const FunctionIds &ids = mFunctionIdMap[function];
// Declare the function.
spirv::WriteFunction(mBuilder.getSpirvFunctions(), ids.returnTypeId, ids.functionId,
spv::FunctionControlMaskNone, ids.functionTypeId);
for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
const TVariable *paramVariable = function->getParam(paramIndex);
const spirv::IdRef paramId =
ids.parameterTypeIds[paramIndex], paramId);
// Remember the id of the variable for future look up.
ASSERT(mSymbolIdMap.count(paramVariable) == 0);
mSymbolIdMap[paramVariable] = paramId;
spirv::WriteName(mBuilder.getSpirvDebug(), paramId,
mBuilder.startNewFunction(ids.functionId, function);
return true;
// If no explicit return was specified, add one automatically here.
if (!mBuilder.isCurrentFunctionBlockTerminated())
if (node->getFunction()->getReturnType().getBasicType() == EbtVoid)
// GLSL allows functions expecting a return value to miss a return. In that case,
// return a null constant.
const TFunction *function = node->getFunction();
const TType &returnType = function->getReturnType();
spirv::IdRef nullConstant;
// Scalar float/uint/int returns use the builder's typed zero constants; any other type
// falls back to a null constant of the function's return type below.
if (returnType.isScalar() && !returnType.isArray())
switch (function->getReturnType().getBasicType())
case EbtFloat:
nullConstant = mBuilder.getFloatConstant(0);
case EbtUInt:
nullConstant = mBuilder.getUintConstant(0);
case EbtInt:
nullConstant = mBuilder.getIntConstant(0);
if (!nullConstant.valid())
nullConstant = mBuilder.getNullConstant(mFunctionIdMap[function].returnTypeId);
spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), nullConstant);
// End the function
return true;
// Handles global qualifier declarations such as `invariant gl_Position;` by adding the Invariant
// decoration to the previously declared variable. `precise` declarations are not handled yet.
// Children are never traversed.
bool OutputSPIRVTraverser::visitGlobalQualifierDeclaration(Visit visit,
TIntermGlobalQualifierDeclaration *node)
if (node->isPrecise())
// TODO: handle precise.
return false;
// Global qualifier declarations apply to variables that are already declared. Invariant simply
// adds a decoration to the variable declaration, which can be done right away. Note that
// invariant cannot be applied to block members like this, except for gl_PerVertex built-ins,
// which are applied to the members directly by DeclarePerVertexBlocks.
const TVariable *variable = &node->getSymbol()->variable();
ASSERT(mSymbolIdMap.count(variable) > 0);
const spirv::IdRef variableId = mSymbolIdMap[variable];
spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationInvariant, {});
return false;
// Computes the SPIR-V ids of a function the first time its prototype is seen and caches them in
// mFunctionIdMap:
// - the return type id;
// - the function type id;
// - the function id, allocated up-front so calls can reference it before the definition.
// Non-const parameters are typed as pointers (UniformConstant storage for opaque types, Function
// storage otherwise); const parameters are passed by value.
void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
const TFunction *function = node->getFunction();
// If the function was previously forward declared, skip this.
if (mFunctionIdMap.count(function) > 0)
FunctionIds ids;
// Declare the function type
ids.returnTypeId = mBuilder.getTypeData(function->getReturnType(), {}).id;
spirv::IdRefList paramTypeIds;
for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
const TType &paramType = function->getParam(paramIndex)->getType();
spirv::IdRef paramId = mBuilder.getTypeData(paramType, {}).id;
// const function parameters are intermediate values, while the rest are "variables"
// with the Function storage class.
if (paramType.getQualifier() != EvqConst)
const spv::StorageClass storageClass = IsOpaqueType(paramType.getBasicType())
? spv::StorageClassUniformConstant
: spv::StorageClassFunction;
paramId = mBuilder.getTypePointerId(paramId, storageClass);
ids.functionTypeId = mBuilder.getFunctionTypeId(ids.returnTypeId, ids.parameterTypeIds);
// Allocate an id for the function up-front.
// Apply decorations to the return value of the function by applying them to the OpFunction
// instruction.
ids.functionId = mBuilder.getNewId(mBuilder.getDecorations(function->getReturnType()));
// Remember the ID of main() for the sake of OpEntryPoint.
if (function->isMain())
// Remember the id of the function for future look up.
mFunctionIdMap[function] = ids;
// Generates code for aggregate nodes, i.e. operations with an argument list: constructors,
// user function calls, barrier built-ins, and other built-ins evaluated by visitOperator.
// Each argument leaves one entry on mNodeData. Those entries are popped after the operation and
// the result is stored, as an rvalue, into the entry that is then on top of the stack.
bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
// Constants are expected to be folded. However, large constructors (such as arrays) are not
// folded and are handled here.
ASSERT(node->getOp() == EOpConstruct || !node->hasConstantValue());
if (visit == PreVisit)
return true;
// Keep the parameters on the stack. If a function call contains out or inout parameters, we
// need to know the access chains for the eventual write back to them.
if (visit == InVisit)
return true;
// Expect to have accumulated as many parameters as the node requires.
ASSERT(mNodeData.size() > node->getChildCount());
const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
spirv::IdRef result;
switch (node->getOp())
case EOpConstruct:
// Construct a value out of the accumulated parameters.
result = createConstructor(node, resultTypeId);
case EOpCallFunctionInAST:
// Create a call to the function.
result = createFunctionCall(node, resultTypeId);
// For barrier functions the scope is device, or with the Vulkan memory model, the queue
// family. We don't use the Vulkan memory model.
case EOpBarrier:
mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
case EOpBarrierTCS:
// Note: The memory scope and semantics are different with the Vulkan memory model,
// which is not supported.
case EOpMemoryBarrier:
case EOpGroupMemoryBarrier:
const spv::Scope scope =
node->getOp() == EOpMemoryBarrier ? spv::ScopeDevice : spv::ScopeWorkgroup;
mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(scope),
mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
spv::MemorySemanticsWorkgroupMemoryMask |
spv::MemorySemanticsImageMemoryMask |
case EOpMemoryBarrierBuffer:
mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
case EOpMemoryBarrierImage:
mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
mBuilder.getUintConstant(spv::MemorySemanticsImageMemoryMask |
case EOpMemoryBarrierShared:
mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
case EOpMemoryBarrierAtomicCounter:
// Atomic counters are emulated.
case EOpEmitVertex:
case EOpEndPrimitive:
case EOpEmitStreamVertex:
case EOpEndStreamPrimitive:
// TODO: support geometry shaders.
result = visitOperator(node, resultTypeId);
// Pop the parameters.
mNodeData.resize(mNodeData.size() - node->getChildCount());
// Keep the result as rvalue.
nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
return false;
// Declares a single variable, optionally with an initializer. Specialization constants and bare
// struct specifiers are handled specially. Steps:
// - Constant initializers outside loops are folded into the OpVariable. Other initializers are
//   emitted as a separate OpStore.
// - Interface-related decorations are added: Block/BufferBlock, gl_PerVertex built-ins, and
//   descriptor set/binding/location.
// - The variable id (and interface block id, if any) is recorded in mSymbolIdMap.
bool OutputSPIRVTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
const TIntermSequence &sequence = *node->getSequence();
// Enforced by ValidateASTOptions::validateMultiDeclarations.
ASSERT(sequence.size() == 1);
// Declare specialization constants especially; they don't require processing the left and right
// nodes, and they are like constant declarations with special instructions and decorations.
if (sequence.front()->getAsTyped()->getType().getQualifier() == EvqSpecConst)
return false;
if (!mInGlobalScope && visit == PreVisit)
mIsSymbolBeingDeclared = visit == PreVisit;
if (visit != PostVisit)
return true;
TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
spirv::IdRef initializerId;
bool initializeWithDeclaration = false;
// Handle declarations with initializer.
if (symbol == nullptr)
TIntermBinary *assign = sequence.front()->getAsBinaryNode();
ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
symbol = assign->getLeft()->getAsSymbolNode();
ASSERT(symbol != nullptr);
// In SPIR-V, it's only possible to initialize a variable together with its declaration if
// the initializer is a constant or a global variable. We ignore the global variable case
// to avoid tracking whether the variable has been modified since the beginning of the
// function. Since variable declarations are always placed at the beginning of the function
// in SPIR-V, it would be wrong for example to initialize |var| below with the global
// variable at declaration time:
// vec4 global = A;
// void f()
// {
// global = B;
// {
// vec4 var = global;
// }
// }
// So the initializer is only used when declaring a variable when it's a constant
// expression. Note that if the variable being declared is itself global (and the
// initializer is not constant), a previous AST transformation (DeferGlobalInitializers)
// makes sure their initialization is deferred to the beginning of main.
// Additionally, if the variable is being defined inside a loop, the initializer is not used
// as that would prevent it from being reinitialized in the next iteration of the loop.
TIntermTyped *initializer = assign->getRight();
initializeWithDeclaration =
!mBuilder.isInLoop() &&
(initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
if (initializeWithDeclaration)
// If a constant, take the Id directly.
initializerId = mNodeData.back().baseId;
// Otherwise generate code to load from right hand side expression.
initializerId = accessChainLoad(&mNodeData.back(), symbol->getType(), nullptr);
// Clean up the initializer data.
const TType &type = symbol->getType();
const TVariable *variable = &symbol->variable();
// If this is just a struct declaration (and not a variable declaration), don't declare the
// struct up-front and let it be lazily defined. If the struct is only used inside an interface
// block for example, this avoids it being doubly defined (once with the unspecified block
// storage and once with interface block's).
if (type.isStructSpecifier() && variable->symbolType() == SymbolType::Empty)
return false;
const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
spv::StorageClass storageClass = GetStorageClass(type);
SpirvDecorations decorations = mBuilder.getDecorations(type);
if (mBuilder.isInvariantOutput(type))
// Apply the Invariant decoration to output variables if specified or if globally enabled.
const spirv::IdRef variableId = mBuilder.declareVariable(
typeId, storageClass, decorations, initializeWithDeclaration ? &initializerId : nullptr,
if (!initializeWithDeclaration && initializerId.valid())
// If not initializing at the same time as the declaration, issue a store instruction.
spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), variableId, initializerId,
const bool isShaderInOut = IsShaderIn(type.getQualifier()) || IsShaderOut(type.getQualifier());
const bool isInterfaceBlock = type.getBasicType() == EbtInterfaceBlock;
// Add decorations, which apply to the element type of arrays, if array.
spirv::IdRef nonArrayTypeId = typeId;
if (type.isArray() && (isShaderInOut || isInterfaceBlock))
SpirvType elementType = mBuilder.getSpirvType(type, {});
elementType.arraySizes = {};
nonArrayTypeId = mBuilder.getSpirvTypeData(elementType, nullptr).id;
if (isShaderInOut)
// Add in and out variables to the list of interface variables.
if (IsShaderIoBlock(type.getQualifier()) && type.isInterfaceBlock())
// For gl_PerVertex in particular, write the necessary BuiltIn decorations
if (type.getQualifier() == EvqPerVertexIn || type.getQualifier() == EvqPerVertexOut)
mBuilder.writePerVertexBuiltIns(type, nonArrayTypeId);
// I/O blocks are decorated with Block
spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId,
spv::DecorationBlock, {});
else if (isInterfaceBlock)
// For uniform and buffer variables, add Block and BufferBlock decorations respectively.
const spv::Decoration decoration =
type.getQualifier() == EvqUniform ? spv::DecorationBlock : spv::DecorationBufferBlock;
spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId, decoration, {});
// Write DescriptorSet, Binding, Location etc decorations if necessary.
mBuilder.writeInterfaceVariableDecorations(type, variableId);
// Remember the id of the variable for future look up. For interface blocks, also remember the
// id of the interface block.
ASSERT(mSymbolIdMap.count(variable) == 0);
mSymbolIdMap[variable] = variableId;
if (type.isInterfaceBlock())
ASSERT(mSymbolIdMap.count(type.getInterfaceBlock()) == 0);
mSymbolIdMap[type.getInterfaceBlock()] = variableId;
return false;
// Assigns the block ids allocated by startConditional for a loop to their roles (header,
// condition, body, continue, merge), following the block order described below. Roles that
// have no block of their own are aliased to the block that takes their place.
void GetLoopBlocks(const SpirvConditional *conditional,
TLoopType loopType,
bool hasCondition,
spirv::IdRef *headerBlock,
spirv::IdRef *condBlock,
spirv::IdRef *bodyBlock,
spirv::IdRef *continueBlock,
spirv::IdRef *mergeBlock)
// The order of the blocks is for |for| and |while|:
// %header %cond [optional] %body %continue %merge
// and for |do-while|:
// %header %body %cond %merge
// Note that the |break| target is always the last block and the |continue| target is the one
// before last.
// If %continue is not present, all jumps are made to %cond (which is necessarily present).
// If %cond is not present, all jumps are made to %body instead.
size_t nextBlock = 0;
*headerBlock = conditional->blockIds[nextBlock++];
// %cond, if any is after header except for |do-while|.
if (loopType != ELoopDoWhile && hasCondition)
*condBlock = conditional->blockIds[nextBlock++];
*bodyBlock = conditional->blockIds[nextBlock++];
// After the body block is either %cond or %continue based on |do-while| or not.
if (loopType != ELoopDoWhile)
*continueBlock = conditional->blockIds[nextBlock++];
*condBlock = conditional->blockIds[nextBlock++];
*mergeBlock = conditional->blockIds[nextBlock++];
ASSERT(nextBlock == conditional->blockIds.size());
if (!continueBlock->valid())
*continueBlock = *condBlock;
if (!condBlock->valid())
*condBlock = *bodyBlock;
// Generates a structured loop (OpLoopMerge) for |for|, |while| and |do-while|. Traversal of the
// children is driven manually from PreVisit so the blocks are emitted in SPIR-V order rather than
// AST order. The function returns false so the generic traverser does not revisit the children.
bool OutputSPIRVTraverser::visitLoop(Visit visit, TIntermLoop *node)
// There are three kinds of loops, and they translate as such:
// for (init; cond; expr) body;
// // pre-loop block
// init
// OpBranch %header
// %header = OpLabel
// OpLoopMerge %merge %continue None
// OpBranch %cond
// // Note: if cond doesn't exist, this section is not generated. The above
// // OpBranch would jump directly to %body.
// %cond = OpLabel
// %v = cond
// OpBranchConditional %v %body %merge
// %body = OpLabel
// body
// OpBranch %continue
// %continue = OpLabel
// expr
// OpBranch %header
// // post-loop block
// %merge = OpLabel
// while (cond) body;
// // pre-loop block
// OpBranch %header
// %header = OpLabel
// OpLoopMerge %merge %continue None
// OpBranch %cond
// %cond = OpLabel
// %v = cond
// OpBranchConditional %v %body %merge
// %body = OpLabel
// body
// OpBranch %continue
// %continue = OpLabel
// OpBranch %header
// // post-loop block
// %merge = OpLabel
// do body; while (cond);
// // pre-loop block
// OpBranch %header
// %header = OpLabel
// OpLoopMerge %merge %cond None
// OpBranch %body
// %body = OpLabel
// body
// OpBranch %cond
// %cond = OpLabel
// %v = cond
// OpBranchConditional %v %header %merge
// // post-loop block
// %merge = OpLabel
// The order of the blocks is not necessarily the same as traversed, so it's much simpler if
// this function enforces traversal in the right order.
ASSERT(visit == PreVisit);
const TLoopType loopType = node->getType();
// The init statement of a for loop is placed in the previous block, so continue generating code
// as-is until that statement is done.
if (node->getInit())
ASSERT(loopType == ELoopFor);
const bool hasCondition = node->getCondition() != nullptr;
// Once the init node is visited, if any, we need to set up the loop.
// For |for| and |while|, we need %header, %body, %continue and %merge. For |do-while|, we
// need %header, %body and %merge. If condition is present, an additional %cond block is
// needed in each case.
const size_t blockCount = (loopType == ELoopDoWhile ? 3 : 4) + (hasCondition ? 1 : 0);
mBuilder.startConditional(blockCount, true, true);
// Generate the %header block.
const SpirvConditional *conditional = mBuilder.getCurrentConditional();
spirv::IdRef headerBlock, condBlock, bodyBlock, continueBlock, mergeBlock;
GetLoopBlocks(conditional, loopType, hasCondition, &headerBlock, &condBlock, &bodyBlock,
&continueBlock, &mergeBlock);
mBuilder.writeLoopHeader(loopType == ELoopDoWhile ? bodyBlock : condBlock, continueBlock,
// %cond, if any is after header except for |do-while|.
if (loopType != ELoopDoWhile && hasCondition)
// Generate the branch at the end of the %cond block.
const spirv::IdRef conditionValue =
accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
mBuilder.writeLoopConditionEnd(conditionValue, bodyBlock, mergeBlock);
// Next comes %body.
// Generate the branch at the end of the %body block.
switch (loopType)
case ELoopFor:
// For |for| loops, the expression is placed after the body and acts as the continue
// block.
if (node->getExpression())
// Generate the branch at the end of the %continue block.
case ELoopWhile:
// |for| loops have the expression in the continue block and |do-while| loops have their
// condition block act as the loop's continue block. |while| loops need a branch-only
// continue loop, which is generated here.
case ELoopDoWhile:
// For |do-while|, %cond comes last.
// Generate the branch at the end of the %cond block.
const spirv::IdRef conditionValue =
accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
mBuilder.writeLoopConditionEnd(conditionValue, headerBlock, mergeBlock);
// Pop from the conditional stack when done.
// Don't traverse the children, that's done already.
return false;
// Generates control-flow terminators for discard, break, continue and return. For `return expr;`
// the expression (already visited as the only child) is loaded and returned with OpReturnValue.
bool OutputSPIRVTraverser::visitBranch(Visit visit, TIntermBranch *node)
if (visit == PreVisit)
return true;
// There is only ever one child at most.
ASSERT(visit != InVisit);
switch (node->getFlowOp())
case EOpKill:
case EOpBreak:
case EOpContinue:
case EOpReturn:
// Evaluate the expression if any, and return.
if (node->getExpression() != nullptr)
ASSERT(mNodeData.size() >= 1);
const spirv::IdRef expressionValue =
accessChainLoad(&mNodeData.back(), node->getExpression()->getType(), nullptr);
spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), expressionValue);
return true;
// Preprocessor directive nodes should have been removed from the AST before SPIR-V generation;
// this visitor does not emit any code for them.
void OutputSPIRVTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
// No preprocessor directives expected at this point.
// Returns the SPIR-V blob assembled by mBuilder.
// NOTE(review): the disassembly dump to stderr below is presumably guarded by the debug-logging
// switch mentioned at the top of the file; the guard is not visible here, so confirm it is not
// emitted in release builds.
spirv::Blob OutputSPIRVTraverser::getSpirv()
spirv::Blob result = mBuilder.getSpirv();
// Validate that correct SPIR-V was generated
// Disassemble and log the generated SPIR-V for debugging.
spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
std::string readableSpirv;
spirvTools.Disassemble(result, &readableSpirv, 0);
fprintf(stderr, "%s\n", readableSpirv.c_str());
return result;
} // anonymous namespace
// Entry point: translates the AST rooted at |root| to SPIR-V with an OutputSPIRVTraverser and
// fetches the resulting blob. The visible code always returns true.
bool OutputSPIRV(TCompiler *compiler,
TIntermBlock *root,
ShCompileOptions compileOptions,
bool forceHighp)
// Traverse the tree and generate SPIR-V instructions
OutputSPIRVTraverser traverser(compiler, compileOptions, forceHighp);
// Generate the final SPIR-V and store in the sink
spirv::Blob spirvBlob = traverser.getSpirv();
return true;
} // namespace sh