blob: 35075455fc1a87b3deebaf56f01b35726a3e369f [file] [log] [blame]
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -07001#include <algorithm>
2#include <cfloat>
Marat Dukhan4fa0fbe2019-10-31 10:23:46 -07003#include <chrono>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -07004#include <cmath>
5#include <functional>
6#include <random>
7#include <vector>
8
9#include "bench/utils.h"
10#include <xnnpack/common.h>
11#include <xnnpack/params.h>
Marat Dukhan4a5c7712022-01-05 22:43:13 -080012#include <xnnpack/params-init.h>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070013#include <xnnpack/raddexpminusmax.h>
Marat Dukhan4a2bbc62019-10-25 17:36:32 -070014#include <xnnpack/raddextexp.h>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070015#include <xnnpack/raddstoreexpminusmax.h>
16#include <xnnpack/rmax.h>
Marat Dukhan58b17ba2022-01-06 11:34:09 -080017#include <xnnpack/vbinary.h>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070018#include <xnnpack/vscaleexpminusmax.h>
Marat Dukhan4a2bbc62019-10-25 17:36:32 -070019#include <xnnpack/vscaleextexp.h>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070020
21#include <benchmark/benchmark.h>
Marat Dukhan8d3c6932020-03-06 20:27:27 -080022#ifdef BENCHMARK_INTEL_DNNL
23#include <dnnl.h>
24#endif // BENCHMARK_INTEL_DNNL
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070025
26
#ifdef BENCHMARK_INTEL_DNNL
// Benchmarks the oneDNN (DNNL) forward-inference SoftMax primitive over a 1-D
// f32 tensor of state.range(0) elements, for comparison with the XNNPACK
// micro-kernel pipelines in this file.
//
// Only dnnl_primitive_execute plus the stream wait that guarantees completion
// is timed (manual time). The input is prefetched into L1 before each
// iteration, and each iteration writes a different slice of an output pool
// sized past benchmark::utils::GetMaxCacheSize(), mirroring the XNNPACK
// benchmarks.
static void DNNLSoftArgMax(
  benchmark::State& state)
{
  const size_t elements = state.range(0);
  const size_t cache_line_size_max = 128;
  const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float));

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), std::ref(rng));

  const size_t num_buffers = 1 +
    benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float));
  std::vector<float> x(elements);
  std::vector<float> y(packed_elements * num_buffers);

  std::generate(x.begin(), x.end(), std::ref(f32rng));

  // Use the same floating-point mode as the XNNPACK benchmarks so the
  // comparison is fair.
  benchmark::utils::DisableDenormals();

  // Owns every DNNL handle created below, so each early-return path destroys
  // whatever was already created instead of leaking it.
  struct DNNLHandles {
    dnnl_engine_t engine = nullptr;
    dnnl_memory_t input_memory = nullptr;
    dnnl_memory_t output_memory = nullptr;
    dnnl_primitive_desc_t softmax_primitive_descriptor = nullptr;
    dnnl_primitive_t softmax_primitive = nullptr;
    dnnl_stream_t stream = nullptr;

    DNNLHandles() = default;
    DNNLHandles(const DNNLHandles&) = delete;
    DNNLHandles& operator=(const DNNLHandles&) = delete;
    ~DNNLHandles() { Release(); }

    // Destroys every live handle (most recently created first) and resets it
    // to nullptr. Returns the message for the first failed destruction, or
    // nullptr if every destruction succeeded.
    const char* Release() {
      const char* error = nullptr;
      auto note = [&error](bool failed, const char* message) {
        if (failed && error == nullptr) {
          error = message;
        }
      };
      if (stream != nullptr) {
        note(dnnl_stream_destroy(stream) != dnnl_success, "failed to destroy stream");
        stream = nullptr;
      }
      if (softmax_primitive != nullptr) {
        note(dnnl_primitive_destroy(softmax_primitive) != dnnl_success, "failed to destroy SoftMax primitive");
        softmax_primitive = nullptr;
      }
      if (softmax_primitive_descriptor != nullptr) {
        note(dnnl_primitive_desc_destroy(softmax_primitive_descriptor) != dnnl_success,
          "failed to destroy SoftMax primitive descriptor");
        softmax_primitive_descriptor = nullptr;
      }
      if (output_memory != nullptr) {
        note(dnnl_memory_destroy(output_memory) != dnnl_success, "failed to destroy output memory");
        output_memory = nullptr;
      }
      if (input_memory != nullptr) {
        note(dnnl_memory_destroy(input_memory) != dnnl_success, "failed to destroy input memory");
        input_memory = nullptr;
      }
      if (engine != nullptr) {
        note(dnnl_engine_destroy(engine) != dnnl_success, "failed to destroy engine");
        engine = nullptr;
      }
      return error;
    }
  } handles;

  if (dnnl_engine_create(&handles.engine, dnnl_cpu, 0) != dnnl_success) {
    state.SkipWithError("failed to create CPU engine");
    return;
  }

  dnnl_dim_t input_output_shape[1] = { static_cast<dnnl_dim_t>(elements) };

  dnnl_memory_desc_t memory_descriptor = { 0 };
  if (dnnl_memory_desc_init_by_tag(
    &memory_descriptor, 1, input_output_shape, dnnl_f32, dnnl_x) != dnnl_success)
  {
    state.SkipWithError("failed to create input memory descriptor");
    return;
  }

  if (dnnl_memory_create(
    &handles.input_memory, &memory_descriptor, handles.engine, x.data()) != dnnl_success)
  {
    state.SkipWithError("failed to create input memory");
    return;
  }

  // Initially bound to the first output slice; rebound to the current slice
  // on every iteration below.
  if (dnnl_memory_create(
    &handles.output_memory, &memory_descriptor, handles.engine, y.data()) != dnnl_success)
  {
    state.SkipWithError("failed to create output memory");
    return;
  }

  dnnl_softmax_desc_t softmax_forward_descriptor = {};
  if (dnnl_softmax_forward_desc_init(
    &softmax_forward_descriptor, dnnl_forward_inference,
    &memory_descriptor, 0) != dnnl_success)
  {
    state.SkipWithError("failed to create SoftMax forward descriptor");
    return;
  }

  if (dnnl_primitive_desc_create(
    &handles.softmax_primitive_descriptor, &softmax_forward_descriptor,
    nullptr /* primitive attributes */, handles.engine, nullptr /* hint */) != dnnl_success)
  {
    state.SkipWithError("failed to create SoftMax primitive descriptor");
    return;
  }

  if (dnnl_primitive_create(
    &handles.softmax_primitive, handles.softmax_primitive_descriptor) != dnnl_success)
  {
    state.SkipWithError("failed to create SoftMax primitive");
    return;
  }

  dnnl_exec_arg_t softmax_args[2] = {
    {DNNL_ARG_SRC, handles.input_memory},
    {DNNL_ARG_DST, handles.output_memory},
  };

  if (dnnl_stream_create(&handles.stream, handles.engine, dnnl_stream_default_flags) != dnnl_success) {
    state.SkipWithError("failed to create stream");
    return;
  }

  size_t buffer_index = 0;
  for (auto _ : state) {
    benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
    if (++buffer_index == num_buffers) {
      buffer_index = 0;
    }

    // Point the output memory at the current rotating slice (outside the timed
    // region). Without this, every iteration writes the same cache-hot buffer.
    if (dnnl_memory_set_data_handle(
      handles.output_memory, y.data() + packed_elements * buffer_index) != dnnl_success)
    {
      state.SkipWithError("failed to set output memory data handle");
      return;
    }

    const auto start = std::chrono::high_resolution_clock::now();
    if (dnnl_primitive_execute(
      handles.softmax_primitive, handles.stream, 2, softmax_args) != dnnl_success)
    {
      state.SkipWithError("failed to execute SoftMax");
      return;
    }
    // Execution may be asynchronous with respect to the caller. Wait so the
    // measured interval covers the whole computation.
    if (dnnl_stream_wait(handles.stream) != dnnl_success) {
      state.SkipWithError("failed to wait for SoftMax completion");
      return;
    }
    const auto end = std::chrono::high_resolution_clock::now();

    const auto elapsed_seconds =
      std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
    state.SetIterationTime(elapsed_seconds.count());
  }

  // Destroy explicitly so that destruction failures are still reported.
  if (const char* error = handles.Release()) {
    state.SkipWithError(error);
    return;
  }

  const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
  if (cpu_frequency != 0) {
    state.counters["cpufreq"] = cpu_frequency;
  }

  const size_t elements_per_iteration = elements;
  state.counters["elements"] =
    benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);

  const size_t bytes_per_iteration = 2 * elements * sizeof(float);
  state.counters["bytes"] =
    benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
}
#endif  // BENCHMARK_INTEL_DNNL
180
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800181static void ThreePassSoftMaxWithRecomputing(
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700182 benchmark::State& state,
183 xnn_f32_rmax_ukernel_function rmax,
184 xnn_f32_raddexpminusmax_ukernel_function raddexpminusmax,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800185 xnn_f32_vscaleexpminusmax_ukernel_function vscaleexpminusmax,
186 benchmark::utils::IsaCheckFunction isa_check = nullptr)
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700187{
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800188 if (isa_check && !isa_check(state)) {
189 return;
190 }
191
Marat Dukhand713e8a2020-12-04 14:23:12 -0800192 const size_t elements = state.range(0);
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700193 const size_t cache_line_size_max = 128;
Marat Dukhand713e8a2020-12-04 14:23:12 -0800194 const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float));
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700195
196 std::random_device random_device;
197 auto rng = std::mt19937(random_device());
Marat Dukhan44f0ca72020-08-02 21:46:58 -0700198 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), std::ref(rng));
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700199
200 const size_t num_buffers = 1 +
Marat Dukhand713e8a2020-12-04 14:23:12 -0800201 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float));
202 std::vector<float> x(elements);
203 std::vector<float> y(packed_elements * num_buffers);
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700204
205 std::generate(x.begin(), x.end(), std::ref(f32rng));
206
207 benchmark::utils::DisableDenormals();
208
209 size_t buffer_index = 0;
210 for (auto _ : state) {
Marat Dukhan42323232019-10-23 02:09:02 -0700211 benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700212 if (++buffer_index == num_buffers) {
213 buffer_index = 0;
214 }
215
216 const auto start = std::chrono::high_resolution_clock::now();
217 float x_max = nanf("");
Marat Dukhand713e8a2020-12-04 14:23:12 -0800218 rmax(elements * sizeof(float), x.data(), &x_max);
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700219 float y_sum = nanf("");
Marat Dukhand713e8a2020-12-04 14:23:12 -0800220 raddexpminusmax(elements * sizeof(float), x.data(), &y_sum, x_max);
221 vscaleexpminusmax(elements * sizeof(float), x.data(), y.data() + packed_elements * buffer_index, x_max, 1.0f / y_sum);
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700222 const auto end = std::chrono::high_resolution_clock::now();
223
224 const auto elapsed_seconds =
225 std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
226 state.SetIterationTime(elapsed_seconds.count());
227 }
228
Marat Dukhand713e8a2020-12-04 14:23:12 -0800229 const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
230 if (cpu_frequency != 0) {
231 state.counters["cpufreq"] = cpu_frequency;
232 }
233
234 const size_t elements_per_iteration = elements;
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700235 state.counters["elements"] =
Marat Dukhand713e8a2020-12-04 14:23:12 -0800236 benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
237
238 const size_t bytes_per_iteration = 2 * elements * sizeof(float);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700239 state.counters["bytes"] =
Marat Dukhand713e8a2020-12-04 14:23:12 -0800240 benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700241}
242
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800243static void ThreePassSoftMaxWithReloading(
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700244 benchmark::State& state,
245 xnn_f32_rmax_ukernel_function rmax,
246 xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax,
Marat Dukhan4a5c7712022-01-05 22:43:13 -0800247 xnn_init_f32_expminus_params_fn init_expminus_params,
Marat Dukhan58b17ba2022-01-06 11:34:09 -0800248 xnn_f32_vbinary_minmax_ukernel_function vmulc,
249 xnn_init_f32_minmax_params_fn init_minmax_params,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800250 benchmark::utils::IsaCheckFunction isa_check = nullptr)
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700251{
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800252 if (isa_check && !isa_check(state)) {
253 return;
254 }
255
Marat Dukhand713e8a2020-12-04 14:23:12 -0800256 const size_t elements = state.range(0);
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700257 const size_t cache_line_size_max = 128;
Marat Dukhand713e8a2020-12-04 14:23:12 -0800258 const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float));
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700259
260 std::random_device random_device;
261 auto rng = std::mt19937(random_device());
Marat Dukhan44f0ca72020-08-02 21:46:58 -0700262 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), std::ref(rng));
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700263
264 const size_t num_buffers = 1 +
Marat Dukhand713e8a2020-12-04 14:23:12 -0800265 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float));
266 std::vector<float> x(elements);
267 std::vector<float> y(packed_elements * num_buffers);
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700268
269 std::generate(x.begin(), x.end(), std::ref(f32rng));
270
271 benchmark::utils::DisableDenormals();
272
Marat Dukhan4a5c7712022-01-05 22:43:13 -0800273 xnn_f32_expminus_params expminus_params;
Marat Dukhan58b17ba2022-01-06 11:34:09 -0800274 xnn_f32_minmax_params minmax_params;
Marat Dukhan4a5c7712022-01-05 22:43:13 -0800275 init_expminus_params(&expminus_params);
Marat Dukhan58b17ba2022-01-06 11:34:09 -0800276 init_minmax_params(&minmax_params, -INFINITY, INFINITY);
Marat Dukhan4a5c7712022-01-05 22:43:13 -0800277
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700278 size_t buffer_index = 0;
279 for (auto _ : state) {
Marat Dukhan42323232019-10-23 02:09:02 -0700280 benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700281 if (++buffer_index == num_buffers) {
282 buffer_index = 0;
283 }
284
285 const auto start = std::chrono::high_resolution_clock::now();
286 float x_max = nanf("");
Marat Dukhand713e8a2020-12-04 14:23:12 -0800287 rmax(elements * sizeof(float), x.data(), &x_max);
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700288 float y_sum = nanf("");
Marat Dukhan4a5c7712022-01-05 22:43:13 -0800289 raddstoreexpminusmax(elements * sizeof(float), x.data(), &x_max, y.data() + packed_elements * buffer_index, &y_sum, &expminus_params);
Marat Dukhan58b17ba2022-01-06 11:34:09 -0800290 const float inv_y_sum = 1.0f / y_sum;
291 vmulc(elements * sizeof(float), y.data() + packed_elements * buffer_index, &inv_y_sum, y.data() + packed_elements * buffer_index, &minmax_params);
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700292 const auto end = std::chrono::high_resolution_clock::now();
293
294 const auto elapsed_seconds =
295 std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
296 state.SetIterationTime(elapsed_seconds.count());
297 }
298
Marat Dukhand713e8a2020-12-04 14:23:12 -0800299 const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
300 if (cpu_frequency != 0) {
301 state.counters["cpufreq"] = cpu_frequency;
302 }
303
304 const size_t elements_per_iteration = elements;
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700305 state.counters["elements"] =
Marat Dukhand713e8a2020-12-04 14:23:12 -0800306 benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
307
308 const size_t bytes_per_iteration = 2 * elements * sizeof(float);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700309 state.counters["bytes"] =
Marat Dukhand713e8a2020-12-04 14:23:12 -0800310 benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700311}
312
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800313static void TwoPassSoftMax(
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700314 benchmark::State& state,
315 xnn_f32_raddextexp_ukernel_function raddextexp,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800316 xnn_f32_vscaleextexp_ukernel_function vscaleextexp,
317 benchmark::utils::IsaCheckFunction isa_check = nullptr)
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700318{
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800319 if (isa_check && !isa_check(state)) {
320 return;
321 }
322
Marat Dukhand713e8a2020-12-04 14:23:12 -0800323 const size_t elements = state.range(0);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700324 const size_t cache_line_size_max = 128;
Marat Dukhand713e8a2020-12-04 14:23:12 -0800325 const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float));
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700326
327 std::random_device random_device;
328 auto rng = std::mt19937(random_device());
Marat Dukhan44f0ca72020-08-02 21:46:58 -0700329 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), std::ref(rng));
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700330
331 const size_t num_buffers = 1 +
Marat Dukhand713e8a2020-12-04 14:23:12 -0800332 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float));
333 std::vector<float> x(elements);
334 std::vector<float> y(packed_elements * num_buffers);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700335
336 std::generate(x.begin(), x.end(), std::ref(f32rng));
337
338 benchmark::utils::DisableDenormals();
339
340 size_t buffer_index = 0;
341 for (auto _ : state) {
342 benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
343 if (++buffer_index == num_buffers) {
344 buffer_index = 0;
345 }
346
347 const auto start = std::chrono::high_resolution_clock::now();
348 float scale[2];
Marat Dukhand713e8a2020-12-04 14:23:12 -0800349 raddextexp(elements * sizeof(float), x.data(), scale);
350 vscaleextexp(elements * sizeof(float), x.data(), y.data() + packed_elements * buffer_index, 1.0f / scale[0], -scale[1]);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700351 const auto end = std::chrono::high_resolution_clock::now();
352
353 const auto elapsed_seconds =
354 std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
355 state.SetIterationTime(elapsed_seconds.count());
356 }
357
Marat Dukhand713e8a2020-12-04 14:23:12 -0800358 const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
359 if (cpu_frequency != 0) {
360 state.counters["cpufreq"] = cpu_frequency;
361 }
362
363 const size_t elements_per_iteration = elements;
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700364 state.counters["elements"] =
Marat Dukhand713e8a2020-12-04 14:23:12 -0800365 benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
366
367 const size_t bytes_per_iteration = 2 * elements * sizeof(float);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700368 state.counters["bytes"] =
Marat Dukhand713e8a2020-12-04 14:23:12 -0800369 benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700370}
371
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700372static void CharacteristicArguments(benchmark::internal::Benchmark* b) {
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800373 for (int32_t n = 1000; n <= 100000000; n *= 10) {
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700374 b->Arg(n);
375 b->Arg(3 * n);
376 }
377}
378
Marat Dukhan8d3c6932020-03-06 20:27:27 -0800379#ifdef BENCHMARK_INTEL_DNNL
380 BENCHMARK(DNNLSoftArgMax)->Apply(CharacteristicArguments)->UseManualTime();
381#endif
382
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700383#if XNN_ARCH_X86 || XNN_ARCH_X86_64
Marat Dukhanb6862d12020-03-08 15:26:49 -0700384 BENCHMARK_CAPTURE(TwoPassSoftMax, avx2_p5,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800385 xnn_f32_raddextexp_ukernel__avx2_p5_x96,
386 xnn_f32_vscaleextexp_ukernel__avx2_p5_x40,
387 benchmark::utils::CheckAVX2)->Apply(CharacteristicArguments)->UseManualTime();
Marat Dukhanb6862d12020-03-08 15:26:49 -0700388 BENCHMARK_CAPTURE(ThreePassSoftMaxWithRecomputing, avx2_p5,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800389 xnn_f32_rmax_ukernel__avx,
390 xnn_f32_raddexpminusmax_ukernel__avx2_p5_x96,
391 xnn_f32_vscaleexpminusmax_ukernel__avx2_p5_x24,
392 benchmark::utils::CheckAVX2)->Apply(CharacteristicArguments)->UseManualTime();
Marat Dukhanb6862d12020-03-08 15:26:49 -0700393 BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, avx2_p5,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800394 xnn_f32_rmax_ukernel__avx,
Marat Dukhan5999c922022-01-05 18:10:20 -0800395 xnn_f32_raddstoreexpminusmax_ukernel__avx2_rr1_p5_x64_acc2,
Marat Dukhan4a5c7712022-01-05 22:43:13 -0800396 xnn_init_f32_expminus_avx2_rr1_p5_params,
Marat Dukhan58b17ba2022-01-06 11:34:09 -0800397 xnn_f32_vmulc_minmax_ukernel__avx_x16,
398 xnn_init_f32_minmax_avx_params,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800399 benchmark::utils::CheckAVX2)->Apply(CharacteristicArguments)->UseManualTime();
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700400
Marat Dukhanb6862d12020-03-08 15:26:49 -0700401 BENCHMARK_CAPTURE(TwoPassSoftMax, avx512f_p5_scalef,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800402 xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_x144_acc3,
403 xnn_f32_vscaleextexp_ukernel__avx512f_p5_scalef_x16,
404 benchmark::utils::CheckAVX512F)->Apply(CharacteristicArguments)->UseManualTime();
Marat Dukhanb6862d12020-03-08 15:26:49 -0700405 BENCHMARK_CAPTURE(ThreePassSoftMaxWithRecomputing, avx512f_p5_scalef,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800406 xnn_f32_rmax_ukernel__avx512f,
407 xnn_f32_raddexpminusmax_ukernel__avx512f_p5_scalef_x128_acc4,
408 xnn_f32_vscaleexpminusmax_ukernel__avx512f_p5_scalef_x16,
409 benchmark::utils::CheckAVX512F)->Apply(CharacteristicArguments)->UseManualTime();
Marat Dukhanb6862d12020-03-08 15:26:49 -0700410 BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, avx512f_p5_scalef,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800411 xnn_f32_rmax_ukernel__avx512f,
Marat Dukhan5999c922022-01-05 18:10:20 -0800412 xnn_f32_raddstoreexpminusmax_ukernel__avx512f_rr1_p5_scalef_x128_acc2,
Marat Dukhan4a5c7712022-01-05 22:43:13 -0800413 xnn_init_f32_expminus_avx512_rr1_p5_params,
Marat Dukhan58b17ba2022-01-06 11:34:09 -0800414 xnn_f32_vmulc_minmax_ukernel__avx512f_x32,
415 xnn_init_f32_minmax_scalar_params,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800416 benchmark::utils::CheckAVX512F)->Apply(CharacteristicArguments)->UseManualTime();
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700417#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
418
419#ifndef XNNPACK_BENCHMARK_NO_MAIN
420BENCHMARK_MAIN();
421#endif