Support built-in transposition of static weights in Fully Connected operator

PiperOrigin-RevId: 283628934
diff --git a/include/xnnpack.h b/include/xnnpack.h
index a909585..7ce9b1c 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -30,6 +30,9 @@
 /// The convolution operator represents a depthwise convolution, and use HWGo layout for filters.
 #define XNN_FLAG_DEPTHWISE_CONVOLUTION 0x00000001
 
+/// The weights in a Fully Connected operator are transposed, i.e. stored as [input_channels, output_channels] rather than [output_channels, input_channels].
+#define XNN_FLAG_TRANSPOSE_WEIGHTS 0x00000001
+
 /// The operator assumes NHWC layout for the input, regardless of the output layout.
 #define XNN_FLAG_INPUT_NHWC 0x00000002
 
diff --git a/src/fully-connected-nc.c b/src/fully-connected-nc.c
index 9399a80..447ad17 100644
--- a/src/fully-connected-nc.c
+++ b/src/fully-connected-nc.c
@@ -143,12 +143,21 @@
   }
   memset(fully_connected_op->packed_weights, kernel_zero_point, n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t)));
 
-  xnn_pack_q8_gemm_goi_w(
-    1, output_channels, input_channels,
-    nr, kr,
-    input_zero_point, kernel_zero_point,
-    kernel, bias,
-    fully_connected_op->packed_weights);
+  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
+    xnn_pack_q8_gemm_io_w(
+      output_channels, input_channels,
+      nr, kr,
+      input_zero_point, kernel_zero_point,
+      kernel, bias,
+      fully_connected_op->packed_weights);
+  } else {
+    xnn_pack_q8_gemm_goi_w(
+      1, output_channels, input_channels,
+      nr, kr,
+      input_zero_point, kernel_zero_point,
+      kernel, bias,
+      fully_connected_op->packed_weights);
+  }
 
   fully_connected_op->group_input_channels = input_channels;
   fully_connected_op->group_output_channels = output_channels;
@@ -276,11 +285,19 @@
   }
   memset(fully_connected_op->packed_weights, 0, n_stride * (k_stride * sizeof(float) + sizeof(float)));
 
-  xnn_pack_f32_gemm_goi_w(
-    1, output_channels, input_channels,
-    nr, kr, sr,
-    kernel, bias,
-    fully_connected_op->packed_weights);
+  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
+    xnn_pack_f32_gemm_io_w(
+      output_channels, input_channels,
+      nr, kr, sr,
+      kernel, bias,
+      fully_connected_op->packed_weights);
+  } else {
+    xnn_pack_f32_gemm_goi_w(
+      1, output_channels, input_channels,
+      nr, kr, sr,
+      kernel, bias,
+      fully_connected_op->packed_weights);
+  }
 
   fully_connected_op->group_input_channels = input_channels;
   fully_connected_op->group_output_channels = output_channels;
diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h
index 865c5d7..a68dfbd 100644
--- a/src/xnnpack/pack.h
+++ b/src/xnnpack/pack.h
@@ -66,6 +66,52 @@
   } while (--g != 0);
 }
 
+static inline void xnn_pack_q8_gemm_io_w(
+  size_t nc,
+  size_t kc,
+  uint32_t nr,
+  uint32_t kr,
+  uint8_t izp,
+  uint8_t kzp,
+  const uint8_t* k,
+  const int32_t* b,
+  void* packed_w)
+{
+  const int32_t boff = (int32_t) kc * (int32_t) izp * (int32_t) kzp;
+  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
+    const size_t nr_block_size = min(nc - nr_block_start, nr);
+    int32_t* packed_b = (int32_t*) packed_w;
+    if XNN_LIKELY(b != NULL) {
+      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+        *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
+        packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
+      }
+    } else {
+      size_t n = nr_block_size;
+      do {
+        *((int32_t*) packed_w) = boff;
+        packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
+      } while (--n != 0);
+    }
+    packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
+    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
+      const size_t kr_block_size = min(kc - kr_block_start, kr);
+      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+        int32_t ksum = 0;
+        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+          const uint8_t kv = k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
+          ksum += (int32_t) kv;
+          *((uint8_t*) packed_w) = kv;
+          packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
+        }
+        packed_b[nr_block_offset] -= ksum * (int32_t) izp;
+        packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
+      }
+      packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
+    }
+  }
+}
+
 static inline void xnn_pack_q8_conv_goki_w(
   size_t g,
   size_t nc,
@@ -363,6 +409,37 @@
   } while (--g != 0);
 }
 
