| // Copyright 2019 Google LLC |
| // |
| // This source code is licensed under the BSD-style license found in the |
| // LICENSE file in the root directory of this source tree. |
| |
// Template inputs checked here: BATCH_TILE (number of floats handled per
// main-loop iteration) and OP (which binary operation to generate).
| $assert BATCH_TILE >= 1 |
// ABC supplies the one-character suffix for each unrolled lane's variables
// (index N selects ABC[N]). It holds 24 characters, while the assert above
// only enforces BATCH_TILE >= 1, so a BATCH_TILE above 24 would index past
// the end of this string.
| $ABC = "0123456789ABCDEFGHIJKLMN" |
| $assert OP in ["ADD", "MUL", "SUB"] |
| #include <assert.h> |
| |
| #include <xnnpack/common.h> |
| #include <xnnpack/math.h> |
| #include <xnnpack/vbinary.h> |
| |
| |
// Names of the scalar min/max helpers used for clamping: the WebAssembly
// builtins when WASM is set, otherwise math_min_f32 / math_max_f32
// (presumably declared in <xnnpack/math.h> -- confirm).
| $MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32" |
| $MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32" |
// OP_FUNC takes two operand names and returns the text of the C expression
// that combines them with the operator selected by OP.
| $OP_FUNC = { |
| $ "ADD": lambda x, y: "%s + %s" % (x, y), |
| $ "MUL": lambda x, y: "%s * %s" % (x, y), |
| $ "SUB": lambda x, y: "%s - %s" % (x, y), |
| $}[OP] |
// Scalar f32 elementwise binary microkernel (ADD, MUL or SUB, chosen by the
// template). For each element it combines a[i] with b[i], clamps the result
// using params->scalar.min and params->scalar.max, and stores it to y[i].
//
// n      - amount of data to process, in BYTES (not elements). Must be
//          non-zero and a multiple of sizeof(float); both are asserted.
// a, b   - input operands, read front to back.
// y      - output, written front to back.
// params - provides the clamping bounds (params->scalar.min / .max).
//
// The generated function name encodes the operation, the "wasm" or "scalar"
// flavour, and the BATCH_TILE unroll factor.
| void xnn_f32_v${OP.lower()}_ukernel__${"wasm" if WASM else "scalar"}_x${BATCH_TILE}( |
| size_t n, |
| const float* a, |
| const float* b, |
| float* y, |
| const union xnn_f32_output_params params[restrict static 1]) |
| { |
| assert(n != 0); |
| assert(n % sizeof(float) == 0); |
| |
// Clamping bounds are read once, before any element is processed.
| const float vy_min = params->scalar.min; |
| const float vy_max = params->scalar.max; |
| |
// Unrolled variant: each main-loop iteration consumes BATCH_TILE floats.
// Work is grouped by stage: all loads from a, all loads from b, all
// arithmetic, lower clamp, upper clamp, then all stores.
| $if BATCH_TILE > 1: |
| for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) { |
| $for N in range(BATCH_TILE): |
| const float va${ABC[N]} = a[${N}]; |
| a += ${BATCH_TILE}; |
| |
| $for N in range(BATCH_TILE): |
| const float vb${ABC[N]} = b[${N}]; |
| b += ${BATCH_TILE}; |
| |
| $for N in range(BATCH_TILE): |
| float vy${ABC[N]} = ${OP_FUNC("va" + ABC[N], "vb" + ABC[N])}; |
| |
// The lower bound (max with vy_min) is applied before the upper bound
// (min with vy_max); the remainder and BATCH_TILE == 1 paths use the same order.
| $for N in range(BATCH_TILE): |
| vy${ABC[N]} = ${MAX_F32}(vy${ABC[N]}, vy_min); |
| |
| $for N in range(BATCH_TILE): |
| vy${ABC[N]} = ${MIN_F32}(vy${ABC[N]}, vy_max); |
| |
| $for N in range(BATCH_TILE): |
| y[${N}] = vy${ABC[N]}; |
| y += ${BATCH_TILE}; |
| } |
// Remainder: fewer than BATCH_TILE floats are left once the main loop exits.
| if XNN_UNLIKELY(n != 0) { |
// With BATCH_TILE > 2 more than one float can remain, so they are processed
// one at a time until n reaches zero.
| $if BATCH_TILE > 2: |
| do { |
| const float va = *a++; |
| const float vb = *b++; |
| float vy = ${OP_FUNC("va", "vb")}; |
| vy = ${MAX_F32}(vy, vy_min); |
| vy = ${MIN_F32}(vy, vy_max); |
| *y++ = vy; |
| n -= sizeof(float); |
| } while (n != 0); |
| $else: |
// BATCH_TILE == 2: n is a multiple of sizeof(float) and smaller than
// 2 * sizeof(float), so exactly one float remains; no loop and no pointer
// or counter updates are needed.
| const float va = *a; |
| const float vb = *b; |
| float vy = ${OP_FUNC("va", "vb")}; |
| vy = ${MAX_F32}(vy, vy_min); |
| vy = ${MIN_F32}(vy, vy_max); |
| *y = vy; |
| } |
| $else: |
// BATCH_TILE == 1: no unrolling, one float per iteration.
| for (; n >= sizeof(float); n -= sizeof(float)) { |
| const float va = *a++; |
| const float vb = *b++; |
| float vy = ${OP_FUNC("va", "vb")}; |
| vy = ${MAX_F32}(vy, vy_min); |
| vy = ${MIN_F32}(vy, vy_max); |
| *y++ = vy; |
| } |
| } |