blob: 7e66577507a367a179929e5129a328ee21e14166 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/common.h>
#include <xnnpack/vunary.h>
// Template for an f32 sigmoid micro-kernel built on NEON intrinsics.
//
// Each element is computed as 0.5 + P(x) / Q(x), where P is an odd polynomial
// of degree 9 (coefficients valpha_*) and Q is an even polynomial of degree 10
// (coefficients vbeta_*), after x has been clamped to [-18, 18]. The division
// is replaced by a reciprocal estimate (vrecpeq_f32) refined with a single
// Newton-Raphson step (vrecpsq_f32).
//
// BATCH_TILE is a template parameter: the main loop consumes BATCH_TILE floats
// per iteration, split into 4-lane groups whose variable names are suffixed
// with ${ABC[N:N+4]}.
//
// Parameters:
//   n      - amount of data to process, in bytes; asserted to be a multiple of
//            sizeof(float).
//   x      - pointer to the input floats.
//   y      - pointer to the output floats.
//   params - not referenced anywhere in this function body.
void xnn_f32_sigmoid_ukernel__neon_frac_p9_p10_nr1recps_x${BATCH_TILE}(
size_t n,
const float* x,
float* y,
const void* params)
{
assert(n % sizeof(float) == 0);
// Constant term added to P(x) / Q(x) in the final step.
const float32x4_t vhalf = vmovq_n_f32(0.5f);
// The coefficients of the numerator polynomial (odd).
const float32x4_t valpha_1 = vmovq_n_f32(2.48287947061529e-01);
const float32x4_t valpha_3 = vmovq_n_f32(8.51377133304701e-03);
const float32x4_t valpha_5 = vmovq_n_f32(6.08574864600143e-05);
const float32x4_t valpha_7 = vmovq_n_f32(1.15627324459942e-07);
const float32x4_t valpha_9 = vmovq_n_f32(4.37031012579801e-11);
// The coefficients of the denominator polynomial (even).
const float32x4_t vbeta_0 = vmovq_n_f32(9.93151921023180e-01);
const float32x4_t vbeta_2 = vmovq_n_f32(1.16817656904453e-01);
const float32x4_t vbeta_4 = vmovq_n_f32(1.70198817374094e-03);
const float32x4_t vbeta_6 = vmovq_n_f32(6.29106785017040e-06);
const float32x4_t vbeta_8 = vmovq_n_f32(5.76102136993427e-09);
const float32x4_t vbeta_10 = vmovq_n_f32(6.10247389755681e-13);
// Clamp bounds for the input. Sigmoid ~saturates outside of this range anyway.
const float32x4_t vsigmoid_maxinput = vdupq_n_f32(18.f);
const float32x4_t vsigmoid_mininput = vdupq_n_f32(-18.f);
// Main loop: BATCH_TILE floats per iteration. Every $for below emits its
// statement(s) once per 4-lane group, so each stage of the computation is
// issued for all groups before the next stage starts.
for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
$for N in range(0, BATCH_TILE, 4):
float32x4_t vn${ABC[N:N+4]} = vld1q_f32(x); x += 4;
// restrict range to avoid overflow, output saturates outside this anyway
$for N in range(0, BATCH_TILE, 4):
vn${ABC[N:N+4]} = vminq_f32(vn${ABC[N:N+4]}, vsigmoid_maxinput);
vn${ABC[N:N+4]} = vmaxq_f32(vn${ABC[N:N+4]}, vsigmoid_mininput);
// square the input
$for N in range(0, BATCH_TILE, 4):
const float32x4_t vn${ABC[N:N+4]}_sq = vmulq_f32(vn${ABC[N:N+4]}, vn${ABC[N:N+4]});
// Evaluate numerator polynomial: Horner's scheme in x^2 from alpha_9 down to
// alpha_1, followed by one multiplication by x to make the polynomial odd.
$for N in range(0, BATCH_TILE, 4):
float32x4_t vnum${ABC[N:N+4]} = vmlaq_f32(valpha_7, vn${ABC[N:N+4]}_sq, valpha_9);
$for N in range(0, BATCH_TILE, 4):
vnum${ABC[N:N+4]} = vmlaq_f32(valpha_5, vn${ABC[N:N+4]}_sq, vnum${ABC[N:N+4]});
$for N in range(0, BATCH_TILE, 4):
vnum${ABC[N:N+4]} = vmlaq_f32(valpha_3, vn${ABC[N:N+4]}_sq, vnum${ABC[N:N+4]});
$for N in range(0, BATCH_TILE, 4):
vnum${ABC[N:N+4]} = vmlaq_f32(valpha_1, vn${ABC[N:N+4]}_sq, vnum${ABC[N:N+4]});
$for N in range(0, BATCH_TILE, 4):
vnum${ABC[N:N+4]} = vmulq_f32(vn${ABC[N:N+4]}, vnum${ABC[N:N+4]});
// Evaluate denominator polynomial: Horner's scheme in x^2 from beta_10 down
// to beta_0.
$for N in range(0, BATCH_TILE, 4):
float32x4_t vdenom${ABC[N:N+4]} = vmlaq_f32(vbeta_8, vn${ABC[N:N+4]}_sq, vbeta_10);
$for N in range(0, BATCH_TILE, 4):
vdenom${ABC[N:N+4]} = vmlaq_f32(vbeta_6, vn${ABC[N:N+4]}_sq, vdenom${ABC[N:N+4]});
$for N in range(0, BATCH_TILE, 4):
vdenom${ABC[N:N+4]} = vmlaq_f32(vbeta_4, vn${ABC[N:N+4]}_sq, vdenom${ABC[N:N+4]});
$for N in range(0, BATCH_TILE, 4):
vdenom${ABC[N:N+4]} = vmlaq_f32(vbeta_2, vn${ABC[N:N+4]}_sq, vdenom${ABC[N:N+4]});
$for N in range(0, BATCH_TILE, 4):
vdenom${ABC[N:N+4]} = vmlaq_f32(vbeta_0, vn${ABC[N:N+4]}_sq, vdenom${ABC[N:N+4]});
// Do division 1. / denom: start from the hardware reciprocal estimate.
$for N in range(0, BATCH_TILE, 4):
float32x4_t vrecp${ABC[N:N+4]} = vrecpeq_f32(vdenom${ABC[N:N+4]});
// One NR iteration: multiply the estimate by the vrecpsq_f32 correction term.
$for N in range(0, BATCH_TILE, 4):
vrecp${ABC[N:N+4]} = vmulq_f32(vrecp${ABC[N:N+4]}, vrecpsq_f32(vrecp${ABC[N:N+4]}, vdenom${ABC[N:N+4]}));
// .5 + num * (1. / denom)
$for N in range(0, BATCH_TILE, 4):
const float32x4_t vsigmoid${ABC[N:N+4]} = vmlaq_f32(vhalf, vnum${ABC[N:N+4]}, vrecp${ABC[N:N+4]});
$for N in range(0, BATCH_TILE, 4):
vst1q_f32(y, vsigmoid${ABC[N:N+4]}); y += 4;
}
// Remainder loop: same computation as above, one 4-float vector at a time,
// for whatever the BATCH_TILE-sized loop left over.
for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
float32x4_t vn0123 = vld1q_f32(x); x += 4;
// Clamp to [-18, 18], then square the input.
vn0123 = vminq_f32(vn0123, vsigmoid_maxinput);
vn0123 = vmaxq_f32(vn0123, vsigmoid_mininput);
const float32x4_t vn0123_sq = vmulq_f32(vn0123, vn0123);
// Evaluate numerator polynomial
float32x4_t vnum0123 = vmlaq_f32(valpha_7, vn0123_sq, valpha_9);
vnum0123 = vmlaq_f32(valpha_5, vn0123_sq, vnum0123);
vnum0123 = vmlaq_f32(valpha_3, vn0123_sq, vnum0123);
vnum0123 = vmlaq_f32(valpha_1, vn0123_sq, vnum0123);
vnum0123 = vmulq_f32(vn0123, vnum0123);
// Evaluate denominator polynomial
float32x4_t vdenom0123 = vmlaq_f32(vbeta_8, vn0123_sq, vbeta_10);
vdenom0123 = vmlaq_f32(vbeta_6, vn0123_sq, vdenom0123);
vdenom0123 = vmlaq_f32(vbeta_4, vn0123_sq, vdenom0123);
vdenom0123 = vmlaq_f32(vbeta_2, vn0123_sq, vdenom0123);
vdenom0123 = vmlaq_f32(vbeta_0, vn0123_sq, vdenom0123);
// Do division, one NR iteration
float32x4_t vrecp0123 = vrecpeq_f32(vdenom0123);
vrecp0123 = vmulq_f32(vrecp0123, vrecpsq_f32(vrecp0123, vdenom0123));
// .5 + num * (1. / denom)
const float32x4_t vsigmoid0123 = vmlaq_f32(vhalf, vnum0123, vrecp0123);
vst1q_f32(y, vsigmoid0123); y += 4;
}
// Tail: fewer than 4 floats remain. A full vector is computed and only the
// valid lanes are stored.
if XNN_UNLIKELY(n != 0) {
// NOTE(review): this loads 4 floats although only 1-3 remain, i.e. it reads
// past the last valid input element. Presumably callers guarantee readable
// padding after x - confirm.
float32x4_t vn0123 = vld1q_f32(x);
vn0123 = vminq_f32(vn0123, vsigmoid_maxinput);
vn0123 = vmaxq_f32(vn0123, vsigmoid_mininput);
const float32x4_t vn0123_sq = vmulq_f32(vn0123, vn0123);
// Evaluate numerator polynomial
float32x4_t vnum0123 = vmlaq_f32(valpha_7, vn0123_sq, valpha_9);
vnum0123 = vmlaq_f32(valpha_5, vn0123_sq, vnum0123);
vnum0123 = vmlaq_f32(valpha_3, vn0123_sq, vnum0123);
vnum0123 = vmlaq_f32(valpha_1, vn0123_sq, vnum0123);
vnum0123 = vmulq_f32(vn0123, vnum0123);
// Evaluate denominator polynomial
float32x4_t vdenom0123 = vmlaq_f32(vbeta_8, vn0123_sq, vbeta_10);
vdenom0123 = vmlaq_f32(vbeta_6, vn0123_sq, vdenom0123);
vdenom0123 = vmlaq_f32(vbeta_4, vn0123_sq, vdenom0123);
vdenom0123 = vmlaq_f32(vbeta_2, vn0123_sq, vdenom0123);
vdenom0123 = vmlaq_f32(vbeta_0, vn0123_sq, vdenom0123);
// Do division, one NR iteration
float32x4_t vrecp0123 = vrecpeq_f32(vdenom0123);
vrecp0123 = vmulq_f32(vrecp0123, vrecpsq_f32(vrecp0123, vdenom0123));
const float32x4_t vsigmoid0123 = vmlaq_f32(vhalf, vnum0123, vrecp0123);
// Start with the low two lanes as the candidate for storing.
float32x2_t vf01 = vget_low_f32(vsigmoid0123);
// Two or three floats left: store the low pair, then switch to the high pair
// so that a possible third element comes from lane 2 of the result.
if (n & (2 * sizeof(float))) {
vst1_f32(y, vf01); y += 2;
vf01 = vget_high_f32(vsigmoid0123);
}
// Odd count: store one more float from lane 0 of the current pair.
if (n & (1 * sizeof(float))) {
vst1_lane_f32(y, vf01, 0);
}
}
}