Auto-switch to LINEAR GEMM/IGEMM/DWCONV micro-kernels

- Initialize pointers to LINEAR-activation micro-kernels on WebAssembly
- Automatically detect linear activation (output_min == -inf and output_max == +inf) and use LINEAR-activation micro-kernels when they are available
- Update unit tests to test both MINMAX- and LINEAR-activation micro-kernels

PiperOrigin-RevId: 305792844
diff --git a/src/convolution-nhwc.c b/src/convolution-nhwc.c
index f8c73ce..daf5137 100644
--- a/src/convolution-nhwc.c
+++ b/src/convolution-nhwc.c
@@ -552,6 +552,7 @@
   } else {
     ukernel_type = xnn_ukernel_type_igemm;
   }
+  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
 
   size_t zero_size = 0;
   switch (ukernel_type) {
@@ -600,8 +601,12 @@
           kernel, bias, convolution_op->packed_weights);
       }
 
+      const union dwconv_fused_ukernels* ukernels = &dwconv_parameters->minmax;
+      if (linear_activation && dwconv_parameters->linear.unipass != NULL) {
+        ukernels = &dwconv_parameters->linear;
+      }
       convolution_op->ukernel.dwconv = (struct xnn_ukernel_dwconv) {
-        .unipass_function = dwconv_parameters->minmax.unipass,
+        .unipass_function = ukernels->unipass,
         .primary_tile = dwconv_parameters->primary_tile,
         .incremental_tile = dwconv_parameters->incremental_tile,
       };
@@ -626,6 +631,10 @@
       }
       memset(convolution_op->packed_weights, 0, packed_group_weights_size * groups);
 
+      const struct gemm_fused_ukernels* ukernels = &xnn_params.f32.gemm.minmax;
+      if (linear_activation && xnn_params.f32.gemm.linear.gemm.function[XNN_UARCH_DEFAULT] != NULL) {
+        ukernels = &xnn_params.f32.gemm.linear;
+      }
       switch (ukernel_type) {
         case xnn_ukernel_type_gemm:
           xnn_pack_f32_gemm_goi_w(
@@ -636,8 +645,8 @@
             .mr = xnn_params.f32.gemm.mr,
             .nr = nr,
             .kr = kr,
-            .general_case = xnn_params.f32.gemm.minmax.gemm,
-            .mr1_case = xnn_params.f32.gemm.minmax.gemm1,
+            .general_case = ukernels->gemm,
+            .mr1_case = ukernels->gemm1,
           };
           break;
         case xnn_ukernel_type_igemm:
@@ -656,8 +665,8 @@
             .mr = xnn_params.f32.gemm.mr,
             .nr = nr,
             .kr = kr,
-            .general_case = xnn_params.f32.gemm.minmax.igemm,
-            .mr1_case = xnn_params.f32.gemm.minmax.igemm1,
+            .general_case = ukernels->igemm,
+            .mr1_case = ukernels->igemm1,
           };
           break;
         default:
diff --git a/src/deconvolution-nhwc.c b/src/deconvolution-nhwc.c
index 323b269..4c1087b 100644
--- a/src/deconvolution-nhwc.c
+++ b/src/deconvolution-nhwc.c
@@ -427,23 +427,24 @@
     goto error;
   }
 
