blob: 58097bc8c61fbef47bea5c761cf9f374eaf30766 [file] [log] [blame]
Matt Sharifid40f9762017-03-14 21:24:23 +01001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "smartselect/token-feature-extractor.h"
18
19#include "gmock/gmock.h"
20#include "gtest/gtest.h"
21
22namespace libtextclassifier {
23namespace {
24
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020025class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
26 public:
27 using TokenFeatureExtractor::TokenFeatureExtractor;
28 using TokenFeatureExtractor::HashToken;
29};
30
31TEST(TokenFeatureExtractorTest, ExtractAscii) {
Matt Sharifid40f9762017-03-14 21:24:23 +010032 TokenFeatureExtractorOptions options;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020033 options.num_buckets = 1000;
34 options.chargram_orders = std::vector<int>{1, 2, 3};
Matt Sharifid40f9762017-03-14 21:24:23 +010035 options.extract_case_feature = true;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020036 options.unicode_aware_features = false;
Matt Sharifid40f9762017-03-14 21:24:23 +010037 options.extract_selection_mask_feature = true;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020038 TestingTokenFeatureExtractor extractor(options);
Matt Sharifid40f9762017-03-14 21:24:23 +010039
40 std::vector<int> sparse_features;
41 std::vector<float> dense_features;
42
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020043 extractor.Extract(Token{"Hello", 0, 5, true}, &sparse_features,
Matt Sharifid40f9762017-03-14 21:24:23 +010044 &dense_features);
45
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020046 EXPECT_THAT(sparse_features,
47 testing::ElementsAreArray({
48 // clang-format off
49 extractor.HashToken("H"),
50 extractor.HashToken("e"),
51 extractor.HashToken("l"),
52 extractor.HashToken("l"),
53 extractor.HashToken("o"),
54 extractor.HashToken("^H"),
55 extractor.HashToken("He"),
56 extractor.HashToken("el"),
57 extractor.HashToken("ll"),
58 extractor.HashToken("lo"),
59 extractor.HashToken("o$"),
60 extractor.HashToken("^He"),
61 extractor.HashToken("Hel"),
62 extractor.HashToken("ell"),
63 extractor.HashToken("llo"),
64 extractor.HashToken("lo$")
65 // clang-format on
66 }));
Matt Sharifid40f9762017-03-14 21:24:23 +010067 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
68
69 sparse_features.clear();
70 dense_features.clear();
71 extractor.Extract(Token{"world!", 23, 29, false}, &sparse_features,
72 &dense_features);
73
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020074 EXPECT_THAT(sparse_features,
75 testing::ElementsAreArray({
76 // clang-format off
77 extractor.HashToken("w"),
78 extractor.HashToken("o"),
79 extractor.HashToken("r"),
80 extractor.HashToken("l"),
81 extractor.HashToken("d"),
82 extractor.HashToken("!"),
83 extractor.HashToken("^w"),
84 extractor.HashToken("wo"),
85 extractor.HashToken("or"),
86 extractor.HashToken("rl"),
87 extractor.HashToken("ld"),
88 extractor.HashToken("d!"),
89 extractor.HashToken("!$"),
90 extractor.HashToken("^wo"),
91 extractor.HashToken("wor"),
92 extractor.HashToken("orl"),
93 extractor.HashToken("rld"),
94 extractor.HashToken("ld!"),
95 extractor.HashToken("d!$"),
96 // clang-format on
97 }));
Matt Sharifid40f9762017-03-14 21:24:23 +010098 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
99}
100
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200101TEST(TokenFeatureExtractorTest, ExtractUnicode) {
102 TokenFeatureExtractorOptions options;
103 options.num_buckets = 1000;
104 options.chargram_orders = std::vector<int>{1, 2, 3};
105 options.extract_case_feature = true;
106 options.unicode_aware_features = true;
107 options.extract_selection_mask_feature = true;
108 TestingTokenFeatureExtractor extractor(options);
109
110 std::vector<int> sparse_features;
111 std::vector<float> dense_features;
112
113 extractor.Extract(Token{"Hělló", 0, 5, true}, &sparse_features,
114 &dense_features);
115
116 EXPECT_THAT(sparse_features,
117 testing::ElementsAreArray({
118 // clang-format off
119 extractor.HashToken("H"),
120 extractor.HashToken("ě"),
121 extractor.HashToken("l"),
122 extractor.HashToken("l"),
123 extractor.HashToken("ó"),
124 extractor.HashToken("^H"),
125 extractor.HashToken("Hě"),
126 extractor.HashToken("ěl"),
127 extractor.HashToken("ll"),
128 extractor.HashToken("ló"),
129 extractor.HashToken("ó$"),
130 extractor.HashToken("^Hě"),
131 extractor.HashToken("Hěl"),
132 extractor.HashToken("ěll"),
133 extractor.HashToken("lló"),
134 extractor.HashToken("ló$")
135 // clang-format on
136 }));
137 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
138
139 sparse_features.clear();
140 dense_features.clear();
141 extractor.Extract(Token{"world!", 23, 29, false}, &sparse_features,
142 &dense_features);
143
144 EXPECT_THAT(sparse_features,
145 testing::ElementsAreArray({
146 // clang-format off
147 extractor.HashToken("w"),
148 extractor.HashToken("o"),
149 extractor.HashToken("r"),
150 extractor.HashToken("l"),
151 extractor.HashToken("d"),
152 extractor.HashToken("!"),
153 extractor.HashToken("^w"),
154 extractor.HashToken("wo"),
155 extractor.HashToken("or"),
156 extractor.HashToken("rl"),
157 extractor.HashToken("ld"),
158 extractor.HashToken("d!"),
159 extractor.HashToken("!$"),
160 extractor.HashToken("^wo"),
161 extractor.HashToken("wor"),
162 extractor.HashToken("orl"),
163 extractor.HashToken("rld"),
164 extractor.HashToken("ld!"),
165 extractor.HashToken("d!$"),
166 // clang-format on
167 }));
168 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
169}
170
171TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
172 TokenFeatureExtractorOptions options;
173 options.num_buckets = 1000;
174 options.chargram_orders = std::vector<int>{1, 2};
175 options.extract_case_feature = true;
176 options.unicode_aware_features = true;
177 options.extract_selection_mask_feature = false;
178 TokenFeatureExtractor extractor(options);
179
180 std::vector<int> sparse_features;
181 std::vector<float> dense_features;
182 extractor.Extract(Token{"Hělló", 0, 5, true}, &sparse_features,
183 &dense_features);
184 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
185
186 sparse_features.clear();
187 dense_features.clear();
188 extractor.Extract(Token{"world!", 23, 29, false}, &sparse_features,
189 &dense_features);
190 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
191
192 sparse_features.clear();
193 dense_features.clear();
194 extractor.Extract(Token{"Ř", 23, 29, false}, &sparse_features,
195 &dense_features);
196 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
197
198 sparse_features.clear();
199 dense_features.clear();
200 extractor.Extract(Token{"ř", 23, 29, false}, &sparse_features,
201 &dense_features);
202 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
203}
204
205TEST(TokenFeatureExtractorTest, DigitRemapping) {
206 TokenFeatureExtractorOptions options;
207 options.num_buckets = 1000;
208 options.chargram_orders = std::vector<int>{1, 2};
209 options.remap_digits = true;
210 options.unicode_aware_features = false;
211 TokenFeatureExtractor extractor(options);
212
213 std::vector<int> sparse_features;
214 std::vector<float> dense_features;
215 extractor.Extract(Token{"9:30am", 0, 6, true}, &sparse_features,
216 &dense_features);
217
218 std::vector<int> sparse_features2;
219 extractor.Extract(Token{"5:32am", 0, 6, true}, &sparse_features2,
220 &dense_features);
221 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
222
223 extractor.Extract(Token{"10:32am", 0, 6, true}, &sparse_features2,
224 &dense_features);
225 EXPECT_THAT(sparse_features,
226 testing::Not(testing::ElementsAreArray(sparse_features2)));
227}
228
229TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) {
230 TokenFeatureExtractorOptions options;
231 options.num_buckets = 1000;
232 options.chargram_orders = std::vector<int>{1, 2};
233 options.remap_digits = true;
234 options.unicode_aware_features = true;
235 TokenFeatureExtractor extractor(options);
236
237 std::vector<int> sparse_features;
238 std::vector<float> dense_features;
239 extractor.Extract(Token{"9:30am", 0, 6, true}, &sparse_features,
240 &dense_features);
241
242 std::vector<int> sparse_features2;
243 extractor.Extract(Token{"5:32am", 0, 6, true}, &sparse_features2,
244 &dense_features);
245 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
246
247 extractor.Extract(Token{"10:32am", 0, 6, true}, &sparse_features2,
248 &dense_features);
249 EXPECT_THAT(sparse_features,
250 testing::Not(testing::ElementsAreArray(sparse_features2)));
251}
252
253TEST(TokenFeatureExtractorTest, RegexFeatures) {
254 TokenFeatureExtractorOptions options;
255 options.num_buckets = 1000;
256 options.chargram_orders = std::vector<int>{1, 2};
257 options.remap_digits = false;
258 options.unicode_aware_features = false;
259 options.regexp_features.push_back("^[a-z]+$"); // all lower case.
260 options.regexp_features.push_back("^[0-9]+$"); // all digits.
261 TokenFeatureExtractor extractor(options);
262
263 std::vector<int> sparse_features;
264 std::vector<float> dense_features;
265 extractor.Extract(Token{"abCde", 0, 6, true}, &sparse_features,
266 &dense_features);
267 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
268
269 dense_features.clear();
270 extractor.Extract(Token{"abcde", 0, 6, true}, &sparse_features,
271 &dense_features);
272 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
273
274 dense_features.clear();
275 extractor.Extract(Token{"12c45", 0, 6, true}, &sparse_features,
276 &dense_features);
277 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
278
279 dense_features.clear();
280 extractor.Extract(Token{"12345", 0, 6, true}, &sparse_features,
281 &dense_features);
282 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
283}
284
285TEST(TokenFeatureExtractorTest, ExtractInvalidUTF8) {
286 TokenFeatureExtractorOptions options;
287 options.num_buckets = 1000;
288 options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5, 100};
289 options.extract_case_feature = true;
290 options.unicode_aware_features = true;
291 options.extract_selection_mask_feature = true;
292 TestingTokenFeatureExtractor extractor(options);
293
294 // Test that this runs. ASAN should catch problems.
295 std::vector<int> sparse_features;
296 std::vector<float> dense_features;
297 extractor.Extract(Token{"\xf0👶👶¾👶🏿\xf0", 0, 7, true},
298 &sparse_features, &dense_features);
299}
300
301TEST(TokenFeatureExtractorTest, ExtractTooLongWord) {
302 TokenFeatureExtractorOptions options;
303 options.num_buckets = 1000;
304 options.chargram_orders = std::vector<int>{22};
305 options.extract_case_feature = true;
306 options.unicode_aware_features = true;
307 options.extract_selection_mask_feature = true;
308 TestingTokenFeatureExtractor extractor(options);
309
310 // Test that this runs. ASAN should catch problems.
311 std::vector<int> sparse_features;
312 std::vector<float> dense_features;
313 extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0, true},
314 &sparse_features, &dense_features);
315
316 EXPECT_THAT(sparse_features,
317 testing::ElementsAreArray({
318 // clang-format off
319 extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
320 extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
321 // clang-format on
322 }));
323}
324
Matt Sharifid40f9762017-03-14 21:24:23 +0100325TEST(TokenFeatureExtractorTest, ExtractForPadToken) {
326 TokenFeatureExtractorOptions options;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200327 options.num_buckets = 1000;
Matt Sharifid40f9762017-03-14 21:24:23 +0100328 options.chargram_orders = std::vector<int>{1, 2};
329 options.extract_case_feature = true;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200330 options.unicode_aware_features = false;
Matt Sharifid40f9762017-03-14 21:24:23 +0100331 options.extract_selection_mask_feature = true;
332
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200333 TestingTokenFeatureExtractor extractor(options);
Matt Sharifid40f9762017-03-14 21:24:23 +0100334
335 std::vector<int> sparse_features;
336 std::vector<float> dense_features;
337
338 extractor.Extract(Token(), &sparse_features, &dense_features);
339
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200340 EXPECT_THAT(sparse_features,
341 testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
Matt Sharifid40f9762017-03-14 21:24:23 +0100342 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
343}
344
345} // namespace
346} // namespace libtextclassifier