blob: 2412ff3d0891ce1ba59f7840a138df0357ff831b [file] [log] [blame]
Lukas Zilka21d8c982018-01-24 11:11:20 +01001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "cached-features.h"
18
19#include "model-executor.h"
20#include "tensor-view.h"
21
22#include "gmock/gmock.h"
23#include "gtest/gtest.h"
24
25using testing::ElementsAreArray;
26using testing::FloatEq;
27using testing::Matcher;
28
29namespace libtextclassifier2 {
30namespace {
31
32Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
33 std::vector<Matcher<float>> matchers;
34 for (const float value : values) {
35 matchers.push_back(FloatEq(value));
36 }
37 return ElementsAreArray(matchers);
38}
39
40// EmbeddingExecutor that always returns features based on
41class FakeEmbeddingExecutor : public EmbeddingExecutor {
42 public:
43 bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
44 int dest_size) override {
45 TC_CHECK_GE(dest_size, 2);
46 EXPECT_EQ(sparse_features.size(), 1);
47
48 dest[0] = sparse_features.data()[0] * 11.0f;
49 dest[1] = -sparse_features.data()[0] * 11.0f;
50 return true;
51 }
52
53 private:
54 std::vector<float> storage_;
55};
56
57TEST(CachedFeaturesTest, Simple) {
58 FeatureProcessorOptions_::BoundsSensitiveFeaturesT config;
59 config.enabled = true;
60 config.num_tokens_before = 2;
61 config.num_tokens_inside_left = 2;
62 config.num_tokens_inside_right = 2;
63 config.num_tokens_after = 2;
64 config.include_inside_bag = true;
65 config.include_inside_length = true;
66 flatbuffers::FlatBufferBuilder builder;
67 builder.Finish(CreateBoundsSensitiveFeatures(builder, &config));
68 flatbuffers::DetachedBuffer config_fb = builder.Release();
69
70 std::vector<std::vector<int>> sparse_features(9);
71 for (int i = 0; i < sparse_features.size(); ++i) {
72 sparse_features[i].push_back(i + 1);
73 }
74 std::vector<std::vector<float>> dense_features(9);
75 for (int i = 0; i < dense_features.size(); ++i) {
76 dense_features[i].push_back((i + 1) * 0.1);
77 }
78
79 std::vector<int> padding_sparse_features = {10203};
80 std::vector<float> padding_dense_features = {321.0};
81
82 FakeEmbeddingExecutor executor;
83 const CachedFeatures cached_features(
84 {3, 9}, sparse_features, dense_features, padding_sparse_features,
85 padding_dense_features,
86 flatbuffers::GetRoot<FeatureProcessorOptions_::BoundsSensitiveFeatures>(
87 config_fb.data()),
88 &executor, /*feature_vector_size=*/3);
89
90 EXPECT_THAT(cached_features.Get({5, 8}),
91 ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2,
92 33.0, -33.0, 0.3, 44.0, -44.0, 0.4,
93 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
94 66.0, -66.0, 0.6, 112233.0, -112233.0, 321.0,
95 132.0, -132.0, 1.2, 3.0}));
96
97 EXPECT_THAT(
98 cached_features.Get({5, 7}),
99 ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
100 -33.0, 0.3, 44.0, -44.0, 0.4, 33.0, -33.0,
101 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
102 66.0, -66.0, 0.6, 77.0, -77.0, 0.7, 2.0}));
103
104 EXPECT_THAT(
105 cached_features.Get({6, 8}),
106 ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0,
107 -44.0, 0.4, 55.0, -55.0, 0.5, 44.0, -44.0,
108 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
109 112233.0, -112233.0, 321.0, 99.0, -99.0, 0.9, 2.0}));
110
111 EXPECT_THAT(
112 cached_features.Get({6, 7}),
113 ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3,
114 44.0, -44.0, 0.4, 112233.0, -112233.0, 321.0,
115 112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4,
116 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
117 44.0, -44.0, 0.4, 1.0}));
118}
119
120} // namespace
121} // namespace libtextclassifier2