+static inline void xnn_pack_f16_gemm_io_w(
+  size_t nc,
+  size_t kc,
+  size_t nr,
+  size_t kr,
+  const uint16_t* k,
+  const uint16_t* b,
+  uint16_t* packed_w)
+{
+  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
+    const size_t nr_block_size = min(nc - nr_block_start, nr);
+    if XNN_LIKELY(b != NULL) {
+      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+        packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
+      }
+    }
+    packed_w += nr;
+    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
+      const size_t kr_block_size = min(kc - kr_block_start, kr);
+      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+          *packed_w++ =
+            k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
+        }
+        packed_w += kr - kr_block_size;
+      }
+      packed_w += (nr - nr_block_size) * kr;
+    }
+  }
+}
+
 static inline void xnn_pack_f32_gemm_goi_w(
   size_t g,
   size_t nc,
@@ -416,6 +493,52 @@
   } while (--g != 0);
 }
 
+static inline void xnn_pack_f32_gemm_io_w(
+  size_t nc,
+  size_t kc,
+  size_t nr,
+  size_t kr,
+  size_t sr,
+  const float* k,
+  const float* b,
+  float* packed_w)
+{
+  const size_t skr = sr * kr;
+  const size_t skc = round_down_po2(kc, skr);
+  const size_t sr_mask = (sr - 1) * kr;
+  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
+    const size_t nr_block_size = min(nc - nr_block_start, nr);
+    if XNN_LIKELY(b != NULL) {
+      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+        packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
+      }
+    }
+    packed_w += nr;
+
+    for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
+      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+        for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
+          *packed_w++ =
+            k[(round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
+        }
+      }
+      packed_w += (nr - nr_block_size) * kr;
+    }
+
+    for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
+      const size_t kr_block_size = min(kc - kr_block_start, kr);
+      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+          *packed_w++ =
+            k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
+        }
+        packed_w += kr - kr_block_size;
+      }
+      packed_w += (nr - nr_block_size) * kr;
+    }
+  }
+}
+
 static inline void xnn_pack_f32_gemminc_goi_w(
   size_t g,
   size_t nc,
diff --git a/test/fully-connected-nc.cc b/test/fully-connected-nc.cc
index 91c9508..81477ed 100644
--- a/test/fully-connected-nc.cc
+++ b/test/fully-connected-nc.cc
@@ -60,6 +60,16 @@
     .TestQ8();
 }
 
+TEST(FULLY_CONNECTED_NC_Q8, unit_batch_transpose_weights) {
+  FullyConnectedOperatorTester()
+    .transpose_weights(true)
+    .batch_size(1)
+    .input_channels(23)
+    .output_channels(19)
+    .iterations(3)
+    .TestQ8();
+}
+
 TEST(FULLY_CONNECTED_NC_Q8, unit_batch_without_bias) {
   FullyConnectedOperatorTester()
     .has_bias(false)
@@ -119,6 +129,16 @@
     .TestQ8();
 }
 
+TEST(FULLY_CONNECTED_NC_Q8, small_batch_transpose_weights) {
+  FullyConnectedOperatorTester()
+    .transpose_weights(true)
+    .batch_size(12)
+    .input_channels(23)
+    .output_channels(19)
+    .iterations(3)
+    .TestQ8();
+}
+
 TEST(FULLY_CONNECTED_NC_Q8, small_batch_without_bias) {
   FullyConnectedOperatorTester()
     .has_bias(false)
@@ -178,6 +198,16 @@
     .TestF32();
 }
 
+TEST(FULLY_CONNECTED_NC_F32, unit_batch_transpose_weights) {
+  FullyConnectedOperatorTester()
+    .transpose_weights(true)
+    .batch_size(1)
+    .input_channels(23)
+    .output_channels(19)
+    .iterations(3)
+    .TestF32();
+}
+
 TEST(FULLY_CONNECTED_NC_F32, unit_batch_without_bias) {
   FullyConnectedOperatorTester()
     .has_bias(false)
@@ -237,6 +267,16 @@
     .TestF32();
 }
 
