blob: 8aadc4a954a8e99c9aa37c1e24d5bf1190bae4ba [file] [log] [blame]
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

6$assert BATCH_TILE % 4 == 0
7$assert BATCH_TILE >= 4
8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
9$assert OP in ["ADD", "DIV", "MAX", "MIN", "MUL", "SUB", "SQRDIFF"]
Frank Barchard674778d2020-08-08 10:17:25 -070010$assert ACTIVATION in ["LINEAR", "MINMAX", "RELU"]
#include <assert.h>
#include <string.h>

#include <wasm_simd128.h>

#include <xnnpack/common.h>
#include <xnnpack/vbinary.h>
17
18
19$WASM_F32X4_OP = {
20$ "ADD": "wasm_f32x4_add",
21$ "DIV": "wasm_f32x4_div",
22$ "MAX": "wasm_f32x4_max",
23$ "MIN": "wasm_f32x4_min",
24$ "MUL": "wasm_f32x4_mul",
25$ "SUB": "wasm_f32x4_sub",
26$ "SQRDIFF": "wasm_f32x4_sub",
27$}[OP]
Frank Barchard674778d2020-08-08 10:17:25 -070028$ARCH_SUFFIX = "" if ACTIVATION in ["LINEAR", "RELU"] and OP not in ["MIN", "MAX"] else "_x86" if X86 else "_arm"
29$ACTIVATION_SUFFIX = {"LINEAR": ""}.get(ACTIVATION, "_" + ACTIVATION.lower())
30$PARAMS = {"LINEAR": "xnn_f32_default_params", "RELU": "xnn_f32_relu_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
Marat Dukhan93d1ba12020-06-26 12:33:35 -070031void xnn_f32_v${OP.lower()}${ACTIVATION_SUFFIX}_ukernel__wasmsimd${ARCH_SUFFIX}_x${BATCH_TILE}(
32 size_t n,
33 const float* a,
34 const float* b,
35 float* y,
36 const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
37{
38 assert(n != 0);
39 assert(n % sizeof(float) == 0);
Frank Barchard0822dde2020-07-04 12:47:24 -070040 assert(a != NULL);
41 assert(b != NULL);
42 assert(y != NULL);
Marat Dukhan93d1ba12020-06-26 12:33:35 -070043
44 $if ACTIVATION == "MINMAX":
45 const v128_t vy_min = wasm_v32x4_load_splat(&params->scalar.min);
46 const v128_t vy_max = wasm_v32x4_load_splat(&params->scalar.max);
Frank Barchard674778d2020-08-08 10:17:25 -070047 $elif ACTIVATION == "RELU":
48 const v128_t vzero = wasm_f32x4_splat(0.0f);
Marat Dukhan93d1ba12020-06-26 12:33:35 -070049
50 for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
51 const v128_t va${ABC[0:4]} = wasm_v128_load(a);
52 $for N in range(4, BATCH_TILE, 4):
53 const v128_t va${ABC[N:N+4]} = wasm_v128_load(a + ${N});
54 a += ${BATCH_TILE};
55
56 const v128_t vb${ABC[0:4]} = wasm_v128_load(b);
57 $for N in range(4, BATCH_TILE, 4):
58 const v128_t vb${ABC[N:N+4]} = wasm_v128_load(b + ${N});
59 b += ${BATCH_TILE};
60
61 $if OP == "MIN" and X86:
62 $for N in range(0, BATCH_TILE, 4):
63 const v128_t vm${ABC[N:N+4]} = wasm_f32x4_lt(va${ABC[N:N+4]}, vb${ABC[N:N+4]});
64
65 $for N in range(0, BATCH_TILE, 4):
66 v128_t vy${ABC[N:N+4]} = wasm_v128_bitselect(va${ABC[N:N+4]}, vb${ABC[N:N+4]}, vm${ABC[N:N+4]});
67 $elif OP == "MAX" and X86:
68 $for N in range(0, BATCH_TILE, 4):
69 const v128_t vm${ABC[N:N+4]} = wasm_f32x4_le(va${ABC[N:N+4]}, vb${ABC[N:N+4]});
70
71 $for N in range(0, BATCH_TILE, 4):
72 v128_t vy${ABC[N:N+4]} = wasm_v128_bitselect(vb${ABC[N:N+4]}, va${ABC[N:N+4]}, vm${ABC[N:N+4]});
73 $else:
74 $for N in range(0, BATCH_TILE, 4):
75 v128_t vy${ABC[N:N+4]} = ${WASM_F32X4_OP}(va${ABC[N:N+4]}, vb${ABC[N:N+4]});
76
77 $if OP == "SQRDIFF":
78 $for N in range(0, BATCH_TILE, 4):
79 vy${ABC[N:N+4]} = wasm_f32x4_mul(vy${ABC[N:N+4]}, vy${ABC[N:N+4]});
80
81 $if ACTIVATION == "MINMAX":
82 $if X86:
83 $for N in range(0, BATCH_TILE, 4):
84 const v128_t vltmask${ABC[N:N+4]} = wasm_f32x4_lt(vy${ABC[N:N+4]}, vy_min);
85
86 $for N in range(0, BATCH_TILE, 4):
87 const v128_t vngtmask${ABC[N:N+4]} = wasm_f32x4_le(vy${ABC[N:N+4]}, vy_max);
88 vy${ABC[N:N+4]} = wasm_v128_bitselect(vy_min, vy${ABC[N:N+4]}, vltmask${ABC[N:N+4]});
89
90 $for N in range(0, BATCH_TILE, 4):
91 vy${ABC[N:N+4]} = wasm_v128_bitselect(vy${ABC[N:N+4]}, vy_max, vngtmask${ABC[N:N+4]});
92 $else:
93 $for N in range(0, BATCH_TILE, 4):
94 vy${ABC[N:N+4]} = wasm_f32x4_max(vy${ABC[N:N+4]}, vy_min);
95
96 $for N in range(0, BATCH_TILE, 4):
97 vy${ABC[N:N+4]} = wasm_f32x4_min(vy${ABC[N:N+4]}, vy_max);
Frank Barchard674778d2020-08-08 10:17:25 -070098 $elif ACTIVATION == "RELU":
99 $for N in range(0, BATCH_TILE, 4):
100 vy${ABC[N:N+4]} = wasm_i32x4_max(vy${ABC[N:N+4]}, vzero);
Marat Dukhan93d1ba12020-06-26 12:33:35 -0700101
102 wasm_v128_store(y, vy${ABC[0:4]});
103 $for N in range(4, BATCH_TILE, 4):
104 wasm_v128_store(y + ${N}, vy${ABC[N:N+4]});
105 y += ${BATCH_TILE};
106 }
107 $if BATCH_TILE > 4:
108 for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
109 const v128_t va = wasm_v128_load(a);
110 a += 4;
111
112 const v128_t vb = wasm_v128_load(b);
113 b += 4;
114
115 $if OP == "MIN" and X86:
116 const v128_t vm = wasm_f32x4_lt(va, vb);
117 v128_t vy = wasm_v128_bitselect(va, vb, vm);
118 $elif OP == "MAX" and X86:
119 const v128_t vm = wasm_f32x4_le(va, vb);
120 v128_t vy = wasm_v128_bitselect(vb, va, vm);
121 $else:
122 v128_t vy = ${WASM_F32X4_OP}(va, vb);
123 $if OP == "SQRDIFF":
124 vy = wasm_f32x4_mul(vy, vy);
125
126 $if ACTIVATION == "MINMAX":
127 $if X86:
128 const v128_t vltmask = wasm_f32x4_lt(vy, vy_min);
129 const v128_t vngtmask = wasm_f32x4_le(vy, vy_max);
130 vy = wasm_v128_bitselect(vy_min, vy, vltmask);
131 vy = wasm_v128_bitselect(vy, vy_max, vngtmask);
132 $else:
133 vy = wasm_f32x4_max(vy, vy_min);
134 vy = wasm_f32x4_min(vy, vy_max);
Frank Barchard674778d2020-08-08 10:17:25 -0700135 $elif ACTIVATION == "RELU":
136 vy = wasm_i32x4_max(vy, vzero);
Marat Dukhan93d1ba12020-06-26 12:33:35 -0700137
138 wasm_v128_store(y, vy);
139 y += 4;
140 }
141 if XNN_UNLIKELY(n != 0) {
142 const v128_t va = wasm_v128_load(a);
143 const v128_t vb = wasm_v128_load(b);
144
Marat Dukhan22855472020-06-26 22:38:47 -0700145 $if OP == "MIN" and X86:
146 const v128_t vm = wasm_f32x4_lt(va, vb);
147 v128_t vy = wasm_v128_bitselect(va, vb, vm);
148 $elif OP == "MAX" and X86:
149 const v128_t vm = wasm_f32x4_le(va, vb);
150 v128_t vy = wasm_v128_bitselect(vb, va, vm);
151 $else:
152 v128_t vy = ${WASM_F32X4_OP}(va, vb);
153 $if OP == "SQRDIFF":
154 vy = wasm_f32x4_mul(vy, vy);
Marat Dukhan93d1ba12020-06-26 12:33:35 -0700155
156 $if ACTIVATION == "MINMAX":
157 $if X86:
158 const v128_t vltmask = wasm_f32x4_lt(vy, vy_min);
159 const v128_t vngtmask = wasm_f32x4_le(vy, vy_max);
160 vy = wasm_v128_bitselect(vy_min, vy, vltmask);
161 vy = wasm_v128_bitselect(vy, vy_max, vngtmask);
162 $else:
163 vy = wasm_f32x4_max(vy, vy_min);
164 vy = wasm_f32x4_min(vy, vy_max);
Frank Barchard674778d2020-08-08 10:17:25 -0700165 $elif ACTIVATION == "RELU":
166 vy = wasm_i32x4_max(vy, vzero);
Marat Dukhan93d1ba12020-06-26 12:33:35 -0700167
168 if (n & (2 * sizeof(float))) {
169 *((double*) y) = wasm_f64x2_extract_lane(vy, 0);
170 vy = wasm_v32x4_shuffle(vy, vy, 2, 3, 2, 3);
171 y += 2;
172 }
173 if (n & (1 * sizeof(float))) {
174 *y = wasm_f32x4_extract_lane(vy, 0);
175 }
176 }
177}