Remove dynamic indexing of matrices and vectors in HLSL

Re-re-relanding after clang warning fix.

Re-re-landing with fix to setting qualifiers on generated nodes. The
previous version failed when a uniform was indexed, because it would
set the uniform qualifier on some of the generated nodes and that
interfered with the operation of UniformsHLSL.

Re-landing after fixing D3D9 specific issues.

HLSL doesn't support dynamic indexing of matrices and vectors, so replace
that with helper functions that unroll dynamic indexing into switch/case
and static indexing.

Both the indexed vector/matrix expression and the index may have side
effects, and these will be evaluated correctly. If necessary, index
expressions that have side effects will be written to a temporary
variable that will replace the index.

Besides dEQP tests, this change is tested by a WebGL 2 conformance test.

In the case that a dynamic index is out-of-range, the base ESSL 3.00 spec
allows undefined behavior. KHR_robust_buffer_access_behavior adds the
requirement that program termination should not occur and that
out-of-range reads must return either a value from the active program's
memory or zero, and out-of-range writes should only affect the active
program's memory or do nothing. This patch clamps out-of-range indices so
that either the first or last item of the matrix/vector is accessed.

The code is not transformed in case the it fits within the limited subset
of ESSL 1.00 given in Appendix A of the spec. If the code isn't within
the restricted subset, even ESSL 1.00 shaders may require this
workaround.

BUG=angleproject:1116
TEST=dEQP-GLES3.functional.shaders.indexing.* (all pass after change)
     WebGL 2 conformance tests (glsl3/vector-dynamic-indexing.html)

