blob: 31914dcf360735b17d7a3aeaccee8dce3a9f1cfd [file] [log] [blame]
Olli Etuaho5d91dda2015-06-18 15:47:46 +03001//
2// Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices,
7// replacing them with calls to functions that choose which component to return or write.
8//
9
10#include "compiler/translator/RemoveDynamicIndexing.h"
11
12#include "compiler/translator/InfoSink.h"
13#include "compiler/translator/IntermNode.h"
Jamie Madill666f65a2016-08-26 01:34:37 +000014#include "compiler/translator/IntermNodePatternMatcher.h"
Olli Etuaho5d91dda2015-06-18 15:47:46 +030015#include "compiler/translator/SymbolTable.h"
16
Jamie Madill45bcc782016-11-07 13:58:48 -050017namespace sh
18{
19
Olli Etuaho5d91dda2015-06-18 15:47:46 +030020namespace
21{
22
23TName GetIndexFunctionName(const TType &type, bool write)
24{
25 TInfoSinkBase nameSink;
26 nameSink << "dyn_index_";
27 if (write)
28 {
29 nameSink << "write_";
30 }
31 if (type.isMatrix())
32 {
33 nameSink << "mat" << type.getCols() << "x" << type.getRows();
34 }
35 else
36 {
37 switch (type.getBasicType())
38 {
39 case EbtInt:
40 nameSink << "ivec";
41 break;
42 case EbtBool:
43 nameSink << "bvec";
44 break;
45 case EbtUInt:
46 nameSink << "uvec";
47 break;
48 case EbtFloat:
49 nameSink << "vec";
50 break;
51 default:
52 UNREACHABLE();
53 }
54 nameSink << type.getNominalSize();
55 }
56 TString nameString = TFunction::mangleName(nameSink.c_str());
57 TName name(nameString);
58 name.setInternal(true);
59 return name;
60}
61
62TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier)
63{
64 TIntermSymbol *symbol = new TIntermSymbol(0, "base", type);
65 symbol->setInternal(true);
66 symbol->getTypePointer()->setQualifier(qualifier);
67 return symbol;
68}
69
70TIntermSymbol *CreateIndexSymbol()
71{
72 TIntermSymbol *symbol = new TIntermSymbol(0, "index", TType(EbtInt, EbpHigh));
73 symbol->setInternal(true);
74 symbol->getTypePointer()->setQualifier(EvqIn);
75 return symbol;
76}
77
78TIntermSymbol *CreateValueSymbol(const TType &type)
79{
80 TIntermSymbol *symbol = new TIntermSymbol(0, "value", type);
81 symbol->setInternal(true);
82 symbol->getTypePointer()->setQualifier(EvqIn);
83 return symbol;
84}
85
86TIntermConstantUnion *CreateIntConstantNode(int i)
87{
88 TConstantUnion *constant = new TConstantUnion();
89 constant->setIConst(i);
90 return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
91}
92
93TIntermBinary *CreateIndexDirectBaseSymbolNode(const TType &indexedType,
94 const TType &fieldType,
95 const int index,
96 TQualifier baseQualifier)
97{
Olli Etuaho5d91dda2015-06-18 15:47:46 +030098 TIntermSymbol *baseSymbol = CreateBaseSymbol(indexedType, baseQualifier);
Olli Etuaho3272a6d2016-08-29 17:54:50 +030099 TIntermBinary *indexNode =
100 new TIntermBinary(EOpIndexDirect, baseSymbol, TIntermTyped::CreateIndexNode(index));
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300101 return indexNode;
102}
103
104TIntermBinary *CreateAssignValueSymbolNode(TIntermTyped *targetNode, const TType &assignedValueType)
105{
Olli Etuaho3272a6d2016-08-29 17:54:50 +0300106 return new TIntermBinary(EOpAssign, targetNode, CreateValueSymbol(assignedValueType));
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300107}
108
109TIntermTyped *EnsureSignedInt(TIntermTyped *node)
110{
111 if (node->getBasicType() == EbtInt)
112 return node;
113
114 TIntermAggregate *convertedNode = new TIntermAggregate(EOpConstructInt);
115 convertedNode->setType(TType(EbtInt));
116 convertedNode->getSequence()->push_back(node);
117 convertedNode->setPrecisionFromChildren();
118 return convertedNode;
119}
120
121TType GetFieldType(const TType &indexedType)
122{
123 if (indexedType.isMatrix())
124 {
125 TType fieldType = TType(indexedType.getBasicType(), indexedType.getPrecision());
126 fieldType.setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
127 return fieldType;
128 }
129 else
130 {
131 return TType(indexedType.getBasicType(), indexedType.getPrecision());
132 }
133}
134
135// Generate a read or write function for one field in a vector/matrix.
136// Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
137// indices in other places.
138// Note that indices can be either int or uint. We create only int versions of the functions,
139// and convert uint indices to int at the call site.
140// read function example:
141// float dyn_index_vec2(in vec2 base, in int index)
142// {
143// switch(index)
144// {
145// case (0):
146// return base[0];
147// case (1):
148// return base[1];
149// default:
150// break;
151// }
152// if (index < 0)
153// return base[0];
154// return base[1];
155// }
156// write function example:
157// void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
158// {
159// switch(index)
160// {
161// case (0):
162// base[0] = value;
163// return;
164// case (1):
165// base[1] = value;
166// return;
167// default:
168// break;
169// }
170// if (index < 0)
171// {
172// base[0] = value;
173// return;
174// }
175// base[1] = value;
176// }
177// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
Olli Etuaho336b1472016-10-05 16:37:55 +0100178TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type, bool write)
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300179{
180 ASSERT(!type.isArray());
181 // Conservatively use highp here, even if the indexed type is not highp. That way the code can't
182 // end up using mediump version of an indexing function for a highp value, if both mediump and
183 // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
184 // principle this code could be used with multiple backends.
185 type.setPrecision(EbpHigh);
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300186
187 TType fieldType = GetFieldType(type);
188 int numCases = 0;
189 if (type.isMatrix())
190 {
191 numCases = type.getCols();
192 }
193 else
194 {
195 numCases = type.getNominalSize();
196 }
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300197
198 TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters);
199 TQualifier baseQualifier = EvqInOut;
200 if (!write)
201 baseQualifier = EvqIn;
202 TIntermSymbol *baseParam = CreateBaseSymbol(type, baseQualifier);
203 paramsNode->getSequence()->push_back(baseParam);
204 TIntermSymbol *indexParam = CreateIndexSymbol();
205 paramsNode->getSequence()->push_back(indexParam);
206 if (write)
207 {
208 TIntermSymbol *valueParam = CreateValueSymbol(fieldType);
209 paramsNode->getSequence()->push_back(valueParam);
210 }
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300211
Olli Etuaho6d40bbd2016-09-30 13:49:38 +0100212 TIntermBlock *statementList = new TIntermBlock();
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300213 for (int i = 0; i < numCases; ++i)
214 {
215 TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
216 statementList->getSequence()->push_back(caseNode);
217
218 TIntermBinary *indexNode =
219 CreateIndexDirectBaseSymbolNode(type, fieldType, i, baseQualifier);
220 if (write)
221 {
222 TIntermBinary *assignNode = CreateAssignValueSymbolNode(indexNode, fieldType);
223 statementList->getSequence()->push_back(assignNode);
224 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
225 statementList->getSequence()->push_back(returnNode);
226 }
227 else
228 {
229 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
230 statementList->getSequence()->push_back(returnNode);
231 }
232 }
233
234 // Default case
235 TIntermCase *defaultNode = new TIntermCase(nullptr);
236 statementList->getSequence()->push_back(defaultNode);
237 TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
238 statementList->getSequence()->push_back(breakNode);
239
240 TIntermSwitch *switchNode = new TIntermSwitch(CreateIndexSymbol(), statementList);
241
Olli Etuaho6d40bbd2016-09-30 13:49:38 +0100242 TIntermBlock *bodyNode = new TIntermBlock();
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300243 bodyNode->getSequence()->push_back(switchNode);
244
Olli Etuaho3272a6d2016-08-29 17:54:50 +0300245 TIntermBinary *cond =
246 new TIntermBinary(EOpLessThan, CreateIndexSymbol(), CreateIntConstantNode(0));
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300247 cond->setType(TType(EbtBool, EbpUndefined));
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300248
249 // Two blocks: one accesses (either reads or writes) the first element and returns,
250 // the other accesses the last element.
Olli Etuaho6d40bbd2016-09-30 13:49:38 +0100251 TIntermBlock *useFirstBlock = new TIntermBlock();
252 TIntermBlock *useLastBlock = new TIntermBlock();
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300253 TIntermBinary *indexFirstNode =
254 CreateIndexDirectBaseSymbolNode(type, fieldType, 0, baseQualifier);
255 TIntermBinary *indexLastNode =
256 CreateIndexDirectBaseSymbolNode(type, fieldType, numCases - 1, baseQualifier);
257 if (write)
258 {
259 TIntermBinary *assignFirstNode = CreateAssignValueSymbolNode(indexFirstNode, fieldType);
260 useFirstBlock->getSequence()->push_back(assignFirstNode);
261 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
262 useFirstBlock->getSequence()->push_back(returnNode);
263
264 TIntermBinary *assignLastNode = CreateAssignValueSymbolNode(indexLastNode, fieldType);
265 useLastBlock->getSequence()->push_back(assignLastNode);
266 }
267 else
268 {
269 TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
270 useFirstBlock->getSequence()->push_back(returnFirstNode);
271
272 TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
273 useLastBlock->getSequence()->push_back(returnLastNode);
274 }
Olli Etuaho57961272016-09-14 13:57:46 +0300275 TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr);
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300276 bodyNode->getSequence()->push_back(ifNode);
277 bodyNode->getSequence()->push_back(useLastBlock);
278
Olli Etuaho336b1472016-10-05 16:37:55 +0100279 TIntermFunctionDefinition *indexingFunction = nullptr;
280 if (write)
281 {
282 indexingFunction = new TIntermFunctionDefinition(TType(EbtVoid), paramsNode, bodyNode);
283 }
284 else
285 {
286 indexingFunction = new TIntermFunctionDefinition(fieldType, paramsNode, bodyNode);
287 }
288 indexingFunction->getFunctionSymbolInfo()->setNameObj(GetIndexFunctionName(type, write));
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300289 return indexingFunction;
290}
291
292class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
293{
294 public:
295 RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable, int shaderVersion);
296
297 bool visitBinary(Visit visit, TIntermBinary *node) override;
298
299 void insertHelperDefinitions(TIntermNode *root);
300
301 void nextIteration();
302
303 bool usedTreeInsertion() const { return mUsedTreeInsertion; }
304
305 protected:
306 // Sets of types that are indexed. Note that these can not store multiple variants
307 // of the same type with different precisions - only one precision gets stored.
308 std::set<TType> mIndexedVecAndMatrixTypes;
309 std::set<TType> mWrittenVecAndMatrixTypes;
310
311 bool mUsedTreeInsertion;
312
313 // When true, the traverser will remove side effects from any indexing expression.
314 // This is done so that in code like
315 // V[j++][i]++.
316 // where V is an array of vectors, j++ will only be evaluated once.
317 bool mRemoveIndexSideEffectsInSubtree;
318};
319
320RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable,
321 int shaderVersion)
322 : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
323 mUsedTreeInsertion(false),
324 mRemoveIndexSideEffectsInSubtree(false)
325{
326}
327
328void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
329{
Olli Etuaho6d40bbd2016-09-30 13:49:38 +0100330 TIntermBlock *rootBlock = root->getAsBlock();
331 ASSERT(rootBlock != nullptr);
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300332 TIntermSequence insertions;
333 for (TType type : mIndexedVecAndMatrixTypes)
334 {
335 insertions.push_back(GetIndexFunctionDefinition(type, false));
336 }
337 for (TType type : mWrittenVecAndMatrixTypes)
338 {
339 insertions.push_back(GetIndexFunctionDefinition(type, true));
340 }
Olli Etuaho6d40bbd2016-09-30 13:49:38 +0100341 mInsertions.push_back(NodeInsertMultipleEntry(rootBlock, 0, insertions, TIntermSequence()));
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300342}
343
344// Create a call to dyn_index_*() based on an indirect indexing op node
345TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
346 TIntermTyped *indexedNode,
347 TIntermTyped *index)
348{
349 ASSERT(node->getOp() == EOpIndexIndirect);
350 TIntermAggregate *indexingCall = new TIntermAggregate(EOpFunctionCall);
351 indexingCall->setLine(node->getLine());
352 indexingCall->setUserDefined();
Olli Etuahobd674552016-10-06 13:28:42 +0100353 indexingCall->getFunctionSymbolInfo()->setNameObj(
354 GetIndexFunctionName(indexedNode->getType(), false));
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300355 indexingCall->getSequence()->push_back(indexedNode);
356 indexingCall->getSequence()->push_back(index);
357
358 TType fieldType = GetFieldType(indexedNode->getType());
359 indexingCall->setType(fieldType);
360 return indexingCall;
361}
362
363TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
364 TIntermTyped *index,
365 TIntermTyped *writtenValue)
366{
367 // Deep copy the left node so that two pointers to the same node don't end up in the tree.
368 TIntermNode *leftCopy = node->getLeft()->deepCopy();
369 ASSERT(leftCopy != nullptr && leftCopy->getAsTyped() != nullptr);
370 TIntermAggregate *indexedWriteCall =
371 CreateIndexFunctionCall(node, leftCopy->getAsTyped(), index);
Olli Etuahobd674552016-10-06 13:28:42 +0100372 indexedWriteCall->getFunctionSymbolInfo()->setNameObj(
373 GetIndexFunctionName(node->getLeft()->getType(), true));
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300374 indexedWriteCall->setType(TType(EbtVoid));
375 indexedWriteCall->getSequence()->push_back(writtenValue);
376 return indexedWriteCall;
377}
378
379bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
380{
381 if (mUsedTreeInsertion)
382 return false;
383
384 if (node->getOp() == EOpIndexIndirect)
385 {
386 if (mRemoveIndexSideEffectsInSubtree)
387 {
388 ASSERT(node->getRight()->hasSideEffects());
389 // In case we're just removing index side effects, convert
390 // v_expr[index_expr]
391 // to this:
392 // int s0 = index_expr; v_expr[s0];
393 // Now v_expr[s0] can be safely executed several times without unintended side effects.
394
395 // Init the temp variable holding the index
Olli Etuaho13389b62016-10-16 11:48:18 +0100396 TIntermDeclaration *initIndex = createTempInitDeclaration(node->getRight());
Jamie Madill1048e432016-07-23 18:51:28 -0400397 insertStatementInParentBlock(initIndex);
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300398 mUsedTreeInsertion = true;
399
400 // Replace the index with the temp variable
401 TIntermSymbol *tempIndex = createTempSymbol(node->getRight()->getType());
Jamie Madill03d863c2016-07-27 18:15:53 -0400402 queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED);
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300403 }
Jamie Madill666f65a2016-08-26 01:34:37 +0000404 else if (IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(node))
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300405 {
406 bool write = isLValueRequiredHere();
407
Jamie Madill666f65a2016-08-26 01:34:37 +0000408#if defined(ANGLE_ENABLE_ASSERTS)
409 // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
410 // implemented checks in this traverser.
411 IntermNodePatternMatcher matcher(
412 IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
413 ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
414#endif
415
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300416 TType type = node->getLeft()->getType();
417 mIndexedVecAndMatrixTypes.insert(type);
418
419 if (write)
420 {
421 // Convert:
422 // v_expr[index_expr]++;
423 // to this:
424 // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
425 // dyn_index_write(v_expr, s0, s1);
426 // This works even if index_expr has some side effects.
427 if (node->getLeft()->hasSideEffects())
428 {
429 // If v_expr has side effects, those need to be removed before proceeding.
430 // Otherwise the side effects of v_expr would be evaluated twice.
431 // The only case where an l-value can have side effects is when it is
432 // indexing. For example, it can be V[j++] where V is an array of vectors.
433 mRemoveIndexSideEffectsInSubtree = true;
434 return true;
435 }
436 // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
437 // only writes it and doesn't need the previous value. http://anglebug.com/1116
438
439 mWrittenVecAndMatrixTypes.insert(type);
440 TType fieldType = GetFieldType(type);
441
442 TIntermSequence insertionsBefore;
443 TIntermSequence insertionsAfter;
444
445 // Store the index in a temporary signed int variable.
446 TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
Olli Etuaho13389b62016-10-16 11:48:18 +0100447 TIntermDeclaration *initIndex = createTempInitDeclaration(indexInitializer);
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300448 initIndex->setLine(node->getLine());
449 insertionsBefore.push_back(initIndex);
450
451 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
452 node, node->getLeft(), createTempSymbol(indexInitializer->getType()));
453
454 // Create a node for referring to the index after the nextTemporaryIndex() call
455 // below.
456 TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType());
457
458 nextTemporaryIndex(); // From now on, creating temporary symbols that refer to the
459 // field value.
460 insertionsBefore.push_back(createTempInitDeclaration(indexingCall));
461
462 TIntermAggregate *indexedWriteCall =
463 CreateIndexedWriteFunctionCall(node, tempIndex, createTempSymbol(fieldType));
464 insertionsAfter.push_back(indexedWriteCall);
465 insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
Jamie Madill03d863c2016-07-27 18:15:53 -0400466 queueReplacement(node, createTempSymbol(fieldType), OriginalNode::IS_DROPPED);
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300467 mUsedTreeInsertion = true;
468 }
469 else
470 {
471 // The indexed value is not being written, so we can simply convert
472 // v_expr[index_expr]
473 // into
474 // dyn_index(v_expr, index_expr)
475 // If the index_expr is unsigned, we'll convert it to signed.
476 ASSERT(!mRemoveIndexSideEffectsInSubtree);
477 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
478 node, node->getLeft(), EnsureSignedInt(node->getRight()));
Jamie Madill03d863c2016-07-27 18:15:53 -0400479 queueReplacement(node, indexingCall, OriginalNode::IS_DROPPED);
Olli Etuaho5d91dda2015-06-18 15:47:46 +0300480 }
481 }
482 }
483 return !mUsedTreeInsertion;
484}
485
486void RemoveDynamicIndexingTraverser::nextIteration()
487{
488 mUsedTreeInsertion = false;
489 mRemoveIndexSideEffectsInSubtree = false;
490 nextTemporaryIndex();
491}
492
493} // namespace
494
495void RemoveDynamicIndexing(TIntermNode *root,
496 unsigned int *temporaryIndex,
497 const TSymbolTable &symbolTable,
498 int shaderVersion)
499{
500 RemoveDynamicIndexingTraverser traverser(symbolTable, shaderVersion);
501 ASSERT(temporaryIndex != nullptr);
502 traverser.useTemporaryIndex(temporaryIndex);
503 do
504 {
505 traverser.nextIteration();
506 root->traverse(&traverser);
507 traverser.updateTree();
508 } while (traverser.usedTreeInsertion());
509 traverser.insertHelperDefinitions(root);
510 traverser.updateTree();
511}
Jamie Madill45bcc782016-11-07 13:58:48 -0500512
513} // namespace sh