blob: 162925a616f35d02b45040474aabbd2a2e747aac [file] [log] [blame]
Jamie Madill1048e432016-07-23 18:51:28 -04001//
2// Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// Implementation of the integer pow expressions HLSL bug workaround.
7// See header for more info.
8
9#include "compiler/translator/ExpandIntegerPowExpressions.h"
10
11#include <cmath>
12#include <cstdlib>
13
14#include "compiler/translator/IntermNode.h"
15
16namespace sh
17{
18
19namespace
20{
21
22class Traverser : public TIntermTraverser
23{
24 public:
25 static void Apply(TIntermNode *root, unsigned int *tempIndex);
26
27 private:
28 Traverser();
29 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
30};
31
32// static
33void Traverser::Apply(TIntermNode *root, unsigned int *tempIndex)
34{
35 Traverser traverser;
36 traverser.useTemporaryIndex(tempIndex);
37 root->traverse(&traverser);
38 traverser.updateTree();
39}
40
41Traverser::Traverser() : TIntermTraverser(true, false, false)
42{
43}
44
45bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
46{
47 // Test 0: skip non-pow operators.
48 if (node->getOp() != EOpPow)
49 {
50 return true;
51 }
52
53 const TIntermSequence *sequence = node->getSequence();
54 ASSERT(sequence->size() == 2u);
55 const TIntermConstantUnion *constantNode = sequence->at(1)->getAsConstantUnion();
56
57 // Test 1: check for a single constant.
58 if (!constantNode || constantNode->getNominalSize() != 1)
59 {
60 return true;
61 }
62
63 const TConstantUnion *constant = constantNode->getUnionArrayPointer();
64
65 TConstantUnion asFloat;
66 asFloat.cast(EbtFloat, *constant);
67
68 float value = asFloat.getFConst();
69
70 // Test 2: value is in the problematic range.
71 if (value < -5.0f || value > 9.0f)
72 {
73 return true;
74 }
75
76 // Test 3: value is integer or pretty close to an integer.
77 float frac = std::abs(value) - std::floor(std::abs(value));
78 if (frac > 0.0001f)
79 {
80 return true;
81 }
82
83 // Test 4: skip -1, 0, and 1
84 int exponent = static_cast<int>(value);
85 int n = std::abs(exponent);
86 if (n < 2)
87 {
88 return true;
89 }
90
91 // Potential problem case detected, apply workaround.
92 nextTemporaryIndex();
93
94 TIntermTyped *lhs = sequence->at(0)->getAsTyped();
95 ASSERT(lhs);
96
97 TIntermAggregate *init = createTempInitDeclaration(lhs);
98 TIntermTyped *current = createTempSymbol(lhs->getType());
99
100 insertStatementInParentBlock(init);
101
102 // Create a chain of n-1 multiples.
103 for (int i = 1; i < n; ++i)
104 {
105 TIntermBinary *mul = new TIntermBinary(EOpMul);
106 mul->setLeft(current);
107 mul->setRight(createTempSymbol(lhs->getType()));
108 mul->setType(node->getType());
109 mul->setLine(node->getLine());
110 current = mul;
111 }
112
113 // For negative pow, compute the reciprocal of the positive pow.
114 if (exponent < 0)
115 {
116 TConstantUnion *oneVal = new TConstantUnion();
117 oneVal->setFConst(1.0f);
118 TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType());
119 TIntermBinary *div = new TIntermBinary(EOpDiv);
120 div->setLeft(oneNode);
121 div->setRight(current);
122 current = div;
123 }
124
125 replace(node, current);
126 return true;
127}
128
129} // anonymous namespace
130
131void ExpandIntegerPowExpressions(TIntermNode *root, unsigned int *tempIndex)
132{
133 Traverser::Apply(root, tempIndex);
134}
135
136} // namespace sh