blob: ecc4d105387f27ff13dc57f8118a8f295fa40fd8 [file] [log] [blame]
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -07001#include <algorithm>
2#include <cfloat>
Marat Dukhan4fa0fbe2019-10-31 10:23:46 -07003#include <chrono>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -07004#include <cmath>
5#include <functional>
6#include <random>
7#include <vector>
8
9#include "bench/utils.h"
10#include <xnnpack/common.h>
11#include <xnnpack/params.h>
12#include <xnnpack/raddexpminusmax.h>
Marat Dukhan4a2bbc62019-10-25 17:36:32 -070013#include <xnnpack/raddextexp.h>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070014#include <xnnpack/raddstoreexpminusmax.h>
15#include <xnnpack/rmax.h>
Marat Dukhan05ac8e32019-10-21 15:39:33 -070016#include <xnnpack/vscale.h>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070017#include <xnnpack/vscaleexpminusmax.h>
Marat Dukhan4a2bbc62019-10-25 17:36:32 -070018#include <xnnpack/vscaleextexp.h>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070019
20#include <benchmark/benchmark.h>
Marat Dukhan8d3c6932020-03-06 20:27:27 -080021#ifdef BENCHMARK_INTEL_DNNL
22#include <dnnl.h>
23#endif // BENCHMARK_INTEL_DNNL
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070024
25
#ifdef BENCHMARK_INTEL_DNNL
// Benchmarks the DNNL SoftMax forward-inference primitive on a 1-D f32 tensor
// with state.range(0) elements.
//
// Each iteration is timed manually around dnnl_primitive_execute() and reported
// through state.SetIterationTime(), so the benchmark is expected to be
// registered with UseManualTime(). Any DNNL failure is reported through
// state.SkipWithError() and ends the benchmark early. On success the "cpufreq"
// (when non-zero), "elements" and "bytes" counters are populated.
static void DNNLSoftArgMax(
  benchmark::State& state)
{
  // Owns every DNNL handle created by this benchmark. Whatever is still
  // non-null when the function exits (including the early-return error paths)
  // is released by the destructor, so a failure half-way through setup does
  // not leak the objects created before it.
  struct DNNLHandles {
    dnnl_engine_t engine = nullptr;
    dnnl_memory_t input_memory = nullptr;
    dnnl_memory_t output_memory = nullptr;
    dnnl_primitive_desc_t softmax_primitive_descriptor = nullptr;
    dnnl_primitive_t softmax_primitive = nullptr;
    dnnl_stream_t stream = nullptr;

    ~DNNLHandles() {
      // Best-effort cleanup: statuses are ignored because this only runs for
      // handles that were not already destroyed (and checked) explicitly.
      if (stream != nullptr) {
        dnnl_stream_destroy(stream);
      }
      if (softmax_primitive_descriptor != nullptr) {
        dnnl_primitive_desc_destroy(softmax_primitive_descriptor);
      }
      if (softmax_primitive != nullptr) {
        dnnl_primitive_destroy(softmax_primitive);
      }
      if (input_memory != nullptr) {
        dnnl_memory_destroy(input_memory);
      }
      if (output_memory != nullptr) {
        dnnl_memory_destroy(output_memory);
      }
      if (engine != nullptr) {
        dnnl_engine_destroy(engine);
      }
    }
  };

  const size_t elements = state.range(0);
  const size_t cache_line_size_max = 128;
  // Output slots are padded to a multiple of cache_line_size_max bytes.
  const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float));

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), std::ref(rng));

  // One more output slot than is needed to cover the maximum cache size;
  // presumably so that consecutive iterations do not write to cache-resident output.
  const size_t num_buffers = 1 +
    benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float));
  std::vector<float> x(elements);
  std::vector<float> y(packed_elements * num_buffers);

  std::generate(x.begin(), x.end(), std::ref(f32rng));

  // Declared after x and y so the DNNL objects are released before the
  // buffers they reference.
  DNNLHandles handles;

  if (dnnl_engine_create(&handles.engine, dnnl_cpu, 0) != dnnl_success) {
    state.SkipWithError("failed to create CPU engine");
    return;
  }

  dnnl_dim_t input_output_shape[1] = { static_cast<dnnl_dim_t>(elements) };

  // The same 1-D f32 descriptor is used for both the input and the output.
  dnnl_memory_desc_t memory_descriptor = { 0 };
  if (dnnl_memory_desc_init_by_tag(
    &memory_descriptor, 1, input_output_shape, dnnl_f32, dnnl_x) != dnnl_success)
  {
    state.SkipWithError("failed to create input memory descriptor");
    return;
  }

  if (dnnl_memory_create(
    &handles.input_memory, &memory_descriptor, handles.engine, x.data()) != dnnl_success)
  {
    state.SkipWithError("failed to create input memory");
    return;
  }

  if (dnnl_memory_create(
    &handles.output_memory, &memory_descriptor, handles.engine, y.data()) != dnnl_success)
  {
    state.SkipWithError("failed to create output memory");
    return;
  }

  dnnl_softmax_desc_t softmax_forward_descriptor = {};
  if (dnnl_softmax_forward_desc_init(
    &softmax_forward_descriptor, dnnl_forward_inference,
    &memory_descriptor, 0) != dnnl_success)
  {
    state.SkipWithError("failed to create SoftMax forward descriptor");
    return;
  }

  if (dnnl_primitive_desc_create(
    &handles.softmax_primitive_descriptor, &softmax_forward_descriptor,
    nullptr /* primitive attributes */, handles.engine, nullptr /* hint */) != dnnl_success)
  {
    state.SkipWithError("failed to create SoftMax primitive descriptor");
    return;
  }

  if (dnnl_primitive_create(
    &handles.softmax_primitive, handles.softmax_primitive_descriptor) != dnnl_success)
  {
    state.SkipWithError("failed to create SoftMax primitive");
    return;
  }

  dnnl_exec_arg_t softmax_args[2] = {
    {DNNL_ARG_SRC, handles.input_memory},
    {DNNL_ARG_DST, handles.output_memory},
  };

  if (dnnl_stream_create(&handles.stream, handles.engine, dnnl_stream_default_flags) != dnnl_success) {
    state.SkipWithError("failed to create stream");
    return;
  }

  size_t buffer_index = 0;
  for (auto _ : state) {
    benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
    if (++buffer_index == num_buffers) {
      buffer_index = 0;
    }

    // Point the destination at the current output slot, outside the timed
    // region, so the rotation through y actually takes effect (matching the
    // XNNPACK benchmarks in this file).
    if (dnnl_memory_set_data_handle(
      handles.output_memory, y.data() + packed_elements * buffer_index) != dnnl_success)
    {
      state.SkipWithError("failed to set output memory handle");
      return;
    }

    const auto start = std::chrono::high_resolution_clock::now();
    if (dnnl_primitive_execute(
      handles.softmax_primitive, handles.stream, 2, softmax_args) != dnnl_success)
    {
      state.SkipWithError("failed to execute SoftMax");
      return;
    }
    const auto end = std::chrono::high_resolution_clock::now();

    const auto elapsed_seconds =
      std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
    state.SetIterationTime(elapsed_seconds.count());
  }

  // Explicit, checked teardown. Each handle is cleared right after its destroy
  // call so the DNNLHandles destructor never destroys it a second time, while
  // still releasing the remaining handles if one of these calls fails.
  dnnl_status_t status = dnnl_stream_destroy(handles.stream);
  handles.stream = nullptr;
  if (status != dnnl_success) {
    state.SkipWithError("failed to destroy stream");
    return;
  }

  status = dnnl_primitive_desc_destroy(handles.softmax_primitive_descriptor);
  handles.softmax_primitive_descriptor = nullptr;
  if (status != dnnl_success) {
    state.SkipWithError("failed to destroy SoftMax primitive descriptor");
    return;
  }

  status = dnnl_primitive_destroy(handles.softmax_primitive);
  handles.softmax_primitive = nullptr;
  if (status != dnnl_success) {
    state.SkipWithError("failed to destroy SoftMax primitive");
    return;
  }

  status = dnnl_memory_destroy(handles.input_memory);
  handles.input_memory = nullptr;
  if (status != dnnl_success) {
    state.SkipWithError("failed to destroy input memory");
    return;
  }

  status = dnnl_memory_destroy(handles.output_memory);
  handles.output_memory = nullptr;
  if (status != dnnl_success) {
    state.SkipWithError("failed to destroy output memory");
    return;
  }

  status = dnnl_engine_destroy(handles.engine);
  handles.engine = nullptr;
  if (status != dnnl_success) {
    state.SkipWithError("failed to destroy engine");
    return;
  }

  const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
  if (cpu_frequency != 0) {
    state.counters["cpufreq"] = cpu_frequency;
  }

  const size_t elements_per_iteration = elements;
  state.counters["elements"] =
    benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);

  // One read of x plus one write of y per element.
  const size_t bytes_per_iteration = 2 * elements * sizeof(float);
  state.counters["bytes"] =
    benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
}
#endif  // BENCHMARK_INTEL_DNNL
179
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800180static void ThreePassSoftMaxWithRecomputing(
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700181 benchmark::State& state,
182 xnn_f32_rmax_ukernel_function rmax,
183 xnn_f32_raddexpminusmax_ukernel_function raddexpminusmax,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800184 xnn_f32_vscaleexpminusmax_ukernel_function vscaleexpminusmax,
185 benchmark::utils::IsaCheckFunction isa_check = nullptr)
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700186{
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800187 if (isa_check && !isa_check(state)) {
188 return;
189 }
190
Marat Dukhand713e8a2020-12-04 14:23:12 -0800191 const size_t elements = state.range(0);
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700192 const size_t cache_line_size_max = 128;
Marat Dukhand713e8a2020-12-04 14:23:12 -0800193 const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float));
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700194
195 std::random_device random_device;
196 auto rng = std::mt19937(random_device());
Marat Dukhan44f0ca72020-08-02 21:46:58 -0700197 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), std::ref(rng));
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700198
199 const size_t num_buffers = 1 +
Marat Dukhand713e8a2020-12-04 14:23:12 -0800200 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float));
201 std::vector<float> x(elements);
202 std::vector<float> y(packed_elements * num_buffers);
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700203
204 std::generate(x.begin(), x.end(), std::ref(f32rng));
205
206 benchmark::utils::DisableDenormals();
207
208 size_t buffer_index = 0;
209 for (auto _ : state) {
Marat Dukhan42323232019-10-23 02:09:02 -0700210 benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700211 if (++buffer_index == num_buffers) {
212 buffer_index = 0;
213 }
214
215 const auto start = std::chrono::high_resolution_clock::now();
216 float x_max = nanf("");
Marat Dukhand713e8a2020-12-04 14:23:12 -0800217 rmax(elements * sizeof(float), x.data(), &x_max);
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700218 float y_sum = nanf("");
Marat Dukhand713e8a2020-12-04 14:23:12 -0800219 raddexpminusmax(elements * sizeof(float), x.data(), &y_sum, x_max);
220 vscaleexpminusmax(elements * sizeof(float), x.data(), y.data() + packed_elements * buffer_index, x_max, 1.0f / y_sum);
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700221 const auto end = std::chrono::high_resolution_clock::now();
222
223 const auto elapsed_seconds =
224 std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
225 state.SetIterationTime(elapsed_seconds.count());
226 }
227
Marat Dukhand713e8a2020-12-04 14:23:12 -0800228 const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
229 if (cpu_frequency != 0) {
230 state.counters["cpufreq"] = cpu_frequency;
231 }
232
233 const size_t elements_per_iteration = elements;
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700234 state.counters["elements"] =
Marat Dukhand713e8a2020-12-04 14:23:12 -0800235 benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
236
237 const size_t bytes_per_iteration = 2 * elements * sizeof(float);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700238 state.counters["bytes"] =
Marat Dukhand713e8a2020-12-04 14:23:12 -0800239 benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700240}
241
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800242static void ThreePassSoftMaxWithReloading(
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700243 benchmark::State& state,
244 xnn_f32_rmax_ukernel_function rmax,
245 xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800246 xnn_f32_vscale_ukernel_function vscale,
247 benchmark::utils::IsaCheckFunction isa_check = nullptr)
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700248{
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800249 if (isa_check && !isa_check(state)) {
250 return;
251 }
252
Marat Dukhand713e8a2020-12-04 14:23:12 -0800253 const size_t elements = state.range(0);
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700254 const size_t cache_line_size_max = 128;
Marat Dukhand713e8a2020-12-04 14:23:12 -0800255 const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float));
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700256
257 std::random_device random_device;
258 auto rng = std::mt19937(random_device());
Marat Dukhan44f0ca72020-08-02 21:46:58 -0700259 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), std::ref(rng));
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700260
261 const size_t num_buffers = 1 +
Marat Dukhand713e8a2020-12-04 14:23:12 -0800262 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float));
263 std::vector<float> x(elements);
264 std::vector<float> y(packed_elements * num_buffers);
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700265
266 std::generate(x.begin(), x.end(), std::ref(f32rng));
267
268 benchmark::utils::DisableDenormals();
269
270 size_t buffer_index = 0;
271 for (auto _ : state) {
Marat Dukhan42323232019-10-23 02:09:02 -0700272 benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700273 if (++buffer_index == num_buffers) {
274 buffer_index = 0;
275 }
276
277 const auto start = std::chrono::high_resolution_clock::now();
278 float x_max = nanf("");
Marat Dukhand713e8a2020-12-04 14:23:12 -0800279 rmax(elements * sizeof(float), x.data(), &x_max);
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700280 float y_sum = nanf("");
Marat Dukhand713e8a2020-12-04 14:23:12 -0800281 raddstoreexpminusmax(elements * sizeof(float), x.data(), y.data() + packed_elements * buffer_index, &y_sum, x_max);
282 vscale(elements * sizeof(float), y.data() + packed_elements * buffer_index, y.data() + packed_elements * buffer_index, 1.0f / y_sum);
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700283 const auto end = std::chrono::high_resolution_clock::now();
284
285 const auto elapsed_seconds =
286 std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
287 state.SetIterationTime(elapsed_seconds.count());
288 }
289
Marat Dukhand713e8a2020-12-04 14:23:12 -0800290 const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
291 if (cpu_frequency != 0) {
292 state.counters["cpufreq"] = cpu_frequency;
293 }
294
295 const size_t elements_per_iteration = elements;
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700296 state.counters["elements"] =
Marat Dukhand713e8a2020-12-04 14:23:12 -0800297 benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
298
299 const size_t bytes_per_iteration = 2 * elements * sizeof(float);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700300 state.counters["bytes"] =
Marat Dukhand713e8a2020-12-04 14:23:12 -0800301 benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700302}
303
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800304static void TwoPassSoftMax(
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700305 benchmark::State& state,
306 xnn_f32_raddextexp_ukernel_function raddextexp,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800307 xnn_f32_vscaleextexp_ukernel_function vscaleextexp,
308 benchmark::utils::IsaCheckFunction isa_check = nullptr)
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700309{
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800310 if (isa_check && !isa_check(state)) {
311 return;
312 }
313
Marat Dukhand713e8a2020-12-04 14:23:12 -0800314 const size_t elements = state.range(0);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700315 const size_t cache_line_size_max = 128;
Marat Dukhand713e8a2020-12-04 14:23:12 -0800316 const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float));
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700317
318 std::random_device random_device;
319 auto rng = std::mt19937(random_device());
Marat Dukhan44f0ca72020-08-02 21:46:58 -0700320 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), std::ref(rng));
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700321
322 const size_t num_buffers = 1 +
Marat Dukhand713e8a2020-12-04 14:23:12 -0800323 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float));
324 std::vector<float> x(elements);
325 std::vector<float> y(packed_elements * num_buffers);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700326
327 std::generate(x.begin(), x.end(), std::ref(f32rng));
328
329 benchmark::utils::DisableDenormals();
330
331 size_t buffer_index = 0;
332 for (auto _ : state) {
333 benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
334 if (++buffer_index == num_buffers) {
335 buffer_index = 0;
336 }
337
338 const auto start = std::chrono::high_resolution_clock::now();
339 float scale[2];
Marat Dukhand713e8a2020-12-04 14:23:12 -0800340 raddextexp(elements * sizeof(float), x.data(), scale);
341 vscaleextexp(elements * sizeof(float), x.data(), y.data() + packed_elements * buffer_index, 1.0f / scale[0], -scale[1]);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700342 const auto end = std::chrono::high_resolution_clock::now();
343
344 const auto elapsed_seconds =
345 std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
346 state.SetIterationTime(elapsed_seconds.count());
347 }
348
Marat Dukhand713e8a2020-12-04 14:23:12 -0800349 const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
350 if (cpu_frequency != 0) {
351 state.counters["cpufreq"] = cpu_frequency;
352 }
353
354 const size_t elements_per_iteration = elements;
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700355 state.counters["elements"] =
Marat Dukhand713e8a2020-12-04 14:23:12 -0800356 benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
357
358 const size_t bytes_per_iteration = 2 * elements * sizeof(float);
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700359 state.counters["bytes"] =
Marat Dukhand713e8a2020-12-04 14:23:12 -0800360 benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700361}
362
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700363static void CharacteristicArguments(benchmark::internal::Benchmark* b) {
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800364 for (int32_t n = 1000; n <= 100000000; n *= 10) {
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700365 b->Arg(n);
366 b->Arg(3 * n);
367 }
368}
369
#ifdef BENCHMARK_INTEL_DNNL
  // DNNL reference benchmark; it reports its own iteration time, hence UseManualTime().
  BENCHMARK(DNNLSoftArgMax)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();
