Marat Dukhan | 1d75a54 | 2020-02-03 12:23:01 -0800 | [diff] [blame] | 1 | // Copyright 2020 Google LLC |
| 2 | // |
| 3 | // This source code is licensed under the BSD-style license found in the |
| 4 | // LICENSE file in the root directory of this source tree. |
| 5 | |
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
| 19 | |
| 20 | |
| 21 | enum xnn_status xnn_create_subgraph( |
| 22 | uint32_t external_value_ids, |
| 23 | uint32_t flags, |
| 24 | xnn_subgraph_t* subgraph_out) |
| 25 | { |
| 26 | struct xnn_subgraph* subgraph = NULL; |
| 27 | enum xnn_status status = xnn_status_uninitialized; |
| 28 | |
Marat Dukhan | 854fb6b | 2020-06-19 12:33:44 -0700 | [diff] [blame] | 29 | if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) { |
Marat Dukhan | 1d75a54 | 2020-02-03 12:23:01 -0800 | [diff] [blame] | 30 | xnn_log_error("failed to create subgraph: XNNPACK is not initialized"); |
| 31 | goto error; |
| 32 | } |
| 33 | |
| 34 | status = xnn_status_out_of_memory; |
| 35 | |
| 36 | subgraph = xnn_allocate_zero_memory(sizeof(struct xnn_subgraph)); |
| 37 | if (subgraph == NULL) { |
| 38 | xnn_log_error("failed to allocate %zu bytes for subgraph descriptor", sizeof(struct xnn_subgraph)); |
| 39 | goto error; |
| 40 | } |
| 41 | |
| 42 | subgraph->external_value_ids = external_value_ids; |
| 43 | |
| 44 | subgraph->values = xnn_allocate_zero_memory(external_value_ids * sizeof(struct xnn_value)); |
| 45 | if (subgraph->values == NULL) { |
| 46 | xnn_log_error("failed to allocate %zu bytes for subgraph values", external_value_ids * sizeof(struct xnn_value)); |
| 47 | goto error; |
| 48 | } |
| 49 | for (size_t i = 0; i < external_value_ids; i++) { |
| 50 | subgraph->values[i].id = i; |
| 51 | } |
| 52 | subgraph->num_values = external_value_ids; |
| 53 | subgraph->num_reserved_values = external_value_ids; |
| 54 | |
| 55 | *subgraph_out = subgraph; |
| 56 | return xnn_status_success; |
| 57 | |
| 58 | error: |
| 59 | xnn_delete_subgraph(subgraph); |
| 60 | return status; |
| 61 | } |
| 62 | |
| 63 | |
| 64 | struct xnn_value* xnn_subgraph_new_internal_value(xnn_subgraph_t subgraph) |
| 65 | { |
| 66 | struct xnn_value* values = subgraph->values; |
| 67 | const size_t size = subgraph->num_values; |
| 68 | const size_t capacity = subgraph->num_reserved_values; |
| 69 | if (capacity < size + 1) { |
| 70 | const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64); |
| 71 | assert(new_capacity >= size + 1); |
| 72 | values = xnn_reallocate_memory(values, new_capacity * sizeof(struct xnn_value)); |
| 73 | if (values == NULL) { |
| 74 | xnn_log_error("failed to allocate %zu bytes for subgraph values", |
| 75 | capacity * sizeof(struct xnn_value)); |
| 76 | return values; |
| 77 | } |
| 78 | |
| 79 | memset(values + size, 0, (new_capacity - size) * sizeof(struct xnn_value)); |
| 80 | subgraph->num_reserved_values = new_capacity; |
| 81 | subgraph->values = values; |
| 82 | } |
| 83 | subgraph->num_values = size + 1; |
| 84 | struct xnn_value* new_value = values + size; |
| 85 | new_value->id = size; |
| 86 | return new_value; |
| 87 | } |
| 88 | |
Marat Dukhan | 1f19872 | 2020-05-24 14:07:03 -0700 | [diff] [blame] | 89 | void xnn_node_clear(struct xnn_node* node) { |
| 90 | assert(node != NULL); |
Marat Dukhan | 1f19872 | 2020-05-24 14:07:03 -0700 | [diff] [blame] | 91 | memset(node, 0, sizeof(struct xnn_node)); |
| 92 | } |
| 93 | |
| 94 | void xnn_value_clear(struct xnn_value* value) { |
| 95 | assert(value != NULL); |
Marat Dukhan | 1f19872 | 2020-05-24 14:07:03 -0700 | [diff] [blame] | 96 | memset(value, 0, sizeof(struct xnn_value)); |
| 97 | } |
| 98 | |
Marat Dukhan | 4620ca6 | 2022-02-03 12:31:00 -0800 | [diff] [blame] | 99 | void xnn_value_copy( |
| 100 | struct xnn_value* dst_value, |
| 101 | const struct xnn_value* src_value) |
| 102 | { |
| 103 | // Note: Value ID stays unchanged |
| 104 | |
| 105 | dst_value->type = src_value->type; |
| 106 | dst_value->datatype = src_value->datatype; |
| 107 | dst_value->quantization = src_value->quantization; |
| 108 | dst_value->shape = src_value->shape; |
| 109 | dst_value->flags = src_value->flags; |
| 110 | dst_value->data = src_value->data; |
| 111 | dst_value->producer = src_value->producer; |
| 112 | dst_value->first_consumer = src_value->first_consumer; |
| 113 | } |
| 114 | |
Marat Dukhan | 1d75a54 | 2020-02-03 12:23:01 -0800 | [diff] [blame] | 115 | struct xnn_node* xnn_subgraph_new_node(xnn_subgraph_t subgraph) |
| 116 | { |
| 117 | struct xnn_node* nodes = subgraph->nodes; |
| 118 | const size_t size = subgraph->num_nodes; |
| 119 | const size_t capacity = subgraph->num_reserved_nodes; |
| 120 | |
| 121 | if (capacity < size + 1) { |
| 122 | const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64); |
| 123 | assert(new_capacity >= size + 1); |
| 124 | nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node)); |
| 125 | if (nodes == NULL) { |
| 126 | xnn_log_error("failed to allocate %zu bytes for subgraph nodes", |
| 127 | capacity * sizeof(struct xnn_node)); |
| 128 | return nodes; |
| 129 | } |
| 130 | |
| 131 | memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node)); |
| 132 | subgraph->num_reserved_nodes = new_capacity; |
| 133 | subgraph->nodes = nodes; |
| 134 | } |
| 135 | subgraph->num_nodes = size + 1; |
| 136 | struct xnn_node* new_node = nodes + size; |
| 137 | new_node->id = size; |
| 138 | return new_node; |
| 139 | } |
| 140 | |
Marat Dukhan | 4620ca6 | 2022-02-03 12:31:00 -0800 | [diff] [blame] | 141 | void xnn_subgraph_add_nodes(xnn_subgraph_t subgraph, size_t num_nodes) |
| 142 | { |
| 143 | struct xnn_node* nodes = subgraph->nodes; |
| 144 | const size_t size = subgraph->num_nodes; |
| 145 | const size_t capacity = subgraph->num_reserved_nodes; |
| 146 | |
| 147 | if (capacity < size + num_nodes) { |
| 148 | const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + max(num_nodes, 64)); |
| 149 | assert(new_capacity >= size + num_nodes); |
| 150 | nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node)); |
| 151 | if (nodes == NULL) { |
| 152 | xnn_log_error("failed to allocate %zu bytes for subgraph nodes", |
| 153 | capacity * sizeof(struct xnn_node)); |
| 154 | return; |
| 155 | } |
| 156 | |
| 157 | memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node)); |
| 158 | subgraph->num_reserved_nodes = new_capacity; |
| 159 | subgraph->nodes = nodes; |
| 160 | } |
| 161 | subgraph->num_nodes = size + num_nodes; |
| 162 | struct xnn_node* new_nodes = nodes + size; |
| 163 | for (size_t i = 0; i < num_nodes; i++) { |
| 164 | new_nodes[i].id = size + i; |
| 165 | } |
| 166 | } |
| 167 | |
| 168 | void xnn_subgraph_analyze_consumers_and_producers(xnn_subgraph_t subgraph) |
| 169 | { |
| 170 | // Initialize producer/consumer fields to safe defaults. |
| 171 | for (uint32_t i = 0; i < subgraph->num_values; i++) { |
| 172 | struct xnn_value* value = &subgraph->values[i]; |
| 173 | value->producer = XNN_INVALID_NODE_ID; |
| 174 | value->first_consumer = XNN_INVALID_NODE_ID; |
| 175 | value->num_consumers = 0; |
| 176 | } |
| 177 | |
| 178 | // Analyse Nodes' inputs and output and update Values' producer/consumer fields |
| 179 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 180 | struct xnn_node* node = &subgraph->nodes[n]; |
| 181 | |
| 182 | for (uint32_t i = 0; i < node->num_inputs; i++) { |
| 183 | const uint32_t input_id = node->inputs[i]; |
| 184 | assert(input_id < subgraph->num_values); |
| 185 | |
| 186 | if (subgraph->values[input_id].num_consumers++ == 0) { |
| 187 | assert(subgraph->values[input_id].first_consumer == XNN_INVALID_NODE_ID); |
| 188 | subgraph->values[input_id].first_consumer = n; |
| 189 | } |
| 190 | } |
| 191 | |
| 192 | for (uint32_t o = 0; o < node->num_outputs; o++) { |
| 193 | const uint32_t output_id = node->outputs[o]; |
| 194 | assert(output_id < subgraph->num_values); |
| 195 | |
| 196 | assert(subgraph->values[output_id].producer == XNN_INVALID_NODE_ID); |
| 197 | subgraph->values[output_id].producer = n; |
| 198 | } |
| 199 | } |
| 200 | |
| 201 | // Count extra consumer for Values which are external outputs. |
| 202 | // Remove unreferenced values. |
| 203 | for (uint32_t i = 0; i < subgraph->num_values; i++) { |
| 204 | struct xnn_value* value = &subgraph->values[i]; |
| 205 | if (value->flags & XNN_VALUE_FLAG_EXTERNAL_OUTPUT) { |
| 206 | value->num_consumers += 1; |
| 207 | } |
| 208 | } |
| 209 | } |
| 210 | |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 211 | #define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW 1 |
| 212 | #define XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW 2 |
| 213 | #define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC 4 |
| 214 | #define XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER 8 |
| 215 | |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 216 | uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node* node) { |
Marat Dukhan | d2ad6d0 | 2021-11-14 19:37:26 -0800 | [diff] [blame] | 217 | if (node->compute_type != xnn_compute_type_fp32) { |
| 218 | return 0; |
| 219 | } |
| 220 | |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 221 | switch (node->type) { |
| 222 | case xnn_node_type_convolution_2d: |
| 223 | // Supported cases: |
| 224 | // - 1x1 convolution (no stride, no dilation, no padding, no groups) |
| 225 | // - 3x3 stride-2 convolution (no dilation, padding 1 on each side, no groups, 3 input channels) |
| 226 | if (node->params.convolution_2d.groups != 1) { |
| 227 | return 0; |
| 228 | } |
| 229 | if ((node->params.convolution_2d.dilation_height | node->params.convolution_2d.dilation_width) != 1) { |
| 230 | return 0; |
| 231 | } |
| 232 | if ((node->params.convolution_2d.kernel_height | node->params.convolution_2d.kernel_width) == 1) { |
| 233 | if ((node->params.convolution_2d.input_padding_top | node->params.convolution_2d.input_padding_right | |
| 234 | node->params.convolution_2d.input_padding_bottom | node->params.convolution_2d.input_padding_left) != 0) |
| 235 | { |
| 236 | return 0; |
| 237 | } |
| 238 | if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 1) { |
| 239 | return 0; |
| 240 | } |
| 241 | return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW; |
| 242 | } else if (node->params.convolution_2d.kernel_height == 3 && node->params.convolution_2d.kernel_width == 3) { |
| 243 | if (node->params.convolution_2d.input_padding_top != 1 || node->params.convolution_2d.input_padding_right != 1 || |
| 244 | node->params.convolution_2d.input_padding_bottom != 1 || node->params.convolution_2d.input_padding_left != 1) |
| 245 | { |
| 246 | return 0; |
| 247 | } |
| 248 | if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 2) { |
| 249 | return 0; |
| 250 | } |
| 251 | if (node->params.convolution_2d.group_input_channels != 3) { |
| 252 | return 0; |
| 253 | } |
| 254 | return XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW; |
| 255 | } |
| 256 | return 0; |
| 257 | case xnn_node_type_depthwise_convolution_2d: |
| 258 | // Supported cases: |
| 259 | // - 3x3 stride-1 convolution (no dilation, padding 1 on each side) |
| 260 | // - 3x3 stride-2 convolution (no dilation, padding 1 on each side) |
| 261 | // - 5x5 stride-1 convolution (no dilation, padding 2 on each side) |
| 262 | // - 5x5 stride-2 convolution (no dilation, padding 2 on each side) |
| 263 | if ((node->params.depthwise_convolution_2d.dilation_height | node->params.depthwise_convolution_2d.dilation_width) != 1) { |
| 264 | return 0; |
| 265 | } |
| 266 | if (node->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) { |
| 267 | return 0; |
| 268 | } |
| 269 | if (node->params.depthwise_convolution_2d.depth_multiplier != 1) { |
| 270 | return 0; |
| 271 | } |
| 272 | if (node->params.depthwise_convolution_2d.subsampling_height != node->params.depthwise_convolution_2d.subsampling_width) { |
| 273 | return 0; |
| 274 | } |
| 275 | switch (node->params.depthwise_convolution_2d.subsampling_height) { |
| 276 | case 1: |
| 277 | case 2: |
| 278 | break; |
| 279 | default: |
| 280 | return 0; |
| 281 | } |
| 282 | if (node->params.depthwise_convolution_2d.kernel_height != node->params.depthwise_convolution_2d.kernel_width) { |
| 283 | return 0; |
| 284 | } |
| 285 | switch (node->params.depthwise_convolution_2d.kernel_height) { |
| 286 | case 3: |
| 287 | return node->params.depthwise_convolution_2d.input_padding_top == 1 && |
| 288 | node->params.depthwise_convolution_2d.input_padding_right == 1 && |
| 289 | node->params.depthwise_convolution_2d.input_padding_bottom == 1 && |
| 290 | node->params.depthwise_convolution_2d.input_padding_left == 1 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0; |
| 291 | case 5: |
| 292 | return node->params.depthwise_convolution_2d.input_padding_top == 2 && |
| 293 | node->params.depthwise_convolution_2d.input_padding_right == 2 && |
| 294 | node->params.depthwise_convolution_2d.input_padding_bottom == 2 && |
| 295 | node->params.depthwise_convolution_2d.input_padding_left == 2 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0; |
| 296 | default: |
| 297 | return 0; |
| 298 | } |
Artsiom Ablavatski | bbe8506 | 2020-11-05 14:07:37 -0800 | [diff] [blame] | 299 | case xnn_node_type_depth_to_space: |
Marat Dukhan | f56b4bb | 2020-12-06 19:06:04 -0800 | [diff] [blame] | 300 | return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC; |
| 301 | case xnn_node_type_global_average_pooling_2d: |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 302 | return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC; |
| 303 | case xnn_node_type_add2: |
| 304 | case xnn_node_type_multiply2: |
| 305 | assert(node->num_inputs == 2); |
| 306 | assert(node->num_outputs == 1); |
| 307 | if (subgraph->values[node->inputs[0]].shape.num_dims != 4 || |
| 308 | subgraph->values[node->inputs[1]].shape.num_dims != 4) |
| 309 | { |
| 310 | return 0; |
| 311 | } |
| 312 | |
| 313 | if (subgraph->values[node->inputs[0]].data != NULL) { |
| 314 | // Check that the first input is representable as either a scalar, or a vector |
| 315 | size_t num_nonunit_dims = 0; |
| 316 | for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) { |
| 317 | if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) { |
| 318 | num_nonunit_dims += 1; |
| 319 | } |
| 320 | } |
| 321 | if (num_nonunit_dims > 1) { |
| 322 | return 0; |
| 323 | } |
| 324 | } |
| 325 | |
| 326 | if (subgraph->values[node->inputs[1]].data != NULL) { |
| 327 | // Check that the second input is representable as either a scalar, or a vector |
| 328 | size_t num_nonunit_dims = 0; |
| 329 | for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) { |
| 330 | if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) { |
| 331 | num_nonunit_dims += 1; |
| 332 | } |
| 333 | } |
| 334 | if (num_nonunit_dims > 1) { |
| 335 | return 0; |
| 336 | } |
| 337 | } |
| 338 | |
| 339 | return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW; |
Artsiom Ablavatski | e6beeba | 2020-10-28 09:12:19 -0700 | [diff] [blame] | 340 | case xnn_node_type_static_resize_bilinear_2d: |
| 341 | return subgraph->values[node->inputs[0]].shape.dim[1] > 1 && |
| 342 | subgraph->values[node->inputs[0]].shape.dim[2] > 1 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0; |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 343 | case xnn_node_type_abs: |
| 344 | case xnn_node_type_bankers_rounding: |
| 345 | case xnn_node_type_ceiling: |
| 346 | case xnn_node_type_clamp: |
Marat Dukhan | 094e692 | 2020-12-08 12:54:38 -0800 | [diff] [blame] | 347 | case xnn_node_type_elu: |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 348 | case xnn_node_type_floor: |
| 349 | case xnn_node_type_hardswish: |
| 350 | case xnn_node_type_leaky_relu: |
| 351 | case xnn_node_type_negate: |
| 352 | case xnn_node_type_sigmoid: |
| 353 | case xnn_node_type_square: |
| 354 | assert(node->num_inputs == 1); |
| 355 | assert(node->num_outputs == 1); |
| 356 | return subgraph->values[node->inputs[0]].shape.num_dims == 4 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0; |
| 357 | default: |
| 358 | return false; |
| 359 | } |
| 360 | } |
| 361 | |
XNNPACK Team | ab8c4c8 | 2020-10-09 08:05:51 -0700 | [diff] [blame] | 362 | void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph) |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 363 | { |
| 364 | // Convert parts of the subgraph to NCHW for sparse inference |
| 365 | // Step 1: detect NCHW-compatible Nodes |
| 366 | // Step 2: detect NCHW-compatible clusters (run connected components graph algorithm) |
| 367 | // Step 3: check that all NCHW-compatible Values are consumed only by NCHW-compatible Nodes |
| 368 | // Step 4: switch Values' layout to NCHW |
| 369 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 370 | struct xnn_node* node = &subgraph->nodes[n]; |
| 371 | node->layout_flags = xnn_check_nchw_compatibility(subgraph, node); |
| 372 | xnn_log_debug("Node #%" PRIu32 ": %s (NCHW: %s, NHWC->NCHW: %s, NCHW->NHWC: %s)", |
| 373 | n, xnn_node_type_to_string(node->type), |
| 374 | node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW ? "yes" : "no", |
| 375 | node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW ? "yes" : "no", |
| 376 | node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC ? "yes" : "no"); |
| 377 | } |
| 378 | |
XNNPACK Team | a117ce7 | 2020-10-05 17:26:02 -0700 | [diff] [blame] | 379 | // Run Shiloach-Vishkin connected components algorithm i.e. find all |
| 380 | // XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC nodes and set them as cluster leaders |
| 381 | // to all the producer nodes |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 382 | bool update = false; |
| 383 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 384 | struct xnn_node* node = &subgraph->nodes[n]; |
| 385 | node->cluster_leader = n; |
| 386 | if (node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC) { |
| 387 | for (uint32_t i = 0; i < node->num_inputs; i++) { |
| 388 | const struct xnn_value* value = &subgraph->values[node->inputs[i]]; |
| 389 | if (value->data != NULL) { |
| 390 | // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated |
| 391 | // during the initial NCHW compatibility check for the Node. |
| 392 | continue; |
| 393 | } |
| 394 | if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) { |
| 395 | // External value, invalid cluster |
| 396 | node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER; |
| 397 | continue; |
| 398 | } |
| 399 | const uint32_t producer_id = value->producer; |
| 400 | assert(producer_id != XNN_INVALID_NODE_ID); |
| 401 | assert(producer_id < n); |
| 402 | struct xnn_node* producer_node = &subgraph->nodes[producer_id]; |
| 403 | if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 && |
| 404 | (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0) |
| 405 | { |
| 406 | producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC; |
| 407 | if (producer_node->cluster_leader != node->cluster_leader) { |
| 408 | producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader); |
| 409 | update = true; |
| 410 | } |
| 411 | } else { |
| 412 | node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER; |
| 413 | } |
| 414 | } |
| 415 | } |
| 416 | } |
XNNPACK Team | a117ce7 | 2020-10-05 17:26:02 -0700 | [diff] [blame] | 417 | // No NCHW2NHWC compatible nodes have been found thus the graph rewriting |
slowy07 | ab1127f | 2021-07-27 08:23:22 +0700 | [diff] [blame] | 418 | // practically cannot happen. |
XNNPACK Team | a117ce7 | 2020-10-05 17:26:02 -0700 | [diff] [blame] | 419 | if (!update) { |
| 420 | return; |
| 421 | } |
| 422 | // Propagate the cluster leader to other nodes in the graph untill all the |
| 423 | // nodes in the cluster is not updated |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 424 | while (update) { |
| 425 | update = false; |
| 426 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 427 | struct xnn_node* node = &subgraph->nodes[n]; |
| 428 | if (node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) { |
| 429 | continue; |
| 430 | } |
| 431 | |
| 432 | if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC)) == 0) { |
| 433 | continue; |
| 434 | } |
| 435 | |
| 436 | for (uint32_t i = 0; i < node->num_inputs; i++) { |
| 437 | const struct xnn_value* value = &subgraph->values[node->inputs[i]]; |
| 438 | if (value->data != NULL) { |
| 439 | // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated |
| 440 | // during the initial NCHW compatibility check for the Node. |
| 441 | continue; |
| 442 | } |
| 443 | if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) { |
| 444 | // External value, invalid cluster |
| 445 | node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER; |
| 446 | continue; |
| 447 | } |
| 448 | const uint32_t producer_id = value->producer; |
| 449 | assert(producer_id != XNN_INVALID_NODE_ID); |
| 450 | assert(producer_id < n); |
| 451 | struct xnn_node* producer_node = &subgraph->nodes[producer_id]; |
| 452 | if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 && |
| 453 | (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0) |
| 454 | { |
| 455 | producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC; |
| 456 | if (producer_node->cluster_leader != node->cluster_leader) { |
| 457 | producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader); |
| 458 | update = true; |
| 459 | } |
| 460 | } else { |
| 461 | node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER; |
| 462 | } |
| 463 | } |
| 464 | } |
| 465 | } |
| 466 | // Propagate XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER flags up to the cluster leaders |
| 467 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 468 | struct xnn_node* node = &subgraph->nodes[n]; |
| 469 | subgraph->nodes[node->cluster_leader].layout_flags |= node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER; |
| 470 | } |
| 471 | // Check that all Values consumed by NCHW-compatible cluster don't have NCHW-incompatible consumers |
| 472 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 473 | struct xnn_node* node = &subgraph->nodes[n]; |
| 474 | if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) { |
| 475 | continue; |
| 476 | } |
| 477 | |
| 478 | if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) { |
| 479 | continue; |
| 480 | } |
| 481 | |
| 482 | for (uint32_t i = 0; i < node->num_inputs; i++) { |
| 483 | struct xnn_value* value = &subgraph->values[node->inputs[i]]; |
| 484 | if (value->data != NULL) { |
| 485 | // Static data, skip this input value because it doesn't have a producer Node. |
| 486 | continue; |
| 487 | } |
| 488 | assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0); |
| 489 | value->num_nchw_compatible_consumers += 1; |
| 490 | } |
| 491 | } |
| 492 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 493 | struct xnn_node* node = &subgraph->nodes[n]; |
| 494 | if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) { |
| 495 | continue; |
| 496 | } |
| 497 | |
| 498 | if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) { |
| 499 | continue; |
| 500 | } |
| 501 | |
| 502 | for (uint32_t i = 0; i < node->num_inputs; i++) { |
| 503 | const struct xnn_value* value = &subgraph->values[node->inputs[i]]; |
| 504 | if (value->data != NULL) { |
| 505 | // Static data, skip this input value because it doesn't have a producer Node. |
| 506 | continue; |
| 507 | } |
| 508 | assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0); |
| 509 | assert(value->num_nchw_compatible_consumers > 0); |
| 510 | if (value->num_nchw_compatible_consumers != value->num_consumers) { |
| 511 | subgraph->nodes[node->cluster_leader].layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER; |
| 512 | } |
| 513 | } |
| 514 | } |
Marat Dukhan | 54b2d54 | 2020-12-08 00:19:52 -0800 | [diff] [blame] | 515 | // Evaluate if it is profitable to run the model as sparse: |
| 516 | // - Compute the number of parameters and zeroes in 1x1 Convolution weights |
| 517 | // - Disable sparse rewriting for clusters without 1x1 Convolutions (num_params == 0) |
| 518 | // or with less than 2/3rd of zeroes in 1x1 Convolution filters |
| 519 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 520 | struct xnn_node* node = &subgraph->nodes[n]; |
| 521 | if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) { |
| 522 | continue; |
| 523 | } |
| 524 | |
| 525 | if (node->type == xnn_node_type_convolution_2d && |
| 526 | max(node->params.convolution_2d.kernel_height, node->params.convolution_2d.kernel_width) == 1) |
| 527 | { |
| 528 | assert(node->num_inputs >= 2); |
| 529 | |
| 530 | const struct xnn_value* filter = &subgraph->values[node->inputs[1]]; |
| 531 | assert(filter->data != NULL); |
| 532 | assert(filter->shape.num_dims == 4); |
| 533 | |
| 534 | const size_t num_params = filter->shape.dim[0] * filter->shape.dim[3]; |
| 535 | subgraph->nodes[node->cluster_leader].num_params += num_params; |
| 536 | |
| 537 | const float* data = (const float*) filter->data; |
| 538 | size_t num_zeroes = 0; |
| 539 | for (size_t i = 0; i < num_params; i++) { |
| 540 | num_zeroes += (size_t) (data[i] == 0.0f); |
| 541 | } |
| 542 | xnn_log_debug("1x1 Convolution 2D Node #%" PRIu32 ": %zu / %zu sparsity", n, num_zeroes, num_params); |
| 543 | subgraph->nodes[node->cluster_leader].num_zeroes += num_zeroes; |
| 544 | } |
| 545 | } |
Artsiom Ablavatski | cd3e068 | 2021-06-02 19:25:22 -0700 | [diff] [blame] | 546 | bool use_nchw_layout = false; |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 547 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 548 | struct xnn_node* node = &subgraph->nodes[n]; |
| 549 | if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) { |
| 550 | continue; |
| 551 | } |
| 552 | |
| 553 | if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) { |
| 554 | continue; |
| 555 | } |
| 556 | |
Marat Dukhan | 54b2d54 | 2020-12-08 00:19:52 -0800 | [diff] [blame] | 557 | if (subgraph->nodes[node->cluster_leader].num_zeroes * 3 <= subgraph->nodes[node->cluster_leader].num_params * 2) { |
| 558 | xnn_log_info("Node #%" PRIu32 ": sparse inference disabled: 1x1 Convolutions contain %zu / %zu zero weights", |
| 559 | n, subgraph->nodes[node->cluster_leader].num_zeroes, subgraph->nodes[node->cluster_leader].num_params); |
| 560 | continue; |
| 561 | } |
| 562 | |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 563 | for (uint32_t i = 0; i < node->num_inputs; i++) { |
| 564 | struct xnn_value* value = &subgraph->values[node->inputs[i]]; |
| 565 | if (value->data != NULL) { |
| 566 | // Static data, skip this input value because it doesn't have a producer Node. |
| 567 | continue; |
| 568 | } |
| 569 | assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0); |
| 570 | assert(value->num_nchw_compatible_consumers > 0); |
| 571 | assert(value->num_nchw_compatible_consumers == value->num_consumers); |
| 572 | if (value->layout != xnn_layout_type_nchw) { |
| 573 | value->layout = xnn_layout_type_nchw; |
| 574 | xnn_log_info("set Value #%"PRIu32" layout to NCHW", node->inputs[i]); |
Artsiom Ablavatski | cd3e068 | 2021-06-02 19:25:22 -0700 | [diff] [blame] | 575 | use_nchw_layout = true; |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 576 | } |
| 577 | } |
| 578 | } |
Artsiom Ablavatski | cd3e068 | 2021-06-02 19:25:22 -0700 | [diff] [blame] | 579 | if (use_nchw_layout) { |
| 580 | xnn_log_info("XNNPACK has switched to sparse inference mode!"); |
| 581 | } |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 582 | } |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 583 | |
Marat Dukhan | 4620ca6 | 2022-02-03 12:31:00 -0800 | [diff] [blame] | 584 | void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) |
| 585 | { |
| 586 | xnn_log_info("Analyzing subgraph for FP16 compatibility"); |
| 587 | |
| 588 | // Convert tensors and operators in the subgraph to FP16 |
| 589 | // 1. Check that all operators in the subgraph are supported in FP16. |
| 590 | // 2. Indicate values that must be converted to FP16. |
| 591 | // 3. Replace FP32 Values with FP16 Values as Nodes' inputs/outputs. |
| 592 | // 4. Insert FP32->FP16 Convert Nodes for external FP32 inputs and FP16->FP32 Convert Nodes for external outputs. |
| 593 | |
| 594 | // Check that all operators in the subgraph are supported in FP16, bail out on any unsupported one. |
| 595 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 596 | struct xnn_node* node = &subgraph->nodes[n]; |
Marat Dukhan | 170f95a | 2022-02-04 02:18:23 -0800 | [diff] [blame] | 597 | if (node->type == xnn_node_type_invalid) { |
| 598 | // Node was fused away, skip. |
| 599 | continue; |
| 600 | } |
| 601 | |
Marat Dukhan | 4620ca6 | 2022-02-03 12:31:00 -0800 | [diff] [blame] | 602 | if (node->compute_type != xnn_compute_type_fp32) { |
| 603 | xnn_log_info("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not FP32", n, xnn_node_type_to_string(node->type)); |
| 604 | return; |
| 605 | } |
| 606 | switch (node->type) { |
| 607 | case xnn_node_type_add2: |
| 608 | assert(node->num_inputs == 2); |
| 609 | for (uint32_t i = 0; i < node->num_inputs; i++) { |
| 610 | if (subgraph->values[node->inputs[i]].data != NULL) { |
| 611 | xnn_log_info("FP16 rewrite aborted: node #%" PRIu32 " (%s) has static input %i", |
| 612 | n, xnn_node_type_to_string(node->type), i); |
| 613 | return; |
| 614 | } |
| 615 | } |
| 616 | break; |
| 617 | case xnn_node_type_convolution_2d: |
| 618 | case xnn_node_type_depthwise_convolution_2d: |
| 619 | case xnn_node_type_global_average_pooling_2d: |
| 620 | case xnn_node_type_hardswish: |
Marat Dukhan | 670826b | 2022-02-04 02:36:08 -0800 | [diff] [blame] | 621 | case xnn_node_type_max_pooling_2d: |
Marat Dukhan | 170f95a | 2022-02-04 02:18:23 -0800 | [diff] [blame] | 622 | case xnn_node_type_prelu: |
Marat Dukhan | 4b90bee | 2022-02-04 00:00:18 -0800 | [diff] [blame] | 623 | case xnn_node_type_static_constant_pad: |
Marat Dukhan | cb872b0 | 2022-02-04 04:05:35 -0800 | [diff] [blame] | 624 | case xnn_node_type_static_reshape: |
Marat Dukhan | 4620ca6 | 2022-02-03 12:31:00 -0800 | [diff] [blame] | 625 | break; |
| 626 | default: |
| 627 | xnn_log_info("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not supported for FP16 inference", |
| 628 | n, xnn_node_type_to_string(node->type)); |
| 629 | return; |
| 630 | } |
| 631 | } |
| 632 | |
| 633 | // Annotate Values to be converted to FP16 as FP16-compatible. |
Marat Dukhan | 170f95a | 2022-02-04 02:18:23 -0800 | [diff] [blame] | 634 | // Note that static weights in [Depthwise] Convolution, Fully Connected, and PReLU Nodes remain FP32, |
Marat Dukhan | 4620ca6 | 2022-02-03 12:31:00 -0800 | [diff] [blame] | 635 | // they will be converted to FP16 during weight repacking when the operator is created. |
| 636 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 637 | struct xnn_node* node = &subgraph->nodes[n]; |
| 638 | switch (node->type) { |
| 639 | case xnn_node_type_convolution_2d: |
| 640 | case xnn_node_type_depthwise_convolution_2d: |
Marat Dukhan | 170f95a | 2022-02-04 02:18:23 -0800 | [diff] [blame] | 641 | case xnn_node_type_prelu: |
Marat Dukhan | 4620ca6 | 2022-02-03 12:31:00 -0800 | [diff] [blame] | 642 | subgraph->values[node->inputs[0]].fp16_compatible = true; |
| 643 | subgraph->values[node->outputs[0]].fp16_compatible = true; |
| 644 | break; |
| 645 | default: |
| 646 | for (uint32_t i = 0; i < node->num_inputs; i++) { |
| 647 | subgraph->values[node->inputs[i]].fp16_compatible = true; |
| 648 | } |
| 649 | for (uint32_t o = 0; o < node->num_outputs; o++) { |
| 650 | subgraph->values[node->outputs[o]].fp16_compatible = true; |
| 651 | } |
| 652 | break; |
| 653 | } |
| 654 | } |
| 655 | |
| 656 | // Replace FP32 Values in Nodes' inputs/outputs with FP16 Values. |
| 657 | // FP32 Values that are not external inputs or outputs are converted to FP16 in-place, |
| 658 | // for external inputs and outputs we create same-shaped FP16 Values and use those instead. |
| 659 | const uint32_t num_original_values = subgraph->num_values; |
| 660 | xnn_subgraph_analyze_consumers_and_producers(subgraph); |
| 661 | for (uint32_t n = 0; n < num_original_values; n++) { |
| 662 | struct xnn_value* value = &subgraph->values[n]; |
| 663 | value->fp16_id = XNN_INVALID_VALUE_ID; |
| 664 | value->fp32_id = XNN_INVALID_VALUE_ID; |
| 665 | if (value->fp16_compatible) { |
| 666 | assert(value->data == NULL); |
| 667 | assert(value->datatype == xnn_datatype_fp32); |
| 668 | if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) { |
| 669 | struct xnn_value* fp16_value = xnn_subgraph_new_internal_value(subgraph); |
| 670 | |
| 671 | // Recompute value due to potential reallocation in xnn_subgraph_new_internal_value |
| 672 | value = &subgraph->values[n]; |
| 673 | xnn_value_copy(fp16_value, value); |
| 674 | fp16_value->datatype = xnn_datatype_fp16; |
| 675 | |
| 676 | fp16_value->producer = value->producer; |
| 677 | fp16_value->num_consumers = value->num_consumers; |
| 678 | fp16_value->first_consumer = value->first_consumer; |
| 679 | value->producer = XNN_INVALID_NODE_ID; |
| 680 | value->num_consumers = 0; |
| 681 | value->first_consumer = XNN_INVALID_NODE_ID; |
| 682 | |
| 683 | // Clear external input/output flags |
| 684 | fp16_value->flags = 0; |
| 685 | xnn_log_debug("FP16 rewrite: created FP16 tensor #%" PRIu32 " for FP32 tensor #%" PRIu32, fp16_value->id, n); |
| 686 | |
| 687 | value->fp16_id = fp16_value->id; |
| 688 | fp16_value->fp32_id = n; |
| 689 | } else { |
| 690 | xnn_log_debug("FP16 rewrite: converted FP32 tensor #%" PRIu32 " to FP16", n); |
| 691 | value->datatype = xnn_datatype_fp16; |
| 692 | } |
| 693 | } |
| 694 | } |
| 695 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 696 | struct xnn_node* node = &subgraph->nodes[n]; |
| 697 | assert(node->compute_type == xnn_compute_type_fp32); |
| 698 | node->compute_type = xnn_compute_type_fp16; |
Marat Dukhan | 4b90bee | 2022-02-04 00:00:18 -0800 | [diff] [blame] | 699 | if (node->type == xnn_node_type_static_constant_pad) { |
| 700 | node->params.static_pad.padding_value = |
| 701 | fp16_ieee_from_fp32_value(fp32_from_bits(node->params.static_pad.padding_value)); |
| 702 | } |
Marat Dukhan | 4620ca6 | 2022-02-03 12:31:00 -0800 | [diff] [blame] | 703 | for (uint32_t i = 0; i < node->num_inputs; i++) { |
| 704 | const uint32_t fp16_id = subgraph->values[node->inputs[i]].fp16_id; |
| 705 | if (fp16_id != XNN_INVALID_VALUE_ID) { |
| 706 | assert(subgraph->values[fp16_id].fp32_id == node->inputs[i]); |
| 707 | node->inputs[i] = fp16_id; |
| 708 | } |
| 709 | } |
| 710 | for (uint32_t o = 0; o < node->num_outputs; o++) { |
| 711 | const uint32_t fp16_id = subgraph->values[node->outputs[o]].fp16_id; |
| 712 | if (fp16_id != XNN_INVALID_VALUE_ID) { |
| 713 | assert(subgraph->values[fp16_id].fp32_id == node->outputs[o]); |
| 714 | node->outputs[o] = fp16_id; |
| 715 | } |
| 716 | } |
| 717 | } |
| 718 | |
| 719 | // Count the number of external inputs and outputs which require Convert nodes |
| 720 | uint32_t num_external_inputs = 0; |
| 721 | uint32_t num_external_outputs = 0; |
| 722 | for (uint32_t n = 0; n < subgraph->num_nodes; n++) { |
| 723 | const struct xnn_node* node = &subgraph->nodes[n]; |
| 724 | for (uint32_t i = 0; i < node->num_inputs; i++) { |
| 725 | const struct xnn_value* value = &subgraph->values[node->inputs[i]]; |
| 726 | if (value->fp32_id != XNN_INVALID_VALUE_ID && value->first_consumer == n) { |
| 727 | assert(value->data == NULL); |
| 728 | assert(value->datatype == xnn_datatype_fp16); |
| 729 | assert(subgraph->values[value->fp32_id].datatype == xnn_datatype_fp32); |
| 730 | assert(subgraph->values[value->fp32_id].flags & XNN_VALUE_FLAG_EXTERNAL_INPUT); |
| 731 | num_external_inputs += 1; |
| 732 | } |
| 733 | } |
| 734 | for (uint32_t o = 0; o < node->num_outputs; o++) { |
| 735 | const struct xnn_value* value = &subgraph->values[node->outputs[o]]; |
| 736 | if (value->fp32_id != XNN_INVALID_VALUE_ID) { |
| 737 | assert(value->datatype == xnn_datatype_fp16); |
| 738 | assert(subgraph->values[value->fp32_id].datatype == xnn_datatype_fp32); |
| 739 | assert(subgraph->values[value->fp32_id].flags & XNN_VALUE_FLAG_EXTERNAL_OUTPUT); |
| 740 | num_external_outputs += 1; |
| 741 | } |
| 742 | } |
| 743 | } |
| 744 | xnn_log_debug("Discovered %"PRIu32" external inputs and %"PRIu32" external outputs", |
| 745 | num_external_inputs, num_external_outputs); |
| 746 | |
| 747 | const uint32_t num_original_nodes = subgraph->num_nodes; |
| 748 | xnn_subgraph_add_nodes(subgraph, num_external_inputs + num_external_outputs); |
| 749 | struct xnn_node* output_node = subgraph->nodes + subgraph->num_nodes - 1; |
| 750 | for (uint32_t n = num_original_nodes; n != 0; n--) { |
| 751 | const struct xnn_node* node = &subgraph->nodes[n - 1]; |
| 752 | // Insert Convert nodes for outputs |
| 753 | for (uint32_t o = 0; o < node->num_outputs; o++) { |
| 754 | const struct xnn_value* value = &subgraph->values[node->outputs[o]]; |
| 755 | if (value->fp32_id != XNN_INVALID_VALUE_ID) { |
| 756 | xnn_log_debug("Inserted FP16->FP32 Convert Node from tensor #%"PRIu32" to tensor #%"PRIu32, |
| 757 | value->id, value->fp32_id); |
| 758 | const uint32_t output_node_id = output_node->id; |
| 759 | assert(output_node >= subgraph->nodes); |
| 760 | xnn_node_clear(output_node); |
| 761 | output_node->id = output_node_id; |
| 762 | xnn_init_convert_node(output_node, xnn_compute_type_fp16_to_fp32, value->id, value->fp32_id, 0 /* flags */); |
| 763 | output_node -= 1; |
| 764 | } |
| 765 | } |
| 766 | // Move the Node to the new location |
| 767 | if (output_node != node) { |
| 768 | const uint32_t output_node_id = output_node->id; |
| 769 | assert(output_node >= subgraph->nodes); |
| 770 | memcpy(output_node, node, sizeof(struct xnn_node)); |
| 771 | output_node->id = output_node_id; |
| 772 | output_node -= 1; |
| 773 | } |
| 774 | // Insert Convert nodes for inputs |
| 775 | for (uint32_t i = 0; i < node->num_inputs; i++) { |
| 776 | const struct xnn_value* value = &subgraph->values[node->inputs[i]]; |
| 777 | if (value->fp32_id != XNN_INVALID_VALUE_ID && value->first_consumer == n - 1) { |
| 778 | xnn_log_debug("Inserted FP32->FP16 Convert Node from tensor #%"PRIu32" to tensor #%"PRIu32, |
| 779 | value->fp32_id, value->id); |
| 780 | const uint32_t output_node_id = output_node->id; |
| 781 | assert(output_node >= subgraph->nodes); |
| 782 | xnn_node_clear(output_node); |
| 783 | output_node->id = output_node_id; |
| 784 | xnn_init_convert_node(output_node, xnn_compute_type_fp32_to_fp16, value->fp32_id, value->id, 0 /* flags */); |
| 785 | output_node -= 1; |
| 786 | } |
| 787 | } |
| 788 | } |
| 789 | } |
| 790 | |
Marat Dukhan | 1f19872 | 2020-05-24 14:07:03 -0700 | [diff] [blame] | 791 | enum xnn_status xnn_subgraph_optimize( |
| 792 | xnn_subgraph_t subgraph, |
| 793 | uint32_t flags) |
| 794 | { |
Marat Dukhan | 4620ca6 | 2022-02-03 12:31:00 -0800 | [diff] [blame] | 795 | xnn_subgraph_analyze_consumers_and_producers(subgraph); |
Marat Dukhan | 1f19872 | 2020-05-24 14:07:03 -0700 | [diff] [blame] | 796 | |
Marat Dukhan | 1f19872 | 2020-05-24 14:07:03 -0700 | [diff] [blame] | 797 | // Remove unreferenced values. |
| 798 | for (uint32_t i = 0; i < subgraph->num_values; i++) { |
| 799 | struct xnn_value* value = &subgraph->values[i]; |
| 800 | if (value->type == xnn_value_type_invalid) { |
| 801 | continue; |
| 802 | } |
| 803 | |
Marat Dukhan | 1f19872 | 2020-05-24 14:07:03 -0700 | [diff] [blame] | 804 | if ((value->flags & XNN_VALUE_FLAG_EXTERNAL_INPUT) == 0 && value->num_consumers == 0) { |
| 805 | xnn_value_clear(value); |
| 806 | } |
| 807 | } |
| 808 | |
| 809 | // Fuse Nodes where possible |
| 810 | for (uint32_t i = 0; i < subgraph->num_values; i++) { |
| 811 | struct xnn_value* value = &subgraph->values[i]; |
| 812 | if (value->num_consumers == 1) { |
| 813 | const uint32_t producer_id = value->producer; |
| 814 | if (producer_id == XNN_INVALID_NODE_ID) { |
| 815 | continue; |
| 816 | } |
| 817 | assert(producer_id < subgraph->num_nodes); |
| 818 | |
| 819 | const uint32_t consumer_id = value->first_consumer; |
| 820 | if (consumer_id == XNN_INVALID_NODE_ID) { |
| 821 | continue; |
| 822 | } |
| 823 | assert(consumer_id < subgraph->num_nodes); |
| 824 | |
| 825 | struct xnn_node* producer = &subgraph->nodes[producer_id]; |
| 826 | assert(producer->type != xnn_node_type_invalid); |
| 827 | struct xnn_node* consumer = &subgraph->nodes[consumer_id]; |
| 828 | assert(consumer->type != xnn_node_type_invalid); |
| 829 | |
| 830 | // Try to fuse Clamp Node upstream into producer Node |
| 831 | if (consumer->type == xnn_node_type_clamp) { |
| 832 | switch (producer->type) { |
| 833 | case xnn_node_type_add2: |
| 834 | case xnn_node_type_average_pooling_2d: |
| 835 | case xnn_node_type_clamp: |
| 836 | case xnn_node_type_convolution_2d: |
Marat Dukhan | b293e8d | 2020-07-23 20:10:45 -0700 | [diff] [blame] | 837 | case xnn_node_type_divide: |
| 838 | case xnn_node_type_deconvolution_2d: |
Marat Dukhan | 1f19872 | 2020-05-24 14:07:03 -0700 | [diff] [blame] | 839 | case xnn_node_type_depthwise_convolution_2d: |
| 840 | case xnn_node_type_fully_connected: |
| 841 | case xnn_node_type_multiply2: |
| 842 | case xnn_node_type_max_pooling_2d: |
Marat Dukhan | b293e8d | 2020-07-23 20:10:45 -0700 | [diff] [blame] | 843 | case xnn_node_type_subtract: |
Marat Dukhan | 1f19872 | 2020-05-24 14:07:03 -0700 | [diff] [blame] | 844 | xnn_log_info("fuse Clamp Node #%"PRIu32" into upstream Node #%"PRIu32, consumer_id, producer_id); |
| 845 | assert(producer->num_outputs == 1); |
| 846 | assert(consumer->num_inputs == 1); |
| 847 | assert(consumer->num_outputs == 1); |
| 848 | |
| 849 | const uint32_t fused_output_id = consumer->outputs[0]; |
| 850 | assert(fused_output_id < subgraph->num_values); |
| 851 | subgraph->values[fused_output_id].producer = producer_id; |
| 852 | producer->outputs[0] = fused_output_id; |
| 853 | |
| 854 | producer->activation.output_min = |
| 855 | math_max_f32(producer->activation.output_min, consumer->activation.output_min); |
| 856 | producer->activation.output_max = |
| 857 | math_min_f32(producer->activation.output_max, consumer->activation.output_max); |
| 858 | |
| 859 | xnn_node_clear(consumer); |
| 860 | xnn_value_clear(value); |
| 861 | break; |
| 862 | default: |
| 863 | break; |
| 864 | } |
| 865 | } |
Marat Dukhan | f3d1205 | 2020-05-25 15:41:37 -0700 | [diff] [blame] | 866 | // Try to fuse Constant Pad node downstream into [Depthwise] Convolution 2D Node |
Marat Dukhan | aff24e2 | 2020-07-23 01:43:58 -0700 | [diff] [blame] | 867 | if (producer->type == xnn_node_type_static_constant_pad) { |
Marat Dukhan | f3d1205 | 2020-05-25 15:41:37 -0700 | [diff] [blame] | 868 | assert(producer->num_inputs == 1); |
| 869 | assert(producer->num_outputs == 1); |
Marat Dukhan | 8c96521 | 2021-08-09 11:25:40 -0700 | [diff] [blame] | 870 | const bool is_spatial_2d_padding = value->shape.num_dims == 4 && |
Marat Dukhan | f3d1205 | 2020-05-25 15:41:37 -0700 | [diff] [blame] | 871 | (producer->params.static_pad.pre_paddings[0] | producer->params.static_pad.post_paddings[0] | |
Marat Dukhan | 8c96521 | 2021-08-09 11:25:40 -0700 | [diff] [blame] | 872 | producer->params.static_pad.pre_paddings[3] | producer->params.static_pad.post_paddings[3]) == 0; |
| 873 | const enum xnn_datatype padding_datatype = subgraph->values[producer->outputs[0]].datatype; |
| 874 | const uint32_t padding_value = producer->params.static_pad.padding_value; |
| 875 | const bool is_zero_padding = |
| 876 | (padding_datatype == xnn_datatype_fp32 && padding_value == 0) || |
| 877 | ((padding_datatype == xnn_datatype_qint8 || padding_datatype == xnn_datatype_quint8) && |
| 878 | padding_value == (uint32_t) (uint8_t) subgraph->values[producer->outputs[0]].quantization.zero_point); |
Marat Dukhan | f3d1205 | 2020-05-25 15:41:37 -0700 | [diff] [blame] | 879 | switch (consumer->type) { |
| 880 | case xnn_node_type_convolution_2d: |
Marat Dukhan | 8c96521 | 2021-08-09 11:25:40 -0700 | [diff] [blame] | 881 | if (is_spatial_2d_padding && is_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) { |
Marat Dukhan | f3d1205 | 2020-05-25 15:41:37 -0700 | [diff] [blame] | 882 | xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Convolution 2D Node #%"PRIu32, |
| 883 | consumer_id, producer_id); |
| 884 | assert(consumer->num_inputs >= 1); |
| 885 | assert(consumer->inputs[0] == producer->outputs[0]); |
| 886 | |
| 887 | consumer->params.convolution_2d.input_padding_top += producer->params.static_pad.pre_paddings[1]; |
Marat Dukhan | facecc5 | 2020-08-10 08:00:08 -0700 | [diff] [blame] | 888 | consumer->params.convolution_2d.input_padding_right += producer->params.static_pad.post_paddings[2]; |
Marat Dukhan | f3d1205 | 2020-05-25 15:41:37 -0700 | [diff] [blame] | 889 | consumer->params.convolution_2d.input_padding_bottom += producer->params.static_pad.post_paddings[1]; |
Marat Dukhan | facecc5 | 2020-08-10 08:00:08 -0700 | [diff] [blame] | 890 | consumer->params.convolution_2d.input_padding_left += producer->params.static_pad.pre_paddings[2]; |
Marat Dukhan | f3d1205 | 2020-05-25 15:41:37 -0700 | [diff] [blame] | 891 | |
| 892 | consumer->inputs[0] = producer->inputs[0]; |
| 893 | |
| 894 | const uint32_t fused_input_id = producer->inputs[0]; |
| 895 | assert(fused_input_id < subgraph->num_values); |
| 896 | if (subgraph->values[fused_input_id].first_consumer == producer_id) { |
| 897 | subgraph->values[fused_input_id].first_consumer = consumer_id; |
| 898 | } |
| 899 | |
| 900 | xnn_node_clear(producer); |
| 901 | xnn_value_clear(value); |
| 902 | } |
| 903 | break; |
| 904 | case xnn_node_type_depthwise_convolution_2d: |
Marat Dukhan | 8c96521 | 2021-08-09 11:25:40 -0700 | [diff] [blame] | 905 | if (is_spatial_2d_padding && is_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) { |
Marat Dukhan | f3d1205 | 2020-05-25 15:41:37 -0700 | [diff] [blame] | 906 | xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Depthwise Convolution 2D Node #%"PRIu32, |
| 907 | consumer_id, producer_id); |
| 908 | assert(consumer->num_inputs >= 1); |
| 909 | assert(consumer->inputs[0] == producer->outputs[0]); |
| 910 | |
| 911 | consumer->params.depthwise_convolution_2d.input_padding_top += |
| 912 | producer->params.static_pad.pre_paddings[1]; |
| 913 | consumer->params.depthwise_convolution_2d.input_padding_right += |
Marat Dukhan | facecc5 | 2020-08-10 08:00:08 -0700 | [diff] [blame] | 914 | producer->params.static_pad.post_paddings[2]; |
Marat Dukhan | f3d1205 | 2020-05-25 15:41:37 -0700 | [diff] [blame] | 915 | consumer->params.depthwise_convolution_2d.input_padding_bottom += |
| 916 | producer->params.static_pad.post_paddings[1]; |
| 917 | consumer->params.depthwise_convolution_2d.input_padding_left += |
Marat Dukhan | facecc5 | 2020-08-10 08:00:08 -0700 | [diff] [blame] | 918 | producer->params.static_pad.pre_paddings[2]; |
Marat Dukhan | f3d1205 | 2020-05-25 15:41:37 -0700 | [diff] [blame] | 919 | |
| 920 | consumer->inputs[0] = producer->inputs[0]; |
| 921 | |
| 922 | const uint32_t fused_input_id = producer->inputs[0]; |
| 923 | assert(fused_input_id < subgraph->num_values); |
| 924 | if (subgraph->values[fused_input_id].first_consumer == producer_id) { |
| 925 | subgraph->values[fused_input_id].first_consumer = consumer_id; |
| 926 | } |
| 927 | |
| 928 | xnn_node_clear(producer); |
| 929 | xnn_value_clear(value); |
| 930 | } |
| 931 | break; |
| 932 | default: |
| 933 | break; |
| 934 | } |
| 935 | } |
Marat Dukhan | 1f19872 | 2020-05-24 14:07:03 -0700 | [diff] [blame] | 936 | } |
| 937 | } |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 938 | |
| 939 | #if XNN_ENABLE_SPARSE |
Marat Dukhan | cfbed0a | 2020-12-08 10:01:51 -0800 | [diff] [blame] | 940 | if ((flags & XNN_FLAG_SPARSE_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_CHW_OPT)) { |
Marat Dukhan | 7332e83 | 2020-12-06 23:26:11 -0800 | [diff] [blame] | 941 | xnn_subgraph_rewrite_for_nchw(subgraph); |
| 942 | } |
Marat Dukhan | 9de90e0 | 2020-06-18 16:04:12 -0700 | [diff] [blame] | 943 | #endif |
| 944 | |
Marat Dukhan | 4620ca6 | 2022-02-03 12:31:00 -0800 | [diff] [blame] | 945 | #ifndef XNN_NO_F16_OPERATORS |
| 946 | if ((flags & XNN_FLAG_FP16_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_F16)) { |
| 947 | xnn_subgraph_rewrite_for_fp16(subgraph); |
| 948 | } |
| 949 | #endif // XNN_NO_F16_OPERATORS |
| 950 | |
Marat Dukhan | 1f19872 | 2020-05-24 14:07:03 -0700 | [diff] [blame] | 951 | return xnn_status_success; |
| 952 | } |
| 953 | |
Marat Dukhan | 1d75a54 | 2020-02-03 12:23:01 -0800 | [diff] [blame] | 954 | enum xnn_status xnn_delete_subgraph( |
| 955 | xnn_subgraph_t subgraph) |
| 956 | { |
| 957 | if (subgraph != NULL) { |
| 958 | memset(subgraph->nodes, 0, sizeof(struct xnn_node) * subgraph->num_nodes); |
| 959 | xnn_release_memory(subgraph->nodes); |
| 960 | |
| 961 | memset(subgraph->values, 0, sizeof(struct xnn_value) * subgraph->num_values); |
| 962 | xnn_release_memory(subgraph->values); |
| 963 | |
| 964 | memset(subgraph, 0, sizeof(struct xnn_subgraph)); |
| 965 | xnn_release_memory(subgraph); |
| 966 | } |
| 967 | return xnn_status_success; |
| 968 | } |