blob: c85ba507475bb598503d67c85eccf5a192dbdbb7 [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#include "smartselect/token-feature-extractor.h"
18
19#include "gmock/gmock.h"
20#include "gtest/gtest.h"
21
22namespace libtextclassifier {
23namespace {
24
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020025class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
26 public:
27 using TokenFeatureExtractor::TokenFeatureExtractor;
28 using TokenFeatureExtractor::HashToken;
29};
30
31TEST(TokenFeatureExtractorTest, ExtractAscii) {
Matt Sharifid40f9762017-03-14 21:24:23 +010032 TokenFeatureExtractorOptions options;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020033 options.num_buckets = 1000;
34 options.chargram_orders = std::vector<int>{1, 2, 3};
Matt Sharifid40f9762017-03-14 21:24:23 +010035 options.extract_case_feature = true;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020036 options.unicode_aware_features = false;
Matt Sharifid40f9762017-03-14 21:24:23 +010037 options.extract_selection_mask_feature = true;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020038 TestingTokenFeatureExtractor extractor(options);
Matt Sharifid40f9762017-03-14 21:24:23 +010039
40 std::vector<int> sparse_features;
41 std::vector<float> dense_features;
42
Lukas Zilka6bb39a82017-04-07 19:55:11 +020043 extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
Matt Sharifid40f9762017-03-14 21:24:23 +010044 &dense_features);
45
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020046 EXPECT_THAT(sparse_features,
47 testing::ElementsAreArray({
48 // clang-format off
49 extractor.HashToken("H"),
50 extractor.HashToken("e"),
51 extractor.HashToken("l"),
52 extractor.HashToken("l"),
53 extractor.HashToken("o"),
54 extractor.HashToken("^H"),
55 extractor.HashToken("He"),
56 extractor.HashToken("el"),
57 extractor.HashToken("ll"),
58 extractor.HashToken("lo"),
59 extractor.HashToken("o$"),
60 extractor.HashToken("^He"),
61 extractor.HashToken("Hel"),
62 extractor.HashToken("ell"),
63 extractor.HashToken("llo"),
64 extractor.HashToken("lo$")
65 // clang-format on
66 }));
Matt Sharifid40f9762017-03-14 21:24:23 +010067 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
68
69 sparse_features.clear();
70 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +020071 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
Matt Sharifid40f9762017-03-14 21:24:23 +010072 &dense_features);
73
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020074 EXPECT_THAT(sparse_features,
75 testing::ElementsAreArray({
76 // clang-format off
77 extractor.HashToken("w"),
78 extractor.HashToken("o"),
79 extractor.HashToken("r"),
80 extractor.HashToken("l"),
81 extractor.HashToken("d"),
82 extractor.HashToken("!"),
83 extractor.HashToken("^w"),
84 extractor.HashToken("wo"),
85 extractor.HashToken("or"),
86 extractor.HashToken("rl"),
87 extractor.HashToken("ld"),
88 extractor.HashToken("d!"),
89 extractor.HashToken("!$"),
90 extractor.HashToken("^wo"),
91 extractor.HashToken("wor"),
92 extractor.HashToken("orl"),
93 extractor.HashToken("rld"),
94 extractor.HashToken("ld!"),
95 extractor.HashToken("d!$"),
96 // clang-format on
97 }));
Matt Sharifid40f9762017-03-14 21:24:23 +010098 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
99}
100
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200101TEST(TokenFeatureExtractorTest, ExtractUnicode) {
102 TokenFeatureExtractorOptions options;
103 options.num_buckets = 1000;
104 options.chargram_orders = std::vector<int>{1, 2, 3};
105 options.extract_case_feature = true;
106 options.unicode_aware_features = true;
107 options.extract_selection_mask_feature = true;
108 TestingTokenFeatureExtractor extractor(options);
109
110 std::vector<int> sparse_features;
111 std::vector<float> dense_features;
112
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200113 extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200114 &dense_features);
115
116 EXPECT_THAT(sparse_features,
117 testing::ElementsAreArray({
118 // clang-format off
119 extractor.HashToken("H"),
120 extractor.HashToken("ě"),
121 extractor.HashToken("l"),
122 extractor.HashToken("l"),
123 extractor.HashToken("ó"),
124 extractor.HashToken("^H"),
125 extractor.HashToken("Hě"),
126 extractor.HashToken("ěl"),
127 extractor.HashToken("ll"),
128 extractor.HashToken("ló"),
129 extractor.HashToken("ó$"),
130 extractor.HashToken("^Hě"),
131 extractor.HashToken("Hěl"),
132 extractor.HashToken("ěll"),
133 extractor.HashToken("lló"),
134 extractor.HashToken("ló$")
135 // clang-format on
136 }));
137 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
138
139 sparse_features.clear();
140 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200141 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200142 &dense_features);
143
144 EXPECT_THAT(sparse_features,
145 testing::ElementsAreArray({
146 // clang-format off
147 extractor.HashToken("w"),
148 extractor.HashToken("o"),
149 extractor.HashToken("r"),
150 extractor.HashToken("l"),
151 extractor.HashToken("d"),
152 extractor.HashToken("!"),
153 extractor.HashToken("^w"),
154 extractor.HashToken("wo"),
155 extractor.HashToken("or"),
156 extractor.HashToken("rl"),
157 extractor.HashToken("ld"),
158 extractor.HashToken("d!"),
159 extractor.HashToken("!$"),
160 extractor.HashToken("^wo"),
161 extractor.HashToken("wor"),
162 extractor.HashToken("orl"),
163 extractor.HashToken("rld"),
164 extractor.HashToken("ld!"),
165 extractor.HashToken("d!$"),
166 // clang-format on
167 }));
168 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
169}
170
171TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
172 TokenFeatureExtractorOptions options;
173 options.num_buckets = 1000;
174 options.chargram_orders = std::vector<int>{1, 2};
175 options.extract_case_feature = true;
176 options.unicode_aware_features = true;
177 options.extract_selection_mask_feature = false;
178 TokenFeatureExtractor extractor(options);
179
180 std::vector<int> sparse_features;
181 std::vector<float> dense_features;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200182 extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200183 &dense_features);
184 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
185
186 sparse_features.clear();
187 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200188 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200189 &dense_features);
190 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
191
192 sparse_features.clear();
193 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200194 extractor.Extract(Token{"Ř", 23, 29}, false, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200195 &dense_features);
196 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
197
198 sparse_features.clear();
199 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200200 extractor.Extract(Token{"ř", 23, 29}, false, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200201 &dense_features);
202 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
203}
204
205TEST(TokenFeatureExtractorTest, DigitRemapping) {
206 TokenFeatureExtractorOptions options;
207 options.num_buckets = 1000;
208 options.chargram_orders = std::vector<int>{1, 2};
209 options.remap_digits = true;
210 options.unicode_aware_features = false;
211 TokenFeatureExtractor extractor(options);
212
213 std::vector<int> sparse_features;
214 std::vector<float> dense_features;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200215 extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200216 &dense_features);
217
218 std::vector<int> sparse_features2;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200219 extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200220 &dense_features);
221 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
222
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200223 extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200224 &dense_features);
225 EXPECT_THAT(sparse_features,
226 testing::Not(testing::ElementsAreArray(sparse_features2)));
227}
228
229TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) {
230 TokenFeatureExtractorOptions options;
231 options.num_buckets = 1000;
232 options.chargram_orders = std::vector<int>{1, 2};
233 options.remap_digits = true;
234 options.unicode_aware_features = true;
235 TokenFeatureExtractor extractor(options);
236
237 std::vector<int> sparse_features;
238 std::vector<float> dense_features;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200239 extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200240 &dense_features);
241
242 std::vector<int> sparse_features2;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200243 extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200244 &dense_features);
245 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
246
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200247 extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200248 &dense_features);
249 EXPECT_THAT(sparse_features,
250 testing::Not(testing::ElementsAreArray(sparse_features2)));
251}
252
Matt Sharifideb722d2017-04-24 13:30:47 +0200253TEST(TokenFeatureExtractorTest, LowercaseAscii) {
254 TokenFeatureExtractorOptions options;
255 options.num_buckets = 1000;
256 options.chargram_orders = std::vector<int>{1, 2};
257 options.lowercase_tokens = true;
258 options.unicode_aware_features = false;
259 TokenFeatureExtractor extractor(options);
260
261 std::vector<int> sparse_features;
262 std::vector<float> dense_features;
263 extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features,
264 &dense_features);
265
266 std::vector<int> sparse_features2;
267 extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2,
268 &dense_features);
269 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
270
271 extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2,
272 &dense_features);
273 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
274}
275
276TEST(TokenFeatureExtractorTest, LowercaseUnicode) {
277 TokenFeatureExtractorOptions options;
278 options.num_buckets = 1000;
279 options.chargram_orders = std::vector<int>{1, 2};
280 options.lowercase_tokens = true;
281 options.unicode_aware_features = true;
282 TokenFeatureExtractor extractor(options);
283
284 std::vector<int> sparse_features;
285 std::vector<float> dense_features;
286 extractor.Extract(Token{"ŘŘ", 0, 6}, true, &sparse_features, &dense_features);
287
288 std::vector<int> sparse_features2;
289 extractor.Extract(Token{"řř", 0, 6}, true, &sparse_features2,
290 &dense_features);
291 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
292}
293
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200294TEST(TokenFeatureExtractorTest, RegexFeatures) {
295 TokenFeatureExtractorOptions options;
296 options.num_buckets = 1000;
297 options.chargram_orders = std::vector<int>{1, 2};
298 options.remap_digits = false;
299 options.unicode_aware_features = false;
300 options.regexp_features.push_back("^[a-z]+$"); // all lower case.
301 options.regexp_features.push_back("^[0-9]+$"); // all digits.
302 TokenFeatureExtractor extractor(options);
303
304 std::vector<int> sparse_features;
305 std::vector<float> dense_features;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200306 extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200307 &dense_features);
308 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
309
310 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200311 extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200312 &dense_features);
313 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
314
315 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200316 extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200317 &dense_features);
318 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
319
320 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200321 extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200322 &dense_features);
323 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
324}
325
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200326TEST(TokenFeatureExtractorTest, ExtractTooLongWord) {
327 TokenFeatureExtractorOptions options;
328 options.num_buckets = 1000;
329 options.chargram_orders = std::vector<int>{22};
330 options.extract_case_feature = true;
331 options.unicode_aware_features = true;
332 options.extract_selection_mask_feature = true;
333 TestingTokenFeatureExtractor extractor(options);
334
335 // Test that this runs. ASAN should catch problems.
336 std::vector<int> sparse_features;
337 std::vector<float> dense_features;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200338 extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200339 &sparse_features, &dense_features);
340
341 EXPECT_THAT(sparse_features,
342 testing::ElementsAreArray({
343 // clang-format off
344 extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
345 extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
346 // clang-format on
347 }));
348}
349
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200350TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
351 TokenFeatureExtractorOptions options;
352 options.num_buckets = 1000;
353 options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
354 options.extract_case_feature = true;
355 options.unicode_aware_features = true;
356 options.extract_selection_mask_feature = true;
357 TestingTokenFeatureExtractor extractor_unicode(options);
358
359 options.unicode_aware_features = false;
360 TestingTokenFeatureExtractor extractor_ascii(options);
361
362 for (const std::string& input :
363 {"https://www.abcdefgh.com/in/xxxkkkvayio",
364 "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
365 "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd",
366 "x", "Hello", "Hey,", "Hi", ""}) {
367 std::vector<int> sparse_features_unicode;
368 std::vector<float> dense_features_unicode;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200369 extractor_unicode.Extract(Token{input, 0, 0}, true,
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200370 &sparse_features_unicode,
371 &dense_features_unicode);
372
373 std::vector<int> sparse_features_ascii;
374 std::vector<float> dense_features_ascii;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200375 extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200376 &dense_features_ascii);
377
378 EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input;
379 EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input;
380 }
381}
382
Matt Sharifid40f9762017-03-14 21:24:23 +0100383TEST(TokenFeatureExtractorTest, ExtractForPadToken) {
384 TokenFeatureExtractorOptions options;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200385 options.num_buckets = 1000;
Matt Sharifid40f9762017-03-14 21:24:23 +0100386 options.chargram_orders = std::vector<int>{1, 2};
387 options.extract_case_feature = true;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200388 options.unicode_aware_features = false;
Matt Sharifid40f9762017-03-14 21:24:23 +0100389 options.extract_selection_mask_feature = true;
390
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200391 TestingTokenFeatureExtractor extractor(options);
Matt Sharifid40f9762017-03-14 21:24:23 +0100392
393 std::vector<int> sparse_features;
394 std::vector<float> dense_features;
395
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200396 extractor.Extract(Token(), false, &sparse_features, &dense_features);
Matt Sharifid40f9762017-03-14 21:24:23 +0100397
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200398 EXPECT_THAT(sparse_features,
399 testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
Matt Sharifid40f9762017-03-14 21:24:23 +0100400 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
401}
402
403} // namespace
404} // namespace libtextclassifier