blob: 35f6c10708788a284346906370af5a293322a30c [file] [log] [blame]
Olli Etuaho5d91dda2015-06-18 15:47:46 +03001//
2// Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices,
7// replacing them with calls to functions that choose which component to return or write.
8//
9
10#include "compiler/translator/RemoveDynamicIndexing.h"
11
12#include "compiler/translator/InfoSink.h"
13#include "compiler/translator/IntermNode.h"
Olli Etuaho7da98502016-07-20 18:45:09 +030014#include "compiler/translator/IntermNodePatternMatcher.h"
Olli Etuaho5d91dda2015-06-18 15:47:46 +030015#include "compiler/translator/SymbolTable.h"
16
17namespace
18{
19
20TName GetIndexFunctionName(const TType &type, bool write)
21{
22 TInfoSinkBase nameSink;
23 nameSink << "dyn_index_";
24 if (write)
25 {
26 nameSink << "write_";
27 }
28 if (type.isMatrix())
29 {
30 nameSink << "mat" << type.getCols() << "x" << type.getRows();
31 }
32 else
33 {
34 switch (type.getBasicType())
35 {
36 case EbtInt:
37 nameSink << "ivec";
38 break;
39 case EbtBool:
40 nameSink << "bvec";
41 break;
42 case EbtUInt:
43 nameSink << "uvec";
44 break;
45 case EbtFloat:
46 nameSink << "vec";
47 break;
48 default:
49 UNREACHABLE();
50 }
51 nameSink << type.getNominalSize();
52 }
53 TString nameString = TFunction::mangleName(nameSink.c_str());
54 TName name(nameString);
55 name.setInternal(true);
56 return name;
57}
58
59TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier)
60{
61 TIntermSymbol *symbol = new TIntermSymbol(0, "base", type);
62 symbol->setInternal(true);
63 symbol->getTypePointer()->setQualifier(qualifier);
64 return symbol;
65}
66
67TIntermSymbol *CreateIndexSymbol()
68{
69 TIntermSymbol *symbol = new TIntermSymbol(0, "index", TType(EbtInt, EbpHigh));
70 symbol->setInternal(true);
71 symbol->getTypePointer()->setQualifier(EvqIn);
72 return symbol;
73}
74
75TIntermSymbol *CreateValueSymbol(const TType &type)
76{
77 TIntermSymbol *symbol = new TIntermSymbol(0, "value", type);
78 symbol->setInternal(true);
79 symbol->getTypePointer()->setQualifier(EvqIn);
80 return symbol;
81}
82
83TIntermConstantUnion *CreateIntConstantNode(int i)
84{
85 TConstantUnion *constant = new TConstantUnion();
86 constant->setIConst(i);
87 return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
88}
89
90TIntermBinary *CreateIndexDirectBaseSymbolNode(const TType &indexedType,
91 const TType &fieldType,
92 const int index,
93 TQualifier baseQualifier)
94{
95 TIntermBinary *indexNode = new TIntermBinary(EOpIndexDirect);
96 indexNode->setType(fieldType);
97 TIntermSymbol *baseSymbol = CreateBaseSymbol(indexedType, baseQualifier);
98 indexNode->setLeft(baseSymbol);
99 indexNode->setRight(CreateIntConstantNode(index));
100 return indexNode;
101}
102
103TIntermBinary *CreateAssignValueSymbolNode(TIntermTyped *targetNode, const TType &assignedValueType)
104{
105 TIntermBinary *assignNode = new TIntermBinary(EOpAssign);
106 assignNode->setType(assignedValueType);
107 assignNode->setLeft(targetNode);
108 assignNode->setRight(CreateValueSymbol(assignedValueType));
109 return assignNode;
110}
111
112TIntermTyped *EnsureSignedInt(TIntermTyped *node)
113{
114 if (node->getBasicType() == EbtInt)
115 return node;
116
117 TIntermAggregate *convertedNode = new TIntermAggregate(EOpConstructInt);
118 convertedNode->setType(TType(EbtInt));
119 convertedNode->getSequence()->push_back(node);
120 convertedNode->setPrecisionFromChildren();
121 return convertedNode;
122}
123
124TType GetFieldType(const TType &indexedType)
125{
126 if (indexedType.isMatrix())
127 {
128 TType fieldType = TType(indexedType.getBasicType(), indexedType.getPrecision());
129 fieldType.setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
130 return fieldType;
131 }
132 else
133 {
134 return TType(indexedType.getBasicType(), indexedType.getPrecision());
135 }
136}
137
138// Generate a read or write function for one field in a vector/matrix.
139// Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
140// indices in other places.
141// Note that indices can be either int or uint. We create only int versions of the functions,
142// and convert uint indices to int at the call site.
143// read function example:
144// float dyn_index_vec2(in vec2 base, in int index)
145// {
146// switch(index)
147// {
148// case (0):
149// return base[0];
150// case (1):
151// return base[1];
152// default:
153// break;
154// }
155// if (index < 0)
156// return base[0];
157// return base[1];
158// }
159// write function example:
160// void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
161// {
162// switch(index)
163// {
164// case (0):
165// base[0] = value;
166// return;
167// case (1):
168// base[1] = value;
169// return;
170// default:
171// break;
172// }
173// if (index < 0)
174// {
175// base[0] = value;
176// return;
177// }
178// base[1] = value;
179// }
180// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
181TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
182{
183 ASSERT(!type.isArray());
184 // Conservatively use highp here, even if the indexed type is not highp. That way the code can't
185 // end up using mediump version of an indexing function for a highp value, if both mediump and
186 // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
187 // principle this code could be used with multiple backends.
188 type.setPrecision(EbpHigh);
189 TIntermAggregate *indexingFunction = new TIntermAggregate(EOpFunction);
190 indexingFunction->setNameObj(GetIndexFunctionName(type, write));
191
192 TType fieldType = GetFieldType(type);
193 int numCases = 0;
194 if (type.isMatrix())
195 {
196 numCases = type.getCols();
197 }
198 else
199 {
200 numCases = type.getNominalSize();
201 }
202 if (write)
203 {
204 indexingFunction->setType(TType(EbtVoid));
205 }
206 else
207 {
208 indexingFunction->setType(fieldType);
209 }
210
211 TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters);
212 TQualifier baseQualifier = EvqInOut;
213 if (!write)
214 baseQualifier = EvqIn;
215 TIntermSymbol *baseParam = CreateBaseSymbol(type, baseQualifier);
216 paramsNode->getSequence()->push_back(baseParam);
217 TIntermSymbol *indexParam = CreateIndexSymbol();
218 paramsNode->getSequence()->push_back(indexParam);
219 if (write)
220 {
221 TIntermSymbol *valueParam = CreateValueSymbol(fieldType);
222 paramsNode->getSequence()->push_back(valueParam);
223 }
224 indexingFunction->getSequence()->push_back(paramsNode);
225
226 TIntermAggregate *statementList = new TIntermAggregate(EOpSequence);
227 for (int i = 0; i < numCases; ++i)
228 {
229 TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
230 statementList->getSequence()->push_back(caseNode);
231
232 TIntermBinary *indexNode =
233 CreateIndexDirectBaseSymbolNode(type, fieldType, i, baseQualifier);
234 if (write)
235 {
236 TIntermBinary *assignNode = CreateAssignValueSymbolNode(indexNode, fieldType);
237 statementList->getSequence()->push_back(assignNode);
238 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
239 statementList->getSequence()->push_back(returnNode);
240 }
241 else
242 {
243 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
244 statementList->getSequence()->push_back(returnNode);
245 }
246 }
247
248 // Default case
249 TIntermCase *defaultNode = new TIntermCase(nullptr);
250 statementList->getSequence()->push_back(defaultNode);
251 TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
252 statementList->getSequence()->push_back(breakNode);
253
254 TIntermSwitch *switchNode = new TIntermSwitch(CreateIndexSymbol(), statementList);
255
256 TIntermAggregate *bodyNode = new TIntermAggregate(EOpSequence);
257 bodyNode->getSequence()->push_back(switchNode);
258
259 TIntermBinary *cond = new TIntermBinary(EOpLessThan);
260 cond->setType(TType(EbtBool, EbpUndefined));
261 cond->setLeft(CreateIndexSymbol());
262 cond->setRight(CreateIntConstantNode(0));
263
264 // Two blocks: one accesses (either reads or writes) the first element and returns,
265 // the other accesses the last element.
266 TIntermAggregate *useFirstBlock = new TIntermAggregate(EOpSequence);
267 TIntermAggregate *useLastBlock = new TIntermAggregate(EOpSequence);
268 TIntermBinary *indexFirstNode =
269 CreateIndexDirectBaseSymbolNode(type, fieldType, 0, baseQualifier);
270 TIntermBinary *indexLastNode =
271 CreateIndexDirectBaseSymbolNode(type, fieldType, numCases - 1, baseQualifier);
272 if (write)
273 {
274 TIntermBinary *assignFirstNode = CreateAssignValueSymbolNode(indexFirstNode, fieldType);
275 useFirstBlock->getSequence()->push_back(assignFirstNode);
276 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
277 useFirstBlock->getSequence()->push_back(returnNode);
278
279 TIntermBinary *assignLastNode = CreateAssignValueSymbolNode(indexLastNode, fieldType);
280 useLastBlock->getSequence()->push_back(assignLastNode);
281 }
282 else
283 {
284 TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
285 useFirstBlock->getSequence()->push_back(returnFirstNode);
286
287 TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
288 useLastBlock->getSequence()->push_back(returnLastNode);
289 }
290 TIntermSelection *ifNode = new TIntermSelection(cond, useFirstBlock, nullptr);
291 bodyNode->getSequence()->push_back(ifNode);
292 bodyNode->getSequence()->push_back(useLastBlock);
293
294 indexingFunction->getSequence()->push_back(bodyNode);
295
296 return indexingFunction;
297}
298
299class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
300{
301 public:
302 RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable, int shaderVersion);
303
304 bool visitBinary(Visit visit, TIntermBinary *node) override;
305
306 void insertHelperDefinitions(TIntermNode *root);
307
308 void nextIteration();
309
310 bool usedTreeInsertion() const { return mUsedTreeInsertion; }
311
312 protected:
313 // Sets of types that are indexed. Note that these can not store multiple variants
314 // of the same type with different precisions - only one precision gets stored.
315 std::set<TType> mIndexedVecAndMatrixTypes;
316 std::set<TType> mWrittenVecAndMatrixTypes;
317
318 bool mUsedTreeInsertion;
319
320 // When true, the traverser will remove side effects from any indexing expression.
321 // This is done so that in code like
322 // V[j++][i]++.
323 // where V is an array of vectors, j++ will only be evaluated once.
324 bool mRemoveIndexSideEffectsInSubtree;
325};
326
327RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable,
328 int shaderVersion)
329 : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
330 mUsedTreeInsertion(false),
331 mRemoveIndexSideEffectsInSubtree(false)
332{
333}
334
335void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
336{
337 TIntermAggregate *rootAgg = root->getAsAggregate();
338 ASSERT(rootAgg != nullptr && rootAgg->getOp() == EOpSequence);
339 TIntermSequence insertions;
340 for (TType type : mIndexedVecAndMatrixTypes)
341 {
342 insertions.push_back(GetIndexFunctionDefinition(type, false));
343 }
344 for (TType type : mWrittenVecAndMatrixTypes)
345 {
346 insertions.push_back(GetIndexFunctionDefinition(type, true));
347 }
348 mInsertions.push_back(NodeInsertMultipleEntry(rootAgg, 0, insertions, TIntermSequence()));
349}
350
351// Create a call to dyn_index_*() based on an indirect indexing op node
352TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
353 TIntermTyped *indexedNode,
354 TIntermTyped *index)
355{
356 ASSERT(node->getOp() == EOpIndexIndirect);
357 TIntermAggregate *indexingCall = new TIntermAggregate(EOpFunctionCall);
358 indexingCall->setLine(node->getLine());
359 indexingCall->setUserDefined();
360 indexingCall->setNameObj(GetIndexFunctionName(indexedNode->getType(), false));
361 indexingCall->getSequence()->push_back(indexedNode);
362 indexingCall->getSequence()->push_back(index);
363
364 TType fieldType = GetFieldType(indexedNode->getType());
365 indexingCall->setType(fieldType);
366 return indexingCall;
367}
368
369TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
370 TIntermTyped *index,
371 TIntermTyped *writtenValue)
372{
373 // Deep copy the left node so that two pointers to the same node don't end up in the tree.
374 TIntermNode *leftCopy = node->getLeft()->deepCopy();
375 ASSERT(leftCopy != nullptr && leftCopy->getAsTyped() != nullptr);
376 TIntermAggregate *indexedWriteCall =
377 CreateIndexFunctionCall(node, leftCopy->getAsTyped(), index);
378 indexedWriteCall->setNameObj(GetIndexFunctionName(node->getLeft()->getType(), true));
379 indexedWriteCall->setType(TType(EbtVoid));
380 indexedWriteCall->getSequence()->push_back(writtenValue);
381 return indexedWriteCall;
382}
383
384bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
385{
386 if (mUsedTreeInsertion)
387 return false;
388
389 if (node->getOp() == EOpIndexIndirect)
390 {
391 if (mRemoveIndexSideEffectsInSubtree)
392 {
393 ASSERT(node->getRight()->hasSideEffects());
394 // In case we're just removing index side effects, convert
395 // v_expr[index_expr]
396 // to this:
397 // int s0 = index_expr; v_expr[s0];
398 // Now v_expr[s0] can be safely executed several times without unintended side effects.
399
400 // Init the temp variable holding the index
401 TIntermAggregate *initIndex = createTempInitDeclaration(node->getRight());
402 TIntermSequence insertions;
403 insertions.push_back(initIndex);
404 insertStatementsInParentBlock(insertions);
405 mUsedTreeInsertion = true;
406
407 // Replace the index with the temp variable
408 TIntermSymbol *tempIndex = createTempSymbol(node->getRight()->getType());
409 NodeUpdateEntry replaceIndex(node, node->getRight(), tempIndex, false);
410 mReplacements.push_back(replaceIndex);
411 }
Olli Etuaho7da98502016-07-20 18:45:09 +0300412 else if (IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(node))
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300413 {
414 bool write = isLValueRequiredHere();
415
Olli Etuaho7da98502016-07-20 18:45:09 +0300416#if defined(ANGLE_ENABLE_ASSERTS)
417 // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
418 // implemented checks in this traverser.
419 IntermNodePatternMatcher matcher(
420 IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
421 ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
422#endif
423
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300424 TType type = node->getLeft()->getType();
425 mIndexedVecAndMatrixTypes.insert(type);
426
427 if (write)
428 {
429 // Convert:
430 // v_expr[index_expr]++;
431 // to this:
432 // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
433 // dyn_index_write(v_expr, s0, s1);
434 // This works even if index_expr has some side effects.
435 if (node->getLeft()->hasSideEffects())
436 {
437 // If v_expr has side effects, those need to be removed before proceeding.
438 // Otherwise the side effects of v_expr would be evaluated twice.
439 // The only case where an l-value can have side effects is when it is
440 // indexing. For example, it can be V[j++] where V is an array of vectors.
441 mRemoveIndexSideEffectsInSubtree = true;
442 return true;
443 }
444 // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
445 // only writes it and doesn't need the previous value. http://anglebug.com/1116
446
447 mWrittenVecAndMatrixTypes.insert(type);
448 TType fieldType = GetFieldType(type);
449
450 TIntermSequence insertionsBefore;
451 TIntermSequence insertionsAfter;
452
453 // Store the index in a temporary signed int variable.
454 TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
455 TIntermAggregate *initIndex = createTempInitDeclaration(indexInitializer);
456 initIndex->setLine(node->getLine());
457 insertionsBefore.push_back(initIndex);
458
459 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
460 node, node->getLeft(), createTempSymbol(indexInitializer->getType()));
461
462 // Create a node for referring to the index after the nextTemporaryIndex() call
463 // below.
464 TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType());
465
466 nextTemporaryIndex(); // From now on, creating temporary symbols that refer to the
467 // field value.
468 insertionsBefore.push_back(createTempInitDeclaration(indexingCall));
469
470 TIntermAggregate *indexedWriteCall =
471 CreateIndexedWriteFunctionCall(node, tempIndex, createTempSymbol(fieldType));
472 insertionsAfter.push_back(indexedWriteCall);
473 insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
474 NodeUpdateEntry replaceIndex(getParentNode(), node, createTempSymbol(fieldType),
475 false);
476 mReplacements.push_back(replaceIndex);
477 mUsedTreeInsertion = true;
478 }
479 else
480 {
481 // The indexed value is not being written, so we can simply convert
482 // v_expr[index_expr]
483 // into
484 // dyn_index(v_expr, index_expr)
485 // If the index_expr is unsigned, we'll convert it to signed.
486 ASSERT(!mRemoveIndexSideEffectsInSubtree);
487 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
488 node, node->getLeft(), EnsureSignedInt(node->getRight()));
489 NodeUpdateEntry replaceIndex(getParentNode(), node, indexingCall, false);
490 mReplacements.push_back(replaceIndex);
491 }
492 }
493 }
494 return !mUsedTreeInsertion;
495}
496
497void RemoveDynamicIndexingTraverser::nextIteration()
498{
499 mUsedTreeInsertion = false;
500 mRemoveIndexSideEffectsInSubtree = false;
501 nextTemporaryIndex();
502}
503
504} // namespace
505
506void RemoveDynamicIndexing(TIntermNode *root,
507 unsigned int *temporaryIndex,
508 const TSymbolTable &symbolTable,
509 int shaderVersion)
510{
511 RemoveDynamicIndexingTraverser traverser(symbolTable, shaderVersion);
512 ASSERT(temporaryIndex != nullptr);
513 traverser.useTemporaryIndex(temporaryIndex);
514 do
515 {
516 traverser.nextIteration();
517 root->traverse(&traverser);
518 traverser.updateTree();
519 } while (traverser.usedTreeInsertion());
520 traverser.insertHelperDefinitions(root);
521 traverser.updateTree();
522}