blob: c3f4b6ce75b6848171c0294d52694c0f7beac055 [file] [log] [blame]
XNNPACK Teamb455b122019-09-27 18:10:33 -07001// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#include <math.h>
7#include <stddef.h>
8#include <stdint.h>
9#include <stdlib.h>
10
11#include <xnnpack.h>
Marat Dukhan04f03be2019-11-19 12:36:47 -080012#include <xnnpack/allocator.h>
XNNPACK Teamb455b122019-09-27 18:10:33 -070013#include <xnnpack/log.h>
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070014#include <xnnpack/operator.h>
15#include <xnnpack/params-init.h>
16#include <xnnpack/params.h>
XNNPACK Teamb455b122019-09-27 18:10:33 -070017
18
19enum xnn_status xnn_create_hardswish_nc_f32(
20 size_t channels,
21 size_t input_stride,
22 size_t output_stride,
23 uint32_t flags,
24 xnn_operator_t* hardswish_op_out)
25{
26 xnn_operator_t hardswish_op = NULL;
27 enum xnn_status status = xnn_status_uninitialized;
28
29 if (!xnn_params.initialized) {
30 xnn_log_error("failed to create HardSwish operator: XNNPACK is not initialized");
31 goto error;
32 }
33
34 status = xnn_status_invalid_parameter;
35
36 if (channels == 0) {
37 xnn_log_error(
38 "failed to create HardSwish operator with %zu channels: number of channels must be non-zero", channels);
39 goto error;
40 }
41
42 if (input_stride < channels) {
43 xnn_log_error(
44 "failed to create HardSwish operator with input element stride of %zu: "
45 "stride must be at least as large as the number of channels (%zu)",
46 input_stride, channels);
47 goto error;
48 }
49
50 if (output_stride < channels) {
51 xnn_log_error(
52 "failed to create HardSwish operator with output element stride of %zu: "
53 "stride must be at least as large as the number of channels (%zu)",
54 output_stride, channels);
55 goto error;
56 }
57
58 status = xnn_status_out_of_memory;
59
Marat Dukhan04f03be2019-11-19 12:36:47 -080060 hardswish_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
XNNPACK Teamb455b122019-09-27 18:10:33 -070061 if (hardswish_op == NULL) {
62 xnn_log_error("failed to allocate %zu bytes for xnn_operator structure", sizeof(struct xnn_operator));
63 goto error;
64 }
65
66 hardswish_op->channels = channels;
67 hardswish_op->input_pixel_stride = input_stride;
68 hardswish_op->output_pixel_stride = output_stride;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070069 hardswish_op->f32_hswish_params = xnn_init_f32_hswish_params();
XNNPACK Teamb455b122019-09-27 18:10:33 -070070
Marat Dukhanefc47b82019-11-18 09:25:38 -080071 hardswish_op->type = xnn_operator_type_hardswish_nc_f32;
XNNPACK Teamb455b122019-09-27 18:10:33 -070072 hardswish_op->ukernel.type = xnn_ukernel_type_hswish;
73
74 hardswish_op->state = xnn_run_state_invalid;
75
76 *hardswish_op_out = hardswish_op;
77 return xnn_status_success;
78
79error:
80 xnn_delete_operator(hardswish_op);
81 return status;
82}
83
84enum xnn_status xnn_setup_hardswish_nc_f32(
85 xnn_operator_t hardswish_op,
86 size_t batch_size,
87 const float* input,
88 float* output,
89 pthreadpool_t threadpool)
90{
Marat Dukhanefc47b82019-11-18 09:25:38 -080091 if (hardswish_op->type != xnn_operator_type_hardswish_nc_f32) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070092 xnn_log_error("failed to setup HardSwish (F32) operator: operator type mismatch");
93 return xnn_status_invalid_parameter;
94 }
95 hardswish_op->state = xnn_run_state_invalid;
96
97 if (!xnn_params.initialized) {
98 xnn_log_error("failed to setup HardSwish operator: XNNPACK is not initialized");
99 return xnn_status_uninitialized;
100 }
101
102 if (batch_size == 0) {
103 hardswish_op->state = xnn_run_state_skip;
104 return xnn_status_success;
105 }
106
107 const size_t channels = hardswish_op->channels;
108 const size_t input_stride = hardswish_op->input_pixel_stride;
109 const size_t output_stride = hardswish_op->output_pixel_stride;
110 if ((((input_stride ^ channels) | (output_stride ^ channels)) == 0) || batch_size == 1) {
111 const size_t block_size = 4096;
112 hardswish_op->context.univector_contiguous = (struct univector_contiguous_context) {
113 .x = input,
114 .x_stride = input_stride * sizeof(float),
115 .y = output,
116 .y_stride = output_stride * sizeof(float),
117 .ukernel = xnn_params.f32.hswish,
118 .params.f32_hswish = hardswish_op->f32_hswish_params,
119 };
120 hardswish_op->compute.type = xnn_parallelization_type_1d_tile_1d;
121 hardswish_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous;
122 hardswish_op->compute.range[0] = batch_size * channels * sizeof(float);
123 hardswish_op->compute.tile[0] = block_size;
124 } else {
125 hardswish_op->context.univector_strided = (struct univector_strided_context) {
126 .n = channels * sizeof(float),
127 .x = input,
128 .x_stride = input_stride * sizeof(float),
129 .y = output,
130 .y_stride = output_stride * sizeof(float),
131 .ukernel = xnn_params.f32.hswish,
132 .params.f32_hswish = hardswish_op->f32_hswish_params,
133 };
134 hardswish_op->compute.type = xnn_parallelization_type_1d_tile_1d;
135 hardswish_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_strided;
136 hardswish_op->compute.range[0] = batch_size;
137 hardswish_op->compute.tile[0] = 1;
138 }
139 hardswish_op->state = xnn_run_state_ready;
140
141 return xnn_status_success;
142}