blob: 93a3270a35fd7c71e8483299dda7f87e3be38d87 [file] [log] [blame]
Lukas Zilka21d8c982018-01-24 11:11:20 +01001/*
Tony Mak6c4cc672018-09-17 11:48:50 +01002 * Copyright (C) 2018 The Android Open Source Project
Lukas Zilka21d8c982018-01-24 11:11:20 +01003 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Tony Mak6c4cc672018-09-17 11:48:50 +010017#include "annotator/annotator.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010018
19#include <algorithm>
Lukas Zilka21d8c982018-01-24 11:11:20 +010020#include <cmath>
Tony Mak21460022020-03-12 18:29:35 +000021#include <cstddef>
Lukas Zilka21d8c982018-01-24 11:11:20 +010022#include <iterator>
23#include <numeric>
Tony Mak63959242020-02-07 18:31:16 +000024#include <string>
Tony Mak448b5862019-03-22 13:36:41 +000025#include <unordered_map>
Tony Mak63959242020-02-07 18:31:16 +000026#include <vector>
Lukas Zilka21d8c982018-01-24 11:11:20 +010027
Tony Mak854015a2019-01-16 15:56:48 +000028#include "annotator/collections.h"
Tony Mak83d2de62019-04-10 16:12:15 +010029#include "annotator/model_generated.h"
30#include "annotator/types.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010031#include "utils/base/logging.h"
Tony Makff31efb2020-03-31 11:13:06 +010032#include "utils/base/status.h"
33#include "utils/base/statusor.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010034#include "utils/checksum.h"
Tony Mak63959242020-02-07 18:31:16 +000035#include "utils/i18n/locale.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010036#include "utils/math/softmax.h"
Tony Mak8cd7ba62019-10-15 15:29:22 +010037#include "utils/normalization.h"
Tony Maka2a1ff42019-09-12 15:40:32 +010038#include "utils/optional.h"
Tony Makd9446602019-02-20 18:25:39 +000039#include "utils/regex-match.h"
Tony Mak63959242020-02-07 18:31:16 +000040#include "utils/strings/numbers.h"
41#include "utils/strings/split.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010042#include "utils/utf8/unicodetext.h"
Tony Mak21460022020-03-12 18:29:35 +000043#include "utils/utf8/unilib-common.h"
Tony Mak378c1f52019-03-04 15:58:11 +000044#include "utils/zlib/zlib_regex.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010045
Tony Mak6c4cc672018-09-17 11:48:50 +010046namespace libtextclassifier3 {
Tony Mak448b5862019-03-22 13:36:41 +000047
48using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
49
Tony Mak6c4cc672018-09-17 11:48:50 +010050const std::string& Annotator::kPhoneCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010051 *[]() { return new std::string("phone"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010052const std::string& Annotator::kAddressCollection =
Lukas Zilkae7962cc2018-03-28 18:09:48 +020053 *[]() { return new std::string("address"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010054const std::string& Annotator::kDateCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010055 *[]() { return new std::string("date"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000056const std::string& Annotator::kUrlCollection =
57 *[]() { return new std::string("url"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000058const std::string& Annotator::kEmailCollection =
59 *[]() { return new std::string("email"); }();
Lukas Zilkab23e2122018-02-09 10:25:19 +010060
Lukas Zilka21d8c982018-01-24 11:11:20 +010061namespace {
62const Model* LoadAndVerifyModel(const void* addr, int size) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010063 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
Tony Mak51a9e542018-11-02 13:36:22 +000064 if (VerifyModelBuffer(verifier)) {
65 return GetModel(addr);
Lukas Zilka21d8c982018-01-24 11:11:20 +010066 } else {
67 return nullptr;
68 }
69}
Tony Mak6c4cc672018-09-17 11:48:50 +010070
Tony Mak76d80962020-01-08 17:30:51 +000071const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
72 int size) {
73 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
74 if (VerifyPersonNameModelBuffer(verifier)) {
75 return GetPersonNameModel(addr);
76 } else {
77 return nullptr;
78 }
79}
80
Tony Mak6c4cc672018-09-17 11:48:50 +010081// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
82// create a new instance, assign ownership to owned_lib, and return it.
83const UniLib* MaybeCreateUnilib(const UniLib* lib,
84 std::unique_ptr<UniLib>* owned_lib) {
85 if (lib) {
86 return lib;
87 } else {
88 owned_lib->reset(new UniLib);
89 return owned_lib->get();
90 }
91}
92
93// As above, but for CalendarLib.
94const CalendarLib* MaybeCreateCalendarlib(
95 const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
96 if (lib) {
97 return lib;
98 } else {
99 owned_lib->reset(new CalendarLib);
100 return owned_lib->get();
101 }
102}
103
Tony Mak968412a2019-11-13 15:39:57 +0000104// Returns whether the provided input is valid:
105// * Valid utf8 text.
106// * Sane span indices.
107bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan span) {
108 if (!context.is_valid()) {
109 return false;
110 }
111 return (span.first >= 0 && span.first < span.second &&
112 span.second <= context.size_codepoints());
113}
114
Tony Mak63959242020-02-07 18:31:16 +0000115std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
116 const flatbuffers::Vector<int32_t>* ints) {
117 if (ints == nullptr) {
118 return {};
119 }
120 std::unordered_set<char32> ints_set;
121 for (auto value : *ints) {
122 ints_set.insert(static_cast<char32>(value));
123 }
124 return ints_set;
125}
126
Tony Mak21460022020-03-12 18:29:35 +0000127DateAnnotationOptions ToDateAnnotationOptions(
128 const GrammarDatetimeModel_::AnnotationOptions* fb_annotation_options,
129 const std::string& reference_timezone, const int64 reference_time_ms_utc) {
130 DateAnnotationOptions result_annotation_options;
131 result_annotation_options.base_timestamp_millis = reference_time_ms_utc;
132 result_annotation_options.reference_timezone = reference_timezone;
133 if (fb_annotation_options != nullptr) {
134 result_annotation_options.enable_special_day_offset =
135 fb_annotation_options->enable_special_day_offset();
136 result_annotation_options.merge_adjacent_components =
137 fb_annotation_options->merge_adjacent_components();
138 result_annotation_options.enable_date_range =
139 fb_annotation_options->enable_date_range();
140 result_annotation_options.include_preposition =
141 fb_annotation_options->include_preposition();
Tony Mak21460022020-03-12 18:29:35 +0000142 if (fb_annotation_options->extra_requested_dates() != nullptr) {
143 for (const auto& extra_requested_date :
144 *fb_annotation_options->extra_requested_dates()) {
145 result_annotation_options.extra_requested_dates.push_back(
146 extra_requested_date->str());
147 }
148 }
Tony Makd99d58c2020-03-19 21:52:02 +0000149 if (fb_annotation_options->ignored_spans() != nullptr) {
150 for (const auto& ignored_span : *fb_annotation_options->ignored_spans()) {
151 result_annotation_options.ignored_spans.push_back(ignored_span->str());
Tony Mak0b8b3322020-03-17 16:30:19 +0000152 }
153 }
Tony Mak21460022020-03-12 18:29:35 +0000154 }
155 return result_annotation_options;
156}
157
Lukas Zilka21d8c982018-01-24 11:11:20 +0100158} // namespace
159
Lukas Zilkaba849e72018-03-08 14:48:21 +0100160tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
161 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100162 TC3_CHECK(selection_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100163 selection_interpreter_ = selection_executor_->CreateInterpreter();
164 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100165 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100166 }
167 }
168 return selection_interpreter_.get();
169}
170
171tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
172 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100173 TC3_CHECK(classification_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100174 classification_interpreter_ = classification_executor_->CreateInterpreter();
175 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100176 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100177 }
178 }
179 return classification_interpreter_.get();
180}
181
Tony Mak6c4cc672018-09-17 11:48:50 +0100182std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
183 const char* buffer, int size, const UniLib* unilib,
184 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100185 const Model* model = LoadAndVerifyModel(buffer, size);
186 if (model == nullptr) {
187 return nullptr;
188 }
189
Lukas Zilkab23e2122018-02-09 10:25:19 +0100190 auto classifier =
Tony Mak6c4cc672018-09-17 11:48:50 +0100191 std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100192 if (!classifier->IsInitialized()) {
193 return nullptr;
194 }
195
196 return classifier;
197}
198
Tony Mak6c4cc672018-09-17 11:48:50 +0100199std::unique_ptr<Annotator> Annotator::FromScopedMmap(
200 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
201 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100202 if (!(*mmap)->handle().ok()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100203 TC3_VLOG(1) << "Mmap failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100204 return nullptr;
205 }
206
207 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
208 (*mmap)->handle().num_bytes());
209 if (!model) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100210 TC3_LOG(ERROR) << "Model verification failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100211 return nullptr;
212 }
213
Tony Mak6c4cc672018-09-17 11:48:50 +0100214 auto classifier = std::unique_ptr<Annotator>(
215 new Annotator(mmap, model, unilib, calendarlib));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100216 if (!classifier->IsInitialized()) {
217 return nullptr;
218 }
219
220 return classifier;
221}
222
Tony Makdf54e742019-03-26 14:04:00 +0000223std::unique_ptr<Annotator> Annotator::FromScopedMmap(
224 std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
225 std::unique_ptr<CalendarLib> calendarlib) {
226 if (!(*mmap)->handle().ok()) {
227 TC3_VLOG(1) << "Mmap failed.";
228 return nullptr;
229 }
230
231 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
232 (*mmap)->handle().num_bytes());
233 if (model == nullptr) {
234 TC3_LOG(ERROR) << "Model verification failed.";
235 return nullptr;
236 }
237
238 auto classifier = std::unique_ptr<Annotator>(
239 new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
240 if (!classifier->IsInitialized()) {
241 return nullptr;
242 }
243
244 return classifier;
245}
246
Tony Mak6c4cc672018-09-17 11:48:50 +0100247std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
248 int fd, int offset, int size, const UniLib* unilib,
249 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100250 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
Tony Mak6c4cc672018-09-17 11:48:50 +0100251 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100252}
253
Tony Mak6c4cc672018-09-17 11:48:50 +0100254std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Makdf54e742019-03-26 14:04:00 +0000255 int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
256 std::unique_ptr<CalendarLib> calendarlib) {
257 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
258 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
259}
260
261std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Mak6c4cc672018-09-17 11:48:50 +0100262 int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100263 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
Tony Mak6c4cc672018-09-17 11:48:50 +0100264 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100265}
266
Tony Makdf54e742019-03-26 14:04:00 +0000267std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
268 int fd, std::unique_ptr<UniLib> unilib,
269 std::unique_ptr<CalendarLib> calendarlib) {
270 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
271 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
272}
273
Tony Mak6c4cc672018-09-17 11:48:50 +0100274std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
275 const UniLib* unilib,
276 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100277 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
Tony Mak6c4cc672018-09-17 11:48:50 +0100278 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100279}
280
Tony Makdf54e742019-03-26 14:04:00 +0000281std::unique_ptr<Annotator> Annotator::FromPath(
282 const std::string& path, std::unique_ptr<UniLib> unilib,
283 std::unique_ptr<CalendarLib> calendarlib) {
284 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
285 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
286}
287
Tony Mak6c4cc672018-09-17 11:48:50 +0100288Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
289 const UniLib* unilib, const CalendarLib* calendarlib)
290 : model_(model),
291 mmap_(std::move(*mmap)),
292 owned_unilib_(nullptr),
293 unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
294 owned_calendarlib_(nullptr),
295 calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
296 ValidateAndInitialize();
297}
298
Tony Makdf54e742019-03-26 14:04:00 +0000299Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
300 std::unique_ptr<UniLib> unilib,
301 std::unique_ptr<CalendarLib> calendarlib)
302 : model_(model),
303 mmap_(std::move(*mmap)),
304 owned_unilib_(std::move(unilib)),
305 unilib_(owned_unilib_.get()),
306 owned_calendarlib_(std::move(calendarlib)),
307 calendarlib_(owned_calendarlib_.get()) {
308 ValidateAndInitialize();
309}
310
Tony Mak6c4cc672018-09-17 11:48:50 +0100311Annotator::Annotator(const Model* model, const UniLib* unilib,
312 const CalendarLib* calendarlib)
313 : model_(model),
314 owned_unilib_(nullptr),
315 unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
316 owned_calendarlib_(nullptr),
317 calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
318 ValidateAndInitialize();
319}
320
321void Annotator::ValidateAndInitialize() {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100322 initialized_ = false;
323
Lukas Zilka21d8c982018-01-24 11:11:20 +0100324 if (model_ == nullptr) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100325 TC3_LOG(ERROR) << "No model specified.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100326 return;
327 }
328
Lukas Zilkaba849e72018-03-08 14:48:21 +0100329 const bool model_enabled_for_annotation =
330 (model_->triggering_options() != nullptr &&
331 (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
332 const bool model_enabled_for_classification =
333 (model_->triggering_options() != nullptr &&
334 (model_->triggering_options()->enabled_modes() &
335 ModeFlag_CLASSIFICATION));
336 const bool model_enabled_for_selection =
337 (model_->triggering_options() != nullptr &&
338 (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
339
340 // Annotation requires the selection model.
341 if (model_enabled_for_annotation || model_enabled_for_selection) {
342 if (!model_->selection_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100343 TC3_LOG(ERROR) << "No selection options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100344 return;
345 }
346 if (!model_->selection_feature_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100347 TC3_LOG(ERROR) << "No selection feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100348 return;
349 }
350 if (!model_->selection_feature_options()->bounds_sensitive_features()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100351 TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100352 return;
353 }
354 if (!model_->selection_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100355 TC3_LOG(ERROR) << "No selection model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100356 return;
357 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100358 selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100359 if (!selection_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100360 TC3_LOG(ERROR) << "Could not initialize selection executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100361 return;
362 }
363 selection_feature_processor_.reset(
364 new FeatureProcessor(model_->selection_feature_options(), unilib_));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100365 }
366
Lukas Zilkaba849e72018-03-08 14:48:21 +0100367 // Annotation requires the classification model for conflict resolution and
368 // scoring.
369 // Selection requires the classification model for conflict resolution.
370 if (model_enabled_for_annotation || model_enabled_for_classification ||
371 model_enabled_for_selection) {
372 if (!model_->classification_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100373 TC3_LOG(ERROR) << "No classification options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100374 return;
375 }
376
377 if (!model_->classification_feature_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100378 TC3_LOG(ERROR) << "No classification feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100379 return;
380 }
381
382 if (!model_->classification_feature_options()
383 ->bounds_sensitive_features()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100384 TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100385 return;
386 }
387 if (!model_->classification_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100388 TC3_LOG(ERROR) << "No clf model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100389 return;
390 }
391
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200392 classification_executor_ =
Tony Mak6c4cc672018-09-17 11:48:50 +0100393 ModelExecutor::FromBuffer(model_->classification_model());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100394 if (!classification_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100395 TC3_LOG(ERROR) << "Could not initialize classification executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100396 return;
397 }
398
399 classification_feature_processor_.reset(new FeatureProcessor(
400 model_->classification_feature_options(), unilib_));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100401 }
402
Lukas Zilkaba849e72018-03-08 14:48:21 +0100403 // The embeddings need to be specified if the model is to be used for
404 // classification or selection.
405 if (model_enabled_for_annotation || model_enabled_for_classification ||
406 model_enabled_for_selection) {
407 if (!model_->embedding_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100408 TC3_LOG(ERROR) << "No embedding model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100409 return;
410 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100411
Lukas Zilkaba849e72018-03-08 14:48:21 +0100412 // Check that the embedding size of the selection and classification model
413 // matches, as they are using the same embeddings.
414 if (model_enabled_for_selection &&
415 (model_->selection_feature_options()->embedding_size() !=
416 model_->classification_feature_options()->embedding_size() ||
417 model_->selection_feature_options()->embedding_quantization_bits() !=
418 model_->classification_feature_options()
419 ->embedding_quantization_bits())) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100420 TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100421 return;
422 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100423
Tony Mak6c4cc672018-09-17 11:48:50 +0100424 embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200425 model_->embedding_model(),
Lukas Zilkaba849e72018-03-08 14:48:21 +0100426 model_->classification_feature_options()->embedding_size(),
Tony Makdf54e742019-03-26 14:04:00 +0000427 model_->classification_feature_options()->embedding_quantization_bits(),
428 model_->embedding_pruning_mask());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200429 if (!embedding_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100430 TC3_LOG(ERROR) << "Could not initialize embedding executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100431 return;
432 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100433 }
434
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200435 std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
Lukas Zilkab23e2122018-02-09 10:25:19 +0100436 if (model_->regex_model()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200437 if (!InitializeRegexModel(decompressor.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100438 TC3_LOG(ERROR) << "Could not initialize regex model.";
Lukas Zilka434442d2018-04-25 11:38:51 +0200439 return;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100440 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100441 }
Tony Mak63959242020-02-07 18:31:16 +0000442 if (model_->grammar_datetime_model() &&
443 model_->grammar_datetime_model()->datetime_rules()) {
444 cfg_datetime_parser_.reset(new dates::CfgDatetimeAnnotator(
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100445 unilib_,
Tony Mak63959242020-02-07 18:31:16 +0000446 /*tokenizer_options=*/
447 model_->grammar_datetime_model()->grammar_tokenizer_options(),
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100448 calendarlib_,
Tony Mak21460022020-03-12 18:29:35 +0000449 /*datetime_rules=*/model_->grammar_datetime_model()->datetime_rules(),
450 model_->grammar_datetime_model()->target_classification_score(),
451 model_->grammar_datetime_model()->priority_score()));
Tony Mak63959242020-02-07 18:31:16 +0000452 if (!cfg_datetime_parser_) {
453 TC3_LOG(ERROR) << "Could not initialize context free grammar based "
454 "datetime parser.";
455 return;
456 }
Tony Makd99d58c2020-03-19 21:52:02 +0000457 }
458
459 if (model_->datetime_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100460 datetime_parser_ = DatetimeParser::Instance(
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100461 model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100462 if (!datetime_parser_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100463 TC3_LOG(ERROR) << "Could not initialize datetime parser.";
Lukas Zilkab23e2122018-02-09 10:25:19 +0100464 return;
465 }
466 }
467
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200468 if (model_->output_options()) {
469 if (model_->output_options()->filtered_collections_annotation()) {
470 for (const auto collection :
471 *model_->output_options()->filtered_collections_annotation()) {
472 filtered_collections_annotation_.insert(collection->str());
473 }
474 }
475 if (model_->output_options()->filtered_collections_classification()) {
476 for (const auto collection :
477 *model_->output_options()->filtered_collections_classification()) {
478 filtered_collections_classification_.insert(collection->str());
479 }
480 }
481 if (model_->output_options()->filtered_collections_selection()) {
482 for (const auto collection :
483 *model_->output_options()->filtered_collections_selection()) {
484 filtered_collections_selection_.insert(collection->str());
485 }
486 }
487 }
488
Tony Mak378c1f52019-03-04 15:58:11 +0000489 if (model_->number_annotator_options() &&
490 model_->number_annotator_options()->enabled()) {
491 number_annotator_.reset(
Tony Mak63959242020-02-07 18:31:16 +0000492 new NumberAnnotator(model_->number_annotator_options(), unilib_));
493 }
494
495 if (model_->money_parsing_options()) {
496 money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
497 model_->money_parsing_options()->separators());
Tony Mak378c1f52019-03-04 15:58:11 +0000498 }
499
Tony Makad2e22d2019-03-20 17:35:13 +0000500 if (model_->duration_annotator_options() &&
501 model_->duration_annotator_options()->enabled()) {
502 duration_annotator_.reset(
503 new DurationAnnotator(model_->duration_annotator_options(),
Tony Mak8cd7ba62019-10-15 15:29:22 +0100504 selection_feature_processor_.get(), unilib_));
Tony Makad2e22d2019-03-20 17:35:13 +0000505 }
506
Tony Makd9446602019-02-20 18:25:39 +0000507 if (model_->entity_data_schema()) {
508 entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
509 model_->entity_data_schema()->Data(),
510 model_->entity_data_schema()->size());
511 if (entity_data_schema_ == nullptr) {
512 TC3_LOG(ERROR) << "Could not load entity data schema data.";
513 return;
514 }
515
516 entity_data_builder_.reset(
517 new ReflectiveFlatbufferBuilder(entity_data_schema_));
518 } else {
Tony Mak378c1f52019-03-04 15:58:11 +0000519 entity_data_schema_ = nullptr;
Tony Makd9446602019-02-20 18:25:39 +0000520 entity_data_builder_ = nullptr;
521 }
522
Tony Mak21460022020-03-12 18:29:35 +0000523 if (model_->grammar_model()) {
524 grammar_annotator_.reset(new GrammarAnnotator(
525 unilib_, model_->grammar_model(), entity_data_builder_.get()));
526 }
527
Tony Makdf54e742019-03-26 14:04:00 +0000528 if (model_->triggering_locales() &&
529 !ParseLocales(model_->triggering_locales()->c_str(),
530 &model_triggering_locales_)) {
Tony Mak378c1f52019-03-04 15:58:11 +0000531 TC3_LOG(ERROR) << "Could not parse model supported locales.";
532 return;
533 }
534
535 if (model_->triggering_options() != nullptr &&
Tony Makdf54e742019-03-26 14:04:00 +0000536 model_->triggering_options()->locales() != nullptr &&
537 !ParseLocales(model_->triggering_options()->locales()->c_str(),
538 &ml_model_triggering_locales_)) {
539 TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
540 return;
541 }
542
543 if (model_->triggering_options() != nullptr &&
Tony Mak378c1f52019-03-04 15:58:11 +0000544 model_->triggering_options()->dictionary_locales() != nullptr &&
545 !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
546 &dictionary_locales_)) {
547 TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
548 return;
549 }
550
Tony Mak5a12b942020-05-01 12:41:31 +0100551 if (model_->conflict_resolution_options() != nullptr) {
552 prioritize_longest_annotation_ =
553 model_->conflict_resolution_options()->prioritize_longest_annotation();
554 do_conflict_resolution_in_raw_mode_ =
555 model_->conflict_resolution_options()
556 ->do_conflict_resolution_in_raw_mode();
557 }
558
Lukas Zilka21d8c982018-01-24 11:11:20 +0100559 initialized_ = true;
560}
561
Tony Mak6c4cc672018-09-17 11:48:50 +0100562bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100563 if (!model_->regex_model()->patterns()) {
Lukas Zilka434442d2018-04-25 11:38:51 +0200564 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100565 }
566
567 // Initialize pattern recognizers.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100568 int regex_pattern_id = 0;
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100569 for (const auto regex_pattern : *model_->regex_model()->patterns()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200570 std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
Tony Mak378c1f52019-03-04 15:58:11 +0000571 UncompressMakeRegexPattern(
572 *unilib_, regex_pattern->pattern(),
573 regex_pattern->compressed_pattern(),
574 model_->regex_model()->lazy_regex_compilation(), decompressor);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100575 if (!compiled_pattern) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100576 TC3_LOG(INFO) << "Failed to load regex pattern";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200577 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100578 }
579
Lukas Zilkaba849e72018-03-08 14:48:21 +0100580 if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100581 annotation_regex_patterns_.push_back(regex_pattern_id);
582 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100583 if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100584 classification_regex_patterns_.push_back(regex_pattern_id);
585 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100586 if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100587 selection_regex_patterns_.push_back(regex_pattern_id);
588 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100589 regex_patterns_.push_back({
Tony Mak854015a2019-01-16 15:56:48 +0000590 regex_pattern,
Tony Mak6c4cc672018-09-17 11:48:50 +0100591 std::move(compiled_pattern),
Tony Mak6c4cc672018-09-17 11:48:50 +0100592 });
Lukas Zilkab23e2122018-02-09 10:25:19 +0100593 ++regex_pattern_id;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100594 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100595
Lukas Zilkab23e2122018-02-09 10:25:19 +0100596 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100597}
598
Tony Mak6c4cc672018-09-17 11:48:50 +0100599bool Annotator::InitializeKnowledgeEngine(
600 const std::string& serialized_config) {
Tony Maka2a1ff42019-09-12 15:40:32 +0100601 std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
Tony Mak63959242020-02-07 18:31:16 +0000602 if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100603 TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
604 return false;
605 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100606 if (model_->triggering_options() != nullptr) {
607 knowledge_engine->SetPriorityScore(
608 model_->triggering_options()->knowledge_priority_score());
609 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100610 knowledge_engine_ = std::move(knowledge_engine);
611 return true;
612}
613
Tony Mak854015a2019-01-16 15:56:48 +0000614bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
Tony Mak5dc5e112019-02-01 14:52:10 +0000615 std::unique_ptr<ContactEngine> contact_engine(
Tony Mak63959242020-02-07 18:31:16 +0000616 new ContactEngine(selection_feature_processor_.get(), unilib_,
617 model_->contact_annotator_options()));
Tony Mak854015a2019-01-16 15:56:48 +0000618 if (!contact_engine->Initialize(serialized_config)) {
619 TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
620 return false;
621 }
622 contact_engine_ = std::move(contact_engine);
623 return true;
624}
625
Tony Makd9446602019-02-20 18:25:39 +0000626bool Annotator::InitializeInstalledAppEngine(
627 const std::string& serialized_config) {
628 std::unique_ptr<InstalledAppEngine> installed_app_engine(
Tony Mak378c1f52019-03-04 15:58:11 +0000629 new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
Tony Makd9446602019-02-20 18:25:39 +0000630 if (!installed_app_engine->Initialize(serialized_config)) {
631 TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
632 return false;
633 }
634 installed_app_engine_ = std::move(installed_app_engine);
635 return true;
636}
637
Tony Mak63959242020-02-07 18:31:16 +0000638void Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
639 lang_id_ = lang_id;
Tony Mak21460022020-03-12 18:29:35 +0000640 if (lang_id_ != nullptr && model_->translate_annotator_options() &&
Tony Mak63959242020-02-07 18:31:16 +0000641 model_->translate_annotator_options()->enabled()) {
642 translate_annotator_.reset(new TranslateAnnotator(
643 model_->translate_annotator_options(), lang_id_, unilib_));
Tony Mak21460022020-03-12 18:29:35 +0000644 } else {
645 translate_annotator_.reset(nullptr);
Tony Mak63959242020-02-07 18:31:16 +0000646 }
647}
648
Tony Mak21460022020-03-12 18:29:35 +0000649bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
650 int size) {
651 const PersonNameModel* person_name_model =
652 LoadAndVerifyPersonNameModel(buffer, size);
Tony Mak76d80962020-01-08 17:30:51 +0000653
654 if (person_name_model == nullptr) {
655 TC3_LOG(ERROR) << "Person name model verification failed.";
656 return false;
657 }
658
659 if (!person_name_model->enabled()) {
660 return true;
661 }
662
663 std::unique_ptr<PersonNameEngine> person_name_engine(
Tony Mak21460022020-03-12 18:29:35 +0000664 new PersonNameEngine(selection_feature_processor_.get(), unilib_));
Tony Mak76d80962020-01-08 17:30:51 +0000665 if (!person_name_engine->Initialize(person_name_model)) {
666 TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
667 return false;
668 }
669 person_name_engine_ = std::move(person_name_engine);
670 return true;
671}
672
Tony Mak21460022020-03-12 18:29:35 +0000673bool Annotator::InitializePersonNameEngineFromScopedMmap(
674 const ScopedMmap& mmap) {
675 if (!mmap.handle().ok()) {
676 TC3_LOG(ERROR) << "Mmap for person name model failed.";
677 return false;
678 }
679
680 return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
681 mmap.handle().num_bytes());
682}
683
684bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
685 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
686 return InitializePersonNameEngineFromScopedMmap(*mmap);
687}
688
689bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
690 int size) {
691 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
692 return InitializePersonNameEngineFromScopedMmap(*mmap);
693}
694
Tony Mak5a12b942020-05-01 12:41:31 +0100695bool Annotator::InitializeExperimentalAnnotators() {
696 if (ExperimentalAnnotator::IsEnabled()) {
697 experimental_annotator_.reset(new ExperimentalAnnotator(*unilib_));
698 return true;
699 }
700 return false;
701}
702
Lukas Zilka21d8c982018-01-24 11:11:20 +0100703namespace {
704
705int CountDigits(const std::string& str, CodepointSpan selection_indices) {
706 int count = 0;
707 int i = 0;
708 const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
709 for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
710 if (i >= selection_indices.first && i < selection_indices.second &&
Tony Mak21460022020-03-12 18:29:35 +0000711 IsDigit(*it)) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100712 ++count;
713 }
714 }
715 return count;
716}
717
Lukas Zilka21d8c982018-01-24 11:11:20 +0100718} // namespace
719
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200720namespace internal {
721// Helper function, which if the initial 'span' contains only white-spaces,
722// moves the selection to a single-codepoint selection on a left or right side
723// of this space.
724CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
725 const UnicodeText& context_unicode,
726 const UniLib& unilib) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100727 TC3_CHECK(ValidNonEmptySpan(span));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200728
729 UnicodeText::const_iterator it;
730
731 // Check that the current selection is all whitespaces.
732 it = context_unicode.begin();
733 std::advance(it, span.first);
734 for (int i = 0; i < (span.second - span.first); ++i, ++it) {
735 if (!unilib.IsWhitespace(*it)) {
736 return span;
737 }
738 }
739
740 CodepointSpan result;
741
742 // Try moving left.
743 result = span;
744 it = context_unicode.begin();
745 std::advance(it, span.first);
746 while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
747 --result.first;
748 --it;
749 }
750 result.second = result.first + 1;
751 if (!unilib.IsWhitespace(*it)) {
752 return result;
753 }
754
755 // If moving left didn't find a non-whitespace character, just return the
756 // original span.
757 return span;
758}
759} // namespace internal
760
Tony Mak6c4cc672018-09-17 11:48:50 +0100761bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200762 return !span.classification.empty() &&
763 filtered_collections_annotation_.find(
764 span.classification[0].collection) !=
765 filtered_collections_annotation_.end();
766}
767
Tony Mak6c4cc672018-09-17 11:48:50 +0100768bool Annotator::FilteredForClassification(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200769 const ClassificationResult& classification) const {
770 return filtered_collections_classification_.find(classification.collection) !=
771 filtered_collections_classification_.end();
772}
773
Tony Mak6c4cc672018-09-17 11:48:50 +0100774bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200775 return !span.classification.empty() &&
776 filtered_collections_selection_.find(
777 span.classification[0].collection) !=
778 filtered_collections_selection_.end();
779}
780
Tony Mak378c1f52019-03-04 15:58:11 +0000781namespace {
782inline bool ClassifiedAsOther(
783 const std::vector<ClassificationResult>& classification) {
784 return !classification.empty() &&
785 classification[0].collection == Collections::Other();
786}
787
Tony Maka2a1ff42019-09-12 15:40:32 +0100788} // namespace
789
790float Annotator::GetPriorityScore(
791 const std::vector<ClassificationResult>& classification) const {
Tony Mak378c1f52019-03-04 15:58:11 +0000792 if (!classification.empty() && !ClassifiedAsOther(classification)) {
793 return classification[0].priority_score;
794 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +0100795 if (model_->triggering_options() != nullptr) {
796 return model_->triggering_options()->other_collection_priority_score();
797 } else {
798 return -1000.0;
799 }
Tony Mak378c1f52019-03-04 15:58:11 +0000800 }
801}
Tony Mak378c1f52019-03-04 15:58:11 +0000802
Tony Makdf54e742019-03-26 14:04:00 +0000803bool Annotator::VerifyRegexMatchCandidate(
804 const std::string& context, const VerificationOptions* verification_options,
805 const std::string& match, const UniLib::RegexMatcher* matcher) const {
806 if (verification_options == nullptr) {
807 return true;
808 }
809 if (verification_options->verify_luhn_checksum() &&
810 !VerifyLuhnChecksum(match)) {
811 return false;
812 }
813 const int lua_verifier = verification_options->lua_verifier();
814 if (lua_verifier >= 0) {
815 if (model_->regex_model()->lua_verifier() == nullptr ||
816 lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
817 TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
818 return false;
819 }
820 return VerifyMatch(
821 context, matcher,
822 model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
823 }
824 return true;
825}
826
Tony Mak6c4cc672018-09-17 11:48:50 +0100827CodepointSpan Annotator::SuggestSelection(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100828 const std::string& context, CodepointSpan click_indices,
829 const SelectionOptions& options) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200830 CodepointSpan original_click_indices = click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100831 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100832 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200833 return original_click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100834 }
Tony Mak5a12b942020-05-01 12:41:31 +0100835 if (options.annotation_usecase !=
836 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
837 TC3_LOG(WARNING)
838 << "Invoking SuggestSelection, which is not supported in RAW mode.";
839 return original_click_indices;
840 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100841 if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200842 return original_click_indices;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100843 }
844
Tony Makdf54e742019-03-26 14:04:00 +0000845 std::vector<Locale> detected_text_language_tags;
846 if (!ParseLocales(options.detected_text_language_tags,
847 &detected_text_language_tags)) {
848 TC3_LOG(WARNING)
849 << "Failed to parse the detected_text_language_tags in options: "
850 << options.detected_text_language_tags;
851 }
852 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
853 model_triggering_locales_,
854 /*default_value=*/true)) {
855 return original_click_indices;
856 }
857
Lukas Zilkadf710db2018-02-27 12:44:09 +0100858 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
859 /*do_copy=*/false);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200860
Tony Mak968412a2019-11-13 15:39:57 +0000861 if (!IsValidSpanInput(context_unicode, click_indices)) {
862 TC3_VLOG(1)
863 << "Trying to run SuggestSelection with invalid input, indices: "
864 << click_indices.first << " " << click_indices.second;
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200865 return original_click_indices;
866 }
867
868 if (model_->snap_whitespace_selections()) {
869 // We want to expand a purely white-space selection to a multi-selection it
870 // would've been part of. But with this feature disabled we would do a no-
871 // op, because no token is found. Therefore, we need to modify the
872 // 'click_indices' a bit to include a part of the token, so that the click-
873 // finding logic finds the clicked token correctly. This modification is
874 // done by the following function. Note, that it's enough to check the left
875 // side of the current selection, because if the white-space is a part of a
Tony Mak6c4cc672018-09-17 11:48:50 +0100876 // multi-selection, necessarily both tokens - on the left and the right
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200877 // sides need to be selected. Thus snapping only to the left is sufficient
878 // (there's a check at the bottom that makes sure that if we snap to the
879 // left token but the result does not contain the initial white-space,
880 // returns the original indices).
881 click_indices = internal::SnapLeftIfWhitespaceSelection(
882 click_indices, context_unicode, *unilib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100883 }
884
Lukas Zilkab23e2122018-02-09 10:25:19 +0100885 std::vector<AnnotatedSpan> candidates;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100886 InterpreterManager interpreter_manager(selection_executor_.get(),
887 classification_executor_.get());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200888 std::vector<Token> tokens;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100889 if (!ModelSuggestSelection(context_unicode, click_indices,
Tony Makdf54e742019-03-26 14:04:00 +0000890 detected_text_language_tags, &interpreter_manager,
891 &tokens, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100892 TC3_LOG(ERROR) << "Model suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200893 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100894 }
Tony Mak83d2de62019-04-10 16:12:15 +0100895 if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
896 /*is_serialized_entity_data_enabled=*/false)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100897 TC3_LOG(ERROR) << "Regex suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200898 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100899 }
Tony Mak83d2de62019-04-10 16:12:15 +0100900 if (!DatetimeChunk(
901 UTF8ToUnicodeText(context, /*do_copy=*/false),
902 /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
903 options.locales, ModeFlag_SELECTION, options.annotation_usecase,
904 /*is_serialized_entity_data_enabled=*/false, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100905 TC3_LOG(ERROR) << "Datetime suggest selection failed.";
906 return original_click_indices;
907 }
Tony Mak378c1f52019-03-04 15:58:11 +0000908 if (knowledge_engine_ != nullptr &&
Tony Maka2a1ff42019-09-12 15:40:32 +0100909 !knowledge_engine_->Chunk(context, options.annotation_usecase,
Tony Mak90d55672020-04-15 18:20:44 +0100910 options.location_context, Permissions(),
911 &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100912 TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200913 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100914 }
Tony Mak378c1f52019-03-04 15:58:11 +0000915 if (contact_engine_ != nullptr &&
Tony Mak854015a2019-01-16 15:56:48 +0000916 !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
917 TC3_LOG(ERROR) << "Contact suggest selection failed.";
918 return original_click_indices;
919 }
Tony Mak378c1f52019-03-04 15:58:11 +0000920 if (installed_app_engine_ != nullptr &&
Tony Makd9446602019-02-20 18:25:39 +0000921 !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
922 TC3_LOG(ERROR) << "Installed app suggest selection failed.";
923 return original_click_indices;
924 }
Tony Mak378c1f52019-03-04 15:58:11 +0000925 if (number_annotator_ != nullptr &&
926 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
927 &candidates)) {
928 TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
929 return original_click_indices;
930 }
Tony Makad2e22d2019-03-20 17:35:13 +0000931 if (duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +0000932 !duration_annotator_->FindAll(context_unicode, tokens,
933 options.annotation_usecase, &candidates)) {
Tony Makad2e22d2019-03-20 17:35:13 +0000934 TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
935 return original_click_indices;
936 }
Tony Mak76d80962020-01-08 17:30:51 +0000937 if (person_name_engine_ != nullptr &&
938 !person_name_engine_->Chunk(context_unicode, tokens, &candidates)) {
939 TC3_LOG(ERROR) << "Person name suggest selection failed.";
940 return original_click_indices;
941 }
Lukas Zilkab23e2122018-02-09 10:25:19 +0100942
Tony Mak21460022020-03-12 18:29:35 +0000943 AnnotatedSpan grammar_suggested_span;
944 if (grammar_annotator_ != nullptr &&
945 grammar_annotator_->SuggestSelection(detected_text_language_tags,
946 context_unicode, click_indices,
947 &grammar_suggested_span)) {
948 candidates.push_back(grammar_suggested_span);
949 }
950
Tony Mak5a12b942020-05-01 12:41:31 +0100951 if (experimental_annotator_ != nullptr) {
952 candidates.push_back(experimental_annotator_->SuggestSelection(
953 context_unicode, click_indices));
954 }
955
Lukas Zilkab23e2122018-02-09 10:25:19 +0100956 // Sort candidates according to their position in the input, so that the next
957 // code can assume that any connected component of overlapping spans forms a
958 // contiguous block.
959 std::sort(candidates.begin(), candidates.end(),
960 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
961 return a.span.first < b.span.first;
962 });
963
964 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +0000965 if (!ResolveConflicts(candidates, context, tokens,
966 detected_text_language_tags, options.annotation_usecase,
967 &interpreter_manager, &candidate_indices)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100968 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200969 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100970 }
971
Tony Mak378c1f52019-03-04 15:58:11 +0000972 std::sort(candidate_indices.begin(), candidate_indices.end(),
Tony Maka2a1ff42019-09-12 15:40:32 +0100973 [this, &candidates](int a, int b) {
Tony Mak378c1f52019-03-04 15:58:11 +0000974 return GetPriorityScore(candidates[a].classification) >
975 GetPriorityScore(candidates[b].classification);
976 });
977
Lukas Zilkab23e2122018-02-09 10:25:19 +0100978 for (const int i : candidate_indices) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200979 if (SpansOverlap(candidates[i].span, click_indices) &&
980 SpansOverlap(candidates[i].span, original_click_indices)) {
981 // Run model classification if not present but requested and there's a
982 // classification collection filter specified.
983 if (candidates[i].classification.empty() &&
984 model_->selection_options()->always_classify_suggested_selection() &&
985 !filtered_collections_selection_.empty()) {
Tony Mak378c1f52019-03-04 15:58:11 +0000986 if (!ModelClassifyText(context, detected_text_language_tags,
987 candidates[i].span, &interpreter_manager,
988 /*embedding_cache=*/nullptr,
989 &candidates[i].classification)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200990 return original_click_indices;
991 }
992 }
993
994 // Ignore if span classification is filtered.
995 if (FilteredForSelection(candidates[i])) {
996 return original_click_indices;
997 }
998
Lukas Zilkab23e2122018-02-09 10:25:19 +0100999 return candidates[i].span;
1000 }
1001 }
1002
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001003 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001004}
1005
1006namespace {
1007// Helper function that returns the index of the first candidate that
1008// transitively does not overlap with the candidate on 'start_index'. If the end
1009// of 'candidates' is reached, it returns the index that points right behind the
1010// array.
1011int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
1012 int start_index) {
1013 int first_non_overlapping = start_index + 1;
1014 CodepointSpan conflicting_span = candidates[start_index].span;
1015 while (
1016 first_non_overlapping < candidates.size() &&
1017 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
1018 // Grow the span to include the current one.
1019 conflicting_span.second = std::max(
1020 conflicting_span.second, candidates[first_non_overlapping].span.second);
1021
1022 ++first_non_overlapping;
1023 }
1024 return first_non_overlapping;
1025}
1026} // namespace
1027
Tony Mak378c1f52019-03-04 15:58:11 +00001028bool Annotator::ResolveConflicts(
1029 const std::vector<AnnotatedSpan>& candidates, const std::string& context,
1030 const std::vector<Token>& cached_tokens,
1031 const std::vector<Locale>& detected_text_language_tags,
1032 AnnotationUsecase annotation_usecase,
1033 InterpreterManager* interpreter_manager, std::vector<int>* result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001034 result->clear();
1035 result->reserve(candidates.size());
1036 for (int i = 0; i < candidates.size();) {
1037 int first_non_overlapping =
1038 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1039
1040 const bool conflict_found = first_non_overlapping != (i + 1);
1041 if (conflict_found) {
1042 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +00001043 if (!ResolveConflict(context, cached_tokens, candidates,
1044 detected_text_language_tags, i,
1045 first_non_overlapping, annotation_usecase,
1046 interpreter_manager, &candidate_indices)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001047 return false;
1048 }
1049 result->insert(result->end(), candidate_indices.begin(),
1050 candidate_indices.end());
1051 } else {
1052 result->push_back(i);
1053 }
1054
1055 // Skip over the whole conflicting group/go to next candidate.
1056 i = first_non_overlapping;
1057 }
1058 return true;
1059}
1060
1061namespace {
Tony Mak448b5862019-03-22 13:36:41 +00001062// Returns true, if the given two sources do conflict in given annotation
1063// usecase.
1064// - In SMART usecase, all sources do conflict, because there's only 1 possible
1065// annotation for a given span.
1066// - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1067// and duration), while others not (e.g. duration and number).
1068bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1069 const AnnotatedSpan::Source source1,
1070 const AnnotatedSpan::Source source2) {
1071 uint32 source_mask =
1072 (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1073
Tony Mak378c1f52019-03-04 15:58:11 +00001074 switch (annotation_usecase) {
1075 case AnnotationUsecase_ANNOTATION_USECASE_SMART:
Tony Mak448b5862019-03-22 13:36:41 +00001076 // In the SMART mode, all annotations conflict.
Tony Mak378c1f52019-03-04 15:58:11 +00001077 return true;
Tony Mak448b5862019-03-22 13:36:41 +00001078
Tony Mak378c1f52019-03-04 15:58:11 +00001079 case AnnotationUsecase_ANNOTATION_USECASE_RAW:
Tony Mak448b5862019-03-22 13:36:41 +00001080 // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1081 // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1082 // hours" (duration).
1083 if ((source_mask &
1084 (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1085 (source_mask &
1086 (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1087 return false;
Tony Mak378c1f52019-03-04 15:58:11 +00001088 }
Tony Mak448b5862019-03-22 13:36:41 +00001089
1090 // A KNOWLEDGE entity does not conflict with anything.
1091 if ((source_mask &
1092 (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1093 return false;
1094 }
1095
Tony Makd0ae7c62020-03-27 13:58:00 +00001096 // A PERSONNAME entity does not conflict with anything.
1097 if ((source_mask &
1098 (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
1099 return false;
1100 }
1101
Tony Mak448b5862019-03-22 13:36:41 +00001102 // Entities from other sources can conflict.
Tony Mak378c1f52019-03-04 15:58:11 +00001103 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001104 }
1105}
1106} // namespace
1107
Tony Mak378c1f52019-03-04 15:58:11 +00001108bool Annotator::ResolveConflict(
1109 const std::string& context, const std::vector<Token>& cached_tokens,
1110 const std::vector<AnnotatedSpan>& candidates,
1111 const std::vector<Locale>& detected_text_language_tags, int start_index,
1112 int end_index, AnnotationUsecase annotation_usecase,
1113 InterpreterManager* interpreter_manager,
1114 std::vector<int>* chosen_indices) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001115 std::vector<int> conflicting_indices;
Tony Mak76d80962020-01-08 17:30:51 +00001116 std::unordered_map<int, std::pair<float, int>> scores_lengths;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001117 for (int i = start_index; i < end_index; ++i) {
1118 conflicting_indices.push_back(i);
1119 if (!candidates[i].classification.empty()) {
Tony Mak76d80962020-01-08 17:30:51 +00001120 scores_lengths[i] = {
1121 GetPriorityScore(candidates[i].classification),
1122 candidates[i].span.second - candidates[i].span.first};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001123 continue;
1124 }
1125
1126 // OPTIMIZATION: So that we don't have to classify all the ML model
1127 // spans apriori, we wait until we get here, when they conflict with
1128 // something and we need the actual classification scores. So if the
1129 // candidate conflicts and comes from the model, we need to run a
1130 // classification to determine its priority:
1131 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +00001132 if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1133 candidates[i].span, interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001134 /*embedding_cache=*/nullptr, &classification)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001135 return false;
1136 }
1137
1138 if (!classification.empty()) {
Tony Mak76d80962020-01-08 17:30:51 +00001139 scores_lengths[i] = {
1140 GetPriorityScore(classification),
1141 candidates[i].span.second - candidates[i].span.first};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001142 }
1143 }
1144
Tony Mak5a12b942020-05-01 12:41:31 +01001145 std::sort(
1146 conflicting_indices.begin(), conflicting_indices.end(),
1147 [this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
1148 if (scores_lengths[i].first == scores_lengths[j].first &&
1149 prioritize_longest_annotation_) {
1150 return scores_lengths[i].second > scores_lengths[j].second;
1151 }
1152 return scores_lengths[i].first > scores_lengths[j].first;
1153 });
Lukas Zilkab23e2122018-02-09 10:25:19 +01001154
Tony Mak448b5862019-03-22 13:36:41 +00001155 // Here we keep a set of indices that were chosen, per-source, to enable
1156 // effective computation.
1157 std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1158 chosen_indices_for_source_map;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001159
1160 // Greedily place the candidates if they don't conflict with the already
1161 // placed ones.
1162 for (int i = 0; i < conflicting_indices.size(); ++i) {
1163 const int considered_candidate = conflicting_indices[i];
Tony Mak448b5862019-03-22 13:36:41 +00001164
1165 // See if there is a conflict between the candidate and all already placed
1166 // candidates.
1167 bool conflict = false;
1168 SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1169 for (auto& source_set_pair : chosen_indices_for_source_map) {
1170 if (source_set_pair.first == candidates[considered_candidate].source) {
1171 chosen_indices_for_source_ptr = &source_set_pair.second;
1172 }
1173
Tony Mak5a12b942020-05-01 12:41:31 +01001174 const bool needs_conflict_resolution =
1175 annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_SMART ||
1176 (annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW &&
1177 do_conflict_resolution_in_raw_mode_);
1178 if (needs_conflict_resolution &&
1179 DoSourcesConflict(annotation_usecase, source_set_pair.first,
Tony Mak448b5862019-03-22 13:36:41 +00001180 candidates[considered_candidate].source) &&
1181 DoesCandidateConflict(considered_candidate, candidates,
1182 source_set_pair.second)) {
1183 conflict = true;
1184 break;
1185 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001186 }
Tony Mak448b5862019-03-22 13:36:41 +00001187
1188 // Skip the candidate if a conflict was found.
1189 if (conflict) {
1190 continue;
1191 }
1192
1193 // If the set of indices for the current source doesn't exist yet,
1194 // initialize it.
1195 if (chosen_indices_for_source_ptr == nullptr) {
1196 SortedIntSet new_set([&candidates](int a, int b) {
1197 return candidates[a].span.first < candidates[b].span.first;
1198 });
1199 chosen_indices_for_source_map[candidates[considered_candidate].source] =
1200 std::move(new_set);
1201 chosen_indices_for_source_ptr =
1202 &chosen_indices_for_source_map[candidates[considered_candidate]
1203 .source];
1204 }
1205
1206 // Place the candidate to the output and to the per-source conflict set.
1207 chosen_indices->push_back(considered_candidate);
1208 chosen_indices_for_source_ptr->insert(considered_candidate);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001209 }
1210
Tony Mak378c1f52019-03-04 15:58:11 +00001211 std::sort(chosen_indices->begin(), chosen_indices->end());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001212
1213 return true;
1214}
1215
Tony Mak6c4cc672018-09-17 11:48:50 +01001216bool Annotator::ModelSuggestSelection(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001217 const UnicodeText& context_unicode, CodepointSpan click_indices,
Tony Makdf54e742019-03-26 14:04:00 +00001218 const std::vector<Locale>& detected_text_language_tags,
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001219 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001220 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001221 if (model_->triggering_options() == nullptr ||
1222 !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1223 return true;
1224 }
1225
Tony Makdf54e742019-03-26 14:04:00 +00001226 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1227 ml_model_triggering_locales_,
1228 /*default_value=*/true)) {
1229 return true;
1230 }
1231
Lukas Zilka21d8c982018-01-24 11:11:20 +01001232 int click_pos;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001233 *tokens = selection_feature_processor_->Tokenize(context_unicode);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001234 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001235 context_unicode, click_indices,
1236 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001237 tokens, &click_pos);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001238 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001239 TC3_VLOG(1) << "Could not calculate the click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001240 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001241 }
1242
1243 const int symmetry_context_size =
1244 model_->selection_options()->symmetry_context_size();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001245 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
Lukas Zilkab23e2122018-02-09 10:25:19 +01001246 bounds_sensitive_features = selection_feature_processor_->GetOptions()
1247 ->bounds_sensitive_features();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001248
1249 // The symmetry context span is the clicked token with symmetry_context_size
1250 // tokens on either side.
1251 const TokenSpan symmetry_context_span = IntersectTokenSpans(
1252 ExpandTokenSpan(SingleTokenSpan(click_pos),
1253 /*num_tokens_left=*/symmetry_context_size,
1254 /*num_tokens_right=*/symmetry_context_size),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001255 {0, tokens->size()});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001256
Lukas Zilkab23e2122018-02-09 10:25:19 +01001257 // Compute the extraction span based on the model type.
1258 TokenSpan extraction_span;
1259 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1260 // The extraction span is the symmetry context span expanded to include
1261 // max_selection_span tokens on either side, which is how far a selection
1262 // can stretch from the click, plus a relevant number of tokens outside of
1263 // the bounds of the selection.
1264 const int max_selection_span =
1265 selection_feature_processor_->GetOptions()->max_selection_span();
1266 extraction_span =
1267 ExpandTokenSpan(symmetry_context_span,
1268 /*num_tokens_left=*/max_selection_span +
1269 bounds_sensitive_features->num_tokens_before(),
1270 /*num_tokens_right=*/max_selection_span +
1271 bounds_sensitive_features->num_tokens_after());
1272 } else {
1273 // The extraction span is the symmetry context span expanded to include
1274 // context_size tokens on either side.
1275 const int context_size =
1276 selection_feature_processor_->GetOptions()->context_size();
1277 extraction_span = ExpandTokenSpan(symmetry_context_span,
1278 /*num_tokens_left=*/context_size,
1279 /*num_tokens_right=*/context_size);
1280 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001281 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
Lukas Zilkab23e2122018-02-09 10:25:19 +01001282
Lukas Zilka434442d2018-04-25 11:38:51 +02001283 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1284 *tokens, extraction_span)) {
1285 return true;
1286 }
1287
Lukas Zilkab23e2122018-02-09 10:25:19 +01001288 std::unique_ptr<CachedFeatures> cached_features;
1289 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001290 *tokens, extraction_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001291 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1292 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01001293 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001294 selection_feature_processor_->EmbeddingSize() +
1295 selection_feature_processor_->DenseFeaturesCount(),
1296 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001297 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001298 return false;
1299 }
1300
1301 // Produce selection model candidates.
1302 std::vector<TokenSpan> chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001303 if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001304 interpreter_manager->SelectionInterpreter(), *cached_features,
1305 &chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001306 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001307 return false;
1308 }
1309
1310 for (const TokenSpan& chunk : chunks) {
1311 AnnotatedSpan candidate;
1312 candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001313 context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001314 if (model_->selection_options()->strip_unpaired_brackets()) {
1315 candidate.span =
1316 StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1317 }
1318
1319 // Only output non-empty spans.
1320 if (candidate.span.first != candidate.span.second) {
1321 result->push_back(candidate);
1322 }
1323 }
1324 return true;
1325}
1326
Tony Mak6c4cc672018-09-17 11:48:50 +01001327bool Annotator::ModelClassifyText(
Tony Mak378c1f52019-03-04 15:58:11 +00001328 const std::string& context,
1329 const std::vector<Locale>& detected_text_language_tags,
1330 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001331 FeatureProcessor::EmbeddingCache* embedding_cache,
1332 std::vector<ClassificationResult>* classification_results) const {
Tony Mak378c1f52019-03-04 15:58:11 +00001333 return ModelClassifyText(context, {}, detected_text_language_tags,
1334 selection_indices, interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001335 embedding_cache, classification_results);
1336}
1337
1338namespace internal {
1339std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1340 CodepointSpan selection_indices,
1341 TokenSpan tokens_around_selection_to_copy) {
1342 const auto first_selection_token = std::upper_bound(
1343 cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1344 [](int selection_start, const Token& token) {
1345 return selection_start < token.end;
1346 });
1347 const auto last_selection_token = std::lower_bound(
1348 cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1349 [](const Token& token, int selection_end) {
1350 return token.start < selection_end;
1351 });
1352
1353 const int64 first_token = std::max(
1354 static_cast<int64>(0),
1355 static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1356 tokens_around_selection_to_copy.first));
1357 const int64 last_token = std::min(
1358 static_cast<int64>(cached_tokens.size()),
1359 static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1360 tokens_around_selection_to_copy.second));
1361
1362 std::vector<Token> tokens;
1363 tokens.reserve(last_token - first_token);
1364 for (int i = first_token; i < last_token; ++i) {
1365 tokens.push_back(cached_tokens[i]);
1366 }
1367 return tokens;
1368}
1369} // namespace internal
1370
Tony Mak6c4cc672018-09-17 11:48:50 +01001371TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001372 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1373 bounds_sensitive_features =
1374 classification_feature_processor_->GetOptions()
1375 ->bounds_sensitive_features();
1376 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1377 // The extraction span is the selection span expanded to include a relevant
1378 // number of tokens outside of the bounds of the selection.
1379 return {bounds_sensitive_features->num_tokens_before(),
1380 bounds_sensitive_features->num_tokens_after()};
1381 } else {
1382 // The extraction span is the clicked token with context_size tokens on
1383 // either side.
1384 const int context_size =
1385 selection_feature_processor_->GetOptions()->context_size();
1386 return {context_size, context_size};
1387 }
1388}
1389
Tony Mak378c1f52019-03-04 15:58:11 +00001390namespace {
1391// Sorts the classification results from high score to low score.
1392void SortClassificationResults(
1393 std::vector<ClassificationResult>* classification_results) {
1394 std::sort(classification_results->begin(), classification_results->end(),
1395 [](const ClassificationResult& a, const ClassificationResult& b) {
1396 return a.score > b.score;
1397 });
1398}
1399} // namespace
1400
Tony Mak6c4cc672018-09-17 11:48:50 +01001401bool Annotator::ModelClassifyText(
Lukas Zilkaba849e72018-03-08 14:48:21 +01001402 const std::string& context, const std::vector<Token>& cached_tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00001403 const std::vector<Locale>& detected_text_language_tags,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001404 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1405 FeatureProcessor::EmbeddingCache* embedding_cache,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001406 std::vector<ClassificationResult>* classification_results) const {
1407 std::vector<Token> tokens;
Tony Mak378c1f52019-03-04 15:58:11 +00001408 return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1409 selection_indices, interpreter_manager,
1410 embedding_cache, classification_results, &tokens);
1411}
1412
1413bool Annotator::ModelClassifyText(
1414 const std::string& context, const std::vector<Token>& cached_tokens,
1415 const std::vector<Locale>& detected_text_language_tags,
1416 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1417 FeatureProcessor::EmbeddingCache* embedding_cache,
1418 std::vector<ClassificationResult>* classification_results,
1419 std::vector<Token>* tokens) const {
1420 if (model_->triggering_options() == nullptr ||
1421 !(model_->triggering_options()->enabled_modes() &
1422 ModeFlag_CLASSIFICATION)) {
1423 return true;
1424 }
1425
Tony Makdf54e742019-03-26 14:04:00 +00001426 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1427 ml_model_triggering_locales_,
1428 /*default_value=*/true)) {
1429 return true;
1430 }
1431
Lukas Zilkaba849e72018-03-08 14:48:21 +01001432 if (cached_tokens.empty()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001433 *tokens = classification_feature_processor_->Tokenize(context);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001434 } else {
Tony Mak378c1f52019-03-04 15:58:11 +00001435 *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1436 ClassifyTextUpperBoundNeededTokens());
Lukas Zilkaba849e72018-03-08 14:48:21 +01001437 }
1438
Lukas Zilkab23e2122018-02-09 10:25:19 +01001439 int click_pos;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001440 classification_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001441 context, selection_indices,
1442 classification_feature_processor_->GetOptions()
1443 ->only_use_line_with_click(),
Tony Mak378c1f52019-03-04 15:58:11 +00001444 tokens, &click_pos);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001445 const TokenSpan selection_token_span =
Tony Mak378c1f52019-03-04 15:58:11 +00001446 CodepointSpanToTokenSpan(*tokens, selection_indices);
Lukas Zilka434442d2018-04-25 11:38:51 +02001447 const int selection_num_tokens = TokenSpanSize(selection_token_span);
1448 if (model_->classification_options()->max_num_tokens() > 0 &&
1449 model_->classification_options()->max_num_tokens() <
1450 selection_num_tokens) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001451 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001452 return true;
1453 }
1454
Lukas Zilkab23e2122018-02-09 10:25:19 +01001455 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1456 bounds_sensitive_features =
1457 classification_feature_processor_->GetOptions()
1458 ->bounds_sensitive_features();
1459 if (selection_token_span.first == kInvalidIndex ||
1460 selection_token_span.second == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001461 TC3_LOG(ERROR) << "Could not determine span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001462 return false;
1463 }
1464
1465 // Compute the extraction span based on the model type.
1466 TokenSpan extraction_span;
1467 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1468 // The extraction span is the selection span expanded to include a relevant
1469 // number of tokens outside of the bounds of the selection.
1470 extraction_span = ExpandTokenSpan(
1471 selection_token_span,
1472 /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1473 /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1474 } else {
1475 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001476 TC3_LOG(ERROR) << "Couldn't choose a click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001477 return false;
1478 }
1479 // The extraction span is the clicked token with context_size tokens on
1480 // either side.
1481 const int context_size =
Lukas Zilkaba849e72018-03-08 14:48:21 +01001482 classification_feature_processor_->GetOptions()->context_size();
Lukas Zilkab23e2122018-02-09 10:25:19 +01001483 extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
1484 /*num_tokens_left=*/context_size,
1485 /*num_tokens_right=*/context_size);
1486 }
Tony Mak378c1f52019-03-04 15:58:11 +00001487 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001488
Lukas Zilka434442d2018-04-25 11:38:51 +02001489 if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
Tony Mak378c1f52019-03-04 15:58:11 +00001490 *tokens, extraction_span)) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001491 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001492 return true;
1493 }
1494
Lukas Zilka21d8c982018-01-24 11:11:20 +01001495 std::unique_ptr<CachedFeatures> cached_features;
1496 if (!classification_feature_processor_->ExtractFeatures(
Tony Mak378c1f52019-03-04 15:58:11 +00001497 *tokens, extraction_span, selection_indices,
1498 embedding_executor_.get(), embedding_cache,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001499 classification_feature_processor_->EmbeddingSize() +
1500 classification_feature_processor_->DenseFeaturesCount(),
1501 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001502 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001503 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001504 }
1505
Lukas Zilkab23e2122018-02-09 10:25:19 +01001506 std::vector<float> features;
1507 features.reserve(cached_features->OutputFeaturesSize());
1508 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1509 cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1510 &features);
1511 } else {
1512 cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001513 }
1514
Lukas Zilkaba849e72018-03-08 14:48:21 +01001515 TensorView<float> logits = classification_executor_->ComputeLogits(
1516 TensorView<float>(features.data(),
1517 {1, static_cast<int>(features.size())}),
1518 interpreter_manager->ClassificationInterpreter());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001519 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001520 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001521 return false;
1522 }
1523
1524 if (logits.dims() != 2 || logits.dim(0) != 1 ||
1525 logits.dim(1) != classification_feature_processor_->NumCollections()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001526 TC3_LOG(ERROR) << "Mismatching output";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001527 return false;
1528 }
1529
1530 const std::vector<float> scores =
1531 ComputeSoftmax(logits.data(), logits.dim(1));
1532
Tony Mak81e52422019-04-30 09:34:45 +01001533 if (scores.empty()) {
1534 *classification_results = {{Collections::Other(), 1.0}};
1535 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001536 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001537
Tony Mak81e52422019-04-30 09:34:45 +01001538 const int best_score_index =
1539 std::max_element(scores.begin(), scores.end()) - scores.begin();
1540 const std::string top_collection =
1541 classification_feature_processor_->LabelToCollection(best_score_index);
1542
1543 // Sanity checks.
1544 if (top_collection == Collections::Phone()) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001545 const int digit_count = CountDigits(context, selection_indices);
1546 if (digit_count <
1547 model_->classification_options()->phone_min_num_digits() ||
1548 digit_count >
1549 model_->classification_options()->phone_max_num_digits()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001550 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001551 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001552 }
Tony Mak81e52422019-04-30 09:34:45 +01001553 } else if (top_collection == Collections::Address()) {
Lukas Zilka434442d2018-04-25 11:38:51 +02001554 if (selection_num_tokens <
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001555 model_->classification_options()->address_min_num_tokens()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001556 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001557 return true;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001558 }
Tony Mak81e52422019-04-30 09:34:45 +01001559 } else if (top_collection == Collections::Dictionary()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001560 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1561 dictionary_locales_,
Tony Makdf54e742019-03-26 14:04:00 +00001562 /*default_value=*/false)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001563 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001564 return true;
Tony Mak378c1f52019-03-04 15:58:11 +00001565 }
1566 }
Tony Mak81e52422019-04-30 09:34:45 +01001567
Tony Makd99d58c2020-03-19 21:52:02 +00001568 *classification_results = {{top_collection, /*arg_score=*/1.0,
1569 /*arg_priority_score=*/scores[best_score_index]}};
1570
1571 // For some entities, we might want to clamp the priority score, for better
1572 // conflict resolution between entities.
1573 if (model_->triggering_options() != nullptr &&
1574 model_->triggering_options()->collection_to_priority() != nullptr) {
1575 if (auto entry =
1576 model_->triggering_options()->collection_to_priority()->LookupByKey(
1577 top_collection.c_str())) {
1578 (*classification_results)[0].priority_score *= entry->value();
1579 }
1580 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001581 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001582}
1583
Tony Mak6c4cc672018-09-17 11:48:50 +01001584bool Annotator::RegexClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001585 const std::string& context, CodepointSpan selection_indices,
Tony Mak378c1f52019-03-04 15:58:11 +00001586 std::vector<ClassificationResult>* classification_result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001587 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001588 UTF8ToUnicodeText(context, /*do_copy=*/false)
1589 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001590 const UnicodeText selection_text_unicode(
1591 UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1592
1593 // Check whether any of the regular expressions match.
1594 for (const int pattern_id : classification_regex_patterns_) {
1595 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1596 const std::unique_ptr<UniLib::RegexMatcher> matcher =
1597 regex_pattern.pattern->Matcher(selection_text_unicode);
1598 int status = UniLib::RegexMatcher::kNoError;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001599 bool matches;
Tony Mak854015a2019-01-16 15:56:48 +00001600 if (regex_pattern.config->use_approximate_matching()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001601 matches = matcher->ApproximatelyMatches(&status);
1602 } else {
1603 matches = matcher->Matches(&status);
1604 }
1605 if (status != UniLib::RegexMatcher::kNoError) {
1606 return false;
1607 }
Tony Makdf54e742019-03-26 14:04:00 +00001608 if (matches && VerifyRegexMatchCandidate(
1609 context, regex_pattern.config->verification_options(),
1610 selection_text, matcher.get())) {
Tony Mak378c1f52019-03-04 15:58:11 +00001611 classification_result->push_back(
1612 {regex_pattern.config->collection_name()->str(),
1613 regex_pattern.config->target_classification_score(),
1614 regex_pattern.config->priority_score()});
Tony Makd9446602019-02-20 18:25:39 +00001615 if (!SerializedEntityDataFromRegexMatch(
1616 regex_pattern.config, matcher.get(),
Tony Mak378c1f52019-03-04 15:58:11 +00001617 &classification_result->back().serialized_entity_data)) {
Tony Makd9446602019-02-20 18:25:39 +00001618 TC3_LOG(ERROR) << "Could not get entity data.";
1619 return false;
1620 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001621 }
1622 }
1623
Tony Mak378c1f52019-03-04 15:58:11 +00001624 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001625}
1626
Tony Mak5dc5e112019-02-01 14:52:10 +00001627namespace {
1628std::string PickCollectionForDatetime(
1629 const DatetimeParseResult& datetime_parse_result) {
1630 switch (datetime_parse_result.granularity) {
1631 case GRANULARITY_HOUR:
1632 case GRANULARITY_MINUTE:
1633 case GRANULARITY_SECOND:
1634 return Collections::DateTime();
1635 default:
1636 return Collections::Date();
1637 }
1638}
Tony Mak83d2de62019-04-10 16:12:15 +01001639
1640std::string CreateDatetimeSerializedEntityData(
1641 const DatetimeParseResult& parse_result) {
1642 EntityDataT entity_data;
1643 entity_data.datetime.reset(new EntityData_::DatetimeT());
1644 entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
1645 entity_data.datetime->granularity =
1646 static_cast<EntityData_::Datetime_::Granularity>(
1647 parse_result.granularity);
1648
Tony Maka2a1ff42019-09-12 15:40:32 +01001649 for (const auto& c : parse_result.datetime_components) {
1650 EntityData_::Datetime_::DatetimeComponentT datetime_component;
1651 datetime_component.absolute_value = c.value;
1652 datetime_component.relative_count = c.relative_count;
1653 datetime_component.component_type =
1654 static_cast<EntityData_::Datetime_::DatetimeComponent_::ComponentType>(
1655 c.component_type);
1656 datetime_component.relation_type =
1657 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE;
1658 if (c.relative_qualifier !=
1659 DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
1660 datetime_component.relation_type =
1661 EntityData_::Datetime_::DatetimeComponent_::RelationType_RELATIVE;
1662 }
1663 entity_data.datetime->datetime_component.emplace_back(
1664 new EntityData_::Datetime_::DatetimeComponentT(datetime_component));
1665 }
Tony Mak83d2de62019-04-10 16:12:15 +01001666 flatbuffers::FlatBufferBuilder builder;
1667 FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
1668 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
1669 builder.GetSize());
1670}
Tony Mak63959242020-02-07 18:31:16 +00001671
Tony Mak5dc5e112019-02-01 14:52:10 +00001672} // namespace
1673
Tony Mak6c4cc672018-09-17 11:48:50 +01001674bool Annotator::DatetimeClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001675 const std::string& context, CodepointSpan selection_indices,
1676 const ClassificationOptions& options,
Tony Mak854015a2019-01-16 15:56:48 +00001677 std::vector<ClassificationResult>* classification_results) const {
Tony Mak63959242020-02-07 18:31:16 +00001678 if (!datetime_parser_ && !cfg_datetime_parser_) {
Tony Makd99d58c2020-03-19 21:52:02 +00001679 return true;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001680 }
1681
Lukas Zilkab23e2122018-02-09 10:25:19 +01001682 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001683 UTF8ToUnicodeText(context, /*do_copy=*/false)
1684 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001685
1686 std::vector<DatetimeParseResultSpan> datetime_spans;
Tony Makd99d58c2020-03-19 21:52:02 +00001687
Tony Mak63959242020-02-07 18:31:16 +00001688 if (cfg_datetime_parser_) {
1689 if (!(model_->grammar_datetime_model()->enabled_modes() &
1690 ModeFlag_CLASSIFICATION)) {
1691 return true;
1692 }
1693 std::vector<Locale> parsed_locales;
1694 ParseLocales(options.locales, &parsed_locales);
Tony Mak21460022020-03-12 18:29:35 +00001695 cfg_datetime_parser_->Parse(
1696 selection_text,
1697 ToDateAnnotationOptions(
1698 model_->grammar_datetime_model()->annotation_options(),
1699 options.reference_timezone, options.reference_time_ms_utc),
1700 parsed_locales, &datetime_spans);
Tony Makd99d58c2020-03-19 21:52:02 +00001701 }
1702
1703 if (datetime_parser_) {
Tony Mak63959242020-02-07 18:31:16 +00001704 if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1705 options.reference_timezone, options.locales,
1706 ModeFlag_CLASSIFICATION,
1707 options.annotation_usecase,
1708 /*anchor_start_end=*/true, &datetime_spans)) {
1709 TC3_LOG(ERROR) << "Error during parsing datetime.";
1710 return false;
1711 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001712 }
Tony Makd99d58c2020-03-19 21:52:02 +00001713
Lukas Zilkab23e2122018-02-09 10:25:19 +01001714 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
1715 // Only consider the result valid if the selection and extracted datetime
1716 // spans exactly match.
1717 if (std::make_pair(datetime_span.span.first + selection_indices.first,
1718 datetime_span.span.second + selection_indices.first) ==
1719 selection_indices) {
Tony Mak854015a2019-01-16 15:56:48 +00001720 for (const DatetimeParseResult& parse_result : datetime_span.data) {
1721 classification_results->emplace_back(
Tony Mak5dc5e112019-02-01 14:52:10 +00001722 PickCollectionForDatetime(parse_result),
1723 datetime_span.target_classification_score);
Tony Mak854015a2019-01-16 15:56:48 +00001724 classification_results->back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01001725 classification_results->back().serialized_entity_data =
1726 CreateDatetimeSerializedEntityData(parse_result);
Tony Mak378c1f52019-03-04 15:58:11 +00001727 classification_results->back().priority_score =
1728 datetime_span.priority_score;
Tony Mak854015a2019-01-16 15:56:48 +00001729 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001730 return true;
1731 }
1732 }
Tony Mak378c1f52019-03-04 15:58:11 +00001733 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001734}
1735
Tony Mak6c4cc672018-09-17 11:48:50 +01001736std::vector<ClassificationResult> Annotator::ClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001737 const std::string& context, CodepointSpan selection_indices,
1738 const ClassificationOptions& options) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01001739 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001740 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilka21d8c982018-01-24 11:11:20 +01001741 return {};
1742 }
Tony Mak5a12b942020-05-01 12:41:31 +01001743 if (options.annotation_usecase !=
1744 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
1745 TC3_LOG(WARNING)
1746 << "Invoking ClassifyText, which is not supported in RAW mode.";
1747 return {};
1748 }
Lukas Zilkaba849e72018-03-08 14:48:21 +01001749 if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1750 return {};
1751 }
1752
Tony Makdf54e742019-03-26 14:04:00 +00001753 std::vector<Locale> detected_text_language_tags;
1754 if (!ParseLocales(options.detected_text_language_tags,
1755 &detected_text_language_tags)) {
1756 TC3_LOG(WARNING)
1757 << "Failed to parse the detected_text_language_tags in options: "
1758 << options.detected_text_language_tags;
1759 }
1760 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1761 model_triggering_locales_,
1762 /*default_value=*/true)) {
1763 return {};
1764 }
1765
Tony Mak968412a2019-11-13 15:39:57 +00001766 if (!IsValidSpanInput(UTF8ToUnicodeText(context, /*do_copy=*/false),
1767 selection_indices)) {
1768 TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
Tony Mak6c4cc672018-09-17 11:48:50 +01001769 << std::get<0>(selection_indices) << " "
1770 << std::get<1>(selection_indices);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001771 return {};
1772 }
1773
Tony Mak378c1f52019-03-04 15:58:11 +00001774 // We'll accumulate a list of candidates, and pick the best candidate in the
1775 // end.
1776 std::vector<AnnotatedSpan> candidates;
1777
Tony Mak6c4cc672018-09-17 11:48:50 +01001778 // Try the knowledge engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001779 // TODO(b/126579108): Propagate error status.
Tony Mak6c4cc672018-09-17 11:48:50 +01001780 ClassificationResult knowledge_result;
Tony Mak63959242020-02-07 18:31:16 +00001781 if (knowledge_engine_ &&
1782 knowledge_engine_->ClassifyText(
1783 context, selection_indices, options.annotation_usecase,
Tony Mak90d55672020-04-15 18:20:44 +01001784 options.location_context, Permissions(), &knowledge_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001785 candidates.push_back({selection_indices, {knowledge_result}});
1786 candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
Tony Mak854015a2019-01-16 15:56:48 +00001787 }
1788
Tony Maka2a1ff42019-09-12 15:40:32 +01001789 AddContactMetadataToKnowledgeClassificationResults(&candidates);
1790
Tony Mak854015a2019-01-16 15:56:48 +00001791 // Try the contact engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001792 // TODO(b/126579108): Propagate error status.
Tony Mak854015a2019-01-16 15:56:48 +00001793 ClassificationResult contact_result;
1794 if (contact_engine_ && contact_engine_->ClassifyText(
1795 context, selection_indices, &contact_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001796 candidates.push_back({selection_indices, {contact_result}});
Tony Mak6c4cc672018-09-17 11:48:50 +01001797 }
1798
Tony Mak76d80962020-01-08 17:30:51 +00001799 // Try the person name engine.
1800 ClassificationResult person_name_result;
1801 if (person_name_engine_ &&
1802 person_name_engine_->ClassifyText(context, selection_indices,
1803 &person_name_result)) {
1804 candidates.push_back({selection_indices, {person_name_result}});
Tony Makd0ae7c62020-03-27 13:58:00 +00001805 candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
Tony Mak76d80962020-01-08 17:30:51 +00001806 }
1807
Tony Makd9446602019-02-20 18:25:39 +00001808 // Try the installed app engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001809 // TODO(b/126579108): Propagate error status.
Tony Makd9446602019-02-20 18:25:39 +00001810 ClassificationResult installed_app_result;
1811 if (installed_app_engine_ &&
1812 installed_app_engine_->ClassifyText(context, selection_indices,
1813 &installed_app_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001814 candidates.push_back({selection_indices, {installed_app_result}});
Tony Makd9446602019-02-20 18:25:39 +00001815 }
1816
Lukas Zilkab23e2122018-02-09 10:25:19 +01001817 // Try the regular expression models.
Tony Mak378c1f52019-03-04 15:58:11 +00001818 std::vector<ClassificationResult> regex_results;
1819 if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1820 return {};
1821 }
1822 for (const ClassificationResult& result : regex_results) {
1823 candidates.push_back({selection_indices, {result}});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001824 }
1825
Lukas Zilkab23e2122018-02-09 10:25:19 +01001826 // Try the date model.
Tony Mak378c1f52019-03-04 15:58:11 +00001827 //
1828 // DatetimeClassifyText only returns the first result, which can however have
1829 // more interpretations. They are inserted in the candidates as a single
1830 // AnnotatedSpan, so that they get treated together by the conflict resolution
1831 // algorithm.
Tony Mak854015a2019-01-16 15:56:48 +00001832 std::vector<ClassificationResult> datetime_results;
Tony Mak378c1f52019-03-04 15:58:11 +00001833 if (!DatetimeClassifyText(context, selection_indices, options,
1834 &datetime_results)) {
1835 return {};
1836 }
1837 if (!datetime_results.empty()) {
1838 candidates.push_back({selection_indices, std::move(datetime_results)});
Tony Mak448b5862019-03-22 13:36:41 +00001839 candidates.back().source = AnnotatedSpan::Source::DATETIME;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001840 }
1841
Tony Mak5a12b942020-05-01 12:41:31 +01001842 const UnicodeText context_unicode =
1843 UTF8ToUnicodeText(context, /*do_copy=*/false);
1844
Tony Mak378c1f52019-03-04 15:58:11 +00001845 // Try the number annotator.
1846 // TODO(b/126579108): Propagate error status.
1847 ClassificationResult number_annotator_result;
1848 if (number_annotator_ &&
Tony Mak5a12b942020-05-01 12:41:31 +01001849 number_annotator_->ClassifyText(context_unicode, selection_indices,
1850 options.annotation_usecase,
1851 &number_annotator_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001852 candidates.push_back({selection_indices, {number_annotator_result}});
1853 }
Lukas Zilkaba849e72018-03-08 14:48:21 +01001854
Tony Makad2e22d2019-03-20 17:35:13 +00001855 // Try the duration annotator.
1856 ClassificationResult duration_annotator_result;
1857 if (duration_annotator_ &&
Tony Mak5a12b942020-05-01 12:41:31 +01001858 duration_annotator_->ClassifyText(context_unicode, selection_indices,
1859 options.annotation_usecase,
1860 &duration_annotator_result)) {
Tony Makad2e22d2019-03-20 17:35:13 +00001861 candidates.push_back({selection_indices, {duration_annotator_result}});
Tony Mak448b5862019-03-22 13:36:41 +00001862 candidates.back().source = AnnotatedSpan::Source::DURATION;
Tony Makad2e22d2019-03-20 17:35:13 +00001863 }
1864
Tony Mak63959242020-02-07 18:31:16 +00001865 // Try the translate annotator.
1866 ClassificationResult translate_annotator_result;
1867 if (translate_annotator_ &&
Tony Mak5a12b942020-05-01 12:41:31 +01001868 translate_annotator_->ClassifyText(context_unicode, selection_indices,
1869 options.user_familiar_language_tags,
1870 &translate_annotator_result)) {
Tony Mak63959242020-02-07 18:31:16 +00001871 candidates.push_back({selection_indices, {translate_annotator_result}});
1872 }
1873
Tony Mak21460022020-03-12 18:29:35 +00001874 // Try the grammar model.
1875 ClassificationResult grammar_annotator_result;
1876 if (grammar_annotator_ && grammar_annotator_->ClassifyText(
Tony Mak5a12b942020-05-01 12:41:31 +01001877 detected_text_language_tags, context_unicode,
Tony Mak21460022020-03-12 18:29:35 +00001878 selection_indices, &grammar_annotator_result)) {
1879 candidates.push_back({selection_indices, {grammar_annotator_result}});
1880 }
1881
Tony Mak5a12b942020-05-01 12:41:31 +01001882 ClassificationResult experimental_annotator_result;
1883 if (experimental_annotator_ &&
1884 experimental_annotator_->ClassifyText(context_unicode, selection_indices,
1885 &experimental_annotator_result)) {
1886 candidates.push_back({selection_indices, {experimental_annotator_result}});
1887 }
1888
Tony Mak378c1f52019-03-04 15:58:11 +00001889 // Try the ML model.
1890 //
1891 // The output of the model is considered as an exclusive 1-of-N choice. That's
1892 // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1893 // span for each candidate, like e.g. the regex model.
Lukas Zilkaba849e72018-03-08 14:48:21 +01001894 InterpreterManager interpreter_manager(selection_executor_.get(),
1895 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00001896 std::vector<ClassificationResult> model_results;
1897 std::vector<Token> tokens;
1898 if (!ModelClassifyText(
1899 context, /*cached_tokens=*/{}, detected_text_language_tags,
1900 selection_indices, &interpreter_manager,
1901 /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1902 return {};
1903 }
1904 if (!model_results.empty()) {
1905 candidates.push_back({selection_indices, std::move(model_results)});
1906 }
1907
1908 std::vector<int> candidate_indices;
1909 if (!ResolveConflicts(candidates, context, tokens,
1910 detected_text_language_tags, options.annotation_usecase,
1911 &interpreter_manager, &candidate_indices)) {
1912 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1913 return {};
1914 }
1915
1916 std::vector<ClassificationResult> results;
1917 for (const int i : candidate_indices) {
1918 for (const ClassificationResult& result : candidates[i].classification) {
1919 if (!FilteredForClassification(result)) {
1920 results.push_back(result);
1921 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001922 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001923 }
1924
Tony Mak378c1f52019-03-04 15:58:11 +00001925 // Sort results according to score.
1926 std::sort(results.begin(), results.end(),
1927 [](const ClassificationResult& a, const ClassificationResult& b) {
1928 return a.score > b.score;
1929 });
1930
1931 if (results.empty()) {
Tony Mak81e52422019-04-30 09:34:45 +01001932 results = {{Collections::Other(), 1.0}};
Tony Mak378c1f52019-03-04 15:58:11 +00001933 }
Tony Mak378c1f52019-03-04 15:58:11 +00001934 return results;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001935}
1936
Tony Mak378c1f52019-03-04 15:58:11 +00001937bool Annotator::ModelAnnotate(
1938 const std::string& context,
1939 const std::vector<Locale>& detected_text_language_tags,
1940 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1941 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001942 if (model_->triggering_options() == nullptr ||
1943 !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1944 return true;
1945 }
1946
Tony Makdf54e742019-03-26 14:04:00 +00001947 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1948 ml_model_triggering_locales_,
1949 /*default_value=*/true)) {
1950 return true;
1951 }
1952
Lukas Zilka21d8c982018-01-24 11:11:20 +01001953 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1954 /*do_copy=*/false);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001955 std::vector<UnicodeTextRange> lines;
1956 if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1957 lines.push_back({context_unicode.begin(), context_unicode.end()});
1958 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +01001959 lines = selection_feature_processor_->SplitContext(
1960 context_unicode, selection_feature_processor_->GetOptions()
1961 ->use_pipe_character_for_newline());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001962 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001963
Lukas Zilkaba849e72018-03-08 14:48:21 +01001964 const float min_annotate_confidence =
1965 (model_->triggering_options() != nullptr
1966 ? model_->triggering_options()->min_annotate_confidence()
1967 : 0.f);
1968
Lukas Zilkab23e2122018-02-09 10:25:19 +01001969 for (const UnicodeTextRange& line : lines) {
Tony Mak408c6b82019-03-08 17:57:27 +00001970 FeatureProcessor::EmbeddingCache embedding_cache;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001971 const std::string line_str =
1972 UnicodeText::UTF8Substring(line.first, line.second);
1973
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001974 *tokens = selection_feature_processor_->Tokenize(line_str);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001975 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001976 line_str, {0, std::distance(line.first, line.second)},
1977 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001978 tokens,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001979 /*click_pos=*/nullptr);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001980 const TokenSpan full_line_span = {0, tokens->size()};
Lukas Zilka21d8c982018-01-24 11:11:20 +01001981
Lukas Zilka434442d2018-04-25 11:38:51 +02001982 // TODO(zilka): Add support for greater granularity of this check.
1983 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1984 *tokens, full_line_span)) {
1985 continue;
1986 }
1987
Lukas Zilka21d8c982018-01-24 11:11:20 +01001988 std::unique_ptr<CachedFeatures> cached_features;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001989 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001990 *tokens, full_line_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001991 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1992 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01001993 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001994 selection_feature_processor_->EmbeddingSize() +
1995 selection_feature_processor_->DenseFeaturesCount(),
Lukas Zilka21d8c982018-01-24 11:11:20 +01001996 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001997 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001998 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001999 }
2000
2001 std::vector<TokenSpan> local_chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002002 if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002003 interpreter_manager->SelectionInterpreter(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002004 *cached_features, &local_chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002005 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002006 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002007 }
2008
2009 const int offset = std::distance(context_unicode.begin(), line.first);
2010 for (const TokenSpan& chunk : local_chunks) {
2011 const CodepointSpan codepoint_span =
2012 selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002013 line_str, TokenSpanToCodepointSpan(*tokens, chunk));
Lukas Zilkab23e2122018-02-09 10:25:19 +01002014
2015 // Skip empty spans.
2016 if (codepoint_span.first != codepoint_span.second) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002017 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +00002018 if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
2019 codepoint_span, interpreter_manager,
2020 &embedding_cache, &classification)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002021 TC3_LOG(ERROR) << "Could not classify text: "
2022 << (codepoint_span.first + offset) << " "
2023 << (codepoint_span.second + offset);
Lukas Zilkaba849e72018-03-08 14:48:21 +01002024 return false;
2025 }
2026
2027 // Do not include the span if it's classified as "other".
2028 if (!classification.empty() && !ClassifiedAsOther(classification) &&
2029 classification[0].score >= min_annotate_confidence) {
2030 AnnotatedSpan result_span;
2031 result_span.span = {codepoint_span.first + offset,
2032 codepoint_span.second + offset};
2033 result_span.classification = std::move(classification);
2034 result->push_back(std::move(result_span));
2035 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002036 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01002037 }
2038 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002039 return true;
2040}
2041
Tony Mak6c4cc672018-09-17 11:48:50 +01002042const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
Lukas Zilka434442d2018-04-25 11:38:51 +02002043 return selection_feature_processor_.get();
2044}
2045
Tony Mak6c4cc672018-09-17 11:48:50 +01002046const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
Lukas Zilka434442d2018-04-25 11:38:51 +02002047 const {
2048 return classification_feature_processor_.get();
Lukas Zilkab23e2122018-02-09 10:25:19 +01002049}
2050
Tony Mak6c4cc672018-09-17 11:48:50 +01002051const DatetimeParser* Annotator::DatetimeParserForTests() const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002052 return datetime_parser_.get();
2053}
2054
Tony Mak83d2de62019-04-10 16:12:15 +01002055void Annotator::RemoveNotEnabledEntityTypes(
2056 const EnabledEntityTypes& is_entity_type_enabled,
2057 std::vector<AnnotatedSpan>* annotated_spans) const {
2058 for (AnnotatedSpan& annotated_span : *annotated_spans) {
2059 std::vector<ClassificationResult>& classifications =
2060 annotated_span.classification;
2061 classifications.erase(
2062 std::remove_if(classifications.begin(), classifications.end(),
2063 [&is_entity_type_enabled](
2064 const ClassificationResult& classification_result) {
2065 return !is_entity_type_enabled(
2066 classification_result.collection);
2067 }),
2068 classifications.end());
2069 }
2070 annotated_spans->erase(
2071 std::remove_if(annotated_spans->begin(), annotated_spans->end(),
2072 [](const AnnotatedSpan& annotated_span) {
2073 return annotated_span.classification.empty();
2074 }),
2075 annotated_spans->end());
2076}
2077
Tony Maka2a1ff42019-09-12 15:40:32 +01002078void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2079 std::vector<AnnotatedSpan>* candidates) const {
2080 if (candidates == nullptr || contact_engine_ == nullptr) {
2081 return;
2082 }
2083 for (auto& candidate : *candidates) {
2084 for (auto& classification_result : candidate.classification) {
2085 contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2086 &classification_result);
2087 }
2088 }
2089}
2090
Tony Makff31efb2020-03-31 11:13:06 +01002091Status Annotator::AnnotateSingleInput(
2092 const std::string& context, const AnnotationOptions& options,
2093 std::vector<AnnotatedSpan>* candidates) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002094 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
Tony Makff31efb2020-03-31 11:13:06 +01002095 return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
Lukas Zilkaba849e72018-03-08 14:48:21 +01002096 }
2097
Tony Mak854015a2019-01-16 15:56:48 +00002098 const UnicodeText context_unicode =
2099 UTF8ToUnicodeText(context, /*do_copy=*/false);
2100 if (!context_unicode.is_valid()) {
Tony Makff31efb2020-03-31 11:13:06 +01002101 return Status(StatusCode::INVALID_ARGUMENT,
2102 "Context string isn't valid UTF8.");
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002103 }
2104
Tony Mak378c1f52019-03-04 15:58:11 +00002105 std::vector<Locale> detected_text_language_tags;
2106 if (!ParseLocales(options.detected_text_language_tags,
2107 &detected_text_language_tags)) {
2108 TC3_LOG(WARNING)
Tony Makdf54e742019-03-26 14:04:00 +00002109 << "Failed to parse the detected_text_language_tags in options: "
Tony Mak378c1f52019-03-04 15:58:11 +00002110 << options.detected_text_language_tags;
2111 }
Tony Makdf54e742019-03-26 14:04:00 +00002112 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2113 model_triggering_locales_,
2114 /*default_value=*/true)) {
Tony Makff31efb2020-03-31 11:13:06 +01002115 return Status(
2116 StatusCode::UNAVAILABLE,
2117 "The detected language tags are not in the supported locales.");
Tony Makdf54e742019-03-26 14:04:00 +00002118 }
2119
2120 InterpreterManager interpreter_manager(selection_executor_.get(),
2121 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00002122
Lukas Zilkab23e2122018-02-09 10:25:19 +01002123 // Annotate with the selection model.
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002124 std::vector<Token> tokens;
Tony Mak378c1f52019-03-04 15:58:11 +00002125 if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
Tony Makff31efb2020-03-31 11:13:06 +01002126 &tokens, candidates)) {
2127 return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002128 }
2129
2130 // Annotate with the regular expression models.
2131 if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
Tony Makff31efb2020-03-31 11:13:06 +01002132 annotation_regex_patterns_, candidates,
Tony Mak83d2de62019-04-10 16:12:15 +01002133 options.is_serialized_entity_data_enabled)) {
Tony Makff31efb2020-03-31 11:13:06 +01002134 return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002135 }
2136
2137 // Annotate with the datetime model.
Tony Mak83d2de62019-04-10 16:12:15 +01002138 const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2139 if ((is_entity_type_enabled(Collections::Date()) ||
2140 is_entity_type_enabled(Collections::DateTime())) &&
2141 !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002142 options.reference_time_ms_utc, options.reference_timezone,
Tony Makd9446602019-02-20 18:25:39 +00002143 options.locales, ModeFlag_ANNOTATION,
Tony Mak83d2de62019-04-10 16:12:15 +01002144 options.annotation_usecase,
Tony Makff31efb2020-03-31 11:13:06 +01002145 options.is_serialized_entity_data_enabled, candidates)) {
2146 return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
Tony Mak6c4cc672018-09-17 11:48:50 +01002147 }
2148
Tony Mak854015a2019-01-16 15:56:48 +00002149 // Annotate with the contact engine.
2150 if (contact_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002151 !contact_engine_->Chunk(context_unicode, tokens, candidates)) {
2152 return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
Tony Mak854015a2019-01-16 15:56:48 +00002153 }
2154
Tony Makd9446602019-02-20 18:25:39 +00002155 // Annotate with the installed app engine.
2156 if (installed_app_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002157 !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
2158 return Status(StatusCode::INTERNAL,
2159 "Couldn't run installed app engine Chunk.");
Tony Makd9446602019-02-20 18:25:39 +00002160 }
2161
Tony Mak378c1f52019-03-04 15:58:11 +00002162 // Annotate with the number annotator.
2163 if (number_annotator_ != nullptr &&
2164 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
Tony Makff31efb2020-03-31 11:13:06 +01002165 candidates)) {
2166 return Status(StatusCode::INTERNAL,
2167 "Couldn't run number annotator FindAll.");
Tony Makad2e22d2019-03-20 17:35:13 +00002168 }
2169
2170 // Annotate with the duration annotator.
Tony Mak83d2de62019-04-10 16:12:15 +01002171 if (is_entity_type_enabled(Collections::Duration()) &&
2172 duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +00002173 !duration_annotator_->FindAll(context_unicode, tokens,
Tony Makff31efb2020-03-31 11:13:06 +01002174 options.annotation_usecase, candidates)) {
2175 return Status(StatusCode::INTERNAL,
2176 "Couldn't run duration annotator FindAll.");
Tony Mak378c1f52019-03-04 15:58:11 +00002177 }
2178
Tony Mak76d80962020-01-08 17:30:51 +00002179 // Annotate with the person name engine.
2180 if (is_entity_type_enabled(Collections::PersonName()) &&
2181 person_name_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002182 !person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
2183 return Status(StatusCode::INTERNAL,
2184 "Couldn't run person name engine Chunk.");
Tony Mak76d80962020-01-08 17:30:51 +00002185 }
2186
Tony Mak21460022020-03-12 18:29:35 +00002187 // Annotate with the grammar annotators.
2188 if (grammar_annotator_ != nullptr &&
2189 !grammar_annotator_->Annotate(detected_text_language_tags,
Tony Makff31efb2020-03-31 11:13:06 +01002190 context_unicode, candidates)) {
2191 return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
Tony Mak21460022020-03-12 18:29:35 +00002192 }
2193
Tony Mak5a12b942020-05-01 12:41:31 +01002194 if (experimental_annotator_ != nullptr &&
2195 !experimental_annotator_->Annotate(context_unicode, candidates)) {
2196 return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
2197 }
2198
Lukas Zilkab23e2122018-02-09 10:25:19 +01002199 // Sort candidates according to their position in the input, so that the next
2200 // code can assume that any connected component of overlapping spans forms a
2201 // contiguous block.
Tony Mak5a12b942020-05-01 12:41:31 +01002202 // Also sort them according to the end position and collection, so that the
2203 // deduplication code below can assume that same spans and classifications
2204 // form contiguous blocks.
Tony Makff31efb2020-03-31 11:13:06 +01002205 std::sort(candidates->begin(), candidates->end(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002206 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
Tony Mak5a12b942020-05-01 12:41:31 +01002207 if (a.span.first != b.span.first) {
2208 return a.span.first < b.span.first;
2209 }
2210
2211 if (a.span.second != b.span.second) {
2212 return a.span.second < b.span.second;
2213 }
2214
2215 return a.classification[0].collection <
2216 b.classification[0].collection;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002217 });
2218
2219 std::vector<int> candidate_indices;
Tony Makff31efb2020-03-31 11:13:06 +01002220 if (!ResolveConflicts(*candidates, context, tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00002221 detected_text_language_tags, options.annotation_usecase,
2222 &interpreter_manager, &candidate_indices)) {
Tony Makff31efb2020-03-31 11:13:06 +01002223 return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002224 }
2225
Tony Mak5a12b942020-05-01 12:41:31 +01002226 // Remove candidates that overlap exactly and have the same collection.
2227 // This can e.g. happen for phone coming from both ML model and regex.
2228 candidate_indices.erase(
2229 std::unique(candidate_indices.begin(), candidate_indices.end(),
2230 [&candidates](const int a_index, const int b_index) {
2231 const AnnotatedSpan& a = (*candidates)[a_index];
2232 const AnnotatedSpan& b = (*candidates)[b_index];
2233 return a.span == b.span &&
2234 a.classification[0].collection ==
2235 b.classification[0].collection;
2236 }),
2237 candidate_indices.end());
2238
Lukas Zilkab23e2122018-02-09 10:25:19 +01002239 std::vector<AnnotatedSpan> result;
2240 result.reserve(candidate_indices.size());
2241 for (const int i : candidate_indices) {
Tony Makff31efb2020-03-31 11:13:06 +01002242 if ((*candidates)[i].classification.empty() ||
2243 ClassifiedAsOther((*candidates)[i].classification) ||
2244 FilteredForAnnotation((*candidates)[i])) {
Tony Mak378c1f52019-03-04 15:58:11 +00002245 continue;
2246 }
Tony Mak5a12b942020-05-01 12:41:31 +01002247 result.push_back(std::move((*candidates)[i]));
Tony Mak378c1f52019-03-04 15:58:11 +00002248 }
2249
Tony Mak83d2de62019-04-10 16:12:15 +01002250 // We generate all candidates and remove them later (with the exception of
2251 // date/time/duration entities) because there are complex interdependencies
2252 // between the entity types. E.g., the TLD of an email can be interpreted as a
2253 // URL, but most likely a user of the API does not want such annotations if
2254 // "url" is enabled and "email" is not.
2255 RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2256
Tony Mak378c1f52019-03-04 15:58:11 +00002257 for (AnnotatedSpan& annotated_span : result) {
2258 SortClassificationResults(&annotated_span.classification);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002259 }
Tony Makff31efb2020-03-31 11:13:06 +01002260 *candidates = result;
2261 return Status::OK;
2262}
Lukas Zilkab23e2122018-02-09 10:25:19 +01002263
Tony Makff31efb2020-03-31 11:13:06 +01002264StatusOr<std::vector<std::vector<AnnotatedSpan>>>
2265Annotator::AnnotateStructuredInput(
2266 const std::vector<InputFragment>& string_fragments,
2267 const AnnotationOptions& options) const {
2268 std::vector<std::vector<AnnotatedSpan>> annotation_candidates(
2269 string_fragments.size());
2270
2271 std::vector<std::string> text_to_annotate;
2272 text_to_annotate.reserve(string_fragments.size());
2273 for (const auto& string_fragment : string_fragments) {
2274 text_to_annotate.push_back(string_fragment.text);
2275 }
2276
2277 // KnowledgeEngine is special, because it supports annotation of multiple
2278 // fragments at once.
2279 if (knowledge_engine_ &&
2280 !knowledge_engine_
2281 ->ChunkMultipleSpans(text_to_annotate, options.annotation_usecase,
Tony Mak90d55672020-04-15 18:20:44 +01002282 options.location_context, options.permissions,
Tony Makff31efb2020-03-31 11:13:06 +01002283 &annotation_candidates)
2284 .ok()) {
2285 return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
2286 }
2287 // The annotator engines shouldn't change the number of annotation vectors.
2288 if (annotation_candidates.size() != text_to_annotate.size()) {
2289 TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
2290 << " texts to annotate but generated a different number of "
2291 "lists of annotations:"
2292 << annotation_candidates.size();
2293 return Status(StatusCode::INTERNAL,
2294 "Number of annotation candidates differs from "
2295 "number of texts to annotate.");
2296 }
2297
2298 // Other annotators run on each fragment independently.
2299 for (int i = 0; i < text_to_annotate.size(); ++i) {
2300 AnnotationOptions annotation_options = options;
2301 if (string_fragments[i].datetime_options.has_value()) {
2302 DatetimeOptions reference_datetime =
2303 string_fragments[i].datetime_options.value();
2304 annotation_options.reference_time_ms_utc =
2305 reference_datetime.reference_time_ms_utc;
2306 annotation_options.reference_timezone =
2307 reference_datetime.reference_timezone;
2308 }
2309
2310 AddContactMetadataToKnowledgeClassificationResults(
2311 &annotation_candidates[i]);
2312
2313 Status annotation_status = AnnotateSingleInput(
2314 text_to_annotate[i], annotation_options, &annotation_candidates[i]);
2315 if (!annotation_status.ok()) {
2316 return annotation_status;
2317 }
2318 }
2319 return annotation_candidates;
2320}
2321
2322std::vector<AnnotatedSpan> Annotator::Annotate(
2323 const std::string& context, const AnnotationOptions& options) const {
2324 std::vector<InputFragment> string_fragments;
2325 string_fragments.push_back({.text = context});
2326 StatusOr<std::vector<std::vector<AnnotatedSpan>>> annotations =
2327 AnnotateStructuredInput(string_fragments, options);
2328 if (!annotations.ok()) {
2329 TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
2330 << annotations.status().error_message();
2331 return {};
2332 }
2333 return annotations.ValueOrDie()[0];
Lukas Zilka21d8c982018-01-24 11:11:20 +01002334}
2335
Tony Mak854015a2019-01-16 15:56:48 +00002336CodepointSpan Annotator::ComputeSelectionBoundaries(
2337 const UniLib::RegexMatcher* match,
2338 const RegexModel_::Pattern* config) const {
2339 if (config->capturing_group() == nullptr) {
2340 // Use first capturing group to specify the selection.
2341 int status = UniLib::RegexMatcher::kNoError;
2342 const CodepointSpan result = {match->Start(1, &status),
2343 match->End(1, &status)};
2344 if (status != UniLib::RegexMatcher::kNoError) {
2345 return {kInvalidIndex, kInvalidIndex};
2346 }
2347 return result;
2348 }
2349
2350 CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2351 const int num_groups = config->capturing_group()->size();
2352 for (int i = 0; i < num_groups; i++) {
2353 if (!config->capturing_group()->Get(i)->extend_selection()) {
2354 continue;
2355 }
2356
2357 int status = UniLib::RegexMatcher::kNoError;
2358 // Check match and adjust bounds.
2359 const int group_start = match->Start(i, &status);
2360 const int group_end = match->End(i, &status);
2361 if (status != UniLib::RegexMatcher::kNoError) {
2362 return {kInvalidIndex, kInvalidIndex};
2363 }
2364 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2365 continue;
2366 }
2367 if (result.first == kInvalidIndex) {
2368 result = {group_start, group_end};
2369 } else {
2370 result.first = std::min(result.first, group_start);
2371 result.second = std::max(result.second, group_end);
2372 }
2373 }
2374 return result;
2375}
2376
Tony Makd9446602019-02-20 18:25:39 +00002377bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
Tony Mak21460022020-03-12 18:29:35 +00002378 if (pattern->serialized_entity_data() != nullptr ||
2379 pattern->entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002380 return true;
2381 }
2382 if (pattern->capturing_group() != nullptr) {
Tony Mak63959242020-02-07 18:31:16 +00002383 for (const CapturingGroup* group : *pattern->capturing_group()) {
Tony Makd9446602019-02-20 18:25:39 +00002384 if (group->entity_field_path() != nullptr) {
2385 return true;
2386 }
Tony Mak21460022020-03-12 18:29:35 +00002387 if (group->serialized_entity_data() != nullptr ||
2388 group->entity_data() != nullptr) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002389 return true;
2390 }
Tony Makd9446602019-02-20 18:25:39 +00002391 }
2392 }
2393 return false;
2394}
2395
2396bool Annotator::SerializedEntityDataFromRegexMatch(
2397 const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2398 std::string* serialized_entity_data) const {
2399 if (!HasEntityData(pattern)) {
2400 serialized_entity_data->clear();
2401 return true;
2402 }
2403 TC3_CHECK(entity_data_builder_ != nullptr);
2404
2405 std::unique_ptr<ReflectiveFlatbuffer> entity_data =
2406 entity_data_builder_->NewRoot();
2407
2408 TC3_CHECK(entity_data != nullptr);
2409
Tony Mak21460022020-03-12 18:29:35 +00002410 // Set fixed entity data.
Tony Makd9446602019-02-20 18:25:39 +00002411 if (pattern->serialized_entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002412 entity_data->MergeFromSerializedFlatbuffer(
2413 StringPiece(pattern->serialized_entity_data()->c_str(),
2414 pattern->serialized_entity_data()->size()));
2415 }
Tony Mak21460022020-03-12 18:29:35 +00002416 if (pattern->entity_data() != nullptr) {
2417 entity_data->MergeFrom(
2418 reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2419 }
Tony Makd9446602019-02-20 18:25:39 +00002420
2421 // Add entity data from rule capturing groups.
2422 if (pattern->capturing_group() != nullptr) {
2423 const int num_groups = pattern->capturing_group()->size();
2424 for (int i = 0; i < num_groups; i++) {
Tony Mak63959242020-02-07 18:31:16 +00002425 const CapturingGroup* group = pattern->capturing_group()->Get(i);
Tony Maka2a1ff42019-09-12 15:40:32 +01002426
2427 // Check whether the group matched.
2428 Optional<std::string> group_match_text =
2429 GetCapturingGroupText(matcher, /*group_id=*/i);
2430 if (!group_match_text.has_value()) {
Tony Makd9446602019-02-20 18:25:39 +00002431 continue;
2432 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002433
Tony Mak21460022020-03-12 18:29:35 +00002434 // Set fixed entity data from capturing group match.
Tony Maka2a1ff42019-09-12 15:40:32 +01002435 if (group->serialized_entity_data() != nullptr) {
2436 entity_data->MergeFromSerializedFlatbuffer(
2437 StringPiece(group->serialized_entity_data()->c_str(),
2438 group->serialized_entity_data()->size()));
2439 }
Tony Mak21460022020-03-12 18:29:35 +00002440 if (group->entity_data() != nullptr) {
2441 entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2442 pattern->entity_data()));
2443 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002444
2445 // Set entity field from capturing group text.
2446 if (group->entity_field_path() != nullptr) {
Tony Mak8cd7ba62019-10-15 15:29:22 +01002447 UnicodeText normalized_group_match_text =
2448 UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2449
2450 // Apply normalization if specified.
2451 if (group->normalization_options() != nullptr) {
2452 normalized_group_match_text =
Tony Mak1ac2e4a2020-04-29 13:41:53 +01002453 NormalizeText(*unilib_, group->normalization_options(),
Tony Mak8cd7ba62019-10-15 15:29:22 +01002454 normalized_group_match_text);
2455 }
2456
2457 if (!entity_data->ParseAndSet(
2458 group->entity_field_path(),
2459 normalized_group_match_text.ToUTF8String())) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002460 TC3_LOG(ERROR)
2461 << "Could not set entity data from rule capturing group.";
2462 return false;
2463 }
Tony Makd9446602019-02-20 18:25:39 +00002464 }
2465 }
2466 }
2467
2468 *serialized_entity_data = entity_data->Serialize();
2469 return true;
2470}
2471
Tony Mak63959242020-02-07 18:31:16 +00002472UnicodeText RemoveMoneySeparators(
2473 const std::unordered_set<char32>& decimal_separators,
2474 const UnicodeText& amount,
2475 UnicodeText::const_iterator it_decimal_separator) {
2476 UnicodeText whole_amount;
2477 for (auto it = amount.begin();
2478 it != amount.end() && it != it_decimal_separator; ++it) {
2479 if (std::find(decimal_separators.begin(), decimal_separators.end(),
2480 static_cast<char32>(*it)) == decimal_separators.end()) {
2481 whole_amount.push_back(*it);
2482 }
2483 }
2484 return whole_amount;
2485}
2486
2487bool Annotator::ParseAndFillInMoneyAmount(
2488 std::string* serialized_entity_data) const {
2489 std::unique_ptr<EntityDataT> data =
2490 LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2491 *serialized_entity_data);
Tony Mak0b8b3322020-03-17 16:30:19 +00002492 if (data == nullptr) {
2493 TC3_LOG(ERROR)
2494 << "Data field is null when trying to parse Money Entity Data";
2495 return false;
2496 }
2497 if (data->money->unnormalized_amount.empty()) {
2498 TC3_LOG(ERROR) << "Data unnormalized_amount is empty when trying to parse "
2499 "Money Entity Data";
Tony Mak63959242020-02-07 18:31:16 +00002500 return false;
2501 }
2502
2503 UnicodeText amount =
2504 UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2505 int separator_back_index = 0;
Tony Mak21460022020-03-12 18:29:35 +00002506 auto it_decimal_separator = --amount.end();
Tony Mak63959242020-02-07 18:31:16 +00002507 for (; it_decimal_separator != amount.begin();
2508 --it_decimal_separator, ++separator_back_index) {
2509 if (std::find(money_separators_.begin(), money_separators_.end(),
2510 static_cast<char32>(*it_decimal_separator)) !=
2511 money_separators_.end()) {
2512 break;
2513 }
2514 }
2515
2516 // If there are 3 digits after the last separator, we consider that a
2517 // thousands separator => the number is an int (e.g. 1.234 is considered int).
2518 // If there is no separator in number, also that number is an int.
Tony Mak21460022020-03-12 18:29:35 +00002519 if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
Tony Mak63959242020-02-07 18:31:16 +00002520 it_decimal_separator = amount.end();
2521 }
2522
2523 if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2524 it_decimal_separator),
2525 &data->money->amount_whole_part)) {
Tony Mak0b8b3322020-03-17 16:30:19 +00002526 TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
2527 "amount: "
2528 << data->money->unnormalized_amount;
Tony Mak63959242020-02-07 18:31:16 +00002529 return false;
2530 }
2531 if (it_decimal_separator == amount.end()) {
2532 data->money->amount_decimal_part = 0;
2533 } else {
2534 const int amount_codepoints_size = amount.size_codepoints();
2535 if (!unilib_->ParseInt32(
2536 UnicodeText::Substring(
Tony Mak21460022020-03-12 18:29:35 +00002537 amount, amount_codepoints_size - separator_back_index,
Tony Mak63959242020-02-07 18:31:16 +00002538 amount_codepoints_size, /*do_copy=*/false),
2539 &data->money->amount_decimal_part)) {
Tony Mak0b8b3322020-03-17 16:30:19 +00002540 TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
2541 "the amount: "
2542 << data->money->unnormalized_amount;
Tony Mak63959242020-02-07 18:31:16 +00002543 return false;
2544 }
2545 }
2546
2547 *serialized_entity_data =
2548 PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2549 return true;
2550}
2551
Tony Mak6c4cc672018-09-17 11:48:50 +01002552bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2553 const std::vector<int>& rules,
Tony Mak83d2de62019-04-10 16:12:15 +01002554 std::vector<AnnotatedSpan>* result,
2555 bool is_serialized_entity_data_enabled) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002556 for (int pattern_id : rules) {
2557 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2558 const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2559 if (!matcher) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002560 TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2561 << pattern_id;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002562 return false;
2563 }
2564
2565 int status = UniLib::RegexMatcher::kNoError;
2566 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
Tony Mak854015a2019-01-16 15:56:48 +00002567 if (regex_pattern.config->verification_options()) {
Tony Makdf54e742019-03-26 14:04:00 +00002568 if (!VerifyRegexMatchCandidate(
2569 context_unicode.ToUTF8String(),
2570 regex_pattern.config->verification_options(),
2571 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002572 continue;
2573 }
2574 }
Tony Makd9446602019-02-20 18:25:39 +00002575
2576 std::string serialized_entity_data;
Tony Mak83d2de62019-04-10 16:12:15 +01002577 if (is_serialized_entity_data_enabled) {
2578 if (!SerializedEntityDataFromRegexMatch(
2579 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2580 TC3_LOG(ERROR) << "Could not get entity data.";
2581 return false;
2582 }
Tony Mak63959242020-02-07 18:31:16 +00002583
2584 // Further parsing unnormalized_amount for money into amount_whole_part
2585 // and amount_decimal_part. Can't do this with regexes because we cannot
2586 // have empty groups (amount_decimal_part might be an empty group).
2587 if (regex_pattern.config->collection_name()->str() ==
2588 Collections::Money()) {
2589 if (!ParseAndFillInMoneyAmount(&serialized_entity_data)) {
2590 TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
2591 }
2592 }
Tony Makd9446602019-02-20 18:25:39 +00002593 }
2594
Lukas Zilkab23e2122018-02-09 10:25:19 +01002595 result->emplace_back();
Tony Mak854015a2019-01-16 15:56:48 +00002596
Lukas Zilkab23e2122018-02-09 10:25:19 +01002597 // Selection/annotation regular expressions need to specify a capturing
2598 // group specifying the selection.
Tony Mak854015a2019-01-16 15:56:48 +00002599 result->back().span =
2600 ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2601
Lukas Zilkab23e2122018-02-09 10:25:19 +01002602 result->back().classification = {
Tony Mak854015a2019-01-16 15:56:48 +00002603 {regex_pattern.config->collection_name()->str(),
2604 regex_pattern.config->target_classification_score(),
2605 regex_pattern.config->priority_score()}};
Tony Makd9446602019-02-20 18:25:39 +00002606
2607 result->back().classification[0].serialized_entity_data =
2608 serialized_entity_data;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002609 }
2610 }
2611 return true;
2612}
2613
Tony Mak6c4cc672018-09-17 11:48:50 +01002614bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2615 tflite::Interpreter* selection_interpreter,
2616 const CachedFeatures& cached_features,
2617 std::vector<TokenSpan>* chunks) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01002618 const int max_selection_span =
2619 selection_feature_processor_->GetOptions()->max_selection_span();
Lukas Zilka21d8c982018-01-24 11:11:20 +01002620 // The inference span is the span of interest expanded to include
2621 // max_selection_span tokens on either side, which is how far a selection can
2622 // stretch from the click.
2623 const TokenSpan inference_span = IntersectTokenSpans(
2624 ExpandTokenSpan(span_of_interest,
2625 /*num_tokens_left=*/max_selection_span,
2626 /*num_tokens_right=*/max_selection_span),
2627 {0, num_tokens});
2628
2629 std::vector<ScoredChunk> scored_chunks;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002630 if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2631 selection_feature_processor_->GetOptions()
2632 ->bounds_sensitive_features()
2633 ->enabled()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002634 if (!ModelBoundsSensitiveScoreChunks(
2635 num_tokens, span_of_interest, inference_span, cached_features,
2636 selection_interpreter, &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002637 return false;
2638 }
2639 } else {
2640 if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002641 cached_features, selection_interpreter,
2642 &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002643 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002644 }
2645 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002646 std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
2647 [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2648 return lhs.score < rhs.score;
2649 });
Lukas Zilka21d8c982018-01-24 11:11:20 +01002650
2651 // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2652 // them greedily as long as they do not overlap with any previously picked
2653 // chunks.
2654 std::vector<bool> token_used(TokenSpanSize(inference_span));
2655 chunks->clear();
2656 for (const ScoredChunk& scored_chunk : scored_chunks) {
2657 bool feasible = true;
2658 for (int i = scored_chunk.token_span.first;
2659 i < scored_chunk.token_span.second; ++i) {
2660 if (token_used[i - inference_span.first]) {
2661 feasible = false;
2662 break;
2663 }
2664 }
2665
2666 if (!feasible) {
2667 continue;
2668 }
2669
2670 for (int i = scored_chunk.token_span.first;
2671 i < scored_chunk.token_span.second; ++i) {
2672 token_used[i - inference_span.first] = true;
2673 }
2674
2675 chunks->push_back(scored_chunk.token_span);
2676 }
2677
2678 std::sort(chunks->begin(), chunks->end());
2679
2680 return true;
2681}
2682
namespace {
// Updates the value stored under `key` in `map` to the maximum of the current
// value and `value`, or simply inserts `value` if `key` is not there yet.
//
// A single emplace() both performs the lookup and, when the key is absent, the
// insertion, so the key is searched only once and the mapped type does not
// have to be default-constructible.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto insertion = map->emplace(key, value);
  if (!insertion.second) {
    // The key already existed: keep the larger of the stored and new value.
    auto& stored_value = insertion.first->second;
    stored_value = std::max(stored_value, value);
  }
}
}  // namespace
2697
Tony Mak6c4cc672018-09-17 11:48:50 +01002698bool Annotator::ModelClickContextScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002699 int num_tokens, const TokenSpan& span_of_interest,
2700 const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002701 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002702 std::vector<ScoredChunk>* scored_chunks) const {
2703 const int max_batch_size = model_->selection_options()->batch_size();
2704
2705 std::vector<float> all_features;
2706 std::map<TokenSpan, float> chunk_scores;
2707 for (int batch_start = span_of_interest.first;
2708 batch_start < span_of_interest.second; batch_start += max_batch_size) {
2709 const int batch_end =
2710 std::min(batch_start + max_batch_size, span_of_interest.second);
2711
2712 // Prepare features for the whole batch.
2713 all_features.clear();
2714 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2715 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2716 cached_features.AppendClickContextFeaturesForClick(click_pos,
2717 &all_features);
2718 }
2719
2720 // Run batched inference.
2721 const int batch_size = batch_end - batch_start;
2722 const int features_size = cached_features.OutputFeaturesSize();
2723 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01002724 TensorView<float>(all_features.data(), {batch_size, features_size}),
2725 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002726 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002727 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002728 return false;
2729 }
2730 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2731 logits.dim(1) !=
2732 selection_feature_processor_->GetSelectionLabelCount()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002733 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002734 return false;
2735 }
2736
2737 // Save results.
2738 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2739 const std::vector<float> scores = ComputeSoftmax(
2740 logits.data() + logits.dim(1) * (click_pos - batch_start),
2741 logits.dim(1));
2742 for (int j = 0;
2743 j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
2744 TokenSpan relative_token_span;
2745 if (!selection_feature_processor_->LabelToTokenSpan(
2746 j, &relative_token_span)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002747 TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002748 return false;
2749 }
2750 const TokenSpan candidate_span = ExpandTokenSpan(
2751 SingleTokenSpan(click_pos), relative_token_span.first,
2752 relative_token_span.second);
2753 if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
2754 UpdateMax(&chunk_scores, candidate_span, scores[j]);
2755 }
2756 }
2757 }
2758 }
2759
2760 scored_chunks->clear();
2761 scored_chunks->reserve(chunk_scores.size());
2762 for (const auto& entry : chunk_scores) {
2763 scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
2764 }
2765
2766 return true;
2767}
2768
Tony Mak6c4cc672018-09-17 11:48:50 +01002769bool Annotator::ModelBoundsSensitiveScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002770 int num_tokens, const TokenSpan& span_of_interest,
2771 const TokenSpan& inference_span, const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002772 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002773 std::vector<ScoredChunk>* scored_chunks) const {
2774 const int max_selection_span =
2775 selection_feature_processor_->GetOptions()->max_selection_span();
2776 const int max_chunk_length = selection_feature_processor_->GetOptions()
2777 ->selection_reduced_output_space()
2778 ? max_selection_span + 1
2779 : 2 * max_selection_span + 1;
Lukas Zilkaba849e72018-03-08 14:48:21 +01002780 const bool score_single_token_spans_as_zero =
2781 selection_feature_processor_->GetOptions()
2782 ->bounds_sensitive_features()
2783 ->score_single_token_spans_as_zero();
2784
2785 scored_chunks->clear();
2786 if (score_single_token_spans_as_zero) {
2787 scored_chunks->reserve(TokenSpanSize(span_of_interest));
2788 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002789
2790 // Prepare all chunk candidates into one batch:
2791 // - Are contained in the inference span
2792 // - Have a non-empty intersection with the span of interest
2793 // - Are at least one token long
2794 // - Are not longer than the maximum chunk length
2795 std::vector<TokenSpan> candidate_spans;
2796 for (int start = inference_span.first; start < span_of_interest.second;
2797 ++start) {
2798 const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
2799 for (int end = leftmost_end_index;
2800 end <= inference_span.second && end - start <= max_chunk_length;
2801 ++end) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002802 const TokenSpan candidate_span = {start, end};
2803 if (score_single_token_spans_as_zero &&
2804 TokenSpanSize(candidate_span) == 1) {
2805 // Do not include the single token span in the batch, add a zero score
2806 // for it directly to the output.
2807 scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
2808 } else {
2809 candidate_spans.push_back(candidate_span);
2810 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002811 }
2812 }
2813
2814 const int max_batch_size = model_->selection_options()->batch_size();
2815
2816 std::vector<float> all_features;
Lukas Zilkaba849e72018-03-08 14:48:21 +01002817 scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
Lukas Zilkab23e2122018-02-09 10:25:19 +01002818 for (int batch_start = 0; batch_start < candidate_spans.size();
2819 batch_start += max_batch_size) {
2820 const int batch_end = std::min(batch_start + max_batch_size,
2821 static_cast<int>(candidate_spans.size()));
2822
2823 // Prepare features for the whole batch.
2824 all_features.clear();
2825 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2826 for (int i = batch_start; i < batch_end; ++i) {
2827 cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
2828 &all_features);
2829 }
2830
2831 // Run batched inference.
2832 const int batch_size = batch_end - batch_start;
2833 const int features_size = cached_features.OutputFeaturesSize();
2834 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01002835 TensorView<float>(all_features.data(), {batch_size, features_size}),
2836 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002837 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002838 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002839 return false;
2840 }
2841 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2842 logits.dim(1) != 1) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002843 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002844 return false;
2845 }
2846
2847 // Save results.
2848 for (int i = batch_start; i < batch_end; ++i) {
2849 scored_chunks->push_back(
2850 ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
2851 }
2852 }
2853
2854 return true;
2855}
2856
Tony Mak6c4cc672018-09-17 11:48:50 +01002857bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
2858 int64 reference_time_ms_utc,
2859 const std::string& reference_timezone,
2860 const std::string& locales, ModeFlag mode,
Tony Makd9446602019-02-20 18:25:39 +00002861 AnnotationUsecase annotation_usecase,
Tony Mak83d2de62019-04-10 16:12:15 +01002862 bool is_serialized_entity_data_enabled,
Tony Mak6c4cc672018-09-17 11:48:50 +01002863 std::vector<AnnotatedSpan>* result) const {
Tony Mak63959242020-02-07 18:31:16 +00002864 std::vector<DatetimeParseResultSpan> datetime_spans;
2865 if (cfg_datetime_parser_) {
2866 if (!(model_->grammar_datetime_model()->enabled_modes() & mode)) {
2867 return true;
2868 }
2869 std::vector<Locale> parsed_locales;
2870 ParseLocales(locales, &parsed_locales);
Tony Mak21460022020-03-12 18:29:35 +00002871 cfg_datetime_parser_->Parse(
2872 context_unicode.ToUTF8String(),
2873 ToDateAnnotationOptions(
2874 model_->grammar_datetime_model()->annotation_options(),
2875 reference_timezone, reference_time_ms_utc),
2876 parsed_locales, &datetime_spans);
Tony Makd99d58c2020-03-19 21:52:02 +00002877 }
2878
2879 if (datetime_parser_) {
Tony Mak63959242020-02-07 18:31:16 +00002880 if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
2881 reference_timezone, locales, mode,
2882 annotation_usecase,
2883 /*anchor_start_end=*/false, &datetime_spans)) {
2884 return false;
2885 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002886 }
2887
Lukas Zilkab23e2122018-02-09 10:25:19 +01002888 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
Tony Mak378c1f52019-03-04 15:58:11 +00002889 AnnotatedSpan annotated_span;
2890 annotated_span.span = datetime_span.span;
Tony Mak854015a2019-01-16 15:56:48 +00002891 for (const DatetimeParseResult& parse_result : datetime_span.data) {
Tony Mak378c1f52019-03-04 15:58:11 +00002892 annotated_span.classification.emplace_back(
2893 PickCollectionForDatetime(parse_result),
2894 datetime_span.target_classification_score,
2895 datetime_span.priority_score);
2896 annotated_span.classification.back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01002897 if (is_serialized_entity_data_enabled) {
2898 annotated_span.classification.back().serialized_entity_data =
2899 CreateDatetimeSerializedEntityData(parse_result);
2900 }
Tony Mak854015a2019-01-16 15:56:48 +00002901 }
Tony Mak448b5862019-03-22 13:36:41 +00002902 annotated_span.source = AnnotatedSpan::Source::DATETIME;
Tony Mak378c1f52019-03-04 15:58:11 +00002903 result->push_back(std::move(annotated_span));
Lukas Zilkab23e2122018-02-09 10:25:19 +01002904 }
2905 return true;
2906}
2907
Tony Mak378c1f52019-03-04 15:58:11 +00002908const Model* Annotator::model() const { return model_; }
Tony Makd9446602019-02-20 18:25:39 +00002909const reflection::Schema* Annotator::entity_data_schema() const {
2910 return entity_data_schema_;
2911}
Tony Mak854015a2019-01-16 15:56:48 +00002912
Lukas Zilka21d8c982018-01-24 11:11:20 +01002913const Model* ViewModel(const void* buffer, int size) {
2914 if (!buffer) {
2915 return nullptr;
2916 }
2917
2918 return LoadAndVerifyModel(buffer, size);
2919}
2920
Tony Makd9446602019-02-20 18:25:39 +00002921bool Annotator::LookUpKnowledgeEntity(
2922 const std::string& id, std::string* serialized_knowledge_result) const {
2923 return knowledge_engine_ &&
2924 knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
2925}
2926
Tony Mak6c4cc672018-09-17 11:48:50 +01002927} // namespace libtextclassifier3