Generalize Fully Connected operator creation across data types

PiperOrigin-RevId: 320083317
diff --git a/bench/q8-gemm.cc b/bench/q8-gemm.cc
index 08168a7..64650f0 100644
--- a/bench/q8-gemm.cc
+++ b/bench/q8-gemm.cc
@@ -37,7 +37,7 @@
 
 static void GEMMBenchmark(benchmark::State& state,
   xnn_q8_gemm_ukernel_function q8gemm,
-  size_t mr, size_t nr, size_t kr)
+  size_t mr, size_t nr, size_t kr, size_t sr)
 {
   if (!cpuinfo_initialize()) {
     state.SkipWithError("cpuinfo initialization failed");
@@ -72,7 +72,7 @@
   std::vector<uint8_t, AlignedAllocator<uint8_t, 32>> w(w_elements * num_buffers);
   std::fill(w.begin(), w.end(), 0);
   const xnn_q8_packing_params packing_params = { 127, 127 };
-  xnn_pack_q8_gemm_goi_w(1 /* groups */, nc, kc, nr, kr, k.data(), b.data(), w.data(), &packing_params);
+  xnn_pack_q8_gemm_goi_w(1 /* groups */, nc, kc, nr, kr, sr, k.data(), b.data(), w.data(), &packing_params);
   std::vector<uint8_t> c(c_elements * num_buffers);
   std::fill(c.begin(), c.end(), 0xA5);
 
@@ -298,11 +298,11 @@
 
 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
   static void q8gemm_4x8__neon(benchmark::State& state, const char* net) {
-    GEMMBenchmark(state, xnn_q8_gemm_minmax_ukernel_4x8__neon, 4, 8, 1);
+    GEMMBenchmark(state, xnn_q8_gemm_minmax_ukernel_4x8__neon, 4, 8, 1, 1);
   }
 
   static void q8gemm_8x8__neon(benchmark::State& state, const char* net) {
-    GEMMBenchmark(state, xnn_q8_gemm_minmax_ukernel_8x8__neon, 8, 8, 1);
+    GEMMBenchmark(state, xnn_q8_gemm_minmax_ukernel_8x8__neon, 8, 8, 1, 1);
   }
 
   BENCHMARK_GEMM(q8gemm_4x8__neon)
@@ -311,11 +311,11 @@
 
 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
   static void q8gemm_4x4c2__sse2(benchmark::State& state, const char* net) {
-    GEMMBenchmark(state, xnn_q8_gemm_minmax_ukernel_4x4c2__sse2, 4, 4, 2);
+    GEMMBenchmark(state, xnn_q8_gemm_minmax_ukernel_4x4c2__sse2, 4, 4, 2, 1);
   }
 
   static void q8gemm_2x4c8__sse2(benchmark::State& state, const char* net) {
-    GEMMBenchmark(state, xnn_q8_gemm_minmax_ukernel_2x4c8__sse2, 2, 4, 8);
+    GEMMBenchmark(state, xnn_q8_gemm_minmax_ukernel_2x4c8__sse2, 2, 4, 8, 1);
   }
 
   BENCHMARK_GEMM(q8gemm_4x4c2__sse2)
diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c
index 51fcefc..e5cdd09 100644
--- a/src/operators/convolution-nhwc.c
+++ b/src/operators/convolution-nhwc.c
@@ -311,7 +311,7 @@
         case xnn_ukernel_type_gemm:
           xnn_pack_q8_gemm_goi_w(
               groups, group_output_channels, group_input_channels,
-              nr, kr,
+              nr, kr, 1 /* sr */,
               kernel, bias, convolution_op->packed_weights,
               &packing_params);
           convolution_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c
index 68dec61..050e2c8 100644
--- a/src/operators/fully-connected-nc.c
+++ b/src/operators/fully-connected-nc.c
@@ -23,22 +23,25 @@
 #include <xnnpack/params.h>
 
 
-enum xnn_status xnn_create_fully_connected_nc_q8(
+static enum xnn_status create_fully_connected_nc(
     size_t input_channels,
     size_t output_channels,
     size_t input_stride,
     size_t output_stride,
-    uint8_t input_zero_point,
-    float input_scale,
-    uint8_t kernel_zero_point,
-    float kernel_scale,
-    const uint8_t* kernel,
-    const int32_t* bias,
-    uint8_t output_zero_point,
-    float output_scale,
-    uint8_t output_min,
-    uint8_t output_max,
+    const void* kernel,
+    const void* bias,
     uint32_t flags,
+    uint32_t log2_filter_element_size,
+    uint32_t bias_element_size,
+    xnn_pack_gemm_io_w_function pack_gemm_io_w,
+    xnn_pack_gemm_goi_w_function pack_gemm_goi_w,
+    const void* packing_params,
+    int packed_weights_padding_byte,
+    const void* params,
+    size_t params_size,
+    const struct gemm_parameters* gemm_parameters,
+    const struct gemm_fused_ukernels* gemm_ukernels,
+    enum xnn_operator_type operator_type,
     xnn_operator_t* fully_connected_op_out)
 {
   xnn_operator_t fully_connected_op = NULL;
@@ -46,7 +49,7 @@
 
   if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
     xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8));
+      xnn_operator_type_to_string(operator_type));
     goto error;
   }
 
