Additional variants of Softmax microkernels
PiperOrigin-RevId: 284483874
diff --git a/BUILD.bazel b/BUILD.bazel
index 1421c6a..c3f0d06 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -838,11 +838,66 @@
]
AVX2_UKERNELS = [
- "src/f32-raddexpminusmax/avx2-p5-unroll64.c",
- "src/f32-raddextexp/avx2-p5-unroll64.c",
- "src/f32-raddstoreexpminusmax/avx2-p5-unroll64.c",
- "src/f32-vscaleexpminusmax/avx2-p5-unroll64.c",
- "src/f32-vscaleextexp/avx2-p5-unroll64.c",
+ "src/f32-raddexpminusmax/gen/avx2-p5-x64.c",
+ "src/f32-raddexpminusmax/gen/avx2-p5-x64-acc2.c",
+ "src/f32-raddexpminusmax/gen/avx2-p5-x64-acc4.c",
+ "src/f32-raddexpminusmax/gen/avx2-p5-x72.c",
+ "src/f32-raddexpminusmax/gen/avx2-p5-x72-acc3.c",
+ "src/f32-raddexpminusmax/gen/avx2-p5-x80.c",
+ "src/f32-raddexpminusmax/gen/avx2-p5-x80-acc2.c",
+ "src/f32-raddexpminusmax/gen/avx2-p5-x80-acc5.c",
+ "src/f32-raddexpminusmax/gen/avx2-p5-x96.c",
+ "src/f32-raddexpminusmax/gen/avx2-p5-x96-acc2.c",
+ "src/f32-raddexpminusmax/gen/avx2-p5-x96-acc3.c",
+ "src/f32-raddexpminusmax/gen/avx2-p5-x96-acc6.c",
+ "src/f32-raddextexp/gen/avx2-p5-x64.c",
+ "src/f32-raddextexp/gen/avx2-p5-x64-acc2.c",
+ "src/f32-raddextexp/gen/avx2-p5-x64-acc4.c",
+ "src/f32-raddextexp/gen/avx2-p5-x72.c",
+ "src/f32-raddextexp/gen/avx2-p5-x72-acc3.c",
+ "src/f32-raddextexp/gen/avx2-p5-x80.c",
+ "src/f32-raddextexp/gen/avx2-p5-x80-acc2.c",
+ "src/f32-raddextexp/gen/avx2-p5-x80-acc5.c",
+ "src/f32-raddextexp/gen/avx2-p5-x96.c",
+ "src/f32-raddextexp/gen/avx2-p5-x96-acc2.c",
+ "src/f32-raddextexp/gen/avx2-p5-x96-acc3.c",
+ "src/f32-raddextexp/gen/avx2-p5-x96-acc6.c",
+ "src/f32-raddstoreexpminusmax/gen/avx2-p5-x64.c",
+ "src/f32-raddstoreexpminusmax/gen/avx2-p5-x64-acc2.c",
+ "src/f32-raddstoreexpminusmax/gen/avx2-p5-x64-acc4.c",
+ "src/f32-raddstoreexpminusmax/gen/avx2-p5-x72.c",
+ "src/f32-raddstoreexpminusmax/gen/avx2-p5-x72-acc3.c",
+ "src/f32-raddstoreexpminusmax/gen/avx2-p5-x80.c",
+ "src/f32-raddstoreexpminusmax/gen/avx2-p5-x80-acc2.c",
+ "src/f32-raddstoreexpminusmax/gen/avx2-p5-x80-acc5.c",
+ "src/f32-raddstoreexpminusmax/gen/avx2-p5-x96.c",
+ "src/f32-raddstoreexpminusmax/gen/avx2-p5-x96-acc2.c",
+ "src/f32-raddstoreexpminusmax/gen/avx2-p5-x96-acc3.c",
+ "src/f32-raddstoreexpminusmax/gen/avx2-p5-x96-acc6.c",
+ "src/f32-vscaleexpminusmax/gen/avx2-p5-x8.c",
+ "src/f32-vscaleexpminusmax/gen/avx2-p5-x16.c",
+ "src/f32-vscaleexpminusmax/gen/avx2-p5-x24.c",
+ "src/f32-vscaleexpminusmax/gen/avx2-p5-x32.c",
+ "src/f32-vscaleexpminusmax/gen/avx2-p5-x40.c",
+ "src/f32-vscaleexpminusmax/gen/avx2-p5-x48.c",
+ "src/f32-vscaleexpminusmax/gen/avx2-p5-x56.c",
+ "src/f32-vscaleexpminusmax/gen/avx2-p5-x64.c",
+ "src/f32-vscaleexpminusmax/gen/avx2-p5-x72.c",
+ "src/f32-vscaleexpminusmax/gen/avx2-p5-x80.c",
+ "src/f32-vscaleexpminusmax/gen/avx2-p5-x88.c",
+ "src/f32-vscaleexpminusmax/gen/avx2-p5-x96.c",
+ "src/f32-vscaleextexp/gen/avx2-p5-x8.c",
+ "src/f32-vscaleextexp/gen/avx2-p5-x16.c",
+ "src/f32-vscaleextexp/gen/avx2-p5-x24.c",
+ "src/f32-vscaleextexp/gen/avx2-p5-x32.c",
+ "src/f32-vscaleextexp/gen/avx2-p5-x40.c",
+ "src/f32-vscaleextexp/gen/avx2-p5-x48.c",
+ "src/f32-vscaleextexp/gen/avx2-p5-x56.c",
+ "src/f32-vscaleextexp/gen/avx2-p5-x64.c",
+ "src/f32-vscaleextexp/gen/avx2-p5-x72.c",
+ "src/f32-vscaleextexp/gen/avx2-p5-x80.c",
+ "src/f32-vscaleextexp/gen/avx2-p5-x88.c",
+ "src/f32-vscaleextexp/gen/avx2-p5-x96.c",
"src/math/exp-avx2-p5.c",
"src/math/exp-avx2-perm-p3.c",
"src/math/exp-avx2-perm-p4.c",
@@ -882,13 +937,68 @@
"src/f32-igemm/gen/6x16-avx512f-broadcast.c",
"src/f32-igemm/gen/7x16-avx512f-broadcast.c",
"src/f32-igemm/gen/8x16-avx512f-broadcast.c",
- "src/f32-raddexpminusmax/avx512f-p5-scalef-unroll128.c",
- "src/f32-raddextexp/avx512f-p5-scalef-unroll128.c",
- "src/f32-raddstoreexpminusmax/avx512f-p5-scalef-unroll128.c",
+ "src/f32-raddexpminusmax/gen/avx512f-p5-scalef-x128.c",
+ "src/f32-raddexpminusmax/gen/avx512f-p5-scalef-x128-acc2.c",
+ "src/f32-raddexpminusmax/gen/avx512f-p5-scalef-x128-acc4.c",
+ "src/f32-raddexpminusmax/gen/avx512f-p5-scalef-x144.c",
+ "src/f32-raddexpminusmax/gen/avx512f-p5-scalef-x144-acc3.c",
+ "src/f32-raddexpminusmax/gen/avx512f-p5-scalef-x160.c",
+ "src/f32-raddexpminusmax/gen/avx512f-p5-scalef-x160-acc2.c",
+ "src/f32-raddexpminusmax/gen/avx512f-p5-scalef-x160-acc5.c",
+ "src/f32-raddexpminusmax/gen/avx512f-p5-scalef-x192.c",
+ "src/f32-raddexpminusmax/gen/avx512f-p5-scalef-x192-acc2.c",
+ "src/f32-raddexpminusmax/gen/avx512f-p5-scalef-x192-acc3.c",
+ "src/f32-raddexpminusmax/gen/avx512f-p5-scalef-x192-acc6.c",
+ "src/f32-raddextexp/gen/avx512f-p5-scalef-x128.c",
+ "src/f32-raddextexp/gen/avx512f-p5-scalef-x128-acc2.c",
+ "src/f32-raddextexp/gen/avx512f-p5-scalef-x128-acc4.c",
+ "src/f32-raddextexp/gen/avx512f-p5-scalef-x144.c",
+ "src/f32-raddextexp/gen/avx512f-p5-scalef-x144-acc3.c",
+ "src/f32-raddextexp/gen/avx512f-p5-scalef-x160.c",
+ "src/f32-raddextexp/gen/avx512f-p5-scalef-x160-acc2.c",
+ "src/f32-raddextexp/gen/avx512f-p5-scalef-x160-acc5.c",
+ "src/f32-raddextexp/gen/avx512f-p5-scalef-x192.c",
+ "src/f32-raddextexp/gen/avx512f-p5-scalef-x192-acc2.c",
+ "src/f32-raddextexp/gen/avx512f-p5-scalef-x192-acc3.c",
+ "src/f32-raddextexp/gen/avx512f-p5-scalef-x192-acc6.c",
+ "src/f32-raddstoreexpminusmax/gen/avx512f-p5-scalef-x128.c",
+ "src/f32-raddstoreexpminusmax/gen/avx512f-p5-scalef-x128-acc2.c",
+ "src/f32-raddstoreexpminusmax/gen/avx512f-p5-scalef-x128-acc4.c",
+ "src/f32-raddstoreexpminusmax/gen/avx512f-p5-scalef-x144.c",
+ "src/f32-raddstoreexpminusmax/gen/avx512f-p5-scalef-x144-acc3.c",
+ "src/f32-raddstoreexpminusmax/gen/avx512f-p5-scalef-x160.c",
+ "src/f32-raddstoreexpminusmax/gen/avx512f-p5-scalef-x160-acc2.c",
+ "src/f32-raddstoreexpminusmax/gen/avx512f-p5-scalef-x160-acc5.c",
+ "src/f32-raddstoreexpminusmax/gen/avx512f-p5-scalef-x192.c",
+ "src/f32-raddstoreexpminusmax/gen/avx512f-p5-scalef-x192-acc2.c",
+ "src/f32-raddstoreexpminusmax/gen/avx512f-p5-scalef-x192-acc3.c",
+ "src/f32-raddstoreexpminusmax/gen/avx512f-p5-scalef-x192-acc6.c",
"src/f32-rmax/avx512f.c",
"src/f32-vscale/avx512f-unroll64.c",
- "src/f32-vscaleexpminusmax/avx512f-p5-scalef-unroll128.c",
- "src/f32-vscaleextexp/avx512f-p5-scalef-unroll128.c",
+ "src/f32-vscaleexpminusmax/gen/avx512f-p5-scalef-x16.c",
+ "src/f32-vscaleexpminusmax/gen/avx512f-p5-scalef-x32.c",
+ "src/f32-vscaleexpminusmax/gen/avx512f-p5-scalef-x48.c",
+ "src/f32-vscaleexpminusmax/gen/avx512f-p5-scalef-x64.c",
+ "src/f32-vscaleexpminusmax/gen/avx512f-p5-scalef-x80.c",
+ "src/f32-vscaleexpminusmax/gen/avx512f-p5-scalef-x96.c",
+ "src/f32-vscaleexpminusmax/gen/avx512f-p5-scalef-x112.c",
+ "src/f32-vscaleexpminusmax/gen/avx512f-p5-scalef-x128.c",
+ "src/f32-vscaleexpminusmax/gen/avx512f-p5-scalef-x144.c",
+ "src/f32-vscaleexpminusmax/gen/avx512f-p5-scalef-x160.c",
+ "src/f32-vscaleexpminusmax/gen/avx512f-p5-scalef-x176.c",
+ "src/f32-vscaleexpminusmax/gen/avx512f-p5-scalef-x192.c",
+ "src/f32-vscaleextexp/gen/avx512f-p5-scalef-x16.c",
+ "src/f32-vscaleextexp/gen/avx512f-p5-scalef-x32.c",
+ "src/f32-vscaleextexp/gen/avx512f-p5-scalef-x48.c",
+ "src/f32-vscaleextexp/gen/avx512f-p5-scalef-x64.c",
+ "src/f32-vscaleextexp/gen/avx512f-p5-scalef-x80.c",
+ "src/f32-vscaleextexp/gen/avx512f-p5-scalef-x96.c",
+ "src/f32-vscaleextexp/gen/avx512f-p5-scalef-x112.c",
+ "src/f32-vscaleextexp/gen/avx512f-p5-scalef-x128.c",
+ "src/f32-vscaleextexp/gen/avx512f-p5-scalef-x144.c",
+ "src/f32-vscaleextexp/gen/avx512f-p5-scalef-x160.c",
+ "src/f32-vscaleextexp/gen/avx512f-p5-scalef-x176.c",
+ "src/f32-vscaleextexp/gen/avx512f-p5-scalef-x192.c",
"src/math/exp-avx512f-p5-scalef.c",
"src/math/exp-avx512f-p5.c",
"src/math/exp-avx512f-perm-p3.c",
@@ -1489,6 +1599,33 @@
)
xnnpack_benchmark(
+ name = "f32_raddexpminusmax_bench",
+ srcs = [
+ "bench/f32-raddexpminusmax.cc",
+ "src/xnnpack/AlignedAllocator.h",
+ ] + MICROKERNEL_BENCHMARK_HDRS,
+ deps = MICROKERNEL_BENCHMARK_DEPS,
+)
+
+xnnpack_benchmark(
+ name = "f32_raddextexp_bench",
+ srcs = [
+ "bench/f32-raddextexp.cc",
+ "src/xnnpack/AlignedAllocator.h",
+ ] + MICROKERNEL_BENCHMARK_HDRS,
+ deps = MICROKERNEL_BENCHMARK_DEPS,
+)
+
+xnnpack_benchmark(
+ name = "f32_raddstoreexpminusmax_bench",
+ srcs = [
+ "bench/f32-raddstoreexpminusmax.cc",
+ "src/xnnpack/AlignedAllocator.h",
+ ] + MICROKERNEL_BENCHMARK_HDRS,
+ deps = MICROKERNEL_BENCHMARK_DEPS,
+)
+
+xnnpack_benchmark(
name = "f32_rmax_bench",
srcs = [
"bench/f32-rmax.cc",
@@ -1528,6 +1665,24 @@
)
xnnpack_benchmark(
+ name = "f32_vscaleexpminusmax_bench",
+ srcs = [
+ "bench/f32-vscaleexpminusmax.cc",
+ "src/xnnpack/AlignedAllocator.h",
+ ] + MICROKERNEL_BENCHMARK_HDRS,
+ deps = MICROKERNEL_BENCHMARK_DEPS,
+)
+
+xnnpack_benchmark(
+ name = "f32_vscaleextexp_bench",
+ srcs = [
+ "bench/f32-vscaleextexp.cc",
+ "src/xnnpack/AlignedAllocator.h",
+ ] + MICROKERNEL_BENCHMARK_HDRS,
+ deps = MICROKERNEL_BENCHMARK_DEPS,
+)
+
+xnnpack_benchmark(
name = "f32_im2col_gemm_bench",
srcs = [
"bench/f32-im2col-gemm.cc",