AVX512F implementation of GEMM & IGEMM micro-kernels
PiperOrigin-RevId: 282810752
diff --git a/bench/f32-igemm.cc b/bench/f32-igemm.cc
index db5dd6f..ba837ff 100644
--- a/bench/f32-igemm.cc
+++ b/bench/f32-igemm.cc
@@ -360,7 +360,6 @@
static void f32_igemm_1x8__sse_load1(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_1x8__sse_load1, 1, 8, 1, 1);
}
-
static void f32_igemm_4x8__sse_load1(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_4x8__sse_load1, 4, 8, 1, 1);
}
@@ -368,7 +367,6 @@
static void f32_igemm_1x8__sse_dup(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_1x8__sse_dup, 1, 8, 1, 1);
}
-
static void f32_igemm_4x8__sse_dup(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_4x8__sse_dup, 4, 8, 1, 1);
}
@@ -376,7 +374,6 @@
static void f32_igemm_1x8s4__sse(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_1x8s4__sse, 1, 8, 1, 4);
}
-
static void f32_igemm_4x8s4__sse(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_4x8s4__sse, 4, 8, 1, 4);
}
@@ -384,19 +381,15 @@
static void f32_igemm_1x8__avx_broadcast(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_1x8__avx_broadcast, 1, 8, 1, 1, benchmark::utils::CheckAVX);
}
-
static void f32_igemm_4x8__avx_broadcast(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_4x8__avx_broadcast, 4, 8, 1, 1, benchmark::utils::CheckAVX);
}
-
static void f32_igemm_5x8__avx_broadcast(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_5x8__avx_broadcast, 5, 8, 1, 1, benchmark::utils::CheckAVX);
}
-
static void f32_igemm_6x8__avx_broadcast(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_6x8__avx_broadcast, 6, 8, 1, 1, benchmark::utils::CheckAVX);
}
-
static void f32_igemm_7x8__avx_broadcast(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_7x8__avx_broadcast, 7, 8, 1, 1, benchmark::utils::CheckAVX);
}
@@ -404,44 +397,69 @@
static void f32_igemm_1x8__fma3_broadcast(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_1x8__fma3_broadcast, 1, 8, 1, 1, benchmark::utils::CheckFMA3);
}
-
static void f32_igemm_4x8__fma3_broadcast(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_4x8__fma3_broadcast, 4, 8, 1, 1, benchmark::utils::CheckFMA3);
}
-
static void f32_igemm_5x8__fma3_broadcast(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_5x8__fma3_broadcast, 5, 8, 1, 1, benchmark::utils::CheckFMA3);
}
-
static void f32_igemm_6x8__fma3_broadcast(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_6x8__fma3_broadcast, 6, 8, 1, 1, benchmark::utils::CheckFMA3);
}
-
static void f32_igemm_7x8__fma3_broadcast(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_7x8__fma3_broadcast, 7, 8, 1, 1, benchmark::utils::CheckFMA3);
}
-
static void f32_igemm_8x8__fma3_broadcast(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_8x8__fma3_broadcast, 8, 8, 1, 1, benchmark::utils::CheckFMA3);
}
+ static void f32_igemm_1x16__avx512f_broadcast(benchmark::State& state, const char* net) {
+ IGEMMBenchmark(state, xnn_f32_igemm_ukernel_1x16__avx512f_broadcast, 1, 16, 1, 1, benchmark::utils::CheckAVX512F);  // avx512f kernel: gate on AVX512F, FMA3 alone is not sufficient
+ }
+ static void f32_igemm_4x16__avx512f_broadcast(benchmark::State& state, const char* net) {
+ IGEMMBenchmark(state, xnn_f32_igemm_ukernel_4x16__avx512f_broadcast, 4, 16, 1, 1, benchmark::utils::CheckAVX512F);  // avx512f kernel: gate on AVX512F, FMA3 alone is not sufficient
+ }
+ static void f32_igemm_5x16__avx512f_broadcast(benchmark::State& state, const char* net) {
+ IGEMMBenchmark(state, xnn_f32_igemm_ukernel_5x16__avx512f_broadcast, 5, 16, 1, 1, benchmark::utils::CheckAVX512F);  // avx512f kernel: gate on AVX512F, FMA3 alone is not sufficient
+ }
+ static void f32_igemm_6x16__avx512f_broadcast(benchmark::State& state, const char* net) {
+ IGEMMBenchmark(state, xnn_f32_igemm_ukernel_6x16__avx512f_broadcast, 6, 16, 1, 1, benchmark::utils::CheckAVX512F);  // avx512f kernel: gate on AVX512F, FMA3 alone is not sufficient
+ }
+ static void f32_igemm_7x16__avx512f_broadcast(benchmark::State& state, const char* net) {
+ IGEMMBenchmark(state, xnn_f32_igemm_ukernel_7x16__avx512f_broadcast, 7, 16, 1, 1, benchmark::utils::CheckAVX512F);  // avx512f kernel: gate on AVX512F, FMA3 alone is not sufficient
+ }
+ static void f32_igemm_8x16__avx512f_broadcast(benchmark::State& state, const char* net) {
+ IGEMMBenchmark(state, xnn_f32_igemm_ukernel_8x16__avx512f_broadcast, 8, 16, 1, 1, benchmark::utils::CheckAVX512F);  // avx512f kernel: gate on AVX512F, FMA3 alone is not sufficient
+ }
+
BENCHMARK_CONV(f32_igemm_1x8__sse_load1)
BENCHMARK_CONV(f32_igemm_4x8__sse_load1)
+
BENCHMARK_CONV(f32_igemm_1x8__sse_dup)
BENCHMARK_CONV(f32_igemm_4x8__sse_dup)
+
BENCHMARK_CONV(f32_igemm_1x8s4__sse)
BENCHMARK_CONV(f32_igemm_4x8s4__sse)
+
BENCHMARK_CONV(f32_igemm_1x8__avx_broadcast)
BENCHMARK_CONV(f32_igemm_4x8__avx_broadcast)
BENCHMARK_CONV(f32_igemm_5x8__avx_broadcast)
BENCHMARK_CONV(f32_igemm_6x8__avx_broadcast)
BENCHMARK_CONV(f32_igemm_7x8__avx_broadcast)
+
BENCHMARK_CONV(f32_igemm_1x8__fma3_broadcast)
BENCHMARK_CONV(f32_igemm_4x8__fma3_broadcast)
BENCHMARK_CONV(f32_igemm_5x8__fma3_broadcast)
BENCHMARK_CONV(f32_igemm_6x8__fma3_broadcast)
BENCHMARK_CONV(f32_igemm_7x8__fma3_broadcast)
BENCHMARK_CONV(f32_igemm_8x8__fma3_broadcast)
+
+ BENCHMARK_CONV(f32_igemm_1x16__avx512f_broadcast)  // avx512f_broadcast IGEMM variants (1x16 .. 8x16), one BENCHMARK_CONV entry per wrapper defined above
+ BENCHMARK_CONV(f32_igemm_4x16__avx512f_broadcast)
+ BENCHMARK_CONV(f32_igemm_5x16__avx512f_broadcast)
+ BENCHMARK_CONV(f32_igemm_6x16__avx512f_broadcast)
+ BENCHMARK_CONV(f32_igemm_7x16__avx512f_broadcast)
+ BENCHMARK_CONV(f32_igemm_8x16__avx512f_broadcast)
#endif /* XNN_ARCH_X86 || XNN_ARCH_X86_64 */
#if !XNN_ARCH_WASM && !XNN_ARCH_ASMJS