@@ -55,14 +58,14 @@
   if (input_channels == 0) {
     xnn_log_error(
       "failed to create %s operator with %zu input channels: number of channels must be non-zero",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8), input_channels);
+      xnn_operator_type_to_string(operator_type), input_channels);
     goto error;
   }
 
   if (output_channels == 0) {
     xnn_log_error(
       "failed to create %s operator with %zu output channels: number of channels must be non-zero",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8), output_channels);
+      xnn_operator_type_to_string(operator_type), output_channels);
     goto error;
   }
 
@@ -70,7 +73,7 @@
     xnn_log_error(
       "failed to create %s operator with input element stride of %zu: "
       "stride must be at least as large as the number of input channels (%zu)",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8), input_stride, input_channels);
+      xnn_operator_type_to_string(operator_type), input_stride, input_channels);
     goto error;
   }
 
@@ -78,197 +81,7 @@
     xnn_log_error(
       "failed to create %s operator with output element stride of %zu: "
       "stride must be at least as large as the number of output channels (%zu)",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8), output_stride, output_channels);
-    goto error;
-  }
-
-  if (input_scale <= 0.0f || !isnormal(input_scale)) {
-    xnn_log_error(
-      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8), input_scale);
-    goto error;
-  }
-
-  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
-    xnn_log_error(
-      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8), kernel_scale);
-    goto error;
-  }
-
-  if (output_scale <= 0.0f || !isnormal(output_scale)) {
-    xnn_log_error(
-      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8), output_scale);
-    goto error;
-  }
-
-  if (output_min >= output_max) {
-    xnn_log_error(
-      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8), output_min, output_max);
-    goto error;
-  }
-
-  status = xnn_status_unsupported_parameter;
-
-  const float requantization_scale = input_scale * kernel_scale / output_scale;
-  if (requantization_scale >= 1.0f) {
-    xnn_log_error(
-      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
-      "requantization scale %.7g is greater or equal to 1.0",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8),
-      input_scale, kernel_scale, output_scale, requantization_scale);
-    goto error;
-  }
-
-  status = xnn_status_out_of_memory;
-
-  fully_connected_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
-  if (fully_connected_op == NULL) {
-    xnn_log_error("failed to allocate %zu bytes for %s operator descriptor",
-      sizeof(struct xnn_operator), xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8));
-    goto error;
-  }
-
-  const uint32_t nr = xnn_params.q8.gemm.nr;
-  const uint32_t kr = UINT32_C(1) << xnn_params.q8.gemm.log2_kr;
-
-  const size_t n_stride = round_up(output_channels, nr);
-  const size_t k_stride = round_up_po2(input_channels, kr);
-
-  const size_t packed_weights_size = n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t));
-  fully_connected_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size);
-  if (fully_connected_op->packed_weights == NULL) {
-    xnn_log_error("failed to allocate %zu bytes for %s operator packed weights",
-      packed_weights_size, xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8));
-    goto error;
-  }
-  memset(fully_connected_op->packed_weights, kernel_zero_point, n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t)));
-
-  const struct xnn_q8_packing_params packing_params = {
-    .input_zero_point = input_zero_point,
-    .kernel_zero_point = kernel_zero_point,
-  };
-  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
-    xnn_pack_q8_gemm_io_w(
-      output_channels, input_channels,
-      nr, kr,
-      kernel, bias,
-      fully_connected_op->packed_weights,
-      &packing_params);
-  } else {
-    xnn_pack_q8_gemm_goi_w(
-      1, output_channels, input_channels,
-      nr, kr,
-      kernel, bias,
-      fully_connected_op->packed_weights,
-      &packing_params);
-  }
-
-  fully_connected_op->group_input_channels = input_channels;
-  fully_connected_op->group_output_channels = output_channels;
-  fully_connected_op->input_pixel_stride = input_stride;
-  fully_connected_op->output_pixel_stride = output_stride;
-
-  fully_connected_op->kernel_zero_point = kernel_zero_point;
-
-  fully_connected_op->params.q8_gemm =
-    xnn_init_q8_gemm_params(
-      input_zero_point, kernel_zero_point,
-      requantization_scale, output_zero_point, output_min, output_max);
-
-  fully_connected_op->type = xnn_operator_type_fully_connected_nc_q8;
-
-  fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
-  fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
-    .general_case = xnn_params.q8.gemm.minmax.gemm,
-    .mr = xnn_params.q8.gemm.mr,
-    .nr = nr,
-    .kr = kr,
-  };
-
-  fully_connected_op->state = xnn_run_state_invalid;
-
-  *fully_connected_op_out = fully_connected_op;
-  return xnn_status_success;
-
-error:
-  xnn_delete_operator(fully_connected_op);
-  return status;
-}
-
-enum xnn_status xnn_create_fully_connected_nc_f32(
-    size_t input_channels,
-    size_t output_channels,
-    size_t input_stride,
-    size_t output_stride,
-    const float* kernel,
-    const float* bias,
-    float output_min,
-    float output_max,
-    uint32_t flags,
-    xnn_operator_t* fully_connected_op_out)
-{
-  xnn_operator_t fully_connected_op = NULL;
-  enum xnn_status status = xnn_status_uninitialized;
-
-  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
-    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
-    goto error;
-  }
-
-  status = xnn_status_invalid_parameter;
-
-  if (input_channels == 0) {
-    xnn_log_error(
-      "failed to create %s operator with %zu input channels: number of channels must be non-zero",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32), input_channels);
-    goto error;
-  }
-
-  if (output_channels == 0) {
-    xnn_log_error(
-      "failed to create %s operator with %zu output channels: number of channels must be non-zero",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32), output_channels);
-    goto error;
-  }
-
-  if (input_stride < input_channels) {
-    xnn_log_error(
-      "failed to create %s operator with input element stride of %zu: "
-      "stride must be at least as large as the number of input channels (%zu)",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32), input_stride, input_channels);
-    goto error;
-  }
-
-  if (output_stride < output_channels) {
-    xnn_log_error(
-      "failed to create %s operator with output element stride of %zu: "
-      "stride must be at least as large as the number of output channels (%zu)",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32), output_stride, output_channels);
-    goto error;
-  }
-
-  if (isnan(output_min)) {
-    xnn_log_error(
-      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
-    goto error;
-  }
-
-  if (isnan(output_max)) {
-    xnn_log_error(
-      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
-    goto error;
-  }
-
-  if (output_min >= output_max) {
-    xnn_log_error(
-      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
-      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32), output_min, output_max);
+      xnn_operator_type_to_string(operator_type), output_stride, output_channels);
     goto error;
   }
 
