blob: e846fccbde99fa018331da73590994fc97e0d049 [file] [log] [blame]
//
// Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices,
// replacing them with calls to functions that choose which component to return or write.
//
9
10#include "compiler/translator/RemoveDynamicIndexing.h"
11
12#include "compiler/translator/InfoSink.h"
13#include "compiler/translator/IntermNode.h"
Olli Etuaho7da98502016-07-20 18:45:09 +030014#include "compiler/translator/IntermNodePatternMatcher.h"
Olli Etuaho5d91dda2015-06-18 15:47:46 +030015#include "compiler/translator/SymbolTable.h"
16
17namespace
18{
19
20TName GetIndexFunctionName(const TType &type, bool write)
21{
22 TInfoSinkBase nameSink;
23 nameSink << "dyn_index_";
24 if (write)
25 {
26 nameSink << "write_";
27 }
28 if (type.isMatrix())
29 {
30 nameSink << "mat" << type.getCols() << "x" << type.getRows();
31 }
32 else
33 {
34 switch (type.getBasicType())
35 {
36 case EbtInt:
37 nameSink << "ivec";
38 break;
39 case EbtBool:
40 nameSink << "bvec";
41 break;
42 case EbtUInt:
43 nameSink << "uvec";
44 break;
45 case EbtFloat:
46 nameSink << "vec";
47 break;
48 default:
49 UNREACHABLE();
50 }
51 nameSink << type.getNominalSize();
52 }
53 TString nameString = TFunction::mangleName(nameSink.c_str());
54 TName name(nameString);
55 name.setInternal(true);
56 return name;
57}
58
59TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier)
60{
61 TIntermSymbol *symbol = new TIntermSymbol(0, "base", type);
62 symbol->setInternal(true);
63 symbol->getTypePointer()->setQualifier(qualifier);
64 return symbol;
65}
66
67TIntermSymbol *CreateIndexSymbol()
68{
69 TIntermSymbol *symbol = new TIntermSymbol(0, "index", TType(EbtInt, EbpHigh));
70 symbol->setInternal(true);
71 symbol->getTypePointer()->setQualifier(EvqIn);
72 return symbol;
73}
74
75TIntermSymbol *CreateValueSymbol(const TType &type)
76{
77 TIntermSymbol *symbol = new TIntermSymbol(0, "value", type);
78 symbol->setInternal(true);
79 symbol->getTypePointer()->setQualifier(EvqIn);
80 return symbol;
81}
82
83TIntermConstantUnion *CreateIntConstantNode(int i)
84{
85 TConstantUnion *constant = new TConstantUnion();
86 constant->setIConst(i);
87 return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
88}
89
90TIntermBinary *CreateIndexDirectBaseSymbolNode(const TType &indexedType,
91 const TType &fieldType,
92 const int index,
93 TQualifier baseQualifier)
94{
95 TIntermBinary *indexNode = new TIntermBinary(EOpIndexDirect);
96 indexNode->setType(fieldType);
97 TIntermSymbol *baseSymbol = CreateBaseSymbol(indexedType, baseQualifier);
98 indexNode->setLeft(baseSymbol);
99 indexNode->setRight(CreateIntConstantNode(index));
100 return indexNode;
101}
102
103TIntermBinary *CreateAssignValueSymbolNode(TIntermTyped *targetNode, const TType &assignedValueType)
104{
105 TIntermBinary *assignNode = new TIntermBinary(EOpAssign);
106 assignNode->setType(assignedValueType);
107 assignNode->setLeft(targetNode);
108 assignNode->setRight(CreateValueSymbol(assignedValueType));
109 return assignNode;
110}
111
112TIntermTyped *EnsureSignedInt(TIntermTyped *node)
113{
114 if (node->getBasicType() == EbtInt)
115 return node;
116
117 TIntermAggregate *convertedNode = new TIntermAggregate(EOpConstructInt);
118 convertedNode->setType(TType(EbtInt));
119 convertedNode->getSequence()->push_back(node);
120 convertedNode->setPrecisionFromChildren();
121 return convertedNode;
122}
123
124TType GetFieldType(const TType &indexedType)
125{
126 if (indexedType.isMatrix())
127 {
128 TType fieldType = TType(indexedType.getBasicType(), indexedType.getPrecision());
129 fieldType.setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
130 return fieldType;
131 }
132 else
133 {
134 return TType(indexedType.getBasicType(), indexedType.getPrecision());
135 }
136}
137
138// Generate a read or write function for one field in a vector/matrix.
139// Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
140// indices in other places.
141// Note that indices can be either int or uint. We create only int versions of the functions,
142// and convert uint indices to int at the call site.
143// read function example:
144// float dyn_index_vec2(in vec2 base, in int index)
145// {
146// switch(index)
147// {
148// case (0):
149// return base[0];
150// case (1):
151// return base[1];
152// default:
153// break;
154// }
155// if (index < 0)
156// return base[0];
157// return base[1];
158// }
159// write function example:
160// void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
161// {
162// switch(index)
163// {
164// case (0):
165// base[0] = value;
166// return;
167// case (1):
168// base[1] = value;
169// return;
170// default:
171// break;
172// }
173// if (index < 0)
174// {
175// base[0] = value;
176// return;
177// }
178// base[1] = value;
179// }
180// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
181TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
182{
183 ASSERT(!type.isArray());
184 // Conservatively use highp here, even if the indexed type is not highp. That way the code can't
185 // end up using mediump version of an indexing function for a highp value, if both mediump and
186 // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
187 // principle this code could be used with multiple backends.
188 type.setPrecision(EbpHigh);
189 TIntermAggregate *indexingFunction = new TIntermAggregate(EOpFunction);
190 indexingFunction->setNameObj(GetIndexFunctionName(type, write));
191
192 TType fieldType = GetFieldType(type);
193 int numCases = 0;
194 if (type.isMatrix())
195 {
196 numCases = type.getCols();
197 }
198 else
199 {
200 numCases = type.getNominalSize();
201 }
202 if (write)
203 {
204 indexingFunction->setType(TType(EbtVoid));
205 }
206 else
207 {
208 indexingFunction->setType(fieldType);
209 }
210
211 TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters);
212 TQualifier baseQualifier = EvqInOut;
213 if (!write)
214 baseQualifier = EvqIn;
215 TIntermSymbol *baseParam = CreateBaseSymbol(type, baseQualifier);
216 paramsNode->getSequence()->push_back(baseParam);
217 TIntermSymbol *indexParam = CreateIndexSymbol();
218 paramsNode->getSequence()->push_back(indexParam);
219 if (write)
220 {
221 TIntermSymbol *valueParam = CreateValueSymbol(fieldType);
222 paramsNode->getSequence()->push_back(valueParam);
223 }
224 indexingFunction->getSequence()->push_back(paramsNode);
225
226 TIntermAggregate *statementList = new TIntermAggregate(EOpSequence);
227 for (int i = 0; i < numCases; ++i)
228 {
229 TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
230 statementList->getSequence()->push_back(caseNode);
231
232 TIntermBinary *indexNode =
233 CreateIndexDirectBaseSymbolNode(type, fieldType, i, baseQualifier);
234 if (write)
235 {
236 TIntermBinary *assignNode = CreateAssignValueSymbolNode(indexNode, fieldType);
237 statementList->getSequence()->push_back(assignNode);
238 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
239 statementList->getSequence()->push_back(returnNode);
240 }
241 else
242 {
243 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
244 statementList->getSequence()->push_back(returnNode);
245 }
246 }
247
248 // Default case
249 TIntermCase *defaultNode = new TIntermCase(nullptr);
250 statementList->getSequence()->push_back(defaultNode);
251 TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
252 statementList->getSequence()->push_back(breakNode);
253
254 TIntermSwitch *switchNode = new TIntermSwitch(CreateIndexSymbol(), statementList);
255
256 TIntermAggregate *bodyNode = new TIntermAggregate(EOpSequence);
257 bodyNode->getSequence()->push_back(switchNode);
258
259 TIntermBinary *cond = new TIntermBinary(EOpLessThan);
260 cond->setType(TType(EbtBool, EbpUndefined));
261 cond->setLeft(CreateIndexSymbol());
262 cond->setRight(CreateIntConstantNode(0));
263
264 // Two blocks: one accesses (either reads or writes) the first element and returns,
265 // the other accesses the last element.
266 TIntermAggregate *useFirstBlock = new TIntermAggregate(EOpSequence);
267 TIntermAggregate *useLastBlock = new TIntermAggregate(EOpSequence);
268 TIntermBinary *indexFirstNode =
269 CreateIndexDirectBaseSymbolNode(type, fieldType, 0, baseQualifier);
270 TIntermBinary *indexLastNode =
271 CreateIndexDirectBaseSymbolNode(type, fieldType, numCases - 1, baseQualifier);
272 if (write)
273 {
274 TIntermBinary *assignFirstNode = CreateAssignValueSymbolNode(indexFirstNode, fieldType);
275 useFirstBlock->getSequence()->push_back(assignFirstNode);
276 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
277 useFirstBlock->getSequence()->push_back(returnNode);
278
279 TIntermBinary *assignLastNode = CreateAssignValueSymbolNode(indexLastNode, fieldType);
280 useLastBlock->getSequence()->push_back(assignLastNode);
281 }
282 else
283 {
284 TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
285 useFirstBlock->getSequence()->push_back(returnFirstNode);
286
287 TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
288 useLastBlock->getSequence()->push_back(returnLastNode);
289 }
290 TIntermSelection *ifNode = new TIntermSelection(cond, useFirstBlock, nullptr);
291 bodyNode->getSequence()->push_back(ifNode);
292 bodyNode->getSequence()->push_back(useLastBlock);
293
294 indexingFunction->getSequence()->push_back(bodyNode);
295
296 return indexingFunction;
297}
298
299class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
300{
301 public:
302 RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable, int shaderVersion);
303
304 bool visitBinary(Visit visit, TIntermBinary *node) override;
305
306 void insertHelperDefinitions(TIntermNode *root);
307
308 void nextIteration();
309
310 bool usedTreeInsertion() const { return mUsedTreeInsertion; }
311
312 protected:
313 // Sets of types that are indexed. Note that these can not store multiple variants
314 // of the same type with different precisions - only one precision gets stored.
315 std::set<TType> mIndexedVecAndMatrixTypes;
316 std::set<TType> mWrittenVecAndMatrixTypes;
317
318 bool mUsedTreeInsertion;
319
320 // When true, the traverser will remove side effects from any indexing expression.
321 // This is done so that in code like
322 // V[j++][i]++.
323 // where V is an array of vectors, j++ will only be evaluated once.
324 bool mRemoveIndexSideEffectsInSubtree;
325};
326
327RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable,
328 int shaderVersion)
329 : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
330 mUsedTreeInsertion(false),
331 mRemoveIndexSideEffectsInSubtree(false)
332{
333}
334
335void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
336{
337 TIntermAggregate *rootAgg = root->getAsAggregate();
338 ASSERT(rootAgg != nullptr && rootAgg->getOp() == EOpSequence);
339 TIntermSequence insertions;
340 for (TType type : mIndexedVecAndMatrixTypes)
341 {
342 insertions.push_back(GetIndexFunctionDefinition(type, false));
343 }
344 for (TType type : mWrittenVecAndMatrixTypes)
345 {
346 insertions.push_back(GetIndexFunctionDefinition(type, true));
347 }
348 mInsertions.push_back(NodeInsertMultipleEntry(rootAgg, 0, insertions, TIntermSequence()));
349}
350
351// Create a call to dyn_index_*() based on an indirect indexing op node
352TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
353 TIntermTyped *indexedNode,
354 TIntermTyped *index)
355{
356 ASSERT(node->getOp() == EOpIndexIndirect);
357 TIntermAggregate *indexingCall = new TIntermAggregate(EOpFunctionCall);
358 indexingCall->setLine(node->getLine());
359 indexingCall->setUserDefined();
360 indexingCall->setNameObj(GetIndexFunctionName(indexedNode->getType(), false));
361 indexingCall->getSequence()->push_back(indexedNode);
362 indexingCall->getSequence()->push_back(index);
363
364 TType fieldType = GetFieldType(indexedNode->getType());
365 indexingCall->setType(fieldType);
366 return indexingCall;
367}
368
369TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
370 TIntermTyped *index,
371 TIntermTyped *writtenValue)
372{
373 // Deep copy the left node so that two pointers to the same node don't end up in the tree.
374 TIntermNode *leftCopy = node->getLeft()->deepCopy();
375 ASSERT(leftCopy != nullptr && leftCopy->getAsTyped() != nullptr);
376 TIntermAggregate *indexedWriteCall =
377 CreateIndexFunctionCall(node, leftCopy->getAsTyped(), index);
378 indexedWriteCall->setNameObj(GetIndexFunctionName(node->getLeft()->getType(), true));
379 indexedWriteCall->setType(TType(EbtVoid));
380 indexedWriteCall->getSequence()->push_back(writtenValue);
381 return indexedWriteCall;
382}
383
384bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
385{
386 if (mUsedTreeInsertion)
387 return false;
388
389 if (node->getOp() == EOpIndexIndirect)
390 {
391 if (mRemoveIndexSideEffectsInSubtree)
392 {
393 ASSERT(node->getRight()->hasSideEffects());
394 // In case we're just removing index side effects, convert
395 // v_expr[index_expr]
396 // to this:
397 // int s0 = index_expr; v_expr[s0];
398 // Now v_expr[s0] can be safely executed several times without unintended side effects.
399
400 // Init the temp variable holding the index
401 TIntermAggregate *initIndex = createTempInitDeclaration(node->getRight());
Jamie Madill1048e432016-07-23 18:51:28 -0400402 insertStatementInParentBlock(initIndex);
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300403 mUsedTreeInsertion = true;
404
405 // Replace the index with the temp variable
406 TIntermSymbol *tempIndex = createTempSymbol(node->getRight()->getType());
407 NodeUpdateEntry replaceIndex(node, node->getRight(), tempIndex, false);
408 mReplacements.push_back(replaceIndex);
409 }
Olli Etuaho7da98502016-07-20 18:45:09 +0300410 else if (IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(node))
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300411 {
412 bool write = isLValueRequiredHere();
413
Olli Etuaho7da98502016-07-20 18:45:09 +0300414#if defined(ANGLE_ENABLE_ASSERTS)
415 // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
416 // implemented checks in this traverser.
417 IntermNodePatternMatcher matcher(
418 IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
419 ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
420#endif
421
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300422 TType type = node->getLeft()->getType();
423 mIndexedVecAndMatrixTypes.insert(type);
424
425 if (write)
426 {
427 // Convert:
428 // v_expr[index_expr]++;
429 // to this:
430 // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
431 // dyn_index_write(v_expr, s0, s1);
432 // This works even if index_expr has some side effects.
433 if (node->getLeft()->hasSideEffects())
434 {
435 // If v_expr has side effects, those need to be removed before proceeding.
436 // Otherwise the side effects of v_expr would be evaluated twice.
437 // The only case where an l-value can have side effects is when it is
438 // indexing. For example, it can be V[j++] where V is an array of vectors.
439 mRemoveIndexSideEffectsInSubtree = true;
440 return true;
441 }
442 // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
443 // only writes it and doesn't need the previous value. http://anglebug.com/1116
444
445 mWrittenVecAndMatrixTypes.insert(type);
446 TType fieldType = GetFieldType(type);
447
448 TIntermSequence insertionsBefore;
449 TIntermSequence insertionsAfter;
450
451 // Store the index in a temporary signed int variable.
452 TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
453 TIntermAggregate *initIndex = createTempInitDeclaration(indexInitializer);
454 initIndex->setLine(node->getLine());
455 insertionsBefore.push_back(initIndex);
456
457 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
458 node, node->getLeft(), createTempSymbol(indexInitializer->getType()));
459
460 // Create a node for referring to the index after the nextTemporaryIndex() call
461 // below.
462 TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType());
463
464 nextTemporaryIndex(); // From now on, creating temporary symbols that refer to the
465 // field value.
466 insertionsBefore.push_back(createTempInitDeclaration(indexingCall));
467
468 TIntermAggregate *indexedWriteCall =
469 CreateIndexedWriteFunctionCall(node, tempIndex, createTempSymbol(fieldType));
470 insertionsAfter.push_back(indexedWriteCall);
471 insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
472 NodeUpdateEntry replaceIndex(getParentNode(), node, createTempSymbol(fieldType),
473 false);
474 mReplacements.push_back(replaceIndex);
475 mUsedTreeInsertion = true;
476 }
477 else
478 {
479 // The indexed value is not being written, so we can simply convert
480 // v_expr[index_expr]
481 // into
482 // dyn_index(v_expr, index_expr)
483 // If the index_expr is unsigned, we'll convert it to signed.
484 ASSERT(!mRemoveIndexSideEffectsInSubtree);
485 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
486 node, node->getLeft(), EnsureSignedInt(node->getRight()));
487 NodeUpdateEntry replaceIndex(getParentNode(), node, indexingCall, false);
488 mReplacements.push_back(replaceIndex);
489 }
490 }
491 }
492 return !mUsedTreeInsertion;
493}
494
495void RemoveDynamicIndexingTraverser::nextIteration()
496{
497 mUsedTreeInsertion = false;
498 mRemoveIndexSideEffectsInSubtree = false;
499 nextTemporaryIndex();
500}
501
502} // namespace
503
504void RemoveDynamicIndexing(TIntermNode *root,
505 unsigned int *temporaryIndex,
506 const TSymbolTable &symbolTable,
507 int shaderVersion)
508{
509 RemoveDynamicIndexingTraverser traverser(symbolTable, shaderVersion);
510 ASSERT(temporaryIndex != nullptr);
511 traverser.useTemporaryIndex(temporaryIndex);
512 do
513 {
514 traverser.nextIteration();
515 root->traverse(&traverser);
516 traverser.updateTree();
517 } while (traverser.usedTreeInsertion());
518 traverser.insertHelperDefinitions(root);
519 traverser.updateTree();
520}