Refactor and open-source Three-Pass Softmax micro-kernels

- RAddExpMinusMax micro-kernel (AVX2 and AVX512F)
- RAddStoreExpMinusMax micro-kernel (AVX2 and AVX512F)
- VScaleExpMinusMax micro-kernel (AVX2 and AVX512F)
- Unit tests for all micro-kernels

PiperOrigin-RevId: 275570264
diff --git a/BUILD b/BUILD
index 49d8dba..3e171ec 100644
--- a/BUILD
+++ b/BUILD
@@ -407,6 +407,9 @@
]
AVX2_UKERNELS = [
+ "src/f32-raddexpminusmax/avx2-p5-unroll64.c",
+ "src/f32-raddstoreexpminusmax/avx2-p5-unroll64.c",
+ "src/f32-vscaleexpminusmax/avx2-p5-unroll64.c",
"src/math/exp-avx2-p5.c",
"src/math/exp-avx2-perm-p3.c",
"src/math/exp-avx2-perm-p4.c",
@@ -414,7 +417,10 @@
]
AVX512F_UKERNELS = [
+ "src/f32-raddexpminusmax/avx512f-p5-scalef-unroll128.c",
+ "src/f32-raddstoreexpminusmax/avx512f-p5-scalef-unroll128.c",
"src/f32-rmax/avx512f.c",
+ "src/f32-vscaleexpminusmax/avx512f-p5-scalef-unroll128.c",
"src/math/exp-avx512f-p5-scalef.c",
"src/math/exp-avx512f-p5.c",
"src/math/exp-avx512f-perm-p3.c",
@@ -489,6 +495,8 @@
"src/xnnpack/pavgpool.h",
"src/xnnpack/ppmm.h",
"src/xnnpack/prelu.h",
+ "src/xnnpack/raddexpminusmax.h",
+ "src/xnnpack/raddstoreexpminusmax.h",
"src/xnnpack/rmax.h",
"src/xnnpack/scalar-utils.h",
"src/xnnpack/spmm.h",
@@ -496,6 +504,7 @@
"src/xnnpack/vadd.h",
"src/xnnpack/vmul.h",
"src/xnnpack/vmulcaddc.h",
+ "src/xnnpack/vscaleexpminusmax.h",
"src/xnnpack/vsub.h",
"src/xnnpack/zip.h",
]
@@ -1234,6 +1243,24 @@
)
xnnpack_unit_test(
+ name = "f32_raddexpminusmax_test",
+ srcs = [
+ "test/f32-raddexpminusmax.cc",
+ "test/raddexpminusmax-microkernel-tester.h",
+ ] + MICROKERNEL_TEST_HDRS,
+ deps = MICROKERNEL_TEST_DEPS,
+)
+
+xnnpack_unit_test(
+ name = "f32_raddstoreexpminusmax_test",
+ srcs = [
+ "test/f32-raddstoreexpminusmax.cc",
+ "test/raddstoreexpminusmax-microkernel-tester.h",
+ ] + MICROKERNEL_TEST_HDRS,
+ deps = MICROKERNEL_TEST_DEPS,
+)
+
+xnnpack_unit_test(
name = "f32_rmax_test",
srcs = [
"test/f32-rmax.cc",
@@ -1262,15 +1289,6 @@
)
xnnpack_unit_test(
- name = "f32_vsub_test",
- srcs = [
- "test/f32-vsub.cc",
- "test/vsub-microkernel-tester.h",
- ] + MICROKERNEL_TEST_HDRS,
- deps = MICROKERNEL_TEST_DEPS,
-)
-
-xnnpack_unit_test(
name = "f32_vmul_test",
srcs = [
"test/f32-vmul.cc",
@@ -1290,6 +1308,24 @@
)
xnnpack_unit_test(
+ name = "f32_vscaleexpminusmax_test",
+ srcs = [
+ "test/f32-vscaleexpminusmax.cc",
+ "test/vscaleexpminusmax-microkernel-tester.h",
+ ] + MICROKERNEL_TEST_HDRS,
+ deps = MICROKERNEL_TEST_DEPS,
+)
+
+xnnpack_unit_test(
+ name = "f32_vsub_test",
+ srcs = [
+ "test/f32-vsub.cc",
+ "test/vsub-microkernel-tester.h",
+ ] + MICROKERNEL_TEST_HDRS,
+ deps = MICROKERNEL_TEST_DEPS,
+)
+
+xnnpack_unit_test(
name = "q8_avgpool_test",
srcs = [
"test/q8-avgpool.cc",