blob: abb57e855b97b12d4dd7c3cf07d2ca72dcff18ce [file] [log] [blame]
Lukas Zilka21d8c982018-01-24 11:11:20 +01001/*
Tony Mak6c4cc672018-09-17 11:48:50 +01002 * Copyright (C) 2018 The Android Open Source Project
Lukas Zilka21d8c982018-01-24 11:11:20 +01003 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Tony Mak6c4cc672018-09-17 11:48:50 +010017#include "annotator/annotator.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010018
19#include <algorithm>
20#include <cctype>
21#include <cmath>
22#include <iterator>
23#include <numeric>
Tony Mak448b5862019-03-22 13:36:41 +000024#include <unordered_map>
Lukas Zilka21d8c982018-01-24 11:11:20 +010025
Tony Mak854015a2019-01-16 15:56:48 +000026#include "annotator/collections.h"
Tony Mak83d2de62019-04-10 16:12:15 +010027#include "annotator/model_generated.h"
28#include "annotator/types.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010029#include "utils/base/logging.h"
30#include "utils/checksum.h"
31#include "utils/math/softmax.h"
Tony Mak8cd7ba62019-10-15 15:29:22 +010032#include "utils/normalization.h"
Tony Maka2a1ff42019-09-12 15:40:32 +010033#include "utils/optional.h"
Tony Makd9446602019-02-20 18:25:39 +000034#include "utils/regex-match.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010035#include "utils/utf8/unicodetext.h"
Tony Mak378c1f52019-03-04 15:58:11 +000036#include "utils/zlib/zlib_regex.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010037
Tony Mak83d2de62019-04-10 16:12:15 +010038
Tony Mak6c4cc672018-09-17 11:48:50 +010039namespace libtextclassifier3 {
Tony Mak448b5862019-03-22 13:36:41 +000040
41using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
42
Tony Mak6c4cc672018-09-17 11:48:50 +010043const std::string& Annotator::kPhoneCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010044 *[]() { return new std::string("phone"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010045const std::string& Annotator::kAddressCollection =
Lukas Zilkae7962cc2018-03-28 18:09:48 +020046 *[]() { return new std::string("address"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010047const std::string& Annotator::kDateCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010048 *[]() { return new std::string("date"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000049const std::string& Annotator::kUrlCollection =
50 *[]() { return new std::string("url"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000051const std::string& Annotator::kEmailCollection =
52 *[]() { return new std::string("email"); }();
Lukas Zilkab23e2122018-02-09 10:25:19 +010053
Lukas Zilka21d8c982018-01-24 11:11:20 +010054namespace {
55const Model* LoadAndVerifyModel(const void* addr, int size) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010056 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
Tony Mak51a9e542018-11-02 13:36:22 +000057 if (VerifyModelBuffer(verifier)) {
58 return GetModel(addr);
Lukas Zilka21d8c982018-01-24 11:11:20 +010059 } else {
60 return nullptr;
61 }
62}
Tony Mak6c4cc672018-09-17 11:48:50 +010063
64// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
65// create a new instance, assign ownership to owned_lib, and return it.
66const UniLib* MaybeCreateUnilib(const UniLib* lib,
67 std::unique_ptr<UniLib>* owned_lib) {
68 if (lib) {
69 return lib;
70 } else {
71 owned_lib->reset(new UniLib);
72 return owned_lib->get();
73 }
74}
75
76// As above, but for CalendarLib.
77const CalendarLib* MaybeCreateCalendarlib(
78 const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
79 if (lib) {
80 return lib;
81 } else {
82 owned_lib->reset(new CalendarLib);
83 return owned_lib->get();
84 }
85}
86
Tony Mak968412a2019-11-13 15:39:57 +000087// Returns whether the provided input is valid:
88// * Valid utf8 text.
89// * Sane span indices.
90bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan span) {
91 if (!context.is_valid()) {
92 return false;
93 }
94 return (span.first >= 0 && span.first < span.second &&
95 span.second <= context.size_codepoints());
96}
97
Lukas Zilka21d8c982018-01-24 11:11:20 +010098} // namespace
99
Lukas Zilkaba849e72018-03-08 14:48:21 +0100100tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
101 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100102 TC3_CHECK(selection_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100103 selection_interpreter_ = selection_executor_->CreateInterpreter();
104 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100105 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100106 }
107 }
108 return selection_interpreter_.get();
109}
110
111tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
112 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100113 TC3_CHECK(classification_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100114 classification_interpreter_ = classification_executor_->CreateInterpreter();
115 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100116 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100117 }
118 }
119 return classification_interpreter_.get();
120}
121
Tony Mak6c4cc672018-09-17 11:48:50 +0100122std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
123 const char* buffer, int size, const UniLib* unilib,
124 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100125 const Model* model = LoadAndVerifyModel(buffer, size);
126 if (model == nullptr) {
127 return nullptr;
128 }
129
Lukas Zilkab23e2122018-02-09 10:25:19 +0100130 auto classifier =
Tony Mak6c4cc672018-09-17 11:48:50 +0100131 std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100132 if (!classifier->IsInitialized()) {
133 return nullptr;
134 }
135
136 return classifier;
137}
138
Tony Maka0f598b2018-11-20 20:39:04 +0000139
Tony Mak6c4cc672018-09-17 11:48:50 +0100140std::unique_ptr<Annotator> Annotator::FromScopedMmap(
141 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
142 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100143 if (!(*mmap)->handle().ok()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100144 TC3_VLOG(1) << "Mmap failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100145 return nullptr;
146 }
147
148 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
149 (*mmap)->handle().num_bytes());
150 if (!model) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100151 TC3_LOG(ERROR) << "Model verification failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100152 return nullptr;
153 }
154
Tony Mak6c4cc672018-09-17 11:48:50 +0100155 auto classifier = std::unique_ptr<Annotator>(
156 new Annotator(mmap, model, unilib, calendarlib));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100157 if (!classifier->IsInitialized()) {
158 return nullptr;
159 }
160
161 return classifier;
162}
163
Tony Makdf54e742019-03-26 14:04:00 +0000164std::unique_ptr<Annotator> Annotator::FromScopedMmap(
165 std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
166 std::unique_ptr<CalendarLib> calendarlib) {
167 if (!(*mmap)->handle().ok()) {
168 TC3_VLOG(1) << "Mmap failed.";
169 return nullptr;
170 }
171
172 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
173 (*mmap)->handle().num_bytes());
174 if (model == nullptr) {
175 TC3_LOG(ERROR) << "Model verification failed.";
176 return nullptr;
177 }
178
179 auto classifier = std::unique_ptr<Annotator>(
180 new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
181 if (!classifier->IsInitialized()) {
182 return nullptr;
183 }
184
185 return classifier;
186}
187
Tony Mak6c4cc672018-09-17 11:48:50 +0100188std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
189 int fd, int offset, int size, const UniLib* unilib,
190 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100191 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
Tony Mak6c4cc672018-09-17 11:48:50 +0100192 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100193}
194
Tony Mak6c4cc672018-09-17 11:48:50 +0100195std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Makdf54e742019-03-26 14:04:00 +0000196 int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
197 std::unique_ptr<CalendarLib> calendarlib) {
198 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
199 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
200}
201
202std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Mak6c4cc672018-09-17 11:48:50 +0100203 int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100204 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
Tony Mak6c4cc672018-09-17 11:48:50 +0100205 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100206}
207
Tony Makdf54e742019-03-26 14:04:00 +0000208std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
209 int fd, std::unique_ptr<UniLib> unilib,
210 std::unique_ptr<CalendarLib> calendarlib) {
211 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
212 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
213}
214
Tony Mak6c4cc672018-09-17 11:48:50 +0100215std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
216 const UniLib* unilib,
217 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100218 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
Tony Mak6c4cc672018-09-17 11:48:50 +0100219 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100220}
221
Tony Makdf54e742019-03-26 14:04:00 +0000222std::unique_ptr<Annotator> Annotator::FromPath(
223 const std::string& path, std::unique_ptr<UniLib> unilib,
224 std::unique_ptr<CalendarLib> calendarlib) {
225 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
226 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
227}
228
Tony Mak6c4cc672018-09-17 11:48:50 +0100229Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
230 const UniLib* unilib, const CalendarLib* calendarlib)
231 : model_(model),
232 mmap_(std::move(*mmap)),
233 owned_unilib_(nullptr),
234 unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
235 owned_calendarlib_(nullptr),
236 calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
237 ValidateAndInitialize();
238}
239
Tony Makdf54e742019-03-26 14:04:00 +0000240Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
241 std::unique_ptr<UniLib> unilib,
242 std::unique_ptr<CalendarLib> calendarlib)
243 : model_(model),
244 mmap_(std::move(*mmap)),
245 owned_unilib_(std::move(unilib)),
246 unilib_(owned_unilib_.get()),
247 owned_calendarlib_(std::move(calendarlib)),
248 calendarlib_(owned_calendarlib_.get()) {
249 ValidateAndInitialize();
250}
251
Tony Mak6c4cc672018-09-17 11:48:50 +0100252Annotator::Annotator(const Model* model, const UniLib* unilib,
253 const CalendarLib* calendarlib)
254 : model_(model),
255 owned_unilib_(nullptr),
256 unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
257 owned_calendarlib_(nullptr),
258 calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
259 ValidateAndInitialize();
260}
261
262void Annotator::ValidateAndInitialize() {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100263 initialized_ = false;
264
Lukas Zilka21d8c982018-01-24 11:11:20 +0100265 if (model_ == nullptr) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100266 TC3_LOG(ERROR) << "No model specified.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100267 return;
268 }
269
Lukas Zilkaba849e72018-03-08 14:48:21 +0100270 const bool model_enabled_for_annotation =
271 (model_->triggering_options() != nullptr &&
272 (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
273 const bool model_enabled_for_classification =
274 (model_->triggering_options() != nullptr &&
275 (model_->triggering_options()->enabled_modes() &
276 ModeFlag_CLASSIFICATION));
277 const bool model_enabled_for_selection =
278 (model_->triggering_options() != nullptr &&
279 (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
280
281 // Annotation requires the selection model.
282 if (model_enabled_for_annotation || model_enabled_for_selection) {
283 if (!model_->selection_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100284 TC3_LOG(ERROR) << "No selection options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100285 return;
286 }
287 if (!model_->selection_feature_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100288 TC3_LOG(ERROR) << "No selection feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100289 return;
290 }
291 if (!model_->selection_feature_options()->bounds_sensitive_features()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100292 TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100293 return;
294 }
295 if (!model_->selection_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100296 TC3_LOG(ERROR) << "No selection model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100297 return;
298 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100299 selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100300 if (!selection_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100301 TC3_LOG(ERROR) << "Could not initialize selection executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100302 return;
303 }
304 selection_feature_processor_.reset(
305 new FeatureProcessor(model_->selection_feature_options(), unilib_));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100306 }
307
Lukas Zilkaba849e72018-03-08 14:48:21 +0100308 // Annotation requires the classification model for conflict resolution and
309 // scoring.
310 // Selection requires the classification model for conflict resolution.
311 if (model_enabled_for_annotation || model_enabled_for_classification ||
312 model_enabled_for_selection) {
313 if (!model_->classification_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100314 TC3_LOG(ERROR) << "No classification options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100315 return;
316 }
317
318 if (!model_->classification_feature_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100319 TC3_LOG(ERROR) << "No classification feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100320 return;
321 }
322
323 if (!model_->classification_feature_options()
324 ->bounds_sensitive_features()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100325 TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100326 return;
327 }
328 if (!model_->classification_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100329 TC3_LOG(ERROR) << "No clf model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100330 return;
331 }
332
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200333 classification_executor_ =
Tony Mak6c4cc672018-09-17 11:48:50 +0100334 ModelExecutor::FromBuffer(model_->classification_model());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100335 if (!classification_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100336 TC3_LOG(ERROR) << "Could not initialize classification executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100337 return;
338 }
339
340 classification_feature_processor_.reset(new FeatureProcessor(
341 model_->classification_feature_options(), unilib_));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100342 }
343
Lukas Zilkaba849e72018-03-08 14:48:21 +0100344 // The embeddings need to be specified if the model is to be used for
345 // classification or selection.
346 if (model_enabled_for_annotation || model_enabled_for_classification ||
347 model_enabled_for_selection) {
348 if (!model_->embedding_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100349 TC3_LOG(ERROR) << "No embedding model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100350 return;
351 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100352
Lukas Zilkaba849e72018-03-08 14:48:21 +0100353 // Check that the embedding size of the selection and classification model
354 // matches, as they are using the same embeddings.
355 if (model_enabled_for_selection &&
356 (model_->selection_feature_options()->embedding_size() !=
357 model_->classification_feature_options()->embedding_size() ||
358 model_->selection_feature_options()->embedding_quantization_bits() !=
359 model_->classification_feature_options()
360 ->embedding_quantization_bits())) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100361 TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100362 return;
363 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100364
Tony Mak6c4cc672018-09-17 11:48:50 +0100365 embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200366 model_->embedding_model(),
Lukas Zilkaba849e72018-03-08 14:48:21 +0100367 model_->classification_feature_options()->embedding_size(),
Tony Makdf54e742019-03-26 14:04:00 +0000368 model_->classification_feature_options()->embedding_quantization_bits(),
369 model_->embedding_pruning_mask());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200370 if (!embedding_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100371 TC3_LOG(ERROR) << "Could not initialize embedding executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100372 return;
373 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100374 }
375
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200376 std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
Lukas Zilkab23e2122018-02-09 10:25:19 +0100377 if (model_->regex_model()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200378 if (!InitializeRegexModel(decompressor.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100379 TC3_LOG(ERROR) << "Could not initialize regex model.";
Lukas Zilka434442d2018-04-25 11:38:51 +0200380 return;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100381 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100382 }
383
Lukas Zilkab23e2122018-02-09 10:25:19 +0100384 if (model_->datetime_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100385 datetime_parser_ = DatetimeParser::Instance(
386 model_->datetime_model(), *unilib_, *calendarlib_, decompressor.get());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100387 if (!datetime_parser_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100388 TC3_LOG(ERROR) << "Could not initialize datetime parser.";
Lukas Zilkab23e2122018-02-09 10:25:19 +0100389 return;
390 }
391 }
392
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200393 if (model_->output_options()) {
394 if (model_->output_options()->filtered_collections_annotation()) {
395 for (const auto collection :
396 *model_->output_options()->filtered_collections_annotation()) {
397 filtered_collections_annotation_.insert(collection->str());
398 }
399 }
400 if (model_->output_options()->filtered_collections_classification()) {
401 for (const auto collection :
402 *model_->output_options()->filtered_collections_classification()) {
403 filtered_collections_classification_.insert(collection->str());
404 }
405 }
406 if (model_->output_options()->filtered_collections_selection()) {
407 for (const auto collection :
408 *model_->output_options()->filtered_collections_selection()) {
409 filtered_collections_selection_.insert(collection->str());
410 }
411 }
412 }
413
Tony Mak378c1f52019-03-04 15:58:11 +0000414 if (model_->number_annotator_options() &&
415 model_->number_annotator_options()->enabled()) {
Tony Makad2e22d2019-03-20 17:35:13 +0000416 if (selection_feature_processor_ == nullptr) {
417 TC3_LOG(ERROR)
418 << "Could not initialize NumberAnnotator without a feature processor";
419 return;
420 }
421
Tony Mak378c1f52019-03-04 15:58:11 +0000422 number_annotator_.reset(
423 new NumberAnnotator(model_->number_annotator_options(),
424 selection_feature_processor_.get()));
425 }
426
Tony Makad2e22d2019-03-20 17:35:13 +0000427 if (model_->duration_annotator_options() &&
428 model_->duration_annotator_options()->enabled()) {
429 duration_annotator_.reset(
430 new DurationAnnotator(model_->duration_annotator_options(),
Tony Mak8cd7ba62019-10-15 15:29:22 +0100431 selection_feature_processor_.get(), unilib_));
Tony Makad2e22d2019-03-20 17:35:13 +0000432 }
433
Tony Makd9446602019-02-20 18:25:39 +0000434 if (model_->entity_data_schema()) {
435 entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
436 model_->entity_data_schema()->Data(),
437 model_->entity_data_schema()->size());
438 if (entity_data_schema_ == nullptr) {
439 TC3_LOG(ERROR) << "Could not load entity data schema data.";
440 return;
441 }
442
443 entity_data_builder_.reset(
444 new ReflectiveFlatbufferBuilder(entity_data_schema_));
445 } else {
Tony Mak378c1f52019-03-04 15:58:11 +0000446 entity_data_schema_ = nullptr;
Tony Makd9446602019-02-20 18:25:39 +0000447 entity_data_builder_ = nullptr;
448 }
449
Tony Makdf54e742019-03-26 14:04:00 +0000450 if (model_->triggering_locales() &&
451 !ParseLocales(model_->triggering_locales()->c_str(),
452 &model_triggering_locales_)) {
Tony Mak378c1f52019-03-04 15:58:11 +0000453 TC3_LOG(ERROR) << "Could not parse model supported locales.";
454 return;
455 }
456
457 if (model_->triggering_options() != nullptr &&
Tony Makdf54e742019-03-26 14:04:00 +0000458 model_->triggering_options()->locales() != nullptr &&
459 !ParseLocales(model_->triggering_options()->locales()->c_str(),
460 &ml_model_triggering_locales_)) {
461 TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
462 return;
463 }
464
465 if (model_->triggering_options() != nullptr &&
Tony Mak378c1f52019-03-04 15:58:11 +0000466 model_->triggering_options()->dictionary_locales() != nullptr &&
467 !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
468 &dictionary_locales_)) {
469 TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
470 return;
471 }
472
Lukas Zilka21d8c982018-01-24 11:11:20 +0100473 initialized_ = true;
474}
475
Tony Mak6c4cc672018-09-17 11:48:50 +0100476bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100477 if (!model_->regex_model()->patterns()) {
Lukas Zilka434442d2018-04-25 11:38:51 +0200478 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100479 }
480
481 // Initialize pattern recognizers.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100482 int regex_pattern_id = 0;
483 for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200484 std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
Tony Mak378c1f52019-03-04 15:58:11 +0000485 UncompressMakeRegexPattern(
486 *unilib_, regex_pattern->pattern(),
487 regex_pattern->compressed_pattern(),
488 model_->regex_model()->lazy_regex_compilation(), decompressor);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100489 if (!compiled_pattern) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100490 TC3_LOG(INFO) << "Failed to load regex pattern";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200491 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100492 }
493
Lukas Zilkaba849e72018-03-08 14:48:21 +0100494 if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100495 annotation_regex_patterns_.push_back(regex_pattern_id);
496 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100497 if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100498 classification_regex_patterns_.push_back(regex_pattern_id);
499 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100500 if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100501 selection_regex_patterns_.push_back(regex_pattern_id);
502 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100503 regex_patterns_.push_back({
Tony Mak854015a2019-01-16 15:56:48 +0000504 regex_pattern,
Tony Mak6c4cc672018-09-17 11:48:50 +0100505 std::move(compiled_pattern),
Tony Mak6c4cc672018-09-17 11:48:50 +0100506 });
Lukas Zilkab23e2122018-02-09 10:25:19 +0100507 ++regex_pattern_id;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100508 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100509
Lukas Zilkab23e2122018-02-09 10:25:19 +0100510 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100511}
512
Tony Mak6c4cc672018-09-17 11:48:50 +0100513bool Annotator::InitializeKnowledgeEngine(
514 const std::string& serialized_config) {
Tony Maka2a1ff42019-09-12 15:40:32 +0100515 std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
Tony Mak6c4cc672018-09-17 11:48:50 +0100516 if (!knowledge_engine->Initialize(serialized_config)) {
517 TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
518 return false;
519 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100520 if (model_->triggering_options() != nullptr) {
521 knowledge_engine->SetPriorityScore(
522 model_->triggering_options()->knowledge_priority_score());
523 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100524 knowledge_engine_ = std::move(knowledge_engine);
525 return true;
526}
527
Tony Mak854015a2019-01-16 15:56:48 +0000528bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
Tony Mak5dc5e112019-02-01 14:52:10 +0000529 std::unique_ptr<ContactEngine> contact_engine(
Tony Mak378c1f52019-03-04 15:58:11 +0000530 new ContactEngine(selection_feature_processor_.get(), unilib_));
Tony Mak854015a2019-01-16 15:56:48 +0000531 if (!contact_engine->Initialize(serialized_config)) {
532 TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
533 return false;
534 }
535 contact_engine_ = std::move(contact_engine);
536 return true;
537}
538
Tony Makd9446602019-02-20 18:25:39 +0000539bool Annotator::InitializeInstalledAppEngine(
540 const std::string& serialized_config) {
541 std::unique_ptr<InstalledAppEngine> installed_app_engine(
Tony Mak378c1f52019-03-04 15:58:11 +0000542 new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
Tony Makd9446602019-02-20 18:25:39 +0000543 if (!installed_app_engine->Initialize(serialized_config)) {
544 TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
545 return false;
546 }
547 installed_app_engine_ = std::move(installed_app_engine);
548 return true;
549}
550
Lukas Zilka21d8c982018-01-24 11:11:20 +0100551namespace {
552
553int CountDigits(const std::string& str, CodepointSpan selection_indices) {
554 int count = 0;
555 int i = 0;
556 const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
557 for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
558 if (i >= selection_indices.first && i < selection_indices.second &&
559 isdigit(*it)) {
560 ++count;
561 }
562 }
563 return count;
564}
565
Lukas Zilka21d8c982018-01-24 11:11:20 +0100566} // namespace
567
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200568namespace internal {
569// Helper function, which if the initial 'span' contains only white-spaces,
570// moves the selection to a single-codepoint selection on a left or right side
571// of this space.
572CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
573 const UnicodeText& context_unicode,
574 const UniLib& unilib) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100575 TC3_CHECK(ValidNonEmptySpan(span));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200576
577 UnicodeText::const_iterator it;
578
579 // Check that the current selection is all whitespaces.
580 it = context_unicode.begin();
581 std::advance(it, span.first);
582 for (int i = 0; i < (span.second - span.first); ++i, ++it) {
583 if (!unilib.IsWhitespace(*it)) {
584 return span;
585 }
586 }
587
588 CodepointSpan result;
589
590 // Try moving left.
591 result = span;
592 it = context_unicode.begin();
593 std::advance(it, span.first);
594 while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
595 --result.first;
596 --it;
597 }
598 result.second = result.first + 1;
599 if (!unilib.IsWhitespace(*it)) {
600 return result;
601 }
602
603 // If moving left didn't find a non-whitespace character, just return the
604 // original span.
605 return span;
606}
607} // namespace internal
608
Tony Mak6c4cc672018-09-17 11:48:50 +0100609bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200610 return !span.classification.empty() &&
611 filtered_collections_annotation_.find(
612 span.classification[0].collection) !=
613 filtered_collections_annotation_.end();
614}
615
Tony Mak6c4cc672018-09-17 11:48:50 +0100616bool Annotator::FilteredForClassification(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200617 const ClassificationResult& classification) const {
618 return filtered_collections_classification_.find(classification.collection) !=
619 filtered_collections_classification_.end();
620}
621
Tony Mak6c4cc672018-09-17 11:48:50 +0100622bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200623 return !span.classification.empty() &&
624 filtered_collections_selection_.find(
625 span.classification[0].collection) !=
626 filtered_collections_selection_.end();
627}
628
Tony Mak378c1f52019-03-04 15:58:11 +0000629namespace {
630inline bool ClassifiedAsOther(
631 const std::vector<ClassificationResult>& classification) {
632 return !classification.empty() &&
633 classification[0].collection == Collections::Other();
634}
635
Tony Maka2a1ff42019-09-12 15:40:32 +0100636} // namespace
637
638float Annotator::GetPriorityScore(
639 const std::vector<ClassificationResult>& classification) const {
Tony Mak378c1f52019-03-04 15:58:11 +0000640 if (!classification.empty() && !ClassifiedAsOther(classification)) {
641 return classification[0].priority_score;
642 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +0100643 if (model_->triggering_options() != nullptr) {
644 return model_->triggering_options()->other_collection_priority_score();
645 } else {
646 return -1000.0;
647 }
Tony Mak378c1f52019-03-04 15:58:11 +0000648 }
649}
Tony Mak378c1f52019-03-04 15:58:11 +0000650
Tony Makdf54e742019-03-26 14:04:00 +0000651bool Annotator::VerifyRegexMatchCandidate(
652 const std::string& context, const VerificationOptions* verification_options,
653 const std::string& match, const UniLib::RegexMatcher* matcher) const {
654 if (verification_options == nullptr) {
655 return true;
656 }
657 if (verification_options->verify_luhn_checksum() &&
658 !VerifyLuhnChecksum(match)) {
659 return false;
660 }
661 const int lua_verifier = verification_options->lua_verifier();
662 if (lua_verifier >= 0) {
663 if (model_->regex_model()->lua_verifier() == nullptr ||
664 lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
665 TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
666 return false;
667 }
668 return VerifyMatch(
669 context, matcher,
670 model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
671 }
672 return true;
673}
674
Tony Mak6c4cc672018-09-17 11:48:50 +0100675CodepointSpan Annotator::SuggestSelection(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100676 const std::string& context, CodepointSpan click_indices,
677 const SelectionOptions& options) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200678 CodepointSpan original_click_indices = click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100679 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100680 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200681 return original_click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100682 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100683 if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200684 return original_click_indices;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100685 }
686
Tony Makdf54e742019-03-26 14:04:00 +0000687 std::vector<Locale> detected_text_language_tags;
688 if (!ParseLocales(options.detected_text_language_tags,
689 &detected_text_language_tags)) {
690 TC3_LOG(WARNING)
691 << "Failed to parse the detected_text_language_tags in options: "
692 << options.detected_text_language_tags;
693 }
694 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
695 model_triggering_locales_,
696 /*default_value=*/true)) {
697 return original_click_indices;
698 }
699
Lukas Zilkadf710db2018-02-27 12:44:09 +0100700 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
701 /*do_copy=*/false);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200702
Tony Mak968412a2019-11-13 15:39:57 +0000703 if (!IsValidSpanInput(context_unicode, click_indices)) {
704 TC3_VLOG(1)
705 << "Trying to run SuggestSelection with invalid input, indices: "
706 << click_indices.first << " " << click_indices.second;
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200707 return original_click_indices;
708 }
709
710 if (model_->snap_whitespace_selections()) {
711 // We want to expand a purely white-space selection to a multi-selection it
712 // would've been part of. But with this feature disabled we would do a no-
713 // op, because no token is found. Therefore, we need to modify the
714 // 'click_indices' a bit to include a part of the token, so that the click-
715 // finding logic finds the clicked token correctly. This modification is
716 // done by the following function. Note, that it's enough to check the left
717 // side of the current selection, because if the white-space is a part of a
Tony Mak6c4cc672018-09-17 11:48:50 +0100718 // multi-selection, necessarily both tokens - on the left and the right
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200719 // sides need to be selected. Thus snapping only to the left is sufficient
720 // (there's a check at the bottom that makes sure that if we snap to the
721 // left token but the result does not contain the initial white-space,
722 // returns the original indices).
723 click_indices = internal::SnapLeftIfWhitespaceSelection(
724 click_indices, context_unicode, *unilib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100725 }
726
Lukas Zilkab23e2122018-02-09 10:25:19 +0100727 std::vector<AnnotatedSpan> candidates;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100728 InterpreterManager interpreter_manager(selection_executor_.get(),
729 classification_executor_.get());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200730 std::vector<Token> tokens;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100731 if (!ModelSuggestSelection(context_unicode, click_indices,
Tony Makdf54e742019-03-26 14:04:00 +0000732 detected_text_language_tags, &interpreter_manager,
733 &tokens, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100734 TC3_LOG(ERROR) << "Model suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200735 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100736 }
Tony Mak83d2de62019-04-10 16:12:15 +0100737 if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
738 /*is_serialized_entity_data_enabled=*/false)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100739 TC3_LOG(ERROR) << "Regex suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200740 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100741 }
Tony Mak83d2de62019-04-10 16:12:15 +0100742 if (!DatetimeChunk(
743 UTF8ToUnicodeText(context, /*do_copy=*/false),
744 /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
745 options.locales, ModeFlag_SELECTION, options.annotation_usecase,
746 /*is_serialized_entity_data_enabled=*/false, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100747 TC3_LOG(ERROR) << "Datetime suggest selection failed.";
748 return original_click_indices;
749 }
Tony Mak378c1f52019-03-04 15:58:11 +0000750 if (knowledge_engine_ != nullptr &&
Tony Maka2a1ff42019-09-12 15:40:32 +0100751 !knowledge_engine_->Chunk(context, options.annotation_usecase,
752 &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100753 TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200754 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100755 }
Tony Mak378c1f52019-03-04 15:58:11 +0000756 if (contact_engine_ != nullptr &&
Tony Mak854015a2019-01-16 15:56:48 +0000757 !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
758 TC3_LOG(ERROR) << "Contact suggest selection failed.";
759 return original_click_indices;
760 }
Tony Mak378c1f52019-03-04 15:58:11 +0000761 if (installed_app_engine_ != nullptr &&
Tony Makd9446602019-02-20 18:25:39 +0000762 !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
763 TC3_LOG(ERROR) << "Installed app suggest selection failed.";
764 return original_click_indices;
765 }
Tony Mak378c1f52019-03-04 15:58:11 +0000766 if (number_annotator_ != nullptr &&
767 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
768 &candidates)) {
769 TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
770 return original_click_indices;
771 }
Tony Makad2e22d2019-03-20 17:35:13 +0000772 if (duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +0000773 !duration_annotator_->FindAll(context_unicode, tokens,
774 options.annotation_usecase, &candidates)) {
Tony Makad2e22d2019-03-20 17:35:13 +0000775 TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
776 return original_click_indices;
777 }
Lukas Zilkab23e2122018-02-09 10:25:19 +0100778
779 // Sort candidates according to their position in the input, so that the next
780 // code can assume that any connected component of overlapping spans forms a
781 // contiguous block.
782 std::sort(candidates.begin(), candidates.end(),
783 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
784 return a.span.first < b.span.first;
785 });
786
787 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +0000788 if (!ResolveConflicts(candidates, context, tokens,
789 detected_text_language_tags, options.annotation_usecase,
790 &interpreter_manager, &candidate_indices)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100791 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200792 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100793 }
794
Tony Mak378c1f52019-03-04 15:58:11 +0000795 std::sort(candidate_indices.begin(), candidate_indices.end(),
Tony Maka2a1ff42019-09-12 15:40:32 +0100796 [this, &candidates](int a, int b) {
Tony Mak378c1f52019-03-04 15:58:11 +0000797 return GetPriorityScore(candidates[a].classification) >
798 GetPriorityScore(candidates[b].classification);
799 });
800
Lukas Zilkab23e2122018-02-09 10:25:19 +0100801 for (const int i : candidate_indices) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200802 if (SpansOverlap(candidates[i].span, click_indices) &&
803 SpansOverlap(candidates[i].span, original_click_indices)) {
804 // Run model classification if not present but requested and there's a
805 // classification collection filter specified.
806 if (candidates[i].classification.empty() &&
807 model_->selection_options()->always_classify_suggested_selection() &&
808 !filtered_collections_selection_.empty()) {
Tony Mak378c1f52019-03-04 15:58:11 +0000809 if (!ModelClassifyText(context, detected_text_language_tags,
810 candidates[i].span, &interpreter_manager,
811 /*embedding_cache=*/nullptr,
812 &candidates[i].classification)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200813 return original_click_indices;
814 }
815 }
816
817 // Ignore if span classification is filtered.
818 if (FilteredForSelection(candidates[i])) {
819 return original_click_indices;
820 }
821
Lukas Zilkab23e2122018-02-09 10:25:19 +0100822 return candidates[i].span;
823 }
824 }
825
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200826 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100827}
828
829namespace {
830// Helper function that returns the index of the first candidate that
831// transitively does not overlap with the candidate on 'start_index'. If the end
832// of 'candidates' is reached, it returns the index that points right behind the
833// array.
834int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
835 int start_index) {
836 int first_non_overlapping = start_index + 1;
837 CodepointSpan conflicting_span = candidates[start_index].span;
838 while (
839 first_non_overlapping < candidates.size() &&
840 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
841 // Grow the span to include the current one.
842 conflicting_span.second = std::max(
843 conflicting_span.second, candidates[first_non_overlapping].span.second);
844
845 ++first_non_overlapping;
846 }
847 return first_non_overlapping;
848}
849} // namespace
850
Tony Mak378c1f52019-03-04 15:58:11 +0000851bool Annotator::ResolveConflicts(
852 const std::vector<AnnotatedSpan>& candidates, const std::string& context,
853 const std::vector<Token>& cached_tokens,
854 const std::vector<Locale>& detected_text_language_tags,
855 AnnotationUsecase annotation_usecase,
856 InterpreterManager* interpreter_manager, std::vector<int>* result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100857 result->clear();
858 result->reserve(candidates.size());
859 for (int i = 0; i < candidates.size();) {
860 int first_non_overlapping =
861 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
862
863 const bool conflict_found = first_non_overlapping != (i + 1);
864 if (conflict_found) {
865 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +0000866 if (!ResolveConflict(context, cached_tokens, candidates,
867 detected_text_language_tags, i,
868 first_non_overlapping, annotation_usecase,
869 interpreter_manager, &candidate_indices)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100870 return false;
871 }
872 result->insert(result->end(), candidate_indices.begin(),
873 candidate_indices.end());
874 } else {
875 result->push_back(i);
876 }
877
878 // Skip over the whole conflicting group/go to next candidate.
879 i = first_non_overlapping;
880 }
881 return true;
882}
883
884namespace {
Tony Mak448b5862019-03-22 13:36:41 +0000885// Returns true, if the given two sources do conflict in given annotation
886// usecase.
887// - In SMART usecase, all sources do conflict, because there's only 1 possible
888// annotation for a given span.
889// - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
890// and duration), while others not (e.g. duration and number).
891bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
892 const AnnotatedSpan::Source source1,
893 const AnnotatedSpan::Source source2) {
894 uint32 source_mask =
895 (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
896
Tony Mak378c1f52019-03-04 15:58:11 +0000897 switch (annotation_usecase) {
898 case AnnotationUsecase_ANNOTATION_USECASE_SMART:
Tony Mak448b5862019-03-22 13:36:41 +0000899 // In the SMART mode, all annotations conflict.
Tony Mak378c1f52019-03-04 15:58:11 +0000900 return true;
Tony Mak448b5862019-03-22 13:36:41 +0000901
Tony Mak378c1f52019-03-04 15:58:11 +0000902 case AnnotationUsecase_ANNOTATION_USECASE_RAW:
Tony Mak448b5862019-03-22 13:36:41 +0000903 // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
904 // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
905 // hours" (duration).
906 if ((source_mask &
907 (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
908 (source_mask &
909 (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
910 return false;
Tony Mak378c1f52019-03-04 15:58:11 +0000911 }
Tony Mak448b5862019-03-22 13:36:41 +0000912
913 // A KNOWLEDGE entity does not conflict with anything.
914 if ((source_mask &
915 (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
916 return false;
917 }
918
919 // Entities from other sources can conflict.
Tony Mak378c1f52019-03-04 15:58:11 +0000920 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100921 }
922}
923} // namespace
924
Tony Mak378c1f52019-03-04 15:58:11 +0000925bool Annotator::ResolveConflict(
926 const std::string& context, const std::vector<Token>& cached_tokens,
927 const std::vector<AnnotatedSpan>& candidates,
928 const std::vector<Locale>& detected_text_language_tags, int start_index,
929 int end_index, AnnotationUsecase annotation_usecase,
930 InterpreterManager* interpreter_manager,
931 std::vector<int>* chosen_indices) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100932 std::vector<int> conflicting_indices;
933 std::unordered_map<int, float> scores;
934 for (int i = start_index; i < end_index; ++i) {
935 conflicting_indices.push_back(i);
936 if (!candidates[i].classification.empty()) {
937 scores[i] = GetPriorityScore(candidates[i].classification);
938 continue;
939 }
940
941 // OPTIMIZATION: So that we don't have to classify all the ML model
942 // spans apriori, we wait until we get here, when they conflict with
943 // something and we need the actual classification scores. So if the
944 // candidate conflicts and comes from the model, we need to run a
945 // classification to determine its priority:
946 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +0000947 if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
948 candidates[i].span, interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +0100949 /*embedding_cache=*/nullptr, &classification)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100950 return false;
951 }
952
953 if (!classification.empty()) {
954 scores[i] = GetPriorityScore(classification);
955 }
956 }
957
958 std::sort(conflicting_indices.begin(), conflicting_indices.end(),
959 [&scores](int i, int j) { return scores[i] > scores[j]; });
960
Tony Mak448b5862019-03-22 13:36:41 +0000961 // Here we keep a set of indices that were chosen, per-source, to enable
962 // effective computation.
963 std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
964 chosen_indices_for_source_map;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100965
966 // Greedily place the candidates if they don't conflict with the already
967 // placed ones.
968 for (int i = 0; i < conflicting_indices.size(); ++i) {
969 const int considered_candidate = conflicting_indices[i];
Tony Mak448b5862019-03-22 13:36:41 +0000970
971 // See if there is a conflict between the candidate and all already placed
972 // candidates.
973 bool conflict = false;
974 SortedIntSet* chosen_indices_for_source_ptr = nullptr;
975 for (auto& source_set_pair : chosen_indices_for_source_map) {
976 if (source_set_pair.first == candidates[considered_candidate].source) {
977 chosen_indices_for_source_ptr = &source_set_pair.second;
978 }
979
980 if (DoSourcesConflict(annotation_usecase, source_set_pair.first,
981 candidates[considered_candidate].source) &&
982 DoesCandidateConflict(considered_candidate, candidates,
983 source_set_pair.second)) {
984 conflict = true;
985 break;
986 }
Lukas Zilkab23e2122018-02-09 10:25:19 +0100987 }
Tony Mak448b5862019-03-22 13:36:41 +0000988
989 // Skip the candidate if a conflict was found.
990 if (conflict) {
991 continue;
992 }
993
994 // If the set of indices for the current source doesn't exist yet,
995 // initialize it.
996 if (chosen_indices_for_source_ptr == nullptr) {
997 SortedIntSet new_set([&candidates](int a, int b) {
998 return candidates[a].span.first < candidates[b].span.first;
999 });
1000 chosen_indices_for_source_map[candidates[considered_candidate].source] =
1001 std::move(new_set);
1002 chosen_indices_for_source_ptr =
1003 &chosen_indices_for_source_map[candidates[considered_candidate]
1004 .source];
1005 }
1006
1007 // Place the candidate to the output and to the per-source conflict set.
1008 chosen_indices->push_back(considered_candidate);
1009 chosen_indices_for_source_ptr->insert(considered_candidate);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001010 }
1011
Tony Mak378c1f52019-03-04 15:58:11 +00001012 std::sort(chosen_indices->begin(), chosen_indices->end());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001013
1014 return true;
1015}
1016
Tony Mak6c4cc672018-09-17 11:48:50 +01001017bool Annotator::ModelSuggestSelection(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001018 const UnicodeText& context_unicode, CodepointSpan click_indices,
Tony Makdf54e742019-03-26 14:04:00 +00001019 const std::vector<Locale>& detected_text_language_tags,
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001020 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001021 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001022 if (model_->triggering_options() == nullptr ||
1023 !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1024 return true;
1025 }
1026
Tony Makdf54e742019-03-26 14:04:00 +00001027 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1028 ml_model_triggering_locales_,
1029 /*default_value=*/true)) {
1030 return true;
1031 }
1032
Lukas Zilka21d8c982018-01-24 11:11:20 +01001033 int click_pos;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001034 *tokens = selection_feature_processor_->Tokenize(context_unicode);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001035 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001036 context_unicode, click_indices,
1037 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001038 tokens, &click_pos);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001039 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001040 TC3_VLOG(1) << "Could not calculate the click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001041 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001042 }
1043
1044 const int symmetry_context_size =
1045 model_->selection_options()->symmetry_context_size();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001046 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
Lukas Zilkab23e2122018-02-09 10:25:19 +01001047 bounds_sensitive_features = selection_feature_processor_->GetOptions()
1048 ->bounds_sensitive_features();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001049
1050 // The symmetry context span is the clicked token with symmetry_context_size
1051 // tokens on either side.
1052 const TokenSpan symmetry_context_span = IntersectTokenSpans(
1053 ExpandTokenSpan(SingleTokenSpan(click_pos),
1054 /*num_tokens_left=*/symmetry_context_size,
1055 /*num_tokens_right=*/symmetry_context_size),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001056 {0, tokens->size()});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001057
Lukas Zilkab23e2122018-02-09 10:25:19 +01001058 // Compute the extraction span based on the model type.
1059 TokenSpan extraction_span;
1060 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1061 // The extraction span is the symmetry context span expanded to include
1062 // max_selection_span tokens on either side, which is how far a selection
1063 // can stretch from the click, plus a relevant number of tokens outside of
1064 // the bounds of the selection.
1065 const int max_selection_span =
1066 selection_feature_processor_->GetOptions()->max_selection_span();
1067 extraction_span =
1068 ExpandTokenSpan(symmetry_context_span,
1069 /*num_tokens_left=*/max_selection_span +
1070 bounds_sensitive_features->num_tokens_before(),
1071 /*num_tokens_right=*/max_selection_span +
1072 bounds_sensitive_features->num_tokens_after());
1073 } else {
1074 // The extraction span is the symmetry context span expanded to include
1075 // context_size tokens on either side.
1076 const int context_size =
1077 selection_feature_processor_->GetOptions()->context_size();
1078 extraction_span = ExpandTokenSpan(symmetry_context_span,
1079 /*num_tokens_left=*/context_size,
1080 /*num_tokens_right=*/context_size);
1081 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001082 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
Lukas Zilkab23e2122018-02-09 10:25:19 +01001083
Lukas Zilka434442d2018-04-25 11:38:51 +02001084 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1085 *tokens, extraction_span)) {
1086 return true;
1087 }
1088
Lukas Zilkab23e2122018-02-09 10:25:19 +01001089 std::unique_ptr<CachedFeatures> cached_features;
1090 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001091 *tokens, extraction_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001092 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1093 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01001094 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001095 selection_feature_processor_->EmbeddingSize() +
1096 selection_feature_processor_->DenseFeaturesCount(),
1097 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001098 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001099 return false;
1100 }
1101
1102 // Produce selection model candidates.
1103 std::vector<TokenSpan> chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001104 if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001105 interpreter_manager->SelectionInterpreter(), *cached_features,
1106 &chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001107 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001108 return false;
1109 }
1110
1111 for (const TokenSpan& chunk : chunks) {
1112 AnnotatedSpan candidate;
1113 candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001114 context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001115 if (model_->selection_options()->strip_unpaired_brackets()) {
1116 candidate.span =
1117 StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1118 }
1119
1120 // Only output non-empty spans.
1121 if (candidate.span.first != candidate.span.second) {
1122 result->push_back(candidate);
1123 }
1124 }
1125 return true;
1126}
1127
Tony Mak6c4cc672018-09-17 11:48:50 +01001128bool Annotator::ModelClassifyText(
Tony Mak378c1f52019-03-04 15:58:11 +00001129 const std::string& context,
1130 const std::vector<Locale>& detected_text_language_tags,
1131 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001132 FeatureProcessor::EmbeddingCache* embedding_cache,
1133 std::vector<ClassificationResult>* classification_results) const {
Tony Mak378c1f52019-03-04 15:58:11 +00001134 return ModelClassifyText(context, {}, detected_text_language_tags,
1135 selection_indices, interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001136 embedding_cache, classification_results);
1137}
1138
1139namespace internal {
1140std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1141 CodepointSpan selection_indices,
1142 TokenSpan tokens_around_selection_to_copy) {
1143 const auto first_selection_token = std::upper_bound(
1144 cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1145 [](int selection_start, const Token& token) {
1146 return selection_start < token.end;
1147 });
1148 const auto last_selection_token = std::lower_bound(
1149 cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1150 [](const Token& token, int selection_end) {
1151 return token.start < selection_end;
1152 });
1153
1154 const int64 first_token = std::max(
1155 static_cast<int64>(0),
1156 static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1157 tokens_around_selection_to_copy.first));
1158 const int64 last_token = std::min(
1159 static_cast<int64>(cached_tokens.size()),
1160 static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1161 tokens_around_selection_to_copy.second));
1162
1163 std::vector<Token> tokens;
1164 tokens.reserve(last_token - first_token);
1165 for (int i = first_token; i < last_token; ++i) {
1166 tokens.push_back(cached_tokens[i]);
1167 }
1168 return tokens;
1169}
1170} // namespace internal
1171
Tony Mak6c4cc672018-09-17 11:48:50 +01001172TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001173 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1174 bounds_sensitive_features =
1175 classification_feature_processor_->GetOptions()
1176 ->bounds_sensitive_features();
1177 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1178 // The extraction span is the selection span expanded to include a relevant
1179 // number of tokens outside of the bounds of the selection.
1180 return {bounds_sensitive_features->num_tokens_before(),
1181 bounds_sensitive_features->num_tokens_after()};
1182 } else {
1183 // The extraction span is the clicked token with context_size tokens on
1184 // either side.
1185 const int context_size =
1186 selection_feature_processor_->GetOptions()->context_size();
1187 return {context_size, context_size};
1188 }
1189}
1190
Tony Mak378c1f52019-03-04 15:58:11 +00001191namespace {
1192// Sorts the classification results from high score to low score.
1193void SortClassificationResults(
1194 std::vector<ClassificationResult>* classification_results) {
1195 std::sort(classification_results->begin(), classification_results->end(),
1196 [](const ClassificationResult& a, const ClassificationResult& b) {
1197 return a.score > b.score;
1198 });
1199}
1200} // namespace
1201
Tony Mak6c4cc672018-09-17 11:48:50 +01001202bool Annotator::ModelClassifyText(
Lukas Zilkaba849e72018-03-08 14:48:21 +01001203 const std::string& context, const std::vector<Token>& cached_tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00001204 const std::vector<Locale>& detected_text_language_tags,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001205 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1206 FeatureProcessor::EmbeddingCache* embedding_cache,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001207 std::vector<ClassificationResult>* classification_results) const {
1208 std::vector<Token> tokens;
Tony Mak378c1f52019-03-04 15:58:11 +00001209 return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1210 selection_indices, interpreter_manager,
1211 embedding_cache, classification_results, &tokens);
1212}
1213
1214bool Annotator::ModelClassifyText(
1215 const std::string& context, const std::vector<Token>& cached_tokens,
1216 const std::vector<Locale>& detected_text_language_tags,
1217 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1218 FeatureProcessor::EmbeddingCache* embedding_cache,
1219 std::vector<ClassificationResult>* classification_results,
1220 std::vector<Token>* tokens) const {
1221 if (model_->triggering_options() == nullptr ||
1222 !(model_->triggering_options()->enabled_modes() &
1223 ModeFlag_CLASSIFICATION)) {
1224 return true;
1225 }
1226
Tony Makdf54e742019-03-26 14:04:00 +00001227 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1228 ml_model_triggering_locales_,
1229 /*default_value=*/true)) {
1230 return true;
1231 }
1232
Lukas Zilkaba849e72018-03-08 14:48:21 +01001233 if (cached_tokens.empty()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001234 *tokens = classification_feature_processor_->Tokenize(context);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001235 } else {
Tony Mak378c1f52019-03-04 15:58:11 +00001236 *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1237 ClassifyTextUpperBoundNeededTokens());
Lukas Zilkaba849e72018-03-08 14:48:21 +01001238 }
1239
Lukas Zilkab23e2122018-02-09 10:25:19 +01001240 int click_pos;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001241 classification_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001242 context, selection_indices,
1243 classification_feature_processor_->GetOptions()
1244 ->only_use_line_with_click(),
Tony Mak378c1f52019-03-04 15:58:11 +00001245 tokens, &click_pos);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001246 const TokenSpan selection_token_span =
Tony Mak378c1f52019-03-04 15:58:11 +00001247 CodepointSpanToTokenSpan(*tokens, selection_indices);
Lukas Zilka434442d2018-04-25 11:38:51 +02001248 const int selection_num_tokens = TokenSpanSize(selection_token_span);
1249 if (model_->classification_options()->max_num_tokens() > 0 &&
1250 model_->classification_options()->max_num_tokens() <
1251 selection_num_tokens) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001252 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001253 return true;
1254 }
1255
Lukas Zilkab23e2122018-02-09 10:25:19 +01001256 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1257 bounds_sensitive_features =
1258 classification_feature_processor_->GetOptions()
1259 ->bounds_sensitive_features();
1260 if (selection_token_span.first == kInvalidIndex ||
1261 selection_token_span.second == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001262 TC3_LOG(ERROR) << "Could not determine span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001263 return false;
1264 }
1265
1266 // Compute the extraction span based on the model type.
1267 TokenSpan extraction_span;
1268 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1269 // The extraction span is the selection span expanded to include a relevant
1270 // number of tokens outside of the bounds of the selection.
1271 extraction_span = ExpandTokenSpan(
1272 selection_token_span,
1273 /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1274 /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1275 } else {
1276 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001277 TC3_LOG(ERROR) << "Couldn't choose a click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001278 return false;
1279 }
1280 // The extraction span is the clicked token with context_size tokens on
1281 // either side.
1282 const int context_size =
Lukas Zilkaba849e72018-03-08 14:48:21 +01001283 classification_feature_processor_->GetOptions()->context_size();
Lukas Zilkab23e2122018-02-09 10:25:19 +01001284 extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
1285 /*num_tokens_left=*/context_size,
1286 /*num_tokens_right=*/context_size);
1287 }
Tony Mak378c1f52019-03-04 15:58:11 +00001288 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001289
Lukas Zilka434442d2018-04-25 11:38:51 +02001290 if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
Tony Mak378c1f52019-03-04 15:58:11 +00001291 *tokens, extraction_span)) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001292 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001293 return true;
1294 }
1295
Lukas Zilka21d8c982018-01-24 11:11:20 +01001296 std::unique_ptr<CachedFeatures> cached_features;
1297 if (!classification_feature_processor_->ExtractFeatures(
Tony Mak378c1f52019-03-04 15:58:11 +00001298 *tokens, extraction_span, selection_indices,
1299 embedding_executor_.get(), embedding_cache,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001300 classification_feature_processor_->EmbeddingSize() +
1301 classification_feature_processor_->DenseFeaturesCount(),
1302 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001303 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001304 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001305 }
1306
Lukas Zilkab23e2122018-02-09 10:25:19 +01001307 std::vector<float> features;
1308 features.reserve(cached_features->OutputFeaturesSize());
1309 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1310 cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1311 &features);
1312 } else {
1313 cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001314 }
1315
Lukas Zilkaba849e72018-03-08 14:48:21 +01001316 TensorView<float> logits = classification_executor_->ComputeLogits(
1317 TensorView<float>(features.data(),
1318 {1, static_cast<int>(features.size())}),
1319 interpreter_manager->ClassificationInterpreter());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001320 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001321 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001322 return false;
1323 }
1324
1325 if (logits.dims() != 2 || logits.dim(0) != 1 ||
1326 logits.dim(1) != classification_feature_processor_->NumCollections()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001327 TC3_LOG(ERROR) << "Mismatching output";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001328 return false;
1329 }
1330
1331 const std::vector<float> scores =
1332 ComputeSoftmax(logits.data(), logits.dim(1));
1333
Tony Mak81e52422019-04-30 09:34:45 +01001334 if (scores.empty()) {
1335 *classification_results = {{Collections::Other(), 1.0}};
1336 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001337 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001338
Tony Mak81e52422019-04-30 09:34:45 +01001339 const int best_score_index =
1340 std::max_element(scores.begin(), scores.end()) - scores.begin();
1341 const std::string top_collection =
1342 classification_feature_processor_->LabelToCollection(best_score_index);
1343
1344 // Sanity checks.
1345 if (top_collection == Collections::Phone()) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001346 const int digit_count = CountDigits(context, selection_indices);
1347 if (digit_count <
1348 model_->classification_options()->phone_min_num_digits() ||
1349 digit_count >
1350 model_->classification_options()->phone_max_num_digits()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001351 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001352 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001353 }
Tony Mak81e52422019-04-30 09:34:45 +01001354 } else if (top_collection == Collections::Address()) {
Lukas Zilka434442d2018-04-25 11:38:51 +02001355 if (selection_num_tokens <
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001356 model_->classification_options()->address_min_num_tokens()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001357 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001358 return true;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001359 }
Tony Mak81e52422019-04-30 09:34:45 +01001360 } else if (top_collection == Collections::Dictionary()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001361 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1362 dictionary_locales_,
Tony Makdf54e742019-03-26 14:04:00 +00001363 /*default_value=*/false)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001364 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001365 return true;
Tony Mak378c1f52019-03-04 15:58:11 +00001366 }
1367 }
Tony Mak81e52422019-04-30 09:34:45 +01001368
1369 *classification_results = {{top_collection, 1.0, scores[best_score_index]}};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001370 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001371}
1372
Tony Mak6c4cc672018-09-17 11:48:50 +01001373bool Annotator::RegexClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001374 const std::string& context, CodepointSpan selection_indices,
Tony Mak378c1f52019-03-04 15:58:11 +00001375 std::vector<ClassificationResult>* classification_result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001376 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001377 UTF8ToUnicodeText(context, /*do_copy=*/false)
1378 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001379 const UnicodeText selection_text_unicode(
1380 UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1381
1382 // Check whether any of the regular expressions match.
1383 for (const int pattern_id : classification_regex_patterns_) {
1384 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1385 const std::unique_ptr<UniLib::RegexMatcher> matcher =
1386 regex_pattern.pattern->Matcher(selection_text_unicode);
1387 int status = UniLib::RegexMatcher::kNoError;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001388 bool matches;
Tony Mak854015a2019-01-16 15:56:48 +00001389 if (regex_pattern.config->use_approximate_matching()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001390 matches = matcher->ApproximatelyMatches(&status);
1391 } else {
1392 matches = matcher->Matches(&status);
1393 }
1394 if (status != UniLib::RegexMatcher::kNoError) {
1395 return false;
1396 }
Tony Makdf54e742019-03-26 14:04:00 +00001397 if (matches && VerifyRegexMatchCandidate(
1398 context, regex_pattern.config->verification_options(),
1399 selection_text, matcher.get())) {
Tony Mak378c1f52019-03-04 15:58:11 +00001400 classification_result->push_back(
1401 {regex_pattern.config->collection_name()->str(),
1402 regex_pattern.config->target_classification_score(),
1403 regex_pattern.config->priority_score()});
Tony Makd9446602019-02-20 18:25:39 +00001404 if (!SerializedEntityDataFromRegexMatch(
1405 regex_pattern.config, matcher.get(),
Tony Mak378c1f52019-03-04 15:58:11 +00001406 &classification_result->back().serialized_entity_data)) {
Tony Makd9446602019-02-20 18:25:39 +00001407 TC3_LOG(ERROR) << "Could not get entity data.";
1408 return false;
1409 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001410 }
1411 }
1412
Tony Mak378c1f52019-03-04 15:58:11 +00001413 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001414}
1415
Tony Mak5dc5e112019-02-01 14:52:10 +00001416namespace {
1417std::string PickCollectionForDatetime(
1418 const DatetimeParseResult& datetime_parse_result) {
1419 switch (datetime_parse_result.granularity) {
1420 case GRANULARITY_HOUR:
1421 case GRANULARITY_MINUTE:
1422 case GRANULARITY_SECOND:
1423 return Collections::DateTime();
1424 default:
1425 return Collections::Date();
1426 }
1427}
Tony Mak83d2de62019-04-10 16:12:15 +01001428
1429std::string CreateDatetimeSerializedEntityData(
1430 const DatetimeParseResult& parse_result) {
1431 EntityDataT entity_data;
1432 entity_data.datetime.reset(new EntityData_::DatetimeT());
1433 entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
1434 entity_data.datetime->granularity =
1435 static_cast<EntityData_::Datetime_::Granularity>(
1436 parse_result.granularity);
1437
Tony Maka2a1ff42019-09-12 15:40:32 +01001438 for (const auto& c : parse_result.datetime_components) {
1439 EntityData_::Datetime_::DatetimeComponentT datetime_component;
1440 datetime_component.absolute_value = c.value;
1441 datetime_component.relative_count = c.relative_count;
1442 datetime_component.component_type =
1443 static_cast<EntityData_::Datetime_::DatetimeComponent_::ComponentType>(
1444 c.component_type);
1445 datetime_component.relation_type =
1446 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE;
1447 if (c.relative_qualifier !=
1448 DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
1449 datetime_component.relation_type =
1450 EntityData_::Datetime_::DatetimeComponent_::RelationType_RELATIVE;
1451 }
1452 entity_data.datetime->datetime_component.emplace_back(
1453 new EntityData_::Datetime_::DatetimeComponentT(datetime_component));
1454 }
Tony Mak83d2de62019-04-10 16:12:15 +01001455 flatbuffers::FlatBufferBuilder builder;
1456 FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
1457 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
1458 builder.GetSize());
1459}
Tony Mak5dc5e112019-02-01 14:52:10 +00001460} // namespace
1461
Tony Mak6c4cc672018-09-17 11:48:50 +01001462bool Annotator::DatetimeClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001463 const std::string& context, CodepointSpan selection_indices,
1464 const ClassificationOptions& options,
Tony Mak854015a2019-01-16 15:56:48 +00001465 std::vector<ClassificationResult>* classification_results) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001466 if (!datetime_parser_) {
1467 return false;
1468 }
1469
Lukas Zilkab23e2122018-02-09 10:25:19 +01001470 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001471 UTF8ToUnicodeText(context, /*do_copy=*/false)
1472 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001473
1474 std::vector<DatetimeParseResultSpan> datetime_spans;
1475 if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1476 options.reference_timezone, options.locales,
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001477 ModeFlag_CLASSIFICATION,
Tony Makd9446602019-02-20 18:25:39 +00001478 options.annotation_usecase,
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001479 /*anchor_start_end=*/true, &datetime_spans)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001480 TC3_LOG(ERROR) << "Error during parsing datetime.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001481 return false;
1482 }
1483 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
1484 // Only consider the result valid if the selection and extracted datetime
1485 // spans exactly match.
1486 if (std::make_pair(datetime_span.span.first + selection_indices.first,
1487 datetime_span.span.second + selection_indices.first) ==
1488 selection_indices) {
Tony Mak854015a2019-01-16 15:56:48 +00001489 for (const DatetimeParseResult& parse_result : datetime_span.data) {
1490 classification_results->emplace_back(
Tony Mak5dc5e112019-02-01 14:52:10 +00001491 PickCollectionForDatetime(parse_result),
1492 datetime_span.target_classification_score);
Tony Mak854015a2019-01-16 15:56:48 +00001493 classification_results->back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01001494 classification_results->back().serialized_entity_data =
1495 CreateDatetimeSerializedEntityData(parse_result);
Tony Mak378c1f52019-03-04 15:58:11 +00001496 classification_results->back().priority_score =
1497 datetime_span.priority_score;
Tony Mak854015a2019-01-16 15:56:48 +00001498 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001499 return true;
1500 }
1501 }
Tony Mak378c1f52019-03-04 15:58:11 +00001502 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001503}
1504
Tony Mak6c4cc672018-09-17 11:48:50 +01001505std::vector<ClassificationResult> Annotator::ClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001506 const std::string& context, CodepointSpan selection_indices,
1507 const ClassificationOptions& options) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01001508 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001509 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilka21d8c982018-01-24 11:11:20 +01001510 return {};
1511 }
1512
Lukas Zilkaba849e72018-03-08 14:48:21 +01001513 if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1514 return {};
1515 }
1516
Tony Makdf54e742019-03-26 14:04:00 +00001517 std::vector<Locale> detected_text_language_tags;
1518 if (!ParseLocales(options.detected_text_language_tags,
1519 &detected_text_language_tags)) {
1520 TC3_LOG(WARNING)
1521 << "Failed to parse the detected_text_language_tags in options: "
1522 << options.detected_text_language_tags;
1523 }
1524 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1525 model_triggering_locales_,
1526 /*default_value=*/true)) {
1527 return {};
1528 }
1529
Tony Mak968412a2019-11-13 15:39:57 +00001530 if (!IsValidSpanInput(UTF8ToUnicodeText(context, /*do_copy=*/false),
1531 selection_indices)) {
1532 TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
Tony Mak6c4cc672018-09-17 11:48:50 +01001533 << std::get<0>(selection_indices) << " "
1534 << std::get<1>(selection_indices);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001535 return {};
1536 }
1537
Tony Mak378c1f52019-03-04 15:58:11 +00001538 // We'll accumulate a list of candidates, and pick the best candidate in the
1539 // end.
1540 std::vector<AnnotatedSpan> candidates;
1541
Tony Mak6c4cc672018-09-17 11:48:50 +01001542 // Try the knowledge engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001543 // TODO(b/126579108): Propagate error status.
Tony Mak6c4cc672018-09-17 11:48:50 +01001544 ClassificationResult knowledge_result;
1545 if (knowledge_engine_ && knowledge_engine_->ClassifyText(
Tony Maka2a1ff42019-09-12 15:40:32 +01001546 context, selection_indices,
1547 options.annotation_usecase, &knowledge_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001548 candidates.push_back({selection_indices, {knowledge_result}});
1549 candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
Tony Mak854015a2019-01-16 15:56:48 +00001550 }
1551
Tony Maka2a1ff42019-09-12 15:40:32 +01001552 AddContactMetadataToKnowledgeClassificationResults(&candidates);
1553
Tony Mak854015a2019-01-16 15:56:48 +00001554 // Try the contact engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001555 // TODO(b/126579108): Propagate error status.
Tony Mak854015a2019-01-16 15:56:48 +00001556 ClassificationResult contact_result;
1557 if (contact_engine_ && contact_engine_->ClassifyText(
1558 context, selection_indices, &contact_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001559 candidates.push_back({selection_indices, {contact_result}});
Tony Mak6c4cc672018-09-17 11:48:50 +01001560 }
1561
Tony Makd9446602019-02-20 18:25:39 +00001562 // Try the installed app engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001563 // TODO(b/126579108): Propagate error status.
Tony Makd9446602019-02-20 18:25:39 +00001564 ClassificationResult installed_app_result;
1565 if (installed_app_engine_ &&
1566 installed_app_engine_->ClassifyText(context, selection_indices,
1567 &installed_app_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001568 candidates.push_back({selection_indices, {installed_app_result}});
Tony Makd9446602019-02-20 18:25:39 +00001569 }
1570
Lukas Zilkab23e2122018-02-09 10:25:19 +01001571 // Try the regular expression models.
Tony Mak378c1f52019-03-04 15:58:11 +00001572 std::vector<ClassificationResult> regex_results;
1573 if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1574 return {};
1575 }
1576 for (const ClassificationResult& result : regex_results) {
1577 candidates.push_back({selection_indices, {result}});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001578 }
1579
Lukas Zilkab23e2122018-02-09 10:25:19 +01001580 // Try the date model.
Tony Mak378c1f52019-03-04 15:58:11 +00001581 //
1582 // DatetimeClassifyText only returns the first result, which can however have
1583 // more interpretations. They are inserted in the candidates as a single
1584 // AnnotatedSpan, so that they get treated together by the conflict resolution
1585 // algorithm.
Tony Mak854015a2019-01-16 15:56:48 +00001586 std::vector<ClassificationResult> datetime_results;
Tony Mak378c1f52019-03-04 15:58:11 +00001587 if (!DatetimeClassifyText(context, selection_indices, options,
1588 &datetime_results)) {
1589 return {};
1590 }
1591 if (!datetime_results.empty()) {
1592 candidates.push_back({selection_indices, std::move(datetime_results)});
Tony Mak448b5862019-03-22 13:36:41 +00001593 candidates.back().source = AnnotatedSpan::Source::DATETIME;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001594 }
1595
Tony Mak378c1f52019-03-04 15:58:11 +00001596 // Try the number annotator.
1597 // TODO(b/126579108): Propagate error status.
1598 ClassificationResult number_annotator_result;
1599 if (number_annotator_ &&
1600 number_annotator_->ClassifyText(
1601 UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
1602 options.annotation_usecase, &number_annotator_result)) {
1603 candidates.push_back({selection_indices, {number_annotator_result}});
1604 }
Lukas Zilkaba849e72018-03-08 14:48:21 +01001605
Tony Makad2e22d2019-03-20 17:35:13 +00001606 // Try the duration annotator.
1607 ClassificationResult duration_annotator_result;
1608 if (duration_annotator_ &&
1609 duration_annotator_->ClassifyText(
1610 UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
1611 options.annotation_usecase, &duration_annotator_result)) {
1612 candidates.push_back({selection_indices, {duration_annotator_result}});
Tony Mak448b5862019-03-22 13:36:41 +00001613 candidates.back().source = AnnotatedSpan::Source::DURATION;
Tony Makad2e22d2019-03-20 17:35:13 +00001614 }
1615
Tony Mak378c1f52019-03-04 15:58:11 +00001616 // Try the ML model.
1617 //
1618 // The output of the model is considered as an exclusive 1-of-N choice. That's
1619 // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1620 // span for each candidate, like e.g. the regex model.
Lukas Zilkaba849e72018-03-08 14:48:21 +01001621 InterpreterManager interpreter_manager(selection_executor_.get(),
1622 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00001623 std::vector<ClassificationResult> model_results;
1624 std::vector<Token> tokens;
1625 if (!ModelClassifyText(
1626 context, /*cached_tokens=*/{}, detected_text_language_tags,
1627 selection_indices, &interpreter_manager,
1628 /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1629 return {};
1630 }
1631 if (!model_results.empty()) {
1632 candidates.push_back({selection_indices, std::move(model_results)});
1633 }
1634
1635 std::vector<int> candidate_indices;
1636 if (!ResolveConflicts(candidates, context, tokens,
1637 detected_text_language_tags, options.annotation_usecase,
1638 &interpreter_manager, &candidate_indices)) {
1639 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1640 return {};
1641 }
1642
1643 std::vector<ClassificationResult> results;
1644 for (const int i : candidate_indices) {
1645 for (const ClassificationResult& result : candidates[i].classification) {
1646 if (!FilteredForClassification(result)) {
1647 results.push_back(result);
1648 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001649 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001650 }
1651
Tony Mak378c1f52019-03-04 15:58:11 +00001652 // Sort results according to score.
1653 std::sort(results.begin(), results.end(),
1654 [](const ClassificationResult& a, const ClassificationResult& b) {
1655 return a.score > b.score;
1656 });
1657
1658 if (results.empty()) {
Tony Mak81e52422019-04-30 09:34:45 +01001659 results = {{Collections::Other(), 1.0}};
Tony Mak378c1f52019-03-04 15:58:11 +00001660 }
Tony Mak378c1f52019-03-04 15:58:11 +00001661 return results;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001662}
1663
Tony Mak378c1f52019-03-04 15:58:11 +00001664bool Annotator::ModelAnnotate(
1665 const std::string& context,
1666 const std::vector<Locale>& detected_text_language_tags,
1667 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1668 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001669 if (model_->triggering_options() == nullptr ||
1670 !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1671 return true;
1672 }
1673
Tony Makdf54e742019-03-26 14:04:00 +00001674 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1675 ml_model_triggering_locales_,
1676 /*default_value=*/true)) {
1677 return true;
1678 }
1679
Lukas Zilka21d8c982018-01-24 11:11:20 +01001680 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1681 /*do_copy=*/false);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001682 std::vector<UnicodeTextRange> lines;
1683 if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1684 lines.push_back({context_unicode.begin(), context_unicode.end()});
1685 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +01001686 lines = selection_feature_processor_->SplitContext(
1687 context_unicode, selection_feature_processor_->GetOptions()
1688 ->use_pipe_character_for_newline());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001689 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001690
Lukas Zilkaba849e72018-03-08 14:48:21 +01001691 const float min_annotate_confidence =
1692 (model_->triggering_options() != nullptr
1693 ? model_->triggering_options()->min_annotate_confidence()
1694 : 0.f);
1695
Lukas Zilkab23e2122018-02-09 10:25:19 +01001696 for (const UnicodeTextRange& line : lines) {
Tony Mak408c6b82019-03-08 17:57:27 +00001697 FeatureProcessor::EmbeddingCache embedding_cache;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001698 const std::string line_str =
1699 UnicodeText::UTF8Substring(line.first, line.second);
1700
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001701 *tokens = selection_feature_processor_->Tokenize(line_str);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001702 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001703 line_str, {0, std::distance(line.first, line.second)},
1704 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001705 tokens,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001706 /*click_pos=*/nullptr);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001707 const TokenSpan full_line_span = {0, tokens->size()};
Lukas Zilka21d8c982018-01-24 11:11:20 +01001708
Lukas Zilka434442d2018-04-25 11:38:51 +02001709 // TODO(zilka): Add support for greater granularity of this check.
1710 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1711 *tokens, full_line_span)) {
1712 continue;
1713 }
1714
Lukas Zilka21d8c982018-01-24 11:11:20 +01001715 std::unique_ptr<CachedFeatures> cached_features;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001716 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001717 *tokens, full_line_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001718 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1719 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01001720 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001721 selection_feature_processor_->EmbeddingSize() +
1722 selection_feature_processor_->DenseFeaturesCount(),
Lukas Zilka21d8c982018-01-24 11:11:20 +01001723 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001724 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001725 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001726 }
1727
1728 std::vector<TokenSpan> local_chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001729 if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001730 interpreter_manager->SelectionInterpreter(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01001731 *cached_features, &local_chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001732 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001733 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001734 }
1735
1736 const int offset = std::distance(context_unicode.begin(), line.first);
1737 for (const TokenSpan& chunk : local_chunks) {
1738 const CodepointSpan codepoint_span =
1739 selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001740 line_str, TokenSpanToCodepointSpan(*tokens, chunk));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001741
1742 // Skip empty spans.
1743 if (codepoint_span.first != codepoint_span.second) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001744 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +00001745 if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
1746 codepoint_span, interpreter_manager,
1747 &embedding_cache, &classification)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001748 TC3_LOG(ERROR) << "Could not classify text: "
1749 << (codepoint_span.first + offset) << " "
1750 << (codepoint_span.second + offset);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001751 return false;
1752 }
1753
1754 // Do not include the span if it's classified as "other".
1755 if (!classification.empty() && !ClassifiedAsOther(classification) &&
1756 classification[0].score >= min_annotate_confidence) {
1757 AnnotatedSpan result_span;
1758 result_span.span = {codepoint_span.first + offset,
1759 codepoint_span.second + offset};
1760 result_span.classification = std::move(classification);
1761 result->push_back(std::move(result_span));
1762 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001763 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001764 }
1765 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001766 return true;
1767}
1768
Tony Mak6c4cc672018-09-17 11:48:50 +01001769const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
Lukas Zilka434442d2018-04-25 11:38:51 +02001770 return selection_feature_processor_.get();
1771}
1772
Tony Mak6c4cc672018-09-17 11:48:50 +01001773const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
Lukas Zilka434442d2018-04-25 11:38:51 +02001774 const {
1775 return classification_feature_processor_.get();
Lukas Zilkab23e2122018-02-09 10:25:19 +01001776}
1777
Tony Mak6c4cc672018-09-17 11:48:50 +01001778const DatetimeParser* Annotator::DatetimeParserForTests() const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001779 return datetime_parser_.get();
1780}
1781
Tony Mak83d2de62019-04-10 16:12:15 +01001782void Annotator::RemoveNotEnabledEntityTypes(
1783 const EnabledEntityTypes& is_entity_type_enabled,
1784 std::vector<AnnotatedSpan>* annotated_spans) const {
1785 for (AnnotatedSpan& annotated_span : *annotated_spans) {
1786 std::vector<ClassificationResult>& classifications =
1787 annotated_span.classification;
1788 classifications.erase(
1789 std::remove_if(classifications.begin(), classifications.end(),
1790 [&is_entity_type_enabled](
1791 const ClassificationResult& classification_result) {
1792 return !is_entity_type_enabled(
1793 classification_result.collection);
1794 }),
1795 classifications.end());
1796 }
1797 annotated_spans->erase(
1798 std::remove_if(annotated_spans->begin(), annotated_spans->end(),
1799 [](const AnnotatedSpan& annotated_span) {
1800 return annotated_span.classification.empty();
1801 }),
1802 annotated_spans->end());
1803}
1804
Tony Maka2a1ff42019-09-12 15:40:32 +01001805void Annotator::AddContactMetadataToKnowledgeClassificationResults(
1806 std::vector<AnnotatedSpan>* candidates) const {
1807 if (candidates == nullptr || contact_engine_ == nullptr) {
1808 return;
1809 }
1810 for (auto& candidate : *candidates) {
1811 for (auto& classification_result : candidate.classification) {
1812 contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
1813 &classification_result);
1814 }
1815 }
1816}
1817
Tony Mak6c4cc672018-09-17 11:48:50 +01001818std::vector<AnnotatedSpan> Annotator::Annotate(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001819 const std::string& context, const AnnotationOptions& options) const {
1820 std::vector<AnnotatedSpan> candidates;
1821
Lukas Zilkaba849e72018-03-08 14:48:21 +01001822 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
1823 return {};
1824 }
1825
Tony Mak854015a2019-01-16 15:56:48 +00001826 const UnicodeText context_unicode =
1827 UTF8ToUnicodeText(context, /*do_copy=*/false);
1828 if (!context_unicode.is_valid()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001829 return {};
1830 }
1831
Tony Mak378c1f52019-03-04 15:58:11 +00001832 std::vector<Locale> detected_text_language_tags;
1833 if (!ParseLocales(options.detected_text_language_tags,
1834 &detected_text_language_tags)) {
1835 TC3_LOG(WARNING)
Tony Makdf54e742019-03-26 14:04:00 +00001836 << "Failed to parse the detected_text_language_tags in options: "
Tony Mak378c1f52019-03-04 15:58:11 +00001837 << options.detected_text_language_tags;
1838 }
Tony Makdf54e742019-03-26 14:04:00 +00001839 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1840 model_triggering_locales_,
1841 /*default_value=*/true)) {
1842 return {};
1843 }
1844
1845 InterpreterManager interpreter_manager(selection_executor_.get(),
1846 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00001847
Lukas Zilkab23e2122018-02-09 10:25:19 +01001848 // Annotate with the selection model.
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001849 std::vector<Token> tokens;
Tony Mak378c1f52019-03-04 15:58:11 +00001850 if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
1851 &tokens, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001852 TC3_LOG(ERROR) << "Couldn't run ModelAnnotate.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001853 return {};
1854 }
1855
1856 // Annotate with the regular expression models.
1857 if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
Tony Mak83d2de62019-04-10 16:12:15 +01001858 annotation_regex_patterns_, &candidates,
1859 options.is_serialized_entity_data_enabled)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001860 TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001861 return {};
1862 }
1863
1864 // Annotate with the datetime model.
Tony Mak83d2de62019-04-10 16:12:15 +01001865 const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
1866 if ((is_entity_type_enabled(Collections::Date()) ||
1867 is_entity_type_enabled(Collections::DateTime())) &&
1868 !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
Lukas Zilkab23e2122018-02-09 10:25:19 +01001869 options.reference_time_ms_utc, options.reference_timezone,
Tony Makd9446602019-02-20 18:25:39 +00001870 options.locales, ModeFlag_ANNOTATION,
Tony Mak83d2de62019-04-10 16:12:15 +01001871 options.annotation_usecase,
1872 options.is_serialized_entity_data_enabled, &candidates)) {
Tony Maka2a1ff42019-09-12 15:40:32 +01001873 TC3_LOG(ERROR) << "Couldn't run DatetimeChunk.";
Tony Mak6c4cc672018-09-17 11:48:50 +01001874 return {};
1875 }
1876
Tony Maka2a1ff42019-09-12 15:40:32 +01001877 // Annotate with the knowledge engine into a temporary vector.
1878 std::vector<AnnotatedSpan> knowledge_candidates;
1879 if (knowledge_engine_ &&
1880 !knowledge_engine_->Chunk(context, options.annotation_usecase,
1881 &knowledge_candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001882 TC3_LOG(ERROR) << "Couldn't run knowledge engine Chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001883 return {};
1884 }
1885
Tony Maka2a1ff42019-09-12 15:40:32 +01001886 AddContactMetadataToKnowledgeClassificationResults(&knowledge_candidates);
1887
1888 // Move the knowledge candidates to the full candidate list, and erase
1889 // knowledge_candidates.
1890 candidates.insert(candidates.end(),
1891 std::make_move_iterator(knowledge_candidates.begin()),
1892 std::make_move_iterator(knowledge_candidates.end()));
1893 knowledge_candidates.clear();
1894
Tony Mak854015a2019-01-16 15:56:48 +00001895 // Annotate with the contact engine.
1896 if (contact_engine_ &&
1897 !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
1898 TC3_LOG(ERROR) << "Couldn't run contact engine Chunk.";
1899 return {};
1900 }
1901
Tony Makd9446602019-02-20 18:25:39 +00001902 // Annotate with the installed app engine.
1903 if (installed_app_engine_ &&
1904 !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
1905 TC3_LOG(ERROR) << "Couldn't run installed app engine Chunk.";
1906 return {};
1907 }
1908
Tony Mak378c1f52019-03-04 15:58:11 +00001909 // Annotate with the number annotator.
1910 if (number_annotator_ != nullptr &&
1911 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
1912 &candidates)) {
Tony Makad2e22d2019-03-20 17:35:13 +00001913 TC3_LOG(ERROR) << "Couldn't run number annotator FindAll.";
1914 return {};
1915 }
1916
1917 // Annotate with the duration annotator.
Tony Mak83d2de62019-04-10 16:12:15 +01001918 if (is_entity_type_enabled(Collections::Duration()) &&
1919 duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +00001920 !duration_annotator_->FindAll(context_unicode, tokens,
1921 options.annotation_usecase, &candidates)) {
Tony Makad2e22d2019-03-20 17:35:13 +00001922 TC3_LOG(ERROR) << "Couldn't run duration annotator FindAll.";
Tony Mak378c1f52019-03-04 15:58:11 +00001923 return {};
1924 }
1925
Lukas Zilkab23e2122018-02-09 10:25:19 +01001926 // Sort candidates according to their position in the input, so that the next
1927 // code can assume that any connected component of overlapping spans forms a
1928 // contiguous block.
1929 std::sort(candidates.begin(), candidates.end(),
1930 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
1931 return a.span.first < b.span.first;
1932 });
1933
1934 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +00001935 if (!ResolveConflicts(candidates, context, tokens,
1936 detected_text_language_tags, options.annotation_usecase,
1937 &interpreter_manager, &candidate_indices)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001938 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001939 return {};
1940 }
1941
Lukas Zilkab23e2122018-02-09 10:25:19 +01001942 std::vector<AnnotatedSpan> result;
1943 result.reserve(candidate_indices.size());
Tony Mak378c1f52019-03-04 15:58:11 +00001944 AnnotatedSpan aggregated_span;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001945 for (const int i : candidate_indices) {
Tony Mak378c1f52019-03-04 15:58:11 +00001946 if (candidates[i].span != aggregated_span.span) {
1947 if (!aggregated_span.classification.empty()) {
1948 result.push_back(std::move(aggregated_span));
1949 }
1950 aggregated_span =
1951 AnnotatedSpan(candidates[i].span, /*arg_classification=*/{});
Lukas Zilkab23e2122018-02-09 10:25:19 +01001952 }
Tony Mak378c1f52019-03-04 15:58:11 +00001953 if (candidates[i].classification.empty() ||
1954 ClassifiedAsOther(candidates[i].classification) ||
1955 FilteredForAnnotation(candidates[i])) {
1956 continue;
1957 }
1958 for (ClassificationResult& classification : candidates[i].classification) {
1959 aggregated_span.classification.push_back(std::move(classification));
1960 }
1961 }
1962 if (!aggregated_span.classification.empty()) {
1963 result.push_back(std::move(aggregated_span));
1964 }
1965
Tony Mak83d2de62019-04-10 16:12:15 +01001966 // We generate all candidates and remove them later (with the exception of
1967 // date/time/duration entities) because there are complex interdependencies
1968 // between the entity types. E.g., the TLD of an email can be interpreted as a
1969 // URL, but most likely a user of the API does not want such annotations if
1970 // "url" is enabled and "email" is not.
1971 RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
1972
Tony Mak378c1f52019-03-04 15:58:11 +00001973 for (AnnotatedSpan& annotated_span : result) {
1974 SortClassificationResults(&annotated_span.classification);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001975 }
1976
Lukas Zilka21d8c982018-01-24 11:11:20 +01001977 return result;
1978}
1979
Tony Mak854015a2019-01-16 15:56:48 +00001980CodepointSpan Annotator::ComputeSelectionBoundaries(
1981 const UniLib::RegexMatcher* match,
1982 const RegexModel_::Pattern* config) const {
1983 if (config->capturing_group() == nullptr) {
1984 // Use first capturing group to specify the selection.
1985 int status = UniLib::RegexMatcher::kNoError;
1986 const CodepointSpan result = {match->Start(1, &status),
1987 match->End(1, &status)};
1988 if (status != UniLib::RegexMatcher::kNoError) {
1989 return {kInvalidIndex, kInvalidIndex};
1990 }
1991 return result;
1992 }
1993
1994 CodepointSpan result = {kInvalidIndex, kInvalidIndex};
1995 const int num_groups = config->capturing_group()->size();
1996 for (int i = 0; i < num_groups; i++) {
1997 if (!config->capturing_group()->Get(i)->extend_selection()) {
1998 continue;
1999 }
2000
2001 int status = UniLib::RegexMatcher::kNoError;
2002 // Check match and adjust bounds.
2003 const int group_start = match->Start(i, &status);
2004 const int group_end = match->End(i, &status);
2005 if (status != UniLib::RegexMatcher::kNoError) {
2006 return {kInvalidIndex, kInvalidIndex};
2007 }
2008 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2009 continue;
2010 }
2011 if (result.first == kInvalidIndex) {
2012 result = {group_start, group_end};
2013 } else {
2014 result.first = std::min(result.first, group_start);
2015 result.second = std::max(result.second, group_end);
2016 }
2017 }
2018 return result;
2019}
2020
Tony Makd9446602019-02-20 18:25:39 +00002021bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
2022 if (pattern->serialized_entity_data() != nullptr) {
2023 return true;
2024 }
2025 if (pattern->capturing_group() != nullptr) {
2026 for (const RegexModel_::Pattern_::CapturingGroup* group :
2027 *pattern->capturing_group()) {
2028 if (group->entity_field_path() != nullptr) {
2029 return true;
2030 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002031 if (group->serialized_entity_data() != nullptr) {
2032 return true;
2033 }
Tony Makd9446602019-02-20 18:25:39 +00002034 }
2035 }
2036 return false;
2037}
2038
2039bool Annotator::SerializedEntityDataFromRegexMatch(
2040 const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2041 std::string* serialized_entity_data) const {
2042 if (!HasEntityData(pattern)) {
2043 serialized_entity_data->clear();
2044 return true;
2045 }
2046 TC3_CHECK(entity_data_builder_ != nullptr);
2047
2048 std::unique_ptr<ReflectiveFlatbuffer> entity_data =
2049 entity_data_builder_->NewRoot();
2050
2051 TC3_CHECK(entity_data != nullptr);
2052
2053 // Set static entity data.
2054 if (pattern->serialized_entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002055 entity_data->MergeFromSerializedFlatbuffer(
2056 StringPiece(pattern->serialized_entity_data()->c_str(),
2057 pattern->serialized_entity_data()->size()));
2058 }
2059
2060 // Add entity data from rule capturing groups.
2061 if (pattern->capturing_group() != nullptr) {
2062 const int num_groups = pattern->capturing_group()->size();
2063 for (int i = 0; i < num_groups; i++) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002064 const RegexModel_::Pattern_::CapturingGroup* group =
2065 pattern->capturing_group()->Get(i);
2066
2067 // Check whether the group matched.
2068 Optional<std::string> group_match_text =
2069 GetCapturingGroupText(matcher, /*group_id=*/i);
2070 if (!group_match_text.has_value()) {
Tony Makd9446602019-02-20 18:25:39 +00002071 continue;
2072 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002073
2074 // Set static entity data from capturing group match.
2075 if (group->serialized_entity_data() != nullptr) {
2076 entity_data->MergeFromSerializedFlatbuffer(
2077 StringPiece(group->serialized_entity_data()->c_str(),
2078 group->serialized_entity_data()->size()));
2079 }
2080
2081 // Set entity field from capturing group text.
2082 if (group->entity_field_path() != nullptr) {
Tony Mak8cd7ba62019-10-15 15:29:22 +01002083 UnicodeText normalized_group_match_text =
2084 UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2085
2086 // Apply normalization if specified.
2087 if (group->normalization_options() != nullptr) {
2088 normalized_group_match_text =
2089 NormalizeText(unilib_, group->normalization_options(),
2090 normalized_group_match_text);
2091 }
2092
2093 if (!entity_data->ParseAndSet(
2094 group->entity_field_path(),
2095 normalized_group_match_text.ToUTF8String())) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002096 TC3_LOG(ERROR)
2097 << "Could not set entity data from rule capturing group.";
2098 return false;
2099 }
Tony Makd9446602019-02-20 18:25:39 +00002100 }
2101 }
2102 }
2103
2104 *serialized_entity_data = entity_data->Serialize();
2105 return true;
2106}
2107
Tony Mak6c4cc672018-09-17 11:48:50 +01002108bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2109 const std::vector<int>& rules,
Tony Mak83d2de62019-04-10 16:12:15 +01002110 std::vector<AnnotatedSpan>* result,
2111 bool is_serialized_entity_data_enabled) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002112 for (int pattern_id : rules) {
2113 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2114 const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2115 if (!matcher) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002116 TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2117 << pattern_id;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002118 return false;
2119 }
2120
2121 int status = UniLib::RegexMatcher::kNoError;
2122 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
Tony Mak854015a2019-01-16 15:56:48 +00002123 if (regex_pattern.config->verification_options()) {
Tony Makdf54e742019-03-26 14:04:00 +00002124 if (!VerifyRegexMatchCandidate(
2125 context_unicode.ToUTF8String(),
2126 regex_pattern.config->verification_options(),
2127 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002128 continue;
2129 }
2130 }
Tony Makd9446602019-02-20 18:25:39 +00002131
2132 std::string serialized_entity_data;
Tony Mak83d2de62019-04-10 16:12:15 +01002133 if (is_serialized_entity_data_enabled) {
2134 if (!SerializedEntityDataFromRegexMatch(
2135 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2136 TC3_LOG(ERROR) << "Could not get entity data.";
2137 return false;
2138 }
Tony Makd9446602019-02-20 18:25:39 +00002139 }
2140
Lukas Zilkab23e2122018-02-09 10:25:19 +01002141 result->emplace_back();
Tony Mak854015a2019-01-16 15:56:48 +00002142
Lukas Zilkab23e2122018-02-09 10:25:19 +01002143 // Selection/annotation regular expressions need to specify a capturing
2144 // group specifying the selection.
Tony Mak854015a2019-01-16 15:56:48 +00002145 result->back().span =
2146 ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2147
Lukas Zilkab23e2122018-02-09 10:25:19 +01002148 result->back().classification = {
Tony Mak854015a2019-01-16 15:56:48 +00002149 {regex_pattern.config->collection_name()->str(),
2150 regex_pattern.config->target_classification_score(),
2151 regex_pattern.config->priority_score()}};
Tony Makd9446602019-02-20 18:25:39 +00002152
2153 result->back().classification[0].serialized_entity_data =
2154 serialized_entity_data;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002155 }
2156 }
2157 return true;
2158}
2159
Tony Mak6c4cc672018-09-17 11:48:50 +01002160bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2161 tflite::Interpreter* selection_interpreter,
2162 const CachedFeatures& cached_features,
2163 std::vector<TokenSpan>* chunks) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01002164 const int max_selection_span =
2165 selection_feature_processor_->GetOptions()->max_selection_span();
Lukas Zilka21d8c982018-01-24 11:11:20 +01002166 // The inference span is the span of interest expanded to include
2167 // max_selection_span tokens on either side, which is how far a selection can
2168 // stretch from the click.
2169 const TokenSpan inference_span = IntersectTokenSpans(
2170 ExpandTokenSpan(span_of_interest,
2171 /*num_tokens_left=*/max_selection_span,
2172 /*num_tokens_right=*/max_selection_span),
2173 {0, num_tokens});
2174
2175 std::vector<ScoredChunk> scored_chunks;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002176 if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2177 selection_feature_processor_->GetOptions()
2178 ->bounds_sensitive_features()
2179 ->enabled()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002180 if (!ModelBoundsSensitiveScoreChunks(
2181 num_tokens, span_of_interest, inference_span, cached_features,
2182 selection_interpreter, &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002183 return false;
2184 }
2185 } else {
2186 if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002187 cached_features, selection_interpreter,
2188 &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002189 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002190 }
2191 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002192 std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
2193 [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2194 return lhs.score < rhs.score;
2195 });
Lukas Zilka21d8c982018-01-24 11:11:20 +01002196
2197 // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2198 // them greedily as long as they do not overlap with any previously picked
2199 // chunks.
2200 std::vector<bool> token_used(TokenSpanSize(inference_span));
2201 chunks->clear();
2202 for (const ScoredChunk& scored_chunk : scored_chunks) {
2203 bool feasible = true;
2204 for (int i = scored_chunk.token_span.first;
2205 i < scored_chunk.token_span.second; ++i) {
2206 if (token_used[i - inference_span.first]) {
2207 feasible = false;
2208 break;
2209 }
2210 }
2211
2212 if (!feasible) {
2213 continue;
2214 }
2215
2216 for (int i = scored_chunk.token_span.first;
2217 i < scored_chunk.token_span.second; ++i) {
2218 token_used[i - inference_span.first] = true;
2219 }
2220
2221 chunks->push_back(scored_chunk.token_span);
2222 }
2223
2224 std::sort(chunks->begin(), chunks->end());
2225
2226 return true;
2227}
2228
namespace {
// Updates the value at the given key in the map to maximum of the current value
// and the given value, or simply inserts the value if the key is not yet there.
//
// A single insert() both probes for the key and adds the entry when it is
// missing, so the map is searched only once and the mapped value is never
// default-constructed first.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto insertion = map->insert(typename Map::value_type(key, value));
  if (!insertion.second) {
    // Key was already present: keep the larger of the stored and new value.
    auto& stored = insertion.first->second;
    stored = std::max(stored, value);
  }
}
}  // namespace
2243
Tony Mak6c4cc672018-09-17 11:48:50 +01002244bool Annotator::ModelClickContextScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002245 int num_tokens, const TokenSpan& span_of_interest,
2246 const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002247 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002248 std::vector<ScoredChunk>* scored_chunks) const {
2249 const int max_batch_size = model_->selection_options()->batch_size();
2250
2251 std::vector<float> all_features;
2252 std::map<TokenSpan, float> chunk_scores;
2253 for (int batch_start = span_of_interest.first;
2254 batch_start < span_of_interest.second; batch_start += max_batch_size) {
2255 const int batch_end =
2256 std::min(batch_start + max_batch_size, span_of_interest.second);
2257
2258 // Prepare features for the whole batch.
2259 all_features.clear();
2260 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2261 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2262 cached_features.AppendClickContextFeaturesForClick(click_pos,
2263 &all_features);
2264 }
2265
2266 // Run batched inference.
2267 const int batch_size = batch_end - batch_start;
2268 const int features_size = cached_features.OutputFeaturesSize();
2269 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01002270 TensorView<float>(all_features.data(), {batch_size, features_size}),
2271 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002272 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002273 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002274 return false;
2275 }
2276 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2277 logits.dim(1) !=
2278 selection_feature_processor_->GetSelectionLabelCount()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002279 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002280 return false;
2281 }
2282
2283 // Save results.
2284 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2285 const std::vector<float> scores = ComputeSoftmax(
2286 logits.data() + logits.dim(1) * (click_pos - batch_start),
2287 logits.dim(1));
2288 for (int j = 0;
2289 j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
2290 TokenSpan relative_token_span;
2291 if (!selection_feature_processor_->LabelToTokenSpan(
2292 j, &relative_token_span)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002293 TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002294 return false;
2295 }
2296 const TokenSpan candidate_span = ExpandTokenSpan(
2297 SingleTokenSpan(click_pos), relative_token_span.first,
2298 relative_token_span.second);
2299 if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
2300 UpdateMax(&chunk_scores, candidate_span, scores[j]);
2301 }
2302 }
2303 }
2304 }
2305
2306 scored_chunks->clear();
2307 scored_chunks->reserve(chunk_scores.size());
2308 for (const auto& entry : chunk_scores) {
2309 scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
2310 }
2311
2312 return true;
2313}
2314
Tony Mak6c4cc672018-09-17 11:48:50 +01002315bool Annotator::ModelBoundsSensitiveScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002316 int num_tokens, const TokenSpan& span_of_interest,
2317 const TokenSpan& inference_span, const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002318 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002319 std::vector<ScoredChunk>* scored_chunks) const {
2320 const int max_selection_span =
2321 selection_feature_processor_->GetOptions()->max_selection_span();
2322 const int max_chunk_length = selection_feature_processor_->GetOptions()
2323 ->selection_reduced_output_space()
2324 ? max_selection_span + 1
2325 : 2 * max_selection_span + 1;
Lukas Zilkaba849e72018-03-08 14:48:21 +01002326 const bool score_single_token_spans_as_zero =
2327 selection_feature_processor_->GetOptions()
2328 ->bounds_sensitive_features()
2329 ->score_single_token_spans_as_zero();
2330
2331 scored_chunks->clear();
2332 if (score_single_token_spans_as_zero) {
2333 scored_chunks->reserve(TokenSpanSize(span_of_interest));
2334 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002335
2336 // Prepare all chunk candidates into one batch:
2337 // - Are contained in the inference span
2338 // - Have a non-empty intersection with the span of interest
2339 // - Are at least one token long
2340 // - Are not longer than the maximum chunk length
2341 std::vector<TokenSpan> candidate_spans;
2342 for (int start = inference_span.first; start < span_of_interest.second;
2343 ++start) {
2344 const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
2345 for (int end = leftmost_end_index;
2346 end <= inference_span.second && end - start <= max_chunk_length;
2347 ++end) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002348 const TokenSpan candidate_span = {start, end};
2349 if (score_single_token_spans_as_zero &&
2350 TokenSpanSize(candidate_span) == 1) {
2351 // Do not include the single token span in the batch, add a zero score
2352 // for it directly to the output.
2353 scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
2354 } else {
2355 candidate_spans.push_back(candidate_span);
2356 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002357 }
2358 }
2359
2360 const int max_batch_size = model_->selection_options()->batch_size();
2361
2362 std::vector<float> all_features;
Lukas Zilkaba849e72018-03-08 14:48:21 +01002363 scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
Lukas Zilkab23e2122018-02-09 10:25:19 +01002364 for (int batch_start = 0; batch_start < candidate_spans.size();
2365 batch_start += max_batch_size) {
2366 const int batch_end = std::min(batch_start + max_batch_size,
2367 static_cast<int>(candidate_spans.size()));
2368
2369 // Prepare features for the whole batch.
2370 all_features.clear();
2371 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2372 for (int i = batch_start; i < batch_end; ++i) {
2373 cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
2374 &all_features);
2375 }
2376
2377 // Run batched inference.
2378 const int batch_size = batch_end - batch_start;
2379 const int features_size = cached_features.OutputFeaturesSize();
2380 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01002381 TensorView<float>(all_features.data(), {batch_size, features_size}),
2382 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002383 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002384 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002385 return false;
2386 }
2387 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2388 logits.dim(1) != 1) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002389 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002390 return false;
2391 }
2392
2393 // Save results.
2394 for (int i = batch_start; i < batch_end; ++i) {
2395 scored_chunks->push_back(
2396 ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
2397 }
2398 }
2399
2400 return true;
2401}
2402
Tony Mak6c4cc672018-09-17 11:48:50 +01002403bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
2404 int64 reference_time_ms_utc,
2405 const std::string& reference_timezone,
2406 const std::string& locales, ModeFlag mode,
Tony Makd9446602019-02-20 18:25:39 +00002407 AnnotationUsecase annotation_usecase,
Tony Mak83d2de62019-04-10 16:12:15 +01002408 bool is_serialized_entity_data_enabled,
Tony Mak6c4cc672018-09-17 11:48:50 +01002409 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002410 if (!datetime_parser_) {
Lukas Zilka434442d2018-04-25 11:38:51 +02002411 return true;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002412 }
2413
Lukas Zilkab23e2122018-02-09 10:25:19 +01002414 std::vector<DatetimeParseResultSpan> datetime_spans;
2415 if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002416 reference_timezone, locales, mode,
Tony Makd9446602019-02-20 18:25:39 +00002417 annotation_usecase,
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002418 /*anchor_start_end=*/false, &datetime_spans)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002419 return false;
2420 }
2421 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
Tony Mak378c1f52019-03-04 15:58:11 +00002422 AnnotatedSpan annotated_span;
2423 annotated_span.span = datetime_span.span;
Tony Mak854015a2019-01-16 15:56:48 +00002424 for (const DatetimeParseResult& parse_result : datetime_span.data) {
Tony Mak378c1f52019-03-04 15:58:11 +00002425 annotated_span.classification.emplace_back(
2426 PickCollectionForDatetime(parse_result),
2427 datetime_span.target_classification_score,
2428 datetime_span.priority_score);
2429 annotated_span.classification.back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01002430 if (is_serialized_entity_data_enabled) {
2431 annotated_span.classification.back().serialized_entity_data =
2432 CreateDatetimeSerializedEntityData(parse_result);
2433 }
Tony Mak854015a2019-01-16 15:56:48 +00002434 }
Tony Mak448b5862019-03-22 13:36:41 +00002435 annotated_span.source = AnnotatedSpan::Source::DATETIME;
Tony Mak378c1f52019-03-04 15:58:11 +00002436 result->push_back(std::move(annotated_span));
Lukas Zilkab23e2122018-02-09 10:25:19 +01002437 }
2438 return true;
2439}
2440
Tony Mak378c1f52019-03-04 15:58:11 +00002441const Model* Annotator::model() const { return model_; }
Tony Makd9446602019-02-20 18:25:39 +00002442const reflection::Schema* Annotator::entity_data_schema() const {
2443 return entity_data_schema_;
2444}
Tony Mak854015a2019-01-16 15:56:48 +00002445
Lukas Zilka21d8c982018-01-24 11:11:20 +01002446const Model* ViewModel(const void* buffer, int size) {
2447 if (!buffer) {
2448 return nullptr;
2449 }
2450
2451 return LoadAndVerifyModel(buffer, size);
2452}
2453
Tony Makd9446602019-02-20 18:25:39 +00002454bool Annotator::LookUpKnowledgeEntity(
2455 const std::string& id, std::string* serialized_knowledge_result) const {
2456 return knowledge_engine_ &&
2457 knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
2458}
2459
Tony Mak6c4cc672018-09-17 11:48:50 +01002460} // namespace libtextclassifier3