// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// unpack_neon.h: optimized NEON specializations of the templates in unpack.h.

#ifndef GEMMLOWP_INTERNAL_UNPACK_NEON_H_
#define GEMMLOWP_INTERNAL_UNPACK_NEON_H_

#include "unpack.h"

#include <arm_neon.h>
namespace gemmlowp {
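
// Multiplies each lane of x by the constant fraction numerator / denominator,
// approximately rounding to nearest. The fraction is split into an integer
// quotient, applied with a plain multiply-accumulate, and a remainder,
// applied as a Q31 fixed-point multiplication via vqrdmulhq_n_s32. For
// example, with numerator = 255 and denominator = 127, int_quotient = 2 and
// the remaining fraction 1/127 is encoded in Q31 as (1 << 31) / 127.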
template <std::uint32_t numerator, std::uint32_t denominator>
int32x4_t MultiplyByConstantFraction(int32x4_t x) {
static_assert(numerator > 0 && denominator > 0,
"only supporting positive num/denom");
if (numerator == denominator) {
return x;
}
static const std::int32_t int_quotient =
(numerator + denominator / 2) / denominator;
static const std::int32_t remaining_numerator =
numerator - int_quotient * denominator;
static const std::int32_t scaled_remaining_numerator =
static_cast<std::int32_t>(
(static_cast<std::int64_t>(remaining_numerator) << 31) / denominator);
const int32x4_t remaining_product =
vqrdmulhq_n_s32(x, scaled_remaining_numerator);
return vmlaq_n_s32(remaining_product, x, int_quotient);
}
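
// Specialization of UnpackResultImpl for uint8 column-major result blocks.
// For each destination entry, Unpack computes
//   clamp_uint8((term_xx + term_x1 + term_1x + term_11) * result_mult_int
//               >> result_shift)   (with round-to-nearest in the shift)
// where term_xx comes from the raw GEMM accumulator, term_x1 and term_1x are
// the precomputed per-row and per-column rank-one offset corrections, and
// term_11 is the constant lhs_offset * rhs_offset * depth folded together
// with result_offset. The xx and x1 / 1x terms are rescaled from the packed
// bit depths back to the full 8-bit range by 255 / kMax constant fractions.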
template <BitDepthSetting BitDepth, typename PackedResultType>
struct UnpackResultImpl<BitDepth, MatrixMap<std::uint8_t, MapOrder::ColMajor>,
PackedResultType> {
typedef MatrixMap<std::uint8_t, MapOrder::ColMajor> ResultBlockType;
static void Unpack(ResultBlockType* dst, const PackedResultType& src,
int depth, const std::int32_t* lhs_rank_one_update,
const std::int32_t* rhs_rank_one_update,
std::int32_t lhs_offset, std::int32_t rhs_offset,
std::int32_t result_offset, std::int32_t result_mult_int,
std::int32_t result_shift) {
ScopedProfilingLabel label("optimized path (NEON)");
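    // Maximum values representable at the packed lhs/rhs bit depths; the
    // 255 / kMax fractions below rescale accumulators to the 8-bit range.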
const int kLhsBits = LhsBitDepth<BitDepth>::kBits;
const int kRhsBits = RhsBitDepth<BitDepth>::kBits;
const std::int32_t kLhsMax = (1 << kLhsBits) - 1;
const std::int32_t kRhsMax = (1 << kRhsBits) - 1;
auto src_map = src.Map();
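    // term_11 is the constant part of the rank-one offset expansion,
    // lhs_offset * rhs_offset * depth, folded together with result_offset.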
const std::int32_t term_11 =
lhs_offset * rhs_offset * depth + result_offset;
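    // vshlq_s32 with a negative per-lane count shifts right; adding
    // preshift_offset beforehand makes that shift round to nearest.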
const int32x4_t shift_reg = vdupq_n_s32(-result_shift);
    // Guard against result_shift == 0, where (1 << (result_shift - 1))
    // would be undefined behavior.
    const std::int32_t preshift_offset =
        result_shift < 1 ? 0 : (1 << (result_shift - 1));
const int32x4_t preshift_offset_reg = vdupq_n_s32(preshift_offset);
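    // Process the destination matrix one column at a time.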
for (int c = 0; c < dst->cols(); c++) {
std::uint8_t* dst_ptr = dst->data(0, c);
const std::int32_t* src_ptr = src_map.data(0, c);
const std::int32_t* rank_one_update_ptr = lhs_rank_one_update;
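      // Per-column constant: the rhs rank-one update for this column,
      // rescaled to the 8-bit range. This calls the scalar overload of
      // MultiplyByConstantFraction (from unpack.h).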
const std::int32_t raw_1x = rhs_rank_one_update[c];
const std::int32_t term_1x =
MultiplyByConstantFraction<255, kRhsMax>(raw_1x);
const std::int32_t term_1x_plus_term_11 = term_1x + term_11;
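      // Main loop: handle 16 destination rows at a time.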
int dst_rows_aligned16 = RoundDown<16>(dst->rows());
for (int r = 0; r < dst_rows_aligned16; r += 16) {
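        // Load 16 raw int32 accumulators from the packed result.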
int32x4_t raw_xx[4];
for (int i = 0; i < 4; i++) {
raw_xx[i] = vld1q_s32(src_ptr);
src_ptr += 4;
}
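        // Load the 16 corresponding lhs rank-one update terms.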
int32x4_t raw_x1[4];
for (int i = 0; i < 4; i++) {
raw_x1[i] = vld1q_s32(rank_one_update_ptr);
rank_one_update_ptr += 4;
}
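        // Rescale the raw accumulators from the packed bit-depth range
        // to the full 8-bit range.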
int32x4_t term_xx[4];
for (int i = 0; i < 4; i++) {
term_xx[i] =
MultiplyByConstantFraction<255 * 255, kLhsMax * kRhsMax>(
raw_xx[i]);
}
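        // Rescale the lhs rank-one update terms likewise.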
int32x4_t term_x1[4];
for (int i = 0; i < 4; i++) {
term_x1[i] = MultiplyByConstantFraction<255, kLhsMax>(raw_x1[i]);
}
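        // Sum the four terms of the unpack equation.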
int32x4_t q[4];
for (int i = 0; i < 4; i++) {
q[i] = vaddq_s32(vaddq_s32(term_xx[i], term_x1[i]),
vdupq_n_s32(term_1x_plus_term_11));
}
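        // Apply the result multiplier.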
for (int i = 0; i < 4; i++) {
q[i] = vmulq_n_s32(q[i], result_mult_int);
}
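        // Rounding right shift by result_shift.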
for (int i = 0; i < 4; i++) {
q[i] = vshlq_s32(vaddq_s32(q[i], preshift_offset_reg), shift_reg);
}
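        // Saturating cast down to int16...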
int16x4_t q16[4];
for (int i = 0; i < 4; i++) {
q16[i] = vqmovn_s32(q[i]);
}
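        // ...then to uint8, clamping to the [0, 255] output range.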
uint8x8_t q8[4];
for (int i = 0; i < 4; i++) {
q8[i] = vqmovun_s16(vcombine_s16(q16[i], q16[i]));
}
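        // Store the 16 result bytes, 4 at a time.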
for (int i = 0; i < 4; i++) {
vst1_lane_u32(reinterpret_cast<std::uint32_t*>(dst_ptr),
vreinterpret_u32_u8(q8[i]), 0);
dst_ptr += 4;
}
}
// We have finished handling groups of 16 entries at once; now
// try to handle 4 entries at once.
int dst_rows_aligned4 = RoundDown<4>(dst->rows());
for (int r = dst_rows_aligned16; r < dst_rows_aligned4; r += 4) {
const int32x4_t raw_xx = vld1q_s32(src_ptr);
const int32x4_t term_xx =
MultiplyByConstantFraction<255 * 255, kLhsMax * kRhsMax>(raw_xx);
const int32x4_t raw_x1 = vld1q_s32(rank_one_update_ptr);
const int32x4_t term_x1 =
MultiplyByConstantFraction<255, kLhsMax>(raw_x1);
int32x4_t q = vaddq_s32(vaddq_s32(term_xx, term_x1),
vdupq_n_s32(term_1x_plus_term_11));
q = vmulq_n_s32(q, result_mult_int);
q = vshlq_s32(vaddq_s32(q, preshift_offset_reg), shift_reg);
int16x4_t q16 = vqmovn_s32(q);
uint8x8_t q8 = vqmovun_s16(vcombine_s16(q16, q16));
vst1_lane_u32(reinterpret_cast<std::uint32_t*>(dst_ptr),
vreinterpret_u32_u8(q8), 0);
dst_ptr += 4;
src_ptr += 4;
rank_one_update_ptr += 4;
}
// We have finished handling 4 entries at once; now handle
// remaining entries one by one.
for (int r = dst_rows_aligned4; r < dst->rows(); r++) {
std::int32_t raw_xx = src_map(r, c);
std::int32_t raw_x1 = lhs_rank_one_update[r];
std::int32_t term_xx =
MultiplyByConstantFraction<255 * 255, kLhsMax * kRhsMax>(raw_xx);
std::int32_t term_x1 =
MultiplyByConstantFraction<255, kLhsMax>(raw_x1);
std::int32_t sum = term_xx + term_x1 + term_1x_plus_term_11;
        // Reuse preshift_offset so the scalar tail rounds exactly like the
        // NEON paths and avoids undefined behavior when result_shift == 0.
        std::int32_t result =
            (sum * result_mult_int + preshift_offset) >> result_shift;
(*dst)(r, c) = result > 255 ? 255 : result < 0 ? 0 : result;
}
}
}
};
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_UNPACK_NEON_H_