blob: 49b747e0381585f2383067278efe05dc157e97b0 [file] [log] [blame]
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * Copyright 2019 Google LLC
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
10
11#include <assert.h>
12
13#include <arm_neon.h>
14
15#include <xnnpack/maxpool.h>
16
17
18void xnn_u8_maxpool_ukernel_9p8q__neon(
19 size_t n,
20 size_t ks,
21 size_t kc,
22 const uint8_t** input,
23 uint8_t* output,
24 size_t input_increment,
25 size_t output_increment,
26 const union xnn_u8_output_params params[restrict static 1])
27{
28 assert(n != 0);
29 assert(ks != 0);
30 assert(kc != 0);
31
32 const uint8x16_t voutput_max = vld1q_dup_u8(&params->neon.max);
33 const uint8x16_t voutput_min = vld1q_dup_u8(&params->neon.min);
34 do {
35 uint8_t* o = output;
36 {
37 const uint8_t* i0 = *input++;
38 const uint8_t* i1 = *input++;
39 const uint8_t* i2 = *input++;
40 const uint8_t* i3 = *input++;
41 const uint8_t* i4 = *input++;
42 const uint8_t* i5 = *input++;
43 const uint8_t* i6 = *input++;
44 const uint8_t* i7 = *input++;
45 const uint8_t* i8 = *input++;
46 if (ks < 2) {
47 i1 = i0;
48 }
49 if (ks <= 2) {
50 i2 = i0;
51 }
52 if (ks < 4) {
53 i3 = i0;
54 }
55 if (ks <= 4) {
56 i4 = i0;
57 }
58 if (ks < 6) {
59 i5 = i0;
60 }
61 if (ks <= 6) {
62 i6 = i0;
63 }
64 if (ks < 8) {
65 i7 = i0;
66 }
67 if (ks <= 8) {
68 i8 = i0;
69 }
70
71 size_t k = kc;
72 for (; k >= 16; k -= 16) {
73 const uint8x16_t vi0 = vld1q_u8(i0); i0 += 16;
74 const uint8x16_t vi1 = vld1q_u8(i1); i1 += 16;
75 const uint8x16_t vi2 = vld1q_u8(i2); i2 += 16;
76 const uint8x16_t vi3 = vld1q_u8(i3); i3 += 16;
77 const uint8x16_t vi4 = vld1q_u8(i4); i4 += 16;
78 const uint8x16_t vi5 = vld1q_u8(i5); i5 += 16;
79 const uint8x16_t vi6 = vld1q_u8(i6); i6 += 16;
80 const uint8x16_t vi7 = vld1q_u8(i7); i7 += 16;
81 const uint8x16_t vi8 = vld1q_u8(i8); i8 += 16;
82
83 const uint8x16_t vmax018 = vmaxq_u8(vmaxq_u8(vi0, vi1), vi8);
84 const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
85 const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
86 const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);
87
88 const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
89 const uint8x16_t vmax01678 = vmaxq_u8(vmax018, vmax67);
90 const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax01678);
91 const uint8x16_t vout = vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min);
92
93 vst1q_u8(o, vout); o += 16;
94 }
95 if (k != 0) {
96 const uint8x16_t vi0 = vld1q_u8(i0);
97 const uint8x16_t vi1 = vld1q_u8(i1);
98 const uint8x16_t vi2 = vld1q_u8(i2);
99 const uint8x16_t vi3 = vld1q_u8(i3);
100 const uint8x16_t vi4 = vld1q_u8(i4);
101 const uint8x16_t vi5 = vld1q_u8(i5);
102 const uint8x16_t vi6 = vld1q_u8(i6);
103 const uint8x16_t vi7 = vld1q_u8(i7);
104 const uint8x16_t vi8 = vld1q_u8(i8);
105
106 const uint8x16_t vmax018 = vmaxq_u8(vmaxq_u8(vi0, vi1), vi8);
107 const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
108 const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
109 const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);
110
111 const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
112 const uint8x16_t vmax01678 = vmaxq_u8(vmax018, vmax67);
113 const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax01678);
114 const uint8x16_t vout = vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min);
115
116 uint8x8_t vout_lo = vget_low_u8(vout);
117 if (k & 8) {
118 vst1_u8(o, vout_lo); o += 8;
119 vout_lo = vget_high_u8(vout);
120 }
121 if (k & 4) {
122 vst1_lane_u32(__builtin_assume_aligned(o, 1), vreinterpret_u32_u8(vout_lo), 0); o += 4;
123 vout_lo = vext_u8(vout_lo, vout_lo, 4);
124 }
125 if (k & 2) {
126 vst1_lane_u16(__builtin_assume_aligned(o, 1), vreinterpret_u16_u8(vout_lo), 0); o += 2;
127 vout_lo = vext_u8(vout_lo, vout_lo, 2);
128 }
129 if (k & 1) {
130 vst1_lane_u8(o, vout_lo, 0); o += 1;
131 }
132 }
133 }
134
135 for (ptrdiff_t m = (ptrdiff_t) ks - 9; m > 0; m -= 8) {
136 const uint8_t* i0 = *input++;
137 const uint8_t* i1 = *input++;
138 const uint8_t* i2 = *input++;
139 const uint8_t* i3 = *input++;
140 const uint8_t* i4 = *input++;
141 const uint8_t* i5 = *input++;
142 const uint8_t* i6 = *input++;
143 const uint8_t* i7 = *input++;
144 if (m < 2) {
145 i1 = i0;
146 }
147 if (m <= 2) {
148 i2 = i0;
149 }
150 if (m < 4) {
151 i3 = i0;
152 }
153 if (m <= 4) {
154 i4 = i0;
155 }
156 if (m < 6) {
157 i5 = i0;
158 }
159 if (m <= 6) {
160 i6 = i0;
161 }
162 if (m < 8) {
163 i7 = i0;
164 }
165
166 o = output;
167 size_t k = kc;
168 for (; k >= 16; k -= 16) {
169 const uint8x16_t vi0 = vld1q_u8(i0); i0 += 16;
170 const uint8x16_t vi1 = vld1q_u8(i1); i1 += 16;
171 const uint8x16_t vi2 = vld1q_u8(i2); i2 += 16;
172 const uint8x16_t vi3 = vld1q_u8(i3); i3 += 16;
173 const uint8x16_t vi4 = vld1q_u8(i4); i4 += 16;
174 const uint8x16_t vi5 = vld1q_u8(i5); i5 += 16;
175 const uint8x16_t vi6 = vld1q_u8(i6); i6 += 16;
176 const uint8x16_t vi7 = vld1q_u8(i7); i7 += 16;
177 const uint8x16_t vo = vld1q_u8(o);
178
179 const uint8x16_t vmax01 = vmaxq_u8(vmaxq_u8(vi0, vi1), vo);
180 const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
181 const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
182 const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);
183
184 const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
185 const uint8x16_t vmax0167 = vmaxq_u8(vmax01, vmax67);
186 const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax0167);
187 const uint8x16_t vout = vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min);
188
189 vst1q_u8(o, vout); o += 16;
190 }
191 if (k != 0) {
192 const uint8x16_t vi0 = vld1q_u8(i0);
193 const uint8x16_t vi1 = vld1q_u8(i1);
194 const uint8x16_t vi2 = vld1q_u8(i2);
195 const uint8x16_t vi3 = vld1q_u8(i3);
196 const uint8x16_t vi4 = vld1q_u8(i4);
197 const uint8x16_t vi5 = vld1q_u8(i5);
198 const uint8x16_t vi6 = vld1q_u8(i6);
199 const uint8x16_t vi7 = vld1q_u8(i7);
200 const uint8x16_t vo = vld1q_u8(o);
201
202 const uint8x16_t vmax01 = vmaxq_u8(vmaxq_u8(vi0, vi1), vo);
203 const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
204 const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
205 const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);
206
207 const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
208 const uint8x16_t vmax0167 = vmaxq_u8(vmax01, vmax67);
209 const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax0167);
210 const uint8x16_t vout = vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min);
211
212 uint8x8_t vout_lo = vget_low_u8(vout);
213 if (k & 8) {
214 vst1_u8(o, vout_lo); o += 8;
215 vout_lo = vget_high_u8(vout);
216 }
217 if (k & 4) {
218 vst1_lane_u32(__builtin_assume_aligned(o, 1), vreinterpret_u32_u8(vout_lo), 0); o += 4;
219 vout_lo = vext_u8(vout_lo, vout_lo, 4);
220 }
221 if (k & 2) {
222 vst1_lane_u16(__builtin_assume_aligned(o, 1), vreinterpret_u16_u8(vout_lo), 0); o += 2;
223 vout_lo = vext_u8(vout_lo, vout_lo, 2);
224 }
225 if (k & 1) {
226 vst1_lane_u8(o, vout_lo, 0); o += 1;
227 }
228 }
229 }
230 input = (const uint8_t**) ((uintptr_t) input + input_increment);
231 output = (uint8_t*) ((uintptr_t) o + output_increment);
232 } while (--n != 0);
233}