Add f32 binary elementwise (scalar/WASM) microkernels with RELU activation
PiperOrigin-RevId: 325607697
diff --git a/src/f32-vbinary/vop-scalar.c.in b/src/f32-vbinary/vop-scalar.c.in
index 0df0def..3802093 100644
--- a/src/f32-vbinary/vop-scalar.c.in
+++ b/src/f32-vbinary/vop-scalar.c.in
@@ -6,7 +6,7 @@
$assert BATCH_TILE >= 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert OP in ["ADD", "DIV", "MAX", "MIN", "MUL", "SUB", "SQRDIFF"]
-$assert ACTIVATION in ["LINEAR", "MINMAX"]
+$assert ACTIVATION in ["LINEAR", "MINMAX", "RELU"]
#include <assert.h>
#include <xnnpack/common.h>
@@ -25,8 +25,8 @@
$ "SUB": lambda x, y: "%s - %s" % (x, y),
$ "SQRDIFF": lambda x, y: "%s - %s" % (x, y),
$}[OP]
-$SUFFIX = {"LINEAR": "", "MINMAX": "_minmax"}[ACTIVATION]
-$PARAMS = {"LINEAR": "xnn_f32_default_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
+$SUFFIX = {"LINEAR": "", "RELU": "_relu", "MINMAX": "_minmax"}[ACTIVATION]
+$PARAMS = {"LINEAR": "xnn_f32_default_params", "RELU": "xnn_f32_relu_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
void xnn_f32_v${OP.lower()}${SUFFIX}_ukernel__${"wasm" if WASM else "scalar"}_x${BATCH_TILE}(
size_t n,
const float* a,
@@ -67,6 +67,9 @@
$for N in range(BATCH_TILE):
vy${ABC[N]} = ${MIN_F32}(vy${ABC[N]}, vy_max);
+ $elif ACTIVATION == "RELU":
+ $for N in range(BATCH_TILE):
+ vy${ABC[N]} = ${MAX_F32}(vy${ABC[N]}, 0.0f);
$for N in range(BATCH_TILE):
y[${N}] = vy${ABC[N]};
@@ -83,6 +86,8 @@
$if ACTIVATION == "MINMAX":
vy = ${MAX_F32}(vy, vy_min);
vy = ${MIN_F32}(vy, vy_max);
+ $elif ACTIVATION == "RELU":
+ vy = ${MAX_F32}(vy, 0.0f);
*y++ = vy;
n -= sizeof(float);
} while (n != 0);
@@ -95,6 +100,8 @@
$if ACTIVATION == "MINMAX":
vy = ${MAX_F32}(vy, vy_min);
vy = ${MIN_F32}(vy, vy_max);
+ $elif ACTIVATION == "RELU":
+ vy = ${MAX_F32}(vy, 0.0f);
*y = vy;
}
$else:
@@ -107,6 +114,8 @@
$if ACTIVATION == "MINMAX":
vy = ${MAX_F32}(vy, vy_min);
vy = ${MIN_F32}(vy, vy_max);
+ $elif ACTIVATION == "RELU":
+ vy = ${MAX_F32}(vy, 0.0f);
*y++ = vy;
}
}