blob: 0a2dab3adf7ac14d9be9f9f2ee5d9a3b9181d802 [file] [log] [blame]
ethannicholasb3058bd2016-07-01 08:22:01 -07001/*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "SkSLSPIRVCodeGenerator.h"
9
10#include "string.h"
11
12#include "GLSL.std.450.h"
13
14#include "ir/SkSLExpressionStatement.h"
15#include "ir/SkSLExtension.h"
16#include "ir/SkSLIndexExpression.h"
17#include "ir/SkSLVariableReference.h"
18
19namespace SkSL {
20
// When nonzero, writeWord() and writeOpCode() emit human-readable text (decimal words in
// parentheses and opcode names) instead of binary words, and opcode_text() is compiled in.
#define SPIRV_DEBUG 0

// NOTE(review): presumably the generator magic number for the emitted module header; the use
// site is not visible in this part of the file -- confirm.
static const int32_t SKSL_MAGIC = 0x0; // FIXME: we should probably register a magic number
24
25void SPIRVCodeGenerator::setupIntrinsics() {
26#define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
27 GLSLstd450 ## x, GLSLstd450 ## x)
28#define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
29 GLSLstd450 ## ifFloat, \
30 GLSLstd450 ## ifInt, \
31 GLSLstd450 ## ifUInt, \
32 SpvOpUndef)
33#define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
34 k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
35 k ## x ## _SpecialIntrinsic)
36 fIntrinsicMap["round"] = ALL_GLSL(Round);
37 fIntrinsicMap["roundEven"] = ALL_GLSL(RoundEven);
38 fIntrinsicMap["trunc"] = ALL_GLSL(Trunc);
39 fIntrinsicMap["abs"] = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
40 fIntrinsicMap["sign"] = BY_TYPE_GLSL(FSign, SSign, SSign);
41 fIntrinsicMap["floor"] = ALL_GLSL(Floor);
42 fIntrinsicMap["ceil"] = ALL_GLSL(Ceil);
43 fIntrinsicMap["fract"] = ALL_GLSL(Fract);
44 fIntrinsicMap["radians"] = ALL_GLSL(Radians);
45 fIntrinsicMap["degrees"] = ALL_GLSL(Degrees);
46 fIntrinsicMap["sin"] = ALL_GLSL(Sin);
47 fIntrinsicMap["cos"] = ALL_GLSL(Cos);
48 fIntrinsicMap["tan"] = ALL_GLSL(Tan);
49 fIntrinsicMap["asin"] = ALL_GLSL(Asin);
50 fIntrinsicMap["acos"] = ALL_GLSL(Acos);
51 fIntrinsicMap["atan"] = SPECIAL(Atan);
52 fIntrinsicMap["sinh"] = ALL_GLSL(Sinh);
53 fIntrinsicMap["cosh"] = ALL_GLSL(Cosh);
54 fIntrinsicMap["tanh"] = ALL_GLSL(Tanh);
55 fIntrinsicMap["asinh"] = ALL_GLSL(Asinh);
56 fIntrinsicMap["acosh"] = ALL_GLSL(Acosh);
57 fIntrinsicMap["atanh"] = ALL_GLSL(Atanh);
58 fIntrinsicMap["pow"] = ALL_GLSL(Pow);
59 fIntrinsicMap["exp"] = ALL_GLSL(Exp);
60 fIntrinsicMap["log"] = ALL_GLSL(Log);
61 fIntrinsicMap["exp2"] = ALL_GLSL(Exp2);
62 fIntrinsicMap["log2"] = ALL_GLSL(Log2);
63 fIntrinsicMap["sqrt"] = ALL_GLSL(Sqrt);
64 fIntrinsicMap["inversesqrt"] = ALL_GLSL(InverseSqrt);
65 fIntrinsicMap["determinant"] = ALL_GLSL(Determinant);
66 fIntrinsicMap["matrixInverse"] = ALL_GLSL(MatrixInverse);
67 fIntrinsicMap["mod"] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFMod, SpvOpSMod,
68 SpvOpUMod, SpvOpUndef);
69 fIntrinsicMap["min"] = BY_TYPE_GLSL(FMin, SMin, UMin);
70 fIntrinsicMap["max"] = BY_TYPE_GLSL(FMax, SMax, UMax);
71 fIntrinsicMap["clamp"] = BY_TYPE_GLSL(FClamp, SClamp, UClamp);
72 fIntrinsicMap["dot"] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot, SpvOpUndef,
73 SpvOpUndef, SpvOpUndef);
74 fIntrinsicMap["mix"] = ALL_GLSL(FMix);
75 fIntrinsicMap["step"] = ALL_GLSL(Step);
76 fIntrinsicMap["smoothstep"] = ALL_GLSL(SmoothStep);
77 fIntrinsicMap["fma"] = ALL_GLSL(Fma);
78 fIntrinsicMap["frexp"] = ALL_GLSL(Frexp);
79 fIntrinsicMap["ldexp"] = ALL_GLSL(Ldexp);
80
81#define PACK(type) fIntrinsicMap["pack" #type] = ALL_GLSL(Pack ## type); \
82 fIntrinsicMap["unpack" #type] = ALL_GLSL(Unpack ## type)
83 PACK(Snorm4x8);
84 PACK(Unorm4x8);
85 PACK(Snorm2x16);
86 PACK(Unorm2x16);
87 PACK(Half2x16);
88 PACK(Double2x32);
89 fIntrinsicMap["length"] = ALL_GLSL(Length);
90 fIntrinsicMap["distance"] = ALL_GLSL(Distance);
91 fIntrinsicMap["cross"] = ALL_GLSL(Cross);
92 fIntrinsicMap["normalize"] = ALL_GLSL(Normalize);
93 fIntrinsicMap["faceForward"] = ALL_GLSL(FaceForward);
94 fIntrinsicMap["reflect"] = ALL_GLSL(Reflect);
95 fIntrinsicMap["refract"] = ALL_GLSL(Refract);
96 fIntrinsicMap["findLSB"] = ALL_GLSL(FindILsb);
97 fIntrinsicMap["findMSB"] = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
98 fIntrinsicMap["dFdx"] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx, SpvOpUndef,
99 SpvOpUndef, SpvOpUndef);
100 fIntrinsicMap["dFdy"] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy, SpvOpUndef,
101 SpvOpUndef, SpvOpUndef);
102 fIntrinsicMap["dFdy"] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy, SpvOpUndef,
103 SpvOpUndef, SpvOpUndef);
104 fIntrinsicMap["texture"] = SPECIAL(Texture);
105 fIntrinsicMap["texture2D"] = SPECIAL(Texture2D);
106 fIntrinsicMap["textureProj"] = SPECIAL(TextureProj);
107
108 fIntrinsicMap["any"] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
109 SpvOpUndef, SpvOpUndef, SpvOpAny);
110 fIntrinsicMap["all"] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
111 SpvOpUndef, SpvOpUndef, SpvOpAll);
112 fIntrinsicMap["equal"] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFOrdEqual,
113 SpvOpIEqual, SpvOpIEqual,
114 SpvOpLogicalEqual);
115 fIntrinsicMap["notEqual"] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFOrdNotEqual,
116 SpvOpINotEqual, SpvOpINotEqual,
117 SpvOpLogicalNotEqual);
118 fIntrinsicMap["lessThan"] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpSLessThan,
119 SpvOpULessThan, SpvOpFOrdLessThan,
120 SpvOpUndef);
121 fIntrinsicMap["lessThanEqual"] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpSLessThanEqual,
122 SpvOpULessThanEqual, SpvOpFOrdLessThanEqual,
123 SpvOpUndef);
124 fIntrinsicMap["greaterThan"] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpSGreaterThan,
125 SpvOpUGreaterThan, SpvOpFOrdGreaterThan,
126 SpvOpUndef);
127 fIntrinsicMap["greaterThanEqual"] = std::make_tuple(kSPIRV_IntrinsicKind,
128 SpvOpSGreaterThanEqual,
129 SpvOpUGreaterThanEqual,
130 SpvOpFOrdGreaterThanEqual,
131 SpvOpUndef);
132
133// interpolateAt* not yet supported...
134}
135
136void SPIRVCodeGenerator::writeWord(int32_t word, std::ostream& out) {
137#if SPIRV_DEBUG
138 out << "(" << word << ") ";
139#else
140 out.write((const char*) &word, sizeof(word));
141#endif
142}
143
144static bool is_float(const Type& type) {
145 if (type.kind() == Type::kVector_Kind) {
146 return is_float(*type.componentType());
147 }
148 return type == *kFloat_Type || type == *kDouble_Type;
149}
150
151static bool is_signed(const Type& type) {
152 if (type.kind() == Type::kVector_Kind) {
153 return is_signed(*type.componentType());
154 }
155 return type == *kInt_Type;
156}
157
158static bool is_unsigned(const Type& type) {
159 if (type.kind() == Type::kVector_Kind) {
160 return is_unsigned(*type.componentType());
161 }
162 return type == *kUInt_Type;
163}
164
165static bool is_bool(const Type& type) {
166 if (type.kind() == Type::kVector_Kind) {
167 return is_bool(*type.componentType());
168 }
169 return type == *kBool_Type;
170}
171
172static bool is_out(std::shared_ptr<Variable> var) {
173 return (var->fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
174}
175
#if SPIRV_DEBUG
// Returns the name of 'opCode' without its "SpvOp" prefix (e.g. SpvOpLoad -> "Load"), for the
// textual output produced when SPIRV_DEBUG is enabled. Aborts on an opcode not listed here.
static std::string opcode_text(SpvOp_ opCode) {
// Expands to a case label for SpvOp<name> that returns the string "<name>".
#define SPIRV_OP_NAME(name) case SpvOp ## name: return #name
    switch (opCode) {
        SPIRV_OP_NAME(Nop);
        SPIRV_OP_NAME(Undef);
        SPIRV_OP_NAME(SourceContinued);
        SPIRV_OP_NAME(Source);
        SPIRV_OP_NAME(SourceExtension);
        SPIRV_OP_NAME(Name);
        SPIRV_OP_NAME(MemberName);
        SPIRV_OP_NAME(String);
        SPIRV_OP_NAME(Line);
        SPIRV_OP_NAME(Extension);
        SPIRV_OP_NAME(ExtInstImport);
        SPIRV_OP_NAME(ExtInst);
        SPIRV_OP_NAME(MemoryModel);
        SPIRV_OP_NAME(EntryPoint);
        SPIRV_OP_NAME(ExecutionMode);
        SPIRV_OP_NAME(Capability);
        SPIRV_OP_NAME(TypeVoid);
        SPIRV_OP_NAME(TypeBool);
        SPIRV_OP_NAME(TypeInt);
        SPIRV_OP_NAME(TypeFloat);
        SPIRV_OP_NAME(TypeVector);
        SPIRV_OP_NAME(TypeMatrix);
        SPIRV_OP_NAME(TypeImage);
        SPIRV_OP_NAME(TypeSampler);
        SPIRV_OP_NAME(TypeSampledImage);
        SPIRV_OP_NAME(TypeArray);
        SPIRV_OP_NAME(TypeRuntimeArray);
        SPIRV_OP_NAME(TypeStruct);
        SPIRV_OP_NAME(TypeOpaque);
        SPIRV_OP_NAME(TypePointer);
        SPIRV_OP_NAME(TypeFunction);
        SPIRV_OP_NAME(TypeEvent);
        SPIRV_OP_NAME(TypeDeviceEvent);
        SPIRV_OP_NAME(TypeReserveId);
        SPIRV_OP_NAME(TypeQueue);
        SPIRV_OP_NAME(TypePipe);
        SPIRV_OP_NAME(TypeForwardPointer);
        SPIRV_OP_NAME(ConstantTrue);
        SPIRV_OP_NAME(ConstantFalse);
        SPIRV_OP_NAME(Constant);
        SPIRV_OP_NAME(ConstantComposite);
        SPIRV_OP_NAME(ConstantSampler);
        SPIRV_OP_NAME(ConstantNull);
        SPIRV_OP_NAME(SpecConstantTrue);
        SPIRV_OP_NAME(SpecConstantFalse);
        SPIRV_OP_NAME(SpecConstant);
        SPIRV_OP_NAME(SpecConstantComposite);
        SPIRV_OP_NAME(SpecConstantOp);
        SPIRV_OP_NAME(Function);
        SPIRV_OP_NAME(FunctionParameter);
        SPIRV_OP_NAME(FunctionEnd);
        SPIRV_OP_NAME(FunctionCall);
        SPIRV_OP_NAME(Variable);
        SPIRV_OP_NAME(ImageTexelPointer);
        SPIRV_OP_NAME(Load);
        SPIRV_OP_NAME(Store);
        SPIRV_OP_NAME(CopyMemory);
        SPIRV_OP_NAME(CopyMemorySized);
        SPIRV_OP_NAME(AccessChain);
        SPIRV_OP_NAME(InBoundsAccessChain);
        SPIRV_OP_NAME(PtrAccessChain);
        SPIRV_OP_NAME(ArrayLength);
        SPIRV_OP_NAME(GenericPtrMemSemantics);
        SPIRV_OP_NAME(InBoundsPtrAccessChain);
        SPIRV_OP_NAME(Decorate);
        SPIRV_OP_NAME(MemberDecorate);
        SPIRV_OP_NAME(DecorationGroup);
        SPIRV_OP_NAME(GroupDecorate);
        SPIRV_OP_NAME(GroupMemberDecorate);
        SPIRV_OP_NAME(VectorExtractDynamic);
        SPIRV_OP_NAME(VectorInsertDynamic);
        SPIRV_OP_NAME(VectorShuffle);
        SPIRV_OP_NAME(CompositeConstruct);
        SPIRV_OP_NAME(CompositeExtract);
        SPIRV_OP_NAME(CompositeInsert);
        SPIRV_OP_NAME(CopyObject);
        SPIRV_OP_NAME(Transpose);
        SPIRV_OP_NAME(SampledImage);
        SPIRV_OP_NAME(ImageSampleImplicitLod);
        SPIRV_OP_NAME(ImageSampleExplicitLod);
        SPIRV_OP_NAME(ImageSampleDrefImplicitLod);
        SPIRV_OP_NAME(ImageSampleDrefExplicitLod);
        SPIRV_OP_NAME(ImageSampleProjImplicitLod);
        SPIRV_OP_NAME(ImageSampleProjExplicitLod);
        SPIRV_OP_NAME(ImageSampleProjDrefImplicitLod);
        SPIRV_OP_NAME(ImageSampleProjDrefExplicitLod);
        SPIRV_OP_NAME(ImageFetch);
        SPIRV_OP_NAME(ImageGather);
        SPIRV_OP_NAME(ImageDrefGather);
        SPIRV_OP_NAME(ImageRead);
        SPIRV_OP_NAME(ImageWrite);
        SPIRV_OP_NAME(Image);
        SPIRV_OP_NAME(ImageQueryFormat);
        SPIRV_OP_NAME(ImageQueryOrder);
        SPIRV_OP_NAME(ImageQuerySizeLod);
        SPIRV_OP_NAME(ImageQuerySize);
        SPIRV_OP_NAME(ImageQueryLod);
        SPIRV_OP_NAME(ImageQueryLevels);
        SPIRV_OP_NAME(ImageQuerySamples);
        SPIRV_OP_NAME(ConvertFToU);
        SPIRV_OP_NAME(ConvertFToS);
        SPIRV_OP_NAME(ConvertSToF);
        SPIRV_OP_NAME(ConvertUToF);
        SPIRV_OP_NAME(UConvert);
        SPIRV_OP_NAME(SConvert);
        SPIRV_OP_NAME(FConvert);
        SPIRV_OP_NAME(QuantizeToF16);
        SPIRV_OP_NAME(ConvertPtrToU);
        SPIRV_OP_NAME(SatConvertSToU);
        SPIRV_OP_NAME(SatConvertUToS);
        SPIRV_OP_NAME(ConvertUToPtr);
        SPIRV_OP_NAME(PtrCastToGeneric);
        SPIRV_OP_NAME(GenericCastToPtr);
        SPIRV_OP_NAME(GenericCastToPtrExplicit);
        SPIRV_OP_NAME(Bitcast);
        SPIRV_OP_NAME(SNegate);
        SPIRV_OP_NAME(FNegate);
        SPIRV_OP_NAME(IAdd);
        SPIRV_OP_NAME(FAdd);
        SPIRV_OP_NAME(ISub);
        SPIRV_OP_NAME(FSub);
        SPIRV_OP_NAME(IMul);
        SPIRV_OP_NAME(FMul);
        SPIRV_OP_NAME(UDiv);
        SPIRV_OP_NAME(SDiv);
        SPIRV_OP_NAME(FDiv);
        SPIRV_OP_NAME(UMod);
        SPIRV_OP_NAME(SRem);
        SPIRV_OP_NAME(SMod);
        SPIRV_OP_NAME(FRem);
        SPIRV_OP_NAME(FMod);
        SPIRV_OP_NAME(VectorTimesScalar);
        SPIRV_OP_NAME(MatrixTimesScalar);
        SPIRV_OP_NAME(VectorTimesMatrix);
        SPIRV_OP_NAME(MatrixTimesVector);
        SPIRV_OP_NAME(MatrixTimesMatrix);
        SPIRV_OP_NAME(OuterProduct);
        SPIRV_OP_NAME(Dot);
        SPIRV_OP_NAME(IAddCarry);
        SPIRV_OP_NAME(ISubBorrow);
        SPIRV_OP_NAME(UMulExtended);
        SPIRV_OP_NAME(SMulExtended);
        SPIRV_OP_NAME(Any);
        SPIRV_OP_NAME(All);
        SPIRV_OP_NAME(IsNan);
        SPIRV_OP_NAME(IsInf);
        SPIRV_OP_NAME(IsFinite);
        SPIRV_OP_NAME(IsNormal);
        SPIRV_OP_NAME(SignBitSet);
        SPIRV_OP_NAME(LessOrGreater);
        SPIRV_OP_NAME(Ordered);
        SPIRV_OP_NAME(Unordered);
        SPIRV_OP_NAME(LogicalEqual);
        SPIRV_OP_NAME(LogicalNotEqual);
        SPIRV_OP_NAME(LogicalOr);
        SPIRV_OP_NAME(LogicalAnd);
        SPIRV_OP_NAME(LogicalNot);
        SPIRV_OP_NAME(Select);
        SPIRV_OP_NAME(IEqual);
        SPIRV_OP_NAME(INotEqual);
        SPIRV_OP_NAME(UGreaterThan);
        SPIRV_OP_NAME(SGreaterThan);
        SPIRV_OP_NAME(UGreaterThanEqual);
        SPIRV_OP_NAME(SGreaterThanEqual);
        SPIRV_OP_NAME(ULessThan);
        SPIRV_OP_NAME(SLessThan);
        SPIRV_OP_NAME(ULessThanEqual);
        SPIRV_OP_NAME(SLessThanEqual);
        SPIRV_OP_NAME(FOrdEqual);
        SPIRV_OP_NAME(FUnordEqual);
        SPIRV_OP_NAME(FOrdNotEqual);
        SPIRV_OP_NAME(FUnordNotEqual);
        SPIRV_OP_NAME(FOrdLessThan);
        SPIRV_OP_NAME(FUnordLessThan);
        SPIRV_OP_NAME(FOrdGreaterThan);
        SPIRV_OP_NAME(FUnordGreaterThan);
        SPIRV_OP_NAME(FOrdLessThanEqual);
        SPIRV_OP_NAME(FUnordLessThanEqual);
        SPIRV_OP_NAME(FOrdGreaterThanEqual);
        SPIRV_OP_NAME(FUnordGreaterThanEqual);
        SPIRV_OP_NAME(ShiftRightLogical);
        SPIRV_OP_NAME(ShiftRightArithmetic);
        SPIRV_OP_NAME(ShiftLeftLogical);
        SPIRV_OP_NAME(BitwiseOr);
        SPIRV_OP_NAME(BitwiseXor);
        SPIRV_OP_NAME(BitwiseAnd);
        SPIRV_OP_NAME(Not);
        SPIRV_OP_NAME(BitFieldInsert);
        SPIRV_OP_NAME(BitFieldSExtract);
        SPIRV_OP_NAME(BitFieldUExtract);
        SPIRV_OP_NAME(BitReverse);
        SPIRV_OP_NAME(BitCount);
        SPIRV_OP_NAME(DPdx);
        SPIRV_OP_NAME(DPdy);
        SPIRV_OP_NAME(Fwidth);
        SPIRV_OP_NAME(DPdxFine);
        SPIRV_OP_NAME(DPdyFine);
        SPIRV_OP_NAME(FwidthFine);
        SPIRV_OP_NAME(DPdxCoarse);
        SPIRV_OP_NAME(DPdyCoarse);
        SPIRV_OP_NAME(FwidthCoarse);
        SPIRV_OP_NAME(EmitVertex);
        SPIRV_OP_NAME(EndPrimitive);
        SPIRV_OP_NAME(EmitStreamVertex);
        SPIRV_OP_NAME(EndStreamPrimitive);
        SPIRV_OP_NAME(ControlBarrier);
        SPIRV_OP_NAME(MemoryBarrier);
        SPIRV_OP_NAME(AtomicLoad);
        SPIRV_OP_NAME(AtomicStore);
        SPIRV_OP_NAME(AtomicExchange);
        SPIRV_OP_NAME(AtomicCompareExchange);
        SPIRV_OP_NAME(AtomicCompareExchangeWeak);
        SPIRV_OP_NAME(AtomicIIncrement);
        SPIRV_OP_NAME(AtomicIDecrement);
        SPIRV_OP_NAME(AtomicIAdd);
        SPIRV_OP_NAME(AtomicISub);
        SPIRV_OP_NAME(AtomicSMin);
        SPIRV_OP_NAME(AtomicUMin);
        SPIRV_OP_NAME(AtomicSMax);
        SPIRV_OP_NAME(AtomicUMax);
        SPIRV_OP_NAME(AtomicAnd);
        SPIRV_OP_NAME(AtomicOr);
        SPIRV_OP_NAME(AtomicXor);
        SPIRV_OP_NAME(Phi);
        SPIRV_OP_NAME(LoopMerge);
        SPIRV_OP_NAME(SelectionMerge);
        SPIRV_OP_NAME(Label);
        SPIRV_OP_NAME(Branch);
        SPIRV_OP_NAME(BranchConditional);
        SPIRV_OP_NAME(Switch);
        SPIRV_OP_NAME(Kill);
        SPIRV_OP_NAME(Return);
        SPIRV_OP_NAME(ReturnValue);
        SPIRV_OP_NAME(Unreachable);
        SPIRV_OP_NAME(LifetimeStart);
        SPIRV_OP_NAME(LifetimeStop);
        SPIRV_OP_NAME(GroupAsyncCopy);
        SPIRV_OP_NAME(GroupWaitEvents);
        SPIRV_OP_NAME(GroupAll);
        SPIRV_OP_NAME(GroupAny);
        SPIRV_OP_NAME(GroupBroadcast);
        SPIRV_OP_NAME(GroupIAdd);
        SPIRV_OP_NAME(GroupFAdd);
        SPIRV_OP_NAME(GroupFMin);
        SPIRV_OP_NAME(GroupUMin);
        SPIRV_OP_NAME(GroupSMin);
        SPIRV_OP_NAME(GroupFMax);
        SPIRV_OP_NAME(GroupUMax);
        SPIRV_OP_NAME(GroupSMax);
        SPIRV_OP_NAME(ReadPipe);
        SPIRV_OP_NAME(WritePipe);
        SPIRV_OP_NAME(ReservedReadPipe);
        SPIRV_OP_NAME(ReservedWritePipe);
        SPIRV_OP_NAME(ReserveReadPipePackets);
        SPIRV_OP_NAME(ReserveWritePipePackets);
        SPIRV_OP_NAME(CommitReadPipe);
        SPIRV_OP_NAME(CommitWritePipe);
        SPIRV_OP_NAME(IsValidReserveId);
        SPIRV_OP_NAME(GetNumPipePackets);
        SPIRV_OP_NAME(GetMaxPipePackets);
        SPIRV_OP_NAME(GroupReserveReadPipePackets);
        SPIRV_OP_NAME(GroupReserveWritePipePackets);
        SPIRV_OP_NAME(GroupCommitReadPipe);
        SPIRV_OP_NAME(GroupCommitWritePipe);
        SPIRV_OP_NAME(EnqueueMarker);
        SPIRV_OP_NAME(EnqueueKernel);
        SPIRV_OP_NAME(GetKernelNDrangeSubGroupCount);
        SPIRV_OP_NAME(GetKernelNDrangeMaxSubGroupSize);
        SPIRV_OP_NAME(GetKernelWorkGroupSize);
        SPIRV_OP_NAME(GetKernelPreferredWorkGroupSizeMultiple);
        SPIRV_OP_NAME(RetainEvent);
        SPIRV_OP_NAME(ReleaseEvent);
        SPIRV_OP_NAME(CreateUserEvent);
        SPIRV_OP_NAME(IsValidEvent);
        SPIRV_OP_NAME(SetUserEventStatus);
        SPIRV_OP_NAME(CaptureEventProfilingInfo);
        SPIRV_OP_NAME(GetDefaultQueue);
        SPIRV_OP_NAME(BuildNDRange);
        SPIRV_OP_NAME(ImageSparseSampleImplicitLod);
        SPIRV_OP_NAME(ImageSparseSampleExplicitLod);
        SPIRV_OP_NAME(ImageSparseSampleDrefImplicitLod);
        SPIRV_OP_NAME(ImageSparseSampleDrefExplicitLod);
        SPIRV_OP_NAME(ImageSparseSampleProjImplicitLod);
        SPIRV_OP_NAME(ImageSparseSampleProjExplicitLod);
        SPIRV_OP_NAME(ImageSparseSampleProjDrefImplicitLod);
        SPIRV_OP_NAME(ImageSparseSampleProjDrefExplicitLod);
        SPIRV_OP_NAME(ImageSparseFetch);
        SPIRV_OP_NAME(ImageSparseGather);
        SPIRV_OP_NAME(ImageSparseDrefGather);
        SPIRV_OP_NAME(ImageSparseTexelsResident);
        SPIRV_OP_NAME(NoLine);
        SPIRV_OP_NAME(AtomicFlagTestAndSet);
        SPIRV_OP_NAME(AtomicFlagClear);
        SPIRV_OP_NAME(ImageSparseRead);
        default:
            ABORT("unsupported SPIR-V op");
    }
#undef SPIRV_OP_NAME
}
#endif
772
773void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, std::ostream& out) {
774 ASSERT(opCode != SpvOpUndef);
775 switch (opCode) {
776 case SpvOpReturn: // fall through
777 case SpvOpReturnValue: // fall through
ethannicholas552882f2016-07-07 06:30:48 -0700778 case SpvOpKill: // fall through
ethannicholasb3058bd2016-07-01 08:22:01 -0700779 case SpvOpBranch: // fall through
780 case SpvOpBranchConditional:
781 ASSERT(fCurrentBlock);
782 fCurrentBlock = 0;
783 break;
784 case SpvOpConstant: // fall through
785 case SpvOpConstantTrue: // fall through
786 case SpvOpConstantFalse: // fall through
787 case SpvOpConstantComposite: // fall through
788 case SpvOpTypeVoid: // fall through
789 case SpvOpTypeInt: // fall through
790 case SpvOpTypeFloat: // fall through
791 case SpvOpTypeBool: // fall through
792 case SpvOpTypeVector: // fall through
793 case SpvOpTypeMatrix: // fall through
794 case SpvOpTypeArray: // fall through
795 case SpvOpTypePointer: // fall through
796 case SpvOpTypeFunction: // fall through
797 case SpvOpTypeRuntimeArray: // fall through
798 case SpvOpTypeStruct: // fall through
799 case SpvOpTypeImage: // fall through
800 case SpvOpTypeSampledImage: // fall through
801 case SpvOpVariable: // fall through
802 case SpvOpFunction: // fall through
803 case SpvOpFunctionParameter: // fall through
804 case SpvOpFunctionEnd: // fall through
805 case SpvOpExecutionMode: // fall through
806 case SpvOpMemoryModel: // fall through
807 case SpvOpCapability: // fall through
808 case SpvOpExtInstImport: // fall through
809 case SpvOpEntryPoint: // fall through
810 case SpvOpSource: // fall through
811 case SpvOpSourceExtension: // fall through
812 case SpvOpName: // fall through
813 case SpvOpMemberName: // fall through
814 case SpvOpDecorate: // fall through
815 case SpvOpMemberDecorate:
816 break;
817 default:
818 ASSERT(fCurrentBlock);
819 }
820#if SPIRV_DEBUG
821 out << std::endl << opcode_text(opCode) << " ";
822#else
823 this->writeWord((length << 16) | opCode, out);
824#endif
825}
826
827void SPIRVCodeGenerator::writeLabel(SpvId label, std::ostream& out) {
828 fCurrentBlock = label;
829 this->writeInstruction(SpvOpLabel, label, out);
830}
831
832void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, std::ostream& out) {
833 this->writeOpCode(opCode, 1, out);
834}
835
836void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, std::ostream& out) {
837 this->writeOpCode(opCode, 2, out);
838 this->writeWord(word1, out);
839}
840
841void SPIRVCodeGenerator::writeString(const char* string, std::ostream& out) {
842 size_t length = strlen(string);
843 out << string;
844 switch (length % 4) {
845 case 1:
846 out << (char) 0;
847 // fall through
848 case 2:
849 out << (char) 0;
850 // fall through
851 case 3:
852 out << (char) 0;
853 break;
854 default:
855 this->writeWord(0, out);
856 }
857}
858
859void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, const char* string, std::ostream& out) {
860 int32_t length = (int32_t) strlen(string);
861 this->writeOpCode(opCode, 1 + (length + 4) / 4, out);
862 this->writeString(string, out);
863}
864
865
866void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, const char* string,
867 std::ostream& out) {
868 int32_t length = (int32_t) strlen(string);
869 this->writeOpCode(opCode, 2 + (length + 4) / 4, out);
870 this->writeWord(word1, out);
871 this->writeString(string, out);
872}
873
874void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
875 const char* string, std::ostream& out) {
876 int32_t length = (int32_t) strlen(string);
877 this->writeOpCode(opCode, 3 + (length + 4) / 4, out);
878 this->writeWord(word1, out);
879 this->writeWord(word2, out);
880 this->writeString(string, out);
881}
882
883void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
884 std::ostream& out) {
885 this->writeOpCode(opCode, 3, out);
886 this->writeWord(word1, out);
887 this->writeWord(word2, out);
888}
889
890void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
891 int32_t word3, std::ostream& out) {
892 this->writeOpCode(opCode, 4, out);
893 this->writeWord(word1, out);
894 this->writeWord(word2, out);
895 this->writeWord(word3, out);
896}
897
898void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
899 int32_t word3, int32_t word4, std::ostream& out) {
900 this->writeOpCode(opCode, 5, out);
901 this->writeWord(word1, out);
902 this->writeWord(word2, out);
903 this->writeWord(word3, out);
904 this->writeWord(word4, out);
905}
906
907void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
908 int32_t word3, int32_t word4, int32_t word5,
909 std::ostream& out) {
910 this->writeOpCode(opCode, 6, out);
911 this->writeWord(word1, out);
912 this->writeWord(word2, out);
913 this->writeWord(word3, out);
914 this->writeWord(word4, out);
915 this->writeWord(word5, out);
916}
917
918void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
919 int32_t word3, int32_t word4, int32_t word5,
920 int32_t word6, std::ostream& out) {
921 this->writeOpCode(opCode, 7, out);
922 this->writeWord(word1, out);
923 this->writeWord(word2, out);
924 this->writeWord(word3, out);
925 this->writeWord(word4, out);
926 this->writeWord(word5, out);
927 this->writeWord(word6, out);
928}
929
930void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
931 int32_t word3, int32_t word4, int32_t word5,
932 int32_t word6, int32_t word7, std::ostream& out) {
933 this->writeOpCode(opCode, 8, out);
934 this->writeWord(word1, out);
935 this->writeWord(word2, out);
936 this->writeWord(word3, out);
937 this->writeWord(word4, out);
938 this->writeWord(word5, out);
939 this->writeWord(word6, out);
940 this->writeWord(word7, out);
941}
942
943void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
944 int32_t word3, int32_t word4, int32_t word5,
945 int32_t word6, int32_t word7, int32_t word8,
946 std::ostream& out) {
947 this->writeOpCode(opCode, 9, out);
948 this->writeWord(word1, out);
949 this->writeWord(word2, out);
950 this->writeWord(word3, out);
951 this->writeWord(word4, out);
952 this->writeWord(word5, out);
953 this->writeWord(word6, out);
954 this->writeWord(word7, out);
955 this->writeWord(word8, out);
956}
957
958void SPIRVCodeGenerator::writeCapabilities(std::ostream& out) {
959 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
960 if (fCapabilities & bit) {
961 this->writeInstruction(SpvOpCapability, (SpvId) i, out);
962 }
963 }
964}
965
966SpvId SPIRVCodeGenerator::nextId() {
967 return fIdCount++;
968}
969
970void SPIRVCodeGenerator::writeStruct(const Type& type, SpvId resultId) {
971 this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
972 // go ahead and write all of the field types, so we don't inadvertently write them while we're
973 // in the middle of writing the struct instruction
974 std::vector<SpvId> types;
975 for (const auto& f : type.fields()) {
976 types.push_back(this->getType(*f.fType));
977 }
978 this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
979 this->writeWord(resultId, fConstantBuffer);
980 for (SpvId id : types) {
981 this->writeWord(id, fConstantBuffer);
982 }
983 size_t offset = 0;
984 for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
985 size_t size = type.fields()[i].fType->size();
986 size_t alignment = type.fields()[i].fType->alignment();
987 size_t mod = offset % alignment;
988 if (mod != 0) {
989 offset += alignment - mod;
990 }
991 this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName.c_str(),
992 fNameBuffer);
993 this->writeLayout(type.fields()[i].fModifiers.fLayout, resultId, i);
994 if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) {
995 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
996 (SpvId) offset, fDecorationBuffer);
997 }
998 if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
999 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
1000 fDecorationBuffer);
1001 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
1002 (SpvId) type.fields()[i].fType->stride(), fDecorationBuffer);
1003 }
1004 offset += size;
1005 Type::Kind kind = type.fields()[i].fType->kind();
1006 if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
1007 offset += alignment - offset % alignment;
1008 }
1009 ASSERT(offset % alignment == 0);
1010 }
1011}
1012
1013SpvId SPIRVCodeGenerator::getType(const Type& type) {
1014 auto entry = fTypeMap.find(type.name());
1015 if (entry == fTypeMap.end()) {
1016 SpvId result = this->nextId();
1017 switch (type.kind()) {
1018 case Type::kScalar_Kind:
1019 if (type == *kBool_Type) {
1020 this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
1021 } else if (type == *kInt_Type) {
1022 this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
1023 } else if (type == *kUInt_Type) {
1024 this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
1025 } else if (type == *kFloat_Type) {
1026 this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
1027 } else if (type == *kDouble_Type) {
1028 this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
1029 } else {
1030 ASSERT(false);
1031 }
1032 break;
1033 case Type::kVector_Kind:
1034 this->writeInstruction(SpvOpTypeVector, result,
1035 this->getType(*type.componentType()),
1036 type.columns(), fConstantBuffer);
1037 break;
1038 case Type::kMatrix_Kind:
1039 this->writeInstruction(SpvOpTypeMatrix, result, this->getType(*index_type(type)),
1040 type.columns(), fConstantBuffer);
1041 break;
1042 case Type::kStruct_Kind:
1043 this->writeStruct(type, result);
1044 break;
1045 case Type::kArray_Kind: {
1046 if (type.columns() > 0) {
1047 IntLiteral count(Position(), type.columns());
1048 this->writeInstruction(SpvOpTypeArray, result,
1049 this->getType(*type.componentType()),
1050 this->writeIntLiteral(count), fConstantBuffer);
1051 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
1052 (int32_t) type.stride(), fDecorationBuffer);
1053 } else {
1054 ABORT("runtime-sized arrays are not yet supported");
1055 this->writeInstruction(SpvOpTypeRuntimeArray, result,
1056 this->getType(*type.componentType()), fConstantBuffer);
1057 }
1058 break;
1059 }
1060 case Type::kSampler_Kind: {
1061 SpvId image = this->nextId();
1062 this->writeInstruction(SpvOpTypeImage, image, this->getType(*kFloat_Type),
1063 type.dimensions(), type.isDepth(), type.isArrayed(),
1064 type.isMultisampled(), type.isSampled(),
1065 SpvImageFormatUnknown, fConstantBuffer);
1066 this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
1067 break;
1068 }
1069 default:
1070 if (type == *kVoid_Type) {
1071 this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
1072 } else {
1073 ABORT("invalid type: %s", type.description().c_str());
1074 }
1075 }
1076 fTypeMap[type.name()] = result;
1077 return result;
1078 }
1079 return entry->second;
1080}
1081
1082SpvId SPIRVCodeGenerator::getFunctionType(std::shared_ptr<FunctionDeclaration> function) {
1083 std::string key = function->fReturnType->description() + "(";
1084 std::string separator = "";
1085 for (size_t i = 0; i < function->fParameters.size(); i++) {
1086 key += separator;
1087 separator = ", ";
1088 key += function->fParameters[i]->fType->description();
1089 }
1090 key += ")";
1091 auto entry = fTypeMap.find(key);
1092 if (entry == fTypeMap.end()) {
1093 SpvId result = this->nextId();
1094 int32_t length = 3 + (int32_t) function->fParameters.size();
1095 SpvId returnType = this->getType(*function->fReturnType);
1096 std::vector<SpvId> parameterTypes;
1097 for (size_t i = 0; i < function->fParameters.size(); i++) {
1098 // glslang seems to treat all function arguments as pointers whether they need to be or
1099 // not. I was initially puzzled by this until I ran bizarre failures with certain
1100 // patterns of function calls and control constructs, as exemplified by this minimal
1101 // failure case:
1102 //
1103 // void sphere(float x) {
1104 // }
1105 //
1106 // void map() {
1107 // sphere(1.0);
1108 // }
1109 //
1110 // void main() {
1111 // for (int i = 0; i < 1; i++) {
1112 // map();
1113 // }
1114 // }
1115 //
1116 // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1117 // crashes. Making it take a float* and storing the argument in a temporary variable,
1118 // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
1119 // the spec makes this make sense.
1120// if (is_out(function->fParameters[i])) {
1121 parameterTypes.push_back(this->getPointerType(function->fParameters[i]->fType,
1122 SpvStorageClassFunction));
1123// } else {
1124// parameterTypes.push_back(this->getType(*function->fParameters[i]->fType));
1125// }
1126 }
1127 this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
1128 this->writeWord(result, fConstantBuffer);
1129 this->writeWord(returnType, fConstantBuffer);
1130 for (SpvId id : parameterTypes) {
1131 this->writeWord(id, fConstantBuffer);
1132 }
1133 fTypeMap[key] = result;
1134 return result;
1135 }
1136 return entry->second;
1137}
1138
1139SpvId SPIRVCodeGenerator::getPointerType(std::shared_ptr<Type> type,
1140 SpvStorageClass_ storageClass) {
1141 std::string key = type->description() + "*" + to_string(storageClass);
1142 auto entry = fTypeMap.find(key);
1143 if (entry == fTypeMap.end()) {
1144 SpvId result = this->nextId();
1145 this->writeInstruction(SpvOpTypePointer, result, storageClass,
1146 this->getType(*type), fConstantBuffer);
1147 fTypeMap[key] = result;
1148 return result;
1149 }
1150 return entry->second;
1151}
1152
1153SpvId SPIRVCodeGenerator::writeExpression(Expression& expr, std::ostream& out) {
1154 switch (expr.fKind) {
1155 case Expression::kBinary_Kind:
1156 return this->writeBinaryExpression((BinaryExpression&) expr, out);
1157 case Expression::kBoolLiteral_Kind:
1158 return this->writeBoolLiteral((BoolLiteral&) expr);
1159 case Expression::kConstructor_Kind:
1160 return this->writeConstructor((Constructor&) expr, out);
1161 case Expression::kIntLiteral_Kind:
1162 return this->writeIntLiteral((IntLiteral&) expr);
1163 case Expression::kFieldAccess_Kind:
1164 return this->writeFieldAccess(((FieldAccess&) expr), out);
1165 case Expression::kFloatLiteral_Kind:
1166 return this->writeFloatLiteral(((FloatLiteral&) expr));
1167 case Expression::kFunctionCall_Kind:
1168 return this->writeFunctionCall((FunctionCall&) expr, out);
1169 case Expression::kPrefix_Kind:
1170 return this->writePrefixExpression((PrefixExpression&) expr, out);
1171 case Expression::kPostfix_Kind:
1172 return this->writePostfixExpression((PostfixExpression&) expr, out);
1173 case Expression::kSwizzle_Kind:
1174 return this->writeSwizzle((Swizzle&) expr, out);
1175 case Expression::kVariableReference_Kind:
1176 return this->writeVariableReference((VariableReference&) expr, out);
1177 case Expression::kTernary_Kind:
1178 return this->writeTernaryExpression((TernaryExpression&) expr, out);
1179 case Expression::kIndex_Kind:
1180 return this->writeIndexExpression((IndexExpression&) expr, out);
1181 default:
1182 ABORT("unsupported expression: %s", expr.description().c_str());
1183 }
1184 return -1;
1185}
1186
1187SpvId SPIRVCodeGenerator::writeIntrinsicCall(FunctionCall& c, std::ostream& out) {
1188 auto intrinsic = fIntrinsicMap.find(c.fFunction->fName);
1189 ASSERT(intrinsic != fIntrinsicMap.end());
1190 std::shared_ptr<Type> type = c.fArguments[0]->fType;
1191 int32_t intrinsicId;
1192 if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(*type)) {
1193 intrinsicId = std::get<1>(intrinsic->second);
1194 } else if (is_signed(*type)) {
1195 intrinsicId = std::get<2>(intrinsic->second);
1196 } else if (is_unsigned(*type)) {
1197 intrinsicId = std::get<3>(intrinsic->second);
1198 } else if (is_bool(*type)) {
1199 intrinsicId = std::get<4>(intrinsic->second);
1200 } else {
1201 ABORT("invalid call %s, cannot operate on '%s'", c.description().c_str(),
1202 type->description().c_str());
1203 }
1204 switch (std::get<0>(intrinsic->second)) {
1205 case kGLSL_STD_450_IntrinsicKind: {
1206 SpvId result = this->nextId();
1207 std::vector<SpvId> arguments;
1208 for (size_t i = 0; i < c.fArguments.size(); i++) {
1209 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1210 }
1211 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
1212 this->writeWord(this->getType(*c.fType), out);
1213 this->writeWord(result, out);
1214 this->writeWord(fGLSLExtendedInstructions, out);
1215 this->writeWord(intrinsicId, out);
1216 for (SpvId id : arguments) {
1217 this->writeWord(id, out);
1218 }
1219 return result;
1220 }
1221 case kSPIRV_IntrinsicKind: {
1222 SpvId result = this->nextId();
1223 std::vector<SpvId> arguments;
1224 for (size_t i = 0; i < c.fArguments.size(); i++) {
1225 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1226 }
1227 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
1228 this->writeWord(this->getType(*c.fType), out);
1229 this->writeWord(result, out);
1230 for (SpvId id : arguments) {
1231 this->writeWord(id, out);
1232 }
1233 return result;
1234 }
1235 case kSpecial_IntrinsicKind:
1236 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
1237 default:
1238 ABORT("unsupported intrinsic kind");
1239 }
1240}
1241
1242SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsic kind,
1243 std::ostream& out) {
1244 SpvId result = this->nextId();
1245 switch (kind) {
1246 case kAtan_SpecialIntrinsic: {
1247 std::vector<SpvId> arguments;
1248 for (size_t i = 0; i < c.fArguments.size(); i++) {
1249 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1250 }
1251 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
1252 this->writeWord(this->getType(*c.fType), out);
1253 this->writeWord(result, out);
1254 this->writeWord(fGLSLExtendedInstructions, out);
1255 this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
1256 for (SpvId id : arguments) {
1257 this->writeWord(id, out);
1258 }
1259 return result;
1260 }
1261 case kTexture_SpecialIntrinsic: {
1262 SpvId type = this->getType(*c.fType);
1263 SpvId sampler = this->writeExpression(*c.fArguments[0], out);
1264 SpvId uv = this->writeExpression(*c.fArguments[1], out);
1265 if (c.fArguments.size() == 3) {
1266 this->writeInstruction(SpvOpImageSampleImplicitLod, type, result, sampler, uv,
1267 SpvImageOperandsBiasMask,
1268 this->writeExpression(*c.fArguments[2], out),
1269 out);
1270 } else {
1271 ASSERT(c.fArguments.size() == 2);
1272 this->writeInstruction(SpvOpImageSampleImplicitLod, type, result, sampler, uv, out);
1273 }
1274 break;
1275 }
1276 case kTextureProj_SpecialIntrinsic: {
1277 SpvId type = this->getType(*c.fType);
1278 SpvId sampler = this->writeExpression(*c.fArguments[0], out);
1279 SpvId uv = this->writeExpression(*c.fArguments[1], out);
1280 if (c.fArguments.size() == 3) {
1281 this->writeInstruction(SpvOpImageSampleProjImplicitLod, type, result, sampler, uv,
1282 SpvImageOperandsBiasMask,
1283 this->writeExpression(*c.fArguments[2], out),
1284 out);
1285 } else {
1286 ASSERT(c.fArguments.size() == 2);
1287 this->writeInstruction(SpvOpImageSampleProjImplicitLod, type, result, sampler, uv,
1288 out);
1289 }
1290 break;
1291 }
1292 case kTexture2D_SpecialIntrinsic: {
1293 SpvId img = this->writeExpression(*c.fArguments[0], out);
1294 SpvId coords = this->writeExpression(*c.fArguments[1], out);
1295 this->writeInstruction(SpvOpImageSampleImplicitLod,
1296 this->getType(*c.fType),
1297 result,
1298 img,
1299 coords,
1300 out);
1301 break;
1302 }
1303 }
1304 return result;
1305}
1306
1307SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) {
1308 const auto& entry = fFunctionMap.find(c.fFunction);
1309 if (entry == fFunctionMap.end()) {
1310 return this->writeIntrinsicCall(c, out);
1311 }
1312 // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
1313 std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues;
1314 std::vector<SpvId> arguments;
1315 for (size_t i = 0; i < c.fArguments.size(); i++) {
1316 // id of temporary variable that we will use to hold this argument, or 0 if it is being
1317 // passed directly
1318 SpvId tmpVar;
1319 // if we need a temporary var to store this argument, this is the value to store in the var
1320 SpvId tmpValueId;
1321 if (is_out(c.fFunction->fParameters[i])) {
1322 std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
1323 SpvId ptr = lv->getPointer();
1324 if (ptr) {
1325 arguments.push_back(ptr);
1326 continue;
1327 } else {
1328 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
1329 // copy it into a temp, call the function, read the value out of the temp, and then
1330 // update the lvalue.
1331 tmpValueId = lv->load(out);
1332 tmpVar = this->nextId();
1333 lvalues.push_back(std::make_tuple(tmpVar, this->getType(*c.fArguments[i]->fType),
1334 std::move(lv)));
1335 }
1336 } else {
1337 // see getFunctionType for an explanation of why we're always using pointer parameters
1338 tmpValueId = this->writeExpression(*c.fArguments[i], out);
1339 tmpVar = this->nextId();
1340 }
1341 this->writeInstruction(SpvOpVariable,
1342 this->getPointerType(c.fArguments[i]->fType,
1343 SpvStorageClassFunction),
1344 tmpVar,
1345 SpvStorageClassFunction,
1346 out);
1347 this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
1348 arguments.push_back(tmpVar);
1349 }
1350 SpvId result = this->nextId();
1351 this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
1352 this->writeWord(this->getType(*c.fType), out);
1353 this->writeWord(result, out);
1354 this->writeWord(entry->second, out);
1355 for (SpvId id : arguments) {
1356 this->writeWord(id, out);
1357 }
1358 // now that the call is complete, we may need to update some lvalues with the new values of out
1359 // arguments
1360 for (const auto& tuple : lvalues) {
1361 SpvId load = this->nextId();
1362 this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out);
1363 std::get<2>(tuple)->store(load, out);
1364 }
1365 return result;
1366}
1367
1368SpvId SPIRVCodeGenerator::writeConstantVector(Constructor& c) {
1369 ASSERT(c.fType->kind() == Type::kVector_Kind && c.isConstant());
1370 SpvId result = this->nextId();
1371 std::vector<SpvId> arguments;
1372 for (size_t i = 0; i < c.fArguments.size(); i++) {
1373 arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
1374 }
1375 SpvId type = this->getType(*c.fType);
1376 if (c.fArguments.size() == 1) {
1377 // with a single argument, a vector will have all of its entries equal to the argument
1378 this->writeOpCode(SpvOpConstantComposite, 3 + c.fType->columns(), fConstantBuffer);
1379 this->writeWord(type, fConstantBuffer);
1380 this->writeWord(result, fConstantBuffer);
1381 for (int i = 0; i < c.fType->columns(); i++) {
1382 this->writeWord(arguments[0], fConstantBuffer);
1383 }
1384 } else {
1385 this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
1386 fConstantBuffer);
1387 this->writeWord(type, fConstantBuffer);
1388 this->writeWord(result, fConstantBuffer);
1389 for (SpvId id : arguments) {
1390 this->writeWord(id, fConstantBuffer);
1391 }
1392 }
1393 return result;
1394}
1395
1396SpvId SPIRVCodeGenerator::writeFloatConstructor(Constructor& c, std::ostream& out) {
1397 ASSERT(c.fType == kFloat_Type);
1398 ASSERT(c.fArguments.size() == 1);
1399 ASSERT(c.fArguments[0]->fType->isNumber());
1400 SpvId result = this->nextId();
1401 SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1402 if (c.fArguments[0]->fType == kInt_Type) {
1403 this->writeInstruction(SpvOpConvertSToF, this->getType(*c.fType), result, parameter,
1404 out);
1405 } else if (c.fArguments[0]->fType == kUInt_Type) {
1406 this->writeInstruction(SpvOpConvertUToF, this->getType(*c.fType), result, parameter,
1407 out);
1408 } else if (c.fArguments[0]->fType == kFloat_Type) {
1409 return parameter;
1410 }
1411 return result;
1412}
1413
1414SpvId SPIRVCodeGenerator::writeIntConstructor(Constructor& c, std::ostream& out) {
1415 ASSERT(c.fType == kInt_Type);
1416 ASSERT(c.fArguments.size() == 1);
1417 ASSERT(c.fArguments[0]->fType->isNumber());
1418 SpvId result = this->nextId();
1419 SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1420 if (c.fArguments[0]->fType == kFloat_Type) {
1421 this->writeInstruction(SpvOpConvertFToS, this->getType(*c.fType), result, parameter,
1422 out);
1423 } else if (c.fArguments[0]->fType == kUInt_Type) {
1424 this->writeInstruction(SpvOpSatConvertUToS, this->getType(*c.fType), result, parameter,
1425 out);
1426 } else if (c.fArguments[0]->fType == kInt_Type) {
1427 return parameter;
1428 }
1429 return result;
1430}
1431
1432SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& out) {
1433 ASSERT(c.fType->kind() == Type::kMatrix_Kind);
1434 // go ahead and write the arguments so we don't try to write new instructions in the middle of
1435 // an instruction
1436 std::vector<SpvId> arguments;
1437 for (size_t i = 0; i < c.fArguments.size(); i++) {
1438 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1439 }
1440 SpvId result = this->nextId();
1441 int rows = c.fType->rows();
1442 int columns = c.fType->columns();
1443 // FIXME this won't work to create a matrix from another matrix
1444 if (arguments.size() == 1) {
1445 // with a single argument, a matrix will have all of its diagonal entries equal to the
1446 // argument and its other values equal to zero
1447 // FIXME this won't work for int matrices
1448 FloatLiteral zero(Position(), 0);
1449 SpvId zeroId = this->writeFloatLiteral(zero);
1450 std::vector<SpvId> columnIds;
1451 for (int column = 0; column < columns; column++) {
1452 this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->rows(),
1453 out);
1454 this->writeWord(this->getType(*c.fType->componentType()->toCompound(rows, 1)), out);
1455 SpvId columnId = this->nextId();
1456 this->writeWord(columnId, out);
1457 columnIds.push_back(columnId);
1458 for (int row = 0; row < c.fType->columns(); row++) {
1459 this->writeWord(row == column ? arguments[0] : zeroId, out);
1460 }
1461 }
1462 this->writeOpCode(SpvOpCompositeConstruct, 3 + columns,
1463 out);
1464 this->writeWord(this->getType(*c.fType), out);
1465 this->writeWord(result, out);
1466 for (SpvId id : columnIds) {
1467 this->writeWord(id, out);
1468 }
1469 } else {
1470 std::vector<SpvId> columnIds;
1471 int currentCount = 0;
1472 for (size_t i = 0; i < arguments.size(); i++) {
1473 if (c.fArguments[i]->fType->kind() == Type::kVector_Kind) {
1474 ASSERT(currentCount == 0);
1475 columnIds.push_back(arguments[i]);
1476 currentCount = 0;
1477 } else {
1478 ASSERT(c.fArguments[i]->fType->kind() == Type::kScalar_Kind);
1479 if (currentCount == 0) {
1480 this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->rows(), out);
1481 this->writeWord(this->getType(*c.fType->componentType()->toCompound(rows, 1)),
1482 out);
1483 SpvId id = this->nextId();
1484 this->writeWord(id, out);
1485 columnIds.push_back(id);
1486 }
1487 this->writeWord(arguments[i], out);
1488 currentCount = (currentCount + 1) % rows;
1489 }
1490 }
1491 ASSERT(columnIds.size() == (size_t) columns);
1492 this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
1493 this->writeWord(this->getType(*c.fType), out);
1494 this->writeWord(result, out);
1495 for (SpvId id : columnIds) {
1496 this->writeWord(id, out);
1497 }
1498 }
1499 return result;
1500}
1501
1502SpvId SPIRVCodeGenerator::writeVectorConstructor(Constructor& c, std::ostream& out) {
1503 ASSERT(c.fType->kind() == Type::kVector_Kind);
1504 if (c.isConstant()) {
1505 return this->writeConstantVector(c);
1506 }
1507 // go ahead and write the arguments so we don't try to write new instructions in the middle of
1508 // an instruction
1509 std::vector<SpvId> arguments;
1510 for (size_t i = 0; i < c.fArguments.size(); i++) {
1511 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1512 }
1513 SpvId result = this->nextId();
1514 if (arguments.size() == 1 && c.fArguments[0]->fType->kind() == Type::kScalar_Kind) {
1515 this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->columns(), out);
1516 this->writeWord(this->getType(*c.fType), out);
1517 this->writeWord(result, out);
1518 for (int i = 0; i < c.fType->columns(); i++) {
1519 this->writeWord(arguments[0], out);
1520 }
1521 } else {
1522 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1523 this->writeWord(this->getType(*c.fType), out);
1524 this->writeWord(result, out);
1525 for (SpvId id : arguments) {
1526 this->writeWord(id, out);
1527 }
1528 }
1529 return result;
1530}
1531
1532SpvId SPIRVCodeGenerator::writeConstructor(Constructor& c, std::ostream& out) {
1533 if (c.fType == kFloat_Type) {
1534 return this->writeFloatConstructor(c, out);
1535 } else if (c.fType == kInt_Type) {
1536 return this->writeIntConstructor(c, out);
1537 }
1538 switch (c.fType->kind()) {
1539 case Type::kVector_Kind:
1540 return this->writeVectorConstructor(c, out);
1541 case Type::kMatrix_Kind:
1542 return this->writeMatrixConstructor(c, out);
1543 default:
1544 ABORT("unsupported constructor: %s", c.description().c_str());
1545 }
1546}
1547
1548SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1549 if (modifiers.fFlags & Modifiers::kIn_Flag) {
1550 return SpvStorageClassInput;
1551 } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1552 return SpvStorageClassOutput;
1553 } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1554 return SpvStorageClassUniform;
1555 } else {
1556 return SpvStorageClassFunction;
1557 }
1558}
1559
1560SpvStorageClass_ get_storage_class(Expression& expr) {
1561 switch (expr.fKind) {
1562 case Expression::kVariableReference_Kind:
1563 return get_storage_class(((VariableReference&) expr).fVariable->fModifiers);
1564 case Expression::kFieldAccess_Kind:
1565 return get_storage_class(*((FieldAccess&) expr).fBase);
1566 case Expression::kIndex_Kind:
1567 return get_storage_class(*((IndexExpression&) expr).fBase);
1568 default:
1569 return SpvStorageClassFunction;
1570 }
1571}
1572
1573std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(Expression& expr, std::ostream& out) {
1574 std::vector<SpvId> chain;
1575 switch (expr.fKind) {
1576 case Expression::kIndex_Kind: {
1577 IndexExpression& indexExpr = (IndexExpression&) expr;
1578 chain = this->getAccessChain(*indexExpr.fBase, out);
1579 chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1580 break;
1581 }
1582 case Expression::kFieldAccess_Kind: {
1583 FieldAccess& fieldExpr = (FieldAccess&) expr;
1584 chain = this->getAccessChain(*fieldExpr.fBase, out);
1585 IntLiteral index(Position(), fieldExpr.fFieldIndex);
1586 chain.push_back(this->writeIntLiteral(index));
1587 break;
1588 }
1589 default:
1590 chain.push_back(this->getLValue(expr, out)->getPointer());
1591 }
1592 return chain;
1593}
1594
1595class PointerLValue : public SPIRVCodeGenerator::LValue {
1596public:
1597 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type)
1598 : fGen(gen)
1599 , fPointer(pointer)
1600 , fType(type) {}
1601
1602 virtual SpvId getPointer() override {
1603 return fPointer;
1604 }
1605
1606 virtual SpvId load(std::ostream& out) override {
1607 SpvId result = fGen.nextId();
1608 fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1609 return result;
1610 }
1611
1612 virtual void store(SpvId value, std::ostream& out) override {
1613 fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1614 }
1615
1616private:
1617 SPIRVCodeGenerator& fGen;
1618 const SpvId fPointer;
1619 const SpvId fType;
1620};
1621
1622class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1623public:
1624 SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1625 const Type& baseType, const Type& swizzleType)
1626 : fGen(gen)
1627 , fVecPointer(vecPointer)
1628 , fComponents(components)
1629 , fBaseType(baseType)
1630 , fSwizzleType(swizzleType) {}
1631
1632 virtual SpvId getPointer() override {
1633 return 0;
1634 }
1635
1636 virtual SpvId load(std::ostream& out) override {
1637 SpvId base = fGen.nextId();
1638 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1639 SpvId result = fGen.nextId();
1640 fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1641 fGen.writeWord(fGen.getType(fSwizzleType), out);
1642 fGen.writeWord(result, out);
1643 fGen.writeWord(base, out);
1644 fGen.writeWord(base, out);
1645 for (int component : fComponents) {
1646 fGen.writeWord(component, out);
1647 }
1648 return result;
1649 }
1650
1651 virtual void store(SpvId value, std::ostream& out) override {
1652 // use OpVectorShuffle to mix and match the vector components. We effectively create
1653 // a virtual vector out of the concatenation of the left and right vectors, and then
1654 // select components from this virtual vector to make the result vector. For
1655 // instance, given:
1656 // vec3 L = ...;
1657 // vec3 R = ...;
1658 // L.xz = R.xy;
1659 // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1660 // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1661 // (3, 1, 4).
1662 SpvId base = fGen.nextId();
1663 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1664 SpvId shuffle = fGen.nextId();
1665 fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1666 fGen.writeWord(fGen.getType(fBaseType), out);
1667 fGen.writeWord(shuffle, out);
1668 fGen.writeWord(base, out);
1669 fGen.writeWord(value, out);
1670 for (int i = 0; i < fBaseType.columns(); i++) {
1671 // current offset into the virtual vector, defaults to pulling the unmodified
1672 // value from the left side
1673 int offset = i;
1674 // check to see if we are writing this component
1675 for (size_t j = 0; j < fComponents.size(); j++) {
1676 if (fComponents[j] == i) {
1677 // we're writing to this component, so adjust the offset to pull from
1678 // the correct component of the right side instead of preserving the
1679 // value from the left
1680 offset = (int) (j + fBaseType.columns());
1681 break;
1682 }
1683 }
1684 fGen.writeWord(offset, out);
1685 }
1686 fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1687 }
1688
1689private:
1690 SPIRVCodeGenerator& fGen;
1691 const SpvId fVecPointer;
1692 const std::vector<int>& fComponents;
1693 const Type& fBaseType;
1694 const Type& fSwizzleType;
1695};
1696
1697std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expression& expr,
1698 std::ostream& out) {
1699 switch (expr.fKind) {
1700 case Expression::kVariableReference_Kind: {
1701 std::shared_ptr<Variable> var = ((VariableReference&) expr).fVariable;
1702 auto entry = fVariableMap.find(var);
1703 ASSERT(entry != fVariableMap.end());
1704 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1705 *this,
1706 entry->second,
1707 this->getType(*expr.fType)));
1708 }
1709 case Expression::kIndex_Kind: // fall through
1710 case Expression::kFieldAccess_Kind: {
1711 std::vector<SpvId> chain = this->getAccessChain(expr, out);
1712 SpvId member = this->nextId();
1713 this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
1714 this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
1715 this->writeWord(member, out);
1716 for (SpvId idx : chain) {
1717 this->writeWord(idx, out);
1718 }
1719 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1720 *this,
1721 member,
1722 this->getType(*expr.fType)));
1723 }
1724
1725 case Expression::kSwizzle_Kind: {
1726 Swizzle& swizzle = (Swizzle&) expr;
1727 size_t count = swizzle.fComponents.size();
1728 SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
1729 ASSERT(base);
1730 if (count == 1) {
1731 IntLiteral index(Position(), swizzle.fComponents[0]);
1732 SpvId member = this->nextId();
1733 this->writeInstruction(SpvOpAccessChain,
1734 this->getPointerType(swizzle.fType,
1735 get_storage_class(*swizzle.fBase)),
1736 member,
1737 base,
1738 this->writeIntLiteral(index),
1739 out);
1740 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1741 *this,
1742 member,
1743 this->getType(*expr.fType)));
1744 } else {
1745 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
1746 *this,
1747 base,
1748 swizzle.fComponents,
1749 *swizzle.fBase->fType,
1750 *expr.fType));
1751 }
1752 }
1753
1754 default:
1755 // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
1756 // to the need to store values in temporary variables during function calls (see
1757 // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
1758 // caught by IRGenerator
1759 SpvId result = this->nextId();
1760 SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
1761 this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction, out);
1762 this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
1763 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1764 *this,
1765 result,
1766 this->getType(*expr.fType)));
1767 }
1768}
1769
1770SpvId SPIRVCodeGenerator::writeVariableReference(VariableReference& ref, std::ostream& out) {
1771 auto entry = fVariableMap.find(ref.fVariable);
1772 ASSERT(entry != fVariableMap.end());
1773 SpvId var = entry->second;
1774 SpvId result = this->nextId();
1775 this->writeInstruction(SpvOpLoad, this->getType(*ref.fVariable->fType), result, var, out);
1776 return result;
1777}
1778
1779SpvId SPIRVCodeGenerator::writeIndexExpression(IndexExpression& expr, std::ostream& out) {
1780 return getLValue(expr, out)->load(out);
1781}
1782
1783SpvId SPIRVCodeGenerator::writeFieldAccess(FieldAccess& f, std::ostream& out) {
1784 return getLValue(f, out)->load(out);
1785}
1786
1787SpvId SPIRVCodeGenerator::writeSwizzle(Swizzle& swizzle, std::ostream& out) {
1788 SpvId base = this->writeExpression(*swizzle.fBase, out);
1789 SpvId result = this->nextId();
1790 size_t count = swizzle.fComponents.size();
1791 if (count == 1) {
1792 this->writeInstruction(SpvOpCompositeExtract, this->getType(*swizzle.fType), result, base,
1793 swizzle.fComponents[0], out);
1794 } else {
1795 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1796 this->writeWord(this->getType(*swizzle.fType), out);
1797 this->writeWord(result, out);
1798 this->writeWord(base, out);
1799 this->writeWord(base, out);
1800 for (int component : swizzle.fComponents) {
1801 this->writeWord(component, out);
1802 }
1803 }
1804 return result;
1805}
1806
1807SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
1808 const Type& operandType, SpvId lhs,
1809 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
1810 SpvOp_ ifUInt, SpvOp_ ifBool, std::ostream& out) {
1811 SpvId result = this->nextId();
1812 if (is_float(operandType)) {
1813 this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
1814 } else if (is_signed(operandType)) {
1815 this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
1816 } else if (is_unsigned(operandType)) {
1817 this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
1818 } else if (operandType == *kBool_Type) {
1819 this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
1820 } else {
1821 ABORT("invalid operandType: %s", operandType.description().c_str());
1822 }
1823 return result;
1824}
1825
1826bool is_assignment(Token::Kind op) {
1827 switch (op) {
1828 case Token::EQ: // fall through
1829 case Token::PLUSEQ: // fall through
1830 case Token::MINUSEQ: // fall through
1831 case Token::STAREQ: // fall through
1832 case Token::SLASHEQ: // fall through
1833 case Token::PERCENTEQ: // fall through
1834 case Token::SHLEQ: // fall through
1835 case Token::SHREQ: // fall through
1836 case Token::BITWISEOREQ: // fall through
1837 case Token::BITWISEXOREQ: // fall through
1838 case Token::BITWISEANDEQ: // fall through
1839 case Token::LOGICALOREQ: // fall through
1840 case Token::LOGICALXOREQ: // fall through
1841 case Token::LOGICALANDEQ:
1842 return true;
1843 default:
1844 return false;
1845 }
1846}
1847
1848SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostream& out) {
1849 // handle cases where we don't necessarily evaluate both LHS and RHS
1850 switch (b.fOperator) {
1851 case Token::EQ: {
1852 SpvId rhs = this->writeExpression(*b.fRight, out);
1853 this->getLValue(*b.fLeft, out)->store(rhs, out);
1854 return rhs;
1855 }
1856 case Token::LOGICALAND:
1857 return this->writeLogicalAnd(b, out);
1858 case Token::LOGICALOR:
1859 return this->writeLogicalOr(b, out);
1860 default:
1861 break;
1862 }
1863
1864 // "normal" operators
1865 const Type& resultType = *b.fType;
1866 std::unique_ptr<LValue> lvalue;
1867 SpvId lhs;
1868 if (is_assignment(b.fOperator)) {
1869 lvalue = this->getLValue(*b.fLeft, out);
1870 lhs = lvalue->load(out);
1871 } else {
1872 lvalue = nullptr;
1873 lhs = this->writeExpression(*b.fLeft, out);
1874 }
1875 SpvId rhs = this->writeExpression(*b.fRight, out);
1876 // component type we are operating on: float, int, uint
1877 const Type* operandType;
1878 // IR allows mismatched types in expressions (e.g. vec2 * float), but they need special handling
1879 // in SPIR-V
1880 if (b.fLeft->fType != b.fRight->fType) {
1881 if (b.fLeft->fType->kind() == Type::kVector_Kind &&
1882 b.fRight->fType->isNumber()) {
1883 // promote number to vector
1884 SpvId vec = this->nextId();
1885 this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType->columns(), out);
1886 this->writeWord(this->getType(resultType), out);
1887 this->writeWord(vec, out);
1888 for (int i = 0; i < resultType.columns(); i++) {
1889 this->writeWord(rhs, out);
1890 }
1891 rhs = vec;
1892 operandType = b.fRight->fType.get();
1893 } else if (b.fRight->fType->kind() == Type::kVector_Kind &&
1894 b.fLeft->fType->isNumber()) {
1895 // promote number to vector
1896 SpvId vec = this->nextId();
1897 this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType->columns(), out);
1898 this->writeWord(this->getType(resultType), out);
1899 this->writeWord(vec, out);
1900 for (int i = 0; i < resultType.columns(); i++) {
1901 this->writeWord(lhs, out);
1902 }
1903 lhs = vec;
1904 ASSERT(!lvalue);
1905 operandType = b.fLeft->fType.get();
1906 } else if (b.fLeft->fType->kind() == Type::kMatrix_Kind) {
1907 SpvOp_ op;
1908 if (b.fRight->fType->kind() == Type::kMatrix_Kind) {
1909 op = SpvOpMatrixTimesMatrix;
1910 } else if (b.fRight->fType->kind() == Type::kVector_Kind) {
1911 op = SpvOpMatrixTimesVector;
1912 } else {
1913 ASSERT(b.fRight->fType->kind() == Type::kScalar_Kind);
1914 op = SpvOpMatrixTimesScalar;
1915 }
1916 SpvId result = this->nextId();
1917 this->writeInstruction(op, this->getType(*b.fType), result, lhs, rhs, out);
1918 if (b.fOperator == Token::STAREQ) {
1919 lvalue->store(result, out);
1920 } else {
1921 ASSERT(b.fOperator == Token::STAR);
1922 }
1923 return result;
1924 } else if (b.fRight->fType->kind() == Type::kMatrix_Kind) {
1925 SpvId result = this->nextId();
1926 if (b.fLeft->fType->kind() == Type::kVector_Kind) {
1927 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(*b.fType), result,
1928 lhs, rhs, out);
1929 } else {
1930 ASSERT(b.fLeft->fType->kind() == Type::kScalar_Kind);
1931 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(*b.fType), result, rhs,
1932 lhs, out);
1933 }
1934 if (b.fOperator == Token::STAREQ) {
1935 lvalue->store(result, out);
1936 } else {
1937 ASSERT(b.fOperator == Token::STAR);
1938 }
1939 return result;
1940 } else {
1941 ABORT("unsupported binary expression: %s", b.description().c_str());
1942 }
1943 } else {
1944 operandType = b.fLeft->fType.get();
1945 ASSERT(*operandType == *b.fRight->fType);
1946 }
1947 switch (b.fOperator) {
1948 case Token::EQEQ:
1949 ASSERT(resultType == *kBool_Type);
1950 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdEqual,
1951 SpvOpIEqual, SpvOpIEqual, SpvOpLogicalEqual, out);
1952 case Token::NEQ:
1953 ASSERT(resultType == *kBool_Type);
1954 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdNotEqual,
1955 SpvOpINotEqual, SpvOpINotEqual, SpvOpLogicalNotEqual,
1956 out);
1957 case Token::GT:
1958 ASSERT(resultType == *kBool_Type);
1959 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1960 SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
1961 SpvOpUGreaterThan, SpvOpUndef, out);
1962 case Token::LT:
1963 ASSERT(resultType == *kBool_Type);
1964 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
1965 SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
1966 case Token::GTEQ:
1967 ASSERT(resultType == *kBool_Type);
1968 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1969 SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
1970 SpvOpUGreaterThanEqual, SpvOpUndef, out);
1971 case Token::LTEQ:
1972 ASSERT(resultType == *kBool_Type);
1973 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1974 SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
1975 SpvOpULessThanEqual, SpvOpUndef, out);
1976 case Token::PLUS:
1977 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
1978 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
1979 case Token::MINUS:
1980 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
1981 SpvOpISub, SpvOpISub, SpvOpUndef, out);
1982 case Token::STAR:
1983 if (b.fLeft->fType->kind() == Type::kMatrix_Kind &&
1984 b.fRight->fType->kind() == Type::kMatrix_Kind) {
1985 // matrix multiply
1986 SpvId result = this->nextId();
1987 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
1988 lhs, rhs, out);
1989 return result;
1990 }
1991 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
1992 SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
1993 case Token::SLASH:
1994 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
1995 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
1996 case Token::PLUSEQ: {
1997 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
1998 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
1999 ASSERT(lvalue);
2000 lvalue->store(result, out);
2001 return result;
2002 }
2003 case Token::MINUSEQ: {
2004 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2005 SpvOpISub, SpvOpISub, SpvOpUndef, out);
2006 ASSERT(lvalue);
2007 lvalue->store(result, out);
2008 return result;
2009 }
2010 case Token::STAREQ: {
2011 if (b.fLeft->fType->kind() == Type::kMatrix_Kind &&
2012 b.fRight->fType->kind() == Type::kMatrix_Kind) {
2013 // matrix multiply
2014 SpvId result = this->nextId();
2015 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2016 lhs, rhs, out);
2017 ASSERT(lvalue);
2018 lvalue->store(result, out);
2019 return result;
2020 }
2021 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2022 SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2023 ASSERT(lvalue);
2024 lvalue->store(result, out);
2025 return result;
2026 }
2027 case Token::SLASHEQ: {
2028 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2029 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2030 ASSERT(lvalue);
2031 lvalue->store(result, out);
2032 return result;
2033 }
2034 default:
2035 // FIXME: missing support for some operators (bitwise, &&=, ||=, shift...)
2036 ABORT("unsupported binary expression: %s", b.description().c_str());
2037 }
2038}
2039
2040SpvId SPIRVCodeGenerator::writeLogicalAnd(BinaryExpression& a, std::ostream& out) {
2041 ASSERT(a.fOperator == Token::LOGICALAND);
2042 BoolLiteral falseLiteral(Position(), false);
2043 SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2044 SpvId lhs = this->writeExpression(*a.fLeft, out);
2045 SpvId rhsLabel = this->nextId();
2046 SpvId end = this->nextId();
2047 SpvId lhsBlock = fCurrentBlock;
2048 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2049 this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2050 this->writeLabel(rhsLabel, out);
2051 SpvId rhs = this->writeExpression(*a.fRight, out);
2052 SpvId rhsBlock = fCurrentBlock;
2053 this->writeInstruction(SpvOpBranch, end, out);
2054 this->writeLabel(end, out);
2055 SpvId result = this->nextId();
2056 this->writeInstruction(SpvOpPhi, this->getType(*kBool_Type), result, falseConstant, lhsBlock,
2057 rhs, rhsBlock, out);
2058 return result;
2059}
2060
2061SpvId SPIRVCodeGenerator::writeLogicalOr(BinaryExpression& o, std::ostream& out) {
2062 ASSERT(o.fOperator == Token::LOGICALOR);
2063 BoolLiteral trueLiteral(Position(), true);
2064 SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2065 SpvId lhs = this->writeExpression(*o.fLeft, out);
2066 SpvId rhsLabel = this->nextId();
2067 SpvId end = this->nextId();
2068 SpvId lhsBlock = fCurrentBlock;
2069 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2070 this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2071 this->writeLabel(rhsLabel, out);
2072 SpvId rhs = this->writeExpression(*o.fRight, out);
2073 SpvId rhsBlock = fCurrentBlock;
2074 this->writeInstruction(SpvOpBranch, end, out);
2075 this->writeLabel(end, out);
2076 SpvId result = this->nextId();
2077 this->writeInstruction(SpvOpPhi, this->getType(*kBool_Type), result, trueConstant, lhsBlock,
2078 rhs, rhsBlock, out);
2079 return result;
2080}
2081
2082SpvId SPIRVCodeGenerator::writeTernaryExpression(TernaryExpression& t, std::ostream& out) {
2083 SpvId test = this->writeExpression(*t.fTest, out);
2084 if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2085 // both true and false are constants, can just use OpSelect
2086 SpvId result = this->nextId();
2087 SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2088 SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2089 this->writeInstruction(SpvOpSelect, this->getType(*t.fType), result, test, trueId, falseId,
2090 out);
2091 return result;
2092 }
2093 // was originally using OpPhi to choose the result, but for some reason that is crashing on
2094 // Adreno. Switched to storing the result in a temp variable as glslang does.
2095 SpvId var = this->nextId();
2096 this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2097 var, SpvStorageClassFunction, out);
2098 SpvId trueLabel = this->nextId();
2099 SpvId falseLabel = this->nextId();
2100 SpvId end = this->nextId();
2101 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2102 this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2103 this->writeLabel(trueLabel, out);
2104 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2105 this->writeInstruction(SpvOpBranch, end, out);
2106 this->writeLabel(falseLabel, out);
2107 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2108 this->writeInstruction(SpvOpBranch, end, out);
2109 this->writeLabel(end, out);
2110 SpvId result = this->nextId();
2111 this->writeInstruction(SpvOpLoad, this->getType(*t.fType), result, var, out);
2112 return result;
2113}
2114
2115Expression* literal_1(const Type& type) {
2116 static IntLiteral int1(Position(), 1);
2117 static FloatLiteral float1(Position(), 1.0);
2118 if (type == *kInt_Type) {
2119 return &int1;
2120 }
2121 else if (type == *kFloat_Type) {
2122 return &float1;
2123 } else {
2124 ABORT("math is unsupported on type '%s'")
2125 }
2126}
2127
2128SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostream& out) {
2129 if (p.fOperator == Token::MINUS) {
2130 SpvId result = this->nextId();
2131 SpvId typeId = this->getType(*p.fType);
2132 SpvId expr = this->writeExpression(*p.fOperand, out);
2133 if (is_float(*p.fType)) {
2134 this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2135 } else if (is_signed(*p.fType)) {
2136 this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2137 } else {
2138 ABORT("unsupported prefix expression %s", p.description().c_str());
2139 };
2140 return result;
2141 }
2142 switch (p.fOperator) {
2143 case Token::PLUS:
2144 return this->writeExpression(*p.fOperand, out);
2145 case Token::PLUSPLUS: {
2146 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2147 SpvId one = this->writeExpression(*literal_1(*p.fType), out);
2148 SpvId result = this->writeBinaryOperation(*p.fType, *p.fType, lv->load(out), one,
2149 SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2150 out);
2151 lv->store(result, out);
2152 return result;
2153 }
2154 case Token::MINUSMINUS: {
2155 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2156 SpvId one = this->writeExpression(*literal_1(*p.fType), out);
2157 SpvId result = this->writeBinaryOperation(*p.fType, *p.fType, lv->load(out), one,
2158 SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2159 out);
2160 lv->store(result, out);
2161 return result;
2162 }
2163 case Token::NOT: {
2164 ASSERT(p.fOperand->fType == kBool_Type);
2165 SpvId result = this->nextId();
2166 this->writeInstruction(SpvOpLogicalNot, this->getType(*p.fOperand->fType), result,
2167 this->writeExpression(*p.fOperand, out), out);
2168 return result;
2169 }
2170 default:
2171 ABORT("unsupported prefix expression: %s", p.description().c_str());
2172 }
2173}
2174
2175SpvId SPIRVCodeGenerator::writePostfixExpression(PostfixExpression& p, std::ostream& out) {
2176 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2177 SpvId result = lv->load(out);
2178 SpvId one = this->writeExpression(*literal_1(*p.fType), out);
2179 switch (p.fOperator) {
2180 case Token::PLUSPLUS: {
2181 SpvId temp = this->writeBinaryOperation(*p.fType, *p.fType, result, one, SpvOpFAdd,
2182 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2183 lv->store(temp, out);
2184 return result;
2185 }
2186 case Token::MINUSMINUS: {
2187 SpvId temp = this->writeBinaryOperation(*p.fType, *p.fType, result, one, SpvOpFSub,
2188 SpvOpISub, SpvOpISub, SpvOpUndef, out);
2189 lv->store(temp, out);
2190 return result;
2191 }
2192 default:
2193 ABORT("unsupported postfix expression %s", p.description().c_str());
2194 }
2195}
2196
2197SpvId SPIRVCodeGenerator::writeBoolLiteral(BoolLiteral& b) {
2198 if (b.fValue) {
2199 if (fBoolTrue == 0) {
2200 fBoolTrue = this->nextId();
2201 this->writeInstruction(SpvOpConstantTrue, this->getType(*b.fType), fBoolTrue,
2202 fConstantBuffer);
2203 }
2204 return fBoolTrue;
2205 } else {
2206 if (fBoolFalse == 0) {
2207 fBoolFalse = this->nextId();
2208 this->writeInstruction(SpvOpConstantFalse, this->getType(*b.fType), fBoolFalse,
2209 fConstantBuffer);
2210 }
2211 return fBoolFalse;
2212 }
2213}
2214
2215SpvId SPIRVCodeGenerator::writeIntLiteral(IntLiteral& i) {
2216 if (i.fType == kInt_Type) {
2217 auto entry = fIntConstants.find(i.fValue);
2218 if (entry == fIntConstants.end()) {
2219 SpvId result = this->nextId();
2220 this->writeInstruction(SpvOpConstant, this->getType(*i.fType), result, (SpvId) i.fValue,
2221 fConstantBuffer);
2222 fIntConstants[i.fValue] = result;
2223 return result;
2224 }
2225 return entry->second;
2226 } else {
2227 ASSERT(i.fType == kUInt_Type);
2228 auto entry = fUIntConstants.find(i.fValue);
2229 if (entry == fUIntConstants.end()) {
2230 SpvId result = this->nextId();
2231 this->writeInstruction(SpvOpConstant, this->getType(*i.fType), result, (SpvId) i.fValue,
2232 fConstantBuffer);
2233 fUIntConstants[i.fValue] = result;
2234 return result;
2235 }
2236 return entry->second;
2237 }
2238}
2239
2240SpvId SPIRVCodeGenerator::writeFloatLiteral(FloatLiteral& f) {
2241 if (f.fType == kFloat_Type) {
2242 float value = (float) f.fValue;
2243 auto entry = fFloatConstants.find(value);
2244 if (entry == fFloatConstants.end()) {
2245 SpvId result = this->nextId();
2246 uint32_t bits;
2247 ASSERT(sizeof(bits) == sizeof(value));
2248 memcpy(&bits, &value, sizeof(bits));
2249 this->writeInstruction(SpvOpConstant, this->getType(*f.fType), result, bits,
2250 fConstantBuffer);
2251 fFloatConstants[value] = result;
2252 return result;
2253 }
2254 return entry->second;
2255 } else {
2256 ASSERT(f.fType == kDouble_Type);
2257 auto entry = fDoubleConstants.find(f.fValue);
2258 if (entry == fDoubleConstants.end()) {
2259 SpvId result = this->nextId();
2260 uint64_t bits;
2261 ASSERT(sizeof(bits) == sizeof(f.fValue));
2262 memcpy(&bits, &f.fValue, sizeof(bits));
2263 this->writeInstruction(SpvOpConstant, this->getType(*f.fType), result,
2264 bits & 0xffffffff, bits >> 32, fConstantBuffer);
2265 fDoubleConstants[f.fValue] = result;
2266 return result;
2267 }
2268 return entry->second;
2269 }
2270}
2271
2272SpvId SPIRVCodeGenerator::writeFunctionStart(std::shared_ptr<FunctionDeclaration> f,
2273 std::ostream& out) {
2274 SpvId result = fFunctionMap[f];
2275 this->writeInstruction(SpvOpFunction, this->getType(*f->fReturnType), result,
2276 SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2277 this->writeInstruction(SpvOpName, result, f->fName.c_str(), fNameBuffer);
2278 for (size_t i = 0; i < f->fParameters.size(); i++) {
2279 SpvId id = this->nextId();
2280 fVariableMap[f->fParameters[i]] = id;
2281 SpvId type;
2282 type = this->getPointerType(f->fParameters[i]->fType, SpvStorageClassFunction);
2283 this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2284 }
2285 return result;
2286}
2287
2288SpvId SPIRVCodeGenerator::writeFunction(FunctionDefinition& f, std::ostream& out) {
2289 SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2290 this->writeLabel(this->nextId(), out);
2291 if (f.fDeclaration->fName == "main") {
2292 out << fGlobalInitializersBuffer.str();
2293 }
2294 std::stringstream bodyBuffer;
2295 this->writeBlock(*f.fBody, bodyBuffer);
2296 out << fVariableBuffer.str();
2297 fVariableBuffer.str("");
2298 out << bodyBuffer.str();
2299 if (fCurrentBlock) {
2300 this->writeInstruction(SpvOpReturn, out);
2301 }
2302 this->writeInstruction(SpvOpFunctionEnd, out);
2303 return result;
2304}
2305
2306void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2307 if (layout.fLocation >= 0) {
2308 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2309 fDecorationBuffer);
2310 }
2311 if (layout.fBinding >= 0) {
2312 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2313 fDecorationBuffer);
2314 }
2315 if (layout.fIndex >= 0) {
2316 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2317 fDecorationBuffer);
2318 }
2319 if (layout.fSet >= 0) {
2320 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2321 fDecorationBuffer);
2322 }
2323 if (layout.fBuiltin >= 0) {
2324 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2325 fDecorationBuffer);
2326 }
2327}
2328
2329void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2330 if (layout.fLocation >= 0) {
2331 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2332 layout.fLocation, fDecorationBuffer);
2333 }
2334 if (layout.fBinding >= 0) {
2335 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2336 layout.fBinding, fDecorationBuffer);
2337 }
2338 if (layout.fIndex >= 0) {
2339 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2340 layout.fIndex, fDecorationBuffer);
2341 }
2342 if (layout.fSet >= 0) {
2343 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2344 layout.fSet, fDecorationBuffer);
2345 }
2346 if (layout.fBuiltin >= 0) {
2347 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2348 layout.fBuiltin, fDecorationBuffer);
2349 }
2350}
2351
2352SpvId SPIRVCodeGenerator::writeInterfaceBlock(InterfaceBlock& intf) {
2353 SpvId type = this->getType(*intf.fVariable->fType);
2354 SpvId result = this->nextId();
2355 this->writeInstruction(SpvOpDecorate, type, SpvDecorationBlock, fDecorationBuffer);
2356 SpvStorageClass_ storageClass = get_storage_class(intf.fVariable->fModifiers);
2357 SpvId ptrType = this->nextId();
2358 this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, type, fConstantBuffer);
2359 this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2360 this->writeLayout(intf.fVariable->fModifiers.fLayout, result);
2361 fVariableMap[intf.fVariable] = result;
2362 return result;
2363}
2364
2365void SPIRVCodeGenerator::writeGlobalVars(VarDeclaration& decl, std::ostream& out) {
2366 for (size_t i = 0; i < decl.fVars.size(); i++) {
2367 if (!decl.fVars[i]->fIsReadFrom && !decl.fVars[i]->fIsWrittenTo) {
2368 continue;
2369 }
2370 SpvStorageClass_ storageClass;
2371 if (decl.fVars[i]->fModifiers.fFlags & Modifiers::kIn_Flag) {
2372 storageClass = SpvStorageClassInput;
2373 } else if (decl.fVars[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
2374 storageClass = SpvStorageClassOutput;
2375 } else if (decl.fVars[i]->fModifiers.fFlags & Modifiers::kUniform_Flag) {
2376 if (decl.fVars[i]->fType->kind() == Type::kSampler_Kind) {
2377 storageClass = SpvStorageClassUniformConstant;
2378 } else {
2379 storageClass = SpvStorageClassUniform;
2380 }
2381 } else {
2382 storageClass = SpvStorageClassPrivate;
2383 }
2384 SpvId id = this->nextId();
2385 fVariableMap[decl.fVars[i]] = id;
2386 SpvId type = this->getPointerType(decl.fVars[i]->fType, storageClass);
2387 this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
2388 this->writeInstruction(SpvOpName, id, decl.fVars[i]->fName.c_str(), fNameBuffer);
2389 if (decl.fVars[i]->fType->kind() == Type::kMatrix_Kind) {
2390 this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationColMajor,
2391 fDecorationBuffer);
2392 this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationMatrixStride,
2393 (SpvId) decl.fVars[i]->fType->stride(), fDecorationBuffer);
2394 }
2395 if (decl.fValues[i]) {
2396 ASSERT(!fCurrentBlock);
2397 fCurrentBlock = -1;
2398 SpvId value = this->writeExpression(*decl.fValues[i], fGlobalInitializersBuffer);
2399 this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
2400 fCurrentBlock = 0;
2401 }
2402 this->writeLayout(decl.fVars[i]->fModifiers.fLayout, id);
2403 }
2404}
2405
2406void SPIRVCodeGenerator::writeVarDeclaration(VarDeclaration& decl, std::ostream& out) {
2407 for (size_t i = 0; i < decl.fVars.size(); i++) {
2408 SpvId id = this->nextId();
2409 fVariableMap[decl.fVars[i]] = id;
2410 SpvId type = this->getPointerType(decl.fVars[i]->fType, SpvStorageClassFunction);
2411 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2412 this->writeInstruction(SpvOpName, id, decl.fVars[i]->fName.c_str(), fNameBuffer);
2413 if (decl.fValues[i]) {
2414 SpvId value = this->writeExpression(*decl.fValues[i], out);
2415 this->writeInstruction(SpvOpStore, id, value, out);
2416 }
2417 }
2418}
2419
2420void SPIRVCodeGenerator::writeStatement(Statement& s, std::ostream& out) {
2421 switch (s.fKind) {
2422 case Statement::kBlock_Kind:
2423 this->writeBlock((Block&) s, out);
2424 break;
2425 case Statement::kExpression_Kind:
2426 this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2427 break;
2428 case Statement::kReturn_Kind:
2429 this->writeReturnStatement((ReturnStatement&) s, out);
2430 break;
2431 case Statement::kVarDeclaration_Kind:
2432 this->writeVarDeclaration(*((VarDeclarationStatement&) s).fDeclaration, out);
2433 break;
2434 case Statement::kIf_Kind:
2435 this->writeIfStatement((IfStatement&) s, out);
2436 break;
2437 case Statement::kFor_Kind:
2438 this->writeForStatement((ForStatement&) s, out);
2439 break;
2440 case Statement::kBreak_Kind:
2441 this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2442 break;
2443 case Statement::kContinue_Kind:
2444 this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2445 break;
2446 case Statement::kDiscard_Kind:
2447 this->writeInstruction(SpvOpKill, out);
2448 break;
2449 default:
2450 ABORT("unsupported statement: %s", s.description().c_str());
2451 }
2452}
2453
2454void SPIRVCodeGenerator::writeBlock(Block& b, std::ostream& out) {
2455 for (size_t i = 0; i < b.fStatements.size(); i++) {
2456 this->writeStatement(*b.fStatements[i], out);
2457 }
2458}
2459
2460void SPIRVCodeGenerator::writeIfStatement(IfStatement& stmt, std::ostream& out) {
2461 SpvId test = this->writeExpression(*stmt.fTest, out);
2462 SpvId ifTrue = this->nextId();
2463 SpvId ifFalse = this->nextId();
2464 if (stmt.fIfFalse) {
2465 SpvId end = this->nextId();
2466 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2467 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2468 this->writeLabel(ifTrue, out);
2469 this->writeStatement(*stmt.fIfTrue, out);
2470 if (fCurrentBlock) {
2471 this->writeInstruction(SpvOpBranch, end, out);
2472 }
2473 this->writeLabel(ifFalse, out);
2474 this->writeStatement(*stmt.fIfFalse, out);
2475 if (fCurrentBlock) {
2476 this->writeInstruction(SpvOpBranch, end, out);
2477 }
2478 this->writeLabel(end, out);
2479 } else {
2480 this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2481 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2482 this->writeLabel(ifTrue, out);
2483 this->writeStatement(*stmt.fIfTrue, out);
2484 if (fCurrentBlock) {
2485 this->writeInstruction(SpvOpBranch, ifFalse, out);
2486 }
2487 this->writeLabel(ifFalse, out);
2488 }
2489}
2490
2491void SPIRVCodeGenerator::writeForStatement(ForStatement& f, std::ostream& out) {
2492 if (f.fInitializer) {
2493 this->writeStatement(*f.fInitializer, out);
2494 }
2495 SpvId header = this->nextId();
2496 SpvId start = this->nextId();
2497 SpvId body = this->nextId();
2498 SpvId next = this->nextId();
2499 fContinueTarget.push(next);
2500 SpvId end = this->nextId();
2501 fBreakTarget.push(end);
2502 this->writeInstruction(SpvOpBranch, header, out);
2503 this->writeLabel(header, out);
2504 this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2505 this->writeInstruction(SpvOpBranch, start, out);
2506 this->writeLabel(start, out);
2507 SpvId test = this->writeExpression(*f.fTest, out);
2508 this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2509 this->writeLabel(body, out);
2510 this->writeStatement(*f.fStatement, out);
2511 if (fCurrentBlock) {
2512 this->writeInstruction(SpvOpBranch, next, out);
2513 }
2514 this->writeLabel(next, out);
2515 if (f.fNext) {
2516 this->writeExpression(*f.fNext, out);
2517 }
2518 this->writeInstruction(SpvOpBranch, header, out);
2519 this->writeLabel(end, out);
2520 fBreakTarget.pop();
2521 fContinueTarget.pop();
2522}
2523
2524void SPIRVCodeGenerator::writeReturnStatement(ReturnStatement& r, std::ostream& out) {
2525 if (r.fExpression) {
2526 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
2527 out);
2528 } else {
2529 this->writeInstruction(SpvOpReturn, out);
2530 }
2531}
2532
2533void SPIRVCodeGenerator::writeInstructions(Program& program, std::ostream& out) {
2534 fGLSLExtendedInstructions = this->nextId();
2535 std::stringstream body;
2536 std::vector<SpvId> interfaceVars;
2537 // assign IDs to functions
2538 for (size_t i = 0; i < program.fElements.size(); i++) {
2539 if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
2540 FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
2541 fFunctionMap[f.fDeclaration] = this->nextId();
2542 }
2543 }
2544 for (size_t i = 0; i < program.fElements.size(); i++) {
2545 if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
2546 InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
2547 SpvId id = this->writeInterfaceBlock(intf);
2548 if ((intf.fVariable->fModifiers.fFlags & Modifiers::kIn_Flag) ||
2549 (intf.fVariable->fModifiers.fFlags & Modifiers::kOut_Flag)) {
2550 interfaceVars.push_back(id);
2551 }
2552 }
2553 }
2554 for (size_t i = 0; i < program.fElements.size(); i++) {
2555 if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) {
2556 this->writeGlobalVars(((VarDeclaration&) *program.fElements[i]), body);
2557 }
2558 }
2559 for (size_t i = 0; i < program.fElements.size(); i++) {
2560 if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
2561 this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
2562 }
2563 }
2564 std::shared_ptr<FunctionDeclaration> main = nullptr;
2565 for (auto entry : fFunctionMap) {
2566 if (entry.first->fName == "main") {
2567 main = entry.first;
2568 }
2569 }
2570 ASSERT(main);
2571 for (auto entry : fVariableMap) {
2572 std::shared_ptr<Variable> var = entry.first;
2573 if (var->fStorage == Variable::kGlobal_Storage &&
2574 ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
2575 (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
2576 interfaceVars.push_back(entry.second);
2577 }
2578 }
2579 this->writeCapabilities(out);
2580 this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
2581 this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
2582 this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (strlen(main->fName.c_str()) + 4) / 4) +
2583 (int32_t) interfaceVars.size(), out);
2584 switch (program.fKind) {
2585 case Program::kVertex_Kind:
2586 this->writeWord(SpvExecutionModelVertex, out);
2587 break;
2588 case Program::kFragment_Kind:
2589 this->writeWord(SpvExecutionModelFragment, out);
2590 break;
2591 }
2592 this->writeWord(fFunctionMap[main], out);
2593 this->writeString(main->fName.c_str(), out);
2594 for (int var : interfaceVars) {
2595 this->writeWord(var, out);
2596 }
2597 if (program.fKind == Program::kFragment_Kind) {
2598 this->writeInstruction(SpvOpExecutionMode,
2599 fFunctionMap[main],
2600 SpvExecutionModeOriginUpperLeft,
2601 out);
2602 }
2603 for (size_t i = 0; i < program.fElements.size(); i++) {
2604 if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) {
2605 this->writeInstruction(SpvOpSourceExtension,
2606 ((Extension&) *program.fElements[i]).fName.c_str(),
2607 out);
2608 }
2609 }
2610
2611 out << fNameBuffer.str();
2612 out << fDecorationBuffer.str();
2613 out << fConstantBuffer.str();
2614 out << fExternalFunctionsBuffer.str();
2615 out << body.str();
2616}
2617
2618void SPIRVCodeGenerator::generateCode(Program& program, std::ostream& out) {
2619 this->writeWord(SpvMagicNumber, out);
2620 this->writeWord(SpvVersion, out);
2621 this->writeWord(SKSL_MAGIC, out);
2622 std::stringstream buffer;
2623 this->writeInstructions(program, buffer);
2624 this->writeWord(fIdCount, out);
2625 this->writeWord(0, out); // reserved, always zero
2626 out << buffer.str();
2627}
2628
2629}