blob: 93870d2429d2dcba46252f3c4556a1b1985ba3f7 [file] [log] [blame]
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -07001#include <algorithm>
2#include <cfloat>
Marat Dukhan4fa0fbe2019-10-31 10:23:46 -07003#include <chrono>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -07004#include <cmath>
5#include <functional>
6#include <random>
7#include <vector>
8
9#include "bench/utils.h"
10#include <xnnpack/common.h>
11#include <xnnpack/params.h>
12#include <xnnpack/raddexpminusmax.h>
Marat Dukhan4a2bbc62019-10-25 17:36:32 -070013#include <xnnpack/raddextexp.h>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070014#include <xnnpack/raddstoreexpminusmax.h>
15#include <xnnpack/rmax.h>
Marat Dukhan05ac8e32019-10-21 15:39:33 -070016#include <xnnpack/vscale.h>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070017#include <xnnpack/vscaleexpminusmax.h>
Marat Dukhan4a2bbc62019-10-25 17:36:32 -070018#include <xnnpack/vscaleextexp.h>
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070019
20#include <benchmark/benchmark.h>
Marat Dukhan8d3c6932020-03-06 20:27:27 -080021#ifdef BENCHMARK_INTEL_DNNL
22#include <dnnl.h>
23#endif // BENCHMARK_INTEL_DNNL
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -070024
25
#ifdef BENCHMARK_INTEL_DNNL
// Benchmarks Intel DNNL's forward-inference SoftMax primitive (axis 0) on a
// 1-D f32 tensor of state.range(0) elements, for comparison with the XNNPACK
// micro-kernel pipelines below.
//
// The input is generated once and prefetched into L1 before every iteration.
// The output memory is re-pointed, outside the timed region, at one of
// `num_buffers` padded buffers whose combined size exceeds
// benchmark::utils::GetMaxCacheSize(), so consecutive iterations write to
// different memory. Only primitive execution (plus the stream wait) is timed.
static void DNNLSoftArgMax(
  benchmark::State& state)
{
  const size_t n = state.range(0);
  const size_t cache_line_size_max = 128;
  // Pad each output buffer to a multiple of 128 bytes.
  const size_t packed_n = benchmark::utils::RoundUp(n, cache_line_size_max / sizeof(float));

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), rng);

  const size_t num_buffers = 1 +
    benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_n * sizeof(float));
  std::vector<float> x(n);
  std::vector<float> y(packed_n * num_buffers);

  std::generate(x.begin(), x.end(), std::ref(f32rng));

  // Owns every DNNL handle created below. The destructor releases whatever
  // has been created, so the early returns on failure no longer leak.
  // Handles start as nullptr. NOTE(review): this assumes DNNL does not write
  // the output handle when a create call fails; confirm for the DNNL version
  // in use.
  struct DNNLHandles {
    dnnl_engine_t engine = nullptr;
    dnnl_memory_t input_memory = nullptr;
    dnnl_memory_t output_memory = nullptr;
    dnnl_primitive_desc_t softmax_primitive_descriptor = nullptr;
    dnnl_primitive_t softmax_primitive = nullptr;
    dnnl_stream_t stream = nullptr;

    DNNLHandles() = default;
    DNNLHandles(const DNNLHandles&) = delete;
    DNNLHandles& operator=(const DNNLHandles&) = delete;

    // Destroys every live handle and returns the message for the first
    // destruction failure, or nullptr if all succeeded. Each handle is reset
    // to nullptr, so calling this again (e.g. from the destructor) is a no-op.
    const char* Release() {
      const char* error = nullptr;
      auto check = [&error](dnnl_status_t status, const char* message) {
        if (status != dnnl_success && error == nullptr) {
          error = message;
        }
      };
      if (stream != nullptr) {
        check(dnnl_stream_destroy(stream), "failed to destroy stream");
        stream = nullptr;
      }
      if (softmax_primitive_descriptor != nullptr) {
        check(dnnl_primitive_desc_destroy(softmax_primitive_descriptor),
              "failed to destroy SoftMax primitive descriptor");
        softmax_primitive_descriptor = nullptr;
      }
      if (softmax_primitive != nullptr) {
        check(dnnl_primitive_destroy(softmax_primitive), "failed to destroy SoftMax primitive");
        softmax_primitive = nullptr;
      }
      if (input_memory != nullptr) {
        check(dnnl_memory_destroy(input_memory), "failed to destroy input memory");
        input_memory = nullptr;
      }
      if (output_memory != nullptr) {
        check(dnnl_memory_destroy(output_memory), "failed to destroy output memory");
        output_memory = nullptr;
      }
      if (engine != nullptr) {
        check(dnnl_engine_destroy(engine), "failed to destroy engine");
        engine = nullptr;
      }
      return error;
    }

    ~DNNLHandles() { Release(); }
  };
  DNNLHandles dnnl;

  if (dnnl_engine_create(&dnnl.engine, dnnl_cpu, 0) != dnnl_success) {
    state.SkipWithError("failed to create CPU engine");
    return;
  }

  // dnnl_dim_t is DNNL's dimension type; converting to int first would
  // needlessly narrow n.
  dnnl_dim_t input_output_shape[1] = { static_cast<dnnl_dim_t>(n) };

  dnnl_memory_desc_t memory_descriptor = {};
  if (dnnl_memory_desc_init_by_tag(
      &memory_descriptor, 1, input_output_shape, dnnl_f32, dnnl_x) != dnnl_success)
  {
    state.SkipWithError("failed to create input memory descriptor");
    return;
  }

  if (dnnl_memory_create(
      &dnnl.input_memory, &memory_descriptor, dnnl.engine, x.data()) != dnnl_success)
  {
    state.SkipWithError("failed to create input memory");
    return;
  }

  // Created over the first output buffer; re-pointed on every iteration below.
  if (dnnl_memory_create(
      &dnnl.output_memory, &memory_descriptor, dnnl.engine, y.data()) != dnnl_success)
  {
    state.SkipWithError("failed to create output memory");
    return;
  }

  dnnl_softmax_desc_t softmax_forward_descriptor = {};
  if (dnnl_softmax_forward_desc_init(
      &softmax_forward_descriptor, dnnl_forward_inference,
      &memory_descriptor, 0) != dnnl_success)
  {
    state.SkipWithError("failed to create SoftMax forward descriptor");
    return;
  }

  if (dnnl_primitive_desc_create(
      &dnnl.softmax_primitive_descriptor, &softmax_forward_descriptor,
      nullptr /* primitive attributes */, dnnl.engine, nullptr /* hint */) != dnnl_success)
  {
    state.SkipWithError("failed to create SoftMax primitive descriptor");
    return;
  }

  if (dnnl_primitive_create(
      &dnnl.softmax_primitive, dnnl.softmax_primitive_descriptor) != dnnl_success)
  {
    state.SkipWithError("failed to create SoftMax primitive");
    return;
  }

  dnnl_exec_arg_t softmax_args[2] = {
    {DNNL_ARG_SRC, dnnl.input_memory},
    {DNNL_ARG_DST, dnnl.output_memory},
  };

  if (dnnl_stream_create(&dnnl.stream, dnnl.engine, dnnl_stream_default_flags) != dnnl_success) {
    state.SkipWithError("failed to create stream");
    return;
  }

  bool execution_failed = false;
  size_t buffer_index = 0;
  for (auto _ : state) {
    benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
    if (++buffer_index == num_buffers) {
      buffer_index = 0;
    }

    // Without this the rotation above has no effect: the destination would
    // stay at y.data() for every iteration.
    if (dnnl_memory_set_data_handle(
        dnnl.output_memory, y.data() + packed_n * buffer_index) != dnnl_success)
    {
      state.SkipWithError("failed to set output memory data handle");
      execution_failed = true;
      break;  // Google Benchmark requires leaving the loop after SkipWithError.
    }

    const auto start = std::chrono::steady_clock::now();
    // Wait on the stream so the measured interval covers the full execution
    // even if the stream executes asynchronously.
    if (dnnl_primitive_execute(
        dnnl.softmax_primitive, dnnl.stream, 2, softmax_args) != dnnl_success ||
        dnnl_stream_wait(dnnl.stream) != dnnl_success)
    {
      state.SkipWithError("failed to execute SoftMax");
      execution_failed = true;
      break;  // Google Benchmark requires leaving the loop after SkipWithError.
    }
    const auto end = std::chrono::steady_clock::now();

    const auto elapsed_seconds =
      std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
    state.SetIterationTime(elapsed_seconds.count());
  }

  if (execution_failed) {
    return;  // ~DNNLHandles releases the handles.
  }

  if (const char* error = dnnl.Release()) {
    state.SkipWithError(error);
    return;
  }

  state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
  state.counters["elements"] =
    benchmark::Counter(uint64_t(state.iterations()) * n, benchmark::Counter::kIsRate);
  state.counters["bytes"] =
    benchmark::Counter(uint64_t(state.iterations()) * 2 * sizeof(float) * n, benchmark::Counter::kIsRate);
}
#endif  // BENCHMARK_INTEL_DNNL
172
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800173static void ThreePassSoftMaxWithRecomputing(
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700174 benchmark::State& state,
175 xnn_f32_rmax_ukernel_function rmax,
176 xnn_f32_raddexpminusmax_ukernel_function raddexpminusmax,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800177 xnn_f32_vscaleexpminusmax_ukernel_function vscaleexpminusmax,
178 benchmark::utils::IsaCheckFunction isa_check = nullptr)
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700179{
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800180 if (isa_check && !isa_check(state)) {
181 return;
182 }
183
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700184 const size_t n = state.range(0);
185 const size_t cache_line_size_max = 128;
Marat Dukhan42323232019-10-23 02:09:02 -0700186 const size_t packed_n = benchmark::utils::RoundUp(n, cache_line_size_max / sizeof(float));
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700187
188 std::random_device random_device;
189 auto rng = std::mt19937(random_device());
190 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), rng);
191
192 const size_t num_buffers = 1 +
Marat Dukhan42323232019-10-23 02:09:02 -0700193 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_n * sizeof(float));
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700194 std::vector<float> x(n);
195 std::vector<float> y(packed_n * num_buffers);
196
197 std::generate(x.begin(), x.end(), std::ref(f32rng));
198
199 benchmark::utils::DisableDenormals();
200
201 size_t buffer_index = 0;
202 for (auto _ : state) {
Marat Dukhan42323232019-10-23 02:09:02 -0700203 benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700204 if (++buffer_index == num_buffers) {
205 buffer_index = 0;
206 }
207
208 const auto start = std::chrono::high_resolution_clock::now();
209 float x_max = nanf("");
210 rmax(n * sizeof(float), x.data(), &x_max);
211 float y_sum = nanf("");
212 raddexpminusmax(n * sizeof(float), x.data(), &y_sum, x_max);
213 vscaleexpminusmax(n * sizeof(float), x.data(), y.data() + packed_n * buffer_index, x_max, 1.0f / y_sum);
214 const auto end = std::chrono::high_resolution_clock::now();
215
216 const auto elapsed_seconds =
217 std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
218 state.SetIterationTime(elapsed_seconds.count());
219 }
220
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700221 state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
222 state.counters["elements"] =
223 benchmark::Counter(uint64_t(state.iterations()) * n, benchmark::Counter::kIsRate);
224 state.counters["bytes"] =
225 benchmark::Counter(uint64_t(state.iterations()) * 2 * sizeof(float) * n, benchmark::Counter::kIsRate);
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700226}
227
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800228static void ThreePassSoftMaxWithReloading(
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700229 benchmark::State& state,
230 xnn_f32_rmax_ukernel_function rmax,
231 xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800232 xnn_f32_vscale_ukernel_function vscale,
233 benchmark::utils::IsaCheckFunction isa_check = nullptr)
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700234{
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800235 if (isa_check && !isa_check(state)) {
236 return;
237 }
238
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700239 const size_t n = state.range(0);
240 const size_t cache_line_size_max = 128;
Marat Dukhan42323232019-10-23 02:09:02 -0700241 const size_t packed_n = benchmark::utils::RoundUp(n, cache_line_size_max / sizeof(float));
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700242
243 std::random_device random_device;
244 auto rng = std::mt19937(random_device());
245 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), rng);
246
247 const size_t num_buffers = 1 +
Marat Dukhan42323232019-10-23 02:09:02 -0700248 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_n * sizeof(float));
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700249 std::vector<float> x(n);
250 std::vector<float> y(packed_n * num_buffers);
251
252 std::generate(x.begin(), x.end(), std::ref(f32rng));
253
254 benchmark::utils::DisableDenormals();
255
256 size_t buffer_index = 0;
257 for (auto _ : state) {
Marat Dukhan42323232019-10-23 02:09:02 -0700258 benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700259 if (++buffer_index == num_buffers) {
260 buffer_index = 0;
261 }
262
263 const auto start = std::chrono::high_resolution_clock::now();
264 float x_max = nanf("");
265 rmax(n * sizeof(float), x.data(), &x_max);
266 float y_sum = nanf("");
267 raddstoreexpminusmax(n * sizeof(float), x.data(), y.data() + packed_n * buffer_index, &y_sum, x_max);
268 vscale(n * sizeof(float), y.data() + packed_n * buffer_index, y.data() + packed_n * buffer_index, 1.0f / y_sum);
269 const auto end = std::chrono::high_resolution_clock::now();
270
271 const auto elapsed_seconds =
272 std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
273 state.SetIterationTime(elapsed_seconds.count());
274 }
275
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700276 state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
277 state.counters["elements"] =
278 benchmark::Counter(uint64_t(state.iterations()) * n, benchmark::Counter::kIsRate);
279 state.counters["bytes"] =
280 benchmark::Counter(uint64_t(state.iterations()) * 2 * sizeof(float) * n, benchmark::Counter::kIsRate);
281}
282
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800283static void TwoPassSoftMax(
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700284 benchmark::State& state,
285 xnn_f32_raddextexp_ukernel_function raddextexp,
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800286 xnn_f32_vscaleextexp_ukernel_function vscaleextexp,
287 benchmark::utils::IsaCheckFunction isa_check = nullptr)
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700288{
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800289 if (isa_check && !isa_check(state)) {
290 return;
291 }
292
Marat Dukhan4a2bbc62019-10-25 17:36:32 -0700293 const size_t n = state.range(0);
294 const size_t cache_line_size_max = 128;
295 const size_t packed_n = benchmark::utils::RoundUp(n, cache_line_size_max / sizeof(float));
296
297 std::random_device random_device;
298 auto rng = std::mt19937(random_device());
299 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), rng);
300
301 const size_t num_buffers = 1 +
302 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_n * sizeof(float));
303 std::vector<float> x(n);
304 std::vector<float> y(packed_n * num_buffers);
305
306 std::generate(x.begin(), x.end(), std::ref(f32rng));
307
308 benchmark::utils::DisableDenormals();
309
310 size_t buffer_index = 0;
311 for (auto _ : state) {
312 benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
313 if (++buffer_index == num_buffers) {
314 buffer_index = 0;
315 }
316
317 const auto start = std::chrono::high_resolution_clock::now();
318 float scale[2];
319 raddextexp(n * sizeof(float), x.data(), scale);
320 vscaleextexp(n * sizeof(float), x.data(), y.data() + packed_n * buffer_index, 1.0f / scale[0], -scale[1]);
321 const auto end = std::chrono::high_resolution_clock::now();
322
323 const auto elapsed_seconds =
324 std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
325 state.SetIterationTime(elapsed_seconds.count());
326 }
327
328 state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
329 state.counters["elements"] =
330 benchmark::Counter(uint64_t(state.iterations()) * n, benchmark::Counter::kIsRate);
331 state.counters["bytes"] =
332 benchmark::Counter(uint64_t(state.iterations()) * 2 * sizeof(float) * n, benchmark::Counter::kIsRate);
Marat Dukhan05ac8e32019-10-21 15:39:33 -0700333}
334
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700335static void CharacteristicArguments(benchmark::internal::Benchmark* b) {
Marat Dukhan4c4eb002019-12-08 21:27:49 -0800336 for (int32_t n = 1000; n <= 100000000; n *= 10) {
Marat Dukhan4a4a7fa2019-10-21 13:46:14 -0700337 b->Arg(n);
338 b->Arg(3 * n);
339 }
340}
341
// Benchmark registrations. Each one sweeps the sizes from
// CharacteristicArguments and reports manually measured time.

