Refactor MaxPool and ArgMaxPool micro-kernels
- Support input_offset argument in MaxPool and ArgMaxPool micro-kernels
- Use input_offset to make indirection buffer independent of batch size
- Simplify and auto-generate unit tests
- Use more descriptive names for micro-kernel parameters
PiperOrigin-RevId: 281447682
diff --git a/src/init.c b/src/init.c
index dbcae34..083b978 100644
--- a/src/init.c
+++ b/src/init.c
@@ -109,7 +109,7 @@
/**************************** U8 micro-kernels ****************************/
#ifndef XNN_NO_U8_OPERATORS
xnn_params.u8.maxpool = (struct maxpool_parameters) {
- .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8q__neon,
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8x__neon_c16,
.mr = 9,
.qr = 8,
};
@@ -178,20 +178,20 @@
.mr = 7,
};
xnn_params.f32.maxpool = (struct maxpool_parameters) {
- .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8q__psimd,
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8x__psimd_c4,
.mr = 9,
.qr = 8,
};
xnn_params.f32.argmaxpool[0] = (struct argmaxpool_parameters) {
- .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up4__psimd,
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_4x__psimd_c4,
.mr = 4,
};
xnn_params.f32.argmaxpool[1] = (struct argmaxpool_parameters) {
- .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up9__psimd,
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_9x__psimd_c4,
.mr = 9,
};
xnn_params.f32.argmaxpool[2] = (struct argmaxpool_parameters) {
- .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_mp9p8q__psimd,
+ .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_9p8x__psimd_c4,
.mr = 9,
.qr = 8,
};
@@ -268,7 +268,7 @@
/**************************** U8 micro-kernels ****************************/
#ifndef XNN_NO_U8_OPERATORS
xnn_params.u8.maxpool = (struct maxpool_parameters) {
- .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8q__neon,
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8x__neon_c16,
.mr = 9,
.qr = 8,
};
@@ -440,20 +440,20 @@
.mr = 7,
};
xnn_params.f32.maxpool = (struct maxpool_parameters) {
- .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8q__psimd,
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8x__psimd_c4,
.mr = 9,
.qr = 8,
};
xnn_params.f32.argmaxpool[0] = (struct argmaxpool_parameters) {
- .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up4__psimd,
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_4x__psimd_c4,
.mr = 4,
};
xnn_params.f32.argmaxpool[1] = (struct argmaxpool_parameters) {
- .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up9__psimd,
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_9x__psimd_c4,
.mr = 9,
};
xnn_params.f32.argmaxpool[2] = (struct argmaxpool_parameters) {
- .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_mp9p8q__psimd,
+ .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_9p8x__psimd_c4,
.mr = 9,
.qr = 8,
};
@@ -588,7 +588,7 @@
/**************************** U8 micro-kernels ****************************/
#ifndef XNN_NO_U8_OPERATORS
xnn_params.u8.maxpool = (struct maxpool_parameters) {
- .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8q__sse2,
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8x__sse2_c16,
.mr = 9,
.qr = 8,
};
@@ -658,20 +658,20 @@
.mr = 7,
};
xnn_params.f32.maxpool = (struct maxpool_parameters) {
- .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8q__sse,
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8x__sse_c4,
.mr = 9,
.qr = 8,
};
xnn_params.f32.argmaxpool[0] = (struct argmaxpool_parameters) {
- .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up4__sse2,
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_4x__sse2_c4,
.mr = 4,
};
xnn_params.f32.argmaxpool[1] = (struct argmaxpool_parameters) {
- .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up9__sse2,
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_9x__sse2_c4,
.mr = 9,
};
xnn_params.f32.argmaxpool[2] = (struct argmaxpool_parameters) {
- .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_mp9p8q__sse2,
+ .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_9p8x__sse2_c4,
.mr = 9,
.qr = 8,
};
@@ -778,7 +778,7 @@
/**************************** U8 micro-kernels ****************************/
#ifndef XNN_NO_U8_OPERATORS
xnn_params.u8.maxpool = (struct maxpool_parameters) {
- .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8q__scalar,
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8x__scalar_c1,
.mr = 9,
.qr = 8,
};
@@ -860,20 +860,20 @@
.mr = 7,
};
xnn_params.f32.maxpool = (struct maxpool_parameters) {
- .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8q__psimd,
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8x__psimd_c4,
.mr = 9,
.qr = 8,
};
xnn_params.f32.argmaxpool[0] = (struct argmaxpool_parameters) {
- .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up4__psimd,
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_4x__psimd_c4,
.mr = 4,
};
xnn_params.f32.argmaxpool[1] = (struct argmaxpool_parameters) {
- .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up9__psimd,
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_9x__psimd_c4,
.mr = 9,
};
xnn_params.f32.argmaxpool[2] = (struct argmaxpool_parameters) {
- .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_mp9p8q__psimd,
+ .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_9p8x__psimd_c4,
.mr = 9,
.qr = 8,
};
@@ -956,7 +956,7 @@
/**************************** U8 micro-kernels ****************************/
#ifndef XNN_NO_U8_OPERATORS
xnn_params.u8.maxpool = (struct maxpool_parameters) {
- .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8q__scalar,
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8x__scalar_c1,
.mr = 9,
.qr = 8,
};
@@ -1036,20 +1036,20 @@
.mr = 7,
};
xnn_params.f32.maxpool = (struct maxpool_parameters) {
- .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8q__scalar,
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8x__scalar_c1,
.mr = 9,
.qr = 8,
};
xnn_params.f32.argmaxpool[0] = (struct argmaxpool_parameters) {
- .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up4__scalar,
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_4x__scalar_c1,
.mr = 4,
};
xnn_params.f32.argmaxpool[1] = (struct argmaxpool_parameters) {
- .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up9__scalar,
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_9x__scalar_c1,
.mr = 9,
};
xnn_params.f32.argmaxpool[2] = (struct argmaxpool_parameters) {
- .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_mp9p8q__scalar,
+ .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_9p8x__scalar_c1,
.mr = 9,
.qr = 8,
};