Support built-in transposition of static weights in Fully Connected operator
PiperOrigin-RevId: 283628934
diff --git a/include/xnnpack.h b/include/xnnpack.h
index a909585..7ce9b1c 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -30,6 +30,9 @@
/// The convolution operator represents a depthwise convolution, and use HWGo layout for filters.
#define XNN_FLAG_DEPTHWISE_CONVOLUTION 0x00000001
+/// The operator assumes transposed weights: the fully connected filter is stored as [input_channels, output_channels] rather than the default [output_channels, input_channels].
+#define XNN_FLAG_TRANSPOSE_WEIGHTS 0x00000001
+
/// The operator assumes NHWC layout for the input, regardless of the output layout.
#define XNN_FLAG_INPUT_NHWC 0x00000002
diff --git a/src/fully-connected-nc.c b/src/fully-connected-nc.c
index 9399a80..447ad17 100644
--- a/src/fully-connected-nc.c
+++ b/src/fully-connected-nc.c
@@ -143,12 +143,21 @@
}
memset(fully_connected_op->packed_weights, kernel_zero_point, n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t)));
- xnn_pack_q8_gemm_goi_w(
- 1, output_channels, input_channels,
- nr, kr,
- input_zero_point, kernel_zero_point,
- kernel, bias,
- fully_connected_op->packed_weights);
+ if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
+ xnn_pack_q8_gemm_io_w(
+ output_channels, input_channels,
+ nr, kr,
+ input_zero_point, kernel_zero_point,
+ kernel, bias,
+ fully_connected_op->packed_weights);
+ } else {
+ xnn_pack_q8_gemm_goi_w(
+ 1, output_channels, input_channels,
+ nr, kr,
+ input_zero_point, kernel_zero_point,
+ kernel, bias,
+ fully_connected_op->packed_weights);
+ }
fully_connected_op->group_input_channels = input_channels;
fully_connected_op->group_output_channels = output_channels;
@@ -276,11 +285,19 @@
}
memset(fully_connected_op->packed_weights, 0, n_stride * (k_stride * sizeof(float) + sizeof(float)));
- xnn_pack_f32_gemm_goi_w(
- 1, output_channels, input_channels,
- nr, kr, sr,
- kernel, bias,
- fully_connected_op->packed_weights);
+ if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
+ xnn_pack_f32_gemm_io_w(
+ output_channels, input_channels,
+ nr, kr, sr,
+ kernel, bias,
+ fully_connected_op->packed_weights);
+ } else {
+ xnn_pack_f32_gemm_goi_w(
+ 1, output_channels, input_channels,
+ nr, kr, sr,
+ kernel, bias,
+ fully_connected_op->packed_weights);
+ }
fully_connected_op->group_input_channels = input_channels;
fully_connected_op->group_output_channels = output_channels;
diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h
index 865c5d7..a68dfbd 100644
--- a/src/xnnpack/pack.h
+++ b/src/xnnpack/pack.h
@@ -66,6 +66,52 @@
} while (--g != 0);
}
+static inline void xnn_pack_q8_gemm_io_w(
+ size_t nc,
+ size_t kc,
+ uint32_t nr,
+ uint32_t kr,
+ uint8_t izp,
+ uint8_t kzp,
+ const uint8_t* k,
+ const int32_t* b,
+ void* packed_w)
+{
+ const int32_t boff = (int32_t) kc * (int32_t) izp * (int32_t) kzp;
+ for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
+ const size_t nr_block_size = min(nc - nr_block_start, nr);
+ int32_t* packed_b = (int32_t*) packed_w;
+ if XNN_LIKELY(b != NULL) {
+ for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+ *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
+ packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
+ }
+ } else {
+ size_t n = nr_block_size;
+ do {
+ *((int32_t*) packed_w) = boff;
+ packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
+ } while (--n != 0);
+ }
+ packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
+ for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
+ const size_t kr_block_size = min(kc - kr_block_start, kr);
+ for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+ int32_t ksum = 0;
+ for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+ const uint8_t kv = k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
+ ksum += (int32_t) kv;
+ *((uint8_t*) packed_w) = kv;
+ packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
+ }
+ packed_b[nr_block_offset] -= ksum * (int32_t) izp;
+ packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
+ }
+ packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
+ }
+ }
+}
+
static inline void xnn_pack_q8_conv_goki_w(
size_t g,
size_t nc,
@@ -363,6 +409,37 @@
} while (--g != 0);
}
+static inline void xnn_pack_f16_gemm_io_w(
+ size_t nc,
+ size_t kc,
+ size_t nr,
+ size_t kr,
+ const uint16_t* k,
+ const uint16_t* b,
+ uint16_t* packed_w)
+{
+ for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
+ const size_t nr_block_size = min(nc - nr_block_start, nr);
+ if XNN_LIKELY(b != NULL) {
+ for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+ packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
+ }
+ }
+ packed_w += nr;
+ for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
+ const size_t kr_block_size = min(kc - kr_block_start, kr);
+ for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+ for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+ *packed_w++ =
+ k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
+ }
+ packed_w += kr - kr_block_size;
+ }
+ packed_w += (nr - nr_block_size) * kr;
+ }
+ }
+}
+
static inline void xnn_pack_f32_gemm_goi_w(
size_t g,
size_t nc,
@@ -416,6 +493,52 @@
} while (--g != 0);
}
+static inline void xnn_pack_f32_gemm_io_w(
+ size_t nc,
+ size_t kc,
+ size_t nr,
+ size_t kr,
+ size_t sr,
+ const float* k,
+ const float* b,
+ float* packed_w)
+{
+ const size_t skr = sr * kr;
+ const size_t skc = round_down_po2(kc, skr);
+ const size_t sr_mask = (sr - 1) * kr;
+ for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
+ const size_t nr_block_size = min(nc - nr_block_start, nr);
+ if XNN_LIKELY(b != NULL) {
+ for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+ packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
+ }
+ }
+ packed_w += nr;
+
+ for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
+ for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+ for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
+ *packed_w++ =
+ k[(round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
+ }
+ }
+ packed_w += (nr - nr_block_size) * kr;
+ }
+
+ for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
+ const size_t kr_block_size = min(kc - kr_block_start, kr);
+ for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+ for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+ *packed_w++ =
+ k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
+ }
+ packed_w += kr - kr_block_size;
+ }
+ packed_w += (nr - nr_block_size) * kr;
+ }
+ }
+}
+
static inline void xnn_pack_f32_gemminc_goi_w(
size_t g,
size_t nc,
diff --git a/test/fully-connected-nc.cc b/test/fully-connected-nc.cc
index 91c9508..81477ed 100644
--- a/test/fully-connected-nc.cc
+++ b/test/fully-connected-nc.cc
@@ -60,6 +60,16 @@
.TestQ8();
}
+TEST(FULLY_CONNECTED_NC_Q8, unit_batch_transpose_weights) {
+ FullyConnectedOperatorTester()
+ .transpose_weights(true)
+ .batch_size(1)
+ .input_channels(23)
+ .output_channels(19)
+ .iterations(3)
+ .TestQ8();
+}
+
TEST(FULLY_CONNECTED_NC_Q8, unit_batch_without_bias) {
FullyConnectedOperatorTester()
.has_bias(false)
@@ -119,6 +129,16 @@
.TestQ8();
}
+TEST(FULLY_CONNECTED_NC_Q8, small_batch_transpose_weights) {
+ FullyConnectedOperatorTester()
+ .transpose_weights(true)
+ .batch_size(12)
+ .input_channels(23)
+ .output_channels(19)
+ .iterations(3)
+ .TestQ8();
+}
+
TEST(FULLY_CONNECTED_NC_Q8, small_batch_without_bias) {
FullyConnectedOperatorTester()
.has_bias(false)
@@ -178,6 +198,16 @@
.TestF32();
}
+TEST(FULLY_CONNECTED_NC_F32, unit_batch_transpose_weights) {
+ FullyConnectedOperatorTester()
+ .transpose_weights(true)
+ .batch_size(1)
+ .input_channels(23)
+ .output_channels(19)
+ .iterations(3)
+ .TestF32();
+}
+
TEST(FULLY_CONNECTED_NC_F32, unit_batch_without_bias) {
FullyConnectedOperatorTester()
.has_bias(false)
@@ -237,6 +267,16 @@
.TestF32();
}
+TEST(FULLY_CONNECTED_NC_F32, small_batch_transpose_weights) {
+ FullyConnectedOperatorTester()
+ .transpose_weights(true)
+ .batch_size(12)
+ .input_channels(23)
+ .output_channels(19)
+ .iterations(3)
+ .TestF32();
+}
+
TEST(FULLY_CONNECTED_NC_F32, small_batch_without_bias) {
FullyConnectedOperatorTester()
.has_bias(false)
diff --git a/test/fully-connected-operator-tester.h b/test/fully-connected-operator-tester.h
index 81370f9..d28a18d 100644
--- a/test/fully-connected-operator-tester.h
+++ b/test/fully-connected-operator-tester.h
@@ -101,6 +101,15 @@
return this->qmax_;
}
+ inline FullyConnectedOperatorTester& transpose_weights(bool transpose_weights) {
+ this->transpose_weights_ = transpose_weights;
+ return *this;
+ }
+
+ inline bool transpose_weights() const {
+ return this->transpose_weights_;
+ }
+
inline FullyConnectedOperatorTester& has_bias(bool has_bias) {
this->has_bias_ = has_bias;
return *this;
@@ -152,12 +161,24 @@
} else {
std::fill(accumulators.begin(), accumulators.end(), 0);
}
- for (size_t i = 0; i < batch_size(); i++) {
- for (size_t oc = 0; oc < output_channels(); oc++) {
- for (size_t ic = 0; ic < input_channels(); ic++) {
- accumulators[i * output_channels() + oc] +=
- (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
- (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
+ if (transpose_weights()) {
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t oc = 0; oc < output_channels(); oc++) {
+ for (size_t ic = 0; ic < input_channels(); ic++) {
+ accumulators[i * output_channels() + oc] +=
+ (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
+ (int32_t(kernel[ic * output_channels() + oc]) - int32_t(kernel_zero_point));
+ }
+ }
+ }
+ } else {
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t oc = 0; oc < output_channels(); oc++) {
+ for (size_t ic = 0; ic < input_channels(); ic++) {
+ accumulators[i * output_channels() + oc] +=
+ (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
+ (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
+ }
}
}
}
@@ -189,7 +210,8 @@
kernel_zero_point, 1.0f /* kernel scale */,
kernel.data(), has_bias() ? bias.data() : nullptr,
output_zero_point, output_scale, qmin(), qmax(),
- 0, &fully_connected_op));
+ transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
+ &fully_connected_op));
// Smart pointer to automatically delete fully_connected_op.
std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
@@ -249,11 +271,22 @@
} else {
std::fill(output_ref.begin(), output_ref.end(), 0.0f);
}
- for (size_t i = 0; i < batch_size(); i++) {
- for (size_t oc = 0; oc < output_channels(); oc++) {
- for (size_t ic = 0; ic < input_channels(); ic++) {
- output_ref[i * output_channels() + oc] +=
- input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
+ if (transpose_weights()) {
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t oc = 0; oc < output_channels(); oc++) {
+ for (size_t ic = 0; ic < input_channels(); ic++) {
+ output_ref[i * output_channels() + oc] +=
+ input[i * input_stride() + ic] * kernel[ic * output_channels() + oc];
+ }
+ }
+ }
+ } else {
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t oc = 0; oc < output_channels(); oc++) {
+ for (size_t ic = 0; ic < input_channels(); ic++) {
+ output_ref[i * output_channels() + oc] +=
+ input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
+ }
}
}
}
@@ -280,7 +313,8 @@
input_stride(), output_stride(),
kernel.data(), has_bias() ? bias.data() : nullptr,
output_min, output_max,
- 0, &fully_connected_op));
+ transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
+ &fully_connected_op));
// Smart pointer to automatically delete fully_connected_op.
std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
@@ -320,6 +354,7 @@
size_t batch_size_{1};
uint8_t qmin_{0};
uint8_t qmax_{255};
+ bool transpose_weights_{false};
bool has_bias_{true};
size_t iterations_{1};
};