+TEST(FULLY_CONNECTED_NC_F32, small_batch_transpose_weights) {
+  FullyConnectedOperatorTester()
+    .transpose_weights(true)
+    .batch_size(12)
+    .input_channels(23)
+    .output_channels(19)
+    .iterations(3)
+    .TestF32();
+}
+
 TEST(FULLY_CONNECTED_NC_F32, small_batch_without_bias) {
   FullyConnectedOperatorTester()
     .has_bias(false)
diff --git a/test/fully-connected-operator-tester.h b/test/fully-connected-operator-tester.h
index 81370f9..d28a18d 100644
--- a/test/fully-connected-operator-tester.h
+++ b/test/fully-connected-operator-tester.h
@@ -101,6 +101,15 @@
     return this->qmax_;
   }
 
+  inline FullyConnectedOperatorTester& transpose_weights(bool transpose_weights) {
+    this->transpose_weights_ = transpose_weights;
+    return *this;
+  }
+
+  inline bool transpose_weights() const {
+    return this->transpose_weights_;
+  }
+
   inline FullyConnectedOperatorTester& has_bias(bool has_bias) {
     this->has_bias_ = has_bias;
     return *this;
@@ -152,12 +161,24 @@
       } else {
         std::fill(accumulators.begin(), accumulators.end(), 0);
       }
-      for (size_t i = 0; i < batch_size(); i++) {
-        for (size_t oc = 0; oc < output_channels(); oc++) {
-          for (size_t ic = 0; ic < input_channels(); ic++) {
-            accumulators[i * output_channels() + oc] +=
-              (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
-              (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
+      if (transpose_weights()) {
+        for (size_t i = 0; i < batch_size(); i++) {
+          for (size_t oc = 0; oc < output_channels(); oc++) {
+            for (size_t ic = 0; ic < input_channels(); ic++) {
+              accumulators[i * output_channels() + oc] +=
+                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
+                (int32_t(kernel[ic * output_channels() + oc]) - int32_t(kernel_zero_point));
+            }
+          }
+        }
+      } else {
+        for (size_t i = 0; i < batch_size(); i++) {
+          for (size_t oc = 0; oc < output_channels(); oc++) {
+            for (size_t ic = 0; ic < input_channels(); ic++) {
+              accumulators[i * output_channels() + oc] +=
+                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
+                (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
+            }
           }
         }
       }
@@ -189,7 +210,8 @@
           kernel_zero_point, 1.0f /* kernel scale */,
           kernel.data(), has_bias() ? bias.data() : nullptr,
           output_zero_point, output_scale, qmin(), qmax(),
-          0, &fully_connected_op));
+          transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
+          &fully_connected_op));
 
       // Smart pointer to automatically delete fully_connected_op.
       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
@@ -249,11 +271,22 @@
       } else {
         std::fill(output_ref.begin(), output_ref.end(), 0.0f);
       }
-      for (size_t i = 0; i < batch_size(); i++) {
-        for (size_t oc = 0; oc < output_channels(); oc++) {
-          for (size_t ic = 0; ic < input_channels(); ic++) {
-            output_ref[i * output_channels() + oc] +=
-              input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
+      if (transpose_weights()) {
+        for (size_t i = 0; i < batch_size(); i++) {
+          for (size_t oc = 0; oc < output_channels(); oc++) {
+            for (size_t ic = 0; ic < input_channels(); ic++) {
+              output_ref[i * output_channels() + oc] +=
+                input[i * input_stride() + ic] * kernel[ic * output_channels() + oc];
+            }
+          }
+        }
+      } else {
+        for (size_t i = 0; i < batch_size(); i++) {
+          for (size_t oc = 0; oc < output_channels(); oc++) {
+            for (size_t ic = 0; ic < input_channels(); ic++) {
+              output_ref[i * output_channels() + oc] +=
+                input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
+            }
           }
         }
       }
@@ -280,7 +313,8 @@
           input_stride(), output_stride(),
           kernel.data(), has_bias() ? bias.data() : nullptr,
           output_min, output_max,
-          0, &fully_connected_op));
+          transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
+          &fully_connected_op));
 
       // Smart pointer to automatically delete fully_connected_op.
       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
@@ -320,6 +354,7 @@
   size_t batch_size_{1};
   uint8_t qmin_{0};
   uint8_t qmax_{255};
+  bool transpose_weights_{false};
   bool has_bias_{true};
   size_t iterations_{1};
 };