blob: 76d477290305146bd8e11590155069f91bd9ac6c [file] [log] [blame]
John Kessenich140f3df2015-06-26 16:58:36 -06001//
2//Copyright (C) 2015 LunarG, Inc.
3//
4//All rights reserved.
5//
6//Redistribution and use in source and binary forms, with or without
7//modification, are permitted provided that the following conditions
8//are met:
9//
10// Redistributions of source code must retain the above copyright
11// notice, this list of conditions and the following disclaimer.
12//
13// Redistributions in binary form must reproduce the above
14// copyright notice, this list of conditions and the following
15// disclaimer in the documentation and/or other materials provided
16// with the distribution.
17//
18// Neither the name of 3Dlabs Inc. Ltd. nor the names of its
19// contributors may be used to endorse or promote products derived
20// from this software without specific prior written permission.
21//
22//THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23//"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24//LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25//FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26//COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
27//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28//BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29//LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30//CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31//LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
32//ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33//POSSIBILITY OF SUCH DAMAGE.
34//
35
36#include "SPVRemapper.h"
37#include "doc.h"
38
39#if !defined (use_cpp11)
40// ... not supported before C++11
41#else // defined (use_cpp11)
42
#include <algorithm>
#include <cassert>
#include <cstdlib>
45
46namespace spv {
47
48 // By default, just abort on error. Can be overridden via RegisterErrorHandler
49 spirvbin_t::errorfn_t spirvbin_t::errorHandler = [](const std::string&) { exit(5); };
50 // By default, eat log messages. Can be overridden via RegisterLogHandler
51 spirvbin_t::logfn_t spirvbin_t::logHandler = [](const std::string&) { };
52
53 // This can be overridden to provide other message behavior if needed
54 void spirvbin_t::msg(int minVerbosity, int indent, const std::string& txt) const
55 {
56 if (verbose >= minVerbosity)
57 logHandler(std::string(indent, ' ') + txt);
58 }
59
60 // hash opcode, with special handling for OpExtInst
61 std::uint32_t spirvbin_t::asOpCodeHash(unsigned word)
62 {
63 const spv::Op opCode = asOpCode(word);
64
65 std::uint32_t offset = 0;
66
67 switch (opCode) {
68 case spv::OpExtInst:
69 offset += asId(word + 4); break;
70 default:
71 break;
72 }
73
74 return opCode * 19 + offset; // 19 = small prime
75 }
76
77 spirvbin_t::range_t spirvbin_t::literalRange(spv::Op opCode) const
78 {
79 static const int maxCount = 1<<30;
80
81 switch (opCode) {
82 case spv::OpTypeFloat: // fall through...
83 case spv::OpTypePointer: return range_t(2, 3);
84 case spv::OpTypeInt: return range_t(2, 4);
85 case spv::OpTypeSampler: return range_t(3, 8);
86 case spv::OpTypeVector: // fall through
87 case spv::OpTypeMatrix: // ...
88 case spv::OpTypePipe: return range_t(3, 4);
89 case spv::OpConstant: return range_t(3, maxCount);
90 default: return range_t(0, 0);
91 }
92 }
93
94 spirvbin_t::range_t spirvbin_t::typeRange(spv::Op opCode) const
95 {
96 static const int maxCount = 1<<30;
97
98 if (isConstOp(opCode))
99 return range_t(1, 2);
100
101 switch (opCode) {
102 case spv::OpTypeVector: // fall through
103 case spv::OpTypeMatrix: // ...
104 case spv::OpTypeSampler: // ...
105 case spv::OpTypeArray: // ...
106 case spv::OpTypeRuntimeArray: // ...
107 case spv::OpTypePipe: return range_t(2, 3);
108 case spv::OpTypeStruct: // fall through
109 case spv::OpTypeFunction: return range_t(2, maxCount);
110 case spv::OpTypePointer: return range_t(3, 4);
111 default: return range_t(0, 0);
112 }
113 }
114
115 spirvbin_t::range_t spirvbin_t::constRange(spv::Op opCode) const
116 {
117 static const int maxCount = 1<<30;
118
119 switch (opCode) {
120 case spv::OpTypeArray: // fall through...
121 case spv::OpTypeRuntimeArray: return range_t(3, 4);
122 case spv::OpConstantComposite: return range_t(3, maxCount);
123 default: return range_t(0, 0);
124 }
125 }
126
127 // Is this an opcode we should remove when using --strip?
128 bool spirvbin_t::isStripOp(spv::Op opCode) const
129 {
130 switch (opCode) {
131 case spv::OpSource:
132 case spv::OpSourceExtension:
133 case spv::OpName:
134 case spv::OpMemberName:
135 case spv::OpLine: return true;
136 default: return false;
137 }
138 }
139
140 bool spirvbin_t::isFlowCtrlOpen(spv::Op opCode) const
141 {
142 switch (opCode) {
143 case spv::OpBranchConditional:
144 case spv::OpSwitch: return true;
145 default: return false;
146 }
147 }
148
149 bool spirvbin_t::isFlowCtrlClose(spv::Op opCode) const
150 {
151 switch (opCode) {
152 case spv::OpLoopMerge:
153 case spv::OpSelectionMerge: return true;
154 default: return false;
155 }
156 }
157
158 bool spirvbin_t::isTypeOp(spv::Op opCode) const
159 {
160 switch (opCode) {
161 case spv::OpTypeVoid:
162 case spv::OpTypeBool:
163 case spv::OpTypeInt:
164 case spv::OpTypeFloat:
165 case spv::OpTypeVector:
166 case spv::OpTypeMatrix:
167 case spv::OpTypeSampler:
168 case spv::OpTypeFilter:
169 case spv::OpTypeArray:
170 case spv::OpTypeRuntimeArray:
171 case spv::OpTypeStruct:
172 case spv::OpTypeOpaque:
173 case spv::OpTypePointer:
174 case spv::OpTypeFunction:
175 case spv::OpTypeEvent:
176 case spv::OpTypeDeviceEvent:
177 case spv::OpTypeReserveId:
178 case spv::OpTypeQueue:
179 case spv::OpTypePipe: return true;
180 default: return false;
181 }
182 }
183
184 bool spirvbin_t::isConstOp(spv::Op opCode) const
185 {
186 switch (opCode) {
187 case spv::OpConstantNullObject: error("unimplemented constant type");
188 case spv::OpConstantSampler: error("unimplemented constant type");
189
190 case spv::OpConstantTrue:
191 case spv::OpConstantFalse:
192 case spv::OpConstantNullPointer:
193 case spv::OpConstantComposite:
194 case spv::OpConstant: return true;
195 default: return false;
196 }
197 }
198
199 const auto inst_fn_nop = [](spv::Op, unsigned) { return false; };
200 const auto op_fn_nop = [](spv::Id&) { };
201
202 // g++ doesn't like these defined in the class proper in an anonymous namespace.
203 // Dunno why. Also MSVC doesn't like the constexpr keyword. Also dunno why.
204 // Defining them externally seems to please both compilers, so, here they are.
205 const spv::Id spirvbin_t::unmapped = spv::Id(-10000);
206 const spv::Id spirvbin_t::unused = spv::Id(-10001);
207 const int spirvbin_t::header_size = 5;
208
209 spv::Id spirvbin_t::nextUnusedId(spv::Id id)
210 {
211 while (isNewIdMapped(id)) // search for an unused ID
212 ++id;
213
214 return id;
215 }
216
217 spv::Id spirvbin_t::localId(spv::Id id, spv::Id newId)
218 {
219 assert(id != spv::NoResult && newId != spv::NoResult);
220
221 if (id >= idMapL.size())
222 idMapL.resize(id+1, unused);
223
224 if (newId != unmapped && newId != unused) {
225 if (isOldIdUnused(id))
226 error(std::string("ID unused in module: ") + std::to_string(id));
227
228 if (!isOldIdUnmapped(id))
229 error(std::string("ID already mapped: ") + std::to_string(id) + " -> "
230 + std::to_string(localId(id)));
231
232 if (isNewIdMapped(newId))
233 error(std::string("ID already used in module: ") + std::to_string(newId));
234
235 msg(4, 4, std::string("map: ") + std::to_string(id) + " -> " + std::to_string(newId));
236 setMapped(newId);
237 largestNewId = std::max(largestNewId, newId);
238 }
239
240 return idMapL[id] = newId;
241 }
242
243 // Parse a literal string from the SPIR binary and return it as an std::string
244 // Due to C++11 RValue references, this doesn't copy the result string.
245 std::string spirvbin_t::literalString(unsigned word) const
246 {
247 std::string literal;
248
249 literal.reserve(16);
250
251 const char* bytes = reinterpret_cast<const char*>(spv.data() + word);
252
253 while (bytes && *bytes)
254 literal += *bytes++;
255
256 return literal;
257 }
258
259
260 void spirvbin_t::applyMap()
261 {
262 msg(3, 2, std::string("Applying map: "));
263
264 // Map local IDs through the ID map
265 process(inst_fn_nop, // ignore instructions
266 [this](spv::Id& id) {
267 id = localId(id);
268 assert(id != unused && id != unmapped);
269 }
270 );
271 }
272
273
274 // Find free IDs for anything we haven't mapped
275 void spirvbin_t::mapRemainder()
276 {
277 msg(3, 2, std::string("Remapping remainder: "));
278
279 spv::Id unusedId = 1; // can't use 0: that's NoResult
280 spirword_t maxBound = 0;
281
282 for (spv::Id id = 0; id < idMapL.size(); ++id) {
283 if (isOldIdUnused(id))
284 continue;
285
286 // Find a new mapping for any used but unmapped IDs
287 if (isOldIdUnmapped(id))
288 localId(id, unusedId = nextUnusedId(unusedId));
289
290 if (isOldIdUnmapped(id))
291 error(std::string("old ID not mapped: ") + std::to_string(id));
292
293 // Track max bound
294 maxBound = std::max(maxBound, localId(id) + 1);
295 }
296
297 bound(maxBound); // reset header ID bound to as big as it now needs to be
298 }
299
300 void spirvbin_t::stripDebug()
301 {
302 if ((options & STRIP) == 0)
303 return;
304
305 // build local Id and name maps
306 process(
307 [&](spv::Op opCode, unsigned start) {
308 // remember opcodes we want to strip later
309 if (isStripOp(opCode))
310 stripInst(start);
311 return true;
312 },
313 op_fn_nop);
314 }
315
316 void spirvbin_t::buildLocalMaps()
317 {
318 msg(2, 2, std::string("build local maps: "));
319
320 mapped.clear();
321 idMapL.clear();
322// preserve nameMap, so we don't clear that.
323 fnPos.clear();
324 fnPosDCE.clear();
325 fnCalls.clear();
326 typeConstPos.clear();
327 typeConstPosR.clear();
328 entryPoint = spv::NoResult;
329 largestNewId = 0;
330
331 idMapL.resize(bound(), unused);
332
333 int fnStart = 0;
334 spv::Id fnRes = spv::NoResult;
335
336 // build local Id and name maps
337 process(
338 [&](spv::Op opCode, unsigned start) {
339 // remember opcodes we want to strip later
340 if ((options & STRIP) && isStripOp(opCode))
341 stripInst(start);
342
343 if (opCode == spv::Op::OpName) {
344 const spv::Id target = asId(start+1);
345 const std::string name = literalString(start+2);
346 nameMap[name] = target;
347
348 } else if (opCode == spv::Op::OpFunctionCall) {
349 ++fnCalls[asId(start + 3)];
350 } else if (opCode == spv::Op::OpEntryPoint) {
351 entryPoint = asId(start + 2);
352 } else if (opCode == spv::Op::OpFunction) {
353 if (fnStart != 0)
354 error("nested function found");
355 fnStart = start;
356 fnRes = asId(start + 2);
357 } else if (opCode == spv::Op::OpFunctionEnd) {
358 assert(fnRes != spv::NoResult);
359 if (fnStart == 0)
360 error("function end without function start");
361 fnPos[fnRes] = range_t(fnStart, start + asWordCount(start));
362 fnStart = 0;
363 } else if (isConstOp(opCode)) {
364 assert(asId(start + 2) != spv::NoResult);
365 typeConstPos.insert(start);
366 typeConstPosR[asId(start + 2)] = start;
367 } else if (isTypeOp(opCode)) {
368 assert(asId(start + 1) != spv::NoResult);
369 typeConstPos.insert(start);
370 typeConstPosR[asId(start + 1)] = start;
371 }
372
373 return false;
374 },
375
376 [this](spv::Id& id) { localId(id, unmapped); }
377 );
378 }
379
380 // Validate the SPIR header
381 void spirvbin_t::validate() const
382 {
383 msg(2, 2, std::string("validating: "));
384
385 if (spv.size() < header_size)
386 error("file too short: ");
387
388 if (magic() != spv::MagicNumber)
389 error("bad magic number");
390
391 // field 1 = version
392 // field 2 = generator magic
393 // field 3 = result <id> bound
394
395 if (schemaNum() != 0)
396 error("bad schema, must be 0");
397 }
398
399
400 int spirvbin_t::processInstruction(unsigned word, instfn_t instFn, idfn_t idFn)
401 {
402 const auto instructionStart = word;
403 const unsigned wordCount = asWordCount(instructionStart);
404 const spv::Op opCode = asOpCode(instructionStart);
405 const int nextInst = word++ + wordCount;
406
407 if (nextInst > int(spv.size()))
408 error("spir instruction terminated too early");
409
410 // Base for computing number of operands; will be updated as more is learned
411 unsigned numOperands = wordCount - 1;
412
413 if (instFn(opCode, instructionStart))
414 return nextInst;
415
416 // Read type and result ID from instruction desc table
417 if (spv::InstructionDesc[opCode].hasType()) {
418 idFn(asId(word++));
419 --numOperands;
420 }
421
422 if (spv::InstructionDesc[opCode].hasResult()) {
423 idFn(asId(word++));
424 --numOperands;
425 }
426
427 // Extended instructions: currently, assume everything is an ID.
428 // TODO: add whatever data we need for exceptions to that
429 if (opCode == spv::OpExtInst) {
430 word += 2; // instruction set, and instruction from set
431 numOperands -= 2;
432
433 for (unsigned op=0; op < numOperands; ++op)
434 idFn(asId(word++)); // ID
435
436 return nextInst;
437 }
438
439 // Store IDs from instruction in our map
440 for (int op = 0; op < spv::InstructionDesc[opCode].operands.getNum(); ++op, --numOperands) {
441 switch (spv::InstructionDesc[opCode].operands.getClass(op)) {
442 case spv::OperandId:
443 idFn(asId(word++));
444 break;
445
446 case spv::OperandOptionalId:
447 case spv::OperandVariableIds:
448 for (unsigned i = 0; i < numOperands; ++i)
449 idFn(asId(word++));
450 return nextInst;
451
452 case spv::OperandVariableLiterals:
453 if (opCode == spv::OpDecorate && asDecoration(word - 1) == spv::DecorationBuiltIn) {
454 ++word;
455 --numOperands;
456 }
457 word += numOperands;
458 return nextInst;
459
460 case spv::OperandVariableLiteralId:
461 while (numOperands > 0) {
462 ++word; // immediate
463 idFn(asId(word++)); // ID
464 numOperands -= 2;
465 }
466 return nextInst;
467
468 case spv::OperandLiteralString:
469 word += literalStringWords(literalString(word));
470 return nextInst;
471
472 // Single word operands we simply ignore, as they hold no IDs
473 case spv::OperandLiteralNumber:
474 case spv::OperandSource:
475 case spv::OperandExecutionModel:
476 case spv::OperandAddressing:
477 case spv::OperandMemory:
478 case spv::OperandExecutionMode:
479 case spv::OperandStorage:
480 case spv::OperandDimensionality:
481 case spv::OperandDecoration:
482 case spv::OperandBuiltIn:
483 case spv::OperandSelect:
484 case spv::OperandLoop:
485 case spv::OperandFunction:
486 case spv::OperandMemorySemantics:
487 case spv::OperandMemoryAccess:
488 case spv::OperandExecutionScope:
489 case spv::OperandGroupOperation:
490 case spv::OperandKernelEnqueueFlags:
491 case spv::OperandKernelProfilingInfo:
492 ++word;
493 break;
494
495 default:
496 break;
497 }
498 }
499
500 return nextInst;
501 }
502
503 // Make a pass over all the instructions and process them given appropriate functions
504 spirvbin_t& spirvbin_t::process(instfn_t instFn, idfn_t idFn, unsigned begin, unsigned end)
505 {
506 // For efficiency, reserve name map space. It can grow if needed.
507 nameMap.reserve(32);
508
509 // If begin or end == 0, use defaults
510 begin = (begin == 0 ? header_size : begin);
511 end = (end == 0 ? unsigned(spv.size()) : end);
512
513 // basic parsing and InstructionDesc table borrowed from SpvDisassemble.cpp...
514 unsigned nextInst = unsigned(spv.size());
515
516 for (unsigned word = begin; word < end; word = nextInst)
517 nextInst = processInstruction(word, instFn, idFn);
518
519 return *this;
520 }
521
522 // Apply global name mapping to a single module
523 void spirvbin_t::mapNames()
524 {
525 static const std::uint32_t softTypeIdLimit = 3011; // small prime. TODO: get from options
526 static const std::uint32_t firstMappedID = 3019; // offset into ID space
527
528 for (const auto& name : nameMap) {
529 std::uint32_t hashval = 1911;
530 for (const char c : name.first)
531 hashval = hashval * 1009 + c;
532
533 if (isOldIdUnmapped(name.second))
534 localId(name.second, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
535 }
536 }
537
538 // Map fn contents to IDs of similar functions in other modules
539 void spirvbin_t::mapFnBodies()
540 {
541 static const std::uint32_t softTypeIdLimit = 19071; // small prime. TODO: get from options
542 static const std::uint32_t firstMappedID = 6203; // offset into ID space
543
544 // Initial approach: go through some high priority opcodes first and assign them
545 // hash values.
546
547 spv::Id fnId = spv::NoResult;
548 std::vector<unsigned> instPos;
549 instPos.reserve(unsigned(spv.size()) / 16); // initial estimate; can grow if needed.
550
551 // Build local table of instruction start positions
552 process(
553 [&](spv::Op, unsigned start) { instPos.push_back(start); return true; },
554 op_fn_nop);
555
556 // Window size for context-sensitive canonicalization values
557 // Emperical best size from a single data set. TODO: Would be a good tunable.
558 // We essentially performa a little convolution around each instruction,
559 // to capture the flavor of nearby code, to hopefully match to similar
560 // code in other modules.
561 static const unsigned windowSize = 2;
562
563 for (unsigned entry = 0; entry < unsigned(instPos.size()); ++entry) {
564 const unsigned start = instPos[entry];
565 const spv::Op opCode = asOpCode(start);
566
567 if (opCode == spv::OpFunction)
568 fnId = asId(start + 2);
569
570 if (opCode == spv::OpFunctionEnd)
571 fnId = spv::NoResult;
572
573 if (fnId != spv::NoResult) { // if inside a function
574 if (spv::InstructionDesc[opCode].hasResult()) {
575 const unsigned word = start + (spv::InstructionDesc[opCode].hasType() ? 2 : 1);
576 const spv::Id resId = asId(word);
577 std::uint32_t hashval = fnId * 17; // small prime
578
579 for (unsigned i = entry-1; i >= entry-windowSize; --i) {
580 if (asOpCode(instPos[i]) == spv::OpFunction)
581 break;
582 hashval = hashval * 30103 + asOpCodeHash(instPos[i]); // 30103 = semiarbitrary prime
583 }
584
585 for (unsigned i = entry; i <= entry + windowSize; ++i) {
586 if (asOpCode(instPos[i]) == spv::OpFunctionEnd)
587 break;
588 hashval = hashval * 30103 + asOpCodeHash(instPos[i]); // 30103 = semiarbitrary prime
589 }
590
591 if (isOldIdUnmapped(resId))
592 localId(resId, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
593 }
594 }
595 }
596
597 spv::Op thisOpCode(spv::OpNop);
598 std::unordered_map<int, int> opCounter;
599 int idCounter(0);
600 fnId = spv::NoResult;
601
602 process(
603 [&](spv::Op opCode, unsigned start) {
604 switch (opCode) {
605 case spv::OpFunction:
606 // Reset counters at each function
607 idCounter = 0;
608 opCounter.clear();
609 fnId = asId(start + 2);
610 break;
611
612 case spv::OpTextureSample:
613 case spv::OpTextureSampleDref:
614 case spv::OpTextureSampleLod:
615 case spv::OpTextureSampleProj:
616 case spv::OpTextureSampleGrad:
617 case spv::OpTextureSampleOffset:
618 case spv::OpTextureSampleProjLod:
619 case spv::OpTextureSampleProjGrad:
620 case spv::OpTextureSampleLodOffset:
621 case spv::OpTextureSampleProjOffset:
622 case spv::OpTextureSampleGradOffset:
623 case spv::OpTextureSampleProjLodOffset:
624 case spv::OpTextureSampleProjGradOffset:
625 case spv::OpDot:
626 case spv::OpCompositeExtract:
627 case spv::OpCompositeInsert:
628 case spv::OpVectorShuffle:
629 case spv::OpLabel:
630 case spv::OpVariable:
631
632 case spv::OpAccessChain:
633 case spv::OpLoad:
634 case spv::OpStore:
635 case spv::OpCompositeConstruct:
636 case spv::OpFunctionCall:
637 ++opCounter[opCode];
638 idCounter = 0;
639 thisOpCode = opCode;
640 break;
641 default:
642 thisOpCode = spv::OpNop;
643 }
644
645 return false;
646 },
647
648 [&](spv::Id& id) {
649 if (thisOpCode != spv::OpNop) {
650 ++idCounter;
651 const std::uint32_t hashval = opCounter[thisOpCode] * thisOpCode * 50047 + idCounter + fnId * 117;
652
653 if (isOldIdUnmapped(id))
654 localId(id, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
655 }
656 });
657 }
658
659 // EXPERIMENTAL: forward IO and uniform load/stores into operands
660 // This produces invalid Schema-0 SPIRV
661 void spirvbin_t::forwardLoadStores()
662 {
663 idset_t fnLocalVars; // set of function local vars
664 idmap_t idMap; // Map of load result IDs to what they load
665
666 // EXPERIMENTAL: Forward input and access chain loads into consumptions
667 process(
668 [&](spv::Op opCode, unsigned start) {
669 // Add inputs and uniforms to the map
670 if (((opCode == spv::OpVariable && asWordCount(start) == 4) || (opCode == spv::OpVariableArray)) &&
671 (spv[start+3] == spv::StorageClassUniform ||
672 spv[start+3] == spv::StorageClassUniformConstant ||
673 spv[start+3] == spv::StorageClassInput))
674 fnLocalVars.insert(asId(start+2));
675
676 if (opCode == spv::OpAccessChain && fnLocalVars.count(asId(start+3)) > 0)
677 fnLocalVars.insert(asId(start+2));
678
679 if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) {
680 idMap[asId(start+2)] = asId(start+3);
681 stripInst(start);
682 }
683
684 return false;
685 },
686
687 [&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }
688 );
689
690 // EXPERIMENTAL: Implicit output stores
691 fnLocalVars.clear();
692 idMap.clear();
693
694 process(
695 [&](spv::Op opCode, unsigned start) {
696 // Add inputs and uniforms to the map
697 if (((opCode == spv::OpVariable && asWordCount(start) == 4) || (opCode == spv::OpVariableArray)) &&
698 (spv[start+3] == spv::StorageClassOutput))
699 fnLocalVars.insert(asId(start+2));
700
701 if (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) {
702 idMap[asId(start+2)] = asId(start+1);
703 stripInst(start);
704 }
705
706 return false;
707 },
708 op_fn_nop);
709
710 process(
711 inst_fn_nop,
712 [&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }
713 );
714
715 strip(); // strip out data we decided to eliminate
716 }
717
718 // remove bodies of uncalled functions
719 void spirvbin_t::optLoadStore()
720 {
721 idset_t fnLocalVars;
722 // Map of load result IDs to what they load
723 idmap_t idMap;
724
725 // Find all the function local pointers stored at most once, and not via access chains
726 process(
727 [&](spv::Op opCode, unsigned start) {
728 const int wordCount = asWordCount(start);
729
730 // Add local variables to the map
731 if ((opCode == spv::OpVariable && spv[start+3] == spv::StorageClassFunction && asWordCount(start) == 4) ||
732 (opCode == spv::OpVariableArray && spv[start+3] == spv::StorageClassFunction))
733 fnLocalVars.insert(asId(start+2));
734
735 // Ignore process vars referenced via access chain
736 if ((opCode == spv::OpAccessChain || opCode == spv::OpInBoundsAccessChain) && fnLocalVars.count(asId(start+3)) > 0) {
737 fnLocalVars.erase(asId(start+3));
738 idMap.erase(asId(start+3));
739 }
740
741 if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) {
742 // Avoid loads before stores (TODO: why? Crashes driver, but seems like it shouldn't).
743 if (idMap.find(asId(start+3)) == idMap.end()) {
744 fnLocalVars.erase(asId(start+3));
745 idMap.erase(asId(start+3));
746 }
747
748 // don't do for volatile references
749 if (wordCount > 4 && (spv[start+4] & spv::MemoryAccessVolatileMask)) {
750 fnLocalVars.erase(asId(start+3));
751 idMap.erase(asId(start+3));
752 }
753 }
754
755 if (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) {
756 if (idMap.find(asId(start+1)) == idMap.end()) {
757 idMap[asId(start+1)] = asId(start+2);
758 } else {
759 // Remove if it has more than one store to the same pointer
760 fnLocalVars.erase(asId(start+1));
761 idMap.erase(asId(start+1));
762 }
763
764 // don't do for volatile references
765 if (wordCount > 3 && (spv[start+3] & spv::MemoryAccessVolatileMask)) {
766 fnLocalVars.erase(asId(start+3));
767 idMap.erase(asId(start+3));
768 }
769 }
770
771 return true;
772 },
773 op_fn_nop);
774
775 process(
776 [&](spv::Op opCode, unsigned start) {
777 if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0)
778 idMap[asId(start+2)] = idMap[asId(start+3)];
779 return false;
780 },
781 op_fn_nop);
782
783 // Remove the load/store/variables for the ones we've discovered
784 process(
785 [&](spv::Op opCode, unsigned start) {
786 if ((opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) ||
787 (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) ||
788 (opCode == spv::OpVariable && fnLocalVars.count(asId(start+2)) > 0)) {
789 stripInst(start);
790 return true;
791 }
792
793 return false;
794 },
795
796 [&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }
797 );
798
799 strip(); // strip out data we decided to eliminate
800 }
801
802 // remove bodies of uncalled functions
803 void spirvbin_t::dceFuncs()
804 {
805 msg(3, 2, std::string("Removing Dead Functions: "));
806
807 // TODO: There are more efficient ways to do this.
808 bool changed = true;
809
810 while (changed) {
811 changed = false;
812
813 for (auto fn = fnPos.begin(); fn != fnPos.end(); ) {
814 if (fn->first == entryPoint) { // don't DCE away the entry point!
815 ++fn;
816 continue;
817 }
818
819 const auto call_it = fnCalls.find(fn->first);
820
821 if (call_it == fnCalls.end() || call_it->second == 0) {
822 changed = true;
823 stripRange.push_back(fn->second);
824 fnPosDCE.insert(*fn);
825
826 // decrease counts of called functions
827 process(
828 [&](spv::Op opCode, unsigned start) {
829 if (opCode == spv::Op::OpFunctionCall) {
830 const auto call_it = fnCalls.find(asId(start + 3));
831 if (call_it != fnCalls.end()) {
832 if (--call_it->second <= 0)
833 fnCalls.erase(call_it);
834 }
835 }
836
837 return true;
838 },
839 op_fn_nop,
840 fn->second.first,
841 fn->second.second);
842
843 fn = fnPos.erase(fn);
844 } else ++fn;
845 }
846 }
847 }
848
849 // remove unused function variables + decorations
850 void spirvbin_t::dceVars()
851 {
852 msg(3, 2, std::string("DCE Vars: "));
853
854 std::unordered_map<spv::Id, int> varUseCount;
855
856 // Count function variable use
857 process(
858 [&](spv::Op opCode, unsigned start) {
859 if (opCode == spv::OpVariable) { ++varUseCount[asId(start+2)]; return true; }
860 return false;
861 },
862
863 [&](spv::Id& id) { if (varUseCount[id]) ++varUseCount[id]; }
864 );
865
866 // Remove single-use function variables + associated decorations and names
867 process(
868 [&](spv::Op opCode, unsigned start) {
869 if ((opCode == spv::OpVariable && varUseCount[asId(start+2)] == 1) ||
870 (opCode == spv::OpDecorate && varUseCount[asId(start+1)] == 1) ||
871 (opCode == spv::OpName && varUseCount[asId(start+1)] == 1)) {
872 stripInst(start);
873 }
874 return true;
875 },
876 op_fn_nop);
877 }
878
879 // remove unused types
880 void spirvbin_t::dceTypes()
881 {
882 std::vector<bool> isType(bound(), false);
883
884 // for speed, make O(1) way to get to type query (map is log(n))
885 for (const auto typeStart : typeConstPos)
886 isType[asTypeConstId(typeStart)] = true;
887
888 std::unordered_map<spv::Id, int> typeUseCount;
889
890 // Count total type usage
891 process(inst_fn_nop,
892 [&](spv::Id& id) { if (isType[id]) ++typeUseCount[id]; }
893 );
894
895 // Remove types from deleted code
896 for (const auto& fn : fnPosDCE)
897 process(inst_fn_nop,
898 [&](spv::Id& id) { if (isType[id]) --typeUseCount[id]; },
899 fn.second.first, fn.second.second);
900
901 // Remove single reference types
902 for (const auto typeStart : typeConstPos) {
903 const spv::Id typeId = asTypeConstId(typeStart);
904 if (typeUseCount[typeId] == 1) {
905 --typeUseCount[typeId];
906 stripInst(typeStart);
907 }
908 }
909 }
910
911
#ifdef NOTDEF
// NOTE(review): everything up to the matching #endif is compiled out.
// It refers to `spir`, `opOpCode`, `opWordCount` and `spv::NoType`, none of
// which the live code in this file uses (the module words are held in `spv`
// there), so it probably needs updating before it can be re-enabled.
// `numOperands` below is also unused.

// Return true when local type/constant `lt` structurally matches entry `gt`
// of `globalTypes`. That requires:
//   - the same opcode and word count;
//   - equal literal operands;
//   - recursively matching constant and type operands.
bool spirvbin_t::matchType(const spirvbin_t::globaltypes_t& globalTypes, spv::Id lt, spv::Id gt) const
{
    // Find the local type id "lt" and global type id "gt"
    const auto lt_it = typeConstPosR.find(lt);
    if (lt_it == typeConstPosR.end())
        return false;

    const auto typeStart = lt_it->second;

    // Search for entry in global table
    const auto gtype = globalTypes.find(gt);
    if (gtype == globalTypes.end())
        return false;

    const auto& gdata = gtype->second;

    // local wordcount and opcode
    const int wordCount = asWordCount(typeStart);
    const spv::Op opCode = asOpCode(typeStart);

    // no type match if opcodes don't match, or operand count doesn't match
    if (opCode != opOpCode(gdata[0]) || wordCount != opWordCount(gdata[0]))
        return false;

    const unsigned numOperands = wordCount - 2; // all types have a result

    // Recursively match the ID operands in `range`, clamped to the word count.
    const auto cmpIdRange = [&](range_t range) {
        for (int x=range.first; x<std::min(range.second, wordCount); ++x)
            if (!matchType(globalTypes, asId(typeStart+x), gdata[x]))
                return false;
        return true;
    };

    const auto cmpConst = [&]() { return cmpIdRange(constRange(opCode)); };
    const auto cmpSubType = [&]() { return cmpIdRange(typeRange(opCode)); };

    // Compare literals in range [start,end)
    const auto cmpLiteral = [&]() {
        const auto range = literalRange(opCode);
        return std::equal(spir.begin() + typeStart + range.first,
            spir.begin() + typeStart + std::min(range.second, wordCount),
            gdata.begin() + range.first);
    };

    assert(isTypeOp(opCode) || isConstOp(opCode));

    switch (opCode) {
    case spv::OpTypeOpaque: // TODO: disable until we compare the literal strings.
    case spv::OpTypeQueue: return false;
    case spv::OpTypeEvent: // fall through...
    case spv::OpTypeDeviceEvent: // ...
    case spv::OpTypeReserveId: return false;
        // for samplers, we don't handle the optional parameters yet
    case spv::OpTypeSampler: return cmpLiteral() && cmpConst() && cmpSubType() && wordCount == 8;
    default: return cmpLiteral() && cmpConst() && cmpSubType();
    }
}


// Look for an equivalent type in the globalTypes map.
// Returns the first matching global ID, or spv::NoType when none matches.
spv::Id spirvbin_t::findType(const spirvbin_t::globaltypes_t& globalTypes, spv::Id lt) const
{
    // Try a recursive type match on each in turn, and return a match if we find one
    for (const auto& gt : globalTypes)
        if (matchType(globalTypes, lt, gt.first))
            return gt.first;

    return spv::NoType;
}
#endif // NOTDEF
983
984 // Return start position in SPV of given type. error if not found.
985 unsigned spirvbin_t::typePos(spv::Id id) const
986 {
987 const auto tid_it = typeConstPosR.find(id);
988 if (tid_it == typeConstPosR.end())
989 error("type ID not found");
990
991 return tid_it->second;
992 }
993
994 // Hash types to canonical values. This can return ID collisions (it's a bit
995 // inevitable): it's up to the caller to handle that gracefully.
996 std::uint32_t spirvbin_t::hashType(unsigned typeStart) const
997 {
998 const unsigned wordCount = asWordCount(typeStart);
999 const spv::Op opCode = asOpCode(typeStart);
1000
1001 switch (opCode) {
1002 case spv::OpTypeVoid: return 0;
1003 case spv::OpTypeBool: return 1;
1004 case spv::OpTypeInt: return 3 + (spv[typeStart+3]);
1005 case spv::OpTypeFloat: return 5;
1006 case spv::OpTypeVector:
1007 return 6 + hashType(typePos(spv[typeStart+2])) * (spv[typeStart+3] - 1);
1008 case spv::OpTypeMatrix:
1009 return 30 + hashType(typePos(spv[typeStart+2])) * (spv[typeStart+3] - 1);
1010 case spv::OpTypeSampler:
1011 return 120 + hashType(typePos(spv[typeStart+2])) +
1012 spv[typeStart+3] + // dimensionality
1013 spv[typeStart+4] * 8 * 16 + // content
1014 spv[typeStart+5] * 4 * 16 + // arrayed
1015 spv[typeStart+6] * 2 * 16 + // compare
1016 spv[typeStart+7] * 1 * 16; // multisampled
1017 case spv::OpTypeFilter:
1018 return 500;
1019 case spv::OpTypeArray:
1020 return 501 + hashType(typePos(spv[typeStart+2])) * spv[typeStart+3];
1021 case spv::OpTypeRuntimeArray:
1022 return 5000 + hashType(typePos(spv[typeStart+2]));
1023 case spv::OpTypeStruct:
1024 {
1025 std::uint32_t hash = 10000;
1026 for (unsigned w=2; w < wordCount; ++w)
1027 hash += w * hashType(typePos(spv[typeStart+w]));
1028 return hash;
1029 }
1030
1031 case spv::OpTypeOpaque: return 6000 + spv[typeStart+2];
1032 case spv::OpTypePointer: return 100000 + hashType(typePos(spv[typeStart+3]));
1033 case spv::OpTypeFunction:
1034 {
1035 std::uint32_t hash = 200000;
1036 for (unsigned w=2; w < wordCount; ++w)
1037 hash += w * hashType(typePos(spv[typeStart+w]));
1038 return hash;
1039 }
1040
1041 case spv::OpTypeEvent: return 300000;
1042 case spv::OpTypeDeviceEvent: return 300001;
1043 case spv::OpTypeReserveId: return 300002;
1044 case spv::OpTypeQueue: return 300003;
1045 case spv::OpTypePipe: return 300004;
1046
1047 case spv::OpConstantNullObject: return 300005;
1048 case spv::OpConstantSampler: return 300006;
1049
1050 case spv::OpConstantTrue: return 300007;
1051 case spv::OpConstantFalse: return 300008;
1052 case spv::OpConstantNullPointer: return 300009;
1053 case spv::OpConstantComposite:
1054 {
1055 std::uint32_t hash = 300011 + hashType(typePos(spv[typeStart+1]));
1056 for (unsigned w=3; w < wordCount; ++w)
1057 hash += w * hashType(typePos(spv[typeStart+w]));
1058 return hash;
1059 }
1060 case spv::OpConstant:
1061 {
1062 std::uint32_t hash = 400011 + hashType(typePos(spv[typeStart+1]));
1063 for (unsigned w=3; w < wordCount; ++w)
1064 hash += w * spv[typeStart+w];
1065 return hash;
1066 }
1067
1068 default:
1069 error("unknown type opcode");
1070 return 0;
1071 }
1072 }
1073
1074 void spirvbin_t::mapTypeConst()
1075 {
1076 globaltypes_t globalTypeMap;
1077
1078 msg(3, 2, std::string("Remapping Consts & Types: "));
1079
1080 static const std::uint32_t softTypeIdLimit = 3011; // small prime. TODO: get from options
1081 static const std::uint32_t firstMappedID = 8; // offset into ID space
1082
1083 for (auto& typeStart : typeConstPos) {
1084 const spv::Id resId = asTypeConstId(typeStart);
1085 const std::uint32_t hashval = hashType(typeStart);
1086
1087 if (isOldIdUnmapped(resId))
1088 localId(resId, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
1089 }
1090 }
1091
1092
1093 // Strip a single binary by removing ranges given in stripRange
1094 void spirvbin_t::strip()
1095 {
1096 if (stripRange.empty()) // nothing to do
1097 return;
1098
1099 // Sort strip ranges in order of traversal
1100 std::sort(stripRange.begin(), stripRange.end());
1101
1102 // Allocate a new binary big enough to hold old binary
1103 // We'll step this iterator through the strip ranges as we go through the binary
1104 auto strip_it = stripRange.begin();
1105
1106 int strippedPos = 0;
1107 for (unsigned word = 0; word < unsigned(spv.size()); ++word) {
1108 if (strip_it != stripRange.end() && word >= strip_it->second)
1109 ++strip_it;
1110
1111 if (strip_it == stripRange.end() || word < strip_it->first || word >= strip_it->second)
1112 spv[strippedPos++] = spv[word];
1113 }
1114
1115 spv.resize(strippedPos);
1116 stripRange.clear();
1117
1118 buildLocalMaps();
1119 }
1120
1121 // Strip a single binary by removing ranges given in stripRange
1122 void spirvbin_t::remap(std::uint32_t opts)
1123 {
1124 options = opts;
1125
1126 // Set up opcode tables from SpvDoc
1127 spv::Parameterize();
1128
1129 validate(); // validate header
1130 buildLocalMaps();
1131
1132 msg(3, 4, std::string("ID bound: ") + std::to_string(bound()));
1133
1134 strip(); // strip out data we decided to eliminate
1135
1136 if (options & OPT_LOADSTORE) optLoadStore();
1137 if (options & OPT_FWD_LS) forwardLoadStores();
1138 if (options & DCE_FUNCS) dceFuncs();
1139 if (options & DCE_VARS) dceVars();
1140 if (options & DCE_TYPES) dceTypes();
1141 if (options & MAP_TYPES) mapTypeConst();
1142 if (options & MAP_NAMES) mapNames();
1143 if (options & MAP_FUNCS) mapFnBodies();
1144
1145 mapRemainder(); // map any unmapped IDs
1146 applyMap(); // Now remap each shader to the new IDs we've come up with
1147 strip(); // strip out data we decided to eliminate
1148 }
1149
1150 // remap from a memory image
1151 void spirvbin_t::remap(std::vector<std::uint32_t>& in_spv, std::uint32_t opts)
1152 {
1153 spv.swap(in_spv);
1154 remap(opts);
1155 spv.swap(in_spv);
1156 }
1157
1158} // namespace SPV
1159
1160#endif // defined (use_cpp11)
1161