blob: cc4a65043a73c8d3955203f5214e20bd6282fa6b [file] [log] [blame]
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <xnnpack.h>
#include <xnnpack/subgraph.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <gtest/gtest.h>
21
// Selects how SubgraphTester::add_tensor() provisions a tensor's data.
enum xnn_tensor_type {
  // Static tensor whose tester-owned data is filled entirely with random values.
  kStaticDense = 0,
  // Static tensor whose tester-owned data is ~90% zeros, with the non-zero
  // values at random positions.
  kStaticSparse = 1,
  // Tensor defined without static data (a null data pointer).
  kDynamic = 2,
};
27
28class SubgraphTester {
29 public:
30 explicit SubgraphTester(uint32_t external_value_ids) {
Marat Dukhanc10585f2020-12-08 09:34:55 -080031 xnn_status status = xnn_initialize(nullptr);
32 EXPECT_EQ(status, xnn_status_success);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070033
Marat Dukhanc10585f2020-12-08 09:34:55 -080034 xnn_subgraph_t subgraph_ptr = nullptr;
35 status = xnn_create_subgraph(external_value_ids, 0 /* flags */, &subgraph_ptr);
36 EXPECT_EQ(status, xnn_status_success);
37 subgraph_.reset(subgraph_ptr);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070038
39 std::random_device random_device;
40 rng_ = std::mt19937(random_device());
41 }
42
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070043 inline SubgraphTester& add_tensor(const std::vector<size_t>& dims,
44 xnn_tensor_type tensor_type,
45 uint32_t external_id) {
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070046 void* data = nullptr;
Marat Dukhan54b2d542020-12-08 00:19:52 -080047 if (tensor_type == kStaticDense || tensor_type == kStaticSparse) {
Marat Dukhan4eddb9c2020-12-13 17:29:44 -080048 const size_t num_elements = std::accumulate(std::begin(dims), std::end(dims), 1, std::multiplies<size_t>());
Marat Dukhan54b2d542020-12-08 00:19:52 -080049 static_data_.emplace_back(num_elements);
50 std::vector<float>& weights = static_data_.back();
51 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1.0f, +1.0f), std::ref(rng_));
52 if (tensor_type == kStaticDense) {
53 std::generate(weights.begin(), weights.end(), std::ref(f32rng));
54 } else {
55 // Create tensor with 90% sparsity in two steps:
56 // 1. Generate non-zero elements in the beginning of the vector
57 // 2. Randomize positions of non-zero elements
58 const size_t num_nonzero_elements = num_elements / 10;
59 std::generate(weights.begin(), weights.begin() + num_nonzero_elements, std::ref(f32rng));
60 std::shuffle(weights.begin(), weights.end(), rng_);
61 }
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070062 data = weights.data();
63 }
64 uint32_t id_out = 0;
Marat Dukhanc10585f2020-12-08 09:34:55 -080065 const xnn_status status =
66 xnn_define_tensor_value(subgraph_.get(), xnn_datatype_fp32, dims.size(),
67 dims.data(), data, external_id, 0 /* flags */, &id_out);
68 EXPECT_EQ(status, xnn_status_success);
69 EXPECT_EQ(id_out, external_id);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070070
71 return *this;
72 }
73
74 inline SubgraphTester& add_conv(
75 uint32_t input_padding_top, uint32_t input_padding_right,
76 uint32_t input_padding_bottom, uint32_t input_padding_left,
77 uint32_t kernel_height, uint32_t kernel_width,
78 uint32_t subsampling_height, uint32_t subsampling_width,
79 uint32_t dilation_height, uint32_t dilation_width, uint32_t groups,
80 size_t group_input_channels, size_t group_output_channels,
81 uint32_t input_id, uint32_t filter_id, uint32_t bias_id,
Marat Dukhanc10585f2020-12-08 09:34:55 -080082 uint32_t output_id)
83 {
84 const xnn_status status = xnn_define_convolution_2d(
85 subgraph_.get(), input_padding_top, input_padding_right,
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070086 input_padding_bottom, input_padding_left, kernel_height, kernel_width,
87 subsampling_height, subsampling_width, dilation_height, dilation_width,
88 groups, group_input_channels, group_output_channels,
89 -std::numeric_limits<float>::infinity(),
90 std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id,
Marat Dukhanc10585f2020-12-08 09:34:55 -080091 output_id, 0 /* flags */);
92 EXPECT_EQ(status, xnn_status_success);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070093
94 return *this;
95 }
96
97 inline SubgraphTester& add_depthwise_conv(
98 uint32_t input_padding_top, uint32_t input_padding_right,
99 uint32_t input_padding_bottom, uint32_t input_padding_left,
100 uint32_t kernel_height, uint32_t kernel_width,
101 uint32_t subsampling_height, uint32_t subsampling_width,
102 uint32_t dilation_height, uint32_t dilation_width,
103 uint32_t depth_multiplier, size_t input_channels, uint32_t input_id,
Marat Dukhanc10585f2020-12-08 09:34:55 -0800104 uint32_t filter_id, uint32_t bias_id, uint32_t output_id)
105 {
106 const xnn_status status = xnn_define_depthwise_convolution_2d(
107 subgraph_.get(), input_padding_top, input_padding_right,
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700108 input_padding_bottom, input_padding_left, kernel_height, kernel_width,
109 subsampling_height, subsampling_width, dilation_height, dilation_width,
110 depth_multiplier, input_channels,
111 -std::numeric_limits<float>::infinity(),
112 std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id,
Marat Dukhanc10585f2020-12-08 09:34:55 -0800113 output_id, 0 /* flags */);
114 EXPECT_EQ(status, xnn_status_success);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700115
116 return *this;
117 }
118
Marat Dukhanc10585f2020-12-08 09:34:55 -0800119 inline SubgraphTester& add_addition(uint32_t input_id1, uint32_t input_id2, uint32_t output_id)
120 {
121 const xnn_status status =
122 xnn_define_add2(subgraph_.get(), -std::numeric_limits<float>::infinity(),
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700123 std::numeric_limits<float>::infinity(), input_id1,
Marat Dukhanc10585f2020-12-08 09:34:55 -0800124 input_id2, output_id, 0 /* flags */);
125 EXPECT_EQ(status, xnn_status_success);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700126
127 return *this;
128 }
129
Marat Dukhanc10585f2020-12-08 09:34:55 -0800130 inline SubgraphTester& add_global_average_pooling(uint32_t input_id, uint32_t output_id)
131 {
132 const xnn_status status = xnn_define_global_average_pooling_2d(
133 subgraph_.get(), -std::numeric_limits<float>::infinity(),
134 std::numeric_limits<float>::infinity(), input_id, output_id, 0 /* flags */);
135 EXPECT_EQ(status, xnn_status_success);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700136
137 return *this;
138 }
139
140 inline SubgraphTester& optimize() {
Marat Dukhanb36582b2020-12-08 11:16:28 -0800141 const xnn_status status = xnn_subgraph_optimize(subgraph_.get(), 0 /* flags */);
Marat Dukhanc10585f2020-12-08 09:34:55 -0800142 EXPECT_EQ(status, xnn_status_success);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700143
144 return *this;
145 }
146
147 inline SubgraphTester& rewrite() {
Marat Dukhanc10585f2020-12-08 09:34:55 -0800148 xnn_subgraph_rewrite_for_nchw(subgraph_.get());
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700149
150 return *this;
151 }
152
Marat Dukhan54b2d542020-12-08 00:19:52 -0800153 inline xnn_layout_type get_layout(uint32_t value_id) const {
Marat Dukhanc10585f2020-12-08 09:34:55 -0800154 return subgraph_->values[value_id].layout;
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700155 }
156
157 private:
Marat Dukhan54b2d542020-12-08 00:19:52 -0800158 std::vector<std::vector<float>> static_data_;
Marat Dukhanc10585f2020-12-08 09:34:55 -0800159 std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> subgraph_{nullptr, xnn_delete_subgraph};
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700160 std::mt19937 rng_;
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700161};