blob: d95293b64bad23e14cfee405be50d5f3e50c7384 [file] [log] [blame]
Marat Dukhan6adff4e2019-10-14 18:32:07 -07001// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#include <assert.h>
7#include <math.h>
8
9#include <immintrin.h>
10
11#include <xnnpack/math-stubs.h>
12
13
// Evaluates exp(x) for a buffer of single-precision floats, 8 elements per iteration, using AVX2 integer
// operations and fused multiply-add.
//
//   n      - size of the input in BYTES; must be a multiple of 8 * sizeof(float) (asserted).
//   input  - source elements, read with unaligned loads.
//   output - destination elements, written with unaligned stores.
//
// Method: n := round(x * log2(e)), t := x - n * ln(2), exp(x) = 2**n * p(t), where p is a degree-5 polynomial.
// 2**n is represented as a product of two powers of two so that results outside the normal exponent range
// can still be reconstructed.
void xnn_math_f32_exp__avx2_p5(
    size_t n,
    const float* input,
    float* output)
{
  assert(n % (8 * sizeof(float)) == 0);

  // Adding this constant (1.5 * 2**23) to a small-magnitude float leaves the rounded integer value in the
  // low mantissa bits of the sum.
  const __m256 vmagic_bias = _mm256_set1_ps(0x1.800000p+23f);
  const __m256 vlog2e = _mm256_set1_ps(0x1.715476p+0f);
  // -ln(2) split into a high and a low part; the two parts are applied in separate FMA steps.
  const __m256 vminus_ln2_hi = _mm256_set1_ps(-0x1.62E43p-1f);
  const __m256 vminus_ln2_lo = _mm256_set1_ps(0x1.05C61p-29f);

  // Polynomial coefficients: p(t) = c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))).
  const __m256 vc5 = _mm256_set1_ps(0x1.0F9F9Cp-7f);
  const __m256 vc4 = _mm256_set1_ps(0x1.573A1Ap-5f);
  const __m256 vc3 = _mm256_set1_ps(0x1.555A80p-3f);
  const __m256 vc2 = _mm256_set1_ps(0x1.FFFDC6p-2f);
  const __m256 vc1 = _mm256_set1_ps(0x1.FFFFF6p-1f);

  // Clamping bounds for the "normal" scale, expressed as n << 23: -126 << 23 and 127 << 23.
  const __m256i vmin_exponent = _mm256_set1_epi32(0xC1000000);
  const __m256i vmax_exponent = _mm256_set1_epi32(0x3F800000);
  // Bit pattern of 1.0f; adding (k << 23) to it produces the bit pattern of 2**k.
  const __m256i vexponent_bias = vmax_exponent;

  // The smallest x for which expf(x) is non-zero.
  const __m256 vzero_cutoff = _mm256_set1_ps(-0x1.9FE368p+6f);
  // The largest x for which expf(x) is finite.
  const __m256 vinf_cutoff = _mm256_set1_ps(0x1.62E42Ep+6f);
  const __m256 vplus_inf = _mm256_set1_ps(INFINITY);

  while (n != 0) {
    const __m256 vx = _mm256_loadu_ps(input);
    input += 8;

    // vbiased_n holds round(x * log2(e)) in its low mantissa bits; vn is the same value as a plain float.
    const __m256 vbiased_n = _mm256_fmadd_ps(vx, vlog2e, vmagic_bias);
    const __m256 vn = _mm256_sub_ps(vbiased_n, vmagic_bias);

    // Shift the low 9 bits of the biased value into the sign/exponent field, giving n << 23 as a signed
    // integer. Split it into a clamped part (ven) and the remainder (veo), then turn each into a power of
    // two: vsn = 2**clamp(n), vso = 2**(n - clamp(n)), so that vsn * vso stands for 2**n.
    const __m256i ve = _mm256_slli_epi32(_mm256_castps_si256(vbiased_n), 23);
    const __m256i ven = _mm256_min_epi32(_mm256_max_epi32(ve, vmin_exponent), vmax_exponent);
    const __m256i veo = _mm256_sub_epi32(ve, ven);
    const __m256 vsn = _mm256_castsi256_ps(_mm256_add_epi32(ven, vexponent_bias));
    const __m256 vso = _mm256_castsi256_ps(_mm256_add_epi32(veo, vexponent_bias));

    // Reduced argument t := x - n * ln(2), subtracting the high and low parts of ln(2) separately.
    __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2_hi, vx);
    vt = _mm256_fmadd_ps(vn, vminus_ln2_lo, vt);

    // Horner evaluation of p(t) = c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))).
    __m256 vp = _mm256_fmadd_ps(vc5, vt, vc4);
    vp = _mm256_fmadd_ps(vp, vt, vc3);
    vp = _mm256_fmadd_ps(vp, vt, vc2);
    vp = _mm256_fmadd_ps(vp, vt, vc1);

    // Reconstruction:
    //   f = so * sn * (1 + t * p)
    //     = sn * (so + (t * so) * p)
    const __m256 vts = _mm256_mul_ps(vt, vso);
    __m256 vf = _mm256_mul_ps(vsn, _mm256_fmadd_ps(vts, vp, vso));

    // Inputs below the zero cutoff produce +0.0f. The comparison is false for NaN inputs, so those lanes
    // keep the computed value.
    const __m256 vunderflow = _mm256_cmp_ps(vx, vzero_cutoff, _CMP_LT_OS);
    vf = _mm256_andnot_ps(vunderflow, vf);
    // Inputs above the inf cutoff produce +inf. Again the comparison is false for NaN inputs.
    const __m256 voverflow = _mm256_cmp_ps(vx, vinf_cutoff, _CMP_GT_OS);
    vf = _mm256_blendv_ps(vf, vplus_inf, voverflow);

    _mm256_storeu_ps(output, vf);
    output += 8;

    n -= 8 * sizeof(float);
  }
}