@@ -278,41 +91,41 @@
   if (fully_connected_op == NULL) {
     xnn_log_error(
       "failed to allocate %zu bytes for %s operator descriptor",
-      sizeof(struct xnn_operator), xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
+      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
     goto error;
   }
 
-  const uint32_t nr = xnn_params.f32.gemm.nr;
-  const uint32_t kr = UINT32_C(1) << xnn_params.f32.gemm.log2_kr;
-  const uint32_t sr = UINT32_C(1) << xnn_params.f32.gemm.log2_sr;
+  const uint32_t nr = gemm_parameters->nr;
+  const uint32_t kr = UINT32_C(1) << gemm_parameters->log2_kr;
+  const uint32_t sr = UINT32_C(1) << gemm_parameters->log2_sr;
 
   const size_t n_stride = round_up(output_channels, nr);
   const size_t k_stride = round_up_po2(input_channels, kr);
 
-  const size_t packed_weights_size = n_stride * (k_stride * sizeof(float) + sizeof(float));
+  const size_t packed_weights_size = n_stride * (bias_element_size + (k_stride << log2_filter_element_size));
   fully_connected_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size);
   if (fully_connected_op->packed_weights == NULL) {
     xnn_log_error(
       "failed to allocate %zu bytes for %s operator packed weights",
-      packed_weights_size, xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
+      packed_weights_size, xnn_operator_type_to_string(operator_type));
     goto error;
   }
