blob: 308073023953581c8173cfd86a20704d5b018fcc [file] [log] [blame]
Jamie Madill1048e432016-07-23 18:51:28 -04001//
2// Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// Implementation of the integer pow expressions HLSL bug workaround.
7// See header for more info.
8
9#include "compiler/translator/ExpandIntegerPowExpressions.h"
10
11#include <cmath>
12#include <cstdlib>
13
Olli Etuahocccf2b02017-07-05 14:50:54 +030014#include "compiler/translator/IntermTraverse.h"
Jamie Madill1048e432016-07-23 18:51:28 -040015
16namespace sh
17{
18
19namespace
20{
21
22class Traverser : public TIntermTraverser
23{
24 public:
Olli Etuahoa5e693a2017-07-13 16:07:26 +030025 static void Apply(TIntermNode *root, TSymbolTable *symbolTable);
Jamie Madill1048e432016-07-23 18:51:28 -040026
27 private:
Olli Etuahoa5e693a2017-07-13 16:07:26 +030028 Traverser(TSymbolTable *symbolTable);
Jamie Madill1048e432016-07-23 18:51:28 -040029 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
Jamie Madill5655b842016-08-02 11:00:07 -040030 void nextIteration();
31
32 bool mFound = false;
Jamie Madill1048e432016-07-23 18:51:28 -040033};
34
35// static
Olli Etuahoa5e693a2017-07-13 16:07:26 +030036void Traverser::Apply(TIntermNode *root, TSymbolTable *symbolTable)
Jamie Madill1048e432016-07-23 18:51:28 -040037{
Olli Etuahoa5e693a2017-07-13 16:07:26 +030038 Traverser traverser(symbolTable);
Jamie Madill5655b842016-08-02 11:00:07 -040039 do
40 {
41 traverser.nextIteration();
42 root->traverse(&traverser);
43 if (traverser.mFound)
44 {
45 traverser.updateTree();
46 }
47 } while (traverser.mFound);
Jamie Madill1048e432016-07-23 18:51:28 -040048}
49
Olli Etuahoa5e693a2017-07-13 16:07:26 +030050Traverser::Traverser(TSymbolTable *symbolTable) : TIntermTraverser(true, false, false, symbolTable)
Jamie Madill1048e432016-07-23 18:51:28 -040051{
52}
53
Jamie Madill5655b842016-08-02 11:00:07 -040054void Traverser::nextIteration()
55{
56 mFound = false;
Jamie Madill5655b842016-08-02 11:00:07 -040057}
58
Jamie Madill1048e432016-07-23 18:51:28 -040059bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
60{
Jamie Madill5655b842016-08-02 11:00:07 -040061 if (mFound)
62 {
63 return false;
64 }
65
Jamie Madill1048e432016-07-23 18:51:28 -040066 // Test 0: skip non-pow operators.
67 if (node->getOp() != EOpPow)
68 {
69 return true;
70 }
71
72 const TIntermSequence *sequence = node->getSequence();
73 ASSERT(sequence->size() == 2u);
Olli Etuaho629a6442017-12-11 10:55:43 +020074 const TIntermConstantUnion *constantExponent = sequence->at(1)->getAsConstantUnion();
Jamie Madill1048e432016-07-23 18:51:28 -040075
76 // Test 1: check for a single constant.
Olli Etuaho629a6442017-12-11 10:55:43 +020077 if (!constantExponent || constantExponent->getNominalSize() != 1)
Jamie Madill1048e432016-07-23 18:51:28 -040078 {
79 return true;
80 }
81
Olli Etuaho629a6442017-12-11 10:55:43 +020082 ASSERT(constantExponent->getBasicType() == EbtFloat);
83 float exponentValue = constantExponent->getUnionArrayPointer()->getFConst();
Jamie Madill1048e432016-07-23 18:51:28 -040084
Olli Etuaho629a6442017-12-11 10:55:43 +020085 // Test 2: exponentValue is in the problematic range.
86 if (exponentValue < -5.0f || exponentValue > 9.0f)
Jamie Madill1048e432016-07-23 18:51:28 -040087 {
88 return true;
89 }
90
Olli Etuaho629a6442017-12-11 10:55:43 +020091 // Test 3: exponentValue is integer or pretty close to an integer.
92 if (std::abs(exponentValue - std::round(exponentValue)) > 0.0001f)
Jamie Madill1048e432016-07-23 18:51:28 -040093 {
94 return true;
95 }
96
97 // Test 4: skip -1, 0, and 1
Olli Etuaho629a6442017-12-11 10:55:43 +020098 int exponent = static_cast<int>(std::round(exponentValue));
Jamie Madill1048e432016-07-23 18:51:28 -040099 int n = std::abs(exponent);
100 if (n < 2)
101 {
102 return true;
103 }
104
105 // Potential problem case detected, apply workaround.
Olli Etuaho4dd06d52017-07-05 12:41:06 +0300106 nextTemporaryId();
Jamie Madill1048e432016-07-23 18:51:28 -0400107
108 TIntermTyped *lhs = sequence->at(0)->getAsTyped();
109 ASSERT(lhs);
110
Olli Etuaho13389b62016-10-16 11:48:18 +0100111 TIntermDeclaration *init = createTempInitDeclaration(lhs);
Jamie Madilld7b1ab52016-12-12 14:42:19 -0500112 TIntermTyped *current = createTempSymbol(lhs->getType());
Jamie Madill1048e432016-07-23 18:51:28 -0400113
114 insertStatementInParentBlock(init);
115
116 // Create a chain of n-1 multiples.
117 for (int i = 1; i < n; ++i)
118 {
Olli Etuaho3272a6d2016-08-29 17:54:50 +0300119 TIntermBinary *mul = new TIntermBinary(EOpMul, current, createTempSymbol(lhs->getType()));
Jamie Madill1048e432016-07-23 18:51:28 -0400120 mul->setLine(node->getLine());
121 current = mul;
122 }
123
124 // For negative pow, compute the reciprocal of the positive pow.
125 if (exponent < 0)
126 {
127 TConstantUnion *oneVal = new TConstantUnion();
128 oneVal->setFConst(1.0f);
129 TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType());
Olli Etuaho3272a6d2016-08-29 17:54:50 +0300130 TIntermBinary *div = new TIntermBinary(EOpDiv, oneNode, current);
Jamie Madilld7b1ab52016-12-12 14:42:19 -0500131 current = div;
Jamie Madill1048e432016-07-23 18:51:28 -0400132 }
133
Olli Etuahoea39a222017-07-06 12:47:59 +0300134 queueReplacement(current, OriginalNode::IS_DROPPED);
Jamie Madill5655b842016-08-02 11:00:07 -0400135 mFound = true;
136 return false;
Jamie Madill1048e432016-07-23 18:51:28 -0400137}
138
139} // anonymous namespace
140
Olli Etuahoa5e693a2017-07-13 16:07:26 +0300141void ExpandIntegerPowExpressions(TIntermNode *root, TSymbolTable *symbolTable)
Jamie Madill1048e432016-07-23 18:51:28 -0400142{
Olli Etuahoa5e693a2017-07-13 16:07:26 +0300143 Traverser::Apply(root, symbolTable);
Jamie Madill1048e432016-07-23 18:51:28 -0400144}
145
146} // namespace sh