Support built-in transposition of static weights in Fully Connected operator
PiperOrigin-RevId: 283628934
diff --git a/test/fully-connected-operator-tester.h b/test/fully-connected-operator-tester.h
index 81370f9..d28a18d 100644
--- a/test/fully-connected-operator-tester.h
+++ b/test/fully-connected-operator-tester.h
@@ -101,6 +101,15 @@
return this->qmax_;
}
+ inline FullyConnectedOperatorTester& transpose_weights(bool transpose_weights) {  // Setter for the transposed-weights test option.
+ this->transpose_weights_ = transpose_weights;  // When true, the tests index the kernel as [ic * output_channels() + oc] and pass XNN_FLAG_TRANSPOSE_WEIGHTS.
+ return *this;  // Returns this tester by reference so option setters can be chained.
+ }
+
+ inline bool transpose_weights() const {  // Getter for the transposed-weights test option.
+ return this->transpose_weights_;  // true selects the [ic * output_channels() + oc] kernel indexing in the reference loops.
+ }
+
inline FullyConnectedOperatorTester& has_bias(bool has_bias) {
this->has_bias_ = has_bias;
return *this;
@@ -152,12 +161,24 @@
} else {
std::fill(accumulators.begin(), accumulators.end(), 0);
}
- for (size_t i = 0; i < batch_size(); i++) {
- for (size_t oc = 0; oc < output_channels(); oc++) {
- for (size_t ic = 0; ic < input_channels(); ic++) {
- accumulators[i * output_channels() + oc] +=
- (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
- (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
+ if (transpose_weights()) {
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t oc = 0; oc < output_channels(); oc++) {
+ for (size_t ic = 0; ic < input_channels(); ic++) {
+ accumulators[i * output_channels() + oc] +=
+ (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
+ (int32_t(kernel[ic * output_channels() + oc]) - int32_t(kernel_zero_point));
+ }
+ }
+ }
+ } else {
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t oc = 0; oc < output_channels(); oc++) {
+ for (size_t ic = 0; ic < input_channels(); ic++) {
+ accumulators[i * output_channels() + oc] +=
+ (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
+ (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
+ }
}
}
}
@@ -189,7 +210,8 @@
kernel_zero_point, 1.0f /* kernel scale */,
kernel.data(), has_bias() ? bias.data() : nullptr,
output_zero_point, output_scale, qmin(), qmax(),
- 0, &fully_connected_op));
+ transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
+ &fully_connected_op));
// Smart pointer to automatically delete fully_connected_op.
std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
@@ -249,11 +271,22 @@
} else {
std::fill(output_ref.begin(), output_ref.end(), 0.0f);
}
- for (size_t i = 0; i < batch_size(); i++) {
- for (size_t oc = 0; oc < output_channels(); oc++) {
- for (size_t ic = 0; ic < input_channels(); ic++) {
- output_ref[i * output_channels() + oc] +=
- input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
+ if (transpose_weights()) {
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t oc = 0; oc < output_channels(); oc++) {
+ for (size_t ic = 0; ic < input_channels(); ic++) {
+ output_ref[i * output_channels() + oc] +=
+ input[i * input_stride() + ic] * kernel[ic * output_channels() + oc];
+ }
+ }
+ }
+ } else {
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t oc = 0; oc < output_channels(); oc++) {
+ for (size_t ic = 0; ic < input_channels(); ic++) {
+ output_ref[i * output_channels() + oc] +=
+ input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
+ }
}
}
}
@@ -280,7 +313,8 @@
input_stride(), output_stride(),
kernel.data(), has_bias() ? bias.data() : nullptr,
output_min, output_max,
- 0, &fully_connected_op));
+ transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
+ &fully_connected_op));
// Smart pointer to automatically delete fully_connected_op.
std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
@@ -320,6 +354,7 @@
size_t batch_size_{1};
uint8_t qmin_{0};
uint8_t qmax_{255};
+ bool transpose_weights_{false};
bool has_bias_{true};
size_t iterations_{1};
};