blob: 01fc8ccaad2d302eb278d5a37a468bf9173e2cd1 [file] [log] [blame]
XNNPACK Teamab8c4c82020-10-09 08:05:51 -07001// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#include <xnnpack.h>
7
8#include "subgraph-tester.h"
9#include <gtest/gtest.h>
10
Marat Dukhan54b2d542020-12-08 00:19:52 -080011TEST(SUBGRAPH_NCHW, single_conv) {
12 auto tester = SubgraphTester(4);
13 tester
14 .add_tensor({1, 256, 256, 3}, kDynamic, 0)
15 .add_tensor({32, 3, 3, 3}, kStaticDense, 1)
16 .add_tensor({32}, kStaticDense, 2)
17 .add_tensor({1, 128, 128, 32}, kDynamic, 3)
18 .add_conv(1, 1, 1, 1, 3, 3, 2, 2, 1, 1, 1, 3, 32, 0, 1, 2, 3)
19 .optimize()
20 .rewrite();
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070021
Marat Dukhan54b2d542020-12-08 00:19:52 -080022 ASSERT_EQ(tester.get_layout(0), xnn_layout_type_nhwc);
23 ASSERT_EQ(tester.get_layout(3), xnn_layout_type_nhwc);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070024}
25
Marat Dukhan54b2d542020-12-08 00:19:52 -080026TEST(SUBGRAPH_NCHW, single_conv_and_global_average_pooling) {
27 auto tester = SubgraphTester(5);
28 tester
29 .add_tensor({1, 256, 256, 3}, kDynamic, 0)
30 .add_tensor({32, 3, 3, 3}, kStaticDense, 1)
31 .add_tensor({32}, kStaticDense, 2)
32 .add_tensor({1, 128, 128, 32}, kDynamic, 3)
33 .add_tensor({32}, kDynamic, 4)
34 .add_conv(1, 1, 1, 1, 3, 3, 2, 2, 1, 1, 1, 3, 32, 0, 1, 2, 3)
35 .add_global_average_pooling(3, 4)
36 .optimize()
37 .rewrite();
38
39 ASSERT_EQ(tester.get_layout(0), xnn_layout_type_nhwc);
40 ASSERT_EQ(tester.get_layout(3), xnn_layout_type_nhwc);
41 ASSERT_EQ(tester.get_layout(4), xnn_layout_type_nhwc);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -070042}
43
Marat Dukhan54b2d542020-12-08 00:19:52 -080044TEST(SUBGRAPH_NCHW, pixelwise_conv_sandwich) {
45 auto tester = SubgraphTester(8);
46 tester
47 .add_tensor({1, 256, 256, 3}, kDynamic, 0)
48 .add_tensor({8, 3, 3, 3}, kStaticDense, 1)
49 .add_tensor({8}, kStaticDense, 2)
50 .add_tensor({1, 128, 128, 8}, kDynamic, 3)
51 .add_tensor({4, 1, 1, 8}, kStaticSparse, 4)
52 .add_tensor({4}, kStaticDense, 5)
53 .add_tensor({1, 128, 128, 4}, kDynamic, 6)
54 .add_tensor({1, 4}, kDynamic, 7)
55 .add_conv(1, 1, 1, 1, 3, 3, 2, 2, 1, 1, 1, 3, 8, 0, 1, 2, 3)
56 .add_conv(0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 8, 4, 3, 4, 5, 6)
57 .add_global_average_pooling(6, 7)
58 .optimize()
59 .rewrite();
60
61 ASSERT_EQ(tester.get_layout(0), xnn_layout_type_nhwc);
62 ASSERT_EQ(tester.get_layout(3), xnn_layout_type_nchw);
63 ASSERT_EQ(tester.get_layout(6), xnn_layout_type_nchw);
64 ASSERT_EQ(tester.get_layout(7), xnn_layout_type_nhwc);
65}
66
67TEST(SUBGRAPH_NCHW, bottleneck) {
68 auto tester = SubgraphTester(15);
69 tester
70 .add_tensor({1, 256, 256, 3}, kDynamic, 0)
71 .add_tensor({8, 3, 3, 3}, kStaticDense, 1)
72 .add_tensor({8}, kStaticDense, 2)
73 .add_tensor({1, 128, 128, 8}, kDynamic, 3)
74 .add_tensor({4, 1, 1, 8}, kStaticSparse, 4)
75 .add_tensor({4}, kStaticDense, 5)
76 .add_tensor({1, 128, 128, 4}, kDynamic, 6)
77 .add_tensor({1, 3, 3, 4}, kStaticDense, 7)
78 .add_tensor({4}, kStaticDense, 8)
79 .add_tensor({1, 128, 128, 4}, kDynamic, 9)
80 .add_tensor({8, 1, 1, 4}, kStaticSparse, 10)
81 .add_tensor({8}, kStaticDense, 11)
82 .add_tensor({1, 128, 128, 8}, kDynamic, 12)
83 .add_tensor({1, 128, 128, 8}, kDynamic, 13)
84 .add_tensor({1, 128, 128, 8}, kDynamic, 13)
85 .add_tensor({1, 8}, kDynamic, 14)
86 .add_conv(1, 1, 1, 1, 3, 3, 2, 2, 1, 1, 1, 3, 8, 0, 1, 2, 3)
87 .add_conv(0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 8, 4, 3, 4, 5, 6)
88 .add_depthwise_conv(1, 1, 1, 1, 3, 3, 1, 1, 1, 1, 1, 4, 6, 7, 8, 9)
89 .add_conv(0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 8, 4, 9, 10, 11, 12)
90 .add_addition(3, 12, 13)
91 .add_global_average_pooling(13, 14)
92 .optimize()
93 .rewrite();
94
95 ASSERT_EQ(tester.get_layout(0), xnn_layout_type_nhwc);
96 ASSERT_EQ(tester.get_layout(3), xnn_layout_type_nchw);
97 ASSERT_EQ(tester.get_layout(6), xnn_layout_type_nchw);
98 ASSERT_EQ(tester.get_layout(9), xnn_layout_type_nchw);
99 ASSERT_EQ(tester.get_layout(12), xnn_layout_type_nchw);
100 ASSERT_EQ(tester.get_layout(13), xnn_layout_type_nchw);
101 ASSERT_EQ(tester.get_layout(14), xnn_layout_type_nhwc);
XNNPACK Teamab8c4c82020-10-09 08:05:51 -0700102}