Re-generate SpMM micro-kernels
PiperOrigin-RevId: 276802114
diff --git a/src/f32-spmm/1x1-scalar.c b/src/f32-spmm/1x1-scalar.c
index 5d5752f..48029eb 100644
--- a/src/f32-spmm/1x1-scalar.c
+++ b/src/f32-spmm/1x1-scalar.c
@@ -33,23 +33,44 @@
const int32_t* dmap = widx_dmap;
const uint32_t* nnzmap = nidx_nnzmap;
size_t j = n;
- do {
+ while (j >= 1) {
uint32_t nnz = *nnzmap++;
- float vacc0 = *w++;
+ float vacc0x0 = *w++;
if XNN_LIKELY(nnz != 0) {
do {
const intptr_t diff = *dmap++;
const float va0 = a[0];
a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
- const float vb = *w++;
- vacc0 += va0 * vb;
+ const float vb0 = *w++;
+ vacc0x0 += va0 * vb0;
} while (--nnz != 0);
}
- float vout0 = math_min_f32(vacc0, vmax);
- vout0 = math_max_f32(vout0, vmin);
- c[0] = vout0;
- c += m;
- } while (--j != 0);
+ float vout0x0 = math_min_f32(vacc0x0, vmax);
+ vout0x0 = math_max_f32(vout0x0, vmin);
+ c[0 * m + 0] = vout0x0;
+ c += 1 * m;
+ j -= 1;
+ }
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ c[0] = vout0;
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
c -= m * n;
c += 1;
a += 1;
diff --git a/src/f32-spmm/2x1-scalar.c b/src/f32-spmm/2x1-scalar.c
index 474c2ea..4e33933 100644
--- a/src/f32-spmm/2x1-scalar.c
+++ b/src/f32-spmm/2x1-scalar.c
@@ -33,29 +33,56 @@
const int32_t* dmap = widx_dmap;
const uint32_t* nnzmap = nidx_nnzmap;
size_t j = n;
- do {
+ while (j >= 1) {
uint32_t nnz = *nnzmap++;
- float vacc0 = *w++;
- float vacc1 = vacc0;
+ float vacc0x0 = *w++;
+ float vacc1x0 = vacc0x0;
if XNN_LIKELY(nnz != 0) {
do {
const intptr_t diff = *dmap++;
const float va0 = a[0];
const float va1 = a[1];
a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
- const float vb = *w++;
- vacc0 += va0 * vb;
- vacc1 += va1 * vb;
+ const float vb0 = *w++;
+ vacc0x0 += va0 * vb0;
+ vacc1x0 += va1 * vb0;
} while (--nnz != 0);
}
- float vout0 = math_min_f32(vacc0, vmax);
- float vout1 = math_min_f32(vacc1, vmax);
- vout0 = math_max_f32(vout0, vmin);
- vout1 = math_max_f32(vout1, vmin);
- c[0] = vout0;
- c[1] = vout1;
- c += m;
- } while (--j != 0);
+ float vout0x0 = math_min_f32(vacc0x0, vmax);
+ float vout1x0 = math_min_f32(vacc1x0, vmax);
+ vout0x0 = math_max_f32(vout0x0, vmin);
+ vout1x0 = math_max_f32(vout1x0, vmin);
+ c[0 * m + 0] = vout0x0;
+ c[0 * m + 1] = vout1x0;
+ c += 1 * m;
+ j -= 1;
+ }
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ float vacc1 = vacc0;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ const float va1 = a[1];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ vacc1 += va1 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
c -= m * n;
c += 2;
a += 2;
@@ -67,23 +94,44 @@
const int32_t* dmap = widx_dmap;
const uint32_t* nnzmap = nidx_nnzmap;
size_t j = n;
- do {
+ while (j >= 1) {
uint32_t nnz = *nnzmap++;
- float vacc0 = *w++;
+ float vacc0x0 = *w++;
if XNN_LIKELY(nnz != 0) {
do {
const intptr_t diff = *dmap++;
const float va0 = a[0];
a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
- const float vb = *w++;
- vacc0 += va0 * vb;
+ const float vb0 = *w++;
+ vacc0x0 += va0 * vb0;
} while (--nnz != 0);
}
- float vout0 = math_min_f32(vacc0, vmax);
- vout0 = math_max_f32(vout0, vmin);
- c[0] = vout0;
- c += m;
- } while (--j != 0);
+ float vout0x0 = math_min_f32(vacc0x0, vmax);
+ vout0x0 = math_max_f32(vout0x0, vmin);
+ c[0 * m + 0] = vout0x0;
+ c += 1 * m;
+ j -= 1;
+ }
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ c[0] = vout0;
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
c -= m * n;
c += 1;
a += 1;
diff --git a/src/f32-spmm/4x1-scalar.c b/src/f32-spmm/4x1-scalar.c
index 353f633..d45c518 100644
--- a/src/f32-spmm/4x1-scalar.c
+++ b/src/f32-spmm/4x1-scalar.c
@@ -33,12 +33,12 @@
const int32_t* dmap = widx_dmap;
const uint32_t* nnzmap = nidx_nnzmap;
size_t j = n;
- do {
+ while (j >= 1) {
uint32_t nnz = *nnzmap++;
- float vacc0 = *w++;
- float vacc1 = vacc0;
- float vacc2 = vacc0;
- float vacc3 = vacc0;
+ float vacc0x0 = *w++;
+ float vacc1x0 = vacc0x0;
+ float vacc2x0 = vacc0x0;
+ float vacc3x0 = vacc0x0;
if XNN_LIKELY(nnz != 0) {
do {
const intptr_t diff = *dmap++;
@@ -47,27 +47,66 @@
const float va2 = a[2];
const float va3 = a[3];
a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
- const float vb = *w++;
- vacc0 += va0 * vb;
- vacc1 += va1 * vb;
- vacc2 += va2 * vb;
- vacc3 += va3 * vb;
+ const float vb0 = *w++;
+ vacc0x0 += va0 * vb0;
+ vacc1x0 += va1 * vb0;
+ vacc2x0 += va2 * vb0;
+ vacc3x0 += va3 * vb0;
} while (--nnz != 0);
}
- float vout0 = math_min_f32(vacc0, vmax);
- float vout1 = math_min_f32(vacc1, vmax);
- float vout2 = math_min_f32(vacc2, vmax);
- float vout3 = math_min_f32(vacc3, vmax);
- vout0 = math_max_f32(vout0, vmin);
- vout1 = math_max_f32(vout1, vmin);
- vout2 = math_max_f32(vout2, vmin);
- vout3 = math_max_f32(vout3, vmin);
- c[0] = vout0;
- c[1] = vout1;
- c[2] = vout2;
- c[3] = vout3;
- c += m;
- } while (--j != 0);
+ float vout0x0 = math_min_f32(vacc0x0, vmax);
+ float vout1x0 = math_min_f32(vacc1x0, vmax);
+ float vout2x0 = math_min_f32(vacc2x0, vmax);
+ float vout3x0 = math_min_f32(vacc3x0, vmax);
+ vout0x0 = math_max_f32(vout0x0, vmin);
+ vout1x0 = math_max_f32(vout1x0, vmin);
+ vout2x0 = math_max_f32(vout2x0, vmin);
+ vout3x0 = math_max_f32(vout3x0, vmin);
+ c[0 * m + 0] = vout0x0;
+ c[0 * m + 1] = vout1x0;
+ c[0 * m + 2] = vout2x0;
+ c[0 * m + 3] = vout3x0;
+ c += 1 * m;
+ j -= 1;
+ }
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ float vacc1 = vacc0;
+ float vacc2 = vacc0;
+ float vacc3 = vacc0;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ const float va1 = a[1];
+ const float va2 = a[2];
+ const float va3 = a[3];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ vacc1 += va1 * vb;
+ vacc2 += va2 * vb;
+ vacc3 += va3 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ float vout2 = math_min_f32(vacc2, vmax);
+ float vout3 = math_min_f32(vacc3, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ vout2 = math_max_f32(vout2, vmin);
+ vout3 = math_max_f32(vout3, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c[2] = vout2;
+ c[3] = vout3;
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
c -= m * n;
c += 4;
a += 4;
@@ -79,29 +118,56 @@
const int32_t* dmap = widx_dmap;
const uint32_t* nnzmap = nidx_nnzmap;
size_t j = n;
- do {
+ while (j >= 1) {
uint32_t nnz = *nnzmap++;
- float vacc0 = *w++;
- float vacc1 = vacc0;
+ float vacc0x0 = *w++;
+ float vacc1x0 = vacc0x0;
if XNN_LIKELY(nnz != 0) {
do {
const intptr_t diff = *dmap++;
const float va0 = a[0];
const float va1 = a[1];
a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
- const float vb = *w++;
- vacc0 += va0 * vb;
- vacc1 += va1 * vb;
+ const float vb0 = *w++;
+ vacc0x0 += va0 * vb0;
+ vacc1x0 += va1 * vb0;
} while (--nnz != 0);
}
- float vout0 = math_min_f32(vacc0, vmax);
- float vout1 = math_min_f32(vacc1, vmax);
- vout0 = math_max_f32(vout0, vmin);
- vout1 = math_max_f32(vout1, vmin);
- c[0] = vout0;
- c[1] = vout1;
- c += m;
- } while (--j != 0);
+ float vout0x0 = math_min_f32(vacc0x0, vmax);
+ float vout1x0 = math_min_f32(vacc1x0, vmax);
+ vout0x0 = math_max_f32(vout0x0, vmin);
+ vout1x0 = math_max_f32(vout1x0, vmin);
+ c[0 * m + 0] = vout0x0;
+ c[0 * m + 1] = vout1x0;
+ c += 1 * m;
+ j -= 1;
+ }
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ float vacc1 = vacc0;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ const float va1 = a[1];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ vacc1 += va1 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
c -= m * n;
c += 2;
a += 2;
@@ -111,23 +177,44 @@
const int32_t* dmap = widx_dmap;
const uint32_t* nnzmap = nidx_nnzmap;
size_t j = n;
- do {
+ while (j >= 1) {
uint32_t nnz = *nnzmap++;
- float vacc0 = *w++;
+ float vacc0x0 = *w++;
if XNN_LIKELY(nnz != 0) {
do {
const intptr_t diff = *dmap++;
const float va0 = a[0];
a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
- const float vb = *w++;
- vacc0 += va0 * vb;
+ const float vb0 = *w++;
+ vacc0x0 += va0 * vb0;
} while (--nnz != 0);
}
- float vout0 = math_min_f32(vacc0, vmax);
- vout0 = math_max_f32(vout0, vmin);
- c[0] = vout0;
- c += m;
- } while (--j != 0);
+ float vout0x0 = math_min_f32(vacc0x0, vmax);
+ vout0x0 = math_max_f32(vout0x0, vmin);
+ c[0 * m + 0] = vout0x0;
+ c += 1 * m;
+ j -= 1;
+ }
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ c[0] = vout0;
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
c -= m * n;
c += 1;
a += 1;
diff --git a/src/f32-spmm/8x1-scalar.c b/src/f32-spmm/8x1-scalar.c
index 2a1ac08..5574a7d 100644
--- a/src/f32-spmm/8x1-scalar.c
+++ b/src/f32-spmm/8x1-scalar.c
@@ -33,16 +33,16 @@
const int32_t* dmap = widx_dmap;
const uint32_t* nnzmap = nidx_nnzmap;
size_t j = n;
- do {
+ while (j >= 1) {
uint32_t nnz = *nnzmap++;
- float vacc0 = *w++;
- float vacc1 = vacc0;
- float vacc2 = vacc0;
- float vacc3 = vacc0;
- float vacc4 = vacc0;
- float vacc5 = vacc0;
- float vacc6 = vacc0;
- float vacc7 = vacc0;
+ float vacc0x0 = *w++;
+ float vacc1x0 = vacc0x0;
+ float vacc2x0 = vacc0x0;
+ float vacc3x0 = vacc0x0;
+ float vacc4x0 = vacc0x0;
+ float vacc5x0 = vacc0x0;
+ float vacc6x0 = vacc0x0;
+ float vacc7x0 = vacc0x0;
if XNN_LIKELY(nnz != 0) {
do {
const intptr_t diff = *dmap++;
@@ -55,43 +55,106 @@
const float va6 = a[6];
const float va7 = a[7];
a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
- const float vb = *w++;
- vacc0 += va0 * vb;
- vacc1 += va1 * vb;
- vacc2 += va2 * vb;
- vacc3 += va3 * vb;
- vacc4 += va4 * vb;
- vacc5 += va5 * vb;
- vacc6 += va6 * vb;
- vacc7 += va7 * vb;
+ const float vb0 = *w++;
+ vacc0x0 += va0 * vb0;
+ vacc1x0 += va1 * vb0;
+ vacc2x0 += va2 * vb0;
+ vacc3x0 += va3 * vb0;
+ vacc4x0 += va4 * vb0;
+ vacc5x0 += va5 * vb0;
+ vacc6x0 += va6 * vb0;
+ vacc7x0 += va7 * vb0;
} while (--nnz != 0);
}
- float vout0 = math_min_f32(vacc0, vmax);
- float vout1 = math_min_f32(vacc1, vmax);
- float vout2 = math_min_f32(vacc2, vmax);
- float vout3 = math_min_f32(vacc3, vmax);
- float vout4 = math_min_f32(vacc4, vmax);
- float vout5 = math_min_f32(vacc5, vmax);
- float vout6 = math_min_f32(vacc6, vmax);
- float vout7 = math_min_f32(vacc7, vmax);
- vout0 = math_max_f32(vout0, vmin);
- vout1 = math_max_f32(vout1, vmin);
- vout2 = math_max_f32(vout2, vmin);
- vout3 = math_max_f32(vout3, vmin);
- vout4 = math_max_f32(vout4, vmin);
- vout5 = math_max_f32(vout5, vmin);
- vout6 = math_max_f32(vout6, vmin);
- vout7 = math_max_f32(vout7, vmin);
- c[0] = vout0;
- c[1] = vout1;
- c[2] = vout2;
- c[3] = vout3;
- c[4] = vout4;
- c[5] = vout5;
- c[6] = vout6;
- c[7] = vout7;
- c += m;
- } while (--j != 0);
+ float vout0x0 = math_min_f32(vacc0x0, vmax);
+ float vout1x0 = math_min_f32(vacc1x0, vmax);
+ float vout2x0 = math_min_f32(vacc2x0, vmax);
+ float vout3x0 = math_min_f32(vacc3x0, vmax);
+ float vout4x0 = math_min_f32(vacc4x0, vmax);
+ float vout5x0 = math_min_f32(vacc5x0, vmax);
+ float vout6x0 = math_min_f32(vacc6x0, vmax);
+ float vout7x0 = math_min_f32(vacc7x0, vmax);
+ vout0x0 = math_max_f32(vout0x0, vmin);
+ vout1x0 = math_max_f32(vout1x0, vmin);
+ vout2x0 = math_max_f32(vout2x0, vmin);
+ vout3x0 = math_max_f32(vout3x0, vmin);
+ vout4x0 = math_max_f32(vout4x0, vmin);
+ vout5x0 = math_max_f32(vout5x0, vmin);
+ vout6x0 = math_max_f32(vout6x0, vmin);
+ vout7x0 = math_max_f32(vout7x0, vmin);
+ c[0 * m + 0] = vout0x0;
+ c[0 * m + 1] = vout1x0;
+ c[0 * m + 2] = vout2x0;
+ c[0 * m + 3] = vout3x0;
+ c[0 * m + 4] = vout4x0;
+ c[0 * m + 5] = vout5x0;
+ c[0 * m + 6] = vout6x0;
+ c[0 * m + 7] = vout7x0;
+ c += 1 * m;
+ j -= 1;
+ }
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ float vacc1 = vacc0;
+ float vacc2 = vacc0;
+ float vacc3 = vacc0;
+ float vacc4 = vacc0;
+ float vacc5 = vacc0;
+ float vacc6 = vacc0;
+ float vacc7 = vacc0;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ const float va1 = a[1];
+ const float va2 = a[2];
+ const float va3 = a[3];
+ const float va4 = a[4];
+ const float va5 = a[5];
+ const float va6 = a[6];
+ const float va7 = a[7];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ vacc1 += va1 * vb;
+ vacc2 += va2 * vb;
+ vacc3 += va3 * vb;
+ vacc4 += va4 * vb;
+ vacc5 += va5 * vb;
+ vacc6 += va6 * vb;
+ vacc7 += va7 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ float vout2 = math_min_f32(vacc2, vmax);
+ float vout3 = math_min_f32(vacc3, vmax);
+ float vout4 = math_min_f32(vacc4, vmax);
+ float vout5 = math_min_f32(vacc5, vmax);
+ float vout6 = math_min_f32(vacc6, vmax);
+ float vout7 = math_min_f32(vacc7, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ vout2 = math_max_f32(vout2, vmin);
+ vout3 = math_max_f32(vout3, vmin);
+ vout4 = math_max_f32(vout4, vmin);
+ vout5 = math_max_f32(vout5, vmin);
+ vout6 = math_max_f32(vout6, vmin);
+ vout7 = math_max_f32(vout7, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c[2] = vout2;
+ c[3] = vout3;
+ c[4] = vout4;
+ c[5] = vout5;
+ c[6] = vout6;
+ c[7] = vout7;
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
c -= m * n;
c += 8;
a += 8;
@@ -103,12 +166,12 @@
const int32_t* dmap = widx_dmap;
const uint32_t* nnzmap = nidx_nnzmap;
size_t j = n;
- do {
+ while (j >= 1) {
uint32_t nnz = *nnzmap++;
- float vacc0 = *w++;
- float vacc1 = vacc0;
- float vacc2 = vacc0;
- float vacc3 = vacc0;
+ float vacc0x0 = *w++;
+ float vacc1x0 = vacc0x0;
+ float vacc2x0 = vacc0x0;
+ float vacc3x0 = vacc0x0;
if XNN_LIKELY(nnz != 0) {
do {
const intptr_t diff = *dmap++;
@@ -117,27 +180,66 @@
const float va2 = a[2];
const float va3 = a[3];
a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
- const float vb = *w++;
- vacc0 += va0 * vb;
- vacc1 += va1 * vb;
- vacc2 += va2 * vb;
- vacc3 += va3 * vb;
+ const float vb0 = *w++;
+ vacc0x0 += va0 * vb0;
+ vacc1x0 += va1 * vb0;
+ vacc2x0 += va2 * vb0;
+ vacc3x0 += va3 * vb0;
} while (--nnz != 0);
}
- float vout0 = math_min_f32(vacc0, vmax);
- float vout1 = math_min_f32(vacc1, vmax);
- float vout2 = math_min_f32(vacc2, vmax);
- float vout3 = math_min_f32(vacc3, vmax);
- vout0 = math_max_f32(vout0, vmin);
- vout1 = math_max_f32(vout1, vmin);
- vout2 = math_max_f32(vout2, vmin);
- vout3 = math_max_f32(vout3, vmin);
- c[0] = vout0;
- c[1] = vout1;
- c[2] = vout2;
- c[3] = vout3;
- c += m;
- } while (--j != 0);
+ float vout0x0 = math_min_f32(vacc0x0, vmax);
+ float vout1x0 = math_min_f32(vacc1x0, vmax);
+ float vout2x0 = math_min_f32(vacc2x0, vmax);
+ float vout3x0 = math_min_f32(vacc3x0, vmax);
+ vout0x0 = math_max_f32(vout0x0, vmin);
+ vout1x0 = math_max_f32(vout1x0, vmin);
+ vout2x0 = math_max_f32(vout2x0, vmin);
+ vout3x0 = math_max_f32(vout3x0, vmin);
+ c[0 * m + 0] = vout0x0;
+ c[0 * m + 1] = vout1x0;
+ c[0 * m + 2] = vout2x0;
+ c[0 * m + 3] = vout3x0;
+ c += 1 * m;
+ j -= 1;
+ }
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ float vacc1 = vacc0;
+ float vacc2 = vacc0;
+ float vacc3 = vacc0;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ const float va1 = a[1];
+ const float va2 = a[2];
+ const float va3 = a[3];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ vacc1 += va1 * vb;
+ vacc2 += va2 * vb;
+ vacc3 += va3 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ float vout2 = math_min_f32(vacc2, vmax);
+ float vout3 = math_min_f32(vacc3, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ vout2 = math_max_f32(vout2, vmin);
+ vout3 = math_max_f32(vout3, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c[2] = vout2;
+ c[3] = vout3;
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
c -= m * n;
c += 4;
a += 4;
@@ -147,29 +249,56 @@
const int32_t* dmap = widx_dmap;
const uint32_t* nnzmap = nidx_nnzmap;
size_t j = n;
- do {
+ while (j >= 1) {
uint32_t nnz = *nnzmap++;
- float vacc0 = *w++;
- float vacc1 = vacc0;
+ float vacc0x0 = *w++;
+ float vacc1x0 = vacc0x0;
if XNN_LIKELY(nnz != 0) {
do {
const intptr_t diff = *dmap++;
const float va0 = a[0];
const float va1 = a[1];
a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
- const float vb = *w++;
- vacc0 += va0 * vb;
- vacc1 += va1 * vb;
+ const float vb0 = *w++;
+ vacc0x0 += va0 * vb0;
+ vacc1x0 += va1 * vb0;
} while (--nnz != 0);
}
- float vout0 = math_min_f32(vacc0, vmax);
- float vout1 = math_min_f32(vacc1, vmax);
- vout0 = math_max_f32(vout0, vmin);
- vout1 = math_max_f32(vout1, vmin);
- c[0] = vout0;
- c[1] = vout1;
- c += m;
- } while (--j != 0);
+ float vout0x0 = math_min_f32(vacc0x0, vmax);
+ float vout1x0 = math_min_f32(vacc1x0, vmax);
+ vout0x0 = math_max_f32(vout0x0, vmin);
+ vout1x0 = math_max_f32(vout1x0, vmin);
+ c[0 * m + 0] = vout0x0;
+ c[0 * m + 1] = vout1x0;
+ c += 1 * m;
+ j -= 1;
+ }
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ float vacc1 = vacc0;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ const float va1 = a[1];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ vacc1 += va1 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
c -= m * n;
c += 2;
a += 2;
@@ -179,23 +308,44 @@
const int32_t* dmap = widx_dmap;
const uint32_t* nnzmap = nidx_nnzmap;
size_t j = n;
- do {
+ while (j >= 1) {
uint32_t nnz = *nnzmap++;
- float vacc0 = *w++;
+ float vacc0x0 = *w++;
if XNN_LIKELY(nnz != 0) {
do {
const intptr_t diff = *dmap++;
const float va0 = a[0];
a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
- const float vb = *w++;
- vacc0 += va0 * vb;
+ const float vb0 = *w++;
+ vacc0x0 += va0 * vb0;
} while (--nnz != 0);
}
- float vout0 = math_min_f32(vacc0, vmax);
- vout0 = math_max_f32(vout0, vmin);
- c[0] = vout0;
- c += m;
- } while (--j != 0);
+ float vout0x0 = math_min_f32(vacc0x0, vmax);
+ vout0x0 = math_max_f32(vout0x0, vmin);
+ c[0 * m + 0] = vout0x0;
+ c += 1 * m;
+ j -= 1;
+ }
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ c[0] = vout0;
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
c -= m * n;
c += 1;
a += 1;
diff --git a/test/f32-spmm.cc b/test/f32-spmm.cc
index 5c352f6..49d397f 100644
--- a/test/f32-spmm.cc
+++ b/test/f32-spmm.cc
@@ -5029,8 +5029,8 @@
}
}
-TEST(F32_SPMM_8X2__SCALAR, n_gt_1) {
- for (uint32_t n = 2; n < 10; n++) {
+TEST(F32_SPMM_8X2__SCALAR, n_gt_2) {
+ for (uint32_t n = 3; n < 10; n++) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5044,9 +5044,23 @@
}
}
+TEST(F32_SPMM_8X2__SCALAR, n_div_2) {
+ for (uint32_t n = 4; n <= 6; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(2)
+ .m(8)
+ .n(n)
+ .k(k)
+ .Test(xnn_f32_spmm_ukernel_8x2__scalar, SpMMMicrokernelTester::Variant::Scalar);
+ }
+ }
+}
+
TEST(F32_SPMM_8X2__SCALAR, m_lt_8) {
for (uint32_t m = 1; m < 8; m++) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 10; n += 3) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5063,7 +5077,7 @@
TEST(F32_SPMM_8X2__SCALAR, m_div_8) {
for (uint32_t m = 16; m <= 24; m += 8) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 10; n += 3) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5080,7 +5094,7 @@
TEST(F32_SPMM_8X2__SCALAR, m_gt_8) {
for (uint32_t m = 9; m < 16; m++) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 10; n += 3) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5096,7 +5110,7 @@
}
TEST(F32_SPMM_8X2__SCALAR, qmin) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 10; n += 3) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5112,7 +5126,7 @@
}
TEST(F32_SPMM_8X2__SCALAR, qmax) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 10; n += 3) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5128,7 +5142,7 @@
}
TEST(F32_SPMM_8X2__SCALAR, half_sparse) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 10; n += 3) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5143,7 +5157,7 @@
}
TEST(F32_SPMM_8X2__SCALAR, zero_weights) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 10; n += 3) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5162,7 +5176,7 @@
.mr(8)
.nr(4)
.m(8)
- .n(2)
+ .n(4)
.k(1)
.sparsity(0.0f)
.Test(xnn_f32_spmm_ukernel_8x4__scalar, SpMMMicrokernelTester::Variant::Scalar);
@@ -5209,8 +5223,8 @@
}
}
-TEST(F32_SPMM_8X4__SCALAR, n_gt_1) {
- for (uint32_t n = 2; n < 10; n++) {
+TEST(F32_SPMM_8X4__SCALAR, n_gt_4) {
+ for (uint32_t n = 5; n < 10; n++) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5224,9 +5238,23 @@
}
}
+TEST(F32_SPMM_8X4__SCALAR, n_div_4) {
+ for (uint32_t n = 8; n <= 12; n += 4) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(4)
+ .m(8)
+ .n(n)
+ .k(k)
+ .Test(xnn_f32_spmm_ukernel_8x4__scalar, SpMMMicrokernelTester::Variant::Scalar);
+ }
+ }
+}
+
TEST(F32_SPMM_8X4__SCALAR, m_lt_8) {
for (uint32_t m = 1; m < 8; m++) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 20; n += 5) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5243,7 +5271,7 @@
TEST(F32_SPMM_8X4__SCALAR, m_div_8) {
for (uint32_t m = 16; m <= 24; m += 8) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 20; n += 5) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5260,7 +5288,7 @@
TEST(F32_SPMM_8X4__SCALAR, m_gt_8) {
for (uint32_t m = 9; m < 16; m++) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 20; n += 5) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5276,7 +5304,7 @@
}
TEST(F32_SPMM_8X4__SCALAR, qmin) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 20; n += 5) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5292,7 +5320,7 @@
}
TEST(F32_SPMM_8X4__SCALAR, qmax) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 20; n += 5) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5308,7 +5336,7 @@
}
TEST(F32_SPMM_8X4__SCALAR, half_sparse) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 20; n += 5) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5323,7 +5351,7 @@
}
TEST(F32_SPMM_8X4__SCALAR, zero_weights) {
- for (uint32_t n = 1; n < 10; n += 2) {
+ for (uint32_t n = 1; n < 20; n += 5) {
for (size_t k = 1; k <= 5; k += 2) {
SpMMMicrokernelTester()
.mr(8)
@@ -5335,4 +5363,4 @@
.Test(xnn_f32_spmm_ukernel_8x4__scalar, SpMMMicrokernelTester::Variant::Scalar);
}
}
-}
+}
\ No newline at end of file