// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>


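// A typical subgraph lifecycle, as a sketch (assumes the public XNNPACK
// subgraph API; error handling omitted):
//
//   xnn_initialize(NULL /* use default allocator */);
//   xnn_subgraph_t subgraph = NULL;
//   xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph);
//   /* define Values and Nodes with the xnn_define_* functions */
//   /* a Runtime created from this subgraph invokes xnn_subgraph_optimize() */
//   xnn_delete_subgraph(subgraph);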
enum xnn_status xnn_create_subgraph(
    uint32_t external_value_ids,
    uint32_t flags,
    xnn_subgraph_t* subgraph_out)
{
  struct xnn_subgraph* subgraph = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create subgraph: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_out_of_memory;

  subgraph = xnn_allocate_zero_memory(sizeof(struct xnn_subgraph));
  if (subgraph == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph descriptor", sizeof(struct xnn_subgraph));
    goto error;
  }

  subgraph->external_value_ids = external_value_ids;

  subgraph->values = xnn_allocate_zero_memory(external_value_ids * sizeof(struct xnn_value));
  if (subgraph->values == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph values", external_value_ids * sizeof(struct xnn_value));
    goto error;
  }
  for (size_t i = 0; i < external_value_ids; i++) {
    subgraph->values[i].id = i;
  }
  subgraph->num_values = external_value_ids;
  subgraph->num_reserved_values = external_value_ids;

  *subgraph_out = subgraph;
  return xnn_status_success;

error:
  xnn_delete_subgraph(subgraph);
  return status;
}


struct xnn_value* xnn_subgraph_new_internal_value(xnn_subgraph_t subgraph)
{
  struct xnn_value* values = subgraph->values;
  const size_t size = subgraph->num_values;
  const size_t capacity = subgraph->num_reserved_values;
  if (capacity < size + 1) {
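    // Growth policy: double the capacity, but grow by at least 64 and at most
    // 512 entries per reallocation (e.g. 0 -> 64, 64 -> 128, 1024 -> 1536).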
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    values = xnn_reallocate_memory(values, new_capacity * sizeof(struct xnn_value));
    if (values == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph values",
        new_capacity * sizeof(struct xnn_value));
      return values;
    }

    memset(values + size, 0, (new_capacity - size) * sizeof(struct xnn_value));
    subgraph->num_reserved_values = new_capacity;
    subgraph->values = values;
  }
  subgraph->num_values = size + 1;
  struct xnn_value* new_value = values + size;
  new_value->id = size;
  return new_value;
}

void xnn_node_clear(struct xnn_node* node) {
  assert(node != NULL);
  assert(node->type != xnn_node_type_invalid);
  memset(node, 0, sizeof(struct xnn_node));
}

void xnn_value_clear(struct xnn_value* value) {
  assert(value != NULL);
  assert(value->type != xnn_value_type_invalid);
  memset(value, 0, sizeof(struct xnn_value));
}

struct xnn_node* xnn_subgraph_new_node(xnn_subgraph_t subgraph)
{
  struct xnn_node* nodes = subgraph->nodes;
  const size_t size = subgraph->num_nodes;
  const size_t capacity = subgraph->num_reserved_nodes;

  if (capacity < size + 1) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
    if (nodes == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
        new_capacity * sizeof(struct xnn_node));
      return nodes;
    }

    memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
    subgraph->num_reserved_nodes = new_capacity;
    subgraph->nodes = nodes;
  }
  subgraph->num_nodes = size + 1;
  struct xnn_node* new_node = nodes + size;
  new_node->id = size;
  return new_node;
}

#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW 1
#define XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW 2
#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC 4
#define XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER 8
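// COMPATIBLE_NCHW marks a Node that can operate on NCHW tensors directly;
// COMPATIBLE_NHWC2NCHW marks a Node that consumes NHWC input and produces
// NCHW output (an entry point into an NCHW cluster); COMPATIBLE_NCHW2NHWC is
// the converse (an exit point). INCOMPATIBLE_CLUSTER poisons a whole cluster
// during the rewrite below.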

uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node* node) {
  switch (node->type) {
    case xnn_node_type_convolution_2d:
      // Supported cases:
      // - 1x1 convolution (no stride, no dilation, no padding, no groups)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side, no groups, 3 input channels)
      if (node->params.convolution_2d.groups != 1) {
        return 0;
      }
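      // The bitwise-OR tests below are a branch-free idiom: for positive
      // values, (a | b) == 1 holds only when a == b == 1, and likewise
      // (a | b) == 2 only when a == b == 2; OR-ing the paddings and comparing
      // against 0 checks that all of them are zero.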
      if ((node->params.convolution_2d.dilation_height | node->params.convolution_2d.dilation_width) != 1) {
        return 0;
      }
      if ((node->params.convolution_2d.kernel_height | node->params.convolution_2d.kernel_width) == 1) {
        if ((node->params.convolution_2d.input_padding_top | node->params.convolution_2d.input_padding_right |
             node->params.convolution_2d.input_padding_bottom | node->params.convolution_2d.input_padding_left) != 0)
        {
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 1) {
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else if (node->params.convolution_2d.kernel_height == 3 && node->params.convolution_2d.kernel_width == 3) {
        if (node->params.convolution_2d.input_padding_top != 1 || node->params.convolution_2d.input_padding_right != 1 ||
            node->params.convolution_2d.input_padding_bottom != 1 || node->params.convolution_2d.input_padding_left != 1)
        {
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 2) {
          return 0;
        }
        if (node->params.convolution_2d.group_input_channels != 3) {
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW;
      }
      return 0;
    case xnn_node_type_depthwise_convolution_2d:
      // Supported cases:
      // - 3x3 stride-1 convolution (no dilation, padding 1 on each side)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side)
      // - 5x5 stride-1 convolution (no dilation, padding 2 on each side)
      // - 5x5 stride-2 convolution (no dilation, padding 2 on each side)
      if ((node->params.depthwise_convolution_2d.dilation_height | node->params.depthwise_convolution_2d.dilation_width) != 1) {
        return 0;
      }
      if (node->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
        return 0;
      }
      if (node->params.depthwise_convolution_2d.depth_multiplier != 1) {
        return 0;
      }
      if (node->params.depthwise_convolution_2d.subsampling_height != node->params.depthwise_convolution_2d.subsampling_width) {
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.subsampling_height) {
        case 1:
        case 2:
          break;
        default:
          return 0;
      }
      if (node->params.depthwise_convolution_2d.kernel_height != node->params.depthwise_convolution_2d.kernel_width) {
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.kernel_height) {
        case 3:
          return node->params.depthwise_convolution_2d.input_padding_top == 1 &&
                 node->params.depthwise_convolution_2d.input_padding_right == 1 &&
                 node->params.depthwise_convolution_2d.input_padding_bottom == 1 &&
                 node->params.depthwise_convolution_2d.input_padding_left == 1 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
        case 5:
          return node->params.depthwise_convolution_2d.input_padding_top == 2 &&
                 node->params.depthwise_convolution_2d.input_padding_right == 2 &&
                 node->params.depthwise_convolution_2d.input_padding_bottom == 2 &&
                 node->params.depthwise_convolution_2d.input_padding_left == 2 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
        default:
          return 0;
      }
    case xnn_node_type_depth_to_space:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_global_average_pooling_2d:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_add2:
    case xnn_node_type_multiply2:
      assert(node->num_inputs == 2);
      assert(node->num_outputs == 1);
      if (subgraph->values[node->inputs[0]].shape.num_dims != 4 ||
          subgraph->values[node->inputs[1]].shape.num_dims != 4)
      {
        return 0;
      }

      if (subgraph->values[node->inputs[0]].data != NULL) {
        // Check that the first input is representable as either a scalar, or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      if (subgraph->values[node->inputs[1]].data != NULL) {
        // Check that the second input is representable as either a scalar, or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[1]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[1]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
    case xnn_node_type_static_resize_bilinear_2d:
      return subgraph->values[node->inputs[0]].shape.dim[1] > 1 &&
             subgraph->values[node->inputs[0]].shape.dim[2] > 1 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
    case xnn_node_type_abs:
    case xnn_node_type_bankers_rounding:
    case xnn_node_type_ceiling:
    case xnn_node_type_clamp:
    case xnn_node_type_elu:
    case xnn_node_type_floor:
    case xnn_node_type_hardswish:
    case xnn_node_type_leaky_relu:
    case xnn_node_type_negate:
    case xnn_node_type_sigmoid:
    case xnn_node_type_square:
      assert(node->num_inputs == 1);
      assert(node->num_outputs == 1);
      return subgraph->values[node->inputs[0]].shape.num_dims == 4 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
    default:
      return 0;
  }
}

void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph)
{
  // Convert parts of the subgraph to NCHW for sparse inference
  // Step 1: detect NCHW-compatible Nodes
  // Step 2: detect NCHW-compatible clusters (run connected components graph algorithm)
  // Step 3: check that all NCHW-compatible Values are consumed only by NCHW-compatible Nodes
  // Step 4: switch Values' layout to NCHW
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->layout_flags = xnn_check_nchw_compatibility(subgraph, node);
    xnn_log_debug("Node #%" PRIu32 ": %s (NCHW: %s, NHWC->NCHW: %s, NCHW->NHWC: %s)",
      n, xnn_node_type_to_string(node->type),
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC ? "yes" : "no");
  }

  // Run the Shiloach-Vishkin connected-components algorithm: find all
  // XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC Nodes and make them cluster leaders
  // for their producer Nodes.
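  // Every Node starts out as its own cluster leader; whenever an edge
  // connects two provisionally NCHW-compatible Nodes, both take the larger of
  // the two leader ids, and the fixed-point loop further below repeats until
  // no leader changes.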
  bool update = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->cluster_leader = n;
    if (node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC) {
      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // If no NCHW2NHWC-compatible Nodes were found, the rewrite cannot apply,
  // so bail out early.
  if (!update) {
    return;
  }
  // Propagate cluster leaders through the graph until a fixed point is
  // reached, i.e. no Node's leader changes in a full pass.
  while (update) {
    update = false;
    for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
      struct xnn_node* node = &subgraph->nodes[n];
      if (node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) {
        continue;
      }

      if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC)) == 0) {
        continue;
      }

      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // Propagate XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER flags up to the cluster leaders
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    subgraph->nodes[node->cluster_leader].layout_flags |= node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
  }
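  // Two passes follow: the first counts, for every non-static internal Value,
  // how many of its consumers sit inside an NCHW-compatible cluster; the
  // second compares that count against the Value's total consumer count and
  // poisons the cluster if any consumer lies outside it.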
  // Check that Values consumed by NCHW-compatible clusters have no NCHW-incompatible consumers
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
      value->num_nchw_compatible_consumers += 1;
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
      assert(value->num_nchw_compatible_consumers > 0);
      if (value->num_nchw_compatible_consumers != value->num_consumers) {
        subgraph->nodes[node->cluster_leader].layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
      }
    }
  }
  // Evaluate whether it is profitable to run the model as sparse:
  // - Compute the number of parameters and zeroes in 1x1 Convolution weights
  // - Disable the sparse rewrite for clusters without 1x1 Convolutions (num_params == 0)
  //   or with fewer than two thirds of the 1x1 Convolution filter weights being zero
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if (node->type == xnn_node_type_convolution_2d &&
        max(node->params.convolution_2d.kernel_height, node->params.convolution_2d.kernel_width) == 1)
    {
      assert(node->num_inputs >= 2);

      const struct xnn_value* filter = &subgraph->values[node->inputs[1]];
      assert(filter->data != NULL);
      assert(filter->shape.num_dims == 4);

      const size_t num_params = filter->shape.dim[0] * filter->shape.dim[3];
      subgraph->nodes[node->cluster_leader].num_params += num_params;

      const float* data = (const float*) filter->data;
      size_t num_zeroes = 0;
      for (size_t i = 0; i < num_params; i++) {
        num_zeroes += (size_t) (data[i] == 0.0f);
      }
      xnn_log_debug("1x1 Convolution 2D Node #%" PRIu32 ": %zu / %zu sparsity", n, num_zeroes, num_params);
      subgraph->nodes[node->cluster_leader].num_zeroes += num_zeroes;
    }
  }
  bool use_nchw_layout = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

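    // Sparse inference pays off only when strictly more than 2/3 of the 1x1
    // Convolution weights are zero; num_zeroes / num_params > 2/3 is written
    // as 3 * num_zeroes > 2 * num_params to stay in integer arithmetic.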
    if (subgraph->nodes[node->cluster_leader].num_zeroes * 3 <= subgraph->nodes[node->cluster_leader].num_params * 2) {
      xnn_log_info("Node #%" PRIu32 ": sparse inference disabled: 1x1 Convolutions contain %zu / %zu zero weights",
        n, subgraph->nodes[node->cluster_leader].num_zeroes, subgraph->nodes[node->cluster_leader].num_params);
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
      assert(value->num_nchw_compatible_consumers > 0);
      assert(value->num_nchw_compatible_consumers == value->num_consumers);
      if (value->layout != xnn_layout_type_nchw) {
        value->layout = xnn_layout_type_nchw;
        xnn_log_info("set Value #%"PRIu32" layout to NCHW", node->inputs[i]);
        use_nchw_layout = true;
      }
    }
  }
  if (use_nchw_layout) {
    xnn_log_info("XNNPACK has switched to sparse inference mode!");
  }
}

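// xnn_subgraph_optimize makes three passes over the subgraph: (1) rebuild the
// Values' producer/consumer metadata, (2) clear Values that are neither
// external nor consumed, and (3) fuse Clamp Nodes into their producers and
// Constant Pad Nodes into their [Depthwise] Convolution 2D consumers. When
// sparse inference is requested and supported, the subgraph is then rewritten
// for NCHW layout.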
enum xnn_status xnn_subgraph_optimize(
  xnn_subgraph_t subgraph,
  uint32_t flags)
{
  // Initialize producer/consumer fields to safe defaults.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    value->producer = XNN_INVALID_NODE_ID;
    value->first_consumer = XNN_INVALID_NODE_ID;
    value->num_consumers = 0;
  }

  // Analyse Nodes' inputs and outputs and update Values' producer/consumer fields
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const uint32_t input_id = node->inputs[i];
      assert(input_id < subgraph->num_values);

      if (subgraph->values[input_id].num_consumers++ == 0) {
        assert(subgraph->values[input_id].first_consumer == XNN_INVALID_NODE_ID);
        subgraph->values[input_id].first_consumer = n;
      }
    }

    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const uint32_t output_id = node->outputs[o];
      assert(output_id < subgraph->num_values);

      assert(subgraph->values[output_id].producer == XNN_INVALID_NODE_ID);
      subgraph->values[output_id].producer = n;
    }
  }

  // Count an extra consumer for Values which are external outputs.
  // Remove unreferenced Values.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->type == xnn_value_type_invalid) {
      continue;
    }

    if (value->flags & XNN_VALUE_FLAG_EXTERNAL_OUTPUT) {
      value->num_consumers += 1;
    }
    if ((value->flags & XNN_VALUE_FLAG_EXTERNAL_INPUT) == 0 && value->num_consumers == 0) {
      xnn_value_clear(value);
    }
  }

  // Fuse Nodes where possible
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->num_consumers == 1) {
      const uint32_t producer_id = value->producer;
      if (producer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(producer_id < subgraph->num_nodes);

      const uint32_t consumer_id = value->first_consumer;
      if (consumer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(consumer_id < subgraph->num_nodes);

      struct xnn_node* producer = &subgraph->nodes[producer_id];
      assert(producer->type != xnn_node_type_invalid);
      struct xnn_node* consumer = &subgraph->nodes[consumer_id];
      assert(consumer->type != xnn_node_type_invalid);

      // Try to fuse Clamp Node upstream into producer Node
      if (consumer->type == xnn_node_type_clamp) {
        switch (producer->type) {
          case xnn_node_type_add2:
          case xnn_node_type_average_pooling_2d:
          case xnn_node_type_clamp:
          case xnn_node_type_convolution_2d:
          case xnn_node_type_divide:
          case xnn_node_type_deconvolution_2d:
          case xnn_node_type_depthwise_convolution_2d:
          case xnn_node_type_fully_connected:
          case xnn_node_type_multiply2:
          case xnn_node_type_max_pooling_2d:
          case xnn_node_type_subtract:
            xnn_log_info("fuse Clamp Node #%"PRIu32" into upstream Node #%"PRIu32, consumer_id, producer_id);
            assert(producer->num_outputs == 1);
            assert(consumer->num_inputs == 1);
            assert(consumer->num_outputs == 1);

            const uint32_t fused_output_id = consumer->outputs[0];
            assert(fused_output_id < subgraph->num_values);
            subgraph->values[fused_output_id].producer = producer_id;
            producer->outputs[0] = fused_output_id;

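            // The fused clamp is the intersection of the producer's and the
            // consumer's [output_min, output_max] ranges: take the larger of
            // the two minima and the smaller of the two maxima.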
            producer->activation.output_min =
              math_max_f32(producer->activation.output_min, consumer->activation.output_min);
            producer->activation.output_max =
              math_min_f32(producer->activation.output_max, consumer->activation.output_max);

            xnn_node_clear(consumer);
            xnn_value_clear(value);
            break;
          default:
            break;
        }
      }
      // Try to fuse Constant Pad Node downstream into [Depthwise] Convolution 2D Node
      if (producer->type == xnn_node_type_static_constant_pad) {
        assert(producer->num_inputs == 1);
        assert(producer->num_outputs == 1);
        const bool is_spatial_2d_zero_padding = value->shape.num_dims == 4 &&
          (producer->params.static_pad.pre_paddings[0] | producer->params.static_pad.post_paddings[0] |
           producer->params.static_pad.pre_paddings[3] | producer->params.static_pad.post_paddings[3]) == 0 &&
          producer->params.static_pad.padding_value == 0;
        switch (consumer->type) {
          case xnn_node_type_convolution_2d:
            if (is_spatial_2d_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Convolution 2D Node #%"PRIu32,
                consumer_id, producer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

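              // Values here are NHWC, so pre/post_paddings[1] holds the
              // height (top/bottom) padding and pre/post_paddings[2] the
              // width (left/right) padding; the same mapping applies in the
              // depthwise case below.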
              consumer->params.convolution_2d.input_padding_top += producer->params.static_pad.pre_paddings[1];
              consumer->params.convolution_2d.input_padding_right += producer->params.static_pad.post_paddings[2];
              consumer->params.convolution_2d.input_padding_bottom += producer->params.static_pad.post_paddings[1];
              consumer->params.convolution_2d.input_padding_left += producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          case xnn_node_type_depthwise_convolution_2d:
            if (is_spatial_2d_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Depthwise Convolution 2D Node #%"PRIu32,
                consumer_id, producer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.depthwise_convolution_2d.input_padding_top +=
                producer->params.static_pad.pre_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_right +=
                producer->params.static_pad.post_paddings[2];
              consumer->params.depthwise_convolution_2d.input_padding_bottom +=
                producer->params.static_pad.post_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_left +=
                producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          default:
            break;
        }
      }
    }
  }

  #if XNN_ENABLE_SPARSE
    if ((flags & XNN_FLAG_SPARSE_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_CHW_OPT)) {
      xnn_subgraph_rewrite_for_nchw(subgraph);
    }
  #endif

  return xnn_status_success;
}

enum xnn_status xnn_delete_subgraph(
  xnn_subgraph_t subgraph)
{
  if (subgraph != NULL) {
    memset(subgraph->nodes, 0, sizeof(struct xnn_node) * subgraph->num_nodes);
    xnn_release_memory(subgraph->nodes);

    memset(subgraph->values, 0, sizeof(struct xnn_value) * subgraph->num_values);
    xnn_release_memory(subgraph->values);

    memset(subgraph, 0, sizeof(struct xnn_subgraph));
    xnn_release_memory(subgraph);
  }
  return xnn_status_success;
}