blob: 6ee983fd12592f38570a486c08d70dc7a85b9bae [file] [log] [blame]
Lukas Zilka21d8c982018-01-24 11:11:20 +01001/*
Tony Mak6c4cc672018-09-17 11:48:50 +01002 * Copyright (C) 2018 The Android Open Source Project
Lukas Zilka21d8c982018-01-24 11:11:20 +01003 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Tony Mak6c4cc672018-09-17 11:48:50 +010017#include "annotator/annotator.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010018
19#include <algorithm>
Lukas Zilka21d8c982018-01-24 11:11:20 +010020#include <cmath>
Tony Mak21460022020-03-12 18:29:35 +000021#include <cstddef>
Lukas Zilka21d8c982018-01-24 11:11:20 +010022#include <iterator>
23#include <numeric>
Tony Mak63959242020-02-07 18:31:16 +000024#include <string>
Tony Mak448b5862019-03-22 13:36:41 +000025#include <unordered_map>
Tony Mak63959242020-02-07 18:31:16 +000026#include <vector>
Lukas Zilka21d8c982018-01-24 11:11:20 +010027
Tony Mak854015a2019-01-16 15:56:48 +000028#include "annotator/collections.h"
Tony Mak83d2de62019-04-10 16:12:15 +010029#include "annotator/model_generated.h"
30#include "annotator/types.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010031#include "utils/base/logging.h"
Tony Makff31efb2020-03-31 11:13:06 +010032#include "utils/base/status.h"
33#include "utils/base/statusor.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010034#include "utils/checksum.h"
Tony Mak63959242020-02-07 18:31:16 +000035#include "utils/i18n/locale.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010036#include "utils/math/softmax.h"
Tony Mak8cd7ba62019-10-15 15:29:22 +010037#include "utils/normalization.h"
Tony Maka2a1ff42019-09-12 15:40:32 +010038#include "utils/optional.h"
Tony Makd9446602019-02-20 18:25:39 +000039#include "utils/regex-match.h"
Tony Mak63959242020-02-07 18:31:16 +000040#include "utils/strings/numbers.h"
41#include "utils/strings/split.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010042#include "utils/utf8/unicodetext.h"
Tony Mak21460022020-03-12 18:29:35 +000043#include "utils/utf8/unilib-common.h"
Tony Mak378c1f52019-03-04 15:58:11 +000044#include "utils/zlib/zlib_regex.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010045
Tony Mak6c4cc672018-09-17 11:48:50 +010046namespace libtextclassifier3 {
Tony Mak448b5862019-03-22 13:36:41 +000047
48using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
49
Tony Mak6c4cc672018-09-17 11:48:50 +010050const std::string& Annotator::kPhoneCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010051 *[]() { return new std::string("phone"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010052const std::string& Annotator::kAddressCollection =
Lukas Zilkae7962cc2018-03-28 18:09:48 +020053 *[]() { return new std::string("address"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010054const std::string& Annotator::kDateCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010055 *[]() { return new std::string("date"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000056const std::string& Annotator::kUrlCollection =
57 *[]() { return new std::string("url"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000058const std::string& Annotator::kEmailCollection =
59 *[]() { return new std::string("email"); }();
Lukas Zilkab23e2122018-02-09 10:25:19 +010060
Lukas Zilka21d8c982018-01-24 11:11:20 +010061namespace {
62const Model* LoadAndVerifyModel(const void* addr, int size) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010063 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
Tony Mak51a9e542018-11-02 13:36:22 +000064 if (VerifyModelBuffer(verifier)) {
65 return GetModel(addr);
Lukas Zilka21d8c982018-01-24 11:11:20 +010066 } else {
67 return nullptr;
68 }
69}
Tony Mak6c4cc672018-09-17 11:48:50 +010070
Tony Mak76d80962020-01-08 17:30:51 +000071const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
72 int size) {
73 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
74 if (VerifyPersonNameModelBuffer(verifier)) {
75 return GetPersonNameModel(addr);
76 } else {
77 return nullptr;
78 }
79}
80
Tony Mak6c4cc672018-09-17 11:48:50 +010081// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
82// create a new instance, assign ownership to owned_lib, and return it.
83const UniLib* MaybeCreateUnilib(const UniLib* lib,
84 std::unique_ptr<UniLib>* owned_lib) {
85 if (lib) {
86 return lib;
87 } else {
88 owned_lib->reset(new UniLib);
89 return owned_lib->get();
90 }
91}
92
93// As above, but for CalendarLib.
94const CalendarLib* MaybeCreateCalendarlib(
95 const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
96 if (lib) {
97 return lib;
98 } else {
99 owned_lib->reset(new CalendarLib);
100 return owned_lib->get();
101 }
102}
103
Tony Mak968412a2019-11-13 15:39:57 +0000104// Returns whether the provided input is valid:
105// * Valid utf8 text.
106// * Sane span indices.
107bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan span) {
108 if (!context.is_valid()) {
109 return false;
110 }
111 return (span.first >= 0 && span.first < span.second &&
112 span.second <= context.size_codepoints());
113}
114
Tony Mak63959242020-02-07 18:31:16 +0000115std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
116 const flatbuffers::Vector<int32_t>* ints) {
117 if (ints == nullptr) {
118 return {};
119 }
120 std::unordered_set<char32> ints_set;
121 for (auto value : *ints) {
122 ints_set.insert(static_cast<char32>(value));
123 }
124 return ints_set;
125}
126
Tony Mak21460022020-03-12 18:29:35 +0000127DateAnnotationOptions ToDateAnnotationOptions(
128 const GrammarDatetimeModel_::AnnotationOptions* fb_annotation_options,
129 const std::string& reference_timezone, const int64 reference_time_ms_utc) {
130 DateAnnotationOptions result_annotation_options;
131 result_annotation_options.base_timestamp_millis = reference_time_ms_utc;
132 result_annotation_options.reference_timezone = reference_timezone;
133 if (fb_annotation_options != nullptr) {
134 result_annotation_options.enable_special_day_offset =
135 fb_annotation_options->enable_special_day_offset();
136 result_annotation_options.merge_adjacent_components =
137 fb_annotation_options->merge_adjacent_components();
138 result_annotation_options.enable_date_range =
139 fb_annotation_options->enable_date_range();
140 result_annotation_options.include_preposition =
141 fb_annotation_options->include_preposition();
Tony Mak21460022020-03-12 18:29:35 +0000142 if (fb_annotation_options->extra_requested_dates() != nullptr) {
143 for (const auto& extra_requested_date :
144 *fb_annotation_options->extra_requested_dates()) {
145 result_annotation_options.extra_requested_dates.push_back(
146 extra_requested_date->str());
147 }
148 }
Tony Makd99d58c2020-03-19 21:52:02 +0000149 if (fb_annotation_options->ignored_spans() != nullptr) {
150 for (const auto& ignored_span : *fb_annotation_options->ignored_spans()) {
151 result_annotation_options.ignored_spans.push_back(ignored_span->str());
Tony Mak0b8b3322020-03-17 16:30:19 +0000152 }
153 }
Tony Mak21460022020-03-12 18:29:35 +0000154 }
155 return result_annotation_options;
156}
157
Lukas Zilka21d8c982018-01-24 11:11:20 +0100158} // namespace
159
Lukas Zilkaba849e72018-03-08 14:48:21 +0100160tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
161 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100162 TC3_CHECK(selection_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100163 selection_interpreter_ = selection_executor_->CreateInterpreter();
164 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100165 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100166 }
167 }
168 return selection_interpreter_.get();
169}
170
171tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
172 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100173 TC3_CHECK(classification_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100174 classification_interpreter_ = classification_executor_->CreateInterpreter();
175 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100176 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100177 }
178 }
179 return classification_interpreter_.get();
180}
181
Tony Mak6c4cc672018-09-17 11:48:50 +0100182std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
183 const char* buffer, int size, const UniLib* unilib,
184 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100185 const Model* model = LoadAndVerifyModel(buffer, size);
186 if (model == nullptr) {
187 return nullptr;
188 }
189
Lukas Zilkab23e2122018-02-09 10:25:19 +0100190 auto classifier =
Tony Mak6c4cc672018-09-17 11:48:50 +0100191 std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100192 if (!classifier->IsInitialized()) {
193 return nullptr;
194 }
195
196 return classifier;
197}
198
Tony Mak6c4cc672018-09-17 11:48:50 +0100199std::unique_ptr<Annotator> Annotator::FromScopedMmap(
200 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
201 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100202 if (!(*mmap)->handle().ok()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100203 TC3_VLOG(1) << "Mmap failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100204 return nullptr;
205 }
206
207 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
208 (*mmap)->handle().num_bytes());
209 if (!model) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100210 TC3_LOG(ERROR) << "Model verification failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100211 return nullptr;
212 }
213
Tony Mak6c4cc672018-09-17 11:48:50 +0100214 auto classifier = std::unique_ptr<Annotator>(
215 new Annotator(mmap, model, unilib, calendarlib));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100216 if (!classifier->IsInitialized()) {
217 return nullptr;
218 }
219
220 return classifier;
221}
222
Tony Makdf54e742019-03-26 14:04:00 +0000223std::unique_ptr<Annotator> Annotator::FromScopedMmap(
224 std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
225 std::unique_ptr<CalendarLib> calendarlib) {
226 if (!(*mmap)->handle().ok()) {
227 TC3_VLOG(1) << "Mmap failed.";
228 return nullptr;
229 }
230
231 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
232 (*mmap)->handle().num_bytes());
233 if (model == nullptr) {
234 TC3_LOG(ERROR) << "Model verification failed.";
235 return nullptr;
236 }
237
238 auto classifier = std::unique_ptr<Annotator>(
239 new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
240 if (!classifier->IsInitialized()) {
241 return nullptr;
242 }
243
244 return classifier;
245}
246
Tony Mak6c4cc672018-09-17 11:48:50 +0100247std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
248 int fd, int offset, int size, const UniLib* unilib,
249 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100250 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
Tony Mak6c4cc672018-09-17 11:48:50 +0100251 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100252}
253
Tony Mak6c4cc672018-09-17 11:48:50 +0100254std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Makdf54e742019-03-26 14:04:00 +0000255 int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
256 std::unique_ptr<CalendarLib> calendarlib) {
257 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
258 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
259}
260
261std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Mak6c4cc672018-09-17 11:48:50 +0100262 int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100263 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
Tony Mak6c4cc672018-09-17 11:48:50 +0100264 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100265}
266
Tony Makdf54e742019-03-26 14:04:00 +0000267std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
268 int fd, std::unique_ptr<UniLib> unilib,
269 std::unique_ptr<CalendarLib> calendarlib) {
270 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
271 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
272}
273
Tony Mak6c4cc672018-09-17 11:48:50 +0100274std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
275 const UniLib* unilib,
276 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100277 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
Tony Mak6c4cc672018-09-17 11:48:50 +0100278 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100279}
280
Tony Makdf54e742019-03-26 14:04:00 +0000281std::unique_ptr<Annotator> Annotator::FromPath(
282 const std::string& path, std::unique_ptr<UniLib> unilib,
283 std::unique_ptr<CalendarLib> calendarlib) {
284 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
285 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
286}
287
Tony Mak6c4cc672018-09-17 11:48:50 +0100288Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
289 const UniLib* unilib, const CalendarLib* calendarlib)
290 : model_(model),
291 mmap_(std::move(*mmap)),
292 owned_unilib_(nullptr),
293 unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
294 owned_calendarlib_(nullptr),
295 calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
296 ValidateAndInitialize();
297}
298
Tony Makdf54e742019-03-26 14:04:00 +0000299Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
300 std::unique_ptr<UniLib> unilib,
301 std::unique_ptr<CalendarLib> calendarlib)
302 : model_(model),
303 mmap_(std::move(*mmap)),
304 owned_unilib_(std::move(unilib)),
305 unilib_(owned_unilib_.get()),
306 owned_calendarlib_(std::move(calendarlib)),
307 calendarlib_(owned_calendarlib_.get()) {
308 ValidateAndInitialize();
309}
310
Tony Mak6c4cc672018-09-17 11:48:50 +0100311Annotator::Annotator(const Model* model, const UniLib* unilib,
312 const CalendarLib* calendarlib)
313 : model_(model),
314 owned_unilib_(nullptr),
315 unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
316 owned_calendarlib_(nullptr),
317 calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
318 ValidateAndInitialize();
319}
320
321void Annotator::ValidateAndInitialize() {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100322 initialized_ = false;
323
Lukas Zilka21d8c982018-01-24 11:11:20 +0100324 if (model_ == nullptr) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100325 TC3_LOG(ERROR) << "No model specified.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100326 return;
327 }
328
Lukas Zilkaba849e72018-03-08 14:48:21 +0100329 const bool model_enabled_for_annotation =
330 (model_->triggering_options() != nullptr &&
331 (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
332 const bool model_enabled_for_classification =
333 (model_->triggering_options() != nullptr &&
334 (model_->triggering_options()->enabled_modes() &
335 ModeFlag_CLASSIFICATION));
336 const bool model_enabled_for_selection =
337 (model_->triggering_options() != nullptr &&
338 (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
339
340 // Annotation requires the selection model.
341 if (model_enabled_for_annotation || model_enabled_for_selection) {
342 if (!model_->selection_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100343 TC3_LOG(ERROR) << "No selection options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100344 return;
345 }
346 if (!model_->selection_feature_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100347 TC3_LOG(ERROR) << "No selection feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100348 return;
349 }
350 if (!model_->selection_feature_options()->bounds_sensitive_features()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100351 TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100352 return;
353 }
354 if (!model_->selection_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100355 TC3_LOG(ERROR) << "No selection model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100356 return;
357 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100358 selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100359 if (!selection_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100360 TC3_LOG(ERROR) << "Could not initialize selection executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100361 return;
362 }
363 selection_feature_processor_.reset(
364 new FeatureProcessor(model_->selection_feature_options(), unilib_));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100365 }
366
Lukas Zilkaba849e72018-03-08 14:48:21 +0100367 // Annotation requires the classification model for conflict resolution and
368 // scoring.
369 // Selection requires the classification model for conflict resolution.
370 if (model_enabled_for_annotation || model_enabled_for_classification ||
371 model_enabled_for_selection) {
372 if (!model_->classification_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100373 TC3_LOG(ERROR) << "No classification options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100374 return;
375 }
376
377 if (!model_->classification_feature_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100378 TC3_LOG(ERROR) << "No classification feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100379 return;
380 }
381
382 if (!model_->classification_feature_options()
383 ->bounds_sensitive_features()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100384 TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100385 return;
386 }
387 if (!model_->classification_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100388 TC3_LOG(ERROR) << "No clf model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100389 return;
390 }
391
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200392 classification_executor_ =
Tony Mak6c4cc672018-09-17 11:48:50 +0100393 ModelExecutor::FromBuffer(model_->classification_model());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100394 if (!classification_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100395 TC3_LOG(ERROR) << "Could not initialize classification executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100396 return;
397 }
398
399 classification_feature_processor_.reset(new FeatureProcessor(
400 model_->classification_feature_options(), unilib_));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100401 }
402
Lukas Zilkaba849e72018-03-08 14:48:21 +0100403 // The embeddings need to be specified if the model is to be used for
404 // classification or selection.
405 if (model_enabled_for_annotation || model_enabled_for_classification ||
406 model_enabled_for_selection) {
407 if (!model_->embedding_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100408 TC3_LOG(ERROR) << "No embedding model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100409 return;
410 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100411
Lukas Zilkaba849e72018-03-08 14:48:21 +0100412 // Check that the embedding size of the selection and classification model
413 // matches, as they are using the same embeddings.
414 if (model_enabled_for_selection &&
415 (model_->selection_feature_options()->embedding_size() !=
416 model_->classification_feature_options()->embedding_size() ||
417 model_->selection_feature_options()->embedding_quantization_bits() !=
418 model_->classification_feature_options()
419 ->embedding_quantization_bits())) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100420 TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100421 return;
422 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100423
Tony Mak6c4cc672018-09-17 11:48:50 +0100424 embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200425 model_->embedding_model(),
Lukas Zilkaba849e72018-03-08 14:48:21 +0100426 model_->classification_feature_options()->embedding_size(),
Tony Makdf54e742019-03-26 14:04:00 +0000427 model_->classification_feature_options()->embedding_quantization_bits(),
428 model_->embedding_pruning_mask());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200429 if (!embedding_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100430 TC3_LOG(ERROR) << "Could not initialize embedding executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100431 return;
432 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100433 }
434
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200435 std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
Lukas Zilkab23e2122018-02-09 10:25:19 +0100436 if (model_->regex_model()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200437 if (!InitializeRegexModel(decompressor.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100438 TC3_LOG(ERROR) << "Could not initialize regex model.";
Lukas Zilka434442d2018-04-25 11:38:51 +0200439 return;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100440 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100441 }
Tony Mak63959242020-02-07 18:31:16 +0000442 if (model_->grammar_datetime_model() &&
443 model_->grammar_datetime_model()->datetime_rules()) {
444 cfg_datetime_parser_.reset(new dates::CfgDatetimeAnnotator(
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100445 unilib_,
Tony Mak63959242020-02-07 18:31:16 +0000446 /*tokenizer_options=*/
447 model_->grammar_datetime_model()->grammar_tokenizer_options(),
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100448 calendarlib_,
Tony Mak21460022020-03-12 18:29:35 +0000449 /*datetime_rules=*/model_->grammar_datetime_model()->datetime_rules(),
450 model_->grammar_datetime_model()->target_classification_score(),
451 model_->grammar_datetime_model()->priority_score()));
Tony Mak63959242020-02-07 18:31:16 +0000452 if (!cfg_datetime_parser_) {
453 TC3_LOG(ERROR) << "Could not initialize context free grammar based "
454 "datetime parser.";
455 return;
456 }
Tony Makd99d58c2020-03-19 21:52:02 +0000457 }
458
459 if (model_->datetime_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100460 datetime_parser_ = DatetimeParser::Instance(
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100461 model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100462 if (!datetime_parser_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100463 TC3_LOG(ERROR) << "Could not initialize datetime parser.";
Lukas Zilkab23e2122018-02-09 10:25:19 +0100464 return;
465 }
466 }
467
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200468 if (model_->output_options()) {
469 if (model_->output_options()->filtered_collections_annotation()) {
470 for (const auto collection :
471 *model_->output_options()->filtered_collections_annotation()) {
472 filtered_collections_annotation_.insert(collection->str());
473 }
474 }
475 if (model_->output_options()->filtered_collections_classification()) {
476 for (const auto collection :
477 *model_->output_options()->filtered_collections_classification()) {
478 filtered_collections_classification_.insert(collection->str());
479 }
480 }
481 if (model_->output_options()->filtered_collections_selection()) {
482 for (const auto collection :
483 *model_->output_options()->filtered_collections_selection()) {
484 filtered_collections_selection_.insert(collection->str());
485 }
486 }
487 }
488
Tony Mak378c1f52019-03-04 15:58:11 +0000489 if (model_->number_annotator_options() &&
490 model_->number_annotator_options()->enabled()) {
491 number_annotator_.reset(
Tony Mak63959242020-02-07 18:31:16 +0000492 new NumberAnnotator(model_->number_annotator_options(), unilib_));
493 }
494
495 if (model_->money_parsing_options()) {
496 money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
497 model_->money_parsing_options()->separators());
Tony Mak378c1f52019-03-04 15:58:11 +0000498 }
499
Tony Makad2e22d2019-03-20 17:35:13 +0000500 if (model_->duration_annotator_options() &&
501 model_->duration_annotator_options()->enabled()) {
502 duration_annotator_.reset(
503 new DurationAnnotator(model_->duration_annotator_options(),
Tony Mak8cd7ba62019-10-15 15:29:22 +0100504 selection_feature_processor_.get(), unilib_));
Tony Makad2e22d2019-03-20 17:35:13 +0000505 }
506
Tony Makd9446602019-02-20 18:25:39 +0000507 if (model_->entity_data_schema()) {
508 entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
509 model_->entity_data_schema()->Data(),
510 model_->entity_data_schema()->size());
511 if (entity_data_schema_ == nullptr) {
512 TC3_LOG(ERROR) << "Could not load entity data schema data.";
513 return;
514 }
515
516 entity_data_builder_.reset(
517 new ReflectiveFlatbufferBuilder(entity_data_schema_));
518 } else {
Tony Mak378c1f52019-03-04 15:58:11 +0000519 entity_data_schema_ = nullptr;
Tony Makd9446602019-02-20 18:25:39 +0000520 entity_data_builder_ = nullptr;
521 }
522
Tony Mak21460022020-03-12 18:29:35 +0000523 if (model_->grammar_model()) {
524 grammar_annotator_.reset(new GrammarAnnotator(
525 unilib_, model_->grammar_model(), entity_data_builder_.get()));
526 }
527
Tony Makdf54e742019-03-26 14:04:00 +0000528 if (model_->triggering_locales() &&
529 !ParseLocales(model_->triggering_locales()->c_str(),
530 &model_triggering_locales_)) {
Tony Mak378c1f52019-03-04 15:58:11 +0000531 TC3_LOG(ERROR) << "Could not parse model supported locales.";
532 return;
533 }
534
535 if (model_->triggering_options() != nullptr &&
Tony Makdf54e742019-03-26 14:04:00 +0000536 model_->triggering_options()->locales() != nullptr &&
537 !ParseLocales(model_->triggering_options()->locales()->c_str(),
538 &ml_model_triggering_locales_)) {
539 TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
540 return;
541 }
542
543 if (model_->triggering_options() != nullptr &&
Tony Mak378c1f52019-03-04 15:58:11 +0000544 model_->triggering_options()->dictionary_locales() != nullptr &&
545 !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
546 &dictionary_locales_)) {
547 TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
548 return;
549 }
550
Tony Mak5a12b942020-05-01 12:41:31 +0100551 if (model_->conflict_resolution_options() != nullptr) {
552 prioritize_longest_annotation_ =
553 model_->conflict_resolution_options()->prioritize_longest_annotation();
554 do_conflict_resolution_in_raw_mode_ =
555 model_->conflict_resolution_options()
556 ->do_conflict_resolution_in_raw_mode();
557 }
558
Chang Licac0b442020-05-21 15:09:37 +0100559#ifdef TC3_EXPERIMENTAL
560 TC3_LOG(WARNING) << "Enabling experimental annotators.";
561 InitializeExperimentalAnnotators();
562#endif
563
Lukas Zilka21d8c982018-01-24 11:11:20 +0100564 initialized_ = true;
565}
566
Tony Mak6c4cc672018-09-17 11:48:50 +0100567bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100568 if (!model_->regex_model()->patterns()) {
Lukas Zilka434442d2018-04-25 11:38:51 +0200569 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100570 }
571
572 // Initialize pattern recognizers.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100573 int regex_pattern_id = 0;
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100574 for (const auto regex_pattern : *model_->regex_model()->patterns()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200575 std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
Tony Mak378c1f52019-03-04 15:58:11 +0000576 UncompressMakeRegexPattern(
577 *unilib_, regex_pattern->pattern(),
578 regex_pattern->compressed_pattern(),
579 model_->regex_model()->lazy_regex_compilation(), decompressor);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100580 if (!compiled_pattern) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100581 TC3_LOG(INFO) << "Failed to load regex pattern";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200582 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100583 }
584
Lukas Zilkaba849e72018-03-08 14:48:21 +0100585 if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100586 annotation_regex_patterns_.push_back(regex_pattern_id);
587 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100588 if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100589 classification_regex_patterns_.push_back(regex_pattern_id);
590 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100591 if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100592 selection_regex_patterns_.push_back(regex_pattern_id);
593 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100594 regex_patterns_.push_back({
Tony Mak854015a2019-01-16 15:56:48 +0000595 regex_pattern,
Tony Mak6c4cc672018-09-17 11:48:50 +0100596 std::move(compiled_pattern),
Tony Mak6c4cc672018-09-17 11:48:50 +0100597 });
Lukas Zilkab23e2122018-02-09 10:25:19 +0100598 ++regex_pattern_id;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100599 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100600
Lukas Zilkab23e2122018-02-09 10:25:19 +0100601 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100602}
603
Tony Mak6c4cc672018-09-17 11:48:50 +0100604bool Annotator::InitializeKnowledgeEngine(
605 const std::string& serialized_config) {
Tony Maka2a1ff42019-09-12 15:40:32 +0100606 std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
Tony Mak63959242020-02-07 18:31:16 +0000607 if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100608 TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
609 return false;
610 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100611 if (model_->triggering_options() != nullptr) {
612 knowledge_engine->SetPriorityScore(
613 model_->triggering_options()->knowledge_priority_score());
614 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100615 knowledge_engine_ = std::move(knowledge_engine);
616 return true;
617}
618
Tony Mak854015a2019-01-16 15:56:48 +0000619bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
Tony Mak5dc5e112019-02-01 14:52:10 +0000620 std::unique_ptr<ContactEngine> contact_engine(
Tony Mak63959242020-02-07 18:31:16 +0000621 new ContactEngine(selection_feature_processor_.get(), unilib_,
622 model_->contact_annotator_options()));
Tony Mak854015a2019-01-16 15:56:48 +0000623 if (!contact_engine->Initialize(serialized_config)) {
624 TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
625 return false;
626 }
627 contact_engine_ = std::move(contact_engine);
628 return true;
629}
630
Tony Makd9446602019-02-20 18:25:39 +0000631bool Annotator::InitializeInstalledAppEngine(
632 const std::string& serialized_config) {
633 std::unique_ptr<InstalledAppEngine> installed_app_engine(
Tony Mak378c1f52019-03-04 15:58:11 +0000634 new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
Tony Makd9446602019-02-20 18:25:39 +0000635 if (!installed_app_engine->Initialize(serialized_config)) {
636 TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
637 return false;
638 }
639 installed_app_engine_ = std::move(installed_app_engine);
640 return true;
641}
642
Tony Mak63959242020-02-07 18:31:16 +0000643void Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
644 lang_id_ = lang_id;
Tony Mak21460022020-03-12 18:29:35 +0000645 if (lang_id_ != nullptr && model_->translate_annotator_options() &&
Tony Mak63959242020-02-07 18:31:16 +0000646 model_->translate_annotator_options()->enabled()) {
647 translate_annotator_.reset(new TranslateAnnotator(
648 model_->translate_annotator_options(), lang_id_, unilib_));
Tony Mak21460022020-03-12 18:29:35 +0000649 } else {
650 translate_annotator_.reset(nullptr);
Tony Mak63959242020-02-07 18:31:16 +0000651 }
652}
653
Tony Mak21460022020-03-12 18:29:35 +0000654bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
655 int size) {
656 const PersonNameModel* person_name_model =
657 LoadAndVerifyPersonNameModel(buffer, size);
Tony Mak76d80962020-01-08 17:30:51 +0000658
659 if (person_name_model == nullptr) {
660 TC3_LOG(ERROR) << "Person name model verification failed.";
661 return false;
662 }
663
664 if (!person_name_model->enabled()) {
665 return true;
666 }
667
668 std::unique_ptr<PersonNameEngine> person_name_engine(
Tony Mak21460022020-03-12 18:29:35 +0000669 new PersonNameEngine(selection_feature_processor_.get(), unilib_));
Tony Mak76d80962020-01-08 17:30:51 +0000670 if (!person_name_engine->Initialize(person_name_model)) {
671 TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
672 return false;
673 }
674 person_name_engine_ = std::move(person_name_engine);
675 return true;
676}
677
Tony Mak21460022020-03-12 18:29:35 +0000678bool Annotator::InitializePersonNameEngineFromScopedMmap(
679 const ScopedMmap& mmap) {
680 if (!mmap.handle().ok()) {
681 TC3_LOG(ERROR) << "Mmap for person name model failed.";
682 return false;
683 }
684
685 return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
686 mmap.handle().num_bytes());
687}
688
689bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
690 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
691 return InitializePersonNameEngineFromScopedMmap(*mmap);
692}
693
694bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
695 int size) {
696 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
697 return InitializePersonNameEngineFromScopedMmap(*mmap);
698}
699
Tony Mak5a12b942020-05-01 12:41:31 +0100700bool Annotator::InitializeExperimentalAnnotators() {
701 if (ExperimentalAnnotator::IsEnabled()) {
Tony Makc121edd2020-05-28 15:25:17 +0100702 experimental_annotator_.reset(new ExperimentalAnnotator(
703 model_->experimental_model(), *selection_feature_processor_, *unilib_));
Tony Mak5a12b942020-05-01 12:41:31 +0100704 return true;
705 }
706 return false;
707}
708
Lukas Zilka21d8c982018-01-24 11:11:20 +0100709namespace {
710
711int CountDigits(const std::string& str, CodepointSpan selection_indices) {
712 int count = 0;
713 int i = 0;
714 const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
715 for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
716 if (i >= selection_indices.first && i < selection_indices.second &&
Tony Mak21460022020-03-12 18:29:35 +0000717 IsDigit(*it)) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100718 ++count;
719 }
720 }
721 return count;
722}
723
Lukas Zilka21d8c982018-01-24 11:11:20 +0100724} // namespace
725
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200726namespace internal {
727// Helper function, which if the initial 'span' contains only white-spaces,
728// moves the selection to a single-codepoint selection on a left or right side
729// of this space.
730CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
731 const UnicodeText& context_unicode,
732 const UniLib& unilib) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100733 TC3_CHECK(ValidNonEmptySpan(span));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200734
735 UnicodeText::const_iterator it;
736
737 // Check that the current selection is all whitespaces.
738 it = context_unicode.begin();
739 std::advance(it, span.first);
740 for (int i = 0; i < (span.second - span.first); ++i, ++it) {
741 if (!unilib.IsWhitespace(*it)) {
742 return span;
743 }
744 }
745
746 CodepointSpan result;
747
748 // Try moving left.
749 result = span;
750 it = context_unicode.begin();
751 std::advance(it, span.first);
752 while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
753 --result.first;
754 --it;
755 }
756 result.second = result.first + 1;
757 if (!unilib.IsWhitespace(*it)) {
758 return result;
759 }
760
761 // If moving left didn't find a non-whitespace character, just return the
762 // original span.
763 return span;
764}
765} // namespace internal
766
Tony Mak6c4cc672018-09-17 11:48:50 +0100767bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200768 return !span.classification.empty() &&
769 filtered_collections_annotation_.find(
770 span.classification[0].collection) !=
771 filtered_collections_annotation_.end();
772}
773
Tony Mak6c4cc672018-09-17 11:48:50 +0100774bool Annotator::FilteredForClassification(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200775 const ClassificationResult& classification) const {
776 return filtered_collections_classification_.find(classification.collection) !=
777 filtered_collections_classification_.end();
778}
779
Tony Mak6c4cc672018-09-17 11:48:50 +0100780bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200781 return !span.classification.empty() &&
782 filtered_collections_selection_.find(
783 span.classification[0].collection) !=
784 filtered_collections_selection_.end();
785}
786
Tony Mak378c1f52019-03-04 15:58:11 +0000787namespace {
788inline bool ClassifiedAsOther(
789 const std::vector<ClassificationResult>& classification) {
790 return !classification.empty() &&
791 classification[0].collection == Collections::Other();
792}
793
Tony Maka2a1ff42019-09-12 15:40:32 +0100794} // namespace
795
796float Annotator::GetPriorityScore(
797 const std::vector<ClassificationResult>& classification) const {
Tony Mak378c1f52019-03-04 15:58:11 +0000798 if (!classification.empty() && !ClassifiedAsOther(classification)) {
799 return classification[0].priority_score;
800 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +0100801 if (model_->triggering_options() != nullptr) {
802 return model_->triggering_options()->other_collection_priority_score();
803 } else {
804 return -1000.0;
805 }
Tony Mak378c1f52019-03-04 15:58:11 +0000806 }
807}
Tony Mak378c1f52019-03-04 15:58:11 +0000808
Tony Makdf54e742019-03-26 14:04:00 +0000809bool Annotator::VerifyRegexMatchCandidate(
810 const std::string& context, const VerificationOptions* verification_options,
811 const std::string& match, const UniLib::RegexMatcher* matcher) const {
812 if (verification_options == nullptr) {
813 return true;
814 }
815 if (verification_options->verify_luhn_checksum() &&
816 !VerifyLuhnChecksum(match)) {
817 return false;
818 }
819 const int lua_verifier = verification_options->lua_verifier();
820 if (lua_verifier >= 0) {
821 if (model_->regex_model()->lua_verifier() == nullptr ||
822 lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
823 TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
824 return false;
825 }
826 return VerifyMatch(
827 context, matcher,
828 model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
829 }
830 return true;
831}
832
Tony Mak6c4cc672018-09-17 11:48:50 +0100833CodepointSpan Annotator::SuggestSelection(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100834 const std::string& context, CodepointSpan click_indices,
835 const SelectionOptions& options) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200836 CodepointSpan original_click_indices = click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100837 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100838 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200839 return original_click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100840 }
Tony Mak5a12b942020-05-01 12:41:31 +0100841 if (options.annotation_usecase !=
842 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
843 TC3_LOG(WARNING)
844 << "Invoking SuggestSelection, which is not supported in RAW mode.";
845 return original_click_indices;
846 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100847 if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200848 return original_click_indices;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100849 }
850
Tony Makdf54e742019-03-26 14:04:00 +0000851 std::vector<Locale> detected_text_language_tags;
852 if (!ParseLocales(options.detected_text_language_tags,
853 &detected_text_language_tags)) {
854 TC3_LOG(WARNING)
855 << "Failed to parse the detected_text_language_tags in options: "
856 << options.detected_text_language_tags;
857 }
858 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
859 model_triggering_locales_,
860 /*default_value=*/true)) {
861 return original_click_indices;
862 }
863
Lukas Zilkadf710db2018-02-27 12:44:09 +0100864 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
865 /*do_copy=*/false);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200866
Tony Mak968412a2019-11-13 15:39:57 +0000867 if (!IsValidSpanInput(context_unicode, click_indices)) {
868 TC3_VLOG(1)
869 << "Trying to run SuggestSelection with invalid input, indices: "
870 << click_indices.first << " " << click_indices.second;
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200871 return original_click_indices;
872 }
873
874 if (model_->snap_whitespace_selections()) {
875 // We want to expand a purely white-space selection to a multi-selection it
876 // would've been part of. But with this feature disabled we would do a no-
877 // op, because no token is found. Therefore, we need to modify the
878 // 'click_indices' a bit to include a part of the token, so that the click-
879 // finding logic finds the clicked token correctly. This modification is
880 // done by the following function. Note, that it's enough to check the left
881 // side of the current selection, because if the white-space is a part of a
Tony Mak6c4cc672018-09-17 11:48:50 +0100882 // multi-selection, necessarily both tokens - on the left and the right
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200883 // sides need to be selected. Thus snapping only to the left is sufficient
884 // (there's a check at the bottom that makes sure that if we snap to the
885 // left token but the result does not contain the initial white-space,
886 // returns the original indices).
887 click_indices = internal::SnapLeftIfWhitespaceSelection(
888 click_indices, context_unicode, *unilib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100889 }
890
Lukas Zilkab23e2122018-02-09 10:25:19 +0100891 std::vector<AnnotatedSpan> candidates;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100892 InterpreterManager interpreter_manager(selection_executor_.get(),
893 classification_executor_.get());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200894 std::vector<Token> tokens;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100895 if (!ModelSuggestSelection(context_unicode, click_indices,
Tony Makdf54e742019-03-26 14:04:00 +0000896 detected_text_language_tags, &interpreter_manager,
897 &tokens, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100898 TC3_LOG(ERROR) << "Model suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200899 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100900 }
Tony Mak83d2de62019-04-10 16:12:15 +0100901 if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
902 /*is_serialized_entity_data_enabled=*/false)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100903 TC3_LOG(ERROR) << "Regex suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200904 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100905 }
Tony Mak83d2de62019-04-10 16:12:15 +0100906 if (!DatetimeChunk(
907 UTF8ToUnicodeText(context, /*do_copy=*/false),
908 /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
909 options.locales, ModeFlag_SELECTION, options.annotation_usecase,
910 /*is_serialized_entity_data_enabled=*/false, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100911 TC3_LOG(ERROR) << "Datetime suggest selection failed.";
912 return original_click_indices;
913 }
Tony Mak378c1f52019-03-04 15:58:11 +0000914 if (knowledge_engine_ != nullptr &&
Tony Maka2a1ff42019-09-12 15:40:32 +0100915 !knowledge_engine_->Chunk(context, options.annotation_usecase,
Tony Mak90d55672020-04-15 18:20:44 +0100916 options.location_context, Permissions(),
917 &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100918 TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200919 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100920 }
Tony Mak378c1f52019-03-04 15:58:11 +0000921 if (contact_engine_ != nullptr &&
Tony Mak854015a2019-01-16 15:56:48 +0000922 !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
923 TC3_LOG(ERROR) << "Contact suggest selection failed.";
924 return original_click_indices;
925 }
Tony Mak378c1f52019-03-04 15:58:11 +0000926 if (installed_app_engine_ != nullptr &&
Tony Makd9446602019-02-20 18:25:39 +0000927 !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
928 TC3_LOG(ERROR) << "Installed app suggest selection failed.";
929 return original_click_indices;
930 }
Tony Mak378c1f52019-03-04 15:58:11 +0000931 if (number_annotator_ != nullptr &&
932 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
933 &candidates)) {
934 TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
935 return original_click_indices;
936 }
Tony Makad2e22d2019-03-20 17:35:13 +0000937 if (duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +0000938 !duration_annotator_->FindAll(context_unicode, tokens,
939 options.annotation_usecase, &candidates)) {
Tony Makad2e22d2019-03-20 17:35:13 +0000940 TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
941 return original_click_indices;
942 }
Tony Mak76d80962020-01-08 17:30:51 +0000943 if (person_name_engine_ != nullptr &&
944 !person_name_engine_->Chunk(context_unicode, tokens, &candidates)) {
945 TC3_LOG(ERROR) << "Person name suggest selection failed.";
946 return original_click_indices;
947 }
Lukas Zilkab23e2122018-02-09 10:25:19 +0100948
Tony Mak21460022020-03-12 18:29:35 +0000949 AnnotatedSpan grammar_suggested_span;
950 if (grammar_annotator_ != nullptr &&
951 grammar_annotator_->SuggestSelection(detected_text_language_tags,
952 context_unicode, click_indices,
953 &grammar_suggested_span)) {
954 candidates.push_back(grammar_suggested_span);
955 }
956
Tony Mak5a12b942020-05-01 12:41:31 +0100957 if (experimental_annotator_ != nullptr) {
958 candidates.push_back(experimental_annotator_->SuggestSelection(
959 context_unicode, click_indices));
960 }
961
Lukas Zilkab23e2122018-02-09 10:25:19 +0100962 // Sort candidates according to their position in the input, so that the next
963 // code can assume that any connected component of overlapping spans forms a
964 // contiguous block.
965 std::sort(candidates.begin(), candidates.end(),
966 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
967 return a.span.first < b.span.first;
968 });
969
970 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +0000971 if (!ResolveConflicts(candidates, context, tokens,
972 detected_text_language_tags, options.annotation_usecase,
973 &interpreter_manager, &candidate_indices)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100974 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200975 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100976 }
977
Tony Mak378c1f52019-03-04 15:58:11 +0000978 std::sort(candidate_indices.begin(), candidate_indices.end(),
Tony Maka2a1ff42019-09-12 15:40:32 +0100979 [this, &candidates](int a, int b) {
Tony Mak378c1f52019-03-04 15:58:11 +0000980 return GetPriorityScore(candidates[a].classification) >
981 GetPriorityScore(candidates[b].classification);
982 });
983
Lukas Zilkab23e2122018-02-09 10:25:19 +0100984 for (const int i : candidate_indices) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200985 if (SpansOverlap(candidates[i].span, click_indices) &&
986 SpansOverlap(candidates[i].span, original_click_indices)) {
987 // Run model classification if not present but requested and there's a
988 // classification collection filter specified.
989 if (candidates[i].classification.empty() &&
990 model_->selection_options()->always_classify_suggested_selection() &&
991 !filtered_collections_selection_.empty()) {
Tony Mak378c1f52019-03-04 15:58:11 +0000992 if (!ModelClassifyText(context, detected_text_language_tags,
993 candidates[i].span, &interpreter_manager,
994 /*embedding_cache=*/nullptr,
995 &candidates[i].classification)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200996 return original_click_indices;
997 }
998 }
999
1000 // Ignore if span classification is filtered.
1001 if (FilteredForSelection(candidates[i])) {
1002 return original_click_indices;
1003 }
1004
Lukas Zilkab23e2122018-02-09 10:25:19 +01001005 return candidates[i].span;
1006 }
1007 }
1008
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001009 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001010}
1011
1012namespace {
1013// Helper function that returns the index of the first candidate that
1014// transitively does not overlap with the candidate on 'start_index'. If the end
1015// of 'candidates' is reached, it returns the index that points right behind the
1016// array.
1017int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
1018 int start_index) {
1019 int first_non_overlapping = start_index + 1;
1020 CodepointSpan conflicting_span = candidates[start_index].span;
1021 while (
1022 first_non_overlapping < candidates.size() &&
1023 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
1024 // Grow the span to include the current one.
1025 conflicting_span.second = std::max(
1026 conflicting_span.second, candidates[first_non_overlapping].span.second);
1027
1028 ++first_non_overlapping;
1029 }
1030 return first_non_overlapping;
1031}
1032} // namespace
1033
Tony Mak378c1f52019-03-04 15:58:11 +00001034bool Annotator::ResolveConflicts(
1035 const std::vector<AnnotatedSpan>& candidates, const std::string& context,
1036 const std::vector<Token>& cached_tokens,
1037 const std::vector<Locale>& detected_text_language_tags,
1038 AnnotationUsecase annotation_usecase,
1039 InterpreterManager* interpreter_manager, std::vector<int>* result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001040 result->clear();
1041 result->reserve(candidates.size());
1042 for (int i = 0; i < candidates.size();) {
1043 int first_non_overlapping =
1044 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1045
1046 const bool conflict_found = first_non_overlapping != (i + 1);
1047 if (conflict_found) {
1048 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +00001049 if (!ResolveConflict(context, cached_tokens, candidates,
1050 detected_text_language_tags, i,
1051 first_non_overlapping, annotation_usecase,
1052 interpreter_manager, &candidate_indices)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001053 return false;
1054 }
1055 result->insert(result->end(), candidate_indices.begin(),
1056 candidate_indices.end());
1057 } else {
1058 result->push_back(i);
1059 }
1060
1061 // Skip over the whole conflicting group/go to next candidate.
1062 i = first_non_overlapping;
1063 }
1064 return true;
1065}
1066
1067namespace {
Tony Mak448b5862019-03-22 13:36:41 +00001068// Returns true, if the given two sources do conflict in given annotation
1069// usecase.
1070// - In SMART usecase, all sources do conflict, because there's only 1 possible
1071// annotation for a given span.
1072// - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1073// and duration), while others not (e.g. duration and number).
1074bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1075 const AnnotatedSpan::Source source1,
1076 const AnnotatedSpan::Source source2) {
1077 uint32 source_mask =
1078 (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1079
Tony Mak378c1f52019-03-04 15:58:11 +00001080 switch (annotation_usecase) {
1081 case AnnotationUsecase_ANNOTATION_USECASE_SMART:
Tony Mak448b5862019-03-22 13:36:41 +00001082 // In the SMART mode, all annotations conflict.
Tony Mak378c1f52019-03-04 15:58:11 +00001083 return true;
Tony Mak448b5862019-03-22 13:36:41 +00001084
Tony Mak378c1f52019-03-04 15:58:11 +00001085 case AnnotationUsecase_ANNOTATION_USECASE_RAW:
Tony Mak448b5862019-03-22 13:36:41 +00001086 // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1087 // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1088 // hours" (duration).
1089 if ((source_mask &
1090 (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1091 (source_mask &
1092 (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1093 return false;
Tony Mak378c1f52019-03-04 15:58:11 +00001094 }
Tony Mak448b5862019-03-22 13:36:41 +00001095
1096 // A KNOWLEDGE entity does not conflict with anything.
1097 if ((source_mask &
1098 (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1099 return false;
1100 }
1101
Tony Makd0ae7c62020-03-27 13:58:00 +00001102 // A PERSONNAME entity does not conflict with anything.
1103 if ((source_mask &
1104 (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
1105 return false;
1106 }
1107
Tony Mak448b5862019-03-22 13:36:41 +00001108 // Entities from other sources can conflict.
Tony Mak378c1f52019-03-04 15:58:11 +00001109 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001110 }
1111}
1112} // namespace
1113
Tony Mak378c1f52019-03-04 15:58:11 +00001114bool Annotator::ResolveConflict(
1115 const std::string& context, const std::vector<Token>& cached_tokens,
1116 const std::vector<AnnotatedSpan>& candidates,
1117 const std::vector<Locale>& detected_text_language_tags, int start_index,
1118 int end_index, AnnotationUsecase annotation_usecase,
1119 InterpreterManager* interpreter_manager,
1120 std::vector<int>* chosen_indices) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001121 std::vector<int> conflicting_indices;
Tony Mak76d80962020-01-08 17:30:51 +00001122 std::unordered_map<int, std::pair<float, int>> scores_lengths;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001123 for (int i = start_index; i < end_index; ++i) {
1124 conflicting_indices.push_back(i);
1125 if (!candidates[i].classification.empty()) {
Tony Mak76d80962020-01-08 17:30:51 +00001126 scores_lengths[i] = {
1127 GetPriorityScore(candidates[i].classification),
1128 candidates[i].span.second - candidates[i].span.first};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001129 continue;
1130 }
1131
1132 // OPTIMIZATION: So that we don't have to classify all the ML model
1133 // spans apriori, we wait until we get here, when they conflict with
1134 // something and we need the actual classification scores. So if the
1135 // candidate conflicts and comes from the model, we need to run a
1136 // classification to determine its priority:
1137 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +00001138 if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1139 candidates[i].span, interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001140 /*embedding_cache=*/nullptr, &classification)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001141 return false;
1142 }
1143
1144 if (!classification.empty()) {
Tony Mak76d80962020-01-08 17:30:51 +00001145 scores_lengths[i] = {
1146 GetPriorityScore(classification),
1147 candidates[i].span.second - candidates[i].span.first};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001148 }
1149 }
1150
Tony Mak5a12b942020-05-01 12:41:31 +01001151 std::sort(
1152 conflicting_indices.begin(), conflicting_indices.end(),
1153 [this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
1154 if (scores_lengths[i].first == scores_lengths[j].first &&
1155 prioritize_longest_annotation_) {
1156 return scores_lengths[i].second > scores_lengths[j].second;
1157 }
1158 return scores_lengths[i].first > scores_lengths[j].first;
1159 });
Lukas Zilkab23e2122018-02-09 10:25:19 +01001160
Tony Mak448b5862019-03-22 13:36:41 +00001161 // Here we keep a set of indices that were chosen, per-source, to enable
1162 // effective computation.
1163 std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1164 chosen_indices_for_source_map;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001165
1166 // Greedily place the candidates if they don't conflict with the already
1167 // placed ones.
1168 for (int i = 0; i < conflicting_indices.size(); ++i) {
1169 const int considered_candidate = conflicting_indices[i];
Tony Mak448b5862019-03-22 13:36:41 +00001170
1171 // See if there is a conflict between the candidate and all already placed
1172 // candidates.
1173 bool conflict = false;
1174 SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1175 for (auto& source_set_pair : chosen_indices_for_source_map) {
1176 if (source_set_pair.first == candidates[considered_candidate].source) {
1177 chosen_indices_for_source_ptr = &source_set_pair.second;
1178 }
1179
Tony Mak5a12b942020-05-01 12:41:31 +01001180 const bool needs_conflict_resolution =
1181 annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_SMART ||
1182 (annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW &&
1183 do_conflict_resolution_in_raw_mode_);
1184 if (needs_conflict_resolution &&
1185 DoSourcesConflict(annotation_usecase, source_set_pair.first,
Tony Mak448b5862019-03-22 13:36:41 +00001186 candidates[considered_candidate].source) &&
1187 DoesCandidateConflict(considered_candidate, candidates,
1188 source_set_pair.second)) {
1189 conflict = true;
1190 break;
1191 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001192 }
Tony Mak448b5862019-03-22 13:36:41 +00001193
1194 // Skip the candidate if a conflict was found.
1195 if (conflict) {
1196 continue;
1197 }
1198
1199 // If the set of indices for the current source doesn't exist yet,
1200 // initialize it.
1201 if (chosen_indices_for_source_ptr == nullptr) {
1202 SortedIntSet new_set([&candidates](int a, int b) {
1203 return candidates[a].span.first < candidates[b].span.first;
1204 });
1205 chosen_indices_for_source_map[candidates[considered_candidate].source] =
1206 std::move(new_set);
1207 chosen_indices_for_source_ptr =
1208 &chosen_indices_for_source_map[candidates[considered_candidate]
1209 .source];
1210 }
1211
1212 // Place the candidate to the output and to the per-source conflict set.
1213 chosen_indices->push_back(considered_candidate);
1214 chosen_indices_for_source_ptr->insert(considered_candidate);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001215 }
1216
Tony Mak378c1f52019-03-04 15:58:11 +00001217 std::sort(chosen_indices->begin(), chosen_indices->end());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001218
1219 return true;
1220}
1221
Tony Mak6c4cc672018-09-17 11:48:50 +01001222bool Annotator::ModelSuggestSelection(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001223 const UnicodeText& context_unicode, CodepointSpan click_indices,
Tony Makdf54e742019-03-26 14:04:00 +00001224 const std::vector<Locale>& detected_text_language_tags,
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001225 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001226 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001227 if (model_->triggering_options() == nullptr ||
1228 !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1229 return true;
1230 }
1231
Tony Makdf54e742019-03-26 14:04:00 +00001232 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1233 ml_model_triggering_locales_,
1234 /*default_value=*/true)) {
1235 return true;
1236 }
1237
Lukas Zilka21d8c982018-01-24 11:11:20 +01001238 int click_pos;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001239 *tokens = selection_feature_processor_->Tokenize(context_unicode);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001240 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001241 context_unicode, click_indices,
1242 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001243 tokens, &click_pos);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001244 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001245 TC3_VLOG(1) << "Could not calculate the click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001246 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001247 }
1248
1249 const int symmetry_context_size =
1250 model_->selection_options()->symmetry_context_size();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001251 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
Lukas Zilkab23e2122018-02-09 10:25:19 +01001252 bounds_sensitive_features = selection_feature_processor_->GetOptions()
1253 ->bounds_sensitive_features();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001254
1255 // The symmetry context span is the clicked token with symmetry_context_size
1256 // tokens on either side.
1257 const TokenSpan symmetry_context_span = IntersectTokenSpans(
1258 ExpandTokenSpan(SingleTokenSpan(click_pos),
1259 /*num_tokens_left=*/symmetry_context_size,
1260 /*num_tokens_right=*/symmetry_context_size),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001261 {0, tokens->size()});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001262
Lukas Zilkab23e2122018-02-09 10:25:19 +01001263 // Compute the extraction span based on the model type.
1264 TokenSpan extraction_span;
1265 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1266 // The extraction span is the symmetry context span expanded to include
1267 // max_selection_span tokens on either side, which is how far a selection
1268 // can stretch from the click, plus a relevant number of tokens outside of
1269 // the bounds of the selection.
1270 const int max_selection_span =
1271 selection_feature_processor_->GetOptions()->max_selection_span();
1272 extraction_span =
1273 ExpandTokenSpan(symmetry_context_span,
1274 /*num_tokens_left=*/max_selection_span +
1275 bounds_sensitive_features->num_tokens_before(),
1276 /*num_tokens_right=*/max_selection_span +
1277 bounds_sensitive_features->num_tokens_after());
1278 } else {
1279 // The extraction span is the symmetry context span expanded to include
1280 // context_size tokens on either side.
1281 const int context_size =
1282 selection_feature_processor_->GetOptions()->context_size();
1283 extraction_span = ExpandTokenSpan(symmetry_context_span,
1284 /*num_tokens_left=*/context_size,
1285 /*num_tokens_right=*/context_size);
1286 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001287 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
Lukas Zilkab23e2122018-02-09 10:25:19 +01001288
Lukas Zilka434442d2018-04-25 11:38:51 +02001289 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1290 *tokens, extraction_span)) {
1291 return true;
1292 }
1293
Lukas Zilkab23e2122018-02-09 10:25:19 +01001294 std::unique_ptr<CachedFeatures> cached_features;
1295 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001296 *tokens, extraction_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001297 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1298 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01001299 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001300 selection_feature_processor_->EmbeddingSize() +
1301 selection_feature_processor_->DenseFeaturesCount(),
1302 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001303 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001304 return false;
1305 }
1306
1307 // Produce selection model candidates.
1308 std::vector<TokenSpan> chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001309 if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001310 interpreter_manager->SelectionInterpreter(), *cached_features,
1311 &chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001312 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001313 return false;
1314 }
1315
1316 for (const TokenSpan& chunk : chunks) {
1317 AnnotatedSpan candidate;
1318 candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001319 context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001320 if (model_->selection_options()->strip_unpaired_brackets()) {
1321 candidate.span =
1322 StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1323 }
1324
1325 // Only output non-empty spans.
1326 if (candidate.span.first != candidate.span.second) {
1327 result->push_back(candidate);
1328 }
1329 }
1330 return true;
1331}
1332
Tony Mak6c4cc672018-09-17 11:48:50 +01001333bool Annotator::ModelClassifyText(
Tony Mak378c1f52019-03-04 15:58:11 +00001334 const std::string& context,
1335 const std::vector<Locale>& detected_text_language_tags,
1336 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001337 FeatureProcessor::EmbeddingCache* embedding_cache,
1338 std::vector<ClassificationResult>* classification_results) const {
Tony Mak378c1f52019-03-04 15:58:11 +00001339 return ModelClassifyText(context, {}, detected_text_language_tags,
1340 selection_indices, interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001341 embedding_cache, classification_results);
1342}
1343
1344namespace internal {
1345std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1346 CodepointSpan selection_indices,
1347 TokenSpan tokens_around_selection_to_copy) {
1348 const auto first_selection_token = std::upper_bound(
1349 cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1350 [](int selection_start, const Token& token) {
1351 return selection_start < token.end;
1352 });
1353 const auto last_selection_token = std::lower_bound(
1354 cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1355 [](const Token& token, int selection_end) {
1356 return token.start < selection_end;
1357 });
1358
1359 const int64 first_token = std::max(
1360 static_cast<int64>(0),
1361 static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1362 tokens_around_selection_to_copy.first));
1363 const int64 last_token = std::min(
1364 static_cast<int64>(cached_tokens.size()),
1365 static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1366 tokens_around_selection_to_copy.second));
1367
1368 std::vector<Token> tokens;
1369 tokens.reserve(last_token - first_token);
1370 for (int i = first_token; i < last_token; ++i) {
1371 tokens.push_back(cached_tokens[i]);
1372 }
1373 return tokens;
1374}
1375} // namespace internal
1376
Tony Mak6c4cc672018-09-17 11:48:50 +01001377TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001378 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1379 bounds_sensitive_features =
1380 classification_feature_processor_->GetOptions()
1381 ->bounds_sensitive_features();
1382 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1383 // The extraction span is the selection span expanded to include a relevant
1384 // number of tokens outside of the bounds of the selection.
1385 return {bounds_sensitive_features->num_tokens_before(),
1386 bounds_sensitive_features->num_tokens_after()};
1387 } else {
1388 // The extraction span is the clicked token with context_size tokens on
1389 // either side.
1390 const int context_size =
1391 selection_feature_processor_->GetOptions()->context_size();
1392 return {context_size, context_size};
1393 }
1394}
1395
Tony Mak378c1f52019-03-04 15:58:11 +00001396namespace {
1397// Sorts the classification results from high score to low score.
1398void SortClassificationResults(
1399 std::vector<ClassificationResult>* classification_results) {
1400 std::sort(classification_results->begin(), classification_results->end(),
1401 [](const ClassificationResult& a, const ClassificationResult& b) {
1402 return a.score > b.score;
1403 });
1404}
1405} // namespace
1406
Tony Mak6c4cc672018-09-17 11:48:50 +01001407bool Annotator::ModelClassifyText(
Lukas Zilkaba849e72018-03-08 14:48:21 +01001408 const std::string& context, const std::vector<Token>& cached_tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00001409 const std::vector<Locale>& detected_text_language_tags,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001410 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1411 FeatureProcessor::EmbeddingCache* embedding_cache,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001412 std::vector<ClassificationResult>* classification_results) const {
1413 std::vector<Token> tokens;
Tony Mak378c1f52019-03-04 15:58:11 +00001414 return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1415 selection_indices, interpreter_manager,
1416 embedding_cache, classification_results, &tokens);
1417}
1418
1419bool Annotator::ModelClassifyText(
1420 const std::string& context, const std::vector<Token>& cached_tokens,
1421 const std::vector<Locale>& detected_text_language_tags,
1422 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1423 FeatureProcessor::EmbeddingCache* embedding_cache,
1424 std::vector<ClassificationResult>* classification_results,
1425 std::vector<Token>* tokens) const {
1426 if (model_->triggering_options() == nullptr ||
1427 !(model_->triggering_options()->enabled_modes() &
1428 ModeFlag_CLASSIFICATION)) {
1429 return true;
1430 }
1431
Tony Makdf54e742019-03-26 14:04:00 +00001432 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1433 ml_model_triggering_locales_,
1434 /*default_value=*/true)) {
1435 return true;
1436 }
1437
Lukas Zilkaba849e72018-03-08 14:48:21 +01001438 if (cached_tokens.empty()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001439 *tokens = classification_feature_processor_->Tokenize(context);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001440 } else {
Tony Mak378c1f52019-03-04 15:58:11 +00001441 *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1442 ClassifyTextUpperBoundNeededTokens());
Lukas Zilkaba849e72018-03-08 14:48:21 +01001443 }
1444
Lukas Zilkab23e2122018-02-09 10:25:19 +01001445 int click_pos;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001446 classification_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001447 context, selection_indices,
1448 classification_feature_processor_->GetOptions()
1449 ->only_use_line_with_click(),
Tony Mak378c1f52019-03-04 15:58:11 +00001450 tokens, &click_pos);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001451 const TokenSpan selection_token_span =
Tony Mak378c1f52019-03-04 15:58:11 +00001452 CodepointSpanToTokenSpan(*tokens, selection_indices);
Lukas Zilka434442d2018-04-25 11:38:51 +02001453 const int selection_num_tokens = TokenSpanSize(selection_token_span);
1454 if (model_->classification_options()->max_num_tokens() > 0 &&
1455 model_->classification_options()->max_num_tokens() <
1456 selection_num_tokens) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001457 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001458 return true;
1459 }
1460
Lukas Zilkab23e2122018-02-09 10:25:19 +01001461 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1462 bounds_sensitive_features =
1463 classification_feature_processor_->GetOptions()
1464 ->bounds_sensitive_features();
1465 if (selection_token_span.first == kInvalidIndex ||
1466 selection_token_span.second == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001467 TC3_LOG(ERROR) << "Could not determine span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001468 return false;
1469 }
1470
1471 // Compute the extraction span based on the model type.
1472 TokenSpan extraction_span;
1473 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1474 // The extraction span is the selection span expanded to include a relevant
1475 // number of tokens outside of the bounds of the selection.
1476 extraction_span = ExpandTokenSpan(
1477 selection_token_span,
1478 /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1479 /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1480 } else {
1481 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001482 TC3_LOG(ERROR) << "Couldn't choose a click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001483 return false;
1484 }
1485 // The extraction span is the clicked token with context_size tokens on
1486 // either side.
1487 const int context_size =
Lukas Zilkaba849e72018-03-08 14:48:21 +01001488 classification_feature_processor_->GetOptions()->context_size();
Lukas Zilkab23e2122018-02-09 10:25:19 +01001489 extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
1490 /*num_tokens_left=*/context_size,
1491 /*num_tokens_right=*/context_size);
1492 }
Tony Mak378c1f52019-03-04 15:58:11 +00001493 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001494
Lukas Zilka434442d2018-04-25 11:38:51 +02001495 if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
Tony Mak378c1f52019-03-04 15:58:11 +00001496 *tokens, extraction_span)) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001497 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001498 return true;
1499 }
1500
Lukas Zilka21d8c982018-01-24 11:11:20 +01001501 std::unique_ptr<CachedFeatures> cached_features;
1502 if (!classification_feature_processor_->ExtractFeatures(
Tony Mak378c1f52019-03-04 15:58:11 +00001503 *tokens, extraction_span, selection_indices,
1504 embedding_executor_.get(), embedding_cache,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001505 classification_feature_processor_->EmbeddingSize() +
1506 classification_feature_processor_->DenseFeaturesCount(),
1507 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001508 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001509 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001510 }
1511
Lukas Zilkab23e2122018-02-09 10:25:19 +01001512 std::vector<float> features;
1513 features.reserve(cached_features->OutputFeaturesSize());
1514 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1515 cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1516 &features);
1517 } else {
1518 cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001519 }
1520
Lukas Zilkaba849e72018-03-08 14:48:21 +01001521 TensorView<float> logits = classification_executor_->ComputeLogits(
1522 TensorView<float>(features.data(),
1523 {1, static_cast<int>(features.size())}),
1524 interpreter_manager->ClassificationInterpreter());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001525 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001526 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001527 return false;
1528 }
1529
1530 if (logits.dims() != 2 || logits.dim(0) != 1 ||
1531 logits.dim(1) != classification_feature_processor_->NumCollections()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001532 TC3_LOG(ERROR) << "Mismatching output";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001533 return false;
1534 }
1535
1536 const std::vector<float> scores =
1537 ComputeSoftmax(logits.data(), logits.dim(1));
1538
Tony Mak81e52422019-04-30 09:34:45 +01001539 if (scores.empty()) {
1540 *classification_results = {{Collections::Other(), 1.0}};
1541 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001542 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001543
Tony Mak81e52422019-04-30 09:34:45 +01001544 const int best_score_index =
1545 std::max_element(scores.begin(), scores.end()) - scores.begin();
1546 const std::string top_collection =
1547 classification_feature_processor_->LabelToCollection(best_score_index);
1548
1549 // Sanity checks.
1550 if (top_collection == Collections::Phone()) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001551 const int digit_count = CountDigits(context, selection_indices);
1552 if (digit_count <
1553 model_->classification_options()->phone_min_num_digits() ||
1554 digit_count >
1555 model_->classification_options()->phone_max_num_digits()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001556 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001557 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001558 }
Tony Mak81e52422019-04-30 09:34:45 +01001559 } else if (top_collection == Collections::Address()) {
Lukas Zilka434442d2018-04-25 11:38:51 +02001560 if (selection_num_tokens <
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001561 model_->classification_options()->address_min_num_tokens()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001562 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001563 return true;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001564 }
Tony Mak81e52422019-04-30 09:34:45 +01001565 } else if (top_collection == Collections::Dictionary()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001566 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1567 dictionary_locales_,
Tony Makdf54e742019-03-26 14:04:00 +00001568 /*default_value=*/false)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001569 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001570 return true;
Tony Mak378c1f52019-03-04 15:58:11 +00001571 }
1572 }
Tony Mak81e52422019-04-30 09:34:45 +01001573
Tony Makd99d58c2020-03-19 21:52:02 +00001574 *classification_results = {{top_collection, /*arg_score=*/1.0,
1575 /*arg_priority_score=*/scores[best_score_index]}};
1576
1577 // For some entities, we might want to clamp the priority score, for better
1578 // conflict resolution between entities.
1579 if (model_->triggering_options() != nullptr &&
1580 model_->triggering_options()->collection_to_priority() != nullptr) {
1581 if (auto entry =
1582 model_->triggering_options()->collection_to_priority()->LookupByKey(
1583 top_collection.c_str())) {
1584 (*classification_results)[0].priority_score *= entry->value();
1585 }
1586 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001587 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001588}
1589
Tony Mak6c4cc672018-09-17 11:48:50 +01001590bool Annotator::RegexClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001591 const std::string& context, CodepointSpan selection_indices,
Tony Mak378c1f52019-03-04 15:58:11 +00001592 std::vector<ClassificationResult>* classification_result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001593 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001594 UTF8ToUnicodeText(context, /*do_copy=*/false)
1595 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001596 const UnicodeText selection_text_unicode(
1597 UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1598
1599 // Check whether any of the regular expressions match.
1600 for (const int pattern_id : classification_regex_patterns_) {
1601 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1602 const std::unique_ptr<UniLib::RegexMatcher> matcher =
1603 regex_pattern.pattern->Matcher(selection_text_unicode);
1604 int status = UniLib::RegexMatcher::kNoError;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001605 bool matches;
Tony Mak854015a2019-01-16 15:56:48 +00001606 if (regex_pattern.config->use_approximate_matching()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001607 matches = matcher->ApproximatelyMatches(&status);
1608 } else {
1609 matches = matcher->Matches(&status);
1610 }
1611 if (status != UniLib::RegexMatcher::kNoError) {
1612 return false;
1613 }
Tony Makdf54e742019-03-26 14:04:00 +00001614 if (matches && VerifyRegexMatchCandidate(
1615 context, regex_pattern.config->verification_options(),
1616 selection_text, matcher.get())) {
Tony Mak378c1f52019-03-04 15:58:11 +00001617 classification_result->push_back(
1618 {regex_pattern.config->collection_name()->str(),
1619 regex_pattern.config->target_classification_score(),
1620 regex_pattern.config->priority_score()});
Tony Makd9446602019-02-20 18:25:39 +00001621 if (!SerializedEntityDataFromRegexMatch(
1622 regex_pattern.config, matcher.get(),
Tony Mak378c1f52019-03-04 15:58:11 +00001623 &classification_result->back().serialized_entity_data)) {
Tony Makd9446602019-02-20 18:25:39 +00001624 TC3_LOG(ERROR) << "Could not get entity data.";
1625 return false;
1626 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001627 }
1628 }
1629
Tony Mak378c1f52019-03-04 15:58:11 +00001630 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001631}
1632
Tony Mak5dc5e112019-02-01 14:52:10 +00001633namespace {
1634std::string PickCollectionForDatetime(
1635 const DatetimeParseResult& datetime_parse_result) {
1636 switch (datetime_parse_result.granularity) {
1637 case GRANULARITY_HOUR:
1638 case GRANULARITY_MINUTE:
1639 case GRANULARITY_SECOND:
1640 return Collections::DateTime();
1641 default:
1642 return Collections::Date();
1643 }
1644}
Tony Mak83d2de62019-04-10 16:12:15 +01001645
1646std::string CreateDatetimeSerializedEntityData(
1647 const DatetimeParseResult& parse_result) {
1648 EntityDataT entity_data;
1649 entity_data.datetime.reset(new EntityData_::DatetimeT());
1650 entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
1651 entity_data.datetime->granularity =
1652 static_cast<EntityData_::Datetime_::Granularity>(
1653 parse_result.granularity);
1654
Tony Maka2a1ff42019-09-12 15:40:32 +01001655 for (const auto& c : parse_result.datetime_components) {
1656 EntityData_::Datetime_::DatetimeComponentT datetime_component;
1657 datetime_component.absolute_value = c.value;
1658 datetime_component.relative_count = c.relative_count;
1659 datetime_component.component_type =
1660 static_cast<EntityData_::Datetime_::DatetimeComponent_::ComponentType>(
1661 c.component_type);
1662 datetime_component.relation_type =
1663 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE;
1664 if (c.relative_qualifier !=
1665 DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
1666 datetime_component.relation_type =
1667 EntityData_::Datetime_::DatetimeComponent_::RelationType_RELATIVE;
1668 }
1669 entity_data.datetime->datetime_component.emplace_back(
1670 new EntityData_::Datetime_::DatetimeComponentT(datetime_component));
1671 }
Tony Mak83d2de62019-04-10 16:12:15 +01001672 flatbuffers::FlatBufferBuilder builder;
1673 FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
1674 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
1675 builder.GetSize());
1676}
Tony Mak63959242020-02-07 18:31:16 +00001677
Tony Mak5dc5e112019-02-01 14:52:10 +00001678} // namespace
1679
Tony Mak6c4cc672018-09-17 11:48:50 +01001680bool Annotator::DatetimeClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001681 const std::string& context, CodepointSpan selection_indices,
1682 const ClassificationOptions& options,
Tony Mak854015a2019-01-16 15:56:48 +00001683 std::vector<ClassificationResult>* classification_results) const {
Tony Mak63959242020-02-07 18:31:16 +00001684 if (!datetime_parser_ && !cfg_datetime_parser_) {
Tony Makd99d58c2020-03-19 21:52:02 +00001685 return true;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001686 }
1687
Lukas Zilkab23e2122018-02-09 10:25:19 +01001688 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001689 UTF8ToUnicodeText(context, /*do_copy=*/false)
1690 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001691
1692 std::vector<DatetimeParseResultSpan> datetime_spans;
Tony Makd99d58c2020-03-19 21:52:02 +00001693
Tony Mak63959242020-02-07 18:31:16 +00001694 if (cfg_datetime_parser_) {
1695 if (!(model_->grammar_datetime_model()->enabled_modes() &
1696 ModeFlag_CLASSIFICATION)) {
1697 return true;
1698 }
1699 std::vector<Locale> parsed_locales;
1700 ParseLocales(options.locales, &parsed_locales);
Tony Mak21460022020-03-12 18:29:35 +00001701 cfg_datetime_parser_->Parse(
1702 selection_text,
1703 ToDateAnnotationOptions(
1704 model_->grammar_datetime_model()->annotation_options(),
1705 options.reference_timezone, options.reference_time_ms_utc),
1706 parsed_locales, &datetime_spans);
Tony Makd99d58c2020-03-19 21:52:02 +00001707 }
1708
1709 if (datetime_parser_) {
Tony Mak63959242020-02-07 18:31:16 +00001710 if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1711 options.reference_timezone, options.locales,
1712 ModeFlag_CLASSIFICATION,
1713 options.annotation_usecase,
1714 /*anchor_start_end=*/true, &datetime_spans)) {
1715 TC3_LOG(ERROR) << "Error during parsing datetime.";
1716 return false;
1717 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001718 }
Tony Makd99d58c2020-03-19 21:52:02 +00001719
Lukas Zilkab23e2122018-02-09 10:25:19 +01001720 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
1721 // Only consider the result valid if the selection and extracted datetime
1722 // spans exactly match.
1723 if (std::make_pair(datetime_span.span.first + selection_indices.first,
1724 datetime_span.span.second + selection_indices.first) ==
1725 selection_indices) {
Tony Mak854015a2019-01-16 15:56:48 +00001726 for (const DatetimeParseResult& parse_result : datetime_span.data) {
1727 classification_results->emplace_back(
Tony Mak5dc5e112019-02-01 14:52:10 +00001728 PickCollectionForDatetime(parse_result),
1729 datetime_span.target_classification_score);
Tony Mak854015a2019-01-16 15:56:48 +00001730 classification_results->back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01001731 classification_results->back().serialized_entity_data =
1732 CreateDatetimeSerializedEntityData(parse_result);
Tony Mak378c1f52019-03-04 15:58:11 +00001733 classification_results->back().priority_score =
1734 datetime_span.priority_score;
Tony Mak854015a2019-01-16 15:56:48 +00001735 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001736 return true;
1737 }
1738 }
Tony Mak378c1f52019-03-04 15:58:11 +00001739 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001740}
1741
Tony Mak6c4cc672018-09-17 11:48:50 +01001742std::vector<ClassificationResult> Annotator::ClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001743 const std::string& context, CodepointSpan selection_indices,
1744 const ClassificationOptions& options) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01001745 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001746 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilka21d8c982018-01-24 11:11:20 +01001747 return {};
1748 }
Tony Mak5a12b942020-05-01 12:41:31 +01001749 if (options.annotation_usecase !=
1750 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
1751 TC3_LOG(WARNING)
1752 << "Invoking ClassifyText, which is not supported in RAW mode.";
1753 return {};
1754 }
Lukas Zilkaba849e72018-03-08 14:48:21 +01001755 if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1756 return {};
1757 }
1758
Tony Makdf54e742019-03-26 14:04:00 +00001759 std::vector<Locale> detected_text_language_tags;
1760 if (!ParseLocales(options.detected_text_language_tags,
1761 &detected_text_language_tags)) {
1762 TC3_LOG(WARNING)
1763 << "Failed to parse the detected_text_language_tags in options: "
1764 << options.detected_text_language_tags;
1765 }
1766 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1767 model_triggering_locales_,
1768 /*default_value=*/true)) {
1769 return {};
1770 }
1771
Tony Mak968412a2019-11-13 15:39:57 +00001772 if (!IsValidSpanInput(UTF8ToUnicodeText(context, /*do_copy=*/false),
1773 selection_indices)) {
1774 TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
Tony Mak6c4cc672018-09-17 11:48:50 +01001775 << std::get<0>(selection_indices) << " "
1776 << std::get<1>(selection_indices);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001777 return {};
1778 }
1779
Tony Mak378c1f52019-03-04 15:58:11 +00001780 // We'll accumulate a list of candidates, and pick the best candidate in the
1781 // end.
1782 std::vector<AnnotatedSpan> candidates;
1783
Tony Mak6c4cc672018-09-17 11:48:50 +01001784 // Try the knowledge engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001785 // TODO(b/126579108): Propagate error status.
Tony Mak6c4cc672018-09-17 11:48:50 +01001786 ClassificationResult knowledge_result;
Tony Mak63959242020-02-07 18:31:16 +00001787 if (knowledge_engine_ &&
1788 knowledge_engine_->ClassifyText(
1789 context, selection_indices, options.annotation_usecase,
Tony Mak90d55672020-04-15 18:20:44 +01001790 options.location_context, Permissions(), &knowledge_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001791 candidates.push_back({selection_indices, {knowledge_result}});
1792 candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
Tony Mak854015a2019-01-16 15:56:48 +00001793 }
1794
Tony Maka2a1ff42019-09-12 15:40:32 +01001795 AddContactMetadataToKnowledgeClassificationResults(&candidates);
1796
Tony Mak854015a2019-01-16 15:56:48 +00001797 // Try the contact engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001798 // TODO(b/126579108): Propagate error status.
Tony Mak854015a2019-01-16 15:56:48 +00001799 ClassificationResult contact_result;
1800 if (contact_engine_ && contact_engine_->ClassifyText(
1801 context, selection_indices, &contact_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001802 candidates.push_back({selection_indices, {contact_result}});
Tony Mak6c4cc672018-09-17 11:48:50 +01001803 }
1804
Tony Mak76d80962020-01-08 17:30:51 +00001805 // Try the person name engine.
1806 ClassificationResult person_name_result;
1807 if (person_name_engine_ &&
1808 person_name_engine_->ClassifyText(context, selection_indices,
1809 &person_name_result)) {
1810 candidates.push_back({selection_indices, {person_name_result}});
Tony Makd0ae7c62020-03-27 13:58:00 +00001811 candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
Tony Mak76d80962020-01-08 17:30:51 +00001812 }
1813
Tony Makd9446602019-02-20 18:25:39 +00001814 // Try the installed app engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001815 // TODO(b/126579108): Propagate error status.
Tony Makd9446602019-02-20 18:25:39 +00001816 ClassificationResult installed_app_result;
1817 if (installed_app_engine_ &&
1818 installed_app_engine_->ClassifyText(context, selection_indices,
1819 &installed_app_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001820 candidates.push_back({selection_indices, {installed_app_result}});
Tony Makd9446602019-02-20 18:25:39 +00001821 }
1822
Lukas Zilkab23e2122018-02-09 10:25:19 +01001823 // Try the regular expression models.
Tony Mak378c1f52019-03-04 15:58:11 +00001824 std::vector<ClassificationResult> regex_results;
1825 if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1826 return {};
1827 }
1828 for (const ClassificationResult& result : regex_results) {
1829 candidates.push_back({selection_indices, {result}});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001830 }
1831
Lukas Zilkab23e2122018-02-09 10:25:19 +01001832 // Try the date model.
Tony Mak378c1f52019-03-04 15:58:11 +00001833 //
1834 // DatetimeClassifyText only returns the first result, which can however have
1835 // more interpretations. They are inserted in the candidates as a single
1836 // AnnotatedSpan, so that they get treated together by the conflict resolution
1837 // algorithm.
Tony Mak854015a2019-01-16 15:56:48 +00001838 std::vector<ClassificationResult> datetime_results;
Tony Mak378c1f52019-03-04 15:58:11 +00001839 if (!DatetimeClassifyText(context, selection_indices, options,
1840 &datetime_results)) {
1841 return {};
1842 }
1843 if (!datetime_results.empty()) {
1844 candidates.push_back({selection_indices, std::move(datetime_results)});
Tony Mak448b5862019-03-22 13:36:41 +00001845 candidates.back().source = AnnotatedSpan::Source::DATETIME;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001846 }
1847
Tony Mak5a12b942020-05-01 12:41:31 +01001848 const UnicodeText context_unicode =
1849 UTF8ToUnicodeText(context, /*do_copy=*/false);
1850
Tony Mak378c1f52019-03-04 15:58:11 +00001851 // Try the number annotator.
1852 // TODO(b/126579108): Propagate error status.
1853 ClassificationResult number_annotator_result;
1854 if (number_annotator_ &&
Tony Mak5a12b942020-05-01 12:41:31 +01001855 number_annotator_->ClassifyText(context_unicode, selection_indices,
1856 options.annotation_usecase,
1857 &number_annotator_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001858 candidates.push_back({selection_indices, {number_annotator_result}});
1859 }
Lukas Zilkaba849e72018-03-08 14:48:21 +01001860
Tony Makad2e22d2019-03-20 17:35:13 +00001861 // Try the duration annotator.
1862 ClassificationResult duration_annotator_result;
1863 if (duration_annotator_ &&
Tony Mak5a12b942020-05-01 12:41:31 +01001864 duration_annotator_->ClassifyText(context_unicode, selection_indices,
1865 options.annotation_usecase,
1866 &duration_annotator_result)) {
Tony Makad2e22d2019-03-20 17:35:13 +00001867 candidates.push_back({selection_indices, {duration_annotator_result}});
Tony Mak448b5862019-03-22 13:36:41 +00001868 candidates.back().source = AnnotatedSpan::Source::DURATION;
Tony Makad2e22d2019-03-20 17:35:13 +00001869 }
1870
Tony Mak63959242020-02-07 18:31:16 +00001871 // Try the translate annotator.
1872 ClassificationResult translate_annotator_result;
1873 if (translate_annotator_ &&
Tony Mak5a12b942020-05-01 12:41:31 +01001874 translate_annotator_->ClassifyText(context_unicode, selection_indices,
1875 options.user_familiar_language_tags,
1876 &translate_annotator_result)) {
Tony Mak63959242020-02-07 18:31:16 +00001877 candidates.push_back({selection_indices, {translate_annotator_result}});
1878 }
1879
Tony Mak21460022020-03-12 18:29:35 +00001880 // Try the grammar model.
1881 ClassificationResult grammar_annotator_result;
1882 if (grammar_annotator_ && grammar_annotator_->ClassifyText(
Tony Mak5a12b942020-05-01 12:41:31 +01001883 detected_text_language_tags, context_unicode,
Tony Mak21460022020-03-12 18:29:35 +00001884 selection_indices, &grammar_annotator_result)) {
1885 candidates.push_back({selection_indices, {grammar_annotator_result}});
1886 }
1887
Tony Mak5a12b942020-05-01 12:41:31 +01001888 ClassificationResult experimental_annotator_result;
1889 if (experimental_annotator_ &&
1890 experimental_annotator_->ClassifyText(context_unicode, selection_indices,
1891 &experimental_annotator_result)) {
1892 candidates.push_back({selection_indices, {experimental_annotator_result}});
1893 }
1894
Tony Mak378c1f52019-03-04 15:58:11 +00001895 // Try the ML model.
1896 //
1897 // The output of the model is considered as an exclusive 1-of-N choice. That's
1898 // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1899 // span for each candidate, like e.g. the regex model.
Lukas Zilkaba849e72018-03-08 14:48:21 +01001900 InterpreterManager interpreter_manager(selection_executor_.get(),
1901 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00001902 std::vector<ClassificationResult> model_results;
1903 std::vector<Token> tokens;
1904 if (!ModelClassifyText(
1905 context, /*cached_tokens=*/{}, detected_text_language_tags,
1906 selection_indices, &interpreter_manager,
1907 /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1908 return {};
1909 }
1910 if (!model_results.empty()) {
1911 candidates.push_back({selection_indices, std::move(model_results)});
1912 }
1913
1914 std::vector<int> candidate_indices;
1915 if (!ResolveConflicts(candidates, context, tokens,
1916 detected_text_language_tags, options.annotation_usecase,
1917 &interpreter_manager, &candidate_indices)) {
1918 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1919 return {};
1920 }
1921
1922 std::vector<ClassificationResult> results;
1923 for (const int i : candidate_indices) {
1924 for (const ClassificationResult& result : candidates[i].classification) {
1925 if (!FilteredForClassification(result)) {
1926 results.push_back(result);
1927 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001928 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001929 }
1930
Tony Mak378c1f52019-03-04 15:58:11 +00001931 // Sort results according to score.
1932 std::sort(results.begin(), results.end(),
1933 [](const ClassificationResult& a, const ClassificationResult& b) {
1934 return a.score > b.score;
1935 });
1936
1937 if (results.empty()) {
Tony Mak81e52422019-04-30 09:34:45 +01001938 results = {{Collections::Other(), 1.0}};
Tony Mak378c1f52019-03-04 15:58:11 +00001939 }
Tony Mak378c1f52019-03-04 15:58:11 +00001940 return results;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001941}
1942
Tony Mak378c1f52019-03-04 15:58:11 +00001943bool Annotator::ModelAnnotate(
1944 const std::string& context,
1945 const std::vector<Locale>& detected_text_language_tags,
1946 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1947 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001948 if (model_->triggering_options() == nullptr ||
1949 !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1950 return true;
1951 }
1952
Tony Makdf54e742019-03-26 14:04:00 +00001953 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1954 ml_model_triggering_locales_,
1955 /*default_value=*/true)) {
1956 return true;
1957 }
1958
Lukas Zilka21d8c982018-01-24 11:11:20 +01001959 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1960 /*do_copy=*/false);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001961 std::vector<UnicodeTextRange> lines;
1962 if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1963 lines.push_back({context_unicode.begin(), context_unicode.end()});
1964 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +01001965 lines = selection_feature_processor_->SplitContext(
1966 context_unicode, selection_feature_processor_->GetOptions()
1967 ->use_pipe_character_for_newline());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001968 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001969
Lukas Zilkaba849e72018-03-08 14:48:21 +01001970 const float min_annotate_confidence =
1971 (model_->triggering_options() != nullptr
1972 ? model_->triggering_options()->min_annotate_confidence()
1973 : 0.f);
1974
Lukas Zilkab23e2122018-02-09 10:25:19 +01001975 for (const UnicodeTextRange& line : lines) {
Tony Mak408c6b82019-03-08 17:57:27 +00001976 FeatureProcessor::EmbeddingCache embedding_cache;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001977 const std::string line_str =
1978 UnicodeText::UTF8Substring(line.first, line.second);
1979
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001980 *tokens = selection_feature_processor_->Tokenize(line_str);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001981 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001982 line_str, {0, std::distance(line.first, line.second)},
1983 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001984 tokens,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001985 /*click_pos=*/nullptr);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001986 const TokenSpan full_line_span = {0, tokens->size()};
Lukas Zilka21d8c982018-01-24 11:11:20 +01001987
Lukas Zilka434442d2018-04-25 11:38:51 +02001988 // TODO(zilka): Add support for greater granularity of this check.
1989 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1990 *tokens, full_line_span)) {
1991 continue;
1992 }
1993
Lukas Zilka21d8c982018-01-24 11:11:20 +01001994 std::unique_ptr<CachedFeatures> cached_features;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001995 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001996 *tokens, full_line_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001997 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1998 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01001999 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002000 selection_feature_processor_->EmbeddingSize() +
2001 selection_feature_processor_->DenseFeaturesCount(),
Lukas Zilka21d8c982018-01-24 11:11:20 +01002002 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002003 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002004 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002005 }
2006
2007 std::vector<TokenSpan> local_chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002008 if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002009 interpreter_manager->SelectionInterpreter(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002010 *cached_features, &local_chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002011 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002012 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002013 }
2014
2015 const int offset = std::distance(context_unicode.begin(), line.first);
2016 for (const TokenSpan& chunk : local_chunks) {
2017 const CodepointSpan codepoint_span =
2018 selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002019 line_str, TokenSpanToCodepointSpan(*tokens, chunk));
Lukas Zilkab23e2122018-02-09 10:25:19 +01002020
2021 // Skip empty spans.
2022 if (codepoint_span.first != codepoint_span.second) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002023 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +00002024 if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
2025 codepoint_span, interpreter_manager,
2026 &embedding_cache, &classification)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002027 TC3_LOG(ERROR) << "Could not classify text: "
2028 << (codepoint_span.first + offset) << " "
2029 << (codepoint_span.second + offset);
Lukas Zilkaba849e72018-03-08 14:48:21 +01002030 return false;
2031 }
2032
2033 // Do not include the span if it's classified as "other".
2034 if (!classification.empty() && !ClassifiedAsOther(classification) &&
2035 classification[0].score >= min_annotate_confidence) {
2036 AnnotatedSpan result_span;
2037 result_span.span = {codepoint_span.first + offset,
2038 codepoint_span.second + offset};
2039 result_span.classification = std::move(classification);
2040 result->push_back(std::move(result_span));
2041 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002042 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01002043 }
2044 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002045 return true;
2046}
2047
Tony Mak6c4cc672018-09-17 11:48:50 +01002048const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
Lukas Zilka434442d2018-04-25 11:38:51 +02002049 return selection_feature_processor_.get();
2050}
2051
Tony Mak6c4cc672018-09-17 11:48:50 +01002052const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
Lukas Zilka434442d2018-04-25 11:38:51 +02002053 const {
2054 return classification_feature_processor_.get();
Lukas Zilkab23e2122018-02-09 10:25:19 +01002055}
2056
Tony Mak6c4cc672018-09-17 11:48:50 +01002057const DatetimeParser* Annotator::DatetimeParserForTests() const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002058 return datetime_parser_.get();
2059}
2060
Tony Mak83d2de62019-04-10 16:12:15 +01002061void Annotator::RemoveNotEnabledEntityTypes(
2062 const EnabledEntityTypes& is_entity_type_enabled,
2063 std::vector<AnnotatedSpan>* annotated_spans) const {
2064 for (AnnotatedSpan& annotated_span : *annotated_spans) {
2065 std::vector<ClassificationResult>& classifications =
2066 annotated_span.classification;
2067 classifications.erase(
2068 std::remove_if(classifications.begin(), classifications.end(),
2069 [&is_entity_type_enabled](
2070 const ClassificationResult& classification_result) {
2071 return !is_entity_type_enabled(
2072 classification_result.collection);
2073 }),
2074 classifications.end());
2075 }
2076 annotated_spans->erase(
2077 std::remove_if(annotated_spans->begin(), annotated_spans->end(),
2078 [](const AnnotatedSpan& annotated_span) {
2079 return annotated_span.classification.empty();
2080 }),
2081 annotated_spans->end());
2082}
2083
Tony Maka2a1ff42019-09-12 15:40:32 +01002084void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2085 std::vector<AnnotatedSpan>* candidates) const {
2086 if (candidates == nullptr || contact_engine_ == nullptr) {
2087 return;
2088 }
2089 for (auto& candidate : *candidates) {
2090 for (auto& classification_result : candidate.classification) {
2091 contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2092 &classification_result);
2093 }
2094 }
2095}
2096
Tony Makff31efb2020-03-31 11:13:06 +01002097Status Annotator::AnnotateSingleInput(
2098 const std::string& context, const AnnotationOptions& options,
2099 std::vector<AnnotatedSpan>* candidates) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002100 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
Tony Makff31efb2020-03-31 11:13:06 +01002101 return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
Lukas Zilkaba849e72018-03-08 14:48:21 +01002102 }
2103
Tony Mak854015a2019-01-16 15:56:48 +00002104 const UnicodeText context_unicode =
2105 UTF8ToUnicodeText(context, /*do_copy=*/false);
2106 if (!context_unicode.is_valid()) {
Tony Makff31efb2020-03-31 11:13:06 +01002107 return Status(StatusCode::INVALID_ARGUMENT,
2108 "Context string isn't valid UTF8.");
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002109 }
2110
Tony Mak378c1f52019-03-04 15:58:11 +00002111 std::vector<Locale> detected_text_language_tags;
2112 if (!ParseLocales(options.detected_text_language_tags,
2113 &detected_text_language_tags)) {
2114 TC3_LOG(WARNING)
Tony Makdf54e742019-03-26 14:04:00 +00002115 << "Failed to parse the detected_text_language_tags in options: "
Tony Mak378c1f52019-03-04 15:58:11 +00002116 << options.detected_text_language_tags;
2117 }
Tony Makdf54e742019-03-26 14:04:00 +00002118 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2119 model_triggering_locales_,
2120 /*default_value=*/true)) {
Tony Makff31efb2020-03-31 11:13:06 +01002121 return Status(
2122 StatusCode::UNAVAILABLE,
2123 "The detected language tags are not in the supported locales.");
Tony Makdf54e742019-03-26 14:04:00 +00002124 }
2125
2126 InterpreterManager interpreter_manager(selection_executor_.get(),
2127 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00002128
Lukas Zilkab23e2122018-02-09 10:25:19 +01002129 // Annotate with the selection model.
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002130 std::vector<Token> tokens;
Tony Mak378c1f52019-03-04 15:58:11 +00002131 if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
Tony Makff31efb2020-03-31 11:13:06 +01002132 &tokens, candidates)) {
2133 return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002134 }
2135
2136 // Annotate with the regular expression models.
2137 if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
Tony Makff31efb2020-03-31 11:13:06 +01002138 annotation_regex_patterns_, candidates,
Tony Mak83d2de62019-04-10 16:12:15 +01002139 options.is_serialized_entity_data_enabled)) {
Tony Makff31efb2020-03-31 11:13:06 +01002140 return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002141 }
2142
2143 // Annotate with the datetime model.
Tony Mak83d2de62019-04-10 16:12:15 +01002144 const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2145 if ((is_entity_type_enabled(Collections::Date()) ||
2146 is_entity_type_enabled(Collections::DateTime())) &&
2147 !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002148 options.reference_time_ms_utc, options.reference_timezone,
Tony Makd9446602019-02-20 18:25:39 +00002149 options.locales, ModeFlag_ANNOTATION,
Tony Mak83d2de62019-04-10 16:12:15 +01002150 options.annotation_usecase,
Tony Makff31efb2020-03-31 11:13:06 +01002151 options.is_serialized_entity_data_enabled, candidates)) {
2152 return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
Tony Mak6c4cc672018-09-17 11:48:50 +01002153 }
2154
Tony Mak854015a2019-01-16 15:56:48 +00002155 // Annotate with the contact engine.
2156 if (contact_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002157 !contact_engine_->Chunk(context_unicode, tokens, candidates)) {
2158 return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
Tony Mak854015a2019-01-16 15:56:48 +00002159 }
2160
Tony Makd9446602019-02-20 18:25:39 +00002161 // Annotate with the installed app engine.
2162 if (installed_app_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002163 !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
2164 return Status(StatusCode::INTERNAL,
2165 "Couldn't run installed app engine Chunk.");
Tony Makd9446602019-02-20 18:25:39 +00002166 }
2167
Tony Mak378c1f52019-03-04 15:58:11 +00002168 // Annotate with the number annotator.
2169 if (number_annotator_ != nullptr &&
2170 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
Tony Makff31efb2020-03-31 11:13:06 +01002171 candidates)) {
2172 return Status(StatusCode::INTERNAL,
2173 "Couldn't run number annotator FindAll.");
Tony Makad2e22d2019-03-20 17:35:13 +00002174 }
2175
2176 // Annotate with the duration annotator.
Tony Mak83d2de62019-04-10 16:12:15 +01002177 if (is_entity_type_enabled(Collections::Duration()) &&
2178 duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +00002179 !duration_annotator_->FindAll(context_unicode, tokens,
Tony Makff31efb2020-03-31 11:13:06 +01002180 options.annotation_usecase, candidates)) {
2181 return Status(StatusCode::INTERNAL,
2182 "Couldn't run duration annotator FindAll.");
Tony Mak378c1f52019-03-04 15:58:11 +00002183 }
2184
Tony Mak76d80962020-01-08 17:30:51 +00002185 // Annotate with the person name engine.
2186 if (is_entity_type_enabled(Collections::PersonName()) &&
2187 person_name_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002188 !person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
2189 return Status(StatusCode::INTERNAL,
2190 "Couldn't run person name engine Chunk.");
Tony Mak76d80962020-01-08 17:30:51 +00002191 }
2192
Tony Mak21460022020-03-12 18:29:35 +00002193 // Annotate with the grammar annotators.
2194 if (grammar_annotator_ != nullptr &&
2195 !grammar_annotator_->Annotate(detected_text_language_tags,
Tony Makff31efb2020-03-31 11:13:06 +01002196 context_unicode, candidates)) {
2197 return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
Tony Mak21460022020-03-12 18:29:35 +00002198 }
2199
Tony Mak5a12b942020-05-01 12:41:31 +01002200 if (experimental_annotator_ != nullptr &&
2201 !experimental_annotator_->Annotate(context_unicode, candidates)) {
2202 return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
2203 }
2204
Lukas Zilkab23e2122018-02-09 10:25:19 +01002205 // Sort candidates according to their position in the input, so that the next
2206 // code can assume that any connected component of overlapping spans forms a
2207 // contiguous block.
Tony Mak5a12b942020-05-01 12:41:31 +01002208 // Also sort them according to the end position and collection, so that the
2209 // deduplication code below can assume that same spans and classifications
2210 // form contiguous blocks.
Tony Makff31efb2020-03-31 11:13:06 +01002211 std::sort(candidates->begin(), candidates->end(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002212 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
Tony Mak5a12b942020-05-01 12:41:31 +01002213 if (a.span.first != b.span.first) {
2214 return a.span.first < b.span.first;
2215 }
2216
2217 if (a.span.second != b.span.second) {
2218 return a.span.second < b.span.second;
2219 }
2220
2221 return a.classification[0].collection <
2222 b.classification[0].collection;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002223 });
2224
2225 std::vector<int> candidate_indices;
Tony Makff31efb2020-03-31 11:13:06 +01002226 if (!ResolveConflicts(*candidates, context, tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00002227 detected_text_language_tags, options.annotation_usecase,
2228 &interpreter_manager, &candidate_indices)) {
Tony Makff31efb2020-03-31 11:13:06 +01002229 return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002230 }
2231
Tony Mak5a12b942020-05-01 12:41:31 +01002232 // Remove candidates that overlap exactly and have the same collection.
2233 // This can e.g. happen for phone coming from both ML model and regex.
2234 candidate_indices.erase(
2235 std::unique(candidate_indices.begin(), candidate_indices.end(),
2236 [&candidates](const int a_index, const int b_index) {
2237 const AnnotatedSpan& a = (*candidates)[a_index];
2238 const AnnotatedSpan& b = (*candidates)[b_index];
2239 return a.span == b.span &&
2240 a.classification[0].collection ==
2241 b.classification[0].collection;
2242 }),
2243 candidate_indices.end());
2244
Lukas Zilkab23e2122018-02-09 10:25:19 +01002245 std::vector<AnnotatedSpan> result;
2246 result.reserve(candidate_indices.size());
2247 for (const int i : candidate_indices) {
Tony Makff31efb2020-03-31 11:13:06 +01002248 if ((*candidates)[i].classification.empty() ||
2249 ClassifiedAsOther((*candidates)[i].classification) ||
2250 FilteredForAnnotation((*candidates)[i])) {
Tony Mak378c1f52019-03-04 15:58:11 +00002251 continue;
2252 }
Tony Mak5a12b942020-05-01 12:41:31 +01002253 result.push_back(std::move((*candidates)[i]));
Tony Mak378c1f52019-03-04 15:58:11 +00002254 }
2255
Tony Mak83d2de62019-04-10 16:12:15 +01002256 // We generate all candidates and remove them later (with the exception of
2257 // date/time/duration entities) because there are complex interdependencies
2258 // between the entity types. E.g., the TLD of an email can be interpreted as a
2259 // URL, but most likely a user of the API does not want such annotations if
2260 // "url" is enabled and "email" is not.
2261 RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2262
Tony Mak378c1f52019-03-04 15:58:11 +00002263 for (AnnotatedSpan& annotated_span : result) {
2264 SortClassificationResults(&annotated_span.classification);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002265 }
Tony Makff31efb2020-03-31 11:13:06 +01002266 *candidates = result;
2267 return Status::OK;
2268}
Lukas Zilkab23e2122018-02-09 10:25:19 +01002269
Tony Makff31efb2020-03-31 11:13:06 +01002270StatusOr<std::vector<std::vector<AnnotatedSpan>>>
2271Annotator::AnnotateStructuredInput(
2272 const std::vector<InputFragment>& string_fragments,
2273 const AnnotationOptions& options) const {
2274 std::vector<std::vector<AnnotatedSpan>> annotation_candidates(
2275 string_fragments.size());
2276
2277 std::vector<std::string> text_to_annotate;
2278 text_to_annotate.reserve(string_fragments.size());
2279 for (const auto& string_fragment : string_fragments) {
2280 text_to_annotate.push_back(string_fragment.text);
2281 }
2282
2283 // KnowledgeEngine is special, because it supports annotation of multiple
2284 // fragments at once.
2285 if (knowledge_engine_ &&
2286 !knowledge_engine_
2287 ->ChunkMultipleSpans(text_to_annotate, options.annotation_usecase,
Tony Mak90d55672020-04-15 18:20:44 +01002288 options.location_context, options.permissions,
Tony Makff31efb2020-03-31 11:13:06 +01002289 &annotation_candidates)
2290 .ok()) {
2291 return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
2292 }
2293 // The annotator engines shouldn't change the number of annotation vectors.
2294 if (annotation_candidates.size() != text_to_annotate.size()) {
2295 TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
2296 << " texts to annotate but generated a different number of "
2297 "lists of annotations:"
2298 << annotation_candidates.size();
2299 return Status(StatusCode::INTERNAL,
2300 "Number of annotation candidates differs from "
2301 "number of texts to annotate.");
2302 }
2303
2304 // Other annotators run on each fragment independently.
2305 for (int i = 0; i < text_to_annotate.size(); ++i) {
2306 AnnotationOptions annotation_options = options;
2307 if (string_fragments[i].datetime_options.has_value()) {
2308 DatetimeOptions reference_datetime =
2309 string_fragments[i].datetime_options.value();
2310 annotation_options.reference_time_ms_utc =
2311 reference_datetime.reference_time_ms_utc;
2312 annotation_options.reference_timezone =
2313 reference_datetime.reference_timezone;
2314 }
2315
2316 AddContactMetadataToKnowledgeClassificationResults(
2317 &annotation_candidates[i]);
2318
2319 Status annotation_status = AnnotateSingleInput(
2320 text_to_annotate[i], annotation_options, &annotation_candidates[i]);
2321 if (!annotation_status.ok()) {
2322 return annotation_status;
2323 }
2324 }
2325 return annotation_candidates;
2326}
2327
2328std::vector<AnnotatedSpan> Annotator::Annotate(
2329 const std::string& context, const AnnotationOptions& options) const {
2330 std::vector<InputFragment> string_fragments;
2331 string_fragments.push_back({.text = context});
2332 StatusOr<std::vector<std::vector<AnnotatedSpan>>> annotations =
2333 AnnotateStructuredInput(string_fragments, options);
2334 if (!annotations.ok()) {
2335 TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
2336 << annotations.status().error_message();
2337 return {};
2338 }
2339 return annotations.ValueOrDie()[0];
Lukas Zilka21d8c982018-01-24 11:11:20 +01002340}
2341
Tony Mak854015a2019-01-16 15:56:48 +00002342CodepointSpan Annotator::ComputeSelectionBoundaries(
2343 const UniLib::RegexMatcher* match,
2344 const RegexModel_::Pattern* config) const {
2345 if (config->capturing_group() == nullptr) {
2346 // Use first capturing group to specify the selection.
2347 int status = UniLib::RegexMatcher::kNoError;
2348 const CodepointSpan result = {match->Start(1, &status),
2349 match->End(1, &status)};
2350 if (status != UniLib::RegexMatcher::kNoError) {
2351 return {kInvalidIndex, kInvalidIndex};
2352 }
2353 return result;
2354 }
2355
2356 CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2357 const int num_groups = config->capturing_group()->size();
2358 for (int i = 0; i < num_groups; i++) {
2359 if (!config->capturing_group()->Get(i)->extend_selection()) {
2360 continue;
2361 }
2362
2363 int status = UniLib::RegexMatcher::kNoError;
2364 // Check match and adjust bounds.
2365 const int group_start = match->Start(i, &status);
2366 const int group_end = match->End(i, &status);
2367 if (status != UniLib::RegexMatcher::kNoError) {
2368 return {kInvalidIndex, kInvalidIndex};
2369 }
2370 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2371 continue;
2372 }
2373 if (result.first == kInvalidIndex) {
2374 result = {group_start, group_end};
2375 } else {
2376 result.first = std::min(result.first, group_start);
2377 result.second = std::max(result.second, group_end);
2378 }
2379 }
2380 return result;
2381}
2382
Tony Makd9446602019-02-20 18:25:39 +00002383bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
Tony Mak21460022020-03-12 18:29:35 +00002384 if (pattern->serialized_entity_data() != nullptr ||
2385 pattern->entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002386 return true;
2387 }
2388 if (pattern->capturing_group() != nullptr) {
Tony Mak63959242020-02-07 18:31:16 +00002389 for (const CapturingGroup* group : *pattern->capturing_group()) {
Tony Makd9446602019-02-20 18:25:39 +00002390 if (group->entity_field_path() != nullptr) {
2391 return true;
2392 }
Tony Mak21460022020-03-12 18:29:35 +00002393 if (group->serialized_entity_data() != nullptr ||
2394 group->entity_data() != nullptr) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002395 return true;
2396 }
Tony Makd9446602019-02-20 18:25:39 +00002397 }
2398 }
2399 return false;
2400}
2401
2402bool Annotator::SerializedEntityDataFromRegexMatch(
2403 const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2404 std::string* serialized_entity_data) const {
2405 if (!HasEntityData(pattern)) {
2406 serialized_entity_data->clear();
2407 return true;
2408 }
2409 TC3_CHECK(entity_data_builder_ != nullptr);
2410
2411 std::unique_ptr<ReflectiveFlatbuffer> entity_data =
2412 entity_data_builder_->NewRoot();
2413
2414 TC3_CHECK(entity_data != nullptr);
2415
Tony Mak21460022020-03-12 18:29:35 +00002416 // Set fixed entity data.
Tony Makd9446602019-02-20 18:25:39 +00002417 if (pattern->serialized_entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002418 entity_data->MergeFromSerializedFlatbuffer(
2419 StringPiece(pattern->serialized_entity_data()->c_str(),
2420 pattern->serialized_entity_data()->size()));
2421 }
Tony Mak21460022020-03-12 18:29:35 +00002422 if (pattern->entity_data() != nullptr) {
2423 entity_data->MergeFrom(
2424 reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2425 }
Tony Makd9446602019-02-20 18:25:39 +00002426
2427 // Add entity data from rule capturing groups.
2428 if (pattern->capturing_group() != nullptr) {
2429 const int num_groups = pattern->capturing_group()->size();
2430 for (int i = 0; i < num_groups; i++) {
Tony Mak63959242020-02-07 18:31:16 +00002431 const CapturingGroup* group = pattern->capturing_group()->Get(i);
Tony Maka2a1ff42019-09-12 15:40:32 +01002432
2433 // Check whether the group matched.
2434 Optional<std::string> group_match_text =
2435 GetCapturingGroupText(matcher, /*group_id=*/i);
2436 if (!group_match_text.has_value()) {
Tony Makd9446602019-02-20 18:25:39 +00002437 continue;
2438 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002439
Tony Mak21460022020-03-12 18:29:35 +00002440 // Set fixed entity data from capturing group match.
Tony Maka2a1ff42019-09-12 15:40:32 +01002441 if (group->serialized_entity_data() != nullptr) {
2442 entity_data->MergeFromSerializedFlatbuffer(
2443 StringPiece(group->serialized_entity_data()->c_str(),
2444 group->serialized_entity_data()->size()));
2445 }
Tony Mak21460022020-03-12 18:29:35 +00002446 if (group->entity_data() != nullptr) {
2447 entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2448 pattern->entity_data()));
2449 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002450
2451 // Set entity field from capturing group text.
2452 if (group->entity_field_path() != nullptr) {
Tony Mak8cd7ba62019-10-15 15:29:22 +01002453 UnicodeText normalized_group_match_text =
2454 UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2455
2456 // Apply normalization if specified.
2457 if (group->normalization_options() != nullptr) {
2458 normalized_group_match_text =
Tony Mak1ac2e4a2020-04-29 13:41:53 +01002459 NormalizeText(*unilib_, group->normalization_options(),
Tony Mak8cd7ba62019-10-15 15:29:22 +01002460 normalized_group_match_text);
2461 }
2462
2463 if (!entity_data->ParseAndSet(
2464 group->entity_field_path(),
2465 normalized_group_match_text.ToUTF8String())) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002466 TC3_LOG(ERROR)
2467 << "Could not set entity data from rule capturing group.";
2468 return false;
2469 }
Tony Makd9446602019-02-20 18:25:39 +00002470 }
2471 }
2472 }
2473
2474 *serialized_entity_data = entity_data->Serialize();
2475 return true;
2476}
2477
Tony Mak63959242020-02-07 18:31:16 +00002478UnicodeText RemoveMoneySeparators(
2479 const std::unordered_set<char32>& decimal_separators,
2480 const UnicodeText& amount,
2481 UnicodeText::const_iterator it_decimal_separator) {
2482 UnicodeText whole_amount;
2483 for (auto it = amount.begin();
2484 it != amount.end() && it != it_decimal_separator; ++it) {
2485 if (std::find(decimal_separators.begin(), decimal_separators.end(),
2486 static_cast<char32>(*it)) == decimal_separators.end()) {
2487 whole_amount.push_back(*it);
2488 }
2489 }
2490 return whole_amount;
2491}
2492
2493bool Annotator::ParseAndFillInMoneyAmount(
2494 std::string* serialized_entity_data) const {
2495 std::unique_ptr<EntityDataT> data =
2496 LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2497 *serialized_entity_data);
Tony Mak0b8b3322020-03-17 16:30:19 +00002498 if (data == nullptr) {
Tony Makc121edd2020-05-28 15:25:17 +01002499 if (model_->version() >= 706) {
2500 // This way of parsing money entity data is enabled for models newer than
2501 // v706, consequently logging errors only for them (b/156634162).
2502 TC3_LOG(ERROR)
2503 << "Data field is null when trying to parse Money Entity Data";
2504 }
Tony Mak0b8b3322020-03-17 16:30:19 +00002505 return false;
2506 }
2507 if (data->money->unnormalized_amount.empty()) {
Tony Makc121edd2020-05-28 15:25:17 +01002508 if (model_->version() >= 706) {
2509 // This way of parsing money entity data is enabled for models newer than
2510 // v706, consequently logging errors only for them (b/156634162).
2511 TC3_LOG(ERROR)
2512 << "Data unnormalized_amount is empty when trying to parse "
2513 "Money Entity Data";
2514 }
Tony Mak63959242020-02-07 18:31:16 +00002515 return false;
2516 }
2517
2518 UnicodeText amount =
2519 UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2520 int separator_back_index = 0;
Tony Mak21460022020-03-12 18:29:35 +00002521 auto it_decimal_separator = --amount.end();
Tony Mak63959242020-02-07 18:31:16 +00002522 for (; it_decimal_separator != amount.begin();
2523 --it_decimal_separator, ++separator_back_index) {
2524 if (std::find(money_separators_.begin(), money_separators_.end(),
2525 static_cast<char32>(*it_decimal_separator)) !=
2526 money_separators_.end()) {
2527 break;
2528 }
2529 }
2530
2531 // If there are 3 digits after the last separator, we consider that a
2532 // thousands separator => the number is an int (e.g. 1.234 is considered int).
2533 // If there is no separator in number, also that number is an int.
Tony Mak21460022020-03-12 18:29:35 +00002534 if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
Tony Mak63959242020-02-07 18:31:16 +00002535 it_decimal_separator = amount.end();
2536 }
2537
2538 if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2539 it_decimal_separator),
2540 &data->money->amount_whole_part)) {
Tony Mak0b8b3322020-03-17 16:30:19 +00002541 TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
2542 "amount: "
2543 << data->money->unnormalized_amount;
Tony Mak63959242020-02-07 18:31:16 +00002544 return false;
2545 }
2546 if (it_decimal_separator == amount.end()) {
2547 data->money->amount_decimal_part = 0;
2548 } else {
2549 const int amount_codepoints_size = amount.size_codepoints();
2550 if (!unilib_->ParseInt32(
2551 UnicodeText::Substring(
Tony Mak21460022020-03-12 18:29:35 +00002552 amount, amount_codepoints_size - separator_back_index,
Tony Mak63959242020-02-07 18:31:16 +00002553 amount_codepoints_size, /*do_copy=*/false),
2554 &data->money->amount_decimal_part)) {
Tony Mak0b8b3322020-03-17 16:30:19 +00002555 TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
2556 "the amount: "
2557 << data->money->unnormalized_amount;
Tony Mak63959242020-02-07 18:31:16 +00002558 return false;
2559 }
2560 }
2561
2562 *serialized_entity_data =
2563 PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2564 return true;
2565}
2566
Tony Mak6c4cc672018-09-17 11:48:50 +01002567bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2568 const std::vector<int>& rules,
Tony Mak83d2de62019-04-10 16:12:15 +01002569 std::vector<AnnotatedSpan>* result,
2570 bool is_serialized_entity_data_enabled) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002571 for (int pattern_id : rules) {
2572 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2573 const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2574 if (!matcher) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002575 TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2576 << pattern_id;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002577 return false;
2578 }
2579
2580 int status = UniLib::RegexMatcher::kNoError;
2581 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
Tony Mak854015a2019-01-16 15:56:48 +00002582 if (regex_pattern.config->verification_options()) {
Tony Makdf54e742019-03-26 14:04:00 +00002583 if (!VerifyRegexMatchCandidate(
2584 context_unicode.ToUTF8String(),
2585 regex_pattern.config->verification_options(),
2586 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002587 continue;
2588 }
2589 }
Tony Makd9446602019-02-20 18:25:39 +00002590
2591 std::string serialized_entity_data;
Tony Mak83d2de62019-04-10 16:12:15 +01002592 if (is_serialized_entity_data_enabled) {
2593 if (!SerializedEntityDataFromRegexMatch(
2594 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2595 TC3_LOG(ERROR) << "Could not get entity data.";
2596 return false;
2597 }
Tony Mak63959242020-02-07 18:31:16 +00002598
2599 // Further parsing unnormalized_amount for money into amount_whole_part
2600 // and amount_decimal_part. Can't do this with regexes because we cannot
2601 // have empty groups (amount_decimal_part might be an empty group).
2602 if (regex_pattern.config->collection_name()->str() ==
2603 Collections::Money()) {
2604 if (!ParseAndFillInMoneyAmount(&serialized_entity_data)) {
Tony Makc121edd2020-05-28 15:25:17 +01002605 if (model_->version() >= 706) {
2606 // This way of parsing money entity data is enabled for models
2607 // newer than v706 => logging errors only for them (b/156634162).
2608 TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
2609 }
Tony Mak63959242020-02-07 18:31:16 +00002610 }
2611 }
Tony Makd9446602019-02-20 18:25:39 +00002612 }
2613
Lukas Zilkab23e2122018-02-09 10:25:19 +01002614 result->emplace_back();
Tony Mak854015a2019-01-16 15:56:48 +00002615
Lukas Zilkab23e2122018-02-09 10:25:19 +01002616 // Selection/annotation regular expressions need to specify a capturing
2617 // group specifying the selection.
Tony Mak854015a2019-01-16 15:56:48 +00002618 result->back().span =
2619 ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2620
Lukas Zilkab23e2122018-02-09 10:25:19 +01002621 result->back().classification = {
Tony Mak854015a2019-01-16 15:56:48 +00002622 {regex_pattern.config->collection_name()->str(),
2623 regex_pattern.config->target_classification_score(),
2624 regex_pattern.config->priority_score()}};
Tony Makd9446602019-02-20 18:25:39 +00002625
2626 result->back().classification[0].serialized_entity_data =
2627 serialized_entity_data;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002628 }
2629 }
2630 return true;
2631}
2632
Tony Mak6c4cc672018-09-17 11:48:50 +01002633bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2634 tflite::Interpreter* selection_interpreter,
2635 const CachedFeatures& cached_features,
2636 std::vector<TokenSpan>* chunks) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01002637 const int max_selection_span =
2638 selection_feature_processor_->GetOptions()->max_selection_span();
Lukas Zilka21d8c982018-01-24 11:11:20 +01002639 // The inference span is the span of interest expanded to include
2640 // max_selection_span tokens on either side, which is how far a selection can
2641 // stretch from the click.
2642 const TokenSpan inference_span = IntersectTokenSpans(
2643 ExpandTokenSpan(span_of_interest,
2644 /*num_tokens_left=*/max_selection_span,
2645 /*num_tokens_right=*/max_selection_span),
2646 {0, num_tokens});
2647
2648 std::vector<ScoredChunk> scored_chunks;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002649 if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2650 selection_feature_processor_->GetOptions()
2651 ->bounds_sensitive_features()
2652 ->enabled()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002653 if (!ModelBoundsSensitiveScoreChunks(
2654 num_tokens, span_of_interest, inference_span, cached_features,
2655 selection_interpreter, &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002656 return false;
2657 }
2658 } else {
2659 if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002660 cached_features, selection_interpreter,
2661 &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002662 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002663 }
2664 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002665 std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
2666 [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2667 return lhs.score < rhs.score;
2668 });
Lukas Zilka21d8c982018-01-24 11:11:20 +01002669
2670 // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2671 // them greedily as long as they do not overlap with any previously picked
2672 // chunks.
2673 std::vector<bool> token_used(TokenSpanSize(inference_span));
2674 chunks->clear();
2675 for (const ScoredChunk& scored_chunk : scored_chunks) {
2676 bool feasible = true;
2677 for (int i = scored_chunk.token_span.first;
2678 i < scored_chunk.token_span.second; ++i) {
2679 if (token_used[i - inference_span.first]) {
2680 feasible = false;
2681 break;
2682 }
2683 }
2684
2685 if (!feasible) {
2686 continue;
2687 }
2688
2689 for (int i = scored_chunk.token_span.first;
2690 i < scored_chunk.token_span.second; ++i) {
2691 token_used[i - inference_span.first] = true;
2692 }
2693
2694 chunks->push_back(scored_chunk.token_span);
2695 }
2696
2697 std::sort(chunks->begin(), chunks->end());
2698
2699 return true;
2700}
2701
namespace {
// Updates the value at the given key in the map to the maximum of the current
// value and the given value, or simply inserts the value if the key is not yet
// there.
//
// A single insertion attempt is used, so the map is searched only once per
// call and the mapped type does not have to be default-constructible.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  // `insert` leaves an existing entry untouched and reports, through
  // `.second`, whether a new entry was created.
  const auto insertion = map->insert(typename Map::value_type(key, value));
  if (!insertion.second) {
    auto& current = insertion.first->second;
    current = std::max(current, value);
  }
}
}  // namespace
2716
Tony Mak6c4cc672018-09-17 11:48:50 +01002717bool Annotator::ModelClickContextScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002718 int num_tokens, const TokenSpan& span_of_interest,
2719 const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002720 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002721 std::vector<ScoredChunk>* scored_chunks) const {
2722 const int max_batch_size = model_->selection_options()->batch_size();
2723
2724 std::vector<float> all_features;
2725 std::map<TokenSpan, float> chunk_scores;
2726 for (int batch_start = span_of_interest.first;
2727 batch_start < span_of_interest.second; batch_start += max_batch_size) {
2728 const int batch_end =
2729 std::min(batch_start + max_batch_size, span_of_interest.second);
2730
2731 // Prepare features for the whole batch.
2732 all_features.clear();
2733 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2734 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2735 cached_features.AppendClickContextFeaturesForClick(click_pos,
2736 &all_features);
2737 }
2738
2739 // Run batched inference.
2740 const int batch_size = batch_end - batch_start;
2741 const int features_size = cached_features.OutputFeaturesSize();
2742 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01002743 TensorView<float>(all_features.data(), {batch_size, features_size}),
2744 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002745 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002746 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002747 return false;
2748 }
2749 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2750 logits.dim(1) !=
2751 selection_feature_processor_->GetSelectionLabelCount()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002752 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002753 return false;
2754 }
2755
2756 // Save results.
2757 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2758 const std::vector<float> scores = ComputeSoftmax(
2759 logits.data() + logits.dim(1) * (click_pos - batch_start),
2760 logits.dim(1));
2761 for (int j = 0;
2762 j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
2763 TokenSpan relative_token_span;
2764 if (!selection_feature_processor_->LabelToTokenSpan(
2765 j, &relative_token_span)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002766 TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002767 return false;
2768 }
2769 const TokenSpan candidate_span = ExpandTokenSpan(
2770 SingleTokenSpan(click_pos), relative_token_span.first,
2771 relative_token_span.second);
2772 if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
2773 UpdateMax(&chunk_scores, candidate_span, scores[j]);
2774 }
2775 }
2776 }
2777 }
2778
2779 scored_chunks->clear();
2780 scored_chunks->reserve(chunk_scores.size());
2781 for (const auto& entry : chunk_scores) {
2782 scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
2783 }
2784
2785 return true;
2786}
2787
Tony Mak6c4cc672018-09-17 11:48:50 +01002788bool Annotator::ModelBoundsSensitiveScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002789 int num_tokens, const TokenSpan& span_of_interest,
2790 const TokenSpan& inference_span, const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002791 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002792 std::vector<ScoredChunk>* scored_chunks) const {
2793 const int max_selection_span =
2794 selection_feature_processor_->GetOptions()->max_selection_span();
2795 const int max_chunk_length = selection_feature_processor_->GetOptions()
2796 ->selection_reduced_output_space()
2797 ? max_selection_span + 1
2798 : 2 * max_selection_span + 1;
Lukas Zilkaba849e72018-03-08 14:48:21 +01002799 const bool score_single_token_spans_as_zero =
2800 selection_feature_processor_->GetOptions()
2801 ->bounds_sensitive_features()
2802 ->score_single_token_spans_as_zero();
2803
2804 scored_chunks->clear();
2805 if (score_single_token_spans_as_zero) {
2806 scored_chunks->reserve(TokenSpanSize(span_of_interest));
2807 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002808
2809 // Prepare all chunk candidates into one batch:
2810 // - Are contained in the inference span
2811 // - Have a non-empty intersection with the span of interest
2812 // - Are at least one token long
2813 // - Are not longer than the maximum chunk length
2814 std::vector<TokenSpan> candidate_spans;
2815 for (int start = inference_span.first; start < span_of_interest.second;
2816 ++start) {
2817 const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
2818 for (int end = leftmost_end_index;
2819 end <= inference_span.second && end - start <= max_chunk_length;
2820 ++end) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002821 const TokenSpan candidate_span = {start, end};
2822 if (score_single_token_spans_as_zero &&
2823 TokenSpanSize(candidate_span) == 1) {
2824 // Do not include the single token span in the batch, add a zero score
2825 // for it directly to the output.
2826 scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
2827 } else {
2828 candidate_spans.push_back(candidate_span);
2829 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002830 }
2831 }
2832
2833 const int max_batch_size = model_->selection_options()->batch_size();
2834
2835 std::vector<float> all_features;
Lukas Zilkaba849e72018-03-08 14:48:21 +01002836 scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
Lukas Zilkab23e2122018-02-09 10:25:19 +01002837 for (int batch_start = 0; batch_start < candidate_spans.size();
2838 batch_start += max_batch_size) {
2839 const int batch_end = std::min(batch_start + max_batch_size,
2840 static_cast<int>(candidate_spans.size()));
2841
2842 // Prepare features for the whole batch.
2843 all_features.clear();
2844 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2845 for (int i = batch_start; i < batch_end; ++i) {
2846 cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
2847 &all_features);
2848 }
2849
2850 // Run batched inference.
2851 const int batch_size = batch_end - batch_start;
2852 const int features_size = cached_features.OutputFeaturesSize();
2853 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01002854 TensorView<float>(all_features.data(), {batch_size, features_size}),
2855 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002856 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002857 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002858 return false;
2859 }
2860 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2861 logits.dim(1) != 1) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002862 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002863 return false;
2864 }
2865
2866 // Save results.
2867 for (int i = batch_start; i < batch_end; ++i) {
2868 scored_chunks->push_back(
2869 ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
2870 }
2871 }
2872
2873 return true;
2874}
2875
Tony Mak6c4cc672018-09-17 11:48:50 +01002876bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
2877 int64 reference_time_ms_utc,
2878 const std::string& reference_timezone,
2879 const std::string& locales, ModeFlag mode,
Tony Makd9446602019-02-20 18:25:39 +00002880 AnnotationUsecase annotation_usecase,
Tony Mak83d2de62019-04-10 16:12:15 +01002881 bool is_serialized_entity_data_enabled,
Tony Mak6c4cc672018-09-17 11:48:50 +01002882 std::vector<AnnotatedSpan>* result) const {
Tony Mak63959242020-02-07 18:31:16 +00002883 std::vector<DatetimeParseResultSpan> datetime_spans;
2884 if (cfg_datetime_parser_) {
2885 if (!(model_->grammar_datetime_model()->enabled_modes() & mode)) {
2886 return true;
2887 }
2888 std::vector<Locale> parsed_locales;
2889 ParseLocales(locales, &parsed_locales);
Tony Mak21460022020-03-12 18:29:35 +00002890 cfg_datetime_parser_->Parse(
2891 context_unicode.ToUTF8String(),
2892 ToDateAnnotationOptions(
2893 model_->grammar_datetime_model()->annotation_options(),
2894 reference_timezone, reference_time_ms_utc),
2895 parsed_locales, &datetime_spans);
Tony Makd99d58c2020-03-19 21:52:02 +00002896 }
2897
2898 if (datetime_parser_) {
Tony Mak63959242020-02-07 18:31:16 +00002899 if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
2900 reference_timezone, locales, mode,
2901 annotation_usecase,
2902 /*anchor_start_end=*/false, &datetime_spans)) {
2903 return false;
2904 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002905 }
2906
Lukas Zilkab23e2122018-02-09 10:25:19 +01002907 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
Tony Mak378c1f52019-03-04 15:58:11 +00002908 AnnotatedSpan annotated_span;
2909 annotated_span.span = datetime_span.span;
Tony Mak854015a2019-01-16 15:56:48 +00002910 for (const DatetimeParseResult& parse_result : datetime_span.data) {
Tony Mak378c1f52019-03-04 15:58:11 +00002911 annotated_span.classification.emplace_back(
2912 PickCollectionForDatetime(parse_result),
2913 datetime_span.target_classification_score,
2914 datetime_span.priority_score);
2915 annotated_span.classification.back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01002916 if (is_serialized_entity_data_enabled) {
2917 annotated_span.classification.back().serialized_entity_data =
2918 CreateDatetimeSerializedEntityData(parse_result);
2919 }
Tony Mak854015a2019-01-16 15:56:48 +00002920 }
Tony Mak448b5862019-03-22 13:36:41 +00002921 annotated_span.source = AnnotatedSpan::Source::DATETIME;
Tony Mak378c1f52019-03-04 15:58:11 +00002922 result->push_back(std::move(annotated_span));
Lukas Zilkab23e2122018-02-09 10:25:19 +01002923 }
2924 return true;
2925}
2926
Tony Mak378c1f52019-03-04 15:58:11 +00002927const Model* Annotator::model() const { return model_; }
Tony Makd9446602019-02-20 18:25:39 +00002928const reflection::Schema* Annotator::entity_data_schema() const {
2929 return entity_data_schema_;
2930}
Tony Mak854015a2019-01-16 15:56:48 +00002931
Lukas Zilka21d8c982018-01-24 11:11:20 +01002932const Model* ViewModel(const void* buffer, int size) {
2933 if (!buffer) {
2934 return nullptr;
2935 }
2936
2937 return LoadAndVerifyModel(buffer, size);
2938}
2939
Tony Makd9446602019-02-20 18:25:39 +00002940bool Annotator::LookUpKnowledgeEntity(
2941 const std::string& id, std::string* serialized_knowledge_result) const {
2942 return knowledge_engine_ &&
2943 knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
2944}
2945
Tony Mak6c4cc672018-09-17 11:48:50 +01002946} // namespace libtextclassifier3