#endif  // BENCHMARK_INTEL_DNNL
373
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // AVX2 micro-kernels (gated at run time by CheckAVX2).
  BENCHMARK_CAPTURE(TwoPassSoftMax, avx2_p5,
                    xnn_f32_raddextexp_ukernel__avx2_p5_x96,
                    xnn_f32_vscaleextexp_ukernel__avx2_p5_x40,
                    benchmark::utils::CheckAVX2)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();
  BENCHMARK_CAPTURE(ThreePassSoftMaxWithRecomputing, avx2_p5,
                    xnn_f32_rmax_ukernel__avx,
                    xnn_f32_raddexpminusmax_ukernel__avx2_p5_x96,
                    xnn_f32_vscaleexpminusmax_ukernel__avx2_p5_x24,
                    benchmark::utils::CheckAVX2)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();
  BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, avx2_p5,
                    xnn_f32_rmax_ukernel__avx,
                    xnn_f32_raddstoreexpminusmax_ukernel__avx2_p5_x64_acc2,
                    xnn_f32_vscale_ukernel__avx_x32,
                    benchmark::utils::CheckAVX2)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();

  // AVX512F micro-kernels (gated at run time by CheckAVX512F).
  BENCHMARK_CAPTURE(TwoPassSoftMax, avx512f_p5_scalef,
                    xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_x144_acc3,
                    xnn_f32_vscaleextexp_ukernel__avx512f_p5_scalef_x16,
                    benchmark::utils::CheckAVX512F)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();
  BENCHMARK_CAPTURE(ThreePassSoftMaxWithRecomputing, avx512f_p5_scalef,
                    xnn_f32_rmax_ukernel__avx512f,
                    xnn_f32_raddexpminusmax_ukernel__avx512f_p5_scalef_x128_acc4,
                    xnn_f32_vscaleexpminusmax_ukernel__avx512f_p5_scalef_x16,
                    benchmark::utils::CheckAVX512F)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();
  BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, avx512f_p5_scalef,
                    xnn_f32_rmax_ukernel__avx512f,
                    xnn_f32_raddstoreexpminusmax_ukernel__avx512f_p5_scalef_x128_acc2,
                    xnn_f32_vscale_ukernel__avx512f_x64,
                    benchmark::utils::CheckAVX512F)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
405
406#ifndef XNNPACK_BENCHMARK_NO_MAIN
407BENCHMARK_MAIN();
408#endif