blob: da74dd42071ce4ed100eb778555d7d4e2f0d4fbf [file] [log] [blame]
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
// Constraints on the template parameters this file is written for:
// a QS8 or QU8 datatype, a channel tile of 1 to 16, at least 3 rows in both
// the first-pass tile and the later-pass sub-tile, a sub-tile no larger than
// the tile, and FP32 requantization only.
$assert DATATYPE in ["QS8", "QU8"]
$assert CHANNEL_TILE >= 1
$assert CHANNEL_TILE <= 16
$assert ROW_TILE >= 3
$assert ROW_SUBTILE >= 3
$assert ROW_SUBTILE <= ROW_TILE
$assert REQUANTIZATION == "FP32"
#include <assert.h>
$if VARIANT == "LRINTF":
#include <math.h>
$elif VARIANT in ["FMAGIC", "IMAGIC"]:
#include <fp16.h>
#include <xnnpack/gavgpool.h>
#include <xnnpack/math.h>
// Derived template variables:
//   PARAMS_STRUCT - name of the params union member read by the kernel,
//                   built from the lower-cased VARIANT.
//   XINT8_T       - element type of input, zero and output (uint8_t for QU8,
//                   int8_t otherwise).
//   MIN_F32 / MAX_F32 - float min/max helpers; the __builtin_wasm_* forms are
//                   selected when WASM is set, the math_* helpers otherwise.
$PARAMS_STRUCT = "fp32_scalar_" + VARIANT.lower()
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
// Multipass global-average-pooling microkernel template (scalar code, FP32
// requantization) for 8-bit quantized inputs.
//
// The generated function sums all rows of each channel into 32-bit
// accumulators and then converts the sums to 8-bit outputs:
//   1. First pass: init_bias plus the first ROW_TILE rows are written to buffer.
//   2. Middle passes: while more than ROW_SUBTILE rows remain, ROW_SUBTILE more
//      rows are added into buffer.
//   3. Last pass: the remaining 1..ROW_SUBTILE rows are added (row pointers
//      past the end are redirected to zero), the sum is multiplied by the
//      params scale, clamped, converted per VARIANT and stored to output.
//
// Parameters:
//   rows         - number of input rows; must exceed ROW_TILE (asserted).
//   channels     - number of channels per row; must be non-zero (asserted).
//   input        - pointer to the first input row.
//   input_stride - distance in bytes between consecutive rows.
//   zero         - pointer substituted for row pointers that would run past
//                  the last input row (presumably a zero-filled row; confirm).
//   buffer       - int32 scratch holding one running sum per channel.
//   output       - destination for one quantized value per channel.
//   params       - quantization parameters; the member matching VARIANT is read.
//
// NOTE(review): when CHANNEL_TILE > 1 the first and middle passes step through
// channels in whole tiles, so they read input rows and write buffer up to
// channels rounded up to a multiple of CHANNEL_TILE - presumably callers pad
// both accordingly; confirm.
void xnn_${DATATYPE.lower()}_gavgpool_minmax_fp32_ukernel_${ROW_TILE}p${ROW_SUBTILE}x__scalar_${VARIANT.lower()}_c${CHANNEL_TILE}(
size_t rows,
size_t channels,
const ${XINT8_T}* input,
size_t input_stride,
const ${XINT8_T}* zero,
int32_t* buffer,
${XINT8_T}* output,
const union xnn_${DATATYPE.lower()}_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(rows > ${ROW_TILE});
assert(channels != 0);
// i0 .. i(ROW_TILE-1) point at the first ROW_TILE consecutive rows.
const ${XINT8_T}* i0 = input;
$for M in range(1, ROW_TILE):
const ${XINT8_T}* i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M-1} + input_stride);
// During a pass each row pointer advances by the tile-rounded channel count;
// adding input_increment afterwards moves it to the start of the row that lies
// ROW_TILE rows below the one it just walked.
const size_t input_increment = ${ROW_TILE} * input_stride - round_up_po2(channels, ${CHANNEL_TILE}) * sizeof(${XINT8_T});
// First pass: every accumulator starts from init_bias, receives one element
// from each of the ROW_TILE rows, and is stored to buffer. Loads are issued
// ahead of the adds that consume them.
const int32_t vinit_bias = params->${PARAMS_STRUCT}.init_bias;
int32_t* b = buffer;
$if CHANNEL_TILE == 1:
size_t c = channels;
do {
int32_t vacc = vinit_bias;
$for M in range(2):
const int32_t vi${M} = (int32_t) *i${M}++;
$for M in range(2, ROW_TILE):
vacc += vi${M-2};
const int32_t vi${M} = (int32_t) *i${M}++;
$for M in range(ROW_TILE - 2, ROW_TILE):
vacc += vi${M};
*b++ = vacc;
} while (--c != 0);
$else:
for (ptrdiff_t c = (ptrdiff_t) channels; c > 0; c -= ${CHANNEL_TILE}) {
$for C in range(CHANNEL_TILE):
const int32_t vi0x${C} = (int32_t) i0[${C}];
i0 += ${CHANNEL_TILE};
$for C in range(CHANNEL_TILE):
int32_t vacc${C} = vi0x${C} + vinit_bias;
const int32_t vi1x${C} = (int32_t) i1[${C}];
i1 += ${CHANNEL_TILE};
$for M in range(2, ROW_TILE):
$for C in range(CHANNEL_TILE):
vacc${C} += vi${M-1}x${C};
const int32_t vi${M}x${C} = (int32_t) i${M}[${C}];
i${M} += ${CHANNEL_TILE};
$for C in range(CHANNEL_TILE):
vacc${C} += vi${ROW_TILE-1}x${C};
$for C in range(CHANNEL_TILE):
b[${C}] = vacc${C};
b += ${CHANNEL_TILE};
}
// Middle passes: the ROW_TILE rows of the first pass are deducted from rows,
// then ROW_SUBTILE rows at a time are added into buffer for as long as more
// than ROW_SUBTILE rows are still unread, which leaves between 1 and
// ROW_SUBTILE rows for the last pass.
// NOTE(review): the pointer rebasing below (and the identical one before the
// last pass) lands on the next unread rows when ROW_SUBTILE == ROW_TILE; for
// ROW_SUBTILE < ROW_TILE it looks like rows would be skipped - confirm before
// instantiating such a configuration.
for (rows -= ${ROW_TILE}; rows > ${ROW_SUBTILE}; rows -= ${ROW_SUBTILE}) {
$for M in range(ROW_SUBTILE):
i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment);
int32_t* b = buffer;
$if CHANNEL_TILE == 1:
size_t c = channels;
do {
int32_t vacc = *b;
$for M in range(2):
const int32_t vi${M} = (int32_t) *i${M}++;
$for M in range(2, ROW_SUBTILE):
vacc += vi${M-2};
const int32_t vi${M} = (int32_t) *i${M}++;
$for M in range(ROW_SUBTILE - 2, ROW_SUBTILE):
vacc += vi${M};
*b++ = vacc;
} while (--c != 0);
$else:
for (ptrdiff_t c = (ptrdiff_t) channels; c > 0; c -= ${CHANNEL_TILE}) {
$for C in range(CHANNEL_TILE):
int32_t vacc${C} = b[${C}];
const int32_t vi0x${C} = (int32_t) i0[${C}];
i0 += ${CHANNEL_TILE};
$for M in range(1, ROW_SUBTILE):
$for C in range(CHANNEL_TILE):
vacc${C} += vi${M-1}x${C};
const int32_t vi${M}x${C} = (int32_t) i${M}[${C}];
i${M} += ${CHANNEL_TILE};
$for C in range(CHANNEL_TILE):
vacc${C} += vi${ROW_SUBTILE-1}x${C};
$for C in range(CHANNEL_TILE):
b[${C}] = vacc${C};
b += ${CHANNEL_TILE};
}
}
// Last pass setup: rebase the row pointers once more. i0 is always kept as an
// input row (at least one row remains here); every further pointer iM is
// redirected to zero when M or fewer rows remain. The two emitted forms of the
// test (rows < M+1 and rows <= M) are the same condition.
i0 = (const ${XINT8_T}*) ((uintptr_t) i${ROW_TILE - ROW_SUBTILE} + input_increment);
$for M in range(1, ROW_SUBTILE):
i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment);
$if M % 2 == 1:
if XNN_UNPREDICTABLE(rows < ${M+1}) {
i${M} = zero;
}
$else:
if XNN_UNPREDICTABLE(rows <= ${M}) {
i${M} = zero;
}
// Load the scale and the requantization constants used by the selected VARIANT.
const float vscale = params->${PARAMS_STRUCT}.scale;
$if VARIANT == "FMAGIC":
const float voutput_min_less_zero_point = params->fp32_scalar_fmagic.output_min_less_zero_point;
const float voutput_max_less_zero_point = params->fp32_scalar_fmagic.output_max_less_zero_point;
const float vmagic_bias = params->fp32_scalar_fmagic.magic_bias;
const int32_t vmagic_bias_less_output_zero_point = params->fp32_scalar_fmagic.magic_bias_less_output_zero_point;
$elif VARIANT == "IMAGIC":
const float vmagic_bias = params->fp32_scalar_imagic.magic_bias;
const int32_t vmagic_min = params->fp32_scalar_imagic.magic_min;
const int32_t vmagic_max = params->fp32_scalar_imagic.magic_max;
const int32_t vmagic_bias_less_zero_point = params->fp32_scalar_imagic.magic_bias_less_zero_point;
$elif VARIANT == "LRINTF":
const float voutput_min_less_zero_point = params->fp32_scalar_lrintf.output_min_less_zero_point;
const float voutput_max_less_zero_point = params->fp32_scalar_lrintf.output_max_less_zero_point;
const int32_t voutput_zero_point = params->fp32_scalar_lrintf.output_zero_point;
// Last pass: add the final rows to the buffered sums, multiply by vscale and
// convert to the 8-bit output type. Conversion per VARIANT:
//   FMAGIC - clamp as float, add the magic bias, reinterpret the float bits as
//            an integer and subtract magic_bias_less_output_zero_point.
//   IMAGIC - add the magic bias, reinterpret the float bits as an integer,
//            clamp as integers, then subtract magic_bias_less_zero_point.
//   LRINTF - clamp as float, round with lrintf, then add the output zero point.
$if CHANNEL_TILE == 1:
do {
int32_t vacc = *buffer++;
$for M in range(2):
const int32_t vi${M} = (int32_t) *i${M}++;
$for M in range(2, ROW_SUBTILE):
vacc += vi${M-2};
const int32_t vi${M} = (int32_t) *i${M}++;
$for M in range(ROW_SUBTILE - 2, ROW_SUBTILE):
vacc += vi${M};
float vfpacc = (float) vacc * vscale;
$if VARIANT == "FMAGIC":
vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
vfpacc += vmagic_bias;
int32_t vout = (int32_t) fp32_to_bits(vfpacc) - vmagic_bias_less_output_zero_point;
$elif VARIANT == "IMAGIC":
vfpacc += vmagic_bias;
int32_t vout = (int32_t) fp32_to_bits(vfpacc);
vout = math_max_s32(vout, vmagic_min);
vout = math_min_s32(vout, vmagic_max);
vout -= vmagic_bias_less_zero_point;
$elif VARIANT == "LRINTF":
vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
const int32_t vrndacc = (int32_t) lrintf(vfpacc);
int32_t vout = vrndacc + voutput_zero_point;
*output++ = (${XINT8_T}) vout;
} while (--channels != 0);
$else:
for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) {
$for C in range(CHANNEL_TILE):
int32_t vacc${C} = buffer[${C}];
const int32_t vi0x${C} = (int32_t) i0[${C}];
buffer += ${CHANNEL_TILE};
i0 += ${CHANNEL_TILE};
$for M in range(1, ROW_SUBTILE):
$for C in range(CHANNEL_TILE):
vacc${C} += vi${M-1}x${C};
const int32_t vi${M}x${C} = (int32_t) i${M}[${C}];
i${M} += ${CHANNEL_TILE};
$for C in range(CHANNEL_TILE):
vacc${C} += vi${ROW_SUBTILE-1}x${C};
$for C in range(CHANNEL_TILE):
float vfpacc${C} = (float) vacc${C} * vscale;
$if VARIANT == "FMAGIC":
$for C in range(CHANNEL_TILE):
vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point);
$for C in range(CHANNEL_TILE):
vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point);
$for C in range(CHANNEL_TILE):
vfpacc${C} += vmagic_bias;
$for C in range(CHANNEL_TILE):
int32_t vout${C} = (int32_t) fp32_to_bits(vfpacc${C}) - vmagic_bias_less_output_zero_point;
$elif VARIANT == "IMAGIC":
$for C in range(CHANNEL_TILE):
vfpacc${C} += vmagic_bias;
$for C in range(CHANNEL_TILE):
int32_t vout${C} = (int32_t) fp32_to_bits(vfpacc${C});
$for C in range(CHANNEL_TILE):
vout${C} = math_max_s32(vout${C}, vmagic_min);
$for C in range(CHANNEL_TILE):
vout${C} = math_min_s32(vout${C}, vmagic_max);
$for C in range(CHANNEL_TILE):
vout${C} -= vmagic_bias_less_zero_point;
$elif VARIANT == "LRINTF":
$for C in range(CHANNEL_TILE):
vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point);
$for C in range(CHANNEL_TILE):
vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point);
$for C in range(CHANNEL_TILE):
const int32_t vrndacc${C} = (int32_t) lrintf(vfpacc${C});
$for C in range(CHANNEL_TILE):
int32_t vout${C} = vrndacc${C} + voutput_zero_point;
$for C in range(CHANNEL_TILE):
output[${C}] = (${XINT8_T}) vout${C};
output += ${CHANNEL_TILE};
}
if XNN_UNLIKELY(channels != 0) {
// Remainder: fewer than CHANNEL_TILE channels are left and are handled one
// at a time. With a channel tile of 2 exactly one channel can remain, so
// that case is emitted without a loop and without advancing the pointers.
$if CHANNEL_TILE == 2:
int32_t vacc = *buffer;
$for M in range(2):
const int32_t vi${M} = (int32_t) *i${M};
$for M in range(2, ROW_SUBTILE):
vacc += vi${M-2};
const int32_t vi${M} = (int32_t) *i${M};
$for M in range(ROW_SUBTILE - 2, ROW_SUBTILE):
vacc += vi${M};
float vfpacc = (float) vacc * vscale;
$if VARIANT == "FMAGIC":
vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
vfpacc += vmagic_bias;
int32_t vout = (int32_t) fp32_to_bits(vfpacc) - vmagic_bias_less_output_zero_point;
$elif VARIANT == "IMAGIC":
vfpacc += vmagic_bias;
int32_t vout = (int32_t) fp32_to_bits(vfpacc);
vout = math_max_s32(vout, vmagic_min);
vout = math_min_s32(vout, vmagic_max);
vout -= vmagic_bias_less_zero_point;
$elif VARIANT == "LRINTF":
vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
const int32_t vrndacc = (int32_t) lrintf(vfpacc);
int32_t vout = vrndacc + voutput_zero_point;
*output = (${XINT8_T}) vout;
$else:
do {
int32_t vacc = *buffer++;
$for M in range(2):
const int32_t vi${M} = (int32_t) *i${M}++;
$for M in range(2, ROW_SUBTILE):
vacc += vi${M-2};
const int32_t vi${M} = (int32_t) *i${M}++;
$for M in range(ROW_SUBTILE - 2, ROW_SUBTILE):
vacc += vi${M};
float vfpacc = (float) vacc * vscale;
$if VARIANT == "FMAGIC":
vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
vfpacc += vmagic_bias;
int32_t vout = (int32_t) fp32_to_bits(vfpacc) - vmagic_bias_less_output_zero_point;
$elif VARIANT == "IMAGIC":
vfpacc += vmagic_bias;
int32_t vout = (int32_t) fp32_to_bits(vfpacc);
vout = math_max_s32(vout, vmagic_min);
vout = math_min_s32(vout, vmagic_max);
vout -= vmagic_bias_less_zero_point;
$elif VARIANT == "LRINTF":
vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
const int32_t vrndacc = (int32_t) lrintf(vfpacc);
int32_t vout = vrndacc + voutput_zero_point;
*output++ = (${XINT8_T}) vout;
} while (--channels != 0);
}
}