blob: d336f46699cccab12cefdd35e81f23cdc7fbd9c0 [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <stdint.h>
#include <stddef.h>
#include <assert.h>
#include <math.h>
#include <fp16.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
// Pointer type for a scalar signed 8-bit (qs8) requantization routine.
// Matches the signature of xnn_qs8_requantize_fp32 / _rndna / _rndnu in this
// header: converts the 32-bit value `input` to int8 using `scale`, adds
// `output_zero_point`, and clamps to the inclusive [output_min, output_max].
typedef int8_t (*xnn_qs8_requantize_fn)(
int32_t input,
float scale,
int8_t output_zero_point,
int8_t output_min,
int8_t output_max);
// Pointer type for a scalar unsigned 8-bit (qu8) requantization routine.
// Matches the signature of xnn_qu8_requantize_fp32 / _rndna / _rndnu in this
// header: converts the 32-bit value `input` to uint8 using `scale`, adds
// `output_zero_point`, and clamps to the inclusive [output_min, output_max].
typedef uint8_t (*xnn_qu8_requantize_fn)(
int32_t input,
float scale,
uint8_t output_zero_point,
uint8_t output_min,
uint8_t output_max);
// Requantizes a 32-bit value to int8 using single-precision arithmetic.
//
// The input is multiplied by `scale` in fp32, clamped (still in float) to the
// output range expressed relative to the zero point, rounded with lrintf
// (current floating-point rounding mode), and finally offset by the zero point.
//
// input      - 32-bit value to requantize.
// scale      - requantization scale; asserted to lie in [2**-32, 256).
// zero_point - output zero point, added after rounding.
// min, max   - inclusive bounds of the quantized output.
// Returns the requantized int8 value.
static inline int8_t xnn_qs8_requantize_fp32(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale >= 0x1.0p-32f);
  assert(scale < 256.0f);

  const int32_t zp = (int32_t) zero_point;
  // Bounds are shifted by the zero point so clamping can be done before
  // rounding, keeping the final zero-point addition inside [min, max].
  const float lower_bound = (float) ((int32_t) min - zp);
  const float upper_bound = (float) ((int32_t) max - zp);

  const float scaled = (float) input * scale;
  const float clamped = math_min_f32(math_max_f32(scaled, lower_bound), upper_bound);
  return (int8_t) ((int32_t) lrintf(clamped) + zp);
}
// Requantizes a 32-bit value to uint8 using single-precision arithmetic.
//
// The input is multiplied by `scale` in fp32, clamped (still in float) to the
// output range expressed relative to the zero point, rounded with lrintf
// (current floating-point rounding mode), and finally offset by the zero point.
//
// input      - 32-bit value to requantize.
// scale      - requantization scale; asserted to lie in [2**-32, 256).
// zero_point - output zero point, added after rounding.
// min, max   - inclusive bounds of the quantized output.
// Returns the requantized uint8 value.
static inline uint8_t xnn_qu8_requantize_fp32(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale >= 0x1.0p-32f);
  assert(scale < 256.0f);

  const int32_t zp = (int32_t) zero_point;
  // Bounds are shifted by the zero point so clamping can be done before
  // rounding, keeping the final zero-point addition inside [min, max].
  const float lower_bound = (float) ((int32_t) min - zp);
  const float upper_bound = (float) ((int32_t) max - zp);

  const float scaled = (float) input * scale;
  const float clamped = math_min_f32(math_max_f32(scaled, lower_bound), upper_bound);
  return (uint8_t) ((int32_t) lrintf(clamped) + zp);
}
// Requantizes a 32-bit value to int8 with a fixed-point multiplier derived from
// the fp32 `scale`, rounding to nearest with ties away from zero.
//
// The scale is decomposed into its 24-bit significand (implicit leading one
// restored) and a right-shift amount so that input * scale ==
// input * multiplier / 2**shift. The magnitude of the input is scaled and
// rounded half-up, then the sign is reapplied, which gives ties-away-from-zero.
//
// input      - 32-bit value to requantize.
// scale      - requantization scale; asserted to lie in [2**-32, 256).
// zero_point - output zero point, added after clamping.
// min, max   - inclusive bounds of the quantized output.
// Returns the requantized int8 value.
static inline int8_t xnn_qs8_requantize_rndna(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale >= 0x1.0p-32f);
  assert(scale < 256.0f);

  const uint32_t scale_bits = fp32_to_bits(scale);
  // 24-bit significand of scale with the implicit leading bit made explicit.
  const uint32_t multiplier = (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000);
  // scale == multiplier * 2**-shift (127 = exponent bias, 23 = mantissa bits).
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const uint64_t rounding = UINT64_C(1) << (shift - 1);
  const int32_t min_less_zero_point = (int32_t) min - (int32_t) zero_point;
  const int32_t max_less_zero_point = (int32_t) max - (int32_t) zero_point;

  // |input| in unsigned arithmetic so INT32_MIN maps to 2**31 without overflow.
  uint32_t abs_input = (uint32_t) input;
  if (input < 0) {
    abs_input = -abs_input;
  }

  // |input| <= 2**31 and multiplier < 2**24, so product + rounding < 2**56.
  const uint64_t abs_prescaled_input = (uint64_t) abs_input * (uint64_t) multiplier;
  // The scaled magnitude can reach ~2**40 (scale < 256), which does not fit in
  // 32 bits; keep it 64-bit until after clamping so large values saturate
  // instead of wrapping.
  const uint64_t abs_scaled_input = (abs_prescaled_input + rounding) >> shift;
  int64_t output = (int64_t) abs_scaled_input;
  if (input < 0) {
    output = -output;
  }

  // Clamp in 64 bits: raise to the lower bound first, then cap at the upper.
  if (output < (int64_t) min_less_zero_point) {
    output = (int64_t) min_less_zero_point;
  }
  if (output > (int64_t) max_less_zero_point) {
    output = (int64_t) max_less_zero_point;
  }
  return (int8_t) ((int32_t) output + (int32_t) zero_point);
}
// Requantizes a 32-bit value to uint8 with a fixed-point multiplier derived
// from the fp32 `scale`, rounding to nearest with ties away from zero.
//
// The scale is decomposed into its 24-bit significand (implicit leading one
// restored) and a right-shift amount so that input * scale ==
// input * multiplier / 2**shift. The magnitude of the input is scaled and
// rounded half-up, then the sign is reapplied, which gives ties-away-from-zero.
//
// input      - 32-bit value to requantize.
// scale      - requantization scale; asserted to lie in [2**-32, 256).
// zero_point - output zero point, added after clamping.
// min, max   - inclusive bounds of the quantized output.
// Returns the requantized uint8 value.
static inline uint8_t xnn_qu8_requantize_rndna(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale >= 0x1.0p-32f);
  assert(scale < 256.0f);

  const uint32_t scale_bits = fp32_to_bits(scale);
  // 24-bit significand of scale with the implicit leading bit made explicit.
  const uint32_t multiplier = (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000);
  // scale == multiplier * 2**-shift (127 = exponent bias, 23 = mantissa bits).
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const uint64_t rounding = UINT64_C(1) << (shift - 1);
  const int32_t min_less_zero_point = (int32_t) min - (int32_t) zero_point;
  const int32_t max_less_zero_point = (int32_t) max - (int32_t) zero_point;

  // |input| in unsigned arithmetic so INT32_MIN maps to 2**31 without overflow.
  uint32_t abs_input = (uint32_t) input;
  if (input < 0) {
    abs_input = -abs_input;
  }

  // |input| <= 2**31 and multiplier < 2**24, so product + rounding < 2**56.
  const uint64_t abs_prescaled_input = (uint64_t) abs_input * (uint64_t) multiplier;
  // The scaled magnitude can reach ~2**40 (scale < 256), which does not fit in
  // 32 bits; keep it 64-bit until after clamping so large values saturate
  // instead of wrapping.
  const uint64_t abs_scaled_input = (abs_prescaled_input + rounding) >> shift;
  int64_t output = (int64_t) abs_scaled_input;
  if (input < 0) {
    output = -output;
  }

  // Clamp in 64 bits: raise to the lower bound first, then cap at the upper.
  if (output < (int64_t) min_less_zero_point) {
    output = (int64_t) min_less_zero_point;
  }
  if (output > (int64_t) max_less_zero_point) {
    output = (int64_t) max_less_zero_point;
  }
  return (uint8_t) ((int32_t) output + (int32_t) zero_point);
}
// Requantizes a 32-bit value to int8 with a fixed-point multiplier derived from
// the fp32 `scale`, rounding to nearest with ties toward +infinity.
//
// The scale is decomposed into its 24-bit significand (implicit leading one
// restored) and a right-shift amount so that input * scale ==
// input * multiplier / 2**shift. Half of 2**shift is added before a flooring
// right shift, which rounds ties upward.
//
// input      - 32-bit value to requantize.
// scale      - requantization scale; asserted to lie in [2**-32, 256).
// zero_point - output zero point, added after clamping.
// min, max   - inclusive bounds of the quantized output.
// Returns the requantized int8 value.
static inline int8_t xnn_qs8_requantize_rndnu(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale < 256.0f);
  assert(scale >= 0x1.0p-32f);

  const uint32_t scale_bits = fp32_to_bits(scale);
  // 24-bit significand of scale with the implicit leading bit made explicit.
  const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
  // scale == multiplier * 2**-shift (127 = exponent bias, 23 = mantissa bits).
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const int64_t rounding = INT64_C(1) << (shift - 1);
  const int32_t min_less_zero_point = (int32_t) min - (int32_t) zero_point;
  const int32_t max_less_zero_point = (int32_t) max - (int32_t) zero_point;

  // Signed product: |input * multiplier| < 2**55, so adding rounding cannot
  // overflow int64.
  const int64_t prescaled_input = (int64_t) input * (int64_t) multiplier;
  // asr_s64 is presumably an arithmetic (flooring) shift. The result can reach
  // ~2**39 (scale < 256), so it stays 64-bit until after clamping; narrowing
  // first would wrap out-of-range values instead of saturating them.
  int64_t output = asr_s64(prescaled_input + rounding, shift);

  // Clamp in 64 bits: raise to the lower bound first, then cap at the upper.
  if (output < (int64_t) min_less_zero_point) {
    output = (int64_t) min_less_zero_point;
  }
  if (output > (int64_t) max_less_zero_point) {
    output = (int64_t) max_less_zero_point;
  }
  return (int8_t) ((int32_t) output + (int32_t) zero_point);
}
// Requantizes a 32-bit value to uint8 with a fixed-point multiplier derived
// from the fp32 `scale`, rounding to nearest with ties toward +infinity.
//
// The scale is decomposed into its 24-bit significand (implicit leading one
// restored) and a right-shift amount so that input * scale ==
// input * multiplier / 2**shift. Half of 2**shift is added before a flooring
// right shift, which rounds ties upward.
//
// input      - 32-bit value to requantize.
// scale      - requantization scale; asserted to lie in [2**-32, 256).
// zero_point - output zero point, added after clamping.
// min, max   - inclusive bounds of the quantized output.
// Returns the requantized uint8 value.
static inline uint8_t xnn_qu8_requantize_rndnu(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale < 256.0f);
  assert(scale >= 0x1.0p-32f);

  const uint32_t scale_bits = fp32_to_bits(scale);
  // 24-bit significand of scale with the implicit leading bit made explicit.
  const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
  // scale == multiplier * 2**-shift (127 = exponent bias, 23 = mantissa bits).
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const int64_t rounding = INT64_C(1) << (shift - 1);
  const int32_t min_less_zero_point = (int32_t) min - (int32_t) zero_point;
  const int32_t max_less_zero_point = (int32_t) max - (int32_t) zero_point;

  // Signed product: |input * multiplier| < 2**55, so adding rounding cannot
  // overflow int64.
  const int64_t prescaled_input = (int64_t) input * (int64_t) multiplier;
  // asr_s64 is presumably an arithmetic (flooring) shift. The result can reach
  // ~2**39 (scale < 256), so it stays 64-bit until after clamping; narrowing
  // first would wrap out-of-range values instead of saturating them.
  int64_t output = asr_s64(prescaled_input + rounding, shift);

  // Clamp in 64 bits: raise to the lower bound first, then cap at the upper.
  if (output < (int64_t) min_less_zero_point) {
    output = (int64_t) min_less_zero_point;
  }
  if (output > (int64_t) max_less_zero_point) {
    output = (int64_t) max_less_zero_point;
  }
  return (uint8_t) ((int32_t) output + (int32_t) zero_point);
}
// Adds two quantized uint8 values using precomputed fixed-point parameters.
//
// a, b   - quantized inputs.
// params - scalar parameters: bias, per-input multipliers, right-shift amount,
//          output bounds relative to the output zero point, and the output
//          zero point.
// Returns the clamped, zero-point-adjusted uint8 result.
static inline uint8_t xnn_qu8_quantize_add(
  uint8_t a, uint8_t b,
  union xnn_qu8_addsub_minmax_params params)
{
  // Multiply by factors and accumulate products.
  int32_t acc = params.scalar.bias + (int32_t) (uint32_t) a * params.scalar.a_multiplier + (int32_t) (uint32_t) b * params.scalar.b_multiplier;
  // Right shift by params.scalar.shift. The shift itself does not round.
  // NOTE(review): any rounding offset is presumably folded into
  // params.scalar.bias by the params initializer; confirm.
  acc = asr_s32(acc, params.scalar.shift);
  // Clamp relative to the output zero point, then add the zero point.
  acc = math_max_s32(acc, params.scalar.output_min_less_zero_point);
  acc = math_min_s32(acc, params.scalar.output_max_less_zero_point);
  // The result lies in the uint8 output range. Convert directly to uint8_t
  // rather than via int8_t, where values above 127 would need an
  // implementation-defined narrowing.
  return (uint8_t) (acc + params.scalar.output_zero_point);
}
// Adds two quantized int8 values using precomputed fixed-point parameters.
//
// a, b   - quantized inputs.
// params - scalar parameters: bias, per-input multipliers, right-shift amount,
//          output bounds relative to the output zero point, and the output
//          zero point.
// Returns the clamped, zero-point-adjusted int8 result.
static inline int8_t xnn_qs8_quantize_add(
  int8_t a, int8_t b,
  union xnn_qs8_addsub_minmax_params params)
{
  // Weighted sum of both inputs plus the precomputed bias.
  int32_t sum = params.scalar.bias
    + (int32_t) a * params.scalar.a_multiplier
    + (int32_t) b * params.scalar.b_multiplier;

  // Right shift by params.scalar.shift. The shift itself does not round.
  // NOTE(review): any rounding offset is presumably folded into
  // params.scalar.bias by the params initializer; confirm.
  sum = asr_s32(sum, params.scalar.shift);

  // Clamp relative to the output zero point, then re-center on it.
  const int32_t clamped = math_min_s32(
    math_max_s32(sum, params.scalar.output_min_less_zero_point),
    params.scalar.output_max_less_zero_point);
  return (int8_t) (clamped + params.scalar.output_zero_point);
}