-  uint32_t mr = xnn_params.f32.gemm.mr;
-  uint32_t nr = xnn_params.f32.gemm.nr;
-  uint32_t kr = UINT32_C(1) << xnn_params.f32.gemm.log2_kr;
-  uint32_t sr = UINT32_C(1) << xnn_params.f32.gemm.log2_sr;
-  struct xnn_hmp_igemm_ukernel igemm_ukernel = xnn_params.f32.gemm.minmax.igemm;
-  struct xnn_hmp_gemm_ukernel gemm_ukernel = xnn_params.f32.gemm.minmax.gemm;
-  if (nr > group_output_channels) {
+  const struct gemm_parameters* gemm_params = &xnn_params.f32.gemm;
+  if (gemm_params->nr > group_output_channels) {
     // Default micro-kernel is suboptimal. Try to find a better micro-kernel.
     if (xnn_params.f32.gemm2.minmax.igemm.function[XNN_UARCH_DEFAULT] != NULL) {
-      mr = xnn_params.f32.gemm2.mr;
-      nr = xnn_params.f32.gemm2.nr;
-      kr = UINT32_C(1) << xnn_params.f32.gemm2.log2_kr;
-      sr = UINT32_C(1) << xnn_params.f32.gemm2.log2_sr;
-      igemm_ukernel = xnn_params.f32.gemm2.minmax.igemm;
-      gemm_ukernel = xnn_params.f32.gemm2.minmax.gemm;
+      gemm_params = &xnn_params.f32.gemm2;
     }
   }
+  const uint32_t mr = gemm_params->mr;
+  const uint32_t nr = gemm_params->nr;
+  const uint32_t kr = UINT32_C(1) << gemm_params->log2_kr;
+  const uint32_t sr = UINT32_C(1) << gemm_params->log2_sr;
+  const struct gemm_fused_ukernels* ukernels = &gemm_params->minmax;
+  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
+  if (linear_activation && gemm_params->linear.gemm.function[XNN_UARCH_DEFAULT] != NULL) {
+    ukernels = &gemm_params->linear;
+  }
+  struct xnn_hmp_igemm_ukernel igemm_ukernel = ukernels->igemm;
+  struct xnn_hmp_gemm_ukernel gemm_ukernel = ukernels->gemm;
 
   const uint32_t n_stride = round_up(group_output_channels, nr);
   const uint32_t k_stride = round_up_po2(group_input_channels, kr);
diff --git a/src/fully-connected-nc.c b/src/fully-connected-nc.c
index c867af9..7677984 100644
--- a/src/fully-connected-nc.c
+++ b/src/fully-connected-nc.c
@@ -308,10 +308,16 @@
 
   fully_connected_op->type = xnn_operator_type_fully_connected_nc_f32;
 
+  const struct gemm_fused_ukernels* ukernels = &xnn_params.f32.gemm.minmax;
+  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
+  if (linear_activation && xnn_params.f32.gemm.linear.gemm.function[XNN_UARCH_DEFAULT] != NULL) {
+    ukernels = &xnn_params.f32.gemm.linear;
+  }
+
   fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
   fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
-    .general_case = xnn_params.f32.gemm.minmax.gemm,
-    .mr1_case = xnn_params.f32.gemm.minmax.gemm1,
+    .general_case = ukernels->gemm,
+    .mr1_case = ukernels->gemm1,
     .mr = xnn_params.f32.gemm.mr,
     .nr = nr,
     .kr = kr,
diff --git a/src/init.c b/src/init.c
index 7a504da..b195d1e 100644
--- a/src/init.c
+++ b/src/init.c
@@ -1407,6 +1407,10 @@
       xnn_params.f32.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_minmax_ukernel_2x4__scalar);
       xnn_params.f32.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_minmax_ukernel_1x4__wasm);
       xnn_params.f32.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_minmax_ukernel_1x4__wasm);
+      xnn_params.f32.gemm.linear.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_2x4__scalar);
+      xnn_params.f32.gemm.linear.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_2x4__scalar);
+      xnn_params.f32.gemm.linear.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x4__wasm);
+      xnn_params.f32.gemm.linear.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x4__wasm);
       xnn_params.f32.gemm.mr = 2;
       xnn_params.f32.gemm.nr = 4;
     } else {
@@ -1414,23 +1418,32 @@
       xnn_params.f32.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_minmax_ukernel_4x4__wasm);
       xnn_params.f32.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_minmax_ukernel_1x4__wasm);
       xnn_params.f32.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_minmax_ukernel_1x4__wasm);
+      xnn_params.f32.gemm.linear.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x4__wasm);
+      xnn_params.f32.gemm.linear.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x4__wasm);
+      xnn_params.f32.gemm.linear.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x4__wasm);
+      xnn_params.f32.gemm.linear.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x4__wasm);
       xnn_params.f32.gemm.mr = 4;
       xnn_params.f32.gemm.nr = 4;
     }
     xnn_params.f32.gemm2.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_minmax_ukernel_4x2__wasm);
     xnn_params.f32.gemm2.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_minmax_ukernel_4x2__wasm),
