blob: ecdad160e00a3481230951429ab9ca7ca3b70125 [file] [log] [blame]
Matt Sharifid40f9762017-03-14 21:24:23 +01001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "smartselect/feature-processor.h"
18
19#include "gmock/gmock.h"
20#include "gtest/gtest.h"
21
22namespace libtextclassifier {
23namespace {
24
25using testing::ElementsAreArray;
26
27TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
28 std::vector<Token> tokens{Token("Hělló", 0, 5, false),
29 Token("fěěbař@google.com", 6, 23, false),
30 Token("heře!", 24, 29, false)};
31
32 internal::SplitTokensOnSelectionBoundaries({9, 12}, &tokens);
33
34 // clang-format off
35 EXPECT_THAT(tokens, ElementsAreArray(
36 {Token("Hělló", 0, 5, false),
37 Token("fěě", 6, 9, false),
38 Token("bař", 9, 12, false),
39 Token("@google.com", 12, 23, false),
40 Token("heře!", 24, 29, false)}));
41 // clang-format on
42}
43
44TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) {
45 std::vector<Token> tokens{Token("Hělló", 0, 5, false),
46 Token("fěěbař@google.com", 6, 23, false),
47 Token("heře!", 24, 29, false)};
48
49 internal::SplitTokensOnSelectionBoundaries({6, 12}, &tokens);
50
51 // clang-format off
52 EXPECT_THAT(tokens, ElementsAreArray(
53 {Token("Hělló", 0, 5, false),
54 Token("fěěbař", 6, 12, false),
55 Token("@google.com", 12, 23, false),
56 Token("heře!", 24, 29, false)}));
57 // clang-format on
58}
59
60TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) {
61 std::vector<Token> tokens{Token("Hělló", 0, 5, false),
62 Token("fěěbař@google.com", 6, 23, false),
63 Token("heře!", 24, 29, false)};
64
65 internal::SplitTokensOnSelectionBoundaries({9, 23}, &tokens);
66
67 // clang-format off
68 EXPECT_THAT(tokens, ElementsAreArray(
69 {Token("Hělló", 0, 5, false),
70 Token("fěě", 6, 9, false),
71 Token("bař@google.com", 9, 23, false),
72 Token("heře!", 24, 29, false)}));
73 // clang-format on
74}
75
76TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) {
77 std::vector<Token> tokens{Token("Hělló", 0, 5, false),
78 Token("fěěbař@google.com", 6, 23, false),
79 Token("heře!", 24, 29, false)};
80
81 internal::SplitTokensOnSelectionBoundaries({6, 23}, &tokens);
82
83 // clang-format off
84 EXPECT_THAT(tokens, ElementsAreArray(
85 {Token("Hělló", 0, 5, false),
86 Token("fěěbař@google.com", 6, 23, false),
87 Token("heře!", 24, 29, false)}));
88 // clang-format on
89}
90
91TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) {
92 std::vector<Token> tokens{Token("Hělló", 0, 5, false),
93 Token("fěěbař@google.com", 6, 23, false),
94 Token("heře!", 24, 29, false)};
95
96 internal::SplitTokensOnSelectionBoundaries({2, 9}, &tokens);
97
98 // clang-format off
99 EXPECT_THAT(tokens, ElementsAreArray(
100 {Token("Hě", 0, 2, false),
101 Token("lló", 2, 5, false),
102 Token("fěě", 6, 9, false),
103 Token("bař@google.com", 9, 23, false),
104 Token("heře!", 24, 29, false)}));
105 // clang-format on
106}
107
108TEST(FeatureProcessorTest, KeepLineWithClickFirst) {
109 SelectionWithContext selection;
110 selection.context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
111
112 // Keeps the first line.
113 selection.click_start = 0;
114 selection.click_end = 5;
115 selection.selection_start = 6;
116 selection.selection_end = 10;
117
118 SelectionWithContext line_selection;
119 int shift;
120 std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
121
122 EXPECT_EQ(line_selection.context, "Fiřst Lině");
123 EXPECT_EQ(line_selection.click_start, 0);
124 EXPECT_EQ(line_selection.click_end, 5);
125 EXPECT_EQ(line_selection.selection_start, 6);
126 EXPECT_EQ(line_selection.selection_end, 10);
127 EXPECT_EQ(shift, 0);
128}
129
130TEST(FeatureProcessorTest, KeepLineWithClickSecond) {
131 SelectionWithContext selection;
132 selection.context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
133
134 // Keeps the second line.
135 selection.click_start = 11;
136 selection.click_end = 17;
137 selection.selection_start = 18;
138 selection.selection_end = 22;
139
140 SelectionWithContext line_selection;
141 int shift;
142 std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
143
144 EXPECT_EQ(line_selection.context, "Sěcond Lině");
145 EXPECT_EQ(line_selection.click_start, 0);
146 EXPECT_EQ(line_selection.click_end, 6);
147 EXPECT_EQ(line_selection.selection_start, 7);
148 EXPECT_EQ(line_selection.selection_end, 11);
149 EXPECT_EQ(shift, 11);
150}
151
152TEST(FeatureProcessorTest, KeepLineWithClickThird) {
153 SelectionWithContext selection;
154 selection.context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
155
156 // Keeps the third line.
157 selection.click_start = 29;
158 selection.click_end = 33;
159 selection.selection_start = 23;
160 selection.selection_end = 28;
161
162 SelectionWithContext line_selection;
163 int shift;
164 std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
165
166 EXPECT_EQ(line_selection.context, "Thiřd Lině");
167 EXPECT_EQ(line_selection.click_start, 6);
168 EXPECT_EQ(line_selection.click_end, 10);
169 EXPECT_EQ(line_selection.selection_start, 0);
170 EXPECT_EQ(line_selection.selection_end, 5);
171 EXPECT_EQ(shift, 23);
172}
173
174TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
175 SelectionWithContext selection;
176 selection.context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
177
178 // Keeps the second line.
179 selection.click_start = 11;
180 selection.click_end = 17;
181 selection.selection_start = 18;
182 selection.selection_end = 22;
183
184 SelectionWithContext line_selection;
185 int shift;
186 std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
187
188 EXPECT_EQ(line_selection.context, "Sěcond Lině");
189 EXPECT_EQ(line_selection.click_start, 0);
190 EXPECT_EQ(line_selection.click_end, 6);
191 EXPECT_EQ(line_selection.selection_start, 7);
192 EXPECT_EQ(line_selection.selection_end, 11);
193 EXPECT_EQ(shift, 11);
194}
195
196TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) {
197 SelectionWithContext selection;
198 selection.context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
199
200 // Selects across lines, so KeepLine should not do any changes.
201 selection.click_start = 6;
202 selection.click_end = 17;
203 selection.selection_start = 0;
204 selection.selection_end = 22;
205
206 SelectionWithContext line_selection;
207 int shift;
208 std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
209
210 EXPECT_EQ(line_selection.context, "Fiřst Lině\nSěcond Lině\nThiřd Lině");
211 EXPECT_EQ(line_selection.click_start, 6);
212 EXPECT_EQ(line_selection.click_end, 17);
213 EXPECT_EQ(line_selection.selection_start, 0);
214 EXPECT_EQ(line_selection.selection_end, 22);
215 EXPECT_EQ(shift, 0);
216}
217
218TEST(FeatureProcessorTest, GetFeaturesWithContextDropout) {
219 FeatureProcessorOptions options;
220 options.set_num_buckets(10);
221 options.set_context_size(7);
222 options.set_max_selection_span(7);
223 options.add_chargram_orders(1);
224 options.set_tokenize_on_space(true);
225 options.set_context_dropout_probability(0.5);
226 options.set_use_variable_context_dropout(true);
227 TokenizationCodepointRange* config =
228 options.add_tokenization_codepoint_config();
229 config->set_start(32);
230 config->set_end(33);
231 config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
232 FeatureProcessor feature_processor(options);
233
234 SelectionWithContext selection_with_context;
235 selection_with_context.context = "1 2 3 c o n t e x t X c o n t e x t 1 2 3";
236 // Selection and click indices of the X in the middle:
237 selection_with_context.selection_start = 20;
238 selection_with_context.selection_end = 21;
239 selection_with_context.click_start = 20;
240 selection_with_context.click_end = 21;
241
242 // Test that two subsequent runs with random context dropout produce
243 // different features.
244 feature_processor.SetRandom(new std::mt19937);
245
246 std::vector<std::vector<std::pair<int, float>>> features;
247 std::vector<std::vector<std::pair<int, float>>> features2;
248 std::vector<float> extra_features;
249 std::vector<CodepointSpan> selection_label_spans;
250 int selection_label;
251 CodepointSpan selection_codepoint_label;
252 int classification_label;
253 EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
254 selection_with_context, &features, &extra_features,
255 &selection_label_spans, &selection_label, &selection_codepoint_label,
256 &classification_label));
257 EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
258 selection_with_context, &features2, &extra_features,
259 &selection_label_spans, &selection_label, &selection_codepoint_label,
260 &classification_label));
261
262 EXPECT_NE(features, features2);
263}
264
265TEST(FeatureProcessorTest, GetFeaturesWithLongerContext) {
266 FeatureProcessorOptions options;
267 options.set_num_buckets(10);
268 options.set_context_size(9);
269 options.set_max_selection_span(7);
270 options.add_chargram_orders(1);
271 options.set_tokenize_on_space(true);
272 TokenizationCodepointRange* config =
273 options.add_tokenization_codepoint_config();
274 config->set_start(32);
275 config->set_end(33);
276 config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
277 FeatureProcessor feature_processor(options);
278
279 SelectionWithContext selection_with_context;
280 selection_with_context.context = "1 2 3 c o n t e x t X c o n t e x t 1 2 3";
281 // Selection and click indices of the X in the middle:
282 selection_with_context.selection_start = 20;
283 selection_with_context.selection_end = 21;
284 selection_with_context.click_start = 20;
285 selection_with_context.click_end = 21;
286
287 std::vector<std::vector<std::pair<int, float>>> features;
288 std::vector<float> extra_features;
289 std::vector<CodepointSpan> selection_label_spans;
290 int selection_label;
291 CodepointSpan selection_codepoint_label;
292 int classification_label;
293 EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
294 selection_with_context, &features, &extra_features,
295 &selection_label_spans, &selection_label, &selection_codepoint_label,
296 &classification_label));
297 EXPECT_EQ(19, features.size());
298
299 // Should pad the string.
300 selection_with_context.context = "X";
301 selection_with_context.selection_start = 0;
302 selection_with_context.selection_end = 1;
303 selection_with_context.click_start = 0;
304 selection_with_context.click_end = 1;
305 EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
306 selection_with_context, &features, &extra_features,
307 &selection_label_spans, &selection_label, &selection_codepoint_label,
308 &classification_label));
309 EXPECT_EQ(19, features.size());
310}
311
312class TestingFeatureProcessor : public FeatureProcessor {
313 public:
314 using FeatureProcessor::FeatureProcessor;
315 using FeatureProcessor::FindTokensInSelection;
316};
317
318TEST(FeatureProcessorTest, FindTokensInSelectionSingleCharacter) {
319 FeatureProcessorOptions options;
320 options.set_num_buckets(10);
321 options.set_context_size(9);
322 options.set_max_selection_span(7);
323 options.add_chargram_orders(1);
324 options.set_tokenize_on_space(true);
325 TokenizationCodepointRange* config =
326 options.add_tokenization_codepoint_config();
327 config->set_start(32);
328 config->set_end(33);
329 config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
330 TestingFeatureProcessor feature_processor(options);
331
332 SelectionWithContext selection_with_context;
333 selection_with_context.context = "1 2 3 c o n t e x t X c o n t e x t 1 2 3";
334
335 // Selection and click indices of the X in the middle:
336 selection_with_context.selection_start = 20;
337 selection_with_context.selection_end = 21;
338 // clang-format off
339 EXPECT_THAT(feature_processor.FindTokensInSelection(
340 feature_processor.Tokenize(selection_with_context.context),
341 selection_with_context),
342 ElementsAreArray({Token("X", 20, 21, false)}));
343 // clang-format on
344}
345
346TEST(FeatureProcessorTest, FindTokensInSelectionInsideTokenBoundary) {
347 FeatureProcessorOptions options;
348 options.set_num_buckets(10);
349 options.set_context_size(9);
350 options.set_max_selection_span(7);
351 options.add_chargram_orders(1);
352 options.set_tokenize_on_space(true);
353 TokenizationCodepointRange* config =
354 options.add_tokenization_codepoint_config();
355 config->set_start(32);
356 config->set_end(33);
357 config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
358 TestingFeatureProcessor feature_processor(options);
359
360 SelectionWithContext selection_with_context;
361 selection_with_context.context = "I live at 350 Third Street, today.";
362
363 const std::vector<Token> expected_selection = {
364 // clang-format off
365 Token("350", 10, 13, false),
366 Token("Third", 14, 19, false),
367 Token("Street,", 20, 27, false),
368 // clang-format on
369 };
370
371 // Selection: I live at {350 Third Str}eet, today.
372 selection_with_context.selection_start = 10;
373 selection_with_context.selection_end = 23;
374 EXPECT_THAT(feature_processor.FindTokensInSelection(
375 feature_processor.Tokenize(selection_with_context.context),
376 selection_with_context),
377 ElementsAreArray(expected_selection));
378
379 // Selection: I live at {350 Third Street,} today.
380 selection_with_context.selection_start = 10;
381 selection_with_context.selection_end = 27;
382 EXPECT_THAT(feature_processor.FindTokensInSelection(
383 feature_processor.Tokenize(selection_with_context.context),
384 selection_with_context),
385 ElementsAreArray(expected_selection));
386
387 // Selection: I live at {350 Third Street, }today.
388 selection_with_context.selection_start = 10;
389 selection_with_context.selection_end = 28;
390 EXPECT_THAT(feature_processor.FindTokensInSelection(
391 feature_processor.Tokenize(selection_with_context.context),
392 selection_with_context),
393 ElementsAreArray(expected_selection));
394
395 // Selection: I live at {350 Third S}treet, today.
396 selection_with_context.selection_start = 10;
397 selection_with_context.selection_end = 21;
398 EXPECT_THAT(feature_processor.FindTokensInSelection(
399 feature_processor.Tokenize(selection_with_context.context),
400 selection_with_context),
401 ElementsAreArray(expected_selection));
402
403 // Test that when crossing the boundary, we select less/more.
404
405 // Selection: I live at {350 Third} Street, today.
406 selection_with_context.selection_start = 10;
407 selection_with_context.selection_end = 19;
408 EXPECT_THAT(feature_processor.FindTokensInSelection(
409 feature_processor.Tokenize(selection_with_context.context),
410 selection_with_context),
411 ElementsAreArray({
412 // clang-format off
413 Token("350", 10, 13, false),
414 Token("Third", 14, 19, false),
415 // clang-format on
416 }));
417
418 // Selection: I live at {350 Third Street, t}oday.
419 selection_with_context.selection_start = 10;
420 selection_with_context.selection_end = 29;
421 EXPECT_THAT(
422 feature_processor.FindTokensInSelection(
423 feature_processor.Tokenize(selection_with_context.context),
424 selection_with_context),
425 ElementsAreArray({
426 // clang-format off
427 Token("350", 10, 13, false),
428 Token("Third", 14, 19, false),
429 Token("Street,", 20, 27, false),
430 Token("today.", 28, 34, false),
431 // clang-format on
432 }));
433}
434
435} // namespace
436} // namespace libtextclassifier