Support built-in transposition of static weights in Fully Connected operator

PiperOrigin-RevId: 283628934
diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h
index 865c5d7..a68dfbd 100644
--- a/src/xnnpack/pack.h
+++ b/src/xnnpack/pack.h
@@ -66,6 +66,52 @@
   } while (--g != 0);
 }
 
+// Packs q8 weights laid out as k[kc][nc] (input-major) and optional bias b
+// into tiles of nr outputs by kr inputs; padding slots are left unwritten.
+static inline void xnn_pack_q8_gemm_io_w(
+  size_t nc,
+  size_t kc,
+  uint32_t nr,
+  uint32_t kr,
+  uint8_t izp,
+  uint8_t kzp,
+  const uint8_t* k,
+  const int32_t* b,
+  void* packed_w)
+{
+  const int32_t zp_product = (int32_t) kc * (int32_t) izp * (int32_t) kzp;
+  uint8_t* out = (uint8_t*) packed_w;  // byte cursor into the packed buffer
+  for (size_t n0 = 0; n0 < nc; n0 += nr) {
+    const size_t n_count = min(nc - n0, nr);
+    int32_t* tile_bias = (int32_t*) out;  // nr int32 bias slots lead each tile
+    if XNN_LIKELY(b != NULL) {
+      for (size_t i = 0; i < n_count; i++) {
+        tile_bias[i] = b[n0 + i] + zp_product;
+      }
+    } else {
+      for (size_t i = 0; i < n_count; i++) {
+        tile_bias[i] = zp_product;
+      }
+    }
+    out += nr * sizeof(int32_t);  // also steps over the unused bias slots
+    for (size_t k0 = 0; k0 < kc; k0 += kr) {
+      const size_t k_count = min(kc - k0, kr);
+      for (size_t i = 0; i < n_count; i++) {
+        const uint8_t* src = &k[k0 * nc + n0 + i];  // consecutive inputs are nc apart
+        int32_t column_sum = 0;
+        for (size_t j = 0; j < k_count; j++) {
+          out[j] = src[j * nc];
+          column_sum += (int32_t) out[j];
+        }
+        // Subtract izp times the sum of the weights just packed for output i.
+        tile_bias[i] -= column_sum * (int32_t) izp;
+        out += kr;  // a partial kr block still occupies kr bytes
+      }
+      out += (nr - n_count) * kr;  // skip the rows of missing output channels
+    }
+  }
+}
+
 static inline void xnn_pack_q8_conv_goki_w(
   size_t g,
   size_t nc,
@@ -363,6 +409,37 @@
   } while (--g != 0);
 }
 
+// Packs f16 weights laid out as k[kc][nc] (input-major) and optional bias b
+// into tiles of nr outputs by kr inputs; padding slots are left unwritten.
+static inline void xnn_pack_f16_gemm_io_w(
+  size_t nc,
+  size_t kc,
+  size_t nr,
+  size_t kr,
+  const uint16_t* k,
+  const uint16_t* b,
+  uint16_t* packed_w)
+{
+  for (size_t n0 = 0; n0 < nc; n0 += nr) {
+    const size_t n_count = min(nc - n0, nr);
+    if XNN_LIKELY(b != NULL) {
+      for (size_t i = 0; i < n_count; i++) {
+        packed_w[i] = b[n0 + i];
+      }
+    }
+    packed_w += nr;  // bias slots are reserved even when b is NULL
+    for (size_t k0 = 0; k0 < kc; k0 += kr) {
+      const size_t k_count = min(kc - k0, kr);
+      for (size_t i = 0; i < n_count; i++, packed_w += kr) {
+        for (size_t j = 0; j < k_count; j++) {
+          packed_w[j] = k[(k0 + j) * nc + n0 + i];  // transposed read of k
+        }
+      }
+      packed_w += (nr - n_count) * kr;  // skip the rows of missing output channels
+    }
+  }
+}
+
 static inline void xnn_pack_f32_gemm_goi_w(
   size_t g,
   size_t nc,
@@ -416,6 +493,52 @@
   } while (--g != 0);
 }
 
+// Packs f32 weights laid out as k[kc][nc] (input-major) and optional bias b
+// into tiles of nr outputs by kr inputs; padding slots are left unwritten.
+static inline void xnn_pack_f32_gemm_io_w(
+  size_t nc,
+  size_t kc,
+  size_t nr,
+  size_t kr,
+  size_t sr,
+  const float* k,
+  const float* b,
+  float* packed_w)
+{
+  const size_t group = sr * kr;
+  const size_t shuffled_kc = round_down_po2(kc, group);
+  const size_t shuffle_mask = (sr - 1) * kr;
+  for (size_t n0 = 0; n0 < nc; n0 += nr) {
+    const size_t n_count = min(nc - n0, nr);
+    if XNN_LIKELY(b != NULL) {
+      for (size_t i = 0; i < n_count; i++) {
+        packed_w[i] = b[n0 + i];
+      }
+    }
+    packed_w += nr;  // bias slots are reserved even when b is NULL
+    // Whole sr*kr groups: the source row is offset by i*kr, then masked.
+    for (size_t k0 = 0; k0 < shuffled_kc; k0 += kr) {
+      for (size_t i = 0; i < n_count; i++, packed_w += kr) {
+        const size_t row = round_down_po2(k0, group) + ((k0 + i * kr) & shuffle_mask);
+        for (size_t j = 0; j < kr; j++) {
+          packed_w[j] = k[(row + j) * nc + n0 + i];
+        }
+      }
+      packed_w += (nr - n_count) * kr;  // skip the rows of missing output channels
+    }
+    // Input channels past the last whole group are packed without the shuffle.
+    for (size_t k0 = shuffled_kc; k0 < kc; k0 += kr) {
+      const size_t k_count = min(kc - k0, kr);
+      for (size_t i = 0; i < n_count; i++, packed_w += kr) {
+        for (size_t j = 0; j < k_count; j++) {
+          packed_w[j] = k[(k0 + j) * nc + n0 + i];
+        }
+      }
+      packed_w += (nr - n_count) * kr;  // skip the rows of missing output channels
+    }
+  }
+}
+
 static inline void xnn_pack_f32_gemminc_goi_w(
   size_t g,
   size_t nc,