#ifdef BENCHMARK_INTEL_DNNL
  BENCHMARK(DNNLSoftArgMax)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();
#endif  // BENCHMARK_INTEL_DNNL

#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // AVX2 micro-kernel pipelines.
  BENCHMARK_CAPTURE(TwoPassSoftMax, avx2_p5,
    xnn_f32_raddextexp_ukernel__avx2_p5_x96,
    xnn_f32_vscaleextexp_ukernel__avx2_p5_x40,
    benchmark::utils::CheckAVX2)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();
  BENCHMARK_CAPTURE(ThreePassSoftMaxWithRecomputing, avx2_p5,
    xnn_f32_rmax_ukernel__avx,
    xnn_f32_raddexpminusmax_ukernel__avx2_p5_x96,
    xnn_f32_vscaleexpminusmax_ukernel__avx2_p5_x24,
    benchmark::utils::CheckAVX2)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();
  BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, avx2_p5,
    xnn_f32_rmax_ukernel__avx,
    xnn_f32_raddstoreexpminusmax_ukernel__avx2_p5_x64_acc2,
    xnn_f32_vscale_ukernel__avx_unroll32,
    benchmark::utils::CheckAVX2)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();

  // AVX512F micro-kernel pipelines.
  BENCHMARK_CAPTURE(TwoPassSoftMax, avx512f_p5_scalef,
    xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_x144_acc3,
    xnn_f32_vscaleextexp_ukernel__avx512f_p5_scalef_x16,
    benchmark::utils::CheckAVX512F)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();
  BENCHMARK_CAPTURE(ThreePassSoftMaxWithRecomputing, avx512f_p5_scalef,
    xnn_f32_rmax_ukernel__avx512f,
    xnn_f32_raddexpminusmax_ukernel__avx512f_p5_scalef_x128_acc4,
    xnn_f32_vscaleexpminusmax_ukernel__avx512f_p5_scalef_x16,
    benchmark::utils::CheckAVX512F)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();
  BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, avx512f_p5_scalef,
    xnn_f32_rmax_ukernel__avx512f,
    xnn_f32_raddstoreexpminusmax_ukernel__avx512f_p5_scalef_x128_acc2,
    xnn_f32_vscale_ukernel__avx512f_unroll64,
    benchmark::utils::CheckAVX512F)
    ->Apply(CharacteristicArguments)
    ->UseManualTime();
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
377
378#ifndef XNNPACK_BENCHMARK_NO_MAIN
379BENCHMARK_MAIN();
380#endif