Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 1 | // Copyright 2020 Google LLC |
| 2 | // |
| 3 | // This source code is licensed under the BSD-style license found in the |
| 4 | // LICENSE file in the root directory of this source tree. |
| 5 | |
| 6 | $assert CHANNEL_TILE % 8 == 0 |
| 7 | $assert CHANNEL_TILE >= 8 |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 8 | $assert ROW_TILE >= 3 |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 9 | $assert REQUANTIZATION == "FP32" |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 10 | $ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" |
| 11 | #include <assert.h> |
| 12 | |
| 13 | #include <arm_neon.h> |
| 14 | |
| 15 | #include <xnnpack/gavgpool.h> |
Marat Dukhan | 139337c | 2022-01-12 14:41:11 -0800 | [diff] [blame^] | 16 | $if ARMV8: |
| 17 | #include <xnnpack/intrinsics-polyfill.h> |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 18 | |
| 19 | |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 20 | $PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if ARMV8 else "neon") |
| 21 | $ISA = "neonv8" if ARMV8 else "neon" |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 22 | void xnn_qs8_gavgpool_minmax_fp32_ukernel_${ROW_TILE}x__${ISA}_c${CHANNEL_TILE}( |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 23 | size_t rows, |
| 24 | size_t channels, |
| 25 | const int8_t* input, |
| 26 | size_t input_stride, |
| 27 | const int8_t* zero, |
| 28 | int8_t* output, |
Marat Dukhan | 5d456ce | 2022-01-07 16:07:54 -0800 | [diff] [blame] | 29 | const union xnn_qs8_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 30 | { |
| 31 | assert(rows != 0); |
| 32 | assert(rows <= ${ROW_TILE}); |
| 33 | assert(channels != 0); |
| 34 | |
| 35 | const int8_t* i0 = input; |
| 36 | $for M in range(1, ROW_TILE): |
| 37 | const int8_t* i${M} = (const int8_t*) ((uintptr_t) i${M-1} + input_stride); |
| 38 | $if M % 2 == 1: |
| 39 | if XNN_UNPREDICTABLE(rows < ${M+1}) { |
| 40 | i${M} = zero; |
| 41 | } |
| 42 | $else: |
| 43 | if XNN_UNPREDICTABLE(rows <= ${M}) { |
| 44 | i${M} = zero; |
| 45 | } |
| 46 | |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 47 | const int32x4_t vinit_bias = vld1q_dup_s32(¶ms->${PARAMS_STRUCT}.init_bias); |
| 48 | const float32x4_t vscale = vld1q_dup_f32(¶ms->${PARAMS_STRUCT}.scale); |
| 49 | $if ARMV8: |
| 50 | const int16x8_t voutput_zero_point = vld1q_dup_s16(¶ms->${PARAMS_STRUCT}.output_zero_point); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 51 | $else: |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 52 | const float32x4_t vmagic_bias = vld1q_dup_f32(¶ms->${PARAMS_STRUCT}.magic_bias); |
| 53 | const int32x4_t vmagic_bias_less_output_zero_point = vld1q_dup_s32(¶ms->${PARAMS_STRUCT}.magic_bias_less_output_zero_point); |
| 54 | $if CHANNEL_TILE > 8: |
| 55 | const int8x16_t voutput_min = vld1q_dup_s8(¶ms->${PARAMS_STRUCT}.output_min); |
| 56 | const int8x16_t voutput_max = vld1q_dup_s8(¶ms->${PARAMS_STRUCT}.output_max); |
| 57 | $else: |
| 58 | const int8x8_t voutput_min = vld1_dup_s8(¶ms->${PARAMS_STRUCT}.output_min); |
| 59 | const int8x8_t voutput_max = vld1_dup_s8(¶ms->${PARAMS_STRUCT}.output_max); |
| 60 | for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) { |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 61 | $for M in range(2): |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 62 | $for C in range(0, CHANNEL_TILE, 8): |
| 63 | const int8x8_t vi${M}x${ABC[C:C+8]} = vld1_s8(i${M}); i${M} += 8; |
| 64 | |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 65 | $for C in range(0, CHANNEL_TILE, 8): |
| 66 | const int8x8_t vi2x${ABC[C:C+8]} = vld1_s8(i2); i2 += 8; |
| 67 | int16x8_t vacc${ABC[C:C+8]} = vaddl_s8(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]}); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 68 | |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 69 | $for M in range(2, ROW_TILE): |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 70 | $for C in range(0, CHANNEL_TILE, 8): |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 71 | $if M + 1 != ROW_TILE: |
| 72 | const int8x8_t vi${M+1}x${ABC[C:C+8]} = vld1_s8(i${M+1}); i${M+1} += 8; |
| 73 | vacc${ABC[C:C+8]} = vaddw_s8(vacc${ABC[C:C+8]}, vi${M}x${ABC[C:C+8]}); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 74 | |
| 75 | $for C in range(0, CHANNEL_TILE, 8): |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 76 | int32x4_t vacc${ABC[C:C+4]} = vaddw_s16(vinit_bias, vget_low_s16(vacc${ABC[C:C+8]})); |
| 77 | int32x4_t vacc${ABC[C+4:C+8]} = vaddw_s16(vinit_bias, vget_high_s16(vacc${ABC[C:C+8]})); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 78 | |
| 79 | $for C in range(0, CHANNEL_TILE, 4): |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 80 | float32x4_t vfpacc${ABC[C:C+4]} = vcvtq_f32_s32(vacc${ABC[C:C+4]}); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 81 | |
| 82 | $for C in range(0, CHANNEL_TILE, 4): |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 83 | vfpacc${ABC[C:C+4]} = vmulq_f32(vfpacc${ABC[C:C+4]}, vscale); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 84 | |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 85 | $if ARMV8: |
| 86 | $for C in range(0, CHANNEL_TILE, 4): |
| 87 | vacc${ABC[C:C+4]} = vcvtnq_s32_f32(vfpacc${ABC[C:C+4]}); |
| 88 | $else: |
| 89 | $for C in range(0, CHANNEL_TILE, 4): |
| 90 | vacc${ABC[C:C+4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[C:C+4]}, vmagic_bias)); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 91 | |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 92 | $for C in range(0, CHANNEL_TILE, 4): |
| 93 | vacc${ABC[C:C+4]} = vqsubq_s32(vacc${ABC[C:C+4]}, vmagic_bias_less_output_zero_point); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 94 | |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 95 | #if XNN_ARCH_ARM64 |
| 96 | $for C in range(0, CHANNEL_TILE, 8): |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 97 | vacc${ABC[C:C+8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[C:C+4]}), vacc${ABC[C+4:C+8]}); |
| 98 | #else // !XNN_ARCH_ARM64 |
| 99 | $for C in range(0, CHANNEL_TILE, 8): |
| 100 | vacc${ABC[C:C+8]} = vcombine_s16(vqmovn_s32(vacc${ABC[C:C+4]}), vqmovn_s32(vacc${ABC[C+4:C+8]})); |
| 101 | #endif // !XNN_ARCH_ARM64 |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 102 | |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 103 | $if ARMV8: |
| 104 | $for C in range(0, CHANNEL_TILE, 8): |
| 105 | vacc${ABC[C:C+8]} = vqaddq_s16(vacc${ABC[C:C+8]}, voutput_zero_point); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 106 | |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 107 | #if XNN_ARCH_ARM64 |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 108 | $for C in range(0, CHANNEL_TILE, 16): |
| 109 | $if C + 8 < CHANNEL_TILE: |
| 110 | int8x16_t vout${ABC[C:C+16]} = vqmovn_high_s16(vqmovn_s16(vacc${ABC[C:C+8]}), vacc${ABC[C+8:C+16]}); |
| 111 | $else: |
| 112 | int8x8_t vout${ABC[C:C+8]} = vqmovn_s16(vacc${ABC[C:C+8]}); |
| 113 | #else // !XNN_ARCH_ARM64 |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 114 | $for C in range(0, CHANNEL_TILE, 16): |
| 115 | $if C + 8 < CHANNEL_TILE: |
| 116 | int8x16_t vout${ABC[C:C+16]} = vcombine_s8(vqmovn_s16(vacc${ABC[C:C+8]}), vqmovn_s16(vacc${ABC[C+8:C+16]})); |
| 117 | $else: |
| 118 | int8x8_t vout${ABC[C:C+8]} = vqmovn_s16(vacc${ABC[C:C+8]}); |
| 119 | #endif // !XNN_ARCH_ARM64 |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 120 | |
| 121 | $for C in range(0, CHANNEL_TILE, 16): |
| 122 | $if C + 8 < CHANNEL_TILE: |
| 123 | vout${ABC[C:C+16]} = vmaxq_s8(vout${ABC[C:C+16]}, voutput_min); |
| 124 | $elif CHANNEL_TILE > 8: |
| 125 | vout${ABC[C:C+8]} = vmax_s8(vout${ABC[C:C+8]}, vget_low_s8(voutput_min)); |
| 126 | $else: |
| 127 | vout${ABC[C:C+8]} = vmax_s8(vout${ABC[C:C+8]}, voutput_min); |
| 128 | |
| 129 | $for C in range(0, CHANNEL_TILE, 16): |
| 130 | $if C + 8 < CHANNEL_TILE: |
| 131 | vout${ABC[C:C+16]} = vminq_s8(vout${ABC[C:C+16]}, voutput_max); |
| 132 | $elif CHANNEL_TILE > 8: |
| 133 | vout${ABC[C:C+8]} = vmin_s8(vout${ABC[C:C+8]}, vget_low_s8(voutput_max)); |
| 134 | $else: |
| 135 | vout${ABC[C:C+8]} = vmin_s8(vout${ABC[C:C+8]}, voutput_max); |
| 136 | |
| 137 | $for C in range(0, CHANNEL_TILE, 16): |
| 138 | $if C + 8 < CHANNEL_TILE: |
| 139 | vst1q_s8(output, vout${ABC[C:C+16]}); output += 16; |
| 140 | $else: |
| 141 | vst1_s8(output, vout${ABC[C:C+8]}); output += 8; |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 142 | } |
| 143 | if XNN_UNLIKELY(channels != 0) { |
| 144 | ${"do " if CHANNEL_TILE > 8 else ""}{ |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 145 | $for M in range(3): |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 146 | const int8x8_t vi${M}x${ABC[0:8]} = vld1_s8(i${M}); i${M} += 8; |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 147 | int16x8_t vacc${ABC[0:8]} = vaddl_s8(vi0x${ABC[0:8]}, vi1x${ABC[0:8]}); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 148 | |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 149 | $for M in range(2, ROW_TILE): |
| 150 | $if M + 1 != ROW_TILE: |
| 151 | const int8x8_t vi${M+1}x${ABC[0:8]} = vld1_s8(i${M+1}); i${M+1} += 8; |
| 152 | vacc${ABC[0:8]} = vaddw_s8(vacc${ABC[0:8]}, vi${M}x${ABC[0:8]}); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 153 | |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 154 | int32x4_t vacc${ABC[0:4]} = vaddw_s16(vinit_bias, vget_low_s16(vacc${ABC[0:8]})); |
| 155 | int32x4_t vacc${ABC[4:8]} = vaddw_s16(vinit_bias, vget_high_s16(vacc${ABC[0:8]})); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 156 | |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 157 | float32x4_t vfpacc${ABC[0:4]} = vcvtq_f32_s32(vacc${ABC[0:4]}); |
| 158 | float32x4_t vfpacc${ABC[4:8]} = vcvtq_f32_s32(vacc${ABC[4:8]}); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 159 | |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 160 | vfpacc${ABC[0:4]} = vmulq_f32(vfpacc${ABC[0:4]}, vscale); |
| 161 | vfpacc${ABC[4:8]} = vmulq_f32(vfpacc${ABC[4:8]}, vscale); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 162 | |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 163 | $if ARMV8: |
| 164 | vacc${ABC[0:4]} = vcvtnq_s32_f32(vfpacc${ABC[0:4]}); |
| 165 | vacc${ABC[4:8]} = vcvtnq_s32_f32(vfpacc${ABC[4:8]}); |
| 166 | $else: |
| 167 | vacc${ABC[0:4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[0:4]}, vmagic_bias)); |
| 168 | vacc${ABC[4:8]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[4:8]}, vmagic_bias)); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 169 | |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 170 | vacc${ABC[0:4]} = vqsubq_s32(vacc${ABC[0:4]}, vmagic_bias_less_output_zero_point); |
| 171 | vacc${ABC[4:8]} = vqsubq_s32(vacc${ABC[4:8]}, vmagic_bias_less_output_zero_point); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 172 | |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 173 | #if XNN_ARCH_ARM64 |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 174 | vacc${ABC[0:8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[0:4]}), vacc${ABC[4:8]}); |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 175 | #else |
Marat Dukhan | 9e258d6 | 2022-01-12 10:50:51 -0800 | [diff] [blame] | 176 | vacc${ABC[0:8]} = vcombine_s16(vqmovn_s32(vacc${ABC[0:4]}), vqmovn_s32(vacc${ABC[4:8]})); |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 177 | #endif |
| 178 | $if ARMV8: |
| 179 | vacc${ABC[0:8]} = vqaddq_s16(vacc${ABC[0:8]}, voutput_zero_point); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 180 | |
| 181 | int8x8_t vout${ABC[0:8]} = vqmovn_s16(vacc${ABC[0:8]}); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 182 | $if CHANNEL_TILE > 8: |
| 183 | vout${ABC[0:8]} = vmax_s8(vout${ABC[0:8]}, vget_low_s8(voutput_min)); |
| 184 | vout${ABC[0:8]} = vmin_s8(vout${ABC[0:8]}, vget_low_s8(voutput_max)); |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 185 | |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 186 | if XNN_LIKELY(channels >= 8) { |
| 187 | vst1_s8(output, vout${ABC[0:8]}); output += 8; |
| 188 | channels -= 8; |
| 189 | } else { |
| 190 | if (channels & 4) { |
Marat Dukhan | 5f7cf55 | 2021-11-25 17:37:03 -0800 | [diff] [blame] | 191 | vst1_lane_u32((void*) output, vreinterpret_u32_s8(vout${ABC[0:8]}), 0); output += 4; |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 192 | vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 4); |
| 193 | } |
| 194 | if (channels & 2) { |
Marat Dukhan | 5f7cf55 | 2021-11-25 17:37:03 -0800 | [diff] [blame] | 195 | vst1_lane_u16((void*) output, vreinterpret_u16_s8(vout${ABC[0:8]}), 0); output += 2; |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 196 | vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 2); |
| 197 | } |
| 198 | if (channels & 1) { |
| 199 | vst1_lane_s8(output, vout${ABC[0:8]}, 0); output += 1; |
| 200 | } |
| 201 | channels = 0; |
| 202 | } |
| 203 | $else: |
Marat Dukhan | 53f4106 | 2022-01-11 19:44:57 -0800 | [diff] [blame] | 204 | vout${ABC[0:8]} = vmax_s8(vout${ABC[0:8]}, voutput_min); |
| 205 | vout${ABC[0:8]} = vmin_s8(vout${ABC[0:8]}, voutput_max); |
| 206 | |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 207 | if (channels & 4) { |
Marat Dukhan | 5f7cf55 | 2021-11-25 17:37:03 -0800 | [diff] [blame] | 208 | vst1_lane_u32((void*) output, vreinterpret_u32_s8(vout${ABC[0:8]}), 0); output += 4; |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 209 | vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 4); |
| 210 | } |
| 211 | if (channels & 2) { |
Marat Dukhan | 5f7cf55 | 2021-11-25 17:37:03 -0800 | [diff] [blame] | 212 | vst1_lane_u16((void*) output, vreinterpret_u16_s8(vout${ABC[0:8]}), 0); output += 2; |
Marat Dukhan | 281262d | 2020-08-10 13:23:21 -0700 | [diff] [blame] | 213 | vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 2); |
| 214 | } |
| 215 | if (channels & 1) { |
| 216 | vst1_lane_s8(output, vout${ABC[0:8]}, 0); |
| 217 | } |
| 218 | }${" while (channels != 0);" if CHANNEL_TILE > 8 else ""} |
| 219 | } |
| 220 | } |