-  memset(fully_connected_op->packed_weights, 0, n_stride * (k_stride * sizeof(float) + sizeof(float)));
+  memset(fully_connected_op->packed_weights, packed_weights_padding_byte, packed_weights_size);
 
   if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
-    xnn_pack_f32_gemm_io_w(
+    pack_gemm_io_w(
       output_channels, input_channels,
       nr, kr, sr,
       kernel, bias,
       fully_connected_op->packed_weights,
-      NULL);
+      packing_params);
   } else {
-    xnn_pack_f32_gemm_goi_w(
+    pack_gemm_goi_w(
       1, output_channels, input_channels,
       nr, kr, sr,
       kernel, bias,
       fully_connected_op->packed_weights,
-      NULL);
+      packing_params);
   }
 
   fully_connected_op->group_input_channels = input_channels;
@@ -320,21 +133,14 @@
   fully_connected_op->input_pixel_stride = input_stride;
   fully_connected_op->output_pixel_stride = output_stride;
 
-  fully_connected_op->params.f32_minmax = xnn_init_f32_minmax_params(output_min, output_max);
-
-  fully_connected_op->type = xnn_operator_type_fully_connected_nc_f32;
-
-  const struct gemm_fused_ukernels* ukernels = &xnn_params.f32.gemm.minmax;
-  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
-  if (linear_activation && xnn_params.f32.gemm.linear.gemm.function[XNN_UARCH_DEFAULT] != NULL) {
-    ukernels = &xnn_params.f32.gemm.linear;
-  }
+  memcpy(&fully_connected_op->params, params, params_size);
+  fully_connected_op->type = operator_type;
 
   fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
   fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
-    .general_case = ukernels->gemm,
-    .mr1_case = ukernels->gemm1,
-    .mr = xnn_params.f32.gemm.mr,
+    .general_case = gemm_ukernels->gemm,
+    .mr1_case = gemm_ukernels->gemm1,
+    .mr = gemm_parameters->mr,
     .nr = nr,
     .kr = kr,
   };
@@ -359,6 +165,7 @@
   uint32_t bias_element_size,
   uint32_t log2_output_element_size,
   const void* params,
+  size_t params_size,
   size_t num_threads)
 {
   fully_connected_op->state = xnn_run_state_invalid;
@@ -407,7 +214,7 @@
     .log2_csize = log2_output_element_size,
     .ukernel = gemm_ukernel,
   };
-  memcpy(&fully_connected_op->context.gemm.params, params, sizeof(fully_connected_op->context.gemm.params));
+  memcpy(&fully_connected_op->context.gemm.params, params, params_size);
 
   size_t nc = output_channels;
   if (num_threads > 1) {
@@ -429,6 +236,138 @@
   return xnn_status_success;
 }
 
