Refactor DWCONV micro-kernels
- Fix bugs in generation of micro-kernels with large channel tiles
- Add missing unit tests
- Generate, test, and benchmark micro-kernels with 2 accumulators, with 2X
channel tile, and their combinations
PiperOrigin-RevId: 279137161
diff --git a/src/f32-dwconv/up4x9-sse.c b/src/f32-dwconv/up4x9-sse.c
index 6779891..1011f23 100644
--- a/src/f32-dwconv/up4x9-sse.c
+++ b/src/f32-dwconv/up4x9-sse.c
@@ -44,114 +44,122 @@
size_t c = channels;
const float* w = weights;
for (; c >= 4; c -= 4) {
- __m128 vacc0 = _mm_load_ps(w);
+ __m128 vacc0123p0 = _mm_load_ps(w);
- const __m128 vi0 = _mm_loadu_ps(i0);
- const __m128 vk0 = _mm_load_ps(w + 4);
- vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi0, vk0));
+
+ const __m128 vi0x0123 = _mm_loadu_ps(i0);
i0 += 4;
- const __m128 vi1 = _mm_loadu_ps(i1);
- const __m128 vk1 = _mm_load_ps(w + 8);
- __m128 vacc1 = _mm_mul_ps(vi1, vk1);
+ const __m128 vk0x0123 = _mm_load_ps(w + 4);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi0x0123, vk0x0123));
+
+ const __m128 vi1x0123 = _mm_loadu_ps(i1);
i1 += 4;
- const __m128 vi2 = _mm_loadu_ps(i2);
- const __m128 vk2 = _mm_load_ps(w + 12);
- vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi2, vk2));
+ const __m128 vk1x0123 = _mm_load_ps(w + 8);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi1x0123, vk1x0123));
+
+ const __m128 vi2x0123 = _mm_loadu_ps(i2);
i2 += 4;
- const __m128 vi3 = _mm_loadu_ps(i3);
- const __m128 vk3 = _mm_load_ps(w + 16);
- vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi3, vk3));
+ const __m128 vk2x0123 = _mm_load_ps(w + 12);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi2x0123, vk2x0123));
+
+ const __m128 vi3x0123 = _mm_loadu_ps(i3);
i3 += 4;
- const __m128 vi4 = _mm_loadu_ps(i4);
- const __m128 vk4 = _mm_load_ps(w + 20);
- vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi4, vk4));
+ const __m128 vk3x0123 = _mm_load_ps(w + 16);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi3x0123, vk3x0123));
+
+ const __m128 vi4x0123 = _mm_loadu_ps(i4);
i4 += 4;
- const __m128 vi5 = _mm_loadu_ps(i5);
- const __m128 vk5 = _mm_load_ps(w + 24);
- vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi5, vk5));
+ const __m128 vk4x0123 = _mm_load_ps(w + 20);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi4x0123, vk4x0123));
+
+ const __m128 vi5x0123 = _mm_loadu_ps(i5);
i5 += 4;
- const __m128 vi6 = _mm_loadu_ps(i6);
- const __m128 vk6 = _mm_load_ps(w + 28);
- vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi6, vk6));
+ const __m128 vk5x0123 = _mm_load_ps(w + 24);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi5x0123, vk5x0123));
+
+ const __m128 vi6x0123 = _mm_loadu_ps(i6);
i6 += 4;
- const __m128 vi7 = _mm_loadu_ps(i7);
- const __m128 vk7 = _mm_load_ps(w + 32);
- vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi7, vk7));
+ const __m128 vk6x0123 = _mm_load_ps(w + 28);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi6x0123, vk6x0123));
+
+ const __m128 vi7x0123 = _mm_loadu_ps(i7);
i7 += 4;
- const __m128 vi8 = _mm_loadu_ps(i8);
- const __m128 vk8 = _mm_load_ps(w + 36);
- vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi8, vk8));
+ const __m128 vk7x0123 = _mm_load_ps(w + 32);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi7x0123, vk7x0123));
+
+ const __m128 vi8x0123 = _mm_loadu_ps(i8);
i8 += 4;
+ const __m128 vk8x0123 = _mm_load_ps(w + 36);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi8x0123, vk8x0123));
+
w += 40;
- vacc0 = _mm_add_ps(vacc0, vacc1);
- vacc0 = _mm_max_ps(vacc0, vmin);
- vacc0 = _mm_min_ps(vacc0, vmax);
+ __m128 vacc0123 = _mm_max_ps(vacc0123p0, vmin);
+ vacc0123 = _mm_min_ps(vacc0123, vmax);
- _mm_storeu_ps(output, vacc0);
+ _mm_storeu_ps(output, vacc0123);
output += 4;
}
if XNN_UNLIKELY(c != 0) {
- __m128 vacc = _mm_load_ps(w);
+ __m128 vacc0123p0 = _mm_load_ps(w);
- const __m128 vi0 = _mm_loadu_ps(i0);
- const __m128 vk0 = _mm_load_ps(w + 4);
- vacc = _mm_add_ps(vacc, _mm_mul_ps(vi0, vk0));
+ const __m128 vi0x0123 = _mm_loadu_ps(i0);
+ const __m128 vk0x0123 = _mm_load_ps(w + 4);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi0x0123, vk0x0123));
- const __m128 vi1 = _mm_loadu_ps(i1);
- const __m128 vk1 = _mm_load_ps(w + 8);
- vacc = _mm_add_ps(vacc, _mm_mul_ps(vi1, vk1));
+ const __m128 vi1x0123 = _mm_loadu_ps(i1);
+ const __m128 vk1x0123 = _mm_load_ps(w + 8);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi1x0123, vk1x0123));
- const __m128 vi2 = _mm_loadu_ps(i2);
- const __m128 vk2 = _mm_load_ps(w + 12);
- vacc = _mm_add_ps(vacc, _mm_mul_ps(vi2, vk2));
+ const __m128 vi2x0123 = _mm_loadu_ps(i2);
+ const __m128 vk2x0123 = _mm_load_ps(w + 12);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi2x0123, vk2x0123));
- const __m128 vi3 = _mm_loadu_ps(i3);
- const __m128 vk3 = _mm_load_ps(w + 16);
- vacc = _mm_add_ps(vacc, _mm_mul_ps(vi3, vk3));
+ const __m128 vi3x0123 = _mm_loadu_ps(i3);
+ const __m128 vk3x0123 = _mm_load_ps(w + 16);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi3x0123, vk3x0123));
- const __m128 vi4 = _mm_loadu_ps(i4);
- const __m128 vk4 = _mm_load_ps(w + 20);
- vacc = _mm_add_ps(vacc, _mm_mul_ps(vi4, vk4));
+ const __m128 vi4x0123 = _mm_loadu_ps(i4);
+ const __m128 vk4x0123 = _mm_load_ps(w + 20);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi4x0123, vk4x0123));
- const __m128 vi5 = _mm_loadu_ps(i5);
- const __m128 vk5 = _mm_load_ps(w + 24);
- vacc = _mm_add_ps(vacc, _mm_mul_ps(vi5, vk5));
+ const __m128 vi5x0123 = _mm_loadu_ps(i5);
+ const __m128 vk5x0123 = _mm_load_ps(w + 24);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi5x0123, vk5x0123));
- const __m128 vi6 = _mm_loadu_ps(i6);
- const __m128 vk6 = _mm_load_ps(w + 28);
- vacc = _mm_add_ps(vacc, _mm_mul_ps(vi6, vk6));
+ const __m128 vi6x0123 = _mm_loadu_ps(i6);
+ const __m128 vk6x0123 = _mm_load_ps(w + 28);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi6x0123, vk6x0123));
- const __m128 vi7 = _mm_loadu_ps(i7);
- const __m128 vk7 = _mm_load_ps(w + 32);
- vacc = _mm_add_ps(vacc, _mm_mul_ps(vi7, vk7));
+ const __m128 vi7x0123 = _mm_loadu_ps(i7);
+ const __m128 vk7x0123 = _mm_load_ps(w + 32);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi7x0123, vk7x0123));
- const __m128 vi8 = _mm_loadu_ps(i8);
- const __m128 vk8 = _mm_load_ps(w + 36);
- vacc = _mm_add_ps(vacc, _mm_mul_ps(vi8, vk8));
+ const __m128 vi8x0123 = _mm_loadu_ps(i8);
+ const __m128 vk8x0123 = _mm_load_ps(w + 36);
+ vacc0123p0 = _mm_add_ps(vacc0123p0, _mm_mul_ps(vi8x0123, vk8x0123));
- w += 40;
- vacc = _mm_max_ps(vacc, vmin);
- vacc = _mm_min_ps(vacc, vmax);
+ __m128 vacc0123 = _mm_max_ps(vacc0123p0, vmin);
+ vacc0123 = _mm_min_ps(vacc0123, vmax);
if (c & 2) {
- _mm_storel_pi((__m64*) output, vacc);
- vacc = _mm_movehl_ps(vacc, vacc);
+ _mm_storel_pi((__m64*) output, vacc0123);
+ vacc0123 = _mm_movehl_ps(vacc0123, vacc0123);
output += 2;
}
if (c & 1) {
- _mm_store_ss(output, vacc);
+ _mm_store_ss(output, vacc0123);
output += 1;
}
}