Support built-in transposition of static weights in Fully Connected operator

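The Fully Connected operator tester now exercises the new
XNN_FLAG_TRANSPOSE_WEIGHTS flag. By default the static kernel is stored
as [output_channels x input_channels] and indexed as
kernel[oc * input_channels + ic]; with transpose_weights(true) the
kernel is interpreted as [input_channels x output_channels] and indexed
as kernel[ic * output_channels + oc]. The reference computations in the
tester cover both layouts, and the flag is passed through to both
operator creation calls.

As a rough sketch (not part of this change; the sizes, buffers, and
clamping bounds are illustrative placeholders), creating an F32 Fully
Connected operator with transposed weights looks roughly like:

  #include <cmath>   // for INFINITY
  #include <vector>
  #include <xnnpack.h>

  // Weights stored transposed: [input_channels x output_channels].
  // (Assumes XNNPACK has already been initialized.)
  std::vector<float> kernel(64 * 32);
  std::vector<float> bias(32);
  xnn_operator_t fc_op = nullptr;
  xnn_status status = xnn_create_fully_connected_nc_f32(
    /*input_channels=*/64, /*output_channels=*/32,
    /*input_stride=*/64, /*output_stride=*/32,
    kernel.data(), bias.data(),
    /*output_min=*/-INFINITY, /*output_max=*/+INFINITY,
    XNN_FLAG_TRANSPOSE_WEIGHTS, &fc_op);
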
PiperOrigin-RevId: 283628934
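
A test exercising the new option drives the tester along these lines
(the batch_size()/input_channels()/output_channels() setters and the
TestF32() entry point are not shown in this diff and are assumed from
the rest of the file):

  FullyConnectedOperatorTester()
    .batch_size(4)
    .input_channels(23)
    .output_channels(19)
    .transpose_weights(true)
    .TestF32();
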
diff --git a/test/fully-connected-operator-tester.h b/test/fully-connected-operator-tester.h
index 81370f9..d28a18d 100644
--- a/test/fully-connected-operator-tester.h
+++ b/test/fully-connected-operator-tester.h
@@ -101,6 +101,15 @@
     return this->qmax_;
   }
 
+  inline FullyConnectedOperatorTester& transpose_weights(bool transpose_weights) {
+    this->transpose_weights_ = transpose_weights;
+    return *this;
+  }
+
+  inline bool transpose_weights() const {
+    return this->transpose_weights_;
+  }
+
   inline FullyConnectedOperatorTester& has_bias(bool has_bias) {
     this->has_bias_ = has_bias;
     return *this;
@@ -152,12 +161,24 @@
       } else {
         std::fill(accumulators.begin(), accumulators.end(), 0);
       }
-      for (size_t i = 0; i < batch_size(); i++) {
-        for (size_t oc = 0; oc < output_channels(); oc++) {
-          for (size_t ic = 0; ic < input_channels(); ic++) {
-            accumulators[i * output_channels() + oc] +=
-              (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
-              (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
+      if (transpose_weights()) {
+        for (size_t i = 0; i < batch_size(); i++) {
+          for (size_t oc = 0; oc < output_channels(); oc++) {
+            for (size_t ic = 0; ic < input_channels(); ic++) {
+              accumulators[i * output_channels() + oc] +=
+                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
+                (int32_t(kernel[ic * output_channels() + oc]) - int32_t(kernel_zero_point));
+            }
+          }
+        }
+      } else {
+        for (size_t i = 0; i < batch_size(); i++) {
+          for (size_t oc = 0; oc < output_channels(); oc++) {
+            for (size_t ic = 0; ic < input_channels(); ic++) {
+              accumulators[i * output_channels() + oc] +=
+                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
+                (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
+            }
           }
         }
       }
@@ -189,7 +210,8 @@
           kernel_zero_point, 1.0f /* kernel scale */,
           kernel.data(), has_bias() ? bias.data() : nullptr,
           output_zero_point, output_scale, qmin(), qmax(),
-          0, &fully_connected_op));
+          transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
+          &fully_connected_op));
 
       // Smart pointer to automatically delete fully_connected_op.
       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
@@ -249,11 +271,22 @@
       } else {
         std::fill(output_ref.begin(), output_ref.end(), 0.0f);
       }
-      for (size_t i = 0; i < batch_size(); i++) {
-        for (size_t oc = 0; oc < output_channels(); oc++) {
-          for (size_t ic = 0; ic < input_channels(); ic++) {
-            output_ref[i * output_channels() + oc] +=
-              input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
+      if (transpose_weights()) {
+        for (size_t i = 0; i < batch_size(); i++) {
+          for (size_t oc = 0; oc < output_channels(); oc++) {
+            for (size_t ic = 0; ic < input_channels(); ic++) {
+              output_ref[i * output_channels() + oc] +=
+                input[i * input_stride() + ic] * kernel[ic * output_channels() + oc];
+            }
+          }
+        }
+      } else {
+        for (size_t i = 0; i < batch_size(); i++) {
+          for (size_t oc = 0; oc < output_channels(); oc++) {
+            for (size_t ic = 0; ic < input_channels(); ic++) {
+              output_ref[i * output_channels() + oc] +=
+                input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
+            }
           }
         }
       }
@@ -280,7 +313,8 @@
           input_stride(), output_stride(),
           kernel.data(), has_bias() ? bias.data() : nullptr,
           output_min, output_max,
-          0, &fully_connected_op));
+          transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
+          &fully_connected_op));
 
       // Smart pointer to automatically delete fully_connected_op.
       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
@@ -320,6 +354,7 @@
   size_t batch_size_{1};
   uint8_t qmin_{0};
   uint8_t qmax_{255};
+  bool transpose_weights_{false};
   bool has_bias_{true};
   size_t iterations_{1};
 };