blob: 1c3c102618f8ce1a063aea4478b0535649eebe83 [file] [log] [blame]
Lukas Zilka21d8c982018-01-24 11:11:20 +01001/*
Tony Mak6c4cc672018-09-17 11:48:50 +01002 * Copyright (C) 2018 The Android Open Source Project
Lukas Zilka21d8c982018-01-24 11:11:20 +01003 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Tony Mak6c4cc672018-09-17 11:48:50 +010017#include "annotator/annotator.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010018
19#include <algorithm>
Lukas Zilka21d8c982018-01-24 11:11:20 +010020#include <cmath>
Tony Mak21460022020-03-12 18:29:35 +000021#include <cstddef>
Lukas Zilka21d8c982018-01-24 11:11:20 +010022#include <iterator>
23#include <numeric>
Tony Mak63959242020-02-07 18:31:16 +000024#include <string>
Tony Mak448b5862019-03-22 13:36:41 +000025#include <unordered_map>
Tony Mak63959242020-02-07 18:31:16 +000026#include <vector>
Lukas Zilka21d8c982018-01-24 11:11:20 +010027
Tony Mak854015a2019-01-16 15:56:48 +000028#include "annotator/collections.h"
Tony Maka44b3082020-08-13 18:57:10 +010029#include "annotator/flatbuffer-utils.h"
30#include "annotator/knowledge/knowledge-engine-types.h"
Tony Mak83d2de62019-04-10 16:12:15 +010031#include "annotator/model_generated.h"
32#include "annotator/types.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010033#include "utils/base/logging.h"
Tony Makff31efb2020-03-31 11:13:06 +010034#include "utils/base/status.h"
35#include "utils/base/statusor.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010036#include "utils/checksum.h"
Tony Mak63959242020-02-07 18:31:16 +000037#include "utils/i18n/locale.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010038#include "utils/math/softmax.h"
Tony Mak8cd7ba62019-10-15 15:29:22 +010039#include "utils/normalization.h"
Tony Maka2a1ff42019-09-12 15:40:32 +010040#include "utils/optional.h"
Tony Makd9446602019-02-20 18:25:39 +000041#include "utils/regex-match.h"
Tony Maka44b3082020-08-13 18:57:10 +010042#include "utils/strings/append.h"
Tony Mak63959242020-02-07 18:31:16 +000043#include "utils/strings/numbers.h"
44#include "utils/strings/split.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010045#include "utils/utf8/unicodetext.h"
Tony Mak21460022020-03-12 18:29:35 +000046#include "utils/utf8/unilib-common.h"
Tony Mak378c1f52019-03-04 15:58:11 +000047#include "utils/zlib/zlib_regex.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010048
Tony Mak6c4cc672018-09-17 11:48:50 +010049namespace libtextclassifier3 {
Tony Mak448b5862019-03-22 13:36:41 +000050
51using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
52
Tony Mak6c4cc672018-09-17 11:48:50 +010053const std::string& Annotator::kPhoneCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010054 *[]() { return new std::string("phone"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010055const std::string& Annotator::kAddressCollection =
Lukas Zilkae7962cc2018-03-28 18:09:48 +020056 *[]() { return new std::string("address"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010057const std::string& Annotator::kDateCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010058 *[]() { return new std::string("date"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000059const std::string& Annotator::kUrlCollection =
60 *[]() { return new std::string("url"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000061const std::string& Annotator::kEmailCollection =
62 *[]() { return new std::string("email"); }();
Lukas Zilkab23e2122018-02-09 10:25:19 +010063
Lukas Zilka21d8c982018-01-24 11:11:20 +010064namespace {
65const Model* LoadAndVerifyModel(const void* addr, int size) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010066 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
Tony Mak51a9e542018-11-02 13:36:22 +000067 if (VerifyModelBuffer(verifier)) {
68 return GetModel(addr);
Lukas Zilka21d8c982018-01-24 11:11:20 +010069 } else {
70 return nullptr;
71 }
72}
Tony Mak6c4cc672018-09-17 11:48:50 +010073
Tony Mak76d80962020-01-08 17:30:51 +000074const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
75 int size) {
76 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
77 if (VerifyPersonNameModelBuffer(verifier)) {
78 return GetPersonNameModel(addr);
79 } else {
80 return nullptr;
81 }
82}
83
Tony Mak6c4cc672018-09-17 11:48:50 +010084// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
85// create a new instance, assign ownership to owned_lib, and return it.
86const UniLib* MaybeCreateUnilib(const UniLib* lib,
87 std::unique_ptr<UniLib>* owned_lib) {
88 if (lib) {
89 return lib;
90 } else {
91 owned_lib->reset(new UniLib);
92 return owned_lib->get();
93 }
94}
95
96// As above, but for CalendarLib.
97const CalendarLib* MaybeCreateCalendarlib(
98 const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
99 if (lib) {
100 return lib;
101 } else {
102 owned_lib->reset(new CalendarLib);
103 return owned_lib->get();
104 }
105}
106
Tony Mak968412a2019-11-13 15:39:57 +0000107// Returns whether the provided input is valid:
108// * Valid utf8 text.
109// * Sane span indices.
Tony Maka44b3082020-08-13 18:57:10 +0100110bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
Tony Mak968412a2019-11-13 15:39:57 +0000111 if (!context.is_valid()) {
112 return false;
113 }
114 return (span.first >= 0 && span.first < span.second &&
115 span.second <= context.size_codepoints());
116}
117
Tony Mak63959242020-02-07 18:31:16 +0000118std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
119 const flatbuffers::Vector<int32_t>* ints) {
120 if (ints == nullptr) {
121 return {};
122 }
123 std::unordered_set<char32> ints_set;
124 for (auto value : *ints) {
125 ints_set.insert(static_cast<char32>(value));
126 }
127 return ints_set;
128}
129
Tony Mak21460022020-03-12 18:29:35 +0000130DateAnnotationOptions ToDateAnnotationOptions(
131 const GrammarDatetimeModel_::AnnotationOptions* fb_annotation_options,
132 const std::string& reference_timezone, const int64 reference_time_ms_utc) {
133 DateAnnotationOptions result_annotation_options;
134 result_annotation_options.base_timestamp_millis = reference_time_ms_utc;
135 result_annotation_options.reference_timezone = reference_timezone;
136 if (fb_annotation_options != nullptr) {
137 result_annotation_options.enable_special_day_offset =
138 fb_annotation_options->enable_special_day_offset();
139 result_annotation_options.merge_adjacent_components =
140 fb_annotation_options->merge_adjacent_components();
141 result_annotation_options.enable_date_range =
142 fb_annotation_options->enable_date_range();
143 result_annotation_options.include_preposition =
144 fb_annotation_options->include_preposition();
Tony Mak21460022020-03-12 18:29:35 +0000145 if (fb_annotation_options->extra_requested_dates() != nullptr) {
146 for (const auto& extra_requested_date :
147 *fb_annotation_options->extra_requested_dates()) {
148 result_annotation_options.extra_requested_dates.push_back(
149 extra_requested_date->str());
150 }
151 }
Tony Makd99d58c2020-03-19 21:52:02 +0000152 if (fb_annotation_options->ignored_spans() != nullptr) {
153 for (const auto& ignored_span : *fb_annotation_options->ignored_spans()) {
154 result_annotation_options.ignored_spans.push_back(ignored_span->str());
Tony Mak0b8b3322020-03-17 16:30:19 +0000155 }
156 }
Tony Mak21460022020-03-12 18:29:35 +0000157 }
158 return result_annotation_options;
159}
160
Lukas Zilka21d8c982018-01-24 11:11:20 +0100161} // namespace
162
Lukas Zilkaba849e72018-03-08 14:48:21 +0100163tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
164 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100165 TC3_CHECK(selection_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100166 selection_interpreter_ = selection_executor_->CreateInterpreter();
167 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100168 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100169 }
170 }
171 return selection_interpreter_.get();
172}
173
174tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
175 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100176 TC3_CHECK(classification_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100177 classification_interpreter_ = classification_executor_->CreateInterpreter();
178 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100179 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100180 }
181 }
182 return classification_interpreter_.get();
183}
184
Tony Mak6c4cc672018-09-17 11:48:50 +0100185std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
186 const char* buffer, int size, const UniLib* unilib,
187 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100188 const Model* model = LoadAndVerifyModel(buffer, size);
189 if (model == nullptr) {
190 return nullptr;
191 }
192
Lukas Zilkab23e2122018-02-09 10:25:19 +0100193 auto classifier =
Tony Mak6c4cc672018-09-17 11:48:50 +0100194 std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100195 if (!classifier->IsInitialized()) {
196 return nullptr;
197 }
198
199 return classifier;
200}
201
Tony Mak6c4cc672018-09-17 11:48:50 +0100202std::unique_ptr<Annotator> Annotator::FromScopedMmap(
203 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
204 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100205 if (!(*mmap)->handle().ok()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100206 TC3_VLOG(1) << "Mmap failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100207 return nullptr;
208 }
209
210 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
211 (*mmap)->handle().num_bytes());
212 if (!model) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100213 TC3_LOG(ERROR) << "Model verification failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100214 return nullptr;
215 }
216
Tony Mak6c4cc672018-09-17 11:48:50 +0100217 auto classifier = std::unique_ptr<Annotator>(
218 new Annotator(mmap, model, unilib, calendarlib));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100219 if (!classifier->IsInitialized()) {
220 return nullptr;
221 }
222
223 return classifier;
224}
225
Tony Makdf54e742019-03-26 14:04:00 +0000226std::unique_ptr<Annotator> Annotator::FromScopedMmap(
227 std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
228 std::unique_ptr<CalendarLib> calendarlib) {
229 if (!(*mmap)->handle().ok()) {
230 TC3_VLOG(1) << "Mmap failed.";
231 return nullptr;
232 }
233
234 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
235 (*mmap)->handle().num_bytes());
236 if (model == nullptr) {
237 TC3_LOG(ERROR) << "Model verification failed.";
238 return nullptr;
239 }
240
241 auto classifier = std::unique_ptr<Annotator>(
242 new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
243 if (!classifier->IsInitialized()) {
244 return nullptr;
245 }
246
247 return classifier;
248}
249
Tony Mak6c4cc672018-09-17 11:48:50 +0100250std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
251 int fd, int offset, int size, const UniLib* unilib,
252 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100253 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
Tony Mak6c4cc672018-09-17 11:48:50 +0100254 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100255}
256
Tony Mak6c4cc672018-09-17 11:48:50 +0100257std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Makdf54e742019-03-26 14:04:00 +0000258 int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
259 std::unique_ptr<CalendarLib> calendarlib) {
260 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
261 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
262}
263
264std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Mak6c4cc672018-09-17 11:48:50 +0100265 int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100266 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
Tony Mak6c4cc672018-09-17 11:48:50 +0100267 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100268}
269
Tony Makdf54e742019-03-26 14:04:00 +0000270std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
271 int fd, std::unique_ptr<UniLib> unilib,
272 std::unique_ptr<CalendarLib> calendarlib) {
273 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
274 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
275}
276
Tony Mak6c4cc672018-09-17 11:48:50 +0100277std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
278 const UniLib* unilib,
279 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100280 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
Tony Mak6c4cc672018-09-17 11:48:50 +0100281 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100282}
283
Tony Makdf54e742019-03-26 14:04:00 +0000284std::unique_ptr<Annotator> Annotator::FromPath(
285 const std::string& path, std::unique_ptr<UniLib> unilib,
286 std::unique_ptr<CalendarLib> calendarlib) {
287 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
288 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
289}
290
Tony Mak6c4cc672018-09-17 11:48:50 +0100291Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
292 const UniLib* unilib, const CalendarLib* calendarlib)
293 : model_(model),
294 mmap_(std::move(*mmap)),
295 owned_unilib_(nullptr),
296 unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
297 owned_calendarlib_(nullptr),
298 calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
299 ValidateAndInitialize();
300}
301
Tony Makdf54e742019-03-26 14:04:00 +0000302Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
303 std::unique_ptr<UniLib> unilib,
304 std::unique_ptr<CalendarLib> calendarlib)
305 : model_(model),
306 mmap_(std::move(*mmap)),
307 owned_unilib_(std::move(unilib)),
308 unilib_(owned_unilib_.get()),
309 owned_calendarlib_(std::move(calendarlib)),
310 calendarlib_(owned_calendarlib_.get()) {
311 ValidateAndInitialize();
312}
313
Tony Mak6c4cc672018-09-17 11:48:50 +0100314Annotator::Annotator(const Model* model, const UniLib* unilib,
315 const CalendarLib* calendarlib)
316 : model_(model),
317 owned_unilib_(nullptr),
318 unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
319 owned_calendarlib_(nullptr),
320 calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
321 ValidateAndInitialize();
322}
323
324void Annotator::ValidateAndInitialize() {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100325 initialized_ = false;
326
Lukas Zilka21d8c982018-01-24 11:11:20 +0100327 if (model_ == nullptr) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100328 TC3_LOG(ERROR) << "No model specified.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100329 return;
330 }
331
Lukas Zilkaba849e72018-03-08 14:48:21 +0100332 const bool model_enabled_for_annotation =
333 (model_->triggering_options() != nullptr &&
334 (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
335 const bool model_enabled_for_classification =
336 (model_->triggering_options() != nullptr &&
337 (model_->triggering_options()->enabled_modes() &
338 ModeFlag_CLASSIFICATION));
339 const bool model_enabled_for_selection =
340 (model_->triggering_options() != nullptr &&
341 (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
342
343 // Annotation requires the selection model.
344 if (model_enabled_for_annotation || model_enabled_for_selection) {
345 if (!model_->selection_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100346 TC3_LOG(ERROR) << "No selection options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100347 return;
348 }
349 if (!model_->selection_feature_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100350 TC3_LOG(ERROR) << "No selection feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100351 return;
352 }
353 if (!model_->selection_feature_options()->bounds_sensitive_features()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100354 TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100355 return;
356 }
357 if (!model_->selection_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100358 TC3_LOG(ERROR) << "No selection model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100359 return;
360 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100361 selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100362 if (!selection_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100363 TC3_LOG(ERROR) << "Could not initialize selection executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100364 return;
365 }
366 selection_feature_processor_.reset(
367 new FeatureProcessor(model_->selection_feature_options(), unilib_));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100368 }
369
Lukas Zilkaba849e72018-03-08 14:48:21 +0100370 // Annotation requires the classification model for conflict resolution and
371 // scoring.
372 // Selection requires the classification model for conflict resolution.
373 if (model_enabled_for_annotation || model_enabled_for_classification ||
374 model_enabled_for_selection) {
375 if (!model_->classification_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100376 TC3_LOG(ERROR) << "No classification options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100377 return;
378 }
379
380 if (!model_->classification_feature_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100381 TC3_LOG(ERROR) << "No classification feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100382 return;
383 }
384
385 if (!model_->classification_feature_options()
386 ->bounds_sensitive_features()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100387 TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100388 return;
389 }
390 if (!model_->classification_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100391 TC3_LOG(ERROR) << "No clf model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100392 return;
393 }
394
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200395 classification_executor_ =
Tony Mak6c4cc672018-09-17 11:48:50 +0100396 ModelExecutor::FromBuffer(model_->classification_model());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100397 if (!classification_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100398 TC3_LOG(ERROR) << "Could not initialize classification executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100399 return;
400 }
401
402 classification_feature_processor_.reset(new FeatureProcessor(
403 model_->classification_feature_options(), unilib_));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100404 }
405
Lukas Zilkaba849e72018-03-08 14:48:21 +0100406 // The embeddings need to be specified if the model is to be used for
407 // classification or selection.
408 if (model_enabled_for_annotation || model_enabled_for_classification ||
409 model_enabled_for_selection) {
410 if (!model_->embedding_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100411 TC3_LOG(ERROR) << "No embedding model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100412 return;
413 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100414
Lukas Zilkaba849e72018-03-08 14:48:21 +0100415 // Check that the embedding size of the selection and classification model
416 // matches, as they are using the same embeddings.
417 if (model_enabled_for_selection &&
418 (model_->selection_feature_options()->embedding_size() !=
419 model_->classification_feature_options()->embedding_size() ||
420 model_->selection_feature_options()->embedding_quantization_bits() !=
421 model_->classification_feature_options()
422 ->embedding_quantization_bits())) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100423 TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100424 return;
425 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100426
Tony Mak6c4cc672018-09-17 11:48:50 +0100427 embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200428 model_->embedding_model(),
Lukas Zilkaba849e72018-03-08 14:48:21 +0100429 model_->classification_feature_options()->embedding_size(),
Tony Makdf54e742019-03-26 14:04:00 +0000430 model_->classification_feature_options()->embedding_quantization_bits(),
431 model_->embedding_pruning_mask());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200432 if (!embedding_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100433 TC3_LOG(ERROR) << "Could not initialize embedding executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100434 return;
435 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100436 }
437
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200438 std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
Lukas Zilkab23e2122018-02-09 10:25:19 +0100439 if (model_->regex_model()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200440 if (!InitializeRegexModel(decompressor.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100441 TC3_LOG(ERROR) << "Could not initialize regex model.";
Lukas Zilka434442d2018-04-25 11:38:51 +0200442 return;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100443 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100444 }
Tony Mak63959242020-02-07 18:31:16 +0000445 if (model_->grammar_datetime_model() &&
446 model_->grammar_datetime_model()->datetime_rules()) {
447 cfg_datetime_parser_.reset(new dates::CfgDatetimeAnnotator(
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100448 unilib_,
Tony Mak63959242020-02-07 18:31:16 +0000449 /*tokenizer_options=*/
450 model_->grammar_datetime_model()->grammar_tokenizer_options(),
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100451 calendarlib_,
Tony Mak21460022020-03-12 18:29:35 +0000452 /*datetime_rules=*/model_->grammar_datetime_model()->datetime_rules(),
453 model_->grammar_datetime_model()->target_classification_score(),
454 model_->grammar_datetime_model()->priority_score()));
Tony Mak63959242020-02-07 18:31:16 +0000455 if (!cfg_datetime_parser_) {
456 TC3_LOG(ERROR) << "Could not initialize context free grammar based "
457 "datetime parser.";
458 return;
459 }
Tony Makd99d58c2020-03-19 21:52:02 +0000460 }
461
462 if (model_->datetime_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100463 datetime_parser_ = DatetimeParser::Instance(
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100464 model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100465 if (!datetime_parser_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100466 TC3_LOG(ERROR) << "Could not initialize datetime parser.";
Lukas Zilkab23e2122018-02-09 10:25:19 +0100467 return;
468 }
469 }
470
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200471 if (model_->output_options()) {
472 if (model_->output_options()->filtered_collections_annotation()) {
473 for (const auto collection :
474 *model_->output_options()->filtered_collections_annotation()) {
475 filtered_collections_annotation_.insert(collection->str());
476 }
477 }
478 if (model_->output_options()->filtered_collections_classification()) {
479 for (const auto collection :
480 *model_->output_options()->filtered_collections_classification()) {
481 filtered_collections_classification_.insert(collection->str());
482 }
483 }
484 if (model_->output_options()->filtered_collections_selection()) {
485 for (const auto collection :
486 *model_->output_options()->filtered_collections_selection()) {
487 filtered_collections_selection_.insert(collection->str());
488 }
489 }
490 }
491
Tony Mak378c1f52019-03-04 15:58:11 +0000492 if (model_->number_annotator_options() &&
493 model_->number_annotator_options()->enabled()) {
494 number_annotator_.reset(
Tony Mak63959242020-02-07 18:31:16 +0000495 new NumberAnnotator(model_->number_annotator_options(), unilib_));
496 }
497
498 if (model_->money_parsing_options()) {
499 money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
500 model_->money_parsing_options()->separators());
Tony Mak378c1f52019-03-04 15:58:11 +0000501 }
502
Tony Makad2e22d2019-03-20 17:35:13 +0000503 if (model_->duration_annotator_options() &&
504 model_->duration_annotator_options()->enabled()) {
505 duration_annotator_.reset(
506 new DurationAnnotator(model_->duration_annotator_options(),
Tony Mak8cd7ba62019-10-15 15:29:22 +0100507 selection_feature_processor_.get(), unilib_));
Tony Makad2e22d2019-03-20 17:35:13 +0000508 }
509
Tony Maka44b3082020-08-13 18:57:10 +0100510 if (model_->grammar_model()) {
511 grammar_annotator_.reset(new GrammarAnnotator(
512 unilib_, model_->grammar_model(), entity_data_builder_.get()));
513 }
514
Tony Maka5090082020-09-18 16:41:23 +0100515 // The following #ifdef is here to aid quality evaluation of a situation, when
516 // a POD NER kill switch in AiAi is invoked, when a model that has POD NER in
517 // it.
518#if !defined(TC3_DISABLE_POD_NER)
Tony Maka44b3082020-08-13 18:57:10 +0100519 if (model_->pod_ner_model()) {
520 pod_ner_annotator_ =
521 PodNerAnnotator::Create(model_->pod_ner_model(), *unilib_);
522 }
Tony Maka5090082020-09-18 16:41:23 +0100523#endif
524
525 if (model_->vocab_model()) {
526 vocab_annotator_ = VocabAnnotator::Create(
527 model_->vocab_model(), *selection_feature_processor_, *unilib_);
528 }
Tony Maka44b3082020-08-13 18:57:10 +0100529
Tony Makd9446602019-02-20 18:25:39 +0000530 if (model_->entity_data_schema()) {
531 entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
532 model_->entity_data_schema()->Data(),
533 model_->entity_data_schema()->size());
534 if (entity_data_schema_ == nullptr) {
535 TC3_LOG(ERROR) << "Could not load entity data schema data.";
536 return;
537 }
538
539 entity_data_builder_.reset(
Tony Maka44b3082020-08-13 18:57:10 +0100540 new MutableFlatbufferBuilder(entity_data_schema_));
Tony Makd9446602019-02-20 18:25:39 +0000541 } else {
Tony Mak378c1f52019-03-04 15:58:11 +0000542 entity_data_schema_ = nullptr;
Tony Makd9446602019-02-20 18:25:39 +0000543 entity_data_builder_ = nullptr;
544 }
545
Tony Makdf54e742019-03-26 14:04:00 +0000546 if (model_->triggering_locales() &&
547 !ParseLocales(model_->triggering_locales()->c_str(),
548 &model_triggering_locales_)) {
Tony Mak378c1f52019-03-04 15:58:11 +0000549 TC3_LOG(ERROR) << "Could not parse model supported locales.";
550 return;
551 }
552
553 if (model_->triggering_options() != nullptr &&
Tony Makdf54e742019-03-26 14:04:00 +0000554 model_->triggering_options()->locales() != nullptr &&
555 !ParseLocales(model_->triggering_options()->locales()->c_str(),
556 &ml_model_triggering_locales_)) {
557 TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
558 return;
559 }
560
561 if (model_->triggering_options() != nullptr &&
Tony Mak378c1f52019-03-04 15:58:11 +0000562 model_->triggering_options()->dictionary_locales() != nullptr &&
563 !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
564 &dictionary_locales_)) {
565 TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
566 return;
567 }
568
Tony Mak5a12b942020-05-01 12:41:31 +0100569 if (model_->conflict_resolution_options() != nullptr) {
570 prioritize_longest_annotation_ =
571 model_->conflict_resolution_options()->prioritize_longest_annotation();
572 do_conflict_resolution_in_raw_mode_ =
573 model_->conflict_resolution_options()
574 ->do_conflict_resolution_in_raw_mode();
575 }
576
Chang Licac0b442020-05-21 15:09:37 +0100577#ifdef TC3_EXPERIMENTAL
578 TC3_LOG(WARNING) << "Enabling experimental annotators.";
579 InitializeExperimentalAnnotators();
580#endif
581
Lukas Zilka21d8c982018-01-24 11:11:20 +0100582 initialized_ = true;
583}
584
Tony Mak6c4cc672018-09-17 11:48:50 +0100585bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100586 if (!model_->regex_model()->patterns()) {
Lukas Zilka434442d2018-04-25 11:38:51 +0200587 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100588 }
589
590 // Initialize pattern recognizers.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100591 int regex_pattern_id = 0;
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100592 for (const auto regex_pattern : *model_->regex_model()->patterns()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200593 std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
Tony Mak378c1f52019-03-04 15:58:11 +0000594 UncompressMakeRegexPattern(
595 *unilib_, regex_pattern->pattern(),
596 regex_pattern->compressed_pattern(),
597 model_->regex_model()->lazy_regex_compilation(), decompressor);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100598 if (!compiled_pattern) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100599 TC3_LOG(INFO) << "Failed to load regex pattern";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200600 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100601 }
602
Lukas Zilkaba849e72018-03-08 14:48:21 +0100603 if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100604 annotation_regex_patterns_.push_back(regex_pattern_id);
605 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100606 if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100607 classification_regex_patterns_.push_back(regex_pattern_id);
608 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100609 if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100610 selection_regex_patterns_.push_back(regex_pattern_id);
611 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100612 regex_patterns_.push_back({
Tony Mak854015a2019-01-16 15:56:48 +0000613 regex_pattern,
Tony Mak6c4cc672018-09-17 11:48:50 +0100614 std::move(compiled_pattern),
Tony Mak6c4cc672018-09-17 11:48:50 +0100615 });
Lukas Zilkab23e2122018-02-09 10:25:19 +0100616 ++regex_pattern_id;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100617 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100618
Lukas Zilkab23e2122018-02-09 10:25:19 +0100619 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100620}
621
Tony Mak6c4cc672018-09-17 11:48:50 +0100622bool Annotator::InitializeKnowledgeEngine(
623 const std::string& serialized_config) {
Tony Maka2a1ff42019-09-12 15:40:32 +0100624 std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
Tony Mak63959242020-02-07 18:31:16 +0000625 if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100626 TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
627 return false;
628 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100629 if (model_->triggering_options() != nullptr) {
630 knowledge_engine->SetPriorityScore(
631 model_->triggering_options()->knowledge_priority_score());
632 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100633 knowledge_engine_ = std::move(knowledge_engine);
634 return true;
635}
636
Tony Mak854015a2019-01-16 15:56:48 +0000637bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
Tony Mak5dc5e112019-02-01 14:52:10 +0000638 std::unique_ptr<ContactEngine> contact_engine(
Tony Mak63959242020-02-07 18:31:16 +0000639 new ContactEngine(selection_feature_processor_.get(), unilib_,
640 model_->contact_annotator_options()));
Tony Mak854015a2019-01-16 15:56:48 +0000641 if (!contact_engine->Initialize(serialized_config)) {
642 TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
643 return false;
644 }
645 contact_engine_ = std::move(contact_engine);
646 return true;
647}
648
Tony Makd9446602019-02-20 18:25:39 +0000649bool Annotator::InitializeInstalledAppEngine(
650 const std::string& serialized_config) {
651 std::unique_ptr<InstalledAppEngine> installed_app_engine(
Tony Mak378c1f52019-03-04 15:58:11 +0000652 new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
Tony Makd9446602019-02-20 18:25:39 +0000653 if (!installed_app_engine->Initialize(serialized_config)) {
654 TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
655 return false;
656 }
657 installed_app_engine_ = std::move(installed_app_engine);
658 return true;
659}
660
Tony Mak63959242020-02-07 18:31:16 +0000661void Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
662 lang_id_ = lang_id;
Tony Mak21460022020-03-12 18:29:35 +0000663 if (lang_id_ != nullptr && model_->translate_annotator_options() &&
Tony Mak63959242020-02-07 18:31:16 +0000664 model_->translate_annotator_options()->enabled()) {
665 translate_annotator_.reset(new TranslateAnnotator(
666 model_->translate_annotator_options(), lang_id_, unilib_));
Tony Mak21460022020-03-12 18:29:35 +0000667 } else {
668 translate_annotator_.reset(nullptr);
Tony Mak63959242020-02-07 18:31:16 +0000669 }
670}
671
Tony Mak21460022020-03-12 18:29:35 +0000672bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
673 int size) {
674 const PersonNameModel* person_name_model =
675 LoadAndVerifyPersonNameModel(buffer, size);
Tony Mak76d80962020-01-08 17:30:51 +0000676
677 if (person_name_model == nullptr) {
678 TC3_LOG(ERROR) << "Person name model verification failed.";
679 return false;
680 }
681
682 if (!person_name_model->enabled()) {
683 return true;
684 }
685
686 std::unique_ptr<PersonNameEngine> person_name_engine(
Tony Mak21460022020-03-12 18:29:35 +0000687 new PersonNameEngine(selection_feature_processor_.get(), unilib_));
Tony Mak76d80962020-01-08 17:30:51 +0000688 if (!person_name_engine->Initialize(person_name_model)) {
689 TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
690 return false;
691 }
692 person_name_engine_ = std::move(person_name_engine);
693 return true;
694}
695
Tony Mak21460022020-03-12 18:29:35 +0000696bool Annotator::InitializePersonNameEngineFromScopedMmap(
697 const ScopedMmap& mmap) {
698 if (!mmap.handle().ok()) {
699 TC3_LOG(ERROR) << "Mmap for person name model failed.";
700 return false;
701 }
702
703 return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
704 mmap.handle().num_bytes());
705}
706
707bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
708 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
709 return InitializePersonNameEngineFromScopedMmap(*mmap);
710}
711
712bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
713 int size) {
714 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
715 return InitializePersonNameEngineFromScopedMmap(*mmap);
716}
717
Tony Mak5a12b942020-05-01 12:41:31 +0100718bool Annotator::InitializeExperimentalAnnotators() {
719 if (ExperimentalAnnotator::IsEnabled()) {
Tony Makc121edd2020-05-28 15:25:17 +0100720 experimental_annotator_.reset(new ExperimentalAnnotator(
721 model_->experimental_model(), *selection_feature_processor_, *unilib_));
Tony Mak5a12b942020-05-01 12:41:31 +0100722 return true;
723 }
724 return false;
725}
726
Lukas Zilka21d8c982018-01-24 11:11:20 +0100727namespace {
728
Tony Maka44b3082020-08-13 18:57:10 +0100729int CountDigits(const std::string& str,
730 const CodepointSpan& selection_indices) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100731 int count = 0;
732 int i = 0;
733 const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
734 for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
735 if (i >= selection_indices.first && i < selection_indices.second &&
Tony Mak21460022020-03-12 18:29:35 +0000736 IsDigit(*it)) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100737 ++count;
738 }
739 }
740 return count;
741}
742
Lukas Zilka21d8c982018-01-24 11:11:20 +0100743} // namespace
744
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200745namespace internal {
746// Helper function, which if the initial 'span' contains only white-spaces,
747// moves the selection to a single-codepoint selection on a left or right side
748// of this space.
Tony Maka44b3082020-08-13 18:57:10 +0100749CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200750 const UnicodeText& context_unicode,
751 const UniLib& unilib) {
Tony Maka44b3082020-08-13 18:57:10 +0100752 TC3_CHECK(span.IsValid() && !span.IsEmpty());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200753
754 UnicodeText::const_iterator it;
755
756 // Check that the current selection is all whitespaces.
757 it = context_unicode.begin();
758 std::advance(it, span.first);
759 for (int i = 0; i < (span.second - span.first); ++i, ++it) {
760 if (!unilib.IsWhitespace(*it)) {
761 return span;
762 }
763 }
764
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200765 // Try moving left.
Tony Maka44b3082020-08-13 18:57:10 +0100766 CodepointSpan result = span;
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200767 it = context_unicode.begin();
768 std::advance(it, span.first);
769 while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
770 --result.first;
771 --it;
772 }
773 result.second = result.first + 1;
774 if (!unilib.IsWhitespace(*it)) {
775 return result;
776 }
777
778 // If moving left didn't find a non-whitespace character, just return the
779 // original span.
780 return span;
781}
782} // namespace internal
783
Tony Mak6c4cc672018-09-17 11:48:50 +0100784bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200785 return !span.classification.empty() &&
786 filtered_collections_annotation_.find(
787 span.classification[0].collection) !=
788 filtered_collections_annotation_.end();
789}
790
Tony Mak6c4cc672018-09-17 11:48:50 +0100791bool Annotator::FilteredForClassification(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200792 const ClassificationResult& classification) const {
793 return filtered_collections_classification_.find(classification.collection) !=
794 filtered_collections_classification_.end();
795}
796
Tony Mak6c4cc672018-09-17 11:48:50 +0100797bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200798 return !span.classification.empty() &&
799 filtered_collections_selection_.find(
800 span.classification[0].collection) !=
801 filtered_collections_selection_.end();
802}
803
Tony Mak378c1f52019-03-04 15:58:11 +0000804namespace {
805inline bool ClassifiedAsOther(
806 const std::vector<ClassificationResult>& classification) {
807 return !classification.empty() &&
808 classification[0].collection == Collections::Other();
809}
810
Tony Maka2a1ff42019-09-12 15:40:32 +0100811} // namespace
812
813float Annotator::GetPriorityScore(
814 const std::vector<ClassificationResult>& classification) const {
Tony Mak378c1f52019-03-04 15:58:11 +0000815 if (!classification.empty() && !ClassifiedAsOther(classification)) {
816 return classification[0].priority_score;
817 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +0100818 if (model_->triggering_options() != nullptr) {
819 return model_->triggering_options()->other_collection_priority_score();
820 } else {
821 return -1000.0;
822 }
Tony Mak378c1f52019-03-04 15:58:11 +0000823 }
824}
Tony Mak378c1f52019-03-04 15:58:11 +0000825
Tony Makdf54e742019-03-26 14:04:00 +0000826bool Annotator::VerifyRegexMatchCandidate(
827 const std::string& context, const VerificationOptions* verification_options,
828 const std::string& match, const UniLib::RegexMatcher* matcher) const {
829 if (verification_options == nullptr) {
830 return true;
831 }
832 if (verification_options->verify_luhn_checksum() &&
833 !VerifyLuhnChecksum(match)) {
834 return false;
835 }
836 const int lua_verifier = verification_options->lua_verifier();
837 if (lua_verifier >= 0) {
838 if (model_->regex_model()->lua_verifier() == nullptr ||
839 lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
840 TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
841 return false;
842 }
843 return VerifyMatch(
844 context, matcher,
845 model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
846 }
847 return true;
848}
849
Tony Mak6c4cc672018-09-17 11:48:50 +0100850CodepointSpan Annotator::SuggestSelection(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100851 const std::string& context, CodepointSpan click_indices,
852 const SelectionOptions& options) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200853 CodepointSpan original_click_indices = click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100854 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100855 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200856 return original_click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100857 }
Tony Mak5a12b942020-05-01 12:41:31 +0100858 if (options.annotation_usecase !=
859 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
860 TC3_LOG(WARNING)
861 << "Invoking SuggestSelection, which is not supported in RAW mode.";
862 return original_click_indices;
863 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100864 if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200865 return original_click_indices;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100866 }
867
Tony Makdf54e742019-03-26 14:04:00 +0000868 std::vector<Locale> detected_text_language_tags;
869 if (!ParseLocales(options.detected_text_language_tags,
870 &detected_text_language_tags)) {
871 TC3_LOG(WARNING)
872 << "Failed to parse the detected_text_language_tags in options: "
873 << options.detected_text_language_tags;
874 }
875 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
876 model_triggering_locales_,
877 /*default_value=*/true)) {
878 return original_click_indices;
879 }
880
Lukas Zilkadf710db2018-02-27 12:44:09 +0100881 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
882 /*do_copy=*/false);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200883
Tony Mak968412a2019-11-13 15:39:57 +0000884 if (!IsValidSpanInput(context_unicode, click_indices)) {
885 TC3_VLOG(1)
886 << "Trying to run SuggestSelection with invalid input, indices: "
887 << click_indices.first << " " << click_indices.second;
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200888 return original_click_indices;
889 }
890
891 if (model_->snap_whitespace_selections()) {
892 // We want to expand a purely white-space selection to a multi-selection it
893 // would've been part of. But with this feature disabled we would do a no-
894 // op, because no token is found. Therefore, we need to modify the
895 // 'click_indices' a bit to include a part of the token, so that the click-
896 // finding logic finds the clicked token correctly. This modification is
897 // done by the following function. Note, that it's enough to check the left
898 // side of the current selection, because if the white-space is a part of a
Tony Mak6c4cc672018-09-17 11:48:50 +0100899 // multi-selection, necessarily both tokens - on the left and the right
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200900 // sides need to be selected. Thus snapping only to the left is sufficient
901 // (there's a check at the bottom that makes sure that if we snap to the
902 // left token but the result does not contain the initial white-space,
903 // returns the original indices).
904 click_indices = internal::SnapLeftIfWhitespaceSelection(
905 click_indices, context_unicode, *unilib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100906 }
907
Tony Maka44b3082020-08-13 18:57:10 +0100908 Annotations candidates;
909 // As we process a single string of context, the candidates will only
910 // contain one vector of AnnotatedSpan.
911 candidates.annotated_spans.resize(1);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100912 InterpreterManager interpreter_manager(selection_executor_.get(),
913 classification_executor_.get());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200914 std::vector<Token> tokens;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100915 if (!ModelSuggestSelection(context_unicode, click_indices,
Tony Makdf54e742019-03-26 14:04:00 +0000916 detected_text_language_tags, &interpreter_manager,
Tony Maka44b3082020-08-13 18:57:10 +0100917 &tokens, &candidates.annotated_spans[0])) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100918 TC3_LOG(ERROR) << "Model suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200919 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100920 }
Tony Maka44b3082020-08-13 18:57:10 +0100921 const std::unordered_set<std::string> set;
922 const EnabledEntityTypes is_entity_type_enabled(set);
923 if (!RegexChunk(context_unicode, selection_regex_patterns_,
924 /*is_serialized_entity_data_enabled=*/false,
925 is_entity_type_enabled, options.annotation_usecase,
926 &candidates.annotated_spans[0])) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100927 TC3_LOG(ERROR) << "Regex suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200928 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100929 }
Tony Maka44b3082020-08-13 18:57:10 +0100930 if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
931 /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
932 options.locales, ModeFlag_SELECTION,
933 options.annotation_usecase,
934 /*is_serialized_entity_data_enabled=*/false,
935 &candidates.annotated_spans[0])) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100936 TC3_LOG(ERROR) << "Datetime suggest selection failed.";
937 return original_click_indices;
938 }
Tony Mak378c1f52019-03-04 15:58:11 +0000939 if (knowledge_engine_ != nullptr &&
Tony Maka2a1ff42019-09-12 15:40:32 +0100940 !knowledge_engine_->Chunk(context, options.annotation_usecase,
Tony Mak90d55672020-04-15 18:20:44 +0100941 options.location_context, Permissions(),
Tony Maka44b3082020-08-13 18:57:10 +0100942 AnnotateMode::kEntityAnnotation, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100943 TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200944 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100945 }
Tony Mak378c1f52019-03-04 15:58:11 +0000946 if (contact_engine_ != nullptr &&
Tony Maka44b3082020-08-13 18:57:10 +0100947 !contact_engine_->Chunk(context_unicode, tokens,
948 &candidates.annotated_spans[0])) {
Tony Mak854015a2019-01-16 15:56:48 +0000949 TC3_LOG(ERROR) << "Contact suggest selection failed.";
950 return original_click_indices;
951 }
Tony Mak378c1f52019-03-04 15:58:11 +0000952 if (installed_app_engine_ != nullptr &&
Tony Maka44b3082020-08-13 18:57:10 +0100953 !installed_app_engine_->Chunk(context_unicode, tokens,
954 &candidates.annotated_spans[0])) {
Tony Makd9446602019-02-20 18:25:39 +0000955 TC3_LOG(ERROR) << "Installed app suggest selection failed.";
956 return original_click_indices;
957 }
Tony Mak378c1f52019-03-04 15:58:11 +0000958 if (number_annotator_ != nullptr &&
959 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
Tony Maka44b3082020-08-13 18:57:10 +0100960 &candidates.annotated_spans[0])) {
Tony Mak378c1f52019-03-04 15:58:11 +0000961 TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
962 return original_click_indices;
963 }
Tony Makad2e22d2019-03-20 17:35:13 +0000964 if (duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +0000965 !duration_annotator_->FindAll(context_unicode, tokens,
Tony Maka44b3082020-08-13 18:57:10 +0100966 options.annotation_usecase,
967 &candidates.annotated_spans[0])) {
Tony Makad2e22d2019-03-20 17:35:13 +0000968 TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
969 return original_click_indices;
970 }
Tony Mak76d80962020-01-08 17:30:51 +0000971 if (person_name_engine_ != nullptr &&
Tony Maka44b3082020-08-13 18:57:10 +0100972 !person_name_engine_->Chunk(context_unicode, tokens,
973 &candidates.annotated_spans[0])) {
Tony Mak76d80962020-01-08 17:30:51 +0000974 TC3_LOG(ERROR) << "Person name suggest selection failed.";
975 return original_click_indices;
976 }
Lukas Zilkab23e2122018-02-09 10:25:19 +0100977
Tony Mak21460022020-03-12 18:29:35 +0000978 AnnotatedSpan grammar_suggested_span;
979 if (grammar_annotator_ != nullptr &&
980 grammar_annotator_->SuggestSelection(detected_text_language_tags,
981 context_unicode, click_indices,
982 &grammar_suggested_span)) {
Tony Maka44b3082020-08-13 18:57:10 +0100983 candidates.annotated_spans[0].push_back(grammar_suggested_span);
984 }
985
986 if (pod_ner_annotator_ != nullptr && options.use_pod_ner) {
987 candidates.annotated_spans[0].push_back(
988 pod_ner_annotator_->SuggestSelection(context_unicode, click_indices));
Tony Mak21460022020-03-12 18:29:35 +0000989 }
990
Tony Mak5a12b942020-05-01 12:41:31 +0100991 if (experimental_annotator_ != nullptr) {
Tony Maka44b3082020-08-13 18:57:10 +0100992 candidates.annotated_spans[0].push_back(
993 experimental_annotator_->SuggestSelection(context_unicode,
994 click_indices));
Tony Mak5a12b942020-05-01 12:41:31 +0100995 }
996
Lukas Zilkab23e2122018-02-09 10:25:19 +0100997 // Sort candidates according to their position in the input, so that the next
998 // code can assume that any connected component of overlapping spans forms a
999 // contiguous block.
Tony Maka44b3082020-08-13 18:57:10 +01001000 std::sort(candidates.annotated_spans[0].begin(),
1001 candidates.annotated_spans[0].end(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01001002 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
1003 return a.span.first < b.span.first;
1004 });
1005
1006 std::vector<int> candidate_indices;
Tony Maka44b3082020-08-13 18:57:10 +01001007 if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00001008 detected_text_language_tags, options.annotation_usecase,
1009 &interpreter_manager, &candidate_indices)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001010 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001011 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001012 }
1013
Tony Mak378c1f52019-03-04 15:58:11 +00001014 std::sort(candidate_indices.begin(), candidate_indices.end(),
Tony Maka2a1ff42019-09-12 15:40:32 +01001015 [this, &candidates](int a, int b) {
Tony Maka44b3082020-08-13 18:57:10 +01001016 return GetPriorityScore(
1017 candidates.annotated_spans[0][a].classification) >
1018 GetPriorityScore(
1019 candidates.annotated_spans[0][b].classification);
Tony Mak378c1f52019-03-04 15:58:11 +00001020 });
1021
Lukas Zilkab23e2122018-02-09 10:25:19 +01001022 for (const int i : candidate_indices) {
Tony Maka44b3082020-08-13 18:57:10 +01001023 if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
1024 SpansOverlap(candidates.annotated_spans[0][i].span,
1025 original_click_indices)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001026 // Run model classification if not present but requested and there's a
1027 // classification collection filter specified.
Tony Maka44b3082020-08-13 18:57:10 +01001028 if (candidates.annotated_spans[0][i].classification.empty() &&
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001029 model_->selection_options()->always_classify_suggested_selection() &&
1030 !filtered_collections_selection_.empty()) {
Tony Maka44b3082020-08-13 18:57:10 +01001031 if (!ModelClassifyText(
1032 context, detected_text_language_tags,
1033 candidates.annotated_spans[0][i].span, &interpreter_manager,
1034 /*embedding_cache=*/nullptr,
1035 &candidates.annotated_spans[0][i].classification)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001036 return original_click_indices;
1037 }
1038 }
1039
1040 // Ignore if span classification is filtered.
Tony Maka44b3082020-08-13 18:57:10 +01001041 if (FilteredForSelection(candidates.annotated_spans[0][i])) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001042 return original_click_indices;
1043 }
1044
Tony Maka44b3082020-08-13 18:57:10 +01001045 return candidates.annotated_spans[0][i].span;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001046 }
1047 }
1048
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001049 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001050}
1051
1052namespace {
1053// Helper function that returns the index of the first candidate that
1054// transitively does not overlap with the candidate on 'start_index'. If the end
1055// of 'candidates' is reached, it returns the index that points right behind the
1056// array.
1057int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
1058 int start_index) {
1059 int first_non_overlapping = start_index + 1;
1060 CodepointSpan conflicting_span = candidates[start_index].span;
1061 while (
1062 first_non_overlapping < candidates.size() &&
1063 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
1064 // Grow the span to include the current one.
1065 conflicting_span.second = std::max(
1066 conflicting_span.second, candidates[first_non_overlapping].span.second);
1067
1068 ++first_non_overlapping;
1069 }
1070 return first_non_overlapping;
1071}
1072} // namespace
1073
Tony Mak378c1f52019-03-04 15:58:11 +00001074bool Annotator::ResolveConflicts(
1075 const std::vector<AnnotatedSpan>& candidates, const std::string& context,
1076 const std::vector<Token>& cached_tokens,
1077 const std::vector<Locale>& detected_text_language_tags,
1078 AnnotationUsecase annotation_usecase,
1079 InterpreterManager* interpreter_manager, std::vector<int>* result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001080 result->clear();
1081 result->reserve(candidates.size());
1082 for (int i = 0; i < candidates.size();) {
1083 int first_non_overlapping =
1084 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1085
1086 const bool conflict_found = first_non_overlapping != (i + 1);
1087 if (conflict_found) {
1088 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +00001089 if (!ResolveConflict(context, cached_tokens, candidates,
1090 detected_text_language_tags, i,
1091 first_non_overlapping, annotation_usecase,
1092 interpreter_manager, &candidate_indices)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001093 return false;
1094 }
1095 result->insert(result->end(), candidate_indices.begin(),
1096 candidate_indices.end());
1097 } else {
1098 result->push_back(i);
1099 }
1100
1101 // Skip over the whole conflicting group/go to next candidate.
1102 i = first_non_overlapping;
1103 }
1104 return true;
1105}
1106
1107namespace {
Tony Mak448b5862019-03-22 13:36:41 +00001108// Returns true, if the given two sources do conflict in given annotation
1109// usecase.
1110// - In SMART usecase, all sources do conflict, because there's only 1 possible
1111// annotation for a given span.
1112// - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1113// and duration), while others not (e.g. duration and number).
1114bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1115 const AnnotatedSpan::Source source1,
1116 const AnnotatedSpan::Source source2) {
1117 uint32 source_mask =
1118 (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1119
Tony Mak378c1f52019-03-04 15:58:11 +00001120 switch (annotation_usecase) {
1121 case AnnotationUsecase_ANNOTATION_USECASE_SMART:
Tony Mak448b5862019-03-22 13:36:41 +00001122 // In the SMART mode, all annotations conflict.
Tony Mak378c1f52019-03-04 15:58:11 +00001123 return true;
Tony Mak448b5862019-03-22 13:36:41 +00001124
Tony Mak378c1f52019-03-04 15:58:11 +00001125 case AnnotationUsecase_ANNOTATION_USECASE_RAW:
Tony Mak448b5862019-03-22 13:36:41 +00001126 // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1127 // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1128 // hours" (duration).
1129 if ((source_mask &
1130 (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1131 (source_mask &
1132 (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1133 return false;
Tony Mak378c1f52019-03-04 15:58:11 +00001134 }
Tony Mak448b5862019-03-22 13:36:41 +00001135
1136 // A KNOWLEDGE entity does not conflict with anything.
1137 if ((source_mask &
1138 (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1139 return false;
1140 }
1141
Tony Makd0ae7c62020-03-27 13:58:00 +00001142 // A PERSONNAME entity does not conflict with anything.
1143 if ((source_mask &
1144 (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
1145 return false;
1146 }
1147
Tony Mak448b5862019-03-22 13:36:41 +00001148 // Entities from other sources can conflict.
Tony Mak378c1f52019-03-04 15:58:11 +00001149 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001150 }
1151}
1152} // namespace
1153
Tony Mak378c1f52019-03-04 15:58:11 +00001154bool Annotator::ResolveConflict(
1155 const std::string& context, const std::vector<Token>& cached_tokens,
1156 const std::vector<AnnotatedSpan>& candidates,
1157 const std::vector<Locale>& detected_text_language_tags, int start_index,
1158 int end_index, AnnotationUsecase annotation_usecase,
1159 InterpreterManager* interpreter_manager,
1160 std::vector<int>* chosen_indices) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001161 std::vector<int> conflicting_indices;
Tony Mak76d80962020-01-08 17:30:51 +00001162 std::unordered_map<int, std::pair<float, int>> scores_lengths;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001163 for (int i = start_index; i < end_index; ++i) {
1164 conflicting_indices.push_back(i);
1165 if (!candidates[i].classification.empty()) {
Tony Mak76d80962020-01-08 17:30:51 +00001166 scores_lengths[i] = {
1167 GetPriorityScore(candidates[i].classification),
1168 candidates[i].span.second - candidates[i].span.first};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001169 continue;
1170 }
1171
1172 // OPTIMIZATION: So that we don't have to classify all the ML model
1173 // spans apriori, we wait until we get here, when they conflict with
1174 // something and we need the actual classification scores. So if the
1175 // candidate conflicts and comes from the model, we need to run a
1176 // classification to determine its priority:
1177 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +00001178 if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1179 candidates[i].span, interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001180 /*embedding_cache=*/nullptr, &classification)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001181 return false;
1182 }
1183
1184 if (!classification.empty()) {
Tony Mak76d80962020-01-08 17:30:51 +00001185 scores_lengths[i] = {
1186 GetPriorityScore(classification),
1187 candidates[i].span.second - candidates[i].span.first};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001188 }
1189 }
1190
Tony Mak5a12b942020-05-01 12:41:31 +01001191 std::sort(
1192 conflicting_indices.begin(), conflicting_indices.end(),
1193 [this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
1194 if (scores_lengths[i].first == scores_lengths[j].first &&
1195 prioritize_longest_annotation_) {
1196 return scores_lengths[i].second > scores_lengths[j].second;
1197 }
1198 return scores_lengths[i].first > scores_lengths[j].first;
1199 });
Lukas Zilkab23e2122018-02-09 10:25:19 +01001200
Tony Mak448b5862019-03-22 13:36:41 +00001201 // Here we keep a set of indices that were chosen, per-source, to enable
1202 // effective computation.
1203 std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1204 chosen_indices_for_source_map;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001205
1206 // Greedily place the candidates if they don't conflict with the already
1207 // placed ones.
1208 for (int i = 0; i < conflicting_indices.size(); ++i) {
1209 const int considered_candidate = conflicting_indices[i];
Tony Mak448b5862019-03-22 13:36:41 +00001210
1211 // See if there is a conflict between the candidate and all already placed
1212 // candidates.
1213 bool conflict = false;
1214 SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1215 for (auto& source_set_pair : chosen_indices_for_source_map) {
1216 if (source_set_pair.first == candidates[considered_candidate].source) {
1217 chosen_indices_for_source_ptr = &source_set_pair.second;
1218 }
1219
Tony Mak5a12b942020-05-01 12:41:31 +01001220 const bool needs_conflict_resolution =
1221 annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_SMART ||
1222 (annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW &&
1223 do_conflict_resolution_in_raw_mode_);
1224 if (needs_conflict_resolution &&
1225 DoSourcesConflict(annotation_usecase, source_set_pair.first,
Tony Mak448b5862019-03-22 13:36:41 +00001226 candidates[considered_candidate].source) &&
1227 DoesCandidateConflict(considered_candidate, candidates,
1228 source_set_pair.second)) {
1229 conflict = true;
1230 break;
1231 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001232 }
Tony Mak448b5862019-03-22 13:36:41 +00001233
1234 // Skip the candidate if a conflict was found.
1235 if (conflict) {
1236 continue;
1237 }
1238
1239 // If the set of indices for the current source doesn't exist yet,
1240 // initialize it.
1241 if (chosen_indices_for_source_ptr == nullptr) {
1242 SortedIntSet new_set([&candidates](int a, int b) {
1243 return candidates[a].span.first < candidates[b].span.first;
1244 });
1245 chosen_indices_for_source_map[candidates[considered_candidate].source] =
1246 std::move(new_set);
1247 chosen_indices_for_source_ptr =
1248 &chosen_indices_for_source_map[candidates[considered_candidate]
1249 .source];
1250 }
1251
1252 // Place the candidate to the output and to the per-source conflict set.
1253 chosen_indices->push_back(considered_candidate);
1254 chosen_indices_for_source_ptr->insert(considered_candidate);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001255 }
1256
Tony Mak378c1f52019-03-04 15:58:11 +00001257 std::sort(chosen_indices->begin(), chosen_indices->end());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001258
1259 return true;
1260}
1261
Tony Mak6c4cc672018-09-17 11:48:50 +01001262bool Annotator::ModelSuggestSelection(
Tony Maka44b3082020-08-13 18:57:10 +01001263 const UnicodeText& context_unicode, const CodepointSpan& click_indices,
Tony Makdf54e742019-03-26 14:04:00 +00001264 const std::vector<Locale>& detected_text_language_tags,
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001265 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001266 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001267 if (model_->triggering_options() == nullptr ||
1268 !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1269 return true;
1270 }
1271
Tony Makdf54e742019-03-26 14:04:00 +00001272 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1273 ml_model_triggering_locales_,
1274 /*default_value=*/true)) {
1275 return true;
1276 }
1277
Lukas Zilka21d8c982018-01-24 11:11:20 +01001278 int click_pos;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001279 *tokens = selection_feature_processor_->Tokenize(context_unicode);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001280 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001281 context_unicode, click_indices,
1282 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001283 tokens, &click_pos);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001284 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001285 TC3_VLOG(1) << "Could not calculate the click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001286 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001287 }
1288
1289 const int symmetry_context_size =
1290 model_->selection_options()->symmetry_context_size();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001291 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
Lukas Zilkab23e2122018-02-09 10:25:19 +01001292 bounds_sensitive_features = selection_feature_processor_->GetOptions()
1293 ->bounds_sensitive_features();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001294
1295 // The symmetry context span is the clicked token with symmetry_context_size
1296 // tokens on either side.
Tony Maka44b3082020-08-13 18:57:10 +01001297 const TokenSpan symmetry_context_span =
1298 IntersectTokenSpans(TokenSpan(click_pos).Expand(
1299 /*num_tokens_left=*/symmetry_context_size,
1300 /*num_tokens_right=*/symmetry_context_size),
1301 AllOf(*tokens));
Lukas Zilka21d8c982018-01-24 11:11:20 +01001302
Lukas Zilkab23e2122018-02-09 10:25:19 +01001303 // Compute the extraction span based on the model type.
1304 TokenSpan extraction_span;
1305 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1306 // The extraction span is the symmetry context span expanded to include
1307 // max_selection_span tokens on either side, which is how far a selection
1308 // can stretch from the click, plus a relevant number of tokens outside of
1309 // the bounds of the selection.
1310 const int max_selection_span =
1311 selection_feature_processor_->GetOptions()->max_selection_span();
Tony Maka44b3082020-08-13 18:57:10 +01001312 extraction_span = symmetry_context_span.Expand(
1313 /*num_tokens_left=*/max_selection_span +
1314 bounds_sensitive_features->num_tokens_before(),
1315 /*num_tokens_right=*/max_selection_span +
1316 bounds_sensitive_features->num_tokens_after());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001317 } else {
1318 // The extraction span is the symmetry context span expanded to include
1319 // context_size tokens on either side.
1320 const int context_size =
1321 selection_feature_processor_->GetOptions()->context_size();
Tony Maka44b3082020-08-13 18:57:10 +01001322 extraction_span = symmetry_context_span.Expand(
1323 /*num_tokens_left=*/context_size,
1324 /*num_tokens_right=*/context_size);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001325 }
Tony Maka44b3082020-08-13 18:57:10 +01001326 extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001327
Lukas Zilka434442d2018-04-25 11:38:51 +02001328 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1329 *tokens, extraction_span)) {
1330 return true;
1331 }
1332
Lukas Zilkab23e2122018-02-09 10:25:19 +01001333 std::unique_ptr<CachedFeatures> cached_features;
1334 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001335 *tokens, extraction_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001336 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1337 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01001338 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001339 selection_feature_processor_->EmbeddingSize() +
1340 selection_feature_processor_->DenseFeaturesCount(),
1341 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001342 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001343 return false;
1344 }
1345
1346 // Produce selection model candidates.
1347 std::vector<TokenSpan> chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001348 if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001349 interpreter_manager->SelectionInterpreter(), *cached_features,
1350 &chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001351 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001352 return false;
1353 }
1354
1355 for (const TokenSpan& chunk : chunks) {
1356 AnnotatedSpan candidate;
1357 candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001358 context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001359 if (model_->selection_options()->strip_unpaired_brackets()) {
1360 candidate.span =
1361 StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1362 }
1363
1364 // Only output non-empty spans.
1365 if (candidate.span.first != candidate.span.second) {
1366 result->push_back(candidate);
1367 }
1368 }
1369 return true;
1370}
1371
Tony Mak6c4cc672018-09-17 11:48:50 +01001372bool Annotator::ModelClassifyText(
Tony Mak378c1f52019-03-04 15:58:11 +00001373 const std::string& context,
1374 const std::vector<Locale>& detected_text_language_tags,
Tony Maka44b3082020-08-13 18:57:10 +01001375 const CodepointSpan& selection_indices,
1376 InterpreterManager* interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001377 FeatureProcessor::EmbeddingCache* embedding_cache,
1378 std::vector<ClassificationResult>* classification_results) const {
Tony Mak378c1f52019-03-04 15:58:11 +00001379 return ModelClassifyText(context, {}, detected_text_language_tags,
1380 selection_indices, interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001381 embedding_cache, classification_results);
1382}
1383
1384namespace internal {
1385std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
Tony Maka44b3082020-08-13 18:57:10 +01001386 const CodepointSpan& selection_indices,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001387 TokenSpan tokens_around_selection_to_copy) {
1388 const auto first_selection_token = std::upper_bound(
1389 cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1390 [](int selection_start, const Token& token) {
1391 return selection_start < token.end;
1392 });
1393 const auto last_selection_token = std::lower_bound(
1394 cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1395 [](const Token& token, int selection_end) {
1396 return token.start < selection_end;
1397 });
1398
1399 const int64 first_token = std::max(
1400 static_cast<int64>(0),
1401 static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1402 tokens_around_selection_to_copy.first));
1403 const int64 last_token = std::min(
1404 static_cast<int64>(cached_tokens.size()),
1405 static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1406 tokens_around_selection_to_copy.second));
1407
1408 std::vector<Token> tokens;
1409 tokens.reserve(last_token - first_token);
1410 for (int i = first_token; i < last_token; ++i) {
1411 tokens.push_back(cached_tokens[i]);
1412 }
1413 return tokens;
1414}
1415} // namespace internal
1416
Tony Mak6c4cc672018-09-17 11:48:50 +01001417TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001418 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1419 bounds_sensitive_features =
1420 classification_feature_processor_->GetOptions()
1421 ->bounds_sensitive_features();
1422 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1423 // The extraction span is the selection span expanded to include a relevant
1424 // number of tokens outside of the bounds of the selection.
1425 return {bounds_sensitive_features->num_tokens_before(),
1426 bounds_sensitive_features->num_tokens_after()};
1427 } else {
1428 // The extraction span is the clicked token with context_size tokens on
1429 // either side.
1430 const int context_size =
1431 selection_feature_processor_->GetOptions()->context_size();
1432 return {context_size, context_size};
1433 }
1434}
1435
Tony Mak378c1f52019-03-04 15:58:11 +00001436namespace {
1437// Sorts the classification results from high score to low score.
1438void SortClassificationResults(
1439 std::vector<ClassificationResult>* classification_results) {
1440 std::sort(classification_results->begin(), classification_results->end(),
1441 [](const ClassificationResult& a, const ClassificationResult& b) {
1442 return a.score > b.score;
1443 });
1444}
1445} // namespace
1446
Tony Mak6c4cc672018-09-17 11:48:50 +01001447bool Annotator::ModelClassifyText(
Lukas Zilkaba849e72018-03-08 14:48:21 +01001448 const std::string& context, const std::vector<Token>& cached_tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00001449 const std::vector<Locale>& detected_text_language_tags,
Tony Maka44b3082020-08-13 18:57:10 +01001450 const CodepointSpan& selection_indices,
1451 InterpreterManager* interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001452 FeatureProcessor::EmbeddingCache* embedding_cache,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001453 std::vector<ClassificationResult>* classification_results) const {
1454 std::vector<Token> tokens;
Tony Mak378c1f52019-03-04 15:58:11 +00001455 return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1456 selection_indices, interpreter_manager,
1457 embedding_cache, classification_results, &tokens);
1458}
1459
1460bool Annotator::ModelClassifyText(
1461 const std::string& context, const std::vector<Token>& cached_tokens,
1462 const std::vector<Locale>& detected_text_language_tags,
Tony Maka44b3082020-08-13 18:57:10 +01001463 const CodepointSpan& selection_indices,
1464 InterpreterManager* interpreter_manager,
Tony Mak378c1f52019-03-04 15:58:11 +00001465 FeatureProcessor::EmbeddingCache* embedding_cache,
1466 std::vector<ClassificationResult>* classification_results,
1467 std::vector<Token>* tokens) const {
1468 if (model_->triggering_options() == nullptr ||
1469 !(model_->triggering_options()->enabled_modes() &
1470 ModeFlag_CLASSIFICATION)) {
1471 return true;
1472 }
1473
Tony Makdf54e742019-03-26 14:04:00 +00001474 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1475 ml_model_triggering_locales_,
1476 /*default_value=*/true)) {
1477 return true;
1478 }
1479
Lukas Zilkaba849e72018-03-08 14:48:21 +01001480 if (cached_tokens.empty()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001481 *tokens = classification_feature_processor_->Tokenize(context);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001482 } else {
Tony Mak378c1f52019-03-04 15:58:11 +00001483 *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1484 ClassifyTextUpperBoundNeededTokens());
Lukas Zilkaba849e72018-03-08 14:48:21 +01001485 }
1486
Lukas Zilkab23e2122018-02-09 10:25:19 +01001487 int click_pos;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001488 classification_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001489 context, selection_indices,
1490 classification_feature_processor_->GetOptions()
1491 ->only_use_line_with_click(),
Tony Mak378c1f52019-03-04 15:58:11 +00001492 tokens, &click_pos);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001493 const TokenSpan selection_token_span =
Tony Mak378c1f52019-03-04 15:58:11 +00001494 CodepointSpanToTokenSpan(*tokens, selection_indices);
Tony Maka44b3082020-08-13 18:57:10 +01001495 const int selection_num_tokens = selection_token_span.Size();
Lukas Zilka434442d2018-04-25 11:38:51 +02001496 if (model_->classification_options()->max_num_tokens() > 0 &&
1497 model_->classification_options()->max_num_tokens() <
1498 selection_num_tokens) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001499 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001500 return true;
1501 }
1502
Lukas Zilkab23e2122018-02-09 10:25:19 +01001503 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1504 bounds_sensitive_features =
1505 classification_feature_processor_->GetOptions()
1506 ->bounds_sensitive_features();
1507 if (selection_token_span.first == kInvalidIndex ||
1508 selection_token_span.second == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001509 TC3_LOG(ERROR) << "Could not determine span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001510 return false;
1511 }
1512
1513 // Compute the extraction span based on the model type.
1514 TokenSpan extraction_span;
1515 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1516 // The extraction span is the selection span expanded to include a relevant
1517 // number of tokens outside of the bounds of the selection.
Tony Maka44b3082020-08-13 18:57:10 +01001518 extraction_span = selection_token_span.Expand(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001519 /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1520 /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1521 } else {
1522 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001523 TC3_LOG(ERROR) << "Couldn't choose a click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001524 return false;
1525 }
1526 // The extraction span is the clicked token with context_size tokens on
1527 // either side.
1528 const int context_size =
Lukas Zilkaba849e72018-03-08 14:48:21 +01001529 classification_feature_processor_->GetOptions()->context_size();
Tony Maka44b3082020-08-13 18:57:10 +01001530 extraction_span = TokenSpan(click_pos).Expand(
1531 /*num_tokens_left=*/context_size,
1532 /*num_tokens_right=*/context_size);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001533 }
Tony Maka44b3082020-08-13 18:57:10 +01001534 extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
Lukas Zilka21d8c982018-01-24 11:11:20 +01001535
Lukas Zilka434442d2018-04-25 11:38:51 +02001536 if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
Tony Mak378c1f52019-03-04 15:58:11 +00001537 *tokens, extraction_span)) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001538 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001539 return true;
1540 }
1541
Lukas Zilka21d8c982018-01-24 11:11:20 +01001542 std::unique_ptr<CachedFeatures> cached_features;
1543 if (!classification_feature_processor_->ExtractFeatures(
Tony Mak378c1f52019-03-04 15:58:11 +00001544 *tokens, extraction_span, selection_indices,
1545 embedding_executor_.get(), embedding_cache,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001546 classification_feature_processor_->EmbeddingSize() +
1547 classification_feature_processor_->DenseFeaturesCount(),
1548 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001549 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001550 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001551 }
1552
Lukas Zilkab23e2122018-02-09 10:25:19 +01001553 std::vector<float> features;
1554 features.reserve(cached_features->OutputFeaturesSize());
1555 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1556 cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1557 &features);
1558 } else {
1559 cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001560 }
1561
Lukas Zilkaba849e72018-03-08 14:48:21 +01001562 TensorView<float> logits = classification_executor_->ComputeLogits(
1563 TensorView<float>(features.data(),
1564 {1, static_cast<int>(features.size())}),
1565 interpreter_manager->ClassificationInterpreter());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001566 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001567 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001568 return false;
1569 }
1570
1571 if (logits.dims() != 2 || logits.dim(0) != 1 ||
1572 logits.dim(1) != classification_feature_processor_->NumCollections()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001573 TC3_LOG(ERROR) << "Mismatching output";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001574 return false;
1575 }
1576
1577 const std::vector<float> scores =
1578 ComputeSoftmax(logits.data(), logits.dim(1));
1579
Tony Mak81e52422019-04-30 09:34:45 +01001580 if (scores.empty()) {
1581 *classification_results = {{Collections::Other(), 1.0}};
1582 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001583 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001584
Tony Mak81e52422019-04-30 09:34:45 +01001585 const int best_score_index =
1586 std::max_element(scores.begin(), scores.end()) - scores.begin();
1587 const std::string top_collection =
1588 classification_feature_processor_->LabelToCollection(best_score_index);
1589
1590 // Sanity checks.
1591 if (top_collection == Collections::Phone()) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001592 const int digit_count = CountDigits(context, selection_indices);
1593 if (digit_count <
1594 model_->classification_options()->phone_min_num_digits() ||
1595 digit_count >
1596 model_->classification_options()->phone_max_num_digits()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001597 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001598 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001599 }
Tony Mak81e52422019-04-30 09:34:45 +01001600 } else if (top_collection == Collections::Address()) {
Lukas Zilka434442d2018-04-25 11:38:51 +02001601 if (selection_num_tokens <
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001602 model_->classification_options()->address_min_num_tokens()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001603 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001604 return true;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001605 }
Tony Mak81e52422019-04-30 09:34:45 +01001606 } else if (top_collection == Collections::Dictionary()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001607 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1608 dictionary_locales_,
Tony Makdf54e742019-03-26 14:04:00 +00001609 /*default_value=*/false)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001610 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001611 return true;
Tony Mak378c1f52019-03-04 15:58:11 +00001612 }
1613 }
Tony Mak81e52422019-04-30 09:34:45 +01001614
Tony Makd99d58c2020-03-19 21:52:02 +00001615 *classification_results = {{top_collection, /*arg_score=*/1.0,
1616 /*arg_priority_score=*/scores[best_score_index]}};
1617
1618 // For some entities, we might want to clamp the priority score, for better
1619 // conflict resolution between entities.
1620 if (model_->triggering_options() != nullptr &&
1621 model_->triggering_options()->collection_to_priority() != nullptr) {
1622 if (auto entry =
1623 model_->triggering_options()->collection_to_priority()->LookupByKey(
1624 top_collection.c_str())) {
1625 (*classification_results)[0].priority_score *= entry->value();
1626 }
1627 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001628 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001629}
1630
Tony Mak6c4cc672018-09-17 11:48:50 +01001631bool Annotator::RegexClassifyText(
Tony Maka44b3082020-08-13 18:57:10 +01001632 const std::string& context, const CodepointSpan& selection_indices,
Tony Mak378c1f52019-03-04 15:58:11 +00001633 std::vector<ClassificationResult>* classification_result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001634 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001635 UTF8ToUnicodeText(context, /*do_copy=*/false)
1636 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001637 const UnicodeText selection_text_unicode(
1638 UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1639
1640 // Check whether any of the regular expressions match.
1641 for (const int pattern_id : classification_regex_patterns_) {
1642 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1643 const std::unique_ptr<UniLib::RegexMatcher> matcher =
1644 regex_pattern.pattern->Matcher(selection_text_unicode);
1645 int status = UniLib::RegexMatcher::kNoError;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001646 bool matches;
Tony Mak854015a2019-01-16 15:56:48 +00001647 if (regex_pattern.config->use_approximate_matching()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001648 matches = matcher->ApproximatelyMatches(&status);
1649 } else {
1650 matches = matcher->Matches(&status);
1651 }
1652 if (status != UniLib::RegexMatcher::kNoError) {
1653 return false;
1654 }
Tony Makdf54e742019-03-26 14:04:00 +00001655 if (matches && VerifyRegexMatchCandidate(
1656 context, regex_pattern.config->verification_options(),
1657 selection_text, matcher.get())) {
Tony Mak378c1f52019-03-04 15:58:11 +00001658 classification_result->push_back(
1659 {regex_pattern.config->collection_name()->str(),
1660 regex_pattern.config->target_classification_score(),
1661 regex_pattern.config->priority_score()});
Tony Makd9446602019-02-20 18:25:39 +00001662 if (!SerializedEntityDataFromRegexMatch(
1663 regex_pattern.config, matcher.get(),
Tony Mak378c1f52019-03-04 15:58:11 +00001664 &classification_result->back().serialized_entity_data)) {
Tony Makd9446602019-02-20 18:25:39 +00001665 TC3_LOG(ERROR) << "Could not get entity data.";
1666 return false;
1667 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001668 }
1669 }
1670
Tony Mak378c1f52019-03-04 15:58:11 +00001671 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001672}
1673
Tony Mak5dc5e112019-02-01 14:52:10 +00001674namespace {
1675std::string PickCollectionForDatetime(
1676 const DatetimeParseResult& datetime_parse_result) {
1677 switch (datetime_parse_result.granularity) {
1678 case GRANULARITY_HOUR:
1679 case GRANULARITY_MINUTE:
1680 case GRANULARITY_SECOND:
1681 return Collections::DateTime();
1682 default:
1683 return Collections::Date();
1684 }
1685}
Tony Mak83d2de62019-04-10 16:12:15 +01001686
Tony Mak5dc5e112019-02-01 14:52:10 +00001687} // namespace
1688
Tony Mak6c4cc672018-09-17 11:48:50 +01001689bool Annotator::DatetimeClassifyText(
Tony Maka44b3082020-08-13 18:57:10 +01001690 const std::string& context, const CodepointSpan& selection_indices,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001691 const ClassificationOptions& options,
Tony Mak854015a2019-01-16 15:56:48 +00001692 std::vector<ClassificationResult>* classification_results) const {
Tony Mak63959242020-02-07 18:31:16 +00001693 if (!datetime_parser_ && !cfg_datetime_parser_) {
Tony Makd99d58c2020-03-19 21:52:02 +00001694 return true;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001695 }
1696
Lukas Zilkab23e2122018-02-09 10:25:19 +01001697 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001698 UTF8ToUnicodeText(context, /*do_copy=*/false)
1699 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001700
1701 std::vector<DatetimeParseResultSpan> datetime_spans;
Tony Makd99d58c2020-03-19 21:52:02 +00001702
Tony Mak63959242020-02-07 18:31:16 +00001703 if (cfg_datetime_parser_) {
1704 if (!(model_->grammar_datetime_model()->enabled_modes() &
1705 ModeFlag_CLASSIFICATION)) {
1706 return true;
1707 }
1708 std::vector<Locale> parsed_locales;
1709 ParseLocales(options.locales, &parsed_locales);
Tony Mak21460022020-03-12 18:29:35 +00001710 cfg_datetime_parser_->Parse(
1711 selection_text,
1712 ToDateAnnotationOptions(
1713 model_->grammar_datetime_model()->annotation_options(),
1714 options.reference_timezone, options.reference_time_ms_utc),
1715 parsed_locales, &datetime_spans);
Tony Makd99d58c2020-03-19 21:52:02 +00001716 }
1717
1718 if (datetime_parser_) {
Tony Mak63959242020-02-07 18:31:16 +00001719 if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1720 options.reference_timezone, options.locales,
1721 ModeFlag_CLASSIFICATION,
1722 options.annotation_usecase,
1723 /*anchor_start_end=*/true, &datetime_spans)) {
1724 TC3_LOG(ERROR) << "Error during parsing datetime.";
1725 return false;
1726 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001727 }
Tony Makd99d58c2020-03-19 21:52:02 +00001728
Lukas Zilkab23e2122018-02-09 10:25:19 +01001729 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
1730 // Only consider the result valid if the selection and extracted datetime
1731 // spans exactly match.
Tony Maka44b3082020-08-13 18:57:10 +01001732 if (CodepointSpan(datetime_span.span.first + selection_indices.first,
1733 datetime_span.span.second + selection_indices.first) ==
Lukas Zilkab23e2122018-02-09 10:25:19 +01001734 selection_indices) {
Tony Mak854015a2019-01-16 15:56:48 +00001735 for (const DatetimeParseResult& parse_result : datetime_span.data) {
1736 classification_results->emplace_back(
Tony Mak5dc5e112019-02-01 14:52:10 +00001737 PickCollectionForDatetime(parse_result),
1738 datetime_span.target_classification_score);
Tony Mak854015a2019-01-16 15:56:48 +00001739 classification_results->back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01001740 classification_results->back().serialized_entity_data =
1741 CreateDatetimeSerializedEntityData(parse_result);
Tony Mak378c1f52019-03-04 15:58:11 +00001742 classification_results->back().priority_score =
1743 datetime_span.priority_score;
Tony Mak854015a2019-01-16 15:56:48 +00001744 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001745 return true;
1746 }
1747 }
Tony Mak378c1f52019-03-04 15:58:11 +00001748 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001749}
1750
Tony Mak6c4cc672018-09-17 11:48:50 +01001751std::vector<ClassificationResult> Annotator::ClassifyText(
Tony Maka44b3082020-08-13 18:57:10 +01001752 const std::string& context, const CodepointSpan& selection_indices,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001753 const ClassificationOptions& options) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01001754 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001755 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilka21d8c982018-01-24 11:11:20 +01001756 return {};
1757 }
Tony Mak5a12b942020-05-01 12:41:31 +01001758 if (options.annotation_usecase !=
1759 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
1760 TC3_LOG(WARNING)
1761 << "Invoking ClassifyText, which is not supported in RAW mode.";
1762 return {};
1763 }
Lukas Zilkaba849e72018-03-08 14:48:21 +01001764 if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1765 return {};
1766 }
1767
Tony Makdf54e742019-03-26 14:04:00 +00001768 std::vector<Locale> detected_text_language_tags;
1769 if (!ParseLocales(options.detected_text_language_tags,
1770 &detected_text_language_tags)) {
1771 TC3_LOG(WARNING)
1772 << "Failed to parse the detected_text_language_tags in options: "
1773 << options.detected_text_language_tags;
1774 }
1775 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1776 model_triggering_locales_,
1777 /*default_value=*/true)) {
1778 return {};
1779 }
1780
Tony Mak968412a2019-11-13 15:39:57 +00001781 if (!IsValidSpanInput(UTF8ToUnicodeText(context, /*do_copy=*/false),
1782 selection_indices)) {
1783 TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
Tony Maka44b3082020-08-13 18:57:10 +01001784 << selection_indices.first << " " << selection_indices.second;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001785 return {};
1786 }
1787
Tony Mak378c1f52019-03-04 15:58:11 +00001788 // We'll accumulate a list of candidates, and pick the best candidate in the
1789 // end.
1790 std::vector<AnnotatedSpan> candidates;
1791
Tony Mak6c4cc672018-09-17 11:48:50 +01001792 // Try the knowledge engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001793 // TODO(b/126579108): Propagate error status.
Tony Mak6c4cc672018-09-17 11:48:50 +01001794 ClassificationResult knowledge_result;
Tony Mak63959242020-02-07 18:31:16 +00001795 if (knowledge_engine_ &&
1796 knowledge_engine_->ClassifyText(
1797 context, selection_indices, options.annotation_usecase,
Tony Mak90d55672020-04-15 18:20:44 +01001798 options.location_context, Permissions(), &knowledge_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001799 candidates.push_back({selection_indices, {knowledge_result}});
1800 candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
Tony Mak854015a2019-01-16 15:56:48 +00001801 }
1802
Tony Maka2a1ff42019-09-12 15:40:32 +01001803 AddContactMetadataToKnowledgeClassificationResults(&candidates);
1804
Tony Mak854015a2019-01-16 15:56:48 +00001805 // Try the contact engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001806 // TODO(b/126579108): Propagate error status.
Tony Mak854015a2019-01-16 15:56:48 +00001807 ClassificationResult contact_result;
1808 if (contact_engine_ && contact_engine_->ClassifyText(
1809 context, selection_indices, &contact_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001810 candidates.push_back({selection_indices, {contact_result}});
Tony Mak6c4cc672018-09-17 11:48:50 +01001811 }
1812
Tony Mak76d80962020-01-08 17:30:51 +00001813 // Try the person name engine.
1814 ClassificationResult person_name_result;
1815 if (person_name_engine_ &&
1816 person_name_engine_->ClassifyText(context, selection_indices,
1817 &person_name_result)) {
1818 candidates.push_back({selection_indices, {person_name_result}});
Tony Makd0ae7c62020-03-27 13:58:00 +00001819 candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
Tony Mak76d80962020-01-08 17:30:51 +00001820 }
1821
Tony Makd9446602019-02-20 18:25:39 +00001822 // Try the installed app engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001823 // TODO(b/126579108): Propagate error status.
Tony Makd9446602019-02-20 18:25:39 +00001824 ClassificationResult installed_app_result;
1825 if (installed_app_engine_ &&
1826 installed_app_engine_->ClassifyText(context, selection_indices,
1827 &installed_app_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001828 candidates.push_back({selection_indices, {installed_app_result}});
Tony Makd9446602019-02-20 18:25:39 +00001829 }
1830
Lukas Zilkab23e2122018-02-09 10:25:19 +01001831 // Try the regular expression models.
Tony Mak378c1f52019-03-04 15:58:11 +00001832 std::vector<ClassificationResult> regex_results;
1833 if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1834 return {};
1835 }
1836 for (const ClassificationResult& result : regex_results) {
1837 candidates.push_back({selection_indices, {result}});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001838 }
1839
Lukas Zilkab23e2122018-02-09 10:25:19 +01001840 // Try the date model.
Tony Mak378c1f52019-03-04 15:58:11 +00001841 //
1842 // DatetimeClassifyText only returns the first result, which can however have
1843 // more interpretations. They are inserted in the candidates as a single
1844 // AnnotatedSpan, so that they get treated together by the conflict resolution
1845 // algorithm.
Tony Mak854015a2019-01-16 15:56:48 +00001846 std::vector<ClassificationResult> datetime_results;
Tony Mak378c1f52019-03-04 15:58:11 +00001847 if (!DatetimeClassifyText(context, selection_indices, options,
1848 &datetime_results)) {
1849 return {};
1850 }
1851 if (!datetime_results.empty()) {
1852 candidates.push_back({selection_indices, std::move(datetime_results)});
Tony Mak448b5862019-03-22 13:36:41 +00001853 candidates.back().source = AnnotatedSpan::Source::DATETIME;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001854 }
1855
Tony Mak5a12b942020-05-01 12:41:31 +01001856 const UnicodeText context_unicode =
1857 UTF8ToUnicodeText(context, /*do_copy=*/false);
1858
Tony Mak378c1f52019-03-04 15:58:11 +00001859 // Try the number annotator.
1860 // TODO(b/126579108): Propagate error status.
1861 ClassificationResult number_annotator_result;
1862 if (number_annotator_ &&
Tony Mak5a12b942020-05-01 12:41:31 +01001863 number_annotator_->ClassifyText(context_unicode, selection_indices,
1864 options.annotation_usecase,
1865 &number_annotator_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001866 candidates.push_back({selection_indices, {number_annotator_result}});
1867 }
Lukas Zilkaba849e72018-03-08 14:48:21 +01001868
Tony Makad2e22d2019-03-20 17:35:13 +00001869 // Try the duration annotator.
1870 ClassificationResult duration_annotator_result;
1871 if (duration_annotator_ &&
Tony Mak5a12b942020-05-01 12:41:31 +01001872 duration_annotator_->ClassifyText(context_unicode, selection_indices,
1873 options.annotation_usecase,
1874 &duration_annotator_result)) {
Tony Makad2e22d2019-03-20 17:35:13 +00001875 candidates.push_back({selection_indices, {duration_annotator_result}});
Tony Mak448b5862019-03-22 13:36:41 +00001876 candidates.back().source = AnnotatedSpan::Source::DURATION;
Tony Makad2e22d2019-03-20 17:35:13 +00001877 }
1878
Tony Mak63959242020-02-07 18:31:16 +00001879 // Try the translate annotator.
1880 ClassificationResult translate_annotator_result;
1881 if (translate_annotator_ &&
Tony Mak5a12b942020-05-01 12:41:31 +01001882 translate_annotator_->ClassifyText(context_unicode, selection_indices,
1883 options.user_familiar_language_tags,
1884 &translate_annotator_result)) {
Tony Mak63959242020-02-07 18:31:16 +00001885 candidates.push_back({selection_indices, {translate_annotator_result}});
1886 }
1887
Tony Mak21460022020-03-12 18:29:35 +00001888 // Try the grammar model.
1889 ClassificationResult grammar_annotator_result;
1890 if (grammar_annotator_ && grammar_annotator_->ClassifyText(
Tony Mak5a12b942020-05-01 12:41:31 +01001891 detected_text_language_tags, context_unicode,
Tony Mak21460022020-03-12 18:29:35 +00001892 selection_indices, &grammar_annotator_result)) {
1893 candidates.push_back({selection_indices, {grammar_annotator_result}});
1894 }
1895
Tony Maka44b3082020-08-13 18:57:10 +01001896 ClassificationResult pod_ner_annotator_result;
1897 if (pod_ner_annotator_ && options.use_pod_ner &&
1898 pod_ner_annotator_->ClassifyText(context_unicode, selection_indices,
1899 &pod_ner_annotator_result)) {
1900 candidates.push_back({selection_indices, {pod_ner_annotator_result}});
1901 }
1902
Tony Maka5090082020-09-18 16:41:23 +01001903 ClassificationResult vocab_annotator_result;
1904 if (vocab_annotator_ &&
1905 vocab_annotator_->ClassifyText(
1906 context_unicode, selection_indices, detected_text_language_tags,
1907 options.trigger_dictionary_on_beginner_words,
1908 &vocab_annotator_result)) {
1909 candidates.push_back({selection_indices, {vocab_annotator_result}});
1910 }
1911
Tony Maka44b3082020-08-13 18:57:10 +01001912 if (experimental_annotator_) {
1913 experimental_annotator_->ClassifyText(context_unicode, selection_indices,
1914 candidates);
Tony Mak5a12b942020-05-01 12:41:31 +01001915 }
1916
Tony Mak378c1f52019-03-04 15:58:11 +00001917 // Try the ML model.
1918 //
1919 // The output of the model is considered as an exclusive 1-of-N choice. That's
1920 // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1921 // span for each candidate, like e.g. the regex model.
Lukas Zilkaba849e72018-03-08 14:48:21 +01001922 InterpreterManager interpreter_manager(selection_executor_.get(),
1923 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00001924 std::vector<ClassificationResult> model_results;
1925 std::vector<Token> tokens;
1926 if (!ModelClassifyText(
1927 context, /*cached_tokens=*/{}, detected_text_language_tags,
1928 selection_indices, &interpreter_manager,
1929 /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1930 return {};
1931 }
1932 if (!model_results.empty()) {
1933 candidates.push_back({selection_indices, std::move(model_results)});
1934 }
1935
1936 std::vector<int> candidate_indices;
1937 if (!ResolveConflicts(candidates, context, tokens,
1938 detected_text_language_tags, options.annotation_usecase,
1939 &interpreter_manager, &candidate_indices)) {
1940 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1941 return {};
1942 }
1943
1944 std::vector<ClassificationResult> results;
1945 for (const int i : candidate_indices) {
1946 for (const ClassificationResult& result : candidates[i].classification) {
1947 if (!FilteredForClassification(result)) {
1948 results.push_back(result);
1949 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001950 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001951 }
1952
Tony Mak378c1f52019-03-04 15:58:11 +00001953 // Sort results according to score.
1954 std::sort(results.begin(), results.end(),
1955 [](const ClassificationResult& a, const ClassificationResult& b) {
1956 return a.score > b.score;
1957 });
1958
1959 if (results.empty()) {
Tony Mak81e52422019-04-30 09:34:45 +01001960 results = {{Collections::Other(), 1.0}};
Tony Mak378c1f52019-03-04 15:58:11 +00001961 }
Tony Mak378c1f52019-03-04 15:58:11 +00001962 return results;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001963}
1964
Tony Mak378c1f52019-03-04 15:58:11 +00001965bool Annotator::ModelAnnotate(
1966 const std::string& context,
1967 const std::vector<Locale>& detected_text_language_tags,
1968 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1969 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001970 if (model_->triggering_options() == nullptr ||
1971 !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1972 return true;
1973 }
1974
Tony Makdf54e742019-03-26 14:04:00 +00001975 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1976 ml_model_triggering_locales_,
1977 /*default_value=*/true)) {
1978 return true;
1979 }
1980
Lukas Zilka21d8c982018-01-24 11:11:20 +01001981 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1982 /*do_copy=*/false);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001983 std::vector<UnicodeTextRange> lines;
1984 if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1985 lines.push_back({context_unicode.begin(), context_unicode.end()});
1986 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +01001987 lines = selection_feature_processor_->SplitContext(
1988 context_unicode, selection_feature_processor_->GetOptions()
1989 ->use_pipe_character_for_newline());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001990 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001991
Lukas Zilkaba849e72018-03-08 14:48:21 +01001992 const float min_annotate_confidence =
1993 (model_->triggering_options() != nullptr
1994 ? model_->triggering_options()->min_annotate_confidence()
1995 : 0.f);
1996
Lukas Zilkab23e2122018-02-09 10:25:19 +01001997 for (const UnicodeTextRange& line : lines) {
Tony Mak408c6b82019-03-08 17:57:27 +00001998 FeatureProcessor::EmbeddingCache embedding_cache;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001999 const std::string line_str =
2000 UnicodeText::UTF8Substring(line.first, line.second);
2001
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002002 *tokens = selection_feature_processor_->Tokenize(line_str);
Lukas Zilkaba849e72018-03-08 14:48:21 +01002003 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002004 line_str, {0, std::distance(line.first, line.second)},
2005 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002006 tokens,
Lukas Zilka21d8c982018-01-24 11:11:20 +01002007 /*click_pos=*/nullptr);
Tony Maka44b3082020-08-13 18:57:10 +01002008 const TokenSpan full_line_span = {0,
2009 static_cast<TokenIndex>(tokens->size())};
Lukas Zilka21d8c982018-01-24 11:11:20 +01002010
Lukas Zilka434442d2018-04-25 11:38:51 +02002011 // TODO(zilka): Add support for greater granularity of this check.
2012 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
2013 *tokens, full_line_span)) {
2014 continue;
2015 }
2016
Lukas Zilka21d8c982018-01-24 11:11:20 +01002017 std::unique_ptr<CachedFeatures> cached_features;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002018 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002019 *tokens, full_line_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002020 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
2021 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01002022 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002023 selection_feature_processor_->EmbeddingSize() +
2024 selection_feature_processor_->DenseFeaturesCount(),
Lukas Zilka21d8c982018-01-24 11:11:20 +01002025 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002026 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002027 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002028 }
2029
2030 std::vector<TokenSpan> local_chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002031 if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002032 interpreter_manager->SelectionInterpreter(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002033 *cached_features, &local_chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002034 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002035 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002036 }
2037
2038 const int offset = std::distance(context_unicode.begin(), line.first);
2039 for (const TokenSpan& chunk : local_chunks) {
Tony Maka44b3082020-08-13 18:57:10 +01002040 CodepointSpan codepoint_span =
Lukas Zilka21d8c982018-01-24 11:11:20 +01002041 selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002042 line_str, TokenSpanToCodepointSpan(*tokens, chunk));
Tony Maka44b3082020-08-13 18:57:10 +01002043 if (model_->selection_options()->strip_unpaired_brackets()) {
2044 codepoint_span =
2045 StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_);
2046 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002047
2048 // Skip empty spans.
2049 if (codepoint_span.first != codepoint_span.second) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002050 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +00002051 if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
2052 codepoint_span, interpreter_manager,
2053 &embedding_cache, &classification)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002054 TC3_LOG(ERROR) << "Could not classify text: "
2055 << (codepoint_span.first + offset) << " "
2056 << (codepoint_span.second + offset);
Lukas Zilkaba849e72018-03-08 14:48:21 +01002057 return false;
2058 }
2059
2060 // Do not include the span if it's classified as "other".
2061 if (!classification.empty() && !ClassifiedAsOther(classification) &&
2062 classification[0].score >= min_annotate_confidence) {
2063 AnnotatedSpan result_span;
2064 result_span.span = {codepoint_span.first + offset,
2065 codepoint_span.second + offset};
2066 result_span.classification = std::move(classification);
2067 result->push_back(std::move(result_span));
2068 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002069 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01002070 }
2071 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002072 return true;
2073}
2074
Tony Mak6c4cc672018-09-17 11:48:50 +01002075const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
Lukas Zilka434442d2018-04-25 11:38:51 +02002076 return selection_feature_processor_.get();
2077}
2078
Tony Mak6c4cc672018-09-17 11:48:50 +01002079const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
Lukas Zilka434442d2018-04-25 11:38:51 +02002080 const {
2081 return classification_feature_processor_.get();
Lukas Zilkab23e2122018-02-09 10:25:19 +01002082}
2083
Tony Mak6c4cc672018-09-17 11:48:50 +01002084const DatetimeParser* Annotator::DatetimeParserForTests() const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002085 return datetime_parser_.get();
2086}
2087
Tony Mak83d2de62019-04-10 16:12:15 +01002088void Annotator::RemoveNotEnabledEntityTypes(
2089 const EnabledEntityTypes& is_entity_type_enabled,
2090 std::vector<AnnotatedSpan>* annotated_spans) const {
2091 for (AnnotatedSpan& annotated_span : *annotated_spans) {
2092 std::vector<ClassificationResult>& classifications =
2093 annotated_span.classification;
2094 classifications.erase(
2095 std::remove_if(classifications.begin(), classifications.end(),
2096 [&is_entity_type_enabled](
2097 const ClassificationResult& classification_result) {
2098 return !is_entity_type_enabled(
2099 classification_result.collection);
2100 }),
2101 classifications.end());
2102 }
2103 annotated_spans->erase(
2104 std::remove_if(annotated_spans->begin(), annotated_spans->end(),
2105 [](const AnnotatedSpan& annotated_span) {
2106 return annotated_span.classification.empty();
2107 }),
2108 annotated_spans->end());
2109}
2110
Tony Maka2a1ff42019-09-12 15:40:32 +01002111void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2112 std::vector<AnnotatedSpan>* candidates) const {
2113 if (candidates == nullptr || contact_engine_ == nullptr) {
2114 return;
2115 }
2116 for (auto& candidate : *candidates) {
2117 for (auto& classification_result : candidate.classification) {
2118 contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2119 &classification_result);
2120 }
2121 }
2122}
2123
Tony Makff31efb2020-03-31 11:13:06 +01002124Status Annotator::AnnotateSingleInput(
2125 const std::string& context, const AnnotationOptions& options,
2126 std::vector<AnnotatedSpan>* candidates) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002127 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
Tony Makff31efb2020-03-31 11:13:06 +01002128 return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
Lukas Zilkaba849e72018-03-08 14:48:21 +01002129 }
2130
Tony Mak854015a2019-01-16 15:56:48 +00002131 const UnicodeText context_unicode =
2132 UTF8ToUnicodeText(context, /*do_copy=*/false);
2133 if (!context_unicode.is_valid()) {
Tony Makff31efb2020-03-31 11:13:06 +01002134 return Status(StatusCode::INVALID_ARGUMENT,
2135 "Context string isn't valid UTF8.");
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002136 }
2137
Tony Mak378c1f52019-03-04 15:58:11 +00002138 std::vector<Locale> detected_text_language_tags;
2139 if (!ParseLocales(options.detected_text_language_tags,
2140 &detected_text_language_tags)) {
2141 TC3_LOG(WARNING)
Tony Makdf54e742019-03-26 14:04:00 +00002142 << "Failed to parse the detected_text_language_tags in options: "
Tony Mak378c1f52019-03-04 15:58:11 +00002143 << options.detected_text_language_tags;
2144 }
Tony Makdf54e742019-03-26 14:04:00 +00002145 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2146 model_triggering_locales_,
2147 /*default_value=*/true)) {
Tony Makff31efb2020-03-31 11:13:06 +01002148 return Status(
2149 StatusCode::UNAVAILABLE,
2150 "The detected language tags are not in the supported locales.");
Tony Makdf54e742019-03-26 14:04:00 +00002151 }
2152
2153 InterpreterManager interpreter_manager(selection_executor_.get(),
2154 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00002155
Lukas Zilkab23e2122018-02-09 10:25:19 +01002156 // Annotate with the selection model.
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002157 std::vector<Token> tokens;
Tony Mak378c1f52019-03-04 15:58:11 +00002158 if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
Tony Makff31efb2020-03-31 11:13:06 +01002159 &tokens, candidates)) {
2160 return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002161 }
2162
Tony Maka44b3082020-08-13 18:57:10 +01002163 const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002164 // Annotate with the regular expression models.
Tony Maka44b3082020-08-13 18:57:10 +01002165 if (!RegexChunk(
2166 UTF8ToUnicodeText(context, /*do_copy=*/false),
2167 annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
2168 is_entity_type_enabled, options.annotation_usecase, candidates)) {
Tony Makff31efb2020-03-31 11:13:06 +01002169 return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002170 }
2171
2172 // Annotate with the datetime model.
Tony Mak83d2de62019-04-10 16:12:15 +01002173 if ((is_entity_type_enabled(Collections::Date()) ||
2174 is_entity_type_enabled(Collections::DateTime())) &&
2175 !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002176 options.reference_time_ms_utc, options.reference_timezone,
Tony Makd9446602019-02-20 18:25:39 +00002177 options.locales, ModeFlag_ANNOTATION,
Tony Mak83d2de62019-04-10 16:12:15 +01002178 options.annotation_usecase,
Tony Makff31efb2020-03-31 11:13:06 +01002179 options.is_serialized_entity_data_enabled, candidates)) {
2180 return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
Tony Mak6c4cc672018-09-17 11:48:50 +01002181 }
2182
Tony Mak854015a2019-01-16 15:56:48 +00002183 // Annotate with the contact engine.
2184 if (contact_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002185 !contact_engine_->Chunk(context_unicode, tokens, candidates)) {
2186 return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
Tony Mak854015a2019-01-16 15:56:48 +00002187 }
2188
Tony Makd9446602019-02-20 18:25:39 +00002189 // Annotate with the installed app engine.
2190 if (installed_app_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002191 !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
2192 return Status(StatusCode::INTERNAL,
2193 "Couldn't run installed app engine Chunk.");
Tony Makd9446602019-02-20 18:25:39 +00002194 }
2195
Tony Mak378c1f52019-03-04 15:58:11 +00002196 // Annotate with the number annotator.
Tony Maka44b3082020-08-13 18:57:10 +01002197 bool number_annotations_enabled = true;
2198 // Disable running the annotator in RAW mode if the number/percentage
2199 // annotations are not explicitly requested.
2200 if (options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW &&
2201 !is_entity_type_enabled(Collections::Number()) &&
2202 !is_entity_type_enabled(Collections::Percentage())) {
2203 number_annotations_enabled = false;
2204 }
2205 if (number_annotations_enabled && number_annotator_ != nullptr &&
Tony Mak378c1f52019-03-04 15:58:11 +00002206 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
Tony Makff31efb2020-03-31 11:13:06 +01002207 candidates)) {
2208 return Status(StatusCode::INTERNAL,
2209 "Couldn't run number annotator FindAll.");
Tony Makad2e22d2019-03-20 17:35:13 +00002210 }
2211
2212 // Annotate with the duration annotator.
Tony Mak83d2de62019-04-10 16:12:15 +01002213 if (is_entity_type_enabled(Collections::Duration()) &&
2214 duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +00002215 !duration_annotator_->FindAll(context_unicode, tokens,
Tony Makff31efb2020-03-31 11:13:06 +01002216 options.annotation_usecase, candidates)) {
2217 return Status(StatusCode::INTERNAL,
2218 "Couldn't run duration annotator FindAll.");
Tony Mak378c1f52019-03-04 15:58:11 +00002219 }
2220
Tony Mak76d80962020-01-08 17:30:51 +00002221 // Annotate with the person name engine.
2222 if (is_entity_type_enabled(Collections::PersonName()) &&
2223 person_name_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002224 !person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
2225 return Status(StatusCode::INTERNAL,
2226 "Couldn't run person name engine Chunk.");
Tony Mak76d80962020-01-08 17:30:51 +00002227 }
2228
Tony Mak21460022020-03-12 18:29:35 +00002229 // Annotate with the grammar annotators.
2230 if (grammar_annotator_ != nullptr &&
2231 !grammar_annotator_->Annotate(detected_text_language_tags,
Tony Makff31efb2020-03-31 11:13:06 +01002232 context_unicode, candidates)) {
2233 return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
Tony Mak21460022020-03-12 18:29:35 +00002234 }
2235
Tony Maka44b3082020-08-13 18:57:10 +01002236 // Annotate with the POD NER annotator.
2237 if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
2238 !pod_ner_annotator_->Annotate(context_unicode, candidates)) {
2239 return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
2240 }
2241
Tony Maka5090082020-09-18 16:41:23 +01002242 // Annotate with the vocab annotator.
2243 if (vocab_annotator_ != nullptr &&
2244 !vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
2245 options.trigger_dictionary_on_beginner_words,
2246 candidates)) {
2247 return Status(StatusCode::INTERNAL, "Couldn't run vocab annotator.");
2248 }
2249
Tony Maka44b3082020-08-13 18:57:10 +01002250 // Annotate with the experimental annotator.
Tony Mak5a12b942020-05-01 12:41:31 +01002251 if (experimental_annotator_ != nullptr &&
2252 !experimental_annotator_->Annotate(context_unicode, candidates)) {
2253 return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
2254 }
2255
Lukas Zilkab23e2122018-02-09 10:25:19 +01002256 // Sort candidates according to their position in the input, so that the next
2257 // code can assume that any connected component of overlapping spans forms a
2258 // contiguous block.
Tony Mak5a12b942020-05-01 12:41:31 +01002259 // Also sort them according to the end position and collection, so that the
2260 // deduplication code below can assume that same spans and classifications
2261 // form contiguous blocks.
Tony Makff31efb2020-03-31 11:13:06 +01002262 std::sort(candidates->begin(), candidates->end(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002263 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
Tony Mak5a12b942020-05-01 12:41:31 +01002264 if (a.span.first != b.span.first) {
2265 return a.span.first < b.span.first;
2266 }
2267
2268 if (a.span.second != b.span.second) {
2269 return a.span.second < b.span.second;
2270 }
2271
2272 return a.classification[0].collection <
2273 b.classification[0].collection;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002274 });
2275
2276 std::vector<int> candidate_indices;
Tony Makff31efb2020-03-31 11:13:06 +01002277 if (!ResolveConflicts(*candidates, context, tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00002278 detected_text_language_tags, options.annotation_usecase,
2279 &interpreter_manager, &candidate_indices)) {
Tony Makff31efb2020-03-31 11:13:06 +01002280 return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002281 }
2282
Tony Mak5a12b942020-05-01 12:41:31 +01002283 // Remove candidates that overlap exactly and have the same collection.
2284 // This can e.g. happen for phone coming from both ML model and regex.
2285 candidate_indices.erase(
2286 std::unique(candidate_indices.begin(), candidate_indices.end(),
2287 [&candidates](const int a_index, const int b_index) {
2288 const AnnotatedSpan& a = (*candidates)[a_index];
2289 const AnnotatedSpan& b = (*candidates)[b_index];
2290 return a.span == b.span &&
2291 a.classification[0].collection ==
2292 b.classification[0].collection;
2293 }),
2294 candidate_indices.end());
2295
Lukas Zilkab23e2122018-02-09 10:25:19 +01002296 std::vector<AnnotatedSpan> result;
2297 result.reserve(candidate_indices.size());
2298 for (const int i : candidate_indices) {
Tony Makff31efb2020-03-31 11:13:06 +01002299 if ((*candidates)[i].classification.empty() ||
2300 ClassifiedAsOther((*candidates)[i].classification) ||
2301 FilteredForAnnotation((*candidates)[i])) {
Tony Mak378c1f52019-03-04 15:58:11 +00002302 continue;
2303 }
Tony Mak5a12b942020-05-01 12:41:31 +01002304 result.push_back(std::move((*candidates)[i]));
Tony Mak378c1f52019-03-04 15:58:11 +00002305 }
2306
Tony Mak83d2de62019-04-10 16:12:15 +01002307 // We generate all candidates and remove them later (with the exception of
2308 // date/time/duration entities) because there are complex interdependencies
2309 // between the entity types. E.g., the TLD of an email can be interpreted as a
2310 // URL, but most likely a user of the API does not want such annotations if
2311 // "url" is enabled and "email" is not.
2312 RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2313
Tony Mak378c1f52019-03-04 15:58:11 +00002314 for (AnnotatedSpan& annotated_span : result) {
2315 SortClassificationResults(&annotated_span.classification);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002316 }
Tony Makff31efb2020-03-31 11:13:06 +01002317 *candidates = result;
2318 return Status::OK;
2319}
Lukas Zilkab23e2122018-02-09 10:25:19 +01002320
Tony Maka44b3082020-08-13 18:57:10 +01002321StatusOr<Annotations> Annotator::AnnotateStructuredInput(
Tony Makff31efb2020-03-31 11:13:06 +01002322 const std::vector<InputFragment>& string_fragments,
2323 const AnnotationOptions& options) const {
Tony Maka44b3082020-08-13 18:57:10 +01002324 Annotations annotation_candidates;
2325 annotation_candidates.annotated_spans.resize(string_fragments.size());
Tony Makff31efb2020-03-31 11:13:06 +01002326
2327 std::vector<std::string> text_to_annotate;
2328 text_to_annotate.reserve(string_fragments.size());
2329 for (const auto& string_fragment : string_fragments) {
2330 text_to_annotate.push_back(string_fragment.text);
2331 }
2332
2333 // KnowledgeEngine is special, because it supports annotation of multiple
2334 // fragments at once.
2335 if (knowledge_engine_ &&
2336 !knowledge_engine_
2337 ->ChunkMultipleSpans(text_to_annotate, options.annotation_usecase,
Tony Mak90d55672020-04-15 18:20:44 +01002338 options.location_context, options.permissions,
Tony Maka44b3082020-08-13 18:57:10 +01002339 options.annotate_mode, &annotation_candidates)
Tony Makff31efb2020-03-31 11:13:06 +01002340 .ok()) {
2341 return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
2342 }
2343 // The annotator engines shouldn't change the number of annotation vectors.
Tony Maka44b3082020-08-13 18:57:10 +01002344 if (annotation_candidates.annotated_spans.size() != text_to_annotate.size()) {
Tony Makff31efb2020-03-31 11:13:06 +01002345 TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
2346 << " texts to annotate but generated a different number of "
2347 "lists of annotations:"
Tony Maka44b3082020-08-13 18:57:10 +01002348 << annotation_candidates.annotated_spans.size();
Tony Makff31efb2020-03-31 11:13:06 +01002349 return Status(StatusCode::INTERNAL,
2350 "Number of annotation candidates differs from "
2351 "number of texts to annotate.");
2352 }
2353
Tony Maka44b3082020-08-13 18:57:10 +01002354 // As an optimization, if the only annotated type is Entity, we skip all the
2355 // other annotators than the KnowledgeEngine. This only happens in the raw
2356 // mode, to make sure it does not affect the result.
2357 if (options.annotation_usecase == ANNOTATION_USECASE_RAW &&
2358 options.entity_types.size() == 1 &&
2359 *options.entity_types.begin() == Collections::Entity()) {
2360 return annotation_candidates;
2361 }
2362
Tony Makff31efb2020-03-31 11:13:06 +01002363 // Other annotators run on each fragment independently.
2364 for (int i = 0; i < text_to_annotate.size(); ++i) {
2365 AnnotationOptions annotation_options = options;
2366 if (string_fragments[i].datetime_options.has_value()) {
2367 DatetimeOptions reference_datetime =
2368 string_fragments[i].datetime_options.value();
2369 annotation_options.reference_time_ms_utc =
2370 reference_datetime.reference_time_ms_utc;
2371 annotation_options.reference_timezone =
2372 reference_datetime.reference_timezone;
2373 }
2374
2375 AddContactMetadataToKnowledgeClassificationResults(
Tony Maka44b3082020-08-13 18:57:10 +01002376 &annotation_candidates.annotated_spans[i]);
Tony Makff31efb2020-03-31 11:13:06 +01002377
Tony Maka44b3082020-08-13 18:57:10 +01002378 Status annotation_status =
2379 AnnotateSingleInput(text_to_annotate[i], annotation_options,
2380 &annotation_candidates.annotated_spans[i]);
Tony Makff31efb2020-03-31 11:13:06 +01002381 if (!annotation_status.ok()) {
2382 return annotation_status;
2383 }
2384 }
2385 return annotation_candidates;
2386}
2387
2388std::vector<AnnotatedSpan> Annotator::Annotate(
2389 const std::string& context, const AnnotationOptions& options) const {
2390 std::vector<InputFragment> string_fragments;
2391 string_fragments.push_back({.text = context});
Tony Maka44b3082020-08-13 18:57:10 +01002392 StatusOr<Annotations> annotations =
Tony Makff31efb2020-03-31 11:13:06 +01002393 AnnotateStructuredInput(string_fragments, options);
2394 if (!annotations.ok()) {
2395 TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
2396 << annotations.status().error_message();
2397 return {};
2398 }
Tony Maka44b3082020-08-13 18:57:10 +01002399 return annotations.ValueOrDie().annotated_spans[0];
Lukas Zilka21d8c982018-01-24 11:11:20 +01002400}
2401
Tony Mak854015a2019-01-16 15:56:48 +00002402CodepointSpan Annotator::ComputeSelectionBoundaries(
2403 const UniLib::RegexMatcher* match,
2404 const RegexModel_::Pattern* config) const {
2405 if (config->capturing_group() == nullptr) {
2406 // Use first capturing group to specify the selection.
2407 int status = UniLib::RegexMatcher::kNoError;
2408 const CodepointSpan result = {match->Start(1, &status),
2409 match->End(1, &status)};
2410 if (status != UniLib::RegexMatcher::kNoError) {
2411 return {kInvalidIndex, kInvalidIndex};
2412 }
2413 return result;
2414 }
2415
2416 CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2417 const int num_groups = config->capturing_group()->size();
2418 for (int i = 0; i < num_groups; i++) {
2419 if (!config->capturing_group()->Get(i)->extend_selection()) {
2420 continue;
2421 }
2422
2423 int status = UniLib::RegexMatcher::kNoError;
2424 // Check match and adjust bounds.
2425 const int group_start = match->Start(i, &status);
2426 const int group_end = match->End(i, &status);
2427 if (status != UniLib::RegexMatcher::kNoError) {
2428 return {kInvalidIndex, kInvalidIndex};
2429 }
2430 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2431 continue;
2432 }
2433 if (result.first == kInvalidIndex) {
2434 result = {group_start, group_end};
2435 } else {
2436 result.first = std::min(result.first, group_start);
2437 result.second = std::max(result.second, group_end);
2438 }
2439 }
2440 return result;
2441}
2442
Tony Makd9446602019-02-20 18:25:39 +00002443bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
Tony Mak21460022020-03-12 18:29:35 +00002444 if (pattern->serialized_entity_data() != nullptr ||
2445 pattern->entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002446 return true;
2447 }
2448 if (pattern->capturing_group() != nullptr) {
Tony Mak63959242020-02-07 18:31:16 +00002449 for (const CapturingGroup* group : *pattern->capturing_group()) {
Tony Makd9446602019-02-20 18:25:39 +00002450 if (group->entity_field_path() != nullptr) {
2451 return true;
2452 }
Tony Mak21460022020-03-12 18:29:35 +00002453 if (group->serialized_entity_data() != nullptr ||
2454 group->entity_data() != nullptr) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002455 return true;
2456 }
Tony Makd9446602019-02-20 18:25:39 +00002457 }
2458 }
2459 return false;
2460}
2461
2462bool Annotator::SerializedEntityDataFromRegexMatch(
2463 const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2464 std::string* serialized_entity_data) const {
2465 if (!HasEntityData(pattern)) {
2466 serialized_entity_data->clear();
2467 return true;
2468 }
2469 TC3_CHECK(entity_data_builder_ != nullptr);
2470
Tony Maka44b3082020-08-13 18:57:10 +01002471 std::unique_ptr<MutableFlatbuffer> entity_data =
Tony Makd9446602019-02-20 18:25:39 +00002472 entity_data_builder_->NewRoot();
2473
2474 TC3_CHECK(entity_data != nullptr);
2475
Tony Mak21460022020-03-12 18:29:35 +00002476 // Set fixed entity data.
Tony Makd9446602019-02-20 18:25:39 +00002477 if (pattern->serialized_entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002478 entity_data->MergeFromSerializedFlatbuffer(
2479 StringPiece(pattern->serialized_entity_data()->c_str(),
2480 pattern->serialized_entity_data()->size()));
2481 }
Tony Mak21460022020-03-12 18:29:35 +00002482 if (pattern->entity_data() != nullptr) {
2483 entity_data->MergeFrom(
2484 reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2485 }
Tony Makd9446602019-02-20 18:25:39 +00002486
2487 // Add entity data from rule capturing groups.
2488 if (pattern->capturing_group() != nullptr) {
2489 const int num_groups = pattern->capturing_group()->size();
2490 for (int i = 0; i < num_groups; i++) {
Tony Mak63959242020-02-07 18:31:16 +00002491 const CapturingGroup* group = pattern->capturing_group()->Get(i);
Tony Maka2a1ff42019-09-12 15:40:32 +01002492
2493 // Check whether the group matched.
2494 Optional<std::string> group_match_text =
2495 GetCapturingGroupText(matcher, /*group_id=*/i);
2496 if (!group_match_text.has_value()) {
Tony Makd9446602019-02-20 18:25:39 +00002497 continue;
2498 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002499
Tony Mak21460022020-03-12 18:29:35 +00002500 // Set fixed entity data from capturing group match.
Tony Maka2a1ff42019-09-12 15:40:32 +01002501 if (group->serialized_entity_data() != nullptr) {
2502 entity_data->MergeFromSerializedFlatbuffer(
2503 StringPiece(group->serialized_entity_data()->c_str(),
2504 group->serialized_entity_data()->size()));
2505 }
Tony Mak21460022020-03-12 18:29:35 +00002506 if (group->entity_data() != nullptr) {
2507 entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2508 pattern->entity_data()));
2509 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002510
2511 // Set entity field from capturing group text.
2512 if (group->entity_field_path() != nullptr) {
Tony Mak8cd7ba62019-10-15 15:29:22 +01002513 UnicodeText normalized_group_match_text =
2514 UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2515
2516 // Apply normalization if specified.
2517 if (group->normalization_options() != nullptr) {
2518 normalized_group_match_text =
Tony Mak1ac2e4a2020-04-29 13:41:53 +01002519 NormalizeText(*unilib_, group->normalization_options(),
Tony Mak8cd7ba62019-10-15 15:29:22 +01002520 normalized_group_match_text);
2521 }
2522
2523 if (!entity_data->ParseAndSet(
2524 group->entity_field_path(),
2525 normalized_group_match_text.ToUTF8String())) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002526 TC3_LOG(ERROR)
2527 << "Could not set entity data from rule capturing group.";
2528 return false;
2529 }
Tony Makd9446602019-02-20 18:25:39 +00002530 }
2531 }
2532 }
2533
2534 *serialized_entity_data = entity_data->Serialize();
2535 return true;
2536}
2537
Tony Mak63959242020-02-07 18:31:16 +00002538UnicodeText RemoveMoneySeparators(
2539 const std::unordered_set<char32>& decimal_separators,
2540 const UnicodeText& amount,
2541 UnicodeText::const_iterator it_decimal_separator) {
2542 UnicodeText whole_amount;
2543 for (auto it = amount.begin();
2544 it != amount.end() && it != it_decimal_separator; ++it) {
2545 if (std::find(decimal_separators.begin(), decimal_separators.end(),
2546 static_cast<char32>(*it)) == decimal_separators.end()) {
2547 whole_amount.push_back(*it);
2548 }
2549 }
2550 return whole_amount;
2551}
2552
Tony Maka44b3082020-08-13 18:57:10 +01002553void Annotator::GetMoneyQuantityFromCapturingGroup(
2554 const UniLib::RegexMatcher* match, const RegexModel_::Pattern* config,
2555 const UnicodeText& context_unicode, std::string* quantity,
2556 int* exponent) const {
2557 if (config->capturing_group() == nullptr) {
2558 *exponent = 0;
2559 return;
2560 }
2561
2562 const int num_groups = config->capturing_group()->size();
2563 for (int i = 0; i < num_groups; i++) {
2564 int status = UniLib::RegexMatcher::kNoError;
2565 const int group_start = match->Start(i, &status);
2566 const int group_end = match->End(i, &status);
2567 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2568 continue;
2569 }
2570
2571 *quantity =
2572 unilib_
2573 ->ToLowerText(UnicodeText::Substring(context_unicode, group_start,
2574 group_end, /*do_copy=*/false))
2575 .ToUTF8String();
2576
2577 if (auto entry = model_->money_parsing_options()
2578 ->quantities_name_to_exponent()
2579 ->LookupByKey((*quantity).c_str())) {
2580 *exponent = entry->value();
2581 return;
2582 }
2583 }
2584 *exponent = 0;
2585}
2586
Tony Mak63959242020-02-07 18:31:16 +00002587bool Annotator::ParseAndFillInMoneyAmount(
Tony Maka44b3082020-08-13 18:57:10 +01002588 std::string* serialized_entity_data, const UniLib::RegexMatcher* match,
2589 const RegexModel_::Pattern* config,
2590 const UnicodeText& context_unicode) const {
Tony Mak63959242020-02-07 18:31:16 +00002591 std::unique_ptr<EntityDataT> data =
2592 LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2593 *serialized_entity_data);
Tony Mak0b8b3322020-03-17 16:30:19 +00002594 if (data == nullptr) {
Tony Makc121edd2020-05-28 15:25:17 +01002595 if (model_->version() >= 706) {
2596 // This way of parsing money entity data is enabled for models newer than
2597 // v706, consequently logging errors only for them (b/156634162).
2598 TC3_LOG(ERROR)
2599 << "Data field is null when trying to parse Money Entity Data";
2600 }
Tony Mak0b8b3322020-03-17 16:30:19 +00002601 return false;
2602 }
2603 if (data->money->unnormalized_amount.empty()) {
Tony Makc121edd2020-05-28 15:25:17 +01002604 if (model_->version() >= 706) {
2605 // This way of parsing money entity data is enabled for models newer than
2606 // v706, consequently logging errors only for them (b/156634162).
2607 TC3_LOG(ERROR)
2608 << "Data unnormalized_amount is empty when trying to parse "
2609 "Money Entity Data";
2610 }
Tony Mak63959242020-02-07 18:31:16 +00002611 return false;
2612 }
2613
2614 UnicodeText amount =
2615 UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2616 int separator_back_index = 0;
Tony Mak21460022020-03-12 18:29:35 +00002617 auto it_decimal_separator = --amount.end();
Tony Mak63959242020-02-07 18:31:16 +00002618 for (; it_decimal_separator != amount.begin();
2619 --it_decimal_separator, ++separator_back_index) {
2620 if (std::find(money_separators_.begin(), money_separators_.end(),
2621 static_cast<char32>(*it_decimal_separator)) !=
2622 money_separators_.end()) {
2623 break;
2624 }
2625 }
2626
2627 // If there are 3 digits after the last separator, we consider that a
2628 // thousands separator => the number is an int (e.g. 1.234 is considered int).
2629 // If there is no separator in number, also that number is an int.
Tony Mak21460022020-03-12 18:29:35 +00002630 if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
Tony Mak63959242020-02-07 18:31:16 +00002631 it_decimal_separator = amount.end();
2632 }
2633
2634 if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2635 it_decimal_separator),
2636 &data->money->amount_whole_part)) {
Tony Mak0b8b3322020-03-17 16:30:19 +00002637 TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
2638 "amount: "
2639 << data->money->unnormalized_amount;
Tony Mak63959242020-02-07 18:31:16 +00002640 return false;
2641 }
Tony Maka44b3082020-08-13 18:57:10 +01002642
Tony Mak63959242020-02-07 18:31:16 +00002643 if (it_decimal_separator == amount.end()) {
2644 data->money->amount_decimal_part = 0;
Tony Maka44b3082020-08-13 18:57:10 +01002645 data->money->nanos = 0;
Tony Mak63959242020-02-07 18:31:16 +00002646 } else {
2647 const int amount_codepoints_size = amount.size_codepoints();
Tony Maka44b3082020-08-13 18:57:10 +01002648 const UnicodeText decimal_part = UnicodeText::Substring(
2649 amount, amount_codepoints_size - separator_back_index,
2650 amount_codepoints_size, /*do_copy=*/false);
2651 if (!unilib_->ParseInt32(decimal_part, &data->money->amount_decimal_part)) {
Tony Mak0b8b3322020-03-17 16:30:19 +00002652 TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
2653 "the amount: "
2654 << data->money->unnormalized_amount;
Tony Mak63959242020-02-07 18:31:16 +00002655 return false;
2656 }
Tony Maka44b3082020-08-13 18:57:10 +01002657 data->money->nanos = data->money->amount_decimal_part *
2658 pow(10, 9 - decimal_part.size_codepoints());
2659 }
2660
2661 if (model_->money_parsing_options()->quantities_name_to_exponent() !=
2662 nullptr) {
2663 int quantity_exponent;
2664 std::string quantity;
2665 GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
2666 &quantity, &quantity_exponent);
Tony Mak074ee382020-09-30 19:11:00 +01002667 if ((quantity_exponent > 0 && quantity_exponent < 9) ||
2668 (quantity_exponent == 9 && data->money->amount_whole_part <= 2)) {
Tony Maka44b3082020-08-13 18:57:10 +01002669 data->money->amount_whole_part =
2670 data->money->amount_whole_part * pow(10, quantity_exponent) +
2671 data->money->nanos / pow(10, 9 - quantity_exponent);
2672 data->money->nanos = data->money->nanos %
2673 static_cast<int>(pow(10, 9 - quantity_exponent)) *
2674 pow(10, quantity_exponent);
Tony Mak074ee382020-09-30 19:11:00 +01002675 }
2676 if (quantity_exponent > 0) {
Tony Maka44b3082020-08-13 18:57:10 +01002677 data->money->unnormalized_amount = strings::JoinStrings(
2678 " ", {data->money->unnormalized_amount, quantity});
2679 }
Tony Mak63959242020-02-07 18:31:16 +00002680 }
2681
2682 *serialized_entity_data =
2683 PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2684 return true;
2685}
2686
Tony Mak6c4cc672018-09-17 11:48:50 +01002687bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2688 const std::vector<int>& rules,
Tony Maka44b3082020-08-13 18:57:10 +01002689 bool is_serialized_entity_data_enabled,
2690 const EnabledEntityTypes& enabled_entity_types,
2691 const AnnotationUsecase& annotation_usecase,
2692 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002693 for (int pattern_id : rules) {
2694 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
Tony Maka44b3082020-08-13 18:57:10 +01002695 if (!enabled_entity_types(regex_pattern.config->collection_name()->str()) &&
2696 annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW) {
2697 // No regex annotation type has been requested, skip regex annotation.
2698 continue;
2699 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002700 const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2701 if (!matcher) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002702 TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2703 << pattern_id;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002704 return false;
2705 }
2706
2707 int status = UniLib::RegexMatcher::kNoError;
2708 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
Tony Mak854015a2019-01-16 15:56:48 +00002709 if (regex_pattern.config->verification_options()) {
Tony Makdf54e742019-03-26 14:04:00 +00002710 if (!VerifyRegexMatchCandidate(
2711 context_unicode.ToUTF8String(),
2712 regex_pattern.config->verification_options(),
2713 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002714 continue;
2715 }
2716 }
Tony Makd9446602019-02-20 18:25:39 +00002717
2718 std::string serialized_entity_data;
Tony Mak83d2de62019-04-10 16:12:15 +01002719 if (is_serialized_entity_data_enabled) {
2720 if (!SerializedEntityDataFromRegexMatch(
2721 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2722 TC3_LOG(ERROR) << "Could not get entity data.";
2723 return false;
2724 }
Tony Mak63959242020-02-07 18:31:16 +00002725
Tony Maka44b3082020-08-13 18:57:10 +01002726 // Further parsing of money amount. Need this since regexes cannot have
2727 // empty groups that fill in entity data (amount_decimal_part and
2728 // quantity might be empty groups).
Tony Mak63959242020-02-07 18:31:16 +00002729 if (regex_pattern.config->collection_name()->str() ==
2730 Collections::Money()) {
Tony Maka44b3082020-08-13 18:57:10 +01002731 if (!ParseAndFillInMoneyAmount(&serialized_entity_data, matcher.get(),
2732 regex_pattern.config,
2733 context_unicode)) {
Tony Makc121edd2020-05-28 15:25:17 +01002734 if (model_->version() >= 706) {
2735 // This way of parsing money entity data is enabled for models
2736 // newer than v706 => logging errors only for them (b/156634162).
2737 TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
2738 }
Tony Mak63959242020-02-07 18:31:16 +00002739 }
2740 }
Tony Makd9446602019-02-20 18:25:39 +00002741 }
2742
Lukas Zilkab23e2122018-02-09 10:25:19 +01002743 result->emplace_back();
Tony Mak854015a2019-01-16 15:56:48 +00002744
Lukas Zilkab23e2122018-02-09 10:25:19 +01002745 // Selection/annotation regular expressions need to specify a capturing
2746 // group specifying the selection.
Tony Mak854015a2019-01-16 15:56:48 +00002747 result->back().span =
2748 ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2749
Lukas Zilkab23e2122018-02-09 10:25:19 +01002750 result->back().classification = {
Tony Mak854015a2019-01-16 15:56:48 +00002751 {regex_pattern.config->collection_name()->str(),
2752 regex_pattern.config->target_classification_score(),
2753 regex_pattern.config->priority_score()}};
Tony Makd9446602019-02-20 18:25:39 +00002754
2755 result->back().classification[0].serialized_entity_data =
2756 serialized_entity_data;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002757 }
2758 }
2759 return true;
2760}
2761
Tony Mak6c4cc672018-09-17 11:48:50 +01002762bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2763 tflite::Interpreter* selection_interpreter,
2764 const CachedFeatures& cached_features,
2765 std::vector<TokenSpan>* chunks) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01002766 const int max_selection_span =
2767 selection_feature_processor_->GetOptions()->max_selection_span();
Lukas Zilka21d8c982018-01-24 11:11:20 +01002768 // The inference span is the span of interest expanded to include
2769 // max_selection_span tokens on either side, which is how far a selection can
2770 // stretch from the click.
Tony Maka44b3082020-08-13 18:57:10 +01002771 const TokenSpan inference_span =
2772 IntersectTokenSpans(span_of_interest.Expand(
2773 /*num_tokens_left=*/max_selection_span,
2774 /*num_tokens_right=*/max_selection_span),
2775 {0, num_tokens});
Lukas Zilka21d8c982018-01-24 11:11:20 +01002776
2777 std::vector<ScoredChunk> scored_chunks;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002778 if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2779 selection_feature_processor_->GetOptions()
2780 ->bounds_sensitive_features()
2781 ->enabled()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002782 if (!ModelBoundsSensitiveScoreChunks(
2783 num_tokens, span_of_interest, inference_span, cached_features,
2784 selection_interpreter, &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002785 return false;
2786 }
2787 } else {
2788 if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002789 cached_features, selection_interpreter,
2790 &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002791 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002792 }
2793 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002794 std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
2795 [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2796 return lhs.score < rhs.score;
2797 });
Lukas Zilka21d8c982018-01-24 11:11:20 +01002798
2799 // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2800 // them greedily as long as they do not overlap with any previously picked
2801 // chunks.
Tony Maka44b3082020-08-13 18:57:10 +01002802 std::vector<bool> token_used(inference_span.Size());
Lukas Zilka21d8c982018-01-24 11:11:20 +01002803 chunks->clear();
2804 for (const ScoredChunk& scored_chunk : scored_chunks) {
2805 bool feasible = true;
2806 for (int i = scored_chunk.token_span.first;
2807 i < scored_chunk.token_span.second; ++i) {
2808 if (token_used[i - inference_span.first]) {
2809 feasible = false;
2810 break;
2811 }
2812 }
2813
2814 if (!feasible) {
2815 continue;
2816 }
2817
2818 for (int i = scored_chunk.token_span.first;
2819 i < scored_chunk.token_span.second; ++i) {
2820 token_used[i - inference_span.first] = true;
2821 }
2822
2823 chunks->push_back(scored_chunk.token_span);
2824 }
2825
2826 std::sort(chunks->begin(), chunks->end());
2827
2828 return true;
2829}
2830
namespace {
// Stores `value` under `key` in `map` unless the map already holds a larger
// value for that key: an absent key is inserted, a present key keeps the
// maximum of its current value and `value`.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto existing = map->find(key);
  if (existing == map->end()) {
    (*map)[key] = value;
    return;
  }
  existing->second = std::max(existing->second, value);
}
}  // namespace
2845
Tony Mak6c4cc672018-09-17 11:48:50 +01002846bool Annotator::ModelClickContextScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002847 int num_tokens, const TokenSpan& span_of_interest,
2848 const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002849 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002850 std::vector<ScoredChunk>* scored_chunks) const {
2851 const int max_batch_size = model_->selection_options()->batch_size();
2852
2853 std::vector<float> all_features;
2854 std::map<TokenSpan, float> chunk_scores;
2855 for (int batch_start = span_of_interest.first;
2856 batch_start < span_of_interest.second; batch_start += max_batch_size) {
2857 const int batch_end =
2858 std::min(batch_start + max_batch_size, span_of_interest.second);
2859
2860 // Prepare features for the whole batch.
2861 all_features.clear();
2862 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2863 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2864 cached_features.AppendClickContextFeaturesForClick(click_pos,
2865 &all_features);
2866 }
2867
2868 // Run batched inference.
2869 const int batch_size = batch_end - batch_start;
2870 const int features_size = cached_features.OutputFeaturesSize();
2871 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01002872 TensorView<float>(all_features.data(), {batch_size, features_size}),
2873 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002874 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002875 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002876 return false;
2877 }
2878 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2879 logits.dim(1) !=
2880 selection_feature_processor_->GetSelectionLabelCount()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002881 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002882 return false;
2883 }
2884
2885 // Save results.
2886 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2887 const std::vector<float> scores = ComputeSoftmax(
2888 logits.data() + logits.dim(1) * (click_pos - batch_start),
2889 logits.dim(1));
2890 for (int j = 0;
2891 j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
2892 TokenSpan relative_token_span;
2893 if (!selection_feature_processor_->LabelToTokenSpan(
2894 j, &relative_token_span)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002895 TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002896 return false;
2897 }
Tony Maka44b3082020-08-13 18:57:10 +01002898 const TokenSpan candidate_span = TokenSpan(click_pos).Expand(
2899 relative_token_span.first, relative_token_span.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002900 if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
2901 UpdateMax(&chunk_scores, candidate_span, scores[j]);
2902 }
2903 }
2904 }
2905 }
2906
2907 scored_chunks->clear();
2908 scored_chunks->reserve(chunk_scores.size());
2909 for (const auto& entry : chunk_scores) {
2910 scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
2911 }
2912
2913 return true;
2914}
2915
Tony Mak6c4cc672018-09-17 11:48:50 +01002916bool Annotator::ModelBoundsSensitiveScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002917 int num_tokens, const TokenSpan& span_of_interest,
2918 const TokenSpan& inference_span, const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002919 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002920 std::vector<ScoredChunk>* scored_chunks) const {
2921 const int max_selection_span =
2922 selection_feature_processor_->GetOptions()->max_selection_span();
2923 const int max_chunk_length = selection_feature_processor_->GetOptions()
2924 ->selection_reduced_output_space()
2925 ? max_selection_span + 1
2926 : 2 * max_selection_span + 1;
Lukas Zilkaba849e72018-03-08 14:48:21 +01002927 const bool score_single_token_spans_as_zero =
2928 selection_feature_processor_->GetOptions()
2929 ->bounds_sensitive_features()
2930 ->score_single_token_spans_as_zero();
2931
2932 scored_chunks->clear();
2933 if (score_single_token_spans_as_zero) {
Tony Maka44b3082020-08-13 18:57:10 +01002934 scored_chunks->reserve(span_of_interest.Size());
Lukas Zilkaba849e72018-03-08 14:48:21 +01002935 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002936
2937 // Prepare all chunk candidates into one batch:
2938 // - Are contained in the inference span
2939 // - Have a non-empty intersection with the span of interest
2940 // - Are at least one token long
2941 // - Are not longer than the maximum chunk length
2942 std::vector<TokenSpan> candidate_spans;
2943 for (int start = inference_span.first; start < span_of_interest.second;
2944 ++start) {
2945 const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
2946 for (int end = leftmost_end_index;
2947 end <= inference_span.second && end - start <= max_chunk_length;
2948 ++end) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002949 const TokenSpan candidate_span = {start, end};
Tony Maka44b3082020-08-13 18:57:10 +01002950 if (score_single_token_spans_as_zero && candidate_span.Size() == 1) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002951 // Do not include the single token span in the batch, add a zero score
2952 // for it directly to the output.
2953 scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
2954 } else {
2955 candidate_spans.push_back(candidate_span);
2956 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002957 }
2958 }
2959
2960 const int max_batch_size = model_->selection_options()->batch_size();
2961
2962 std::vector<float> all_features;
Lukas Zilkaba849e72018-03-08 14:48:21 +01002963 scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
Lukas Zilkab23e2122018-02-09 10:25:19 +01002964 for (int batch_start = 0; batch_start < candidate_spans.size();
2965 batch_start += max_batch_size) {
2966 const int batch_end = std::min(batch_start + max_batch_size,
2967 static_cast<int>(candidate_spans.size()));
2968
2969 // Prepare features for the whole batch.
2970 all_features.clear();
2971 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2972 for (int i = batch_start; i < batch_end; ++i) {
2973 cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
2974 &all_features);
2975 }
2976
2977 // Run batched inference.
2978 const int batch_size = batch_end - batch_start;
2979 const int features_size = cached_features.OutputFeaturesSize();
2980 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01002981 TensorView<float>(all_features.data(), {batch_size, features_size}),
2982 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002983 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002984 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002985 return false;
2986 }
2987 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2988 logits.dim(1) != 1) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002989 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002990 return false;
2991 }
2992
2993 // Save results.
2994 for (int i = batch_start; i < batch_end; ++i) {
2995 scored_chunks->push_back(
2996 ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
2997 }
2998 }
2999
3000 return true;
3001}
3002
Tony Mak6c4cc672018-09-17 11:48:50 +01003003bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
3004 int64 reference_time_ms_utc,
3005 const std::string& reference_timezone,
3006 const std::string& locales, ModeFlag mode,
Tony Makd9446602019-02-20 18:25:39 +00003007 AnnotationUsecase annotation_usecase,
Tony Mak83d2de62019-04-10 16:12:15 +01003008 bool is_serialized_entity_data_enabled,
Tony Mak6c4cc672018-09-17 11:48:50 +01003009 std::vector<AnnotatedSpan>* result) const {
Tony Mak63959242020-02-07 18:31:16 +00003010 std::vector<DatetimeParseResultSpan> datetime_spans;
3011 if (cfg_datetime_parser_) {
3012 if (!(model_->grammar_datetime_model()->enabled_modes() & mode)) {
3013 return true;
3014 }
3015 std::vector<Locale> parsed_locales;
3016 ParseLocales(locales, &parsed_locales);
Tony Mak21460022020-03-12 18:29:35 +00003017 cfg_datetime_parser_->Parse(
3018 context_unicode.ToUTF8String(),
3019 ToDateAnnotationOptions(
3020 model_->grammar_datetime_model()->annotation_options(),
3021 reference_timezone, reference_time_ms_utc),
3022 parsed_locales, &datetime_spans);
Tony Makd99d58c2020-03-19 21:52:02 +00003023 }
3024
3025 if (datetime_parser_) {
Tony Mak63959242020-02-07 18:31:16 +00003026 if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
3027 reference_timezone, locales, mode,
3028 annotation_usecase,
3029 /*anchor_start_end=*/false, &datetime_spans)) {
3030 return false;
3031 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02003032 }
3033
Lukas Zilkab23e2122018-02-09 10:25:19 +01003034 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
Tony Mak378c1f52019-03-04 15:58:11 +00003035 AnnotatedSpan annotated_span;
3036 annotated_span.span = datetime_span.span;
Tony Mak854015a2019-01-16 15:56:48 +00003037 for (const DatetimeParseResult& parse_result : datetime_span.data) {
Tony Mak378c1f52019-03-04 15:58:11 +00003038 annotated_span.classification.emplace_back(
3039 PickCollectionForDatetime(parse_result),
3040 datetime_span.target_classification_score,
3041 datetime_span.priority_score);
3042 annotated_span.classification.back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01003043 if (is_serialized_entity_data_enabled) {
3044 annotated_span.classification.back().serialized_entity_data =
3045 CreateDatetimeSerializedEntityData(parse_result);
3046 }
Tony Mak854015a2019-01-16 15:56:48 +00003047 }
Tony Mak448b5862019-03-22 13:36:41 +00003048 annotated_span.source = AnnotatedSpan::Source::DATETIME;
Tony Mak378c1f52019-03-04 15:58:11 +00003049 result->push_back(std::move(annotated_span));
Lukas Zilkab23e2122018-02-09 10:25:19 +01003050 }
3051 return true;
3052}
3053
Tony Mak378c1f52019-03-04 15:58:11 +00003054const Model* Annotator::model() const { return model_; }
Tony Makd9446602019-02-20 18:25:39 +00003055const reflection::Schema* Annotator::entity_data_schema() const {
3056 return entity_data_schema_;
3057}
Tony Mak854015a2019-01-16 15:56:48 +00003058
Lukas Zilka21d8c982018-01-24 11:11:20 +01003059const Model* ViewModel(const void* buffer, int size) {
3060 if (!buffer) {
3061 return nullptr;
3062 }
3063
3064 return LoadAndVerifyModel(buffer, size);
3065}
3066
Tony Makd9446602019-02-20 18:25:39 +00003067bool Annotator::LookUpKnowledgeEntity(
3068 const std::string& id, std::string* serialized_knowledge_result) const {
3069 return knowledge_engine_ &&
3070 knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
3071}
3072
Tony Mak6c4cc672018-09-17 11:48:50 +01003073} // namespace libtextclassifier3