blob: 9ff2f12650fc3ead91348c80c26bfa008f3ea2ae [file] [log] [blame]
Jamie Madill1048e432016-07-23 18:51:28 -04001//
2// Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// Implementation of the integer pow expressions HLSL bug workaround.
7// See header for more info.
8
9#include "compiler/translator/ExpandIntegerPowExpressions.h"
10
11#include <cmath>
12#include <cstdlib>
13
14#include "compiler/translator/IntermNode.h"
15
16namespace sh
17{
18
19namespace
20{
21
22class Traverser : public TIntermTraverser
23{
24 public:
25 static void Apply(TIntermNode *root, unsigned int *tempIndex);
26
27 private:
28 Traverser();
29 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
Jamie Madill5655b842016-08-02 11:00:07 -040030 void nextIteration();
31
32 bool mFound = false;
Jamie Madill1048e432016-07-23 18:51:28 -040033};
34
35// static
36void Traverser::Apply(TIntermNode *root, unsigned int *tempIndex)
37{
38 Traverser traverser;
39 traverser.useTemporaryIndex(tempIndex);
Jamie Madill5655b842016-08-02 11:00:07 -040040 do
41 {
42 traverser.nextIteration();
43 root->traverse(&traverser);
44 if (traverser.mFound)
45 {
46 traverser.updateTree();
47 }
48 } while (traverser.mFound);
Jamie Madill1048e432016-07-23 18:51:28 -040049}
50
51Traverser::Traverser() : TIntermTraverser(true, false, false)
52{
53}
54
Jamie Madill5655b842016-08-02 11:00:07 -040055void Traverser::nextIteration()
56{
57 mFound = false;
58 nextTemporaryIndex();
59}
60
Jamie Madill1048e432016-07-23 18:51:28 -040061bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
62{
Jamie Madill5655b842016-08-02 11:00:07 -040063 if (mFound)
64 {
65 return false;
66 }
67
Jamie Madill1048e432016-07-23 18:51:28 -040068 // Test 0: skip non-pow operators.
69 if (node->getOp() != EOpPow)
70 {
71 return true;
72 }
73
74 const TIntermSequence *sequence = node->getSequence();
75 ASSERT(sequence->size() == 2u);
76 const TIntermConstantUnion *constantNode = sequence->at(1)->getAsConstantUnion();
77
78 // Test 1: check for a single constant.
79 if (!constantNode || constantNode->getNominalSize() != 1)
80 {
81 return true;
82 }
83
84 const TConstantUnion *constant = constantNode->getUnionArrayPointer();
85
86 TConstantUnion asFloat;
87 asFloat.cast(EbtFloat, *constant);
88
89 float value = asFloat.getFConst();
90
91 // Test 2: value is in the problematic range.
92 if (value < -5.0f || value > 9.0f)
93 {
94 return true;
95 }
96
97 // Test 3: value is integer or pretty close to an integer.
Jamie Madill6c9503e2016-08-16 14:06:32 -040098 float absval = std::abs(value);
99 float frac = absval - std::round(absval);
Jamie Madill1048e432016-07-23 18:51:28 -0400100 if (frac > 0.0001f)
101 {
102 return true;
103 }
104
105 // Test 4: skip -1, 0, and 1
106 int exponent = static_cast<int>(value);
107 int n = std::abs(exponent);
108 if (n < 2)
109 {
110 return true;
111 }
112
113 // Potential problem case detected, apply workaround.
114 nextTemporaryIndex();
115
116 TIntermTyped *lhs = sequence->at(0)->getAsTyped();
117 ASSERT(lhs);
118
Olli Etuaho13389b62016-10-16 11:48:18 +0100119 TIntermDeclaration *init = createTempInitDeclaration(lhs);
Jamie Madilld7b1ab52016-12-12 14:42:19 -0500120 TIntermTyped *current = createTempSymbol(lhs->getType());
Jamie Madill1048e432016-07-23 18:51:28 -0400121
122 insertStatementInParentBlock(init);
123
124 // Create a chain of n-1 multiples.
125 for (int i = 1; i < n; ++i)
126 {
Olli Etuaho3272a6d2016-08-29 17:54:50 +0300127 TIntermBinary *mul = new TIntermBinary(EOpMul, current, createTempSymbol(lhs->getType()));
Jamie Madill1048e432016-07-23 18:51:28 -0400128 mul->setLine(node->getLine());
129 current = mul;
130 }
131
132 // For negative pow, compute the reciprocal of the positive pow.
133 if (exponent < 0)
134 {
135 TConstantUnion *oneVal = new TConstantUnion();
136 oneVal->setFConst(1.0f);
137 TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType());
Olli Etuaho3272a6d2016-08-29 17:54:50 +0300138 TIntermBinary *div = new TIntermBinary(EOpDiv, oneNode, current);
Jamie Madilld7b1ab52016-12-12 14:42:19 -0500139 current = div;
Jamie Madill1048e432016-07-23 18:51:28 -0400140 }
141
Jamie Madill03d863c2016-07-27 18:15:53 -0400142 queueReplacement(node, current, OriginalNode::IS_DROPPED);
Jamie Madill5655b842016-08-02 11:00:07 -0400143 mFound = true;
144 return false;
Jamie Madill1048e432016-07-23 18:51:28 -0400145}
146
147} // anonymous namespace
148
149void ExpandIntegerPowExpressions(TIntermNode *root, unsigned int *tempIndex)
150{
151 Traverser::Apply(root, tempIndex);
152}
153
154} // namespace sh