blob: 9566a8dc990d6122932032d116f854b04b93b7df [file] [log] [blame]
Lukas Zilka21d8c982018-01-24 11:11:20 +01001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "cached-features.h"
18
19#include "model-executor.h"
20#include "tensor-view.h"
21
22#include "gmock/gmock.h"
23#include "gtest/gtest.h"
24
25using testing::ElementsAreArray;
26using testing::FloatEq;
27using testing::Matcher;
28
29namespace libtextclassifier2 {
30namespace {
31
32Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
33 std::vector<Matcher<float>> matchers;
34 for (const float value : values) {
35 matchers.push_back(FloatEq(value));
36 }
37 return ElementsAreArray(matchers);
38}
39
40// EmbeddingExecutor that always returns features based on
41class FakeEmbeddingExecutor : public EmbeddingExecutor {
42 public:
43 bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
44 int dest_size) override {
45 TC_CHECK_GE(dest_size, 2);
46 EXPECT_EQ(sparse_features.size(), 1);
47
48 dest[0] = sparse_features.data()[0] * 11.0f;
49 dest[1] = -sparse_features.data()[0] * 11.0f;
50 return true;
51 }
52
53 private:
54 std::vector<float> storage_;
55};
56
Lukas Zilkab23e2122018-02-09 10:25:19 +010057std::vector<float> GetCachedClickContextFeatures(
58 const CachedFeatures& cached_features, int click_pos) {
59 std::vector<float> output_features;
60 cached_features.AppendClickContextFeaturesForClick(click_pos,
61 &output_features);
62 return output_features;
63}
64
65std::vector<float> GetCachedBoundsSensitiveFeatures(
66 const CachedFeatures& cached_features, TokenSpan selected_span) {
67 std::vector<float> output_features;
68 cached_features.AppendBoundsSensitiveFeaturesForSpan(selected_span,
69 &output_features);
70 return output_features;
71}
72
73TEST(CachedFeaturesTest, ClickContext) {
74 FeatureProcessorOptionsT options;
75 options.context_size = 2;
76 options.feature_version = 1;
Lukas Zilka21d8c982018-01-24 11:11:20 +010077 flatbuffers::FlatBufferBuilder builder;
Lukas Zilkab23e2122018-02-09 10:25:19 +010078 builder.Finish(CreateFeatureProcessorOptions(builder, &options));
79 flatbuffers::DetachedBuffer options_fb = builder.Release();
Lukas Zilka21d8c982018-01-24 11:11:20 +010080
81 std::vector<std::vector<int>> sparse_features(9);
82 for (int i = 0; i < sparse_features.size(); ++i) {
83 sparse_features[i].push_back(i + 1);
84 }
85 std::vector<std::vector<float>> dense_features(9);
86 for (int i = 0; i < dense_features.size(); ++i) {
87 dense_features[i].push_back((i + 1) * 0.1);
88 }
89
90 std::vector<int> padding_sparse_features = {10203};
91 std::vector<float> padding_dense_features = {321.0};
92
93 FakeEmbeddingExecutor executor;
Lukas Zilkab23e2122018-02-09 10:25:19 +010094 const std::unique_ptr<CachedFeatures> cached_features =
95 CachedFeatures::Create(
96 {3, 10}, sparse_features, dense_features, padding_sparse_features,
97 padding_dense_features,
98 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
99 &executor, /*feature_vector_size=*/3);
100 ASSERT_TRUE(cached_features);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100101
Lukas Zilkab23e2122018-02-09 10:25:19 +0100102 EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 5),
103 ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0, -33.0,
104 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5}));
105
106 EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 6),
107 ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0, -44.0,
108 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6}));
109
110 EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 7),
111 ElementsAreFloat({33.0, -33.0, 0.3, 44.0, -44.0, 0.4, 55.0, -55.0,
112 0.5, 66.0, -66.0, 0.6, 77.0, -77.0, 0.7}));
113}
114
115TEST(CachedFeaturesTest, BoundsSensitive) {
116 std::unique_ptr<FeatureProcessorOptions_::BoundsSensitiveFeaturesT> config(
117 new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
118 config->enabled = true;
119 config->num_tokens_before = 2;
120 config->num_tokens_inside_left = 2;
121 config->num_tokens_inside_right = 2;
122 config->num_tokens_after = 2;
123 config->include_inside_bag = true;
124 config->include_inside_length = true;
125 FeatureProcessorOptionsT options;
126 options.bounds_sensitive_features = std::move(config);
127 options.feature_version = 2;
128 flatbuffers::FlatBufferBuilder builder;
129 builder.Finish(CreateFeatureProcessorOptions(builder, &options));
130 flatbuffers::DetachedBuffer options_fb = builder.Release();
131
132 std::vector<std::vector<int>> sparse_features(6);
133 for (int i = 0; i < sparse_features.size(); ++i) {
134 sparse_features[i].push_back(i + 1);
135 }
136 std::vector<std::vector<float>> dense_features(6);
137 for (int i = 0; i < dense_features.size(); ++i) {
138 dense_features[i].push_back((i + 1) * 0.1);
139 }
140
141 std::vector<int> padding_sparse_features = {10203};
142 std::vector<float> padding_dense_features = {321.0};
143
144 FakeEmbeddingExecutor executor;
145 const std::unique_ptr<CachedFeatures> cached_features =
146 CachedFeatures::Create(
147 {3, 9}, sparse_features, dense_features, padding_sparse_features,
148 padding_dense_features,
149 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
150 &executor, /*feature_vector_size=*/3);
151 ASSERT_TRUE(cached_features);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100152
153 EXPECT_THAT(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100154 GetCachedBoundsSensitiveFeatures(*cached_features, {5, 8}),
155 ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
156 -33.0, 0.3, 44.0, -44.0, 0.4, 44.0, -44.0,
157 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
158 112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4, 3.0}));
159
160 EXPECT_THAT(
161 GetCachedBoundsSensitiveFeatures(*cached_features, {5, 7}),
Lukas Zilka21d8c982018-01-24 11:11:20 +0100162 ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
163 -33.0, 0.3, 44.0, -44.0, 0.4, 33.0, -33.0,
164 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
Lukas Zilkab23e2122018-02-09 10:25:19 +0100165 66.0, -66.0, 0.6, 38.5, -38.5, 0.35, 2.0}));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100166
167 EXPECT_THAT(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100168 GetCachedBoundsSensitiveFeatures(*cached_features, {6, 8}),
Lukas Zilka21d8c982018-01-24 11:11:20 +0100169 ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0,
170 -44.0, 0.4, 55.0, -55.0, 0.5, 44.0, -44.0,
171 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
Lukas Zilkab23e2122018-02-09 10:25:19 +0100172 112233.0, -112233.0, 321.0, 49.5, -49.5, 0.45, 2.0}));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100173
174 EXPECT_THAT(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100175 GetCachedBoundsSensitiveFeatures(*cached_features, {6, 7}),
Lukas Zilka21d8c982018-01-24 11:11:20 +0100176 ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3,
177 44.0, -44.0, 0.4, 112233.0, -112233.0, 321.0,
178 112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4,
179 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
180 44.0, -44.0, 0.4, 1.0}));
181}
182
183} // namespace
184} // namespace libtextclassifier2