+    xnn_params.f32.gemm2.linear.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x2__wasm);
+    xnn_params.f32.gemm2.linear.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2__wasm);
     xnn_params.f32.gemm2.mr = 4;
     xnn_params.f32.gemm2.nr = 2;
 
     xnn_params.f32.dwconv[0].minmax.unipass = (xnn_dwconv_unipass_ukernel_function) xnn_f32_dwconv_minmax_ukernel_up1x4__wasm_acc2;
+    xnn_params.f32.dwconv[0].linear.unipass = (xnn_dwconv_unipass_ukernel_function) xnn_f32_dwconv_ukernel_up1x4__wasm_acc2;
     xnn_params.f32.dwconv[0].channel_tile = 1;
     xnn_params.f32.dwconv[0].primary_tile = 4;
 
     xnn_params.f32.dwconv[1].minmax.unipass = (xnn_dwconv_unipass_ukernel_function) xnn_f32_dwconv_minmax_ukernel_up1x9__wasm_acc2;
+    xnn_params.f32.dwconv[1].linear.unipass = (xnn_dwconv_unipass_ukernel_function) xnn_f32_dwconv_ukernel_up1x9__wasm_acc2;
     xnn_params.f32.dwconv[1].channel_tile = 1;
     xnn_params.f32.dwconv[1].primary_tile = 9;
 
     xnn_params.f32.dwconv[2].minmax.unipass = (xnn_dwconv_unipass_ukernel_function) xnn_f32_dwconv_minmax_ukernel_up1x25__wasm_acc2;
+    xnn_params.f32.dwconv[2].linear.unipass = (xnn_dwconv_unipass_ukernel_function) xnn_f32_dwconv_ukernel_up1x25__wasm_acc2;
     xnn_params.f32.dwconv[2].channel_tile = 1;
     xnn_params.f32.dwconv[2].primary_tile = 25;
 
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
index d9ceaf6..caff185 100644
--- a/src/xnnpack/params.h
+++ b/src/xnnpack/params.h
@@ -1293,6 +1293,7 @@
 
 struct gemm_parameters {
   struct gemm_fused_ukernels minmax;
+  struct gemm_fused_ukernels linear;
   uint8_t mr;
   uint8_t nr;
   uint8_t log2_kr;
@@ -1355,6 +1356,7 @@
 
 struct dwconv_parameters {
   union dwconv_fused_ukernels minmax;
+  union dwconv_fused_ukernels linear;
   uint8_t channel_tile;
   uint8_t primary_tile;
   uint8_t incremental_tile;
diff --git a/test/convolution-operator-tester.h b/test/convolution-operator-tester.h
index f6a092d..70de774 100644
--- a/test/convolution-operator-tester.h
+++ b/test/convolution-operator-tester.h
@@ -957,8 +957,10 @@
       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
 
-      const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
-      const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
+      const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() :
+        accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
+      const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() :
+        accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
 
       // Clamp reference results.
       for (float& value : output_ref) {
diff --git a/test/deconvolution-operator-tester.h b/test/deconvolution-operator-tester.h
index 347baa9..08f48e8 100644
--- a/test/deconvolution-operator-tester.h
+++ b/test/deconvolution-operator-tester.h
@@ -633,8 +633,10 @@
       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
 
-      const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
-      const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
+      const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() :
+        accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
+      const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() :
+        accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
 
       // Clamp reference results.
       for (float& value : output_ref) {
diff --git a/test/fully-connected-operator-tester.h b/test/fully-connected-operator-tester.h
index d28a18d..cfc3738 100644
--- a/test/fully-connected-operator-tester.h
+++ b/test/fully-connected-operator-tester.h
@@ -295,8 +295,10 @@
       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
 
-      const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
-      const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
+      const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() :
+        accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
+      const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() :
+        accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
 
       // Clamp reference results.
       for (float& value : output_ref) {