+enum xnn_status xnn_create_fully_connected_nc_q8(
+    size_t input_channels,
+    size_t output_channels,
+    size_t input_stride,
+    size_t output_stride,
+    uint8_t input_zero_point,
+    float input_scale,
+    uint8_t kernel_zero_point,
+    float kernel_scale,
+    const uint8_t* kernel,
+    const int32_t* bias,
+    uint8_t output_zero_point,
+    float output_scale,
+    uint8_t output_min,
+    uint8_t output_max,
+    uint32_t flags,
+    xnn_operator_t* fully_connected_op_out)
+{
+  if (input_scale <= 0.0f || !isnormal(input_scale)) {
+    xnn_log_error(
+      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
+      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8), input_scale);
+    return xnn_status_invalid_parameter;
+  }
+
+  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
+    xnn_log_error(
+      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
+      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8), kernel_scale);
+    return xnn_status_invalid_parameter;
+  }
+
+  if (output_scale <= 0.0f || !isnormal(output_scale)) {
+    xnn_log_error(
+      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
+      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8), output_scale);
+    return xnn_status_invalid_parameter;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
+      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8), output_min, output_max);
+    return xnn_status_invalid_parameter;
+  }
+
+  const float requantization_scale = input_scale * kernel_scale / output_scale;
+  if (requantization_scale >= 1.0f) {
+    xnn_log_error(
+      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
+      "requantization scale %.7g is greater or equal to 1.0",
+      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_q8),
+      input_scale, kernel_scale, output_scale, requantization_scale);
+    return xnn_status_unsupported_parameter;
+  }
+
+  const union xnn_q8_gemm_params params = xnn_init_q8_gemm_params(
+    input_zero_point, kernel_zero_point, requantization_scale, output_zero_point, output_min, output_max);
+  const struct xnn_q8_packing_params packing_params = {
+    .input_zero_point = input_zero_point,
+    .kernel_zero_point = kernel_zero_point,
+  };
+  return create_fully_connected_nc(
+    input_channels, output_channels,
+    input_stride, output_stride,
+    kernel, bias, flags,
+    0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
+    sizeof(int32_t) /* sizeof(bias element) */,
+    (xnn_pack_gemm_io_w_function) xnn_pack_q8_gemm_io_w,
+    (xnn_pack_gemm_goi_w_function) xnn_pack_q8_gemm_goi_w,
+    &packing_params /* packing params */, kernel_zero_point /* packed weights padding byte */,
+    &params, sizeof(params),
+    &xnn_params.q8.gemm, &xnn_params.q8.gemm.minmax,
+    xnn_operator_type_fully_connected_nc_q8,
+    fully_connected_op_out);
+}
+
+enum xnn_status xnn_create_fully_connected_nc_f32(
+    size_t input_channels,
+    size_t output_channels,
+    size_t input_stride,
+    size_t output_stride,
+    const float* kernel,
+    const float* bias,
+    float output_min,
+    float output_max,
+    uint32_t flags,
+    xnn_operator_t* fully_connected_op_out)
+{
+  if (isnan(output_min)) {
+    xnn_log_error(
+      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
+      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
+    return xnn_status_invalid_parameter;
+  }
+
+  if (isnan(output_max)) {
+    xnn_log_error(
+      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
+      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
+    return xnn_status_invalid_parameter;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
+      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32), output_min, output_max);
+    return xnn_status_invalid_parameter;
+  }
+
+  const struct gemm_fused_ukernels* gemm_ukernels = &xnn_params.f32.gemm.minmax;
+  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
+  if (linear_activation && xnn_params.f32.gemm.linear.gemm.function[XNN_UARCH_DEFAULT] != NULL) {
+    gemm_ukernels = &xnn_params.f32.gemm.linear;
+  }
+
+  const union xnn_f32_minmax_params params = xnn_init_f32_minmax_params(output_min, output_max);
+  return create_fully_connected_nc(
+    input_channels, output_channels,
+    input_stride, output_stride,
+    kernel, bias, flags,
+    2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
+    sizeof(float) /* sizeof(bias element) */,
+    (xnn_pack_gemm_io_w_function) xnn_pack_f32_gemm_io_w,
+    (xnn_pack_gemm_goi_w_function) xnn_pack_f32_gemm_goi_w,
+    NULL /* packing params */, 0 /* packed weights padding byte */,
+    &params, sizeof(params),
+    &xnn_params.f32.gemm, gemm_ukernels,
+    xnn_operator_type_fully_connected_nc_f32,
+    fully_connected_op_out);
+}
+
 enum xnn_status xnn_setup_fully_connected_nc_q8(
     xnn_operator_t fully_connected_op,
     size_t batch_size,
@@ -452,6 +391,7 @@
     sizeof(int32_t) /* sizeof(bias element) */,
     0 /* log2(sizeof(output element)) = log2(sizeof(uint8_t)) */,
     &fully_connected_op->params.q8_gemm,
+    sizeof(fully_connected_op->params.q8_gemm),
     pthreadpool_get_threads_count(threadpool));
 }
 
