/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_MUL_PARAMS_H_
#define RUY_RUY_MUL_PARAMS_H_

#include <cstdint>
#include <limits>
#include <type_traits>

#include "ruy/check_macros.h"
#include "ruy/size_util.h"

namespace ruy {

// Enumeration to designate which dimension is the 'channels', for MulParams
// features that are 'per-channel', namely the bias-vector and the quantized
// multiplier.
enum class ChannelDimension : std::int8_t {
  // kRow means that 'per-channel' means 'per row of the destination matrix'
  kRow,
  // kCol means that 'per-channel' means 'per column of the destination matrix'
  kCol
};
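
// A minimal usage sketch (illustrative only): for a hypothetical 4 x 8
// destination matrix with ChannelDimension::kRow, per-channel buffers are
// indexed by destination row, so a bias vector needs 4 entries and bias[r]
// is added to every entry of row r. With kCol it would need 8 entries, one
// per destination column. For example (`mul_params` is assumed to be a
// MulParams<std::int32_t, std::int8_t>):
//
//   static const std::int32_t bias_data[4] = {10, 20, 30, 40};  // one per row
//   mul_params.set_channel_dimension(ruy::ChannelDimension::kRow);
//   mul_params.set_bias(bias_data);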

namespace detail {
template <typename tAccumScalar, typename tDstScalar>
struct MulParamsStorage;
}  // namespace detail

// MulParams describes everything about a matrix multiplication that
// isn't encoded in the LHS, RHS and destination matrices. Some of that
// information is encoded as compile-time constants and types (for instance,
// the choice of accumulator type, AccumScalar). Some of that information is
// encoded as runtime values (for instance, the optional bias vector).
//
// Template parameters:
// AccumScalar: Accumulator type. The type of accumulators used to compute the
// dot-products before being ultimately cast to the destination type.
// DstScalar: The destination scalar type.
//
// Constraints on these template parameters (see also the ruy::Mul comment):
// * If DstScalar is floating-point then AccumScalar must also be.
// * If DstScalar is integral then AccumScalar must be std::int32_t. Moreover,
//   in that integral case, there is a mode switch:
//   - If DstScalar is std::int32_t then the multiplier_* fields are all
//     disabled, and ruy::Mul will just return raw (unscaled) accumulators.
//   - If DstScalar is not std::int32_t then the multiplier_* fields are
//     enabled, and ruy::Mul will use them to scale internal std::int32_t
//     accumulators before casting them to the DstScalar type. The default
//     values are such that the effective multiplier is 1 (no scaling).
//
// For the latter case (DstScalar integral and narrower than std::int32_t),
// reference code can be found in the implementation of ruy::ApplyMultiplier.
// If you look there, you'll find warnings like this:
//
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Warning: this code is not meant to be bit-exact-normative.
// Please refer to the class comment of ruy::MulParams, in mul_params.h.
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//
// The explanation of this warning is that as of early 2021, we still don't
// know whether it is advisable to let this code, as-is, have normative value,
// or whether that would only become advisable after some specific final
// change.
//
// Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform
// bit-exactly to this reference, but we also know that x86 could be faster if
// it didn't, and so could NEON-less ARM (such as Cortex-M) (see [2]). We don't
// know that this particular reference code is inherently better than other
// forms that could perform better on these architectures. For example, the
// alternative that was proposed in [2] as better performing on ARM Cortex-M
// is also inherently more accurate thanks to rounding only once, but it would
// perform worse on both ARM NEON and x86.
//
// In fact, if we look at other hardware architectures beyond current Ruy
// targets, namely "hardware accelerators", it becomes clear that there is no
// hope for any form of this to be efficiently implementable simultaneously on
// all current relevant hardware. Indeed, some accelerators prefer to perform
// the multiplication in IEEE float32, others in IEEE float16, others in
// bfloat16, others in 16-bit fixed-point...
//
// See:
// [1] https://github.com/google/ruy/pull/227
// [2] https://github.com/tensorflow/tensorflow/issues/25087
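//
// A minimal setup sketch for the integer case (illustrative only; the
// surrounding ruy::Mul call and matrices are assumed and not shown here):
//
//   ruy::MulParams<std::int32_t, std::int8_t> mul_params;
//   // The effective multiplier is
//   // multiplier_fixedpoint * 2^(multiplier_exponent - 31).
//   // E.g. an effective multiplier of 0.375 = 0.75 * 2^-1:
//   mul_params.set_multiplier_fixedpoint(1610612736);  // 0.75 * 2^31 (Q0.31)
//   mul_params.set_multiplier_exponent(-1);
//   mul_params.set_clamp_min(-128);  // these defaults already span int8;
//   mul_params.set_clamp_max(127);   // narrower bounds fuse an activation.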
template <typename tAccumScalar, typename tDstScalar>
class MulParams final {
 public:
  using AccumScalar = tAccumScalar;
  using DstScalar = tDstScalar;

  // The bias vector data, if not null.
  const AccumScalar* bias() const { return storage_.bias; }
  void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; }
  // Only for non-floating-point cases. The fixed-point part of the multiplier
  // by which accumulators are multiplied before being cast to the destination
  // type. This is a fixed-point quantity with 0 integer bits. Since
  // (as explained in the class comment) AccumScalar must be std::int32_t,
  // that means that the fixed-point format is Q0.31. For example,
  // a multiplier_fixedpoint value of 2^30 has the effect of multiplying
  // by one half (1/2). More generally, the effect is to multiply by
  // (multiplier_fixedpoint / (2^31)).
  AccumScalar multiplier_fixedpoint() const {
    return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint;
  }
  void set_multiplier_fixedpoint(const AccumScalar value) {
    set_perchannel(false);
    storage_.multiplier_fixedpoint = value;
  }
  // Only for non-floating-point cases. The exponent part of the aforementioned
  // multiplier.
  int multiplier_exponent() const {
    return storage_.perchannel ? 0 : storage_.multiplier_exponent;
  }
  void set_multiplier_exponent(const int value) {
    set_perchannel(false);
    storage_.multiplier_exponent = value;
  }
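  // A sketch (illustrative, not part of the ruy API) of how one might derive
  // these two fields from a positive real-valued scale, using the usual
  // frexp-based decomposition (`scale` and `mul_params` are assumed here;
  // needs <cmath>):
  //
  //   int exponent;
  //   const double fraction = std::frexp(scale, &exponent);  // in [0.5, 1)
  //   std::int64_t fixedpoint = std::llround(fraction * (1ll << 31));
  //   if (fixedpoint == (1ll << 31)) {  // fraction rounded up to 1.0
  //     fixedpoint /= 2;
  //     ++exponent;
  //   }
  //   mul_params.set_multiplier_fixedpoint(
  //       static_cast<std::int32_t>(fixedpoint));
  //   mul_params.set_multiplier_exponent(exponent);
  //
  // The effective multiplier is then fixedpoint * 2^(exponent - 31) ~= scale.
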
  // Per-channel variant of multiplier_fixedpoint. Setting this switches
  // to per-channel mode, where `multiplier_fixedpoint` and
  // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel`
  // and `multiplier_exponent_perchannel` are used instead.
  //
  // This must point to a buffer of as many values as there are rows or columns
  // in the destination matrix, whichever is the channels dimension. Each
  // channel of the destination matrix will use the corresponding buffer
  // element instead of multiplier_fixedpoint.
  const AccumScalar* multiplier_fixedpoint_perchannel() const {
    return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel
                               : nullptr;
  }
  void set_multiplier_fixedpoint_perchannel(const AccumScalar* ptr) {
    set_perchannel(true);
    storage_.multiplier_fixedpoint_perchannel = ptr;
  }
  // Per-channel variant of multiplier_exponent. Same comments as for
  // multiplier_fixedpoint_perchannel.
  const int* multiplier_exponent_perchannel() const {
    return storage_.perchannel ? storage_.multiplier_exponent_perchannel
                               : nullptr;
  }
  void set_multiplier_exponent_perchannel(const int* ptr) {
    set_perchannel(true);
    storage_.multiplier_exponent_perchannel = ptr;
  }
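  // A per-channel sketch (illustrative only; the sizes assume a hypothetical
  // destination with 3 rows and channel_dimension() == kRow). Note that
  // MulParams stores only the pointers, so the buffers must outlive the
  // ruy::Mul call:
  //
  //   static const std::int32_t fixedpoint[3] = {
  //       1073741824,    // 0.5 in Q0.31
  //       1610612736,    // 0.75
  //       2147483647};   // ~1.0
  //   static const int exponent[3] = {0, -1, 1};
  //   mul_params.set_multiplier_fixedpoint_perchannel(fixedpoint);
  //   mul_params.set_multiplier_exponent_perchannel(exponent);
  //   // Row r is scaled by fixedpoint[r] * 2^(exponent[r] - 31).
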
  // min clamp bound of destination values.
  DstScalar clamp_min() const { return storage_.clamp_min; }
  void set_clamp_min(const DstScalar value) { storage_.clamp_min = value; }
  // max clamp bound of destination values.
  DstScalar clamp_max() const { return storage_.clamp_max; }
  void set_clamp_max(const DstScalar value) { storage_.clamp_max = value; }
  // Designates which dimension is the 'channels', for per-channel features
  // such as bias-addition and per-channel quantization multipliers.
  ChannelDimension channel_dimension() const {
    return storage_.channel_dimension;
  }
  void set_channel_dimension(ChannelDimension value) {
    storage_.channel_dimension = value;
  }
  // Specifies the upward rounding of the allocated capacity of per-channel
  // buffers such as bias vectors and per-channel quantization multipliers.
  // The unit is matrix entries, not bytes.
  //
  // This value must be a power of two.
  //
  // The default value, 1, means no upward rounding, meaning that the buffers
  // are not required to have a capacity greater than the size of the
  // corresponding matrix dimension, i.e. the number of rows (respectively
  // columns) of the destination matrix if `channel_dimension()` is kRow
  // (respectively kCol).
  //
  // Higher values allow the implementation to assume that it is OK to access
  // these buffers a little past this boundary, which is useful in SIMD
  // optimized kernels. In practice, when this value is lower than what the
  // kernel requires, ruy has to internally reallocate and copy per-channel
  // buffers. When this value is high enough, this reallocation and copy is
  // avoided.
  //
  // When a value greater than 1 is specified, the tail region of the buffer
  // (past the end of the values actually corresponding to channels) is
  // required to be zero-initialized.
  //
  // As of 2020, values as high as 16 may be useful on some CPU architectures
  // (corresponding to the widest kernels used on any CPU architecture).
  int perchannel_buffers_capacity_rounding() const {
    return 1 << storage_.perchannel_buffers_capacity_rounding_log2;
  }
  void set_perchannel_buffers_capacity_rounding(int value) {
    // Note: pot_log2 asserts (debug-only) that its argument is a power-of-two.
    storage_.perchannel_buffers_capacity_rounding_log2 = pot_log2(value);
  }
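  // A capacity-rounding sketch (illustrative only; kDstRows = 50 and the
  // rounding value 16 are assumed, not prescribed by ruy). 50 rounds up to a
  // capacity of 64 entries, and the tail entries must be zero:
  //
  //   constexpr int kDstRows = 50;   // channels dimension is kRow here
  //   constexpr int kRounding = 16;
  //   constexpr int kCapacity =
  //       ((kDstRows + kRounding - 1) / kRounding) * kRounding;  // 64
  //   std::vector<std::int32_t> bias(kCapacity, 0);  // zero-initialized tail
  //   // ... fill bias[0] .. bias[kDstRows - 1] ...
  //   mul_params.set_perchannel_buffers_capacity_rounding(kRounding);
  //   mul_params.set_bias(bias.data());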

 private:
  detail::MulParamsStorage<AccumScalar, DstScalar> storage_;

  void set_perchannel(bool perchannel) {
    storage_.perchannel = perchannel;
  }
};

namespace detail {

// Floating-point case.
template <typename AccumScalar, typename DstScalar>
struct MulParamsStorage final {
  static_assert(std::is_floating_point<AccumScalar>::value, "");
  static_assert(std::is_floating_point<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) <= sizeof(AccumScalar), "");

  const AccumScalar* bias = nullptr;
  DstScalar clamp_min = -std::numeric_limits<DstScalar>::infinity();
  DstScalar clamp_max = std::numeric_limits<DstScalar>::infinity();
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;

  // Data members that are disabled in this case are left as `static constexpr`
  // so that one can write some generic code.
  static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
      nullptr;
  static constexpr const int* multiplier_exponent_perchannel = nullptr;
  static constexpr AccumScalar multiplier_fixedpoint = 0;
  static constexpr int multiplier_exponent = 0;
  static constexpr bool perchannel = false;
};

// Specialization for the integer-quantized type, with down-quantization of
// int32 accumulators to a narrower destination scalar type.
template <typename DstScalar>
struct MulParamsStorage<std::int32_t, DstScalar> final {
  using AccumScalar = std::int32_t;
  static_assert(std::is_integral<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, "");

  const AccumScalar* bias = nullptr;
  union {
    const AccumScalar* multiplier_fixedpoint_perchannel;
    // Let the default multiplier be effectively a multiplication by 1, so that
    // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
    // 1 is not exactly representable in fixedpoint with 0 integer bits, but
    // using the highest representable value is a sufficiently good
    // approximation: since this specialization of MulParams is for the case
    // where DstScalar is at least 2x narrower than AccumScalar, the values
    // for which there would be a difference will get saturated anyway.
    AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max();
  };
  union {
    const int* multiplier_exponent_perchannel;
    // See the above comment about the default value of multiplier_fixedpoint.
    int multiplier_exponent = 0;
  };
  DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
  DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  bool perchannel = false;
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;
};

// Specialization used in the integer case when outputting raw int32
// accumulators, without down-quantization to a narrower destination scalar
// type. In this case, the feature of clamping destination values is not
// available.
template <>
struct MulParamsStorage<std::int32_t, std::int32_t> final {
  using AccumScalar = std::int32_t;
  using DstScalar = std::int32_t;

  const AccumScalar* bias = nullptr;
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;

  // Data members that are disabled in this case are left as `static constexpr`
  // so that one can write some generic code.
  static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
      nullptr;
  static constexpr const int* multiplier_exponent_perchannel = nullptr;
  static constexpr AccumScalar multiplier_fixedpoint = 0;
  static constexpr int multiplier_exponent = 0;
  static constexpr DstScalar clamp_min =
      std::numeric_limits<DstScalar>::lowest();
  static constexpr DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
  static constexpr bool perchannel = false;
};

}  // namespace detail

}  // namespace ruy

#endif  // RUY_RUY_MUL_PARAMS_H_