Minor optimization in NEON S8/U8 IBILINEAR microkernels on ARM64: replace the two VSHRN-by-16 + VCOMBINE sequence with a single VUZP2, which picks the high (odd-indexed) 16-bit halves of the 32-bit accumulators directly.
PiperOrigin-RevId: 412101661
diff --git a/src/s8-ibilinear/gen/neon-c16.c b/src/s8-ibilinear/gen/neon-c16.c
index 4a86c93..5281688 100644
--- a/src/s8-ibilinear/gen/neon-c16.c
+++ b/src/s8-ibilinear/gen/neon-c16.c
@@ -91,8 +91,13 @@
const int32x4_t vacc89AB = vmlaq_s32(vshlq_n_s32(vt89AB, 11), vd89AB, valphav);
const int32x4_t vaccCDEF = vmlaq_s32(vshlq_n_s32(vtCDEF, 11), vdCDEF, valphav);
- const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
- const int16x8_t vacc89ABCDEF = vcombine_s16(vshrn_n_s32(vacc89AB, 16), vshrn_n_s32(vaccCDEF, 16));
+ #if XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
+ const int16x8_t vacc89ABCDEF = vuzp2q_s16(vreinterpretq_s16_s32(vacc89AB), vreinterpretq_s16_s32(vaccCDEF));
+ #else // !XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ const int16x8_t vacc89ABCDEF = vcombine_s16(vshrn_n_s32(vacc89AB, 16), vshrn_n_s32(vaccCDEF, 16));
+ #endif // !XNN_ARCH_ARM64
const int8x8_t vo01234567 = vrshrn_n_s16(vacc01234567, 6);
const int8x8_t vo89ABCDEF = vrshrn_n_s16(vacc89ABCDEF, 6);
@@ -130,7 +135,11 @@
const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
- const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #if XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
+ #else // !XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #endif // !XNN_ARCH_ARM64
const int8x8_t vo01234567 = vrshrn_n_s16(vacc01234567, 6);
@@ -166,7 +175,11 @@
const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
- const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #if XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
+ #else // !XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #endif // !XNN_ARCH_ARM64
int8x8_t vo01234567 = vrshrn_n_s16(vacc01234567, 6);
diff --git a/src/s8-ibilinear/gen/neon-c8.c b/src/s8-ibilinear/gen/neon-c8.c
index f6e9854..b255793 100644
--- a/src/s8-ibilinear/gen/neon-c8.c
+++ b/src/s8-ibilinear/gen/neon-c8.c
@@ -72,7 +72,11 @@
const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
- const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #if XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
+ #else // !XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #endif // !XNN_ARCH_ARM64
const int8x8_t vo01234567 = vrshrn_n_s16(vacc01234567, 6);
@@ -108,7 +112,11 @@
const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
- const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #if XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
+ #else // !XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #endif // !XNN_ARCH_ARM64
int8x8_t vo01234567 = vrshrn_n_s16(vacc01234567, 6);
diff --git a/src/s8-ibilinear/neon.c.in b/src/s8-ibilinear/neon.c.in
index ac6dc6c..6fc3e6e 100644
--- a/src/s8-ibilinear/neon.c.in
+++ b/src/s8-ibilinear/neon.c.in
@@ -96,8 +96,13 @@
const int32x4_t vacc${ABC[C:C+4]} = vmlaq_s32(vshlq_n_s32(vt${ABC[C:C+4]}, 11), vd${ABC[C:C+4]}, valphav);
const int32x4_t vacc${ABC[C+4:C+8]} = vmlaq_s32(vshlq_n_s32(vt${ABC[C+4:C+8]}, 11), vd${ABC[C+4:C+8]}, valphav);
- $for C in range(0, CHANNEL_TILE, 8):
- const int16x8_t vacc${ABC[C:C+8]} = vcombine_s16(vshrn_n_s32(vacc${ABC[C:C+4]}, 16), vshrn_n_s32(vacc${ABC[C+4:C+8]}, 16));
+ #if XNN_ARCH_ARM64
+ $for C in range(0, CHANNEL_TILE, 8):
+ const int16x8_t vacc${ABC[C:C+8]} = vuzp2q_s16(vreinterpretq_s16_s32(vacc${ABC[C:C+4]}), vreinterpretq_s16_s32(vacc${ABC[C+4:C+8]}));
+ #else // !XNN_ARCH_ARM64
+ $for C in range(0, CHANNEL_TILE, 8):
+ const int16x8_t vacc${ABC[C:C+8]} = vcombine_s16(vshrn_n_s32(vacc${ABC[C:C+4]}, 16), vshrn_n_s32(vacc${ABC[C+4:C+8]}, 16));
+ #endif // !XNN_ARCH_ARM64
$if DATATYPE == "S8":
$for C in range(0, CHANNEL_TILE, 8):
@@ -145,7 +150,11 @@
const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
- const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #if XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
+ #else // !XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #endif // !XNN_ARCH_ARM64
$if DATATYPE == "S8":
const int8x8_t vo01234567 = vrshrn_n_s16(vacc01234567, 6);
@@ -190,7 +199,11 @@
const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
- const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #if XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
+ #else // !XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #endif // !XNN_ARCH_ARM64
$if DATATYPE == "S8":
int8x8_t vo01234567 = vrshrn_n_s16(vacc01234567, 6);
diff --git a/src/u8-ibilinear/gen/neon-c16.c b/src/u8-ibilinear/gen/neon-c16.c
index 496aef8..968a054 100644
--- a/src/u8-ibilinear/gen/neon-c16.c
+++ b/src/u8-ibilinear/gen/neon-c16.c
@@ -91,8 +91,13 @@
const int32x4_t vacc89AB = vmlaq_s32(vshlq_n_s32(vt89AB, 11), vd89AB, valphav);
const int32x4_t vaccCDEF = vmlaq_s32(vshlq_n_s32(vtCDEF, 11), vdCDEF, valphav);
- const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
- const int16x8_t vacc89ABCDEF = vcombine_s16(vshrn_n_s32(vacc89AB, 16), vshrn_n_s32(vaccCDEF, 16));
+ #if XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
+ const int16x8_t vacc89ABCDEF = vuzp2q_s16(vreinterpretq_s16_s32(vacc89AB), vreinterpretq_s16_s32(vaccCDEF));
+ #else // !XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ const int16x8_t vacc89ABCDEF = vcombine_s16(vshrn_n_s32(vacc89AB, 16), vshrn_n_s32(vaccCDEF, 16));
+ #endif // !XNN_ARCH_ARM64
const uint8x8_t vo01234567 = vrshrn_n_u16(vreinterpretq_u16_s16(vacc01234567), 6);
const uint8x8_t vo89ABCDEF = vrshrn_n_u16(vreinterpretq_u16_s16(vacc89ABCDEF), 6);
@@ -130,7 +135,11 @@
const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
- const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #if XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
+ #else // !XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #endif // !XNN_ARCH_ARM64
const uint8x8_t vo01234567 = vrshrn_n_u16(vreinterpretq_u16_s16(vacc01234567), 6);
@@ -166,7 +175,11 @@
const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
- const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #if XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
+ #else // !XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #endif // !XNN_ARCH_ARM64
uint8x8_t vo01234567 = vrshrn_n_u16(vreinterpretq_u16_s16(vacc01234567), 6);
diff --git a/src/u8-ibilinear/gen/neon-c8.c b/src/u8-ibilinear/gen/neon-c8.c
index f2cd288..527ff6b 100644
--- a/src/u8-ibilinear/gen/neon-c8.c
+++ b/src/u8-ibilinear/gen/neon-c8.c
@@ -72,7 +72,11 @@
const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
- const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #if XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
+ #else // !XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #endif // !XNN_ARCH_ARM64
const uint8x8_t vo01234567 = vrshrn_n_u16(vreinterpretq_u16_s16(vacc01234567), 6);
@@ -108,7 +112,11 @@
const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
- const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #if XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
+ #else // !XNN_ARCH_ARM64
+ const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
+ #endif // !XNN_ARCH_ARM64
uint8x8_t vo01234567 = vrshrn_n_u16(vreinterpretq_u16_s16(vacc01234567), 6);