blob: dc9b30ea17e39e5ed480f53c6a65865b28bcd8e0 [file] [log] [blame]
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
Tony Mak6c4cc672018-09-17 11:48:50 +010017#include "annotator/annotator.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010018
19#include <algorithm>
Lukas Zilka21d8c982018-01-24 11:11:20 +010020#include <cmath>
Tony Mak21460022020-03-12 18:29:35 +000021#include <cstddef>
Lukas Zilka21d8c982018-01-24 11:11:20 +010022#include <iterator>
23#include <numeric>
Tony Mak63959242020-02-07 18:31:16 +000024#include <string>
Tony Mak448b5862019-03-22 13:36:41 +000025#include <unordered_map>
Tony Mak63959242020-02-07 18:31:16 +000026#include <vector>
Lukas Zilka21d8c982018-01-24 11:11:20 +010027
Tony Mak854015a2019-01-16 15:56:48 +000028#include "annotator/collections.h"
Tony Mak83d2de62019-04-10 16:12:15 +010029#include "annotator/model_generated.h"
30#include "annotator/types.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010031#include "utils/base/logging.h"
32#include "utils/checksum.h"
Tony Mak63959242020-02-07 18:31:16 +000033#include "utils/i18n/locale.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010034#include "utils/math/softmax.h"
Tony Mak8cd7ba62019-10-15 15:29:22 +010035#include "utils/normalization.h"
Tony Maka2a1ff42019-09-12 15:40:32 +010036#include "utils/optional.h"
Tony Makd9446602019-02-20 18:25:39 +000037#include "utils/regex-match.h"
Tony Mak63959242020-02-07 18:31:16 +000038#include "utils/strings/numbers.h"
39#include "utils/strings/split.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010040#include "utils/utf8/unicodetext.h"
Tony Mak21460022020-03-12 18:29:35 +000041#include "utils/utf8/unilib-common.h"
Tony Mak378c1f52019-03-04 15:58:11 +000042#include "utils/zlib/zlib_regex.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010043
Tony Mak6c4cc672018-09-17 11:48:50 +010044namespace libtextclassifier3 {
Tony Mak448b5862019-03-22 13:36:41 +000045
// Ordered set of ints whose ordering is supplied at construction time as a
// std::function comparator (rather than being fixed by the type).
using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
47
Tony Mak6c4cc672018-09-17 11:48:50 +010048const std::string& Annotator::kPhoneCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010049 *[]() { return new std::string("phone"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010050const std::string& Annotator::kAddressCollection =
Lukas Zilkae7962cc2018-03-28 18:09:48 +020051 *[]() { return new std::string("address"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010052const std::string& Annotator::kDateCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010053 *[]() { return new std::string("date"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000054const std::string& Annotator::kUrlCollection =
55 *[]() { return new std::string("url"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000056const std::string& Annotator::kEmailCollection =
57 *[]() { return new std::string("email"); }();
Lukas Zilkab23e2122018-02-09 10:25:19 +010058
Lukas Zilka21d8c982018-01-24 11:11:20 +010059namespace {
60const Model* LoadAndVerifyModel(const void* addr, int size) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010061 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
Tony Mak51a9e542018-11-02 13:36:22 +000062 if (VerifyModelBuffer(verifier)) {
63 return GetModel(addr);
Lukas Zilka21d8c982018-01-24 11:11:20 +010064 } else {
65 return nullptr;
66 }
67}
Tony Mak6c4cc672018-09-17 11:48:50 +010068
Tony Mak76d80962020-01-08 17:30:51 +000069const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
70 int size) {
71 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
72 if (VerifyPersonNameModelBuffer(verifier)) {
73 return GetPersonNameModel(addr);
74 } else {
75 return nullptr;
76 }
77}
78
Tony Mak6c4cc672018-09-17 11:48:50 +010079// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
80// create a new instance, assign ownership to owned_lib, and return it.
81const UniLib* MaybeCreateUnilib(const UniLib* lib,
82 std::unique_ptr<UniLib>* owned_lib) {
83 if (lib) {
84 return lib;
85 } else {
86 owned_lib->reset(new UniLib);
87 return owned_lib->get();
88 }
89}
90
91// As above, but for CalendarLib.
92const CalendarLib* MaybeCreateCalendarlib(
93 const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
94 if (lib) {
95 return lib;
96 } else {
97 owned_lib->reset(new CalendarLib);
98 return owned_lib->get();
99 }
100}
101
Tony Mak968412a2019-11-13 15:39:57 +0000102// Returns whether the provided input is valid:
103// * Valid utf8 text.
104// * Sane span indices.
105bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan span) {
106 if (!context.is_valid()) {
107 return false;
108 }
109 return (span.first >= 0 && span.first < span.second &&
110 span.second <= context.size_codepoints());
111}
112
Tony Mak63959242020-02-07 18:31:16 +0000113std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
114 const flatbuffers::Vector<int32_t>* ints) {
115 if (ints == nullptr) {
116 return {};
117 }
118 std::unordered_set<char32> ints_set;
119 for (auto value : *ints) {
120 ints_set.insert(static_cast<char32>(value));
121 }
122 return ints_set;
123}
124
Tony Mak21460022020-03-12 18:29:35 +0000125DateAnnotationOptions ToDateAnnotationOptions(
126 const GrammarDatetimeModel_::AnnotationOptions* fb_annotation_options,
127 const std::string& reference_timezone, const int64 reference_time_ms_utc) {
128 DateAnnotationOptions result_annotation_options;
129 result_annotation_options.base_timestamp_millis = reference_time_ms_utc;
130 result_annotation_options.reference_timezone = reference_timezone;
131 if (fb_annotation_options != nullptr) {
132 result_annotation_options.enable_special_day_offset =
133 fb_annotation_options->enable_special_day_offset();
134 result_annotation_options.merge_adjacent_components =
135 fb_annotation_options->merge_adjacent_components();
136 result_annotation_options.enable_date_range =
137 fb_annotation_options->enable_date_range();
138 result_annotation_options.include_preposition =
139 fb_annotation_options->include_preposition();
140 result_annotation_options.expand_date_series =
141 fb_annotation_options->expand_date_series();
142 if (fb_annotation_options->extra_requested_dates() != nullptr) {
143 for (const auto& extra_requested_date :
144 *fb_annotation_options->extra_requested_dates()) {
145 result_annotation_options.extra_requested_dates.push_back(
146 extra_requested_date->str());
147 }
148 }
149 }
150 return result_annotation_options;
151}
152
Lukas Zilka21d8c982018-01-24 11:11:20 +0100153} // namespace
154
Lukas Zilkaba849e72018-03-08 14:48:21 +0100155tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
156 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100157 TC3_CHECK(selection_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100158 selection_interpreter_ = selection_executor_->CreateInterpreter();
159 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100160 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100161 }
162 }
163 return selection_interpreter_.get();
164}
165
166tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
167 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100168 TC3_CHECK(classification_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100169 classification_interpreter_ = classification_executor_->CreateInterpreter();
170 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100171 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100172 }
173 }
174 return classification_interpreter_.get();
175}
176
Tony Mak6c4cc672018-09-17 11:48:50 +0100177std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
178 const char* buffer, int size, const UniLib* unilib,
179 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100180 const Model* model = LoadAndVerifyModel(buffer, size);
181 if (model == nullptr) {
182 return nullptr;
183 }
184
Lukas Zilkab23e2122018-02-09 10:25:19 +0100185 auto classifier =
Tony Mak6c4cc672018-09-17 11:48:50 +0100186 std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100187 if (!classifier->IsInitialized()) {
188 return nullptr;
189 }
190
191 return classifier;
192}
193
Tony Mak6c4cc672018-09-17 11:48:50 +0100194std::unique_ptr<Annotator> Annotator::FromScopedMmap(
195 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
196 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100197 if (!(*mmap)->handle().ok()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100198 TC3_VLOG(1) << "Mmap failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100199 return nullptr;
200 }
201
202 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
203 (*mmap)->handle().num_bytes());
204 if (!model) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100205 TC3_LOG(ERROR) << "Model verification failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100206 return nullptr;
207 }
208
Tony Mak6c4cc672018-09-17 11:48:50 +0100209 auto classifier = std::unique_ptr<Annotator>(
210 new Annotator(mmap, model, unilib, calendarlib));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100211 if (!classifier->IsInitialized()) {
212 return nullptr;
213 }
214
215 return classifier;
216}
217
Tony Makdf54e742019-03-26 14:04:00 +0000218std::unique_ptr<Annotator> Annotator::FromScopedMmap(
219 std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
220 std::unique_ptr<CalendarLib> calendarlib) {
221 if (!(*mmap)->handle().ok()) {
222 TC3_VLOG(1) << "Mmap failed.";
223 return nullptr;
224 }
225
226 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
227 (*mmap)->handle().num_bytes());
228 if (model == nullptr) {
229 TC3_LOG(ERROR) << "Model verification failed.";
230 return nullptr;
231 }
232
233 auto classifier = std::unique_ptr<Annotator>(
234 new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
235 if (!classifier->IsInitialized()) {
236 return nullptr;
237 }
238
239 return classifier;
240}
241
Tony Mak6c4cc672018-09-17 11:48:50 +0100242std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
243 int fd, int offset, int size, const UniLib* unilib,
244 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100245 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
Tony Mak6c4cc672018-09-17 11:48:50 +0100246 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100247}
248
Tony Mak6c4cc672018-09-17 11:48:50 +0100249std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Makdf54e742019-03-26 14:04:00 +0000250 int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
251 std::unique_ptr<CalendarLib> calendarlib) {
252 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
253 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
254}
255
256std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Mak6c4cc672018-09-17 11:48:50 +0100257 int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100258 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
Tony Mak6c4cc672018-09-17 11:48:50 +0100259 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100260}
261
Tony Makdf54e742019-03-26 14:04:00 +0000262std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
263 int fd, std::unique_ptr<UniLib> unilib,
264 std::unique_ptr<CalendarLib> calendarlib) {
265 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
266 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
267}
268
Tony Mak6c4cc672018-09-17 11:48:50 +0100269std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
270 const UniLib* unilib,
271 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100272 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
Tony Mak6c4cc672018-09-17 11:48:50 +0100273 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100274}
275
Tony Makdf54e742019-03-26 14:04:00 +0000276std::unique_ptr<Annotator> Annotator::FromPath(
277 const std::string& path, std::unique_ptr<UniLib> unilib,
278 std::unique_ptr<CalendarLib> calendarlib) {
279 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
280 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
281}
282
// Constructs an annotator over a memory-mapped model. Takes the mapping out of
// `*mmap` (the caller's unique_ptr is left moved-from) and borrows `unilib` /
// `calendarlib`; for whichever of those is null, a new instance is created and
// stored in the corresponding owned_* member. Finishes by running
// ValidateAndInitialize(), whose outcome is recorded in initialized_.
Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
                     const UniLib* unilib, const CalendarLib* calendarlib)
    : model_(model),
      mmap_(std::move(*mmap)),
      owned_unilib_(nullptr),
      // NOTE(review): MaybeCreateUnilib may write into owned_unilib_, so this
      // relies on owned_unilib_ being declared before unilib_ in the class
      // (same for the calendarlib pair) - confirm against the header.
      unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
      owned_calendarlib_(nullptr),
      calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
  ValidateAndInitialize();
}
293
// Constructs an annotator over a memory-mapped model, taking ownership of the
// mapping (moved out of `*mmap`) and of `unilib` / `calendarlib`. The raw
// unilib_ / calendarlib_ pointers are taken from the owned instances, so they
// are null if the caller passed empty unique_ptrs. Finishes by running
// ValidateAndInitialize(), whose outcome is recorded in initialized_.
Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
                     std::unique_ptr<UniLib> unilib,
                     std::unique_ptr<CalendarLib> calendarlib)
    : model_(model),
      mmap_(std::move(*mmap)),
      owned_unilib_(std::move(unilib)),
      // NOTE(review): reading owned_unilib_ here relies on it being declared
      // before unilib_ in the class (same for the calendarlib pair) - confirm
      // against the header.
      unilib_(owned_unilib_.get()),
      owned_calendarlib_(std::move(calendarlib)),
      calendarlib_(owned_calendarlib_.get()) {
  ValidateAndInitialize();
}
305
// Constructs an annotator over a model the caller keeps alive (no mapping is
// stored). Borrows `unilib` / `calendarlib`; for whichever of those is null, a
// new instance is created and stored in the corresponding owned_* member.
// Finishes by running ValidateAndInitialize(), whose outcome is recorded in
// initialized_.
Annotator::Annotator(const Model* model, const UniLib* unilib,
                     const CalendarLib* calendarlib)
    : model_(model),
      owned_unilib_(nullptr),
      // NOTE(review): MaybeCreateUnilib may write into owned_unilib_, so this
      // relies on owned_unilib_ being declared before unilib_ in the class
      // (same for the calendarlib pair) - confirm against the header.
      unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
      owned_calendarlib_(nullptr),
      calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
  ValidateAndInitialize();
}
315
// Validates model_ and builds every component the model configures.
//
// initialized_ is cleared on entry and set to true only when every step below
// succeeds. On the first failure an error is logged and the function returns
// immediately; members that were already set up before the failing step are
// left as they are.
void Annotator::ValidateAndInitialize() {
  initialized_ = false;

  if (model_ == nullptr) {
    TC3_LOG(ERROR) << "No model specified.";
    return;
  }

  // Decode which modes the model's triggering options enable. Without
  // triggering options all three flags are false.
  const bool model_enabled_for_annotation =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
  const bool model_enabled_for_classification =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() &
        ModeFlag_CLASSIFICATION));
  const bool model_enabled_for_selection =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));

  // Annotation requires the selection model.
  if (model_enabled_for_annotation || model_enabled_for_selection) {
    if (!model_->selection_options()) {
      TC3_LOG(ERROR) << "No selection options.";
      return;
    }
    if (!model_->selection_feature_options()) {
      TC3_LOG(ERROR) << "No selection feature options.";
      return;
    }
    if (!model_->selection_feature_options()->bounds_sensitive_features()) {
      TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
      return;
    }
    if (!model_->selection_model()) {
      TC3_LOG(ERROR) << "No selection model.";
      return;
    }
    selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
    if (!selection_executor_) {
      TC3_LOG(ERROR) << "Could not initialize selection executor.";
      return;
    }
    selection_feature_processor_.reset(
        new FeatureProcessor(model_->selection_feature_options(), unilib_));
  }

  // Annotation requires the classification model for conflict resolution and
  // scoring.
  // Selection requires the classification model for conflict resolution.
  if (model_enabled_for_annotation || model_enabled_for_classification ||
      model_enabled_for_selection) {
    if (!model_->classification_options()) {
      TC3_LOG(ERROR) << "No classification options.";
      return;
    }

    if (!model_->classification_feature_options()) {
      TC3_LOG(ERROR) << "No classification feature options.";
      return;
    }

    if (!model_->classification_feature_options()
             ->bounds_sensitive_features()) {
      TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
      return;
    }
    if (!model_->classification_model()) {
      TC3_LOG(ERROR) << "No clf model.";
      return;
    }

    classification_executor_ =
        ModelExecutor::FromBuffer(model_->classification_model());
    if (!classification_executor_) {
      TC3_LOG(ERROR) << "Could not initialize classification executor.";
      return;
    }

    classification_feature_processor_.reset(new FeatureProcessor(
        model_->classification_feature_options(), unilib_));
  }

  // The embeddings need to be specified if the model is to be used for
  // classification or selection.
  if (model_enabled_for_annotation || model_enabled_for_classification ||
      model_enabled_for_selection) {
    if (!model_->embedding_model()) {
      TC3_LOG(ERROR) << "No embedding model.";
      return;
    }

    // Check that the embedding size of the selection and classification model
    // matches, as they are using the same embeddings.
    // Note: the comparison only runs when the selection mode is enabled; an
    // annotation-only model skips it even though its selection feature options
    // were validated above.
    if (model_enabled_for_selection &&
        (model_->selection_feature_options()->embedding_size() !=
             model_->classification_feature_options()->embedding_size() ||
         model_->selection_feature_options()->embedding_quantization_bits() !=
             model_->classification_feature_options()
                 ->embedding_quantization_bits())) {
      TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
      return;
    }

    // The embedding executor is parameterized from the classification feature
    // options.
    embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
        model_->embedding_model(),
        model_->classification_feature_options()->embedding_size(),
        model_->classification_feature_options()->embedding_quantization_bits(),
        model_->embedding_pruning_mask());
    if (!embedding_executor_) {
      TC3_LOG(ERROR) << "Could not initialize embedding executor.";
      return;
    }
  }

  // One decompressor instance is shared by the regex model and the regex-based
  // datetime parser below; it lives only for the duration of this function.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  if (model_->regex_model()) {
    if (!InitializeRegexModel(decompressor.get())) {
      TC3_LOG(ERROR) << "Could not initialize regex model.";
      return;
    }
  }
  // Datetime parsing: the grammar-based parser is used when the model carries
  // grammar datetime rules; otherwise the datetime_model-based parser is used
  // if present. At most one of the two is created.
  if (model_->grammar_datetime_model() &&
      model_->grammar_datetime_model()->datetime_rules()) {
    cfg_datetime_parser_.reset(new dates::CfgDatetimeAnnotator(
        *unilib_,
        /*tokenizer_options=*/
        model_->grammar_datetime_model()->grammar_tokenizer_options(),
        *calendarlib_,
        /*datetime_rules=*/model_->grammar_datetime_model()->datetime_rules(),
        model_->grammar_datetime_model()->target_classification_score(),
        model_->grammar_datetime_model()->priority_score()));
    // NOTE(review): the pointer was just set from `new`, so this branch looks
    // unreachable unless operator new can return null in this build - confirm.
    if (!cfg_datetime_parser_) {
      TC3_LOG(ERROR) << "Could not initialize context free grammar based "
                        "datetime parser.";
      return;
    }
  } else if (model_->datetime_model()) {
    datetime_parser_ = DatetimeParser::Instance(
        model_->datetime_model(), *unilib_, *calendarlib_, decompressor.get());
    if (!datetime_parser_) {
      TC3_LOG(ERROR) << "Could not initialize datetime parser.";
      return;
    }
  }

  // Copy the per-mode lists of filtered collection names out of the model.
  if (model_->output_options()) {
    if (model_->output_options()->filtered_collections_annotation()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_annotation()) {
        filtered_collections_annotation_.insert(collection->str());
      }
    }
    if (model_->output_options()->filtered_collections_classification()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_classification()) {
        filtered_collections_classification_.insert(collection->str());
      }
    }
    if (model_->output_options()->filtered_collections_selection()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_selection()) {
        filtered_collections_selection_.insert(collection->str());
      }
    }
  }

  // Optional annotators: each is created only when its options table is
  // present and marked enabled.
  if (model_->number_annotator_options() &&
      model_->number_annotator_options()->enabled()) {
    number_annotator_.reset(
        new NumberAnnotator(model_->number_annotator_options(), unilib_));
  }

  if (model_->money_parsing_options()) {
    money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
        model_->money_parsing_options()->separators());
  }

  // NOTE(review): selection_feature_processor_ is only created above when the
  // model is enabled for annotation or selection, so it can be null here -
  // confirm DurationAnnotator accepts a null feature processor.
  if (model_->duration_annotator_options() &&
      model_->duration_annotator_options()->enabled()) {
    duration_annotator_.reset(
        new DurationAnnotator(model_->duration_annotator_options(),
                              selection_feature_processor_.get(), unilib_));
  }

  // Entity data: when the model embeds a schema, verify it and create a
  // builder bound to it; otherwise make sure both members are cleared.
  if (model_->entity_data_schema()) {
    entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
        model_->entity_data_schema()->Data(),
        model_->entity_data_schema()->size());
    if (entity_data_schema_ == nullptr) {
      TC3_LOG(ERROR) << "Could not load entity data schema data.";
      return;
    }

    entity_data_builder_.reset(
        new ReflectiveFlatbufferBuilder(entity_data_schema_));
  } else {
    entity_data_schema_ = nullptr;
    entity_data_builder_ = nullptr;
  }

  // Must come after the entity data setup: the grammar annotator is handed the
  // entity data builder (which may be null).
  if (model_->grammar_model()) {
    grammar_annotator_.reset(new GrammarAnnotator(
        unilib_, model_->grammar_model(), entity_data_builder_.get()));
  }

  // Parse the three locale lists. Each is optional; a list that is present but
  // fails to parse aborts initialization.
  if (model_->triggering_locales() &&
      !ParseLocales(model_->triggering_locales()->c_str(),
                    &model_triggering_locales_)) {
    TC3_LOG(ERROR) << "Could not parse model supported locales.";
    return;
  }

  if (model_->triggering_options() != nullptr &&
      model_->triggering_options()->locales() != nullptr &&
      !ParseLocales(model_->triggering_options()->locales()->c_str(),
                    &ml_model_triggering_locales_)) {
    TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
    return;
  }

  if (model_->triggering_options() != nullptr &&
      model_->triggering_options()->dictionary_locales() != nullptr &&
      !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
                    &dictionary_locales_)) {
    TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
    return;
  }

  initialized_ = true;
}
546
Tony Mak6c4cc672018-09-17 11:48:50 +0100547bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100548 if (!model_->regex_model()->patterns()) {
Lukas Zilka434442d2018-04-25 11:38:51 +0200549 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100550 }
551
552 // Initialize pattern recognizers.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100553 int regex_pattern_id = 0;
554 for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200555 std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
Tony Mak378c1f52019-03-04 15:58:11 +0000556 UncompressMakeRegexPattern(
557 *unilib_, regex_pattern->pattern(),
558 regex_pattern->compressed_pattern(),
559 model_->regex_model()->lazy_regex_compilation(), decompressor);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100560 if (!compiled_pattern) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100561 TC3_LOG(INFO) << "Failed to load regex pattern";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200562 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100563 }
564
Lukas Zilkaba849e72018-03-08 14:48:21 +0100565 if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100566 annotation_regex_patterns_.push_back(regex_pattern_id);
567 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100568 if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100569 classification_regex_patterns_.push_back(regex_pattern_id);
570 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100571 if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100572 selection_regex_patterns_.push_back(regex_pattern_id);
573 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100574 regex_patterns_.push_back({
Tony Mak854015a2019-01-16 15:56:48 +0000575 regex_pattern,
Tony Mak6c4cc672018-09-17 11:48:50 +0100576 std::move(compiled_pattern),
Tony Mak6c4cc672018-09-17 11:48:50 +0100577 });
Lukas Zilkab23e2122018-02-09 10:25:19 +0100578 ++regex_pattern_id;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100579 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100580
Lukas Zilkab23e2122018-02-09 10:25:19 +0100581 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100582}
583
Tony Mak6c4cc672018-09-17 11:48:50 +0100584bool Annotator::InitializeKnowledgeEngine(
585 const std::string& serialized_config) {
Tony Maka2a1ff42019-09-12 15:40:32 +0100586 std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
Tony Mak63959242020-02-07 18:31:16 +0000587 if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100588 TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
589 return false;
590 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100591 if (model_->triggering_options() != nullptr) {
592 knowledge_engine->SetPriorityScore(
593 model_->triggering_options()->knowledge_priority_score());
594 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100595 knowledge_engine_ = std::move(knowledge_engine);
596 return true;
597}
598
Tony Mak854015a2019-01-16 15:56:48 +0000599bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
Tony Mak5dc5e112019-02-01 14:52:10 +0000600 std::unique_ptr<ContactEngine> contact_engine(
Tony Mak63959242020-02-07 18:31:16 +0000601 new ContactEngine(selection_feature_processor_.get(), unilib_,
602 model_->contact_annotator_options()));
Tony Mak854015a2019-01-16 15:56:48 +0000603 if (!contact_engine->Initialize(serialized_config)) {
604 TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
605 return false;
606 }
607 contact_engine_ = std::move(contact_engine);
608 return true;
609}
610
Tony Makd9446602019-02-20 18:25:39 +0000611bool Annotator::InitializeInstalledAppEngine(
612 const std::string& serialized_config) {
613 std::unique_ptr<InstalledAppEngine> installed_app_engine(
Tony Mak378c1f52019-03-04 15:58:11 +0000614 new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
Tony Makd9446602019-02-20 18:25:39 +0000615 if (!installed_app_engine->Initialize(serialized_config)) {
616 TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
617 return false;
618 }
619 installed_app_engine_ = std::move(installed_app_engine);
620 return true;
621}
622
Tony Mak63959242020-02-07 18:31:16 +0000623void Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
624 lang_id_ = lang_id;
Tony Mak21460022020-03-12 18:29:35 +0000625 if (lang_id_ != nullptr && model_->translate_annotator_options() &&
Tony Mak63959242020-02-07 18:31:16 +0000626 model_->translate_annotator_options()->enabled()) {
627 translate_annotator_.reset(new TranslateAnnotator(
628 model_->translate_annotator_options(), lang_id_, unilib_));
Tony Mak21460022020-03-12 18:29:35 +0000629 } else {
630 translate_annotator_.reset(nullptr);
Tony Mak63959242020-02-07 18:31:16 +0000631 }
632}
633
Tony Mak21460022020-03-12 18:29:35 +0000634bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
635 int size) {
636 const PersonNameModel* person_name_model =
637 LoadAndVerifyPersonNameModel(buffer, size);
Tony Mak76d80962020-01-08 17:30:51 +0000638
639 if (person_name_model == nullptr) {
640 TC3_LOG(ERROR) << "Person name model verification failed.";
641 return false;
642 }
643
644 if (!person_name_model->enabled()) {
645 return true;
646 }
647
648 std::unique_ptr<PersonNameEngine> person_name_engine(
Tony Mak21460022020-03-12 18:29:35 +0000649 new PersonNameEngine(selection_feature_processor_.get(), unilib_));
Tony Mak76d80962020-01-08 17:30:51 +0000650 if (!person_name_engine->Initialize(person_name_model)) {
651 TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
652 return false;
653 }
654 person_name_engine_ = std::move(person_name_engine);
655 return true;
656}
657
Tony Mak21460022020-03-12 18:29:35 +0000658bool Annotator::InitializePersonNameEngineFromScopedMmap(
659 const ScopedMmap& mmap) {
660 if (!mmap.handle().ok()) {
661 TC3_LOG(ERROR) << "Mmap for person name model failed.";
662 return false;
663 }
664
665 return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
666 mmap.handle().num_bytes());
667}
668
669bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
670 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
671 return InitializePersonNameEngineFromScopedMmap(*mmap);
672}
673
674bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
675 int size) {
676 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
677 return InitializePersonNameEngineFromScopedMmap(*mmap);
678}
679
Lukas Zilka21d8c982018-01-24 11:11:20 +0100680namespace {
681
682int CountDigits(const std::string& str, CodepointSpan selection_indices) {
683 int count = 0;
684 int i = 0;
685 const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
686 for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
687 if (i >= selection_indices.first && i < selection_indices.second &&
Tony Mak21460022020-03-12 18:29:35 +0000688 IsDigit(*it)) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100689 ++count;
690 }
691 }
692 return count;
693}
694
Lukas Zilka21d8c982018-01-24 11:11:20 +0100695} // namespace
696
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200697namespace internal {
698// Helper function, which if the initial 'span' contains only white-spaces,
699// moves the selection to a single-codepoint selection on a left or right side
700// of this space.
701CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
702 const UnicodeText& context_unicode,
703 const UniLib& unilib) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100704 TC3_CHECK(ValidNonEmptySpan(span));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200705
706 UnicodeText::const_iterator it;
707
708 // Check that the current selection is all whitespaces.
709 it = context_unicode.begin();
710 std::advance(it, span.first);
711 for (int i = 0; i < (span.second - span.first); ++i, ++it) {
712 if (!unilib.IsWhitespace(*it)) {
713 return span;
714 }
715 }
716
717 CodepointSpan result;
718
719 // Try moving left.
720 result = span;
721 it = context_unicode.begin();
722 std::advance(it, span.first);
723 while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
724 --result.first;
725 --it;
726 }
727 result.second = result.first + 1;
728 if (!unilib.IsWhitespace(*it)) {
729 return result;
730 }
731
732 // If moving left didn't find a non-whitespace character, just return the
733 // original span.
734 return span;
735}
736} // namespace internal
737
Tony Mak6c4cc672018-09-17 11:48:50 +0100738bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200739 return !span.classification.empty() &&
740 filtered_collections_annotation_.find(
741 span.classification[0].collection) !=
742 filtered_collections_annotation_.end();
743}
744
Tony Mak6c4cc672018-09-17 11:48:50 +0100745bool Annotator::FilteredForClassification(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200746 const ClassificationResult& classification) const {
747 return filtered_collections_classification_.find(classification.collection) !=
748 filtered_collections_classification_.end();
749}
750
Tony Mak6c4cc672018-09-17 11:48:50 +0100751bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200752 return !span.classification.empty() &&
753 filtered_collections_selection_.find(
754 span.classification[0].collection) !=
755 filtered_collections_selection_.end();
756}
757
Tony Mak378c1f52019-03-04 15:58:11 +0000758namespace {
759inline bool ClassifiedAsOther(
760 const std::vector<ClassificationResult>& classification) {
761 return !classification.empty() &&
762 classification[0].collection == Collections::Other();
763}
764
Tony Maka2a1ff42019-09-12 15:40:32 +0100765} // namespace
766
767float Annotator::GetPriorityScore(
768 const std::vector<ClassificationResult>& classification) const {
Tony Mak378c1f52019-03-04 15:58:11 +0000769 if (!classification.empty() && !ClassifiedAsOther(classification)) {
770 return classification[0].priority_score;
771 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +0100772 if (model_->triggering_options() != nullptr) {
773 return model_->triggering_options()->other_collection_priority_score();
774 } else {
775 return -1000.0;
776 }
Tony Mak378c1f52019-03-04 15:58:11 +0000777 }
778}
Tony Mak378c1f52019-03-04 15:58:11 +0000779
Tony Makdf54e742019-03-26 14:04:00 +0000780bool Annotator::VerifyRegexMatchCandidate(
781 const std::string& context, const VerificationOptions* verification_options,
782 const std::string& match, const UniLib::RegexMatcher* matcher) const {
783 if (verification_options == nullptr) {
784 return true;
785 }
786 if (verification_options->verify_luhn_checksum() &&
787 !VerifyLuhnChecksum(match)) {
788 return false;
789 }
790 const int lua_verifier = verification_options->lua_verifier();
791 if (lua_verifier >= 0) {
792 if (model_->regex_model()->lua_verifier() == nullptr ||
793 lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
794 TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
795 return false;
796 }
797 return VerifyMatch(
798 context, matcher,
799 model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
800 }
801 return true;
802}
803
Tony Mak6c4cc672018-09-17 11:48:50 +0100804CodepointSpan Annotator::SuggestSelection(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100805 const std::string& context, CodepointSpan click_indices,
806 const SelectionOptions& options) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200807 CodepointSpan original_click_indices = click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100808 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100809 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200810 return original_click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100811 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100812 if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200813 return original_click_indices;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100814 }
815
Tony Makdf54e742019-03-26 14:04:00 +0000816 std::vector<Locale> detected_text_language_tags;
817 if (!ParseLocales(options.detected_text_language_tags,
818 &detected_text_language_tags)) {
819 TC3_LOG(WARNING)
820 << "Failed to parse the detected_text_language_tags in options: "
821 << options.detected_text_language_tags;
822 }
823 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
824 model_triggering_locales_,
825 /*default_value=*/true)) {
826 return original_click_indices;
827 }
828
Lukas Zilkadf710db2018-02-27 12:44:09 +0100829 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
830 /*do_copy=*/false);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200831
Tony Mak968412a2019-11-13 15:39:57 +0000832 if (!IsValidSpanInput(context_unicode, click_indices)) {
833 TC3_VLOG(1)
834 << "Trying to run SuggestSelection with invalid input, indices: "
835 << click_indices.first << " " << click_indices.second;
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200836 return original_click_indices;
837 }
838
839 if (model_->snap_whitespace_selections()) {
840 // We want to expand a purely white-space selection to a multi-selection it
841 // would've been part of. But with this feature disabled we would do a no-
842 // op, because no token is found. Therefore, we need to modify the
843 // 'click_indices' a bit to include a part of the token, so that the click-
844 // finding logic finds the clicked token correctly. This modification is
845 // done by the following function. Note, that it's enough to check the left
846 // side of the current selection, because if the white-space is a part of a
Tony Mak6c4cc672018-09-17 11:48:50 +0100847 // multi-selection, necessarily both tokens - on the left and the right
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200848 // sides need to be selected. Thus snapping only to the left is sufficient
849 // (there's a check at the bottom that makes sure that if we snap to the
850 // left token but the result does not contain the initial white-space,
851 // returns the original indices).
852 click_indices = internal::SnapLeftIfWhitespaceSelection(
853 click_indices, context_unicode, *unilib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100854 }
855
Lukas Zilkab23e2122018-02-09 10:25:19 +0100856 std::vector<AnnotatedSpan> candidates;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100857 InterpreterManager interpreter_manager(selection_executor_.get(),
858 classification_executor_.get());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200859 std::vector<Token> tokens;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100860 if (!ModelSuggestSelection(context_unicode, click_indices,
Tony Makdf54e742019-03-26 14:04:00 +0000861 detected_text_language_tags, &interpreter_manager,
862 &tokens, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100863 TC3_LOG(ERROR) << "Model suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200864 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100865 }
Tony Mak83d2de62019-04-10 16:12:15 +0100866 if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
867 /*is_serialized_entity_data_enabled=*/false)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100868 TC3_LOG(ERROR) << "Regex suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200869 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100870 }
Tony Mak83d2de62019-04-10 16:12:15 +0100871 if (!DatetimeChunk(
872 UTF8ToUnicodeText(context, /*do_copy=*/false),
873 /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
874 options.locales, ModeFlag_SELECTION, options.annotation_usecase,
875 /*is_serialized_entity_data_enabled=*/false, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100876 TC3_LOG(ERROR) << "Datetime suggest selection failed.";
877 return original_click_indices;
878 }
Tony Mak378c1f52019-03-04 15:58:11 +0000879 if (knowledge_engine_ != nullptr &&
Tony Maka2a1ff42019-09-12 15:40:32 +0100880 !knowledge_engine_->Chunk(context, options.annotation_usecase,
Tony Mak63959242020-02-07 18:31:16 +0000881 options.location_context, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100882 TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200883 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100884 }
Tony Mak378c1f52019-03-04 15:58:11 +0000885 if (contact_engine_ != nullptr &&
Tony Mak854015a2019-01-16 15:56:48 +0000886 !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
887 TC3_LOG(ERROR) << "Contact suggest selection failed.";
888 return original_click_indices;
889 }
Tony Mak378c1f52019-03-04 15:58:11 +0000890 if (installed_app_engine_ != nullptr &&
Tony Makd9446602019-02-20 18:25:39 +0000891 !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
892 TC3_LOG(ERROR) << "Installed app suggest selection failed.";
893 return original_click_indices;
894 }
Tony Mak378c1f52019-03-04 15:58:11 +0000895 if (number_annotator_ != nullptr &&
896 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
897 &candidates)) {
898 TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
899 return original_click_indices;
900 }
Tony Makad2e22d2019-03-20 17:35:13 +0000901 if (duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +0000902 !duration_annotator_->FindAll(context_unicode, tokens,
903 options.annotation_usecase, &candidates)) {
Tony Makad2e22d2019-03-20 17:35:13 +0000904 TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
905 return original_click_indices;
906 }
Tony Mak76d80962020-01-08 17:30:51 +0000907 if (person_name_engine_ != nullptr &&
908 !person_name_engine_->Chunk(context_unicode, tokens, &candidates)) {
909 TC3_LOG(ERROR) << "Person name suggest selection failed.";
910 return original_click_indices;
911 }
Lukas Zilkab23e2122018-02-09 10:25:19 +0100912
Tony Mak21460022020-03-12 18:29:35 +0000913 AnnotatedSpan grammar_suggested_span;
914 if (grammar_annotator_ != nullptr &&
915 grammar_annotator_->SuggestSelection(detected_text_language_tags,
916 context_unicode, click_indices,
917 &grammar_suggested_span)) {
918 candidates.push_back(grammar_suggested_span);
919 }
920
Lukas Zilkab23e2122018-02-09 10:25:19 +0100921 // Sort candidates according to their position in the input, so that the next
922 // code can assume that any connected component of overlapping spans forms a
923 // contiguous block.
924 std::sort(candidates.begin(), candidates.end(),
925 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
926 return a.span.first < b.span.first;
927 });
928
929 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +0000930 if (!ResolveConflicts(candidates, context, tokens,
931 detected_text_language_tags, options.annotation_usecase,
932 &interpreter_manager, &candidate_indices)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100933 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200934 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100935 }
936
Tony Mak378c1f52019-03-04 15:58:11 +0000937 std::sort(candidate_indices.begin(), candidate_indices.end(),
Tony Maka2a1ff42019-09-12 15:40:32 +0100938 [this, &candidates](int a, int b) {
Tony Mak378c1f52019-03-04 15:58:11 +0000939 return GetPriorityScore(candidates[a].classification) >
940 GetPriorityScore(candidates[b].classification);
941 });
942
Lukas Zilkab23e2122018-02-09 10:25:19 +0100943 for (const int i : candidate_indices) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200944 if (SpansOverlap(candidates[i].span, click_indices) &&
945 SpansOverlap(candidates[i].span, original_click_indices)) {
946 // Run model classification if not present but requested and there's a
947 // classification collection filter specified.
948 if (candidates[i].classification.empty() &&
949 model_->selection_options()->always_classify_suggested_selection() &&
950 !filtered_collections_selection_.empty()) {
Tony Mak378c1f52019-03-04 15:58:11 +0000951 if (!ModelClassifyText(context, detected_text_language_tags,
952 candidates[i].span, &interpreter_manager,
953 /*embedding_cache=*/nullptr,
954 &candidates[i].classification)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200955 return original_click_indices;
956 }
957 }
958
959 // Ignore if span classification is filtered.
960 if (FilteredForSelection(candidates[i])) {
961 return original_click_indices;
962 }
963
Lukas Zilkab23e2122018-02-09 10:25:19 +0100964 return candidates[i].span;
965 }
966 }
967
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200968 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100969}
970
971namespace {
972// Helper function that returns the index of the first candidate that
973// transitively does not overlap with the candidate on 'start_index'. If the end
974// of 'candidates' is reached, it returns the index that points right behind the
975// array.
976int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
977 int start_index) {
978 int first_non_overlapping = start_index + 1;
979 CodepointSpan conflicting_span = candidates[start_index].span;
980 while (
981 first_non_overlapping < candidates.size() &&
982 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
983 // Grow the span to include the current one.
984 conflicting_span.second = std::max(
985 conflicting_span.second, candidates[first_non_overlapping].span.second);
986
987 ++first_non_overlapping;
988 }
989 return first_non_overlapping;
990}
991} // namespace
992
Tony Mak378c1f52019-03-04 15:58:11 +0000993bool Annotator::ResolveConflicts(
994 const std::vector<AnnotatedSpan>& candidates, const std::string& context,
995 const std::vector<Token>& cached_tokens,
996 const std::vector<Locale>& detected_text_language_tags,
997 AnnotationUsecase annotation_usecase,
998 InterpreterManager* interpreter_manager, std::vector<int>* result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100999 result->clear();
1000 result->reserve(candidates.size());
1001 for (int i = 0; i < candidates.size();) {
1002 int first_non_overlapping =
1003 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1004
1005 const bool conflict_found = first_non_overlapping != (i + 1);
1006 if (conflict_found) {
1007 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +00001008 if (!ResolveConflict(context, cached_tokens, candidates,
1009 detected_text_language_tags, i,
1010 first_non_overlapping, annotation_usecase,
1011 interpreter_manager, &candidate_indices)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001012 return false;
1013 }
1014 result->insert(result->end(), candidate_indices.begin(),
1015 candidate_indices.end());
1016 } else {
1017 result->push_back(i);
1018 }
1019
1020 // Skip over the whole conflicting group/go to next candidate.
1021 i = first_non_overlapping;
1022 }
1023 return true;
1024}
1025
1026namespace {
Tony Mak448b5862019-03-22 13:36:41 +00001027// Returns true, if the given two sources do conflict in given annotation
1028// usecase.
1029// - In SMART usecase, all sources do conflict, because there's only 1 possible
1030// annotation for a given span.
1031// - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1032// and duration), while others not (e.g. duration and number).
1033bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1034 const AnnotatedSpan::Source source1,
1035 const AnnotatedSpan::Source source2) {
1036 uint32 source_mask =
1037 (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1038
Tony Mak378c1f52019-03-04 15:58:11 +00001039 switch (annotation_usecase) {
1040 case AnnotationUsecase_ANNOTATION_USECASE_SMART:
Tony Mak448b5862019-03-22 13:36:41 +00001041 // In the SMART mode, all annotations conflict.
Tony Mak378c1f52019-03-04 15:58:11 +00001042 return true;
Tony Mak448b5862019-03-22 13:36:41 +00001043
Tony Mak378c1f52019-03-04 15:58:11 +00001044 case AnnotationUsecase_ANNOTATION_USECASE_RAW:
Tony Mak448b5862019-03-22 13:36:41 +00001045 // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1046 // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1047 // hours" (duration).
1048 if ((source_mask &
1049 (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1050 (source_mask &
1051 (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1052 return false;
Tony Mak378c1f52019-03-04 15:58:11 +00001053 }
Tony Mak448b5862019-03-22 13:36:41 +00001054
1055 // A KNOWLEDGE entity does not conflict with anything.
1056 if ((source_mask &
1057 (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1058 return false;
1059 }
1060
1061 // Entities from other sources can conflict.
Tony Mak378c1f52019-03-04 15:58:11 +00001062 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001063 }
1064}
1065} // namespace
1066
Tony Mak378c1f52019-03-04 15:58:11 +00001067bool Annotator::ResolveConflict(
1068 const std::string& context, const std::vector<Token>& cached_tokens,
1069 const std::vector<AnnotatedSpan>& candidates,
1070 const std::vector<Locale>& detected_text_language_tags, int start_index,
1071 int end_index, AnnotationUsecase annotation_usecase,
1072 InterpreterManager* interpreter_manager,
1073 std::vector<int>* chosen_indices) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001074 std::vector<int> conflicting_indices;
Tony Mak76d80962020-01-08 17:30:51 +00001075 std::unordered_map<int, std::pair<float, int>> scores_lengths;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001076 for (int i = start_index; i < end_index; ++i) {
1077 conflicting_indices.push_back(i);
1078 if (!candidates[i].classification.empty()) {
Tony Mak76d80962020-01-08 17:30:51 +00001079 scores_lengths[i] = {
1080 GetPriorityScore(candidates[i].classification),
1081 candidates[i].span.second - candidates[i].span.first};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001082 continue;
1083 }
1084
1085 // OPTIMIZATION: So that we don't have to classify all the ML model
1086 // spans apriori, we wait until we get here, when they conflict with
1087 // something and we need the actual classification scores. So if the
1088 // candidate conflicts and comes from the model, we need to run a
1089 // classification to determine its priority:
1090 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +00001091 if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1092 candidates[i].span, interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001093 /*embedding_cache=*/nullptr, &classification)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001094 return false;
1095 }
1096
1097 if (!classification.empty()) {
Tony Mak76d80962020-01-08 17:30:51 +00001098 scores_lengths[i] = {
1099 GetPriorityScore(classification),
1100 candidates[i].span.second - candidates[i].span.first};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001101 }
1102 }
1103
Tony Mak101bc2a2020-01-09 12:32:17 +00001104 const bool prioritize_longest_annotation =
1105 model_->triggering_options() != nullptr &&
1106 model_->triggering_options()->prioritize_longest_annotation();
1107 std::sort(conflicting_indices.begin(), conflicting_indices.end(),
1108 [&scores_lengths, candidates, conflicting_indices,
1109 prioritize_longest_annotation](int i, int j) {
1110 if (scores_lengths[i].first == scores_lengths[j].first &&
1111 prioritize_longest_annotation) {
1112 return scores_lengths[i].second > scores_lengths[j].second;
1113 }
1114 return scores_lengths[i].first > scores_lengths[j].first;
1115 });
Lukas Zilkab23e2122018-02-09 10:25:19 +01001116
Tony Mak448b5862019-03-22 13:36:41 +00001117 // Here we keep a set of indices that were chosen, per-source, to enable
1118 // effective computation.
1119 std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1120 chosen_indices_for_source_map;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001121
1122 // Greedily place the candidates if they don't conflict with the already
1123 // placed ones.
1124 for (int i = 0; i < conflicting_indices.size(); ++i) {
1125 const int considered_candidate = conflicting_indices[i];
Tony Mak448b5862019-03-22 13:36:41 +00001126
1127 // See if there is a conflict between the candidate and all already placed
1128 // candidates.
1129 bool conflict = false;
1130 SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1131 for (auto& source_set_pair : chosen_indices_for_source_map) {
1132 if (source_set_pair.first == candidates[considered_candidate].source) {
1133 chosen_indices_for_source_ptr = &source_set_pair.second;
1134 }
1135
1136 if (DoSourcesConflict(annotation_usecase, source_set_pair.first,
1137 candidates[considered_candidate].source) &&
1138 DoesCandidateConflict(considered_candidate, candidates,
1139 source_set_pair.second)) {
1140 conflict = true;
1141 break;
1142 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001143 }
Tony Mak448b5862019-03-22 13:36:41 +00001144
1145 // Skip the candidate if a conflict was found.
1146 if (conflict) {
1147 continue;
1148 }
1149
1150 // If the set of indices for the current source doesn't exist yet,
1151 // initialize it.
1152 if (chosen_indices_for_source_ptr == nullptr) {
1153 SortedIntSet new_set([&candidates](int a, int b) {
1154 return candidates[a].span.first < candidates[b].span.first;
1155 });
1156 chosen_indices_for_source_map[candidates[considered_candidate].source] =
1157 std::move(new_set);
1158 chosen_indices_for_source_ptr =
1159 &chosen_indices_for_source_map[candidates[considered_candidate]
1160 .source];
1161 }
1162
1163 // Place the candidate to the output and to the per-source conflict set.
1164 chosen_indices->push_back(considered_candidate);
1165 chosen_indices_for_source_ptr->insert(considered_candidate);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001166 }
1167
Tony Mak378c1f52019-03-04 15:58:11 +00001168 std::sort(chosen_indices->begin(), chosen_indices->end());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001169
1170 return true;
1171}
1172
Tony Mak6c4cc672018-09-17 11:48:50 +01001173bool Annotator::ModelSuggestSelection(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001174 const UnicodeText& context_unicode, CodepointSpan click_indices,
Tony Makdf54e742019-03-26 14:04:00 +00001175 const std::vector<Locale>& detected_text_language_tags,
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001176 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001177 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001178 if (model_->triggering_options() == nullptr ||
1179 !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1180 return true;
1181 }
1182
Tony Makdf54e742019-03-26 14:04:00 +00001183 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1184 ml_model_triggering_locales_,
1185 /*default_value=*/true)) {
1186 return true;
1187 }
1188
Lukas Zilka21d8c982018-01-24 11:11:20 +01001189 int click_pos;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001190 *tokens = selection_feature_processor_->Tokenize(context_unicode);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001191 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001192 context_unicode, click_indices,
1193 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001194 tokens, &click_pos);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001195 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001196 TC3_VLOG(1) << "Could not calculate the click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001197 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001198 }
1199
1200 const int symmetry_context_size =
1201 model_->selection_options()->symmetry_context_size();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001202 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
Lukas Zilkab23e2122018-02-09 10:25:19 +01001203 bounds_sensitive_features = selection_feature_processor_->GetOptions()
1204 ->bounds_sensitive_features();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001205
1206 // The symmetry context span is the clicked token with symmetry_context_size
1207 // tokens on either side.
1208 const TokenSpan symmetry_context_span = IntersectTokenSpans(
1209 ExpandTokenSpan(SingleTokenSpan(click_pos),
1210 /*num_tokens_left=*/symmetry_context_size,
1211 /*num_tokens_right=*/symmetry_context_size),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001212 {0, tokens->size()});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001213
Lukas Zilkab23e2122018-02-09 10:25:19 +01001214 // Compute the extraction span based on the model type.
1215 TokenSpan extraction_span;
1216 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1217 // The extraction span is the symmetry context span expanded to include
1218 // max_selection_span tokens on either side, which is how far a selection
1219 // can stretch from the click, plus a relevant number of tokens outside of
1220 // the bounds of the selection.
1221 const int max_selection_span =
1222 selection_feature_processor_->GetOptions()->max_selection_span();
1223 extraction_span =
1224 ExpandTokenSpan(symmetry_context_span,
1225 /*num_tokens_left=*/max_selection_span +
1226 bounds_sensitive_features->num_tokens_before(),
1227 /*num_tokens_right=*/max_selection_span +
1228 bounds_sensitive_features->num_tokens_after());
1229 } else {
1230 // The extraction span is the symmetry context span expanded to include
1231 // context_size tokens on either side.
1232 const int context_size =
1233 selection_feature_processor_->GetOptions()->context_size();
1234 extraction_span = ExpandTokenSpan(symmetry_context_span,
1235 /*num_tokens_left=*/context_size,
1236 /*num_tokens_right=*/context_size);
1237 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001238 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
Lukas Zilkab23e2122018-02-09 10:25:19 +01001239
Lukas Zilka434442d2018-04-25 11:38:51 +02001240 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1241 *tokens, extraction_span)) {
1242 return true;
1243 }
1244
Lukas Zilkab23e2122018-02-09 10:25:19 +01001245 std::unique_ptr<CachedFeatures> cached_features;
1246 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001247 *tokens, extraction_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001248 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1249 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01001250 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001251 selection_feature_processor_->EmbeddingSize() +
1252 selection_feature_processor_->DenseFeaturesCount(),
1253 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001254 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001255 return false;
1256 }
1257
1258 // Produce selection model candidates.
1259 std::vector<TokenSpan> chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001260 if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001261 interpreter_manager->SelectionInterpreter(), *cached_features,
1262 &chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001263 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001264 return false;
1265 }
1266
1267 for (const TokenSpan& chunk : chunks) {
1268 AnnotatedSpan candidate;
1269 candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001270 context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001271 if (model_->selection_options()->strip_unpaired_brackets()) {
1272 candidate.span =
1273 StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1274 }
1275
1276 // Only output non-empty spans.
1277 if (candidate.span.first != candidate.span.second) {
1278 result->push_back(candidate);
1279 }
1280 }
1281 return true;
1282}
1283
Tony Mak6c4cc672018-09-17 11:48:50 +01001284bool Annotator::ModelClassifyText(
Tony Mak378c1f52019-03-04 15:58:11 +00001285 const std::string& context,
1286 const std::vector<Locale>& detected_text_language_tags,
1287 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001288 FeatureProcessor::EmbeddingCache* embedding_cache,
1289 std::vector<ClassificationResult>* classification_results) const {
Tony Mak378c1f52019-03-04 15:58:11 +00001290 return ModelClassifyText(context, {}, detected_text_language_tags,
1291 selection_indices, interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001292 embedding_cache, classification_results);
1293}
1294
1295namespace internal {
1296std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1297 CodepointSpan selection_indices,
1298 TokenSpan tokens_around_selection_to_copy) {
1299 const auto first_selection_token = std::upper_bound(
1300 cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1301 [](int selection_start, const Token& token) {
1302 return selection_start < token.end;
1303 });
1304 const auto last_selection_token = std::lower_bound(
1305 cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1306 [](const Token& token, int selection_end) {
1307 return token.start < selection_end;
1308 });
1309
1310 const int64 first_token = std::max(
1311 static_cast<int64>(0),
1312 static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1313 tokens_around_selection_to_copy.first));
1314 const int64 last_token = std::min(
1315 static_cast<int64>(cached_tokens.size()),
1316 static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1317 tokens_around_selection_to_copy.second));
1318
1319 std::vector<Token> tokens;
1320 tokens.reserve(last_token - first_token);
1321 for (int i = first_token; i < last_token; ++i) {
1322 tokens.push_back(cached_tokens[i]);
1323 }
1324 return tokens;
1325}
1326} // namespace internal
1327
Tony Mak6c4cc672018-09-17 11:48:50 +01001328TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001329 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1330 bounds_sensitive_features =
1331 classification_feature_processor_->GetOptions()
1332 ->bounds_sensitive_features();
1333 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1334 // The extraction span is the selection span expanded to include a relevant
1335 // number of tokens outside of the bounds of the selection.
1336 return {bounds_sensitive_features->num_tokens_before(),
1337 bounds_sensitive_features->num_tokens_after()};
1338 } else {
1339 // The extraction span is the clicked token with context_size tokens on
1340 // either side.
1341 const int context_size =
1342 selection_feature_processor_->GetOptions()->context_size();
1343 return {context_size, context_size};
1344 }
1345}
1346
Tony Mak378c1f52019-03-04 15:58:11 +00001347namespace {
1348// Sorts the classification results from high score to low score.
1349void SortClassificationResults(
1350 std::vector<ClassificationResult>* classification_results) {
1351 std::sort(classification_results->begin(), classification_results->end(),
1352 [](const ClassificationResult& a, const ClassificationResult& b) {
1353 return a.score > b.score;
1354 });
1355}
1356} // namespace
1357
Tony Mak6c4cc672018-09-17 11:48:50 +01001358bool Annotator::ModelClassifyText(
Lukas Zilkaba849e72018-03-08 14:48:21 +01001359 const std::string& context, const std::vector<Token>& cached_tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00001360 const std::vector<Locale>& detected_text_language_tags,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001361 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1362 FeatureProcessor::EmbeddingCache* embedding_cache,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001363 std::vector<ClassificationResult>* classification_results) const {
1364 std::vector<Token> tokens;
Tony Mak378c1f52019-03-04 15:58:11 +00001365 return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1366 selection_indices, interpreter_manager,
1367 embedding_cache, classification_results, &tokens);
1368}
1369
1370bool Annotator::ModelClassifyText(
1371 const std::string& context, const std::vector<Token>& cached_tokens,
1372 const std::vector<Locale>& detected_text_language_tags,
1373 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1374 FeatureProcessor::EmbeddingCache* embedding_cache,
1375 std::vector<ClassificationResult>* classification_results,
1376 std::vector<Token>* tokens) const {
1377 if (model_->triggering_options() == nullptr ||
1378 !(model_->triggering_options()->enabled_modes() &
1379 ModeFlag_CLASSIFICATION)) {
1380 return true;
1381 }
1382
Tony Makdf54e742019-03-26 14:04:00 +00001383 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1384 ml_model_triggering_locales_,
1385 /*default_value=*/true)) {
1386 return true;
1387 }
1388
Lukas Zilkaba849e72018-03-08 14:48:21 +01001389 if (cached_tokens.empty()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001390 *tokens = classification_feature_processor_->Tokenize(context);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001391 } else {
Tony Mak378c1f52019-03-04 15:58:11 +00001392 *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1393 ClassifyTextUpperBoundNeededTokens());
Lukas Zilkaba849e72018-03-08 14:48:21 +01001394 }
1395
Lukas Zilkab23e2122018-02-09 10:25:19 +01001396 int click_pos;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001397 classification_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001398 context, selection_indices,
1399 classification_feature_processor_->GetOptions()
1400 ->only_use_line_with_click(),
Tony Mak378c1f52019-03-04 15:58:11 +00001401 tokens, &click_pos);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001402 const TokenSpan selection_token_span =
Tony Mak378c1f52019-03-04 15:58:11 +00001403 CodepointSpanToTokenSpan(*tokens, selection_indices);
Lukas Zilka434442d2018-04-25 11:38:51 +02001404 const int selection_num_tokens = TokenSpanSize(selection_token_span);
1405 if (model_->classification_options()->max_num_tokens() > 0 &&
1406 model_->classification_options()->max_num_tokens() <
1407 selection_num_tokens) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001408 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001409 return true;
1410 }
1411
Lukas Zilkab23e2122018-02-09 10:25:19 +01001412 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1413 bounds_sensitive_features =
1414 classification_feature_processor_->GetOptions()
1415 ->bounds_sensitive_features();
1416 if (selection_token_span.first == kInvalidIndex ||
1417 selection_token_span.second == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001418 TC3_LOG(ERROR) << "Could not determine span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001419 return false;
1420 }
1421
1422 // Compute the extraction span based on the model type.
1423 TokenSpan extraction_span;
1424 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1425 // The extraction span is the selection span expanded to include a relevant
1426 // number of tokens outside of the bounds of the selection.
1427 extraction_span = ExpandTokenSpan(
1428 selection_token_span,
1429 /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1430 /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1431 } else {
1432 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001433 TC3_LOG(ERROR) << "Couldn't choose a click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001434 return false;
1435 }
1436 // The extraction span is the clicked token with context_size tokens on
1437 // either side.
1438 const int context_size =
Lukas Zilkaba849e72018-03-08 14:48:21 +01001439 classification_feature_processor_->GetOptions()->context_size();
Lukas Zilkab23e2122018-02-09 10:25:19 +01001440 extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
1441 /*num_tokens_left=*/context_size,
1442 /*num_tokens_right=*/context_size);
1443 }
Tony Mak378c1f52019-03-04 15:58:11 +00001444 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001445
Lukas Zilka434442d2018-04-25 11:38:51 +02001446 if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
Tony Mak378c1f52019-03-04 15:58:11 +00001447 *tokens, extraction_span)) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001448 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001449 return true;
1450 }
1451
Lukas Zilka21d8c982018-01-24 11:11:20 +01001452 std::unique_ptr<CachedFeatures> cached_features;
1453 if (!classification_feature_processor_->ExtractFeatures(
Tony Mak378c1f52019-03-04 15:58:11 +00001454 *tokens, extraction_span, selection_indices,
1455 embedding_executor_.get(), embedding_cache,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001456 classification_feature_processor_->EmbeddingSize() +
1457 classification_feature_processor_->DenseFeaturesCount(),
1458 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001459 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001460 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001461 }
1462
Lukas Zilkab23e2122018-02-09 10:25:19 +01001463 std::vector<float> features;
1464 features.reserve(cached_features->OutputFeaturesSize());
1465 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1466 cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1467 &features);
1468 } else {
1469 cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001470 }
1471
Lukas Zilkaba849e72018-03-08 14:48:21 +01001472 TensorView<float> logits = classification_executor_->ComputeLogits(
1473 TensorView<float>(features.data(),
1474 {1, static_cast<int>(features.size())}),
1475 interpreter_manager->ClassificationInterpreter());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001476 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001477 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001478 return false;
1479 }
1480
1481 if (logits.dims() != 2 || logits.dim(0) != 1 ||
1482 logits.dim(1) != classification_feature_processor_->NumCollections()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001483 TC3_LOG(ERROR) << "Mismatching output";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001484 return false;
1485 }
1486
1487 const std::vector<float> scores =
1488 ComputeSoftmax(logits.data(), logits.dim(1));
1489
Tony Mak81e52422019-04-30 09:34:45 +01001490 if (scores.empty()) {
1491 *classification_results = {{Collections::Other(), 1.0}};
1492 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001493 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001494
Tony Mak81e52422019-04-30 09:34:45 +01001495 const int best_score_index =
1496 std::max_element(scores.begin(), scores.end()) - scores.begin();
1497 const std::string top_collection =
1498 classification_feature_processor_->LabelToCollection(best_score_index);
1499
1500 // Sanity checks.
1501 if (top_collection == Collections::Phone()) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001502 const int digit_count = CountDigits(context, selection_indices);
1503 if (digit_count <
1504 model_->classification_options()->phone_min_num_digits() ||
1505 digit_count >
1506 model_->classification_options()->phone_max_num_digits()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001507 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001508 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001509 }
Tony Mak81e52422019-04-30 09:34:45 +01001510 } else if (top_collection == Collections::Address()) {
Lukas Zilka434442d2018-04-25 11:38:51 +02001511 if (selection_num_tokens <
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001512 model_->classification_options()->address_min_num_tokens()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001513 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001514 return true;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001515 }
Tony Mak81e52422019-04-30 09:34:45 +01001516 } else if (top_collection == Collections::Dictionary()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001517 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1518 dictionary_locales_,
Tony Makdf54e742019-03-26 14:04:00 +00001519 /*default_value=*/false)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001520 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001521 return true;
Tony Mak378c1f52019-03-04 15:58:11 +00001522 }
1523 }
Tony Mak81e52422019-04-30 09:34:45 +01001524
1525 *classification_results = {{top_collection, 1.0, scores[best_score_index]}};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001526 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001527}
1528
Tony Mak6c4cc672018-09-17 11:48:50 +01001529bool Annotator::RegexClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001530 const std::string& context, CodepointSpan selection_indices,
Tony Mak378c1f52019-03-04 15:58:11 +00001531 std::vector<ClassificationResult>* classification_result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001532 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001533 UTF8ToUnicodeText(context, /*do_copy=*/false)
1534 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001535 const UnicodeText selection_text_unicode(
1536 UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1537
1538 // Check whether any of the regular expressions match.
1539 for (const int pattern_id : classification_regex_patterns_) {
1540 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1541 const std::unique_ptr<UniLib::RegexMatcher> matcher =
1542 regex_pattern.pattern->Matcher(selection_text_unicode);
1543 int status = UniLib::RegexMatcher::kNoError;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001544 bool matches;
Tony Mak854015a2019-01-16 15:56:48 +00001545 if (regex_pattern.config->use_approximate_matching()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001546 matches = matcher->ApproximatelyMatches(&status);
1547 } else {
1548 matches = matcher->Matches(&status);
1549 }
1550 if (status != UniLib::RegexMatcher::kNoError) {
1551 return false;
1552 }
Tony Makdf54e742019-03-26 14:04:00 +00001553 if (matches && VerifyRegexMatchCandidate(
1554 context, regex_pattern.config->verification_options(),
1555 selection_text, matcher.get())) {
Tony Mak378c1f52019-03-04 15:58:11 +00001556 classification_result->push_back(
1557 {regex_pattern.config->collection_name()->str(),
1558 regex_pattern.config->target_classification_score(),
1559 regex_pattern.config->priority_score()});
Tony Makd9446602019-02-20 18:25:39 +00001560 if (!SerializedEntityDataFromRegexMatch(
1561 regex_pattern.config, matcher.get(),
Tony Mak378c1f52019-03-04 15:58:11 +00001562 &classification_result->back().serialized_entity_data)) {
Tony Makd9446602019-02-20 18:25:39 +00001563 TC3_LOG(ERROR) << "Could not get entity data.";
1564 return false;
1565 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001566 }
1567 }
1568
Tony Mak378c1f52019-03-04 15:58:11 +00001569 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001570}
1571
Tony Mak5dc5e112019-02-01 14:52:10 +00001572namespace {
1573std::string PickCollectionForDatetime(
1574 const DatetimeParseResult& datetime_parse_result) {
1575 switch (datetime_parse_result.granularity) {
1576 case GRANULARITY_HOUR:
1577 case GRANULARITY_MINUTE:
1578 case GRANULARITY_SECOND:
1579 return Collections::DateTime();
1580 default:
1581 return Collections::Date();
1582 }
1583}
Tony Mak83d2de62019-04-10 16:12:15 +01001584
1585std::string CreateDatetimeSerializedEntityData(
1586 const DatetimeParseResult& parse_result) {
1587 EntityDataT entity_data;
1588 entity_data.datetime.reset(new EntityData_::DatetimeT());
1589 entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
1590 entity_data.datetime->granularity =
1591 static_cast<EntityData_::Datetime_::Granularity>(
1592 parse_result.granularity);
1593
Tony Maka2a1ff42019-09-12 15:40:32 +01001594 for (const auto& c : parse_result.datetime_components) {
1595 EntityData_::Datetime_::DatetimeComponentT datetime_component;
1596 datetime_component.absolute_value = c.value;
1597 datetime_component.relative_count = c.relative_count;
1598 datetime_component.component_type =
1599 static_cast<EntityData_::Datetime_::DatetimeComponent_::ComponentType>(
1600 c.component_type);
1601 datetime_component.relation_type =
1602 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE;
1603 if (c.relative_qualifier !=
1604 DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
1605 datetime_component.relation_type =
1606 EntityData_::Datetime_::DatetimeComponent_::RelationType_RELATIVE;
1607 }
1608 entity_data.datetime->datetime_component.emplace_back(
1609 new EntityData_::Datetime_::DatetimeComponentT(datetime_component));
1610 }
Tony Mak83d2de62019-04-10 16:12:15 +01001611 flatbuffers::FlatBufferBuilder builder;
1612 FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
1613 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
1614 builder.GetSize());
1615}
Tony Mak63959242020-02-07 18:31:16 +00001616
Tony Mak5dc5e112019-02-01 14:52:10 +00001617} // namespace
1618
Tony Mak6c4cc672018-09-17 11:48:50 +01001619bool Annotator::DatetimeClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001620 const std::string& context, CodepointSpan selection_indices,
1621 const ClassificationOptions& options,
Tony Mak854015a2019-01-16 15:56:48 +00001622 std::vector<ClassificationResult>* classification_results) const {
Tony Mak63959242020-02-07 18:31:16 +00001623 if (!datetime_parser_ && !cfg_datetime_parser_) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001624 return false;
1625 }
1626
Lukas Zilkab23e2122018-02-09 10:25:19 +01001627 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001628 UTF8ToUnicodeText(context, /*do_copy=*/false)
1629 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001630
1631 std::vector<DatetimeParseResultSpan> datetime_spans;
Tony Mak63959242020-02-07 18:31:16 +00001632 if (cfg_datetime_parser_) {
1633 if (!(model_->grammar_datetime_model()->enabled_modes() &
1634 ModeFlag_CLASSIFICATION)) {
1635 return true;
1636 }
1637 std::vector<Locale> parsed_locales;
1638 ParseLocales(options.locales, &parsed_locales);
Tony Mak21460022020-03-12 18:29:35 +00001639 cfg_datetime_parser_->Parse(
1640 selection_text,
1641 ToDateAnnotationOptions(
1642 model_->grammar_datetime_model()->annotation_options(),
1643 options.reference_timezone, options.reference_time_ms_utc),
1644 parsed_locales, &datetime_spans);
Tony Mak63959242020-02-07 18:31:16 +00001645 } else if (datetime_parser_) {
1646 if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1647 options.reference_timezone, options.locales,
1648 ModeFlag_CLASSIFICATION,
1649 options.annotation_usecase,
1650 /*anchor_start_end=*/true, &datetime_spans)) {
1651 TC3_LOG(ERROR) << "Error during parsing datetime.";
1652 return false;
1653 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001654 }
1655 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
1656 // Only consider the result valid if the selection and extracted datetime
1657 // spans exactly match.
1658 if (std::make_pair(datetime_span.span.first + selection_indices.first,
1659 datetime_span.span.second + selection_indices.first) ==
1660 selection_indices) {
Tony Mak854015a2019-01-16 15:56:48 +00001661 for (const DatetimeParseResult& parse_result : datetime_span.data) {
1662 classification_results->emplace_back(
Tony Mak5dc5e112019-02-01 14:52:10 +00001663 PickCollectionForDatetime(parse_result),
1664 datetime_span.target_classification_score);
Tony Mak854015a2019-01-16 15:56:48 +00001665 classification_results->back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01001666 classification_results->back().serialized_entity_data =
1667 CreateDatetimeSerializedEntityData(parse_result);
Tony Mak378c1f52019-03-04 15:58:11 +00001668 classification_results->back().priority_score =
1669 datetime_span.priority_score;
Tony Mak854015a2019-01-16 15:56:48 +00001670 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001671 return true;
1672 }
1673 }
Tony Mak378c1f52019-03-04 15:58:11 +00001674 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001675}
1676
Tony Mak6c4cc672018-09-17 11:48:50 +01001677std::vector<ClassificationResult> Annotator::ClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001678 const std::string& context, CodepointSpan selection_indices,
1679 const ClassificationOptions& options) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01001680 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001681 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilka21d8c982018-01-24 11:11:20 +01001682 return {};
1683 }
1684
Lukas Zilkaba849e72018-03-08 14:48:21 +01001685 if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1686 return {};
1687 }
1688
Tony Makdf54e742019-03-26 14:04:00 +00001689 std::vector<Locale> detected_text_language_tags;
1690 if (!ParseLocales(options.detected_text_language_tags,
1691 &detected_text_language_tags)) {
1692 TC3_LOG(WARNING)
1693 << "Failed to parse the detected_text_language_tags in options: "
1694 << options.detected_text_language_tags;
1695 }
1696 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1697 model_triggering_locales_,
1698 /*default_value=*/true)) {
1699 return {};
1700 }
1701
Tony Mak968412a2019-11-13 15:39:57 +00001702 if (!IsValidSpanInput(UTF8ToUnicodeText(context, /*do_copy=*/false),
1703 selection_indices)) {
1704 TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
Tony Mak6c4cc672018-09-17 11:48:50 +01001705 << std::get<0>(selection_indices) << " "
1706 << std::get<1>(selection_indices);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001707 return {};
1708 }
1709
Tony Mak378c1f52019-03-04 15:58:11 +00001710 // We'll accumulate a list of candidates, and pick the best candidate in the
1711 // end.
1712 std::vector<AnnotatedSpan> candidates;
1713
Tony Mak6c4cc672018-09-17 11:48:50 +01001714 // Try the knowledge engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001715 // TODO(b/126579108): Propagate error status.
Tony Mak6c4cc672018-09-17 11:48:50 +01001716 ClassificationResult knowledge_result;
Tony Mak63959242020-02-07 18:31:16 +00001717 if (knowledge_engine_ &&
1718 knowledge_engine_->ClassifyText(
1719 context, selection_indices, options.annotation_usecase,
1720 options.location_context, &knowledge_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001721 candidates.push_back({selection_indices, {knowledge_result}});
1722 candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
Tony Mak854015a2019-01-16 15:56:48 +00001723 }
1724
Tony Maka2a1ff42019-09-12 15:40:32 +01001725 AddContactMetadataToKnowledgeClassificationResults(&candidates);
1726
Tony Mak854015a2019-01-16 15:56:48 +00001727 // Try the contact engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001728 // TODO(b/126579108): Propagate error status.
Tony Mak854015a2019-01-16 15:56:48 +00001729 ClassificationResult contact_result;
1730 if (contact_engine_ && contact_engine_->ClassifyText(
1731 context, selection_indices, &contact_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001732 candidates.push_back({selection_indices, {contact_result}});
Tony Mak6c4cc672018-09-17 11:48:50 +01001733 }
1734
Tony Mak76d80962020-01-08 17:30:51 +00001735 // Try the person name engine.
1736 ClassificationResult person_name_result;
1737 if (person_name_engine_ &&
1738 person_name_engine_->ClassifyText(context, selection_indices,
1739 &person_name_result)) {
1740 candidates.push_back({selection_indices, {person_name_result}});
1741 }
1742
Tony Makd9446602019-02-20 18:25:39 +00001743 // Try the installed app engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001744 // TODO(b/126579108): Propagate error status.
Tony Makd9446602019-02-20 18:25:39 +00001745 ClassificationResult installed_app_result;
1746 if (installed_app_engine_ &&
1747 installed_app_engine_->ClassifyText(context, selection_indices,
1748 &installed_app_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001749 candidates.push_back({selection_indices, {installed_app_result}});
Tony Makd9446602019-02-20 18:25:39 +00001750 }
1751
Lukas Zilkab23e2122018-02-09 10:25:19 +01001752 // Try the regular expression models.
Tony Mak378c1f52019-03-04 15:58:11 +00001753 std::vector<ClassificationResult> regex_results;
1754 if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1755 return {};
1756 }
1757 for (const ClassificationResult& result : regex_results) {
1758 candidates.push_back({selection_indices, {result}});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001759 }
1760
Lukas Zilkab23e2122018-02-09 10:25:19 +01001761 // Try the date model.
Tony Mak378c1f52019-03-04 15:58:11 +00001762 //
1763 // DatetimeClassifyText only returns the first result, which can however have
1764 // more interpretations. They are inserted in the candidates as a single
1765 // AnnotatedSpan, so that they get treated together by the conflict resolution
1766 // algorithm.
Tony Mak854015a2019-01-16 15:56:48 +00001767 std::vector<ClassificationResult> datetime_results;
Tony Mak378c1f52019-03-04 15:58:11 +00001768 if (!DatetimeClassifyText(context, selection_indices, options,
1769 &datetime_results)) {
1770 return {};
1771 }
1772 if (!datetime_results.empty()) {
1773 candidates.push_back({selection_indices, std::move(datetime_results)});
Tony Mak448b5862019-03-22 13:36:41 +00001774 candidates.back().source = AnnotatedSpan::Source::DATETIME;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001775 }
1776
Tony Mak378c1f52019-03-04 15:58:11 +00001777 // Try the number annotator.
1778 // TODO(b/126579108): Propagate error status.
1779 ClassificationResult number_annotator_result;
1780 if (number_annotator_ &&
1781 number_annotator_->ClassifyText(
1782 UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
1783 options.annotation_usecase, &number_annotator_result)) {
1784 candidates.push_back({selection_indices, {number_annotator_result}});
1785 }
Lukas Zilkaba849e72018-03-08 14:48:21 +01001786
Tony Makad2e22d2019-03-20 17:35:13 +00001787 // Try the duration annotator.
1788 ClassificationResult duration_annotator_result;
1789 if (duration_annotator_ &&
1790 duration_annotator_->ClassifyText(
1791 UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
1792 options.annotation_usecase, &duration_annotator_result)) {
1793 candidates.push_back({selection_indices, {duration_annotator_result}});
Tony Mak448b5862019-03-22 13:36:41 +00001794 candidates.back().source = AnnotatedSpan::Source::DURATION;
Tony Makad2e22d2019-03-20 17:35:13 +00001795 }
1796
Tony Mak63959242020-02-07 18:31:16 +00001797 // Try the translate annotator.
1798 ClassificationResult translate_annotator_result;
1799 if (translate_annotator_ &&
1800 translate_annotator_->ClassifyText(
1801 UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
1802 options.user_familiar_language_tags, &translate_annotator_result)) {
1803 candidates.push_back({selection_indices, {translate_annotator_result}});
1804 }
1805
Tony Mak21460022020-03-12 18:29:35 +00001806 // Try the grammar model.
1807 ClassificationResult grammar_annotator_result;
1808 if (grammar_annotator_ && grammar_annotator_->ClassifyText(
1809 detected_text_language_tags,
1810 UTF8ToUnicodeText(context, /*do_copy=*/false),
1811 selection_indices, &grammar_annotator_result)) {
1812 candidates.push_back({selection_indices, {grammar_annotator_result}});
1813 }
1814
Tony Mak378c1f52019-03-04 15:58:11 +00001815 // Try the ML model.
1816 //
1817 // The output of the model is considered as an exclusive 1-of-N choice. That's
1818 // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1819 // span for each candidate, like e.g. the regex model.
Lukas Zilkaba849e72018-03-08 14:48:21 +01001820 InterpreterManager interpreter_manager(selection_executor_.get(),
1821 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00001822 std::vector<ClassificationResult> model_results;
1823 std::vector<Token> tokens;
1824 if (!ModelClassifyText(
1825 context, /*cached_tokens=*/{}, detected_text_language_tags,
1826 selection_indices, &interpreter_manager,
1827 /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1828 return {};
1829 }
1830 if (!model_results.empty()) {
1831 candidates.push_back({selection_indices, std::move(model_results)});
1832 }
1833
1834 std::vector<int> candidate_indices;
1835 if (!ResolveConflicts(candidates, context, tokens,
1836 detected_text_language_tags, options.annotation_usecase,
1837 &interpreter_manager, &candidate_indices)) {
1838 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1839 return {};
1840 }
1841
1842 std::vector<ClassificationResult> results;
1843 for (const int i : candidate_indices) {
1844 for (const ClassificationResult& result : candidates[i].classification) {
1845 if (!FilteredForClassification(result)) {
1846 results.push_back(result);
1847 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001848 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001849 }
1850
Tony Mak378c1f52019-03-04 15:58:11 +00001851 // Sort results according to score.
1852 std::sort(results.begin(), results.end(),
1853 [](const ClassificationResult& a, const ClassificationResult& b) {
1854 return a.score > b.score;
1855 });
1856
1857 if (results.empty()) {
Tony Mak81e52422019-04-30 09:34:45 +01001858 results = {{Collections::Other(), 1.0}};
Tony Mak378c1f52019-03-04 15:58:11 +00001859 }
Tony Mak378c1f52019-03-04 15:58:11 +00001860 return results;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001861}
1862
Tony Mak378c1f52019-03-04 15:58:11 +00001863bool Annotator::ModelAnnotate(
1864 const std::string& context,
1865 const std::vector<Locale>& detected_text_language_tags,
1866 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1867 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001868 if (model_->triggering_options() == nullptr ||
1869 !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1870 return true;
1871 }
1872
Tony Makdf54e742019-03-26 14:04:00 +00001873 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1874 ml_model_triggering_locales_,
1875 /*default_value=*/true)) {
1876 return true;
1877 }
1878
Lukas Zilka21d8c982018-01-24 11:11:20 +01001879 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1880 /*do_copy=*/false);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001881 std::vector<UnicodeTextRange> lines;
1882 if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1883 lines.push_back({context_unicode.begin(), context_unicode.end()});
1884 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +01001885 lines = selection_feature_processor_->SplitContext(
1886 context_unicode, selection_feature_processor_->GetOptions()
1887 ->use_pipe_character_for_newline());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001888 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001889
Lukas Zilkaba849e72018-03-08 14:48:21 +01001890 const float min_annotate_confidence =
1891 (model_->triggering_options() != nullptr
1892 ? model_->triggering_options()->min_annotate_confidence()
1893 : 0.f);
1894
Lukas Zilkab23e2122018-02-09 10:25:19 +01001895 for (const UnicodeTextRange& line : lines) {
Tony Mak408c6b82019-03-08 17:57:27 +00001896 FeatureProcessor::EmbeddingCache embedding_cache;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001897 const std::string line_str =
1898 UnicodeText::UTF8Substring(line.first, line.second);
1899
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001900 *tokens = selection_feature_processor_->Tokenize(line_str);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001901 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001902 line_str, {0, std::distance(line.first, line.second)},
1903 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001904 tokens,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001905 /*click_pos=*/nullptr);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001906 const TokenSpan full_line_span = {0, tokens->size()};
Lukas Zilka21d8c982018-01-24 11:11:20 +01001907
Lukas Zilka434442d2018-04-25 11:38:51 +02001908 // TODO(zilka): Add support for greater granularity of this check.
1909 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1910 *tokens, full_line_span)) {
1911 continue;
1912 }
1913
Lukas Zilka21d8c982018-01-24 11:11:20 +01001914 std::unique_ptr<CachedFeatures> cached_features;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001915 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001916 *tokens, full_line_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001917 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1918 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01001919 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001920 selection_feature_processor_->EmbeddingSize() +
1921 selection_feature_processor_->DenseFeaturesCount(),
Lukas Zilka21d8c982018-01-24 11:11:20 +01001922 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001923 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001924 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001925 }
1926
1927 std::vector<TokenSpan> local_chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001928 if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001929 interpreter_manager->SelectionInterpreter(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01001930 *cached_features, &local_chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001931 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001932 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001933 }
1934
1935 const int offset = std::distance(context_unicode.begin(), line.first);
1936 for (const TokenSpan& chunk : local_chunks) {
1937 const CodepointSpan codepoint_span =
1938 selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001939 line_str, TokenSpanToCodepointSpan(*tokens, chunk));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001940
1941 // Skip empty spans.
1942 if (codepoint_span.first != codepoint_span.second) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001943 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +00001944 if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
1945 codepoint_span, interpreter_manager,
1946 &embedding_cache, &classification)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001947 TC3_LOG(ERROR) << "Could not classify text: "
1948 << (codepoint_span.first + offset) << " "
1949 << (codepoint_span.second + offset);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001950 return false;
1951 }
1952
1953 // Do not include the span if it's classified as "other".
1954 if (!classification.empty() && !ClassifiedAsOther(classification) &&
1955 classification[0].score >= min_annotate_confidence) {
1956 AnnotatedSpan result_span;
1957 result_span.span = {codepoint_span.first + offset,
1958 codepoint_span.second + offset};
1959 result_span.classification = std::move(classification);
1960 result->push_back(std::move(result_span));
1961 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001962 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001963 }
1964 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001965 return true;
1966}
1967
Tony Mak6c4cc672018-09-17 11:48:50 +01001968const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
Lukas Zilka434442d2018-04-25 11:38:51 +02001969 return selection_feature_processor_.get();
1970}
1971
Tony Mak6c4cc672018-09-17 11:48:50 +01001972const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
Lukas Zilka434442d2018-04-25 11:38:51 +02001973 const {
1974 return classification_feature_processor_.get();
Lukas Zilkab23e2122018-02-09 10:25:19 +01001975}
1976
Tony Mak6c4cc672018-09-17 11:48:50 +01001977const DatetimeParser* Annotator::DatetimeParserForTests() const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001978 return datetime_parser_.get();
1979}
1980
Tony Mak83d2de62019-04-10 16:12:15 +01001981void Annotator::RemoveNotEnabledEntityTypes(
1982 const EnabledEntityTypes& is_entity_type_enabled,
1983 std::vector<AnnotatedSpan>* annotated_spans) const {
1984 for (AnnotatedSpan& annotated_span : *annotated_spans) {
1985 std::vector<ClassificationResult>& classifications =
1986 annotated_span.classification;
1987 classifications.erase(
1988 std::remove_if(classifications.begin(), classifications.end(),
1989 [&is_entity_type_enabled](
1990 const ClassificationResult& classification_result) {
1991 return !is_entity_type_enabled(
1992 classification_result.collection);
1993 }),
1994 classifications.end());
1995 }
1996 annotated_spans->erase(
1997 std::remove_if(annotated_spans->begin(), annotated_spans->end(),
1998 [](const AnnotatedSpan& annotated_span) {
1999 return annotated_span.classification.empty();
2000 }),
2001 annotated_spans->end());
2002}
2003
Tony Maka2a1ff42019-09-12 15:40:32 +01002004void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2005 std::vector<AnnotatedSpan>* candidates) const {
2006 if (candidates == nullptr || contact_engine_ == nullptr) {
2007 return;
2008 }
2009 for (auto& candidate : *candidates) {
2010 for (auto& classification_result : candidate.classification) {
2011 contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2012 &classification_result);
2013 }
2014 }
2015}
2016
Tony Mak6c4cc672018-09-17 11:48:50 +01002017std::vector<AnnotatedSpan> Annotator::Annotate(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002018 const std::string& context, const AnnotationOptions& options) const {
2019 std::vector<AnnotatedSpan> candidates;
2020
Lukas Zilkaba849e72018-03-08 14:48:21 +01002021 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
2022 return {};
2023 }
2024
Tony Mak854015a2019-01-16 15:56:48 +00002025 const UnicodeText context_unicode =
2026 UTF8ToUnicodeText(context, /*do_copy=*/false);
2027 if (!context_unicode.is_valid()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002028 return {};
2029 }
2030
Tony Mak378c1f52019-03-04 15:58:11 +00002031 std::vector<Locale> detected_text_language_tags;
2032 if (!ParseLocales(options.detected_text_language_tags,
2033 &detected_text_language_tags)) {
2034 TC3_LOG(WARNING)
Tony Makdf54e742019-03-26 14:04:00 +00002035 << "Failed to parse the detected_text_language_tags in options: "
Tony Mak378c1f52019-03-04 15:58:11 +00002036 << options.detected_text_language_tags;
2037 }
Tony Makdf54e742019-03-26 14:04:00 +00002038 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2039 model_triggering_locales_,
2040 /*default_value=*/true)) {
2041 return {};
2042 }
2043
2044 InterpreterManager interpreter_manager(selection_executor_.get(),
2045 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00002046
Lukas Zilkab23e2122018-02-09 10:25:19 +01002047 // Annotate with the selection model.
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002048 std::vector<Token> tokens;
Tony Mak378c1f52019-03-04 15:58:11 +00002049 if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
2050 &tokens, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002051 TC3_LOG(ERROR) << "Couldn't run ModelAnnotate.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002052 return {};
2053 }
2054
2055 // Annotate with the regular expression models.
2056 if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
Tony Mak83d2de62019-04-10 16:12:15 +01002057 annotation_regex_patterns_, &candidates,
2058 options.is_serialized_entity_data_enabled)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002059 TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002060 return {};
2061 }
2062
2063 // Annotate with the datetime model.
Tony Mak83d2de62019-04-10 16:12:15 +01002064 const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2065 if ((is_entity_type_enabled(Collections::Date()) ||
2066 is_entity_type_enabled(Collections::DateTime())) &&
2067 !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002068 options.reference_time_ms_utc, options.reference_timezone,
Tony Makd9446602019-02-20 18:25:39 +00002069 options.locales, ModeFlag_ANNOTATION,
Tony Mak83d2de62019-04-10 16:12:15 +01002070 options.annotation_usecase,
2071 options.is_serialized_entity_data_enabled, &candidates)) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002072 TC3_LOG(ERROR) << "Couldn't run DatetimeChunk.";
Tony Mak6c4cc672018-09-17 11:48:50 +01002073 return {};
2074 }
2075
Tony Maka2a1ff42019-09-12 15:40:32 +01002076 // Annotate with the knowledge engine into a temporary vector.
2077 std::vector<AnnotatedSpan> knowledge_candidates;
2078 if (knowledge_engine_ &&
2079 !knowledge_engine_->Chunk(context, options.annotation_usecase,
Tony Mak63959242020-02-07 18:31:16 +00002080 options.location_context,
Tony Maka2a1ff42019-09-12 15:40:32 +01002081 &knowledge_candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002082 TC3_LOG(ERROR) << "Couldn't run knowledge engine Chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002083 return {};
2084 }
2085
Tony Maka2a1ff42019-09-12 15:40:32 +01002086 AddContactMetadataToKnowledgeClassificationResults(&knowledge_candidates);
2087
2088 // Move the knowledge candidates to the full candidate list, and erase
2089 // knowledge_candidates.
2090 candidates.insert(candidates.end(),
2091 std::make_move_iterator(knowledge_candidates.begin()),
2092 std::make_move_iterator(knowledge_candidates.end()));
2093 knowledge_candidates.clear();
2094
Tony Mak854015a2019-01-16 15:56:48 +00002095 // Annotate with the contact engine.
2096 if (contact_engine_ &&
2097 !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
2098 TC3_LOG(ERROR) << "Couldn't run contact engine Chunk.";
2099 return {};
2100 }
2101
Tony Makd9446602019-02-20 18:25:39 +00002102 // Annotate with the installed app engine.
2103 if (installed_app_engine_ &&
2104 !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
2105 TC3_LOG(ERROR) << "Couldn't run installed app engine Chunk.";
2106 return {};
2107 }
2108
Tony Mak378c1f52019-03-04 15:58:11 +00002109 // Annotate with the number annotator.
2110 if (number_annotator_ != nullptr &&
2111 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
2112 &candidates)) {
Tony Makad2e22d2019-03-20 17:35:13 +00002113 TC3_LOG(ERROR) << "Couldn't run number annotator FindAll.";
2114 return {};
2115 }
2116
2117 // Annotate with the duration annotator.
Tony Mak83d2de62019-04-10 16:12:15 +01002118 if (is_entity_type_enabled(Collections::Duration()) &&
2119 duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +00002120 !duration_annotator_->FindAll(context_unicode, tokens,
2121 options.annotation_usecase, &candidates)) {
Tony Makad2e22d2019-03-20 17:35:13 +00002122 TC3_LOG(ERROR) << "Couldn't run duration annotator FindAll.";
Tony Mak378c1f52019-03-04 15:58:11 +00002123 return {};
2124 }
2125
Tony Mak76d80962020-01-08 17:30:51 +00002126 // Annotate with the person name engine.
2127 if (is_entity_type_enabled(Collections::PersonName()) &&
2128 person_name_engine_ &&
2129 !person_name_engine_->Chunk(context_unicode, tokens, &candidates)) {
2130 TC3_LOG(ERROR) << "Couldn't run person name engine Chunk.";
2131 return {};
2132 }
2133
Tony Mak21460022020-03-12 18:29:35 +00002134 // Annotate with the grammar annotators.
2135 if (grammar_annotator_ != nullptr &&
2136 !grammar_annotator_->Annotate(detected_text_language_tags,
2137 context_unicode, &candidates)) {
2138 TC3_LOG(ERROR) << "Couldn't run grammar annotators.";
2139 return {};
2140 }
2141
Lukas Zilkab23e2122018-02-09 10:25:19 +01002142 // Sort candidates according to their position in the input, so that the next
2143 // code can assume that any connected component of overlapping spans forms a
2144 // contiguous block.
2145 std::sort(candidates.begin(), candidates.end(),
2146 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
2147 return a.span.first < b.span.first;
2148 });
2149
2150 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +00002151 if (!ResolveConflicts(candidates, context, tokens,
2152 detected_text_language_tags, options.annotation_usecase,
2153 &interpreter_manager, &candidate_indices)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002154 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002155 return {};
2156 }
2157
Lukas Zilkab23e2122018-02-09 10:25:19 +01002158 std::vector<AnnotatedSpan> result;
2159 result.reserve(candidate_indices.size());
Tony Mak378c1f52019-03-04 15:58:11 +00002160 AnnotatedSpan aggregated_span;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002161 for (const int i : candidate_indices) {
Tony Mak378c1f52019-03-04 15:58:11 +00002162 if (candidates[i].span != aggregated_span.span) {
2163 if (!aggregated_span.classification.empty()) {
2164 result.push_back(std::move(aggregated_span));
2165 }
2166 aggregated_span =
2167 AnnotatedSpan(candidates[i].span, /*arg_classification=*/{});
Lukas Zilkab23e2122018-02-09 10:25:19 +01002168 }
Tony Mak378c1f52019-03-04 15:58:11 +00002169 if (candidates[i].classification.empty() ||
2170 ClassifiedAsOther(candidates[i].classification) ||
2171 FilteredForAnnotation(candidates[i])) {
2172 continue;
2173 }
2174 for (ClassificationResult& classification : candidates[i].classification) {
2175 aggregated_span.classification.push_back(std::move(classification));
2176 }
2177 }
2178 if (!aggregated_span.classification.empty()) {
2179 result.push_back(std::move(aggregated_span));
2180 }
2181
Tony Mak83d2de62019-04-10 16:12:15 +01002182 // We generate all candidates and remove them later (with the exception of
2183 // date/time/duration entities) because there are complex interdependencies
2184 // between the entity types. E.g., the TLD of an email can be interpreted as a
2185 // URL, but most likely a user of the API does not want such annotations if
2186 // "url" is enabled and "email" is not.
2187 RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2188
Tony Mak378c1f52019-03-04 15:58:11 +00002189 for (AnnotatedSpan& annotated_span : result) {
2190 SortClassificationResults(&annotated_span.classification);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002191 }
2192
Lukas Zilka21d8c982018-01-24 11:11:20 +01002193 return result;
2194}
2195
Tony Mak854015a2019-01-16 15:56:48 +00002196CodepointSpan Annotator::ComputeSelectionBoundaries(
2197 const UniLib::RegexMatcher* match,
2198 const RegexModel_::Pattern* config) const {
2199 if (config->capturing_group() == nullptr) {
2200 // Use first capturing group to specify the selection.
2201 int status = UniLib::RegexMatcher::kNoError;
2202 const CodepointSpan result = {match->Start(1, &status),
2203 match->End(1, &status)};
2204 if (status != UniLib::RegexMatcher::kNoError) {
2205 return {kInvalidIndex, kInvalidIndex};
2206 }
2207 return result;
2208 }
2209
2210 CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2211 const int num_groups = config->capturing_group()->size();
2212 for (int i = 0; i < num_groups; i++) {
2213 if (!config->capturing_group()->Get(i)->extend_selection()) {
2214 continue;
2215 }
2216
2217 int status = UniLib::RegexMatcher::kNoError;
2218 // Check match and adjust bounds.
2219 const int group_start = match->Start(i, &status);
2220 const int group_end = match->End(i, &status);
2221 if (status != UniLib::RegexMatcher::kNoError) {
2222 return {kInvalidIndex, kInvalidIndex};
2223 }
2224 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2225 continue;
2226 }
2227 if (result.first == kInvalidIndex) {
2228 result = {group_start, group_end};
2229 } else {
2230 result.first = std::min(result.first, group_start);
2231 result.second = std::max(result.second, group_end);
2232 }
2233 }
2234 return result;
2235}
2236
Tony Makd9446602019-02-20 18:25:39 +00002237bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
Tony Mak21460022020-03-12 18:29:35 +00002238 if (pattern->serialized_entity_data() != nullptr ||
2239 pattern->entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002240 return true;
2241 }
2242 if (pattern->capturing_group() != nullptr) {
Tony Mak63959242020-02-07 18:31:16 +00002243 for (const CapturingGroup* group : *pattern->capturing_group()) {
Tony Makd9446602019-02-20 18:25:39 +00002244 if (group->entity_field_path() != nullptr) {
2245 return true;
2246 }
Tony Mak21460022020-03-12 18:29:35 +00002247 if (group->serialized_entity_data() != nullptr ||
2248 group->entity_data() != nullptr) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002249 return true;
2250 }
Tony Makd9446602019-02-20 18:25:39 +00002251 }
2252 }
2253 return false;
2254}
2255
2256bool Annotator::SerializedEntityDataFromRegexMatch(
2257 const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2258 std::string* serialized_entity_data) const {
2259 if (!HasEntityData(pattern)) {
2260 serialized_entity_data->clear();
2261 return true;
2262 }
2263 TC3_CHECK(entity_data_builder_ != nullptr);
2264
2265 std::unique_ptr<ReflectiveFlatbuffer> entity_data =
2266 entity_data_builder_->NewRoot();
2267
2268 TC3_CHECK(entity_data != nullptr);
2269
Tony Mak21460022020-03-12 18:29:35 +00002270 // Set fixed entity data.
Tony Makd9446602019-02-20 18:25:39 +00002271 if (pattern->serialized_entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002272 entity_data->MergeFromSerializedFlatbuffer(
2273 StringPiece(pattern->serialized_entity_data()->c_str(),
2274 pattern->serialized_entity_data()->size()));
2275 }
Tony Mak21460022020-03-12 18:29:35 +00002276 if (pattern->entity_data() != nullptr) {
2277 entity_data->MergeFrom(
2278 reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2279 }
Tony Makd9446602019-02-20 18:25:39 +00002280
2281 // Add entity data from rule capturing groups.
2282 if (pattern->capturing_group() != nullptr) {
2283 const int num_groups = pattern->capturing_group()->size();
2284 for (int i = 0; i < num_groups; i++) {
Tony Mak63959242020-02-07 18:31:16 +00002285 const CapturingGroup* group = pattern->capturing_group()->Get(i);
Tony Maka2a1ff42019-09-12 15:40:32 +01002286
2287 // Check whether the group matched.
2288 Optional<std::string> group_match_text =
2289 GetCapturingGroupText(matcher, /*group_id=*/i);
2290 if (!group_match_text.has_value()) {
Tony Makd9446602019-02-20 18:25:39 +00002291 continue;
2292 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002293
Tony Mak21460022020-03-12 18:29:35 +00002294 // Set fixed entity data from capturing group match.
Tony Maka2a1ff42019-09-12 15:40:32 +01002295 if (group->serialized_entity_data() != nullptr) {
2296 entity_data->MergeFromSerializedFlatbuffer(
2297 StringPiece(group->serialized_entity_data()->c_str(),
2298 group->serialized_entity_data()->size()));
2299 }
Tony Mak21460022020-03-12 18:29:35 +00002300 if (group->entity_data() != nullptr) {
2301 entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2302 pattern->entity_data()));
2303 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002304
2305 // Set entity field from capturing group text.
2306 if (group->entity_field_path() != nullptr) {
Tony Mak8cd7ba62019-10-15 15:29:22 +01002307 UnicodeText normalized_group_match_text =
2308 UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2309
2310 // Apply normalization if specified.
2311 if (group->normalization_options() != nullptr) {
2312 normalized_group_match_text =
2313 NormalizeText(unilib_, group->normalization_options(),
2314 normalized_group_match_text);
2315 }
2316
2317 if (!entity_data->ParseAndSet(
2318 group->entity_field_path(),
2319 normalized_group_match_text.ToUTF8String())) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002320 TC3_LOG(ERROR)
2321 << "Could not set entity data from rule capturing group.";
2322 return false;
2323 }
Tony Makd9446602019-02-20 18:25:39 +00002324 }
2325 }
2326 }
2327
2328 *serialized_entity_data = entity_data->Serialize();
2329 return true;
2330}
2331
Tony Mak63959242020-02-07 18:31:16 +00002332UnicodeText RemoveMoneySeparators(
2333 const std::unordered_set<char32>& decimal_separators,
2334 const UnicodeText& amount,
2335 UnicodeText::const_iterator it_decimal_separator) {
2336 UnicodeText whole_amount;
2337 for (auto it = amount.begin();
2338 it != amount.end() && it != it_decimal_separator; ++it) {
2339 if (std::find(decimal_separators.begin(), decimal_separators.end(),
2340 static_cast<char32>(*it)) == decimal_separators.end()) {
2341 whole_amount.push_back(*it);
2342 }
2343 }
2344 return whole_amount;
2345}
2346
2347bool Annotator::ParseAndFillInMoneyAmount(
2348 std::string* serialized_entity_data) const {
2349 std::unique_ptr<EntityDataT> data =
2350 LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2351 *serialized_entity_data);
Tony Mak21460022020-03-12 18:29:35 +00002352 if (data == nullptr || data->money->unnormalized_amount.empty()) {
Tony Mak63959242020-02-07 18:31:16 +00002353 return false;
2354 }
2355
2356 UnicodeText amount =
2357 UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2358 int separator_back_index = 0;
Tony Mak21460022020-03-12 18:29:35 +00002359 auto it_decimal_separator = --amount.end();
Tony Mak63959242020-02-07 18:31:16 +00002360 for (; it_decimal_separator != amount.begin();
2361 --it_decimal_separator, ++separator_back_index) {
2362 if (std::find(money_separators_.begin(), money_separators_.end(),
2363 static_cast<char32>(*it_decimal_separator)) !=
2364 money_separators_.end()) {
2365 break;
2366 }
2367 }
2368
2369 // If there are 3 digits after the last separator, we consider that a
2370 // thousands separator => the number is an int (e.g. 1.234 is considered int).
2371 // If there is no separator in number, also that number is an int.
Tony Mak21460022020-03-12 18:29:35 +00002372 if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
Tony Mak63959242020-02-07 18:31:16 +00002373 it_decimal_separator = amount.end();
2374 }
2375
2376 if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2377 it_decimal_separator),
2378 &data->money->amount_whole_part)) {
2379 TC3_LOG(ERROR) << "Could not parse the money whole part as int32.";
2380 return false;
2381 }
2382 if (it_decimal_separator == amount.end()) {
2383 data->money->amount_decimal_part = 0;
2384 } else {
2385 const int amount_codepoints_size = amount.size_codepoints();
2386 if (!unilib_->ParseInt32(
2387 UnicodeText::Substring(
Tony Mak21460022020-03-12 18:29:35 +00002388 amount, amount_codepoints_size - separator_back_index,
Tony Mak63959242020-02-07 18:31:16 +00002389 amount_codepoints_size, /*do_copy=*/false),
2390 &data->money->amount_decimal_part)) {
2391 TC3_LOG(ERROR) << "Could not parse the money decimal part as int32.";
2392 return false;
2393 }
2394 }
2395
2396 *serialized_entity_data =
2397 PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2398 return true;
2399}
2400
Tony Mak6c4cc672018-09-17 11:48:50 +01002401bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2402 const std::vector<int>& rules,
Tony Mak83d2de62019-04-10 16:12:15 +01002403 std::vector<AnnotatedSpan>* result,
2404 bool is_serialized_entity_data_enabled) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002405 for (int pattern_id : rules) {
2406 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2407 const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2408 if (!matcher) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002409 TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2410 << pattern_id;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002411 return false;
2412 }
2413
2414 int status = UniLib::RegexMatcher::kNoError;
2415 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
Tony Mak854015a2019-01-16 15:56:48 +00002416 if (regex_pattern.config->verification_options()) {
Tony Makdf54e742019-03-26 14:04:00 +00002417 if (!VerifyRegexMatchCandidate(
2418 context_unicode.ToUTF8String(),
2419 regex_pattern.config->verification_options(),
2420 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002421 continue;
2422 }
2423 }
Tony Makd9446602019-02-20 18:25:39 +00002424
2425 std::string serialized_entity_data;
Tony Mak83d2de62019-04-10 16:12:15 +01002426 if (is_serialized_entity_data_enabled) {
2427 if (!SerializedEntityDataFromRegexMatch(
2428 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2429 TC3_LOG(ERROR) << "Could not get entity data.";
2430 return false;
2431 }
Tony Mak63959242020-02-07 18:31:16 +00002432
2433 // Further parsing unnormalized_amount for money into amount_whole_part
2434 // and amount_decimal_part. Can't do this with regexes because we cannot
2435 // have empty groups (amount_decimal_part might be an empty group).
2436 if (regex_pattern.config->collection_name()->str() ==
2437 Collections::Money()) {
2438 if (!ParseAndFillInMoneyAmount(&serialized_entity_data)) {
2439 TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
2440 }
2441 }
Tony Makd9446602019-02-20 18:25:39 +00002442 }
2443
Lukas Zilkab23e2122018-02-09 10:25:19 +01002444 result->emplace_back();
Tony Mak854015a2019-01-16 15:56:48 +00002445
Lukas Zilkab23e2122018-02-09 10:25:19 +01002446 // Selection/annotation regular expressions need to specify a capturing
2447 // group specifying the selection.
Tony Mak854015a2019-01-16 15:56:48 +00002448 result->back().span =
2449 ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2450
Lukas Zilkab23e2122018-02-09 10:25:19 +01002451 result->back().classification = {
Tony Mak854015a2019-01-16 15:56:48 +00002452 {regex_pattern.config->collection_name()->str(),
2453 regex_pattern.config->target_classification_score(),
2454 regex_pattern.config->priority_score()}};
Tony Makd9446602019-02-20 18:25:39 +00002455
2456 result->back().classification[0].serialized_entity_data =
2457 serialized_entity_data;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002458 }
2459 }
2460 return true;
2461}
2462
Tony Mak6c4cc672018-09-17 11:48:50 +01002463bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2464 tflite::Interpreter* selection_interpreter,
2465 const CachedFeatures& cached_features,
2466 std::vector<TokenSpan>* chunks) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01002467 const int max_selection_span =
2468 selection_feature_processor_->GetOptions()->max_selection_span();
Lukas Zilka21d8c982018-01-24 11:11:20 +01002469 // The inference span is the span of interest expanded to include
2470 // max_selection_span tokens on either side, which is how far a selection can
2471 // stretch from the click.
2472 const TokenSpan inference_span = IntersectTokenSpans(
2473 ExpandTokenSpan(span_of_interest,
2474 /*num_tokens_left=*/max_selection_span,
2475 /*num_tokens_right=*/max_selection_span),
2476 {0, num_tokens});
2477
2478 std::vector<ScoredChunk> scored_chunks;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002479 if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2480 selection_feature_processor_->GetOptions()
2481 ->bounds_sensitive_features()
2482 ->enabled()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002483 if (!ModelBoundsSensitiveScoreChunks(
2484 num_tokens, span_of_interest, inference_span, cached_features,
2485 selection_interpreter, &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002486 return false;
2487 }
2488 } else {
2489 if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002490 cached_features, selection_interpreter,
2491 &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002492 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002493 }
2494 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002495 std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
2496 [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2497 return lhs.score < rhs.score;
2498 });
Lukas Zilka21d8c982018-01-24 11:11:20 +01002499
2500 // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2501 // them greedily as long as they do not overlap with any previously picked
2502 // chunks.
2503 std::vector<bool> token_used(TokenSpanSize(inference_span));
2504 chunks->clear();
2505 for (const ScoredChunk& scored_chunk : scored_chunks) {
2506 bool feasible = true;
2507 for (int i = scored_chunk.token_span.first;
2508 i < scored_chunk.token_span.second; ++i) {
2509 if (token_used[i - inference_span.first]) {
2510 feasible = false;
2511 break;
2512 }
2513 }
2514
2515 if (!feasible) {
2516 continue;
2517 }
2518
2519 for (int i = scored_chunk.token_span.first;
2520 i < scored_chunk.token_span.second; ++i) {
2521 token_used[i - inference_span.first] = true;
2522 }
2523
2524 chunks->push_back(scored_chunk.token_span);
2525 }
2526
2527 std::sort(chunks->begin(), chunks->end());
2528
2529 return true;
2530}
2531
namespace {
// Updates the value at the given key in the map to maximum of the current value
// and the given value, or simply inserts the value if the key is not yet there.
//
// Uses a single insert: if the key is new the value is stored directly, and
// otherwise the returned iterator points at the existing entry. This avoids a
// second lookup and does not require the mapped type to be
// default-constructible.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto insertion = map->insert(typename Map::value_type(key, value));
  if (!insertion.second) {
    // Key was already present: keep the larger of the two values.
    auto& current = insertion.first->second;
    current = std::max(current, value);
  }
}
}  // namespace
2546
Tony Mak6c4cc672018-09-17 11:48:50 +01002547bool Annotator::ModelClickContextScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002548 int num_tokens, const TokenSpan& span_of_interest,
2549 const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002550 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002551 std::vector<ScoredChunk>* scored_chunks) const {
2552 const int max_batch_size = model_->selection_options()->batch_size();
2553
2554 std::vector<float> all_features;
2555 std::map<TokenSpan, float> chunk_scores;
2556 for (int batch_start = span_of_interest.first;
2557 batch_start < span_of_interest.second; batch_start += max_batch_size) {
2558 const int batch_end =
2559 std::min(batch_start + max_batch_size, span_of_interest.second);
2560
2561 // Prepare features for the whole batch.
2562 all_features.clear();
2563 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2564 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2565 cached_features.AppendClickContextFeaturesForClick(click_pos,
2566 &all_features);
2567 }
2568
2569 // Run batched inference.
2570 const int batch_size = batch_end - batch_start;
2571 const int features_size = cached_features.OutputFeaturesSize();
2572 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01002573 TensorView<float>(all_features.data(), {batch_size, features_size}),
2574 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002575 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002576 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002577 return false;
2578 }
2579 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2580 logits.dim(1) !=
2581 selection_feature_processor_->GetSelectionLabelCount()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002582 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002583 return false;
2584 }
2585
2586 // Save results.
2587 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2588 const std::vector<float> scores = ComputeSoftmax(
2589 logits.data() + logits.dim(1) * (click_pos - batch_start),
2590 logits.dim(1));
2591 for (int j = 0;
2592 j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
2593 TokenSpan relative_token_span;
2594 if (!selection_feature_processor_->LabelToTokenSpan(
2595 j, &relative_token_span)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002596 TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002597 return false;
2598 }
2599 const TokenSpan candidate_span = ExpandTokenSpan(
2600 SingleTokenSpan(click_pos), relative_token_span.first,
2601 relative_token_span.second);
2602 if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
2603 UpdateMax(&chunk_scores, candidate_span, scores[j]);
2604 }
2605 }
2606 }
2607 }
2608
2609 scored_chunks->clear();
2610 scored_chunks->reserve(chunk_scores.size());
2611 for (const auto& entry : chunk_scores) {
2612 scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
2613 }
2614
2615 return true;
2616}
2617
Tony Mak6c4cc672018-09-17 11:48:50 +01002618bool Annotator::ModelBoundsSensitiveScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002619 int num_tokens, const TokenSpan& span_of_interest,
2620 const TokenSpan& inference_span, const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002621 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002622 std::vector<ScoredChunk>* scored_chunks) const {
2623 const int max_selection_span =
2624 selection_feature_processor_->GetOptions()->max_selection_span();
2625 const int max_chunk_length = selection_feature_processor_->GetOptions()
2626 ->selection_reduced_output_space()
2627 ? max_selection_span + 1
2628 : 2 * max_selection_span + 1;
Lukas Zilkaba849e72018-03-08 14:48:21 +01002629 const bool score_single_token_spans_as_zero =
2630 selection_feature_processor_->GetOptions()
2631 ->bounds_sensitive_features()
2632 ->score_single_token_spans_as_zero();
2633
2634 scored_chunks->clear();
2635 if (score_single_token_spans_as_zero) {
2636 scored_chunks->reserve(TokenSpanSize(span_of_interest));
2637 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002638
2639 // Prepare all chunk candidates into one batch:
2640 // - Are contained in the inference span
2641 // - Have a non-empty intersection with the span of interest
2642 // - Are at least one token long
2643 // - Are not longer than the maximum chunk length
2644 std::vector<TokenSpan> candidate_spans;
2645 for (int start = inference_span.first; start < span_of_interest.second;
2646 ++start) {
2647 const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
2648 for (int end = leftmost_end_index;
2649 end <= inference_span.second && end - start <= max_chunk_length;
2650 ++end) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002651 const TokenSpan candidate_span = {start, end};
2652 if (score_single_token_spans_as_zero &&
2653 TokenSpanSize(candidate_span) == 1) {
2654 // Do not include the single token span in the batch, add a zero score
2655 // for it directly to the output.
2656 scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
2657 } else {
2658 candidate_spans.push_back(candidate_span);
2659 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002660 }
2661 }
2662
2663 const int max_batch_size = model_->selection_options()->batch_size();
2664
2665 std::vector<float> all_features;
Lukas Zilkaba849e72018-03-08 14:48:21 +01002666 scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
Lukas Zilkab23e2122018-02-09 10:25:19 +01002667 for (int batch_start = 0; batch_start < candidate_spans.size();
2668 batch_start += max_batch_size) {
2669 const int batch_end = std::min(batch_start + max_batch_size,
2670 static_cast<int>(candidate_spans.size()));
2671
2672 // Prepare features for the whole batch.
2673 all_features.clear();
2674 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2675 for (int i = batch_start; i < batch_end; ++i) {
2676 cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
2677 &all_features);
2678 }
2679
2680 // Run batched inference.
2681 const int batch_size = batch_end - batch_start;
2682 const int features_size = cached_features.OutputFeaturesSize();
2683 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01002684 TensorView<float>(all_features.data(), {batch_size, features_size}),
2685 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002686 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002687 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002688 return false;
2689 }
2690 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2691 logits.dim(1) != 1) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002692 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002693 return false;
2694 }
2695
2696 // Save results.
2697 for (int i = batch_start; i < batch_end; ++i) {
2698 scored_chunks->push_back(
2699 ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
2700 }
2701 }
2702
2703 return true;
2704}
2705
Tony Mak6c4cc672018-09-17 11:48:50 +01002706bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
2707 int64 reference_time_ms_utc,
2708 const std::string& reference_timezone,
2709 const std::string& locales, ModeFlag mode,
Tony Makd9446602019-02-20 18:25:39 +00002710 AnnotationUsecase annotation_usecase,
Tony Mak83d2de62019-04-10 16:12:15 +01002711 bool is_serialized_entity_data_enabled,
Tony Mak6c4cc672018-09-17 11:48:50 +01002712 std::vector<AnnotatedSpan>* result) const {
Tony Mak63959242020-02-07 18:31:16 +00002713 std::vector<DatetimeParseResultSpan> datetime_spans;
2714 if (cfg_datetime_parser_) {
2715 if (!(model_->grammar_datetime_model()->enabled_modes() & mode)) {
2716 return true;
2717 }
2718 std::vector<Locale> parsed_locales;
2719 ParseLocales(locales, &parsed_locales);
Tony Mak21460022020-03-12 18:29:35 +00002720 cfg_datetime_parser_->Parse(
2721 context_unicode.ToUTF8String(),
2722 ToDateAnnotationOptions(
2723 model_->grammar_datetime_model()->annotation_options(),
2724 reference_timezone, reference_time_ms_utc),
2725 parsed_locales, &datetime_spans);
Tony Mak63959242020-02-07 18:31:16 +00002726 } else if (datetime_parser_) {
2727 if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
2728 reference_timezone, locales, mode,
2729 annotation_usecase,
2730 /*anchor_start_end=*/false, &datetime_spans)) {
2731 return false;
2732 }
2733 } else {
Lukas Zilka434442d2018-04-25 11:38:51 +02002734 return true;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002735 }
2736
Lukas Zilkab23e2122018-02-09 10:25:19 +01002737 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
Tony Mak378c1f52019-03-04 15:58:11 +00002738 AnnotatedSpan annotated_span;
2739 annotated_span.span = datetime_span.span;
Tony Mak854015a2019-01-16 15:56:48 +00002740 for (const DatetimeParseResult& parse_result : datetime_span.data) {
Tony Mak378c1f52019-03-04 15:58:11 +00002741 annotated_span.classification.emplace_back(
2742 PickCollectionForDatetime(parse_result),
2743 datetime_span.target_classification_score,
2744 datetime_span.priority_score);
2745 annotated_span.classification.back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01002746 if (is_serialized_entity_data_enabled) {
2747 annotated_span.classification.back().serialized_entity_data =
2748 CreateDatetimeSerializedEntityData(parse_result);
2749 }
Tony Mak854015a2019-01-16 15:56:48 +00002750 }
Tony Mak448b5862019-03-22 13:36:41 +00002751 annotated_span.source = AnnotatedSpan::Source::DATETIME;
Tony Mak378c1f52019-03-04 15:58:11 +00002752 result->push_back(std::move(annotated_span));
Lukas Zilkab23e2122018-02-09 10:25:19 +01002753 }
2754 return true;
2755}
2756
Tony Mak378c1f52019-03-04 15:58:11 +00002757const Model* Annotator::model() const { return model_; }
Tony Makd9446602019-02-20 18:25:39 +00002758const reflection::Schema* Annotator::entity_data_schema() const {
2759 return entity_data_schema_;
2760}
Tony Mak854015a2019-01-16 15:56:48 +00002761
Lukas Zilka21d8c982018-01-24 11:11:20 +01002762const Model* ViewModel(const void* buffer, int size) {
2763 if (!buffer) {
2764 return nullptr;
2765 }
2766
2767 return LoadAndVerifyModel(buffer, size);
2768}
2769
Tony Makd9446602019-02-20 18:25:39 +00002770bool Annotator::LookUpKnowledgeEntity(
2771 const std::string& id, std::string* serialized_knowledge_result) const {
2772 return knowledge_engine_ &&
2773 knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
2774}
2775
Tony Mak6c4cc672018-09-17 11:48:50 +01002776} // namespace libtextclassifier3