//
// Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// Implementation of the integer pow expressions HLSL bug workaround.
// See header for more info.
| 8 | |
| 9 | #include "compiler/translator/ExpandIntegerPowExpressions.h" |
| 10 | |
| 11 | #include <cmath> |
| 12 | #include <cstdlib> |
| 13 | |
| 14 | #include "compiler/translator/IntermNode.h" |
| 15 | |
| 16 | namespace sh |
| 17 | { |
| 18 | |
| 19 | namespace |
| 20 | { |
| 21 | |
| 22 | class Traverser : public TIntermTraverser |
| 23 | { |
| 24 | public: |
| 25 | static void Apply(TIntermNode *root, unsigned int *tempIndex); |
| 26 | |
| 27 | private: |
| 28 | Traverser(); |
| 29 | bool visitAggregate(Visit visit, TIntermAggregate *node) override; |
Jamie Madill | 5655b84 | 2016-08-02 11:00:07 -0400 | [diff] [blame] | 30 | void nextIteration(); |
| 31 | |
| 32 | bool mFound = false; |
Jamie Madill | 1048e43 | 2016-07-23 18:51:28 -0400 | [diff] [blame] | 33 | }; |
| 34 | |
| 35 | // static |
| 36 | void Traverser::Apply(TIntermNode *root, unsigned int *tempIndex) |
| 37 | { |
| 38 | Traverser traverser; |
| 39 | traverser.useTemporaryIndex(tempIndex); |
Jamie Madill | 5655b84 | 2016-08-02 11:00:07 -0400 | [diff] [blame] | 40 | do |
| 41 | { |
| 42 | traverser.nextIteration(); |
| 43 | root->traverse(&traverser); |
| 44 | if (traverser.mFound) |
| 45 | { |
| 46 | traverser.updateTree(); |
| 47 | } |
| 48 | } while (traverser.mFound); |
Jamie Madill | 1048e43 | 2016-07-23 18:51:28 -0400 | [diff] [blame] | 49 | } |
| 50 | |
| 51 | Traverser::Traverser() : TIntermTraverser(true, false, false) |
| 52 | { |
| 53 | } |
| 54 | |
Jamie Madill | 5655b84 | 2016-08-02 11:00:07 -0400 | [diff] [blame] | 55 | void Traverser::nextIteration() |
| 56 | { |
| 57 | mFound = false; |
| 58 | nextTemporaryIndex(); |
| 59 | } |
| 60 | |
Jamie Madill | 1048e43 | 2016-07-23 18:51:28 -0400 | [diff] [blame] | 61 | bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) |
| 62 | { |
Jamie Madill | 5655b84 | 2016-08-02 11:00:07 -0400 | [diff] [blame] | 63 | if (mFound) |
| 64 | { |
| 65 | return false; |
| 66 | } |
| 67 | |
Jamie Madill | 1048e43 | 2016-07-23 18:51:28 -0400 | [diff] [blame] | 68 | // Test 0: skip non-pow operators. |
| 69 | if (node->getOp() != EOpPow) |
| 70 | { |
| 71 | return true; |
| 72 | } |
| 73 | |
| 74 | const TIntermSequence *sequence = node->getSequence(); |
| 75 | ASSERT(sequence->size() == 2u); |
| 76 | const TIntermConstantUnion *constantNode = sequence->at(1)->getAsConstantUnion(); |
| 77 | |
| 78 | // Test 1: check for a single constant. |
| 79 | if (!constantNode || constantNode->getNominalSize() != 1) |
| 80 | { |
| 81 | return true; |
| 82 | } |
| 83 | |
| 84 | const TConstantUnion *constant = constantNode->getUnionArrayPointer(); |
| 85 | |
| 86 | TConstantUnion asFloat; |
| 87 | asFloat.cast(EbtFloat, *constant); |
| 88 | |
| 89 | float value = asFloat.getFConst(); |
| 90 | |
| 91 | // Test 2: value is in the problematic range. |
| 92 | if (value < -5.0f || value > 9.0f) |
| 93 | { |
| 94 | return true; |
| 95 | } |
| 96 | |
| 97 | // Test 3: value is integer or pretty close to an integer. |
Jamie Madill | 6c9503e | 2016-08-16 14:06:32 -0400 | [diff] [blame] | 98 | float absval = std::abs(value); |
| 99 | float frac = absval - std::round(absval); |
Jamie Madill | 1048e43 | 2016-07-23 18:51:28 -0400 | [diff] [blame] | 100 | if (frac > 0.0001f) |
| 101 | { |
| 102 | return true; |
| 103 | } |
| 104 | |
| 105 | // Test 4: skip -1, 0, and 1 |
| 106 | int exponent = static_cast<int>(value); |
| 107 | int n = std::abs(exponent); |
| 108 | if (n < 2) |
| 109 | { |
| 110 | return true; |
| 111 | } |
| 112 | |
| 113 | // Potential problem case detected, apply workaround. |
| 114 | nextTemporaryIndex(); |
| 115 | |
| 116 | TIntermTyped *lhs = sequence->at(0)->getAsTyped(); |
| 117 | ASSERT(lhs); |
| 118 | |
| 119 | TIntermAggregate *init = createTempInitDeclaration(lhs); |
| 120 | TIntermTyped *current = createTempSymbol(lhs->getType()); |
| 121 | |
| 122 | insertStatementInParentBlock(init); |
| 123 | |
| 124 | // Create a chain of n-1 multiples. |
| 125 | for (int i = 1; i < n; ++i) |
| 126 | { |
| 127 | TIntermBinary *mul = new TIntermBinary(EOpMul); |
| 128 | mul->setLeft(current); |
| 129 | mul->setRight(createTempSymbol(lhs->getType())); |
| 130 | mul->setType(node->getType()); |
| 131 | mul->setLine(node->getLine()); |
| 132 | current = mul; |
| 133 | } |
| 134 | |
| 135 | // For negative pow, compute the reciprocal of the positive pow. |
| 136 | if (exponent < 0) |
| 137 | { |
| 138 | TConstantUnion *oneVal = new TConstantUnion(); |
| 139 | oneVal->setFConst(1.0f); |
| 140 | TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType()); |
| 141 | TIntermBinary *div = new TIntermBinary(EOpDiv); |
| 142 | div->setLeft(oneNode); |
| 143 | div->setRight(current); |
| 144 | current = div; |
| 145 | } |
| 146 | |
Jamie Madill | 03d863c | 2016-07-27 18:15:53 -0400 | [diff] [blame] | 147 | queueReplacement(node, current, OriginalNode::IS_DROPPED); |
Jamie Madill | 5655b84 | 2016-08-02 11:00:07 -0400 | [diff] [blame] | 148 | mFound = true; |
| 149 | return false; |
Jamie Madill | 1048e43 | 2016-07-23 18:51:28 -0400 | [diff] [blame] | 150 | } |
| 151 | |
| 152 | } // anonymous namespace |
| 153 | |
| 154 | void ExpandIntegerPowExpressions(TIntermNode *root, unsigned int *tempIndex) |
| 155 | { |
| 156 | Traverser::Apply(root, tempIndex); |
| 157 | } |
| 158 | |
| 159 | } // namespace sh |