blob: 813356006ec696df273dc6e9b2f9315a6624221c [file] [log] [blame]
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <xnnpack.h>
#include <xnnpack/subgraph.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <gtest/gtest.h>
22
// Selects how SubgraphTester::add_tensor backs the tensor it defines.
enum xnn_tensor_type {
  // Static data: every element is drawn uniformly at random from [-1, 1).
  kStaticDense,
  // Static data: only the first tenth of the elements is randomized, then
  // the whole buffer is shuffled; the remaining elements stay zero.
  kStaticSparse,
  // No static data: the tensor is defined with a null data pointer.
  kDynamic,
};
28
29class SubgraphTester {
30 public:
31 explicit SubgraphTester(uint32_t external_value_ids) {
Marat Dukhanc10585f2020-12-08 09:34:55 -080032 xnn_status status = xnn_initialize(nullptr);
33 EXPECT_EQ(status, xnn_status_success);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070034
Marat Dukhanc10585f2020-12-08 09:34:55 -080035 xnn_subgraph_t subgraph_ptr = nullptr;
36 status = xnn_create_subgraph(external_value_ids, 0 /* flags */, &subgraph_ptr);
37 EXPECT_EQ(status, xnn_status_success);
38 subgraph_.reset(subgraph_ptr);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070039
40 std::random_device random_device;
41 rng_ = std::mt19937(random_device());
42 }
43
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070044 inline SubgraphTester& add_tensor(const std::vector<size_t>& dims,
45 xnn_tensor_type tensor_type,
46 uint32_t external_id) {
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070047 void* data = nullptr;
Marat Dukhan54b2d542020-12-08 00:19:52 -080048 if (tensor_type == kStaticDense || tensor_type == kStaticSparse) {
Marat Dukhan4eddb9c2020-12-13 17:29:44 -080049 const size_t num_elements = std::accumulate(std::begin(dims), std::end(dims), 1, std::multiplies<size_t>());
Marat Dukhan54b2d542020-12-08 00:19:52 -080050 static_data_.emplace_back(num_elements);
51 std::vector<float>& weights = static_data_.back();
52 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1.0f, +1.0f), std::ref(rng_));
53 if (tensor_type == kStaticDense) {
54 std::generate(weights.begin(), weights.end(), std::ref(f32rng));
55 } else {
56 // Create tensor with 90% sparsity in two steps:
57 // 1. Generate non-zero elements in the beginning of the vector
58 // 2. Randomize positions of non-zero elements
59 const size_t num_nonzero_elements = num_elements / 10;
60 std::generate(weights.begin(), weights.begin() + num_nonzero_elements, std::ref(f32rng));
61 std::shuffle(weights.begin(), weights.end(), rng_);
62 }
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070063 data = weights.data();
64 }
65 uint32_t id_out = 0;
Marat Dukhanc10585f2020-12-08 09:34:55 -080066 const xnn_status status =
67 xnn_define_tensor_value(subgraph_.get(), xnn_datatype_fp32, dims.size(),
68 dims.data(), data, external_id, 0 /* flags */, &id_out);
69 EXPECT_EQ(status, xnn_status_success);
70 EXPECT_EQ(id_out, external_id);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070071
72 return *this;
73 }
74
75 inline SubgraphTester& add_conv(
76 uint32_t input_padding_top, uint32_t input_padding_right,
77 uint32_t input_padding_bottom, uint32_t input_padding_left,
78 uint32_t kernel_height, uint32_t kernel_width,
79 uint32_t subsampling_height, uint32_t subsampling_width,
80 uint32_t dilation_height, uint32_t dilation_width, uint32_t groups,
81 size_t group_input_channels, size_t group_output_channels,
82 uint32_t input_id, uint32_t filter_id, uint32_t bias_id,
Marat Dukhanc10585f2020-12-08 09:34:55 -080083 uint32_t output_id)
84 {
85 const xnn_status status = xnn_define_convolution_2d(
86 subgraph_.get(), input_padding_top, input_padding_right,
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070087 input_padding_bottom, input_padding_left, kernel_height, kernel_width,
88 subsampling_height, subsampling_width, dilation_height, dilation_width,
89 groups, group_input_channels, group_output_channels,
90 -std::numeric_limits<float>::infinity(),
91 std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id,
Marat Dukhanc10585f2020-12-08 09:34:55 -080092 output_id, 0 /* flags */);
93 EXPECT_EQ(status, xnn_status_success);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070094
95 return *this;
96 }
97
98 inline SubgraphTester& add_depthwise_conv(
99 uint32_t input_padding_top, uint32_t input_padding_right,
100 uint32_t input_padding_bottom, uint32_t input_padding_left,
101 uint32_t kernel_height, uint32_t kernel_width,
102 uint32_t subsampling_height, uint32_t subsampling_width,
103 uint32_t dilation_height, uint32_t dilation_width,
104 uint32_t depth_multiplier, size_t input_channels, uint32_t input_id,
Marat Dukhanc10585f2020-12-08 09:34:55 -0800105 uint32_t filter_id, uint32_t bias_id, uint32_t output_id)
106 {
107 const xnn_status status = xnn_define_depthwise_convolution_2d(
108 subgraph_.get(), input_padding_top, input_padding_right,
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700109 input_padding_bottom, input_padding_left, kernel_height, kernel_width,
110 subsampling_height, subsampling_width, dilation_height, dilation_width,
111 depth_multiplier, input_channels,
112 -std::numeric_limits<float>::infinity(),
113 std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id,
Marat Dukhanc10585f2020-12-08 09:34:55 -0800114 output_id, 0 /* flags */);
115 EXPECT_EQ(status, xnn_status_success);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700116
117 return *this;
118 }
119
Marat Dukhanc10585f2020-12-08 09:34:55 -0800120 inline SubgraphTester& add_addition(uint32_t input_id1, uint32_t input_id2, uint32_t output_id)
121 {
122 const xnn_status status =
123 xnn_define_add2(subgraph_.get(), -std::numeric_limits<float>::infinity(),
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700124 std::numeric_limits<float>::infinity(), input_id1,
Marat Dukhanc10585f2020-12-08 09:34:55 -0800125 input_id2, output_id, 0 /* flags */);
126 EXPECT_EQ(status, xnn_status_success);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700127
128 return *this;
129 }
130
Marat Dukhanc10585f2020-12-08 09:34:55 -0800131 inline SubgraphTester& add_global_average_pooling(uint32_t input_id, uint32_t output_id)
132 {
133 const xnn_status status = xnn_define_global_average_pooling_2d(
134 subgraph_.get(), -std::numeric_limits<float>::infinity(),
135 std::numeric_limits<float>::infinity(), input_id, output_id, 0 /* flags */);
136 EXPECT_EQ(status, xnn_status_success);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700137
138 return *this;
139 }
140
141 inline SubgraphTester& optimize() {
Marat Dukhanb36582b2020-12-08 11:16:28 -0800142 const xnn_status status = xnn_subgraph_optimize(subgraph_.get(), 0 /* flags */);
Marat Dukhanc10585f2020-12-08 09:34:55 -0800143 EXPECT_EQ(status, xnn_status_success);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700144
145 return *this;
146 }
147
148 inline SubgraphTester& rewrite() {
Marat Dukhanc10585f2020-12-08 09:34:55 -0800149 xnn_subgraph_rewrite_for_nchw(subgraph_.get());
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700150
151 return *this;
152 }
153
Marat Dukhan54b2d542020-12-08 00:19:52 -0800154 inline xnn_layout_type get_layout(uint32_t value_id) const {
Marat Dukhanc10585f2020-12-08 09:34:55 -0800155 return subgraph_->values[value_id].layout;
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700156 }
157
158 private:
Marat Dukhan54b2d542020-12-08 00:19:52 -0800159 std::vector<std::vector<float>> static_data_;
Marat Dukhanc10585f2020-12-08 09:34:55 -0800160 std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> subgraph_{nullptr, xnn_delete_subgraph};
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700161 std::mt19937 rng_;
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700162};