blob: b219971a5662bf2de6ecedf3b7be1d29d138e167 [file] [log] [blame]
//
// Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// Implementation of the integer pow expressions HLSL bug workaround.
// See header for more info.
8
9#include "compiler/translator/ExpandIntegerPowExpressions.h"
10
11#include <cmath>
12#include <cstdlib>
13
Olli Etuahoc26214d2018-03-16 10:43:11 +020014#include "compiler/translator/tree_util/IntermNode_util.h"
15#include "compiler/translator/tree_util/IntermTraverse.h"
Jamie Madill1048e432016-07-23 18:51:28 -040016
17namespace sh
18{
19
20namespace
21{
22
23class Traverser : public TIntermTraverser
24{
25 public:
Olli Etuahoa5e693a2017-07-13 16:07:26 +030026 static void Apply(TIntermNode *root, TSymbolTable *symbolTable);
Jamie Madill1048e432016-07-23 18:51:28 -040027
28 private:
Olli Etuahoa5e693a2017-07-13 16:07:26 +030029 Traverser(TSymbolTable *symbolTable);
Jamie Madill1048e432016-07-23 18:51:28 -040030 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
Jamie Madill5655b842016-08-02 11:00:07 -040031 void nextIteration();
32
33 bool mFound = false;
Jamie Madill1048e432016-07-23 18:51:28 -040034};
35
36// static
Olli Etuahoa5e693a2017-07-13 16:07:26 +030037void Traverser::Apply(TIntermNode *root, TSymbolTable *symbolTable)
Jamie Madill1048e432016-07-23 18:51:28 -040038{
Olli Etuahoa5e693a2017-07-13 16:07:26 +030039 Traverser traverser(symbolTable);
Jamie Madill5655b842016-08-02 11:00:07 -040040 do
41 {
42 traverser.nextIteration();
43 root->traverse(&traverser);
44 if (traverser.mFound)
45 {
46 traverser.updateTree();
47 }
48 } while (traverser.mFound);
Jamie Madill1048e432016-07-23 18:51:28 -040049}
50
Olli Etuahoa5e693a2017-07-13 16:07:26 +030051Traverser::Traverser(TSymbolTable *symbolTable) : TIntermTraverser(true, false, false, symbolTable)
Jamie Madill1048e432016-07-23 18:51:28 -040052{
53}
54
Jamie Madill5655b842016-08-02 11:00:07 -040055void Traverser::nextIteration()
56{
57 mFound = false;
Jamie Madill5655b842016-08-02 11:00:07 -040058}
59
Jamie Madill1048e432016-07-23 18:51:28 -040060bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
61{
Jamie Madill5655b842016-08-02 11:00:07 -040062 if (mFound)
63 {
64 return false;
65 }
66
Jamie Madill1048e432016-07-23 18:51:28 -040067 // Test 0: skip non-pow operators.
68 if (node->getOp() != EOpPow)
69 {
70 return true;
71 }
72
73 const TIntermSequence *sequence = node->getSequence();
74 ASSERT(sequence->size() == 2u);
Olli Etuaho629a6442017-12-11 10:55:43 +020075 const TIntermConstantUnion *constantExponent = sequence->at(1)->getAsConstantUnion();
Jamie Madill1048e432016-07-23 18:51:28 -040076
77 // Test 1: check for a single constant.
Olli Etuaho629a6442017-12-11 10:55:43 +020078 if (!constantExponent || constantExponent->getNominalSize() != 1)
Jamie Madill1048e432016-07-23 18:51:28 -040079 {
80 return true;
81 }
82
Olli Etuaho629a6442017-12-11 10:55:43 +020083 ASSERT(constantExponent->getBasicType() == EbtFloat);
Olli Etuahoea22b7a2018-01-04 17:09:11 +020084 float exponentValue = constantExponent->getConstantValue()->getFConst();
Jamie Madill1048e432016-07-23 18:51:28 -040085
Olli Etuaho629a6442017-12-11 10:55:43 +020086 // Test 2: exponentValue is in the problematic range.
87 if (exponentValue < -5.0f || exponentValue > 9.0f)
Jamie Madill1048e432016-07-23 18:51:28 -040088 {
89 return true;
90 }
91
Olli Etuaho629a6442017-12-11 10:55:43 +020092 // Test 3: exponentValue is integer or pretty close to an integer.
93 if (std::abs(exponentValue - std::round(exponentValue)) > 0.0001f)
Jamie Madill1048e432016-07-23 18:51:28 -040094 {
95 return true;
96 }
97
98 // Test 4: skip -1, 0, and 1
Olli Etuaho629a6442017-12-11 10:55:43 +020099 int exponent = static_cast<int>(std::round(exponentValue));
Jamie Madill1048e432016-07-23 18:51:28 -0400100 int n = std::abs(exponent);
101 if (n < 2)
102 {
103 return true;
104 }
105
106 // Potential problem case detected, apply workaround.
Jamie Madill1048e432016-07-23 18:51:28 -0400107
108 TIntermTyped *lhs = sequence->at(0)->getAsTyped();
109 ASSERT(lhs);
110
Olli Etuaho195be942017-12-04 23:40:14 +0200111 TIntermDeclaration *lhsVariableDeclaration = nullptr;
112 TVariable *lhsVariable =
113 DeclareTempVariable(mSymbolTable, lhs, EvqTemporary, &lhsVariableDeclaration);
114 insertStatementInParentBlock(lhsVariableDeclaration);
Jamie Madill1048e432016-07-23 18:51:28 -0400115
116 // Create a chain of n-1 multiples.
Olli Etuaho195be942017-12-04 23:40:14 +0200117 TIntermTyped *current = CreateTempSymbolNode(lhsVariable);
Jamie Madill1048e432016-07-23 18:51:28 -0400118 for (int i = 1; i < n; ++i)
119 {
Olli Etuaho195be942017-12-04 23:40:14 +0200120 TIntermBinary *mul = new TIntermBinary(EOpMul, current, CreateTempSymbolNode(lhsVariable));
Jamie Madill1048e432016-07-23 18:51:28 -0400121 mul->setLine(node->getLine());
122 current = mul;
123 }
124
125 // For negative pow, compute the reciprocal of the positive pow.
126 if (exponent < 0)
127 {
128 TConstantUnion *oneVal = new TConstantUnion();
129 oneVal->setFConst(1.0f);
130 TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType());
Olli Etuaho3272a6d2016-08-29 17:54:50 +0300131 TIntermBinary *div = new TIntermBinary(EOpDiv, oneNode, current);
Jamie Madilld7b1ab52016-12-12 14:42:19 -0500132 current = div;
Jamie Madill1048e432016-07-23 18:51:28 -0400133 }
134
Olli Etuahoea39a222017-07-06 12:47:59 +0300135 queueReplacement(current, OriginalNode::IS_DROPPED);
Jamie Madill5655b842016-08-02 11:00:07 -0400136 mFound = true;
137 return false;
Jamie Madill1048e432016-07-23 18:51:28 -0400138}
139
140} // anonymous namespace
141
Olli Etuahoa5e693a2017-07-13 16:07:26 +0300142void ExpandIntegerPowExpressions(TIntermNode *root, TSymbolTable *symbolTable)
Jamie Madill1048e432016-07-23 18:51:28 -0400143{
Olli Etuahoa5e693a2017-07-13 16:07:26 +0300144 Traverser::Apply(root, symbolTable);
Jamie Madill1048e432016-07-23 18:51:28 -0400145}
146
147} // namespace sh