// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>

enum xnn_status xnn_create_subgraph(
  uint32_t external_value_ids,
  uint32_t flags,
  xnn_subgraph_t* subgraph_out)
{
  struct xnn_subgraph* subgraph = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create subgraph: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_out_of_memory;

  subgraph = xnn_allocate_zero_memory(sizeof(struct xnn_subgraph));
  if (subgraph == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph descriptor", sizeof(struct xnn_subgraph));
    goto error;
  }

  subgraph->external_value_ids = external_value_ids;

  subgraph->values = xnn_allocate_zero_memory(external_value_ids * sizeof(struct xnn_value));
  if (subgraph->values == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph values", external_value_ids * sizeof(struct xnn_value));
    goto error;
  }
  for (size_t i = 0; i < external_value_ids; i++) {
    subgraph->values[i].id = i;
  }
  subgraph->num_values = external_value_ids;
  subgraph->num_reserved_values = external_value_ids;

  *subgraph_out = subgraph;
  return xnn_status_success;

error:
  xnn_delete_subgraph(subgraph);
  return status;
}

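// Illustrative sketch of how the functions in this file fit together (in normal use
// the xnn_create_runtime* entry points drive the optimization step); it assumes
// XNNPACK was already initialized with xnn_initialize():
//
//   xnn_subgraph_t subgraph = NULL;
//   if (xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph) == xnn_status_success) {
//     // ... define Values and Nodes via the xnn_define_* APIs declared in xnnpack.h ...
//     xnn_subgraph_optimize(subgraph, /*flags=*/0);
//     xnn_delete_subgraph(subgraph);
//   }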

struct xnn_value* xnn_subgraph_new_internal_value(xnn_subgraph_t subgraph)
{
  struct xnn_value* values = subgraph->values;
  const size_t size = subgraph->num_values;
  const size_t capacity = subgraph->num_reserved_values;
  if (capacity < size + 1) {
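    // Grow the Value array geometrically: double the capacity, but grow by at least 64
    // and by at most 512 extra elements per reallocation.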
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    values = xnn_reallocate_memory(values, new_capacity * sizeof(struct xnn_value));
    if (values == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph values",
        new_capacity * sizeof(struct xnn_value));
      return values;
    }

    memset(values + size, 0, (new_capacity - size) * sizeof(struct xnn_value));
    subgraph->num_reserved_values = new_capacity;
    subgraph->values = values;
  }
  subgraph->num_values = size + 1;
  struct xnn_value* new_value = values + size;
  new_value->id = size;
  return new_value;
}

void xnn_node_clear(struct xnn_node* node) {
  assert(node != NULL);
  memset(node, 0, sizeof(struct xnn_node));
}

void xnn_value_clear(struct xnn_value* value) {
  assert(value != NULL);
  memset(value, 0, sizeof(struct xnn_value));
}

void xnn_value_copy(
  struct xnn_value* dst_value,
  const struct xnn_value* src_value)
{
  // Note: Value ID stays unchanged

  dst_value->type = src_value->type;
  dst_value->datatype = src_value->datatype;
  dst_value->quantization = src_value->quantization;
  dst_value->shape = src_value->shape;
  dst_value->flags = src_value->flags;
  dst_value->data = src_value->data;
  dst_value->producer = src_value->producer;
  dst_value->first_consumer = src_value->first_consumer;
}

struct xnn_node* xnn_subgraph_new_node(xnn_subgraph_t subgraph)
{
  struct xnn_node* nodes = subgraph->nodes;
  const size_t size = subgraph->num_nodes;
  const size_t capacity = subgraph->num_reserved_nodes;

  if (capacity < size + 1) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
    if (nodes == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
        new_capacity * sizeof(struct xnn_node));
      return nodes;
    }

    memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
    subgraph->num_reserved_nodes = new_capacity;
    subgraph->nodes = nodes;
  }
  subgraph->num_nodes = size + 1;
  struct xnn_node* new_node = nodes + size;
  new_node->id = size;
  return new_node;
}
140
Marat Dukhan4620ca62022-02-03 12:31:00 -0800141void xnn_subgraph_add_nodes(xnn_subgraph_t subgraph, size_t num_nodes)
142{
143 struct xnn_node* nodes = subgraph->nodes;
144 const size_t size = subgraph->num_nodes;
145 const size_t capacity = subgraph->num_reserved_nodes;
146
147 if (capacity < size + num_nodes) {
148 const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + max(num_nodes, 64));
149 assert(new_capacity >= size + num_nodes);
150 nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
151 if (nodes == NULL) {
152 xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
153 capacity * sizeof(struct xnn_node));
154 return;
155 }
156
157 memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
158 subgraph->num_reserved_nodes = new_capacity;
159 subgraph->nodes = nodes;
160 }
161 subgraph->num_nodes = size + num_nodes;
162 struct xnn_node* new_nodes = nodes + size;
163 for (size_t i = 0; i < num_nodes; i++) {
164 new_nodes[i].id = size + i;
165 }
166}

void xnn_subgraph_analyze_consumers_and_producers(xnn_subgraph_t subgraph)
{
  // Initialize producer/consumer fields to safe defaults.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    value->producer = XNN_INVALID_NODE_ID;
    value->first_consumer = XNN_INVALID_NODE_ID;
    value->num_consumers = 0;
  }

  // Analyze Nodes' inputs and outputs and update Values' producer/consumer fields.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const uint32_t input_id = node->inputs[i];
      assert(input_id < subgraph->num_values);

      if (subgraph->values[input_id].num_consumers++ == 0) {
        assert(subgraph->values[input_id].first_consumer == XNN_INVALID_NODE_ID);
        subgraph->values[input_id].first_consumer = n;
      }
    }

    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const uint32_t output_id = node->outputs[o];
      assert(output_id < subgraph->num_values);

      assert(subgraph->values[output_id].producer == XNN_INVALID_NODE_ID);
      subgraph->values[output_id].producer = n;
    }
  }

  // Count an extra consumer for Values which are external outputs,
  // so that they are never treated as unreferenced.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->flags & XNN_VALUE_FLAG_EXTERNAL_OUTPUT) {
      value->num_consumers += 1;
    }
  }
}

#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW 1
#define XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW 2
#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC 4
#define XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER 8

uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node* node) {
  if (node->compute_type != xnn_compute_type_fp32) {
    return 0;
  }

  switch (node->type) {
    case xnn_node_type_convolution_2d:
      // Supported cases:
      // - 1x1 convolution (no stride, no dilation, no padding, no groups)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side, no groups, 3 input channels)
      if (node->params.convolution_2d.groups != 1) {
        return 0;
      }
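      // Note: the bitwise OR of two positive fields equals 1 only when both are 1, so this
      // check (and the similar OR-based checks below) tests both dimensions at once.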
      if ((node->params.convolution_2d.dilation_height | node->params.convolution_2d.dilation_width) != 1) {
        return 0;
      }
      if ((node->params.convolution_2d.kernel_height | node->params.convolution_2d.kernel_width) == 1) {
        if ((node->params.convolution_2d.input_padding_top | node->params.convolution_2d.input_padding_right |
             node->params.convolution_2d.input_padding_bottom | node->params.convolution_2d.input_padding_left) != 0)
        {
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 1) {
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else if (node->params.convolution_2d.kernel_height == 3 && node->params.convolution_2d.kernel_width == 3) {
        if (node->params.convolution_2d.input_padding_top != 1 || node->params.convolution_2d.input_padding_right != 1 ||
            node->params.convolution_2d.input_padding_bottom != 1 || node->params.convolution_2d.input_padding_left != 1)
        {
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 2) {
          return 0;
        }
        if (node->params.convolution_2d.group_input_channels != 3) {
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW;
      }
      return 0;
    case xnn_node_type_depthwise_convolution_2d:
      // Supported cases:
      // - 3x3 stride-1 convolution (no dilation, padding 1 on each side)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side)
      // - 5x5 stride-1 convolution (no dilation, padding 2 on each side)
      // - 5x5 stride-2 convolution (no dilation, padding 2 on each side)
      if ((node->params.depthwise_convolution_2d.dilation_height | node->params.depthwise_convolution_2d.dilation_width) != 1) {
        return 0;
      }
      if (node->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
        return 0;
      }
      if (node->params.depthwise_convolution_2d.depth_multiplier != 1) {
        return 0;
      }
      if (node->params.depthwise_convolution_2d.subsampling_height != node->params.depthwise_convolution_2d.subsampling_width) {
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.subsampling_height) {
        case 1:
        case 2:
          break;
        default:
          return 0;
      }
      if (node->params.depthwise_convolution_2d.kernel_height != node->params.depthwise_convolution_2d.kernel_width) {
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.kernel_height) {
        case 3:
          return node->params.depthwise_convolution_2d.input_padding_top == 1 &&
                 node->params.depthwise_convolution_2d.input_padding_right == 1 &&
                 node->params.depthwise_convolution_2d.input_padding_bottom == 1 &&
                 node->params.depthwise_convolution_2d.input_padding_left == 1 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
        case 5:
          return node->params.depthwise_convolution_2d.input_padding_top == 2 &&
                 node->params.depthwise_convolution_2d.input_padding_right == 2 &&
                 node->params.depthwise_convolution_2d.input_padding_bottom == 2 &&
                 node->params.depthwise_convolution_2d.input_padding_left == 2 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
        default:
          return 0;
      }
    case xnn_node_type_depth_to_space:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_global_average_pooling_2d:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_add2:
    case xnn_node_type_multiply2:
      assert(node->num_inputs == 2);
      assert(node->num_outputs == 1);
      if (subgraph->values[node->inputs[0]].shape.num_dims != 4 ||
          subgraph->values[node->inputs[1]].shape.num_dims != 4)
      {
        return 0;
      }

      if (subgraph->values[node->inputs[0]].data != NULL) {
        // Check that the first input is representable as either a scalar, or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      if (subgraph->values[node->inputs[1]].data != NULL) {
        // Check that the second input is representable as either a scalar, or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[1]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[1]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
    case xnn_node_type_static_resize_bilinear_2d:
      return subgraph->values[node->inputs[0]].shape.dim[1] > 1 &&
             subgraph->values[node->inputs[0]].shape.dim[2] > 1 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
    case xnn_node_type_abs:
    case xnn_node_type_bankers_rounding:
    case xnn_node_type_ceiling:
    case xnn_node_type_clamp:
    case xnn_node_type_elu:
    case xnn_node_type_floor:
    case xnn_node_type_hardswish:
    case xnn_node_type_leaky_relu:
    case xnn_node_type_negate:
    case xnn_node_type_sigmoid:
    case xnn_node_type_square:
      assert(node->num_inputs == 1);
      assert(node->num_outputs == 1);
      return subgraph->values[node->inputs[0]].shape.num_dims == 4 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
    default:
      return 0;
  }
}
361
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700362void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph)
Marat Dukhan9de90e02020-06-18 16:04:12 -0700363{
364 // Convert parts of the subgraph to NCHW for sparse inference
365 // Step 1: detect NCHW-compatible Nodes
366 // Step 2: detect NCHW-compatible clusters (run connected components graph algorithm)
367 // Step 3: check that all NCHW-compatible Values are consumed only by NCHW-compatible Nodes
368 // Step 4: switch Values' layout to NCHW
369 for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
370 struct xnn_node* node = &subgraph->nodes[n];
371 node->layout_flags = xnn_check_nchw_compatibility(subgraph, node);
372 xnn_log_debug("Node #%" PRIu32 ": %s (NCHW: %s, NHWC->NCHW: %s, NCHW->NHWC: %s)",
373 n, xnn_node_type_to_string(node->type),
374 node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW ? "yes" : "no",
375 node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW ? "yes" : "no",
376 node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC ? "yes" : "no");
377 }
378
XNNPACK Teama117ce72020-10-05 17:26:02 -0700379 // Run Shiloach-Vishkin connected components algorithm i.e. find all
380 // XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC nodes and set them as cluster leaders
381 // to all the producer nodes
Marat Dukhan9de90e02020-06-18 16:04:12 -0700382 bool update = false;
383 for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
384 struct xnn_node* node = &subgraph->nodes[n];
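    // Every Node starts out as the leader of its own cluster; when two clusters merge,
    // the larger of the two leaders' node IDs becomes the leader of the merged cluster.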
    node->cluster_leader = n;
    if (node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC) {
      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // If no NCHW2NHWC-compatible Nodes were found, the graph rewriting cannot happen,
  // so bail out early.
  if (!update) {
    return;
  }
  // Propagate the cluster leader to the other Nodes in the graph until no
  // Node's cluster leader changes anymore.
Marat Dukhan9de90e02020-06-18 16:04:12 -0700424 while (update) {
425 update = false;
426 for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
427 struct xnn_node* node = &subgraph->nodes[n];
428 if (node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) {
429 continue;
430 }
431
432 if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC)) == 0) {
433 continue;
434 }
435
436 for (uint32_t i = 0; i < node->num_inputs; i++) {
437 const struct xnn_value* value = &subgraph->values[node->inputs[i]];
438 if (value->data != NULL) {
439 // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
440 // during the initial NCHW compatibility check for the Node.
441 continue;
442 }
443 if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
444 // External value, invalid cluster
445 node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
446 continue;
447 }
448 const uint32_t producer_id = value->producer;
449 assert(producer_id != XNN_INVALID_NODE_ID);
450 assert(producer_id < n);
451 struct xnn_node* producer_node = &subgraph->nodes[producer_id];
452 if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
453 (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
454 {
455 producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
456 if (producer_node->cluster_leader != node->cluster_leader) {
457 producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
458 update = true;
459 }
460 } else {
461 node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
462 }
463 }
464 }
465 }
  // Propagate XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER flags up to the cluster leaders.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    subgraph->nodes[node->cluster_leader].layout_flags |= node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
  }
  // Check that all Values consumed by an NCHW-compatible cluster don't have NCHW-incompatible consumers.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
      value->num_nchw_compatible_consumers += 1;
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
      assert(value->num_nchw_compatible_consumers > 0);
      if (value->num_nchw_compatible_consumers != value->num_consumers) {
        subgraph->nodes[node->cluster_leader].layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
      }
    }
  }
  // Evaluate whether it is profitable to run the model as sparse:
  // - Compute the number of parameters and zeroes in 1x1 Convolution weights.
  // - Disable the sparse rewriting for clusters without 1x1 Convolutions (num_params == 0)
  //   or with 2/3 or fewer of the 1x1 Convolution weights being zero.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if (node->type == xnn_node_type_convolution_2d &&
        max(node->params.convolution_2d.kernel_height, node->params.convolution_2d.kernel_width) == 1)
    {
      assert(node->num_inputs >= 2);

      const struct xnn_value* filter = &subgraph->values[node->inputs[1]];
      assert(filter->data != NULL);
      assert(filter->shape.num_dims == 4);

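      // A 1x1 Convolution filter is 4-dimensional with both spatial kernel dimensions equal
      // to 1, so dim[0] * dim[3] (output channels x input channels) is the total weight count.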
      const size_t num_params = filter->shape.dim[0] * filter->shape.dim[3];
      subgraph->nodes[node->cluster_leader].num_params += num_params;

      const float* data = (const float*) filter->data;
      size_t num_zeroes = 0;
      for (size_t i = 0; i < num_params; i++) {
        num_zeroes += (size_t) (data[i] == 0.0f);
      }
      xnn_log_debug("1x1 Convolution 2D Node #%" PRIu32 ": %zu / %zu sparsity", n, num_zeroes, num_params);
      subgraph->nodes[node->cluster_leader].num_zeroes += num_zeroes;
    }
  }
  bool use_nchw_layout = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

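    // num_zeroes * 3 <= num_params * 2 is the integer form of num_zeroes / num_params <= 2/3:
    // the cluster's 1x1 Convolution weights are not sparse enough to benefit from rewriting.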
    if (subgraph->nodes[node->cluster_leader].num_zeroes * 3 <= subgraph->nodes[node->cluster_leader].num_params * 2) {
      xnn_log_info("Node #%" PRIu32 ": sparse inference disabled: 1x1 Convolutions contain %zu / %zu zero weights",
        n, subgraph->nodes[node->cluster_leader].num_zeroes, subgraph->nodes[node->cluster_leader].num_params);
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
      assert(value->num_nchw_compatible_consumers > 0);
      assert(value->num_nchw_compatible_consumers == value->num_consumers);
      if (value->layout != xnn_layout_type_nchw) {
        value->layout = xnn_layout_type_nchw;
        xnn_log_info("set Value #%"PRIu32" layout to NCHW", node->inputs[i]);
        use_nchw_layout = true;
      }
    }
  }
  if (use_nchw_layout) {
    xnn_log_info("XNNPACK has switched to sparse inference mode!");
  }
}

void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
{
  xnn_log_info("Analyzing subgraph for FP16 compatibility");

  // Convert tensors and operators in the subgraph to FP16:
  // 1. Check that all operators in the subgraph are supported in FP16.
  // 2. Indicate values that must be converted to FP16.
  // 3. Replace FP32 Values with FP16 Values as Nodes' inputs/outputs.
  // 4. Insert FP32->FP16 Convert Nodes for external FP32 inputs and FP16->FP32 Convert Nodes for external outputs.

  // Check that all operators in the subgraph are supported in FP16, bail out on any unsupported one.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if (node->type == xnn_node_type_invalid) {
      // Node was fused away, skip.
      continue;
    }

    if (node->compute_type != xnn_compute_type_fp32) {
      xnn_log_info("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not FP32", n, xnn_node_type_to_string(node->type));
      return;
    }
    switch (node->type) {
      case xnn_node_type_add2:
        assert(node->num_inputs == 2);
        for (uint32_t i = 0; i < node->num_inputs; i++) {
          if (subgraph->values[node->inputs[i]].data != NULL) {
            xnn_log_info("FP16 rewrite aborted: node #%" PRIu32 " (%s) has static input %" PRIu32,
              n, xnn_node_type_to_string(node->type), i);
            return;
          }
        }
        break;
      case xnn_node_type_convolution_2d:
      case xnn_node_type_depthwise_convolution_2d:
      case xnn_node_type_global_average_pooling_2d:
      case xnn_node_type_hardswish:
      case xnn_node_type_max_pooling_2d:
      case xnn_node_type_prelu:
      case xnn_node_type_static_constant_pad:
      case xnn_node_type_static_reshape:
        break;
      default:
        xnn_log_info("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not supported for FP16 inference",
          n, xnn_node_type_to_string(node->type));
        return;
    }
  }

633 // Annotate Values to be converted to FP16 as FP16-compatible.
Marat Dukhan170f95a2022-02-04 02:18:23 -0800634 // Note that static weights in [Depthwise] Convolution, Fully Connected, and PReLU Nodes remain FP32,
Marat Dukhan4620ca62022-02-03 12:31:00 -0800635 // they will be converted to FP16 during weight repacking when the operator is created.
636 for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
637 struct xnn_node* node = &subgraph->nodes[n];
638 switch (node->type) {
639 case xnn_node_type_convolution_2d:
640 case xnn_node_type_depthwise_convolution_2d:
Marat Dukhan170f95a2022-02-04 02:18:23 -0800641 case xnn_node_type_prelu:
Marat Dukhan4620ca62022-02-03 12:31:00 -0800642 subgraph->values[node->inputs[0]].fp16_compatible = true;
643 subgraph->values[node->outputs[0]].fp16_compatible = true;
644 break;
645 default:
646 for (uint32_t i = 0; i < node->num_inputs; i++) {
647 subgraph->values[node->inputs[i]].fp16_compatible = true;
648 }
649 for (uint32_t o = 0; o < node->num_outputs; o++) {
650 subgraph->values[node->outputs[o]].fp16_compatible = true;
651 }
652 break;
653 }
654 }
655
  // Replace FP32 Values in Nodes' inputs/outputs with FP16 Values.
  // FP32 Values that are not external inputs or outputs are converted to FP16 in-place;
  // for external inputs and outputs we create same-shaped FP16 Values and use those instead.
  const uint32_t num_original_values = subgraph->num_values;
  xnn_subgraph_analyze_consumers_and_producers(subgraph);
  for (uint32_t n = 0; n < num_original_values; n++) {
    struct xnn_value* value = &subgraph->values[n];
    value->fp16_id = XNN_INVALID_VALUE_ID;
    value->fp32_id = XNN_INVALID_VALUE_ID;
    if (value->fp16_compatible) {
      assert(value->data == NULL);
      assert(value->datatype == xnn_datatype_fp32);
      if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
        struct xnn_value* fp16_value = xnn_subgraph_new_internal_value(subgraph);

        // Recompute value due to potential reallocation in xnn_subgraph_new_internal_value
        value = &subgraph->values[n];
        xnn_value_copy(fp16_value, value);
        fp16_value->datatype = xnn_datatype_fp16;

        fp16_value->producer = value->producer;
        fp16_value->num_consumers = value->num_consumers;
        fp16_value->first_consumer = value->first_consumer;
        value->producer = XNN_INVALID_NODE_ID;
        value->num_consumers = 0;
        value->first_consumer = XNN_INVALID_NODE_ID;

        // Clear external input/output flags
        fp16_value->flags = 0;
        xnn_log_debug("FP16 rewrite: created FP16 tensor #%" PRIu32 " for FP32 tensor #%" PRIu32, fp16_value->id, n);

        value->fp16_id = fp16_value->id;
        fp16_value->fp32_id = n;
      } else {
        xnn_log_debug("FP16 rewrite: converted FP32 tensor #%" PRIu32 " to FP16", n);
        value->datatype = xnn_datatype_fp16;
      }
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    assert(node->compute_type == xnn_compute_type_fp32);
    node->compute_type = xnn_compute_type_fp16;
    if (node->type == xnn_node_type_static_constant_pad) {
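      // The padding value holds the raw bits of an FP32 number; reinterpret those bits
      // and re-encode the value as FP16 for the FP16 operator.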
      node->params.static_pad.padding_value =
        fp16_ieee_from_fp32_value(fp32_from_bits(node->params.static_pad.padding_value));
    }
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const uint32_t fp16_id = subgraph->values[node->inputs[i]].fp16_id;
      if (fp16_id != XNN_INVALID_VALUE_ID) {
        assert(subgraph->values[fp16_id].fp32_id == node->inputs[i]);
        node->inputs[i] = fp16_id;
      }
    }
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const uint32_t fp16_id = subgraph->values[node->outputs[o]].fp16_id;
      if (fp16_id != XNN_INVALID_VALUE_ID) {
        assert(subgraph->values[fp16_id].fp32_id == node->outputs[o]);
        node->outputs[o] = fp16_id;
      }
    }
  }

  // Count the number of external inputs and outputs which require Convert nodes
  uint32_t num_external_inputs = 0;
  uint32_t num_external_outputs = 0;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    const struct xnn_node* node = &subgraph->nodes[n];
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID && value->first_consumer == n) {
        assert(value->data == NULL);
        assert(value->datatype == xnn_datatype_fp16);
        assert(subgraph->values[value->fp32_id].datatype == xnn_datatype_fp32);
        assert(subgraph->values[value->fp32_id].flags & XNN_VALUE_FLAG_EXTERNAL_INPUT);
        num_external_inputs += 1;
      }
    }
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const struct xnn_value* value = &subgraph->values[node->outputs[o]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID) {
        assert(value->datatype == xnn_datatype_fp16);
        assert(subgraph->values[value->fp32_id].datatype == xnn_datatype_fp32);
        assert(subgraph->values[value->fp32_id].flags & XNN_VALUE_FLAG_EXTERNAL_OUTPUT);
        num_external_outputs += 1;
      }
    }
  }
  xnn_log_debug("Discovered %"PRIu32" external inputs and %"PRIu32" external outputs",
    num_external_inputs, num_external_outputs);

  const uint32_t num_original_nodes = subgraph->num_nodes;
  xnn_subgraph_add_nodes(subgraph, num_external_inputs + num_external_outputs);
  struct xnn_node* output_node = subgraph->nodes + subgraph->num_nodes - 1;
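  // Process the original Nodes in reverse order, copying each one toward the end of the
  // enlarged Node array and emitting Convert Nodes next to it, so that Nodes which have
  // not been processed yet are never overwritten.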
  for (uint32_t n = num_original_nodes; n != 0; n--) {
    const struct xnn_node* node = &subgraph->nodes[n - 1];
    // Insert Convert nodes for outputs
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const struct xnn_value* value = &subgraph->values[node->outputs[o]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID) {
        xnn_log_debug("Inserted FP16->FP32 Convert Node from tensor #%"PRIu32" to tensor #%"PRIu32,
          value->id, value->fp32_id);
        const uint32_t output_node_id = output_node->id;
        assert(output_node >= subgraph->nodes);
        xnn_node_clear(output_node);
        output_node->id = output_node_id;
        xnn_init_convert_node(output_node, xnn_compute_type_fp16_to_fp32, value->id, value->fp32_id, 0 /* flags */);
        output_node -= 1;
      }
    }
    // Move the Node to the new location
    if (output_node != node) {
      const uint32_t output_node_id = output_node->id;
      assert(output_node >= subgraph->nodes);
      memcpy(output_node, node, sizeof(struct xnn_node));
      output_node->id = output_node_id;
      output_node -= 1;
    }
    // Insert Convert nodes for inputs
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID && value->first_consumer == n - 1) {
        xnn_log_debug("Inserted FP32->FP16 Convert Node from tensor #%"PRIu32" to tensor #%"PRIu32,
          value->fp32_id, value->id);
        const uint32_t output_node_id = output_node->id;
        assert(output_node >= subgraph->nodes);
        xnn_node_clear(output_node);
        output_node->id = output_node_id;
        xnn_init_convert_node(output_node, xnn_compute_type_fp32_to_fp16, value->fp32_id, value->id, 0 /* flags */);
        output_node -= 1;
      }
    }
  }
}

enum xnn_status xnn_subgraph_optimize(
  xnn_subgraph_t subgraph,
  uint32_t flags)
{
  xnn_subgraph_analyze_consumers_and_producers(subgraph);

  // Remove unreferenced values.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->type == xnn_value_type_invalid) {
      continue;
    }

    if ((value->flags & XNN_VALUE_FLAG_EXTERNAL_INPUT) == 0 && value->num_consumers == 0) {
      xnn_value_clear(value);
    }
  }

  // Fuse Nodes where possible
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->num_consumers == 1) {
      const uint32_t producer_id = value->producer;
      if (producer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(producer_id < subgraph->num_nodes);

      const uint32_t consumer_id = value->first_consumer;
      if (consumer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(consumer_id < subgraph->num_nodes);

      struct xnn_node* producer = &subgraph->nodes[producer_id];
      assert(producer->type != xnn_node_type_invalid);
      struct xnn_node* consumer = &subgraph->nodes[consumer_id];
      assert(consumer->type != xnn_node_type_invalid);

      // Try to fuse Clamp Node upstream into producer Node
      if (consumer->type == xnn_node_type_clamp) {
        switch (producer->type) {
          case xnn_node_type_add2:
          case xnn_node_type_average_pooling_2d:
          case xnn_node_type_clamp:
          case xnn_node_type_convolution_2d:
          case xnn_node_type_divide:
          case xnn_node_type_deconvolution_2d:
          case xnn_node_type_depthwise_convolution_2d:
          case xnn_node_type_fully_connected:
          case xnn_node_type_multiply2:
          case xnn_node_type_max_pooling_2d:
          case xnn_node_type_subtract:
            xnn_log_info("fuse Clamp Node #%"PRIu32" into upstream Node #%"PRIu32, consumer_id, producer_id);
            assert(producer->num_outputs == 1);
            assert(consumer->num_inputs == 1);
            assert(consumer->num_outputs == 1);

            const uint32_t fused_output_id = consumer->outputs[0];
            assert(fused_output_id < subgraph->num_values);
            subgraph->values[fused_output_id].producer = producer_id;
            producer->outputs[0] = fused_output_id;

            producer->activation.output_min =
              math_max_f32(producer->activation.output_min, consumer->activation.output_min);
            producer->activation.output_max =
              math_min_f32(producer->activation.output_max, consumer->activation.output_max);

            xnn_node_clear(consumer);
            xnn_value_clear(value);
            break;
          default:
            break;
        }
      }
      // Try to fuse Constant Pad node downstream into [Depthwise] Convolution 2D Node
      if (producer->type == xnn_node_type_static_constant_pad) {
        assert(producer->num_inputs == 1);
        assert(producer->num_outputs == 1);
        const bool is_spatial_2d_padding = value->shape.num_dims == 4 &&
          (producer->params.static_pad.pre_paddings[0] | producer->params.static_pad.post_paddings[0] |
           producer->params.static_pad.pre_paddings[3] | producer->params.static_pad.post_paddings[3]) == 0;
        const enum xnn_datatype padding_datatype = subgraph->values[producer->outputs[0]].datatype;
        const uint32_t padding_value = producer->params.static_pad.padding_value;
        const bool is_zero_padding =
          (padding_datatype == xnn_datatype_fp32 && padding_value == 0) ||
          ((padding_datatype == xnn_datatype_qint8 || padding_datatype == xnn_datatype_quint8) &&
            padding_value == (uint32_t) (uint8_t) subgraph->values[producer->outputs[0]].quantization.zero_point);
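        // Only padding with zeroes (or with the zero point, for quantized datatypes) can be
        // folded into the Convolution's implicit input padding.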
        switch (consumer->type) {
          case xnn_node_type_convolution_2d:
            if (is_spatial_2d_padding && is_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Convolution 2D Node #%"PRIu32,
                producer_id, consumer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.convolution_2d.input_padding_top += producer->params.static_pad.pre_paddings[1];
              consumer->params.convolution_2d.input_padding_right += producer->params.static_pad.post_paddings[2];
              consumer->params.convolution_2d.input_padding_bottom += producer->params.static_pad.post_paddings[1];
              consumer->params.convolution_2d.input_padding_left += producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          case xnn_node_type_depthwise_convolution_2d:
            if (is_spatial_2d_padding && is_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Depthwise Convolution 2D Node #%"PRIu32,
                producer_id, consumer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.depthwise_convolution_2d.input_padding_top +=
                producer->params.static_pad.pre_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_right +=
                producer->params.static_pad.post_paddings[2];
              consumer->params.depthwise_convolution_2d.input_padding_bottom +=
                producer->params.static_pad.post_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_left +=
                producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          default:
            break;
        }
      }
    }
  }

  #if XNN_ENABLE_SPARSE
    if ((flags & XNN_FLAG_SPARSE_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_CHW_OPT)) {
      xnn_subgraph_rewrite_for_nchw(subgraph);
    }
  #endif

  #ifndef XNN_NO_F16_OPERATORS
    if ((flags & XNN_FLAG_FP16_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_F16)) {
      xnn_subgraph_rewrite_for_fp16(subgraph);
    }
  #endif  // XNN_NO_F16_OPERATORS

  return xnn_status_success;
}

enum xnn_status xnn_delete_subgraph(
  xnn_subgraph_t subgraph)
{
  if (subgraph != NULL) {
    memset(subgraph->nodes, 0, sizeof(struct xnn_node) * subgraph->num_nodes);
    xnn_release_memory(subgraph->nodes);

    memset(subgraph->values, 0, sizeof(struct xnn_value) * subgraph->num_values);
    xnn_release_memory(subgraph->values);

    memset(subgraph, 0, sizeof(struct xnn_subgraph));
    xnn_release_memory(subgraph);
  }
  return xnn_status_success;
}