blob: 1635186baa6ad7139a2ebfb08fda3e1774a07c77 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6$assert CHANNEL_TILE % 8 == 0
7$assert CHANNEL_TILE >= 8
Marat Dukhan9e258d62022-01-12 10:50:51 -08008$assert ROW_TILE >= 3
Marat Dukhan53f41062022-01-11 19:44:57 -08009$assert REQUANTIZATION == "FP32"
Marat Dukhan281262d2020-08-10 13:23:21 -070010$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
11#include <assert.h>
12
13#include <arm_neon.h>
14
15#include <xnnpack/gavgpool.h>
Marat Dukhan139337c2022-01-12 14:41:11 -080016$if ARMV8:
17 #include <xnnpack/intrinsics-polyfill.h>
Marat Dukhan281262d2020-08-10 13:23:21 -070018
19
Marat Dukhan53f41062022-01-11 19:44:57 -080020$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if ARMV8 else "neon")
21$ISA = "neonv8" if ARMV8 else "neon"
Marat Dukhan9e258d62022-01-12 10:50:51 -080022void xnn_qs8_gavgpool_minmax_fp32_ukernel_${ROW_TILE}x__${ISA}_c${CHANNEL_TILE}(
Marat Dukhan281262d2020-08-10 13:23:21 -070023 size_t rows,
24 size_t channels,
25 const int8_t* input,
26 size_t input_stride,
27 const int8_t* zero,
28 int8_t* output,
Marat Dukhan5d456ce2022-01-07 16:07:54 -080029 const union xnn_qs8_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
Marat Dukhan281262d2020-08-10 13:23:21 -070030{
31 assert(rows != 0);
32 assert(rows <= ${ROW_TILE});
33 assert(channels != 0);
34
35 const int8_t* i0 = input;
36 $for M in range(1, ROW_TILE):
37 const int8_t* i${M} = (const int8_t*) ((uintptr_t) i${M-1} + input_stride);
38 $if M % 2 == 1:
39 if XNN_UNPREDICTABLE(rows < ${M+1}) {
40 i${M} = zero;
41 }
42 $else:
43 if XNN_UNPREDICTABLE(rows <= ${M}) {
44 i${M} = zero;
45 }
46
Marat Dukhan53f41062022-01-11 19:44:57 -080047 const int32x4_t vinit_bias = vld1q_dup_s32(&params->${PARAMS_STRUCT}.init_bias);
48 const float32x4_t vscale = vld1q_dup_f32(&params->${PARAMS_STRUCT}.scale);
49 $if ARMV8:
50 const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->${PARAMS_STRUCT}.output_zero_point);
Marat Dukhan281262d2020-08-10 13:23:21 -070051 $else:
Marat Dukhan53f41062022-01-11 19:44:57 -080052 const float32x4_t vmagic_bias = vld1q_dup_f32(&params->${PARAMS_STRUCT}.magic_bias);
53 const int32x4_t vmagic_bias_less_output_zero_point = vld1q_dup_s32(&params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point);
54 $if CHANNEL_TILE > 8:
55 const int8x16_t voutput_min = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_min);
56 const int8x16_t voutput_max = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_max);
57 $else:
58 const int8x8_t voutput_min = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_min);
59 const int8x8_t voutput_max = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_max);
60 for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) {
Marat Dukhan9e258d62022-01-12 10:50:51 -080061 $for M in range(2):
Marat Dukhan281262d2020-08-10 13:23:21 -070062 $for C in range(0, CHANNEL_TILE, 8):
63 const int8x8_t vi${M}x${ABC[C:C+8]} = vld1_s8(i${M}); i${M} += 8;
64
Marat Dukhan9e258d62022-01-12 10:50:51 -080065 $for C in range(0, CHANNEL_TILE, 8):
66 const int8x8_t vi2x${ABC[C:C+8]} = vld1_s8(i2); i2 += 8;
67 int16x8_t vacc${ABC[C:C+8]} = vaddl_s8(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]});
Marat Dukhan281262d2020-08-10 13:23:21 -070068
Marat Dukhan9e258d62022-01-12 10:50:51 -080069 $for M in range(2, ROW_TILE):
Marat Dukhan281262d2020-08-10 13:23:21 -070070 $for C in range(0, CHANNEL_TILE, 8):
Marat Dukhan9e258d62022-01-12 10:50:51 -080071 $if M + 1 != ROW_TILE:
72 const int8x8_t vi${M+1}x${ABC[C:C+8]} = vld1_s8(i${M+1}); i${M+1} += 8;
73 vacc${ABC[C:C+8]} = vaddw_s8(vacc${ABC[C:C+8]}, vi${M}x${ABC[C:C+8]});
Marat Dukhan281262d2020-08-10 13:23:21 -070074
75 $for C in range(0, CHANNEL_TILE, 8):
Marat Dukhan9e258d62022-01-12 10:50:51 -080076 int32x4_t vacc${ABC[C:C+4]} = vaddw_s16(vinit_bias, vget_low_s16(vacc${ABC[C:C+8]}));
77 int32x4_t vacc${ABC[C+4:C+8]} = vaddw_s16(vinit_bias, vget_high_s16(vacc${ABC[C:C+8]}));
Marat Dukhan281262d2020-08-10 13:23:21 -070078
79 $for C in range(0, CHANNEL_TILE, 4):
Marat Dukhan53f41062022-01-11 19:44:57 -080080 float32x4_t vfpacc${ABC[C:C+4]} = vcvtq_f32_s32(vacc${ABC[C:C+4]});
Marat Dukhan281262d2020-08-10 13:23:21 -070081
82 $for C in range(0, CHANNEL_TILE, 4):
Marat Dukhan53f41062022-01-11 19:44:57 -080083 vfpacc${ABC[C:C+4]} = vmulq_f32(vfpacc${ABC[C:C+4]}, vscale);
Marat Dukhan281262d2020-08-10 13:23:21 -070084
Marat Dukhan53f41062022-01-11 19:44:57 -080085 $if ARMV8:
86 $for C in range(0, CHANNEL_TILE, 4):
87 vacc${ABC[C:C+4]} = vcvtnq_s32_f32(vfpacc${ABC[C:C+4]});
88 $else:
89 $for C in range(0, CHANNEL_TILE, 4):
90 vacc${ABC[C:C+4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[C:C+4]}, vmagic_bias));
Marat Dukhan281262d2020-08-10 13:23:21 -070091
Marat Dukhan53f41062022-01-11 19:44:57 -080092 $for C in range(0, CHANNEL_TILE, 4):
93 vacc${ABC[C:C+4]} = vqsubq_s32(vacc${ABC[C:C+4]}, vmagic_bias_less_output_zero_point);
Marat Dukhan281262d2020-08-10 13:23:21 -070094
Marat Dukhan53f41062022-01-11 19:44:57 -080095 #if XNN_ARCH_ARM64
96 $for C in range(0, CHANNEL_TILE, 8):
Marat Dukhan9e258d62022-01-12 10:50:51 -080097 vacc${ABC[C:C+8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[C:C+4]}), vacc${ABC[C+4:C+8]});
98 #else // !XNN_ARCH_ARM64
99 $for C in range(0, CHANNEL_TILE, 8):
100 vacc${ABC[C:C+8]} = vcombine_s16(vqmovn_s32(vacc${ABC[C:C+4]}), vqmovn_s32(vacc${ABC[C+4:C+8]}));
101 #endif // !XNN_ARCH_ARM64
Marat Dukhan281262d2020-08-10 13:23:21 -0700102
Marat Dukhan9e258d62022-01-12 10:50:51 -0800103 $if ARMV8:
104 $for C in range(0, CHANNEL_TILE, 8):
105 vacc${ABC[C:C+8]} = vqaddq_s16(vacc${ABC[C:C+8]}, voutput_zero_point);
Marat Dukhan281262d2020-08-10 13:23:21 -0700106
Marat Dukhan9e258d62022-01-12 10:50:51 -0800107 #if XNN_ARCH_ARM64
Marat Dukhan53f41062022-01-11 19:44:57 -0800108 $for C in range(0, CHANNEL_TILE, 16):
109 $if C + 8 < CHANNEL_TILE:
110 int8x16_t vout${ABC[C:C+16]} = vqmovn_high_s16(vqmovn_s16(vacc${ABC[C:C+8]}), vacc${ABC[C+8:C+16]});
111 $else:
112 int8x8_t vout${ABC[C:C+8]} = vqmovn_s16(vacc${ABC[C:C+8]});
113 #else // !XNN_ARCH_ARM64
Marat Dukhan53f41062022-01-11 19:44:57 -0800114 $for C in range(0, CHANNEL_TILE, 16):
115 $if C + 8 < CHANNEL_TILE:
116 int8x16_t vout${ABC[C:C+16]} = vcombine_s8(vqmovn_s16(vacc${ABC[C:C+8]}), vqmovn_s16(vacc${ABC[C+8:C+16]}));
117 $else:
118 int8x8_t vout${ABC[C:C+8]} = vqmovn_s16(vacc${ABC[C:C+8]});
119 #endif // !XNN_ARCH_ARM64
Marat Dukhan281262d2020-08-10 13:23:21 -0700120
121 $for C in range(0, CHANNEL_TILE, 16):
122 $if C + 8 < CHANNEL_TILE:
123 vout${ABC[C:C+16]} = vmaxq_s8(vout${ABC[C:C+16]}, voutput_min);
124 $elif CHANNEL_TILE > 8:
125 vout${ABC[C:C+8]} = vmax_s8(vout${ABC[C:C+8]}, vget_low_s8(voutput_min));
126 $else:
127 vout${ABC[C:C+8]} = vmax_s8(vout${ABC[C:C+8]}, voutput_min);
128
129 $for C in range(0, CHANNEL_TILE, 16):
130 $if C + 8 < CHANNEL_TILE:
131 vout${ABC[C:C+16]} = vminq_s8(vout${ABC[C:C+16]}, voutput_max);
132 $elif CHANNEL_TILE > 8:
133 vout${ABC[C:C+8]} = vmin_s8(vout${ABC[C:C+8]}, vget_low_s8(voutput_max));
134 $else:
135 vout${ABC[C:C+8]} = vmin_s8(vout${ABC[C:C+8]}, voutput_max);
136
137 $for C in range(0, CHANNEL_TILE, 16):
138 $if C + 8 < CHANNEL_TILE:
139 vst1q_s8(output, vout${ABC[C:C+16]}); output += 16;
140 $else:
141 vst1_s8(output, vout${ABC[C:C+8]}); output += 8;
Marat Dukhan281262d2020-08-10 13:23:21 -0700142 }
143 if XNN_UNLIKELY(channels != 0) {
144 ${"do " if CHANNEL_TILE > 8 else ""}{
Marat Dukhan9e258d62022-01-12 10:50:51 -0800145 $for M in range(3):
Marat Dukhan281262d2020-08-10 13:23:21 -0700146 const int8x8_t vi${M}x${ABC[0:8]} = vld1_s8(i${M}); i${M} += 8;
Marat Dukhan9e258d62022-01-12 10:50:51 -0800147 int16x8_t vacc${ABC[0:8]} = vaddl_s8(vi0x${ABC[0:8]}, vi1x${ABC[0:8]});
Marat Dukhan281262d2020-08-10 13:23:21 -0700148
Marat Dukhan9e258d62022-01-12 10:50:51 -0800149 $for M in range(2, ROW_TILE):
150 $if M + 1 != ROW_TILE:
151 const int8x8_t vi${M+1}x${ABC[0:8]} = vld1_s8(i${M+1}); i${M+1} += 8;
152 vacc${ABC[0:8]} = vaddw_s8(vacc${ABC[0:8]}, vi${M}x${ABC[0:8]});
Marat Dukhan281262d2020-08-10 13:23:21 -0700153
Marat Dukhan9e258d62022-01-12 10:50:51 -0800154 int32x4_t vacc${ABC[0:4]} = vaddw_s16(vinit_bias, vget_low_s16(vacc${ABC[0:8]}));
155 int32x4_t vacc${ABC[4:8]} = vaddw_s16(vinit_bias, vget_high_s16(vacc${ABC[0:8]}));
Marat Dukhan281262d2020-08-10 13:23:21 -0700156
Marat Dukhan53f41062022-01-11 19:44:57 -0800157 float32x4_t vfpacc${ABC[0:4]} = vcvtq_f32_s32(vacc${ABC[0:4]});
158 float32x4_t vfpacc${ABC[4:8]} = vcvtq_f32_s32(vacc${ABC[4:8]});
Marat Dukhan281262d2020-08-10 13:23:21 -0700159
Marat Dukhan53f41062022-01-11 19:44:57 -0800160 vfpacc${ABC[0:4]} = vmulq_f32(vfpacc${ABC[0:4]}, vscale);
161 vfpacc${ABC[4:8]} = vmulq_f32(vfpacc${ABC[4:8]}, vscale);
Marat Dukhan281262d2020-08-10 13:23:21 -0700162
Marat Dukhan53f41062022-01-11 19:44:57 -0800163 $if ARMV8:
164 vacc${ABC[0:4]} = vcvtnq_s32_f32(vfpacc${ABC[0:4]});
165 vacc${ABC[4:8]} = vcvtnq_s32_f32(vfpacc${ABC[4:8]});
166 $else:
167 vacc${ABC[0:4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[0:4]}, vmagic_bias));
168 vacc${ABC[4:8]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[4:8]}, vmagic_bias));
Marat Dukhan281262d2020-08-10 13:23:21 -0700169
Marat Dukhan53f41062022-01-11 19:44:57 -0800170 vacc${ABC[0:4]} = vqsubq_s32(vacc${ABC[0:4]}, vmagic_bias_less_output_zero_point);
171 vacc${ABC[4:8]} = vqsubq_s32(vacc${ABC[4:8]}, vmagic_bias_less_output_zero_point);
Marat Dukhan281262d2020-08-10 13:23:21 -0700172
Marat Dukhan53f41062022-01-11 19:44:57 -0800173 #if XNN_ARCH_ARM64
Marat Dukhan9e258d62022-01-12 10:50:51 -0800174 vacc${ABC[0:8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[0:4]}), vacc${ABC[4:8]});
Marat Dukhan53f41062022-01-11 19:44:57 -0800175 #else
Marat Dukhan9e258d62022-01-12 10:50:51 -0800176 vacc${ABC[0:8]} = vcombine_s16(vqmovn_s32(vacc${ABC[0:4]}), vqmovn_s32(vacc${ABC[4:8]}));
Marat Dukhan53f41062022-01-11 19:44:57 -0800177 #endif
178 $if ARMV8:
179 vacc${ABC[0:8]} = vqaddq_s16(vacc${ABC[0:8]}, voutput_zero_point);
Marat Dukhan281262d2020-08-10 13:23:21 -0700180
181 int8x8_t vout${ABC[0:8]} = vqmovn_s16(vacc${ABC[0:8]});
Marat Dukhan281262d2020-08-10 13:23:21 -0700182 $if CHANNEL_TILE > 8:
183 vout${ABC[0:8]} = vmax_s8(vout${ABC[0:8]}, vget_low_s8(voutput_min));
184 vout${ABC[0:8]} = vmin_s8(vout${ABC[0:8]}, vget_low_s8(voutput_max));
Marat Dukhan281262d2020-08-10 13:23:21 -0700185
Marat Dukhan281262d2020-08-10 13:23:21 -0700186 if XNN_LIKELY(channels >= 8) {
187 vst1_s8(output, vout${ABC[0:8]}); output += 8;
188 channels -= 8;
189 } else {
190 if (channels & 4) {
Marat Dukhan5f7cf552021-11-25 17:37:03 -0800191 vst1_lane_u32((void*) output, vreinterpret_u32_s8(vout${ABC[0:8]}), 0); output += 4;
Marat Dukhan281262d2020-08-10 13:23:21 -0700192 vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 4);
193 }
194 if (channels & 2) {
Marat Dukhan5f7cf552021-11-25 17:37:03 -0800195 vst1_lane_u16((void*) output, vreinterpret_u16_s8(vout${ABC[0:8]}), 0); output += 2;
Marat Dukhan281262d2020-08-10 13:23:21 -0700196 vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 2);
197 }
198 if (channels & 1) {
199 vst1_lane_s8(output, vout${ABC[0:8]}, 0); output += 1;
200 }
201 channels = 0;
202 }
203 $else:
Marat Dukhan53f41062022-01-11 19:44:57 -0800204 vout${ABC[0:8]} = vmax_s8(vout${ABC[0:8]}, voutput_min);
205 vout${ABC[0:8]} = vmin_s8(vout${ABC[0:8]}, voutput_max);
206
Marat Dukhan281262d2020-08-10 13:23:21 -0700207 if (channels & 4) {
Marat Dukhan5f7cf552021-11-25 17:37:03 -0800208 vst1_lane_u32((void*) output, vreinterpret_u32_s8(vout${ABC[0:8]}), 0); output += 4;
Marat Dukhan281262d2020-08-10 13:23:21 -0700209 vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 4);
210 }
211 if (channels & 2) {
Marat Dukhan5f7cf552021-11-25 17:37:03 -0800212 vst1_lane_u16((void*) output, vreinterpret_u16_s8(vout${ABC[0:8]}), 0); output += 2;
Marat Dukhan281262d2020-08-10 13:23:21 -0700213 vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 2);
214 }
215 if (channels & 1) {
216 vst1_lane_s8(output, vout${ABC[0:8]}, 0);
217 }
218 }${" while (channels != 0);" if CHANNEL_TILE > 8 else ""}
219 }
220}