blob: 74814f22a793fc7a75c83c267f9bfcbb4ac04ab6 [file] [log] [blame]
Olli Etuaho5d91dda2015-06-18 15:47:46 +03001//
2// Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices,
7// replacing them with calls to functions that choose which component to return or write.
8//
9
10#include "compiler/translator/RemoveDynamicIndexing.h"
11
12#include "compiler/translator/InfoSink.h"
13#include "compiler/translator/IntermNode.h"
14#include "compiler/translator/SymbolTable.h"
15
16namespace
17{
18
19TName GetIndexFunctionName(const TType &type, bool write)
20{
21 TInfoSinkBase nameSink;
22 nameSink << "dyn_index_";
23 if (write)
24 {
25 nameSink << "write_";
26 }
27 if (type.isMatrix())
28 {
29 nameSink << "mat" << type.getCols() << "x" << type.getRows();
30 }
31 else
32 {
33 switch (type.getBasicType())
34 {
35 case EbtInt:
36 nameSink << "ivec";
37 break;
38 case EbtBool:
39 nameSink << "bvec";
40 break;
41 case EbtUInt:
42 nameSink << "uvec";
43 break;
44 case EbtFloat:
45 nameSink << "vec";
46 break;
47 default:
48 UNREACHABLE();
49 }
50 nameSink << type.getNominalSize();
51 }
52 TString nameString = TFunction::mangleName(nameSink.c_str());
53 TName name(nameString);
54 name.setInternal(true);
55 return name;
56}
57
58TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier)
59{
60 TIntermSymbol *symbol = new TIntermSymbol(0, "base", type);
61 symbol->setInternal(true);
62 symbol->getTypePointer()->setQualifier(qualifier);
63 return symbol;
64}
65
66TIntermSymbol *CreateIndexSymbol()
67{
68 TIntermSymbol *symbol = new TIntermSymbol(0, "index", TType(EbtInt, EbpHigh));
69 symbol->setInternal(true);
70 symbol->getTypePointer()->setQualifier(EvqIn);
71 return symbol;
72}
73
74TIntermSymbol *CreateValueSymbol(const TType &type)
75{
76 TIntermSymbol *symbol = new TIntermSymbol(0, "value", type);
77 symbol->setInternal(true);
78 symbol->getTypePointer()->setQualifier(EvqIn);
79 return symbol;
80}
81
82TIntermConstantUnion *CreateIntConstantNode(int i)
83{
84 TConstantUnion *constant = new TConstantUnion();
85 constant->setIConst(i);
86 return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
87}
88
89TIntermBinary *CreateIndexDirectBaseSymbolNode(const TType &indexedType,
90 const TType &fieldType,
91 const int index,
92 TQualifier baseQualifier)
93{
94 TIntermBinary *indexNode = new TIntermBinary(EOpIndexDirect);
95 indexNode->setType(fieldType);
96 TIntermSymbol *baseSymbol = CreateBaseSymbol(indexedType, baseQualifier);
97 indexNode->setLeft(baseSymbol);
98 indexNode->setRight(CreateIntConstantNode(index));
99 return indexNode;
100}
101
102TIntermBinary *CreateAssignValueSymbolNode(TIntermTyped *targetNode, const TType &assignedValueType)
103{
104 TIntermBinary *assignNode = new TIntermBinary(EOpAssign);
105 assignNode->setType(assignedValueType);
106 assignNode->setLeft(targetNode);
107 assignNode->setRight(CreateValueSymbol(assignedValueType));
108 return assignNode;
109}
110
111TIntermTyped *EnsureSignedInt(TIntermTyped *node)
112{
113 if (node->getBasicType() == EbtInt)
114 return node;
115
116 TIntermAggregate *convertedNode = new TIntermAggregate(EOpConstructInt);
117 convertedNode->setType(TType(EbtInt));
118 convertedNode->getSequence()->push_back(node);
119 convertedNode->setPrecisionFromChildren();
120 return convertedNode;
121}
122
123TType GetFieldType(const TType &indexedType)
124{
125 if (indexedType.isMatrix())
126 {
127 TType fieldType = TType(indexedType.getBasicType(), indexedType.getPrecision());
128 fieldType.setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
129 return fieldType;
130 }
131 else
132 {
133 return TType(indexedType.getBasicType(), indexedType.getPrecision());
134 }
135}
136
137// Generate a read or write function for one field in a vector/matrix.
138// Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
139// indices in other places.
140// Note that indices can be either int or uint. We create only int versions of the functions,
141// and convert uint indices to int at the call site.
142// read function example:
143// float dyn_index_vec2(in vec2 base, in int index)
144// {
145// switch(index)
146// {
147// case (0):
148// return base[0];
149// case (1):
150// return base[1];
151// default:
152// break;
153// }
154// if (index < 0)
155// return base[0];
156// return base[1];
157// }
158// write function example:
159// void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
160// {
161// switch(index)
162// {
163// case (0):
164// base[0] = value;
165// return;
166// case (1):
167// base[1] = value;
168// return;
169// default:
170// break;
171// }
172// if (index < 0)
173// {
174// base[0] = value;
175// return;
176// }
177// base[1] = value;
178// }
179// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
180TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
181{
182 ASSERT(!type.isArray());
183 // Conservatively use highp here, even if the indexed type is not highp. That way the code can't
184 // end up using mediump version of an indexing function for a highp value, if both mediump and
185 // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
186 // principle this code could be used with multiple backends.
187 type.setPrecision(EbpHigh);
188 TIntermAggregate *indexingFunction = new TIntermAggregate(EOpFunction);
189 indexingFunction->setNameObj(GetIndexFunctionName(type, write));
190
191 TType fieldType = GetFieldType(type);
192 int numCases = 0;
193 if (type.isMatrix())
194 {
195 numCases = type.getCols();
196 }
197 else
198 {
199 numCases = type.getNominalSize();
200 }
201 if (write)
202 {
203 indexingFunction->setType(TType(EbtVoid));
204 }
205 else
206 {
207 indexingFunction->setType(fieldType);
208 }
209
210 TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters);
211 TQualifier baseQualifier = EvqInOut;
212 if (!write)
213 baseQualifier = EvqIn;
214 TIntermSymbol *baseParam = CreateBaseSymbol(type, baseQualifier);
215 paramsNode->getSequence()->push_back(baseParam);
216 TIntermSymbol *indexParam = CreateIndexSymbol();
217 paramsNode->getSequence()->push_back(indexParam);
218 if (write)
219 {
220 TIntermSymbol *valueParam = CreateValueSymbol(fieldType);
221 paramsNode->getSequence()->push_back(valueParam);
222 }
223 indexingFunction->getSequence()->push_back(paramsNode);
224
225 TIntermAggregate *statementList = new TIntermAggregate(EOpSequence);
226 for (int i = 0; i < numCases; ++i)
227 {
228 TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
229 statementList->getSequence()->push_back(caseNode);
230
231 TIntermBinary *indexNode =
232 CreateIndexDirectBaseSymbolNode(type, fieldType, i, baseQualifier);
233 if (write)
234 {
235 TIntermBinary *assignNode = CreateAssignValueSymbolNode(indexNode, fieldType);
236 statementList->getSequence()->push_back(assignNode);
237 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
238 statementList->getSequence()->push_back(returnNode);
239 }
240 else
241 {
242 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
243 statementList->getSequence()->push_back(returnNode);
244 }
245 }
246
247 // Default case
248 TIntermCase *defaultNode = new TIntermCase(nullptr);
249 statementList->getSequence()->push_back(defaultNode);
250 TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
251 statementList->getSequence()->push_back(breakNode);
252
253 TIntermSwitch *switchNode = new TIntermSwitch(CreateIndexSymbol(), statementList);
254
255 TIntermAggregate *bodyNode = new TIntermAggregate(EOpSequence);
256 bodyNode->getSequence()->push_back(switchNode);
257
258 TIntermBinary *cond = new TIntermBinary(EOpLessThan);
259 cond->setType(TType(EbtBool, EbpUndefined));
260 cond->setLeft(CreateIndexSymbol());
261 cond->setRight(CreateIntConstantNode(0));
262
263 // Two blocks: one accesses (either reads or writes) the first element and returns,
264 // the other accesses the last element.
265 TIntermAggregate *useFirstBlock = new TIntermAggregate(EOpSequence);
266 TIntermAggregate *useLastBlock = new TIntermAggregate(EOpSequence);
267 TIntermBinary *indexFirstNode =
268 CreateIndexDirectBaseSymbolNode(type, fieldType, 0, baseQualifier);
269 TIntermBinary *indexLastNode =
270 CreateIndexDirectBaseSymbolNode(type, fieldType, numCases - 1, baseQualifier);
271 if (write)
272 {
273 TIntermBinary *assignFirstNode = CreateAssignValueSymbolNode(indexFirstNode, fieldType);
274 useFirstBlock->getSequence()->push_back(assignFirstNode);
275 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
276 useFirstBlock->getSequence()->push_back(returnNode);
277
278 TIntermBinary *assignLastNode = CreateAssignValueSymbolNode(indexLastNode, fieldType);
279 useLastBlock->getSequence()->push_back(assignLastNode);
280 }
281 else
282 {
283 TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
284 useFirstBlock->getSequence()->push_back(returnFirstNode);
285
286 TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
287 useLastBlock->getSequence()->push_back(returnLastNode);
288 }
289 TIntermSelection *ifNode = new TIntermSelection(cond, useFirstBlock, nullptr);
290 bodyNode->getSequence()->push_back(ifNode);
291 bodyNode->getSequence()->push_back(useLastBlock);
292
293 indexingFunction->getSequence()->push_back(bodyNode);
294
295 return indexingFunction;
296}
297
298class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
299{
300 public:
301 RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable, int shaderVersion);
302
303 bool visitBinary(Visit visit, TIntermBinary *node) override;
304
305 void insertHelperDefinitions(TIntermNode *root);
306
307 void nextIteration();
308
309 bool usedTreeInsertion() const { return mUsedTreeInsertion; }
310
311 protected:
312 // Sets of types that are indexed. Note that these can not store multiple variants
313 // of the same type with different precisions - only one precision gets stored.
314 std::set<TType> mIndexedVecAndMatrixTypes;
315 std::set<TType> mWrittenVecAndMatrixTypes;
316
317 bool mUsedTreeInsertion;
318
319 // When true, the traverser will remove side effects from any indexing expression.
320 // This is done so that in code like
321 // V[j++][i]++.
322 // where V is an array of vectors, j++ will only be evaluated once.
323 bool mRemoveIndexSideEffectsInSubtree;
324};
325
326RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable,
327 int shaderVersion)
328 : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
329 mUsedTreeInsertion(false),
330 mRemoveIndexSideEffectsInSubtree(false)
331{
332}
333
334void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
335{
336 TIntermAggregate *rootAgg = root->getAsAggregate();
337 ASSERT(rootAgg != nullptr && rootAgg->getOp() == EOpSequence);
338 TIntermSequence insertions;
339 for (TType type : mIndexedVecAndMatrixTypes)
340 {
341 insertions.push_back(GetIndexFunctionDefinition(type, false));
342 }
343 for (TType type : mWrittenVecAndMatrixTypes)
344 {
345 insertions.push_back(GetIndexFunctionDefinition(type, true));
346 }
347 mInsertions.push_back(NodeInsertMultipleEntry(rootAgg, 0, insertions, TIntermSequence()));
348}
349
350// Create a call to dyn_index_*() based on an indirect indexing op node
351TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
352 TIntermTyped *indexedNode,
353 TIntermTyped *index)
354{
355 ASSERT(node->getOp() == EOpIndexIndirect);
356 TIntermAggregate *indexingCall = new TIntermAggregate(EOpFunctionCall);
357 indexingCall->setLine(node->getLine());
358 indexingCall->setUserDefined();
359 indexingCall->setNameObj(GetIndexFunctionName(indexedNode->getType(), false));
360 indexingCall->getSequence()->push_back(indexedNode);
361 indexingCall->getSequence()->push_back(index);
362
363 TType fieldType = GetFieldType(indexedNode->getType());
364 indexingCall->setType(fieldType);
365 return indexingCall;
366}
367
368TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
369 TIntermTyped *index,
370 TIntermTyped *writtenValue)
371{
372 // Deep copy the left node so that two pointers to the same node don't end up in the tree.
373 TIntermNode *leftCopy = node->getLeft()->deepCopy();
374 ASSERT(leftCopy != nullptr && leftCopy->getAsTyped() != nullptr);
375 TIntermAggregate *indexedWriteCall =
376 CreateIndexFunctionCall(node, leftCopy->getAsTyped(), index);
377 indexedWriteCall->setNameObj(GetIndexFunctionName(node->getLeft()->getType(), true));
378 indexedWriteCall->setType(TType(EbtVoid));
379 indexedWriteCall->getSequence()->push_back(writtenValue);
380 return indexedWriteCall;
381}
382
383bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
384{
385 if (mUsedTreeInsertion)
386 return false;
387
388 if (node->getOp() == EOpIndexIndirect)
389 {
390 if (mRemoveIndexSideEffectsInSubtree)
391 {
392 ASSERT(node->getRight()->hasSideEffects());
393 // In case we're just removing index side effects, convert
394 // v_expr[index_expr]
395 // to this:
396 // int s0 = index_expr; v_expr[s0];
397 // Now v_expr[s0] can be safely executed several times without unintended side effects.
398
399 // Init the temp variable holding the index
400 TIntermAggregate *initIndex = createTempInitDeclaration(node->getRight());
401 TIntermSequence insertions;
402 insertions.push_back(initIndex);
403 insertStatementsInParentBlock(insertions);
404 mUsedTreeInsertion = true;
405
406 // Replace the index with the temp variable
407 TIntermSymbol *tempIndex = createTempSymbol(node->getRight()->getType());
408 NodeUpdateEntry replaceIndex(node, node->getRight(), tempIndex, false);
409 mReplacements.push_back(replaceIndex);
410 }
411 else if (!node->getLeft()->isArray() && node->getLeft()->getBasicType() != EbtStruct)
412 {
413 bool write = isLValueRequiredHere();
414
415 TType type = node->getLeft()->getType();
416 mIndexedVecAndMatrixTypes.insert(type);
417
418 if (write)
419 {
420 // Convert:
421 // v_expr[index_expr]++;
422 // to this:
423 // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
424 // dyn_index_write(v_expr, s0, s1);
425 // This works even if index_expr has some side effects.
426 if (node->getLeft()->hasSideEffects())
427 {
428 // If v_expr has side effects, those need to be removed before proceeding.
429 // Otherwise the side effects of v_expr would be evaluated twice.
430 // The only case where an l-value can have side effects is when it is
431 // indexing. For example, it can be V[j++] where V is an array of vectors.
432 mRemoveIndexSideEffectsInSubtree = true;
433 return true;
434 }
435 // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
436 // only writes it and doesn't need the previous value. http://anglebug.com/1116
437
438 mWrittenVecAndMatrixTypes.insert(type);
439 TType fieldType = GetFieldType(type);
440
441 TIntermSequence insertionsBefore;
442 TIntermSequence insertionsAfter;
443
444 // Store the index in a temporary signed int variable.
445 TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
446 TIntermAggregate *initIndex = createTempInitDeclaration(indexInitializer);
447 initIndex->setLine(node->getLine());
448 insertionsBefore.push_back(initIndex);
449
450 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
451 node, node->getLeft(), createTempSymbol(indexInitializer->getType()));
452
453 // Create a node for referring to the index after the nextTemporaryIndex() call
454 // below.
455 TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType());
456
457 nextTemporaryIndex(); // From now on, creating temporary symbols that refer to the
458 // field value.
459 insertionsBefore.push_back(createTempInitDeclaration(indexingCall));
460
461 TIntermAggregate *indexedWriteCall =
462 CreateIndexedWriteFunctionCall(node, tempIndex, createTempSymbol(fieldType));
463 insertionsAfter.push_back(indexedWriteCall);
464 insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
465 NodeUpdateEntry replaceIndex(getParentNode(), node, createTempSymbol(fieldType),
466 false);
467 mReplacements.push_back(replaceIndex);
468 mUsedTreeInsertion = true;
469 }
470 else
471 {
472 // The indexed value is not being written, so we can simply convert
473 // v_expr[index_expr]
474 // into
475 // dyn_index(v_expr, index_expr)
476 // If the index_expr is unsigned, we'll convert it to signed.
477 ASSERT(!mRemoveIndexSideEffectsInSubtree);
478 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
479 node, node->getLeft(), EnsureSignedInt(node->getRight()));
480 NodeUpdateEntry replaceIndex(getParentNode(), node, indexingCall, false);
481 mReplacements.push_back(replaceIndex);
482 }
483 }
484 }
485 return !mUsedTreeInsertion;
486}
487
488void RemoveDynamicIndexingTraverser::nextIteration()
489{
490 mUsedTreeInsertion = false;
491 mRemoveIndexSideEffectsInSubtree = false;
492 nextTemporaryIndex();
493}
494
495} // namespace
496
497void RemoveDynamicIndexing(TIntermNode *root,
498 unsigned int *temporaryIndex,
499 const TSymbolTable &symbolTable,
500 int shaderVersion)
501{
502 RemoveDynamicIndexingTraverser traverser(symbolTable, shaderVersion);
503 ASSERT(temporaryIndex != nullptr);
504 traverser.useTemporaryIndex(temporaryIndex);
505 do
506 {
507 traverser.nextIteration();
508 root->traverse(&traverser);
509 traverser.updateTree();
510 } while (traverser.usedTreeInsertion());
511 traverser.insertHelperDefinitions(root);
512 traverser.updateTree();
513}