FP16 versions of SpMM micro-kernels
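
These micro-kernels compute a sparse-weights-times-dense-activations product in
FP16: for each output channel the packed weights hold a bias followed by that
channel's non-zero weights, nidx_nnzmap holds the per-channel non-zero counts,
and widx_dmap holds byte deltas that walk the activation pointer from one
non-zero's input channel to the next (the last delta wraps back to the first).
As a rough scalar sketch of that contract (illustrative only, not part of this
change; __fp16 assumes an FP16-capable AArch64 toolchain, and the clamping
parameters are shown as plain floats instead of xnn_f16_output_params):

  #include <stdint.h>

  static void spmm_f16_scalar_reference(
      uint32_t m, uint32_t n,
      const void* input,             // dense activations, m values per input channel;
                                     // caller pre-offsets this to the first non-zero's channel
      const void* weights,           // per channel: bias, then one value per non-zero
      const int32_t* widx_dmap,      // byte deltas between consecutive non-zeros' input channels
      const uint32_t* nidx_nnzmap,   // non-zero count per output channel
      void* output,                  // m values per output channel
      float scale, float output_min, float output_max)
  {
    const __fp16* a = (const __fp16*) input;
    __fp16* c = (__fp16*) output;
    for (uint32_t i = 0; i < m; i++) {          // one element of the m dimension at a time;
                                                // the kernels below do 8/16/24/32 at once
      const __fp16* w = (const __fp16*) weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      const __fp16* ai = a + i;
      for (uint32_t j = 0; j < n; j++) {
        float acc = (float) *w++;               // bias initializes the accumulator
        for (uint32_t nnz = *nnzmap++; nnz != 0; nnz--) {
          acc += (float) *ai * (float) *w++;    // one multiply-add per non-zero weight
          ai = (const __fp16*) ((uintptr_t) ai + (uintptr_t) *dmap++);
        }
        float out = acc * scale;
        out = out > output_max ? output_max : out;
        out = out < output_min ? output_min : out;
        c[j * m + i] = (__fp16) out;            // output is channel-major with stride m
      }
    }
  }
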
PiperOrigin-RevId: 293518537
diff --git a/BUILD.bazel b/BUILD.bazel
index 91d2a03..095b259 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -843,6 +843,14 @@
"src/f16-gemm/gen/4x8-neonfp16arith-ld64.c",
"src/f16-gemm/gen/6x8-neonfp16arith-ld64.c",
"src/f16-gemm/gen/8x8-neonfp16arith-ld64.c",
+ "src/f16-spmm/gen/8x1-neonfp16arith.c",
+ "src/f16-spmm/gen/8x1-neonfp16arith-unroll2.c",
+ "src/f16-spmm/gen/16x1-neonfp16arith.c",
+ "src/f16-spmm/gen/16x1-neonfp16arith-unroll2.c",
+ "src/f16-spmm/gen/24x1-neonfp16arith.c",
+ "src/f16-spmm/gen/24x1-neonfp16arith-unroll2.c",
+ "src/f16-spmm/gen/32x1-neonfp16arith.c",
+ "src/f16-spmm/gen/32x1-neonfp16arith-unroll2.c",
]
SSE_UKERNELS = [
@@ -1942,6 +1950,18 @@
)
xnnpack_benchmark(
+ name = "f16_spmm_bench",
+ srcs = [
+ "bench/f16-spmm.cc",
+ "bench/gemm.h",
+ "bench/google/gemm.h",
+ "src/xnnpack/AlignedAllocator.h",
+ ] + MICROKERNEL_BENCHMARK_HDRS,
+ copts = ["-Wno-unused-function"],
+ deps = MICROKERNEL_BENCHMARK_DEPS,
+)
+
+xnnpack_benchmark(
name = "f32_igemm_bench",
srcs = [
"bench/f32-igemm.cc",
@@ -2312,6 +2332,16 @@
)
xnnpack_unit_test(
+ name = "f16_spmm_test",
+ srcs = [
+ "test/f16-spmm.cc",
+ "test/spmm-microkernel-tester.h",
+ "src/xnnpack/AlignedAllocator.h",
+ ] + MICROKERNEL_TEST_HDRS,
+ deps = MICROKERNEL_TEST_DEPS,
+)
+
+xnnpack_unit_test(
name = "f32_argmaxpool_test",
srcs = [
"test/f32-argmaxpool.cc",
diff --git a/bench/f16-spmm.cc b/bench/f16-spmm.cc
new file mode 100644
index 0000000..12529b8
--- /dev/null
+++ b/bench/f16-spmm.cc
@@ -0,0 +1,202 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <algorithm>
+#include <cfloat>
+#include <cmath>
+#include <functional>
+#include <random>
+#include <vector>
+
+#include <cpuinfo.h>
+
+#include <benchmark/benchmark.h>
+#include <fp16/fp16.h>
+#include "bench/gemm.h"
+#include "bench/utils.h"
+#include <xnnpack/AlignedAllocator.h>
+#include <xnnpack/common.h>
+#include <xnnpack/params-init.h>
+#include <xnnpack/params.h>
+#include <xnnpack/spmm.h>
+
+
+static void SpMMBenchmark(benchmark::State& state,
+ xnn_f16_spmm_ukernel_function spmm, uint32_t mr, uint32_t nr, float sparsity)
+{
+ if (!cpuinfo_initialize()) {
+ state.SkipWithError("cpuinfo initialization failed");
+ return;
+ }
+
+ const size_t mc = state.range(0);
+ const size_t nc = state.range(1);
+ const size_t kc = state.range(2);
+
+ std::random_device random_device;
+ auto rng = std::mt19937(random_device());
+ auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
+ auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
+
+  // If using blocks, generate the reduced matrix first and then extrude along
+  // the block dimension (n) to get the full matrix.
+ size_t ncols = nc / nr + nc % nr;
+ std::vector<uint16_t> b(ncols * kc);
+ std::vector<uint16_t> bias(nc);
+ std::vector<uint16_t> w;
+ std::vector<uint32_t> nmap;
+ std::vector<int32_t> dmap;
+ const size_t sparse_end = std::min(size_t(float(b.size()) * sparsity), b.size());
+ const size_t num_nonzeroes = nr * (b.size() - sparse_end);
+
+ const size_t w_elements = num_nonzeroes + nc;
+ const size_t c_elements = mc * nc;
+ const size_t dmap_elements = num_nonzeroes / nr;
+ const size_t nmap_elements = nc;
+ const size_t num_buffers = 1 +
+ benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
+ sizeof(uint16_t) * (w_elements + c_elements) + sizeof(uint32_t) * (dmap_elements + nmap_elements));
+
+ // Micro-kernel can access one element beyond w and dmap for software pipelining.
+ w.reserve(num_buffers * w_elements + 1);
+ dmap.reserve(num_buffers * dmap_elements + 1);
+ nmap.resize(num_buffers * nmap_elements);
+
+ std::vector<size_t> a_offsets(num_buffers);
+
+ for (size_t buffer_index = 0; buffer_index < num_buffers; buffer_index++) {
+    // Re-generate weights. Note: each re-generation produces the same number of non-zeroes.
+ std::fill(b.begin(), b.begin() + sparse_end, 0);
+ std::generate(b.begin() + sparse_end, b.end(), std::ref(f16rng));
+ std::shuffle(b.begin(), b.end(), rng);
+ std::generate(bias.begin(), bias.end(), std::ref(f16rng));
+
+ uint32_t first_j = 0, last_j = 0;
+ bool is_first_nonzero = true;
+ for (uint32_t i = 0; i < nc / nr; i++) {
+ for (uint32_t n = 0; n < nr; n++)
+ w.push_back(bias[nr * i + n]);
+ for (uint32_t j = 0; j < kc; j++) {
+ if ((b[i * kc + j] & 0x7FFF) != 0) {
+ for (size_t l = 0; l < nr; l++)
+ w.push_back(fp16_ieee_from_fp32_value(fp16_ieee_to_fp32_value(b[i * kc + j]) + static_cast<float>(i)));
+ if (is_first_nonzero) {
+ first_j = j;
+ } else {
+ const ptrdiff_t increment = int32_t(j - last_j) * int32_t(mc) * int32_t(sizeof(uint16_t));
+ dmap.push_back(increment);
+ }
+ last_j = j;
+ is_first_nonzero = false;
+ nmap[buffer_index * nmap_elements + i] += 1;
+ }
+ }
+ }
+ for (uint32_t i = nc / nr; i < ncols; i++) {
+ w.push_back(bias[i]);
+ for (uint32_t j = 0; j < kc; j++) {
+ if ((b[i * kc + j] & 0x7FFF) != 0) {
+ w.push_back(b[i * kc + j]);
+ if (is_first_nonzero) {
+ first_j = j;
+ } else {
+ const ptrdiff_t increment = int32_t(j - last_j) * int32_t(mc) * int32_t(sizeof(uint16_t));
+ dmap.push_back(increment);
+ }
+ last_j = j;
+ is_first_nonzero = false;
+ nmap[buffer_index * nmap_elements + i] += 1;
+ }
+ }
+ }
+ {
+ const ptrdiff_t increment = int32_t(first_j - last_j) * int32_t(mc) * int32_t(sizeof(uint16_t));
+ dmap.push_back(increment);
+ }
+
+ a_offsets[buffer_index] = first_j * mc;
+ }
+
+ // Micro-kernel can access one element beyond w and dmap for software pipelining.
+ w.resize(w.size() + 1);
+ dmap.resize(dmap.size() + 1);
+
+  std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> a(kc * mc);
+  std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> c(num_buffers * c_elements);
+
+  std::generate(a.begin(), a.end(), std::ref(f16rng));
+  std::fill(c.begin(), c.end(), 0x7E00 /* fp16 NaN */);
+
+ xnn_f16_output_params output_params{
+ 0x3C00 /* 1.0 */, 0x7C00 /* inf */, 0xFC00 /* -inf */};
+
+ size_t buffer_index = 0;
+ for (auto _ : state) {
+ // Use circular buffers (exceeding cache size) and prefetch to control cache state:
+    // - A is always in L1 cache (if it fits; otherwise L2, L3, etc.)
+    // - W, Dmap, and Nmap are not in cache (at any level)
+    // - C is not in cache (at any level)
+ state.PauseTiming();
+ benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(uint16_t));
+ buffer_index = (buffer_index + 1) % num_buffers;
+ state.ResumeTiming();
+
+ spmm(mc, nc,
+ a.data() + a_offsets[buffer_index],
+ w.data() + buffer_index * w_elements,
+ dmap.data() + buffer_index * dmap_elements,
+ nmap.data() + buffer_index * nmap_elements,
+ c.data() + buffer_index * c_elements,
+ &output_params);
+ }
+
+ state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
+ state.counters["FLOPS"] = benchmark::Counter(
+ uint64_t(state.iterations()) * 2 * mc * num_nonzeroes, benchmark::Counter::kIsRate);
+
+ state.counters["EffFLOPS"] = benchmark::Counter(
+ uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate);
+}
+
+
+#if XNN_ARCH_ARM64
+ static void spmm80_8x1__neonfp16arith(benchmark::State& state, const char* net) {
+ SpMMBenchmark(state, xnn_f16_spmm_ukernel_8x1__neonfp16arith, 8, 1, 0.8f);
+ }
+ static void spmm80_8x1__neonfp16arith_unroll2(benchmark::State& state, const char* net) {
+ SpMMBenchmark(state, xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2, 8, 1, 0.8f);
+ }
+ static void spmm80_16x1__neonfp16arith(benchmark::State& state, const char* net) {
+ SpMMBenchmark(state, xnn_f16_spmm_ukernel_16x1__neonfp16arith, 16, 1, 0.8f);
+ }
+ static void spmm80_16x1__neonfp16arith_unroll2(benchmark::State& state, const char* net) {
+ SpMMBenchmark(state, xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2, 16, 1, 0.8f);
+ }
+ static void spmm80_24x1__neonfp16arith(benchmark::State& state, const char* net) {
+ SpMMBenchmark(state, xnn_f16_spmm_ukernel_24x1__neonfp16arith, 24, 1, 0.8f);
+ }
+ static void spmm80_24x1__neonfp16arith_unroll2(benchmark::State& state, const char* net) {
+ SpMMBenchmark(state, xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2, 24, 1, 0.8f);
+ }
+ static void spmm80_32x1__neonfp16arith(benchmark::State& state, const char* net) {
+ SpMMBenchmark(state, xnn_f16_spmm_ukernel_32x1__neonfp16arith, 32, 1, 0.8f);
+ }
+ static void spmm80_32x1__neonfp16arith_unroll2(benchmark::State& state, const char* net) {
+ SpMMBenchmark(state, xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2, 32, 1, 0.8f);
+ }
+
+ BENCHMARK_GEMM(spmm80_8x1__neonfp16arith)
+ BENCHMARK_GEMM(spmm80_8x1__neonfp16arith_unroll2)
+ BENCHMARK_GEMM(spmm80_16x1__neonfp16arith)
+ BENCHMARK_GEMM(spmm80_16x1__neonfp16arith_unroll2)
+ BENCHMARK_GEMM(spmm80_24x1__neonfp16arith)
+ BENCHMARK_GEMM(spmm80_24x1__neonfp16arith_unroll2)
+ BENCHMARK_GEMM(spmm80_32x1__neonfp16arith)
+ BENCHMARK_GEMM(spmm80_32x1__neonfp16arith_unroll2)
+#endif // XNN_ARCH_ARM64
+
+#ifndef XNNPACK_BENCHMARK_NO_MAIN
+BENCHMARK_MAIN();
+#endif
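
The benchmark above constructs the packed sparse representation inline. For the
NR=1 kernels added in this change, that packing amounts to the following
stand-alone helper (a hypothetical sketch to document the layout, not code from
this change; dense values are raw fp16 bits):

  #include <stddef.h>
  #include <stdint.h>

  // Returns the element offset (first non-zero's input channel times m) that the
  // caller should add to its activation pointer, mirroring a_offsets above.
  static size_t pack_spmm_f16_nr1(
      uint32_t n, uint32_t kc, uint32_t m,
      const uint16_t* dense /* n x kc */, const uint16_t* bias /* n */,
      uint16_t* weights, int32_t* dmap, uint32_t* nnzmap)
  {
    uint32_t first_k = 0, last_k = 0;
    int seen_nonzero = 0;
    for (uint32_t i = 0; i < n; i++) {
      *weights++ = bias[i];                  // bias precedes the channel's non-zero weights
      nnzmap[i] = 0;
      for (uint32_t k = 0; k < kc; k++) {
        const uint16_t v = dense[i * kc + k];
        if ((v & 0x7FFF) != 0) {             // treat +0.0 and -0.0 as zero
          *weights++ = v;
          if (seen_nonzero) {
            // byte delta from the previous non-zero's input channel to this one
            *dmap++ = ((int32_t) k - (int32_t) last_k) * (int32_t) m * (int32_t) sizeof(uint16_t);
          } else {
            first_k = k;
          }
          last_k = k;
          seen_nonzero = 1;
          nnzmap[i] += 1;
        }
      }
    }
    // The final delta wraps back to the first non-zero's channel so the kernel can
    // replay the same dmap sequence for the next block along the m dimension.
    *dmap++ = ((int32_t) first_k - (int32_t) last_k) * (int32_t) m * (int32_t) sizeof(uint16_t);
    return (size_t) first_k * m;
  }
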
diff --git a/scripts/generate-f16-spmm.sh b/scripts/generate-f16-spmm.sh
new file mode 100755
index 0000000..22eaa98
--- /dev/null
+++ b/scripts/generate-f16-spmm.sh
@@ -0,0 +1,20 @@
+#!/bin/sh
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+################################### ARM NEON ##################################
+### Microkernels without unrolling
+tools/xngen src/f16-spmm/neonfp16arith.c.in -D MR=8 -D NR=1 -D UNROLL=1 -o src/f16-spmm/gen/8x1-neonfp16arith.c
+tools/xngen src/f16-spmm/neonfp16arith.c.in -D MR=16 -D NR=1 -D UNROLL=1 -o src/f16-spmm/gen/16x1-neonfp16arith.c
+tools/xngen src/f16-spmm/neonfp16arith.c.in -D MR=24 -D NR=1 -D UNROLL=1 -o src/f16-spmm/gen/24x1-neonfp16arith.c
+tools/xngen src/f16-spmm/neonfp16arith.c.in -D MR=32 -D NR=1 -D UNROLL=1 -o src/f16-spmm/gen/32x1-neonfp16arith.c
+### Microkernels with 2X unrolling
+tools/xngen src/f16-spmm/neonfp16arith.c.in -D MR=8 -D NR=1 -D UNROLL=2 -o src/f16-spmm/gen/8x1-neonfp16arith-unroll2.c
+tools/xngen src/f16-spmm/neonfp16arith.c.in -D MR=16 -D NR=1 -D UNROLL=2 -o src/f16-spmm/gen/16x1-neonfp16arith-unroll2.c
+tools/xngen src/f16-spmm/neonfp16arith.c.in -D MR=24 -D NR=1 -D UNROLL=2 -o src/f16-spmm/gen/24x1-neonfp16arith-unroll2.c
+tools/xngen src/f16-spmm/neonfp16arith.c.in -D MR=32 -D NR=1 -D UNROLL=2 -o src/f16-spmm/gen/32x1-neonfp16arith-unroll2.c
+
+################################## Unit tests #################################
+tools/generate-spmm-test.py --spec test/f16-spmm.yaml --output test/f16-spmm.cc
diff --git a/src/f16-spmm/gen/16x1-neonfp16arith-unroll2.c b/src/f16-spmm/gen/16x1-neonfp16arith-unroll2.c
new file mode 100644
index 0000000..45a7f05
--- /dev/null
+++ b/src/f16-spmm/gen/16x1-neonfp16arith-unroll2.c
@@ -0,0 +1,201 @@
+// Auto-generated file. Do not edit!
+// Template: src/f16-spmm/neonfp16arith.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2(
+ uint32_t m,
+ uint32_t n,
+ const void*restrict input,
+ const void*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ void*restrict output,
+ const struct xnn_f16_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const __fp16*restrict a = input;
+ __fp16*restrict c = output;
+
+  const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+  const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+  const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+
+ size_t i = m;
+ while XNN_LIKELY(i >= 16) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567x0 = vld1q_dup_f16(w); w += 1;
+ float16x8_t vacc01234567x1 = vmovq_n_f16(0.0f);
+ float16x8_t vacc89ABCDEFx0 = vacc01234567x0;
+ float16x8_t vacc89ABCDEFx1 = vmovq_n_f16(0.0f);
+ for (; nnz >= 2; nnz -= 2) {
+ const intptr_t diff0 = dmap[0];
+ const intptr_t diff1 = dmap[1];
+ dmap += 2;
+ const float16x8_t va01234567x0 = vld1q_f16(a);
+ const float16x8_t va89ABCDEFx0 = vld1q_f16(a + 8);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff0);
+ const float16x8_t vb0 = vld1q_dup_f16(w); w += 1;
+ vacc01234567x0 = vfmaq_f16(vacc01234567x0, va01234567x0, vb0);
+ vacc89ABCDEFx0 = vfmaq_f16(vacc89ABCDEFx0, va89ABCDEFx0, vb0);
+ const float16x8_t va01234567x1 = vld1q_f16(a);
+ const float16x8_t va89ABCDEFx1 = vld1q_f16(a + 8);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff1);
+ const float16x8_t vb1 = vld1q_dup_f16(w); w += 1;
+ vacc01234567x1 = vfmaq_f16(vacc01234567x1, va01234567x1, vb1);
+ vacc89ABCDEFx1 = vfmaq_f16(vacc89ABCDEFx1, va89ABCDEFx1, vb1);
+ }
+ float16x8_t vacc01234567 = vacc01234567x0;
+ float16x8_t vacc89ABCDEF = vacc89ABCDEFx0;
+ vacc01234567 = vaddq_f16(vacc01234567, vacc01234567x1);
+ vacc89ABCDEF = vaddq_f16(vacc89ABCDEF, vacc89ABCDEFx1);
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ const float16x8_t va89ABCDEF = vld1q_f16(a + 8);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ vacc89ABCDEF = vfmaq_f16(vacc89ABCDEF, va89ABCDEF, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vmulq_f16(vacc01234567, vscale);
+ float16x8_t vout89ABCDEF = vmulq_f16(vacc89ABCDEF, vscale);
+ vout01234567 = vminq_f16(vout01234567, vmax);
+ vout89ABCDEF = vminq_f16(vout89ABCDEF, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vout89ABCDEF = vmaxq_f16(vout89ABCDEF, vmin);
+ vst1q_f16(c, vout01234567);
+ vst1q_f16(c + 8, vout89ABCDEF);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 16;
+ a += 16;
+ i -= 16;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ if (i & 8) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vminq_f16(vacc01234567, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vst1q_f16(c, vout01234567);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 8;
+ a += 8;
+ }
+ if (i & 4) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0123 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0123 = vld1_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0123 = vfma_f16(vacc0123, va0123, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0123 = vmin_f16(vacc0123, vget_low_f16(vmax));
+ vout0123 = vmax_f16(vout0123, vget_low_f16(vmin));
+ vst1_f16(c, vout0123);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 4;
+ a += 4;
+ }
+ if (i & 2) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc01 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+            const float16x4_t va01 = vreinterpret_f16_f32(vld1_dup_f32(__builtin_assume_aligned(a, 1)));
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc01 = vfma_f16(vacc01, va01, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout01 = vmin_f16(vacc01, vget_low_f16(vmax));
+ vout01 = vmax_f16(vout01, vget_low_f16(vmin));
+        vst1_lane_f32(__builtin_assume_aligned(c, 1), vreinterpret_f32_f16(vout01), 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 2;
+ a += 2;
+ }
+ if (i & 1) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0 = vld1_dup_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0 = vfma_f16(vacc0, va0, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0 = vmin_f16(vacc0, vget_low_f16(vmax));
+ vout0 = vmax_f16(vout0, vget_low_f16(vmin));
+ vst1_lane_f16(c, vout0, 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f16-spmm/gen/16x1-neonfp16arith.c b/src/f16-spmm/gen/16x1-neonfp16arith.c
new file mode 100644
index 0000000..7f8cb71
--- /dev/null
+++ b/src/f16-spmm/gen/16x1-neonfp16arith.c
@@ -0,0 +1,178 @@
+// Auto-generated file. Do not edit!
+// Template: src/f16-spmm/neonfp16arith.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f16_spmm_ukernel_16x1__neonfp16arith(
+ uint32_t m,
+ uint32_t n,
+ const void*restrict input,
+ const void*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ void*restrict output,
+ const struct xnn_f16_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const __fp16*restrict a = input;
+ __fp16*restrict c = output;
+
+  const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+  const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+  const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+
+ size_t i = m;
+ while XNN_LIKELY(i >= 16) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ float16x8_t vacc89ABCDEF = vacc01234567;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ const float16x8_t va89ABCDEF = vld1q_f16(a + 8);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ vacc89ABCDEF = vfmaq_f16(vacc89ABCDEF, va89ABCDEF, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vmulq_f16(vacc01234567, vscale);
+ float16x8_t vout89ABCDEF = vmulq_f16(vacc89ABCDEF, vscale);
+ vout01234567 = vminq_f16(vout01234567, vmax);
+ vout89ABCDEF = vminq_f16(vout89ABCDEF, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vout89ABCDEF = vmaxq_f16(vout89ABCDEF, vmin);
+ vst1q_f16(c, vout01234567);
+ vst1q_f16(c + 8, vout89ABCDEF);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 16;
+ a += 16;
+ i -= 16;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ if (i & 8) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vminq_f16(vacc01234567, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vst1q_f16(c, vout01234567);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 8;
+ a += 8;
+ }
+ if (i & 4) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0123 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0123 = vld1_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0123 = vfma_f16(vacc0123, va0123, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0123 = vmin_f16(vacc0123, vget_low_f16(vmax));
+ vout0123 = vmax_f16(vout0123, vget_low_f16(vmin));
+ vst1_f16(c, vout0123);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 4;
+ a += 4;
+ }
+ if (i & 2) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc01 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+            const float16x4_t va01 = vreinterpret_f16_f32(vld1_dup_f32(__builtin_assume_aligned(a, 1)));
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc01 = vfma_f16(vacc01, va01, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout01 = vmin_f16(vacc01, vget_low_f16(vmax));
+ vout01 = vmax_f16(vout01, vget_low_f16(vmin));
+        vst1_lane_f32(__builtin_assume_aligned(c, 1), vreinterpret_f32_f16(vout01), 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 2;
+ a += 2;
+ }
+ if (i & 1) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0 = vld1_dup_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0 = vfma_f16(vacc0, va0, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0 = vmin_f16(vacc0, vget_low_f16(vmax));
+ vout0 = vmax_f16(vout0, vget_low_f16(vmin));
+ vst1_lane_f16(c, vout0, 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f16-spmm/gen/24x1-neonfp16arith-unroll2.c b/src/f16-spmm/gen/24x1-neonfp16arith-unroll2.c
new file mode 100644
index 0000000..dac9d87
--- /dev/null
+++ b/src/f16-spmm/gen/24x1-neonfp16arith-unroll2.c
@@ -0,0 +1,247 @@
+// Auto-generated file. Do not edit!
+// Template: src/f16-spmm/neonfp16arith.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2(
+ uint32_t m,
+ uint32_t n,
+ const void*restrict input,
+ const void*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ void*restrict output,
+ const struct xnn_f16_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const __fp16*restrict a = input;
+ __fp16*restrict c = output;
+
+  const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+  const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+  const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+
+ size_t i = m;
+ while XNN_LIKELY(i >= 24) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567x0 = vld1q_dup_f16(w); w += 1;
+ float16x8_t vacc01234567x1 = vmovq_n_f16(0.0f);
+ float16x8_t vacc89ABCDEFx0 = vacc01234567x0;
+ float16x8_t vacc89ABCDEFx1 = vmovq_n_f16(0.0f);
+ float16x8_t vaccGHIJKLMNx0 = vacc01234567x0;
+ float16x8_t vaccGHIJKLMNx1 = vmovq_n_f16(0.0f);
+ for (; nnz >= 2; nnz -= 2) {
+ const intptr_t diff0 = dmap[0];
+ const intptr_t diff1 = dmap[1];
+ dmap += 2;
+ const float16x8_t va01234567x0 = vld1q_f16(a);
+ const float16x8_t va89ABCDEFx0 = vld1q_f16(a + 8);
+ const float16x8_t vaGHIJKLMNx0 = vld1q_f16(a + 16);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff0);
+ const float16x8_t vb0 = vld1q_dup_f16(w); w += 1;
+ vacc01234567x0 = vfmaq_f16(vacc01234567x0, va01234567x0, vb0);
+ vacc89ABCDEFx0 = vfmaq_f16(vacc89ABCDEFx0, va89ABCDEFx0, vb0);
+ vaccGHIJKLMNx0 = vfmaq_f16(vaccGHIJKLMNx0, vaGHIJKLMNx0, vb0);
+ const float16x8_t va01234567x1 = vld1q_f16(a);
+ const float16x8_t va89ABCDEFx1 = vld1q_f16(a + 8);
+ const float16x8_t vaGHIJKLMNx1 = vld1q_f16(a + 16);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff1);
+ const float16x8_t vb1 = vld1q_dup_f16(w); w += 1;
+ vacc01234567x1 = vfmaq_f16(vacc01234567x1, va01234567x1, vb1);
+ vacc89ABCDEFx1 = vfmaq_f16(vacc89ABCDEFx1, va89ABCDEFx1, vb1);
+ vaccGHIJKLMNx1 = vfmaq_f16(vaccGHIJKLMNx1, vaGHIJKLMNx1, vb1);
+ }
+ float16x8_t vacc01234567 = vacc01234567x0;
+ float16x8_t vacc89ABCDEF = vacc89ABCDEFx0;
+ float16x8_t vaccGHIJKLMN = vaccGHIJKLMNx0;
+ vacc01234567 = vaddq_f16(vacc01234567, vacc01234567x1);
+ vacc89ABCDEF = vaddq_f16(vacc89ABCDEF, vacc89ABCDEFx1);
+ vaccGHIJKLMN = vaddq_f16(vaccGHIJKLMN, vaccGHIJKLMNx1);
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ const float16x8_t va89ABCDEF = vld1q_f16(a + 8);
+ const float16x8_t vaGHIJKLMN = vld1q_f16(a + 16);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ vacc89ABCDEF = vfmaq_f16(vacc89ABCDEF, va89ABCDEF, vb);
+ vaccGHIJKLMN = vfmaq_f16(vaccGHIJKLMN, vaGHIJKLMN, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vmulq_f16(vacc01234567, vscale);
+ float16x8_t vout89ABCDEF = vmulq_f16(vacc89ABCDEF, vscale);
+ float16x8_t voutGHIJKLMN = vmulq_f16(vaccGHIJKLMN, vscale);
+ vout01234567 = vminq_f16(vout01234567, vmax);
+ vout89ABCDEF = vminq_f16(vout89ABCDEF, vmax);
+ voutGHIJKLMN = vminq_f16(voutGHIJKLMN, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vout89ABCDEF = vmaxq_f16(vout89ABCDEF, vmin);
+ voutGHIJKLMN = vmaxq_f16(voutGHIJKLMN, vmin);
+ vst1q_f16(c, vout01234567);
+ vst1q_f16(c + 8, vout89ABCDEF);
+ vst1q_f16(c + 16, voutGHIJKLMN);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 24;
+ a += 24;
+ i -= 24;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ if (i & 16) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ float16x8_t vacc89ABCDEF = vacc01234567;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ const float16x8_t va89ABCDEF = vld1q_f16(a + 8);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ vacc89ABCDEF = vfmaq_f16(vacc89ABCDEF, va89ABCDEF, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vminq_f16(vacc01234567, vmax);
+ float16x8_t vout89ABCDEF = vminq_f16(vacc89ABCDEF, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vout89ABCDEF = vmaxq_f16(vout89ABCDEF, vmin);
+ vst1q_f16(c, vout01234567);
+ vst1q_f16(c + 8, vout89ABCDEF);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 16;
+ a += 16;
+ }
+ if (i & 8) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vminq_f16(vacc01234567, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vst1q_f16(c, vout01234567);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 8;
+ a += 8;
+ }
+ if (i & 4) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0123 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0123 = vld1_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0123 = vfma_f16(vacc0123, va0123, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0123 = vmin_f16(vacc0123, vget_low_f16(vmax));
+ vout0123 = vmax_f16(vout0123, vget_low_f16(vmin));
+ vst1_f16(c, vout0123);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 4;
+ a += 4;
+ }
+ if (i & 2) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc01 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+            const float16x4_t va01 = vreinterpret_f16_f32(vld1_dup_f32(__builtin_assume_aligned(a, 1)));
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc01 = vfma_f16(vacc01, va01, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout01 = vmin_f16(vacc01, vget_low_f16(vmax));
+ vout01 = vmax_f16(vout01, vget_low_f16(vmin));
+        vst1_lane_f32(__builtin_assume_aligned(c, 1), vreinterpret_f32_f16(vout01), 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 2;
+ a += 2;
+ }
+ if (i & 1) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0 = vld1_dup_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0 = vfma_f16(vacc0, va0, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0 = vmin_f16(vacc0, vget_low_f16(vmax));
+ vout0 = vmax_f16(vout0, vget_low_f16(vmin));
+ vst1_lane_f16(c, vout0, 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f16-spmm/gen/24x1-neonfp16arith.c b/src/f16-spmm/gen/24x1-neonfp16arith.c
new file mode 100644
index 0000000..75358cf
--- /dev/null
+++ b/src/f16-spmm/gen/24x1-neonfp16arith.c
@@ -0,0 +1,217 @@
+// Auto-generated file. Do not edit!
+// Template: src/f16-spmm/neonfp16arith.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f16_spmm_ukernel_24x1__neonfp16arith(
+ uint32_t m,
+ uint32_t n,
+ const void*restrict input,
+ const void*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ void*restrict output,
+ const struct xnn_f16_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const __fp16*restrict a = input;
+ __fp16*restrict c = output;
+
+  const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+  const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+  const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+
+ size_t i = m;
+ while XNN_LIKELY(i >= 24) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ float16x8_t vacc89ABCDEF = vacc01234567;
+ float16x8_t vaccGHIJKLMN = vacc01234567;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ const float16x8_t va89ABCDEF = vld1q_f16(a + 8);
+ const float16x8_t vaGHIJKLMN = vld1q_f16(a + 16);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ vacc89ABCDEF = vfmaq_f16(vacc89ABCDEF, va89ABCDEF, vb);
+ vaccGHIJKLMN = vfmaq_f16(vaccGHIJKLMN, vaGHIJKLMN, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vmulq_f16(vacc01234567, vscale);
+ float16x8_t vout89ABCDEF = vmulq_f16(vacc89ABCDEF, vscale);
+ float16x8_t voutGHIJKLMN = vmulq_f16(vaccGHIJKLMN, vscale);
+ vout01234567 = vminq_f16(vout01234567, vmax);
+ vout89ABCDEF = vminq_f16(vout89ABCDEF, vmax);
+ voutGHIJKLMN = vminq_f16(voutGHIJKLMN, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vout89ABCDEF = vmaxq_f16(vout89ABCDEF, vmin);
+ voutGHIJKLMN = vmaxq_f16(voutGHIJKLMN, vmin);
+ vst1q_f16(c, vout01234567);
+ vst1q_f16(c + 8, vout89ABCDEF);
+ vst1q_f16(c + 16, voutGHIJKLMN);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 24;
+ a += 24;
+ i -= 24;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ if (i & 16) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ float16x8_t vacc89ABCDEF = vacc01234567;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ const float16x8_t va89ABCDEF = vld1q_f16(a + 8);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ vacc89ABCDEF = vfmaq_f16(vacc89ABCDEF, va89ABCDEF, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vminq_f16(vacc01234567, vmax);
+ float16x8_t vout89ABCDEF = vminq_f16(vacc89ABCDEF, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vout89ABCDEF = vmaxq_f16(vout89ABCDEF, vmin);
+ vst1q_f16(c, vout01234567);
+ vst1q_f16(c + 8, vout89ABCDEF);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 16;
+ a += 16;
+ }
+ if (i & 8) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vminq_f16(vacc01234567, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vst1q_f16(c, vout01234567);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 8;
+ a += 8;
+ }
+ if (i & 4) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0123 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0123 = vld1_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0123 = vfma_f16(vacc0123, va0123, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0123 = vmin_f16(vacc0123, vget_low_f16(vmax));
+ vout0123 = vmax_f16(vout0123, vget_low_f16(vmin));
+ vst1_f16(c, vout0123);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 4;
+ a += 4;
+ }
+ if (i & 2) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc01 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+            const float16x4_t va01 = vreinterpret_f16_f32(vld1_dup_f32(__builtin_assume_aligned(a, 1)));
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc01 = vfma_f16(vacc01, va01, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout01 = vmin_f16(vacc01, vget_low_f16(vmax));
+ vout01 = vmax_f16(vout01, vget_low_f16(vmin));
+        vst1_lane_f32(__builtin_assume_aligned(c, 1), vreinterpret_f32_f16(vout01), 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 2;
+ a += 2;
+ }
+ if (i & 1) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0 = vld1_dup_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0 = vfma_f16(vacc0, va0, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0 = vmin_f16(vacc0, vget_low_f16(vmax));
+ vout0 = vmax_f16(vout0, vget_low_f16(vmin));
+ vst1_lane_f16(c, vout0, 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f16-spmm/gen/32x1-neonfp16arith-unroll2.c b/src/f16-spmm/gen/32x1-neonfp16arith-unroll2.c
new file mode 100644
index 0000000..46956c3
--- /dev/null
+++ b/src/f16-spmm/gen/32x1-neonfp16arith-unroll2.c
@@ -0,0 +1,261 @@
+// Auto-generated file. Do not edit!
+// Template: src/f16-spmm/neonfp16arith.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2(
+ uint32_t m,
+ uint32_t n,
+ const void*restrict input,
+ const void*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ void*restrict output,
+ const struct xnn_f16_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const __fp16*restrict a = input;
+ __fp16*restrict c = output;
+
+  const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+  const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+  const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+
+ size_t i = m;
+ while XNN_LIKELY(i >= 32) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567x0 = vld1q_dup_f16(w); w += 1;
+ float16x8_t vacc01234567x1 = vmovq_n_f16(0.0f);
+ float16x8_t vacc89ABCDEFx0 = vacc01234567x0;
+ float16x8_t vacc89ABCDEFx1 = vmovq_n_f16(0.0f);
+ float16x8_t vaccGHIJKLMNx0 = vacc01234567x0;
+ float16x8_t vaccGHIJKLMNx1 = vmovq_n_f16(0.0f);
+ float16x8_t vaccOPQRSTUVx0 = vacc01234567x0;
+ float16x8_t vaccOPQRSTUVx1 = vmovq_n_f16(0.0f);
+ for (; nnz >= 2; nnz -= 2) {
+ const intptr_t diff0 = dmap[0];
+ const intptr_t diff1 = dmap[1];
+ dmap += 2;
+ const float16x8_t va01234567x0 = vld1q_f16(a);
+ const float16x8_t va89ABCDEFx0 = vld1q_f16(a + 8);
+ const float16x8_t vaGHIJKLMNx0 = vld1q_f16(a + 16);
+ const float16x8_t vaOPQRSTUVx0 = vld1q_f16(a + 24);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff0);
+ const float16x8_t vb0 = vld1q_dup_f16(w); w += 1;
+ vacc01234567x0 = vfmaq_f16(vacc01234567x0, va01234567x0, vb0);
+ vacc89ABCDEFx0 = vfmaq_f16(vacc89ABCDEFx0, va89ABCDEFx0, vb0);
+ vaccGHIJKLMNx0 = vfmaq_f16(vaccGHIJKLMNx0, vaGHIJKLMNx0, vb0);
+ vaccOPQRSTUVx0 = vfmaq_f16(vaccOPQRSTUVx0, vaOPQRSTUVx0, vb0);
+ const float16x8_t va01234567x1 = vld1q_f16(a);
+ const float16x8_t va89ABCDEFx1 = vld1q_f16(a + 8);
+ const float16x8_t vaGHIJKLMNx1 = vld1q_f16(a + 16);
+ const float16x8_t vaOPQRSTUVx1 = vld1q_f16(a + 24);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff1);
+ const float16x8_t vb1 = vld1q_dup_f16(w); w += 1;
+ vacc01234567x1 = vfmaq_f16(vacc01234567x1, va01234567x1, vb1);
+ vacc89ABCDEFx1 = vfmaq_f16(vacc89ABCDEFx1, va89ABCDEFx1, vb1);
+ vaccGHIJKLMNx1 = vfmaq_f16(vaccGHIJKLMNx1, vaGHIJKLMNx1, vb1);
+ vaccOPQRSTUVx1 = vfmaq_f16(vaccOPQRSTUVx1, vaOPQRSTUVx1, vb1);
+ }
+ float16x8_t vacc01234567 = vacc01234567x0;
+ float16x8_t vacc89ABCDEF = vacc89ABCDEFx0;
+ float16x8_t vaccGHIJKLMN = vaccGHIJKLMNx0;
+ float16x8_t vaccOPQRSTUV = vaccOPQRSTUVx0;
+ vacc01234567 = vaddq_f16(vacc01234567, vacc01234567x1);
+ vacc89ABCDEF = vaddq_f16(vacc89ABCDEF, vacc89ABCDEFx1);
+ vaccGHIJKLMN = vaddq_f16(vaccGHIJKLMN, vaccGHIJKLMNx1);
+ vaccOPQRSTUV = vaddq_f16(vaccOPQRSTUV, vaccOPQRSTUVx1);
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ const float16x8_t va89ABCDEF = vld1q_f16(a + 8);
+ const float16x8_t vaGHIJKLMN = vld1q_f16(a + 16);
+ const float16x8_t vaOPQRSTUV = vld1q_f16(a + 24);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ vacc89ABCDEF = vfmaq_f16(vacc89ABCDEF, va89ABCDEF, vb);
+ vaccGHIJKLMN = vfmaq_f16(vaccGHIJKLMN, vaGHIJKLMN, vb);
+ vaccOPQRSTUV = vfmaq_f16(vaccOPQRSTUV, vaOPQRSTUV, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vmulq_f16(vacc01234567, vscale);
+ float16x8_t vout89ABCDEF = vmulq_f16(vacc89ABCDEF, vscale);
+ float16x8_t voutGHIJKLMN = vmulq_f16(vaccGHIJKLMN, vscale);
+ float16x8_t voutOPQRSTUV = vmulq_f16(vaccOPQRSTUV, vscale);
+ vout01234567 = vminq_f16(vout01234567, vmax);
+ vout89ABCDEF = vminq_f16(vout89ABCDEF, vmax);
+ voutGHIJKLMN = vminq_f16(voutGHIJKLMN, vmax);
+ voutOPQRSTUV = vminq_f16(voutOPQRSTUV, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vout89ABCDEF = vmaxq_f16(vout89ABCDEF, vmin);
+ voutGHIJKLMN = vmaxq_f16(voutGHIJKLMN, vmin);
+ voutOPQRSTUV = vmaxq_f16(voutOPQRSTUV, vmin);
+ vst1q_f16(c, vout01234567);
+ vst1q_f16(c + 8, vout89ABCDEF);
+ vst1q_f16(c + 16, voutGHIJKLMN);
+ vst1q_f16(c + 24, voutOPQRSTUV);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 32;
+ a += 32;
+ i -= 32;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ if (i & 16) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ float16x8_t vacc89ABCDEF = vacc01234567;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ const float16x8_t va89ABCDEF = vld1q_f16(a + 8);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ vacc89ABCDEF = vfmaq_f16(vacc89ABCDEF, va89ABCDEF, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vminq_f16(vacc01234567, vmax);
+ float16x8_t vout89ABCDEF = vminq_f16(vacc89ABCDEF, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vout89ABCDEF = vmaxq_f16(vout89ABCDEF, vmin);
+ vst1q_f16(c, vout01234567);
+ vst1q_f16(c + 8, vout89ABCDEF);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 16;
+ a += 16;
+ }
+ if (i & 8) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vminq_f16(vacc01234567, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vst1q_f16(c, vout01234567);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 8;
+ a += 8;
+ }
+ if (i & 4) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0123 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0123 = vld1_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0123 = vfma_f16(vacc0123, va0123, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0123 = vmin_f16(vacc0123, vget_low_f16(vmax));
+ vout0123 = vmax_f16(vout0123, vget_low_f16(vmin));
+ vst1_f16(c, vout0123);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 4;
+ a += 4;
+ }
+ if (i & 2) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc01 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+            const float16x4_t va01 = vreinterpret_f16_f32(vld1_dup_f32(__builtin_assume_aligned(a, 1)));
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc01 = vfma_f16(vacc01, va01, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout01 = vmin_f16(vacc01, vget_low_f16(vmax));
+ vout01 = vmax_f16(vout01, vget_low_f16(vmin));
+        vst1_lane_f32(__builtin_assume_aligned(c, 1), vreinterpret_f32_f16(vout01), 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 2;
+ a += 2;
+ }
+ if (i & 1) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0 = vld1_dup_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0 = vfma_f16(vacc0, va0, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0 = vmin_f16(vacc0, vget_low_f16(vmax));
+ vout0 = vmax_f16(vout0, vget_low_f16(vmin));
+ vst1_lane_f16(c, vout0, 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f16-spmm/gen/32x1-neonfp16arith.c b/src/f16-spmm/gen/32x1-neonfp16arith.c
new file mode 100644
index 0000000..664765d
--- /dev/null
+++ b/src/f16-spmm/gen/32x1-neonfp16arith.c
@@ -0,0 +1,224 @@
+// Auto-generated file. Do not edit!
+// Template: src/f16-spmm/neonfp16arith.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f16_spmm_ukernel_32x1__neonfp16arith(
+ uint32_t m,
+ uint32_t n,
+ const void*restrict input,
+ const void*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ void*restrict output,
+ const struct xnn_f16_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const __fp16*restrict a = input;
+ __fp16*restrict c = output;
+
+  const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+  const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+  const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+
+ size_t i = m;
+ while XNN_LIKELY(i >= 32) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ float16x8_t vacc89ABCDEF = vacc01234567;
+ float16x8_t vaccGHIJKLMN = vacc01234567;
+ float16x8_t vaccOPQRSTUV = vacc01234567;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ const float16x8_t va89ABCDEF = vld1q_f16(a + 8);
+ const float16x8_t vaGHIJKLMN = vld1q_f16(a + 16);
+ const float16x8_t vaOPQRSTUV = vld1q_f16(a + 24);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ vacc89ABCDEF = vfmaq_f16(vacc89ABCDEF, va89ABCDEF, vb);
+ vaccGHIJKLMN = vfmaq_f16(vaccGHIJKLMN, vaGHIJKLMN, vb);
+ vaccOPQRSTUV = vfmaq_f16(vaccOPQRSTUV, vaOPQRSTUV, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vmulq_f16(vacc01234567, vscale);
+ float16x8_t vout89ABCDEF = vmulq_f16(vacc89ABCDEF, vscale);
+ float16x8_t voutGHIJKLMN = vmulq_f16(vaccGHIJKLMN, vscale);
+ float16x8_t voutOPQRSTUV = vmulq_f16(vaccOPQRSTUV, vscale);
+ vout01234567 = vminq_f16(vout01234567, vmax);
+ vout89ABCDEF = vminq_f16(vout89ABCDEF, vmax);
+ voutGHIJKLMN = vminq_f16(voutGHIJKLMN, vmax);
+ voutOPQRSTUV = vminq_f16(voutOPQRSTUV, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vout89ABCDEF = vmaxq_f16(vout89ABCDEF, vmin);
+ voutGHIJKLMN = vmaxq_f16(voutGHIJKLMN, vmin);
+ voutOPQRSTUV = vmaxq_f16(voutOPQRSTUV, vmin);
+ vst1q_f16(c, vout01234567);
+ vst1q_f16(c + 8, vout89ABCDEF);
+ vst1q_f16(c + 16, voutGHIJKLMN);
+ vst1q_f16(c + 24, voutOPQRSTUV);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 32;
+ a += 32;
+ i -= 32;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ if (i & 16) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ float16x8_t vacc89ABCDEF = vacc01234567;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ const float16x8_t va89ABCDEF = vld1q_f16(a + 8);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ vacc89ABCDEF = vfmaq_f16(vacc89ABCDEF, va89ABCDEF, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vminq_f16(vacc01234567, vmax);
+ float16x8_t vout89ABCDEF = vminq_f16(vacc89ABCDEF, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vout89ABCDEF = vmaxq_f16(vout89ABCDEF, vmin);
+ vst1q_f16(c, vout01234567);
+ vst1q_f16(c + 8, vout89ABCDEF);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 16;
+ a += 16;
+ }
+ if (i & 8) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vminq_f16(vacc01234567, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vst1q_f16(c, vout01234567);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 8;
+ a += 8;
+ }
+ if (i & 4) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0123 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0123 = vld1_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0123 = vfma_f16(vacc0123, va0123, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0123 = vmin_f16(vacc0123, vget_low_f16(vmax));
+ vout0123 = vmax_f16(vout0123, vget_low_f16(vmin));
+ vst1_f16(c, vout0123);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 4;
+ a += 4;
+ }
+ if (i & 2) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc01 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+            const float16x4_t va01 = vreinterpret_f16_f32(vld1_dup_f32(__builtin_assume_aligned(a, 1)));
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc01 = vfma_f16(vacc01, va01, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout01 = vmin_f16(vacc01, vget_low_f16(vmax));
+ vout01 = vmax_f16(vout01, vget_low_f16(vmin));
+        vst1_lane_f32(__builtin_assume_aligned(c, 1), vreinterpret_f32_f16(vout01), 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 2;
+ a += 2;
+ }
+ if (i & 1) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0 = vld1_dup_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0 = vfma_f16(vacc0, va0, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0 = vmin_f16(vacc0, vget_low_f16(vmax));
+ vout0 = vmax_f16(vout0, vget_low_f16(vmin));
+ vst1_lane_f16(c, vout0, 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f16-spmm/gen/8x1-neonfp16arith-unroll2.c b/src/f16-spmm/gen/8x1-neonfp16arith-unroll2.c
new file mode 100644
index 0000000..7c78a0a
--- /dev/null
+++ b/src/f16-spmm/gen/8x1-neonfp16arith-unroll2.c
@@ -0,0 +1,161 @@
+// Auto-generated file. Do not edit!
+// Template: src/f16-spmm/neonfp16arith.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2(
+ uint32_t m,
+ uint32_t n,
+ const void*restrict input,
+ const void*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ void*restrict output,
+ const struct xnn_f16_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const __fp16*restrict a = input;
+ __fp16*restrict c = output;
+
+  const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+  const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+  const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+
+ size_t i = m;
+ while XNN_LIKELY(i >= 8) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567x0 = vld1q_dup_f16(w); w += 1;
+ float16x8_t vacc01234567x1 = vmovq_n_f16(0.0f);
+ for (; nnz >= 2; nnz -= 2) {
+ const intptr_t diff0 = dmap[0];
+ const intptr_t diff1 = dmap[1];
+ dmap += 2;
+ const float16x8_t va01234567x0 = vld1q_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff0);
+ const float16x8_t vb0 = vld1q_dup_f16(w); w += 1;
+ vacc01234567x0 = vfmaq_f16(vacc01234567x0, va01234567x0, vb0);
+ const float16x8_t va01234567x1 = vld1q_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff1);
+ const float16x8_t vb1 = vld1q_dup_f16(w); w += 1;
+ vacc01234567x1 = vfmaq_f16(vacc01234567x1, va01234567x1, vb1);
+ }
+ float16x8_t vacc01234567 = vacc01234567x0;
+ vacc01234567 = vaddq_f16(vacc01234567, vacc01234567x1);
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vmulq_f16(vacc01234567, vscale);
+ vout01234567 = vminq_f16(vout01234567, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vst1q_f16(c, vout01234567);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 8;
+ a += 8;
+ i -= 8;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ if (i & 4) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0123 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0123 = vld1_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0123 = vfma_f16(vacc0123, va0123, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0123 = vmin_f16(vacc0123, vget_low_f16(vmax));
+ vout0123 = vmax_f16(vout0123, vget_low_f16(vmin));
+ vst1_f16(c, vout0123);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 4;
+ a += 4;
+ }
+ if (i & 2) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc01 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va01 = vreinterpret_f16_f32(vld1_dup_f32(__builtin_assume_aligned(a, 1)));
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc01 = vfma_f16(vacc01, va01, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout01 = vmin_f16(vacc01, vget_low_f16(vmax));
+ vout01 = vmax_f16(vout01, vget_low_f16(vmin));
+ vst1_lane_f32(__builtin_assume_aligned(c, 1), vreinterpret_f32_f16(vout01), 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 2;
+ a += 2;
+ }
+ if (i & 1) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0 = vld1_dup_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0 = vfma_f16(vacc0, va0, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0 = vmin_f16(vacc0, vget_low_f16(vmax));
+ vout0 = vmax_f16(vout0, vget_low_f16(vmin));
+ vst1_lane_f16(c, vout0, 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f16-spmm/gen/8x1-neonfp16arith.c b/src/f16-spmm/gen/8x1-neonfp16arith.c
new file mode 100644
index 0000000..9cf730b
--- /dev/null
+++ b/src/f16-spmm/gen/8x1-neonfp16arith.c
@@ -0,0 +1,145 @@
+// Auto-generated file. Do not edit!
+// Template: src/f16-spmm/neonfp16arith.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f16_spmm_ukernel_8x1__neonfp16arith(
+ uint32_t m,
+ uint32_t n,
+ const void*restrict input,
+ const void*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ void*restrict output,
+ const struct xnn_f16_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const __fp16*restrict a = input;
+ __fp16*restrict c = output;
+
+ const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+ const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+ const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+
+ size_t i = m;
+ while XNN_LIKELY(i >= 8) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
+ } while (--nnz != 0);
+ }
+ float16x8_t vout01234567 = vmulq_f16(vacc01234567, vscale);
+ vout01234567 = vminq_f16(vout01234567, vmax);
+ vout01234567 = vmaxq_f16(vout01234567, vmin);
+ vst1q_f16(c, vout01234567);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 8;
+ a += 8;
+ i -= 8;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ if (i & 4) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0123 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0123 = vld1_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0123 = vfma_f16(vacc0123, va0123, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0123 = vmin_f16(vacc0123, vget_low_f16(vmax));
+ vout0123 = vmax_f16(vout0123, vget_low_f16(vmin));
+ vst1_f16(c, vout0123);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 4;
+ a += 4;
+ }
+ if (i & 2) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc01 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va01 = vreinterpret_f16_f32(vld1_dup_f32(__builtin_assume_aligned(a, 1)));
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc01 = vfma_f16(vacc01, va01, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout01 = vmin_f16(vacc01, vget_low_f16(vmax));
+ vout01 = vmax_f16(vout01, vget_low_f16(vmin));
+ vst1_lane_f32(__builtin_assume_aligned(c, 1), vreinterpret_f32_f16(vout01), 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 2;
+ a += 2;
+ }
+ if (i & 1) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float16x4_t vacc0 = vld1_dup_f16(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x4_t va0 = vld1_dup_f16(a);
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ vacc0 = vfma_f16(vacc0, va0, vb);
+ } while (--nnz != 0);
+ }
+ float16x4_t vout0 = vmin_f16(vacc0, vget_low_f16(vmax));
+ vout0 = vmax_f16(vout0, vget_low_f16(vmin));
+ vst1_lane_f16(c, vout0, 0);
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
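
For readers unfamiliar with the packed sparse format these micro-kernels consume, a hedged scalar sketch of the same traversal follows (plain float instead of __fp16, and the scale/clamp step driven by the output params omitted). Per output channel, weights holds a bias followed by that channel's non-zero values; nidx_nnzmap holds the non-zero count per output channel; widx_dmap holds, per non-zero, the byte increment that moves the input pointer to the input channel of the next non-zero; input and output are laid out M-major (all M values of one channel contiguous). The function name is made up for illustration and is not XNNPACK API:

#include <stddef.h>
#include <stdint.h>

// Illustrative scalar reference of the SpMM traversal; not part of XNNPACK.
static void spmm_scalar_sketch(
    size_t m, size_t n,
    const float* a,          // points at the input channel of the first non-zero weight
    const float* w,          // per output channel: bias, then that channel's non-zero values
    const int32_t* dmap,     // per non-zero: byte offset to the next non-zero's input channel
    const uint32_t* nmap,    // per output channel: number of non-zero weights
    float* c)                // m x n output, column-major
{
  for (size_t j = 0; j < n; j++) {
    const float bias = *w++;
    for (size_t i = 0; i < m; i++) {
      c[j * m + i] = bias;  // initialize the whole output column with the bias
    }
    for (uint32_t nnz = nmap[j]; nnz != 0; nnz--) {
      const float b = *w++;
      for (size_t i = 0; i < m; i++) {
        c[j * m + i] += a[i] * b;  // rank-1 update for this non-zero weight
      }
      // Advance to the input channel of the next non-zero (byte offset, may be negative).
      a = (const float*) ((uintptr_t) a + (uintptr_t) *dmap++);
    }
  }
}

The final dmap entry wraps the input pointer back to its starting channel, which is why the vectorized kernels can simply move on to the next block of MR rows and re-walk the same w/dmap/nnzmap streams.
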
diff --git a/src/f16-spmm/neonfp16arith.c.in b/src/f16-spmm/neonfp16arith.c.in
new file mode 100644
index 0000000..966b647
--- /dev/null
+++ b/src/f16-spmm/neonfp16arith.c.in
@@ -0,0 +1,165 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert MR % 8 == 0
+$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f16_spmm_ukernel_${MR}x${NR}__neonfp16arith${"_unroll" + str(UNROLL) if UNROLL > 1 else ""}(
+ uint32_t m,
+ uint32_t n,
+ const void*restrict input,
+ const void*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ void*restrict output,
+ const struct xnn_f16_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const __fp16*restrict a = input;
+ __fp16*restrict c = output;
+
+ const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+ const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+ const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+
+ size_t i = m;
+ while XNN_LIKELY(i >= ${MR}) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ $if UNROLL > 1:
+ float16x8_t vacc01234567x0 = vld1q_dup_f16(w); w += 1;
+ $for K in range(1, UNROLL):
+ float16x8_t vacc01234567x${K} = vmovq_n_f16(0.0f);
+ $for M in range(8, MR, 8):
+ float16x8_t vacc${ABC[M:M+8]}x0 = vacc01234567x0;
+ $for K in range(1, UNROLL):
+ float16x8_t vacc${ABC[M:M+8]}x${K} = vmovq_n_f16(0.0f);
+ for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
+ $for K in range(UNROLL):
+ const intptr_t diff${K} = dmap[${K}];
+ dmap += ${UNROLL};
+ $for K in range(UNROLL):
+ const float16x8_t va01234567x${K} = vld1q_f16(a);
+ $for M in range(8, MR, 8):
+ const float16x8_t va${ABC[M:M+8]}x${K} = vld1q_f16(a + ${M});
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff${K});
+ const float16x8_t vb${K} = vld1q_dup_f16(w); w += 1;
+ $for M in range(0, MR, 8):
+ vacc${ABC[M:M+8]}x${K} = vfmaq_f16(vacc${ABC[M:M+8]}x${K}, va${ABC[M:M+8]}x${K}, vb${K});
+ }
+ $for M in range(0, MR, 8):
+ float16x8_t vacc${ABC[M:M+8]} = vacc${ABC[M:M+8]}x0;
+ $for K in range(1, UNROLL):
+ $for M in range(0, MR, 8):
+ vacc${ABC[M:M+8]} = vaddq_f16(vacc${ABC[M:M+8]}, vacc${ABC[M:M+8]}x${K});
+ $else:
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ $for M in range(8, MR, 8):
+ float16x8_t vacc${ABC[M:M+8]} = vacc01234567;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float16x8_t va01234567 = vld1q_f16(a);
+ $for M in range(8, MR, 8):
+ const float16x8_t va${ABC[M:M+8]} = vld1q_f16(a + ${M});
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ $for M in range(0, MR, 8):
+ vacc${ABC[M:M+8]} = vfmaq_f16(vacc${ABC[M:M+8]}, va${ABC[M:M+8]}, vb);
+ } while (--nnz != 0);
+ }
+ $for M in range(0, MR, 8):
+ float16x8_t vout${ABC[M:M+8]} = vmulq_f16(vacc${ABC[M:M+8]}, vscale);
+ $for M in range(0, MR, 8):
+ vout${ABC[M:M+8]} = vminq_f16(vout${ABC[M:M+8]}, vmax);
+ $for M in range(0, MR, 8):
+ vout${ABC[M:M+8]} = vmaxq_f16(vout${ABC[M:M+8]}, vmin);
+ vst1q_f16(c, vout01234567);
+ $for M in range(8, MR, 8):
+ vst1q_f16(c + ${M}, vout${ABC[M:M+8]});
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += ${MR};
+ a += ${MR};
+ i -= ${MR};
+ }
+ if XNN_UNLIKELY(i != 0) {
+ $for LOG2M in reversed(range((MR - 1).bit_length())):
+ $SUBMR = 1 << LOG2M
+ if (i & ${SUBMR}) {
+ const __fp16*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ $if SUBMR <= 4:
+ float16x4_t vacc${ABC[0:SUBMR]} = vld1_dup_f16(w); w += 1;
+ $else:
+ float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
+ $for M in range(8, SUBMR, 8):
+ float16x8_t vacc${ABC[M:M+8]} = vacc01234567;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ $if SUBMR == 1:
+ const float16x4_t va0 = vld1_dup_f16(a);
+ $elif SUBMR == 2:
+ const float16x4_t va01 = vreinterpret_f16_f32(vld1_dup_f32(__builtin_assume_aligned(a, 1)));
+ $elif SUBMR == 4:
+ const float16x4_t va0123 = vld1_f16(a);
+ $else:
+ const float16x8_t va01234567 = vld1q_f16(a);
+ $for M in range(8, SUBMR, 8):
+ const float16x8_t va${ABC[M:M+8]} = vld1q_f16(a + ${M});
+ a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ $if SUBMR <= 4:
+ const float16x4_t vb = vld1_dup_f16(w); w += 1;
+ $else:
+ const float16x8_t vb = vld1q_dup_f16(w); w += 1;
+ $if SUBMR <= 4:
+ vacc${ABC[0:SUBMR]} = vfma_f16(vacc${ABC[0:SUBMR]}, va${ABC[0:SUBMR]}, vb);
+ $else:
+ $for M in range(0, SUBMR, 8):
+ vacc${ABC[M:M+8]} = vfmaq_f16(vacc${ABC[M:M+8]}, va${ABC[M:M+8]}, vb);
+ } while (--nnz != 0);
+ }
+ $if SUBMR <= 4:
+ float16x4_t vout${ABC[0:SUBMR]} = vmin_f16(vacc${ABC[0:SUBMR]}, vget_low_f16(vmax));
+ vout${ABC[0:SUBMR]} = vmax_f16(vout${ABC[0:SUBMR]}, vget_low_f16(vmin));
+ $if SUBMR == 1:
+ vst1_lane_f16(c, vout${ABC[0]}, 0);
+ $elif SUBMR == 2:
+ vst1_lane_f32(__builtin_assume_aligned(c, 1), vreinterpret_f32_f16(vout${ABC[0:SUBMR]}), 0);
+ $else:
+ vst1_f16(c, vout${ABC[0:SUBMR]});
+ $else:
+ $for M in range(0, SUBMR, 8):
+ float16x8_t vout${ABC[M:M+8]} = vminq_f16(vacc${ABC[M:M+8]}, vmax);
+ $for M in range(0, SUBMR, 8):
+ vout${ABC[M:M+8]} = vmaxq_f16(vout${ABC[M:M+8]}, vmin);
+ vst1q_f16(c, vout01234567);
+ $for M in range(8, SUBMR, 8):
+ vst1q_f16(c + ${M}, vout${ABC[M:M+8]});
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += ${SUBMR};
+ a += ${SUBMR};
+ }
+ }
+}
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
index 8259291..f3ec735 100644
--- a/src/xnnpack/params.h
+++ b/src/xnnpack/params.h
@@ -562,6 +562,16 @@
void* c,
const void* params);
+typedef void (*xnn_f16_spmm_ukernel_function)(
+ uint32_t m,
+ uint32_t n,
+ const void* a,
+ const void* w,
+ const int32_t* dmap,
+ const uint32_t* nmap,
+ void* c,
+ const struct xnn_f16_output_params* params);
+
typedef void (*xnn_f32_spmm_ukernel_function)(
uint32_t m,
uint32_t n,
diff --git a/src/xnnpack/spmm.h b/src/xnnpack/spmm.h
index c525a70..f49bc4a 100644
--- a/src/xnnpack/spmm.h
+++ b/src/xnnpack/spmm.h
@@ -58,6 +58,25 @@
DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_8x1__scalar_pipelined)
DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_8x1__sse)
+#define DECLARE_F16_SPMM_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ uint32_t m, \
+ uint32_t n, \
+ const void* a, \
+ const void* w, \
+ const int32_t* dmap, \
+ const uint32_t* nmap, \
+ void* c, \
+ const struct xnn_f16_output_params* params);
+
+DECLARE_F16_SPMM_UKERNEL_FUNCTION(xnn_f16_spmm_ukernel_8x1__neonfp16arith)
+DECLARE_F16_SPMM_UKERNEL_FUNCTION(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2)
+DECLARE_F16_SPMM_UKERNEL_FUNCTION(xnn_f16_spmm_ukernel_16x1__neonfp16arith)
+DECLARE_F16_SPMM_UKERNEL_FUNCTION(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2)
+DECLARE_F16_SPMM_UKERNEL_FUNCTION(xnn_f16_spmm_ukernel_24x1__neonfp16arith)
+DECLARE_F16_SPMM_UKERNEL_FUNCTION(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2)
+DECLARE_F16_SPMM_UKERNEL_FUNCTION(xnn_f16_spmm_ukernel_32x1__neonfp16arith)
+DECLARE_F16_SPMM_UKERNEL_FUNCTION(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2)
#ifdef __cplusplus
} // extern "C"
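
As a usage note (not part of this change), the new xnn_f16_spmm_ukernel_function typedef lets callers pick between the variants declared above at runtime. A minimal sketch, assuming NEON FP16 arithmetic support has already been verified; the selector function is hypothetical:

#include <stdbool.h>

#include <xnnpack/params.h>
#include <xnnpack/spmm.h>

// Hypothetical selector, for illustration only.
static xnn_f16_spmm_ukernel_function select_f16_spmm_ukernel(bool prefer_unrolled) {
  return prefer_unrolled
      ? xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2
      : xnn_f16_spmm_ukernel_16x1__neonfp16arith;
}
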
diff --git a/test/f16-spmm.cc b/test/f16-spmm.cc
new file mode 100644
index 0000000..cce89d0
--- /dev/null
+++ b/test/f16-spmm.cc
@@ -0,0 +1,1449 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+//
+// Auto-generated file. Do not edit!
+// Specification: test/f16-spmm.yaml
+// Generator: tools/generate-spmm-test.py
+
+
+#include <gtest/gtest.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/isa-checks.h>
+
+#include <xnnpack/spmm.h>
+#include "spmm-microkernel-tester.h"
+
+
+#if XNN_ARCH_ARM64
+ TEST(F16_SPMM_8X1__NEONFP16ARITH, k_eq_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(8)
+ .n(1)
+ .k(1)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith);
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH, k_gt_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 2; k < 10; k++) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(8)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith);
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH, n_gt_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 2; n < 10; n++) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(8)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH, m_lt_8) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 1; m < 8; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH, m_div_8) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 16; m <= 24; m += 8) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH, m_gt_8) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 9; m < 16; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH, qmin) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(16)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmin(128)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH, qmax) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(16)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmax(128)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH, half_sparse) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(16)
+ .n(n)
+ .k(k)
+ .sparsity(0.5f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH, zero_weights) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(16)
+ .n(n)
+ .k(k)
+ .sparsity(1.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith);
+ }
+ }
+ }
+#endif // XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM64
+ TEST(F16_SPMM_8X1__NEONFP16ARITH_UNROLL2, k_eq_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(8)
+ .n(1)
+ .k(2)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2);
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH_UNROLL2, k_lt_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 1; k < 2; k++) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(8)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2);
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH_UNROLL2, k_gt_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 3; k < 4; k++) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(8)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2);
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH_UNROLL2, k_div_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 4; k <= 20; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(8)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2);
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH_UNROLL2, n_gt_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 2; n < 10; n++) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(8)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH_UNROLL2, m_lt_8) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 1; m < 8; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH_UNROLL2, m_div_8) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 16; m <= 24; m += 8) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH_UNROLL2, m_gt_8) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 9; m < 16; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH_UNROLL2, qmin) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(16)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmin(128)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH_UNROLL2, qmax) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(16)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmax(128)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH_UNROLL2, half_sparse) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(16)
+ .n(n)
+ .k(k)
+ .sparsity(0.5f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_8X1__NEONFP16ARITH_UNROLL2, zero_weights) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(8)
+ .nr(1)
+ .m(16)
+ .n(n)
+ .k(k)
+ .sparsity(1.0f)
+ .Test(xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+#endif // XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM64
+ TEST(F16_SPMM_16X1__NEONFP16ARITH, k_eq_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(16)
+ .n(1)
+ .k(1)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith);
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH, k_gt_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 2; k < 10; k++) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(16)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith);
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH, n_gt_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 2; n < 10; n++) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(16)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH, m_lt_16) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 1; m < 16; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH, m_div_16) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 32; m <= 48; m += 16) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH, m_gt_16) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 17; m < 32; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH, qmin) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(32)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmin(128)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH, qmax) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(32)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmax(128)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH, half_sparse) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(32)
+ .n(n)
+ .k(k)
+ .sparsity(0.5f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH, zero_weights) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(32)
+ .n(n)
+ .k(k)
+ .sparsity(1.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith);
+ }
+ }
+ }
+#endif // XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM64
+ TEST(F16_SPMM_16X1__NEONFP16ARITH_UNROLL2, k_eq_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(16)
+ .n(1)
+ .k(2)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2);
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH_UNROLL2, k_lt_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 1; k < 2; k++) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(16)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2);
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH_UNROLL2, k_gt_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 3; k < 4; k++) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(16)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2);
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH_UNROLL2, k_div_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 4; k <= 20; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(16)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2);
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH_UNROLL2, n_gt_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 2; n < 10; n++) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(16)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH_UNROLL2, m_lt_16) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 1; m < 16; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH_UNROLL2, m_div_16) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 32; m <= 48; m += 16) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH_UNROLL2, m_gt_16) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 17; m < 32; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH_UNROLL2, qmin) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(32)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmin(128)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH_UNROLL2, qmax) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(32)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmax(128)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH_UNROLL2, half_sparse) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(32)
+ .n(n)
+ .k(k)
+ .sparsity(0.5f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_16X1__NEONFP16ARITH_UNROLL2, zero_weights) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(16)
+ .nr(1)
+ .m(32)
+ .n(n)
+ .k(k)
+ .sparsity(1.0f)
+ .Test(xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+#endif // XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM64
+ TEST(F16_SPMM_24X1__NEONFP16ARITH, k_eq_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(24)
+ .n(1)
+ .k(1)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith);
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH, k_gt_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 2; k < 10; k++) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(24)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith);
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH, n_gt_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 2; n < 10; n++) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(24)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH, m_lt_24) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 1; m < 24; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH, m_div_24) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 48; m <= 72; m += 24) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH, m_gt_24) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 25; m < 48; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH, qmin) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(48)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmin(128)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH, qmax) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(48)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmax(128)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH, half_sparse) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(48)
+ .n(n)
+ .k(k)
+ .sparsity(0.5f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH, zero_weights) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(48)
+ .n(n)
+ .k(k)
+ .sparsity(1.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith);
+ }
+ }
+ }
+#endif // XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM64
+ TEST(F16_SPMM_24X1__NEONFP16ARITH_UNROLL2, k_eq_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(24)
+ .n(1)
+ .k(2)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2);
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH_UNROLL2, k_lt_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 1; k < 2; k++) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(24)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2);
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH_UNROLL2, k_gt_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 3; k < 4; k++) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(24)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2);
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH_UNROLL2, k_div_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 4; k <= 20; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(24)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2);
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH_UNROLL2, n_gt_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 2; n < 10; n++) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(24)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH_UNROLL2, m_lt_24) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 1; m < 24; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH_UNROLL2, m_div_24) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 48; m <= 72; m += 24) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH_UNROLL2, m_gt_24) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 25; m < 48; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH_UNROLL2, qmin) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(48)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmin(128)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH_UNROLL2, qmax) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(48)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmax(128)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH_UNROLL2, half_sparse) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(48)
+ .n(n)
+ .k(k)
+ .sparsity(0.5f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_24X1__NEONFP16ARITH_UNROLL2, zero_weights) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(24)
+ .nr(1)
+ .m(48)
+ .n(n)
+ .k(k)
+ .sparsity(1.0f)
+ .Test(xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+#endif // XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM64
+ TEST(F16_SPMM_32X1__NEONFP16ARITH, k_eq_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(32)
+ .n(1)
+ .k(1)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith);
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH, k_gt_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 2; k < 10; k++) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(32)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith);
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH, n_gt_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 2; n < 10; n++) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(32)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH, m_lt_32) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 1; m < 32; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH, m_div_32) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 64; m <= 96; m += 32) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH, m_gt_32) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 33; m < 64; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH, qmin) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(64)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmin(128)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH, qmax) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(64)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmax(128)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH, half_sparse) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(64)
+ .n(n)
+ .k(k)
+ .sparsity(0.5f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH, zero_weights) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 5; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(64)
+ .n(n)
+ .k(k)
+ .sparsity(1.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith);
+ }
+ }
+ }
+#endif // XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM64
+ TEST(F16_SPMM_32X1__NEONFP16ARITH_UNROLL2, k_eq_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(32)
+ .n(1)
+ .k(2)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2);
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH_UNROLL2, k_lt_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 1; k < 2; k++) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(32)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2);
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH_UNROLL2, k_gt_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 3; k < 4; k++) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(32)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2);
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH_UNROLL2, k_div_2) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (size_t k = 4; k <= 20; k += 2) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(32)
+ .n(1)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2);
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH_UNROLL2, n_gt_1) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 2; n < 10; n++) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(32)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH_UNROLL2, m_lt_32) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 1; m < 32; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH_UNROLL2, m_div_32) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 64; m <= 96; m += 32) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH_UNROLL2, m_gt_32) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t m = 33; m < 64; m++) {
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH_UNROLL2, qmin) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(64)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmin(128)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH_UNROLL2, qmax) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(64)
+ .n(n)
+ .k(k)
+ .sparsity(0.0f)
+ .qmax(128)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH_UNROLL2, half_sparse) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(64)
+ .n(n)
+ .k(k)
+ .sparsity(0.5f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+
+ TEST(F16_SPMM_32X1__NEONFP16ARITH_UNROLL2, zero_weights) {
+ TEST_REQUIRES_ARM_NEON_FP16_ARITH;
+ for (uint32_t n = 1; n < 10; n += 2) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ SpMMMicrokernelTester()
+ .mr(32)
+ .nr(1)
+ .m(64)
+ .n(n)
+ .k(k)
+ .sparsity(1.0f)
+ .Test(xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2);
+ }
+ }
+ }
+#endif // XNN_ARCH_ARM64
diff --git a/test/f16-spmm.yaml b/test/f16-spmm.yaml
new file mode 100644
index 0000000..4956c02
--- /dev/null
+++ b/test/f16-spmm.yaml
@@ -0,0 +1,36 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+- name: xnn_f16_spmm_ukernel_8x1__neonfp16arith
+ k-block: 1
+ arch:
+ - aarch64
+- name: xnn_f16_spmm_ukernel_8x1__neonfp16arith_unroll2
+ k-block: 2
+ arch:
+ - aarch64
+- name: xnn_f16_spmm_ukernel_16x1__neonfp16arith
+ k-block: 1
+ arch:
+ - aarch64
+- name: xnn_f16_spmm_ukernel_16x1__neonfp16arith_unroll2
+ k-block: 2
+ arch:
+ - aarch64
+- name: xnn_f16_spmm_ukernel_24x1__neonfp16arith
+ k-block: 1
+ arch:
+ - aarch64
+- name: xnn_f16_spmm_ukernel_24x1__neonfp16arith_unroll2
+ k-block: 2
+ arch:
+ - aarch64
+- name: xnn_f16_spmm_ukernel_32x1__neonfp16arith
+ k-block: 1
+ arch:
+ - aarch64
+- name: xnn_f16_spmm_ukernel_32x1__neonfp16arith_unroll2
+ k-block: 2
+ arch:
+ - aarch64
diff --git a/test/spmm-microkernel-tester.h b/test/spmm-microkernel-tester.h
index 20d6224..bcffad6 100644
--- a/test/spmm-microkernel-tester.h
+++ b/test/spmm-microkernel-tester.h
@@ -22,6 +22,12 @@
#include <xnnpack/params.h>
+static inline bool is_fp16_zero(uint16_t x) {
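+ // Doubling the bit pattern shifts the sign bit out of the low 16 bits, so both
+ // +0.0 (0x0000) and -0.0 (0x8000) are treated as zero.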
+ const uint32_t ext_x = x;
+ const uint32_t two_x = ext_x + ext_x;
+ return (uint16_t) two_x == 0;
+}
+
class SpMMMicrokernelTester {
public:
enum class Variant {
@@ -280,6 +286,173 @@
}
}
+ void Test(xnn_f16_spmm_ukernel_function spmm) const {
+ ASSERT_GE(m(), 1);
+ ASSERT_GE(n(), 1);
+ ASSERT_GE(k(), 1);
+
+ std::random_device random_device;
+ auto rng = std::mt19937(random_device());
+ auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
+ auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
+ auto prng = std::bind(std::uniform_real_distribution<float>(), rng);
+
+ std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> a(k() * m());
+ // Think of b as an (n/nr + n % nr) x k matrix; expansion to the full n x k happens later.
+ const size_t ncols = n() / nr() + n() % nr();
+ std::vector<uint16_t> b(ncols * k());
+ std::vector<uint16_t> bias(n());
+ // Number of non-zero weights per N (output channel).
+ std::vector<uint32_t> nmap(n());
+ // Mapping from index of non-zero weight to increment of K (input channel) following this index.
+ std::vector<int32_t> dmap(n() * k());
+ std::vector<uint16_t> w(n() * k() + n());
+ std::vector<uint16_t> c(n() * m());
+ std::vector<float> c_ref(n() * m());
+
+ for (size_t iteration = 0; iteration < iterations(); iteration++) {
+ std::generate(a.begin(), a.end(), std::ref(f16rng));
+ std::generate(b.begin(), b.end(), std::ref(f16rng));
+ std::generate(bias.begin(), bias.end(), std::ref(f16rng));
+ std::fill(c.begin(), c.end(), 0xC000);
+ std::fill(c_ref.begin(), c_ref.end(), 0.0f);
+ std::fill(nmap.begin(), nmap.end(), 0);
+ std::fill(dmap.begin(), dmap.end(), 0);
+ std::fill(w.begin(), w.end(), 0);
+
+ for (uint16_t& b_value : b) {
+ if (prng() <= sparsity()) {
+ b_value = 0;
+ }
+ }
+
+ uint32_t nnz = 0;
+ uint32_t wcnt = 0;
+ size_t last_kk = 0;
+ bool first_nzz = true;
+ size_t first_kk = 0;
+ for (size_t nn = 0; nn < n() / nr(); nn++) {
+ for (size_t i = 0; i < nr(); ++i)
+ w[wcnt++] = bias[nr() * nn + i];
+ for (size_t kk = 0; kk < k(); kk++) {
+ if (!is_fp16_zero(b[nn * k() + kk])) {
+ // Every non-zero actually corresponds to nr adjacent non-zeros.
+ for (size_t i = 0; i < nr(); ++i)
+ w[wcnt++] = fp16_ieee_from_fp32_value(fp16_ieee_to_fp32_value(b[nn * k() + kk]) + static_cast<float>(i));
+ // Skip the very first non-zero weight as we record only the difference.
+ if (first_nzz) {
+ first_kk = kk;
+ } else {
+ const int32_t increment = int32_t(kk - last_kk) * int32_t(m() * sizeof(uint16_t));
+ dmap[nnz++] = increment;
+ }
+ last_kk = kk;
+ first_nzz = false;
+ nmap[nn] += 1;
+ }
+ }
+ }
+
+ // The blocked part of the matrix is now constructed; switch to the
+ // leftover output channels, which are always handled with nr = 1.
+ for (size_t nn = n() / nr(); nn < ncols; nn++) {
+ w[wcnt++] = bias[(n() / nr()) * nr() + (nn - n() / nr())];
+ for (size_t kk = 0; kk < k(); kk++) {
+ if (!is_fp16_zero(b[nn * k() + kk])) {
+ // Every non-zero actually corresponds to nr adjacent non-zeros.
+ w[wcnt++] = b[nn * k() + kk];
+ // Skip the very first non-zero weight as we record only the difference.
+ if (first_nzz) {
+ first_kk = kk;
+ } else {
+ const int32_t increment = int32_t(kk - last_kk) * int32_t(m() * sizeof(uint16_t));
+ dmap[nnz++] = increment;
+ }
+ last_kk = kk;
+ first_nzz = false;
+ nmap[nn] += 1;
+ }
+ }
+ }
+ // In the end, the input pointer must be returned to its initial value.
+ const int64_t increment = int32_t(first_kk - last_kk) * int32_t(m() * sizeof(uint16_t));
+ dmap[nnz++] = increment;
+
+ // Generate the expanded b used in the reference calculation.
+ // Wherever the original has a non-zero, copy it and add nr - 1 adjacent
+ // non-zeros with incremented weight values.
+ std::vector<uint16_t> b_full(n() * k());
+ if (nr() == 1) {
+ b_full = b;
+ }
+ else {
+ for (size_t nn = 0; nn < n() / nr(); nn++) {
+ for (size_t kk = 0; kk < k(); kk++) {
+ if (b[nn * k() + kk] != 0.0f) {
+ for (size_t i = 0; i < nr(); ++i)
+ b_full[nr() * nn * k() + i * k() + kk] = fp16_ieee_from_fp32_value(
+ fp16_ieee_to_fp32_value(b[nn * k() + kk]) + static_cast<float>(i));
+ }
+ }
+ }
+ for (size_t nn = n() / nr(); nn < ncols; nn++) {
+ for (size_t kk = 0; kk < k(); kk++) {
+ if (b[nn * k() + kk] != 0.0f) {
+ b_full[nr() * (n() / nr()) * k() + (nn - n() / nr()) * k() + kk] = b[nn * k() + kk];
+ }
+ }
+ }
+ }
+
+ for (size_t oc = 0; oc < n(); oc++) {
+ for (size_t pxb = 0; pxb < m(); pxb++) {
+ c_ref[oc * m() + pxb] = fp16_ieee_to_fp32_value(bias[oc]);
+ for (size_t ic = 0; ic < k(); ic++) {
+ c_ref[oc * m() + pxb] += fp16_ieee_to_fp32_value(a[ic * m() + pxb]) * fp16_ieee_to_fp32_value(b_full[oc * k() + ic]);
+ }
+ }
+ }
+
+ // The micro-kernel may read one element beyond the end of w and dmap for software pipelining.
+ w.resize(wcnt + 1);
+ dmap.resize(nnz + 1);
+
+ // Compute clamping parameters.
+ const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
+ const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
+ const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
+ const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
+
+ // Clamp reference results.
+ for (float& c_value : c_ref) {
+ c_value = std::min(std::max(c_value, c_min), c_max);
+ }
+
+ // Prepare output parameters.
+ xnn_f16_output_params output_params;
+ output_params.scale = UINT16_C(0x3C00) /* 1.0 */;
+ output_params.max = fp16_ieee_from_fp32_value(c_max);
+ output_params.min = fp16_ieee_from_fp32_value(c_min);
+
+ spmm(m(), n(),
+ a.data() + first_kk * m(), w.data(), dmap.data(), nmap.data(), c.data(),
+ &output_params);
+
+ // Validate micro-kernel outputs.
+ for (size_t pxb = 0; pxb < n(); pxb++) {
+ for (size_t oc = 0; oc < m(); oc++) {
+ ASSERT_NEAR(
+ fp16_ieee_to_fp32_value(c[pxb * m() + oc]),
+ c_ref[pxb * m() + oc],
+ std::abs(c_ref[pxb * m() + oc]) * 1.0e-2f)
+ << "at " << pxb << ", " << oc
+ << ": Mr x Nr x Kr = " << mr() << " x " << nr()
+ << ", M x N x K = " << m() << " x " << n() << " x " << k();
+ }
+ }
+ }
+ }
+
private:
size_t mr_{1};
size_t nr_{1};