blob: 8198632d128d14266d56e427c7cb3d21b7a07e95 [file] [log] [blame]
Erich Elsen0a1970e2020-06-10 09:24:59 -07001// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#include <assert.h>
7
8#include <psimd.h>
9
10#include <xnnpack/conv.h>
11#include <xnnpack/math.h>
12
13
// 3x3 convolution with stride 2 over a 3-channel interleaved (HWC) input,
// writing one plane per output channel (CHW).
//
// Work is tiled as follows:
//   - row loop:     2 output rows per pass, fed by 5 input rows (i0..i4);
//   - channel loop: 4 output channels per pass, one channel per vector lane;
//   - column loop:  4 input pixels -> 2 output columns per pass, followed by a
//                   tail for the last 1..3 input pixels.
// Accumulators are named vo<row>x<col>; taps are named vk<krow><kcol>c<in_ch>.
//
// Parameters:
//   input_height          number of valid input rows; rows at or beyond it are
//                         read through the zero pointer instead.
//   input_width           pixels per input row; must be non-zero.
//   output_y_start/_end   half-open range of output rows to produce.
//   input                 first input row; rows are input_width * 3 floats apart.
//   zero                  row used in place of any out-of-range input row.
//                         NOTE(review): presumably zero-filled and at least one
//                         row long plus read slack - confirm with the caller.
//   weights               112 floats per block of 4 output channels: 4 initial
//                         accumulator values (presumably biases), then 27 groups
//                         of 4 taps in the order they are loaded below.
//   output                output base; row y of channel c starts at
//                         output + y * output_height_stride
//                                + c * output_channel_stride (byte offsets).
//   input_padding_top     0 or 1 implicit padding rows above the input.
//   output_channels       number of output channels; must be non-zero.
//   output_height_stride  byte distance between consecutive output rows.
//   output_channel_stride byte distance between consecutive output channels.
//   params                scalar.min / scalar.max clamp applied before storing.
14void xnn_f32_conv_hwc2chw_ukernel_3x3s2p1c3x4__psimd_2x2(
15 size_t input_height,
16 size_t input_width,
17 size_t output_y_start,
18 size_t output_y_end,
19 const float* input,
20 const float* zero,
21 const float* weights,
22 float* output,
23 size_t input_padding_top,
24 size_t output_channels,
25 size_t output_height_stride,
26 size_t output_channel_stride,
27 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
28{
29 assert(input_width != 0);
30 assert(output_y_end > output_y_start);
31 assert(input_padding_top <= 1);
32 assert(output_channels != 0);
33
   // All strides and increments below are in bytes.
   // input_width_increment is how far the main column loop advances each row
   // pointer (the tail does not advance them); it is used to rewind the rows.
   // output_channel_increment rewinds a channel pointer by one full output row
   // and moves it forward by 4 channels.
34 const size_t input_height_stride = input_width * 3 /* channels */ * sizeof(float);
35 const size_t input_width_increment = round_down_po2(input_width, 4) * 3 /* channels */ * sizeof(float);
36 const size_t output_width = (input_width + 1) / 2;
37 const size_t output_channel_increment = output_channel_stride * 4 - output_width * sizeof(float);
38
39 // Adjustment for padding processed below
   // With output_y_start == 0 and input_padding_top == 1 the row index wraps
   // around in size_t; the arithmetic is done on uintptr_t and i0 is replaced
   // just below, while i1..i4 are derived from the unadjusted i0.
40 const float* i0 = (const float*) ((uintptr_t) input + input_height_stride * (output_y_start * 2 - input_padding_top));
41 const float* i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
42 const float* i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
43 const float* i3 = (const float*) ((uintptr_t) i2 + input_height_stride);
44 const float* i4 = (const float*) ((uintptr_t) i3 + input_height_stride);
45 float* output0 = (float*) ((uintptr_t) output + output_height_stride * output_y_start);
46 float* output1 = (float*) ((uintptr_t) output0 + output_height_stride);
47
   // The first row lies in the top padding: read it through the zero pointer.
48 if XNN_UNPREDICTABLE(output_y_start < input_padding_top) {
49 i0 = zero;
50 }
51
52 const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
53 const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
54
   // Row loop: output rows output_y and output_y + 1 per pass.
   // Output row 0 reads input rows i0, i1, i2; output row 1 reads i2, i3, i4.
55 for (size_t output_y = output_y_start; output_y < output_y_end; output_y += 2) {
     // i2, i3 and i4 hold input rows input_y2, input_y2 + 1 and input_y4.
     // Any of them at or past input_height is redirected to the zero pointer.
56 const size_t input_y2 = output_y * 2 + 2 - input_padding_top;
57 const size_t input_y4 = input_y2 + 2;
58 if XNN_UNPREDICTABLE(input_y2 >= input_height) {
59 i2 = zero;
60 }
61 if XNN_UNPREDICTABLE(input_y4 > input_height) {
62 i3 = zero;
63 }
64 if XNN_UNPREDICTABLE(input_y4 >= input_height) {
65 i4 = zero;
66 }
     // Only one output row left: alias row 1 onto row 0. Row 1 is stored
     // before row 0 below, so the row 0 values are the ones that remain.
67 if XNN_UNPREDICTABLE(output_y + 2 > output_y_end) {
68 output1 = output0;
69 }
70
71 const float* w = weights;
72 size_t c = output_channels;
     // o<row>c<channel>: write cursors for the 2 rows x 4 channels of the block.
73 float* o0c0 = output0;
74 float* o1c0 = output1;
75 float* o0c1 = (float*) ((uintptr_t) o0c0 + output_channel_stride);
76 float* o1c1 = (float*) ((uintptr_t) o1c0 + output_channel_stride);
77 float* o0c2 = (float*) ((uintptr_t) o0c1 + output_channel_stride);
78 float* o1c2 = (float*) ((uintptr_t) o1c1 + output_channel_stride);
79 float* o0c3 = (float*) ((uintptr_t) o0c2 + output_channel_stride);
80 float* o1c3 = (float*) ((uintptr_t) o1c2 + output_channel_stride);
     // Channel loop: 4 output channels per pass.
81 do {
       // Fewer than 4 channels left: alias each surplus cursor onto the last
       // valid channel. Stores go c0, c1, c2, c3, so the surplus lanes are
       // written last to the aliased location.
       // NOTE(review): this is only correct if the surplus weight lanes
       // replicate the last valid channel - presumably guaranteed by the
       // weight packing; confirm.
82 if XNN_UNPREDICTABLE(c < 2) {
83 o0c1 = o0c0;
84 o1c1 = o1c0;
85 }
86 if XNN_UNPREDICTABLE(c <= 2) {
87 o0c2 = o0c1;
88 o1c2 = o1c1;
89 }
90 if XNN_UNPREDICTABLE(c < 4) {
91 o0c3 = o0c2;
92 o1c3 = o1c2;
93 }
94
       // Lanes in the layout comments are listed from lane 3 down to lane 0.
       // viMx0 carries the pixel to the left of the current group in lanes
       // 1..3; it starts as zero, which acts as the left padding column.
95 // viMx0 = ( iM0c2, iM0c1, iM0c0, --- )
96 psimd_f32 vi0x0 = psimd_zero_f32();
97 psimd_f32 vi1x0 = psimd_zero_f32();
98 psimd_f32 vi2x0 = psimd_zero_f32();
99 psimd_f32 vi3x0 = psimd_zero_f32();
100 psimd_f32 vi4x0 = psimd_zero_f32();
101
       // Main column loop: 4 input pixels (12 floats per row, loaded as three
       // 4-float vectors) produce 2 output columns.
102 size_t iw = input_width;
103 for (; iw >= 4; iw -= 4) {
         // All four accumulators start from the first 4 floats of the block.
104 psimd_f32 vo0x0 = psimd_load_f32(w);
105 psimd_f32 vo1x0 = vo0x0;
106 psimd_f32 vo0x1 = vo0x0;
107 psimd_f32 vo1x1 = vo0x0;
108
         // Kernel column 0 (left neighbour), input channel 0, kernel rows 0..2.
         // Each step multiplies a 4-channel tap vector by one broadcast input
         // lane and adds it into the accumulator.
109 const psimd_f32 vk00c0 = psimd_load_f32(w + 4);
110
111 // viMx1 = ( iM2c0, iM1c2, iM1c1, iM1c0 )
112 const psimd_f32 vi0x1 = psimd_load_f32(i0); i0 += 4;
113 const psimd_f32 vi1x1 = psimd_load_f32(i1); i1 += 4;
114 const psimd_f32 vi2x1 = psimd_load_f32(i2); i2 += 4;
115 const psimd_f32 vi3x1 = psimd_load_f32(i3); i3 += 4;
116 const psimd_f32 vi4x1 = psimd_load_f32(i4); i4 += 4;
117
118 vo0x0 = psimd_qfma_f32(vo0x0, vk00c0, psimd_splat1_f32(vi0x0));
119 vo1x0 = psimd_qfma_f32(vo1x0, vk00c0, psimd_splat1_f32(vi2x0));
120 vo0x1 = psimd_qfma_f32(vo0x1, vk00c0, psimd_splat3_f32(vi0x1));
121 vo1x1 = psimd_qfma_f32(vo1x1, vk00c0, psimd_splat3_f32(vi2x1));
122
123 const psimd_f32 vk10c0 = psimd_load_f32(w + 8);
124
125 vo0x0 = psimd_qfma_f32(vo0x0, vk10c0, psimd_splat1_f32(vi1x0));
126 vo1x0 = psimd_qfma_f32(vo1x0, vk10c0, psimd_splat1_f32(vi3x0));
127 vo0x1 = psimd_qfma_f32(vo0x1, vk10c0, psimd_splat3_f32(vi1x1));
128 vo1x1 = psimd_qfma_f32(vo1x1, vk10c0, psimd_splat3_f32(vi3x1));
129
130 const psimd_f32 vk20c0 = psimd_load_f32(w + 12);
131
132 vo0x0 = psimd_qfma_f32(vo0x0, vk20c0, psimd_splat1_f32(vi2x0));
133 vo1x0 = psimd_qfma_f32(vo1x0, vk20c0, psimd_splat1_f32(vi4x0));
134 vo0x1 = psimd_qfma_f32(vo0x1, vk20c0, psimd_splat3_f32(vi2x1));
135 vo1x1 = psimd_qfma_f32(vo1x1, vk20c0, psimd_splat3_f32(vi4x1));
136
         // Kernel column 0, input channel 1.
137 const psimd_f32 vk00c1 = psimd_load_f32(w + 16);
138
139 // viMx2 = ( iM3c1, iM3c0, iM2c2, iM2c1 )
140 const psimd_f32 vi0x2 = psimd_load_f32(i0); i0 += 4;
141 const psimd_f32 vi1x2 = psimd_load_f32(i1); i1 += 4;
142 const psimd_f32 vi2x2 = psimd_load_f32(i2); i2 += 4;
143 const psimd_f32 vi3x2 = psimd_load_f32(i3); i3 += 4;
144 const psimd_f32 vi4x2 = psimd_load_f32(i4); i4 += 4;
145
146 vo0x0 = psimd_qfma_f32(vo0x0, vk00c1, psimd_splat2_f32(vi0x0));
147 vo1x0 = psimd_qfma_f32(vo1x0, vk00c1, psimd_splat2_f32(vi2x0));
148 vo0x1 = psimd_qfma_f32(vo0x1, vk00c1, psimd_splat0_f32(vi0x2));
149 vo1x1 = psimd_qfma_f32(vo1x1, vk00c1, psimd_splat0_f32(vi2x2));
150
151 const psimd_f32 vk10c1 = psimd_load_f32(w + 20);
152
153 vo0x0 = psimd_qfma_f32(vo0x0, vk10c1, psimd_splat2_f32(vi1x0));
154 vo1x0 = psimd_qfma_f32(vo1x0, vk10c1, psimd_splat2_f32(vi3x0));
155 vo0x1 = psimd_qfma_f32(vo0x1, vk10c1, psimd_splat0_f32(vi1x2));
156 vo1x1 = psimd_qfma_f32(vo1x1, vk10c1, psimd_splat0_f32(vi3x2));
157
158 const psimd_f32 vk20c1 = psimd_load_f32(w + 24);
159
160 vo0x0 = psimd_qfma_f32(vo0x0, vk20c1, psimd_splat2_f32(vi2x0));
161 vo1x0 = psimd_qfma_f32(vo1x0, vk20c1, psimd_splat2_f32(vi4x0));
162 vo0x1 = psimd_qfma_f32(vo0x1, vk20c1, psimd_splat0_f32(vi2x2));
163 vo1x1 = psimd_qfma_f32(vo1x1, vk20c1, psimd_splat0_f32(vi4x2));
164
         // Kernel column 0, input channel 2.
165 const psimd_f32 vk00c2 = psimd_load_f32(w + 28);
166
167 vo0x0 = psimd_qfma_f32(vo0x0, vk00c2, psimd_splat3_f32(vi0x0));
168 vo1x0 = psimd_qfma_f32(vo1x0, vk00c2, psimd_splat3_f32(vi2x0));
169 vo0x1 = psimd_qfma_f32(vo0x1, vk00c2, psimd_splat1_f32(vi0x2));
170 vo1x1 = psimd_qfma_f32(vo1x1, vk00c2, psimd_splat1_f32(vi2x2));
171
172 const psimd_f32 vk10c2 = psimd_load_f32(w + 32);
173
174 vo0x0 = psimd_qfma_f32(vo0x0, vk10c2, psimd_splat3_f32(vi1x0));
175 vo1x0 = psimd_qfma_f32(vo1x0, vk10c2, psimd_splat3_f32(vi3x0));
176 vo0x1 = psimd_qfma_f32(vo0x1, vk10c2, psimd_splat1_f32(vi1x2));
177 vo1x1 = psimd_qfma_f32(vo1x1, vk10c2, psimd_splat1_f32(vi3x2));
178
179 const psimd_f32 vk20c2 = psimd_load_f32(w + 36);
180
181 vo0x0 = psimd_qfma_f32(vo0x0, vk20c2, psimd_splat3_f32(vi2x0));
182 vo1x0 = psimd_qfma_f32(vo1x0, vk20c2, psimd_splat3_f32(vi4x0));
183 vo0x1 = psimd_qfma_f32(vo0x1, vk20c2, psimd_splat1_f32(vi2x2));
184 vo1x1 = psimd_qfma_f32(vo1x1, vk20c2, psimd_splat1_f32(vi4x2));
185
         // Kernel column 1 (centre pixel), input channel 0.
186 const psimd_f32 vk01c0 = psimd_load_f32(w + 40);
187
188 vo0x0 = psimd_qfma_f32(vo0x0, vk01c0, psimd_splat0_f32(vi0x1));
189 vo1x0 = psimd_qfma_f32(vo1x0, vk01c0, psimd_splat0_f32(vi2x1));
190 vo0x1 = psimd_qfma_f32(vo0x1, vk01c0, psimd_splat2_f32(vi0x2));
191 vo1x1 = psimd_qfma_f32(vo1x1, vk01c0, psimd_splat2_f32(vi2x2));
192
193 const psimd_f32 vk11c0 = psimd_load_f32(w + 44);
194
195 vo0x0 = psimd_qfma_f32(vo0x0, vk11c0, psimd_splat0_f32(vi1x1));
196 vo1x0 = psimd_qfma_f32(vo1x0, vk11c0, psimd_splat0_f32(vi3x1));
197 vo0x1 = psimd_qfma_f32(vo0x1, vk11c0, psimd_splat2_f32(vi1x2));
198 vo1x1 = psimd_qfma_f32(vo1x1, vk11c0, psimd_splat2_f32(vi3x2));
199
200 const psimd_f32 vk21c0 = psimd_load_f32(w + 48);
201
202 vo0x0 = psimd_qfma_f32(vo0x0, vk21c0, psimd_splat0_f32(vi2x1));
203 vo1x0 = psimd_qfma_f32(vo1x0, vk21c0, psimd_splat0_f32(vi4x1));
204 vo0x1 = psimd_qfma_f32(vo0x1, vk21c0, psimd_splat2_f32(vi2x2));
205 vo1x1 = psimd_qfma_f32(vo1x1, vk21c0, psimd_splat2_f32(vi4x2));
206
         // Kernel column 1, input channel 1.
207 const psimd_f32 vk01c1 = psimd_load_f32(w + 52);
208
209 vo0x0 = psimd_qfma_f32(vo0x0, vk01c1, psimd_splat1_f32(vi0x1));
210 vo1x0 = psimd_qfma_f32(vo1x0, vk01c1, psimd_splat1_f32(vi2x1));
211 vo0x1 = psimd_qfma_f32(vo0x1, vk01c1, psimd_splat3_f32(vi0x2));
212 vo1x1 = psimd_qfma_f32(vo1x1, vk01c1, psimd_splat3_f32(vi2x2));
213
214 const psimd_f32 vk11c1 = psimd_load_f32(w + 56);
215
216 vo0x0 = psimd_qfma_f32(vo0x0, vk11c1, psimd_splat1_f32(vi1x1));
217 vo1x0 = psimd_qfma_f32(vo1x0, vk11c1, psimd_splat1_f32(vi3x1));
218 vo0x1 = psimd_qfma_f32(vo0x1, vk11c1, psimd_splat3_f32(vi1x2));
219 vo1x1 = psimd_qfma_f32(vo1x1, vk11c1, psimd_splat3_f32(vi3x2));
220
221 const psimd_f32 vk21c1 = psimd_load_f32(w + 60);
222
223 vo0x0 = psimd_qfma_f32(vo0x0, vk21c1, psimd_splat1_f32(vi2x1));
224 vo1x0 = psimd_qfma_f32(vo1x0, vk21c1, psimd_splat1_f32(vi4x1));
225 vo0x1 = psimd_qfma_f32(vo0x1, vk21c1, psimd_splat3_f32(vi2x2));
226 vo1x1 = psimd_qfma_f32(vo1x1, vk21c1, psimd_splat3_f32(vi4x2));
227
         // Kernel column 1, input channel 2.
228 const psimd_f32 vk01c2 = psimd_load_f32(w + 64);
229
230 // viMx3 = ( iM4c2, iM4c1, iM4c0, iM3c2 )
231 const psimd_f32 vi0x3 = psimd_load_f32(i0); i0 += 4;
232 const psimd_f32 vi1x3 = psimd_load_f32(i1); i1 += 4;
233 const psimd_f32 vi2x3 = psimd_load_f32(i2); i2 += 4;
234 const psimd_f32 vi3x3 = psimd_load_f32(i3); i3 += 4;
235 const psimd_f32 vi4x3 = psimd_load_f32(i4); i4 += 4;
236
237 vo0x0 = psimd_qfma_f32(vo0x0, vk01c2, psimd_splat2_f32(vi0x1));
238 vo1x0 = psimd_qfma_f32(vo1x0, vk01c2, psimd_splat2_f32(vi2x1));
239 vo0x1 = psimd_qfma_f32(vo0x1, vk01c2, psimd_splat0_f32(vi0x3));
240 vo1x1 = psimd_qfma_f32(vo1x1, vk01c2, psimd_splat0_f32(vi2x3));
241
242 const psimd_f32 vk11c2 = psimd_load_f32(w + 68);
243
244 vo0x0 = psimd_qfma_f32(vo0x0, vk11c2, psimd_splat2_f32(vi1x1));
245 vo1x0 = psimd_qfma_f32(vo1x0, vk11c2, psimd_splat2_f32(vi3x1));
246 vo0x1 = psimd_qfma_f32(vo0x1, vk11c2, psimd_splat0_f32(vi1x3));
247 vo1x1 = psimd_qfma_f32(vo1x1, vk11c2, psimd_splat0_f32(vi3x3));
248
249 const psimd_f32 vk21c2 = psimd_load_f32(w + 72);
250
251 vo0x0 = psimd_qfma_f32(vo0x0, vk21c2, psimd_splat2_f32(vi2x1));
252 vo1x0 = psimd_qfma_f32(vo1x0, vk21c2, psimd_splat2_f32(vi4x1));
253 vo0x1 = psimd_qfma_f32(vo0x1, vk21c2, psimd_splat0_f32(vi2x3));
254 vo1x1 = psimd_qfma_f32(vo1x1, vk21c2, psimd_splat0_f32(vi4x3));
255
         // Kernel column 2 (right neighbour), input channel 0.
256 const psimd_f32 vk02c0 = psimd_load_f32(w + 76);
257
258 vo0x0 = psimd_qfma_f32(vo0x0, vk02c0, psimd_splat3_f32(vi0x1));
259 vo1x0 = psimd_qfma_f32(vo1x0, vk02c0, psimd_splat3_f32(vi2x1));
260 vo0x1 = psimd_qfma_f32(vo0x1, vk02c0, psimd_splat1_f32(vi0x3));
261 vo1x1 = psimd_qfma_f32(vo1x1, vk02c0, psimd_splat1_f32(vi2x3));
262
263 const psimd_f32 vk12c0 = psimd_load_f32(w + 80);
264
265 vo0x0 = psimd_qfma_f32(vo0x0, vk12c0, psimd_splat3_f32(vi1x1));
266 vo1x0 = psimd_qfma_f32(vo1x0, vk12c0, psimd_splat3_f32(vi3x1));
267 vo0x1 = psimd_qfma_f32(vo0x1, vk12c0, psimd_splat1_f32(vi1x3));
268 vo1x1 = psimd_qfma_f32(vo1x1, vk12c0, psimd_splat1_f32(vi3x3));
269
270 const psimd_f32 vk22c0 = psimd_load_f32(w + 84);
271
272 vo0x0 = psimd_qfma_f32(vo0x0, vk22c0, psimd_splat3_f32(vi2x1));
273 vo1x0 = psimd_qfma_f32(vo1x0, vk22c0, psimd_splat3_f32(vi4x1));
274 vo0x1 = psimd_qfma_f32(vo0x1, vk22c0, psimd_splat1_f32(vi2x3));
275 vo1x1 = psimd_qfma_f32(vo1x1, vk22c0, psimd_splat1_f32(vi4x3));
276
         // Kernel column 2, input channel 1.
277 const psimd_f32 vk02c1 = psimd_load_f32(w + 88);
278
279 vo0x0 = psimd_qfma_f32(vo0x0, vk02c1, psimd_splat0_f32(vi0x2));
280 vo1x0 = psimd_qfma_f32(vo1x0, vk02c1, psimd_splat0_f32(vi2x2));
281 vo0x1 = psimd_qfma_f32(vo0x1, vk02c1, psimd_splat2_f32(vi0x3));
282 vo1x1 = psimd_qfma_f32(vo1x1, vk02c1, psimd_splat2_f32(vi2x3));
283
284 const psimd_f32 vk12c1 = psimd_load_f32(w + 92);
285
286 vo0x0 = psimd_qfma_f32(vo0x0, vk12c1, psimd_splat0_f32(vi1x2));
287 vo1x0 = psimd_qfma_f32(vo1x0, vk12c1, psimd_splat0_f32(vi3x2));
288 vo0x1 = psimd_qfma_f32(vo0x1, vk12c1, psimd_splat2_f32(vi1x3));
289 vo1x1 = psimd_qfma_f32(vo1x1, vk12c1, psimd_splat2_f32(vi3x3));
290
291 const psimd_f32 vk22c1 = psimd_load_f32(w + 96);
292
293 vo0x0 = psimd_qfma_f32(vo0x0, vk22c1, psimd_splat0_f32(vi2x2));
294 vo1x0 = psimd_qfma_f32(vo1x0, vk22c1, psimd_splat0_f32(vi4x2));
295 vo0x1 = psimd_qfma_f32(vo0x1, vk22c1, psimd_splat2_f32(vi2x3));
296 vo1x1 = psimd_qfma_f32(vo1x1, vk22c1, psimd_splat2_f32(vi4x3));
297
         // Kernel column 2, input channel 2.
298 const psimd_f32 vk02c2 = psimd_load_f32(w + 100);
299
300 vo0x0 = psimd_qfma_f32(vo0x0, vk02c2, psimd_splat1_f32(vi0x2));
301 vo1x0 = psimd_qfma_f32(vo1x0, vk02c2, psimd_splat1_f32(vi2x2));
302 vo0x1 = psimd_qfma_f32(vo0x1, vk02c2, psimd_splat3_f32(vi0x3));
303 vo1x1 = psimd_qfma_f32(vo1x1, vk02c2, psimd_splat3_f32(vi2x3));
304
305 const psimd_f32 vk12c2 = psimd_load_f32(w + 104);
306
307 vo0x0 = psimd_qfma_f32(vo0x0, vk12c2, psimd_splat1_f32(vi1x2));
308 vo1x0 = psimd_qfma_f32(vo1x0, vk12c2, psimd_splat1_f32(vi3x2));
309 vo0x1 = psimd_qfma_f32(vo0x1, vk12c2, psimd_splat3_f32(vi1x3));
310 vo1x1 = psimd_qfma_f32(vo1x1, vk12c2, psimd_splat3_f32(vi3x3));
311
312 const psimd_f32 vk22c2 = psimd_load_f32(w + 108);
313
314 vo0x0 = psimd_qfma_f32(vo0x0, vk22c2, psimd_splat1_f32(vi2x2));
315 vo1x0 = psimd_qfma_f32(vo1x0, vk22c2, psimd_splat1_f32(vi4x2));
316 vo0x1 = psimd_qfma_f32(vo0x1, vk22c2, psimd_splat3_f32(vi2x3));
317 vo1x1 = psimd_qfma_f32(vo1x1, vk22c2, psimd_splat3_f32(vi4x3));
318
         // Lanes 1..3 of viMx3 hold the last pixel of this group; it becomes
         // the left neighbour of the next group.
319 vi0x0 = vi0x3;
320 vi1x0 = vi1x3;
321 vi2x0 = vi2x3;
322 vi3x0 = vi3x3;
323 vi4x0 = vi4x3;
324
         // Clamp to [min, max].
325 vo0x0 = psimd_max_f32(vo0x0, vmin);
326 vo1x0 = psimd_max_f32(vo1x0, vmin);
327 vo0x1 = psimd_max_f32(vo0x1, vmin);
328 vo1x1 = psimd_max_f32(vo1x1, vmin);
329
330 vo0x0 = psimd_min_f32(vo0x0, vmax);
331 vo1x0 = psimd_min_f32(vo1x0, vmax);
332 vo0x1 = psimd_min_f32(vo0x1, vmax);
333 vo1x1 = psimd_min_f32(vo1x1, vmax);
334
         // Interleave the two columns so that each 2-float store writes both
         // output columns of a single channel.
335 const psimd_f32 vo0c01 = psimd_interleave_lo_f32(vo0x0, vo0x1);
336 const psimd_f32 vo0c23 = psimd_interleave_hi_f32(vo0x0, vo0x1);
337 const psimd_f32 vo1c01 = psimd_interleave_lo_f32(vo1x0, vo1x1);
338 const psimd_f32 vo1c23 = psimd_interleave_hi_f32(vo1x0, vo1x1);
339
340 // Always 2+ output width elements remaining
         // Row 1 is written before row 0 (see the output1 aliasing above).
341 psimd_store2_f32(o1c0, vo1c01); o1c0 += 2;
342 psimd_store2_f32(o1c1, psimd_concat_hi_f32(vo1c01, vo1c01)); o1c1 += 2;
343 psimd_store2_f32(o1c2, vo1c23); o1c2 += 2;
344 psimd_store2_f32(o1c3, psimd_concat_hi_f32(vo1c23, vo1c23)); o1c3 += 2;
345
346 psimd_store2_f32(o0c0, vo0c01); o0c0 += 2;
347 psimd_store2_f32(o0c1, psimd_concat_hi_f32(vo0c01, vo0c01)); o0c1 += 2;
348 psimd_store2_f32(o0c2, vo0c23); o0c2 += 2;
349 psimd_store2_f32(o0c3, psimd_concat_hi_f32(vo0c23, vo0c23)); o0c3 += 2;
350 }
       // Column tail: 1 to 3 input pixels remain.
       //   iw == 1 or 2 -> one output column; iw == 3 -> two output columns.
       // The row pointers are not advanced here, so the rewind after the tail
       // only has to undo the main loop.
       // Some vo0x1 / vo1x1 updates below are unguarded; they are harmless
       // because those accumulators are stored only when iw == 3.
351 assert(iw < 4);
352 if XNN_UNLIKELY(iw != 0) {
353 psimd_f32 vo0x0 = psimd_load_f32(w);
354 psimd_f32 vo1x0 = vo0x0;
355 psimd_f32 vo0x1 = vo0x0;
356 psimd_f32 vo1x1 = vo0x0;
357
358 const psimd_f32 vk00c0 = psimd_load_f32(w + 4);
359
         // NOTE(review): this is a full 4-float load even when iw == 1, where
         // only 3 floats of the row remain - presumably the buffers carry read
         // slack past the end; confirm with the caller.
360 // viMx1 = ( iM2c0, iM1c2, iM1c1, iM1c0 )
361 psimd_f32 vi0x1 = psimd_load_f32(i0);
362 psimd_f32 vi1x1 = psimd_load_f32(i1);
363 psimd_f32 vi2x1 = psimd_load_f32(i2);
364 psimd_f32 vi3x1 = psimd_load_f32(i3);
365 psimd_f32 vi4x1 = psimd_load_f32(i4);
366
367 vo0x0 = psimd_qfma_f32(vo0x0, vk00c0, psimd_splat1_f32(vi0x0));
368 vo1x0 = psimd_qfma_f32(vo1x0, vk00c0, psimd_splat1_f32(vi2x0));
369 if (iw > 2) {
370 vo0x1 = psimd_qfma_f32(vo0x1, vk00c0, psimd_splat3_f32(vi0x1));
371 vo1x1 = psimd_qfma_f32(vo1x1, vk00c0, psimd_splat3_f32(vi2x1));
372 }
373
374 const psimd_f32 vk10c0 = psimd_load_f32(w + 8);
375
376 vo0x0 = psimd_qfma_f32(vo0x0, vk10c0, psimd_splat1_f32(vi1x0));
377 vo1x0 = psimd_qfma_f32(vo1x0, vk10c0, psimd_splat1_f32(vi3x0));
378 if (iw > 2) {
379 vo0x1 = psimd_qfma_f32(vo0x1, vk10c0, psimd_splat3_f32(vi1x1));
380 vo1x1 = psimd_qfma_f32(vo1x1, vk10c0, psimd_splat3_f32(vi3x1));
381 }
382
383 const psimd_f32 vk20c0 = psimd_load_f32(w + 12);
384
385 vo0x0 = psimd_qfma_f32(vo0x0, vk20c0, psimd_splat1_f32(vi2x0));
386 vo1x0 = psimd_qfma_f32(vo1x0, vk20c0, psimd_splat1_f32(vi4x0));
387 if (iw > 2) {
388 vo0x1 = psimd_qfma_f32(vo0x1, vk20c0, psimd_splat3_f32(vi2x1));
389 vo1x1 = psimd_qfma_f32(vo1x1, vk20c0, psimd_splat3_f32(vi4x1));
390 }
391
392 const psimd_f32 vk00c1 = psimd_load_f32(w + 16);
393
         // The second vector of each row is loaded only when a second pixel
         // exists; otherwise it stays zero.
         // NOTE(review): with iw == 2 only 2 of these 4 floats belong to the
         // row - same read-slack assumption as above.
394 psimd_f32 vi0x2 = psimd_zero_f32();
395 psimd_f32 vi1x2 = psimd_zero_f32();
396 psimd_f32 vi2x2 = psimd_zero_f32();
397 psimd_f32 vi3x2 = psimd_zero_f32();
398 psimd_f32 vi4x2 = psimd_zero_f32();
399 if (iw >= 2) {
400 // viMx2 = ( iM3c1, iM3c0, iM2c2, iM2c1 )
401 vi0x2 = psimd_load_f32(i0 + 4);
402 vi1x2 = psimd_load_f32(i1 + 4);
403 vi2x2 = psimd_load_f32(i2 + 4);
404 vi3x2 = psimd_load_f32(i3 + 4);
405 vi4x2 = psimd_load_f32(i4 + 4);
406 }
407
408 vo0x0 = psimd_qfma_f32(vo0x0, vk00c1, psimd_splat2_f32(vi0x0));
409 vo1x0 = psimd_qfma_f32(vo1x0, vk00c1, psimd_splat2_f32(vi2x0));
410 vo0x1 = psimd_qfma_f32(vo0x1, vk00c1, psimd_splat0_f32(vi0x2));
411 vo1x1 = psimd_qfma_f32(vo1x1, vk00c1, psimd_splat0_f32(vi2x2));
412
413 const psimd_f32 vk10c1 = psimd_load_f32(w + 20);
414
415 vo0x0 = psimd_qfma_f32(vo0x0, vk10c1, psimd_splat2_f32(vi1x0));
416 vo1x0 = psimd_qfma_f32(vo1x0, vk10c1, psimd_splat2_f32(vi3x0));
417 vo0x1 = psimd_qfma_f32(vo0x1, vk10c1, psimd_splat0_f32(vi1x2));
418 vo1x1 = psimd_qfma_f32(vo1x1, vk10c1, psimd_splat0_f32(vi3x2));
419
420 const psimd_f32 vk20c1 = psimd_load_f32(w + 24);
421
422 vo0x0 = psimd_qfma_f32(vo0x0, vk20c1, psimd_splat2_f32(vi2x0));
423 vo1x0 = psimd_qfma_f32(vo1x0, vk20c1, psimd_splat2_f32(vi4x0));
424 vo0x1 = psimd_qfma_f32(vo0x1, vk20c1, psimd_splat0_f32(vi2x2));
425 vo1x1 = psimd_qfma_f32(vo1x1, vk20c1, psimd_splat0_f32(vi4x2));
426
427 const psimd_f32 vk00c2 = psimd_load_f32(w + 28);
428
429 vo0x0 = psimd_qfma_f32(vo0x0, vk00c2, psimd_splat3_f32(vi0x0));
430 vo1x0 = psimd_qfma_f32(vo1x0, vk00c2, psimd_splat3_f32(vi2x0));
431 vo0x1 = psimd_qfma_f32(vo0x1, vk00c2, psimd_splat1_f32(vi0x2));
432 vo1x1 = psimd_qfma_f32(vo1x1, vk00c2, psimd_splat1_f32(vi2x2));
433
434 const psimd_f32 vk10c2 = psimd_load_f32(w + 32);
435
436 vo0x0 = psimd_qfma_f32(vo0x0, vk10c2, psimd_splat3_f32(vi1x0));
437 vo1x0 = psimd_qfma_f32(vo1x0, vk10c2, psimd_splat3_f32(vi3x0));
438 vo0x1 = psimd_qfma_f32(vo0x1, vk10c2, psimd_splat1_f32(vi1x2));
439 vo1x1 = psimd_qfma_f32(vo1x1, vk10c2, psimd_splat1_f32(vi3x2));
440
441 const psimd_f32 vk20c2 = psimd_load_f32(w + 36);
442
443 vo0x0 = psimd_qfma_f32(vo0x0, vk20c2, psimd_splat3_f32(vi2x0));
444 vo1x0 = psimd_qfma_f32(vo1x0, vk20c2, psimd_splat3_f32(vi4x0));
445 vo0x1 = psimd_qfma_f32(vo0x1, vk20c2, psimd_splat1_f32(vi2x2));
446 vo1x1 = psimd_qfma_f32(vo1x1, vk20c2, psimd_splat1_f32(vi4x2));
447
         // Kernel column 1 (centre pixel). For the second output column the
         // centre pixel exists only when iw == 3, hence the iw > 2 guards.
448 const psimd_f32 vk01c0 = psimd_load_f32(w + 40);
449
450 vo0x0 = psimd_qfma_f32(vo0x0, vk01c0, psimd_splat0_f32(vi0x1));
451 vo1x0 = psimd_qfma_f32(vo1x0, vk01c0, psimd_splat0_f32(vi2x1));
452 if (iw > 2) {
453 vo0x1 = psimd_qfma_f32(vo0x1, vk01c0, psimd_splat2_f32(vi0x2));
454 vo1x1 = psimd_qfma_f32(vo1x1, vk01c0, psimd_splat2_f32(vi2x2));
455 }
456
457 const psimd_f32 vk11c0 = psimd_load_f32(w + 44);
458
459 vo0x0 = psimd_qfma_f32(vo0x0, vk11c0, psimd_splat0_f32(vi1x1));
460 vo1x0 = psimd_qfma_f32(vo1x0, vk11c0, psimd_splat0_f32(vi3x1));
461 if (iw > 2) {
462 vo0x1 = psimd_qfma_f32(vo0x1, vk11c0, psimd_splat2_f32(vi1x2));
463 vo1x1 = psimd_qfma_f32(vo1x1, vk11c0, psimd_splat2_f32(vi3x2));
464 }
465
466 const psimd_f32 vk21c0 = psimd_load_f32(w + 48);
467
468 vo0x0 = psimd_qfma_f32(vo0x0, vk21c0, psimd_splat0_f32(vi2x1));
469 vo1x0 = psimd_qfma_f32(vo1x0, vk21c0, psimd_splat0_f32(vi4x1));
470 if (iw > 2) {
471 vo0x1 = psimd_qfma_f32(vo0x1, vk21c0, psimd_splat2_f32(vi2x2));
472 vo1x1 = psimd_qfma_f32(vo1x1, vk21c0, psimd_splat2_f32(vi4x2));
473 }
474
475 const psimd_f32 vk01c1 = psimd_load_f32(w + 52);
476
477 vo0x0 = psimd_qfma_f32(vo0x0, vk01c1, psimd_splat1_f32(vi0x1));
478 vo1x0 = psimd_qfma_f32(vo1x0, vk01c1, psimd_splat1_f32(vi2x1));
479 if (iw > 2) {
480 vo0x1 = psimd_qfma_f32(vo0x1, vk01c1, psimd_splat3_f32(vi0x2));
481 vo1x1 = psimd_qfma_f32(vo1x1, vk01c1, psimd_splat3_f32(vi2x2));
482 }
483
484 const psimd_f32 vk11c1 = psimd_load_f32(w + 56);
485
486 vo0x0 = psimd_qfma_f32(vo0x0, vk11c1, psimd_splat1_f32(vi1x1));
487 vo1x0 = psimd_qfma_f32(vo1x0, vk11c1, psimd_splat1_f32(vi3x1));
488 if (iw > 2) {
489 vo0x1 = psimd_qfma_f32(vo0x1, vk11c1, psimd_splat3_f32(vi1x2));
490 vo1x1 = psimd_qfma_f32(vo1x1, vk11c1, psimd_splat3_f32(vi3x2));
491 }
492
493 const psimd_f32 vk21c1 = psimd_load_f32(w + 60);
494
495 vo0x0 = psimd_qfma_f32(vo0x0, vk21c1, psimd_splat1_f32(vi2x1));
496 vo1x0 = psimd_qfma_f32(vo1x0, vk21c1, psimd_splat1_f32(vi4x1));
497 if (iw > 2) {
498 vo0x1 = psimd_qfma_f32(vo0x1, vk21c1, psimd_splat3_f32(vi2x2));
499 vo1x1 = psimd_qfma_f32(vo1x1, vk21c1, psimd_splat3_f32(vi4x2));
500 }
501
502 const psimd_f32 vk01c2 = psimd_load_f32(w + 64);
503
         // Only the last float of the third pixel can remain at this point, so
         // a single-element load is used and the other lanes stay zero.
504 psimd_f32 vi0x3 = psimd_zero_f32();
505 psimd_f32 vi1x3 = psimd_zero_f32();
506 psimd_f32 vi2x3 = psimd_zero_f32();
507 psimd_f32 vi3x3 = psimd_zero_f32();
508 psimd_f32 vi4x3 = psimd_zero_f32();
509 if (iw > 2) {
510 // viMx3 = ( 0.0, 0.0, 0.0, iM3c2 )
511 vi0x3 = psimd_load1_f32(i0 + 8);
512 vi1x3 = psimd_load1_f32(i1 + 8);
513 vi2x3 = psimd_load1_f32(i2 + 8);
514 vi3x3 = psimd_load1_f32(i3 + 8);
515 vi4x3 = psimd_load1_f32(i4 + 8);
516 }
517
518 vo0x0 = psimd_qfma_f32(vo0x0, vk01c2, psimd_splat2_f32(vi0x1));
519 vo1x0 = psimd_qfma_f32(vo1x0, vk01c2, psimd_splat2_f32(vi2x1));
520 vo0x1 = psimd_qfma_f32(vo0x1, vk01c2, psimd_splat0_f32(vi0x3));
521 vo1x1 = psimd_qfma_f32(vo1x1, vk01c2, psimd_splat0_f32(vi2x3));
522
523 const psimd_f32 vk11c2 = psimd_load_f32(w + 68);
524
525 vo0x0 = psimd_qfma_f32(vo0x0, vk11c2, psimd_splat2_f32(vi1x1));
526 vo1x0 = psimd_qfma_f32(vo1x0, vk11c2, psimd_splat2_f32(vi3x1));
527 vo0x1 = psimd_qfma_f32(vo0x1, vk11c2, psimd_splat0_f32(vi1x3));
528 vo1x1 = psimd_qfma_f32(vo1x1, vk11c2, psimd_splat0_f32(vi3x3));
529
530 const psimd_f32 vk21c2 = psimd_load_f32(w + 72);
531
532 vo0x0 = psimd_qfma_f32(vo0x0, vk21c2, psimd_splat2_f32(vi2x1));
533 vo1x0 = psimd_qfma_f32(vo1x0, vk21c2, psimd_splat2_f32(vi4x1));
534 vo0x1 = psimd_qfma_f32(vo0x1, vk21c2, psimd_splat0_f32(vi2x3));
535 vo1x1 = psimd_qfma_f32(vo1x1, vk21c2, psimd_splat0_f32(vi4x3));
536
         // Kernel column 2 (right neighbour) for the first output column; it
         // is applied only when a second input pixel exists. There are no
         // column 2 taps for the second output column anywhere in the tail.
537 if (iw >= 2) {
538 const psimd_f32 vk02c0 = psimd_load_f32(w + 76);
539
540 vo0x0 = psimd_qfma_f32(vo0x0, vk02c0, psimd_splat3_f32(vi0x1));
541 vo1x0 = psimd_qfma_f32(vo1x0, vk02c0, psimd_splat3_f32(vi2x1));
542
543 const psimd_f32 vk12c0 = psimd_load_f32(w + 80);
544
545 vo0x0 = psimd_qfma_f32(vo0x0, vk12c0, psimd_splat3_f32(vi1x1));
546 vo1x0 = psimd_qfma_f32(vo1x0, vk12c0, psimd_splat3_f32(vi3x1));
547
548 const psimd_f32 vk22c0 = psimd_load_f32(w + 84);
549
550 vo0x0 = psimd_qfma_f32(vo0x0, vk22c0, psimd_splat3_f32(vi2x1));
551 vo1x0 = psimd_qfma_f32(vo1x0, vk22c0, psimd_splat3_f32(vi4x1));
552
553 const psimd_f32 vk02c1 = psimd_load_f32(w + 88);
554
555 vo0x0 = psimd_qfma_f32(vo0x0, vk02c1, psimd_splat0_f32(vi0x2));
556 vo1x0 = psimd_qfma_f32(vo1x0, vk02c1, psimd_splat0_f32(vi2x2));
557
558 const psimd_f32 vk12c1 = psimd_load_f32(w + 92);
559
560 vo0x0 = psimd_qfma_f32(vo0x0, vk12c1, psimd_splat0_f32(vi1x2));
561 vo1x0 = psimd_qfma_f32(vo1x0, vk12c1, psimd_splat0_f32(vi3x2));
562
563 const psimd_f32 vk22c1 = psimd_load_f32(w + 96);
564
565 vo0x0 = psimd_qfma_f32(vo0x0, vk22c1, psimd_splat0_f32(vi2x2));
566 vo1x0 = psimd_qfma_f32(vo1x0, vk22c1, psimd_splat0_f32(vi4x2));
567
568 const psimd_f32 vk02c2 = psimd_load_f32(w + 100);
569
570 vo0x0 = psimd_qfma_f32(vo0x0, vk02c2, psimd_splat1_f32(vi0x2));
571 vo1x0 = psimd_qfma_f32(vo1x0, vk02c2, psimd_splat1_f32(vi2x2));
572
573 const psimd_f32 vk12c2 = psimd_load_f32(w + 104);
574
575 vo0x0 = psimd_qfma_f32(vo0x0, vk12c2, psimd_splat1_f32(vi1x2));
576 vo1x0 = psimd_qfma_f32(vo1x0, vk12c2, psimd_splat1_f32(vi3x2));
577
578 const psimd_f32 vk22c2 = psimd_load_f32(w + 108);
579
580 vo0x0 = psimd_qfma_f32(vo0x0, vk22c2, psimd_splat1_f32(vi2x2));
581 vo1x0 = psimd_qfma_f32(vo1x0, vk22c2, psimd_splat1_f32(vi4x2));
582 }
583
         // Clamp to [min, max].
584 vo0x0 = psimd_max_f32(vo0x0, vmin);
585 vo1x0 = psimd_max_f32(vo1x0, vmin);
586 vo0x1 = psimd_max_f32(vo0x1, vmin);
587 vo1x1 = psimd_max_f32(vo1x1, vmin);
588
589 vo0x0 = psimd_min_f32(vo0x0, vmax);
590 vo1x0 = psimd_min_f32(vo1x0, vmax);
591 vo0x1 = psimd_min_f32(vo0x1, vmax);
592 vo1x1 = psimd_min_f32(vo1x1, vmax);
593
594 if (iw == 3) {
595 // Exactly 2 output width elements remaining
596 const psimd_f32 vo0c01 = psimd_interleave_lo_f32(vo0x0, vo0x1);
597 const psimd_f32 vo0c23 = psimd_interleave_hi_f32(vo0x0, vo0x1);
598 const psimd_f32 vo1c01 = psimd_interleave_lo_f32(vo1x0, vo1x1);
599 const psimd_f32 vo1c23 = psimd_interleave_hi_f32(vo1x0, vo1x1);
600
601 psimd_store2_f32(o1c0, vo1c01); o1c0 += 2;
602 psimd_store2_f32(o1c1, psimd_concat_hi_f32(vo1c01, vo1c01)); o1c1 += 2;
603 psimd_store2_f32(o1c2, vo1c23); o1c2 += 2;
604 psimd_store2_f32(o1c3, psimd_concat_hi_f32(vo1c23, vo1c23)); o1c3 += 2;
605
606 psimd_store2_f32(o0c0, vo0c01); o0c0 += 2;
607 psimd_store2_f32(o0c1, psimd_concat_hi_f32(vo0c01, vo0c01)); o0c1 += 2;
608 psimd_store2_f32(o0c2, vo0c23); o0c2 += 2;
609 psimd_store2_f32(o0c3, psimd_concat_hi_f32(vo0c23, vo0c23)); o0c3 += 2;
610 } else {
611 // Exactly 1 output width element remaining
           // Only the first column accumulators are written; lane n of each
           // goes to channel n.
612
613 psimd_store1_f32(o1c0, psimd_splat0_f32(vo1x0)); o1c0 += 1;
614 psimd_store1_f32(o1c1, psimd_splat1_f32(vo1x0)); o1c1 += 1;
615 psimd_store1_f32(o1c2, psimd_splat2_f32(vo1x0)); o1c2 += 1;
616 psimd_store1_f32(o1c3, psimd_splat3_f32(vo1x0)); o1c3 += 1;
617
618 psimd_store1_f32(o0c0, psimd_splat0_f32(vo0x0)); o0c0 += 1;
619 psimd_store1_f32(o0c1, psimd_splat1_f32(vo0x0)); o0c1 += 1;
620 psimd_store1_f32(o0c2, psimd_splat2_f32(vo0x0)); o0c2 += 1;
621 psimd_store1_f32(o0c3, psimd_splat3_f32(vo0x0)); o0c3 += 1;
622 }
623 }
624 // Move output pointers back to the position of the first pixel in a row,
625 // and forward to the next block of output channels.
626 o0c0 = (float*) ((uintptr_t) o0c0 + output_channel_increment);
627 o0c1 = (float*) ((uintptr_t) o0c1 + output_channel_increment);
628 o0c2 = (float*) ((uintptr_t) o0c2 + output_channel_increment);
629 o0c3 = (float*) ((uintptr_t) o0c3 + output_channel_increment);
630 o1c0 = (float*) ((uintptr_t) o1c0 + output_channel_increment);
631 o1c1 = (float*) ((uintptr_t) o1c1 + output_channel_increment);
632 o1c2 = (float*) ((uintptr_t) o1c2 + output_channel_increment);
633 o1c3 = (float*) ((uintptr_t) o1c3 + output_channel_increment);
634 // Revert input pointers to the position of the first pixel in a row
635 i0 = (const float*) ((uintptr_t) i0 - input_width_increment);
636 i1 = (const float*) ((uintptr_t) i1 - input_width_increment);
637 i2 = (const float*) ((uintptr_t) i2 - input_width_increment);
638 i3 = (const float*) ((uintptr_t) i3 - input_width_increment);
639 i4 = (const float*) ((uintptr_t) i4 - input_width_increment);
640 // Move to the block of weights for the next 4 output channels
       // 112 floats = 4 initial accumulator values + 27 taps x 4 channels.
       // doz(c, 4) leaves 0 once fewer than 4 channels remain, ending the loop.
641 w += 112;
642 c = doz(c, 4);
643 } while (c != 0);
644 // Move output pointers forward to the next two rows
645 output0 = (float*) ((uintptr_t) output1 + output_height_stride);
646 output1 = (float*) ((uintptr_t) output0 + output_height_stride);
647 // Move input pointers forward to the next four rows
     // The last row of this pass (i4) is the first row of the next pass.
648 i0 = i4;
649 i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
650 i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
651 i3 = (const float*) ((uintptr_t) i2 + input_height_stride);
652 i4 = (const float*) ((uintptr_t) i3 + input_height_stride);
653 }
654}