/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
Lukas Zilka21d8c982018-01-24 11:11:20 +010017#include "feature-processor.h"
18
19#include "model-executor.h"
20#include "tensor-view.h"
Matt Sharifid40f9762017-03-14 21:24:23 +010021
22#include "gmock/gmock.h"
23#include "gtest/gtest.h"
24
Lukas Zilka21d8c982018-01-24 11:11:20 +010025namespace libtextclassifier2 {
Matt Sharifid40f9762017-03-14 21:24:23 +010026namespace {
27
28using testing::ElementsAreArray;
Lukas Zilka26e8c2e2017-04-06 15:54:24 +020029using testing::FloatEq;
Lukas Zilkaba849e72018-03-08 14:48:21 +010030using testing::Matcher;
Matt Sharifid40f9762017-03-14 21:24:23 +010031
Lukas Zilka21d8c982018-01-24 11:11:20 +010032flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
33 const FeatureProcessorOptionsT& options) {
34 flatbuffers::FlatBufferBuilder builder;
35 builder.Finish(CreateFeatureProcessorOptions(builder, &options));
36 return builder.Release();
37}
38
// Returns a copy of the elements of `vector` at indices [start, end).
//
// Bounds are clamped to the vector: a negative `start` is treated as 0, an
// `end` past the last element is treated as `vector.size()`, and an inverted
// range (`end` < `start`) yields an empty vector. Without the clamping, such
// arguments would form an invalid iterator range, which is undefined
// behaviour.
template <typename T>
std::vector<T> Subvector(const std::vector<T>& vector, int start, int end) {
  const int size = static_cast<int>(vector.size());
  if (start < 0) start = 0;
  if (start > size) start = size;
  if (end > size) end = size;
  if (end < start) end = start;
  return std::vector<T>(vector.begin() + start, vector.begin() + end);
}
43
44Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
45 std::vector<Matcher<float>> matchers;
46 for (const float value : values) {
47 matchers.push_back(FloatEq(value));
48 }
49 return ElementsAreArray(matchers);
50}
51
Lukas Zilka726b4d22017-12-13 16:37:03 +010052class TestingFeatureProcessor : public FeatureProcessor {
53 public:
54 using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
55 using FeatureProcessor::FeatureProcessor;
56 using FeatureProcessor::ICUTokenize;
57 using FeatureProcessor::IsCodepointInRanges;
58 using FeatureProcessor::SpanToLabel;
59 using FeatureProcessor::StripTokensFromOtherLines;
60 using FeatureProcessor::supported_codepoint_ranges_;
61 using FeatureProcessor::SupportedCodepointsRatio;
62};
63
Lukas Zilka21d8c982018-01-24 11:11:20 +010064// EmbeddingExecutor that always returns features based on
65class FakeEmbeddingExecutor : public EmbeddingExecutor {
66 public:
67 bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
Lukas Zilkaba849e72018-03-08 14:48:21 +010068 int dest_size) const override {
Lukas Zilka21d8c982018-01-24 11:11:20 +010069 TC_CHECK_GE(dest_size, 4);
70 EXPECT_EQ(sparse_features.size(), 1);
71 dest[0] = sparse_features.data()[0];
72 dest[1] = sparse_features.data()[0];
73 dest[2] = -sparse_features.data()[0];
74 dest[3] = -sparse_features.data()[0];
75 return true;
76 }
77
78 private:
79 std::vector<float> storage_;
80};
81
Matt Sharifid40f9762017-03-14 21:24:23 +010082TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
Lukas Zilka6bb39a82017-04-07 19:55:11 +020083 std::vector<Token> tokens{Token("Hělló", 0, 5),
84 Token("fěěbař@google.com", 6, 23),
85 Token("heře!", 24, 29)};
Matt Sharifid40f9762017-03-14 21:24:23 +010086
87 internal::SplitTokensOnSelectionBoundaries({9, 12}, &tokens);
88
89 // clang-format off
90 EXPECT_THAT(tokens, ElementsAreArray(
Lukas Zilka6bb39a82017-04-07 19:55:11 +020091 {Token("Hělló", 0, 5),
92 Token("fěě", 6, 9),
93 Token("bař", 9, 12),
94 Token("@google.com", 12, 23),
95 Token("heře!", 24, 29)}));
Matt Sharifid40f9762017-03-14 21:24:23 +010096 // clang-format on
97}
98
99TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) {
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200100 std::vector<Token> tokens{Token("Hělló", 0, 5),
101 Token("fěěbař@google.com", 6, 23),
102 Token("heře!", 24, 29)};
Matt Sharifid40f9762017-03-14 21:24:23 +0100103
104 internal::SplitTokensOnSelectionBoundaries({6, 12}, &tokens);
105
106 // clang-format off
107 EXPECT_THAT(tokens, ElementsAreArray(
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200108 {Token("Hělló", 0, 5),
109 Token("fěěbař", 6, 12),
110 Token("@google.com", 12, 23),
111 Token("heře!", 24, 29)}));
Matt Sharifid40f9762017-03-14 21:24:23 +0100112 // clang-format on
113}
114
115TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) {
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200116 std::vector<Token> tokens{Token("Hělló", 0, 5),
117 Token("fěěbař@google.com", 6, 23),
118 Token("heře!", 24, 29)};
Matt Sharifid40f9762017-03-14 21:24:23 +0100119
120 internal::SplitTokensOnSelectionBoundaries({9, 23}, &tokens);
121
122 // clang-format off
123 EXPECT_THAT(tokens, ElementsAreArray(
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200124 {Token("Hělló", 0, 5),
125 Token("fěě", 6, 9),
126 Token("bař@google.com", 9, 23),
127 Token("heře!", 24, 29)}));
Matt Sharifid40f9762017-03-14 21:24:23 +0100128 // clang-format on
129}
130
131TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) {
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200132 std::vector<Token> tokens{Token("Hělló", 0, 5),
133 Token("fěěbař@google.com", 6, 23),
134 Token("heře!", 24, 29)};
Matt Sharifid40f9762017-03-14 21:24:23 +0100135
136 internal::SplitTokensOnSelectionBoundaries({6, 23}, &tokens);
137
138 // clang-format off
139 EXPECT_THAT(tokens, ElementsAreArray(
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200140 {Token("Hělló", 0, 5),
141 Token("fěěbař@google.com", 6, 23),
142 Token("heře!", 24, 29)}));
Matt Sharifid40f9762017-03-14 21:24:23 +0100143 // clang-format on
144}
145
146TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) {
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200147 std::vector<Token> tokens{Token("Hělló", 0, 5),
148 Token("fěěbař@google.com", 6, 23),
149 Token("heře!", 24, 29)};
Matt Sharifid40f9762017-03-14 21:24:23 +0100150
151 internal::SplitTokensOnSelectionBoundaries({2, 9}, &tokens);
152
153 // clang-format off
154 EXPECT_THAT(tokens, ElementsAreArray(
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200155 {Token("Hě", 0, 2),
156 Token("lló", 2, 5),
157 Token("fěě", 6, 9),
158 Token("bař@google.com", 9, 23),
159 Token("heře!", 24, 29)}));
Matt Sharifid40f9762017-03-14 21:24:23 +0100160 // clang-format on
161}
162
163TEST(FeatureProcessorTest, KeepLineWithClickFirst) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100164 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100165 FeatureProcessorOptionsT options;
166 options.only_use_line_with_click = true;
167 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
168 TestingFeatureProcessor feature_processor(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100169 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
170 &unilib);
Lukas Zilka726b4d22017-12-13 16:37:03 +0100171
Matt Sharifibe876dc2017-03-17 17:02:43 +0100172 const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
173 const CodepointSpan span = {0, 5};
174 // clang-format off
175 std::vector<Token> tokens = {Token("Fiřst", 0, 5),
176 Token("Lině", 6, 10),
177 Token("Sěcond", 11, 17),
178 Token("Lině", 18, 22),
179 Token("Thiřd", 23, 28),
180 Token("Lině", 29, 33)};
181 // clang-format on
Matt Sharifid40f9762017-03-14 21:24:23 +0100182
183 // Keeps the first line.
Lukas Zilka726b4d22017-12-13 16:37:03 +0100184 feature_processor.StripTokensFromOtherLines(context, span, &tokens);
Matt Sharifibe876dc2017-03-17 17:02:43 +0100185 EXPECT_THAT(tokens,
186 ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)}));
Matt Sharifid40f9762017-03-14 21:24:23 +0100187}
188
189TEST(FeatureProcessorTest, KeepLineWithClickSecond) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100190 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100191 FeatureProcessorOptionsT options;
192 options.only_use_line_with_click = true;
193 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
194 TestingFeatureProcessor feature_processor(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100195 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
196 &unilib);
Lukas Zilka726b4d22017-12-13 16:37:03 +0100197
Matt Sharifibe876dc2017-03-17 17:02:43 +0100198 const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
199 const CodepointSpan span = {18, 22};
200 // clang-format off
201 std::vector<Token> tokens = {Token("Fiřst", 0, 5),
202 Token("Lině", 6, 10),
203 Token("Sěcond", 11, 17),
204 Token("Lině", 18, 22),
205 Token("Thiřd", 23, 28),
206 Token("Lině", 29, 33)};
207 // clang-format on
Matt Sharifid40f9762017-03-14 21:24:23 +0100208
Matt Sharifibe876dc2017-03-17 17:02:43 +0100209 // Keeps the first line.
Lukas Zilka726b4d22017-12-13 16:37:03 +0100210 feature_processor.StripTokensFromOtherLines(context, span, &tokens);
Matt Sharifibe876dc2017-03-17 17:02:43 +0100211 EXPECT_THAT(tokens, ElementsAreArray(
212 {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
Matt Sharifid40f9762017-03-14 21:24:23 +0100213}
214
215TEST(FeatureProcessorTest, KeepLineWithClickThird) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100216 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100217 FeatureProcessorOptionsT options;
218 options.only_use_line_with_click = true;
219 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
220 TestingFeatureProcessor feature_processor(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100221 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
222 &unilib);
Lukas Zilka726b4d22017-12-13 16:37:03 +0100223
Matt Sharifibe876dc2017-03-17 17:02:43 +0100224 const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
225 const CodepointSpan span = {24, 33};
226 // clang-format off
227 std::vector<Token> tokens = {Token("Fiřst", 0, 5),
228 Token("Lině", 6, 10),
229 Token("Sěcond", 11, 17),
230 Token("Lině", 18, 22),
231 Token("Thiřd", 23, 28),
232 Token("Lině", 29, 33)};
233 // clang-format on
Matt Sharifid40f9762017-03-14 21:24:23 +0100234
Matt Sharifibe876dc2017-03-17 17:02:43 +0100235 // Keeps the first line.
Lukas Zilka726b4d22017-12-13 16:37:03 +0100236 feature_processor.StripTokensFromOtherLines(context, span, &tokens);
Matt Sharifibe876dc2017-03-17 17:02:43 +0100237 EXPECT_THAT(tokens, ElementsAreArray(
238 {Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
Matt Sharifid40f9762017-03-14 21:24:23 +0100239}
240
241TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100242 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100243 FeatureProcessorOptionsT options;
244 options.only_use_line_with_click = true;
245 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
246 TestingFeatureProcessor feature_processor(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100247 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
248 &unilib);
Lukas Zilka726b4d22017-12-13 16:37:03 +0100249
Matt Sharifibe876dc2017-03-17 17:02:43 +0100250 const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
251 const CodepointSpan span = {18, 22};
252 // clang-format off
253 std::vector<Token> tokens = {Token("Fiřst", 0, 5),
254 Token("Lině", 6, 10),
255 Token("Sěcond", 11, 17),
256 Token("Lině", 18, 22),
257 Token("Thiřd", 23, 28),
258 Token("Lině", 29, 33)};
259 // clang-format on
Matt Sharifid40f9762017-03-14 21:24:23 +0100260
Matt Sharifibe876dc2017-03-17 17:02:43 +0100261 // Keeps the first line.
Lukas Zilka726b4d22017-12-13 16:37:03 +0100262 feature_processor.StripTokensFromOtherLines(context, span, &tokens);
Matt Sharifibe876dc2017-03-17 17:02:43 +0100263 EXPECT_THAT(tokens, ElementsAreArray(
264 {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
Matt Sharifid40f9762017-03-14 21:24:23 +0100265}
266
267TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100268 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100269 FeatureProcessorOptionsT options;
270 options.only_use_line_with_click = true;
271 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
272 TestingFeatureProcessor feature_processor(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100273 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
274 &unilib);
Lukas Zilka726b4d22017-12-13 16:37:03 +0100275
Matt Sharifibe876dc2017-03-17 17:02:43 +0100276 const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
277 const CodepointSpan span = {5, 23};
278 // clang-format off
279 std::vector<Token> tokens = {Token("Fiřst", 0, 5),
280 Token("Lině", 6, 10),
281 Token("Sěcond", 18, 23),
282 Token("Lině", 19, 23),
283 Token("Thiřd", 23, 28),
284 Token("Lině", 29, 33)};
285 // clang-format on
Matt Sharifid40f9762017-03-14 21:24:23 +0100286
Matt Sharifibe876dc2017-03-17 17:02:43 +0100287 // Keeps the first line.
Lukas Zilka726b4d22017-12-13 16:37:03 +0100288 feature_processor.StripTokensFromOtherLines(context, span, &tokens);
Matt Sharifibe876dc2017-03-17 17:02:43 +0100289 EXPECT_THAT(tokens, ElementsAreArray(
290 {Token("Fiřst", 0, 5), Token("Lině", 6, 10),
291 Token("Sěcond", 18, 23), Token("Lině", 19, 23),
292 Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
Matt Sharifid40f9762017-03-14 21:24:23 +0100293}
294
Matt Sharifi0d68ef92017-03-27 14:20:21 +0200295TEST(FeatureProcessorTest, SpanToLabel) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100296 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100297 FeatureProcessorOptionsT options;
298 options.context_size = 1;
299 options.max_selection_span = 1;
300 options.snap_label_span_boundaries_to_containing_tokens = false;
Matt Sharifi0d68ef92017-03-27 14:20:21 +0200301
Lukas Zilka21d8c982018-01-24 11:11:20 +0100302 options.tokenization_codepoint_config.emplace_back(
303 new TokenizationCodepointRangeT());
304 auto& config = options.tokenization_codepoint_config.back();
305 config->start = 32;
306 config->end = 33;
307 config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
Matt Sharifi0d68ef92017-03-27 14:20:21 +0200308
Lukas Zilka21d8c982018-01-24 11:11:20 +0100309 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
310 TestingFeatureProcessor feature_processor(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100311 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
312 &unilib);
Matt Sharifi0d68ef92017-03-27 14:20:21 +0200313 std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
314 ASSERT_EQ(3, tokens.size());
315 int label;
316 ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
317 EXPECT_EQ(kInvalidLabel, label);
318 ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
319 EXPECT_NE(kInvalidLabel, label);
320 TokenSpan token_span;
321 feature_processor.LabelToTokenSpan(label, &token_span);
322 EXPECT_EQ(0, token_span.first);
323 EXPECT_EQ(0, token_span.second);
324
325 // Reconfigure with snapping enabled.
Lukas Zilka21d8c982018-01-24 11:11:20 +0100326 options.snap_label_span_boundaries_to_containing_tokens = true;
327 flatbuffers::DetachedBuffer options2_fb =
328 PackFeatureProcessorOptions(options);
329 TestingFeatureProcessor feature_processor2(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100330 flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
331 &unilib);
Matt Sharifi0d68ef92017-03-27 14:20:21 +0200332 int label2;
333 ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
334 EXPECT_EQ(label, label2);
335 ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
336 EXPECT_EQ(label, label2);
337 ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
338 EXPECT_EQ(label, label2);
339
340 // Cross a token boundary.
341 ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
342 EXPECT_EQ(kInvalidLabel, label2);
343 ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
344 EXPECT_EQ(kInvalidLabel, label2);
345
346 // Multiple tokens.
Lukas Zilka21d8c982018-01-24 11:11:20 +0100347 options.context_size = 2;
348 options.max_selection_span = 2;
349 flatbuffers::DetachedBuffer options3_fb =
350 PackFeatureProcessorOptions(options);
351 TestingFeatureProcessor feature_processor3(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100352 flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
353 &unilib);
Matt Sharifi0d68ef92017-03-27 14:20:21 +0200354 tokens = feature_processor3.Tokenize("zero, one, two, three, four");
355 ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
356 EXPECT_NE(kInvalidLabel, label2);
357 feature_processor3.LabelToTokenSpan(label2, &token_span);
358 EXPECT_EQ(1, token_span.first);
359 EXPECT_EQ(0, token_span.second);
360
361 int label3;
362 ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
363 EXPECT_EQ(label2, label3);
364 ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
365 EXPECT_EQ(label2, label3);
366 ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
367 EXPECT_EQ(label2, label3);
368}
369
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200370TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100371 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100372 FeatureProcessorOptionsT options;
373 options.context_size = 1;
374 options.max_selection_span = 1;
375 options.snap_label_span_boundaries_to_containing_tokens = false;
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200376
Lukas Zilka21d8c982018-01-24 11:11:20 +0100377 options.tokenization_codepoint_config.emplace_back(
378 new TokenizationCodepointRangeT());
379 auto& config = options.tokenization_codepoint_config.back();
380 config->start = 32;
381 config->end = 33;
382 config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200383
Lukas Zilka21d8c982018-01-24 11:11:20 +0100384 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
385 TestingFeatureProcessor feature_processor(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100386 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
387 &unilib);
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200388 std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
389 ASSERT_EQ(3, tokens.size());
390 int label;
391 ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
392 EXPECT_EQ(kInvalidLabel, label);
393 ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
394 EXPECT_NE(kInvalidLabel, label);
395 TokenSpan token_span;
396 feature_processor.LabelToTokenSpan(label, &token_span);
397 EXPECT_EQ(0, token_span.first);
398 EXPECT_EQ(0, token_span.second);
399
400 // Reconfigure with snapping enabled.
Lukas Zilka21d8c982018-01-24 11:11:20 +0100401 options.snap_label_span_boundaries_to_containing_tokens = true;
402 flatbuffers::DetachedBuffer options2_fb =
403 PackFeatureProcessorOptions(options);
404 TestingFeatureProcessor feature_processor2(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100405 flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
406 &unilib);
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200407 int label2;
408 ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
409 EXPECT_EQ(label, label2);
410 ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
411 EXPECT_EQ(label, label2);
412 ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
413 EXPECT_EQ(label, label2);
414
415 // Cross a token boundary.
416 ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
417 EXPECT_EQ(kInvalidLabel, label2);
418 ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
419 EXPECT_EQ(kInvalidLabel, label2);
420
421 // Multiple tokens.
Lukas Zilka21d8c982018-01-24 11:11:20 +0100422 options.context_size = 2;
423 options.max_selection_span = 2;
424 flatbuffers::DetachedBuffer options3_fb =
425 PackFeatureProcessorOptions(options);
426 TestingFeatureProcessor feature_processor3(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100427 flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
428 &unilib);
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200429 tokens = feature_processor3.Tokenize("zero, one, two, three, four");
430 ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
431 EXPECT_NE(kInvalidLabel, label2);
432 feature_processor3.LabelToTokenSpan(label2, &token_span);
433 EXPECT_EQ(1, token_span.first);
434 EXPECT_EQ(0, token_span.second);
435
436 int label3;
437 ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
438 EXPECT_EQ(label2, label3);
439 ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
440 EXPECT_EQ(label2, label3);
441 ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
442 EXPECT_EQ(label2, label3);
443}
444
Matt Sharifibe876dc2017-03-17 17:02:43 +0100445TEST(FeatureProcessorTest, CenterTokenFromClick) {
446 int token_index;
447
448 // Exactly aligned indices.
449 token_index = internal::CenterTokenFromClick(
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200450 {6, 11},
451 {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
Matt Sharifibe876dc2017-03-17 17:02:43 +0100452 EXPECT_EQ(token_index, 1);
453
454 // Click is contained in a token.
455 token_index = internal::CenterTokenFromClick(
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200456 {13, 17},
457 {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
Matt Sharifibe876dc2017-03-17 17:02:43 +0100458 EXPECT_EQ(token_index, 2);
459
460 // Click spans two tokens.
461 token_index = internal::CenterTokenFromClick(
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200462 {6, 17},
463 {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
Matt Sharifibe876dc2017-03-17 17:02:43 +0100464 EXPECT_EQ(token_index, kInvalidIndex);
465}
466
467TEST(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) {
Matt Sharifibe876dc2017-03-17 17:02:43 +0100468 int token_index;
469
470 // Selection of length 3. Exactly aligned indices.
471 token_index = internal::CenterTokenFromMiddleOfSelection(
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200472 {7, 27},
473 {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
474 Token("Token4", 21, 27), Token("Token5", 28, 34)});
Matt Sharifibe876dc2017-03-17 17:02:43 +0100475 EXPECT_EQ(token_index, 2);
476
477 // Selection of length 1 token. Exactly aligned indices.
478 token_index = internal::CenterTokenFromMiddleOfSelection(
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200479 {21, 27},
480 {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
481 Token("Token4", 21, 27), Token("Token5", 28, 34)});
Matt Sharifibe876dc2017-03-17 17:02:43 +0100482 EXPECT_EQ(token_index, 3);
483
484 // Selection marks sub-token range, with no tokens in it.
485 token_index = internal::CenterTokenFromMiddleOfSelection(
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200486 {29, 33},
487 {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
488 Token("Token4", 21, 27), Token("Token5", 28, 34)});
Matt Sharifibe876dc2017-03-17 17:02:43 +0100489 EXPECT_EQ(token_index, kInvalidIndex);
490
491 // Selection of length 2. Sub-token indices.
492 token_index = internal::CenterTokenFromMiddleOfSelection(
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200493 {3, 25},
494 {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
495 Token("Token4", 21, 27), Token("Token5", 28, 34)});
Matt Sharifibe876dc2017-03-17 17:02:43 +0100496 EXPECT_EQ(token_index, 1);
497
498 // Selection of length 1. Sub-token indices.
499 token_index = internal::CenterTokenFromMiddleOfSelection(
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200500 {22, 34},
501 {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
502 Token("Token4", 21, 27), Token("Token5", 28, 34)});
Matt Sharifibe876dc2017-03-17 17:02:43 +0100503 EXPECT_EQ(token_index, 4);
Alex Salcianu9087f1f2017-03-22 21:22:39 -0400504
505 // Some invalid ones.
506 token_index = internal::CenterTokenFromMiddleOfSelection({7, 27}, {});
507 EXPECT_EQ(token_index, -1);
508}
509
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200510TEST(FeatureProcessorTest, SupportedCodepointsRatio) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100511 FeatureProcessorOptionsT options;
512 options.context_size = 2;
513 options.max_selection_span = 2;
514 options.snap_label_span_boundaries_to_containing_tokens = false;
515 options.feature_version = 2;
516 options.embedding_size = 4;
517 options.bounds_sensitive_features.reset(
518 new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
519 options.bounds_sensitive_features->enabled = true;
520 options.bounds_sensitive_features->num_tokens_before = 5;
521 options.bounds_sensitive_features->num_tokens_inside_left = 3;
522 options.bounds_sensitive_features->num_tokens_inside_right = 3;
523 options.bounds_sensitive_features->num_tokens_after = 5;
524 options.bounds_sensitive_features->include_inside_bag = true;
525 options.bounds_sensitive_features->include_inside_length = true;
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200526
Lukas Zilka21d8c982018-01-24 11:11:20 +0100527 options.tokenization_codepoint_config.emplace_back(
528 new TokenizationCodepointRangeT());
529 auto& config = options.tokenization_codepoint_config.back();
530 config->start = 32;
531 config->end = 33;
532 config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200533
Lukas Zilka21d8c982018-01-24 11:11:20 +0100534 {
535 options.supported_codepoint_ranges.emplace_back(
536 new FeatureProcessorOptions_::CodepointRangeT());
537 auto& range = options.supported_codepoint_ranges.back();
538 range->start = 0;
539 range->end = 128;
540 }
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200541
Lukas Zilka21d8c982018-01-24 11:11:20 +0100542 {
543 options.supported_codepoint_ranges.emplace_back(
544 new FeatureProcessorOptions_::CodepointRangeT());
545 auto& range = options.supported_codepoint_ranges.back();
546 range->start = 10000;
547 range->end = 10001;
548 }
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200549
Lukas Zilka21d8c982018-01-24 11:11:20 +0100550 {
551 options.supported_codepoint_ranges.emplace_back(
552 new FeatureProcessorOptions_::CodepointRangeT());
553 auto& range = options.supported_codepoint_ranges.back();
554 range->start = 20000;
555 range->end = 30000;
556 }
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200557
Lukas Zilka21d8c982018-01-24 11:11:20 +0100558 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100559 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100560 TestingFeatureProcessor feature_processor(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100561 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
562 &unilib);
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200563 EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
Lukas Zilka21d8c982018-01-24 11:11:20 +0100564 {0, 3}, feature_processor.Tokenize("aaa bbb ccc")),
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200565 FloatEq(1.0));
566 EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
Lukas Zilka21d8c982018-01-24 11:11:20 +0100567 {0, 3}, feature_processor.Tokenize("aaa bbb ěěě")),
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200568 FloatEq(2.0 / 3));
569 EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
Lukas Zilka21d8c982018-01-24 11:11:20 +0100570 {0, 3}, feature_processor.Tokenize("ěěě řřř ěěě")),
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200571 FloatEq(0.0));
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200572 EXPECT_FALSE(feature_processor.IsCodepointInRanges(
573 -1, feature_processor.supported_codepoint_ranges_));
574 EXPECT_TRUE(feature_processor.IsCodepointInRanges(
575 0, feature_processor.supported_codepoint_ranges_));
576 EXPECT_TRUE(feature_processor.IsCodepointInRanges(
577 10, feature_processor.supported_codepoint_ranges_));
578 EXPECT_TRUE(feature_processor.IsCodepointInRanges(
579 127, feature_processor.supported_codepoint_ranges_));
580 EXPECT_FALSE(feature_processor.IsCodepointInRanges(
581 128, feature_processor.supported_codepoint_ranges_));
582 EXPECT_FALSE(feature_processor.IsCodepointInRanges(
583 9999, feature_processor.supported_codepoint_ranges_));
584 EXPECT_TRUE(feature_processor.IsCodepointInRanges(
585 10000, feature_processor.supported_codepoint_ranges_));
586 EXPECT_FALSE(feature_processor.IsCodepointInRanges(
587 10001, feature_processor.supported_codepoint_ranges_));
588 EXPECT_TRUE(feature_processor.IsCodepointInRanges(
589 25000, feature_processor.supported_codepoint_ranges_));
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200590
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200591 std::unique_ptr<CachedFeatures> cached_features;
592
Lukas Zilka21d8c982018-01-24 11:11:20 +0100593 FakeEmbeddingExecutor embedding_executor;
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200594
Lukas Zilka21d8c982018-01-24 11:11:20 +0100595 const std::vector<Token> tokens = {Token("ěěě", 0, 3), Token("řřř", 4, 7),
596 Token("eee", 8, 11)};
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200597
Lukas Zilka21d8c982018-01-24 11:11:20 +0100598 options.min_supported_codepoint_ratio = 0.0;
599 flatbuffers::DetachedBuffer options2_fb =
600 PackFeatureProcessorOptions(options);
601 TestingFeatureProcessor feature_processor2(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100602 flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
603 &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100604 EXPECT_TRUE(feature_processor2.ExtractFeatures(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100605 tokens, /*token_span=*/{0, 3},
606 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
Lukas Zilkaba849e72018-03-08 14:48:21 +0100607 &embedding_executor, /*embedding_cache=*/nullptr,
Lukas Zilka21d8c982018-01-24 11:11:20 +0100608 /*feature_vector_size=*/4, &cached_features));
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200609
Lukas Zilka21d8c982018-01-24 11:11:20 +0100610 options.min_supported_codepoint_ratio = 0.2;
611 flatbuffers::DetachedBuffer options3_fb =
612 PackFeatureProcessorOptions(options);
613 TestingFeatureProcessor feature_processor3(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100614 flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
615 &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100616 EXPECT_TRUE(feature_processor3.ExtractFeatures(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100617 tokens, /*token_span=*/{0, 3},
618 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
Lukas Zilkaba849e72018-03-08 14:48:21 +0100619 &embedding_executor, /*embedding_cache=*/nullptr,
Lukas Zilka21d8c982018-01-24 11:11:20 +0100620 /*feature_vector_size=*/4, &cached_features));
621
622 options.min_supported_codepoint_ratio = 0.5;
623 flatbuffers::DetachedBuffer options4_fb =
624 PackFeatureProcessorOptions(options);
625 TestingFeatureProcessor feature_processor4(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100626 flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()),
627 &unilib);
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200628 EXPECT_FALSE(feature_processor4.ExtractFeatures(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100629 tokens, /*token_span=*/{0, 3},
630 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
Lukas Zilkaba849e72018-03-08 14:48:21 +0100631 &embedding_executor, /*embedding_cache=*/nullptr,
Lukas Zilka21d8c982018-01-24 11:11:20 +0100632 /*feature_vector_size=*/4, &cached_features));
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200633}
634
Lukas Zilkab23e2122018-02-09 10:25:19 +0100635TEST(FeatureProcessorTest, InSpanFeature) {
636 FeatureProcessorOptionsT options;
637 options.context_size = 2;
638 options.max_selection_span = 2;
639 options.snap_label_span_boundaries_to_containing_tokens = false;
640 options.feature_version = 2;
641 options.embedding_size = 4;
642 options.extract_selection_mask_feature = true;
643
644 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100645 CREATE_UNILIB_FOR_TESTING;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100646 TestingFeatureProcessor feature_processor(
647 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
648 &unilib);
649
650 std::unique_ptr<CachedFeatures> cached_features;
651
652 FakeEmbeddingExecutor embedding_executor;
653
654 const std::vector<Token> tokens = {Token("aaa", 0, 3), Token("bbb", 4, 7),
655 Token("ccc", 8, 11), Token("ddd", 12, 15)};
656
657 EXPECT_TRUE(feature_processor.ExtractFeatures(
658 tokens, /*token_span=*/{0, 4},
659 /*selection_span_for_feature=*/{4, 11}, &embedding_executor,
Lukas Zilkaba849e72018-03-08 14:48:21 +0100660 /*embedding_cache=*/nullptr, /*feature_vector_size=*/5,
661 &cached_features));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100662 std::vector<float> features;
663 cached_features->AppendClickContextFeaturesForClick(1, &features);
664 ASSERT_EQ(features.size(), 25);
665 EXPECT_THAT(features[4], FloatEq(0.0));
666 EXPECT_THAT(features[9], FloatEq(0.0));
667 EXPECT_THAT(features[14], FloatEq(1.0));
668 EXPECT_THAT(features[19], FloatEq(1.0));
669 EXPECT_THAT(features[24], FloatEq(0.0));
670}
671
Lukas Zilkaba849e72018-03-08 14:48:21 +0100672TEST(FeatureProcessorTest, EmbeddingCache) {
673 FeatureProcessorOptionsT options;
674 options.context_size = 2;
675 options.max_selection_span = 2;
676 options.snap_label_span_boundaries_to_containing_tokens = false;
677 options.feature_version = 2;
678 options.embedding_size = 4;
679 options.bounds_sensitive_features.reset(
680 new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
681 options.bounds_sensitive_features->enabled = true;
682 options.bounds_sensitive_features->num_tokens_before = 3;
683 options.bounds_sensitive_features->num_tokens_inside_left = 2;
684 options.bounds_sensitive_features->num_tokens_inside_right = 2;
685 options.bounds_sensitive_features->num_tokens_after = 3;
686
687 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
688 CREATE_UNILIB_FOR_TESTING;
689 TestingFeatureProcessor feature_processor(
690 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
691 &unilib);
692
693 std::unique_ptr<CachedFeatures> cached_features;
694
695 FakeEmbeddingExecutor embedding_executor;
696
697 const std::vector<Token> tokens = {
698 Token("aaa", 0, 3), Token("bbb", 4, 7), Token("ccc", 8, 11),
699 Token("ddd", 12, 15), Token("eee", 16, 19), Token("fff", 20, 23)};
700
701 // We pre-populate the cache with dummy embeddings, to make sure they are
702 // used when populating the features vector.
703 const std::vector<float> cached_padding_features = {10.0, -10.0, 10.0, -10.0};
704 const std::vector<float> cached_features1 = {1.0, 2.0, 3.0, 4.0};
705 const std::vector<float> cached_features2 = {5.0, 6.0, 7.0, 8.0};
706 FeatureProcessor::EmbeddingCache embedding_cache = {
707 {{kInvalidIndex, kInvalidIndex}, cached_padding_features},
708 {{4, 7}, cached_features1},
709 {{12, 15}, cached_features2},
710 };
711
712 EXPECT_TRUE(feature_processor.ExtractFeatures(
713 tokens, /*token_span=*/{0, 6},
714 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
715 &embedding_executor, &embedding_cache, /*feature_vector_size=*/4,
716 &cached_features));
717 std::vector<float> features;
718 cached_features->AppendBoundsSensitiveFeaturesForSpan({2, 4}, &features);
719 ASSERT_EQ(features.size(), 40);
720 // Check that the dummy embeddings were used.
721 EXPECT_THAT(Subvector(features, 0, 4),
722 ElementsAreFloat(cached_padding_features));
723 EXPECT_THAT(Subvector(features, 8, 12), ElementsAreFloat(cached_features1));
724 EXPECT_THAT(Subvector(features, 16, 20), ElementsAreFloat(cached_features2));
725 EXPECT_THAT(Subvector(features, 24, 28), ElementsAreFloat(cached_features2));
726 EXPECT_THAT(Subvector(features, 36, 40),
727 ElementsAreFloat(cached_padding_features));
728 // Check that the real embeddings were cached.
729 EXPECT_EQ(embedding_cache.size(), 7);
730 EXPECT_THAT(Subvector(features, 4, 8),
731 ElementsAreFloat(embedding_cache.at({0, 3})));
732 EXPECT_THAT(Subvector(features, 12, 16),
733 ElementsAreFloat(embedding_cache.at({8, 11})));
734 EXPECT_THAT(Subvector(features, 20, 24),
735 ElementsAreFloat(embedding_cache.at({8, 11})));
736 EXPECT_THAT(Subvector(features, 28, 32),
737 ElementsAreFloat(embedding_cache.at({16, 19})));
738 EXPECT_THAT(Subvector(features, 32, 36),
739 ElementsAreFloat(embedding_cache.at({20, 23})));
740}
741
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200742TEST(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {
743 std::vector<Token> tokens_orig{
744 Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
745 Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
746 Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
747 Token("12", 0, 0)};
748
749 std::vector<Token> tokens;
750 int click_index;
751
752 // Try to click first token and see if it gets padded from left.
753 tokens = tokens_orig;
754 click_index = 0;
755 internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
756 // clang-format off
757 EXPECT_EQ(tokens, std::vector<Token>({Token(),
758 Token(),
759 Token("0", 0, 0),
760 Token("1", 0, 0),
761 Token("2", 0, 0)}));
762 // clang-format on
763 EXPECT_EQ(click_index, 2);
764
765 // When we click the second token nothing should get padded.
766 tokens = tokens_orig;
767 click_index = 2;
768 internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
769 // clang-format off
770 EXPECT_EQ(tokens, std::vector<Token>({Token("0", 0, 0),
771 Token("1", 0, 0),
772 Token("2", 0, 0),
773 Token("3", 0, 0),
774 Token("4", 0, 0)}));
775 // clang-format on
776 EXPECT_EQ(click_index, 2);
777
778 // When we click the last token tokens should get padded from the right.
779 tokens = tokens_orig;
780 click_index = 12;
781 internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
782 // clang-format off
783 EXPECT_EQ(tokens, std::vector<Token>({Token("10", 0, 0),
784 Token("11", 0, 0),
785 Token("12", 0, 0),
786 Token(),
787 Token()}));
788 // clang-format on
789 EXPECT_EQ(click_index, 2);
790}
791
792TEST(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) {
793 std::vector<Token> tokens_orig{
794 Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
795 Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
796 Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
797 Token("12", 0, 0)};
798
799 std::vector<Token> tokens;
800 int click_index;
801
802 // Try to click first token and see if it gets padded from left to maximum
803 // context_size.
804 tokens = tokens_orig;
805 click_index = 0;
806 internal::StripOrPadTokens({2, 3}, 2, &tokens, &click_index);
807 // clang-format off
808 EXPECT_EQ(tokens, std::vector<Token>({Token(),
809 Token(),
810 Token("0", 0, 0),
811 Token("1", 0, 0),
812 Token("2", 0, 0),
813 Token("3", 0, 0),
814 Token("4", 0, 0),
815 Token("5", 0, 0)}));
816 // clang-format on
817 EXPECT_EQ(click_index, 2);
818
819 // Clicking to the middle with enough context should not produce any padding.
820 tokens = tokens_orig;
821 click_index = 6;
822 internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
823 // clang-format off
824 EXPECT_EQ(tokens, std::vector<Token>({Token("1", 0, 0),
825 Token("2", 0, 0),
826 Token("3", 0, 0),
827 Token("4", 0, 0),
828 Token("5", 0, 0),
829 Token("6", 0, 0),
830 Token("7", 0, 0),
831 Token("8", 0, 0),
832 Token("9", 0, 0)}));
833 // clang-format on
834 EXPECT_EQ(click_index, 5);
835
836 // Clicking at the end should pad right to maximum context_size.
837 tokens = tokens_orig;
838 click_index = 11;
839 internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
840 // clang-format off
841 EXPECT_EQ(tokens, std::vector<Token>({Token("6", 0, 0),
842 Token("7", 0, 0),
843 Token("8", 0, 0),
844 Token("9", 0, 0),
845 Token("10", 0, 0),
846 Token("11", 0, 0),
847 Token("12", 0, 0),
848 Token(),
849 Token()}));
850 // clang-format on
851 EXPECT_EQ(click_index, 5);
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200852}
853
Lukas Zilka21d8c982018-01-24 11:11:20 +0100854TEST(FeatureProcessorTest, InternalTokenizeOnScriptChange) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100855 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100856 FeatureProcessorOptionsT options;
857 options.tokenization_codepoint_config.emplace_back(
858 new TokenizationCodepointRangeT());
859 {
860 auto& config = options.tokenization_codepoint_config.back();
861 config->start = 0;
862 config->end = 256;
863 config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
864 config->script_id = 1;
865 }
866 options.tokenize_on_script_change = false;
Lukas Zilka40c18de2017-04-10 17:22:22 +0200867
Lukas Zilka21d8c982018-01-24 11:11:20 +0100868 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
869 TestingFeatureProcessor feature_processor(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100870 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
871 &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100872
873 EXPECT_EQ(feature_processor.Tokenize("앨라배마123웹사이트"),
874 std::vector<Token>({Token("앨라배마123웹사이트", 0, 11)}));
875
876 options.tokenize_on_script_change = true;
877 flatbuffers::DetachedBuffer options_fb2 =
878 PackFeatureProcessorOptions(options);
879 TestingFeatureProcessor feature_processor2(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100880 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb2.data()),
881 &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100882
883 EXPECT_EQ(feature_processor2.Tokenize("앨라배마123웹사이트"),
884 std::vector<Token>({Token("앨라배마", 0, 4), Token("123", 4, 7),
885 Token("웹사이트", 7, 11)}));
886}
887
#ifdef LIBTEXTCLASSIFIER_TEST_ICU
// Checks Tokenize() with tokenization_type set to ICU on a Thai string that
// contains no spaces; five tokens with the codepoint offsets listed below are
// expected. Only compiled when LIBTEXTCLASSIFIER_TEST_ICU is defined.
TEST(FeatureProcessorTest, ICUTokenize) {
  FeatureProcessorOptionsT options;
  options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU;

  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  // The processor is constructed with a UniLib instance, exactly like in the
  // tests that are not guarded by LIBTEXTCLASSIFIER_TEST_ICU.
  CREATE_UNILIB_FOR_TESTING;
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib);
  std::vector<Token> tokens = feature_processor.Tokenize("พระบาทสมเด็จพระปรมิ");
  ASSERT_EQ(tokens,
            // clang-format off
            std::vector<Token>({Token("พระบาท", 0, 6),
                                Token("สมเด็จ", 6, 12),
                                Token("พระ", 12, 15),
                                Token("ปร", 15, 17),
                                Token("มิ", 17, 19)}));
  // clang-format on
}
#endif
Lukas Zilka40c18de2017-04-10 17:22:22 +0200907
#ifdef LIBTEXTCLASSIFIER_TEST_ICU
// Checks Tokenize() with tokenization_type set to ICU and
// icu_preserve_whitespace_tokens enabled: the single spaces between the Thai
// words are expected to come back as tokens of their own. Only compiled when
// LIBTEXTCLASSIFIER_TEST_ICU is defined.
TEST(FeatureProcessorTest, ICUTokenizeWithWhitespaces) {
  FeatureProcessorOptionsT options;
  options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU;
  options.icu_preserve_whitespace_tokens = true;

  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  // The processor is constructed with a UniLib instance, exactly like in the
  // tests that are not guarded by LIBTEXTCLASSIFIER_TEST_ICU.
  CREATE_UNILIB_FOR_TESTING;
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib);
  std::vector<Token> tokens =
      feature_processor.Tokenize("พระบาท สมเด็จ พระ ปร มิ");
  ASSERT_EQ(tokens,
            // clang-format off
            std::vector<Token>({Token("พระบาท", 0, 6),
                                Token(" ", 6, 7),
                                Token("สมเด็จ", 7, 13),
                                Token(" ", 13, 14),
                                Token("พระ", 14, 17),
                                Token(" ", 17, 18),
                                Token("ปร", 18, 20),
                                Token(" ", 20, 21),
                                Token("มิ", 21, 23)}));
  // clang-format on
}
#endif
Lukas Zilka40c18de2017-04-10 17:22:22 +0200933
#ifdef LIBTEXTCLASSIFIER_TEST_ICU
// Checks Tokenize() with tokenization_type set to MIXED. The space codepoint
// (32) is configured as a whitespace separator and four adjacent codepoint
// ranges covering [0, 592) are registered as internal tokenizer ranges. The
// input mixes Japanese, accented Latin text and a URL; the expected tokens and
// their codepoint offsets are listed in the final assertion. Only compiled
// when LIBTEXTCLASSIFIER_TEST_ICU is defined.
TEST(FeatureProcessorTest, MixedTokenize) {
  FeatureProcessorOptionsT options;
  options.tokenization_type = FeatureProcessorOptions_::TokenizationType_MIXED;

  options.tokenization_codepoint_config.emplace_back(
      new TokenizationCodepointRangeT());
  auto& config = options.tokenization_codepoint_config.back();
  config->start = 32;
  config->end = 33;
  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;

  // {start, end} pairs, registered in this order.
  const int kInternalTokenizerRanges[][2] = {
      {0, 128}, {128, 256}, {256, 384}, {384, 592}};
  for (const auto& bounds : kInternalTokenizerRanges) {
    options.internal_tokenizer_codepoint_ranges.emplace_back(
        new FeatureProcessorOptions_::CodepointRangeT());
    auto& range = options.internal_tokenizer_codepoint_ranges.back();
    range->start = bounds[0];
    range->end = bounds[1];
  }

  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  // The processor is constructed with a UniLib instance, exactly like in the
  // tests that are not guarded by LIBTEXTCLASSIFIER_TEST_ICU.
  CREATE_UNILIB_FOR_TESTING;
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib);
  std::vector<Token> tokens = feature_processor.Tokenize(
      "こんにちはJapanese-ląnguagę text 世界 http://www.google.com/");
  ASSERT_EQ(tokens,
            // clang-format off
            std::vector<Token>({Token("こんにちは", 0, 5),
                                Token("Japanese-ląnguagę", 5, 22),
                                Token("text", 23, 27),
                                Token("世界", 28, 30),
                                Token("http://www.google.com/", 31, 53)}));
  // clang-format on
}
#endif
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200993
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200994TEST(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100995 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100996 FeatureProcessorOptionsT options;
997 options.ignored_span_boundary_codepoints.push_back('.');
998 options.ignored_span_boundary_codepoints.push_back(',');
999 options.ignored_span_boundary_codepoints.push_back('[');
1000 options.ignored_span_boundary_codepoints.push_back(']');
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +02001001
Lukas Zilka21d8c982018-01-24 11:11:20 +01001002 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
1003 TestingFeatureProcessor feature_processor(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001004 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
1005 &unilib);
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +02001006
1007 const std::string text1_utf8 = "ěščř";
1008 const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false);
1009 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1010 text1.begin(), text1.end(),
1011 /*count_from_beginning=*/true),
1012 0);
1013 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1014 text1.begin(), text1.end(),
1015 /*count_from_beginning=*/false),
1016 0);
1017
1018 const std::string text2_utf8 = ".,abčd";
1019 const UnicodeText text2 = UTF8ToUnicodeText(text2_utf8, /*do_copy=*/false);
1020 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1021 text2.begin(), text2.end(),
1022 /*count_from_beginning=*/true),
1023 2);
1024 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1025 text2.begin(), text2.end(),
1026 /*count_from_beginning=*/false),
1027 0);
1028
1029 const std::string text3_utf8 = ".,abčd[]";
1030 const UnicodeText text3 = UTF8ToUnicodeText(text3_utf8, /*do_copy=*/false);
1031 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1032 text3.begin(), text3.end(),
1033 /*count_from_beginning=*/true),
1034 2);
1035 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1036 text3.begin(), text3.end(),
1037 /*count_from_beginning=*/false),
1038 2);
1039
1040 const std::string text4_utf8 = "[abčd]";
1041 const UnicodeText text4 = UTF8ToUnicodeText(text4_utf8, /*do_copy=*/false);
1042 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1043 text4.begin(), text4.end(),
1044 /*count_from_beginning=*/true),
1045 1);
1046 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1047 text4.begin(), text4.end(),
1048 /*count_from_beginning=*/false),
1049 1);
1050
1051 const std::string text5_utf8 = "";
1052 const UnicodeText text5 = UTF8ToUnicodeText(text5_utf8, /*do_copy=*/false);
1053 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1054 text5.begin(), text5.end(),
1055 /*count_from_beginning=*/true),
1056 0);
1057 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1058 text5.begin(), text5.end(),
1059 /*count_from_beginning=*/false),
1060 0);
1061
1062 const std::string text6_utf8 = "012345ěščř";
1063 const UnicodeText text6 = UTF8ToUnicodeText(text6_utf8, /*do_copy=*/false);
1064 UnicodeText::const_iterator text6_begin = text6.begin();
1065 std::advance(text6_begin, 6);
1066 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1067 text6_begin, text6.end(),
1068 /*count_from_beginning=*/true),
1069 0);
1070 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1071 text6_begin, text6.end(),
1072 /*count_from_beginning=*/false),
1073 0);
1074
1075 const std::string text7_utf8 = "012345.,ěščř";
1076 const UnicodeText text7 = UTF8ToUnicodeText(text7_utf8, /*do_copy=*/false);
1077 UnicodeText::const_iterator text7_begin = text7.begin();
1078 std::advance(text7_begin, 6);
1079 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1080 text7_begin, text7.end(),
1081 /*count_from_beginning=*/true),
1082 2);
1083 UnicodeText::const_iterator text7_end = text7.begin();
1084 std::advance(text7_end, 8);
1085 EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
1086 text7.begin(), text7_end,
1087 /*count_from_beginning=*/false),
1088 2);
1089
1090 // Test not stripping.
1091 EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
1092 "Hello [[[Wořld]] or not?", {0, 24}),
1093 std::make_pair(0, 24));
1094 // Test basic stripping.
1095 EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
1096 "Hello [[[Wořld]] or not?", {6, 16}),
1097 std::make_pair(9, 14));
1098 // Test stripping when everything is stripped.
1099 EXPECT_EQ(
1100 feature_processor.StripBoundaryCodepoints("Hello [[[]] or not?", {6, 11}),
1101 std::make_pair(6, 6));
1102 // Test stripping empty string.
1103 EXPECT_EQ(feature_processor.StripBoundaryCodepoints("", {0, 0}),
1104 std::make_pair(0, 0));
1105}
1106
Lukas Zilka726b4d22017-12-13 16:37:03 +01001107TEST(FeatureProcessorTest, CodepointSpanToTokenSpan) {
1108 const std::vector<Token> tokens{Token("Hělló", 0, 5),
1109 Token("fěěbař@google.com", 6, 23),
1110 Token("heře!", 24, 29)};
1111
1112 // Spans matching the tokens exactly.
1113 EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}));
1114 EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}));
1115 EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}));
1116 EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}));
1117 EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}));
1118 EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}));
1119
1120 // Snapping to containing tokens has no effect.
1121 EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}, true));
1122 EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}, true));
1123 EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}, true));
1124 EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}, true));
1125 EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}, true));
1126 EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}, true));
1127
1128 // Span boundaries inside tokens.
1129 EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {1, 28}));
1130 EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {1, 28}, true));
1131
1132 // Tokens adjacent to the span, but not overlapping.
1133 EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}));
1134 EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}, true));
1135}
1136
Matt Sharifid40f9762017-03-14 21:24:23 +01001137} // namespace
Lukas Zilka21d8c982018-01-24 11:11:20 +01001138} // namespace libtextclassifier2