FP16 versions of SpMM micro-kernels
PiperOrigin-RevId: 293518537
diff --git a/test/spmm-microkernel-tester.h b/test/spmm-microkernel-tester.h
index 20d6224..bcffad6 100644
--- a/test/spmm-microkernel-tester.h
+++ b/test/spmm-microkernel-tester.h
@@ -22,6 +22,12 @@
#include <xnnpack/params.h>
+// Returns true iff the IEEE half-precision bit pattern `x` encodes a zero of
+// either sign (+0.0 = 0x0000, -0.0 = 0x8000): every bit except the sign bit
+// (bit 15) must be clear.
+static inline bool is_fp16_zero(uint16_t x) {
+  const uint32_t magnitude_bits = x & UINT16_C(0x7FFF);
+  return magnitude_bits == 0;
+}
+
class SpMMMicrokernelTester {
public:
enum class Variant {
@@ -280,6 +286,173 @@
}
}
+  // Runs the FP16 SpMM micro-kernel `spmm` on randomly generated sparse
+  // weights and compares its output against an FP32 reference.
+  //
+  // Layout of the packed operands handed to the kernel (as built below):
+  //   w    - per nr-wide output-channel block: nr fp16 biases, then nr fp16
+  //          weights per non-zero; the n % nr leftover channels use nr = 1.
+  //   dmap - signed byte increments of the input pointer between consecutive
+  //          non-zeros; the last entry rewinds to the first non-zero's row.
+  //   nmap - number of non-zeros per output-channel block / leftover channel.
+  // Input `a` is indexed a[ic * m + pixel]; output `c` is c[oc * m + pixel].
+  void Test(xnn_f16_spmm_ukernel_function spmm) const {
+    ASSERT_GE(m(), 1);
+    ASSERT_GE(n(), 1);
+    ASSERT_GE(k(), 1);
+
+    std::random_device random_device;
+    auto rng = std::mt19937(random_device());
+    // Bind the engine by reference: std::bind copies its arguments, so binding
+    // `rng` by value would give f32rng and prng identical private copies and
+    // the sparsity mask would replay the values drawn for `a`.
+    auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
+    auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
+    auto prng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
+
+    std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> a(k() * m());
+    // Think of b as (n/nr + n%nr) x k: one row per nr-wide output-channel
+    // block plus one row per leftover channel; expansion happens later.
+    const size_t nblocks = n() / nr();
+    const size_t ncols = nblocks + n() % nr();
+    std::vector<uint16_t> b(ncols * k());
+    std::vector<uint16_t> bias(n());
+    // Number of non-zero weights per N (output channel).
+    std::vector<uint32_t> nmap(n());
+    // Mapping from index of non-zero weight to increment of K (input channel) following this index.
+    std::vector<int32_t> dmap(n() * k());
+    std::vector<uint16_t> w(n() * k() + n());
+    std::vector<uint16_t> c(n() * m());
+    std::vector<float> c_ref(n() * m());
+
+    for (size_t iteration = 0; iteration < iterations(); iteration++) {
+      std::generate(a.begin(), a.end(), std::ref(f16rng));
+      std::generate(b.begin(), b.end(), std::ref(f16rng));
+      std::generate(bias.begin(), bias.end(), std::ref(f16rng));
+      // Pre-fill the output with 0xC000 (-2.0 in fp16) so elements the
+      // micro-kernel never writes are unlikely to match the reference.
+      std::fill(c.begin(), c.end(), 0xC000);
+      std::fill(c_ref.begin(), c_ref.end(), 0.0f);
+      std::fill(nmap.begin(), nmap.end(), 0);
+      // w and dmap are truncated at the end of every iteration (see the
+      // resize() calls below); restore their full size here so the packing
+      // writes stay in bounds and are not later zeroed by a growing resize().
+      dmap.assign(n() * k(), 0);
+      w.assign(n() * k() + n(), 0);
+
+      for (uint16_t& b_value : b) {
+        if (prng() <= sparsity()) {
+          b_value = 0;
+        }
+      }
+
+      uint32_t nnz = 0;
+      uint32_t wcnt = 0;
+      bool first_nzz = true;
+      size_t first_kk = 0;
+      size_t last_kk = 0;
+      // Records a non-zero at input channel `kk`. The first non-zero fixes the
+      // starting input row (we record only differences); each later one
+      // appends the signed byte offset from the previous non-zero's row, which
+      // is negative when a new output channel restarts at a lower kk.
+      auto record_nonzero = [&](size_t kk) {
+        if (first_nzz) {
+          first_kk = kk;
+          first_nzz = false;
+        } else {
+          dmap[nnz++] = (static_cast<int32_t>(kk) - static_cast<int32_t>(last_kk)) *
+                        static_cast<int32_t>(m() * sizeof(uint16_t));
+        }
+        last_kk = kk;
+      };
+
+      for (size_t nn = 0; nn < nblocks; nn++) {
+        for (size_t i = 0; i < nr(); ++i) {
+          w[wcnt++] = bias[nr() * nn + i];
+        }
+        for (size_t kk = 0; kk < k(); kk++) {
+          const uint16_t weight = b[nn * k() + kk];
+          if (!is_fp16_zero(weight)) {
+            // Every non-zero actually corresponds to nr adjacent non-zeros
+            // with values weight, weight + 1, ..., weight + (nr - 1).
+            for (size_t i = 0; i < nr(); ++i) {
+              w[wcnt++] = fp16_ieee_from_fp32_value(
+                fp16_ieee_to_fp32_value(weight) + static_cast<float>(i));
+            }
+            record_nonzero(kk);
+            nmap[nn] += 1;
+          }
+        }
+      }
+
+      // Now we've constructed the matrix for the blocked part and switch to the
+      // leftovers, which we do as nr=1 always.
+      for (size_t nn = nblocks; nn < ncols; nn++) {
+        w[wcnt++] = bias[nblocks * nr() + (nn - nblocks)];
+        for (size_t kk = 0; kk < k(); kk++) {
+          const uint16_t weight = b[nn * k() + kk];
+          if (!is_fp16_zero(weight)) {
+            // Leftover channels are unblocked: each non-zero is a single weight.
+            w[wcnt++] = weight;
+            record_nonzero(kk);
+            nmap[nn] += 1;
+          }
+        }
+      }
+      // In the end, we must return input pointer to the initial value.
+      dmap[nnz++] = (static_cast<int32_t>(first_kk) - static_cast<int32_t>(last_kk)) *
+                    static_cast<int32_t>(m() * sizeof(uint16_t));
+
+      // Generate expanded b which will be used in reference calculation.
+      // Everywhere there is a non-zero in the original we copy it and add
+      // nr - 1 adjacent non-zeros with incremented weight values. Zero tests
+      // use is_fp16_zero, exactly as the packing above does.
+      std::vector<uint16_t> b_full(n() * k());
+      if (nr() == 1) {
+        b_full = b;
+      } else {
+        for (size_t nn = 0; nn < nblocks; nn++) {
+          for (size_t kk = 0; kk < k(); kk++) {
+            const uint16_t weight = b[nn * k() + kk];
+            if (!is_fp16_zero(weight)) {
+              for (size_t i = 0; i < nr(); ++i) {
+                b_full[nr() * nn * k() + i * k() + kk] = fp16_ieee_from_fp32_value(
+                  fp16_ieee_to_fp32_value(weight) + static_cast<float>(i));
+              }
+            }
+          }
+        }
+        for (size_t nn = nblocks; nn < ncols; nn++) {
+          for (size_t kk = 0; kk < k(); kk++) {
+            const uint16_t weight = b[nn * k() + kk];
+            if (!is_fp16_zero(weight)) {
+              b_full[nr() * nblocks * k() + (nn - nblocks) * k() + kk] = weight;
+            }
+          }
+        }
+      }
+
+      for (size_t oc = 0; oc < n(); oc++) {
+        for (size_t pxb = 0; pxb < m(); pxb++) {
+          c_ref[oc * m() + pxb] = fp16_ieee_to_fp32_value(bias[oc]);
+          for (size_t ic = 0; ic < k(); ic++) {
+            c_ref[oc * m() + pxb] +=
+              fp16_ieee_to_fp32_value(a[ic * m() + pxb]) * fp16_ieee_to_fp32_value(b_full[oc * k() + ic]);
+          }
+        }
+      }
+
+      // Micro-kernel can access one element beyond w and dmap for software pipelining.
+      w.resize(wcnt + 1);
+      dmap.resize(nnz + 1);
+
+      // Compute clamping parameters.
+      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
+      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
+      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
+      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
+
+      // Clamp reference results.
+      for (float& c_value : c_ref) {
+        c_value = std::min(std::max(c_value, c_min), c_max);
+      }
+
+      // Prepare output parameters.
+      xnn_f16_output_params output_params;
+      output_params.scale = UINT16_C(0x3C00) /* 1.0 */;
+      output_params.max = fp16_ieee_from_fp32_value(c_max);
+      output_params.min = fp16_ieee_from_fp32_value(c_min);
+
+      // The input pointer starts at the row of the first non-zero weight.
+      spmm(m(), n(),
+        a.data() + first_kk * m(), w.data(), dmap.data(), nmap.data(), c.data(),
+        &output_params);
+
+      // Validate micro-kernel outputs (c is indexed c[oc * m + pxb]).
+      for (size_t oc = 0; oc < n(); oc++) {
+        for (size_t pxb = 0; pxb < m(); pxb++) {
+          ASSERT_NEAR(
+              fp16_ieee_to_fp32_value(c[oc * m() + pxb]),
+              c_ref[oc * m() + pxb],
+              std::abs(c_ref[oc * m() + pxb]) * 1.0e-2f)
+            << "at " << oc << ", " << pxb
+            << ": Mr x Nr x Kr = " << mr() << " x " << nr()
+            << ", M x N x K = " << m() << " x " << n() << " x " << k();
+        }
+      }
+    }
+  }
+
private:
size_t mr_{1};
size_t nr_{1};