blob: 630475c1f8f99b93dd7340f2558d00a89b916027 [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
6
7#include <algorithm>
8#include <cmath>
9#include <functional>
10#include <random>
11#include <vector>
12
13#include <xnnpack.h>
14
15#include <benchmark/benchmark.h>
16
17
18static void softargmax_q8(benchmark::State& state) {
19 const size_t batch_size = static_cast<size_t>(state.range(0));
20 const size_t channels = static_cast<size_t>(state.range(1));
21
22 std::random_device random_device;
23 auto rng = std::mt19937(random_device());
24 auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
25
26 std::vector<uint8_t> input(batch_size * channels);
27 std::vector<uint8_t> output(batch_size * channels);
28 std::generate(input.begin(), input.end(), std::ref(u8rng));
29 std::fill(output.begin(), output.end(), 0xA5);
30
31 xnn_status status = xnn_initialize();
32 if (status != xnn_status_success) {
33 state.SkipWithError("failed to initialize XNNPACK");
34 return;
35 }
36
37 xnn_operator_t softargmax_op = nullptr;
38 status = xnn_create_softargmax_nc_q8(
39 channels, channels /* input stride */, channels /* output stride */,
40 1.0f /* input scale */,
41 0 /* output zero point */, 1.0f / 256.0f /* output scale */,
42 0 /* flags */, &softargmax_op);
43 if (status != xnn_status_success || softargmax_op == nullptr) {
44 state.SkipWithError("failed to create SoftArgMax operator");
45 return;
46 }
47
48 status = xnn_setup_softargmax_nc_q8(
49 softargmax_op,
50 batch_size,
51 input.data(), output.data(),
52 nullptr /* thread pool */);
53 if (status != xnn_status_success) {
54 state.SkipWithError("failed to setup SoftArgMax operator");
55 return;
56 }
57
58 for (auto _ : state) {
59 status = xnn_run_operator(softargmax_op, nullptr /* thread pool */);
60 if (status != xnn_status_success) {
61 state.SkipWithError("failed to run SoftArgMax operator");
62 return;
63 }
64 }
65
66 status = xnn_delete_operator(softargmax_op);
67 if (status != xnn_status_success) {
68 state.SkipWithError("failed to delete SoftArgMax operator");
69 return;
70 }
71
72 const size_t elements_per_iteration = batch_size * channels;
73 state.counters["elements"] =
74 benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
75
76 const size_t bytes_per_iteration = 2 * elements_per_iteration * sizeof(uint8_t);
77 state.counters["bytes"] =
78 benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
79}
80
81static void CharacteristicArguments(benchmark::internal::Benchmark* b)
82{
83 b->ArgNames({"N", "C"});
84
85 // CIFAR-10
86 b->Args({1, 10});
87 // CIFAR-100 */
88 b->Args({1, 100});
89 // ImageNet-1K
90 b->Args({1, 1000});
91 // ImageNet-1K+1
92 b->Args({1, 1001});
93 // ImageNet-22K
94 b->Args({1, 21841});
95}
96
97BENCHMARK(softargmax_q8)->Apply(CharacteristicArguments)->UseRealTime();
98
99#ifndef XNNPACK_BENCHMARK_NO_MAIN
100BENCHMARK_MAIN();
101#endif