blob: d50fac751048b84a550916526c7d201cafd4e346 [file] [log] [blame]
Lukas Zilka21d8c982018-01-24 11:11:20 +01001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "tensor-view.h"
18
19#include "gmock/gmock.h"
20#include "gtest/gtest.h"
21
22namespace libtextclassifier2 {
23namespace {
24
25TEST(TensorViewTest, TestSize) {
26 std::vector<float> data{0.1, 0.2, 0.3, 0.4, 0.5, 0.6};
27 const TensorView<float> tensor(data.data(), {3, 1, 2});
28 EXPECT_TRUE(tensor.is_valid());
29 EXPECT_EQ(tensor.shape(), (std::vector<int>{3, 1, 2}));
30 EXPECT_EQ(tensor.data(), data.data());
31 EXPECT_EQ(tensor.size(), 6);
32 EXPECT_EQ(tensor.dims(), 3);
33 EXPECT_EQ(tensor.dim(0), 3);
34 EXPECT_EQ(tensor.dim(1), 1);
35 EXPECT_EQ(tensor.dim(2), 2);
36 std::vector<float> output_data(6);
37 EXPECT_TRUE(tensor.copy_to(output_data.data(), output_data.size()));
38 EXPECT_EQ(data, output_data);
39
40 // Should not copy when the output is small.
41 std::vector<float> small_output_data{-1, -1, -1};
42 EXPECT_FALSE(
43 tensor.copy_to(small_output_data.data(), small_output_data.size()));
44 // The output buffer should not be changed.
45 EXPECT_EQ(small_output_data, (std::vector<float>{-1, -1, -1}));
46
47 const TensorView<float> invalid_tensor = TensorView<float>::Invalid();
48 EXPECT_FALSE(invalid_tensor.is_valid());
49}
50
51} // namespace
52} // namespace libtextclassifier2