// blob: dd7b82445e7057a2fb93c69122d64990fd580972 [file] [log] [blame]
/*
* Copyright 2019 Google LLC
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// The channel tile (template parameter CR) must be a multiple of 4: every
// vector load/store in this kernel moves one 4-lane float32x4_t.
$assert CR % 4 == 0
// Lane-index alphabet; slices such as ABC[C:C+4] form the suffixes of the
// generated per-vector variable names (e.g. "0123", "4567").
$ABC = "0123456789ABCDEFGHIJKLMN"
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/math.h>
#include <xnnpack/vmulcaddc.h>
// Multiply-add-clamp micro-kernel template (NEON / NEON-FMA).
//
// For every row and every channel this computes
//   y = min(max(x * scale + bias, params->scalar.min), params->scalar.max)
// handling MR rows per outer iteration and CR channels per inner iteration
// (MR, CR and FMA are template parameters).
//
//   m         - number of rows; must be non-zero.
//   channels  - row length in BYTES; a non-zero multiple of sizeof(float).
//   x         - pointer to the first input row.
//   x_stride  - distance between consecutive input rows, in bytes.
//   weights   - for each group of CR channels: CR scales followed by CR biases.
//   y         - pointer to the first output row.
//   y_stride  - distance between consecutive output rows, in bytes.
//   params    - clamping bounds, read from scalar.min / scalar.max.
//
// NOTE(review): the remainder path always loads whole 4-float vectors (CR
// floats per row and 2*CR floats of weights) even when fewer channels remain,
// so inputs and weights presumably need readable padding past the last
// channel - confirm against the callers and the weight-packing code.
void xnn_f32_vmulcaddc_ukernel_c${CR}__${"neonfma" if FMA else "neon"}_x${MR}(
    size_t m,
    size_t channels,
    const float*restrict x,
    size_t x_stride,
    const float*restrict weights,
    float*restrict y,
    size_t y_stride,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(m != 0);
  assert(channels != 0);
  assert(channels % sizeof(float) == 0);

  // Once a row has been processed its pointer has advanced by `channels`
  // bytes; these increments carry it on to the row MR rows further down.
  const size_t x_increment = x_stride * ${MR} - channels;
  const size_t y_increment = y_stride * ${MR} - channels;

  const float* x0 = x;
  float* y0 = y;
  // When fewer than MR rows exist, surplus row pointers alias the previous
  // row, so that row is recomputed instead of touching memory past row m.
  $for M in range(1, MR):
    const float* x${M} = (const float*) ((uintptr_t) x${M-1} + x_stride);
    float* y${M} = (float*) ((uintptr_t) y${M-1} + y_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(m <= ${M}) {
        x${M} = x${M-1};
        y${M} = y${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(m < ${M+1}) {
        x${M} = x${M-1};
        y${M} = y${M-1};
      }

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  do {
    // Every group of MR rows restarts from the beginning of the weights.
    const float* w = weights;
    size_t c = channels;
    // Main loop: one full tile of CR channels per row and iteration.
    for (; c >= ${CR} * sizeof(float); c -= ${CR} * sizeof(float)) {
      $for C in range(0, CR, 4):
        const float32x4_t vscale${ABC[C:C+4]} = vld1q_f32(w); w += 4;

      $for M in range(MR):
        $for C in range(0, CR, 4):
          const float32x4_t vx${M}x${ABC[C:C+4]} = vld1q_f32(x${M}); x${M} += 4;

      $if not FMA:
        $for M in range(MR):
          $for C in range(0, CR, 4):
            float32x4_t vacc${M}x${ABC[C:C+4]} = vmulq_f32(vx${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});

      $for C in range(0, CR, 4):
        const float32x4_t vbias${ABC[C:C+4]} = vld1q_f32(w); w += 4;

      $if not FMA:
        $for M in range(MR):
          $for C in range(0, CR, 4):
            vacc${M}x${ABC[C:C+4]} = vaddq_f32(vacc${M}x${ABC[C:C+4]}, vbias${ABC[C:C+4]});
      $else:
        $for M in range(MR):
          $for C in range(0, CR, 4):
            float32x4_t vacc${M}x${ABC[C:C+4]} = vfmaq_f32(vbias${ABC[C:C+4]}, vx${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});

      $for M in range(MR):
        $for C in range(0, CR, 4):
          vacc${M}x${ABC[C:C+4]} = vmaxq_f32(vacc${M}x${ABC[C:C+4]}, vmin);

      $for M in range(MR):
        $for C in range(0, CR, 4):
          vacc${M}x${ABC[C:C+4]} = vminq_f32(vacc${M}x${ABC[C:C+4]}, vmax);

      $for M in range(MR):
        $for C in range(0, CR, 4):
          vst1q_f32(y${M}, vacc${M}x${ABC[C:C+4]}); y${M} += 4;
    }
    if XNN_UNLIKELY(c != 0) {
      // Remainder: fewer than CR channels are left. A full tile is computed,
      // then only the valid part is stored by peeling power-of-two chunks
      // off `c` (largest first).
      $for C in range(0, CR, 4):
        const float32x4_t vscale${ABC[C:C+4]} = vld1q_f32(w); w += 4;

      // Each vector of a row comes from its own consecutive address, and the
      // row pointer moves forward by the remaining byte count exactly once so
      // that x_increment lands on the correct next row.
      $for M in range(MR):
        $for C in range(0, CR, 4):
          const float32x4_t vx${M}x${ABC[C:C+4]} = vld1q_f32(x${M} + ${C});
        x${M} = (const float*) ((uintptr_t) x${M} + c);

      $if not FMA:
        $for M in range(MR):
          $for C in range(0, CR, 4):
            float32x4_t vacc${M}x${ABC[C:C+4]} = vmulq_f32(vx${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});

      $for C in range(0, CR, 4):
        const float32x4_t vbias${ABC[C:C+4]} = vld1q_f32(w); w += 4;

      $if not FMA:
        $for M in range(MR):
          $for C in range(0, CR, 4):
            vacc${M}x${ABC[C:C+4]} = vaddq_f32(vacc${M}x${ABC[C:C+4]}, vbias${ABC[C:C+4]});
      $else:
        $for M in range(MR):
          $for C in range(0, CR, 4):
            float32x4_t vacc${M}x${ABC[C:C+4]} = vfmaq_f32(vbias${ABC[C:C+4]}, vx${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});

      $for M in range(MR):
        $for C in range(0, CR, 4):
          vacc${M}x${ABC[C:C+4]} = vmaxq_f32(vacc${M}x${ABC[C:C+4]}, vmin);

      $for M in range(MR):
        $for C in range(0, CR, 4):
          vacc${M}x${ABC[C:C+4]} = vminq_f32(vacc${M}x${ABC[C:C+4]}, vmax);

      // After a chunk of 2^LOG2C channels is stored, the not-yet-stored
      // results are shifted down so the next (smaller) chunk is again read
      // from the lowest-numbered accumulators.
      $for LOG2C in reversed(range(CR.bit_length())):
        $if CR != 1 << LOG2C:
          if (c & (${1 << LOG2C} * sizeof(float))) {
            $if LOG2C >= 2:
              $for M in range(MR):
                $for C in range(0, 1 << LOG2C, 4):
                  vst1q_f32(y${M}, vacc${M}x${ABC[C:C+4]}); y${M} += 4;
              $for M in range(MR):
                $for C in range(0, 1 << LOG2C, 4):
                  $if C + (1 << LOG2C) < CR:
                    vacc${M}x${ABC[C:C+4]} = vacc${M}x${ABC[C+(1<<LOG2C):C+(1<<LOG2C)+4]};
            $elif LOG2C == 1:
              $for M in range(MR):
                vst1_f32(y${M}, vacc${M}x${ABC[0:2]}); y${M} += 2;
              $for M in range(MR):
                vacc${M}x${ABC[0:2]} = vget_high_f32(vacc${M}x${ABC[0:4]});
            $elif LOG2C == 0:
              $for M in range(MR):
                vst1_lane_f32(y${M}, vacc${M}x${ABC[0:2]}, 0); y${M} += 1;
          }
        $if LOG2C == 2:
          $for M in range(MR):
            float32x2_t vacc${M}x${ABC[0:2]} = vget_low_f32(vacc${M}x${ABC[0:4]});
    }
    // Move every row pointer MR rows ahead; rows that fall beyond the
    // remaining row count alias the previous row again.
    $for M in range(MR):
      x${M} = (const float*) ((uintptr_t) x${M} + x_increment);
      y${M} = (float*) ((uintptr_t) y${M} + y_increment);
      $if M % 2 == 1:
        if XNN_UNPREDICTABLE(m < ${MR+M+1}) {
          x${M} = x${M-1};
          y${M} = y${M-1};
        }
      $elif M != 0:
        if XNN_UNPREDICTABLE(m <= ${MR+M}) {
          x${M} = x${M-1};
          y${M} = y${M-1};
        }
    // NOTE(review): doz() is declared in <xnnpack/math.h>; presumably a
    // "difference or zero" (saturating) subtraction so that m reaches 0 when
    // fewer than MR rows remain - confirm.
    m = doz(m, ${MR});
  } while (m != 0);
}