Change-Id: I9f6d7c7ecda8ac4dc3c30b39e15a9a0b5381c5a8
Reviewed-on: https://chromium-review.googlesource.com/310010
Reviewed-by: Geoff Lang <geofflang@chromium.org>
Tested-by: Geoff Lang <geofflang@chromium.org>
diff --git a/src/compiler/translator/RemoveDynamicIndexing.cpp b/src/compiler/translator/RemoveDynamicIndexing.cpp
new file mode 100644
index 0000000..74814f2
--- /dev/null
+++ b/src/compiler/translator/RemoveDynamicIndexing.cpp
@@ -0,0 +1,513 @@
+//
+// Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices,
+// replacing them with calls to functions that choose which component to return or write.
+//
+
+#include "compiler/translator/RemoveDynamicIndexing.h"
+
+#include "compiler/translator/InfoSink.h"
+#include "compiler/translator/IntermNode.h"
+#include "compiler/translator/SymbolTable.h"
+
+namespace
+{
+
+TName GetIndexFunctionName(const TType &type, bool write)
+{
+    TInfoSinkBase nameSink;
+    nameSink << "dyn_index_";
+    if (write)
+    {
+        nameSink << "write_";
+    }
+    if (type.isMatrix())
+    {
+        nameSink << "mat" << type.getCols() << "x" << type.getRows();
+    }
+    else
+    {
+        switch (type.getBasicType())
+        {
+            case EbtInt:
+                nameSink << "ivec";
+                break;
+            case EbtBool:
+                nameSink << "bvec";
+                break;
+            case EbtUInt:
+                nameSink << "uvec";
+                break;
+            case EbtFloat:
+                nameSink << "vec";
+                break;
+            default:
+                UNREACHABLE();
+        }
+        nameSink << type.getNominalSize();
+    }
+    TString nameString = TFunction::mangleName(nameSink.c_str());
+    TName name(nameString);
+    name.setInternal(true);
+    return name;
+}
+
+TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier)
+{
+    TIntermSymbol *symbol = new TIntermSymbol(0, "base", type);
+    symbol->setInternal(true);
+    symbol->getTypePointer()->setQualifier(qualifier);
+    return symbol;
+}
+
+TIntermSymbol *CreateIndexSymbol()
+{
+    TIntermSymbol *symbol = new TIntermSymbol(0, "index", TType(EbtInt, EbpHigh));
+    symbol->setInternal(true);
+    symbol->getTypePointer()->setQualifier(EvqIn);
+    return symbol;
+}
+
+TIntermSymbol *CreateValueSymbol(const TType &type)
+{
+    TIntermSymbol *symbol = new TIntermSymbol(0, "value", type);
+    symbol->setInternal(true);
+    symbol->getTypePointer()->setQualifier(EvqIn);
+    return symbol;
+}
+
+TIntermConstantUnion *CreateIntConstantNode(int i)
+{
+    TConstantUnion *constant = new TConstantUnion();
+    constant->setIConst(i);
+    return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
+}
+
+TIntermBinary *CreateIndexDirectBaseSymbolNode(const TType &indexedType,
+                                               const TType &fieldType,
+                                               const int index,
+                                               TQualifier baseQualifier)
+{
+    TIntermBinary *indexNode = new TIntermBinary(EOpIndexDirect);
+    indexNode->setType(fieldType);
+    TIntermSymbol *baseSymbol = CreateBaseSymbol(indexedType, baseQualifier);
+    indexNode->setLeft(baseSymbol);
+    indexNode->setRight(CreateIntConstantNode(index));
+    return indexNode;
+}
+
+TIntermBinary *CreateAssignValueSymbolNode(TIntermTyped *targetNode, const TType &assignedValueType)
+{
+    TIntermBinary *assignNode = new TIntermBinary(EOpAssign);
+    assignNode->setType(assignedValueType);
+    assignNode->setLeft(targetNode);
+    assignNode->setRight(CreateValueSymbol(assignedValueType));
+    return assignNode;
+}
+
+TIntermTyped *EnsureSignedInt(TIntermTyped *node)
+{
+    if (node->getBasicType() == EbtInt)
+        return node;
+
+    TIntermAggregate *convertedNode = new TIntermAggregate(EOpConstructInt);
+    convertedNode->setType(TType(EbtInt));
+    convertedNode->getSequence()->push_back(node);
+    convertedNode->setPrecisionFromChildren();
+    return convertedNode;
+}
+
+TType GetFieldType(const TType &indexedType)
+{
+    if (indexedType.isMatrix())
+    {
+        TType fieldType = TType(indexedType.getBasicType(), indexedType.getPrecision());
+        fieldType.setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
+        return fieldType;
+    }
+    else
+    {
+        return TType(indexedType.getBasicType(), indexedType.getPrecision());
+    }
+}
+
+// Generate a read or write function for one field in a vector/matrix.
+// Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
+// indices in other places.
+// Note that indices can be either int or uint. We create only int versions of the functions,
+// and convert uint indices to int at the call site.
+// read function example:
+// float dyn_index_vec2(in vec2 base, in int index)
+// {
+//    switch(index)
+//    {
+//      case (0):
+//        return base[0];
+//      case (1):
+//        return base[1];
+//      default:
+//        break;
+//    }
+//    if (index < 0)
+//      return base[0];
+//    return base[1];
+// }
+// write function example:
+// void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
+// {
+//    switch(index)
+//    {
+//      case (0):
+//        base[0] = value;
+//        return;
+//      case (1):
+//        base[1] = value;
+//        return;
+//      default:
+//        break;
+//    }
+//    if (index < 0)
+//    {
+//      base[0] = value;
+//      return;
+//    }
+//    base[1] = value;
+// }
+// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
+TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
+{
+    ASSERT(!type.isArray());
+    // Conservatively use highp here, even if the indexed type is not highp. That way the code can't
+    // end up using mediump version of an indexing function for a highp value, if both mediump and
+    // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
+    // principle this code could be used with multiple backends.
+    type.setPrecision(EbpHigh);
+    TIntermAggregate *indexingFunction = new TIntermAggregate(EOpFunction);
+    indexingFunction->setNameObj(GetIndexFunctionName(type, write));
+
+    TType fieldType = GetFieldType(type);
+    int numCases = 0;
+    if (type.isMatrix())
+    {
+        numCases = type.getCols();
+    }
+    else
+    {
+        numCases = type.getNominalSize();
+    }
+    if (write)
+    {
+        indexingFunction->setType(TType(EbtVoid));
+    }
+    else
+    {
+        indexingFunction->setType(fieldType);
+    }
+
+    TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters);
+    TQualifier baseQualifier = EvqInOut;
+    if (!write)
+        baseQualifier        = EvqIn;
+    TIntermSymbol *baseParam = CreateBaseSymbol(type, baseQualifier);
+    paramsNode->getSequence()->push_back(baseParam);
+    TIntermSymbol *indexParam = CreateIndexSymbol();
+    paramsNode->getSequence()->push_back(indexParam);
+    if (write)
+    {
+        TIntermSymbol *valueParam = CreateValueSymbol(fieldType);
+        paramsNode->getSequence()->push_back(valueParam);
+    }
+    indexingFunction->getSequence()->push_back(paramsNode);
+
+    TIntermAggregate *statementList = new TIntermAggregate(EOpSequence);
+    for (int i = 0; i < numCases; ++i)
+    {
+        TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
+        statementList->getSequence()->push_back(caseNode);
+
+        TIntermBinary *indexNode =
+            CreateIndexDirectBaseSymbolNode(type, fieldType, i, baseQualifier);
+        if (write)
+        {
+            TIntermBinary *assignNode = CreateAssignValueSymbolNode(indexNode, fieldType);
+            statementList->getSequence()->push_back(assignNode);
+            TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
+            statementList->getSequence()->push_back(returnNode);
+        }
+        else
+        {
+            TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
+            statementList->getSequence()->push_back(returnNode);
+        }
+    }
+
+    // Default case
+    TIntermCase *defaultNode = new TIntermCase(nullptr);
+    statementList->getSequence()->push_back(defaultNode);
+    TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
+    statementList->getSequence()->push_back(breakNode);
+
+    TIntermSwitch *switchNode = new TIntermSwitch(CreateIndexSymbol(), statementList);
+
+    TIntermAggregate *bodyNode = new TIntermAggregate(EOpSequence);
+    bodyNode->getSequence()->push_back(switchNode);
+
+    TIntermBinary *cond = new TIntermBinary(EOpLessThan);
+    cond->setType(TType(EbtBool, EbpUndefined));
+    cond->setLeft(CreateIndexSymbol());
+    cond->setRight(CreateIntConstantNode(0));
+
+    // Two blocks: one accesses (either reads or writes) the first element and returns,
+    // the other accesses the last element.
+    TIntermAggregate *useFirstBlock = new TIntermAggregate(EOpSequence);
+    TIntermAggregate *useLastBlock = new TIntermAggregate(EOpSequence);
+    TIntermBinary *indexFirstNode =
+        CreateIndexDirectBaseSymbolNode(type, fieldType, 0, baseQualifier);
+    TIntermBinary *indexLastNode =
+        CreateIndexDirectBaseSymbolNode(type, fieldType, numCases - 1, baseQualifier);
+    if (write)
+    {
+        TIntermBinary *assignFirstNode = CreateAssignValueSymbolNode(indexFirstNode, fieldType);
+        useFirstBlock->getSequence()->push_back(assignFirstNode);
+        TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
+        useFirstBlock->getSequence()->push_back(returnNode);
+
+        TIntermBinary *assignLastNode = CreateAssignValueSymbolNode(indexLastNode, fieldType);
+        useLastBlock->getSequence()->push_back(assignLastNode);
+    }
+    else
+    {
+        TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
+        useFirstBlock->getSequence()->push_back(returnFirstNode);
+
+        TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
+        useLastBlock->getSequence()->push_back(returnLastNode);
+    }
+    TIntermSelection *ifNode = new TIntermSelection(cond, useFirstBlock, nullptr);
+    bodyNode->getSequence()->push_back(ifNode);
+    bodyNode->getSequence()->push_back(useLastBlock);
+
+    indexingFunction->getSequence()->push_back(bodyNode);
+
+    return indexingFunction;
+}
+
+class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
+{
+  public:
+    RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable, int shaderVersion);
+
+    bool visitBinary(Visit visit, TIntermBinary *node) override;
+
+    void insertHelperDefinitions(TIntermNode *root);
+
+    void nextIteration();
+
+    bool usedTreeInsertion() const { return mUsedTreeInsertion; }
+
+  protected:
+    // Sets of types that are indexed. Note that these can not store multiple variants
+    // of the same type with different precisions - only one precision gets stored.
+    std::set<TType> mIndexedVecAndMatrixTypes;
+    std::set<TType> mWrittenVecAndMatrixTypes;
+
+    bool mUsedTreeInsertion;
+
+    // When true, the traverser will remove side effects from any indexing expression.
+    // This is done so that in code like
+    //   V[j++][i]++.
+    // where V is an array of vectors, j++ will only be evaluated once.
+    bool mRemoveIndexSideEffectsInSubtree;
+};
+
+RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable,
+                                                               int shaderVersion)
+    : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
+      mUsedTreeInsertion(false),
+      mRemoveIndexSideEffectsInSubtree(false)
+{
+}
+
+void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
+{
+    TIntermAggregate *rootAgg = root->getAsAggregate();
+    ASSERT(rootAgg != nullptr && rootAgg->getOp() == EOpSequence);
+    TIntermSequence insertions;
+    for (TType type : mIndexedVecAndMatrixTypes)
+    {
+        insertions.push_back(GetIndexFunctionDefinition(type, false));
+    }
+    for (TType type : mWrittenVecAndMatrixTypes)
+    {
+        insertions.push_back(GetIndexFunctionDefinition(type, true));
+    }
+    mInsertions.push_back(NodeInsertMultipleEntry(rootAgg, 0, insertions, TIntermSequence()));
+}
+
+// Create a call to dyn_index_*() based on an indirect indexing op node
+TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
+                                          TIntermTyped *indexedNode,
+                                          TIntermTyped *index)
+{
+    ASSERT(node->getOp() == EOpIndexIndirect);
+    TIntermAggregate *indexingCall = new TIntermAggregate(EOpFunctionCall);
+    indexingCall->setLine(node->getLine());
+    indexingCall->setUserDefined();
+    indexingCall->setNameObj(GetIndexFunctionName(indexedNode->getType(), false));
+    indexingCall->getSequence()->push_back(indexedNode);
+    indexingCall->getSequence()->push_back(index);
+
+    TType fieldType = GetFieldType(indexedNode->getType());
+    indexingCall->setType(fieldType);
+    return indexingCall;
+}
+
+TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
+                                                 TIntermTyped *index,
+                                                 TIntermTyped *writtenValue)
+{
+    // Deep copy the left node so that two pointers to the same node don't end up in the tree.
+    TIntermNode *leftCopy = node->getLeft()->deepCopy();
+    ASSERT(leftCopy != nullptr && leftCopy->getAsTyped() != nullptr);
+    TIntermAggregate *indexedWriteCall =
+        CreateIndexFunctionCall(node, leftCopy->getAsTyped(), index);
+    indexedWriteCall->setNameObj(GetIndexFunctionName(node->getLeft()->getType(), true));
+    indexedWriteCall->setType(TType(EbtVoid));
+    indexedWriteCall->getSequence()->push_back(writtenValue);
+    return indexedWriteCall;
+}
+
+bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
+{
+    if (mUsedTreeInsertion)
+        return false;
+
+    if (node->getOp() == EOpIndexIndirect)
+    {
+        if (mRemoveIndexSideEffectsInSubtree)
+        {
+            ASSERT(node->getRight()->hasSideEffects());
+            // In case we're just removing index side effects, convert
+            //   v_expr[index_expr]
+            // to this:
+            //   int s0 = index_expr; v_expr[s0];
+            // Now v_expr[s0] can be safely executed several times without unintended side effects.
+
+            // Init the temp variable holding the index
+            TIntermAggregate *initIndex = createTempInitDeclaration(node->getRight());
+            TIntermSequence insertions;
+            insertions.push_back(initIndex);
+            insertStatementsInParentBlock(insertions);
+            mUsedTreeInsertion = true;
+
+            // Replace the index with the temp variable
+            TIntermSymbol *tempIndex = createTempSymbol(node->getRight()->getType());
+            NodeUpdateEntry replaceIndex(node, node->getRight(), tempIndex, false);
+            mReplacements.push_back(replaceIndex);
+        }
+        else if (!node->getLeft()->isArray() && node->getLeft()->getBasicType() != EbtStruct)
+        {
+            bool write = isLValueRequiredHere();
+
+            TType type = node->getLeft()->getType();
+            mIndexedVecAndMatrixTypes.insert(type);
+
+            if (write)
+            {
+                // Convert:
+                //   v_expr[index_expr]++;
+                // to this:
+                //   int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
+                //   dyn_index_write(v_expr, s0, s1);
+                // This works even if index_expr has some side effects.
+                if (node->getLeft()->hasSideEffects())
+                {
+                    // If v_expr has side effects, those need to be removed before proceeding.
+                    // Otherwise the side effects of v_expr would be evaluated twice.
+                    // The only case where an l-value can have side effects is when it is
+                    // indexing. For example, it can be V[j++] where V is an array of vectors.
+                    mRemoveIndexSideEffectsInSubtree = true;
+                    return true;
+                }
+                // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
+                // only writes it and doesn't need the previous value. http://anglebug.com/1116
+
+                mWrittenVecAndMatrixTypes.insert(type);
+                TType fieldType = GetFieldType(type);
+
+                TIntermSequence insertionsBefore;
+                TIntermSequence insertionsAfter;
+
+                // Store the index in a temporary signed int variable.
+                TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
+                TIntermAggregate *initIndex = createTempInitDeclaration(indexInitializer);
+                initIndex->setLine(node->getLine());
+                insertionsBefore.push_back(initIndex);
+
+                TIntermAggregate *indexingCall = CreateIndexFunctionCall(
+                    node, node->getLeft(), createTempSymbol(indexInitializer->getType()));
+
+                // Create a node for referring to the index after the nextTemporaryIndex() call
+                // below.
+                TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType());
+
+                nextTemporaryIndex();  // From now on, creating temporary symbols that refer to the
+                                       // field value.
+                insertionsBefore.push_back(createTempInitDeclaration(indexingCall));
+
+                TIntermAggregate *indexedWriteCall =
+                    CreateIndexedWriteFunctionCall(node, tempIndex, createTempSymbol(fieldType));
+                insertionsAfter.push_back(indexedWriteCall);
+                insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
+                NodeUpdateEntry replaceIndex(getParentNode(), node, createTempSymbol(fieldType),
+                                             false);
+                mReplacements.push_back(replaceIndex);
+                mUsedTreeInsertion = true;
+            }
+            else
+            {
+                // The indexed value is not being written, so we can simply convert
+                //   v_expr[index_expr]
+                // into
+                //   dyn_index(v_expr, index_expr)
+                // If the index_expr is unsigned, we'll convert it to signed.
+                ASSERT(!mRemoveIndexSideEffectsInSubtree);
+                TIntermAggregate *indexingCall = CreateIndexFunctionCall(
+                    node, node->getLeft(), EnsureSignedInt(node->getRight()));
+                NodeUpdateEntry replaceIndex(getParentNode(), node, indexingCall, false);
+                mReplacements.push_back(replaceIndex);
+            }
+        }
+    }
+    return !mUsedTreeInsertion;
+}
+
+void RemoveDynamicIndexingTraverser::nextIteration()
+{
+    mUsedTreeInsertion               = false;
+    mRemoveIndexSideEffectsInSubtree = false;
+    nextTemporaryIndex();
+}
+
+}  // namespace
+
+void RemoveDynamicIndexing(TIntermNode *root,
+                           unsigned int *temporaryIndex,
+                           const TSymbolTable &symbolTable,
+                           int shaderVersion)
+{
+    RemoveDynamicIndexingTraverser traverser(symbolTable, shaderVersion);
+    ASSERT(temporaryIndex != nullptr);
+    traverser.useTemporaryIndex(temporaryIndex);
+    do
+    {
+        traverser.nextIteration();
+        root->traverse(&traverser);
+        traverser.updateTree();
+    } while (traverser.usedTreeInsertion());
+    traverser.insertHelperDefinitions(root);
+    traverser.updateTree();
+}