@@ -478,5 +418,6 @@
     sizeof(float) /* sizeof(bias element) */,
     2 /* log2(sizeof(output element)) = log2(sizeof(float)) */,
     &fully_connected_op->params.f32_minmax,
+    sizeof(fully_connected_op->params.f32_minmax),
     pthreadpool_get_threads_count(threadpool));
 }
diff --git a/src/packing.c b/src/packing.c
index 2dd7f1f..0868104 100644
--- a/src/packing.c
+++ b/src/packing.c
@@ -127,11 +127,13 @@
   size_t kc,
   size_t nr,
   size_t kr,
+  size_t sr,
   const uint8_t* k,
   const int32_t* b,
   void* packed_w,
   const struct xnn_q8_packing_params* params)
 {
+  assert(sr == 1);
   const int32_t izp = (int32_t) params->input_zero_point;
   const int32_t boff = (int32_t) kc * izp * (int32_t) params->kernel_zero_point;
   do {
@@ -273,11 +275,13 @@
   size_t kc,
   size_t nr,
   size_t kr,
+  size_t sr,
   const uint8_t* k,
   const int32_t* b,
   void* packed_w,
   const struct xnn_q8_packing_params* params)
 {
+  assert(sr == 1);
   const int32_t izp = (int32_t) params->input_zero_point;
   const int32_t boff = (int32_t) kc * izp * (int32_t) params->kernel_zero_point;
   for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h
index cb20b58..a45eb63 100644
--- a/src/xnnpack/pack.h
+++ b/src/xnnpack/pack.h
@@ -26,6 +26,18 @@
 };
 
 
+typedef void (*xnn_pack_gemm_goi_w_function)(
+  size_t g,
+  size_t nc,
+  size_t kc,
+  size_t nr,
+  size_t kr,
+  size_t sr,
+  const void* k,
+  const void* b,
+  void* packed_w,
+  const void* params);
+
 XNN_INTERNAL void xnn_pack_f32_gemm_goi_w(
   size_t g,
   size_t nc,
@@ -56,12 +68,24 @@
   size_t kc,
   size_t nr,
   size_t kr,
+  size_t sr,
   const uint8_t* k,
   const int32_t* b,
   void* packed_w,
   const struct xnn_q8_packing_params* params);
 
 
+typedef void (*xnn_pack_gemm_io_w_function)(
+  size_t nc,
+  size_t kc,
+  size_t nr,
+  size_t kr,
+  size_t sr,
+  const void* k,
+  const void* b,
+  void* packed_w,
+  const void* params);
+
 XNN_INTERNAL void xnn_pack_f32_gemm_io_w(
   size_t nc,
   size_t kc,
@@ -89,6 +113,7 @@
   size_t kc,
   size_t nr,
   size_t kr,
+  size_t sr,
   const uint8_t* k,
   const int32_t* b,
   void* packed_w,
diff --git a/test/gemm-microkernel-tester.h b/test/gemm-microkernel-tester.h
index b86a1f6..4e35875 100644
--- a/test/gemm-microkernel-tester.h
+++ b/test/gemm-microkernel-tester.h
@@ -243,7 +243,7 @@
         .input_zero_point = a_zero_point(),
         .kernel_zero_point = b_zero_point(),
       };
-      xnn_pack_q8_gemm_goi_w(1, n(), k(), nr(), kr(),
+      xnn_pack_q8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
         b.data(), bias.data(), packed_w.data(), &packing_params);
 
       // Compute 32-bit results and output quantization arguments.