Snap for 4620899 from 0ed4f31d5ced2432473aa7063bc1e28d990ff3f2 to pi-release
Change-Id: I656a8f97b61e98c1ed76ba1e52d597b6d2f7353d
diff --git a/doc/quantization.md b/doc/quantization.md
index 3e0df16..3a8f72b 100644
--- a/doc/quantization.md
+++ b/doc/quantization.md
@@ -301,7 +301,7 @@
The specific output pipeline stage implementing the present quantization
paradigm, i.e. implementing the precise computation detailed in the previous
section (equation (5)), is
-`OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint`.
+`OutputStageQuantizeDownInt32ByFixedPoint`.
Please refer to the comment explaining it in
[public/output_stages.h](../public/output_stages.h).
@@ -313,7 +313,7 @@
document boils down to the difference between the legacy output stage
implementing it, `OutputStageQuantizeDownInt32ToUint8Scale`, and the new output
stage implementing the new paradigm,
-`OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint`.
+`OutputStageQuantizeDownInt32ByFixedPoint`.
Please refer to the comments in
[public/output_stages.h](../public/output_stages.h) for details about these two
@@ -323,7 +323,7 @@
1. The int32 accumulators (inputs to the output stage) undergo a plain int32
   multiplication with an int32 multiplier, which may overflow. By contrast, in
- the newer `OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint`, this
+ the newer `OutputStageQuantizeDownInt32ByFixedPoint`, this
integer multiplication becomes a fixed-point multiplication and cannot
overflow.
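
A sketch of the difference, with hypothetical values:

    std::int32_t acc = 1 << 20;         // hypothetical accumulator value
    std::int32_t multiplier = 1 << 13;  // hypothetical quantized multiplier
    // Legacy stage: acc * multiplier == 2^33 overflows int32 (undefined
    // behavior). The fixed-point stage instead keeps only the high 32 bits
    // of the doubled 64-bit product, which fits in int32 by construction:
    std::int64_t product = 2 * static_cast<std::int64_t>(acc) * multiplier;
    std::int32_t high = static_cast<std::int32_t>(product >> 32);  // == 4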
diff --git a/doc/quantization_example.cc b/doc/quantization_example.cc
index 4368de2..d7b147d 100644
--- a/doc/quantization_example.cc
+++ b/doc/quantization_example.cc
@@ -201,7 +201,7 @@
//
// This is how to obtain the fixed-point multiplier and right shift
// parameters to pass to
-// OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint.
+// OutputStageQuantizeDownInt32ByFixedPoint.
//
// Note: all this code only needs to run offline to generate the quantized
// neural network workload, not at runtime on the
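
As a sketch of that offline computation (hypothetical helper; the name and
details are illustrative, not quoted from this file): normalize the real
multiplier into [1/2, 1), then round it to Q31.

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    // Decompose 0 < real_multiplier < 1 into a Q31 fixed-point multiplier
    // in [2^30, 2^31) and a right shift, so that
    // real_multiplier ~= quantized_multiplier * 2^-31 * 2^-right_shift.
    void QuantizeMultiplierSmallerThanOne(double real_multiplier,
                                          std::int32_t* quantized_multiplier,
                                          int* right_shift) {
      assert(real_multiplier > 0.0 && real_multiplier < 1.0);
      int shift = 0;
      while (real_multiplier < 0.5) {  // normalize into [0.5, 1)
        real_multiplier *= 2.0;
        ++shift;
      }
      std::int64_t q =
          static_cast<std::int64_t>(std::round(real_multiplier * (1ll << 31)));
      if (q == (1ll << 31)) {  // rounding may bump us to exactly 1.0
        q /= 2;
        --shift;
      }
      *quantized_multiplier = static_cast<std::int32_t>(q);
      *right_shift = shift;
    }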
@@ -347,7 +347,7 @@
<< "use quantized arithmetic.\n"
<< std::endl;
- gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
+ gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint
quantize_down_stage;
quantize_down_stage.result_offset_after_shift = result_offset;
quantize_down_stage.result_fixedpoint_multiplier = quantized_multiplier;
diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h
index e21337f..d39341b 100644
--- a/fixedpoint/fixedpoint.h
+++ b/fixedpoint/fixedpoint.h
@@ -50,6 +50,12 @@
static const int kLanes = 1;
};
+template <>
+struct FixedPointRawTypeTraits<std::int16_t> {
+ typedef std::int16_t ScalarRawType;
+ static const int kLanes = 1;
+};
+
// Returns a SIMD value duplicating a scalar value across all lanes.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
@@ -217,6 +223,50 @@
return static_cast<std::int32_t>((sum + sign) / 2);
}
+template <>
+inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
+ std::int32_t a32 = a;
+ std::int32_t b32 = b;
+ std::int32_t sum = a32 + b32;
+ std::int32_t sign = sum >= 0 ? 1 : -1;
+ return static_cast<std::int16_t>((sum + sign) / 2);
+}
+
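// A few sample values to make the rounding concrete (a sketch of the
// specialization above; ties round away from zero):
//   RoundingHalfSum(std::int16_t(3), std::int16_t(4))         == 4    (3.5)
//   RoundingHalfSum(std::int16_t(-3), std::int16_t(-4))       == -4   (-3.5)
//   RoundingHalfSum(std::int16_t(32767), std::int16_t(32767)) == 32767
// The int32 widening is what makes the last case safe from overflow.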
+template <typename IntegerType>
+IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
+ static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+ return a;
+}
+
+// So far this is only needed for int16.
+template <>
+inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
+ std::int32_t a32 = a;
+ std::int32_t b32 = b;
+ std::int32_t sum = a32 + b32;
+ return static_cast<std::int16_t>(std::min(32767, std::max(-32768, sum)));
+}
+
+// Returns a+b, saturating if the integers are 16bit or narrower,
+// otherwise just a plain addition.
+template <typename IntegerType, bool Is16Bit>
+struct AddSaturatingIf16BitImpl {
+ static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
+};
+template <typename IntegerType>
+struct AddSaturatingIf16BitImpl<IntegerType, true> {
+ static IntegerType Run(IntegerType a, IntegerType b) {
+ return SaturatingAdd(a, b);
+ }
+};
+template <typename IntegerType>
+IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
+ using ScalarType =
+ typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
+ return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
+ b);
+}
+
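// The dispatch is purely on scalar width, so (a sketch):
//   AddSaturatingIf16Bit(std::int16_t(30000), std::int16_t(30000)) == 32767
//     (16-bit scalar: SaturatingAdd clamps at the int16 maximum)
//   AddSaturatingIf16Bit(std::int32_t(30000), std::int32_t(30000)) == 60000
//     (32-bit scalar: plain Add, no saturation)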
// Returns the integer that represents the product of two fixed-point
// numbers, interpreting all integers as fixed-point values in the
// interval [-1, 1), rounding to the nearest value, and saturating
@@ -266,14 +316,23 @@
return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
}
+template <>
+inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
+ std::int16_t b) {
+ bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min();
+ std::int32_t a_32(a);
+ std::int32_t b_32(b);
+ std::int32_t ab_32 = a_32 * b_32;
+ std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14));
+ std::int16_t ab_x2_high16 =
+ static_cast<std::int16_t>((ab_32 + nudge) / (1 << 15));
+ return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16;
+}
+
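// Worked example, reading int16 as Q0.15 fixed-point in [-1, 1):
// a = b = 16384 (i.e. 0.5):
//   ab_32        = 16384 * 16384                   = 268435456
//   nudge        = 1 << 14                         = 16384
//   ab_x2_high16 = (268435456 + 16384) / (1 << 15) = 8192, i.e. 0.25
// a = b = -32768 (i.e. -1.0): the true product +1.0 is not representable,
// so this is the single overflow case and the result saturates to 32767.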
// Correctly-rounded-to-nearest division by a power-of-two.
// Also known as a rounding arithmetic right shift.
template <typename IntegerType>
inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
- using ScalarIntegerType =
- typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
- static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value,
- "Currently only supporting int32 scalar and SIMD types");
assert(exponent >= 0);
assert(exponent <= 31);
const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
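
With the int32-only static_assert gone, RoundingDivideByPOT now also accepts
int16 scalars (via the new FixedPointRawTypeTraits<std::int16_t>) and int16
SIMD types. Rounding is to nearest, ties away from zero; for instance (a
sketch):

    RoundingDivideByPOT(std::int32_t(5), 1);   // == 3  (2.5 rounds away from zero)
    RoundingDivideByPOT(std::int32_t(-5), 1);  // == -3 (-2.5 rounds away from zero)
    RoundingDivideByPOT(std::int16_t(12), 2);  // == 3  (exact)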
@@ -304,14 +363,14 @@
static IntegerType eval(IntegerType x) {
using ScalarIntegerType =
typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
- static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value,
- "Currently only supporting int32 scalar and SIMD types");
const IntegerType min =
- Dup<IntegerType>(std::numeric_limits<std::int32_t>::min());
+ Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
const IntegerType max =
- Dup<IntegerType>(std::numeric_limits<std::int32_t>::max());
+ Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
+ const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
- const std::int32_t threshold = ((1 << (31 - Exponent)) - 1);
+ const std::int32_t threshold =
+ ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
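  // Worked example of the generalized threshold (a sketch): for int16
  // (ScalarIntegerTypeBits == 16) and Exponent == 3,
  //   threshold = (1 << (16 - 1 - 3)) - 1 = 4095;
  // 4095 << 3 == 32760 still fits in int16, while 4096 << 3 == 32768 would
  // overflow, so lanes beyond +/-4095 saturate to the int16 extremes below.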
const IntegerType positive_mask =
MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
const IntegerType negative_mask =
@@ -425,15 +484,16 @@
static FixedPoint Zero() { return FromScalarRaw(0); }
static FixedPoint One() {
- return FromScalarRaw(kIntegerBits == 0
- ? ScalarRawMax()
- : (ScalarRawType(1) << kFractionalBits));
+ return FromScalarRaw(
+ kIntegerBits == 0
+ ? ScalarRawMax()
+ : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
}
static FixedPoint FromDouble(double x) {
const double min_bound = static_cast<double>(ScalarRawMin());
const double max_bound = static_cast<double>(ScalarRawMax());
- return FromScalarRaw(static_cast<std::int32_t>(std::min(
+ return FromScalarRaw(static_cast<ScalarRawType>(std::min(
std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
min_bound),
max_bound)));
@@ -555,6 +615,22 @@
return !(a == b);
}
+template <typename tRawType, int tIntegerBits>
+FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
+ FixedPoint<tRawType, tIntegerBits> a,
+ FixedPoint<tRawType, tIntegerBits> b) {
+ return FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ SaturatingAdd(a.raw(), b.raw()));
+}
+
+template <typename tRawType, int tIntegerBits>
+FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
+ FixedPoint<tRawType, tIntegerBits> a,
+ FixedPoint<tRawType, tIntegerBits> b) {
+ return FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ AddSaturatingIf16Bit(a.raw(), b.raw()));
+}
+
// Conversion to floating-point.
template <typename tRawType, int tIntegerBits>
double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
@@ -579,23 +655,41 @@
// initialized as real numbers, in a way that does not compile floating-point
// arithmetic in production code, yet still checks agreement with the
// floating-point expressions when asserts are enabled.
+//
+// The raw integer value provided is always an int32, encoding a 32-bit
+// fixed-point value, regardless of the actual Scalar type. This allows
+// writing generic code that applies just as well to the 32-bit and 16-bit
+// cases. In the 16-bit case, the raw integer value is internally
+// rounding-shifted by 16 bits to the right.
+template <typename FixedPointType>
+inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
+ std::int32_t int32_value) {
+ typedef typename FixedPointType::ScalarRawType ScalarRawType;
+ static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
+ return static_cast<ScalarRawType>(
+ RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
+}
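// Worked example (hypothetical constant): ~1/3 supplied as the 32-bit raw
// value round(2^31 / 3) == 715827883. For a 16-bit FixedPointType,
// ScalarTypeBits == 16, so:
//   RoundingDivideByPOT(715827883, 32 - 16) == 10923
// and 10923 / 2^15 == 0.33334..., the same real value at Q0.15 precision.
// For 32-bit types the shift amount is 0 and the value passes through.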
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
template <typename FixedPointType>
-FixedPointType CheckedFixedPointConstant(
- typename FixedPointType::ScalarRawType raw_value, double double_value) {
- typedef typename FixedPointType::RawType RawType;
+FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,
+ double double_value) {
const FixedPointType result = FixedPointType::FromScalarRaw(raw_value);
assert(result == FixedPointType::FromDouble(double_value));
return result;
}
-#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
- DoubleValue) \
- (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue))
+#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \
+ ScalarRawInt32Value, DoubleValue) \
+ (gemmlowp::CheckedFixedPointConstant<FixedPointType>( \
+ gemmlowp::RescaleConstantInitializer<FixedPointType>( \
+ ScalarRawInt32Value), \
+ DoubleValue))
#else
-#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
- DoubleValue) \
- (FixedPointType::FromScalarRaw(ScalarRawValue))
+#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \
+ ScalarRawInt32Value, DoubleValue) \
+ (FixedPointType::FromScalarRaw( \
+ gemmlowp::RescaleConstantInitializer<FixedPointType>( \
+ ScalarRawInt32Value)))
#endif
// Implementation of exponential function.
@@ -620,8 +714,9 @@
F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
SaturatingRoundingMultiplyByPOT<-1>(
((x4_over_4 + x3) * constant_1_over_3) + x2);
- return constant_term +
- constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2);
+ return AddSaturatingIf16Bit(
+ constant_term,
+ constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
}
// Returns exp(x) for x < 0.
@@ -661,7 +756,7 @@
#undef GEMMLOWP_EXP_BARREL_SHIFTER
if (kIntegerBits > 5) {
- static const int b = kIntegerBits > 5 ? kFractionalBits + 5 : 0;
+ static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
const InputF clamp =
GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
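
The new bound follows from constants now always being written at 32-bit scale
(see RescaleConstantInitializer): -32.0 laid out with kIntegerBits integer
bits and 31 - kIntegerBits fractional bits is

    -32.0 * 2^(31 - kIntegerBits) = -(1 << (36 - kIntegerBits)) = -(1 << b)

whereas the old kFractionalBits + 5 only equaled this for 32-bit storage types.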
@@ -774,6 +869,8 @@
#include "./fixedpoint_neon.h"
#elif defined(GEMMLOWP_SSE4)
#include "./fixedpoint_sse.h"
+#elif defined(GEMMLOWP_MSA)
+#include "./fixedpoint_msa.h"
#endif
#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
diff --git a/fixedpoint/fixedpoint_msa.h b/fixedpoint/fixedpoint_msa.h
new file mode 100644
index 0000000..c7a110c
--- /dev/null
+++ b/fixedpoint/fixedpoint_msa.h
@@ -0,0 +1,354 @@
+// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// fixedpoint_msa.h: optimized MSA specializations of the templates
+// in fixedpoint.h.
+
+#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_
+#define GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_
+
+#include <msa.h>
+
+namespace gemmlowp {
+
+template <>
+struct FixedPointRawTypeTraits<v4i32> {
+ typedef std::int32_t ScalarRawType;
+ static const int kLanes = 4;
+};
+
+template <>
+struct FixedPointRawTypeTraits<v8i16> {
+ typedef std::int16_t ScalarRawType;
+ static const int kLanes = 8;
+};
+
+template <>
+inline v4i32 BitAnd(v4i32 a, v4i32 b) {
+ return reinterpret_cast<v4i32>(__builtin_msa_and_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(b)));
+}
+
+template <>
+inline v8i16 BitAnd(v8i16 a, v8i16 b) {
+ return reinterpret_cast<v8i16>(__builtin_msa_and_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(b)));
+}
+
+template <>
+inline v4i32 BitOr(v4i32 a, v4i32 b) {
+ return reinterpret_cast<v4i32>(__builtin_msa_or_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(b)));
+}
+
+template <>
+inline v8i16 BitOr(v8i16 a, v8i16 b) {
+ return reinterpret_cast<v8i16>(__builtin_msa_or_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(b)));
+}
+
+template <>
+inline v4i32 BitXor(v4i32 a, v4i32 b) {
+ return reinterpret_cast<v4i32>(__builtin_msa_xor_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(b)));
+}
+
+template <>
+inline v8i16 BitXor(v8i16 a, v8i16 b) {
+ return reinterpret_cast<v8i16>(__builtin_msa_xor_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(b)));
+}
+
+template <>
+inline v4i32 BitNot(v4i32 a) {
+ return reinterpret_cast<v4i32>(__builtin_msa_nor_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(a)));
+}
+
+template <>
+inline v8i16 BitNot(v8i16 a) {
+ return reinterpret_cast<v8i16>(__builtin_msa_nor_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(a)));
+}
+
+template <>
+inline v4i32 Add(v4i32 a, v4i32 b) {
+ return __builtin_msa_addv_w(a, b);
+}
+
+template <>
+inline v8i16 Add(v8i16 a, v8i16 b) {
+ return __builtin_msa_addv_h(a, b);
+}
+
+template <>
+inline v4i32 Sub(v4i32 a, v4i32 b) {
+ return __builtin_msa_subv_w(a, b);
+}
+
+template <>
+inline v8i16 Sub(v8i16 a, v8i16 b) {
+ return __builtin_msa_subv_h(a, b);
+}
+
+template <>
+inline v4i32 Neg(v4i32 a) {
+ v4i32 zeroes = __builtin_msa_ldi_w(0);
+ return __builtin_msa_subv_w(zeroes, a);
+}
+
+template <>
+inline v8i16 Neg(v8i16 a) {
+ v8i16 zeroes = __builtin_msa_ldi_h(0);
+ return __builtin_msa_subv_h(zeroes, a);
+}
+
+template <>
+inline v4i32 ShiftLeft(v4i32 a, int offset) {
+ return __builtin_msa_sll_w(a, __builtin_msa_fill_w(offset));
+}
+
+template <>
+inline v8i16 ShiftLeft(v8i16 a, int offset) {
+ return __builtin_msa_sll_h(a, __builtin_msa_fill_h(offset));
+}
+
+template <>
+inline v4i32 ShiftRight(v4i32 a, int offset) {
+ return __builtin_msa_sra_w(a, __builtin_msa_fill_w(offset));
+}
+
+template <>
+inline v8i16 ShiftRight(v8i16 a, int offset) {
+ return __builtin_msa_sra_h(a, __builtin_msa_fill_h(offset));
+}
+
+template <>
+inline v4i32 SelectUsingMask(v4i32 if_mask, v4i32 then_val, v4i32 else_val) {
+ if_mask = reinterpret_cast<v4i32>(__builtin_msa_bsel_v(reinterpret_cast<v16u8>(if_mask),
+ reinterpret_cast<v16u8>(else_val),
+ reinterpret_cast<v16u8>(then_val)));
+ return if_mask;
+}
+
+template <>
+inline v8i16 SelectUsingMask(v8i16 if_mask, v8i16 then_val, v8i16 else_val) {
+ if_mask = reinterpret_cast<v8i16>(__builtin_msa_bsel_v(reinterpret_cast<v16u8>(if_mask),
+ reinterpret_cast<v16u8>(else_val),
+ reinterpret_cast<v16u8>(then_val)));
+ return if_mask;
+}
+
+template <>
+inline v4i32 MaskIfEqual(v4i32 a, v4i32 b) {
+ return __builtin_msa_ceq_w(a, b);
+}
+
+template <>
+inline v8i16 MaskIfEqual(v8i16 a, v8i16 b) {
+ return __builtin_msa_ceq_h(a, b);
+}
+
+template <>
+inline v4i32 MaskIfNotEqual(v4i32 a, v4i32 b) {
+ return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
+inline v8i16 MaskIfNotEqual(v8i16 a, v8i16 b) {
+ return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
+inline v4i32 MaskIfZero(v4i32 a) {
+ return __builtin_msa_ceqi_w(a, 0);
+}
+
+template <>
+inline v8i16 MaskIfZero(v8i16 a) {
+ return __builtin_msa_ceqi_h(a, 0);
+}
+
+template <>
+inline v4i32 MaskIfNonZero(v4i32 a) {
+ return BitNot(MaskIfZero(a));
+}
+
+template <>
+inline v8i16 MaskIfNonZero(v8i16 a) {
+ return BitNot(MaskIfZero(a));
+}
+
+template <>
+inline v4i32 MaskIfGreaterThan(v4i32 a, v4i32 b) {
+ return __builtin_msa_clt_s_w(b, a);
+}
+
+template <>
+inline v8i16 MaskIfGreaterThan(v8i16 a, v8i16 b) {
+ return __builtin_msa_clt_s_h(b, a);
+}
+
+template <>
+inline v4i32 MaskIfGreaterThanOrEqual(v4i32 a, v4i32 b) {
+ return __builtin_msa_cle_s_w(b, a);
+}
+
+template <>
+inline v8i16 MaskIfGreaterThanOrEqual(v8i16 a, v8i16 b) {
+ return __builtin_msa_cle_s_h(b, a);
+}
+
+template <>
+inline v4i32 MaskIfLessThan(v4i32 a, v4i32 b) {
+ return __builtin_msa_clt_s_w(a, b);
+}
+
+template <>
+inline v8i16 MaskIfLessThan(v8i16 a, v8i16 b) {
+ return __builtin_msa_clt_s_h(a, b);
+}
+
+template <>
+inline v4i32 MaskIfLessThanOrEqual(v4i32 a, v4i32 b) {
+ return __builtin_msa_cle_s_w(a, b);
+}
+
+template <>
+inline v8i16 MaskIfLessThanOrEqual(v8i16 a, v8i16 b) {
+ return __builtin_msa_cle_s_h(a, b);
+}
+
+template <>
+inline bool All(v4i32 a) {
+ return __builtin_msa_bz_v(reinterpret_cast<v16u8>(BitNot(a)));
+}
+
+template <>
+inline bool All(v8i16 a) {
+ return __builtin_msa_bz_v(reinterpret_cast<v16u8>(BitNot(a)));
+}
+
+template <>
+inline bool Any(v4i32 a) {
+ return __builtin_msa_bnz_v(reinterpret_cast<v16u8>(a));
+}
+
+template <>
+inline bool Any(v8i16 a) {
+ return __builtin_msa_bnz_v(reinterpret_cast<v16u8>(a));
+}
+
+template <>
+inline v4i32 RoundingHalfSum(v4i32 a, v4i32 b) {
+ return __builtin_msa_aver_s_w(a, b);
+}
+
+template <>
+inline v8i16 RoundingHalfSum(v8i16 a, v8i16 b) {
+ return __builtin_msa_aver_s_h(a, b);
+}
+
+template <>
+inline v4i32 SaturatingRoundingDoublingHighMul(v4i32 a, v4i32 b) {
+ return __builtin_msa_mulr_q_w(a, b);
+}
+
+template <>
+inline v8i16 SaturatingRoundingDoublingHighMul(v8i16 a, v8i16 b) {
+ return __builtin_msa_mulr_q_h(a, b);
+}
+
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, 1> {
+ static v4i32 eval(v4i32 x) {
+ static_assert(Exponent >= 0 && Exponent < 32, "");
+ if (Exponent < 5) {
+ for (int i = 0; i < Exponent; i++) {
+ x = __builtin_msa_adds_s_w(x, x);
+ }
+ return x;
+ } else {
+ // Saturate each signed 32-bit element to (32 - Exponent)
+ // bits (this takes full care of negative elements).
+ v4i32 res = __builtin_msa_sat_s_w(x, 31 - Exponent);
+      // Set tmp to 0x7FFFFFFF for those elements which saturated
+ // to smaller (positive) values and 0 for all others.
+ v4i32 tmp = __builtin_msa_srli_w(__builtin_msa_clt_s_w(res, x), 1);
+ // Shift the saturated elements. The positive saturated elements
+ // will have Exponent trailing zero bits after the shift. Those
+ // need to be ones, not zeroes.
+ res = __builtin_msa_slli_w(res, Exponent);
+ // Finally, set those trailing zero bits to ones.
+ res = reinterpret_cast<v4i32>(__builtin_msa_or_v(reinterpret_cast<v16u8>(res),
+ reinterpret_cast<v16u8>(tmp)));
+ return res;
+ }
+ }
+};
+
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, 1> {
+ static v8i16 eval(v8i16 x) {
+ static_assert(Exponent >= 0 && Exponent < 16, "");
+ if (Exponent < 5) {
+ for (int i = 0; i < Exponent; i++) {
+ x = __builtin_msa_adds_s_h(x, x);
+ }
+ return x;
+ } else {
+ // Saturate each signed 16-bit element to (16 - Exponent)
+ // bits (this takes full care of negative elements).
+ v8i16 res = __builtin_msa_sat_s_h(x, 15 - Exponent);
+      // Set tmp to 0x7FFF for those elements which saturated
+ // to smaller (positive) values and 0 for all others.
+ v8i16 tmp = __builtin_msa_srli_h(__builtin_msa_clt_s_h(res, x), 1);
+ // Shift the saturated elements. The positive saturated elements
+ // will have Exponent trailing zero bits after the shift. Those
+ // need to be ones, not zeroes.
+ res = __builtin_msa_slli_h(res, Exponent);
+ // Finally, set those trailing zero bits to ones.
+ res = reinterpret_cast<v8i16>(__builtin_msa_or_v(reinterpret_cast<v16u8>(res),
+ reinterpret_cast<v16u8>(tmp)));
+ return res;
+ }
+ }
+};
+
+// TODO: possibly implement:
+// template <> v4i32 RoundingDivideByPOT(v4i32, int)
+// template <> v8i16 RoundingDivideByPOT(v8i16, int)
+// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1>
+// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1>
+
+template <>
+inline v4i32 Dup<v4i32>(std::int32_t x) {
+ return __builtin_msa_fill_w(x);
+}
+
+template <>
+inline v8i16 Dup<v8i16>(std::int16_t x) {
+ return __builtin_msa_fill_h(x);
+}
+
+// So far this is only needed for int16.
+template <>
+inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) {
+  return __builtin_msa_adds_s_h(a, b);
+}
+
+} // end namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_
diff --git a/fixedpoint/fixedpoint_neon.h b/fixedpoint/fixedpoint_neon.h
index 8b23de2..92b349b 100644
--- a/fixedpoint/fixedpoint_neon.h
+++ b/fixedpoint/fixedpoint_neon.h
@@ -29,97 +29,194 @@
};
template <>
+struct FixedPointRawTypeTraits<int16x8_t> {
+ typedef std::int16_t ScalarRawType;
+ static const int kLanes = 8;
+};
+
+template <>
inline int32x4_t BitAnd(int32x4_t a, int32x4_t b) {
return vandq_s32(a, b);
}
template <>
+inline int16x8_t BitAnd(int16x8_t a, int16x8_t b) {
+ return vandq_s16(a, b);
+}
+
+template <>
inline int32x4_t BitOr(int32x4_t a, int32x4_t b) {
return vorrq_s32(a, b);
}
template <>
+inline int16x8_t BitOr(int16x8_t a, int16x8_t b) {
+ return vorrq_s16(a, b);
+}
+
+template <>
inline int32x4_t BitXor(int32x4_t a, int32x4_t b) {
return veorq_s32(a, b);
}
template <>
+inline int16x8_t BitXor(int16x8_t a, int16x8_t b) {
+ return veorq_s16(a, b);
+}
+
+template <>
inline int32x4_t BitNot(int32x4_t a) {
return veorq_s32(a, vdupq_n_s32(-1));
}
template <>
+inline int16x8_t BitNot(int16x8_t a) {
+ return veorq_s16(a, vdupq_n_s16(-1));
+}
+
+template <>
inline int32x4_t Add(int32x4_t a, int32x4_t b) {
return vaddq_s32(a, b);
}
template <>
+inline int16x8_t Add(int16x8_t a, int16x8_t b) {
+ return vaddq_s16(a, b);
+}
+
+template <>
inline int32x4_t Sub(int32x4_t a, int32x4_t b) {
return vsubq_s32(a, b);
}
template <>
+inline int16x8_t Sub(int16x8_t a, int16x8_t b) {
+ return vsubq_s16(a, b);
+}
+
+template <>
inline int32x4_t Neg(int32x4_t a) {
return vnegq_s32(a);
}
template <>
+inline int16x8_t Neg(int16x8_t a) {
+ return vnegq_s16(a);
+}
+
+template <>
inline int32x4_t ShiftLeft(int32x4_t a, int offset) {
return vshlq_s32(a, vdupq_n_s32(offset));
}
template <>
+inline int16x8_t ShiftLeft(int16x8_t a, int offset) {
+ return vshlq_s16(a, vdupq_n_s16(offset));
+}
+
+template <>
inline int32x4_t ShiftRight(int32x4_t a, int offset) {
return vshlq_s32(a, vdupq_n_s32(-offset));
}
template <>
+inline int16x8_t ShiftRight(int16x8_t a, int offset) {
+ return vshlq_s16(a, vdupq_n_s16(-offset));
+}
+
+template <>
inline int32x4_t SelectUsingMask(int32x4_t if_mask, int32x4_t then_val,
int32x4_t else_val) {
return vbslq_s32(vreinterpretq_u32_s32(if_mask), then_val, else_val);
}
template <>
+inline int16x8_t SelectUsingMask(int16x8_t if_mask, int16x8_t then_val,
+ int16x8_t else_val) {
+ return vbslq_s16(vreinterpretq_u16_s16(if_mask), then_val, else_val);
+}
+
+template <>
inline int32x4_t MaskIfEqual(int32x4_t a, int32x4_t b) {
return vreinterpretq_s32_u32(vceqq_s32(a, b));
}
template <>
+inline int16x8_t MaskIfEqual(int16x8_t a, int16x8_t b) {
+ return vreinterpretq_s16_u16(vceqq_s16(a, b));
+}
+
+template <>
inline int32x4_t MaskIfNotEqual(int32x4_t a, int32x4_t b) {
return BitNot(MaskIfEqual(a, b));
}
template <>
+inline int16x8_t MaskIfNotEqual(int16x8_t a, int16x8_t b) {
+ return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
inline int32x4_t MaskIfZero(int32x4_t a) {
return MaskIfEqual(a, vdupq_n_s32(0));
}
template <>
+inline int16x8_t MaskIfZero(int16x8_t a) {
+ return MaskIfEqual(a, vdupq_n_s16(0));
+}
+
+template <>
inline int32x4_t MaskIfNonZero(int32x4_t a) {
return vreinterpretq_s32_u32(vtstq_s32(a, a));
}
template <>
+inline int16x8_t MaskIfNonZero(int16x8_t a) {
+ return vreinterpretq_s16_u16(vtstq_s16(a, a));
+}
+
+template <>
inline int32x4_t MaskIfGreaterThan(int32x4_t a, int32x4_t b) {
return vreinterpretq_s32_u32(vcgtq_s32(a, b));
}
template <>
+inline int16x8_t MaskIfGreaterThan(int16x8_t a, int16x8_t b) {
+ return vreinterpretq_s16_u16(vcgtq_s16(a, b));
+}
+
+template <>
inline int32x4_t MaskIfGreaterThanOrEqual(int32x4_t a, int32x4_t b) {
return vreinterpretq_s32_u32(vcgeq_s32(a, b));
}
template <>
+inline int16x8_t MaskIfGreaterThanOrEqual(int16x8_t a, int16x8_t b) {
+ return vreinterpretq_s16_u16(vcgeq_s16(a, b));
+}
+
+template <>
inline int32x4_t MaskIfLessThan(int32x4_t a, int32x4_t b) {
return vreinterpretq_s32_u32(vcltq_s32(a, b));
}
template <>
+inline int16x8_t MaskIfLessThan(int16x8_t a, int16x8_t b) {
+ return vreinterpretq_s16_u16(vcltq_s16(a, b));
+}
+
+template <>
inline int32x4_t MaskIfLessThanOrEqual(int32x4_t a, int32x4_t b) {
return vreinterpretq_s32_u32(vcleq_s32(a, b));
}
template <>
+inline int16x8_t MaskIfLessThanOrEqual(int16x8_t a, int16x8_t b) {
+ return vreinterpretq_s16_u16(vcleq_s16(a, b));
+}
+
+template <>
inline bool All(int32x4_t a) {
a = vandq_s32(a, vextq_s32(a, a, 1));
a = vandq_s32(a, vextq_s32(a, a, 2));
@@ -127,6 +224,14 @@
}
template <>
+inline bool All(int16x8_t a) {
+ a = vandq_s16(a, vextq_s16(a, a, 1));
+ a = vandq_s16(a, vextq_s16(a, a, 2));
+ a = vandq_s16(a, vextq_s16(a, a, 4));
+ return vgetq_lane_s16(a, 0);
+}
+
+template <>
inline bool Any(int32x4_t a) {
a = vorrq_s32(a, vextq_s32(a, a, 1));
a = vorrq_s32(a, vextq_s32(a, a, 2));
@@ -134,16 +239,34 @@
}
template <>
+inline bool Any(int16x8_t a) {
+ a = vorrq_s16(a, vextq_s16(a, a, 1));
+ a = vorrq_s16(a, vextq_s16(a, a, 2));
+ a = vorrq_s16(a, vextq_s16(a, a, 4));
+ return vgetq_lane_s16(a, 0);
+}
+
+template <>
inline int32x4_t RoundingHalfSum(int32x4_t a, int32x4_t b) {
return vrhaddq_s32(a, b);
}
template <>
+inline int16x8_t RoundingHalfSum(int16x8_t a, int16x8_t b) {
+ return vrhaddq_s16(a, b);
+}
+
+template <>
inline int32x4_t SaturatingRoundingDoublingHighMul(int32x4_t a, int32x4_t b) {
return vqrdmulhq_s32(a, b);
}
template <>
+inline int16x8_t SaturatingRoundingDoublingHighMul(int16x8_t a, int16x8_t b) {
+ return vqrdmulhq_s16(a, b);
+}
+
+template <>
inline int32x4_t RoundingDivideByPOT(int32x4_t x, int exponent) {
const int32x4_t shift_vec = vdupq_n_s32(-exponent);
const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
@@ -151,6 +274,14 @@
return vrshlq_s32(fixed_up_x, shift_vec);
}
+template <>
+inline int16x8_t RoundingDivideByPOT(int16x8_t x, int exponent) {
+ const int16x8_t shift_vec = vdupq_n_s16(-exponent);
+ const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
+ const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
+ return vrshlq_s16(fixed_up_x, shift_vec);
+}
+
template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
@@ -165,11 +296,36 @@
}
};
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, 1> {
+ static int16x8_t eval(int16x8_t x) { return vqshlq_n_s16(x, Exponent); }
+};
+
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, -1> {
+ static int16x8_t eval(int16x8_t x) {
+ const int16x8_t fixup = vshrq_n_s16(x, 15);
+ const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
+ return vrshrq_n_s16(fixed_up_x, -Exponent);
+ }
+};
+
template <>
inline int32x4_t Dup<int32x4_t>(std::int32_t x) {
return vdupq_n_s32(x);
}
+template <>
+inline int16x8_t Dup<int16x8_t>(std::int16_t x) {
+ return vdupq_n_s16(x);
+}
+
+// So far this is only needed for int16.
+template <>
+inline int16x8_t SaturatingAdd(int16x8_t a, int16x8_t b) {
+ return vqaddq_s16(a, b);
+}
+
} // end namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
diff --git a/fixedpoint/fixedpoint_sse.h b/fixedpoint/fixedpoint_sse.h
index 3f2654d..ba990f0 100644
--- a/fixedpoint/fixedpoint_sse.h
+++ b/fixedpoint/fixedpoint_sse.h
@@ -23,6 +23,22 @@
namespace gemmlowp {
+// SSE intrinsics are not finely typed: there is a single __m128i vector
+// type that does not distinguish between "int32x4" and "int16x8" use
+// cases, unlike the NEON equivalents. Because we had initially focused
+// on int32x4, we did not pay attention and specialized these fixedpoint
+// templates directly for __m128i, hardcoding the int32x4 semantics and
+// leaving no room for int16x8 semantics. We amend that by adding a separate
+// data type, int16x8_m128i, that wraps __m128i while being a distinct type
+// for the purposes of template specialization.
+struct int16x8_m128i {
+ int16x8_m128i() {}
+ explicit int16x8_m128i(__m128i w) : v(w) {}
+ ~int16x8_m128i() {}
+
+ __m128i v;
+};
+
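// The wrapper exists purely so that template specialization can tell the
// int32x4 and int16x8 interpretations apart. Usage sketch:
//   int16x8_m128i v(_mm_set1_epi16(5));  // wrap: explicit from __m128i
//   __m128i raw = v.v;                   // unwrap when calling intrinsics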
template <>
struct FixedPointRawTypeTraits<__m128i> {
typedef std::int32_t ScalarRawType;
@@ -30,61 +46,125 @@
};
template <>
+struct FixedPointRawTypeTraits<int16x8_m128i> {
+ typedef std::int16_t ScalarRawType;
+ static const int kLanes = 8;
+};
+
+template <>
inline __m128i BitAnd(__m128i a, __m128i b) {
return _mm_and_si128(a, b);
}
template <>
+inline int16x8_m128i BitAnd(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_and_si128(a.v, b.v));
+}
+
+template <>
inline __m128i BitOr(__m128i a, __m128i b) {
return _mm_or_si128(a, b);
}
template <>
+inline int16x8_m128i BitOr(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_or_si128(a.v, b.v));
+}
+
+template <>
inline __m128i BitXor(__m128i a, __m128i b) {
return _mm_xor_si128(a, b);
}
template <>
+inline int16x8_m128i BitXor(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_xor_si128(a.v, b.v));
+}
+
+template <>
inline __m128i BitNot(__m128i a) {
return _mm_andnot_si128(a, _mm_set1_epi32(-1));
}
template <>
+inline int16x8_m128i BitNot(int16x8_m128i a) {
+ return int16x8_m128i(_mm_andnot_si128(a.v, _mm_set1_epi16(-1)));
+}
+
+template <>
inline __m128i Add(__m128i a, __m128i b) {
return _mm_add_epi32(a, b);
}
template <>
+inline int16x8_m128i Add(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_add_epi16(a.v, b.v));
+}
+
+template <>
inline __m128i Mul(__m128i a, __m128i b) {
return _mm_mullo_epi32(a, b);
}
template <>
+inline int16x8_m128i Mul(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_mullo_epi16(a.v, b.v));
+}
+
+template <>
inline __m128i Sub(__m128i a, __m128i b) {
return _mm_sub_epi32(a, b);
}
template <>
+inline int16x8_m128i Sub(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_sub_epi16(a.v, b.v));
+}
+
+template <>
inline __m128i Neg(__m128i a) {
return _mm_sign_epi32(a, _mm_set1_epi32(-1));
}
template <>
+inline int16x8_m128i Neg(int16x8_m128i a) {
+ return int16x8_m128i(_mm_sign_epi16(a.v, _mm_set1_epi16(-1)));
+}
+
+template <>
inline __m128i ShiftLeft(__m128i a, int offset) {
return _mm_slli_epi32(a, offset);
}
template <>
+inline int16x8_m128i ShiftLeft(int16x8_m128i a, int offset) {
+ return int16x8_m128i(_mm_slli_epi16(a.v, offset));
+}
+
+template <>
inline __m128i ShiftRight(__m128i a, int offset) {
return _mm_srai_epi32(a, offset);
}
template <>
+inline int16x8_m128i ShiftRight(int16x8_m128i a, int offset) {
+ return int16x8_m128i(_mm_srai_epi16(a.v, offset));
+}
+
+template <>
inline __m128i SelectUsingMask(__m128i if_mask, __m128i then_val,
__m128i else_val) {
- return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(else_val),
- _mm_castsi128_ps(then_val),
- _mm_castsi128_ps(if_mask)));
+ // borrowed from Intel's arm_neon_sse.h header.
+ return _mm_or_si128(_mm_and_si128(if_mask, then_val),
+ _mm_andnot_si128(if_mask, else_val));
+}
+
+template <>
+inline int16x8_m128i SelectUsingMask(int16x8_m128i if_mask,
+ int16x8_m128i then_val,
+ int16x8_m128i else_val) {
+ // borrowed from Intel's arm_neon_sse.h header.
+ return int16x8_m128i(SelectUsingMask(if_mask.v, then_val.v, else_val.v));
}
template <>
@@ -93,40 +173,81 @@
}
template <>
+inline int16x8_m128i MaskIfEqual(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_cmpeq_epi16(a.v, b.v));
+}
+
+template <>
inline __m128i MaskIfNotEqual(__m128i a, __m128i b) {
return BitNot(MaskIfEqual(a, b));
}
template <>
+inline int16x8_m128i MaskIfNotEqual(int16x8_m128i a, int16x8_m128i b) {
+ return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
inline __m128i MaskIfZero(__m128i a) {
return MaskIfEqual(a, _mm_set1_epi32(0));
}
template <>
+inline int16x8_m128i MaskIfZero(int16x8_m128i a) {
+ return MaskIfEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
+}
+
+template <>
inline __m128i MaskIfNonZero(__m128i a) {
return MaskIfNotEqual(a, _mm_set1_epi32(0));
}
template <>
+inline int16x8_m128i MaskIfNonZero(int16x8_m128i a) {
+ return MaskIfNotEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
+}
+
+template <>
inline __m128i MaskIfGreaterThan(__m128i a, __m128i b) {
return _mm_cmpgt_epi32(a, b);
}
template <>
+inline int16x8_m128i MaskIfGreaterThan(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_cmpgt_epi16(a.v, b.v));
+}
+
+template <>
inline __m128i MaskIfLessThan(__m128i a, __m128i b) {
return _mm_cmplt_epi32(a, b);
}
template <>
+inline int16x8_m128i MaskIfLessThan(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_cmplt_epi16(a.v, b.v));
+}
+
+template <>
inline __m128i MaskIfGreaterThanOrEqual(__m128i a, __m128i b) {
return BitNot(MaskIfLessThan(a, b));
}
template <>
+inline int16x8_m128i MaskIfGreaterThanOrEqual(int16x8_m128i a,
+ int16x8_m128i b) {
+ return BitNot(MaskIfLessThan(a, b));
+}
+
+template <>
inline __m128i MaskIfLessThanOrEqual(__m128i a, __m128i b) {
return BitNot(MaskIfGreaterThan(a, b));
}
+template <>
+inline int16x8_m128i MaskIfLessThanOrEqual(int16x8_m128i a, int16x8_m128i b) {
+ return BitNot(MaskIfGreaterThan(a, b));
+}
+
/* Assumptions:
- All and Any are used on masks.
- masks are all_ones for true lanes, all_zeroes otherwise.
@@ -139,8 +260,18 @@
}
template <>
+inline bool All(int16x8_m128i a) {
+ return _mm_testc_si128(a.v, a.v);
+}
+
+template <>
inline bool Any(__m128i a) {
- return BitNot(_mm_testz_si128(a, a));
+ return !_mm_testz_si128(a, a);
+}
+
+template <>
+inline bool Any(int16x8_m128i a) {
+ return !_mm_testz_si128(a.v, a.v);
}
template <>
@@ -171,6 +302,18 @@
}
template <>
+inline int16x8_m128i RoundingHalfSum(int16x8_m128i a, int16x8_m128i b) {
+ // Idea: go to unsigned to use _mm_avg_epu16,
+ // borrowed from Intel's arm_neon_sse.h header.
+ __m128i constant_neg_32768 = _mm_set1_epi16(-32768);
+ __m128i a_unsigned = _mm_sub_epi16(a.v, constant_neg_32768);
+ __m128i b_unsigned = _mm_sub_epi16(b.v, constant_neg_32768);
+ __m128i avg_unsigned = _mm_avg_epu16(a_unsigned, b_unsigned);
+ __m128i avg = _mm_add_epi16(avg_unsigned, constant_neg_32768);
+ return int16x8_m128i(avg);
+}
+
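// Scalar model of the unsigned-average trick above (a sketch): adding 32768
// maps int16 onto uint16 order-preservingly, _mm_avg_epu16 then computes the
// unsigned rounding average, and the bias cancels:
//   avg_epu16(a + 32768, b + 32768) - 32768
//     == ((a + b + 65536 + 1) >> 1) - 32768
//     == (a + b + 1) >> 1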
+template <>
inline __m128i SaturatingRoundingDoublingHighMul(__m128i a, __m128i b) {
__m128i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
__m128i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
@@ -209,10 +352,33 @@
}
template <>
+inline int16x8_m128i SaturatingRoundingDoublingHighMul(int16x8_m128i a,
+ int16x8_m128i b) {
+ // Idea: use _mm_mulhrs_epi16 then saturate with a bit-operation,
+ // borrowed from Intel's arm_neon_sse.h header.
+ __m128i result_unsaturated = _mm_mulhrs_epi16(a.v, b.v);
+ __m128i saturation_mask =
+ _mm_cmpeq_epi16(result_unsaturated, _mm_set1_epi16(0x8000));
+ __m128i result = _mm_xor_si128(result_unsaturated, saturation_mask);
+ return int16x8_m128i(result);
+}
+
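// Per lane, _mm_mulhrs_epi16 computes ((a * b) + (1 << 14)) >> 15. The only
// input where that differs from the saturating definition is
// a == b == -32768, whose result wraps to 0x8000 (-32768) instead of +32767;
// XOR-ing exactly those lanes with the all-ones compare mask flips 0x8000
// to 0x7FFF.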
+template <>
inline __m128i Dup<__m128i>(std::int32_t x) {
return _mm_set1_epi32(x);
}
+template <>
+inline int16x8_m128i Dup<int16x8_m128i>(std::int16_t x) {
+ return int16x8_m128i(_mm_set1_epi16(x));
+}
+
+// So far this is only needed for int16.
+template <>
+inline int16x8_m128i SaturatingAdd(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_adds_epi16(a.v, b.v));
+}
+
} // end namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_
diff --git a/internal/common.h b/internal/common.h
index 9de151b..26b6713 100644
--- a/internal/common.h
+++ b/internal/common.h
@@ -55,6 +55,19 @@
#define GEMMLOWP_ARM
#endif
+// Detect MIPS, 32-bit or 64-bit
+#if defined(__mips) && !defined(__LP64__)
+#define GEMMLOWP_MIPS_32
+#endif
+
+#if defined(__mips) && defined(__LP64__)
+#define GEMMLOWP_MIPS_64
+#endif
+
+#if defined(GEMMLOWP_MIPS_32) || defined(GEMMLOWP_MIPS_64)
+#define GEMMLOWP_MIPS
+#endif
+
// Detect x86, 32-bit or 64-bit
#if defined(__i386__) || defined(_M_IX86) || defined(_X86_) || defined(__i386)
#define GEMMLOWP_X86_32
@@ -87,6 +100,23 @@
#define GEMMLOWP_NEON_64
#endif
+// Detect MIPS MSA.
+// Limit MSA optimizations to little-endian CPUs for now.
+// TODO: Perhaps eventually support MSA optimizations on big-endian CPUs?
+#if defined(GEMMLOWP_MIPS) && (__mips_isa_rev >= 5) && defined(__mips_msa) && \
+ defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+#define GEMMLOWP_MSA
+#endif
+
+// Convenience MIPS MSA tokens for 32-bit or 64-bit.
+#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_32)
+#define GEMMLOWP_MSA_32
+#endif
+
+#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_64)
+#define GEMMLOWP_MSA_64
+#endif
+
// Detect SSE.
#ifdef __SSE4_1__
#define GEMMLOWP_SSE4
@@ -97,7 +127,8 @@
#endif
// Convenience SSE4 tokens for 32-bit or 64-bit
-#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32)
+#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32) && \
+ !defined(GEMMLOWP_DISABLE_SSE4)
#define GEMMLOWP_SSE4_32
#endif
@@ -105,7 +136,8 @@
#define GEMMLOWP_SSE3_32
#endif
-#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64)
+#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64) && \
+ !defined(GEMMLOWP_DISABLE_SSE4)
#define GEMMLOWP_SSE4_64
#endif
@@ -178,6 +210,10 @@
// x86-32 and not Android. Same as x86-64 but less bullish.
const int kDefaultL1CacheSize = 32 * 1024;
const int kDefaultL2CacheSize = 2 * 1024 * 1024;
+#elif defined(GEMMLOWP_MIPS)
+// MIPS and not Android. TODO: MIPS and Android?
+const int kDefaultL1CacheSize = 32 * 1024;
+const int kDefaultL2CacheSize = 1024 * 1024;
#else
// Less common hardware. Maybe some unusual or older or embedded thing.
// Assume smaller caches, but don't depart too far from what we do
diff --git a/internal/kernel_default.h b/internal/kernel_default.h
index 7037bda..a919ffe 100644
--- a/internal/kernel_default.h
+++ b/internal/kernel_default.h
@@ -18,18 +18,13 @@
#ifndef GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_
#define GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_
-#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
-#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
-#endif
-
#include "../public/bit_depth.h"
#include "common.h"
#include "kernel_reference.h"
namespace gemmlowp {
-template <bool MaxProductIsLessThan4096,
- bool LhsAlwaysNonzero>
+template <bool MaxProductIsLessThan4096, bool LhsAlwaysNonzero>
struct DefaultKernelImpl {};
// Partial specialization implementing the logic that if we want to use
@@ -56,12 +51,12 @@
} // end namespace gemmlowp
-#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, \
- LhsAlwaysNonzero, Kernel) \
- namespace gemmlowp { \
- template <> \
- struct DefaultKernelImpl<MaxProductIsLessThan4096, \
- LhsAlwaysNonzero> : Kernel {}; \
+#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, \
+ LhsAlwaysNonzero, Kernel) \
+ namespace gemmlowp { \
+ template <> \
+ struct DefaultKernelImpl<MaxProductIsLessThan4096, LhsAlwaysNonzero> \
+ : Kernel {}; \
}
#if defined GEMMLOWP_NEON_32
@@ -76,6 +71,9 @@
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_64_Kernel12x8Depth2)
GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
NEON_64bit_GEMM_Int8Operands_LhsNonzero)
+#elif defined(GEMMLOWP_MSA)
+#include "kernel_msa.h"
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, MSA_Kernel12x8Depth2)
#elif defined GEMMLOWP_SSE4_32
#include "kernel_sse.h"
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_32_Kernel4x4Depth2)
@@ -83,23 +81,6 @@
#include "kernel_sse.h"
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_64_Kernel12x4Depth2)
#else
-#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
-#if defined __ARM_ARCH_5TE__
-// SIMD is not available on this platform. The slow fallback will be used.
-// Don't require GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK because there's nothing
-// the user can do about it.
-#elif defined __powerpc__
-// There is currently no fast kernel using SIMD instructions on POWER. Don't
-// require GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK because there's nothing the user
-// can do about it.
-#else
-#error \
- "SIMD not enabled, you'd be getting a slow software fallback. Consider \
-enabling SIMD extensions (for example using -msse4 if you're on modern x86). \
-If that's not an option, and you would like to continue with the \
-slow fallback, define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK."
-#endif
-#endif
#include "kernel_reference.h"
namespace gemmlowp {
typedef ReferenceKernel<KernelFormat<
diff --git a/internal/kernel_msa.h b/internal/kernel_msa.h
new file mode 100644
index 0000000..4985b73
--- /dev/null
+++ b/internal/kernel_msa.h
@@ -0,0 +1,339 @@
+// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// kernel_msa.h: a collection of MSA optimized kernels.
+// See kernel_default.h to check which one(s) are actually used by default.
+// The others are mere experiments; they are still covered by tests
+// in case they might be useful some day.
+
+#ifndef GEMMLOWP_INTERNAL_KERNEL_MSA_H_
+#define GEMMLOWP_INTERNAL_KERNEL_MSA_H_
+
+#include "kernel.h"
+
+#include <msa.h>
+#include <cassert>
+
+namespace gemmlowp {
+
+#ifdef GEMMLOWP_MSA
+
+// Some convenience macros to hide differences between MIPS32 and MIPS64.
+#ifdef GEMMLOWP_MIPS_64
+#define GEMMLOWP_MIPS_XADDU "daddu"
+#define GEMMLOWP_MIPS_XADDIU "daddiu"
+#define GEMMLOWP_MIPS_XSLL "dsll"
+#else
+#define GEMMLOWP_MIPS_XADDU "addu"
+#define GEMMLOWP_MIPS_XADDIU "addiu"
+#define GEMMLOWP_MIPS_XSLL "sll"
+#endif
+
+// Our main GEMM kernel.
+struct MSA_Kernel12x8Depth2 : KernelBase {
+ typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
+ KernelSideFormat<CellFormat<4, 2>, 2> >
+ Format;
+
+ const char* Name() const override { return "MSA, 12x8, depth 2"; }
+
+ // TODO(benoitjacob): reorder function arguments so dst comes last
+ void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
+ std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
+ const std::uint8_t* rhs_ptr, std::size_t start_depth,
+ std::size_t run_depth) const override {
+ ScopedProfilingLabel label("optimized kernel (MSA 12x8)");
+// See comments above for why we need local numerical labels in our asm.
+#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
+#define GEMMLOWP_LABEL_BEFORE_LOOP "2"
+#define GEMMLOWP_LABEL_LOOP "3"
+#define GEMMLOWP_LABEL_AFTER_LOOP "4"
+
+ assert(dst_row_stride == 1);
+ asm volatile(
+ // Set a temp to all zeroes.
+ "ldi.b $w31, 0\n"
+
+ // Multiply dst_col_stride by 4 == sizeof(int32) to use
+ // it as a byte offset below.
+ GEMMLOWP_MIPS_XSLL
+ " %[dst_col_stride], %[dst_col_stride], 2\n"
+
+ // Check if start_depth==0 to decide whether we will clear
+ // accumulators or load existing accumulators.
+ "beqz %[start_depth], " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n"
+
+ // Load accumulators (start_depth != 0).
+ GEMMLOWP_MIPS_XADDU
+ " $a0, %[dst_ptr], %[dst_col_stride]\n"
+ "ld.w $w0, (0*16)(%[dst_ptr])\n"
+ "ld.w $w4, (1*16)(%[dst_ptr])\n"
+ "ld.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
+ "ld.w $w1, (0*16)($a0)\n"
+ "ld.w $w5, (1*16)($a0)\n"
+ "ld.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
+ "ld.w $w2, (0*16)($a1)\n"
+ "ld.w $w6, (1*16)($a1)\n"
+ "ld.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
+ "ld.w $w3, (0*16)($a0)\n"
+ "ld.w $w7, (1*16)($a0)\n"
+ "ld.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
+ "ld.w $w12, (0*16)($a1)\n"
+ "ld.w $w16, (1*16)($a1)\n"
+ "ld.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
+ "ld.w $w13, (0*16)($a0)\n"
+ "ld.w $w17, (1*16)($a0)\n"
+ "ld.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
+ "ld.w $w14, (0*16)($a1)\n"
+ "ld.w $w18, (1*16)($a1)\n"
+ "ld.w $w22, (2*16)($a1)\n"
+ "ld.w $w15, (0*16)($a0)\n"
+ "ld.w $w19, (1*16)($a0)\n"
+ "ld.w $w23, (2*16)($a0)\n"
+ "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"
+
+ GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
+ ":\n"
+ // Clear accumulators (start_depth == 0).
+ "ldi.w $w0, 0\n"
+ "ldi.w $w4, 0\n"
+ "ldi.w $w8, 0\n"
+ "ldi.w $w1, 0\n"
+ "ldi.w $w5, 0\n"
+ "ldi.w $w9, 0\n"
+ "ldi.w $w2, 0\n"
+ "ldi.w $w6, 0\n"
+ "ldi.w $w10, 0\n"
+ "ldi.w $w3, 0\n"
+ "ldi.w $w7, 0\n"
+ "ldi.w $w11, 0\n"
+ "ldi.w $w12, 0\n"
+ "ldi.w $w16, 0\n"
+ "ldi.w $w20, 0\n"
+ "ldi.w $w13, 0\n"
+ "ldi.w $w17, 0\n"
+ "ldi.w $w21, 0\n"
+ "ldi.w $w14, 0\n"
+ "ldi.w $w18, 0\n"
+ "ldi.w $w22, 0\n"
+ "ldi.w $w15, 0\n"
+ "ldi.w $w19, 0\n"
+ "ldi.w $w23, 0\n"
+
+ GEMMLOWP_LABEL_BEFORE_LOOP ":\n"
+
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
+ // Overview of register layout:
+ //
+ // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
+ // (each register contains 4 replicas of a pair of elements).
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
+ // A 12x8 block of accumulators is stored in 32bit in w0-w23.
+ //
+ // +------+------+------+------+
+ // Rhs |w27 |w28 |w29 |w30 |
+ // +------+------+------+------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+ - - - - +------+------+------+------+
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // +---+ - - - - +------+------+------+------+
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // +---+ - - - - +------+------+------+------+
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // +---+ - - - - +------+------+------+------+
+ //
+ // Accumulators
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w24, 0(%[lhs_ptr])\n"
+ "ld.b $w25, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for the first half of depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the first half of depth 1.
+ "lbu $v0, 4(%[rhs_ptr])\n"
+ "lbu $v1, 5(%[rhs_ptr])\n"
+ "lbu $t8, 6(%[rhs_ptr])\n"
+ "lbu $t9, 7(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w24, $w31, $w24\n"
+ "ilvl.b $w26, $w31, $w25\n"
+ "ilvr.b $w25, $w31, $w25\n"
+ // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
+ "ilvl.d $w27, $w31, $w24\n"
+ "ilvl.d $w28, $w31, $w25\n"
+ "ilvl.d $w29, $w31, $w26\n"
+ "ilvr.h $w24, $w27, $w24\n"
+ "ilvr.h $w25, $w28, $w25\n"
+ "ilvr.h $w26, $w29, $w26\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for
+ // dpadd_u.w (for the first half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Load 4 bytes of rhs[] for the second half of depth 0.
+ "lbu $a0, 8(%[rhs_ptr])\n"
+ "lbu $a1, 9(%[rhs_ptr])\n"
+ "lbu $a2, 10(%[rhs_ptr])\n"
+ "lbu $a3, 11(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the second half of depth 1.
+ "lbu $v0, 12(%[rhs_ptr])\n"
+ "lbu $v1, 13(%[rhs_ptr])\n"
+ "lbu $t8, 14(%[rhs_ptr])\n"
+ "lbu $t9, 15(%[rhs_ptr])\n"
+
+ // First half of depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w0, $w24, $w27\n"
+ "dpadd_u.w $w4, $w25, $w27\n"
+ "dpadd_u.w $w8, $w26, $w27\n"
+ "dpadd_u.w $w1, $w24, $w28\n"
+ "dpadd_u.w $w5, $w25, $w28\n"
+ "dpadd_u.w $w9, $w26, $w28\n"
+ "dpadd_u.w $w2, $w24, $w29\n"
+ "dpadd_u.w $w6, $w25, $w29\n"
+ "dpadd_u.w $w10, $w26, $w29\n"
+ "dpadd_u.w $w3, $w24, $w30\n"
+ "dpadd_u.w $w7, $w25, $w30\n"
+ "dpadd_u.w $w11, $w26, $w30\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for
+ // dpadd_u.w (for the second half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Second half of depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w12, $w24, $w27\n"
+ "dpadd_u.w $w16, $w25, $w27\n"
+ "dpadd_u.w $w20, $w26, $w27\n"
+ "dpadd_u.w $w13, $w24, $w28\n"
+ "dpadd_u.w $w17, $w25, $w28\n"
+ "dpadd_u.w $w21, $w26, $w28\n"
+ "dpadd_u.w $w14, $w24, $w29\n"
+ "dpadd_u.w $w18, $w25, $w29\n"
+ "dpadd_u.w $w22, $w26, $w29\n"
+ "dpadd_u.w $w15, $w24, $w30\n"
+ "dpadd_u.w $w19, $w25, $w30\n"
+ "dpadd_u.w $w23, $w26, $w30\n"
+
+ GEMMLOWP_MIPS_XADDIU " %[run_depth], -2\n" GEMMLOWP_MIPS_XADDIU
+ " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU
+ " %[rhs_ptr], 16\n"
+ "bnez %[run_depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ GEMMLOWP_LABEL_AFTER_LOOP ":\n"
+
+ // Store accumulators.
+ GEMMLOWP_MIPS_XADDU
+ " $a0, %[dst_ptr], %[dst_col_stride]\n"
+ "st.w $w0, (0*16)(%[dst_ptr])\n"
+ "st.w $w4, (1*16)(%[dst_ptr])\n"
+ "st.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
+ "st.w $w1, (0*16)($a0)\n"
+ "st.w $w5, (1*16)($a0)\n"
+ "st.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
+ "st.w $w2, (0*16)($a1)\n"
+ "st.w $w6, (1*16)($a1)\n"
+ "st.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
+ "st.w $w3, (0*16)($a0)\n"
+ "st.w $w7, (1*16)($a0)\n"
+ "st.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
+ "st.w $w12, (0*16)($a1)\n"
+ "st.w $w16, (1*16)($a1)\n"
+ "st.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
+ "st.w $w13, (0*16)($a0)\n"
+ "st.w $w17, (1*16)($a0)\n"
+ "st.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
+ "st.w $w14, (0*16)($a1)\n"
+ "st.w $w18, (1*16)($a1)\n"
+ "st.w $w22, (2*16)($a1)\n"
+ "st.w $w15, (0*16)($a0)\n"
+ "st.w $w19, (1*16)($a0)\n"
+ "st.w $w23, (2*16)($a0)\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [run_depth] "+r"(run_depth),
+ [dst_col_stride] "+r"(dst_col_stride)
+ : // inputs
+ [dst_ptr] "r"(dst_ptr),
+ [start_depth] "r"(start_depth)
+ : // clobbers
+ "memory", "v0", "v1", "a0", "a1", "a2", "a3", "t8", "t9", "$f0", "$f1",
+ "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9", "$f10", "$f11",
+ "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20",
+ "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29",
+ "$f30", "$f31");
+
+#undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
+#undef GEMMLOWP_LABEL_BEFORE_LOOP
+#undef GEMMLOWP_LABEL_LOOP
+#undef GEMMLOWP_LABEL_AFTER_LOOP
+ }
+};
+
+#undef GEMMLOWP_MIPS_XADDU
+#undef GEMMLOWP_MIPS_XADDIU
+#undef GEMMLOWP_MIPS_XSLL
+
+#endif // GEMMLOWP_MSA
+
+} // namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_KERNEL_MSA_H_
diff --git a/internal/kernel_neon.h b/internal/kernel_neon.h
index 5c253ba..3cd48f4 100644
--- a/internal/kernel_neon.h
+++ b/internal/kernel_neon.h
@@ -421,52 +421,52 @@
GEMMLOWP_LOOP_NEON_32_KERNEL_12X4_DEPTH2_ASSUMING_12BIT_PRODUCTS
":\n"
-// Overview of register layout:
-//
-// Registers q4--q16 are the local 16-bit accumulators.
-// However, each entry in the result matrix is represented
-// by *two* local 16-bit accumulators: one for even levels
-// of depth and one for odd levels of depth. These correspond
-// to the scalars at even and odd indices within each q-register.
-// Thus we effectively use 32 bits of register space for each
-// entry in the result matrix. The accumulators register layout
-// is the same as was described above for the global 32-bit
-// accumulators (3 cells of size 4x4 in diagonal-major order)
-// with the only difference that instead of 32bit values we have
-// pairs of 16bit values.
-//
-// A 2x4 cell of Rhs is stored in 8bit in d0.
-// A 12x2 block of 3 4x2 cells Lhs is stored in 8bit in d1--d3.
-//
-// +--------+--------+--------+--------+
-// |d0[0] |d0[2] |d0[4] |d0[6] |
-// Rhs +--------+--------+--------+--------+
-// |d0[1] |d0[3] |d0[5] |d0[7] |
-// +--------+--------+--------+--------+
-//
-// | | | | |
-//
-// Lhs | | | | |
-//
-// +-----+-----+ - - - +--------+--------+--------+--------+
-// |d1[0]|d1[1]| |q4[0,1] |q5[0,1] |q6[0,1] |q7[0,1] |
-// |d1[2]|d1[3]| |q7[2,3] |q4[2,3] |q5[2,3] |q6[2,3] |
-// |d1[4]|d1[5]| |q6[4,5] |q7[4,5] |q4[4,5] |q5[4,5] |
-// |d1[6]|d1[7]| |q5[6,7] |q6[6,7] |q7[6,7] |q4[6,7] |
-// +-----+-----+ - - - +--------+--------+--------+--------+
-// |d2[0]|d2[1]| |q8[0,1] |q8[0,1] |q8[0,1] |q8[0,1] |
-// |d2[2]|d2[3]| |q9[2,3] |q9[2,3] |q9[2,3] |q9[2,3] |
-// |d2[4]|d2[5]| |q10[4,5]|q10[4,5]|q10[4,5]|q10[4,5]|
-// |d2[6]|d2[7]| |q11[6,7]|q11[6,7]|q11[6,7]|q11[6,7]|
-// +-----+-----+ - - - +--------+--------+--------+--------+
-// |d3[0]|d3[1]| |q12[0,1]|q12[0,1]|q12[0,1]|q12[0,1]|
-// |d3[2]|d3[3]| |q13[2,3]|q13[2,3]|q13[2,3]|q13[2,3]|
-// |d3[4]|d3[5]| |q14[4,5]|q14[4,5]|q14[4,5]|q14[4,5]|
-// |d3[6]|d3[7]| |q15[6,7]|q15[6,7]|q15[6,7]|q15[6,7]|
-// +-----+-----+ - - - +--------+--------+--------+--------+
-//
-// Local 16-bit accumulators
-// Note: 2 scalars per matrix entry
+ // Overview of register layout:
+ //
+ // Registers q4--q16 are the local 16-bit accumulators.
+ // However, each entry in the result matrix is represented
+ // by *two* local 16-bit accumulators: one for even levels
+ // of depth and one for odd levels of depth. These correspond
+ // to the scalars at even and odd indices within each q-register.
+ // Thus we effectively use 32 bits of register space for each
+ // entry in the result matrix. The accumulators register layout
+ // is the same as was described above for the global 32-bit
+ // accumulators (3 cells of size 4x4 in diagonal-major order)
+ // with the only difference that instead of 32bit values we have
+ // pairs of 16bit values.
+ //
+ // A 2x4 cell of Rhs is stored in 8bit in d0.
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 8bit in d1--d3.
+ //
+ // +--------+--------+--------+--------+
+ // |d0[0] |d0[2] |d0[4] |d0[6] |
+ // Rhs +--------+--------+--------+--------+
+ // |d0[1] |d0[3] |d0[5] |d0[7] |
+ // +--------+--------+--------+--------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +-----+-----+ - - - +--------+--------+--------+--------+
+ // |d1[0]|d1[1]| |q4[0,1] |q5[0,1] |q6[0,1] |q7[0,1] |
+ // |d1[2]|d1[3]| |q7[2,3] |q4[2,3] |q5[2,3] |q6[2,3] |
+ // |d1[4]|d1[5]| |q6[4,5] |q7[4,5] |q4[4,5] |q5[4,5] |
+ // |d1[6]|d1[7]| |q5[6,7] |q6[6,7] |q7[6,7] |q4[6,7] |
+ // +-----+-----+ - - - +--------+--------+--------+--------+
+ // |d2[0]|d2[1]| |q8[0,1] |q8[0,1] |q8[0,1] |q8[0,1] |
+ // |d2[2]|d2[3]| |q9[2,3] |q9[2,3] |q9[2,3] |q9[2,3] |
+ // |d2[4]|d2[5]| |q10[4,5]|q10[4,5]|q10[4,5]|q10[4,5]|
+ // |d2[6]|d2[7]| |q11[6,7]|q11[6,7]|q11[6,7]|q11[6,7]|
+ // +-----+-----+ - - - +--------+--------+--------+--------+
+ // |d3[0]|d3[1]| |q12[0,1]|q12[0,1]|q12[0,1]|q12[0,1]|
+ // |d3[2]|d3[3]| |q13[2,3]|q13[2,3]|q13[2,3]|q13[2,3]|
+ // |d3[4]|d3[5]| |q14[4,5]|q14[4,5]|q14[4,5]|q14[4,5]|
+ // |d3[6]|d3[7]| |q15[6,7]|q15[6,7]|q15[6,7]|q15[6,7]|
+ // +-----+-----+ - - - +--------+--------+--------+--------+
+ //
+ // Local 16-bit accumulators
+ // Note: 2 scalars per matrix entry
#define GEMMLOWP_ACCUMULATE_2_LEVELS_OF_DEPTH \
/* Load 3 Lhs cells of size 4x2 */ \
@@ -1261,7 +1261,6 @@
}
};
-
// Our main GEMM kernel.
struct NEON_64_Kernel12x8Depth2 : KernelBase {
typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
diff --git a/internal/multi_thread_gemm.h b/internal/multi_thread_gemm.h
index df7387a..791402f 100644
--- a/internal/multi_thread_gemm.h
+++ b/internal/multi_thread_gemm.h
@@ -149,9 +149,7 @@
// to have finished working.
class BlockingCounter {
public:
- BlockingCounter()
- : count_(0),
- initial_count_(0) {
+ BlockingCounter() : count_(0), initial_count_(0) {
pthread_cond_init(&cond_, nullptr);
pthread_mutex_init(&mutex_, nullptr);
}
@@ -548,11 +546,6 @@
WorkersPool workers_pool_;
};
-// Needed by chrome native builds
-#ifndef _SC_NPROCESSORS_CONF
-#define _SC_NPROCESSORS_CONF _SC_NPROCESSORS_ONLN
-#endif
-
// Determines how many threads should be used for a given Gemm
// operation.
template <int KernelRows>
diff --git a/internal/output.h b/internal/output.h
index 8ccb8ee..dcfe2b5 100644
--- a/internal/output.h
+++ b/internal/output.h
@@ -119,12 +119,12 @@
template <int Size>
struct OutputStageEvalBufferImpl<
- OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
+ OutputStageQuantizeDownInt32ByFixedPoint,
RegisterBuffer<std::int32_t, Size>> {
typedef RegisterBuffer<std::int32_t, Size> InputType;
typedef RegisterBuffer<std::int32_t, Size> OutputType;
- typedef OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint OutputStage;
+ typedef OutputStageQuantizeDownInt32ByFixedPoint OutputStage;
OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {}
@@ -146,6 +146,39 @@
const OutputStage& output_stage;
};
+template <int Size>
+struct OutputStageEvalBufferImpl<OutputStageScaleInt32ByFixedPointAndExponent,
+ RegisterBuffer<std::int32_t, Size>> {
+ typedef RegisterBuffer<std::int32_t, Size> InputType;
+ typedef RegisterBuffer<std::int32_t, Size> OutputType;
+
+ typedef OutputStageScaleInt32ByFixedPointAndExponent OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {
+ left_shift = std::max(0, output_stage.result_exponent);
+ right_shift = std::max(0, -output_stage.result_exponent);
+ }
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ using RegisterType = typename InputType::RegisterType;
+ const RegisterType result_offset_after_shift =
+ Dup<RegisterType>(output_stage.result_offset_after_shift);
+ for (int i = 0; i < InputType::kRegisterCount; i++) {
+ const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul(
+ ShiftLeft(input.reg[i], left_shift),
+ output_stage.result_fixedpoint_multiplier);
+ output.reg[i] = Add(RoundingDivideByPOT(mulhigh_val, right_shift),
+ result_offset_after_shift);
+ }
+ return output;
+ }
+
+ const OutputStage& output_stage;
+ int left_shift;
+ int right_shift;
+};
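+
+// Worked example for the stage above (assumed parameter values):
+// with result_fixedpoint_multiplier = 1 << 30 (i.e. 0.5 in Q31),
+// result_exponent = -1 and result_offset_after_shift = 0, the
+// constructor sets left_shift = 0 and right_shift = 1, so an input
+// accumulator of 1000 becomes
+// RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(1000, 1 << 30), 1)
+// = RoundingDivideByPOT(500, 1) = 250, i.e. a multiplication by
+// 0.5 * 2^-1 = 0.25.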
+
// Implementation of OutputStageSaturatingCastToUint8 for scalar data
template <int Size>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
@@ -169,6 +202,29 @@
}
};
+// Implementation of OutputStageSaturatingCastToInt16 for scalar data
+template <int Size>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegisterBuffer<std::int32_t, Size>> {
+ typedef RegisterBuffer<std::int32_t, Size> InputType;
+ typedef RegisterBuffer<std::int16_t, Size> OutputType;
+ static_assert(InputType::kRegisterLanes == 1,
+ "This path is only for scalar values");
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ for (int i = 0; i < InputType::kRegisterCount; i++) {
+ std::int32_t data = input.reg[i];
+ output.reg[i] = data > 32767 ? 32767 : data < -32768 ? -32768 : data;
+ }
+ return output;
+ }
+};
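+
+// For instance, an input of 70000 saturates to 32767, -70000
+// saturates to -32768, and values already in [-32768, 32767] pass
+// through unchanged.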
+
template <int Rows, int Cols, typename VectorType>
struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>,
RegisterBlock<std::int32_t, Rows, Cols>> {
@@ -430,6 +486,8 @@
#include "output_neon.h"
#elif defined(GEMMLOWP_SSE4)
#include "output_sse.h"
+#elif defined(GEMMLOWP_MSA)
+#include "output_msa.h"
#endif
#endif // GEMMLOWP_INTERNAL_OUTPUT_H_
diff --git a/internal/output_msa.h b/internal/output_msa.h
new file mode 100644
index 0000000..4c8eb5d
--- /dev/null
+++ b/internal/output_msa.h
@@ -0,0 +1,622 @@
+// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// output_msa.h: optimized MSA specializations of the templates in output.h.
+
+#ifndef GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
+#define GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
+
+#include "output.h"
+
+#include <msa.h>
+
+namespace gemmlowp {
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+ RegBufferInt32<4>> {
+ typedef RegBufferInt32<4> InputType;
+ typedef RegBufferUint8<4> OutputType;
+
+ typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ // Signed saturate each 32-bit element to 9 bits
+ // (this takes full care of non-negative elements).
+ v4i32 tmp = __builtin_msa_sat_s_w(input.reg[0], 8);
+ // Pack every 32-bit element into 16 bits.
+ tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(tmp), reinterpret_cast<v8i16>(tmp)));
+ // Detect negative elements with arithmetic shift right (we
+ // get a 16-bit mask of all zeroes or all ones for every element).
+ v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp), 15);
+ // Zero out negative elements.
+ signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
+ reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp), 0));
+ // Pack every element into 8 bits.
+ tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
+ reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
+ // Return 4 uint8_t elements as uint32_t.
+ output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
+ return output;
+ }
+};
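+
+// Worked example of the trick above (assumed input values): 300
+// first saturates to 255 and survives the packing steps unchanged,
+// while -20 saturates to -20, gets an all-ones mask from the
+// arithmetic shift, and is forced to 0 by bseli -- together this
+// yields the unsigned clamp to [0, 255] without a dedicated
+// unsigned-saturation instruction.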
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+ RegBufferInt32<8>> {
+ typedef RegBufferInt32<8> InputType;
+ typedef RegBufferUint8<8> OutputType;
+
+ typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ // Signed saturate each 32-bit element to 9 bits
+ // (this takes full care of non-negative elements).
+ v4i32 tmp_lo = __builtin_msa_sat_s_w(input.reg[0], 8);
+ v4i32 tmp_hi = __builtin_msa_sat_s_w(input.reg[1], 8);
+ // Pack every 32-bit element into 16 bits,
+ // combining all 8 elements into one vector.
+ tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(tmp_hi), reinterpret_cast<v8i16>(tmp_lo)));
+ // Detect negative elements with arithmetic shift right (we
+ // get a 16-bit mask of all zeroes or all ones for every element).
+ v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp_lo), 15);
+ // Zero out negative elements.
+ signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
+ reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp_lo), 0));
+ // Pack every element into 8 bits.
+ tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
+ reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
+ // Return 8 uint8_t elements as 2 uint32_t's.
+ output.reg[0] = __builtin_msa_copy_s_w(tmp_lo, 0);
+ output.reg[1] = __builtin_msa_copy_s_w(tmp_lo, 1);
+ return output;
+ }
+};
+
+#define GEMMLOWP_MIPS_SAT_U8_16(out, in0, in1, in2, in3) \
+ { \
+ v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 8); \
+ v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 8); \
+ v4i32 tmp2 = __builtin_msa_sat_s_w(in2, 8); \
+ v4i32 tmp3 = __builtin_msa_sat_s_w(in3, 8); \
+ tmp0 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h( \
+ reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0))); \
+ tmp2 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h( \
+ reinterpret_cast<v8i16>(tmp3), reinterpret_cast<v8i16>(tmp2))); \
+ v8i16 signs0 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp0), 15); \
+ v8i16 signs1 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp2), 15); \
+ signs0 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b( \
+ reinterpret_cast<v16u8>(signs0), reinterpret_cast<v16u8>(tmp0), 0)); \
+ signs1 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b( \
+ reinterpret_cast<v16u8>(signs1), reinterpret_cast<v16u8>(tmp2), 0)); \
+ signs0 = reinterpret_cast<v8i16>(__builtin_msa_pckev_b( \
+ reinterpret_cast<v16i8>(signs1), reinterpret_cast<v16i8>(signs0))); \
+ out = reinterpret_cast<v16i8>(signs0); \
+ }
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+ RegBufferInt32<16>> {
+ typedef RegBufferInt32<16> InputType;
+ typedef RegBufferUint8<16> OutputType;
+
+ typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
+ input.reg[2], input.reg[3]);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+ RegBufferInt32<32>> {
+ typedef RegBufferInt32<32> InputType;
+ typedef RegBufferUint8<32> OutputType;
+
+ typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
+ input.reg[2], input.reg[3]);
+ GEMMLOWP_MIPS_SAT_U8_16(output.reg[1], input.reg[4], input.reg[5],
+ input.reg[6], input.reg[7]);
+ return output;
+ }
+};
+
+#undef GEMMLOWP_MIPS_SAT_U8_16
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<4>> {
+ typedef RegBufferInt32<4> InputType;
+ typedef RegBufferInt16<4> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ // Signed saturate each 32-bit element to 16 bits.
+ v8i16 tmp = reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(
+ input.reg[0], 15));
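+    // After sat_s_w, each result fits in the low 16 bits of its
+    // 32-bit element, so (with little-endian lane ordering) the
+    // values sit in the even half-word lanes 0, 2, 4, 6 copied out
+    // below.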
+ output.reg[0] = __builtin_msa_copy_s_h(tmp, 0);
+ output.reg[1] = __builtin_msa_copy_s_h(tmp, 2);
+ output.reg[2] = __builtin_msa_copy_s_h(tmp, 4);
+ output.reg[3] = __builtin_msa_copy_s_h(tmp, 6);
+ return output;
+ }
+};
+
+#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1) \
+ { \
+ v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15); \
+ v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15); \
+ out = __builtin_msa_pckev_h( \
+ reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)); \
+ }
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<8>> {
+ typedef RegBufferInt32<8> InputType;
+ typedef RegBufferInt16<8> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<16>> {
+ typedef RegBufferInt32<16> InputType;
+ typedef RegBufferInt16<16> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<32>> {
+ typedef RegBufferInt32<32> InputType;
+ typedef RegBufferInt16<32> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[2], input.reg[4], input.reg[5]);
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[3], input.reg[6], input.reg[7]);
+ return output;
+ }
+};
+
+#undef GEMMLOWP_MIPS_SAT_I16_8
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
+ static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
+ } else {
+ *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
+ *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
+ *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
+ *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
+ static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
+ StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
+ } else {
+ *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
+ *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
+ *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
+ *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
+ *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
+ *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
+ *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
+ *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
+ static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
+ int col) {
+ *dst->data(row + 0, col) = src.buf.reg[0];
+ *dst->data(row + 1, col) = src.buf.reg[1];
+ *dst->data(row + 2, col) = src.buf.reg[2];
+ *dst->data(row + 3, col) = src.buf.reg[3];
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
+ static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
+ } else {
+ *dst->data(row + 0, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 0);
+ *dst->data(row + 1, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 1);
+ *dst->data(row + 2, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 2);
+ *dst->data(row + 3, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 3);
+ *dst->data(row + 4, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 4);
+ *dst->data(row + 5, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 5);
+ *dst->data(row + 6, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 6);
+ *dst->data(row + 7, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 7);
+ }
+ }
+};
+
+inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
+ RegBlockInt32<4, 4> result;
+ v4i32 tmp0, tmp1;
+ tmp0 = __builtin_msa_ilvr_w(src.buf.reg[1], src.buf.reg[0]);
+ tmp1 = __builtin_msa_ilvr_w(src.buf.reg[3], src.buf.reg[2]);
+ result.buf.reg[0] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
+ reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
+ result.buf.reg[1] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
+ reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
+ tmp0 = __builtin_msa_ilvl_w(src.buf.reg[1], src.buf.reg[0]);
+ tmp1 = __builtin_msa_ilvl_w(src.buf.reg[3], src.buf.reg[2]);
+ result.buf.reg[2] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
+ reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
+ result.buf.reg[3] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
+ reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
+ return result;
+}
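+
+// Illustrative walk-through with rows a, b, c, d (assumed values):
+// the ilvr_w/ilvl_w steps interleave 32-bit words into
+// (a0 b0 a1 b1), (c0 d0 c1 d1), (a2 b2 a3 b3), (c2 d2 c3 d3), and
+// the ilvr_d/ilvl_d steps then pair 64-bit halves into the
+// transposed rows (a0 b0 c0 d0) through (a3 b3 c3 d3).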
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
+ static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]);
+ }
+ } else {
+ const auto transpose = Transpose(src);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]);
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
+ static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
+ int col) {
+ std::int16_t buf[16];
+ StoreInt16x8(buf + 0, src.buf.reg[0]);
+ StoreInt16x8(buf + 8, src.buf.reg[1]);
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ *dst->data(row + i, col + j) = buf[i + 4 * j];
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
+ static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
+ StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
+ }
+ } else {
+ RegBlockInt32<4, 4> top;
+ top.buf.reg[0] = src.buf.reg[0];
+ top.buf.reg[1] = src.buf.reg[2];
+ top.buf.reg[2] = src.buf.reg[4];
+ top.buf.reg[3] = src.buf.reg[6];
+ const auto transpose_top = Transpose(top);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]);
+ }
+ RegBlockInt32<4, 4> bottom;
+ bottom.buf.reg[0] = src.buf.reg[1];
+ bottom.buf.reg[1] = src.buf.reg[3];
+ bottom.buf.reg[2] = src.buf.reg[5];
+ bottom.buf.reg[3] = src.buf.reg[7];
+ const auto transpose_bottom = Transpose(bottom);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]);
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
+ static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 4; i++) {
+ StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
+ }
+ } else {
+ std::int16_t buf[32];
+ StoreInt16x8(buf + 0, src.buf.reg[0]);
+ StoreInt16x8(buf + 8, src.buf.reg[1]);
+ StoreInt16x8(buf + 16, src.buf.reg[2]);
+ StoreInt16x8(buf + 24, src.buf.reg[3]);
+ for (int i = 0; i < 8; i++) {
+ for (int j = 0; j < 4; j++) {
+ *dst->data(row + i, col + j) = buf[i + 8 * j];
+ }
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
+ static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 8; i++) {
+ StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
+ StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
+ }
+ } else {
+ RegBlockInt32<4, 4> top_left;
+ top_left.buf.reg[0] = src.buf.reg[0];
+ top_left.buf.reg[1] = src.buf.reg[2];
+ top_left.buf.reg[2] = src.buf.reg[4];
+ top_left.buf.reg[3] = src.buf.reg[6];
+ const auto transpose_top_left = Transpose(top_left);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]);
+ }
+ RegBlockInt32<4, 4> bottom_left;
+ bottom_left.buf.reg[0] = src.buf.reg[1];
+ bottom_left.buf.reg[1] = src.buf.reg[3];
+ bottom_left.buf.reg[2] = src.buf.reg[5];
+ bottom_left.buf.reg[3] = src.buf.reg[7];
+ const auto transpose_bottom_left = Transpose(bottom_left);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + 4 + i, col),
+ transpose_bottom_left.buf.reg[i]);
+ }
+ RegBlockInt32<4, 4> top_right;
+ top_right.buf.reg[0] = src.buf.reg[8];
+ top_right.buf.reg[1] = src.buf.reg[10];
+ top_right.buf.reg[2] = src.buf.reg[12];
+ top_right.buf.reg[3] = src.buf.reg[14];
+ const auto transpose_top_right = Transpose(top_right);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + i, col + 4),
+ transpose_top_right.buf.reg[i]);
+ }
+ RegBlockInt32<4, 4> bottom_right;
+ bottom_right.buf.reg[0] = src.buf.reg[9];
+ bottom_right.buf.reg[1] = src.buf.reg[11];
+ bottom_right.buf.reg[2] = src.buf.reg[13];
+ bottom_right.buf.reg[3] = src.buf.reg[15];
+ const auto transpose_bottom_right = Transpose(bottom_right);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + 4 + i, col + 4),
+ transpose_bottom_right.buf.reg[i]);
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
+ static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 8; i++) {
+ StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
+ }
+ } else {
+ // top-left 4x4
+ v4i32 t0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[1],
+ src.buf.reg[0]));
+ v4i32 t1 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[3],
+ src.buf.reg[2]));
+ v2i64 u0 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t1, t0));
+ v2i64 u1 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t1, t0));
+ // top-right 4x4
+ v4i32 t2 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[5],
+ src.buf.reg[4]));
+ v4i32 t3 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[7],
+ src.buf.reg[6]));
+ v2i64 u2 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t3, t2));
+ v2i64 u3 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t3, t2));
+ // bottom-left 4x4
+ v4i32 t4 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[1],
+ src.buf.reg[0]));
+ v4i32 t5 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[3],
+ src.buf.reg[2]));
+ v2i64 u4 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t5, t4));
+ v2i64 u5 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t5, t4));
+ // bottom-right 4x4
+ v4i32 t6 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[5],
+ src.buf.reg[4]));
+ v4i32 t7 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[7],
+ src.buf.reg[6]));
+ v2i64 u6 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t7, t6));
+ v2i64 u7 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t7, t6));
+
+ StoreInt16x8(dst->data(row + 0, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvr_d(u2, u0)));
+ StoreInt16x8(dst->data(row + 1, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvl_d(u2, u0)));
+ StoreInt16x8(dst->data(row + 2, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvr_d(u3, u1)));
+ StoreInt16x8(dst->data(row + 3, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvl_d(u3, u1)));
+ StoreInt16x8(dst->data(row + 4, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvr_d(u6, u4)));
+ StoreInt16x8(dst->data(row + 5, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvl_d(u6, u4)));
+ StoreInt16x8(dst->data(row + 6, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvr_d(u7, u5)));
+ StoreInt16x8(dst->data(row + 7, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvl_d(u7, u5)));
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
+ static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]);
+ *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]);
+ *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]);
+ *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]);
+ } else {
+ StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
+ static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
+ int col) {
+ const std::uint32_t src_reg = src.buf.reg[0];
+ for (int i = 0; i < 4; i++) {
+ *dst->data(row + i, col) = (src_reg >> (8 * i));
+ }
+ }
+};
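+
+// Each 32-bit reg of a uint8 RegBlock packs 4 bytes, so shifting
+// right by 8*i and letting the uint8 assignment truncate extracts
+// byte i; the same pattern is used by the 8x1 and 1x4 cases below.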
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
+ static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
+ int col) {
+ for (int i = 0; i < 4; i++) {
+ *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i));
+ }
+ for (int i = 0; i < 4; i++) {
+ *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i));
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
+ static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
+ int col) {
+ for (int i = 0; i < 4; i++) {
+ *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
+ static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
+ int col) {
+ std::uint8_t buf[16];
+ StoreUint8x16(buf, src.buf.reg[0]);
+ for (int c = 0; c < 4; c++) {
+ for (int r = 0; r < 4; r++) {
+ *dst->data(row + r, col + c) = buf[r + 4 * c];
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
+ static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
+ int col) {
+ std::uint8_t buf[32];
+ StoreUint8x16(buf, src.buf.reg[0]);
+ StoreUint8x16(buf + 16, src.buf.reg[1]);
+ for (int c = 0; c < 4; c++) {
+ for (int r = 0; r < 8; r++) {
+ *dst->data(row + r, col + c) = buf[r + 8 * c];
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
+ static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
+ int col) {
+ std::uint8_t buf[64];
+ StoreUint8x16(buf, src.buf.reg[0]);
+ StoreUint8x16(buf + 16, src.buf.reg[1]);
+ StoreUint8x16(buf + 32, src.buf.reg[2]);
+ StoreUint8x16(buf + 48, src.buf.reg[3]);
+ for (int c = 0; c < 8; c++) {
+ for (int r = 0; r < 8; r++) {
+ *dst->data(row + r, col + c) = buf[r + 8 * c];
+ }
+ }
+ }
+};
+
+} // namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
diff --git a/internal/output_neon.h b/internal/output_neon.h
index 7e111e5..911fed0 100644
--- a/internal/output_neon.h
+++ b/internal/output_neon.h
@@ -107,6 +107,85 @@
}
};
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<4>> {
+ typedef RegBufferInt32<4> InputType;
+ typedef RegBufferInt16<4> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] = vqmovn_s32(input.reg[0]);
+ return output;
+ }
+};
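+
+// vqmovn_s32 is the NEON saturating narrow: each 32-bit lane is
+// clamped to [-32768, 32767] and narrowed to 16 bits (e.g. an
+// input of 100000 becomes 32767). The wider buffers below apply it
+// pairwise and combine the halves with vcombine_s16.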
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<8>> {
+ typedef RegBufferInt32<8> InputType;
+ typedef RegBufferInt16<8> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] =
+ vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<16>> {
+ typedef RegBufferInt32<16> InputType;
+ typedef RegBufferInt16<16> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] =
+ vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
+ output.reg[1] =
+ vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<32>> {
+ typedef RegBufferInt32<32> InputType;
+ typedef RegBufferInt16<32> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] =
+ vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
+ output.reg[1] =
+ vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
+ output.reg[2] =
+ vcombine_s16(vqmovn_s32(input.reg[4]), vqmovn_s32(input.reg[5]));
+ output.reg[3] =
+ vcombine_s16(vqmovn_s32(input.reg[6]), vqmovn_s32(input.reg[7]));
+ return output;
+ }
+};
+
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
@@ -115,14 +194,48 @@
StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
} else {
- *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
- *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
- *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
- *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
- *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
- *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
- *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
- *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
+ vst1q_lane_s32(dst->data(row + 0, col), src.buf.reg[0], 0);
+ vst1q_lane_s32(dst->data(row + 1, col), src.buf.reg[0], 1);
+ vst1q_lane_s32(dst->data(row + 2, col), src.buf.reg[0], 2);
+ vst1q_lane_s32(dst->data(row + 3, col), src.buf.reg[0], 3);
+ vst1q_lane_s32(dst->data(row + 4, col), src.buf.reg[1], 0);
+ vst1q_lane_s32(dst->data(row + 5, col), src.buf.reg[1], 1);
+ vst1q_lane_s32(dst->data(row + 6, col), src.buf.reg[1], 2);
+ vst1q_lane_s32(dst->data(row + 7, col), src.buf.reg[1], 3);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
+ static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ StoreInt16x4(dst->data(row, col), src.buf.reg[0]);
+ } else {
+ vst1_lane_s16(dst->data(row + 0, col), src.buf.reg[0], 0);
+ vst1_lane_s16(dst->data(row + 1, col), src.buf.reg[0], 1);
+ vst1_lane_s16(dst->data(row + 2, col), src.buf.reg[0], 2);
+ vst1_lane_s16(dst->data(row + 3, col), src.buf.reg[0], 3);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
+ static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
+ } else {
+ vst1q_lane_s16(dst->data(row + 0, col), src.buf.reg[0], 0);
+ vst1q_lane_s16(dst->data(row + 1, col), src.buf.reg[0], 1);
+ vst1q_lane_s16(dst->data(row + 2, col), src.buf.reg[0], 2);
+ vst1q_lane_s16(dst->data(row + 3, col), src.buf.reg[0], 3);
+ vst1q_lane_s16(dst->data(row + 4, col), src.buf.reg[0], 4);
+ vst1q_lane_s16(dst->data(row + 5, col), src.buf.reg[0], 5);
+ vst1q_lane_s16(dst->data(row + 6, col), src.buf.reg[0], 6);
+ vst1q_lane_s16(dst->data(row + 7, col), src.buf.reg[0], 7);
}
}
};
@@ -157,6 +270,35 @@
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
+ static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ vst1_s16(dst->data(row, col + 0), vget_low_s16(src.buf.reg[0]));
+ vst1_s16(dst->data(row, col + 1), vget_high_s16(src.buf.reg[0]));
+ vst1_s16(dst->data(row, col + 2), vget_low_s16(src.buf.reg[1]));
+ vst1_s16(dst->data(row, col + 3), vget_high_s16(src.buf.reg[1]));
+ } else {
+ const int16x4x2_t t0 =
+ vtrn_s16(vget_low_s16(src.buf.reg[0]), vget_high_s16(src.buf.reg[0]));
+ const int16x4x2_t t1 =
+ vtrn_s16(vget_low_s16(src.buf.reg[1]), vget_high_s16(src.buf.reg[1]));
+ const int32x4x2_t t =
+ vtrnq_s32(vreinterpretq_s32_s16(vcombine_s16(t0.val[0], t0.val[1])),
+ vreinterpretq_s32_s16(vcombine_s16(t1.val[0], t1.val[1])));
+ vst1_s16(dst->data(row + 0, col),
+ vget_low_s16(vreinterpretq_s16_s32(t.val[0])));
+ vst1_s16(dst->data(row + 1, col),
+ vget_high_s16(vreinterpretq_s16_s32(t.val[0])));
+ vst1_s16(dst->data(row + 2, col),
+ vget_low_s16(vreinterpretq_s16_s32(t.val[1])));
+ vst1_s16(dst->data(row + 3, col),
+ vget_high_s16(vreinterpretq_s16_s32(t.val[1])));
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
int col) {
@@ -192,6 +334,42 @@
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
+ static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ vst1q_s16(dst->data(row, col + 0), src.buf.reg[0]);
+ vst1q_s16(dst->data(row, col + 1), src.buf.reg[1]);
+ vst1q_s16(dst->data(row, col + 2), src.buf.reg[2]);
+ vst1q_s16(dst->data(row, col + 3), src.buf.reg[3]);
+ } else {
+ const int16x8x2_t t0 = vtrnq_s16(src.buf.reg[0], src.buf.reg[1]);
+ const int16x8x2_t t1 = vtrnq_s16(src.buf.reg[2], src.buf.reg[3]);
+ const int32x4x2_t u0 = vtrnq_s32(vreinterpretq_s32_s16(t0.val[0]),
+ vreinterpretq_s32_s16(t1.val[0]));
+ const int32x4x2_t u1 = vtrnq_s32(vreinterpretq_s32_s16(t0.val[1]),
+ vreinterpretq_s32_s16(t1.val[1]));
+ vst1_s16(dst->data(row + 0, col),
+ vget_low_s16(vreinterpretq_s16_s32(u0.val[0])));
+ vst1_s16(dst->data(row + 1, col),
+ vget_low_s16(vreinterpretq_s16_s32(u1.val[0])));
+ vst1_s16(dst->data(row + 2, col),
+ vget_low_s16(vreinterpretq_s16_s32(u0.val[1])));
+ vst1_s16(dst->data(row + 3, col),
+ vget_low_s16(vreinterpretq_s16_s32(u1.val[1])));
+ vst1_s16(dst->data(row + 4, col),
+ vget_high_s16(vreinterpretq_s16_s32(u0.val[0])));
+ vst1_s16(dst->data(row + 5, col),
+ vget_high_s16(vreinterpretq_s16_s32(u1.val[0])));
+ vst1_s16(dst->data(row + 6, col),
+ vget_high_s16(vreinterpretq_s16_s32(u0.val[1])));
+ vst1_s16(dst->data(row + 7, col),
+ vget_high_s16(vreinterpretq_s16_s32(u1.val[1])));
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
int col) {
@@ -281,6 +459,23 @@
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<1, 4>, DstType> {
+ static void Run(const RegBlockInt16<1, 4>& src, DstType* dst, int row,
+ int col) {
+ std::int16_t* dst_ptr = dst->data(row, col);
+ if (DstType::kOrder == MapOrder::RowMajor) {
+ vst1_s16(dst_ptr, src.buf.reg[0]);
+ } else {
+ int col_stride = dst->cols_stride();
+ vst1_lane_s16(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
+ vst1_lane_s16(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
+ vst1_lane_s16(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
+ vst1_lane_s16(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
int col) {
@@ -427,6 +622,70 @@
}
};
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
+ static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ vst1q_s16(dst->data(row, col + 0), src.buf.reg[0]);
+ vst1q_s16(dst->data(row, col + 1), src.buf.reg[1]);
+ vst1q_s16(dst->data(row, col + 2), src.buf.reg[2]);
+ vst1q_s16(dst->data(row, col + 3), src.buf.reg[3]);
+ vst1q_s16(dst->data(row, col + 4), src.buf.reg[4]);
+ vst1q_s16(dst->data(row, col + 5), src.buf.reg[5]);
+ vst1q_s16(dst->data(row, col + 6), src.buf.reg[6]);
+ vst1q_s16(dst->data(row, col + 7), src.buf.reg[7]);
+ } else {
+ int16x8x2_t a[4];
+ a[0] = vtrnq_s16(src.buf.reg[0], src.buf.reg[1]);
+ a[1] = vtrnq_s16(src.buf.reg[2], src.buf.reg[3]);
+ a[2] = vtrnq_s16(src.buf.reg[4], src.buf.reg[5]);
+ a[3] = vtrnq_s16(src.buf.reg[6], src.buf.reg[7]);
+ int32x4x2_t b[4];
+ b[0] = vtrnq_s32(vreinterpretq_s32_s16(a[0].val[0]),
+ vreinterpretq_s32_s16(a[1].val[0]));
+ b[1] = vtrnq_s32(vreinterpretq_s32_s16(a[0].val[1]),
+ vreinterpretq_s32_s16(a[1].val[1]));
+ b[2] = vtrnq_s32(vreinterpretq_s32_s16(a[2].val[0]),
+ vreinterpretq_s32_s16(a[3].val[0]));
+ b[3] = vtrnq_s32(vreinterpretq_s32_s16(a[2].val[1]),
+ vreinterpretq_s32_s16(a[3].val[1]));
+ vst1_s16(dst->data(row + 0, col + 0),
+ vget_low_s16(vreinterpretq_s16_s32(b[0].val[0])));
+ vst1_s16(dst->data(row + 0, col + 4),
+ vget_low_s16(vreinterpretq_s16_s32(b[2].val[0])));
+ vst1_s16(dst->data(row + 1, col + 0),
+ vget_low_s16(vreinterpretq_s16_s32(b[1].val[0])));
+ vst1_s16(dst->data(row + 1, col + 4),
+ vget_low_s16(vreinterpretq_s16_s32(b[3].val[0])));
+ vst1_s16(dst->data(row + 2, col + 0),
+ vget_low_s16(vreinterpretq_s16_s32(b[0].val[1])));
+ vst1_s16(dst->data(row + 2, col + 4),
+ vget_low_s16(vreinterpretq_s16_s32(b[2].val[1])));
+ vst1_s16(dst->data(row + 3, col + 0),
+ vget_low_s16(vreinterpretq_s16_s32(b[1].val[1])));
+ vst1_s16(dst->data(row + 3, col + 4),
+ vget_low_s16(vreinterpretq_s16_s32(b[3].val[1])));
+ vst1_s16(dst->data(row + 4, col + 0),
+ vget_high_s16(vreinterpretq_s16_s32(b[0].val[0])));
+ vst1_s16(dst->data(row + 4, col + 4),
+ vget_high_s16(vreinterpretq_s16_s32(b[2].val[0])));
+ vst1_s16(dst->data(row + 5, col + 0),
+ vget_high_s16(vreinterpretq_s16_s32(b[1].val[0])));
+ vst1_s16(dst->data(row + 5, col + 4),
+ vget_high_s16(vreinterpretq_s16_s32(b[3].val[0])));
+ vst1_s16(dst->data(row + 6, col + 0),
+ vget_high_s16(vreinterpretq_s16_s32(b[0].val[1])));
+ vst1_s16(dst->data(row + 6, col + 4),
+ vget_high_s16(vreinterpretq_s16_s32(b[2].val[1])));
+ vst1_s16(dst->data(row + 7, col + 0),
+ vget_high_s16(vreinterpretq_s16_s32(b[1].val[1])));
+ vst1_s16(dst->data(row + 7, col + 4),
+ vget_high_s16(vreinterpretq_s16_s32(b[3].val[1])));
+ }
+ }
+};
+
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
diff --git a/internal/output_sse.h b/internal/output_sse.h
index 5c06253..75aebfd 100644
--- a/internal/output_sse.h
+++ b/internal/output_sse.h
@@ -103,6 +103,82 @@
}
};
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<4>> {
+ typedef RegBufferInt32<4> InputType;
+ typedef RegBufferInt16<4> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]);
+ output.reg[0] = _mm_extract_epi16(res_16, 0);
+ output.reg[1] = _mm_extract_epi16(res_16, 1);
+ output.reg[2] = _mm_extract_epi16(res_16, 2);
+ output.reg[3] = _mm_extract_epi16(res_16, 3);
+ return output;
+ }
+};
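+
+// _mm_packs_epi32 packs two vectors of 32-bit ints into one vector
+// of 16-bit ints with signed saturation; packing input.reg[0] with
+// itself duplicates the 4 results, so only lanes 0--3 are read out.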
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<8>> {
+ typedef RegBufferInt32<8> InputType;
+ typedef RegBufferInt16<8> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<16>> {
+ typedef RegBufferInt32<16> InputType;
+ typedef RegBufferInt16<16> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]);
+ output.reg[1] = _mm_packs_epi32(input.reg[2], input.reg[3]);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<32>> {
+ typedef RegBufferInt32<32> InputType;
+ typedef RegBufferInt16<32> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]);
+ output.reg[1] = _mm_packs_epi32(input.reg[2], input.reg[3]);
+ output.reg[2] = _mm_packs_epi32(input.reg[4], input.reg[5]);
+ output.reg[3] = _mm_packs_epi32(input.reg[6], input.reg[7]);
+ return output;
+ }
+};
+
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
@@ -138,6 +214,36 @@
}
};
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
+ static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
+ int col) {
+ *dst->data(row + 0, col) = src.buf.reg[0];
+ *dst->data(row + 1, col) = src.buf.reg[1];
+ *dst->data(row + 2, col) = src.buf.reg[2];
+ *dst->data(row + 3, col) = src.buf.reg[3];
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
+ static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
+ } else {
+ *dst->data(row + 0, col) = _mm_extract_epi16(src.buf.reg[0], 0);
+ *dst->data(row + 1, col) = _mm_extract_epi16(src.buf.reg[0], 1);
+ *dst->data(row + 2, col) = _mm_extract_epi16(src.buf.reg[0], 2);
+ *dst->data(row + 3, col) = _mm_extract_epi16(src.buf.reg[0], 3);
+ *dst->data(row + 4, col) = _mm_extract_epi16(src.buf.reg[0], 4);
+ *dst->data(row + 5, col) = _mm_extract_epi16(src.buf.reg[0], 5);
+ *dst->data(row + 6, col) = _mm_extract_epi16(src.buf.reg[0], 6);
+ *dst->data(row + 7, col) = _mm_extract_epi16(src.buf.reg[0], 7);
+ }
+ }
+};
+
inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
__m128i t0 = _mm_unpacklo_epi32(src.buf.reg[0], src.buf.reg[1]);
__m128i t1 = _mm_unpacklo_epi32(src.buf.reg[2], src.buf.reg[3]);
@@ -170,6 +276,21 @@
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
+ static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
+ int col) {
+ std::int16_t buf[16];
+ StoreInt16x8(buf + 0, src.buf.reg[0]);
+ StoreInt16x8(buf + 8, src.buf.reg[1]);
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ *dst->data(row + i, col + j) = buf[i + 4 * j];
+ }
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
int col) {
@@ -202,6 +323,29 @@
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
+ static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 4; i++) {
+ StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
+ }
+ } else {
+ std::int16_t buf[32];
+ StoreInt16x8(buf + 0, src.buf.reg[0]);
+ StoreInt16x8(buf + 8, src.buf.reg[1]);
+ StoreInt16x8(buf + 16, src.buf.reg[2]);
+ StoreInt16x8(buf + 24, src.buf.reg[3]);
+ for (int i = 0; i < 8; i++) {
+ for (int j = 0; j < 4; j++) {
+ *dst->data(row + i, col + j) = buf[i + 8 * j];
+ }
+ }
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
int col) {
@@ -255,6 +399,48 @@
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
+ static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 8; i++) {
+ StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
+ }
+ } else {
+ // top-left 4x4
+ __m128i t0 = _mm_unpacklo_epi16(src.buf.reg[0], src.buf.reg[1]);
+ __m128i t1 = _mm_unpacklo_epi16(src.buf.reg[2], src.buf.reg[3]);
+ __m128i u0 = _mm_unpacklo_epi32(t0, t1);
+ __m128i u1 = _mm_unpackhi_epi32(t0, t1);
+ // top-right 4x4
+ __m128i t2 = _mm_unpacklo_epi16(src.buf.reg[4], src.buf.reg[5]);
+ __m128i t3 = _mm_unpacklo_epi16(src.buf.reg[6], src.buf.reg[7]);
+ __m128i u2 = _mm_unpacklo_epi32(t2, t3);
+ __m128i u3 = _mm_unpackhi_epi32(t2, t3);
+ // bottom-left 4x4
+ __m128i t4 = _mm_unpackhi_epi16(src.buf.reg[0], src.buf.reg[1]);
+ __m128i t5 = _mm_unpackhi_epi16(src.buf.reg[2], src.buf.reg[3]);
+ __m128i u4 = _mm_unpacklo_epi32(t4, t5);
+ __m128i u5 = _mm_unpackhi_epi32(t4, t5);
+ // bottom-right 4x4
+ __m128i t6 = _mm_unpackhi_epi16(src.buf.reg[4], src.buf.reg[5]);
+ __m128i t7 = _mm_unpackhi_epi16(src.buf.reg[6], src.buf.reg[7]);
+ __m128i u6 = _mm_unpacklo_epi32(t6, t7);
+ __m128i u7 = _mm_unpackhi_epi32(t6, t7);
+
+ StoreInt16x8(dst->data(row + 0, col), _mm_unpacklo_epi64(u0, u2));
+ StoreInt16x8(dst->data(row + 1, col), _mm_unpackhi_epi64(u0, u2));
+ StoreInt16x8(dst->data(row + 2, col), _mm_unpacklo_epi64(u1, u3));
+ StoreInt16x8(dst->data(row + 3, col), _mm_unpackhi_epi64(u1, u3));
+ StoreInt16x8(dst->data(row + 4, col), _mm_unpacklo_epi64(u4, u6));
+ StoreInt16x8(dst->data(row + 5, col), _mm_unpackhi_epi64(u4, u6));
+ StoreInt16x8(dst->data(row + 6, col), _mm_unpacklo_epi64(u5, u7));
+ StoreInt16x8(dst->data(row + 7, col), _mm_unpackhi_epi64(u5, u7));
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
int col) {
diff --git a/internal/pack.h b/internal/pack.h
index 3395396..cb4b93a 100644
--- a/internal/pack.h
+++ b/internal/pack.h
@@ -430,6 +430,8 @@
#include "pack_neon.h"
#elif defined(GEMMLOWP_SSE4)
#include "pack_sse.h"
+#elif defined(GEMMLOWP_MSA)
+#include "pack_msa.h"
#endif
#endif // GEMMLOWP_INTERNAL_PACK_H_
diff --git a/internal/pack_msa.h b/internal/pack_msa.h
new file mode 100644
index 0000000..fba8a0f
--- /dev/null
+++ b/internal/pack_msa.h
@@ -0,0 +1,353 @@
+// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// pack_msa.h: optimized MSA specializations of the templates in pack.h.
+
+#ifndef GEMMLOWP_INTERNAL_PACK_MSA_H_
+#define GEMMLOWP_INTERNAL_PACK_MSA_H_
+
+#include "pack.h"
+
+#include <msa.h>
+
+namespace gemmlowp {
+
+typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
+ WidthMajorUint8SideMap;
+
+template <int Cells>
+using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>;
+
+template <int Cells>
+class PackingRegisterBlock<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>>
+ : public PackingRegisterBlockBase<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>> {
+ public:
+ typedef DepthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
+ typedef typename KernelSideFormat::Cell CellFormat;
+ static constexpr int kCells = KernelSideFormat::kCells;
+ static const int kCellWidth = CellFormat::kWidth;
+ static const int kKernelWidth = CellFormat::kWidth * kCells;
+ static const int kCellDepth = CellFormat::kDepth;
+ static const int kCellSize = CellFormat::kSize;
+
+ void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
+ std::uint8_t* dst_ptr = dst->current_data();
+ const std::uint8_t* const src_ptr = this->complete_src_.data();
+ const int stride = this->complete_src_.stride();
+ // Load source WidthMajor data
+ v16i8 src_lines[4 * kCells];
+ for (int i = 0; i < 4 * kCells; i++) {
+ src_lines[i] = __builtin_msa_ld_b(
+ const_cast<std::uint8_t*>(src_ptr + i * stride), 0);
+ }
+ // Reorder the data within registers to make DepthMajor 4x2 cells
+ v16i8 src_lines_intertwined_2x[2 * kCells][2];
+ for (int i = 0; i < kCells; i++) {
+ src_lines_intertwined_2x[2 * i][0] =
+ __builtin_msa_ilvr_b(src_lines[4 * i + 2], src_lines[4 * i]);
+ src_lines_intertwined_2x[2 * i][1] =
+ __builtin_msa_ilvl_b(src_lines[4 * i + 2], src_lines[4 * i]);
+ src_lines_intertwined_2x[2 * i + 1][0] =
+ __builtin_msa_ilvr_b(src_lines[4 * i + 3], src_lines[4 * i + 1]);
+ src_lines_intertwined_2x[2 * i + 1][1] =
+ __builtin_msa_ilvl_b(src_lines[4 * i + 3], src_lines[4 * i + 1]);
+ }
+ v16i8 src_lines_intertwined_4x[2 * kCells][2];
+ for (int i = 0; i < kCells; i++) {
+ src_lines_intertwined_4x[2 * i][0] =
+ __builtin_msa_ilvr_b(src_lines_intertwined_2x[2 * i + 1][0],
+ src_lines_intertwined_2x[2 * i][0]);
+ src_lines_intertwined_4x[2 * i][1] =
+ __builtin_msa_ilvl_b(src_lines_intertwined_2x[2 * i + 1][0],
+ src_lines_intertwined_2x[2 * i][0]);
+ src_lines_intertwined_4x[2 * i + 1][0] =
+ __builtin_msa_ilvr_b(src_lines_intertwined_2x[2 * i + 1][1],
+ src_lines_intertwined_2x[2 * i][1]);
+ src_lines_intertwined_4x[2 * i + 1][1] =
+ __builtin_msa_ilvl_b(src_lines_intertwined_2x[2 * i + 1][1],
+ src_lines_intertwined_2x[2 * i][1]);
+ }
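+    // After the two interleave rounds above, bytes are grouped as
+    // (w0 w1 w2 w3) for each depth level, so each 64-bit half of a
+    // register is one complete 4x2 (width x depth) cell; the stores
+    // below combine halves from adjacent cells into 16-byte vectors.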
+ // Store the resulting DepthMajor 4x2 cells in the destination packed block
+ for (int outer = 0; outer < 2; outer++) {
+ for (int inner = 0; inner < 2; inner++) {
+ if (kCells % 2 == 0) {
+ for (int cell = 0; cell < kCells; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvr_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ for (int cell = 0; cell < kCells; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvl_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ } else {
+ // Store even number of low vector halves.
+ for (int cell = 0; cell < kCells - 1; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvr_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ // Store last low half and first high half.
+ v2i64 tmp = reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * 0 + outer][inner]);
+ tmp = __builtin_msa_insve_d(
+ tmp, 0,
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (kCells - 1) + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ // Store even number of high vector halves.
+ for (int cell = 1; cell < kCells; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvl_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ }
+ }
+ }
+ // Compute sums across the depth dimension
+ v8i16 sums_of_2_cells[kCells][4];
+ const v16i8 zeroes = __builtin_msa_ldi_b(0);
+ for (int outer = 0; outer < 2; outer++) {
+ for (int inner = 0; inner < 2; inner++) {
+ int i = 2 * outer + inner;
+ for (int cell = 0; cell < kCells; cell++) {
+ v8i16 tmp0 = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(
+ zeroes, src_lines_intertwined_4x[2 * cell + outer][inner]));
+ v8i16 tmp1 = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(
+ zeroes, src_lines_intertwined_4x[2 * cell + outer][inner]));
+ sums_of_2_cells[cell][i] = __builtin_msa_addv_h(tmp0, tmp1);
+ }
+ }
+ }
+ v4i32 sums_of_4_cells[kCells][4];
+ for (int i = 0; i < 4; i++) {
+ for (int cell = 0; cell < kCells; cell++) {
+ v4i32 tmp0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(
+ reinterpret_cast<v8i16>(zeroes), sums_of_2_cells[cell][i]));
+ v4i32 tmp1 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(
+ reinterpret_cast<v8i16>(zeroes), sums_of_2_cells[cell][i]));
+ sums_of_4_cells[cell][i] = __builtin_msa_addv_w(tmp0, tmp1);
+ }
+ }
+ // Update the sums_of_each_slice vector
+ for (int cell = 0; cell < kCells; cell++) {
+ v4i32 s01 = __builtin_msa_addv_w(sums_of_4_cells[cell][0],
+ sums_of_4_cells[cell][1]);
+ v4i32 s23 = __builtin_msa_addv_w(sums_of_4_cells[cell][2],
+ sums_of_4_cells[cell][3]);
+ v4i32 s = __builtin_msa_addv_w(s01, s23);
+ std::int32_t* sums_of_each_slice_ptr =
+ dst->sums_of_each_slice() + start_width + 4 * cell;
+ v4i32 tmp = __builtin_msa_ld_w(sums_of_each_slice_ptr, 0);
+ tmp = __builtin_msa_addv_w(tmp, s);
+ __builtin_msa_st_w(tmp, sums_of_each_slice_ptr, 0);
+ }
+ dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
+ }
+};
+
+template <int Cells>
+using WidthMajorSideFormatNCells4x2 =
+ KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
+
+template <int Cells>
+class PackingRegisterBlock<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
+ : public PackingRegisterBlockBase<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
+ public:
+ typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
+ typedef typename KernelSideFormat::Cell CellFormat;
+ static constexpr int kCells = KernelSideFormat::kCells;
+ static const int kCellWidth = CellFormat::kWidth;
+ static const int kKernelWidth = CellFormat::kWidth * kCells;
+ static const int kCellDepth = CellFormat::kDepth;
+ static const int kCellSize = CellFormat::kSize;
+
+ void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
+ std::uint8_t* dst_ptr = dst->current_data();
+ const std::uint8_t* src_ptr = this->complete_src_.data();
+ const int stride = this->complete_src_.stride();
+ // Load source WidthMajor data
+ v8i16 src_lines[kCells * 4];
+ for (int i = 0; i < kCells; i++) {
+#define GEMMLOWP_UNROLLED_LOOP_ITER(k) \
+ src_lines[4 * i + k] = \
+ __builtin_msa_ld_h(const_cast<std::uint8_t*>(src_ptr), 0); \
+ src_ptr += stride;
+
+ GEMMLOWP_UNROLLED_LOOP_ITER(0)
+ GEMMLOWP_UNROLLED_LOOP_ITER(1)
+ GEMMLOWP_UNROLLED_LOOP_ITER(2)
+ GEMMLOWP_UNROLLED_LOOP_ITER(3)
+
+#undef GEMMLOWP_UNROLLED_LOOP_ITER
+ }
+ // Reorder the data within registers to make WidthMajor 4x2 cells
+ v8i16 src_lines_intertwined_2x[2 * kCells][2];
+ for (int i = 0; i < kCells; i++) {
+ src_lines_intertwined_2x[2 * i][0] =
+ __builtin_msa_ilvr_h(src_lines[4 * i + 2], src_lines[4 * i]);
+ src_lines_intertwined_2x[2 * i][1] =
+ __builtin_msa_ilvl_h(src_lines[4 * i + 2], src_lines[4 * i]);
+ src_lines_intertwined_2x[2 * i + 1][0] =
+ __builtin_msa_ilvr_h(src_lines[4 * i + 3], src_lines[4 * i + 1]);
+ src_lines_intertwined_2x[2 * i + 1][1] =
+ __builtin_msa_ilvl_h(src_lines[4 * i + 3], src_lines[4 * i + 1]);
+ }
+ v8i16 src_lines_intertwined_4x[2 * kCells][2];
+ for (int i = 0; i < kCells; i++) {
+ src_lines_intertwined_4x[2 * i][0] =
+ __builtin_msa_ilvr_h(src_lines_intertwined_2x[2 * i + 1][0],
+ src_lines_intertwined_2x[2 * i][0]);
+ src_lines_intertwined_4x[2 * i][1] =
+ __builtin_msa_ilvl_h(src_lines_intertwined_2x[2 * i + 1][0],
+ src_lines_intertwined_2x[2 * i][0]);
+ src_lines_intertwined_4x[2 * i + 1][0] =
+ __builtin_msa_ilvr_h(src_lines_intertwined_2x[2 * i + 1][1],
+ src_lines_intertwined_2x[2 * i][1]);
+ src_lines_intertwined_4x[2 * i + 1][1] =
+ __builtin_msa_ilvl_h(src_lines_intertwined_2x[2 * i + 1][1],
+ src_lines_intertwined_2x[2 * i][1]);
+ }
+ // Store the resulting WidthMajor 4x2 cells in the destination packed block
+ for (int outer = 0; outer < 2; outer++) {
+ for (int inner = 0; inner < 2; inner++) {
+ if (kCells % 2 == 0) {
+ for (int cell = 0; cell < kCells; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvr_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ for (int cell = 0; cell < kCells; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvl_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ } else {
+ // Store even number of low vector halves.
+ for (int cell = 0; cell < kCells - 1; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvr_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ // Store last low half and first high half.
+ v2i64 tmp = reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * 0 + outer][inner]);
+ tmp = __builtin_msa_insve_d(
+ tmp, 0,
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (kCells - 1) + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ // Store even number of high vector halves.
+ for (int cell = 1; cell < kCells; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvl_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ }
+ }
+ }
+ // Compute sums across the depth dimension
+ v8i16 sums_of_2[kCells][4];
+ for (int outer = 0; outer < 2; outer++) {
+ for (int inner = 0; inner < 2; inner++) {
+ int i = 2 * outer + inner;
+ for (int cell = 0; cell < kCells; cell++) {
+ sums_of_2[cell][i] = reinterpret_cast<v8i16>(__builtin_msa_hadd_u_h(
+ reinterpret_cast<v16u8>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]),
+ reinterpret_cast<v16u8>(
+ src_lines_intertwined_4x[2 * cell + outer][inner])));
+ }
+ }
+ }
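+    // (Note: hadd_u.h with both operands equal adds each odd-position
+    // unsigned byte to the adjacent even-position byte, so every 16-bit
+    // lane of sums_of_2 holds the sum of 2 adjacent depth values.)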
+ v8i16 sums_of_4[kCells][2];
+ for (int i = 0; i < 2; i++) {
+ for (int cell = 0; cell < kCells; cell++) {
+ sums_of_4[cell][i] = __builtin_msa_addv_h(sums_of_2[cell][2 * i],
+ sums_of_2[cell][2 * i + 1]);
+ }
+ }
+ v8i16 sums_of_8[kCells];
+ for (int cell = 0; cell < kCells; cell++) {
+ sums_of_8[cell] =
+ __builtin_msa_addv_h(sums_of_4[cell][0], sums_of_4[cell][1]);
+ }
+
+ v4i32 sums_of_16[kCells];
+ const v8i16 zeroes = __builtin_msa_ldi_h(0);
+ for (int cell = 0; cell < kCells; cell++) {
+ sums_of_16[cell] = reinterpret_cast<v4i32>(
+ __builtin_msa_ilvr_h(zeroes, sums_of_8[cell]));
+ v8i16 tmp = __builtin_msa_ilvl_h(zeroes, sums_of_8[cell]);
+ sums_of_16[cell] =
+ __builtin_msa_addv_w(sums_of_16[cell], reinterpret_cast<v4i32>(tmp));
+ }
+ // Update the sums_of_each_slice vector
+ for (int cell = 0; cell < kCells; cell++) {
+ std::int32_t* sums_of_each_slice_ptr =
+ dst->sums_of_each_slice() + start_width + 4 * cell;
+ v4i32 tmp = __builtin_msa_ld_w(sums_of_each_slice_ptr, 0);
+ tmp = __builtin_msa_addv_w(tmp, sums_of_16[cell]);
+ __builtin_msa_st_w(tmp, sums_of_each_slice_ptr, 0);
+ }
+ dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
+ }
+};
+
+} // namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_PACK_MSA_H_
diff --git a/internal/pack_neon.h b/internal/pack_neon.h
index e212d07..2b08464 100644
--- a/internal/pack_neon.h
+++ b/internal/pack_neon.h
@@ -153,10 +153,10 @@
// Load source WidthMajor data
uint16x8_t src_lines[kCells * 4];
for (int i = 0; i < kCells; i++) {
-// This packing path is used with our current
-// less-than-8-bit kernel, and the partial unrolling of this loop
-// results in substantially faster code (thanks to better
-// register allocation) on Nexus 5.
+ // This packing path is used with our current
+ // less-than-8-bit kernel, and the partial unrolling of this loop
+ // results in substantially faster code (thanks to better
+ // register allocation) on Nexus 5.
#define GEMMLOWP_UNROLLED_LOOP_ITER(k) \
src_lines[4 * i + k] = vreinterpretq_u16_u8(vld1q_u8(src_ptr)); \
diff --git a/internal/platform.h b/internal/platform.h
old mode 100755
new mode 100644
index 49e41a9..1114767
--- a/internal/platform.h
+++ b/internal/platform.h
@@ -17,17 +17,20 @@
#ifndef GEMMLOWP_INTERNAL_PLATFORM_H_
#define GEMMLOWP_INTERNAL_PLATFORM_H_
-
#ifdef _WIN32
#include <windows.h>
#else
-#include <unistd.h>
-#include <time.h>
#include <stdlib.h>
+#include <time.h>
+#include <unistd.h>
#endif
-#include <malloc.h>
+
+#ifdef __APPLE__
+#include <sys/time.h>
+#endif
#if defined ANDROID || defined __ANDROID__
+#include <malloc.h>
#include <android/api-level.h>
// The 18 here should be 16, but has to be 18 for now due
// to a Google-internal issue.
@@ -42,6 +45,10 @@
#endif
#endif
+// Needed by Chrome native builds
+#ifndef _SC_NPROCESSORS_CONF
+#define _SC_NPROCESSORS_CONF _SC_NPROCESSORS_ONLN
+#endif
namespace gemmlowp {
@@ -50,9 +57,7 @@
return _aligned_malloc(size, alignment);
}
-inline void aligned_free(void *memptr) {
- _aligned_free(memptr);
-}
+inline void aligned_free(void *memptr) { _aligned_free(memptr); }
inline int GetHardwareConcurrency(int max_threads) {
if (max_threads == 0) {
@@ -64,8 +69,9 @@
}
inline double real_time_in_seconds() {
- __int64 wintime; GetSystemTimeAsFileTime((FILETIME*)&wintime);
- wintime -= 116444736000000000i64; //1jan1601 to 1jan1970
+ __int64 wintime;
+ GetSystemTimeAsFileTime((FILETIME *)&wintime);
+ wintime -= 116444736000000000i64; // 1jan1601 to 1jan1970
return wintime / 10000000i64 + wintime % 10000000i64 * 100 * 1e-9;
}
@@ -91,9 +97,7 @@
return max_threads;
}
-inline void aligned_free(void *memptr) {
- free(memptr);
-}
+inline void aligned_free(void *memptr) { free(memptr); }
inline double real_time_in_seconds() {
#ifdef __APPLE__
@@ -108,5 +112,5 @@
}
#endif
-} // namespace gemmlowp
+} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_PLATFORM_H_
diff --git a/internal/simd_wrappers.h b/internal/simd_wrappers.h
index e39eaf8..d9721c9 100644
--- a/internal/simd_wrappers.h
+++ b/internal/simd_wrappers.h
@@ -491,10 +491,14 @@
template <int N>
using RegBufferInt32 = RegisterBuffer<std::int32_t, N>;
template <int N>
+using RegBufferInt16 = RegisterBuffer<std::int16_t, N>;
+template <int N>
using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
template <int R, int C>
using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
template <int R, int C>
+using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>;
+template <int R, int C>
using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
} // end namespace gemmlowp
@@ -503,6 +507,8 @@
#include "simd_wrappers_neon.h"
#elif defined GEMMLOWP_SSE4
#include "simd_wrappers_sse.h"
+#elif defined GEMMLOWP_MSA
+#include "simd_wrappers_msa.h"
#endif
#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
diff --git a/internal/simd_wrappers_msa.h b/internal/simd_wrappers_msa.h
new file mode 100644
index 0000000..cf5e8e9
--- /dev/null
+++ b/internal/simd_wrappers_msa.h
@@ -0,0 +1,196 @@
+// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// simd_wrappers_msa.h: MSA specialization of simd_wrappers.h
+
+#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_MSA_H_
+#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_MSA_H_
+
+#include <msa.h>
+
+namespace gemmlowp {
+
+using Int32x4 = v4i32;
+using Int16x8 = v8i16;
+using Uint8x16 = v16i8;
+
+template <int ScalarCount>
+struct RegisterType<std::int32_t, ScalarCount> {
+ using Type =
+ typename std::conditional<ScalarCount >= 4, Int32x4, std::int32_t>::type;
+};
+
+template <int ScalarCount>
+struct RegisterType<std::int16_t, ScalarCount> {
+ using Type =
+ typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type;
+};
+
+template <int ScalarCount>
+struct RegisterType<std::uint8_t, ScalarCount> {
+ using Type = typename std::conditional<
+ ScalarCount >= 16, Uint8x16,
+ typename std::conditional<ScalarCount >= 4, std::uint32_t,
+ std::uint8_t>::type>::type;
+};
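+
+// A quick illustration of the selections above (hypothetical
+// static_asserts, not part of the library):
+//
+//   static_assert(std::is_same<RegisterType<std::int16_t, 8>::Type,
+//                              Int16x8>::value, "");
+//   static_assert(std::is_same<RegisterType<std::int16_t, 4>::Type,
+//                              std::int16_t>::value, "");
+//   static_assert(std::is_same<RegisterType<std::uint8_t, 4>::Type,
+//                              std::uint32_t>::value, "");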
+
+inline Int32x4 LoadInt32x4(const std::int32_t* src) {
+ return __builtin_msa_ld_w(const_cast<std::int32_t*>(src), 0);
+}
+
+inline Int32x4 LoadInt32x4(const Int32x4* src) {
+ return __builtin_msa_ld_w(const_cast<Int32x4*>(src), 0);
+}
+
+inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
+ __builtin_msa_st_w(value, dst, 0);
+}
+
+inline void StoreInt32x4(Int32x4* dst, Int32x4 value) {
+ __builtin_msa_st_w(value, dst, 0);
+}
+
+inline Int16x8 LoadInt16x8(const std::int16_t* src) {
+ return __builtin_msa_ld_h(const_cast<std::int16_t*>(src), 0);
+}
+
+inline Int16x8 LoadInt16x8(const Int16x8* src) {
+ return __builtin_msa_ld_h(const_cast<Int16x8*>(src), 0);
+}
+
+inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
+ __builtin_msa_st_h(value, dst, 0);
+}
+
+inline void StoreInt16x8(Int16x8* dst, Int16x8 value) {
+ __builtin_msa_st_h(value, dst, 0);
+}
+
+inline Uint8x16 LoadUint8x16(const std::uint8_t* src) {
+ return __builtin_msa_ld_b(const_cast<std::uint8_t*>(src), 0);
+}
+
+inline Uint8x16 LoadUint8x16(const Uint8x16* src) {
+ return __builtin_msa_ld_b(const_cast<Uint8x16*>(src), 0);
+}
+
+inline void StoreUint8x16(std::uint8_t* dst, Uint8x16 value) {
+ __builtin_msa_st_b(value, dst, 0);
+}
+
+inline void StoreUint8x16(Uint8x16* dst, Uint8x16 value) {
+ __builtin_msa_st_b(value, dst, 0);
+}
+
+template <int Lane>
+std::int32_t GetLane(Int32x4 value) {
+ return __builtin_msa_copy_s_w(value, Lane);
+}
+
+template <int Lane>
+Int32x4 DupLane(Int32x4 value) {
+ static_assert(Lane >= 0 && Lane <= 3, "");
+ return __builtin_msa_splati_w(value, Lane);
+}
+
+inline Int32x4 Mul(Int32x4 a, std::int32_t b) {
+ return __builtin_msa_mulv_w(a, __builtin_msa_fill_w(b));
+}
+
+inline Int32x4 Min(Int32x4 a, Int32x4 b) { return __builtin_msa_min_s_w(a, b); }
+
+inline Int32x4 Max(Int32x4 a, Int32x4 b) { return __builtin_msa_max_s_w(a, b); }
+
+inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
+ return __builtin_msa_mulr_q_w(a, __builtin_msa_fill_w(b));
+}
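+
+// (mulr_q.w computes the rounded, doubling high half of the 32x32 product,
+// saturating on overflow; this is the same operation as NEON's sqrdmulh.)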
+
+template <int Lane>
+Int32x4 MulByRhsLane(Int32x4 a, Int32x4 b) {
+ static_assert(Lane >= 0 && Lane <= 3, "");
+ return __builtin_msa_mulv_w(a, __builtin_msa_splati_w(b, Lane));
+}
+
+static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) {
+ // Workaround for incorrect encoding of maddv.df in gcc (a exchanged with c).
+#if 0
+ return __builtin_msa_maddv_w(a, b, c);
+#else
+ asm volatile("maddv.w %w[a], %w[b], %w[c]\n"
+ // Outputs
+ : [a] "+f"(a)
+ // Inputs
+ : [b] "f"(b), [c] "f"(c));
+ return a;
+#endif
+}
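+
+// (Semantically the result is simply a + b * c on each of the 4 int32
+// lanes; only the instruction encoding needs the inline-asm workaround.)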
+
+inline void MulAdd(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
+ Int32x4 tmp = LoadInt32x4(acc);
+ tmp = workaround_msa_maddv_w(tmp, lhs, rhs);
+ StoreInt32x4(acc, tmp);
+}
+
+inline void MulAdd(Int32x4 lhs, std::int32_t rhs, Int32x4* acc) {
+ Int32x4 tmp = LoadInt32x4(acc);
+ tmp = workaround_msa_maddv_w(tmp, lhs, __builtin_msa_fill_w(rhs));
+ StoreInt32x4(acc, tmp);
+}
+
+template <int Lane>
+inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
+ static_assert(Lane >= 0 && Lane <= 3, "");
+ Int32x4 tmp = LoadInt32x4(acc);
+ tmp = workaround_msa_maddv_w(tmp, lhs, __builtin_msa_splati_w(rhs, Lane));
+ StoreInt32x4(acc, tmp);
+}
+
+template <>
+struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
+ static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
+ RegBlockUint8<8, 8> result;
+ for (int i = 0; i < 4; i++) {
+ result.buf.reg[i] = LoadUint8x16(src + 16 * i);
+ }
+ return result;
+ }
+};
+
+template <>
+struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
+ static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
+ RegBlockInt32<8, 8> result;
+ for (int i = 0; i < 16; i++) {
+ result.buf.reg[i] = LoadInt32x4(src + 4 * i);
+ }
+ return result;
+ }
+};
+
+template <>
+struct LoadContiguousImpl<RegBlockInt16<8, 8>> {
+ static RegBlockInt16<8, 8> Run(const std::int16_t* src) {
+ RegBlockInt16<8, 8> result;
+ for (int i = 0; i < 8; i++) {
+ result.buf.reg[i] = LoadInt16x8(src + 8 * i);
+ }
+ return result;
+ }
+};
+
+} // end namespace gemmlowp
+
+#include "simd_wrappers_common_neon_sse.h"
+
+#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_MSA_H_
diff --git a/internal/simd_wrappers_neon.h b/internal/simd_wrappers_neon.h
index c992b15..2949173 100644
--- a/internal/simd_wrappers_neon.h
+++ b/internal/simd_wrappers_neon.h
@@ -22,6 +22,8 @@
namespace gemmlowp {
using Int32x4 = int32x4_t;
+using Int16x4 = int16x4_t;
+using Int16x8 = int16x8_t;
using Uint8x8 = uint8x8_t;
template <int ScalarCount>
@@ -31,6 +33,14 @@
};
template <int ScalarCount>
+struct RegisterType<std::int16_t, ScalarCount> {
+ using Type = typename std::conditional<
+ ScalarCount >= 8, Int16x8,
+ typename std::conditional<ScalarCount >= 4, Int16x4,
+ std::int16_t>::type>::type;
+};
+
+template <int ScalarCount>
struct RegisterType<std::uint8_t, ScalarCount> {
using Type = typename std::conditional<
ScalarCount >= 8, Uint8x8,
@@ -39,11 +49,21 @@
};
inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); }
+inline Int16x4 LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); }
+inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); }
inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
vst1q_s32(dst, value);
}
+inline void StoreInt16x4(std::int16_t* dst, Int16x4 value) {
+ vst1_s16(dst, value);
+}
+
+inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
+ vst1q_s16(dst, value);
+}
+
template <int Lane>
std::int32_t GetLane(Int32x4 value) {
return vgetq_lane_s32(value, Lane);
@@ -122,6 +142,17 @@
}
template <>
+struct LoadContiguousImpl<RegBlockInt16<8, 8>> {
+ static RegBlockInt16<8, 8> Run(const std::int16_t* src) {
+ RegBlockInt16<8, 8> result;
+ for (int i = 0; i < 8; i++) {
+ result.buf.reg[i] = vld1q_s16(src + 8 * i);
+ }
+ return result;
+ }
+};
+
+template <>
struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
RegBlockUint8<8, 8> result;
diff --git a/internal/simd_wrappers_sse.h b/internal/simd_wrappers_sse.h
index 6480b66..3b78cb4 100644
--- a/internal/simd_wrappers_sse.h
+++ b/internal/simd_wrappers_sse.h
@@ -22,6 +22,7 @@
namespace gemmlowp {
using Int32x4 = __m128i;
+using Int16x8 = __m128i;
using Uint8x16 = __m128i;
template <int ScalarCount>
@@ -31,6 +32,12 @@
};
template <int ScalarCount>
+struct RegisterType<std::int16_t, ScalarCount> {
+ using Type =
+ typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type;
+};
+
+template <int ScalarCount>
struct RegisterType<std::uint8_t, ScalarCount> {
using Type = typename std::conditional<
ScalarCount >= 16, Uint8x16,
@@ -42,10 +49,18 @@
return _mm_loadu_si128(reinterpret_cast<const Int32x4*>(src));
}
+inline Int16x8 LoadInt16x8(const std::int16_t* src) {
+ return _mm_loadu_si128(reinterpret_cast<const Int16x8*>(src));
+}
+
inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst), value);
}
+inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), value);
+}
+
inline Uint8x16 LoadUint8x16(const std::uint8_t* src) {
return _mm_loadu_si128(reinterpret_cast<const Uint8x16*>(src));
}
@@ -116,6 +131,17 @@
}
};
+template <>
+struct LoadContiguousImpl<RegBlockInt16<8, 8>> {
+ static RegBlockInt16<8, 8> Run(const std::int16_t* src) {
+ RegBlockInt16<8, 8> result;
+ for (int i = 0; i < 8; i++) {
+ result.buf.reg[i] = LoadInt16x8(src + 8 * i);
+ }
+ return result;
+ }
+};
+
} // end namespace gemmlowp
#include "simd_wrappers_common_neon_sse.h"
diff --git a/internal/single_thread_gemm.h b/internal/single_thread_gemm.h
index 3d430c5..35a7835 100644
--- a/internal/single_thread_gemm.h
+++ b/internal/single_thread_gemm.h
@@ -89,10 +89,9 @@
Allocator* allocator = context->allocator();
BlockParams block_params;
- block_params.Init<KernelFormat>(rows, cols, depth, 1,
- context->l1_bytes_to_use(),
- context->l2_bytes_to_use(),
- context->l2_rhs_factor());
+ block_params.Init<KernelFormat>(
+ rows, cols, depth, 1, context->l1_bytes_to_use(),
+ context->l2_bytes_to_use(), context->l2_rhs_factor());
#ifdef GEMMLOWP_PROFILING_SIZES
// Using a static map of label strings. Not reentrant at all!
diff --git a/meta/multi_thread_common.h b/meta/multi_thread_common.h
index dc1b799..0b35759 100644
--- a/meta/multi_thread_common.h
+++ b/meta/multi_thread_common.h
@@ -20,6 +20,15 @@
namespace gemmlowp {
namespace meta {
+inline int ResolveMaxThreads(int max_threads) {
+ if (max_threads == 0) {
+ static const int hardware_threads_count =
+ static_cast<int>(sysconf(_SC_NPROCESSORS_CONF));
+ return hardware_threads_count;
+ }
+ return max_threads;
+}
+
template <typename WorkersPool>
class SimpleContext {
public:
diff --git a/profiling/instrumentation.h b/profiling/instrumentation.h
index 539076a..437fe54 100644
--- a/profiling/instrumentation.h
+++ b/profiling/instrumentation.h
@@ -24,7 +24,6 @@
#ifndef GEMMLOWP_PROFILING_INSTRUMENTATION_H_
#define GEMMLOWP_PROFILING_INSTRUMENTATION_H_
-#include <pthread.h>
#include <cstdio>
#ifndef GEMMLOWP_USE_STLPORT
@@ -32,15 +31,15 @@
#else
#include <stdint.h>
namespace std {
-using ::uint8_t;
-using ::uint16_t;
-using ::uint32_t;
-using ::int8_t;
using ::int16_t;
using ::int32_t;
+using ::int8_t;
using ::size_t;
+using ::uint16_t;
+using ::uint32_t;
+using ::uint8_t;
using ::uintptr_t;
-}
+} // namespace std
#endif
#include <algorithm>
@@ -52,6 +51,8 @@
#include <set>
#endif
+#include "./pthread_everywhere.h"
+
namespace gemmlowp {
inline void ReleaseBuildAssertion(bool condition, const char* msg) {
diff --git a/profiling/pthread_everywhere.h b/profiling/pthread_everywhere.h
index 7e12d66..df17c6f 100644
--- a/profiling/pthread_everywhere.h
+++ b/profiling/pthread_everywhere.h
@@ -18,8 +18,6 @@
#ifndef GEMMLOWP_PROFILING_PTHREAD_EVERYWHERE_H_
#define GEMMLOWP_PROFILING_PTHREAD_EVERYWHERE_H_
-#include "pthread_everywhere.h"
-
#ifndef _WIN32
#define GEMMLOWP_USE_PTHREAD
#endif
@@ -39,39 +37,29 @@
// structs; ours take nullptr_t. That is because gemmlowp always passes
// nullptr at the moment, so any support we would code for non-null
// attribs would be unused.
-#include <thread>
-#include <mutex>
#include <condition_variable>
#include <cstddef>
+#include <mutex>
+#include <thread>
namespace gemmlowp {
-using pthread_t = std::thread*;
-using pthread_mutex_t = std::mutex*;
-using pthread_cond_t = std::condition_variable*;
-inline void pthread_create(pthread_t* thread, std::nullptr_t,
- void *(*start_routine) (void *), void *arg) {
+using pthread_t = std::thread *;
+using pthread_mutex_t = std::mutex *;
+using pthread_cond_t = std::condition_variable *;
+inline void pthread_create(pthread_t *thread, std::nullptr_t,
+ void *(*start_routine)(void *), void *arg) {
*thread = new std::thread(start_routine, arg);
}
-inline void pthread_join(pthread_t thread, std::nullptr_t) {
- thread->join();
-}
+inline void pthread_join(pthread_t thread, std::nullptr_t) { thread->join(); }
inline void pthread_mutex_init(pthread_mutex_t *mutex, std::nullptr_t) {
*mutex = new std::mutex;
}
-inline void pthread_mutex_lock(pthread_mutex_t* mutex) {
- (*mutex)->lock();
-}
-inline void pthread_mutex_unlock(pthread_mutex_t* mutex) {
- (*mutex)->unlock();
-}
-inline void pthread_mutex_destroy(pthread_mutex_t *mutex) {
- delete *mutex;
-}
+inline void pthread_mutex_lock(pthread_mutex_t *mutex) { (*mutex)->lock(); }
+inline void pthread_mutex_unlock(pthread_mutex_t *mutex) { (*mutex)->unlock(); }
+inline void pthread_mutex_destroy(pthread_mutex_t *mutex) { delete *mutex; }
inline void pthread_cond_init(pthread_cond_t *cond, std::nullptr_t) {
*cond = new std::condition_variable;
}
-inline void pthread_cond_signal(pthread_cond_t* cond) {
- (*cond)->notify_one();
-}
+inline void pthread_cond_signal(pthread_cond_t *cond) { (*cond)->notify_one(); }
inline void pthread_cond_wait(pthread_cond_t *cond, pthread_mutex_t *mutex) {
std::unique_lock<std::mutex> lock(**mutex, std::adopt_lock);
(*cond)->wait(lock);
@@ -79,10 +67,8 @@
// the lock is not released
lock.release();
}
-inline void pthread_cond_destroy(pthread_cond_t *cond) {
- delete *cond;
-}
+inline void pthread_cond_destroy(pthread_cond_t *cond) { delete *cond; }
} // end namespace gemmlowp
#endif
-#endif // GEMMLOWP_PROFILING_PTHREAD_EVERYWHERE_H_
\ No newline at end of file
+#endif // GEMMLOWP_PROFILING_PTHREAD_EVERYWHERE_H_
diff --git a/public/output_stages.h b/public/output_stages.h
index 23bcdc0..1d5fca4 100644
--- a/public/output_stages.h
+++ b/public/output_stages.h
@@ -66,8 +66,9 @@
};
// This output stage takes int32 values and returns still int32 values,
-// but "quantized down" to the uint8 scale; in other words, its output
-// is typically what one would then clamp to [0..255] and cast to uint8
+// but "quantized down" to a difference scale; for example, in a pipeline
+// that outputs uint8 values in [0..255], the output of this stage would be
+// int32 values ready to be clamped to [0..255] and casted to uint8
// (see OutputStageSaturatingCastToUint8).
//
// This "quantization down" process depends on 3 parameters,
@@ -111,17 +112,42 @@
// expansions that implicitly rely on 0-padding. If 0 were not
// a representable value, such operations would have to pad
// using a nonzero value, introducing bias in the computation.
-struct OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint {
+struct OutputStageQuantizeDownInt32ByFixedPoint {
std::int32_t result_fixedpoint_multiplier;
std::int32_t result_shift;
std::int32_t result_offset_after_shift;
};
+// OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint is the old deprecated
+// name of OutputStageQuantizeDownInt32ByFixedPoint, before we noticed that
+// there really wasn't anything Uint8-specific about it.
+using OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint =
+    OutputStageQuantizeDownInt32ByFixedPoint;
+
+// Variant of OutputStageQuantizeDownInt32ByFixedPoint where the 'shift'
+// is not necessarily just a right shift, so we can represent multipliers
+// greater than 1. This takes a result_exponent parameter; when it's
+// <= 0, this is equivalent to OutputStageQuantizeDownInt32ByFixedPoint
+// with result_shift = -result_exponent.
+// In the general case, this consists of first left-shifting by
+// std::max(result_exponent, 0), before doing the same as
+// OutputStageQuantizeDownInt32ByFixedPoint with
+// result_shift = std::max(-result_exponent, 0).
+struct OutputStageScaleInt32ByFixedPointAndExponent {
+ std::int32_t result_fixedpoint_multiplier;
+ std::int32_t result_exponent;
+ std::int32_t result_offset_after_shift;
+};
+
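+// For illustration, a scalar sketch of the computation that
+// OutputStageScaleInt32ByFixedPointAndExponent specifies for an int32
+// accumulator value 'input', assuming the helpers from
+// fixedpoint/fixedpoint.h (the actual SIMD implementations live in the
+// output stage evaluators):
+//
+//   std::int32_t left_shift = std::max(result_exponent, 0);
+//   std::int32_t right_shift = std::max(-result_exponent, 0);
+//   std::int32_t scaled = SaturatingRoundingDoublingHighMul(
+//       input * (1 << left_shift), result_fixedpoint_multiplier);
+//   std::int32_t result =
+//       RoundingDivideByPOT(scaled, right_shift) + result_offset_after_shift;
+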
// This output stage takes int32 values that are expected to be already
// on the final uint8 scale, but not necessarily in the [0..255] range.
// It clamps them to the [0..255] range and returns them casted to uint8.
struct OutputStageSaturatingCastToUint8 {};
+// This output stage takes int32 values that are expected to be already
+// on the final int16 scale, but not necessarily in the [-32768..32767] range.
+// It clamps them to the [-32768..32767] range and returns them cast to
+// int16.
+struct OutputStageSaturatingCastToInt16 {};
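+//
+// (For reference, the saturation is the obvious scalar clamp, e.g.
+//      std::int16_t out = static_cast<std::int16_t>(
+//          std::min(32767, std::max(-32768, input)));
+//  a SIMD implementation would typically use saturating narrowing
+//  instructions instead.)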
+
// This output stage depends on a "bias vector" that should contain int32
// entries, and be either a row-vector of the same number of columns as the
// result matrix, or a column-vector of the same number of rows as the
diff --git a/scripts/ci-test.sh b/scripts/ci-test.sh
index de6e344..83cc5cd 100755
--- a/scripts/ci-test.sh
+++ b/scripts/ci-test.sh
@@ -11,4 +11,4 @@
fi
if [ $TEST == "x86" ]; then
make -f Makefile.travis unittest
-fi
+fi
diff --git a/standalone/neon-gemm-kernel-benchmark.cc b/standalone/neon-gemm-kernel-benchmark.cc
index 2a936c1..bff33fb 100644
--- a/standalone/neon-gemm-kernel-benchmark.cc
+++ b/standalone/neon-gemm-kernel-benchmark.cc
@@ -61,15 +61,30 @@
#include <cassert>
#include <cstdint>
#include <cstdlib>
+#include <cstring>
#include <iostream>
#include <random>
#include <type_traits>
-#if !defined __arm__ && !defined __aarch64__
-#error This benchmark assumes ARM (for inline assembly sections).
+#if !defined(__arm__) && !defined(__aarch64__) && \
+ !(defined(__mips) && (__mips_isa_rev >= 5) && defined(__mips_msa))
+#error This benchmark assumes ARM or MIPS (for intrinsics and inline assembly sections).
#endif
+#if defined(__arm__) || defined(__aarch64__)
#include <arm_neon.h>
+#endif
+
+#if defined(__mips)
+#include <msa.h>
+
+// Some convenience macros to hide differences between MIPS32 and MIPS64.
+#ifdef __LP64__
+#define GEMMLOWP_MIPS_XADDIU "daddiu"
+#else
+#define GEMMLOWP_MIPS_XADDIU "addiu"
+#endif
+#endif
// Typically one wants to fit in L1 cache, and GEMM implementations
// are carefully optimized to tune their access patterns to that effect.
@@ -2501,6 +2516,291 @@
}
};
+#ifdef __ARM_FEATURE_DOTPROD
+// Kernels utilizing the Armv8.2 Dot Product extension.
+//
+// The dot product instructions work by taking 4 consecutive 8-bit depth
+// values from each operand, multiplying the 4 pairs together and
+// accumulating all the results into the corresponding 32-bit accumulator
+// lane. As such, the operation is identical to a 32-bit instruction (like
+// FMLA used in SGEMM), except that 4 depth values are processed at a time
+// instead of 1.
+
+// Thus, this first kernel is a carbon copy of
+// "NEON_64bit_GEMM_Float32_WithScalar_A57" (which should provide good
+// performance for most processors) below with the opcode (fmla -> udot) and
+// types (float32 -> uint8/uint32) changed.
+//
+// A signed version of this kernel could be produced by replacing "udot"
+// with "sdot" - performance should be identical to this udot kernel.
+struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "mov x0, %[accum_ptr]\n"
+ "ld1 {v8.4s}, [x0], #16\n"
+ "ld1 {v16.4s}, [x0], #16\n"
+ "ld1 {v24.4s}, [x0], #16\n"
+ "ld1 {v9.4s}, [x0], #16\n"
+ "ld1 {v17.4s}, [x0], #16\n"
+ "ld1 {v25.4s}, [x0], #16\n"
+ "ld1 {v10.4s}, [x0], #16\n"
+ "ld1 {v18.4s}, [x0], #16\n"
+ "ld1 {v26.4s}, [x0], #16\n"
+ "ld1 {v11.4s}, [x0], #16\n"
+ "ld1 {v19.4s}, [x0], #16\n"
+ "ld1 {v27.4s}, [x0], #16\n"
+ "ld1 {v12.4s}, [x0], #16\n"
+ "ld1 {v20.4s}, [x0], #16\n"
+ "ld1 {v28.4s}, [x0], #16\n"
+ "ld1 {v13.4s}, [x0], #16\n"
+ "ld1 {v21.4s}, [x0], #16\n"
+ "ld1 {v29.4s}, [x0], #16\n"
+ "ld1 {v14.4s}, [x0], #16\n"
+ "ld1 {v22.4s}, [x0], #16\n"
+ "ld1 {v30.4s}, [x0], #16\n"
+ "ld1 {v15.4s}, [x0], #16\n"
+ "ld1 {v23.4s}, [x0], #16\n"
+ "ld1 {v31.4s}, [x0], #16\n"
+
+ // The start of the loop assumes first Rhs cell is already loaded, so
+ // do it here for first iteration.
+ "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
+
+ // And the same for the first Lhs cell.
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
+
+ // Start the MACs at the head of the loop - 1st cell from each side
+ // already loaded.
+ "udot v8.4s, v2.16b, v0.b[0]\n"
+ "udot v9.4s, v2.16b, v0.b[1]\n"
+ "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell.
+ "udot v10.4s, v2.16b, v0.b[2]\n"
+ "udot v11.4s, v2.16b, v0.b[3]\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell.
+ "udot v12.4s, v2.16b, v1.b[0]\n"
+ "udot v13.4s, v2.16b, v1.b[1]\n"
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell.
+ "udot v14.4s, v2.16b, v1.b[2]\n"
+ "udot v15.4s, v2.16b, v1.b[3]\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load
+ // for the next iteration early.
+ "udot v16.4s, v3.16b, v0.b[0]\n"
+ "udot v17.4s, v3.16b, v0.b[1]\n"
+ "udot v18.4s, v3.16b, v0.b[2]\n"
+ "udot v19.4s, v3.16b, v0.b[3]\n"
+ "udot v20.4s, v3.16b, v1.b[0]\n"
+ "udot v21.4s, v3.16b, v1.b[1]\n"
+ "udot v22.4s, v3.16b, v1.b[2]\n"
+ "udot v23.4s, v3.16b, v1.b[3]\n"
+ "udot v24.4s, v4.16b, v0.b[0]\n"
+ "udot v25.4s, v4.16b, v0.b[1]\n"
+ "udot v26.4s, v4.16b, v0.b[2]\n"
+ "udot v27.4s, v4.16b, v0.b[3]\n"
+ "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell -
+ // load for the next iteration early.
+ "udot v28.4s, v4.16b, v1.b[0]\n"
+ "udot v29.4s, v4.16b, v1.b[1]\n"
+
+ // Loop. Decrement loop index (depth) by 4 as udot processes 4
+ // depth values.
+ "subs %w[depth], %w[depth], #4\n"
+ "udot v30.4s, v4.16b, v1.b[2]\n"
+ "udot v31.4s, v4.16b, v1.b[3]\n"
+
+ "bne " GEMMLOWP_LABEL_LOOP
+ "b\n"
+
+ // Store accumulators
+ "mov x0, %[accum_ptr]\n"
+ "st1 {v8.16b}, [x0], #16\n"
+ "st1 {v16.16b}, [x0], #16\n"
+ "st1 {v24.16b}, [x0], #16\n"
+ "st1 {v9.16b}, [x0], #16\n"
+ "st1 {v17.16b}, [x0], #16\n"
+ "st1 {v25.16b}, [x0], #16\n"
+ "st1 {v10.16b}, [x0], #16\n"
+ "st1 {v18.16b}, [x0], #16\n"
+ "st1 {v26.16b}, [x0], #16\n"
+ "st1 {v11.16b}, [x0], #16\n"
+ "st1 {v19.16b}, [x0], #16\n"
+ "st1 {v27.16b}, [x0], #16\n"
+ "st1 {v12.16b}, [x0], #16\n"
+ "st1 {v20.16b}, [x0], #16\n"
+ "st1 {v28.16b}, [x0], #16\n"
+ "st1 {v13.16b}, [x0], #16\n"
+ "st1 {v21.16b}, [x0], #16\n"
+ "st1 {v29.16b}, [x0], #16\n"
+ "st1 {v14.16b}, [x0], #16\n"
+ "st1 {v22.16b}, [x0], #16\n"
+ "st1 {v30.16b}, [x0], #16\n"
+ "st1 {v15.16b}, [x0], #16\n"
+ "st1 {v23.16b}, [x0], #16\n"
+ "st1 {v31.16b}, [x0], #16\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+ "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
+ "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
+ "v28", "v29", "v30", "v31");
+ }
+};
+
+// As above, except tuned for Cortex-A55r1.
+//
+// Similarly, this is a clone of NEON_64bit_GEMM_Float32_WithScalar_A55r1
+// with the names changed.
+struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1 {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "mov x0, %[accum_ptr]\n"
+ "ld1 {v8.4s}, [x0], #16\n"
+ "ld1 {v16.4s}, [x0], #16\n"
+ "ld1 {v24.4s}, [x0], #16\n"
+ "ld1 {v9.4s}, [x0], #16\n"
+ "ld1 {v17.4s}, [x0], #16\n"
+ "ld1 {v25.4s}, [x0], #16\n"
+ "ld1 {v10.4s}, [x0], #16\n"
+ "ld1 {v18.4s}, [x0], #16\n"
+ "ld1 {v26.4s}, [x0], #16\n"
+ "ld1 {v11.4s}, [x0], #16\n"
+ "ld1 {v19.4s}, [x0], #16\n"
+ "ld1 {v27.4s}, [x0], #16\n"
+ "ld1 {v12.4s}, [x0], #16\n"
+ "ld1 {v20.4s}, [x0], #16\n"
+ "ld1 {v28.4s}, [x0], #16\n"
+ "ld1 {v13.4s}, [x0], #16\n"
+ "ld1 {v21.4s}, [x0], #16\n"
+ "ld1 {v29.4s}, [x0], #16\n"
+ "ld1 {v14.4s}, [x0], #16\n"
+ "ld1 {v22.4s}, [x0], #16\n"
+ "ld1 {v30.4s}, [x0], #16\n"
+ "ld1 {v15.4s}, [x0], #16\n"
+ "ld1 {v23.4s}, [x0], #16\n"
+ "ld1 {v31.4s}, [x0], #16\n"
+
+ // For details on how this kernel works, see the Float32 kernel below.
+
+ "ldr d0, [%[rhs_ptr]]\n"
+ "ldr x18, [%[rhs_ptr], #8]\n"
+
+ "ldr q2, [%[lhs_ptr]]\n"
+ "ldr q3, [%[lhs_ptr], #16]\n"
+
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
+
+ "udot v8.4s, v2.16b, v0.b[0]\n"
+ "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1
+ "udot v9.4s, v2.16b, v0.b[1]\n"
+ "ins v0.d[1], x18\n" // Finish loading v0
+ "udot v16.4s, v3.16b, v0.b[0]\n" // out of sequence - used to reduce load/use pressure.
+ "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register
+ "udot v17.4s, v3.16b, v0.b[1]\n" // out of sequence - used to reduce load/use pressure.
+ "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer.
+ "udot v10.4s, v2.16b, v0.b[2]\n"
+ "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4
+ "udot v11.4s, v2.16b, v0.b[3]\n"
+ "ins v1.d[1], x18\n" // Finish loading v1
+ "udot v12.4s, v2.16b, v1.b[0]\n"
+ "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register
+ "udot v13.4s, v2.16b, v1.b[1]\n"
+ "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer.
+ "udot v14.4s, v2.16b, v1.b[2]\n"
+
+ "udot v15.4s, v2.16b, v1.b[3]\n"
+ "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time)
+ "udot v18.4s, v3.16b, v0.b[2]\n"
+ "ins v4.d[1], x18\n" // Finish loading v4
+ "udot v19.4s, v3.16b, v0.b[3]\n"
+ "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register
+ "udot v20.4s, v3.16b, v1.b[0]\n"
+ "subs %w[depth], %w[depth], #4\n"
+ "udot v21.4s, v3.16b, v1.b[1]\n"
+
+ "udot v22.4s, v3.16b, v1.b[2]\n"
+
+ "udot v23.4s, v3.16b, v1.b[3]\n"
+ "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time)
+ "udot v24.4s, v4.16b, v0.b[0]\n"
+ "ins v2.d[1], x18\n" // Finish loading next v2
+ "udot v25.4s, v4.16b, v0.b[1]\n"
+ "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register
+ "udot v26.4s, v4.16b, v0.b[2]\n"
+
+ "udot v27.4s, v4.16b, v0.b[3]\n"
+ "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time)
+ "udot v28.4s, v4.16b, v1.b[0]\n"
+ "ins v3.d[1], x18\n" // Finish loading next v3
+ "udot v29.4s, v4.16b, v1.b[1]\n"
+ "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register
+ "udot v30.4s, v4.16b, v1.b[2]\n"
+
+ "udot v31.4s, v4.16b, v1.b[3]\n"
+ "bne " GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators
+ "mov x0, %[accum_ptr]\n"
+ "st1 {v8.4s}, [x0], #16\n"
+ "st1 {v16.4s}, [x0], #16\n"
+ "st1 {v24.4s}, [x0], #16\n"
+ "st1 {v9.4s}, [x0], #16\n"
+ "st1 {v17.4s}, [x0], #16\n"
+ "st1 {v25.4s}, [x0], #16\n"
+ "st1 {v10.4s}, [x0], #16\n"
+ "st1 {v18.4s}, [x0], #16\n"
+ "st1 {v26.4s}, [x0], #16\n"
+ "st1 {v11.4s}, [x0], #16\n"
+ "st1 {v19.4s}, [x0], #16\n"
+ "st1 {v27.4s}, [x0], #16\n"
+ "st1 {v12.4s}, [x0], #16\n"
+ "st1 {v20.4s}, [x0], #16\n"
+ "st1 {v28.4s}, [x0], #16\n"
+ "st1 {v13.4s}, [x0], #16\n"
+ "st1 {v21.4s}, [x0], #16\n"
+ "st1 {v29.4s}, [x0], #16\n"
+ "st1 {v14.4s}, [x0], #16\n"
+ "st1 {v22.4s}, [x0], #16\n"
+ "st1 {v30.4s}, [x0], #16\n"
+ "st1 {v15.4s}, [x0], #16\n"
+ "st1 {v23.4s}, [x0], #16\n"
+ "st1 {v31.4s}, [x0], #16\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+ "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+ "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31");
+ }
+};
+#endif // __ARM_FEATURE_DOTPROD
+
// We don't actually use int32*int32 in production. This is just an
// experiment to help dissociate the effect of integer-vs-float, from the
// effect of operands width.
@@ -3203,8 +3503,172 @@
};
#endif
+// Faster kernel contributed by ARM. Tuned for A55r1.
+struct NEON_64bit_GEMM_Float32_WithScalar_A55r1 {
+ typedef float OperandType;
+ typedef float AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "mov x0, %[accum_ptr]\n"
+ "ld1 {v8.4s}, [x0], #16\n"
+ "ld1 {v16.4s}, [x0], #16\n"
+ "ld1 {v24.4s}, [x0], #16\n"
+ "ld1 {v9.4s}, [x0], #16\n"
+ "ld1 {v17.4s}, [x0], #16\n"
+ "ld1 {v25.4s}, [x0], #16\n"
+ "ld1 {v10.4s}, [x0], #16\n"
+ "ld1 {v18.4s}, [x0], #16\n"
+ "ld1 {v26.4s}, [x0], #16\n"
+ "ld1 {v11.4s}, [x0], #16\n"
+ "ld1 {v19.4s}, [x0], #16\n"
+ "ld1 {v27.4s}, [x0], #16\n"
+ "ld1 {v12.4s}, [x0], #16\n"
+ "ld1 {v20.4s}, [x0], #16\n"
+ "ld1 {v28.4s}, [x0], #16\n"
+ "ld1 {v13.4s}, [x0], #16\n"
+ "ld1 {v21.4s}, [x0], #16\n"
+ "ld1 {v29.4s}, [x0], #16\n"
+ "ld1 {v14.4s}, [x0], #16\n"
+ "ld1 {v22.4s}, [x0], #16\n"
+ "ld1 {v30.4s}, [x0], #16\n"
+ "ld1 {v15.4s}, [x0], #16\n"
+ "ld1 {v23.4s}, [x0], #16\n"
+ "ld1 {v31.4s}, [x0], #16\n"
+
+ // A55r1 requires a hybrid of the A53 and standard approaches.
+ //
+ // Like A53, this processor prefers 64-bit loads.
+ //
+ // Unlike A53, it is capable of dual-issuing a 64-bit vector load
+ // (or INS) with a FMLA instruction.
+ //
+ // Therefore we aim to issue an FMLA instruction every cycle.
+ // Alongside three FMLAs we can dual issue a (vector) 64-bit load, a
+ // scalar 64-bit load and finally an INS to replicate the effect of
+ // a single 128-bit load.
+ //
+ // The loop contains 24 FMLA instructions, and 5 vector registers
+ // need to be loaded, consuming 15 dual issue slots. This leaves 9
+ // dual issue slots. Four of these are used for loop housekeeping
+ // (2 pointer adds, 1 counter update and 1 branch), leaving 5 left
+ // over (marked by blank lines).
+ //
+ // Choice of x18 to store the upper halves on their way into the
+ // vector registers is arbitrary. Added to the clobber list so that
+ // the compiler will make it available.
+
+
+ // At the start of the loop, it is assumed that v0 is "half loaded" -
+ // bottom half in place in d0 and the upper half in x18 ready to
+ // insert. So set that up here for the first iteration:
+ "ldr d0, [%[rhs_ptr]]\n" // Bottom half of first Rhs cell
+ "ldr x18, [%[rhs_ptr], #8]\n" // Upper half
+
+ // v2-v3 should be fully loaded - as it's outside the loop proper it's fine
+ // to use a 128-bit load here.
+ "ldr q2, [%[lhs_ptr]]\n" // first Lhs cell
+ "ldr q3, [%[lhs_ptr], #16]\n" // second Lhs cell
+
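+      // (The "ldr d / ldr x / ins" triples in the loop below are 128-bit
+      // loads split into pieces that can each dual-issue with an FMLA, as
+      // described above.)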
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
+
+ "fmla v8.4s, v2.4s, v0.s[0]\n"
+ "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1
+ "fmla v9.4s, v2.4s, v0.s[1]\n"
+ "ins v0.d[1], x18\n" // Finish loading v0
+ "fmla v16.4s, v3.4s, v0.s[0]\n" // out of sequence - used to reduce load/use pressure.
+ "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register
+ "fmla v17.4s, v3.4s, v0.s[1]\n" // out of sequence - used to reduce load/use pressure.
+ "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer.
+ "fmla v10.4s, v2.4s, v0.s[2]\n"
+ "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4
+ "fmla v11.4s, v2.4s, v0.s[3]\n"
+ "ins v1.d[1], x18\n" // Finish loading v1
+ "fmla v12.4s, v2.4s, v1.s[0]\n"
+ "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register
+ "fmla v13.4s, v2.4s, v1.s[1]\n"
+ "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer.
+ "fmla v14.4s, v2.4s, v1.s[2]\n"
+
+ "fmla v15.4s, v2.4s, v1.s[3]\n"
+ "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time)
+ "fmla v18.4s, v3.4s, v0.s[2]\n"
+ "ins v4.d[1], x18\n" // Finish loading v4
+ "fmla v19.4s, v3.4s, v0.s[3]\n"
+ "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register
+ "fmla v20.4s, v3.4s, v1.s[0]\n"
+ "subs %w[depth], %w[depth], #1\n"
+ "fmla v21.4s, v3.4s, v1.s[1]\n"
+
+ "fmla v22.4s, v3.4s, v1.s[2]\n"
+
+ "fmla v23.4s, v3.4s, v1.s[3]\n"
+ "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time)
+ "fmla v24.4s, v4.4s, v0.s[0]\n"
+ "ins v2.d[1], x18\n" // Finish loading next v2
+ "fmla v25.4s, v4.4s, v0.s[1]\n"
+ "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register
+ "fmla v26.4s, v4.4s, v0.s[2]\n"
+
+ "fmla v27.4s, v4.4s, v0.s[3]\n"
+ "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time)
+ "fmla v28.4s, v4.4s, v1.s[0]\n"
+ "ins v3.d[1], x18\n" // Finish loading next v3
+ "fmla v29.4s, v4.4s, v1.s[1]\n"
+ "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register
+ "fmla v30.4s, v4.4s, v1.s[2]\n"
+
+ "fmla v31.4s, v4.4s, v1.s[3]\n"
+ "bne " GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators
+ "mov x0, %[accum_ptr]\n"
+ "st1 {v8.4s}, [x0], #16\n"
+ "st1 {v16.4s}, [x0], #16\n"
+ "st1 {v24.4s}, [x0], #16\n"
+ "st1 {v9.4s}, [x0], #16\n"
+ "st1 {v17.4s}, [x0], #16\n"
+ "st1 {v25.4s}, [x0], #16\n"
+ "st1 {v10.4s}, [x0], #16\n"
+ "st1 {v18.4s}, [x0], #16\n"
+ "st1 {v26.4s}, [x0], #16\n"
+ "st1 {v11.4s}, [x0], #16\n"
+ "st1 {v19.4s}, [x0], #16\n"
+ "st1 {v27.4s}, [x0], #16\n"
+ "st1 {v12.4s}, [x0], #16\n"
+ "st1 {v20.4s}, [x0], #16\n"
+ "st1 {v28.4s}, [x0], #16\n"
+ "st1 {v13.4s}, [x0], #16\n"
+ "st1 {v21.4s}, [x0], #16\n"
+ "st1 {v29.4s}, [x0], #16\n"
+ "st1 {v14.4s}, [x0], #16\n"
+ "st1 {v22.4s}, [x0], #16\n"
+ "st1 {v30.4s}, [x0], #16\n"
+ "st1 {v15.4s}, [x0], #16\n"
+ "st1 {v23.4s}, [x0], #16\n"
+ "st1 {v31.4s}, [x0], #16\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+ "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+ "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31");
+ }
+};
+
#endif // __aarch64__
+#if defined(__arm__) || defined(__aarch64__)
#ifndef __aarch64__
inline int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
const int32x2_t c = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
@@ -3388,6 +3852,974 @@
using NEON_64bit_GEMM_Float32_WithScalar_intrinsics =
NEON_GEMM_Float32_WithScalar_intrinsics<2>;
+#endif // __arm__ || __aarch64__
+
+#ifdef __mips
+static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) {
+ // Workaround for incorrect encoding of maddv.df in gcc (a exchanged with c).
+#if 0
+ return __builtin_msa_maddv_w(a, b, c);
+#else
+ asm volatile("maddv.w %w[a], %w[b], %w[c]\n"
+ // Outputs
+ : [a] "+f"(a)
+ // Inputs
+ : [b] "f"(b), [c] "f"(c));
+ return a;
+#endif
+}
+
+// Using 32x32=32 multiplications.
+// 20 MSA regs used:
+// - 12 accumulators
+// - 6 lhs
+// - 1 rhs
+// - 1 temps/zeroes
+// ~55 instructions in the loop.
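+// (Operands are zero-extended all the way to 32 bits so that plain maddv.w
+// multiply-adds can be used; the "assembly2" variant further below widens
+// only to 16 bits and uses dpadd_u.w 16x16=32 dot products instead, needing
+// fewer instructions.)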
+struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics {
+ typedef std::uint8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ const v16i8 zeroes = __builtin_msa_ldi_b(0);
+ v4i32 acc[3][4];
+ // Load accumulators.
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 4; j++) {
+ acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0);
+ }
+ }
+
+ while (depth > 0) {
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ v8i16 lhs[6];
+ lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0));
+ lhs[1] =
+ reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0));
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[0])));
+ lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[1])));
+ lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[1])));
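+      // (ilvr.b with the zero vector as the first operand pairs each byte
+      // of the low half with a zero high byte, i.e. zero-extends 8 bytes
+      // to 8 halfwords; ilvl.b does the same for the high half.)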
+
+ // Zero-extend 16-bit elements of lhs[] to 32 bits.
+ lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
+ lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
+ lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
+ lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
+ lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
+ lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
+
+ // Depth 0.
+ for (int j = 0; j < 4; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
+ }
+ }
+
+ // Depth 1.
+ for (int j = 0; j < 4; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
+ }
+ }
+
+ lhs_ptr += 24;
+ rhs_ptr += 8;
+ depth -= 2;
+ }
+
+ // Store accumulators.
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 4; j++) {
+ __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0);
+ }
+ }
+ }
+};
+
+// Assembly implementation of the above
+// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics.
+// Using 32x32=32 multiplications.
+// 20 MSA regs used:
+// - 12 accumulators
+// - 6 lhs
+// - 1 rhs
+// - 1 temps/zeroes
+// ~55 instructions in the loop.
+struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly {
+ typedef std::uint8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
+ Format;
+ static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "ld.w $w0, (0*16)(%[accum_ptr])\n"
+ "ld.w $w4, (1*16)(%[accum_ptr])\n"
+ "ld.w $w8, (2*16)(%[accum_ptr])\n"
+ "ld.w $w1, (3*16)(%[accum_ptr])\n"
+ "ld.w $w5, (4*16)(%[accum_ptr])\n"
+ "ld.w $w9, (5*16)(%[accum_ptr])\n"
+ "ld.w $w2, (6*16)(%[accum_ptr])\n"
+ "ld.w $w6, (7*16)(%[accum_ptr])\n"
+ "ld.w $w10, (8*16)(%[accum_ptr])\n"
+ "ld.w $w3, (9*16)(%[accum_ptr])\n"
+ "ld.w $w7, (10*16)(%[accum_ptr])\n"
+ "ld.w $w11, (11*16)(%[accum_ptr])\n"
+ // Set a temp to all zeroes.
+ "ldi.b $w19, 0\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+ // A half of the 2x4 cell of Rhs is stored in 32bit in w18.
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w12-w17.
+ // A 12x4 block of accumulators is stored in 32bit in w0-w11.
+ //
+ // +------+------+------+------+
+ // Rhs |w18[0]|w18[1]|w18[2]|w18[3]|
+ // +------+------+------+------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+---+ - - - - +------+------+------+------+
+ // |w12|w15| | w0 | w1 | w2 | w3 |
+ // |w12|w15| | w0 | w1 | w2 | w3 |
+ // |w12|w15| | w0 | w1 | w2 | w3 |
+ // |w12|w15| | w0 | w1 | w2 | w3 |
+ // +---+---+ - - - - +------+------+------+------+
+ // |w13|w16| | w4 | w5 | w6 | w7 |
+ // |w13|w16| | w4 | w5 | w6 | w7 |
+ // |w13|w16| | w4 | w5 | w6 | w7 |
+ // |w13|w16| | w4 | w5 | w6 | w7 |
+ // +---+---+ - - - - +------+------+------+------+
+ // |w14|w17| | w8 | w9 | w10 | w11 |
+ // |w14|w17| | w8 | w9 | w10 | w11 |
+ // |w14|w17| | w8 | w9 | w10 | w11 |
+ // |w14|w17| | w8 | w9 | w10 | w11 |
+ // +---+---+ - - - - +------+------+------+------+
+ //
+ // Accumulator
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w12, 0(%[lhs_ptr])\n"
+ "ld.b $w13, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w12, $w19, $w12\n"
+ "ilvl.b $w14, $w19, $w13\n"
+ "ilvr.b $w13, $w19, $w13\n"
+ // Zero-extend 16-bit elements of lhs[] to 32 bits.
+ "ilvl.h $w15, $w19, $w12\n"
+ "ilvl.h $w16, $w19, $w13\n"
+ "ilvl.h $w17, $w19, $w14\n"
+ "ilvr.h $w12, $w19, $w12\n"
+ "ilvr.h $w13, $w19, $w13\n"
+ "ilvr.h $w14, $w19, $w14\n"
+
+ // Depth 0.
+ "fill.w $w18, $a0\n"
+ "lbu $a0, 4(%[rhs_ptr])\n"
+ "maddv.w $w0, $w12, $w18\n"
+ "maddv.w $w4, $w13, $w18\n"
+ "maddv.w $w8, $w14, $w18\n"
+ "fill.w $w18, $a1\n"
+ "lbu $a1, 5(%[rhs_ptr])\n"
+ "maddv.w $w1, $w12, $w18\n"
+ "maddv.w $w5, $w13, $w18\n"
+ "maddv.w $w9, $w14, $w18\n"
+ "fill.w $w18, $a2\n"
+ "lbu $a2, 6(%[rhs_ptr])\n"
+ "maddv.w $w2, $w12, $w18\n"
+ "maddv.w $w6, $w13, $w18\n"
+ "maddv.w $w10, $w14, $w18\n"
+ "fill.w $w18, $a3\n"
+ "lbu $a3, 7(%[rhs_ptr])\n"
+ "maddv.w $w3, $w12, $w18\n"
+ "maddv.w $w7, $w13, $w18\n"
+ "maddv.w $w11, $w14, $w18\n"
+
+ // Depth 1.
+ "fill.w $w18, $a0\n"
+ "maddv.w $w0, $w15, $w18\n"
+ "maddv.w $w4, $w16, $w18\n"
+ "maddv.w $w8, $w17, $w18\n"
+ "fill.w $w18, $a1\n"
+ "maddv.w $w1, $w15, $w18\n"
+ "maddv.w $w5, $w16, $w18\n"
+ "maddv.w $w9, $w17, $w18\n"
+ "fill.w $w18, $a2\n"
+ "maddv.w $w2, $w15, $w18\n"
+ "maddv.w $w6, $w16, $w18\n"
+ "maddv.w $w10, $w17, $w18\n"
+ "fill.w $w18, $a3\n"
+ "maddv.w $w3, $w15, $w18\n"
+ "maddv.w $w7, $w16, $w18\n"
+ "maddv.w $w11, $w17, $w18\n"
+
+ "addiu %[depth], -2\n"
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators.
+ "st.w $w0, (0*16)(%[accum_ptr])\n"
+ "st.w $w4, (1*16)(%[accum_ptr])\n"
+ "st.w $w8, (2*16)(%[accum_ptr])\n"
+ "st.w $w1, (3*16)(%[accum_ptr])\n"
+ "st.w $w5, (4*16)(%[accum_ptr])\n"
+ "st.w $w9, (5*16)(%[accum_ptr])\n"
+ "st.w $w2, (6*16)(%[accum_ptr])\n"
+ "st.w $w6, (7*16)(%[accum_ptr])\n"
+ "st.w $w10, (8*16)(%[accum_ptr])\n"
+ "st.w $w3, (9*16)(%[accum_ptr])\n"
+ "st.w $w7, (10*16)(%[accum_ptr])\n"
+ "st.w $w11, (11*16)(%[accum_ptr])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "memory",
+ "a0", "a1", "a2", "a3",
+ "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
+ "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
+ "$f16", "$f17", "$f18", "$f19");
+ }
+};
+
+// Assembly implementation of the above
+// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO).
+// Using 16x16=32 multiplications.
+// 20 MSA regs used:
+// - 12 accumulators
+// - 3 lhs
+// - 4 rhs
+// - 1 temps/zeroes
+// ~45 instructions in the loop.
+struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2 {
+ typedef std::uint8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
+ Format;
+ static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "ld.w $w0, (0*16)(%[accum_ptr])\n"
+ "ld.w $w4, (1*16)(%[accum_ptr])\n"
+ "ld.w $w8, (2*16)(%[accum_ptr])\n"
+ "ld.w $w1, (3*16)(%[accum_ptr])\n"
+ "ld.w $w5, (4*16)(%[accum_ptr])\n"
+ "ld.w $w9, (5*16)(%[accum_ptr])\n"
+ "ld.w $w2, (6*16)(%[accum_ptr])\n"
+ "ld.w $w6, (7*16)(%[accum_ptr])\n"
+ "ld.w $w10, (8*16)(%[accum_ptr])\n"
+ "ld.w $w3, (9*16)(%[accum_ptr])\n"
+ "ld.w $w7, (10*16)(%[accum_ptr])\n"
+ "ld.w $w11, (11*16)(%[accum_ptr])\n"
+ // Set a temp to all zeroes.
+ "ldi.b $w19, 0\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+ // A 2x4 cell of Rhs is stored in 16bit in w15-w18 (each register
+ // contains 4 replicas of a pair of elements).
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w12-w14.
+ // A 12x4 block of accumulators is stored in 32bit in w0-w11.
+ //
+ // +-----+-----+-----+-----+
+ // Rhs | w15 | w16 | w17 | w18 |
+ // +-----+-----+-----+-----+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+ - - - - +-----+-----+-----+-----+
+ // |w12| | w0 | w1 | w2 | w3 |
+ // |w12| | w0 | w1 | w2 | w3 |
+ // |w12| | w0 | w1 | w2 | w3 |
+ // |w12| | w0 | w1 | w2 | w3 |
+ // +---+ - - - - +-----+-----+-----+-----+
+ // |w13| | w4 | w5 | w6 | w7 |
+ // |w13| | w4 | w5 | w6 | w7 |
+ // |w13| | w4 | w5 | w6 | w7 |
+ // |w13| | w4 | w5 | w6 | w7 |
+ // +---+ - - - - +-----+-----+-----+-----+
+ // |w14| | w8 | w9 | w10 | w11 |
+ // |w14| | w8 | w9 | w10 | w11 |
+ // |w14| | w8 | w9 | w10 | w11 |
+ // |w14| | w8 | w9 | w10 | w11 |
+ // +---+ - - - - +-----+-----+-----+-----+
+ //
+ // Accumulators
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w12, 0(%[lhs_ptr])\n"
+ "ld.b $w13, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for depth 1.
+ "lbu $v0, 4(%[rhs_ptr])\n"
+ "lbu $v1, 5(%[rhs_ptr])\n"
+ "lbu $t8, 6(%[rhs_ptr])\n"
+ "lbu $t9, 7(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w12, $w19, $w12\n"
+ "ilvl.b $w14, $w19, $w13\n"
+ "ilvr.b $w13, $w19, $w13\n"
+ // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
+ "ilvl.d $w15, $w19, $w12\n"
+ "ilvl.d $w16, $w19, $w13\n"
+ "ilvl.d $w17, $w19, $w14\n"
+ "ilvr.h $w12, $w15, $w12\n"
+ "ilvr.h $w13, $w16, $w13\n"
+ "ilvr.h $w14, $w17, $w14\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w.
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w15, $a0\n"
+ "fill.w $w16, $a1\n"
+ "fill.w $w17, $a2\n"
+ "fill.w $w18, $a3\n"
+
+ // Depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w0, $w12, $w15\n"
+ "dpadd_u.w $w4, $w13, $w15\n"
+ "dpadd_u.w $w8, $w14, $w15\n"
+ "dpadd_u.w $w1, $w12, $w16\n"
+ "dpadd_u.w $w5, $w13, $w16\n"
+ "dpadd_u.w $w9, $w14, $w16\n"
+ "dpadd_u.w $w2, $w12, $w17\n"
+ "dpadd_u.w $w6, $w13, $w17\n"
+ "dpadd_u.w $w10, $w14, $w17\n"
+ "dpadd_u.w $w3, $w12, $w18\n"
+ "dpadd_u.w $w7, $w13, $w18\n"
+ "dpadd_u.w $w11, $w14, $w18\n"
+
+ "addiu %[depth], -2\n"
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators.
+ "st.w $w0, (0*16)(%[accum_ptr])\n"
+ "st.w $w4, (1*16)(%[accum_ptr])\n"
+ "st.w $w8, (2*16)(%[accum_ptr])\n"
+ "st.w $w1, (3*16)(%[accum_ptr])\n"
+ "st.w $w5, (4*16)(%[accum_ptr])\n"
+ "st.w $w9, (5*16)(%[accum_ptr])\n"
+ "st.w $w2, (6*16)(%[accum_ptr])\n"
+ "st.w $w6, (7*16)(%[accum_ptr])\n"
+ "st.w $w10, (8*16)(%[accum_ptr])\n"
+ "st.w $w3, (9*16)(%[accum_ptr])\n"
+ "st.w $w7, (10*16)(%[accum_ptr])\n"
+ "st.w $w11, (11*16)(%[accum_ptr])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "memory",
+ "v0", "v1",
+ "a0", "a1", "a2", "a3",
+ "t8", "t9",
+ "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
+ "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
+ "$f16", "$f17", "$f18", "$f19");
+ }
+};
+
+// Using 32x32=32 multiplications.
+// 32 MSA regs used:
+// - 24 accumulators
+// - 6 lhs
+// - 1 rhs
+// - 1 temps/zeroes
+// ~95 instructions in the loop.
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ const v16i8 zeroes = __builtin_msa_ldi_b(0);
+ v4i32 acc[3][8];
+ // Load accumulators.
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 8; j++) {
+ acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0);
+ }
+ }
+
+ while (depth > 0) {
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ v8i16 lhs[6];
+ lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0));
+ lhs[1] =
+ reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0));
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[0])));
+ lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[1])));
+ lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[1])));
+
+ // Zero-extend 16-bit elements of lhs[] to 32 bits.
+ lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
+ lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
+ lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
+ lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
+ lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
+ lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
+
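+      // rhs[] holds two DepthMajor 4x2 cells per iteration: bytes 0-3 are
+      // cell 0 depth 0, 4-7 cell 0 depth 1, 8-11 cell 1 depth 0, and 12-15
+      // cell 1 depth 1; hence the j + 4 / j + 8 offsets in the loops below.
+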
+ // Depth 0.
+ for (int j = 0; j < 4; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
+ }
+ }
+ for (int j = 4; j < 8; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
+ }
+ }
+
+ // Depth 1.
+ for (int j = 0; j < 4; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
+ }
+ }
+ for (int j = 4; j < 8; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 8]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
+ }
+ }
+
+ lhs_ptr += 24;
+ rhs_ptr += 16;
+ depth -= 2;
+ }
+
+ // Store accumulators.
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 8; j++) {
+ __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0);
+ }
+ }
+ }
+};
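+
+// For reference, a rough plain-C++ sketch of the arithmetic that the 12x8
+// kernels above and below perform on the packed operands. It only
+// illustrates the DepthMajor 4x2 cell layout assumed by those kernels; the
+// function name is ad hoc, it is not benchmarked, and a generic reference
+// kernel is also copied further down in this file.
+inline void Reference_MSA_GEMM_12x8_Uint8_Uint32_sketch(
+    const std::uint8_t* lhs_ptr, const std::uint8_t* rhs_ptr,
+    std::uint32_t* accum_ptr, int depth) {
+  for (int d = 0; d < depth; d += 2) {
+    for (int dd = 0; dd < 2; dd++) {  // the two depth levels of this step
+      for (int row = 0; row < 12; row++) {
+        // Lhs: 3 DepthMajor 4x2 cells of 8 bytes each; within a cell,
+        // 4 bytes of depth dd=0, then 4 bytes of depth dd=1.
+        const std::uint32_t lhs_val =
+            lhs_ptr[8 * (row / 4) + 4 * dd + (row % 4)];
+        for (int col = 0; col < 8; col++) {
+          // Rhs: 2 DepthMajor 4x2 cells of 8 bytes each, same ordering.
+          const std::uint32_t rhs_val =
+              rhs_ptr[8 * (col / 4) + 4 * dd + (col % 4)];
+          // Accumulators are stored column-major, 12 rows x 8 columns.
+          accum_ptr[row + 12 * col] += lhs_val * rhs_val;
+        }
+      }
+    }
+    lhs_ptr += 24;
+    rhs_ptr += 16;
+  }
+}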
+
+// Assembly implementation of the above
+// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics.
+// Using 32x32=32 multiplications.
+// 32 MSA regs used:
+// - 24 accumulators
+// - 6 lhs
+// - 1 rhs
+// - 1 temps/zeroes
+// ~95 instructions in the loop.
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
+ Format;
+ static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "ld.w $w0, (0*16)(%[accum_ptr])\n"
+ "ld.w $w4, (1*16)(%[accum_ptr])\n"
+ "ld.w $w8, (2*16)(%[accum_ptr])\n"
+ "ld.w $w1, (3*16)(%[accum_ptr])\n"
+ "ld.w $w5, (4*16)(%[accum_ptr])\n"
+ "ld.w $w9, (5*16)(%[accum_ptr])\n"
+ "ld.w $w2, (6*16)(%[accum_ptr])\n"
+ "ld.w $w6, (7*16)(%[accum_ptr])\n"
+ "ld.w $w10, (8*16)(%[accum_ptr])\n"
+ "ld.w $w3, (9*16)(%[accum_ptr])\n"
+ "ld.w $w7, (10*16)(%[accum_ptr])\n"
+ "ld.w $w11, (11*16)(%[accum_ptr])\n"
+ "ld.w $w12, (12*16)(%[accum_ptr])\n"
+ "ld.w $w16, (13*16)(%[accum_ptr])\n"
+ "ld.w $w20, (14*16)(%[accum_ptr])\n"
+ "ld.w $w13, (15*16)(%[accum_ptr])\n"
+ "ld.w $w17, (16*16)(%[accum_ptr])\n"
+ "ld.w $w21, (17*16)(%[accum_ptr])\n"
+ "ld.w $w14, (18*16)(%[accum_ptr])\n"
+ "ld.w $w18, (19*16)(%[accum_ptr])\n"
+ "ld.w $w22, (20*16)(%[accum_ptr])\n"
+ "ld.w $w15, (21*16)(%[accum_ptr])\n"
+ "ld.w $w19, (22*16)(%[accum_ptr])\n"
+ "ld.w $w23, (23*16)(%[accum_ptr])\n"
+ // Set a temp to all zeroes.
+ "ldi.b $w31, 0\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+        // A quarter of the 2 2x4 cells of Rhs is stored in 32bit in w30
+        // (one value at a time, replicated across all 4 lanes by fill.w).
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w24-w29.
+ // A 12x8 block of accumulators is stored in 32bit in w0-w23.
+ //
+ // +------+------+------+------+
+ // Rhs |w30[0]|w30[1]|w30[2]|w30[3]|
+ // +------+------+------+------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+---+ - - - - +------+------+------+------+
+ // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // +---+---+ - - - - +------+------+------+------+
+ // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // +---+---+ - - - - +------+------+------+------+
+ // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
+ // +---+---+ - - - - +------+------+------+------+
+ //
+        // Accumulators
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w24, 0(%[lhs_ptr])\n"
+ "ld.b $w25, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for the first half of depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
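+        // (Interleaving with the all-zeroes $w31 implements the
+        // zero-extension: ilvr.b/ilvl.b pair the low/high 8 bytes of the
+        // data register with zero bytes.)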
+ "ilvr.b $w24, $w31, $w24\n"
+ "ilvl.b $w26, $w31, $w25\n"
+ "ilvr.b $w25, $w31, $w25\n"
+ // Zero-extend 16-bit elements of lhs[] to 32 bits.
+ "ilvl.h $w27, $w31, $w24\n"
+ "ilvl.h $w28, $w31, $w25\n"
+ "ilvl.h $w29, $w31, $w26\n"
+ "ilvr.h $w24, $w31, $w24\n"
+ "ilvr.h $w25, $w31, $w25\n"
+ "ilvr.h $w26, $w31, $w26\n"
+
+ // Depth 0.
+ "fill.w $w30, $a0\n"
+ "lbu $a0, 8(%[rhs_ptr])\n"
+ "maddv.w $w0, $w24, $w30\n"
+ "maddv.w $w4, $w25, $w30\n"
+ "maddv.w $w8, $w26, $w30\n"
+ "fill.w $w30, $a1\n"
+ "lbu $a1, 9(%[rhs_ptr])\n"
+ "maddv.w $w1, $w24, $w30\n"
+ "maddv.w $w5, $w25, $w30\n"
+ "maddv.w $w9, $w26, $w30\n"
+ "fill.w $w30, $a2\n"
+ "lbu $a2, 10(%[rhs_ptr])\n"
+ "maddv.w $w2, $w24, $w30\n"
+ "maddv.w $w6, $w25, $w30\n"
+ "maddv.w $w10, $w26, $w30\n"
+ "fill.w $w30, $a3\n"
+ "lbu $a3, 11(%[rhs_ptr])\n"
+ "maddv.w $w3, $w24, $w30\n"
+ "maddv.w $w7, $w25, $w30\n"
+ "maddv.w $w11, $w26, $w30\n"
+
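+        // Columns 4-7 of depth 0 (rhs bytes 8-11, loaded above); the lbu's
+        // here fetch rhs bytes 4-7 for depth 1.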
+ "fill.w $w30, $a0\n"
+ "lbu $a0, 4(%[rhs_ptr])\n"
+ "maddv.w $w12, $w24, $w30\n"
+ "maddv.w $w16, $w25, $w30\n"
+ "maddv.w $w20, $w26, $w30\n"
+ "fill.w $w30, $a1\n"
+ "lbu $a1, 5(%[rhs_ptr])\n"
+ "maddv.w $w13, $w24, $w30\n"
+ "maddv.w $w17, $w25, $w30\n"
+ "maddv.w $w21, $w26, $w30\n"
+ "fill.w $w30, $a2\n"
+ "lbu $a2, 6(%[rhs_ptr])\n"
+ "maddv.w $w14, $w24, $w30\n"
+ "maddv.w $w18, $w25, $w30\n"
+ "maddv.w $w22, $w26, $w30\n"
+ "fill.w $w30, $a3\n"
+ "lbu $a3, 7(%[rhs_ptr])\n"
+ "maddv.w $w15, $w24, $w30\n"
+ "maddv.w $w19, $w25, $w30\n"
+ "maddv.w $w23, $w26, $w30\n"
+
+ // Depth 1.
+ "fill.w $w30, $a0\n"
+ "lbu $a0, 12(%[rhs_ptr])\n"
+ "maddv.w $w0, $w27, $w30\n"
+ "maddv.w $w4, $w28, $w30\n"
+ "maddv.w $w8, $w29, $w30\n"
+ "fill.w $w30, $a1\n"
+ "lbu $a1, 13(%[rhs_ptr])\n"
+ "maddv.w $w1, $w27, $w30\n"
+ "maddv.w $w5, $w28, $w30\n"
+ "maddv.w $w9, $w29, $w30\n"
+ "fill.w $w30, $a2\n"
+ "lbu $a2, 14(%[rhs_ptr])\n"
+ "maddv.w $w2, $w27, $w30\n"
+ "maddv.w $w6, $w28, $w30\n"
+ "maddv.w $w10, $w29, $w30\n"
+ "fill.w $w30, $a3\n"
+ "lbu $a3, 15(%[rhs_ptr])\n"
+ "maddv.w $w3, $w27, $w30\n"
+ "maddv.w $w7, $w28, $w30\n"
+ "maddv.w $w11, $w29, $w30\n"
+
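+        // Columns 4-7 of depth 1 (rhs bytes 12-15, loaded above).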
+ "fill.w $w30, $a0\n"
+ "maddv.w $w12, $w27, $w30\n"
+ "maddv.w $w16, $w28, $w30\n"
+ "maddv.w $w20, $w29, $w30\n"
+ "fill.w $w30, $a1\n"
+ "maddv.w $w13, $w27, $w30\n"
+ "maddv.w $w17, $w28, $w30\n"
+ "maddv.w $w21, $w29, $w30\n"
+ "fill.w $w30, $a2\n"
+ "maddv.w $w14, $w27, $w30\n"
+ "maddv.w $w18, $w28, $w30\n"
+ "maddv.w $w22, $w29, $w30\n"
+ "fill.w $w30, $a3\n"
+ "maddv.w $w15, $w27, $w30\n"
+ "maddv.w $w19, $w28, $w30\n"
+ "maddv.w $w23, $w29, $w30\n"
+
+ "addiu %[depth], -2\n"
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators.
+ "st.w $w0, (0*16)(%[accum_ptr])\n"
+ "st.w $w4, (1*16)(%[accum_ptr])\n"
+ "st.w $w8, (2*16)(%[accum_ptr])\n"
+ "st.w $w1, (3*16)(%[accum_ptr])\n"
+ "st.w $w5, (4*16)(%[accum_ptr])\n"
+ "st.w $w9, (5*16)(%[accum_ptr])\n"
+ "st.w $w2, (6*16)(%[accum_ptr])\n"
+ "st.w $w6, (7*16)(%[accum_ptr])\n"
+ "st.w $w10, (8*16)(%[accum_ptr])\n"
+ "st.w $w3, (9*16)(%[accum_ptr])\n"
+ "st.w $w7, (10*16)(%[accum_ptr])\n"
+ "st.w $w11, (11*16)(%[accum_ptr])\n"
+ "st.w $w12, (12*16)(%[accum_ptr])\n"
+ "st.w $w16, (13*16)(%[accum_ptr])\n"
+ "st.w $w20, (14*16)(%[accum_ptr])\n"
+ "st.w $w13, (15*16)(%[accum_ptr])\n"
+ "st.w $w17, (16*16)(%[accum_ptr])\n"
+ "st.w $w21, (17*16)(%[accum_ptr])\n"
+ "st.w $w14, (18*16)(%[accum_ptr])\n"
+ "st.w $w18, (19*16)(%[accum_ptr])\n"
+ "st.w $w22, (20*16)(%[accum_ptr])\n"
+ "st.w $w15, (21*16)(%[accum_ptr])\n"
+ "st.w $w19, (22*16)(%[accum_ptr])\n"
+ "st.w $w23, (23*16)(%[accum_ptr])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "memory",
+ "a0", "a1", "a2", "a3",
+ "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
+ "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
+ "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23",
+ "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
+ }
+};
+
+// Assembly implementation of a to-be-written (TODO)
+// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics2 (a tentative
+// intrinsics sketch, not part of the benchmark, follows this kernel).
+// Using 16x16=32 multiplications.
+// 32 MSA regs used:
+// - 24 accumulators
+// - 3 lhs
+// - 4 rhs
+// - 1 temps/zeroes
+// ~70 instructions in the loop.
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
+ Format;
+ static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "ld.w $w0, (0*16)(%[accum_ptr])\n"
+ "ld.w $w4, (1*16)(%[accum_ptr])\n"
+ "ld.w $w8, (2*16)(%[accum_ptr])\n"
+ "ld.w $w1, (3*16)(%[accum_ptr])\n"
+ "ld.w $w5, (4*16)(%[accum_ptr])\n"
+ "ld.w $w9, (5*16)(%[accum_ptr])\n"
+ "ld.w $w2, (6*16)(%[accum_ptr])\n"
+ "ld.w $w6, (7*16)(%[accum_ptr])\n"
+ "ld.w $w10, (8*16)(%[accum_ptr])\n"
+ "ld.w $w3, (9*16)(%[accum_ptr])\n"
+ "ld.w $w7, (10*16)(%[accum_ptr])\n"
+ "ld.w $w11, (11*16)(%[accum_ptr])\n"
+ "ld.w $w12, (12*16)(%[accum_ptr])\n"
+ "ld.w $w16, (13*16)(%[accum_ptr])\n"
+ "ld.w $w20, (14*16)(%[accum_ptr])\n"
+ "ld.w $w13, (15*16)(%[accum_ptr])\n"
+ "ld.w $w17, (16*16)(%[accum_ptr])\n"
+ "ld.w $w21, (17*16)(%[accum_ptr])\n"
+ "ld.w $w14, (18*16)(%[accum_ptr])\n"
+ "ld.w $w18, (19*16)(%[accum_ptr])\n"
+ "ld.w $w22, (20*16)(%[accum_ptr])\n"
+ "ld.w $w15, (21*16)(%[accum_ptr])\n"
+ "ld.w $w19, (22*16)(%[accum_ptr])\n"
+ "ld.w $w23, (23*16)(%[accum_ptr])\n"
+ // Set a temp to all zeroes.
+ "ldi.b $w31, 0\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+ // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
+ // (each register contains 4 replicas of a pair of elements).
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
+ // A 12x8 block of accumulators is stored in 32bit in w0-w23.
+ //
+ // +------+------+------+------+
+ // Rhs |w27 |w28 |w29 |w30 |
+ // +------+------+------+------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+ - - - - +------+------+------+------+
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // +---+ - - - - +------+------+------+------+
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // +---+ - - - - +------+------+------+------+
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // +---+ - - - - +------+------+------+------+
+ //
+ // Accumulators
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w24, 0(%[lhs_ptr])\n"
+ "ld.b $w25, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for the first half of depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the first half of depth 1.
+ "lbu $v0, 4(%[rhs_ptr])\n"
+ "lbu $v1, 5(%[rhs_ptr])\n"
+ "lbu $t8, 6(%[rhs_ptr])\n"
+ "lbu $t9, 7(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w24, $w31, $w24\n"
+ "ilvl.b $w26, $w31, $w25\n"
+ "ilvr.b $w25, $w31, $w25\n"
+ // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
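+        // (ilvl.d moves the depth-1 halfwords into the low half of a temp;
+        // ilvr.h then pairs each depth-0 element with its depth-1 neighbor.)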
+ "ilvl.d $w27, $w31, $w24\n"
+ "ilvl.d $w28, $w31, $w25\n"
+ "ilvl.d $w29, $w31, $w26\n"
+ "ilvr.h $w24, $w27, $w24\n"
+ "ilvr.h $w25, $w28, $w25\n"
+ "ilvr.h $w26, $w29, $w26\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w
+ // (for the first half).
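+        // (ins rt, rs, pos, size inserts the low 'size' bits of rs into rt
+        // at bit 'pos', so each GPR below ends up holding depth0 in bits
+        // 7:0 and depth1 in bits 23:16, i.e. two 16-bit lanes.)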
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Load 4 bytes of rhs[] for the second half of depth 0.
+ "lbu $a0, 8(%[rhs_ptr])\n"
+ "lbu $a1, 9(%[rhs_ptr])\n"
+ "lbu $a2, 10(%[rhs_ptr])\n"
+ "lbu $a3, 11(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the second half of depth 1.
+ "lbu $v0, 12(%[rhs_ptr])\n"
+ "lbu $v1, 13(%[rhs_ptr])\n"
+ "lbu $t8, 14(%[rhs_ptr])\n"
+ "lbu $t9, 15(%[rhs_ptr])\n"
+
+ // First half of depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
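+        // (dpadd_u.w: each 32-bit lane of the destination accumulates the
+        // sum of products of the two corresponding unsigned 16-bit lanes of
+        // the sources, i.e. acc += d0*r0 + d1*r1 per lane.)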
+ "dpadd_u.w $w0, $w24, $w27\n"
+ "dpadd_u.w $w4, $w25, $w27\n"
+ "dpadd_u.w $w8, $w26, $w27\n"
+ "dpadd_u.w $w1, $w24, $w28\n"
+ "dpadd_u.w $w5, $w25, $w28\n"
+ "dpadd_u.w $w9, $w26, $w28\n"
+ "dpadd_u.w $w2, $w24, $w29\n"
+ "dpadd_u.w $w6, $w25, $w29\n"
+ "dpadd_u.w $w10, $w26, $w29\n"
+ "dpadd_u.w $w3, $w24, $w30\n"
+ "dpadd_u.w $w7, $w25, $w30\n"
+ "dpadd_u.w $w11, $w26, $w30\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w
+ // (for the second half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Second half of depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w12, $w24, $w27\n"
+ "dpadd_u.w $w16, $w25, $w27\n"
+ "dpadd_u.w $w20, $w26, $w27\n"
+ "dpadd_u.w $w13, $w24, $w28\n"
+ "dpadd_u.w $w17, $w25, $w28\n"
+ "dpadd_u.w $w21, $w26, $w28\n"
+ "dpadd_u.w $w14, $w24, $w29\n"
+ "dpadd_u.w $w18, $w25, $w29\n"
+ "dpadd_u.w $w22, $w26, $w29\n"
+ "dpadd_u.w $w15, $w24, $w30\n"
+ "dpadd_u.w $w19, $w25, $w30\n"
+ "dpadd_u.w $w23, $w26, $w30\n"
+
+ "addiu %[depth], -2\n"
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators.
+ "st.w $w0, (0*16)(%[accum_ptr])\n"
+ "st.w $w4, (1*16)(%[accum_ptr])\n"
+ "st.w $w8, (2*16)(%[accum_ptr])\n"
+ "st.w $w1, (3*16)(%[accum_ptr])\n"
+ "st.w $w5, (4*16)(%[accum_ptr])\n"
+ "st.w $w9, (5*16)(%[accum_ptr])\n"
+ "st.w $w2, (6*16)(%[accum_ptr])\n"
+ "st.w $w6, (7*16)(%[accum_ptr])\n"
+ "st.w $w10, (8*16)(%[accum_ptr])\n"
+ "st.w $w3, (9*16)(%[accum_ptr])\n"
+ "st.w $w7, (10*16)(%[accum_ptr])\n"
+ "st.w $w11, (11*16)(%[accum_ptr])\n"
+ "st.w $w12, (12*16)(%[accum_ptr])\n"
+ "st.w $w16, (13*16)(%[accum_ptr])\n"
+ "st.w $w20, (14*16)(%[accum_ptr])\n"
+ "st.w $w13, (15*16)(%[accum_ptr])\n"
+ "st.w $w17, (16*16)(%[accum_ptr])\n"
+ "st.w $w21, (17*16)(%[accum_ptr])\n"
+ "st.w $w14, (18*16)(%[accum_ptr])\n"
+ "st.w $w18, (19*16)(%[accum_ptr])\n"
+ "st.w $w22, (20*16)(%[accum_ptr])\n"
+ "st.w $w15, (21*16)(%[accum_ptr])\n"
+ "st.w $w19, (22*16)(%[accum_ptr])\n"
+ "st.w $w23, (23*16)(%[accum_ptr])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "memory",
+ "v0", "v1",
+ "a0", "a1", "a2", "a3",
+ "t8", "t9",
+ "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
+ "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
+ "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23",
+ "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
+ }
+};
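+
+// The intrinsics2 counterpart referenced above is still TODO. As a rough
+// illustration only, here is a minimal intrinsics sketch of the same
+// 16x16=32 dot-product scheme. It assumes <msa.h> provides the v4u32/v8u16
+// types and __builtin_msa_dpadd_u_w with the (v4u32, v8u16, v8u16)
+// signature; the struct name is ad hoc and this version is untuned and not
+// benchmarked.
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics2_sketch {
+  typedef std::uint8_t OperandType;
+  typedef std::uint32_t AccumulatorType;
+  typedef KernelFormat<
+      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
+      Format;
+  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+                  AccumulatorType* accum_ptr, int depth) {
+    const v16i8 zeroes = __builtin_msa_ldi_b(0);
+    v4i32 acc[3][8];
+    // Load accumulators.
+    for (int i = 0; i < 3; i++) {
+      for (int j = 0; j < 8; j++) {
+        acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0);
+      }
+    }
+    while (depth > 0) {
+      // Load 3 x 8 bytes of lhs[] and zero-extend them to 16 bits.
+      v8i16 lhs[3];
+      v16i8 lo = __builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0);
+      v16i8 hi = __builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0);
+      lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, lo));
+      lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes, hi));
+      lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, hi));
+      // Pair each depth-0 halfword with its depth-1 neighbor, as the
+      // assembly above does with ilvl.d + ilvr.h.
+      for (int i = 0; i < 3; i++) {
+        v8i16 d1 = reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(
+            reinterpret_cast<v2i64>(zeroes), reinterpret_cast<v2i64>(lhs[i])));
+        lhs[i] = __builtin_msa_ilvr_h(d1, lhs[i]);
+      }
+      for (int j = 0; j < 8; j++) {
+        // Replicate the (depth 0, depth 1) rhs byte pair of column j across
+        // all 4 lanes, matching the DepthMajor 4x2 rhs cell layout.
+        const int cell = 8 * (j / 4);
+        const std::int32_t pair =
+            rhs_ptr[cell + (j % 4)] | (rhs_ptr[cell + 4 + (j % 4)] << 16);
+        v8u16 rhs = reinterpret_cast<v8u16>(__builtin_msa_fill_w(pair));
+        for (int i = 0; i < 3; i++) {
+          // Each 32-bit lane accumulates d0*r0 + d1*r1.
+          acc[i][j] = reinterpret_cast<v4i32>(__builtin_msa_dpadd_u_w(
+              reinterpret_cast<v4u32>(acc[i][j]),
+              reinterpret_cast<v8u16>(lhs[i]), rhs));
+        }
+      }
+      lhs_ptr += 24;
+      rhs_ptr += 16;
+      depth -= 2;
+    }
+    // Store accumulators.
+    for (int i = 0; i < 3; i++) {
+      for (int j = 0; j < 8; j++) {
+        __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0);
+      }
+    }
+  }
+};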
+#endif // __mips
// BEGIN code copied from gemmlowp/internal/kernel_reference.h
@@ -3451,8 +4883,9 @@
data_ = nullptr;
// Adds a few bytes of padding here, because the 64-bit 'A57' kernel
// reads one iteration past the end of the buffer, causing a crash on iOS.
- posix_memalign(reinterpret_cast<void**>(&data_), kCacheLineSize,
- size_ * sizeof(DataType) + 16);
+ int res = posix_memalign(reinterpret_cast<void**>(&data_), kCacheLineSize,
+ size_ * sizeof(DataType) + 16);
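+    // Cast the result to void to silence -Wunused-result on posix_memalign.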
+ (void)res;
}
~CacheLineAlignedBuffer() { free(data_); }
@@ -3460,7 +4893,7 @@
const DataType* data() const { return data_; }
DataType* data() { return data_; }
- const std::size_t size() const { return size_; }
+ std::size_t size() const { return size_; }
private:
const std::size_t size_;
@@ -3726,12 +5159,15 @@
#endif
#ifdef __aarch64__
-
BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits);
BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics);
BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators);
BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics);
BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand_A57);
+#ifdef __ARM_FEATURE_DOTPROD
+ BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct);
+ BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1);
+#endif
BENCHMARK(NEON_64bit_GEMM_Int32_WithScalar);
BENCHMARK(NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar);
BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar);
@@ -3740,6 +5176,16 @@
#ifndef __APPLE__
BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A53);
#endif
+ BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A55r1);
+#endif
+
+#ifdef __mips
+ BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics);
+ BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly);
+ BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2);
+ BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics);
+ BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly);
+ BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2);
#endif
return 0;