// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <stdint.h>
#include <xnnpack/math.h>
#include <xnnpack/operator.h>

16static inline void xnn_pack_q8_gemm_goi_w(
17 size_t g,
18 size_t nc,
19 size_t kc,
20 uint32_t nr,
21 uint32_t kr,
22 uint8_t izp,
23 uint8_t kzp,
24 const uint8_t* k,
25 const int32_t* b,
26 void* packed_w)
27{
28 const int32_t boff = (int32_t) kc * (int32_t) izp * (int32_t) kzp;
29 do {
30 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
31 const size_t nr_block_size = min(nc - nr_block_start, nr);
32 int32_t* packed_b = (int32_t*) packed_w;
33 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
34 *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
35 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
36 }
37 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
38 for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
39 const size_t kr_block_size = min(kc - kr_block_start, kr);
40 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
41 int32_t ksum = 0;
42 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
43 const uint8_t kv = k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
44 ksum += (int32_t) kv;
45 *((uint8_t*) packed_w) = kv;
46 packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
47 }
48 packed_b[nr_block_offset] -= ksum * (int32_t) izp;
49 packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
50 }
51 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
52 }
53 }
54 k += nc * kc;
55 b += nc;
56 } while (--g != 0);
57}
58
59static inline void xnn_pack_q8_conv_goki_w(
60 size_t g,
61 size_t nc,
62 size_t ks,
63 size_t kc,
64 uint32_t nr,
65 uint32_t kr,
66 uint8_t izp,
67 uint8_t kzp,
68 const uint8_t* k,
69 const int32_t* b,
70 void* packed_w)
71{
72 const int32_t boff = (int32_t) ks * (int32_t) kc * (int32_t) izp * (int32_t) kzp;
73 do {
74 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
75 const size_t nr_block_size = min(nc - nr_block_start, nr);
76 int32_t* packed_b = (int32_t*) packed_w;
77 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
78 *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
79 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
80 }
81 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
82 for (size_t ki = 0; ki < ks; ki++) {
83 for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
84 const size_t kr_block_size = min(kc - kr_block_start, kr);
85 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
86 int32_t ksum = 0;
87 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
88 const uint8_t kv =
89 k[((nr_block_start + nr_block_offset) * ks + ki) * kc + (kr_block_start + kr_block_offset)];
90 ksum += (int32_t) kv;
91 *((uint8_t*) packed_w) = kv;
92 packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
93 }
94 packed_b[nr_block_offset] -= ksum * (int32_t) izp;
95 packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
96 }
97 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
98 }
99 }
100 }
101 k += ks * kc * nc;
102 b += nc;
103 } while (--g != 0);
104}
105
106static inline void xnn_pack_q8_conv_kgo_w(
107 size_t g,
108 size_t nc,
109 size_t ks,
110 uint32_t nr,
111 uint32_t kr,
112 uint8_t izp,
113 uint8_t kzp,
114 const uint8_t* k,
115 const int32_t* b,
116 void* packed_w)
117{
118 const int32_t boff = (int32_t) ks * (int32_t) izp * (int32_t) kzp;
119 for (size_t i = 0; i < g; i++) {
120 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
121 const size_t nr_block_size = min(nc - nr_block_start, nr);
122 int32_t* packed_b = (int32_t*) packed_w;
123 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
124 *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
125 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
126 }
127 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
128 for (size_t ki = 0; ki < ks; ki++) {
129 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
130 const uint8_t kv =
131 k[ki * g * nc + (nr_block_start + nr_block_offset)];
132 *((uint8_t*) packed_w) = kv;
133 packed_b[nr_block_offset] -= (int32_t) kv * (int32_t) izp;
134 packed_w = (void*) ((uintptr_t) packed_w + kr * sizeof(uint8_t));
135 }
136 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
137 }
138 }
139 k += nc;
140 b += nc;
141 }
142}
143
144static inline void xnn_pack_q8_deconv_goki_w(
145 size_t g,
146 size_t nc,
147 size_t kh,
148 size_t kw,
149 size_t kc,
150 size_t sh,
151 size_t sw,
152 size_t nr,
153 size_t kr,
154 uint8_t izp,
155 uint8_t kzp,
156 const uint8_t* k,
157 const int32_t* b,
158 void* packed_w,
159 struct subconvolution_params* params)
160{
161 for (size_t i = 0; i < g; i++) {
162 for (size_t oy = 0; oy < sh; oy++) {
163 for (size_t ox = 0; ox < sw; ox++) {
164 if (i == 0) {
165 (*params++).weights = packed_w;
166 }
167 const int32_t boff = (int32_t) divide_round_up(kh - oy, sh) * (int32_t) divide_round_up(kw - ox, sw) * (int32_t) kc * (int32_t) izp * (int32_t) kzp;
168 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
169 const size_t nr_block_size = min(nc - nr_block_start, nr);
170 int32_t* packed_b = (int32_t*) packed_w;
171 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
172 *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
173 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
174 }
175 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
176 for (size_t ky = oy; ky < kh; ky += sh) {
177 for (size_t kx = ox; kx < kw; kx += sw) {
178 for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
179 const size_t kr_block_size = min(kc - kr_block_start, kr);
180 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
181 int32_t ksum = 0;
182 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
183 const uint8_t kv =
184 k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + (kr_block_start + kr_block_offset)];
185 ksum += (int32_t) kv;
186 *((uint8_t*) packed_w) = kv;
187 packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
188 }
189 packed_b[nr_block_offset] -= ksum * (int32_t) izp;
190 packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
191 }
192 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
193 }
194 }
195 }
196 }
197 }
198 }
199 k += kh * kw * kc * nc;
200 b += nc;
201 }
202}
203
204static inline void xnn_pack_q8_dwconv_ghw_w(
205 size_t h,
206 size_t w,
207 size_t c,
208 size_t cr,
209 uint8_t izp,
210 uint8_t kzp,
211 const uint8_t* k,
212 const int32_t* b,
213 void* packed_w)
214{
215 const int32_t boff = (int32_t) h * (int32_t) w * (int32_t) izp * (int32_t) kzp;
216 for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
217 const size_t cr_block_size = min(c - cr_block_start, cr);
218 int32_t* packed_b = (int32_t*) packed_w;
219 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
220 *((int32_t*) packed_w) = b[cr_block_start + cr_block_offset] + boff;
221 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
222 }
223 packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(int32_t));
224 for (size_t x = 0; x < w; x++) {
225 for (size_t y = 0; y < h; y++) {
226 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
227 const uint8_t kv = k[((cr_block_start + cr_block_offset) * h + y) * w + x];
228 packed_b[cr_block_offset] -= (int32_t) kv * (int32_t) izp;
229 *((uint8_t*) packed_w) = kv;
230 packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
231 }
232 packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(uint8_t));
233 }
234 }
235 }
236}
237
238static inline void xnn_pack_q8_dwconv_hwg_w(
239 size_t h,
240 size_t w,
241 size_t c,
242 size_t cr,
243 uint8_t izp,
244 uint8_t kzp,
245 const uint8_t* k,
246 const int32_t* b,
247 void* packed_w)
248{
249 const int32_t boff = (int32_t) h * (int32_t) w * (int32_t) izp * (int32_t) kzp;
250 for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
251 const size_t cr_block_size = min(c - cr_block_start, cr);
252 int32_t* packed_b = (int32_t*) packed_w;
253 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
254 *((int32_t*) packed_w) = b[cr_block_start + cr_block_offset] + boff;
255 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
256 }
257 packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(int32_t));
258 for (size_t x = 0; x < w; x++) {
259 for (size_t y = 0; y < h; y++) {
260 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
261 const uint8_t kv = k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
262 packed_b[cr_block_offset] -= (int32_t) kv * (int32_t) izp;
263 *((uint8_t*) packed_w) = kv;
264 packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
265 }
266 packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(uint8_t));
267 }
268 }
269 }
270}
271
272static inline void xnn_pack_f16_gemm_goi_w(
273 size_t g,
274 size_t nc,
275 size_t kc,
276 size_t nr,
277 size_t kr,
278 const uint16_t* k,
279 const uint16_t* b,
280 uint16_t* packed_w)
281{
282 do {
283 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
284 const size_t nr_block_size = min(nc - nr_block_start, nr);
285 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
286 *packed_w++ = b[nr_block_start + nr_block_offset];
287 }
288 packed_w += nr - nr_block_size;
289 for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
290 const size_t kr_block_size = min(kc - kr_block_start, kr);
291 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
292 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
293 *packed_w++ =
294 k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
295 }
296 packed_w += kr - kr_block_size;
297 }
298 packed_w += (nr - nr_block_size) * kr;
299 }
300 }
301 k += nc * kc;
302 b += nc;
303 } while (--g != 0);
304}
305
306static inline void xnn_pack_f32_gemm_goi_w(
307 size_t g,
308 size_t nc,
309 size_t kc,
310 size_t nr,
311 size_t kr,
312 size_t sr,
313 const float* k,
314 const float* b,
315 float* packed_w)
316{
317 const size_t skr = sr * kr;
318 const size_t skc = round_down_po2(kc, skr);
319 const size_t sr_mask = (sr - 1) * kr;
320 do {
321 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
322 const size_t nr_block_size = min(nc - nr_block_start, nr);
323 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
324 *packed_w++ = b[nr_block_start + nr_block_offset];
325 }
326 packed_w += nr - nr_block_size;
327
328 for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
329 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
330 for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
331 *packed_w++ =
332 k[(nr_block_start + nr_block_offset) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
333 }
334 }
335 packed_w += (nr - nr_block_size) * kr;
336 }
337
338 for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
339 const size_t kr_block_size = min(kc - kr_block_start, kr);
340 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
341 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
342 *packed_w++ =
343 k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
344 }
345 packed_w += kr - kr_block_size;
346 }
347 packed_w += (nr - nr_block_size) * kr;
348 }
349 }
350 k += nc * kc;
351 b += nc;
352 } while (--g != 0);
353}
354
355static inline void xnn_pack_f32_gemminc_goi_w(
356 size_t g,
357 size_t nc,
358 size_t kc,
359 size_t nr,
360 size_t kr,
361 size_t sr,
362 const float* k,
363 float* packed_w)
364{
365 const size_t skr = sr * kr;
366 const size_t skc = round_down_po2(kc, skr);
367 const size_t sr_mask = (sr - 1) * kr;
368 do {
369 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
370 const size_t nr_block_size = min(nc - nr_block_start, nr);
371
372 for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
373 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
374 for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
375 *packed_w++ =
376 k[(nr_block_start + nr_block_offset) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
377 }
378 }
379 packed_w += (nr - nr_block_size) * kr;
380 }
381
382 for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
383 const size_t kr_block_size = min(kc - kr_block_start, kr);
384 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
385 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
386 *packed_w++ =
387 k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
388 }
389 packed_w += kr - kr_block_size;
390 }
391 packed_w += (nr - nr_block_size) * kr;
392 }
393 }
394 k += nc * kc;
395 } while (--g != 0);
396}
397
398static inline void xnn_pack_f32_conv_goki_w(
399 size_t g,
400 size_t nc,
401 size_t ks,
402 size_t kc,
403 size_t nr,
404 size_t kr,
405 size_t sr,
406 const float* k,
407 const float* b,
408 float* packed_w)
409{
410 const size_t skr = sr * kr;
411 const size_t skc = round_down_po2(kc, skr);
412 const size_t sr_mask = (sr - 1) * kr;
413 do {
414 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
415 const size_t nr_block_size = min(nc - nr_block_start, nr);
416 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
417 *packed_w++ = b[nr_block_start + nr_block_offset];
418 }
419 packed_w += nr - nr_block_size;
420
421 for (size_t ki = 0; ki < ks; ki++) {
422 for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
423 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
424 for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
425 *packed_w++ =
426 k[((nr_block_start + nr_block_offset) * ks + ki) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
427 }
428 }
429 packed_w += (nr - nr_block_size) * kr;
430 }
431
432 for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
433 const size_t kr_block_size = min(kc - kr_block_start, kr);
434 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
435 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
436 *packed_w++ =
437 k[((nr_block_start + nr_block_offset) * ks + ki) * kc + (kr_block_start + kr_block_offset)];
438 }
439 packed_w += kr - kr_block_size;
440 }
441 packed_w += (nr - nr_block_size) * kr;
442 }
443 }
444 }
445 k += ks * kc * nc;
446 b += nc;
447 } while (--g != 0);
448}
449
450static inline void xnn_pack_f32_conv_kgo_w(
451 size_t g,
452 size_t nc,
453 size_t ks,
454 size_t nr,
455 size_t kr,
456 const float* k,
457 const float* b,
458 float* packed_w)
459{
460 for (size_t i = 0; i < g; i++) {
461 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
462 const size_t nr_block_size = min(nc - nr_block_start, nr);
463 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
464 *packed_w++ = b[nr_block_start + nr_block_offset];
465 }
466 packed_w += nr - nr_block_size;
467 for (size_t ki = 0; ki < ks; ki++) {
468 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
469 *packed_w =
470 k[ki * g * nc + (nr_block_start + nr_block_offset)];
471 packed_w += kr;
472 }
473 packed_w += (nr - nr_block_size) * kr;
474 }
475 }
476 k += nc;
477 b += nc;
478 }
479}
480
481static inline void xnn_pack_f32_dconv_oki_w(
482 size_t nc,
483 size_t kc,
484 size_t nr,
485 size_t kh,
486 size_t kw,
487 const float* k,
488 const float* b,
489 float* packed_w)
490{
491 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
492 const size_t nr_block_size = min(nc - nr_block_start, nr);
493 for (size_t nr_block_offset = 0; nr_block_offset < nr; nr_block_offset++) {
494 *packed_w++ = b[nr_block_start + min(nr_block_offset, nr_block_size - 1)];
495 }
496
497 for (size_t kx = 0; kx < kw; kx++) {
498 for (size_t c = 0; c < kc; c++) {
499 for (size_t ky = 0; ky < kh; ky++) {
500 for (size_t nr_block_offset = 0; nr_block_offset < nr; nr_block_offset++) {
501 *packed_w++ = k[(((nr_block_start + min(nr_block_offset, nr_block_size - 1)) * kh + ky) * kw + kx) * kc + c];
502 }
503 }
504 }
505 }
506 }
507}
508
509static inline void xnn_pack_f32_deconv_goki_w(
510 size_t g,
511 size_t nc,
512 size_t kh,
513 size_t kw,
514 size_t kc,
515 size_t sh,
516 size_t sw,
517 size_t nr,
518 size_t kr,
Marat Dukhanc4ae7de2019-10-25 02:06:26 -0700519 size_t sr,
XNNPACK Teamb455b122019-09-27 18:10:33 -0700520 const float* k,
521 const float* b,
522 float* packed_w,
523 struct subconvolution_params* params)
524{
Marat Dukhanc4ae7de2019-10-25 02:06:26 -0700525 const size_t skr = sr * kr;
526 const size_t skc = round_down_po2(kc, skr);
527 const size_t sr_mask = (sr - 1) * kr;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700528 for (size_t i = 0; i < g; i++) {
529 for (size_t oy = 0; oy < sh; oy++) {
530 for (size_t ox = 0; ox < sw; ox++) {
531 if (i == 0) {
532 (*params++).weights = packed_w;
533 }
534 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
535 const size_t nr_block_size = min(nc - nr_block_start, nr);
536 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
537 *packed_w++ = b[nr_block_start + nr_block_offset];
538 }
539 packed_w += nr - nr_block_size;
540 for (size_t ky = oy; ky < kh; ky += sh) {
541 for (size_t kx = ox; kx < kw; kx += sw) {
Marat Dukhanc4ae7de2019-10-25 02:06:26 -0700542 for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
543 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
544 for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
545 *packed_w++ =
546 k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
547 }
548 }
549 packed_w += (nr - nr_block_size) * kr;
550 }
551
552 for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700553 const size_t kr_block_size = min(kc - kr_block_start, kr);
554 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
555 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
556 *packed_w++ =
557 k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + (kr_block_start + kr_block_offset)];
558 }
559 packed_w += kr - kr_block_size;
560 }
561 packed_w += (nr - nr_block_size) * kr;
562 }
563 }
564 }
565 }
566 }
567 }
568 k += kh * kw * kc * nc;
569 b += nc;
570 }
571}
572
573static inline void xnn_pack_f32_dwconv_ghw_w(
574 size_t h,
575 size_t w,
576 size_t c,
577 size_t cr,
578 const float* k,
579 const float* b,
580 float* packed_w)
581{
582 for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
583 const size_t cr_block_size = min(c - cr_block_start, cr);
584 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
585 *packed_w++ = b[cr_block_start + cr_block_offset];
586 }
587 packed_w += cr - cr_block_size;
588 for (size_t x = 0; x < w; x++) {
589 for (size_t y = 0; y < h; y++) {
590 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
591 const float kv = k[((cr_block_start + cr_block_offset) * h + y) * w + x];
592 *packed_w++ = kv;
593 }
594 packed_w += cr - cr_block_size;
595 }
596 }
597 }
598}
599
600static inline void xnn_pack_f32_dwconv_hwg_w(
601 size_t h,
602 size_t w,
603 size_t c,
604 size_t cr,
605 const float* k,
606 const float* b,
607 float* packed_w)
608{
609 for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
610 const size_t cr_block_size = min(c - cr_block_start, cr);
611 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
612 *packed_w++ = b[cr_block_start + cr_block_offset];
613 }
614 packed_w += cr - cr_block_size;
615 for (size_t x = 0; x < w; x++) {
616 for (size_t y = 0; y < h; y++) {
617 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
618 const float kv = k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
619 *packed_w++ = kv;
620 }
621 packed_w += cr - cr_block_size;
622 }
623 }
624 }
625}
626
627static inline void xnn_pack_f32_spchw_dwconv_ghw_w(
628 size_t kernel_size,
629 size_t groups,
630 const float* kernel,
631 const float* bias,
632 float* packed_weights)
633{
634 for (size_t g = 0; g < groups; g++) {
635 *packed_weights++ = *bias++;
636 for (size_t i = 0; i < kernel_size; i++) {
637 *packed_weights++ = kernel[g * kernel_size + i];
638 }
639 }
640}
641
642static inline void xnn_pack_f32_vmulcaddc_w(
643 size_t c,
644 size_t cr,
645 const float* s,
646 const float* b,
647 float* packed_w)
648{
649 for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
650 const size_t cr_block_size = min(c - cr_block_start, cr);
651 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
652 packed_w[cr_block_offset] = s[cr_block_start + cr_block_offset];
653 }
654 packed_w += cr;
655 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
656 packed_w[cr_block_offset] = b[cr_block_start + cr_block_offset];
657 }
658 packed_w += cr;
659 }
660}