blob: 4f06cecc78ef95fe3ec6ad727ca3155403aaef3d [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <assert.h>
#include <stdint.h>
#include <psimd.h>
#include <fp16/bitcasts.h>
#include <xnnpack/requantization-stubs.h>
void xnn_requantize_precise__psimd(
size_t n,
const int32_t* input,
float scale,
uint8_t zero_point,
uint8_t qmin,
uint8_t qmax,
uint8_t* output)
{
assert(n % 16 == 0);
assert(scale < 1.0f);
assert(scale >= 0x1.0p-32f);
const uint32_t scale_bits = fp32_to_bits(scale);
const uint32_t multiplier = (scale_bits << 8) | UINT32_C(0x80000000);
const uint32_t shift = 127 + 31 - (scale_bits >> 23);
assert(shift >= 32);
assert(shift < 64);
const uint64_t rounding = UINT64_C(1) << (shift - 1);
const psimd_u32 vmultiplier_lo = psimd_splat_u32(multiplier & UINT32_C(0x0000FFFF));
const psimd_u32 vmultiplier_hi = psimd_splat_u32(multiplier >> 16);
const psimd_s32 vzero_point = psimd_splat_s32((int32_t)(uint32_t) zero_point);
const psimd_s32 vsmin = psimd_splat_s32((int32_t)(uint32_t) qmin - (int32_t)(uint32_t) zero_point);
const psimd_s32 vsmax = psimd_splat_s32((int32_t)(uint32_t) qmax - (int32_t)(uint32_t) zero_point);
const psimd_u32 vrounding_lo = psimd_splat_u32((uint32_t) rounding);
const psimd_u32 vrounding_hi = psimd_splat_u32((uint32_t)(rounding >> 32));
const psimd_u32 vshift = psimd_splat_u32(shift - 32);
for (; n != 0; n -= 16) {
const psimd_s32 x = psimd_load_s32(input);
const psimd_s32 y = psimd_load_s32(input + 4);
const psimd_s32 z = psimd_load_s32(input + 8);
const psimd_s32 w = psimd_load_s32(input + 12);
input += 16;
const psimd_s32 x_neg_mask = x >> psimd_splat_s32(31);
const psimd_s32 y_neg_mask = y >> psimd_splat_s32(31);
const psimd_s32 z_neg_mask = z >> psimd_splat_s32(31);
const psimd_s32 w_neg_mask = w >> psimd_splat_s32(31);
const psimd_u32 x_abs = (psimd_u32)((x ^ x_neg_mask) - x_neg_mask);
const psimd_u32 y_abs = (psimd_u32)((y ^ y_neg_mask) - y_neg_mask);
const psimd_u32 z_abs = (psimd_u32)((z ^ z_neg_mask) - z_neg_mask);
const psimd_u32 w_abs = (psimd_u32)((w ^ w_neg_mask) - w_neg_mask);
const psimd_u32 x_abs_lo = x_abs & psimd_splat_u32(UINT32_C(0x0000FFFF));
const psimd_u32 x_abs_hi = x_abs >> psimd_splat_u32(16);
const psimd_u32 y_abs_lo = y_abs & psimd_splat_u32(UINT32_C(0x0000FFFF));
const psimd_u32 y_abs_hi = y_abs >> psimd_splat_u32(16);
const psimd_u32 z_abs_lo = z_abs & psimd_splat_u32(UINT32_C(0x0000FFFF));
const psimd_u32 z_abs_hi = z_abs >> psimd_splat_u32(16);
const psimd_u32 w_abs_lo = w_abs & psimd_splat_u32(UINT32_C(0x0000FFFF));
const psimd_u32 w_abs_hi = w_abs >> psimd_splat_u32(16);
const psimd_u32 x_product_ll = x_abs_lo * vmultiplier_lo;
const psimd_u32 y_product_ll = y_abs_lo * vmultiplier_lo;
const psimd_u32 z_product_ll = z_abs_lo * vmultiplier_lo;
const psimd_u32 w_product_ll = w_abs_lo * vmultiplier_lo;
const psimd_u32 x_product_lh = x_abs_lo * vmultiplier_hi + (x_product_ll >> psimd_splat_u32(16));
const psimd_u32 y_product_lh = y_abs_lo * vmultiplier_hi + (y_product_ll >> psimd_splat_u32(16));
const psimd_u32 z_product_lh = z_abs_lo * vmultiplier_hi + (z_product_ll >> psimd_splat_u32(16));
const psimd_u32 w_product_lh = w_abs_lo * vmultiplier_hi + (w_product_ll >> psimd_splat_u32(16));
const psimd_u32 x_product_hl = x_abs_hi * vmultiplier_lo + (x_product_lh & psimd_splat_u32(UINT32_C(0x0000FFFF)));
const psimd_u32 y_product_hl = y_abs_hi * vmultiplier_lo + (y_product_lh & psimd_splat_u32(UINT32_C(0x0000FFFF)));
const psimd_u32 z_product_hl = z_abs_hi * vmultiplier_lo + (z_product_lh & psimd_splat_u32(UINT32_C(0x0000FFFF)));
const psimd_u32 w_product_hl = w_abs_hi * vmultiplier_lo + (w_product_lh & psimd_splat_u32(UINT32_C(0x0000FFFF)));
const psimd_u32 x_product_lo =
(x_product_hl << psimd_splat_u32(16)) + (x_product_ll & psimd_splat_u32(UINT32_C(0x0000FFFF)));
const psimd_u32 y_product_lo =
(y_product_hl << psimd_splat_u32(16)) + (y_product_ll & psimd_splat_u32(UINT32_C(0x0000FFFF)));
const psimd_u32 z_product_lo =
(z_product_hl << psimd_splat_u32(16)) + (z_product_ll & psimd_splat_u32(UINT32_C(0x0000FFFF)));
const psimd_u32 w_product_lo =
(w_product_hl << psimd_splat_u32(16)) + (w_product_ll & psimd_splat_u32(UINT32_C(0x0000FFFF)));
const psimd_u32 x_product_hi =
x_abs_hi * vmultiplier_hi + (x_product_lh >> psimd_splat_u32(16)) + (x_product_hl >> psimd_splat_u32(16));
const psimd_u32 y_product_hi =
y_abs_hi * vmultiplier_hi + (y_product_lh >> psimd_splat_u32(16)) + (y_product_hl >> psimd_splat_u32(16));
const psimd_u32 z_product_hi =
z_abs_hi * vmultiplier_hi + (z_product_lh >> psimd_splat_u32(16)) + (z_product_hl >> psimd_splat_u32(16));
const psimd_u32 w_product_hi =
w_abs_hi * vmultiplier_hi + (w_product_lh >> psimd_splat_u32(16)) + (w_product_hl >> psimd_splat_u32(16));
const psimd_u32 x_adjusted_product =
(x_product_hi + vrounding_hi) - ((psimd_s32)(x_product_lo & vrounding_lo) >> psimd_splat_s32(31));
const psimd_u32 y_adjusted_product =
(y_product_hi + vrounding_hi) - ((psimd_s32)(y_product_lo & vrounding_lo) >> psimd_splat_s32(31));
const psimd_u32 z_adjusted_product =
(z_product_hi + vrounding_hi) - ((psimd_s32)(z_product_lo & vrounding_lo) >> psimd_splat_s32(31));
const psimd_u32 w_adjusted_product =
(w_product_hi + vrounding_hi) - ((psimd_s32)(w_product_lo & vrounding_lo) >> psimd_splat_s32(31));
const psimd_u32 x_abs_scaled = x_adjusted_product >> vshift;
const psimd_u32 y_abs_scaled = y_adjusted_product >> vshift;
const psimd_u32 z_abs_scaled = z_adjusted_product >> vshift;
const psimd_u32 w_abs_scaled = w_adjusted_product >> vshift;
const psimd_s32 x_scaled = (psimd_s32)(x_abs_scaled ^ x_neg_mask) - x_neg_mask;
const psimd_s32 y_scaled = (psimd_s32)(y_abs_scaled ^ y_neg_mask) - y_neg_mask;
const psimd_s32 z_scaled = (psimd_s32)(z_abs_scaled ^ z_neg_mask) - z_neg_mask;
const psimd_s32 w_scaled = (psimd_s32)(w_abs_scaled ^ w_neg_mask) - w_neg_mask;
const psimd_u32 x_clamped = (psimd_u32) psimd_max_s32(psimd_min_s32(x_scaled, vsmax), vsmin) + vzero_point;
const psimd_u32 y_clamped = (psimd_u32) psimd_max_s32(psimd_min_s32(y_scaled, vsmax), vsmin) + vzero_point;
const psimd_u32 z_clamped = (psimd_u32) psimd_max_s32(psimd_min_s32(z_scaled, vsmax), vsmin) + vzero_point;
const psimd_u32 w_clamped = (psimd_u32) psimd_max_s32(psimd_min_s32(w_scaled, vsmax), vsmin) + vzero_point;
const psimd_u16 xy_clamped = psimd_concat_even_u16((psimd_u16) x_clamped, (psimd_u16) y_clamped);
const psimd_u16 zw_clamped = psimd_concat_even_u16((psimd_u16) z_clamped, (psimd_u16) w_clamped);
const psimd_u8 xyzw_clamped = psimd_concat_even_u8((psimd_u8) xy_clamped, (psimd_u8) zw_clamped);
psimd_store_u8(output, xyzw_clamped);
output += 16;
}
}