blob: a6857b0ca73f9810df294b81c9541a86f69ffcdf [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8
9#include <gtest/gtest.h>
10
11#include "softargmax-operator-tester.h"
12
XNNPACK Teamb455b122019-09-27 18:10:33 -070013
Marat Dukhanefc47b82019-11-18 09:25:38 -080014TEST(SOFTARGMAX_NC_Q8, single_class) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070015 SoftArgMaxOperatorTester()
16 .batch_size(1)
17 .channels(1)
18 .iterations(100)
19 .TestQ8();
20}
21
Marat Dukhanefc47b82019-11-18 09:25:38 -080022TEST(SOFTARGMAX_NC_Q8, two_classes) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070023 SoftArgMaxOperatorTester()
24 .batch_size(1)
25 .channels(2)
26 .iterations(100)
27 .TestQ8();
28}
29
Marat Dukhanefc47b82019-11-18 09:25:38 -080030TEST(SOFTARGMAX_NC_Q8, many_classes) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070031 for (size_t channels = 3; channels < 100; channels++) {
32 SoftArgMaxOperatorTester()
33 .batch_size(1)
34 .channels(channels)
35 .iterations(1)
36 .TestQ8();
37 }
38}
39
Marat Dukhanefc47b82019-11-18 09:25:38 -080040TEST(SOFTARGMAX_NC_Q8, cifar_classes) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070041 // CIFAR-10
42 SoftArgMaxOperatorTester()
43 .batch_size(1)
44 .channels(10)
45 .iterations(15)
46 .TestQ8();
47 // CIFAR-100
48 SoftArgMaxOperatorTester()
49 .batch_size(1)
50 .channels(100)
51 .iterations(15)
52 .TestQ8();
53}
54
Marat Dukhanefc47b82019-11-18 09:25:38 -080055TEST(SOFTARGMAX_NC_Q8, imagenet_classes) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070056 // ImageNet-1K
57 SoftArgMaxOperatorTester()
58 .batch_size(1)
59 .channels(1000)
60 .iterations(10)
61 .TestQ8();
62 // ImageNet-1K+1
63 SoftArgMaxOperatorTester()
64 .batch_size(1)
65 .channels(1001)
66 .iterations(10)
67 .TestQ8();
68 // ImageNet-22K
69 SoftArgMaxOperatorTester()
70 .batch_size(1)
71 .channels(21841)
72 .iterations(10)
73 .TestQ8();
74}
75
Marat Dukhanefc47b82019-11-18 09:25:38 -080076TEST(SOFTARGMAX_NC_Q8, many_channels_with_input_scale) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070077 for (size_t channels = 1; channels < 100; channels += 5) {
78 for (float input_scale = 1.0e-2f; input_scale < 1.0e+2f; input_scale *= 3.14159265f) {
79 SoftArgMaxOperatorTester()
80 .batch_size(1)
81 .channels(channels)
82 .input_scale(input_scale)
83 .iterations(1)
84 .TestQ8();
85 }
86 }
87}
88
Marat Dukhanefc47b82019-11-18 09:25:38 -080089TEST(SOFTARGMAX_NC_Q8, many_channels_with_input_zero_point) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070090 for (size_t channels = 1; channels < 100; channels += 5) {
91 for (int32_t input_zero_point = 0; input_zero_point <= 255; input_zero_point += 51) {
92 SoftArgMaxOperatorTester()
93 .batch_size(1)
94 .channels(channels)
95 .input_zero_point(uint8_t(input_zero_point))
96 .iterations(1)
97 .TestQ8();
98 }
99 }
100}
101
Marat Dukhanefc47b82019-11-18 09:25:38 -0800102TEST(SOFTARGMAX_NC_Q8, small_batch) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700103 for (size_t channels = 1; channels < 100; channels += 5) {
104 SoftArgMaxOperatorTester()
105 .batch_size(3)
106 .channels(channels)
107 .iterations(3)
108 .TestQ8();
109 }
110}
111
Marat Dukhanefc47b82019-11-18 09:25:38 -0800112TEST(SOFTARGMAX_NC_Q8, small_batch_with_input_stride) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700113 for (size_t channels = 1; channels < 100; channels += 5) {
114 SoftArgMaxOperatorTester()
115 .batch_size(3)
116 .channels(channels)
117 .input_stride(129)
118 .iterations(3)
119 .TestQ8();
120 }
121}
122
Marat Dukhanefc47b82019-11-18 09:25:38 -0800123TEST(SOFTARGMAX_NC_Q8, small_batch_with_output_stride) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700124 for (size_t channels = 1; channels < 100; channels += 5) {
125 SoftArgMaxOperatorTester()
126 .batch_size(3)
127 .channels(channels)
128 .output_stride(117)
129 .iterations(3)
130 .TestQ8();
131 }
132}
133
Marat Dukhanefc47b82019-11-18 09:25:38 -0800134TEST(SOFTARGMAX_NC_Q8, strided_batch_with_input_and_output_stride) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700135 for (size_t channels = 1; channels < 100; channels += 5) {
136 SoftArgMaxOperatorTester()
137 .batch_size(3)
138 .channels(channels)
139 .input_stride(129)
140 .output_stride(117)
141 .iterations(3)
142 .TestQ8();
143 }
144}