blob: e296a6494f32a288108be07bed9869e6c1658f87 [file] [log] [blame]
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
Tony Mak6c4cc672018-09-17 11:48:50 +010017#include "annotator/annotator.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010018
19#include <algorithm>
Lukas Zilka21d8c982018-01-24 11:11:20 +010020#include <cmath>
Tony Mak21460022020-03-12 18:29:35 +000021#include <cstddef>
Lukas Zilka21d8c982018-01-24 11:11:20 +010022#include <iterator>
Tony Mak13125532021-01-13 21:12:07 +000023#include <limits>
Lukas Zilka21d8c982018-01-24 11:11:20 +010024#include <numeric>
Tony Mak63959242020-02-07 18:31:16 +000025#include <string>
Tony Mak448b5862019-03-22 13:36:41 +000026#include <unordered_map>
Tony Mak63959242020-02-07 18:31:16 +000027#include <vector>
Lukas Zilka21d8c982018-01-24 11:11:20 +010028
Tony Mak854015a2019-01-16 15:56:48 +000029#include "annotator/collections.h"
Tony Make940bc22021-04-07 15:35:23 +010030#include "annotator/datetime/grammar-parser.h"
Tony Mak13125532021-01-13 21:12:07 +000031#include "annotator/datetime/regex-parser.h"
Tony Maka44b3082020-08-13 18:57:10 +010032#include "annotator/flatbuffer-utils.h"
33#include "annotator/knowledge/knowledge-engine-types.h"
Tony Mak83d2de62019-04-10 16:12:15 +010034#include "annotator/model_generated.h"
35#include "annotator/types.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010036#include "utils/base/logging.h"
Tony Makff31efb2020-03-31 11:13:06 +010037#include "utils/base/status.h"
38#include "utils/base/statusor.h"
Tony Mak13125532021-01-13 21:12:07 +000039#include "utils/calendar/calendar.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010040#include "utils/checksum.h"
Tony Make940bc22021-04-07 15:35:23 +010041#include "utils/grammar/analyzer.h"
Tony Mak13125532021-01-13 21:12:07 +000042#include "utils/i18n/locale-list.h"
Tony Mak63959242020-02-07 18:31:16 +000043#include "utils/i18n/locale.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010044#include "utils/math/softmax.h"
Tony Mak8cd7ba62019-10-15 15:29:22 +010045#include "utils/normalization.h"
Tony Maka2a1ff42019-09-12 15:40:32 +010046#include "utils/optional.h"
Tony Makd9446602019-02-20 18:25:39 +000047#include "utils/regex-match.h"
Tony Maka44b3082020-08-13 18:57:10 +010048#include "utils/strings/append.h"
Tony Mak63959242020-02-07 18:31:16 +000049#include "utils/strings/numbers.h"
50#include "utils/strings/split.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010051#include "utils/utf8/unicodetext.h"
Tony Mak21460022020-03-12 18:29:35 +000052#include "utils/utf8/unilib-common.h"
Tony Mak378c1f52019-03-04 15:58:11 +000053#include "utils/zlib/zlib_regex.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010054
Tony Mak6c4cc672018-09-17 11:48:50 +010055namespace libtextclassifier3 {
Tony Mak448b5862019-03-22 13:36:41 +000056
// Set of ints whose ordering is defined by a comparison function supplied when
// the set is constructed.
using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
58
Tony Mak6c4cc672018-09-17 11:48:50 +010059const std::string& Annotator::kPhoneCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010060 *[]() { return new std::string("phone"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010061const std::string& Annotator::kAddressCollection =
Lukas Zilkae7962cc2018-03-28 18:09:48 +020062 *[]() { return new std::string("address"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010063const std::string& Annotator::kDateCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010064 *[]() { return new std::string("date"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000065const std::string& Annotator::kUrlCollection =
66 *[]() { return new std::string("url"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000067const std::string& Annotator::kEmailCollection =
68 *[]() { return new std::string("email"); }();
Lukas Zilkab23e2122018-02-09 10:25:19 +010069
Lukas Zilka21d8c982018-01-24 11:11:20 +010070namespace {
71const Model* LoadAndVerifyModel(const void* addr, int size) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010072 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
Tony Mak51a9e542018-11-02 13:36:22 +000073 if (VerifyModelBuffer(verifier)) {
74 return GetModel(addr);
Lukas Zilka21d8c982018-01-24 11:11:20 +010075 } else {
76 return nullptr;
77 }
78}
Tony Mak6c4cc672018-09-17 11:48:50 +010079
Tony Mak76d80962020-01-08 17:30:51 +000080const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
81 int size) {
82 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
83 if (VerifyPersonNameModelBuffer(verifier)) {
84 return GetPersonNameModel(addr);
85 } else {
86 return nullptr;
87 }
88}
89
Tony Mak6c4cc672018-09-17 11:48:50 +010090// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
91// create a new instance, assign ownership to owned_lib, and return it.
92const UniLib* MaybeCreateUnilib(const UniLib* lib,
93 std::unique_ptr<UniLib>* owned_lib) {
94 if (lib) {
95 return lib;
96 } else {
97 owned_lib->reset(new UniLib);
98 return owned_lib->get();
99 }
100}
101
102// As above, but for CalendarLib.
103const CalendarLib* MaybeCreateCalendarlib(
104 const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
105 if (lib) {
106 return lib;
107 } else {
108 owned_lib->reset(new CalendarLib);
109 return owned_lib->get();
110 }
111}
112
Tony Mak968412a2019-11-13 15:39:57 +0000113// Returns whether the provided input is valid:
Tony Mak968412a2019-11-13 15:39:57 +0000114// * Sane span indices.
Tony Maka44b3082020-08-13 18:57:10 +0100115bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
Tony Mak968412a2019-11-13 15:39:57 +0000116 return (span.first >= 0 && span.first < span.second &&
117 span.second <= context.size_codepoints());
118}
119
Tony Mak63959242020-02-07 18:31:16 +0000120std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
121 const flatbuffers::Vector<int32_t>* ints) {
122 if (ints == nullptr) {
123 return {};
124 }
125 std::unordered_set<char32> ints_set;
126 for (auto value : *ints) {
127 ints_set.insert(static_cast<char32>(value));
128 }
129 return ints_set;
130}
131
Lukas Zilka21d8c982018-01-24 11:11:20 +0100132} // namespace
133
Lukas Zilkaba849e72018-03-08 14:48:21 +0100134tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
135 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100136 TC3_CHECK(selection_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100137 selection_interpreter_ = selection_executor_->CreateInterpreter();
138 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100139 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100140 }
141 }
142 return selection_interpreter_.get();
143}
144
145tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
146 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100147 TC3_CHECK(classification_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100148 classification_interpreter_ = classification_executor_->CreateInterpreter();
149 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100150 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100151 }
152 }
153 return classification_interpreter_.get();
154}
155
Tony Mak6c4cc672018-09-17 11:48:50 +0100156std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
157 const char* buffer, int size, const UniLib* unilib,
158 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100159 const Model* model = LoadAndVerifyModel(buffer, size);
160 if (model == nullptr) {
161 return nullptr;
162 }
163
Tony Makc7bdd322020-10-08 12:26:40 +0100164 auto classifier = std::unique_ptr<Annotator>(new Annotator());
165 unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
166 calendarlib =
167 MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
168 classifier->ValidateAndInitialize(model, unilib, calendarlib);
169 if (!classifier->IsInitialized()) {
170 return nullptr;
171 }
172
173 return classifier;
174}
175
176std::unique_ptr<Annotator> Annotator::FromString(
177 const std::string& buffer, const UniLib* unilib,
178 const CalendarLib* calendarlib) {
179 auto classifier = std::unique_ptr<Annotator>(new Annotator());
180 classifier->owned_buffer_ = buffer;
181 const Model* model = LoadAndVerifyModel(classifier->owned_buffer_.data(),
182 classifier->owned_buffer_.size());
183 if (model == nullptr) {
184 return nullptr;
185 }
186 unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
187 calendarlib =
188 MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
189 classifier->ValidateAndInitialize(model, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100190 if (!classifier->IsInitialized()) {
191 return nullptr;
192 }
193
194 return classifier;
195}
196
Tony Mak6c4cc672018-09-17 11:48:50 +0100197std::unique_ptr<Annotator> Annotator::FromScopedMmap(
198 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
199 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100200 if (!(*mmap)->handle().ok()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100201 TC3_VLOG(1) << "Mmap failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100202 return nullptr;
203 }
204
205 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
206 (*mmap)->handle().num_bytes());
207 if (!model) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100208 TC3_LOG(ERROR) << "Model verification failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100209 return nullptr;
210 }
211
Tony Makc7bdd322020-10-08 12:26:40 +0100212 auto classifier = std::unique_ptr<Annotator>(new Annotator());
213 classifier->mmap_ = std::move(*mmap);
214 unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
215 calendarlib =
216 MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
217 classifier->ValidateAndInitialize(model, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100218 if (!classifier->IsInitialized()) {
219 return nullptr;
220 }
221
222 return classifier;
223}
224
Tony Makdf54e742019-03-26 14:04:00 +0000225std::unique_ptr<Annotator> Annotator::FromScopedMmap(
226 std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
227 std::unique_ptr<CalendarLib> calendarlib) {
228 if (!(*mmap)->handle().ok()) {
229 TC3_VLOG(1) << "Mmap failed.";
230 return nullptr;
231 }
232
233 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
234 (*mmap)->handle().num_bytes());
235 if (model == nullptr) {
236 TC3_LOG(ERROR) << "Model verification failed.";
237 return nullptr;
238 }
239
Tony Makc7bdd322020-10-08 12:26:40 +0100240 auto classifier = std::unique_ptr<Annotator>(new Annotator());
241 classifier->mmap_ = std::move(*mmap);
242 classifier->owned_unilib_ = std::move(unilib);
243 classifier->owned_calendarlib_ = std::move(calendarlib);
244 classifier->ValidateAndInitialize(model, classifier->owned_unilib_.get(),
245 classifier->owned_calendarlib_.get());
Tony Makdf54e742019-03-26 14:04:00 +0000246 if (!classifier->IsInitialized()) {
247 return nullptr;
248 }
249
250 return classifier;
251}
252
Tony Mak6c4cc672018-09-17 11:48:50 +0100253std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
254 int fd, int offset, int size, const UniLib* unilib,
255 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100256 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
Tony Mak6c4cc672018-09-17 11:48:50 +0100257 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100258}
259
Tony Mak6c4cc672018-09-17 11:48:50 +0100260std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Makdf54e742019-03-26 14:04:00 +0000261 int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
262 std::unique_ptr<CalendarLib> calendarlib) {
263 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
264 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
265}
266
267std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Mak6c4cc672018-09-17 11:48:50 +0100268 int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100269 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
Tony Mak6c4cc672018-09-17 11:48:50 +0100270 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100271}
272
Tony Makdf54e742019-03-26 14:04:00 +0000273std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
274 int fd, std::unique_ptr<UniLib> unilib,
275 std::unique_ptr<CalendarLib> calendarlib) {
276 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
277 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
278}
279
Tony Mak6c4cc672018-09-17 11:48:50 +0100280std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
281 const UniLib* unilib,
282 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100283 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
Tony Mak6c4cc672018-09-17 11:48:50 +0100284 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100285}
286
Tony Makdf54e742019-03-26 14:04:00 +0000287std::unique_ptr<Annotator> Annotator::FromPath(
288 const std::string& path, std::unique_ptr<UniLib> unilib,
289 std::unique_ptr<CalendarLib> calendarlib) {
290 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
291 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
292}
293
Tony Makc7bdd322020-10-08 12:26:40 +0100294void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib,
295 const CalendarLib* calendarlib) {
296 model_ = model;
297 unilib_ = unilib;
298 calendarlib_ = calendarlib;
Tony Mak6c4cc672018-09-17 11:48:50 +0100299
Lukas Zilkab23e2122018-02-09 10:25:19 +0100300 initialized_ = false;
301
Lukas Zilka21d8c982018-01-24 11:11:20 +0100302 if (model_ == nullptr) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100303 TC3_LOG(ERROR) << "No model specified.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100304 return;
305 }
306
Lukas Zilkaba849e72018-03-08 14:48:21 +0100307 const bool model_enabled_for_annotation =
308 (model_->triggering_options() != nullptr &&
309 (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
310 const bool model_enabled_for_classification =
311 (model_->triggering_options() != nullptr &&
312 (model_->triggering_options()->enabled_modes() &
313 ModeFlag_CLASSIFICATION));
314 const bool model_enabled_for_selection =
315 (model_->triggering_options() != nullptr &&
316 (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
317
318 // Annotation requires the selection model.
319 if (model_enabled_for_annotation || model_enabled_for_selection) {
320 if (!model_->selection_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100321 TC3_LOG(ERROR) << "No selection options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100322 return;
323 }
324 if (!model_->selection_feature_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100325 TC3_LOG(ERROR) << "No selection feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100326 return;
327 }
328 if (!model_->selection_feature_options()->bounds_sensitive_features()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100329 TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100330 return;
331 }
332 if (!model_->selection_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100333 TC3_LOG(ERROR) << "No selection model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100334 return;
335 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100336 selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100337 if (!selection_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100338 TC3_LOG(ERROR) << "Could not initialize selection executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100339 return;
340 }
341 selection_feature_processor_.reset(
342 new FeatureProcessor(model_->selection_feature_options(), unilib_));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100343 }
344
Lukas Zilkaba849e72018-03-08 14:48:21 +0100345 // Annotation requires the classification model for conflict resolution and
346 // scoring.
347 // Selection requires the classification model for conflict resolution.
348 if (model_enabled_for_annotation || model_enabled_for_classification ||
349 model_enabled_for_selection) {
350 if (!model_->classification_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100351 TC3_LOG(ERROR) << "No classification options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100352 return;
353 }
354
355 if (!model_->classification_feature_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100356 TC3_LOG(ERROR) << "No classification feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100357 return;
358 }
359
360 if (!model_->classification_feature_options()
361 ->bounds_sensitive_features()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100362 TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100363 return;
364 }
365 if (!model_->classification_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100366 TC3_LOG(ERROR) << "No clf model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100367 return;
368 }
369
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200370 classification_executor_ =
Tony Mak6c4cc672018-09-17 11:48:50 +0100371 ModelExecutor::FromBuffer(model_->classification_model());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100372 if (!classification_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100373 TC3_LOG(ERROR) << "Could not initialize classification executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100374 return;
375 }
376
377 classification_feature_processor_.reset(new FeatureProcessor(
378 model_->classification_feature_options(), unilib_));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100379 }
380
Lukas Zilkaba849e72018-03-08 14:48:21 +0100381 // The embeddings need to be specified if the model is to be used for
382 // classification or selection.
383 if (model_enabled_for_annotation || model_enabled_for_classification ||
384 model_enabled_for_selection) {
385 if (!model_->embedding_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100386 TC3_LOG(ERROR) << "No embedding model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100387 return;
388 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100389
Lukas Zilkaba849e72018-03-08 14:48:21 +0100390 // Check that the embedding size of the selection and classification model
391 // matches, as they are using the same embeddings.
392 if (model_enabled_for_selection &&
393 (model_->selection_feature_options()->embedding_size() !=
394 model_->classification_feature_options()->embedding_size() ||
395 model_->selection_feature_options()->embedding_quantization_bits() !=
396 model_->classification_feature_options()
397 ->embedding_quantization_bits())) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100398 TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100399 return;
400 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100401
Tony Mak6c4cc672018-09-17 11:48:50 +0100402 embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200403 model_->embedding_model(),
Lukas Zilkaba849e72018-03-08 14:48:21 +0100404 model_->classification_feature_options()->embedding_size(),
Tony Makdf54e742019-03-26 14:04:00 +0000405 model_->classification_feature_options()->embedding_quantization_bits(),
406 model_->embedding_pruning_mask());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200407 if (!embedding_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100408 TC3_LOG(ERROR) << "Could not initialize embedding executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100409 return;
410 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100411 }
412
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200413 std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
Lukas Zilkab23e2122018-02-09 10:25:19 +0100414 if (model_->regex_model()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200415 if (!InitializeRegexModel(decompressor.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100416 TC3_LOG(ERROR) << "Could not initialize regex model.";
Lukas Zilka434442d2018-04-25 11:38:51 +0200417 return;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100418 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100419 }
Tony Makd99d58c2020-03-19 21:52:02 +0000420
Tony Make940bc22021-04-07 15:35:23 +0100421 if (model_->datetime_grammar_model()) {
422 if (model_->datetime_grammar_model()->rules()) {
423 analyzer_ = std::make_unique<grammar::Analyzer>(
424 unilib_, model_->datetime_grammar_model()->rules());
425 datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_);
426 datetime_parser_ = std::make_unique<GrammarDatetimeParser>(
427 *analyzer_, *datetime_grounder_,
428 /*target_classification_score=*/1.0,
429 /*priority_score=*/1.0);
430 }
431 } else if (model_->datetime_model()) {
Tony Mak13125532021-01-13 21:12:07 +0000432 datetime_parser_ = RegexDatetimeParser::Instance(
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100433 model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100434 if (!datetime_parser_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100435 TC3_LOG(ERROR) << "Could not initialize datetime parser.";
Lukas Zilkab23e2122018-02-09 10:25:19 +0100436 return;
437 }
438 }
439
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200440 if (model_->output_options()) {
441 if (model_->output_options()->filtered_collections_annotation()) {
442 for (const auto collection :
443 *model_->output_options()->filtered_collections_annotation()) {
444 filtered_collections_annotation_.insert(collection->str());
445 }
446 }
447 if (model_->output_options()->filtered_collections_classification()) {
448 for (const auto collection :
449 *model_->output_options()->filtered_collections_classification()) {
450 filtered_collections_classification_.insert(collection->str());
451 }
452 }
453 if (model_->output_options()->filtered_collections_selection()) {
454 for (const auto collection :
455 *model_->output_options()->filtered_collections_selection()) {
456 filtered_collections_selection_.insert(collection->str());
457 }
458 }
459 }
460
Tony Mak378c1f52019-03-04 15:58:11 +0000461 if (model_->number_annotator_options() &&
462 model_->number_annotator_options()->enabled()) {
463 number_annotator_.reset(
Tony Mak63959242020-02-07 18:31:16 +0000464 new NumberAnnotator(model_->number_annotator_options(), unilib_));
465 }
466
467 if (model_->money_parsing_options()) {
468 money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
469 model_->money_parsing_options()->separators());
Tony Mak378c1f52019-03-04 15:58:11 +0000470 }
471
Tony Makad2e22d2019-03-20 17:35:13 +0000472 if (model_->duration_annotator_options() &&
473 model_->duration_annotator_options()->enabled()) {
474 duration_annotator_.reset(
475 new DurationAnnotator(model_->duration_annotator_options(),
Tony Mak8cd7ba62019-10-15 15:29:22 +0100476 selection_feature_processor_.get(), unilib_));
Tony Makad2e22d2019-03-20 17:35:13 +0000477 }
478
Tony Maka44b3082020-08-13 18:57:10 +0100479 if (model_->grammar_model()) {
480 grammar_annotator_.reset(new GrammarAnnotator(
481 unilib_, model_->grammar_model(), entity_data_builder_.get()));
482 }
483
Tony Maka5090082020-09-18 16:41:23 +0100484 // The following #ifdef is here to aid quality evaluation of a situation, when
485 // a POD NER kill switch in AiAi is invoked, when a model that has POD NER in
486 // it.
487#if !defined(TC3_DISABLE_POD_NER)
Tony Maka44b3082020-08-13 18:57:10 +0100488 if (model_->pod_ner_model()) {
489 pod_ner_annotator_ =
490 PodNerAnnotator::Create(model_->pod_ner_model(), *unilib_);
491 }
Tony Maka5090082020-09-18 16:41:23 +0100492#endif
493
494 if (model_->vocab_model()) {
495 vocab_annotator_ = VocabAnnotator::Create(
496 model_->vocab_model(), *selection_feature_processor_, *unilib_);
497 }
Tony Maka44b3082020-08-13 18:57:10 +0100498
Tony Makd9446602019-02-20 18:25:39 +0000499 if (model_->entity_data_schema()) {
500 entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
501 model_->entity_data_schema()->Data(),
502 model_->entity_data_schema()->size());
503 if (entity_data_schema_ == nullptr) {
504 TC3_LOG(ERROR) << "Could not load entity data schema data.";
505 return;
506 }
507
508 entity_data_builder_.reset(
Tony Maka44b3082020-08-13 18:57:10 +0100509 new MutableFlatbufferBuilder(entity_data_schema_));
Tony Makd9446602019-02-20 18:25:39 +0000510 } else {
Tony Mak378c1f52019-03-04 15:58:11 +0000511 entity_data_schema_ = nullptr;
Tony Makd9446602019-02-20 18:25:39 +0000512 entity_data_builder_ = nullptr;
513 }
514
Tony Makdf54e742019-03-26 14:04:00 +0000515 if (model_->triggering_locales() &&
516 !ParseLocales(model_->triggering_locales()->c_str(),
517 &model_triggering_locales_)) {
Tony Mak378c1f52019-03-04 15:58:11 +0000518 TC3_LOG(ERROR) << "Could not parse model supported locales.";
519 return;
520 }
521
522 if (model_->triggering_options() != nullptr &&
Tony Makdf54e742019-03-26 14:04:00 +0000523 model_->triggering_options()->locales() != nullptr &&
524 !ParseLocales(model_->triggering_options()->locales()->c_str(),
525 &ml_model_triggering_locales_)) {
526 TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
527 return;
528 }
529
530 if (model_->triggering_options() != nullptr &&
Tony Mak378c1f52019-03-04 15:58:11 +0000531 model_->triggering_options()->dictionary_locales() != nullptr &&
532 !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
533 &dictionary_locales_)) {
534 TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
535 return;
536 }
537
Tony Mak5a12b942020-05-01 12:41:31 +0100538 if (model_->conflict_resolution_options() != nullptr) {
539 prioritize_longest_annotation_ =
540 model_->conflict_resolution_options()->prioritize_longest_annotation();
541 do_conflict_resolution_in_raw_mode_ =
542 model_->conflict_resolution_options()
543 ->do_conflict_resolution_in_raw_mode();
544 }
545
Chang Licac0b442020-05-21 15:09:37 +0100546#ifdef TC3_EXPERIMENTAL
547 TC3_LOG(WARNING) << "Enabling experimental annotators.";
548 InitializeExperimentalAnnotators();
549#endif
550
Lukas Zilka21d8c982018-01-24 11:11:20 +0100551 initialized_ = true;
552}
553
Tony Mak6c4cc672018-09-17 11:48:50 +0100554bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100555 if (!model_->regex_model()->patterns()) {
Lukas Zilka434442d2018-04-25 11:38:51 +0200556 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100557 }
558
559 // Initialize pattern recognizers.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100560 int regex_pattern_id = 0;
Tony Mak1ac2e4a2020-04-29 13:41:53 +0100561 for (const auto regex_pattern : *model_->regex_model()->patterns()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200562 std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
Tony Mak378c1f52019-03-04 15:58:11 +0000563 UncompressMakeRegexPattern(
564 *unilib_, regex_pattern->pattern(),
565 regex_pattern->compressed_pattern(),
566 model_->regex_model()->lazy_regex_compilation(), decompressor);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100567 if (!compiled_pattern) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100568 TC3_LOG(INFO) << "Failed to load regex pattern";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200569 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100570 }
571
Lukas Zilkaba849e72018-03-08 14:48:21 +0100572 if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100573 annotation_regex_patterns_.push_back(regex_pattern_id);
574 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100575 if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100576 classification_regex_patterns_.push_back(regex_pattern_id);
577 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100578 if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100579 selection_regex_patterns_.push_back(regex_pattern_id);
580 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100581 regex_patterns_.push_back({
Tony Mak854015a2019-01-16 15:56:48 +0000582 regex_pattern,
Tony Mak6c4cc672018-09-17 11:48:50 +0100583 std::move(compiled_pattern),
Tony Mak6c4cc672018-09-17 11:48:50 +0100584 });
Lukas Zilkab23e2122018-02-09 10:25:19 +0100585 ++regex_pattern_id;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100586 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100587
Lukas Zilkab23e2122018-02-09 10:25:19 +0100588 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100589}
590
Tony Mak6c4cc672018-09-17 11:48:50 +0100591bool Annotator::InitializeKnowledgeEngine(
592 const std::string& serialized_config) {
Tony Maka2a1ff42019-09-12 15:40:32 +0100593 std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
Tony Mak63959242020-02-07 18:31:16 +0000594 if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100595 TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
596 return false;
597 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100598 if (model_->triggering_options() != nullptr) {
599 knowledge_engine->SetPriorityScore(
600 model_->triggering_options()->knowledge_priority_score());
601 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100602 knowledge_engine_ = std::move(knowledge_engine);
603 return true;
604}
605
Tony Mak854015a2019-01-16 15:56:48 +0000606bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
Tony Mak5dc5e112019-02-01 14:52:10 +0000607 std::unique_ptr<ContactEngine> contact_engine(
Tony Mak63959242020-02-07 18:31:16 +0000608 new ContactEngine(selection_feature_processor_.get(), unilib_,
609 model_->contact_annotator_options()));
Tony Mak854015a2019-01-16 15:56:48 +0000610 if (!contact_engine->Initialize(serialized_config)) {
611 TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
612 return false;
613 }
614 contact_engine_ = std::move(contact_engine);
615 return true;
616}
617
Tony Makd9446602019-02-20 18:25:39 +0000618bool Annotator::InitializeInstalledAppEngine(
619 const std::string& serialized_config) {
620 std::unique_ptr<InstalledAppEngine> installed_app_engine(
Tony Mak378c1f52019-03-04 15:58:11 +0000621 new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
Tony Makd9446602019-02-20 18:25:39 +0000622 if (!installed_app_engine->Initialize(serialized_config)) {
623 TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
624 return false;
625 }
626 installed_app_engine_ = std::move(installed_app_engine);
627 return true;
628}
629
Tony Mak13125532021-01-13 21:12:07 +0000630bool Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
631 if (lang_id == nullptr) {
632 return false;
633 }
634
Tony Mak63959242020-02-07 18:31:16 +0000635 lang_id_ = lang_id;
Tony Mak21460022020-03-12 18:29:35 +0000636 if (lang_id_ != nullptr && model_->translate_annotator_options() &&
Tony Mak63959242020-02-07 18:31:16 +0000637 model_->translate_annotator_options()->enabled()) {
638 translate_annotator_.reset(new TranslateAnnotator(
639 model_->translate_annotator_options(), lang_id_, unilib_));
Tony Mak21460022020-03-12 18:29:35 +0000640 } else {
641 translate_annotator_.reset(nullptr);
Tony Mak63959242020-02-07 18:31:16 +0000642 }
Tony Mak13125532021-01-13 21:12:07 +0000643 return true;
Tony Mak63959242020-02-07 18:31:16 +0000644}
645
Tony Mak21460022020-03-12 18:29:35 +0000646bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
647 int size) {
648 const PersonNameModel* person_name_model =
649 LoadAndVerifyPersonNameModel(buffer, size);
Tony Mak76d80962020-01-08 17:30:51 +0000650
651 if (person_name_model == nullptr) {
652 TC3_LOG(ERROR) << "Person name model verification failed.";
653 return false;
654 }
655
656 if (!person_name_model->enabled()) {
657 return true;
658 }
659
660 std::unique_ptr<PersonNameEngine> person_name_engine(
Tony Mak21460022020-03-12 18:29:35 +0000661 new PersonNameEngine(selection_feature_processor_.get(), unilib_));
Tony Mak76d80962020-01-08 17:30:51 +0000662 if (!person_name_engine->Initialize(person_name_model)) {
663 TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
664 return false;
665 }
666 person_name_engine_ = std::move(person_name_engine);
667 return true;
668}
669
Tony Mak21460022020-03-12 18:29:35 +0000670bool Annotator::InitializePersonNameEngineFromScopedMmap(
671 const ScopedMmap& mmap) {
672 if (!mmap.handle().ok()) {
673 TC3_LOG(ERROR) << "Mmap for person name model failed.";
674 return false;
675 }
676
677 return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
678 mmap.handle().num_bytes());
679}
680
681bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
682 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
683 return InitializePersonNameEngineFromScopedMmap(*mmap);
684}
685
686bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
687 int size) {
688 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
689 return InitializePersonNameEngineFromScopedMmap(*mmap);
690}
691
Tony Mak5a12b942020-05-01 12:41:31 +0100692bool Annotator::InitializeExperimentalAnnotators() {
693 if (ExperimentalAnnotator::IsEnabled()) {
Tony Makc121edd2020-05-28 15:25:17 +0100694 experimental_annotator_.reset(new ExperimentalAnnotator(
695 model_->experimental_model(), *selection_feature_processor_, *unilib_));
Tony Mak5a12b942020-05-01 12:41:31 +0100696 return true;
697 }
698 return false;
699}
700
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200701namespace internal {
702// Helper function, which if the initial 'span' contains only white-spaces,
703// moves the selection to a single-codepoint selection on a left or right side
704// of this space.
Tony Maka44b3082020-08-13 18:57:10 +0100705CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200706 const UnicodeText& context_unicode,
707 const UniLib& unilib) {
Tony Maka44b3082020-08-13 18:57:10 +0100708 TC3_CHECK(span.IsValid() && !span.IsEmpty());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200709
710 UnicodeText::const_iterator it;
711
712 // Check that the current selection is all whitespaces.
713 it = context_unicode.begin();
714 std::advance(it, span.first);
715 for (int i = 0; i < (span.second - span.first); ++i, ++it) {
716 if (!unilib.IsWhitespace(*it)) {
717 return span;
718 }
719 }
720
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200721 // Try moving left.
Tony Maka44b3082020-08-13 18:57:10 +0100722 CodepointSpan result = span;
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200723 it = context_unicode.begin();
724 std::advance(it, span.first);
725 while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
726 --result.first;
727 --it;
728 }
729 result.second = result.first + 1;
730 if (!unilib.IsWhitespace(*it)) {
731 return result;
732 }
733
734 // If moving left didn't find a non-whitespace character, just return the
735 // original span.
736 return span;
737}
738} // namespace internal
739
Tony Mak6c4cc672018-09-17 11:48:50 +0100740bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200741 return !span.classification.empty() &&
742 filtered_collections_annotation_.find(
743 span.classification[0].collection) !=
744 filtered_collections_annotation_.end();
745}
746
Tony Mak6c4cc672018-09-17 11:48:50 +0100747bool Annotator::FilteredForClassification(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200748 const ClassificationResult& classification) const {
749 return filtered_collections_classification_.find(classification.collection) !=
750 filtered_collections_classification_.end();
751}
752
Tony Mak6c4cc672018-09-17 11:48:50 +0100753bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200754 return !span.classification.empty() &&
755 filtered_collections_selection_.find(
756 span.classification[0].collection) !=
757 filtered_collections_selection_.end();
758}
759
Tony Mak378c1f52019-03-04 15:58:11 +0000760namespace {
761inline bool ClassifiedAsOther(
762 const std::vector<ClassificationResult>& classification) {
763 return !classification.empty() &&
764 classification[0].collection == Collections::Other();
765}
766
Tony Maka2a1ff42019-09-12 15:40:32 +0100767} // namespace
768
769float Annotator::GetPriorityScore(
770 const std::vector<ClassificationResult>& classification) const {
Tony Mak378c1f52019-03-04 15:58:11 +0000771 if (!classification.empty() && !ClassifiedAsOther(classification)) {
772 return classification[0].priority_score;
773 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +0100774 if (model_->triggering_options() != nullptr) {
775 return model_->triggering_options()->other_collection_priority_score();
776 } else {
777 return -1000.0;
778 }
Tony Mak378c1f52019-03-04 15:58:11 +0000779 }
780}
Tony Mak378c1f52019-03-04 15:58:11 +0000781
Tony Makdf54e742019-03-26 14:04:00 +0000782bool Annotator::VerifyRegexMatchCandidate(
783 const std::string& context, const VerificationOptions* verification_options,
784 const std::string& match, const UniLib::RegexMatcher* matcher) const {
785 if (verification_options == nullptr) {
786 return true;
787 }
788 if (verification_options->verify_luhn_checksum() &&
789 !VerifyLuhnChecksum(match)) {
790 return false;
791 }
792 const int lua_verifier = verification_options->lua_verifier();
793 if (lua_verifier >= 0) {
794 if (model_->regex_model()->lua_verifier() == nullptr ||
795 lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
796 TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
797 return false;
798 }
799 return VerifyMatch(
800 context, matcher,
801 model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
802 }
803 return true;
804}
805
Tony Mak6c4cc672018-09-17 11:48:50 +0100806CodepointSpan Annotator::SuggestSelection(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100807 const std::string& context, CodepointSpan click_indices,
808 const SelectionOptions& options) const {
Tony Mak13125532021-01-13 21:12:07 +0000809 if (context.size() > std::numeric_limits<int>::max()) {
810 TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
811 return {};
812 }
813
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200814 CodepointSpan original_click_indices = click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100815 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100816 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200817 return original_click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100818 }
Tony Mak5a12b942020-05-01 12:41:31 +0100819 if (options.annotation_usecase !=
820 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
821 TC3_LOG(WARNING)
822 << "Invoking SuggestSelection, which is not supported in RAW mode.";
823 return original_click_indices;
824 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100825 if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200826 return original_click_indices;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100827 }
828
Tony Makdf54e742019-03-26 14:04:00 +0000829 std::vector<Locale> detected_text_language_tags;
830 if (!ParseLocales(options.detected_text_language_tags,
831 &detected_text_language_tags)) {
832 TC3_LOG(WARNING)
833 << "Failed to parse the detected_text_language_tags in options: "
834 << options.detected_text_language_tags;
835 }
836 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
837 model_triggering_locales_,
838 /*default_value=*/true)) {
839 return original_click_indices;
840 }
841
Lukas Zilkadf710db2018-02-27 12:44:09 +0100842 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
843 /*do_copy=*/false);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200844
Tony Mak13125532021-01-13 21:12:07 +0000845 if (!unilib_->IsValidUtf8(context_unicode)) {
846 TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
847 return original_click_indices;
848 }
849
Tony Mak968412a2019-11-13 15:39:57 +0000850 if (!IsValidSpanInput(context_unicode, click_indices)) {
851 TC3_VLOG(1)
852 << "Trying to run SuggestSelection with invalid input, indices: "
853 << click_indices.first << " " << click_indices.second;
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200854 return original_click_indices;
855 }
856
857 if (model_->snap_whitespace_selections()) {
858 // We want to expand a purely white-space selection to a multi-selection it
859 // would've been part of. But with this feature disabled we would do a no-
860 // op, because no token is found. Therefore, we need to modify the
861 // 'click_indices' a bit to include a part of the token, so that the click-
862 // finding logic finds the clicked token correctly. This modification is
863 // done by the following function. Note, that it's enough to check the left
864 // side of the current selection, because if the white-space is a part of a
Tony Mak6c4cc672018-09-17 11:48:50 +0100865 // multi-selection, necessarily both tokens - on the left and the right
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200866 // sides need to be selected. Thus snapping only to the left is sufficient
867 // (there's a check at the bottom that makes sure that if we snap to the
868 // left token but the result does not contain the initial white-space,
869 // returns the original indices).
870 click_indices = internal::SnapLeftIfWhitespaceSelection(
871 click_indices, context_unicode, *unilib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100872 }
873
Tony Maka44b3082020-08-13 18:57:10 +0100874 Annotations candidates;
875 // As we process a single string of context, the candidates will only
876 // contain one vector of AnnotatedSpan.
877 candidates.annotated_spans.resize(1);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100878 InterpreterManager interpreter_manager(selection_executor_.get(),
879 classification_executor_.get());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200880 std::vector<Token> tokens;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100881 if (!ModelSuggestSelection(context_unicode, click_indices,
Tony Makdf54e742019-03-26 14:04:00 +0000882 detected_text_language_tags, &interpreter_manager,
Tony Maka44b3082020-08-13 18:57:10 +0100883 &tokens, &candidates.annotated_spans[0])) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100884 TC3_LOG(ERROR) << "Model suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200885 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100886 }
Tony Maka44b3082020-08-13 18:57:10 +0100887 const std::unordered_set<std::string> set;
888 const EnabledEntityTypes is_entity_type_enabled(set);
889 if (!RegexChunk(context_unicode, selection_regex_patterns_,
890 /*is_serialized_entity_data_enabled=*/false,
891 is_entity_type_enabled, options.annotation_usecase,
892 &candidates.annotated_spans[0])) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100893 TC3_LOG(ERROR) << "Regex suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200894 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100895 }
Tony Maka44b3082020-08-13 18:57:10 +0100896 if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
897 /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
898 options.locales, ModeFlag_SELECTION,
899 options.annotation_usecase,
900 /*is_serialized_entity_data_enabled=*/false,
901 &candidates.annotated_spans[0])) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100902 TC3_LOG(ERROR) << "Datetime suggest selection failed.";
903 return original_click_indices;
904 }
Tony Mak378c1f52019-03-04 15:58:11 +0000905 if (knowledge_engine_ != nullptr &&
Tony Makf5fd3652021-03-18 19:23:10 +0000906 !knowledge_engine_
907 ->Chunk(context, options.annotation_usecase,
908 options.location_context, Permissions(),
909 AnnotateMode::kEntityAnnotation, &candidates)
910 .ok()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100911 TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200912 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100913 }
Tony Mak378c1f52019-03-04 15:58:11 +0000914 if (contact_engine_ != nullptr &&
Tony Maka44b3082020-08-13 18:57:10 +0100915 !contact_engine_->Chunk(context_unicode, tokens,
916 &candidates.annotated_spans[0])) {
Tony Mak854015a2019-01-16 15:56:48 +0000917 TC3_LOG(ERROR) << "Contact suggest selection failed.";
918 return original_click_indices;
919 }
Tony Mak378c1f52019-03-04 15:58:11 +0000920 if (installed_app_engine_ != nullptr &&
Tony Maka44b3082020-08-13 18:57:10 +0100921 !installed_app_engine_->Chunk(context_unicode, tokens,
922 &candidates.annotated_spans[0])) {
Tony Makd9446602019-02-20 18:25:39 +0000923 TC3_LOG(ERROR) << "Installed app suggest selection failed.";
924 return original_click_indices;
925 }
Tony Mak378c1f52019-03-04 15:58:11 +0000926 if (number_annotator_ != nullptr &&
927 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
Tony Maka44b3082020-08-13 18:57:10 +0100928 &candidates.annotated_spans[0])) {
Tony Mak378c1f52019-03-04 15:58:11 +0000929 TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
930 return original_click_indices;
931 }
Tony Makad2e22d2019-03-20 17:35:13 +0000932 if (duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +0000933 !duration_annotator_->FindAll(context_unicode, tokens,
Tony Maka44b3082020-08-13 18:57:10 +0100934 options.annotation_usecase,
935 &candidates.annotated_spans[0])) {
Tony Makad2e22d2019-03-20 17:35:13 +0000936 TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
937 return original_click_indices;
938 }
Tony Mak76d80962020-01-08 17:30:51 +0000939 if (person_name_engine_ != nullptr &&
Tony Maka44b3082020-08-13 18:57:10 +0100940 !person_name_engine_->Chunk(context_unicode, tokens,
941 &candidates.annotated_spans[0])) {
Tony Mak76d80962020-01-08 17:30:51 +0000942 TC3_LOG(ERROR) << "Person name suggest selection failed.";
943 return original_click_indices;
944 }
Lukas Zilkab23e2122018-02-09 10:25:19 +0100945
Tony Mak21460022020-03-12 18:29:35 +0000946 AnnotatedSpan grammar_suggested_span;
947 if (grammar_annotator_ != nullptr &&
948 grammar_annotator_->SuggestSelection(detected_text_language_tags,
949 context_unicode, click_indices,
950 &grammar_suggested_span)) {
Tony Maka44b3082020-08-13 18:57:10 +0100951 candidates.annotated_spans[0].push_back(grammar_suggested_span);
952 }
953
Tony Mak13125532021-01-13 21:12:07 +0000954 AnnotatedSpan pod_ner_suggested_span;
955 if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
956 pod_ner_annotator_->SuggestSelection(context_unicode, click_indices,
957 &pod_ner_suggested_span)) {
958 candidates.annotated_spans[0].push_back(pod_ner_suggested_span);
Tony Mak21460022020-03-12 18:29:35 +0000959 }
960
Tony Mak5a12b942020-05-01 12:41:31 +0100961 if (experimental_annotator_ != nullptr) {
Tony Maka44b3082020-08-13 18:57:10 +0100962 candidates.annotated_spans[0].push_back(
963 experimental_annotator_->SuggestSelection(context_unicode,
964 click_indices));
Tony Mak5a12b942020-05-01 12:41:31 +0100965 }
966
Lukas Zilkab23e2122018-02-09 10:25:19 +0100967 // Sort candidates according to their position in the input, so that the next
968 // code can assume that any connected component of overlapping spans forms a
969 // contiguous block.
Tony Maka44b3082020-08-13 18:57:10 +0100970 std::sort(candidates.annotated_spans[0].begin(),
971 candidates.annotated_spans[0].end(),
Lukas Zilkab23e2122018-02-09 10:25:19 +0100972 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
973 return a.span.first < b.span.first;
974 });
975
976 std::vector<int> candidate_indices;
Tony Maka44b3082020-08-13 18:57:10 +0100977 if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
Tony Mak13125532021-01-13 21:12:07 +0000978 detected_text_language_tags, options,
Tony Mak378c1f52019-03-04 15:58:11 +0000979 &interpreter_manager, &candidate_indices)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100980 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200981 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100982 }
983
Tony Mak378c1f52019-03-04 15:58:11 +0000984 std::sort(candidate_indices.begin(), candidate_indices.end(),
Tony Maka2a1ff42019-09-12 15:40:32 +0100985 [this, &candidates](int a, int b) {
Tony Maka44b3082020-08-13 18:57:10 +0100986 return GetPriorityScore(
987 candidates.annotated_spans[0][a].classification) >
988 GetPriorityScore(
989 candidates.annotated_spans[0][b].classification);
Tony Mak378c1f52019-03-04 15:58:11 +0000990 });
991
Lukas Zilkab23e2122018-02-09 10:25:19 +0100992 for (const int i : candidate_indices) {
Tony Maka44b3082020-08-13 18:57:10 +0100993 if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
994 SpansOverlap(candidates.annotated_spans[0][i].span,
995 original_click_indices)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200996 // Run model classification if not present but requested and there's a
997 // classification collection filter specified.
Tony Maka44b3082020-08-13 18:57:10 +0100998 if (candidates.annotated_spans[0][i].classification.empty() &&
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200999 model_->selection_options()->always_classify_suggested_selection() &&
1000 !filtered_collections_selection_.empty()) {
Tony Makf5fd3652021-03-18 19:23:10 +00001001 if (!ModelClassifyText(context, /*cached_tokens=*/{},
1002 detected_text_language_tags,
1003 candidates.annotated_spans[0][i].span, options,
1004 &interpreter_manager,
1005 /*embedding_cache=*/nullptr,
1006 &candidates.annotated_spans[0][i].classification,
1007 /*tokens=*/nullptr)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001008 return original_click_indices;
1009 }
1010 }
1011
1012 // Ignore if span classification is filtered.
Tony Maka44b3082020-08-13 18:57:10 +01001013 if (FilteredForSelection(candidates.annotated_spans[0][i])) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001014 return original_click_indices;
1015 }
1016
Tony Mak8a501052021-02-24 20:08:27 +00001017 // We return a suggested span contains the original span.
1018 // This compensates for "select all" selection that may come from
1019 // other apps. See http://b/179890518.
1020 if (SpanContains(candidates.annotated_spans[0][i].span,
1021 original_click_indices)) {
1022 return candidates.annotated_spans[0][i].span;
1023 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001024 }
1025 }
1026
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001027 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001028}
1029
1030namespace {
1031// Helper function that returns the index of the first candidate that
1032// transitively does not overlap with the candidate on 'start_index'. If the end
1033// of 'candidates' is reached, it returns the index that points right behind the
1034// array.
1035int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
1036 int start_index) {
1037 int first_non_overlapping = start_index + 1;
1038 CodepointSpan conflicting_span = candidates[start_index].span;
1039 while (
1040 first_non_overlapping < candidates.size() &&
1041 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
1042 // Grow the span to include the current one.
1043 conflicting_span.second = std::max(
1044 conflicting_span.second, candidates[first_non_overlapping].span.second);
1045
1046 ++first_non_overlapping;
1047 }
1048 return first_non_overlapping;
1049}
1050} // namespace
1051
Tony Mak378c1f52019-03-04 15:58:11 +00001052bool Annotator::ResolveConflicts(
1053 const std::vector<AnnotatedSpan>& candidates, const std::string& context,
1054 const std::vector<Token>& cached_tokens,
1055 const std::vector<Locale>& detected_text_language_tags,
Tony Mak13125532021-01-13 21:12:07 +00001056 const BaseOptions& options, InterpreterManager* interpreter_manager,
1057 std::vector<int>* result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001058 result->clear();
1059 result->reserve(candidates.size());
1060 for (int i = 0; i < candidates.size();) {
1061 int first_non_overlapping =
1062 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1063
1064 const bool conflict_found = first_non_overlapping != (i + 1);
1065 if (conflict_found) {
1066 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +00001067 if (!ResolveConflict(context, cached_tokens, candidates,
1068 detected_text_language_tags, i,
Tony Mak13125532021-01-13 21:12:07 +00001069 first_non_overlapping, options, interpreter_manager,
1070 &candidate_indices)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001071 return false;
1072 }
1073 result->insert(result->end(), candidate_indices.begin(),
1074 candidate_indices.end());
1075 } else {
1076 result->push_back(i);
1077 }
1078
1079 // Skip over the whole conflicting group/go to next candidate.
1080 i = first_non_overlapping;
1081 }
1082 return true;
1083}
1084
1085namespace {
Tony Mak448b5862019-03-22 13:36:41 +00001086// Returns true, if the given two sources do conflict in given annotation
1087// usecase.
1088// - In SMART usecase, all sources do conflict, because there's only 1 possible
1089// annotation for a given span.
1090// - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1091// and duration), while others not (e.g. duration and number).
1092bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1093 const AnnotatedSpan::Source source1,
1094 const AnnotatedSpan::Source source2) {
1095 uint32 source_mask =
1096 (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1097
Tony Mak378c1f52019-03-04 15:58:11 +00001098 switch (annotation_usecase) {
1099 case AnnotationUsecase_ANNOTATION_USECASE_SMART:
Tony Mak448b5862019-03-22 13:36:41 +00001100 // In the SMART mode, all annotations conflict.
Tony Mak378c1f52019-03-04 15:58:11 +00001101 return true;
Tony Mak448b5862019-03-22 13:36:41 +00001102
Tony Mak378c1f52019-03-04 15:58:11 +00001103 case AnnotationUsecase_ANNOTATION_USECASE_RAW:
Tony Mak448b5862019-03-22 13:36:41 +00001104 // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1105 // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1106 // hours" (duration).
1107 if ((source_mask &
1108 (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1109 (source_mask &
1110 (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1111 return false;
Tony Mak378c1f52019-03-04 15:58:11 +00001112 }
Tony Mak448b5862019-03-22 13:36:41 +00001113
1114 // A KNOWLEDGE entity does not conflict with anything.
1115 if ((source_mask &
1116 (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1117 return false;
1118 }
1119
Tony Makd0ae7c62020-03-27 13:58:00 +00001120 // A PERSONNAME entity does not conflict with anything.
1121 if ((source_mask &
1122 (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
1123 return false;
1124 }
1125
Tony Mak448b5862019-03-22 13:36:41 +00001126 // Entities from other sources can conflict.
Tony Mak378c1f52019-03-04 15:58:11 +00001127 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001128 }
1129}
1130} // namespace
1131
Tony Mak378c1f52019-03-04 15:58:11 +00001132bool Annotator::ResolveConflict(
1133 const std::string& context, const std::vector<Token>& cached_tokens,
1134 const std::vector<AnnotatedSpan>& candidates,
1135 const std::vector<Locale>& detected_text_language_tags, int start_index,
Tony Mak13125532021-01-13 21:12:07 +00001136 int end_index, const BaseOptions& options,
Tony Mak378c1f52019-03-04 15:58:11 +00001137 InterpreterManager* interpreter_manager,
1138 std::vector<int>* chosen_indices) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001139 std::vector<int> conflicting_indices;
Tony Mak76d80962020-01-08 17:30:51 +00001140 std::unordered_map<int, std::pair<float, int>> scores_lengths;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001141 for (int i = start_index; i < end_index; ++i) {
1142 conflicting_indices.push_back(i);
1143 if (!candidates[i].classification.empty()) {
Tony Mak76d80962020-01-08 17:30:51 +00001144 scores_lengths[i] = {
1145 GetPriorityScore(candidates[i].classification),
1146 candidates[i].span.second - candidates[i].span.first};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001147 continue;
1148 }
1149
1150 // OPTIMIZATION: So that we don't have to classify all the ML model
1151 // spans apriori, we wait until we get here, when they conflict with
1152 // something and we need the actual classification scores. So if the
1153 // candidate conflicts and comes from the model, we need to run a
1154 // classification to determine its priority:
1155 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +00001156 if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
Tony Mak13125532021-01-13 21:12:07 +00001157 candidates[i].span, options, interpreter_manager,
Tony Makf5fd3652021-03-18 19:23:10 +00001158 /*embedding_cache=*/nullptr, &classification,
1159 /*tokens=*/nullptr)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001160 return false;
1161 }
1162
1163 if (!classification.empty()) {
Tony Mak76d80962020-01-08 17:30:51 +00001164 scores_lengths[i] = {
1165 GetPriorityScore(classification),
1166 candidates[i].span.second - candidates[i].span.first};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001167 }
1168 }
1169
Tony Mak5a12b942020-05-01 12:41:31 +01001170 std::sort(
1171 conflicting_indices.begin(), conflicting_indices.end(),
1172 [this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
1173 if (scores_lengths[i].first == scores_lengths[j].first &&
1174 prioritize_longest_annotation_) {
1175 return scores_lengths[i].second > scores_lengths[j].second;
1176 }
1177 return scores_lengths[i].first > scores_lengths[j].first;
1178 });
Lukas Zilkab23e2122018-02-09 10:25:19 +01001179
Tony Mak448b5862019-03-22 13:36:41 +00001180 // Here we keep a set of indices that were chosen, per-source, to enable
1181 // effective computation.
1182 std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1183 chosen_indices_for_source_map;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001184
1185 // Greedily place the candidates if they don't conflict with the already
1186 // placed ones.
1187 for (int i = 0; i < conflicting_indices.size(); ++i) {
1188 const int considered_candidate = conflicting_indices[i];
Tony Mak448b5862019-03-22 13:36:41 +00001189
1190 // See if there is a conflict between the candidate and all already placed
1191 // candidates.
1192 bool conflict = false;
1193 SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1194 for (auto& source_set_pair : chosen_indices_for_source_map) {
1195 if (source_set_pair.first == candidates[considered_candidate].source) {
1196 chosen_indices_for_source_ptr = &source_set_pair.second;
1197 }
1198
Tony Mak5a12b942020-05-01 12:41:31 +01001199 const bool needs_conflict_resolution =
Tony Mak13125532021-01-13 21:12:07 +00001200 options.annotation_usecase ==
1201 AnnotationUsecase_ANNOTATION_USECASE_SMART ||
1202 (options.annotation_usecase ==
1203 AnnotationUsecase_ANNOTATION_USECASE_RAW &&
Tony Mak5a12b942020-05-01 12:41:31 +01001204 do_conflict_resolution_in_raw_mode_);
1205 if (needs_conflict_resolution &&
Tony Mak13125532021-01-13 21:12:07 +00001206 DoSourcesConflict(options.annotation_usecase, source_set_pair.first,
Tony Mak448b5862019-03-22 13:36:41 +00001207 candidates[considered_candidate].source) &&
1208 DoesCandidateConflict(considered_candidate, candidates,
1209 source_set_pair.second)) {
1210 conflict = true;
1211 break;
1212 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001213 }
Tony Mak448b5862019-03-22 13:36:41 +00001214
1215 // Skip the candidate if a conflict was found.
1216 if (conflict) {
1217 continue;
1218 }
1219
1220 // If the set of indices for the current source doesn't exist yet,
1221 // initialize it.
1222 if (chosen_indices_for_source_ptr == nullptr) {
1223 SortedIntSet new_set([&candidates](int a, int b) {
1224 return candidates[a].span.first < candidates[b].span.first;
1225 });
1226 chosen_indices_for_source_map[candidates[considered_candidate].source] =
1227 std::move(new_set);
1228 chosen_indices_for_source_ptr =
1229 &chosen_indices_for_source_map[candidates[considered_candidate]
1230 .source];
1231 }
1232
1233 // Place the candidate to the output and to the per-source conflict set.
1234 chosen_indices->push_back(considered_candidate);
1235 chosen_indices_for_source_ptr->insert(considered_candidate);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001236 }
1237
Tony Mak378c1f52019-03-04 15:58:11 +00001238 std::sort(chosen_indices->begin(), chosen_indices->end());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001239
1240 return true;
1241}
1242
Tony Mak6c4cc672018-09-17 11:48:50 +01001243bool Annotator::ModelSuggestSelection(
Tony Maka44b3082020-08-13 18:57:10 +01001244 const UnicodeText& context_unicode, const CodepointSpan& click_indices,
Tony Makdf54e742019-03-26 14:04:00 +00001245 const std::vector<Locale>& detected_text_language_tags,
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001246 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001247 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001248 if (model_->triggering_options() == nullptr ||
1249 !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1250 return true;
1251 }
1252
Tony Makdf54e742019-03-26 14:04:00 +00001253 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1254 ml_model_triggering_locales_,
1255 /*default_value=*/true)) {
1256 return true;
1257 }
1258
Lukas Zilka21d8c982018-01-24 11:11:20 +01001259 int click_pos;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001260 *tokens = selection_feature_processor_->Tokenize(context_unicode);
Tony Makf5fd3652021-03-18 19:23:10 +00001261 const auto [click_begin, click_end] =
1262 CodepointSpanToUnicodeTextRange(context_unicode, click_indices);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001263 selection_feature_processor_->RetokenizeAndFindClick(
Tony Makf5fd3652021-03-18 19:23:10 +00001264 context_unicode, click_begin, click_end, click_indices,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001265 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001266 tokens, &click_pos);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001267 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001268 TC3_VLOG(1) << "Could not calculate the click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001269 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001270 }
1271
1272 const int symmetry_context_size =
1273 model_->selection_options()->symmetry_context_size();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001274 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
Lukas Zilkab23e2122018-02-09 10:25:19 +01001275 bounds_sensitive_features = selection_feature_processor_->GetOptions()
1276 ->bounds_sensitive_features();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001277
1278 // The symmetry context span is the clicked token with symmetry_context_size
1279 // tokens on either side.
Tony Maka44b3082020-08-13 18:57:10 +01001280 const TokenSpan symmetry_context_span =
1281 IntersectTokenSpans(TokenSpan(click_pos).Expand(
1282 /*num_tokens_left=*/symmetry_context_size,
1283 /*num_tokens_right=*/symmetry_context_size),
1284 AllOf(*tokens));
Lukas Zilka21d8c982018-01-24 11:11:20 +01001285
Lukas Zilkab23e2122018-02-09 10:25:19 +01001286 // Compute the extraction span based on the model type.
1287 TokenSpan extraction_span;
1288 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1289 // The extraction span is the symmetry context span expanded to include
1290 // max_selection_span tokens on either side, which is how far a selection
1291 // can stretch from the click, plus a relevant number of tokens outside of
1292 // the bounds of the selection.
1293 const int max_selection_span =
1294 selection_feature_processor_->GetOptions()->max_selection_span();
Tony Maka44b3082020-08-13 18:57:10 +01001295 extraction_span = symmetry_context_span.Expand(
1296 /*num_tokens_left=*/max_selection_span +
1297 bounds_sensitive_features->num_tokens_before(),
1298 /*num_tokens_right=*/max_selection_span +
1299 bounds_sensitive_features->num_tokens_after());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001300 } else {
1301 // The extraction span is the symmetry context span expanded to include
1302 // context_size tokens on either side.
1303 const int context_size =
1304 selection_feature_processor_->GetOptions()->context_size();
Tony Maka44b3082020-08-13 18:57:10 +01001305 extraction_span = symmetry_context_span.Expand(
1306 /*num_tokens_left=*/context_size,
1307 /*num_tokens_right=*/context_size);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001308 }
Tony Maka44b3082020-08-13 18:57:10 +01001309 extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001310
Lukas Zilka434442d2018-04-25 11:38:51 +02001311 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1312 *tokens, extraction_span)) {
1313 return true;
1314 }
1315
Lukas Zilkab23e2122018-02-09 10:25:19 +01001316 std::unique_ptr<CachedFeatures> cached_features;
1317 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001318 *tokens, extraction_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001319 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1320 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01001321 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001322 selection_feature_processor_->EmbeddingSize() +
1323 selection_feature_processor_->DenseFeaturesCount(),
1324 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001325 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001326 return false;
1327 }
1328
1329 // Produce selection model candidates.
1330 std::vector<TokenSpan> chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001331 if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001332 interpreter_manager->SelectionInterpreter(), *cached_features,
1333 &chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001334 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001335 return false;
1336 }
1337
1338 for (const TokenSpan& chunk : chunks) {
1339 AnnotatedSpan candidate;
1340 candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001341 context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001342 if (model_->selection_options()->strip_unpaired_brackets()) {
1343 candidate.span =
1344 StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1345 }
1346
1347 // Only output non-empty spans.
1348 if (candidate.span.first != candidate.span.second) {
1349 result->push_back(candidate);
1350 }
1351 }
1352 return true;
1353}
1354
Lukas Zilkaba849e72018-03-08 14:48:21 +01001355namespace internal {
1356std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
Tony Maka44b3082020-08-13 18:57:10 +01001357 const CodepointSpan& selection_indices,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001358 TokenSpan tokens_around_selection_to_copy) {
1359 const auto first_selection_token = std::upper_bound(
1360 cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1361 [](int selection_start, const Token& token) {
1362 return selection_start < token.end;
1363 });
1364 const auto last_selection_token = std::lower_bound(
1365 cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1366 [](const Token& token, int selection_end) {
1367 return token.start < selection_end;
1368 });
1369
1370 const int64 first_token = std::max(
1371 static_cast<int64>(0),
1372 static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1373 tokens_around_selection_to_copy.first));
1374 const int64 last_token = std::min(
1375 static_cast<int64>(cached_tokens.size()),
1376 static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1377 tokens_around_selection_to_copy.second));
1378
1379 std::vector<Token> tokens;
1380 tokens.reserve(last_token - first_token);
1381 for (int i = first_token; i < last_token; ++i) {
1382 tokens.push_back(cached_tokens[i]);
1383 }
1384 return tokens;
1385}
1386} // namespace internal
1387
Tony Mak6c4cc672018-09-17 11:48:50 +01001388TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001389 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1390 bounds_sensitive_features =
1391 classification_feature_processor_->GetOptions()
1392 ->bounds_sensitive_features();
1393 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1394 // The extraction span is the selection span expanded to include a relevant
1395 // number of tokens outside of the bounds of the selection.
1396 return {bounds_sensitive_features->num_tokens_before(),
1397 bounds_sensitive_features->num_tokens_after()};
1398 } else {
1399 // The extraction span is the clicked token with context_size tokens on
1400 // either side.
1401 const int context_size =
1402 selection_feature_processor_->GetOptions()->context_size();
1403 return {context_size, context_size};
1404 }
1405}
1406
Tony Mak378c1f52019-03-04 15:58:11 +00001407namespace {
1408// Sorts the classification results from high score to low score.
1409void SortClassificationResults(
1410 std::vector<ClassificationResult>* classification_results) {
1411 std::sort(classification_results->begin(), classification_results->end(),
1412 [](const ClassificationResult& a, const ClassificationResult& b) {
1413 return a.score > b.score;
1414 });
1415}
1416} // namespace
1417
Tony Mak6c4cc672018-09-17 11:48:50 +01001418bool Annotator::ModelClassifyText(
Lukas Zilkaba849e72018-03-08 14:48:21 +01001419 const std::string& context, const std::vector<Token>& cached_tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00001420 const std::vector<Locale>& detected_text_language_tags,
Tony Mak13125532021-01-13 21:12:07 +00001421 const CodepointSpan& selection_indices, const BaseOptions& options,
Tony Maka44b3082020-08-13 18:57:10 +01001422 InterpreterManager* interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001423 FeatureProcessor::EmbeddingCache* embedding_cache,
Tony Makf5fd3652021-03-18 19:23:10 +00001424 std::vector<ClassificationResult>* classification_results,
1425 std::vector<Token>* tokens) const {
1426 const UnicodeText context_unicode =
1427 UTF8ToUnicodeText(context, /*do_copy=*/false);
1428 const auto [span_begin, span_end] =
1429 CodepointSpanToUnicodeTextRange(context_unicode, selection_indices);
1430 return ModelClassifyText(context_unicode, cached_tokens,
1431 detected_text_language_tags, span_begin, span_end,
1432 /*line=*/nullptr, selection_indices, options,
1433 interpreter_manager, embedding_cache,
1434 classification_results, tokens);
Tony Mak378c1f52019-03-04 15:58:11 +00001435}
1436
1437bool Annotator::ModelClassifyText(
Tony Makf5fd3652021-03-18 19:23:10 +00001438 const UnicodeText& context_unicode, const std::vector<Token>& cached_tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00001439 const std::vector<Locale>& detected_text_language_tags,
Tony Makf5fd3652021-03-18 19:23:10 +00001440 const UnicodeText::const_iterator& span_begin,
1441 const UnicodeText::const_iterator& span_end, const UnicodeTextRange* line,
Tony Mak13125532021-01-13 21:12:07 +00001442 const CodepointSpan& selection_indices, const BaseOptions& options,
Tony Maka44b3082020-08-13 18:57:10 +01001443 InterpreterManager* interpreter_manager,
Tony Mak378c1f52019-03-04 15:58:11 +00001444 FeatureProcessor::EmbeddingCache* embedding_cache,
1445 std::vector<ClassificationResult>* classification_results,
1446 std::vector<Token>* tokens) const {
1447 if (model_->triggering_options() == nullptr ||
1448 !(model_->triggering_options()->enabled_modes() &
1449 ModeFlag_CLASSIFICATION)) {
1450 return true;
1451 }
1452
Tony Makdf54e742019-03-26 14:04:00 +00001453 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1454 ml_model_triggering_locales_,
1455 /*default_value=*/true)) {
1456 return true;
1457 }
1458
Tony Makf5fd3652021-03-18 19:23:10 +00001459 std::vector<Token> local_tokens;
1460 if (tokens == nullptr) {
1461 tokens = &local_tokens;
1462 }
1463
Lukas Zilkaba849e72018-03-08 14:48:21 +01001464 if (cached_tokens.empty()) {
Tony Makf5fd3652021-03-18 19:23:10 +00001465 *tokens = classification_feature_processor_->Tokenize(context_unicode);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001466 } else {
Tony Mak378c1f52019-03-04 15:58:11 +00001467 *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1468 ClassifyTextUpperBoundNeededTokens());
Lukas Zilkaba849e72018-03-08 14:48:21 +01001469 }
1470
Lukas Zilkab23e2122018-02-09 10:25:19 +01001471 int click_pos;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001472 classification_feature_processor_->RetokenizeAndFindClick(
Tony Makf5fd3652021-03-18 19:23:10 +00001473 context_unicode, span_begin, span_end, selection_indices,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001474 classification_feature_processor_->GetOptions()
1475 ->only_use_line_with_click(),
Tony Mak378c1f52019-03-04 15:58:11 +00001476 tokens, &click_pos);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001477 const TokenSpan selection_token_span =
Tony Mak378c1f52019-03-04 15:58:11 +00001478 CodepointSpanToTokenSpan(*tokens, selection_indices);
Tony Maka44b3082020-08-13 18:57:10 +01001479 const int selection_num_tokens = selection_token_span.Size();
Lukas Zilka434442d2018-04-25 11:38:51 +02001480 if (model_->classification_options()->max_num_tokens() > 0 &&
1481 model_->classification_options()->max_num_tokens() <
1482 selection_num_tokens) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001483 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001484 return true;
1485 }
1486
Lukas Zilkab23e2122018-02-09 10:25:19 +01001487 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1488 bounds_sensitive_features =
1489 classification_feature_processor_->GetOptions()
1490 ->bounds_sensitive_features();
1491 if (selection_token_span.first == kInvalidIndex ||
1492 selection_token_span.second == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001493 TC3_LOG(ERROR) << "Could not determine span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001494 return false;
1495 }
1496
1497 // Compute the extraction span based on the model type.
1498 TokenSpan extraction_span;
1499 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1500 // The extraction span is the selection span expanded to include a relevant
1501 // number of tokens outside of the bounds of the selection.
Tony Maka44b3082020-08-13 18:57:10 +01001502 extraction_span = selection_token_span.Expand(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001503 /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1504 /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1505 } else {
1506 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001507 TC3_LOG(ERROR) << "Couldn't choose a click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001508 return false;
1509 }
1510 // The extraction span is the clicked token with context_size tokens on
1511 // either side.
1512 const int context_size =
Lukas Zilkaba849e72018-03-08 14:48:21 +01001513 classification_feature_processor_->GetOptions()->context_size();
Tony Maka44b3082020-08-13 18:57:10 +01001514 extraction_span = TokenSpan(click_pos).Expand(
1515 /*num_tokens_left=*/context_size,
1516 /*num_tokens_right=*/context_size);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001517 }
Tony Maka44b3082020-08-13 18:57:10 +01001518 extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
Lukas Zilka21d8c982018-01-24 11:11:20 +01001519
Lukas Zilka434442d2018-04-25 11:38:51 +02001520 if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
Tony Mak378c1f52019-03-04 15:58:11 +00001521 *tokens, extraction_span)) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001522 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001523 return true;
1524 }
1525
Lukas Zilka21d8c982018-01-24 11:11:20 +01001526 std::unique_ptr<CachedFeatures> cached_features;
1527 if (!classification_feature_processor_->ExtractFeatures(
Tony Mak378c1f52019-03-04 15:58:11 +00001528 *tokens, extraction_span, selection_indices,
1529 embedding_executor_.get(), embedding_cache,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001530 classification_feature_processor_->EmbeddingSize() +
1531 classification_feature_processor_->DenseFeaturesCount(),
1532 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001533 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001534 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001535 }
1536
Lukas Zilkab23e2122018-02-09 10:25:19 +01001537 std::vector<float> features;
1538 features.reserve(cached_features->OutputFeaturesSize());
1539 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1540 cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1541 &features);
1542 } else {
1543 cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001544 }
1545
Lukas Zilkaba849e72018-03-08 14:48:21 +01001546 TensorView<float> logits = classification_executor_->ComputeLogits(
1547 TensorView<float>(features.data(),
1548 {1, static_cast<int>(features.size())}),
1549 interpreter_manager->ClassificationInterpreter());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001550 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001551 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001552 return false;
1553 }
1554
1555 if (logits.dims() != 2 || logits.dim(0) != 1 ||
1556 logits.dim(1) != classification_feature_processor_->NumCollections()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001557 TC3_LOG(ERROR) << "Mismatching output";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001558 return false;
1559 }
1560
1561 const std::vector<float> scores =
1562 ComputeSoftmax(logits.data(), logits.dim(1));
1563
Tony Mak81e52422019-04-30 09:34:45 +01001564 if (scores.empty()) {
1565 *classification_results = {{Collections::Other(), 1.0}};
1566 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001567 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001568
Tony Mak81e52422019-04-30 09:34:45 +01001569 const int best_score_index =
1570 std::max_element(scores.begin(), scores.end()) - scores.begin();
1571 const std::string top_collection =
1572 classification_feature_processor_->LabelToCollection(best_score_index);
1573
1574 // Sanity checks.
1575 if (top_collection == Collections::Phone()) {
Tony Makf5fd3652021-03-18 19:23:10 +00001576 const int digit_count = std::count_if(span_begin, span_end, IsDigit);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001577 if (digit_count <
1578 model_->classification_options()->phone_min_num_digits() ||
1579 digit_count >
1580 model_->classification_options()->phone_max_num_digits()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001581 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001582 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001583 }
Tony Mak81e52422019-04-30 09:34:45 +01001584 } else if (top_collection == Collections::Address()) {
Lukas Zilka434442d2018-04-25 11:38:51 +02001585 if (selection_num_tokens <
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001586 model_->classification_options()->address_min_num_tokens()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001587 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001588 return true;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001589 }
Tony Mak81e52422019-04-30 09:34:45 +01001590 } else if (top_collection == Collections::Dictionary()) {
Tony Mak13125532021-01-13 21:12:07 +00001591 if ((options.use_vocab_annotator && vocab_annotator_) ||
1592 !Locale::IsAnyLocaleSupported(detected_text_language_tags,
Tony Mak378c1f52019-03-04 15:58:11 +00001593 dictionary_locales_,
Tony Makdf54e742019-03-26 14:04:00 +00001594 /*default_value=*/false)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001595 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001596 return true;
Tony Mak378c1f52019-03-04 15:58:11 +00001597 }
1598 }
Tony Makd99d58c2020-03-19 21:52:02 +00001599 *classification_results = {{top_collection, /*arg_score=*/1.0,
1600 /*arg_priority_score=*/scores[best_score_index]}};
1601
1602 // For some entities, we might want to clamp the priority score, for better
1603 // conflict resolution between entities.
1604 if (model_->triggering_options() != nullptr &&
1605 model_->triggering_options()->collection_to_priority() != nullptr) {
1606 if (auto entry =
1607 model_->triggering_options()->collection_to_priority()->LookupByKey(
1608 top_collection.c_str())) {
1609 (*classification_results)[0].priority_score *= entry->value();
1610 }
1611 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001612 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001613}
1614
Tony Mak6c4cc672018-09-17 11:48:50 +01001615bool Annotator::RegexClassifyText(
Tony Maka44b3082020-08-13 18:57:10 +01001616 const std::string& context, const CodepointSpan& selection_indices,
Tony Mak378c1f52019-03-04 15:58:11 +00001617 std::vector<ClassificationResult>* classification_result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001618 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001619 UTF8ToUnicodeText(context, /*do_copy=*/false)
1620 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001621 const UnicodeText selection_text_unicode(
1622 UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1623
1624 // Check whether any of the regular expressions match.
1625 for (const int pattern_id : classification_regex_patterns_) {
1626 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1627 const std::unique_ptr<UniLib::RegexMatcher> matcher =
1628 regex_pattern.pattern->Matcher(selection_text_unicode);
1629 int status = UniLib::RegexMatcher::kNoError;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001630 bool matches;
Tony Mak854015a2019-01-16 15:56:48 +00001631 if (regex_pattern.config->use_approximate_matching()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001632 matches = matcher->ApproximatelyMatches(&status);
1633 } else {
1634 matches = matcher->Matches(&status);
1635 }
1636 if (status != UniLib::RegexMatcher::kNoError) {
1637 return false;
1638 }
Tony Makdf54e742019-03-26 14:04:00 +00001639 if (matches && VerifyRegexMatchCandidate(
1640 context, regex_pattern.config->verification_options(),
1641 selection_text, matcher.get())) {
Tony Mak378c1f52019-03-04 15:58:11 +00001642 classification_result->push_back(
1643 {regex_pattern.config->collection_name()->str(),
1644 regex_pattern.config->target_classification_score(),
1645 regex_pattern.config->priority_score()});
Tony Makd9446602019-02-20 18:25:39 +00001646 if (!SerializedEntityDataFromRegexMatch(
1647 regex_pattern.config, matcher.get(),
Tony Mak378c1f52019-03-04 15:58:11 +00001648 &classification_result->back().serialized_entity_data)) {
Tony Makd9446602019-02-20 18:25:39 +00001649 TC3_LOG(ERROR) << "Could not get entity data.";
1650 return false;
1651 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001652 }
1653 }
1654
Tony Mak378c1f52019-03-04 15:58:11 +00001655 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001656}
1657
Tony Mak5dc5e112019-02-01 14:52:10 +00001658namespace {
1659std::string PickCollectionForDatetime(
1660 const DatetimeParseResult& datetime_parse_result) {
1661 switch (datetime_parse_result.granularity) {
1662 case GRANULARITY_HOUR:
1663 case GRANULARITY_MINUTE:
1664 case GRANULARITY_SECOND:
1665 return Collections::DateTime();
1666 default:
1667 return Collections::Date();
1668 }
1669}
Tony Mak83d2de62019-04-10 16:12:15 +01001670
Tony Mak5dc5e112019-02-01 14:52:10 +00001671} // namespace
1672
Tony Mak6c4cc672018-09-17 11:48:50 +01001673bool Annotator::DatetimeClassifyText(
Tony Maka44b3082020-08-13 18:57:10 +01001674 const std::string& context, const CodepointSpan& selection_indices,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001675 const ClassificationOptions& options,
Tony Mak854015a2019-01-16 15:56:48 +00001676 std::vector<ClassificationResult>* classification_results) const {
Tony Mak13125532021-01-13 21:12:07 +00001677 if (!datetime_parser_) {
Tony Makd99d58c2020-03-19 21:52:02 +00001678 return true;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001679 }
1680
Lukas Zilkab23e2122018-02-09 10:25:19 +01001681 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001682 UTF8ToUnicodeText(context, /*do_copy=*/false)
1683 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001684
Tony Mak13125532021-01-13 21:12:07 +00001685 LocaleList locale_list = LocaleList::ParseFrom(options.locales);
1686 StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
1687 datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1688 options.reference_timezone, locale_list,
1689 ModeFlag_CLASSIFICATION,
1690 options.annotation_usecase,
1691 /*anchor_start_end=*/true);
1692 if (!result_status.ok()) {
1693 TC3_LOG(ERROR) << "Error during parsing datetime.";
1694 return false;
Tony Makd99d58c2020-03-19 21:52:02 +00001695 }
1696
Tony Mak13125532021-01-13 21:12:07 +00001697 for (const DatetimeParseResultSpan& datetime_span :
1698 result_status.ValueOrDie()) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001699 // Only consider the result valid if the selection and extracted datetime
1700 // spans exactly match.
Tony Maka44b3082020-08-13 18:57:10 +01001701 if (CodepointSpan(datetime_span.span.first + selection_indices.first,
1702 datetime_span.span.second + selection_indices.first) ==
Lukas Zilkab23e2122018-02-09 10:25:19 +01001703 selection_indices) {
Tony Mak854015a2019-01-16 15:56:48 +00001704 for (const DatetimeParseResult& parse_result : datetime_span.data) {
1705 classification_results->emplace_back(
Tony Mak5dc5e112019-02-01 14:52:10 +00001706 PickCollectionForDatetime(parse_result),
1707 datetime_span.target_classification_score);
Tony Mak854015a2019-01-16 15:56:48 +00001708 classification_results->back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01001709 classification_results->back().serialized_entity_data =
1710 CreateDatetimeSerializedEntityData(parse_result);
Tony Mak378c1f52019-03-04 15:58:11 +00001711 classification_results->back().priority_score =
1712 datetime_span.priority_score;
Tony Mak854015a2019-01-16 15:56:48 +00001713 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001714 return true;
1715 }
1716 }
Tony Mak378c1f52019-03-04 15:58:11 +00001717 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001718}
1719
Tony Mak6c4cc672018-09-17 11:48:50 +01001720std::vector<ClassificationResult> Annotator::ClassifyText(
Tony Maka44b3082020-08-13 18:57:10 +01001721 const std::string& context, const CodepointSpan& selection_indices,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001722 const ClassificationOptions& options) const {
Tony Mak13125532021-01-13 21:12:07 +00001723 if (context.size() > std::numeric_limits<int>::max()) {
1724 TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
1725 return {};
1726 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001727 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001728 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilka21d8c982018-01-24 11:11:20 +01001729 return {};
1730 }
Tony Mak5a12b942020-05-01 12:41:31 +01001731 if (options.annotation_usecase !=
1732 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
1733 TC3_LOG(WARNING)
1734 << "Invoking ClassifyText, which is not supported in RAW mode.";
1735 return {};
1736 }
Lukas Zilkaba849e72018-03-08 14:48:21 +01001737 if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1738 return {};
1739 }
1740
Tony Makdf54e742019-03-26 14:04:00 +00001741 std::vector<Locale> detected_text_language_tags;
1742 if (!ParseLocales(options.detected_text_language_tags,
1743 &detected_text_language_tags)) {
1744 TC3_LOG(WARNING)
1745 << "Failed to parse the detected_text_language_tags in options: "
1746 << options.detected_text_language_tags;
1747 }
1748 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1749 model_triggering_locales_,
1750 /*default_value=*/true)) {
1751 return {};
1752 }
1753
Tony Mak13125532021-01-13 21:12:07 +00001754 const UnicodeText context_unicode =
1755 UTF8ToUnicodeText(context, /*do_copy=*/false);
1756
1757 if (!unilib_->IsValidUtf8(context_unicode)) {
1758 TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
1759 return {};
1760 }
1761
1762 if (!IsValidSpanInput(context_unicode, selection_indices)) {
Tony Mak968412a2019-11-13 15:39:57 +00001763 TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
Tony Maka44b3082020-08-13 18:57:10 +01001764 << selection_indices.first << " " << selection_indices.second;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001765 return {};
1766 }
1767
Tony Mak378c1f52019-03-04 15:58:11 +00001768 // We'll accumulate a list of candidates, and pick the best candidate in the
1769 // end.
1770 std::vector<AnnotatedSpan> candidates;
1771
Tony Mak6c4cc672018-09-17 11:48:50 +01001772 // Try the knowledge engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001773 // TODO(b/126579108): Propagate error status.
Tony Mak6c4cc672018-09-17 11:48:50 +01001774 ClassificationResult knowledge_result;
Tony Mak63959242020-02-07 18:31:16 +00001775 if (knowledge_engine_ &&
Tony Makf5fd3652021-03-18 19:23:10 +00001776 knowledge_engine_
1777 ->ClassifyText(context, selection_indices, options.annotation_usecase,
1778 options.location_context, Permissions(),
1779 &knowledge_result)
1780 .ok()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001781 candidates.push_back({selection_indices, {knowledge_result}});
1782 candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
Tony Mak854015a2019-01-16 15:56:48 +00001783 }
1784
Tony Maka2a1ff42019-09-12 15:40:32 +01001785 AddContactMetadataToKnowledgeClassificationResults(&candidates);
1786
Tony Mak854015a2019-01-16 15:56:48 +00001787 // Try the contact engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001788 // TODO(b/126579108): Propagate error status.
Tony Mak854015a2019-01-16 15:56:48 +00001789 ClassificationResult contact_result;
1790 if (contact_engine_ && contact_engine_->ClassifyText(
1791 context, selection_indices, &contact_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001792 candidates.push_back({selection_indices, {contact_result}});
Tony Mak6c4cc672018-09-17 11:48:50 +01001793 }
1794
Tony Mak76d80962020-01-08 17:30:51 +00001795 // Try the person name engine.
1796 ClassificationResult person_name_result;
1797 if (person_name_engine_ &&
1798 person_name_engine_->ClassifyText(context, selection_indices,
1799 &person_name_result)) {
1800 candidates.push_back({selection_indices, {person_name_result}});
Tony Makd0ae7c62020-03-27 13:58:00 +00001801 candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
Tony Mak76d80962020-01-08 17:30:51 +00001802 }
1803
Tony Makd9446602019-02-20 18:25:39 +00001804 // Try the installed app engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001805 // TODO(b/126579108): Propagate error status.
Tony Makd9446602019-02-20 18:25:39 +00001806 ClassificationResult installed_app_result;
1807 if (installed_app_engine_ &&
1808 installed_app_engine_->ClassifyText(context, selection_indices,
1809 &installed_app_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001810 candidates.push_back({selection_indices, {installed_app_result}});
Tony Makd9446602019-02-20 18:25:39 +00001811 }
1812
Lukas Zilkab23e2122018-02-09 10:25:19 +01001813 // Try the regular expression models.
Tony Mak378c1f52019-03-04 15:58:11 +00001814 std::vector<ClassificationResult> regex_results;
1815 if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1816 return {};
1817 }
1818 for (const ClassificationResult& result : regex_results) {
1819 candidates.push_back({selection_indices, {result}});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001820 }
1821
Lukas Zilkab23e2122018-02-09 10:25:19 +01001822 // Try the date model.
Tony Mak378c1f52019-03-04 15:58:11 +00001823 //
1824 // DatetimeClassifyText only returns the first result, which can however have
1825 // more interpretations. They are inserted in the candidates as a single
1826 // AnnotatedSpan, so that they get treated together by the conflict resolution
1827 // algorithm.
Tony Mak854015a2019-01-16 15:56:48 +00001828 std::vector<ClassificationResult> datetime_results;
Tony Mak378c1f52019-03-04 15:58:11 +00001829 if (!DatetimeClassifyText(context, selection_indices, options,
1830 &datetime_results)) {
1831 return {};
1832 }
1833 if (!datetime_results.empty()) {
1834 candidates.push_back({selection_indices, std::move(datetime_results)});
Tony Mak448b5862019-03-22 13:36:41 +00001835 candidates.back().source = AnnotatedSpan::Source::DATETIME;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001836 }
1837
Tony Mak378c1f52019-03-04 15:58:11 +00001838 // Try the number annotator.
1839 // TODO(b/126579108): Propagate error status.
1840 ClassificationResult number_annotator_result;
1841 if (number_annotator_ &&
Tony Mak5a12b942020-05-01 12:41:31 +01001842 number_annotator_->ClassifyText(context_unicode, selection_indices,
1843 options.annotation_usecase,
1844 &number_annotator_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001845 candidates.push_back({selection_indices, {number_annotator_result}});
1846 }
Lukas Zilkaba849e72018-03-08 14:48:21 +01001847
Tony Makad2e22d2019-03-20 17:35:13 +00001848 // Try the duration annotator.
1849 ClassificationResult duration_annotator_result;
1850 if (duration_annotator_ &&
Tony Mak5a12b942020-05-01 12:41:31 +01001851 duration_annotator_->ClassifyText(context_unicode, selection_indices,
1852 options.annotation_usecase,
1853 &duration_annotator_result)) {
Tony Makad2e22d2019-03-20 17:35:13 +00001854 candidates.push_back({selection_indices, {duration_annotator_result}});
Tony Mak448b5862019-03-22 13:36:41 +00001855 candidates.back().source = AnnotatedSpan::Source::DURATION;
Tony Makad2e22d2019-03-20 17:35:13 +00001856 }
1857
Tony Mak63959242020-02-07 18:31:16 +00001858 // Try the translate annotator.
1859 ClassificationResult translate_annotator_result;
1860 if (translate_annotator_ &&
Tony Mak5a12b942020-05-01 12:41:31 +01001861 translate_annotator_->ClassifyText(context_unicode, selection_indices,
1862 options.user_familiar_language_tags,
1863 &translate_annotator_result)) {
Tony Mak63959242020-02-07 18:31:16 +00001864 candidates.push_back({selection_indices, {translate_annotator_result}});
1865 }
1866
Tony Mak21460022020-03-12 18:29:35 +00001867 // Try the grammar model.
1868 ClassificationResult grammar_annotator_result;
1869 if (grammar_annotator_ && grammar_annotator_->ClassifyText(
Tony Mak5a12b942020-05-01 12:41:31 +01001870 detected_text_language_tags, context_unicode,
Tony Mak21460022020-03-12 18:29:35 +00001871 selection_indices, &grammar_annotator_result)) {
1872 candidates.push_back({selection_indices, {grammar_annotator_result}});
1873 }
1874
Tony Maka44b3082020-08-13 18:57:10 +01001875 ClassificationResult pod_ner_annotator_result;
1876 if (pod_ner_annotator_ && options.use_pod_ner &&
1877 pod_ner_annotator_->ClassifyText(context_unicode, selection_indices,
1878 &pod_ner_annotator_result)) {
1879 candidates.push_back({selection_indices, {pod_ner_annotator_result}});
1880 }
1881
Tony Maka5090082020-09-18 16:41:23 +01001882 ClassificationResult vocab_annotator_result;
Tony Mak13125532021-01-13 21:12:07 +00001883 if (vocab_annotator_ && options.use_vocab_annotator &&
Tony Maka5090082020-09-18 16:41:23 +01001884 vocab_annotator_->ClassifyText(
1885 context_unicode, selection_indices, detected_text_language_tags,
1886 options.trigger_dictionary_on_beginner_words,
1887 &vocab_annotator_result)) {
1888 candidates.push_back({selection_indices, {vocab_annotator_result}});
1889 }
1890
Tony Maka44b3082020-08-13 18:57:10 +01001891 if (experimental_annotator_) {
1892 experimental_annotator_->ClassifyText(context_unicode, selection_indices,
1893 candidates);
Tony Mak5a12b942020-05-01 12:41:31 +01001894 }
1895
Tony Mak378c1f52019-03-04 15:58:11 +00001896 // Try the ML model.
1897 //
1898 // The output of the model is considered as an exclusive 1-of-N choice. That's
1899 // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1900 // span for each candidate, like e.g. the regex model.
Lukas Zilkaba849e72018-03-08 14:48:21 +01001901 InterpreterManager interpreter_manager(selection_executor_.get(),
1902 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00001903 std::vector<ClassificationResult> model_results;
1904 std::vector<Token> tokens;
1905 if (!ModelClassifyText(
1906 context, /*cached_tokens=*/{}, detected_text_language_tags,
Tony Mak13125532021-01-13 21:12:07 +00001907 selection_indices, options, &interpreter_manager,
Tony Mak378c1f52019-03-04 15:58:11 +00001908 /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1909 return {};
1910 }
1911 if (!model_results.empty()) {
1912 candidates.push_back({selection_indices, std::move(model_results)});
1913 }
1914
1915 std::vector<int> candidate_indices;
1916 if (!ResolveConflicts(candidates, context, tokens,
Tony Mak13125532021-01-13 21:12:07 +00001917 detected_text_language_tags, options,
Tony Mak378c1f52019-03-04 15:58:11 +00001918 &interpreter_manager, &candidate_indices)) {
1919 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1920 return {};
1921 }
1922
1923 std::vector<ClassificationResult> results;
1924 for (const int i : candidate_indices) {
1925 for (const ClassificationResult& result : candidates[i].classification) {
1926 if (!FilteredForClassification(result)) {
1927 results.push_back(result);
1928 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001929 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001930 }
1931
Tony Mak378c1f52019-03-04 15:58:11 +00001932 // Sort results according to score.
1933 std::sort(results.begin(), results.end(),
1934 [](const ClassificationResult& a, const ClassificationResult& b) {
1935 return a.score > b.score;
1936 });
1937
1938 if (results.empty()) {
Tony Mak81e52422019-04-30 09:34:45 +01001939 results = {{Collections::Other(), 1.0}};
Tony Mak378c1f52019-03-04 15:58:11 +00001940 }
Tony Mak378c1f52019-03-04 15:58:11 +00001941 return results;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001942}
1943
Tony Mak378c1f52019-03-04 15:58:11 +00001944bool Annotator::ModelAnnotate(
1945 const std::string& context,
1946 const std::vector<Locale>& detected_text_language_tags,
Tony Mak8a501052021-02-24 20:08:27 +00001947 const AnnotationOptions& options, InterpreterManager* interpreter_manager,
Tony Mak13125532021-01-13 21:12:07 +00001948 std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001949 if (model_->triggering_options() == nullptr ||
1950 !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1951 return true;
1952 }
1953
Tony Makdf54e742019-03-26 14:04:00 +00001954 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1955 ml_model_triggering_locales_,
1956 /*default_value=*/true)) {
1957 return true;
1958 }
1959
Lukas Zilka21d8c982018-01-24 11:11:20 +01001960 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1961 /*do_copy=*/false);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001962 std::vector<UnicodeTextRange> lines;
1963 if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1964 lines.push_back({context_unicode.begin(), context_unicode.end()});
1965 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +01001966 lines = selection_feature_processor_->SplitContext(
1967 context_unicode, selection_feature_processor_->GetOptions()
1968 ->use_pipe_character_for_newline());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001969 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001970
Lukas Zilkaba849e72018-03-08 14:48:21 +01001971 const float min_annotate_confidence =
1972 (model_->triggering_options() != nullptr
1973 ? model_->triggering_options()->min_annotate_confidence()
1974 : 0.f);
1975
Lukas Zilkab23e2122018-02-09 10:25:19 +01001976 for (const UnicodeTextRange& line : lines) {
Tony Mak408c6b82019-03-08 17:57:27 +00001977 FeatureProcessor::EmbeddingCache embedding_cache;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001978 const std::string line_str =
1979 UnicodeText::UTF8Substring(line.first, line.second);
1980
Tony Mak13125532021-01-13 21:12:07 +00001981 std::vector<Token> line_tokens;
1982 line_tokens = selection_feature_processor_->Tokenize(line_str);
1983
Lukas Zilkaba849e72018-03-08 14:48:21 +01001984 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001985 line_str, {0, std::distance(line.first, line.second)},
1986 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Tony Mak13125532021-01-13 21:12:07 +00001987 &line_tokens,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001988 /*click_pos=*/nullptr);
Tony Mak13125532021-01-13 21:12:07 +00001989 const TokenSpan full_line_span = {
1990 0, static_cast<TokenIndex>(line_tokens.size())};
Lukas Zilka21d8c982018-01-24 11:11:20 +01001991
Lukas Zilka434442d2018-04-25 11:38:51 +02001992 // TODO(zilka): Add support for greater granularity of this check.
1993 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
Tony Mak13125532021-01-13 21:12:07 +00001994 line_tokens, full_line_span)) {
Lukas Zilka434442d2018-04-25 11:38:51 +02001995 continue;
1996 }
1997
Lukas Zilka21d8c982018-01-24 11:11:20 +01001998 std::unique_ptr<CachedFeatures> cached_features;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001999 if (!selection_feature_processor_->ExtractFeatures(
Tony Mak13125532021-01-13 21:12:07 +00002000 line_tokens, full_line_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002001 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
2002 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01002003 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002004 selection_feature_processor_->EmbeddingSize() +
2005 selection_feature_processor_->DenseFeaturesCount(),
Lukas Zilka21d8c982018-01-24 11:11:20 +01002006 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002007 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002008 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002009 }
2010
2011 std::vector<TokenSpan> local_chunks;
Tony Mak13125532021-01-13 21:12:07 +00002012 if (!ModelChunk(line_tokens.size(), /*span_of_interest=*/full_line_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002013 interpreter_manager->SelectionInterpreter(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002014 *cached_features, &local_chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002015 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002016 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002017 }
2018
2019 const int offset = std::distance(context_unicode.begin(), line.first);
Tony Mak8a501052021-02-24 20:08:27 +00002020 UnicodeText line_unicode;
2021 std::vector<UnicodeText::const_iterator> line_codepoints;
2022 if (options.enable_optimization) {
2023 if (local_chunks.empty()) {
2024 continue;
2025 }
2026 line_unicode = UTF8ToUnicodeText(line_str, /*do_copy=*/false);
2027 line_codepoints = line_unicode.Codepoints();
2028 line_codepoints.push_back(line_unicode.end());
2029 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01002030 for (const TokenSpan& chunk : local_chunks) {
Tony Maka44b3082020-08-13 18:57:10 +01002031 CodepointSpan codepoint_span =
Tony Mak8a501052021-02-24 20:08:27 +00002032 TokenSpanToCodepointSpan(line_tokens, chunk);
2033 if (options.enable_optimization) {
2034 if (!codepoint_span.IsValid() ||
2035 codepoint_span.second > line_codepoints.size()) {
2036 continue;
2037 }
2038 codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
2039 /*span_begin=*/line_codepoints[codepoint_span.first],
2040 /*span_end=*/line_codepoints[codepoint_span.second],
2041 codepoint_span);
2042 if (model_->selection_options()->strip_unpaired_brackets()) {
2043 codepoint_span = StripUnpairedBrackets(
2044 /*span_begin=*/line_codepoints[codepoint_span.first],
2045 /*span_end=*/line_codepoints[codepoint_span.second],
2046 codepoint_span, *unilib_);
2047 }
2048 } else {
2049 codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
2050 line_str, codepoint_span);
2051 if (model_->selection_options()->strip_unpaired_brackets()) {
2052 codepoint_span =
2053 StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_);
2054 }
Tony Maka44b3082020-08-13 18:57:10 +01002055 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002056
2057 // Skip empty spans.
2058 if (codepoint_span.first != codepoint_span.second) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002059 std::vector<ClassificationResult> classification;
Tony Makf5fd3652021-03-18 19:23:10 +00002060 if (options.enable_optimization) {
2061 if (!ModelClassifyText(
2062 line_unicode, line_tokens, detected_text_language_tags,
2063 /*span_begin=*/line_codepoints[codepoint_span.first],
2064 /*span_end=*/line_codepoints[codepoint_span.second], &line,
2065 codepoint_span, options, interpreter_manager,
2066 &embedding_cache, &classification, /*tokens=*/nullptr)) {
2067 TC3_LOG(ERROR) << "Could not classify text: "
2068 << (codepoint_span.first + offset) << " "
2069 << (codepoint_span.second + offset);
2070 return false;
2071 }
2072 } else {
2073 if (!ModelClassifyText(line_str, line_tokens,
2074 detected_text_language_tags, codepoint_span,
2075 options, interpreter_manager, &embedding_cache,
2076 &classification, /*tokens=*/nullptr)) {
2077 TC3_LOG(ERROR) << "Could not classify text: "
2078 << (codepoint_span.first + offset) << " "
2079 << (codepoint_span.second + offset);
2080 return false;
2081 }
Lukas Zilkaba849e72018-03-08 14:48:21 +01002082 }
2083
2084 // Do not include the span if it's classified as "other".
2085 if (!classification.empty() && !ClassifiedAsOther(classification) &&
2086 classification[0].score >= min_annotate_confidence) {
2087 AnnotatedSpan result_span;
2088 result_span.span = {codepoint_span.first + offset,
2089 codepoint_span.second + offset};
2090 result_span.classification = std::move(classification);
2091 result->push_back(std::move(result_span));
2092 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002093 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01002094 }
Tony Mak13125532021-01-13 21:12:07 +00002095
2096 // If we are going line-by-line, we need to insert the tokens for each line.
2097 // But if not, we can optimize and just std::move the current line vector to
2098 // the output.
2099 if (selection_feature_processor_->GetOptions()
2100 ->only_use_line_with_click()) {
2101 tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
2102 } else {
2103 *tokens = std::move(line_tokens);
2104 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01002105 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002106 return true;
2107}
2108
Tony Mak6c4cc672018-09-17 11:48:50 +01002109const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
Lukas Zilka434442d2018-04-25 11:38:51 +02002110 return selection_feature_processor_.get();
2111}
2112
Tony Mak6c4cc672018-09-17 11:48:50 +01002113const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
Lukas Zilka434442d2018-04-25 11:38:51 +02002114 const {
2115 return classification_feature_processor_.get();
Lukas Zilkab23e2122018-02-09 10:25:19 +01002116}
2117
Tony Mak6c4cc672018-09-17 11:48:50 +01002118const DatetimeParser* Annotator::DatetimeParserForTests() const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002119 return datetime_parser_.get();
2120}
2121
Tony Mak83d2de62019-04-10 16:12:15 +01002122void Annotator::RemoveNotEnabledEntityTypes(
2123 const EnabledEntityTypes& is_entity_type_enabled,
2124 std::vector<AnnotatedSpan>* annotated_spans) const {
2125 for (AnnotatedSpan& annotated_span : *annotated_spans) {
2126 std::vector<ClassificationResult>& classifications =
2127 annotated_span.classification;
2128 classifications.erase(
2129 std::remove_if(classifications.begin(), classifications.end(),
2130 [&is_entity_type_enabled](
2131 const ClassificationResult& classification_result) {
2132 return !is_entity_type_enabled(
2133 classification_result.collection);
2134 }),
2135 classifications.end());
2136 }
2137 annotated_spans->erase(
2138 std::remove_if(annotated_spans->begin(), annotated_spans->end(),
2139 [](const AnnotatedSpan& annotated_span) {
2140 return annotated_span.classification.empty();
2141 }),
2142 annotated_spans->end());
2143}
2144
Tony Maka2a1ff42019-09-12 15:40:32 +01002145void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2146 std::vector<AnnotatedSpan>* candidates) const {
2147 if (candidates == nullptr || contact_engine_ == nullptr) {
2148 return;
2149 }
2150 for (auto& candidate : *candidates) {
2151 for (auto& classification_result : candidate.classification) {
2152 contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2153 &classification_result);
2154 }
2155 }
2156}
2157
Tony Makff31efb2020-03-31 11:13:06 +01002158Status Annotator::AnnotateSingleInput(
2159 const std::string& context, const AnnotationOptions& options,
2160 std::vector<AnnotatedSpan>* candidates) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002161 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
Tony Makff31efb2020-03-31 11:13:06 +01002162 return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
Lukas Zilkaba849e72018-03-08 14:48:21 +01002163 }
2164
Tony Mak854015a2019-01-16 15:56:48 +00002165 const UnicodeText context_unicode =
2166 UTF8ToUnicodeText(context, /*do_copy=*/false);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002167
Tony Mak378c1f52019-03-04 15:58:11 +00002168 std::vector<Locale> detected_text_language_tags;
2169 if (!ParseLocales(options.detected_text_language_tags,
2170 &detected_text_language_tags)) {
2171 TC3_LOG(WARNING)
Tony Makdf54e742019-03-26 14:04:00 +00002172 << "Failed to parse the detected_text_language_tags in options: "
Tony Mak378c1f52019-03-04 15:58:11 +00002173 << options.detected_text_language_tags;
2174 }
Tony Makdf54e742019-03-26 14:04:00 +00002175 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2176 model_triggering_locales_,
2177 /*default_value=*/true)) {
Tony Makff31efb2020-03-31 11:13:06 +01002178 return Status(
2179 StatusCode::UNAVAILABLE,
2180 "The detected language tags are not in the supported locales.");
Tony Makdf54e742019-03-26 14:04:00 +00002181 }
2182
2183 InterpreterManager interpreter_manager(selection_executor_.get(),
2184 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00002185
Tony Mak13125532021-01-13 21:12:07 +00002186 const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2187 const bool is_raw_usecase =
2188 options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
2189
Lukas Zilkab23e2122018-02-09 10:25:19 +01002190 // Annotate with the selection model.
Tony Mak13125532021-01-13 21:12:07 +00002191 const bool model_annotations_enabled =
2192 !is_raw_usecase || IsAnyModelEntityTypeEnabled(is_entity_type_enabled);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002193 std::vector<Token> tokens;
Tony Mak13125532021-01-13 21:12:07 +00002194 if (model_annotations_enabled &&
2195 !ModelAnnotate(context, detected_text_language_tags, options,
2196 &interpreter_manager, &tokens, candidates)) {
Tony Makff31efb2020-03-31 11:13:06 +01002197 return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
Tony Mak13125532021-01-13 21:12:07 +00002198 } else if (!model_annotations_enabled) {
2199 // If the ML model didn't run, we need to tokenize to support the other
2200 // annotators that depend on the tokens.
2201 // Optimization could be made to only do this when an annotator that uses
2202 // the tokens is enabled, but it's unclear if the added complexity is worth
2203 // it.
2204 if (selection_feature_processor_ != nullptr) {
2205 tokens = selection_feature_processor_->Tokenize(context_unicode);
2206 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002207 }
2208
2209 // Annotate with the regular expression models.
Tony Mak13125532021-01-13 21:12:07 +00002210 const bool regex_annotations_enabled =
2211 !is_raw_usecase || IsAnyRegexEntityTypeEnabled(is_entity_type_enabled);
2212 if (regex_annotations_enabled &&
2213 !RegexChunk(
Tony Maka44b3082020-08-13 18:57:10 +01002214 UTF8ToUnicodeText(context, /*do_copy=*/false),
2215 annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
2216 is_entity_type_enabled, options.annotation_usecase, candidates)) {
Tony Makff31efb2020-03-31 11:13:06 +01002217 return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002218 }
2219
2220 // Annotate with the datetime model.
Tony Mak13125532021-01-13 21:12:07 +00002221 // NOTE: Datetime can be disabled even in the SMART usecase, because it's been
2222 // relatively slow for some clients.
Tony Mak83d2de62019-04-10 16:12:15 +01002223 if ((is_entity_type_enabled(Collections::Date()) ||
2224 is_entity_type_enabled(Collections::DateTime())) &&
2225 !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002226 options.reference_time_ms_utc, options.reference_timezone,
Tony Makd9446602019-02-20 18:25:39 +00002227 options.locales, ModeFlag_ANNOTATION,
Tony Mak83d2de62019-04-10 16:12:15 +01002228 options.annotation_usecase,
Tony Makff31efb2020-03-31 11:13:06 +01002229 options.is_serialized_entity_data_enabled, candidates)) {
2230 return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
Tony Mak6c4cc672018-09-17 11:48:50 +01002231 }
2232
Tony Mak854015a2019-01-16 15:56:48 +00002233 // Annotate with the contact engine.
Tony Mak13125532021-01-13 21:12:07 +00002234 const bool contact_annotations_enabled =
2235 !is_raw_usecase || is_entity_type_enabled(Collections::Contact());
2236 if (contact_annotations_enabled && contact_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002237 !contact_engine_->Chunk(context_unicode, tokens, candidates)) {
2238 return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
Tony Mak854015a2019-01-16 15:56:48 +00002239 }
2240
Tony Makd9446602019-02-20 18:25:39 +00002241 // Annotate with the installed app engine.
Tony Mak13125532021-01-13 21:12:07 +00002242 const bool app_annotations_enabled =
2243 !is_raw_usecase || is_entity_type_enabled(Collections::App());
2244 if (app_annotations_enabled && installed_app_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002245 !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
2246 return Status(StatusCode::INTERNAL,
2247 "Couldn't run installed app engine Chunk.");
Tony Makd9446602019-02-20 18:25:39 +00002248 }
2249
Tony Mak378c1f52019-03-04 15:58:11 +00002250 // Annotate with the number annotator.
Tony Mak13125532021-01-13 21:12:07 +00002251 const bool number_annotations_enabled =
2252 !is_raw_usecase || (is_entity_type_enabled(Collections::Number()) ||
2253 is_entity_type_enabled(Collections::Percentage()));
Tony Maka44b3082020-08-13 18:57:10 +01002254 if (number_annotations_enabled && number_annotator_ != nullptr &&
Tony Mak378c1f52019-03-04 15:58:11 +00002255 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
Tony Makff31efb2020-03-31 11:13:06 +01002256 candidates)) {
2257 return Status(StatusCode::INTERNAL,
2258 "Couldn't run number annotator FindAll.");
Tony Makad2e22d2019-03-20 17:35:13 +00002259 }
2260
2261 // Annotate with the duration annotator.
Tony Mak13125532021-01-13 21:12:07 +00002262 const bool duration_annotations_enabled =
2263 !is_raw_usecase || is_entity_type_enabled(Collections::Duration());
2264 if (duration_annotations_enabled && duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +00002265 !duration_annotator_->FindAll(context_unicode, tokens,
Tony Makff31efb2020-03-31 11:13:06 +01002266 options.annotation_usecase, candidates)) {
2267 return Status(StatusCode::INTERNAL,
2268 "Couldn't run duration annotator FindAll.");
Tony Mak378c1f52019-03-04 15:58:11 +00002269 }
2270
Tony Mak76d80962020-01-08 17:30:51 +00002271 // Annotate with the person name engine.
Tony Mak13125532021-01-13 21:12:07 +00002272 const bool person_annotations_enabled =
2273 !is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
2274 if (person_annotations_enabled && person_name_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002275 !person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
2276 return Status(StatusCode::INTERNAL,
2277 "Couldn't run person name engine Chunk.");
Tony Mak76d80962020-01-08 17:30:51 +00002278 }
2279
Tony Mak21460022020-03-12 18:29:35 +00002280 // Annotate with the grammar annotators.
2281 if (grammar_annotator_ != nullptr &&
2282 !grammar_annotator_->Annotate(detected_text_language_tags,
Tony Makff31efb2020-03-31 11:13:06 +01002283 context_unicode, candidates)) {
2284 return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
Tony Mak21460022020-03-12 18:29:35 +00002285 }
2286
Tony Maka44b3082020-08-13 18:57:10 +01002287 // Annotate with the POD NER annotator.
Tony Mak13125532021-01-13 21:12:07 +00002288 const bool pod_ner_annotations_enabled =
2289 !is_raw_usecase || IsAnyPodNerEntityTypeEnabled(is_entity_type_enabled);
2290 if (pod_ner_annotations_enabled && pod_ner_annotator_ != nullptr &&
2291 options.use_pod_ner &&
Tony Maka44b3082020-08-13 18:57:10 +01002292 !pod_ner_annotator_->Annotate(context_unicode, candidates)) {
2293 return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
2294 }
2295
Tony Maka5090082020-09-18 16:41:23 +01002296 // Annotate with the vocab annotator.
Tony Mak13125532021-01-13 21:12:07 +00002297 const bool vocab_annotations_enabled =
2298 !is_raw_usecase || is_entity_type_enabled(Collections::Dictionary());
2299 if (vocab_annotations_enabled && vocab_annotator_ != nullptr &&
2300 options.use_vocab_annotator &&
Tony Maka5090082020-09-18 16:41:23 +01002301 !vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
2302 options.trigger_dictionary_on_beginner_words,
2303 candidates)) {
2304 return Status(StatusCode::INTERNAL, "Couldn't run vocab annotator.");
2305 }
2306
Tony Maka44b3082020-08-13 18:57:10 +01002307 // Annotate with the experimental annotator.
Tony Mak5a12b942020-05-01 12:41:31 +01002308 if (experimental_annotator_ != nullptr &&
2309 !experimental_annotator_->Annotate(context_unicode, candidates)) {
2310 return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
2311 }
2312
Lukas Zilkab23e2122018-02-09 10:25:19 +01002313 // Sort candidates according to their position in the input, so that the next
2314 // code can assume that any connected component of overlapping spans forms a
2315 // contiguous block.
Tony Mak5a12b942020-05-01 12:41:31 +01002316 // Also sort them according to the end position and collection, so that the
2317 // deduplication code below can assume that same spans and classifications
2318 // form contiguous blocks.
Tony Makff31efb2020-03-31 11:13:06 +01002319 std::sort(candidates->begin(), candidates->end(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002320 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
Tony Mak5a12b942020-05-01 12:41:31 +01002321 if (a.span.first != b.span.first) {
2322 return a.span.first < b.span.first;
2323 }
2324
2325 if (a.span.second != b.span.second) {
2326 return a.span.second < b.span.second;
2327 }
2328
2329 return a.classification[0].collection <
2330 b.classification[0].collection;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002331 });
2332
2333 std::vector<int> candidate_indices;
Tony Makff31efb2020-03-31 11:13:06 +01002334 if (!ResolveConflicts(*candidates, context, tokens,
Tony Mak13125532021-01-13 21:12:07 +00002335 detected_text_language_tags, options,
Tony Mak378c1f52019-03-04 15:58:11 +00002336 &interpreter_manager, &candidate_indices)) {
Tony Makff31efb2020-03-31 11:13:06 +01002337 return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002338 }
2339
Tony Mak5a12b942020-05-01 12:41:31 +01002340 // Remove candidates that overlap exactly and have the same collection.
2341 // This can e.g. happen for phone coming from both ML model and regex.
2342 candidate_indices.erase(
2343 std::unique(candidate_indices.begin(), candidate_indices.end(),
2344 [&candidates](const int a_index, const int b_index) {
2345 const AnnotatedSpan& a = (*candidates)[a_index];
2346 const AnnotatedSpan& b = (*candidates)[b_index];
2347 return a.span == b.span &&
2348 a.classification[0].collection ==
2349 b.classification[0].collection;
2350 }),
2351 candidate_indices.end());
2352
Lukas Zilkab23e2122018-02-09 10:25:19 +01002353 std::vector<AnnotatedSpan> result;
2354 result.reserve(candidate_indices.size());
2355 for (const int i : candidate_indices) {
Tony Makff31efb2020-03-31 11:13:06 +01002356 if ((*candidates)[i].classification.empty() ||
2357 ClassifiedAsOther((*candidates)[i].classification) ||
2358 FilteredForAnnotation((*candidates)[i])) {
Tony Mak378c1f52019-03-04 15:58:11 +00002359 continue;
2360 }
Tony Mak5a12b942020-05-01 12:41:31 +01002361 result.push_back(std::move((*candidates)[i]));
Tony Mak378c1f52019-03-04 15:58:11 +00002362 }
2363
Tony Mak83d2de62019-04-10 16:12:15 +01002364 // We generate all candidates and remove them later (with the exception of
2365 // date/time/duration entities) because there are complex interdependencies
2366 // between the entity types. E.g., the TLD of an email can be interpreted as a
2367 // URL, but most likely a user of the API does not want such annotations if
2368 // "url" is enabled and "email" is not.
2369 RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2370
Tony Mak378c1f52019-03-04 15:58:11 +00002371 for (AnnotatedSpan& annotated_span : result) {
2372 SortClassificationResults(&annotated_span.classification);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002373 }
Tony Makff31efb2020-03-31 11:13:06 +01002374 *candidates = result;
2375 return Status::OK;
2376}
Lukas Zilkab23e2122018-02-09 10:25:19 +01002377
Tony Maka44b3082020-08-13 18:57:10 +01002378StatusOr<Annotations> Annotator::AnnotateStructuredInput(
Tony Makff31efb2020-03-31 11:13:06 +01002379 const std::vector<InputFragment>& string_fragments,
2380 const AnnotationOptions& options) const {
Tony Maka44b3082020-08-13 18:57:10 +01002381 Annotations annotation_candidates;
2382 annotation_candidates.annotated_spans.resize(string_fragments.size());
Tony Makff31efb2020-03-31 11:13:06 +01002383
2384 std::vector<std::string> text_to_annotate;
2385 text_to_annotate.reserve(string_fragments.size());
Tony Mak13125532021-01-13 21:12:07 +00002386 std::vector<FragmentMetadata> fragment_metadata;
2387 fragment_metadata.reserve(string_fragments.size());
Tony Makff31efb2020-03-31 11:13:06 +01002388 for (const auto& string_fragment : string_fragments) {
2389 text_to_annotate.push_back(string_fragment.text);
Tony Mak13125532021-01-13 21:12:07 +00002390 fragment_metadata.push_back(
2391 {.relative_bounding_box_top = string_fragment.bounding_box_top,
2392 .relative_bounding_box_height = string_fragment.bounding_box_height});
Tony Makff31efb2020-03-31 11:13:06 +01002393 }
2394
2395 // KnowledgeEngine is special, because it supports annotation of multiple
2396 // fragments at once.
2397 if (knowledge_engine_ &&
2398 !knowledge_engine_
Tony Mak13125532021-01-13 21:12:07 +00002399 ->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
2400 options.annotation_usecase,
Tony Mak90d55672020-04-15 18:20:44 +01002401 options.location_context, options.permissions,
Tony Maka44b3082020-08-13 18:57:10 +01002402 options.annotate_mode, &annotation_candidates)
Tony Makff31efb2020-03-31 11:13:06 +01002403 .ok()) {
2404 return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
2405 }
2406 // The annotator engines shouldn't change the number of annotation vectors.
Tony Maka44b3082020-08-13 18:57:10 +01002407 if (annotation_candidates.annotated_spans.size() != text_to_annotate.size()) {
Tony Makff31efb2020-03-31 11:13:06 +01002408 TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
2409 << " texts to annotate but generated a different number of "
2410 "lists of annotations:"
Tony Maka44b3082020-08-13 18:57:10 +01002411 << annotation_candidates.annotated_spans.size();
Tony Makff31efb2020-03-31 11:13:06 +01002412 return Status(StatusCode::INTERNAL,
2413 "Number of annotation candidates differs from "
2414 "number of texts to annotate.");
2415 }
2416
Tony Maka44b3082020-08-13 18:57:10 +01002417 // As an optimization, if the only annotated type is Entity, we skip all the
2418 // other annotators than the KnowledgeEngine. This only happens in the raw
2419 // mode, to make sure it does not affect the result.
2420 if (options.annotation_usecase == ANNOTATION_USECASE_RAW &&
2421 options.entity_types.size() == 1 &&
2422 *options.entity_types.begin() == Collections::Entity()) {
2423 return annotation_candidates;
2424 }
2425
Tony Makff31efb2020-03-31 11:13:06 +01002426 // Other annotators run on each fragment independently.
2427 for (int i = 0; i < text_to_annotate.size(); ++i) {
2428 AnnotationOptions annotation_options = options;
2429 if (string_fragments[i].datetime_options.has_value()) {
2430 DatetimeOptions reference_datetime =
2431 string_fragments[i].datetime_options.value();
2432 annotation_options.reference_time_ms_utc =
2433 reference_datetime.reference_time_ms_utc;
2434 annotation_options.reference_timezone =
2435 reference_datetime.reference_timezone;
2436 }
2437
2438 AddContactMetadataToKnowledgeClassificationResults(
Tony Maka44b3082020-08-13 18:57:10 +01002439 &annotation_candidates.annotated_spans[i]);
Tony Makff31efb2020-03-31 11:13:06 +01002440
Tony Maka44b3082020-08-13 18:57:10 +01002441 Status annotation_status =
2442 AnnotateSingleInput(text_to_annotate[i], annotation_options,
2443 &annotation_candidates.annotated_spans[i]);
Tony Makff31efb2020-03-31 11:13:06 +01002444 if (!annotation_status.ok()) {
2445 return annotation_status;
2446 }
2447 }
2448 return annotation_candidates;
2449}
2450
2451std::vector<AnnotatedSpan> Annotator::Annotate(
2452 const std::string& context, const AnnotationOptions& options) const {
Tony Mak13125532021-01-13 21:12:07 +00002453 if (context.size() > std::numeric_limits<int>::max()) {
2454 TC3_LOG(ERROR) << "Rejecting too long input.";
2455 return {};
2456 }
2457
2458 const UnicodeText context_unicode =
2459 UTF8ToUnicodeText(context, /*do_copy=*/false);
2460 if (!unilib_->IsValidUtf8(context_unicode)) {
2461 TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
2462 return {};
2463 }
2464
Tony Makff31efb2020-03-31 11:13:06 +01002465 std::vector<InputFragment> string_fragments;
2466 string_fragments.push_back({.text = context});
Tony Maka44b3082020-08-13 18:57:10 +01002467 StatusOr<Annotations> annotations =
Tony Makff31efb2020-03-31 11:13:06 +01002468 AnnotateStructuredInput(string_fragments, options);
2469 if (!annotations.ok()) {
2470 TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
2471 << annotations.status().error_message();
2472 return {};
2473 }
Tony Maka44b3082020-08-13 18:57:10 +01002474 return annotations.ValueOrDie().annotated_spans[0];
Lukas Zilka21d8c982018-01-24 11:11:20 +01002475}
2476
Tony Mak854015a2019-01-16 15:56:48 +00002477CodepointSpan Annotator::ComputeSelectionBoundaries(
2478 const UniLib::RegexMatcher* match,
2479 const RegexModel_::Pattern* config) const {
2480 if (config->capturing_group() == nullptr) {
2481 // Use first capturing group to specify the selection.
2482 int status = UniLib::RegexMatcher::kNoError;
2483 const CodepointSpan result = {match->Start(1, &status),
2484 match->End(1, &status)};
2485 if (status != UniLib::RegexMatcher::kNoError) {
2486 return {kInvalidIndex, kInvalidIndex};
2487 }
2488 return result;
2489 }
2490
2491 CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2492 const int num_groups = config->capturing_group()->size();
2493 for (int i = 0; i < num_groups; i++) {
2494 if (!config->capturing_group()->Get(i)->extend_selection()) {
2495 continue;
2496 }
2497
2498 int status = UniLib::RegexMatcher::kNoError;
2499 // Check match and adjust bounds.
2500 const int group_start = match->Start(i, &status);
2501 const int group_end = match->End(i, &status);
2502 if (status != UniLib::RegexMatcher::kNoError) {
2503 return {kInvalidIndex, kInvalidIndex};
2504 }
2505 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2506 continue;
2507 }
2508 if (result.first == kInvalidIndex) {
2509 result = {group_start, group_end};
2510 } else {
2511 result.first = std::min(result.first, group_start);
2512 result.second = std::max(result.second, group_end);
2513 }
2514 }
2515 return result;
2516}
2517
Tony Makd9446602019-02-20 18:25:39 +00002518bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
Tony Mak21460022020-03-12 18:29:35 +00002519 if (pattern->serialized_entity_data() != nullptr ||
2520 pattern->entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002521 return true;
2522 }
2523 if (pattern->capturing_group() != nullptr) {
Tony Mak63959242020-02-07 18:31:16 +00002524 for (const CapturingGroup* group : *pattern->capturing_group()) {
Tony Makd9446602019-02-20 18:25:39 +00002525 if (group->entity_field_path() != nullptr) {
2526 return true;
2527 }
Tony Mak21460022020-03-12 18:29:35 +00002528 if (group->serialized_entity_data() != nullptr ||
2529 group->entity_data() != nullptr) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002530 return true;
2531 }
Tony Makd9446602019-02-20 18:25:39 +00002532 }
2533 }
2534 return false;
2535}
2536
2537bool Annotator::SerializedEntityDataFromRegexMatch(
2538 const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2539 std::string* serialized_entity_data) const {
2540 if (!HasEntityData(pattern)) {
2541 serialized_entity_data->clear();
2542 return true;
2543 }
2544 TC3_CHECK(entity_data_builder_ != nullptr);
2545
Tony Maka44b3082020-08-13 18:57:10 +01002546 std::unique_ptr<MutableFlatbuffer> entity_data =
Tony Makd9446602019-02-20 18:25:39 +00002547 entity_data_builder_->NewRoot();
2548
2549 TC3_CHECK(entity_data != nullptr);
2550
Tony Mak21460022020-03-12 18:29:35 +00002551 // Set fixed entity data.
Tony Makd9446602019-02-20 18:25:39 +00002552 if (pattern->serialized_entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002553 entity_data->MergeFromSerializedFlatbuffer(
2554 StringPiece(pattern->serialized_entity_data()->c_str(),
2555 pattern->serialized_entity_data()->size()));
2556 }
Tony Mak21460022020-03-12 18:29:35 +00002557 if (pattern->entity_data() != nullptr) {
2558 entity_data->MergeFrom(
2559 reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2560 }
Tony Makd9446602019-02-20 18:25:39 +00002561
2562 // Add entity data from rule capturing groups.
2563 if (pattern->capturing_group() != nullptr) {
2564 const int num_groups = pattern->capturing_group()->size();
2565 for (int i = 0; i < num_groups; i++) {
Tony Mak63959242020-02-07 18:31:16 +00002566 const CapturingGroup* group = pattern->capturing_group()->Get(i);
Tony Maka2a1ff42019-09-12 15:40:32 +01002567
2568 // Check whether the group matched.
2569 Optional<std::string> group_match_text =
2570 GetCapturingGroupText(matcher, /*group_id=*/i);
2571 if (!group_match_text.has_value()) {
Tony Makd9446602019-02-20 18:25:39 +00002572 continue;
2573 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002574
Tony Mak21460022020-03-12 18:29:35 +00002575 // Set fixed entity data from capturing group match.
Tony Maka2a1ff42019-09-12 15:40:32 +01002576 if (group->serialized_entity_data() != nullptr) {
2577 entity_data->MergeFromSerializedFlatbuffer(
2578 StringPiece(group->serialized_entity_data()->c_str(),
2579 group->serialized_entity_data()->size()));
2580 }
Tony Mak21460022020-03-12 18:29:35 +00002581 if (group->entity_data() != nullptr) {
2582 entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2583 pattern->entity_data()));
2584 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002585
2586 // Set entity field from capturing group text.
2587 if (group->entity_field_path() != nullptr) {
Tony Mak8cd7ba62019-10-15 15:29:22 +01002588 UnicodeText normalized_group_match_text =
2589 UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2590
2591 // Apply normalization if specified.
2592 if (group->normalization_options() != nullptr) {
2593 normalized_group_match_text =
Tony Mak1ac2e4a2020-04-29 13:41:53 +01002594 NormalizeText(*unilib_, group->normalization_options(),
Tony Mak8cd7ba62019-10-15 15:29:22 +01002595 normalized_group_match_text);
2596 }
2597
2598 if (!entity_data->ParseAndSet(
2599 group->entity_field_path(),
2600 normalized_group_match_text.ToUTF8String())) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002601 TC3_LOG(ERROR)
2602 << "Could not set entity data from rule capturing group.";
2603 return false;
2604 }
Tony Makd9446602019-02-20 18:25:39 +00002605 }
2606 }
2607 }
2608
2609 *serialized_entity_data = entity_data->Serialize();
2610 return true;
2611}
2612
Tony Mak63959242020-02-07 18:31:16 +00002613UnicodeText RemoveMoneySeparators(
2614 const std::unordered_set<char32>& decimal_separators,
2615 const UnicodeText& amount,
2616 UnicodeText::const_iterator it_decimal_separator) {
2617 UnicodeText whole_amount;
2618 for (auto it = amount.begin();
2619 it != amount.end() && it != it_decimal_separator; ++it) {
2620 if (std::find(decimal_separators.begin(), decimal_separators.end(),
2621 static_cast<char32>(*it)) == decimal_separators.end()) {
2622 whole_amount.push_back(*it);
2623 }
2624 }
2625 return whole_amount;
2626}
2627
Tony Maka44b3082020-08-13 18:57:10 +01002628void Annotator::GetMoneyQuantityFromCapturingGroup(
2629 const UniLib::RegexMatcher* match, const RegexModel_::Pattern* config,
2630 const UnicodeText& context_unicode, std::string* quantity,
2631 int* exponent) const {
2632 if (config->capturing_group() == nullptr) {
2633 *exponent = 0;
2634 return;
2635 }
2636
2637 const int num_groups = config->capturing_group()->size();
2638 for (int i = 0; i < num_groups; i++) {
2639 int status = UniLib::RegexMatcher::kNoError;
2640 const int group_start = match->Start(i, &status);
2641 const int group_end = match->End(i, &status);
2642 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2643 continue;
2644 }
2645
2646 *quantity =
2647 unilib_
2648 ->ToLowerText(UnicodeText::Substring(context_unicode, group_start,
2649 group_end, /*do_copy=*/false))
2650 .ToUTF8String();
2651
2652 if (auto entry = model_->money_parsing_options()
2653 ->quantities_name_to_exponent()
2654 ->LookupByKey((*quantity).c_str())) {
2655 *exponent = entry->value();
2656 return;
2657 }
2658 }
2659 *exponent = 0;
2660}
2661
Tony Mak63959242020-02-07 18:31:16 +00002662bool Annotator::ParseAndFillInMoneyAmount(
Tony Maka44b3082020-08-13 18:57:10 +01002663 std::string* serialized_entity_data, const UniLib::RegexMatcher* match,
2664 const RegexModel_::Pattern* config,
2665 const UnicodeText& context_unicode) const {
Tony Mak63959242020-02-07 18:31:16 +00002666 std::unique_ptr<EntityDataT> data =
2667 LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2668 *serialized_entity_data);
Tony Mak0b8b3322020-03-17 16:30:19 +00002669 if (data == nullptr) {
Tony Makc121edd2020-05-28 15:25:17 +01002670 if (model_->version() >= 706) {
2671 // This way of parsing money entity data is enabled for models newer than
2672 // v706, consequently logging errors only for them (b/156634162).
2673 TC3_LOG(ERROR)
2674 << "Data field is null when trying to parse Money Entity Data";
2675 }
Tony Mak0b8b3322020-03-17 16:30:19 +00002676 return false;
2677 }
2678 if (data->money->unnormalized_amount.empty()) {
Tony Makc121edd2020-05-28 15:25:17 +01002679 if (model_->version() >= 706) {
2680 // This way of parsing money entity data is enabled for models newer than
2681 // v706, consequently logging errors only for them (b/156634162).
2682 TC3_LOG(ERROR)
2683 << "Data unnormalized_amount is empty when trying to parse "
2684 "Money Entity Data";
2685 }
Tony Mak63959242020-02-07 18:31:16 +00002686 return false;
2687 }
2688
2689 UnicodeText amount =
2690 UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2691 int separator_back_index = 0;
Tony Mak21460022020-03-12 18:29:35 +00002692 auto it_decimal_separator = --amount.end();
Tony Mak63959242020-02-07 18:31:16 +00002693 for (; it_decimal_separator != amount.begin();
2694 --it_decimal_separator, ++separator_back_index) {
2695 if (std::find(money_separators_.begin(), money_separators_.end(),
2696 static_cast<char32>(*it_decimal_separator)) !=
2697 money_separators_.end()) {
2698 break;
2699 }
2700 }
2701
2702 // If there are 3 digits after the last separator, we consider that a
2703 // thousands separator => the number is an int (e.g. 1.234 is considered int).
2704 // If there is no separator in number, also that number is an int.
Tony Mak21460022020-03-12 18:29:35 +00002705 if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
Tony Mak63959242020-02-07 18:31:16 +00002706 it_decimal_separator = amount.end();
2707 }
2708
2709 if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2710 it_decimal_separator),
2711 &data->money->amount_whole_part)) {
Tony Mak0b8b3322020-03-17 16:30:19 +00002712 TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
2713 "amount: "
2714 << data->money->unnormalized_amount;
Tony Mak63959242020-02-07 18:31:16 +00002715 return false;
2716 }
Tony Maka44b3082020-08-13 18:57:10 +01002717
Tony Mak63959242020-02-07 18:31:16 +00002718 if (it_decimal_separator == amount.end()) {
2719 data->money->amount_decimal_part = 0;
Tony Maka44b3082020-08-13 18:57:10 +01002720 data->money->nanos = 0;
Tony Mak63959242020-02-07 18:31:16 +00002721 } else {
2722 const int amount_codepoints_size = amount.size_codepoints();
Tony Maka44b3082020-08-13 18:57:10 +01002723 const UnicodeText decimal_part = UnicodeText::Substring(
2724 amount, amount_codepoints_size - separator_back_index,
2725 amount_codepoints_size, /*do_copy=*/false);
2726 if (!unilib_->ParseInt32(decimal_part, &data->money->amount_decimal_part)) {
Tony Mak0b8b3322020-03-17 16:30:19 +00002727 TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
2728 "the amount: "
2729 << data->money->unnormalized_amount;
Tony Mak63959242020-02-07 18:31:16 +00002730 return false;
2731 }
Tony Maka44b3082020-08-13 18:57:10 +01002732 data->money->nanos = data->money->amount_decimal_part *
2733 pow(10, 9 - decimal_part.size_codepoints());
2734 }
2735
2736 if (model_->money_parsing_options()->quantities_name_to_exponent() !=
2737 nullptr) {
2738 int quantity_exponent;
2739 std::string quantity;
2740 GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
2741 &quantity, &quantity_exponent);
Chang Li97265d02021-04-27 20:49:53 +00002742 if (quantity_exponent > 0 && quantity_exponent <= 9) {
2743 const double amount_whole_part =
Tony Maka44b3082020-08-13 18:57:10 +01002744 data->money->amount_whole_part * pow(10, quantity_exponent) +
2745 data->money->nanos / pow(10, 9 - quantity_exponent);
Chang Li97265d02021-04-27 20:49:53 +00002746 // TODO(jacekj): Change type of `data->money->amount_whole_part` to int64
2747 // (and `std::numeric_limits<int>::max()` to
2748 // `std::numeric_limits<int64>::max()`).
2749 if (amount_whole_part < std::numeric_limits<int>::max()) {
2750 data->money->amount_whole_part = amount_whole_part;
2751 data->money->nanos = data->money->nanos %
2752 static_cast<int>(pow(10, 9 - quantity_exponent)) *
2753 pow(10, quantity_exponent);
2754 }
Tony Mak074ee382020-09-30 19:11:00 +01002755 }
2756 if (quantity_exponent > 0) {
Tony Maka44b3082020-08-13 18:57:10 +01002757 data->money->unnormalized_amount = strings::JoinStrings(
2758 " ", {data->money->unnormalized_amount, quantity});
2759 }
Tony Mak63959242020-02-07 18:31:16 +00002760 }
2761
2762 *serialized_entity_data =
2763 PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2764 return true;
2765}
2766
Tony Mak13125532021-01-13 21:12:07 +00002767bool Annotator::IsAnyModelEntityTypeEnabled(
2768 const EnabledEntityTypes& is_entity_type_enabled) const {
2769 if (model_->classification_feature_options() == nullptr ||
2770 model_->classification_feature_options()->collections() == nullptr) {
2771 return false;
2772 }
2773 for (int i = 0;
2774 i < model_->classification_feature_options()->collections()->size();
2775 i++) {
2776 if (is_entity_type_enabled(model_->classification_feature_options()
2777 ->collections()
2778 ->Get(i)
2779 ->str())) {
2780 return true;
2781 }
2782 }
2783 return false;
2784}
2785
2786bool Annotator::IsAnyRegexEntityTypeEnabled(
2787 const EnabledEntityTypes& is_entity_type_enabled) const {
2788 if (model_->regex_model() == nullptr ||
2789 model_->regex_model()->patterns() == nullptr) {
2790 return false;
2791 }
2792 for (int i = 0; i < model_->regex_model()->patterns()->size(); i++) {
2793 if (is_entity_type_enabled(model_->regex_model()
2794 ->patterns()
2795 ->Get(i)
2796 ->collection_name()
2797 ->str())) {
2798 return true;
2799 }
2800 }
2801 return false;
2802}
2803
2804bool Annotator::IsAnyPodNerEntityTypeEnabled(
2805 const EnabledEntityTypes& is_entity_type_enabled) const {
2806 if (pod_ner_annotator_ == nullptr) {
2807 return false;
2808 }
2809
2810 for (const std::string& collection :
2811 pod_ner_annotator_->GetSupportedCollections()) {
2812 if (is_entity_type_enabled(collection)) {
2813 return true;
2814 }
2815 }
2816 return false;
2817}
2818
Tony Mak6c4cc672018-09-17 11:48:50 +01002819bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2820 const std::vector<int>& rules,
Tony Maka44b3082020-08-13 18:57:10 +01002821 bool is_serialized_entity_data_enabled,
2822 const EnabledEntityTypes& enabled_entity_types,
2823 const AnnotationUsecase& annotation_usecase,
2824 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002825 for (int pattern_id : rules) {
2826 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
Tony Maka44b3082020-08-13 18:57:10 +01002827 if (!enabled_entity_types(regex_pattern.config->collection_name()->str()) &&
2828 annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW) {
2829 // No regex annotation type has been requested, skip regex annotation.
2830 continue;
2831 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002832 const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2833 if (!matcher) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002834 TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2835 << pattern_id;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002836 return false;
2837 }
2838
2839 int status = UniLib::RegexMatcher::kNoError;
2840 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
Tony Mak854015a2019-01-16 15:56:48 +00002841 if (regex_pattern.config->verification_options()) {
Tony Makdf54e742019-03-26 14:04:00 +00002842 if (!VerifyRegexMatchCandidate(
2843 context_unicode.ToUTF8String(),
2844 regex_pattern.config->verification_options(),
2845 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002846 continue;
2847 }
2848 }
Tony Makd9446602019-02-20 18:25:39 +00002849
2850 std::string serialized_entity_data;
Tony Mak83d2de62019-04-10 16:12:15 +01002851 if (is_serialized_entity_data_enabled) {
2852 if (!SerializedEntityDataFromRegexMatch(
2853 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2854 TC3_LOG(ERROR) << "Could not get entity data.";
2855 return false;
2856 }
Tony Mak63959242020-02-07 18:31:16 +00002857
Tony Maka44b3082020-08-13 18:57:10 +01002858 // Further parsing of money amount. Need this since regexes cannot have
2859 // empty groups that fill in entity data (amount_decimal_part and
2860 // quantity might be empty groups).
Tony Mak63959242020-02-07 18:31:16 +00002861 if (regex_pattern.config->collection_name()->str() ==
2862 Collections::Money()) {
Tony Maka44b3082020-08-13 18:57:10 +01002863 if (!ParseAndFillInMoneyAmount(&serialized_entity_data, matcher.get(),
2864 regex_pattern.config,
2865 context_unicode)) {
Tony Makc121edd2020-05-28 15:25:17 +01002866 if (model_->version() >= 706) {
2867 // This way of parsing money entity data is enabled for models
2868 // newer than v706 => logging errors only for them (b/156634162).
2869 TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
2870 }
Tony Mak63959242020-02-07 18:31:16 +00002871 }
2872 }
Tony Makd9446602019-02-20 18:25:39 +00002873 }
2874
Lukas Zilkab23e2122018-02-09 10:25:19 +01002875 result->emplace_back();
Tony Mak854015a2019-01-16 15:56:48 +00002876
Lukas Zilkab23e2122018-02-09 10:25:19 +01002877 // Selection/annotation regular expressions need to specify a capturing
2878 // group specifying the selection.
Tony Mak854015a2019-01-16 15:56:48 +00002879 result->back().span =
2880 ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2881
Lukas Zilkab23e2122018-02-09 10:25:19 +01002882 result->back().classification = {
Tony Mak854015a2019-01-16 15:56:48 +00002883 {regex_pattern.config->collection_name()->str(),
2884 regex_pattern.config->target_classification_score(),
2885 regex_pattern.config->priority_score()}};
Tony Makd9446602019-02-20 18:25:39 +00002886
2887 result->back().classification[0].serialized_entity_data =
2888 serialized_entity_data;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002889 }
2890 }
2891 return true;
2892}
2893
Tony Mak6c4cc672018-09-17 11:48:50 +01002894bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2895 tflite::Interpreter* selection_interpreter,
2896 const CachedFeatures& cached_features,
2897 std::vector<TokenSpan>* chunks) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01002898 const int max_selection_span =
2899 selection_feature_processor_->GetOptions()->max_selection_span();
Lukas Zilka21d8c982018-01-24 11:11:20 +01002900 // The inference span is the span of interest expanded to include
2901 // max_selection_span tokens on either side, which is how far a selection can
2902 // stretch from the click.
Tony Maka44b3082020-08-13 18:57:10 +01002903 const TokenSpan inference_span =
2904 IntersectTokenSpans(span_of_interest.Expand(
2905 /*num_tokens_left=*/max_selection_span,
2906 /*num_tokens_right=*/max_selection_span),
2907 {0, num_tokens});
Lukas Zilka21d8c982018-01-24 11:11:20 +01002908
2909 std::vector<ScoredChunk> scored_chunks;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002910 if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2911 selection_feature_processor_->GetOptions()
2912 ->bounds_sensitive_features()
2913 ->enabled()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002914 if (!ModelBoundsSensitiveScoreChunks(
2915 num_tokens, span_of_interest, inference_span, cached_features,
2916 selection_interpreter, &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002917 return false;
2918 }
2919 } else {
2920 if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002921 cached_features, selection_interpreter,
2922 &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002923 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002924 }
2925 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002926 std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
2927 [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2928 return lhs.score < rhs.score;
2929 });
Lukas Zilka21d8c982018-01-24 11:11:20 +01002930
2931 // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2932 // them greedily as long as they do not overlap with any previously picked
2933 // chunks.
Tony Maka44b3082020-08-13 18:57:10 +01002934 std::vector<bool> token_used(inference_span.Size());
Lukas Zilka21d8c982018-01-24 11:11:20 +01002935 chunks->clear();
2936 for (const ScoredChunk& scored_chunk : scored_chunks) {
2937 bool feasible = true;
2938 for (int i = scored_chunk.token_span.first;
2939 i < scored_chunk.token_span.second; ++i) {
2940 if (token_used[i - inference_span.first]) {
2941 feasible = false;
2942 break;
2943 }
2944 }
2945
2946 if (!feasible) {
2947 continue;
2948 }
2949
2950 for (int i = scored_chunk.token_span.first;
2951 i < scored_chunk.token_span.second; ++i) {
2952 token_used[i - inference_span.first] = true;
2953 }
2954
2955 chunks->push_back(scored_chunk.token_span);
2956 }
2957
2958 std::sort(chunks->begin(), chunks->end());
2959
2960 return true;
2961}
2962
namespace {
// Stores `value` under `key` if the key is not in the map yet; otherwise keeps
// whichever is larger, the stored value or `value`.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  // emplace leaves an existing entry untouched and reports that via `second`.
  const auto insertion = map->emplace(key, value);
  if (!insertion.second) {
    auto& stored_value = insertion.first->second;
    stored_value = std::max(stored_value, value);
  }
}
}  // namespace
2977
Tony Mak6c4cc672018-09-17 11:48:50 +01002978bool Annotator::ModelClickContextScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002979 int num_tokens, const TokenSpan& span_of_interest,
2980 const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002981 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002982 std::vector<ScoredChunk>* scored_chunks) const {
2983 const int max_batch_size = model_->selection_options()->batch_size();
2984
2985 std::vector<float> all_features;
2986 std::map<TokenSpan, float> chunk_scores;
2987 for (int batch_start = span_of_interest.first;
2988 batch_start < span_of_interest.second; batch_start += max_batch_size) {
2989 const int batch_end =
2990 std::min(batch_start + max_batch_size, span_of_interest.second);
2991
2992 // Prepare features for the whole batch.
2993 all_features.clear();
2994 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2995 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2996 cached_features.AppendClickContextFeaturesForClick(click_pos,
2997 &all_features);
2998 }
2999
3000 // Run batched inference.
3001 const int batch_size = batch_end - batch_start;
3002 const int features_size = cached_features.OutputFeaturesSize();
3003 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01003004 TensorView<float>(all_features.data(), {batch_size, features_size}),
3005 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01003006 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01003007 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01003008 return false;
3009 }
3010 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
3011 logits.dim(1) !=
3012 selection_feature_processor_->GetSelectionLabelCount()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01003013 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01003014 return false;
3015 }
3016
3017 // Save results.
3018 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
3019 const std::vector<float> scores = ComputeSoftmax(
3020 logits.data() + logits.dim(1) * (click_pos - batch_start),
3021 logits.dim(1));
3022 for (int j = 0;
3023 j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
3024 TokenSpan relative_token_span;
3025 if (!selection_feature_processor_->LabelToTokenSpan(
3026 j, &relative_token_span)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01003027 TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01003028 return false;
3029 }
Tony Maka44b3082020-08-13 18:57:10 +01003030 const TokenSpan candidate_span = TokenSpan(click_pos).Expand(
3031 relative_token_span.first, relative_token_span.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01003032 if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
3033 UpdateMax(&chunk_scores, candidate_span, scores[j]);
3034 }
3035 }
3036 }
3037 }
3038
3039 scored_chunks->clear();
3040 scored_chunks->reserve(chunk_scores.size());
3041 for (const auto& entry : chunk_scores) {
3042 scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
3043 }
3044
3045 return true;
3046}
3047
Tony Mak6c4cc672018-09-17 11:48:50 +01003048bool Annotator::ModelBoundsSensitiveScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01003049 int num_tokens, const TokenSpan& span_of_interest,
3050 const TokenSpan& inference_span, const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01003051 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01003052 std::vector<ScoredChunk>* scored_chunks) const {
3053 const int max_selection_span =
3054 selection_feature_processor_->GetOptions()->max_selection_span();
3055 const int max_chunk_length = selection_feature_processor_->GetOptions()
3056 ->selection_reduced_output_space()
3057 ? max_selection_span + 1
3058 : 2 * max_selection_span + 1;
Lukas Zilkaba849e72018-03-08 14:48:21 +01003059 const bool score_single_token_spans_as_zero =
3060 selection_feature_processor_->GetOptions()
3061 ->bounds_sensitive_features()
3062 ->score_single_token_spans_as_zero();
3063
3064 scored_chunks->clear();
3065 if (score_single_token_spans_as_zero) {
Tony Maka44b3082020-08-13 18:57:10 +01003066 scored_chunks->reserve(span_of_interest.Size());
Lukas Zilkaba849e72018-03-08 14:48:21 +01003067 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01003068
3069 // Prepare all chunk candidates into one batch:
3070 // - Are contained in the inference span
3071 // - Have a non-empty intersection with the span of interest
3072 // - Are at least one token long
3073 // - Are not longer than the maximum chunk length
3074 std::vector<TokenSpan> candidate_spans;
3075 for (int start = inference_span.first; start < span_of_interest.second;
3076 ++start) {
3077 const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
3078 for (int end = leftmost_end_index;
3079 end <= inference_span.second && end - start <= max_chunk_length;
3080 ++end) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01003081 const TokenSpan candidate_span = {start, end};
Tony Maka44b3082020-08-13 18:57:10 +01003082 if (score_single_token_spans_as_zero && candidate_span.Size() == 1) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01003083 // Do not include the single token span in the batch, add a zero score
3084 // for it directly to the output.
3085 scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
3086 } else {
3087 candidate_spans.push_back(candidate_span);
3088 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01003089 }
3090 }
3091
3092 const int max_batch_size = model_->selection_options()->batch_size();
3093
3094 std::vector<float> all_features;
Lukas Zilkaba849e72018-03-08 14:48:21 +01003095 scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
Lukas Zilkab23e2122018-02-09 10:25:19 +01003096 for (int batch_start = 0; batch_start < candidate_spans.size();
3097 batch_start += max_batch_size) {
3098 const int batch_end = std::min(batch_start + max_batch_size,
3099 static_cast<int>(candidate_spans.size()));
3100
3101 // Prepare features for the whole batch.
3102 all_features.clear();
3103 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
3104 for (int i = batch_start; i < batch_end; ++i) {
3105 cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
3106 &all_features);
3107 }
3108
3109 // Run batched inference.
3110 const int batch_size = batch_end - batch_start;
3111 const int features_size = cached_features.OutputFeaturesSize();
3112 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01003113 TensorView<float>(all_features.data(), {batch_size, features_size}),
3114 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01003115 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01003116 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01003117 return false;
3118 }
3119 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
3120 logits.dim(1) != 1) {
Tony Mak6c4cc672018-09-17 11:48:50 +01003121 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01003122 return false;
3123 }
3124
3125 // Save results.
3126 for (int i = batch_start; i < batch_end; ++i) {
3127 scored_chunks->push_back(
3128 ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
3129 }
3130 }
3131
3132 return true;
3133}
3134
Tony Mak6c4cc672018-09-17 11:48:50 +01003135bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
3136 int64 reference_time_ms_utc,
3137 const std::string& reference_timezone,
3138 const std::string& locales, ModeFlag mode,
Tony Makd9446602019-02-20 18:25:39 +00003139 AnnotationUsecase annotation_usecase,
Tony Mak83d2de62019-04-10 16:12:15 +01003140 bool is_serialized_entity_data_enabled,
Tony Mak6c4cc672018-09-17 11:48:50 +01003141 std::vector<AnnotatedSpan>* result) const {
Tony Mak13125532021-01-13 21:12:07 +00003142 if (!datetime_parser_) {
3143 return true;
3144 }
3145 LocaleList locale_list = LocaleList::ParseFrom(locales);
3146 StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
3147 datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
3148 reference_timezone, locale_list, mode,
3149 annotation_usecase,
3150 /*anchor_start_end=*/false);
3151 if (!result_status.ok()) {
3152 return false;
Tony Makd99d58c2020-03-19 21:52:02 +00003153 }
3154
Tony Mak13125532021-01-13 21:12:07 +00003155 for (const DatetimeParseResultSpan& datetime_span :
3156 result_status.ValueOrDie()) {
Tony Mak378c1f52019-03-04 15:58:11 +00003157 AnnotatedSpan annotated_span;
3158 annotated_span.span = datetime_span.span;
Tony Mak854015a2019-01-16 15:56:48 +00003159 for (const DatetimeParseResult& parse_result : datetime_span.data) {
Tony Mak378c1f52019-03-04 15:58:11 +00003160 annotated_span.classification.emplace_back(
3161 PickCollectionForDatetime(parse_result),
3162 datetime_span.target_classification_score,
3163 datetime_span.priority_score);
3164 annotated_span.classification.back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01003165 if (is_serialized_entity_data_enabled) {
3166 annotated_span.classification.back().serialized_entity_data =
3167 CreateDatetimeSerializedEntityData(parse_result);
3168 }
Tony Mak854015a2019-01-16 15:56:48 +00003169 }
Tony Mak448b5862019-03-22 13:36:41 +00003170 annotated_span.source = AnnotatedSpan::Source::DATETIME;
Tony Mak378c1f52019-03-04 15:58:11 +00003171 result->push_back(std::move(annotated_span));
Lukas Zilkab23e2122018-02-09 10:25:19 +01003172 }
3173 return true;
3174}
3175
Tony Mak378c1f52019-03-04 15:58:11 +00003176const Model* Annotator::model() const { return model_; }
Tony Makd9446602019-02-20 18:25:39 +00003177const reflection::Schema* Annotator::entity_data_schema() const {
3178 return entity_data_schema_;
3179}
Tony Mak854015a2019-01-16 15:56:48 +00003180
Lukas Zilka21d8c982018-01-24 11:11:20 +01003181const Model* ViewModel(const void* buffer, int size) {
3182 if (!buffer) {
3183 return nullptr;
3184 }
3185
3186 return LoadAndVerifyModel(buffer, size);
3187}
3188
Tony Makf5fd3652021-03-18 19:23:10 +00003189StatusOr<std::string> Annotator::LookUpKnowledgeEntity(
3190 const std::string& id) const {
3191 if (!knowledge_engine_) {
3192 return Status(StatusCode::FAILED_PRECONDITION,
3193 "knowledge_engine_ is nullptr");
3194 }
3195 return knowledge_engine_->LookUpEntity(id);
Tony Makd9446602019-02-20 18:25:39 +00003196}
3197
Tony Mak8a501052021-02-24 20:08:27 +00003198StatusOr<std::string> Annotator::LookUpKnowledgeEntityProperty(
3199 const std::string& mid_str, const std::string& property) const {
3200 if (!knowledge_engine_) {
3201 return Status(StatusCode::FAILED_PRECONDITION,
3202 "knowledge_engine_ is nullptr");
3203 }
3204 return knowledge_engine_->LookUpEntityProperty(mid_str, property);
3205}
3206
Tony Mak6c4cc672018-09-17 11:48:50 +01003207} // namespace libtextclassifier3