// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>
#include <stddef.h>

#include <immintrin.h>

#include <xnnpack/math-stubs.h>

void xnn_math_f32_exp__avx2_p5(
    size_t n,
    const float* input,
    float* output)
{
  assert(n % (8 * sizeof(float)) == 0);
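
  // magic_bias = 1.5 * 2**23. Adding it to x * log2(e) (whose magnitude stays well
  // below 2**22 for inputs that do not overflow) forces the sum to round to an
  // integer held in the low mantissa bits; subtracting the bias later recovers
  // round(x * log2(e)) as a float.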
  const __m256 magic_bias = _mm256_set1_ps(0x1.800000p+23f);
  // The smallest x for which expf(x) is non-zero.
  const __m256 zero_cutoff = _mm256_set1_ps(-0x1.9FE368p+6f);
  // The largest x for which expf(x) is finite.
  const __m256 inf_cutoff = _mm256_set1_ps(0x1.62E42Ep+6f);
  const __m256 log2e = _mm256_set1_ps(0x1.715476p+0f);
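  // -ln(2) split into a "high" and a "low" part (Cody-Waite style), so that
  // x - t * ln(2) can be computed in two FMA steps with little rounding error.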
  const __m256 minus_ln2_hi = _mm256_set1_ps(-0x1.62E43p-1f);
  const __m256 minus_ln2_lo = _mm256_set1_ps(0x1.05C61p-29f);
  const __m256 plus_inf = _mm256_set1_ps(INFINITY);
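
  // Coefficients of a degree-5 polynomial approximation of exp(r) on roughly
  // [-ln(2)/2, ln(2)/2]:
  //   exp(r) ~ 1 + r * (c1 + r * (c2 + r * (c3 + r * (c4 + r * c5))))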
  const __m256 c1 = _mm256_set1_ps(0x1.FFFFF6p-1f);
  const __m256 c2 = _mm256_set1_ps(0x1.FFFDC6p-2f);
  const __m256 c3 = _mm256_set1_ps(0x1.555A80p-3f);
  const __m256 c4 = _mm256_set1_ps(0x1.573A1Ap-5f);
  const __m256 c5 = _mm256_set1_ps(0x1.0F9F9Cp-7f);
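
  // min_exponent = -126 << 23 and max_exponent = 127 << 23, the extreme exponent
  // fields of normalized single-precision floats; default_exponent is the bit
  // pattern of 1.0f. They are used below to build the scale factor 2**n as a
  // product of two representable factors.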
  const __m256i min_exponent = _mm256_set1_epi32(0xC1000000);
  const __m256i max_exponent = _mm256_set1_epi32(0x3F800000);
  const __m256i default_exponent = max_exponent;

  for (; n != 0; n -= 8 * sizeof(float)) {
    const __m256 x = _mm256_loadu_ps(input);
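
    // t := x * log2(e) + magic_bias; the low mantissa bits of t now hold
    // n = round(x * log2(e)), so that exp(x) = 2**n * exp(x - n * ln(2)).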
    __m256 t = _mm256_fmadd_ps(x, log2e, magic_bias);
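    // Shift n into the exponent field (eo = n << 23), then split it into a part
    // en clamped to [-126 << 23, 127 << 23] and the remainder eo - en. Adding the
    // bits of 1.0f to each part yields scale factors sn and so with sn * so = 2**n;
    // the split keeps both factors representable even when 2**n alone would
    // overflow or underflow.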
    __m256i eo = _mm256_slli_epi32(_mm256_castps_si256(t), 23);
    __m256i en = _mm256_max_epi32(eo, min_exponent);
    en = _mm256_min_epi32(en, max_exponent);
    eo = _mm256_sub_epi32(eo, en);
    const __m256 sn = _mm256_castsi256_ps(_mm256_add_epi32(en, default_exponent));
    const __m256 so = _mm256_castsi256_ps(_mm256_add_epi32(eo, default_exponent));
    t = _mm256_sub_ps(t, magic_bias);
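
    // Cody-Waite range reduction: rx = x - t * ln(2), computed in two FMA steps
    // with the split -ln(2) constants; |rx| stays within about [-ln(2)/2, ln(2)/2].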
    __m256 rx = _mm256_fmadd_ps(t, minus_ln2_hi, x);
    rx = _mm256_fmadd_ps(t, minus_ln2_lo, rx);

    // f = so * sn * (1 + rx * (c1 + rx * (c2 + rx * (c3 + rx * (c4 + rx * c5)))))
    //   = sn * (so + (rx * so) * (c1 + rx * (c2 + rx * (c3 + rx * (c4 + rx * c5)))))
    __m256 rf = _mm256_fmadd_ps(c5, rx, c4);
    rf = _mm256_fmadd_ps(rf, rx, c3);
    rf = _mm256_fmadd_ps(rf, rx, c2);
    rf = _mm256_fmadd_ps(rf, rx, c1);
    rx = _mm256_mul_ps(rx, so);
    __m256 f = _mm256_mul_ps(sn, _mm256_fmadd_ps(rx, rf, so));

    // For inputs below zero cutoff, replace output with +0.0f.
    // Note that for NaN inputs, comparison result is false, and outputs are left unchanged.
    f = _mm256_andnot_ps(_mm256_cmp_ps(x, zero_cutoff, _CMP_LT_OS), f);
    // For inputs above inf cutoff, replace output with +inf.
    // Note that for NaN inputs, comparison result is false, and outputs are left unchanged.
    f = _mm256_blendv_ps(f, plus_inf, _mm256_cmp_ps(x, inf_cutoff, _CMP_GT_OS));
    _mm256_storeu_ps(output, f);

    input += 8;
    output += 8;
  }
}