blob: 5586f853d1940f73b32395476fc84afb265fac84 [file] [log] [blame]
Olli Etuaho83f34112015-06-18 15:47:46 +03001//
2// Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices,
7// replacing them with calls to functions that choose which component to return or write.
8//
9
10#include "compiler/translator/RemoveDynamicIndexing.h"
11
12#include "compiler/translator/InfoSink.h"
13#include "compiler/translator/IntermNode.h"
14#include "compiler/translator/SymbolTable.h"
15
16namespace
17{
18
19TName GetIndexFunctionName(const TType &type, bool write)
20{
21 TInfoSinkBase nameSink;
22 nameSink << "dyn_index_";
23 if (write)
24 {
25 nameSink << "write_";
26 }
27 if (type.isMatrix())
28 {
29 nameSink << "mat" << type.getCols() << "x" << type.getRows();
30 }
31 else
32 {
33 switch (type.getBasicType())
34 {
35 case EbtInt:
36 nameSink << "ivec";
37 break;
38 case EbtBool:
39 nameSink << "bvec";
40 break;
41 case EbtUInt:
42 nameSink << "uvec";
43 break;
44 case EbtFloat:
45 nameSink << "vec";
46 break;
47 default:
48 UNREACHABLE();
49 }
50 nameSink << type.getNominalSize();
51 }
52 TString nameString = TFunction::mangleName(nameSink.c_str());
53 TName name(nameString);
54 name.setInternal(true);
55 return name;
56}
57
58TIntermSymbol *CreateBaseSymbol(const TType &type)
59{
60 TIntermSymbol *symbol = new TIntermSymbol(0, "base", type);
61 symbol->setInternal(true);
62 return symbol;
63}
64
65TIntermSymbol *CreateIndexSymbol()
66{
67 TIntermSymbol *symbol = new TIntermSymbol(0, "index", TType(EbtInt, EbpHigh));
68 symbol->setInternal(true);
69 return symbol;
70}
71
72TIntermSymbol *CreateValueSymbol(const TType &type)
73{
74 TIntermSymbol *symbol = new TIntermSymbol(0, "value", type);
75 symbol->setInternal(true);
76 return symbol;
77}
78
79TIntermConstantUnion *CreateIntConstantNode(int i)
80{
81 TConstantUnion *constant = new TConstantUnion();
82 constant->setIConst(i);
83 return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
84}
85
86TIntermBinary *CreateIndexDirectBaseSymbolNode(const TType &indexedType,
87 const TType &fieldType,
88 const int index)
89{
90 TIntermBinary *indexNode = new TIntermBinary(EOpIndexDirect);
91 indexNode->setType(fieldType);
92 indexNode->setLeft(CreateBaseSymbol(indexedType));
93 indexNode->setRight(CreateIntConstantNode(index));
94 return indexNode;
95}
96
97TIntermBinary *CreateAssignValueSymbolNode(TIntermTyped *targetNode, const TType &assignedValueType)
98{
99 TIntermBinary *assignNode = new TIntermBinary(EOpAssign);
100 assignNode->setType(assignedValueType);
101 assignNode->setLeft(targetNode);
102 assignNode->setRight(CreateValueSymbol(assignedValueType));
103 return assignNode;
104}
105
106TIntermTyped *EnsureSignedInt(TIntermTyped *node)
107{
108 if (node->getBasicType() == EbtInt)
109 return node;
110
111 TIntermAggregate *convertedNode = new TIntermAggregate(EOpConstructInt);
112 convertedNode->setType(TType(EbtInt));
113 convertedNode->getSequence()->push_back(node);
114 convertedNode->setPrecisionFromChildren();
115 return convertedNode;
116}
117
118TType GetFieldType(const TType &indexedType)
119{
120 if (indexedType.isMatrix())
121 {
122 TType fieldType = TType(indexedType.getBasicType(), indexedType.getPrecision());
123 fieldType.setPrimarySize(unsigned char(indexedType.getRows()));
124 return fieldType;
125 }
126 else
127 {
128 return TType(indexedType.getBasicType(), indexedType.getPrecision());
129 }
130}
131
132// Generate a read or write function for one field in a vector/matrix.
133// Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
134// indices in other places.
135// Note that indices can be either int or uint. We create only int versions of the functions,
136// and convert uint indices to int at the call site.
137// read function example:
138// float dyn_index_vec2(in vec2 base, in int index) {
139// switch(index) {
140// case (0):
141// return base[0];
142// case (1):
143// return base[1];
144// default:
145// if (index < 0)
146// return base[0];
147// else
148// return base[1];
149// }
150// }
151// write function example:
152// void dyn_index_write_vec2(inout vec2 base, in int index, in float value) {
153// switch(index) {
154// case (0):
155// base[0] = value;
156// break;
157// case (1):
158// base[1] = value;
159// break;
160// default:
161// if (index < 0)
162// base[0] = value;
163// else
164// base[1] = value;
165// break;
166// }
167// }
168TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
169{
170 ASSERT(!type.isArray());
171 // Conservatively use highp here, even if the indexed type is not highp. That way the code can't
172 // end up using mediump version of an indexing function for a highp value, if both mediump and
173 // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
174 // principle this code could be used with multiple backends.
175 type.setPrecision(EbpHigh);
176 TIntermAggregate *indexingFunction = new TIntermAggregate(EOpFunction);
177 indexingFunction->setNameObj(GetIndexFunctionName(type, write));
178
179 TType fieldType = GetFieldType(type);
180 int numCases = 0;
181 if (type.isMatrix())
182 {
183 numCases = type.getCols();
184 }
185 else
186 {
187 numCases = type.getNominalSize();
188 }
189 if (write)
190 {
191 indexingFunction->setType(TType(EbtVoid));
192 }
193 else
194 {
195 indexingFunction->setType(fieldType);
196 }
197
198 TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters);
199 TIntermSymbol *baseParam = CreateBaseSymbol(type);
200 if (write)
201 baseParam->getTypePointer()->setQualifier(EvqInOut);
202 else
203 baseParam->getTypePointer()->setQualifier(EvqIn);
204 paramsNode->getSequence()->push_back(baseParam);
205 TIntermSymbol *indexParam = CreateIndexSymbol();
206 indexParam->getTypePointer()->setQualifier(EvqIn);
207 paramsNode->getSequence()->push_back(indexParam);
208 if (write)
209 {
210 TIntermSymbol *valueParam = CreateValueSymbol(fieldType);
211 valueParam->getTypePointer()->setQualifier(EvqIn);
212 paramsNode->getSequence()->push_back(valueParam);
213 }
214 indexingFunction->getSequence()->push_back(paramsNode);
215
216 TIntermAggregate *statementList = new TIntermAggregate(EOpSequence);
217 for (int i = 0; i < numCases; ++i)
218 {
219 TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
220 statementList->getSequence()->push_back(caseNode);
221
222 TIntermBinary *indexNode = CreateIndexDirectBaseSymbolNode(type, fieldType, i);
223 if (write)
224 {
225 TIntermBinary *assignNode = CreateAssignValueSymbolNode(indexNode, fieldType);
226 statementList->getSequence()->push_back(assignNode);
227 TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
228 statementList->getSequence()->push_back(breakNode);
229 }
230 else
231 {
232 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
233 statementList->getSequence()->push_back(returnNode);
234 }
235 }
236
237 // Default case
238 TIntermCase *defaultNode = new TIntermCase(nullptr);
239 statementList->getSequence()->push_back(defaultNode);
240 TIntermBinary *cond = new TIntermBinary(EOpLessThan);
241 cond->setType(TType(EbtBool, EbpUndefined));
242 cond->setLeft(CreateIndexSymbol());
243 cond->setRight(CreateIntConstantNode(0));
244 TIntermAggregate *trueBlock = new TIntermAggregate(EOpSequence);
245 TIntermAggregate *falseBlock = new TIntermAggregate(EOpSequence);
246 TIntermBinary *indexFirstNode = CreateIndexDirectBaseSymbolNode(type, fieldType, 0);
247 TIntermBinary *indexLastNode = CreateIndexDirectBaseSymbolNode(type, fieldType, numCases - 1);
248 if (write)
249 {
250 TIntermBinary *assignFirstNode = CreateAssignValueSymbolNode(indexFirstNode, fieldType);
251 trueBlock->getSequence()->push_back(assignFirstNode);
252 TIntermBinary *assignLastNode = CreateAssignValueSymbolNode(indexLastNode, fieldType);
253 falseBlock->getSequence()->push_back(assignLastNode);
254 }
255 else
256 {
257 TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
258 trueBlock->getSequence()->push_back(returnFirstNode);
259
260 TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
261 falseBlock->getSequence()->push_back(returnLastNode);
262 }
263 TIntermSelection *ifNode = new TIntermSelection(cond, trueBlock, falseBlock);
264 statementList->getSequence()->push_back(ifNode);
265 TIntermSwitch *switchNode = new TIntermSwitch(CreateIndexSymbol(), statementList);
266
267 TIntermAggregate *bodyNode = new TIntermAggregate(EOpSequence);
268 bodyNode->getSequence()->push_back(switchNode);
269 indexingFunction->getSequence()->push_back(bodyNode);
270
271 return indexingFunction;
272}
273
274class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
275{
276 public:
277 RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable, int shaderVersion);
278
279 bool visitBinary(Visit visit, TIntermBinary *node) override;
280
281 void insertHelperDefinitions(TIntermNode *root);
282
283 void nextIteration();
284
285 bool usedTreeInsertion() const { return mUsedTreeInsertion; }
286
287 protected:
288 // Sets of types that are indexed. Note that these can not store multiple variants
289 // of the same type with different precisions - only one precision gets stored.
290 std::set<TType> mIndexedVecAndMatrixTypes;
291 std::set<TType> mWrittenVecAndMatrixTypes;
292
293 bool mUsedTreeInsertion;
294
295 // When true, the traverser will remove side effects from any indexing expression.
296 // This is done so that in code like
297 // V[j++][i]++.
298 // where V is an array of vectors, j++ will only be evaluated once.
299 bool mRemoveIndexSideEffectsInSubtree;
300};
301
302RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable,
303 int shaderVersion)
304 : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
305 mUsedTreeInsertion(false),
306 mRemoveIndexSideEffectsInSubtree(false)
307{
308}
309
310void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
311{
312 TIntermAggregate *rootAgg = root->getAsAggregate();
313 ASSERT(rootAgg != nullptr && rootAgg->getOp() == EOpSequence);
314 TIntermSequence insertions;
315 for (TType type : mIndexedVecAndMatrixTypes)
316 {
317 insertions.push_back(GetIndexFunctionDefinition(type, false));
318 }
319 for (TType type : mWrittenVecAndMatrixTypes)
320 {
321 insertions.push_back(GetIndexFunctionDefinition(type, true));
322 }
323 mInsertions.push_back(NodeInsertMultipleEntry(rootAgg, 0, insertions, TIntermSequence()));
324}
325
326// Create a call to dyn_index_*() based on an indirect indexing op node
327TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
328 TIntermTyped *indexedNode,
329 TIntermTyped *index)
330{
331 ASSERT(node->getOp() == EOpIndexIndirect);
332 TIntermAggregate *indexingCall = new TIntermAggregate(EOpFunctionCall);
333 indexingCall->setLine(node->getLine());
334 indexingCall->setUserDefined();
335 indexingCall->setNameObj(GetIndexFunctionName(indexedNode->getType(), false));
336 indexingCall->getSequence()->push_back(indexedNode);
337 indexingCall->getSequence()->push_back(index);
338
339 TType fieldType = GetFieldType(indexedNode->getType());
340 indexingCall->setType(fieldType);
341 return indexingCall;
342}
343
344TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
345 TIntermTyped *index,
346 TIntermTyped *writtenValue)
347{
348 // Deep copy the left node so that two pointers to the same node don't end up in the tree.
349 TIntermNode *leftCopy = node->getLeft()->deepCopy();
350 ASSERT(leftCopy != nullptr && leftCopy->getAsTyped() != nullptr);
351 TIntermAggregate *indexedWriteCall =
352 CreateIndexFunctionCall(node, leftCopy->getAsTyped(), index);
353 indexedWriteCall->setNameObj(GetIndexFunctionName(node->getLeft()->getType(), true));
354 indexedWriteCall->setType(TType(EbtVoid));
355 indexedWriteCall->getSequence()->push_back(writtenValue);
356 return indexedWriteCall;
357}
358
359bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
360{
361 if (mUsedTreeInsertion)
362 return false;
363
364 if (node->getOp() == EOpIndexIndirect)
365 {
366 if (mRemoveIndexSideEffectsInSubtree)
367 {
368 ASSERT(node->getRight()->hasSideEffects());
369 // In case we're just removing index side effects, convert
370 // v_expr[index_expr]
371 // to this:
372 // int s0 = index_expr; v_expr[s0];
373 // Now v_expr[s0] can be safely executed several times without unintended side effects.
374
375 // Init the temp variable holding the index
376 TIntermAggregate *initIndex = createTempInitDeclaration(node->getRight());
377 TIntermSequence insertions;
378 insertions.push_back(initIndex);
379 insertStatementsInParentBlock(insertions);
380 mUsedTreeInsertion = true;
381
382 // Replace the index with the temp variable
383 TIntermSymbol *tempIndex = createTempSymbol(node->getRight()->getType());
384 NodeUpdateEntry replaceIndex(node, node->getRight(), tempIndex, false);
385 mReplacements.push_back(replaceIndex);
386 }
387 else if (!node->getLeft()->isArray() && node->getLeft()->getBasicType() != EbtStruct)
388 {
389 bool write = isLValueRequiredHere();
390
391 TType type = node->getLeft()->getType();
392 mIndexedVecAndMatrixTypes.insert(type);
393
394 if (write)
395 {
396 // Convert:
397 // v_expr[index_expr]++;
398 // to this:
399 // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
400 // dyn_index_write(v_expr, s0, s1);
401 // This works even if index_expr has some side effects.
402 if (node->getLeft()->hasSideEffects())
403 {
404 // If v_expr has side effects, those need to be removed before proceeding.
405 // Otherwise the side effects of v_expr would be evaluated twice.
406 // The only case where an l-value can have side effects is when it is
407 // indexing. For example, it can be V[j++] where V is an array of vectors.
408 mRemoveIndexSideEffectsInSubtree = true;
409 return true;
410 }
411 // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
412 // only writes it and doesn't need the previous value. http://anglebug.com/1116
413
414 mWrittenVecAndMatrixTypes.insert(type);
415 TType fieldType = GetFieldType(type);
416
417 TIntermSequence insertionsBefore;
418 TIntermSequence insertionsAfter;
419
420 // Store the index in a temporary signed int variable.
421 TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
422 TIntermAggregate *initIndex = createTempInitDeclaration(indexInitializer);
423 initIndex->setLine(node->getLine());
424 insertionsBefore.push_back(initIndex);
425
426 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
427 node, node->getLeft(), createTempSymbol(indexInitializer->getType()));
428
429 // Create a node for referring to the index after the nextTemporaryIndex() call
430 // below.
431 TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType());
432
433 nextTemporaryIndex(); // From now on, creating temporary symbols that refer to the
434 // field value.
435 insertionsBefore.push_back(createTempInitDeclaration(indexingCall));
436
437 TIntermAggregate *indexedWriteCall =
438 CreateIndexedWriteFunctionCall(node, tempIndex, createTempSymbol(fieldType));
439 insertionsAfter.push_back(indexedWriteCall);
440 insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
441 NodeUpdateEntry replaceIndex(getParentNode(), node, createTempSymbol(fieldType),
442 false);
443 mReplacements.push_back(replaceIndex);
444 mUsedTreeInsertion = true;
445 }
446 else
447 {
448 // The indexed value is not being written, so we can simply convert
449 // v_expr[index_expr]
450 // into
451 // dyn_index(v_expr, index_expr)
452 // If the index_expr is unsigned, we'll convert it to signed.
453 ASSERT(!mRemoveIndexSideEffectsInSubtree);
454 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
455 node, node->getLeft(), EnsureSignedInt(node->getRight()));
456 NodeUpdateEntry replaceIndex(getParentNode(), node, indexingCall, false);
457 mReplacements.push_back(replaceIndex);
458 }
459 }
460 }
461 return !mUsedTreeInsertion;
462}
463
464void RemoveDynamicIndexingTraverser::nextIteration()
465{
466 mUsedTreeInsertion = false;
467 mRemoveIndexSideEffectsInSubtree = false;
468 nextTemporaryIndex();
469}
470
471} // namespace
472
473void RemoveDynamicIndexing(TIntermNode *root,
474 unsigned int *temporaryIndex,
475 const TSymbolTable &symbolTable,
476 int shaderVersion)
477{
478 RemoveDynamicIndexingTraverser traverser(symbolTable, shaderVersion);
479 ASSERT(temporaryIndex != nullptr);
480 traverser.useTemporaryIndex(temporaryIndex);
481 do
482 {
483 traverser.nextIteration();
484 root->traverse(&traverser);
485 traverser.updateTree();
486 } while (traverser.usedTreeInsertion());
487 traverser.insertHelperDefinitions(root);
488 traverser.updateTree();
489}