blob: c4ef4b9bd07ace0056edc536d4c4f0322e5059a4 [file] [log] [blame]
Lukas Zilka21d8c982018-01-24 11:11:20 +01001/*
Tony Mak6c4cc672018-09-17 11:48:50 +01002 * Copyright (C) 2018 The Android Open Source Project
Lukas Zilka21d8c982018-01-24 11:11:20 +01003 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Tony Mak6c4cc672018-09-17 11:48:50 +010017#include "annotator/annotator.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010018
19#include <algorithm>
Lukas Zilka21d8c982018-01-24 11:11:20 +010020#include <cmath>
Tony Mak21460022020-03-12 18:29:35 +000021#include <cstddef>
Lukas Zilka21d8c982018-01-24 11:11:20 +010022#include <iterator>
23#include <numeric>
Tony Mak63959242020-02-07 18:31:16 +000024#include <string>
Tony Mak448b5862019-03-22 13:36:41 +000025#include <unordered_map>
Tony Mak63959242020-02-07 18:31:16 +000026#include <vector>
Lukas Zilka21d8c982018-01-24 11:11:20 +010027
Tony Mak854015a2019-01-16 15:56:48 +000028#include "annotator/collections.h"
Tony Mak83d2de62019-04-10 16:12:15 +010029#include "annotator/model_generated.h"
30#include "annotator/types.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010031#include "utils/base/logging.h"
Tony Makff31efb2020-03-31 11:13:06 +010032#include "utils/base/status.h"
33#include "utils/base/statusor.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010034#include "utils/checksum.h"
Tony Mak63959242020-02-07 18:31:16 +000035#include "utils/i18n/locale.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010036#include "utils/math/softmax.h"
Tony Mak8cd7ba62019-10-15 15:29:22 +010037#include "utils/normalization.h"
Tony Maka2a1ff42019-09-12 15:40:32 +010038#include "utils/optional.h"
Tony Makd9446602019-02-20 18:25:39 +000039#include "utils/regex-match.h"
Tony Mak63959242020-02-07 18:31:16 +000040#include "utils/strings/numbers.h"
41#include "utils/strings/split.h"
Tony Mak6c4cc672018-09-17 11:48:50 +010042#include "utils/utf8/unicodetext.h"
Tony Mak21460022020-03-12 18:29:35 +000043#include "utils/utf8/unilib-common.h"
Tony Mak378c1f52019-03-04 15:58:11 +000044#include "utils/zlib/zlib_regex.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010045
Tony Mak6c4cc672018-09-17 11:48:50 +010046namespace libtextclassifier3 {
Tony Mak448b5862019-03-22 13:36:41 +000047
// A set of ints whose ordering is defined by a comparator supplied when the
// set is constructed (e.g. a lambda), rather than by a fixed comparator type.
using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
49
Tony Mak6c4cc672018-09-17 11:48:50 +010050const std::string& Annotator::kPhoneCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010051 *[]() { return new std::string("phone"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010052const std::string& Annotator::kAddressCollection =
Lukas Zilkae7962cc2018-03-28 18:09:48 +020053 *[]() { return new std::string("address"); }();
Tony Mak6c4cc672018-09-17 11:48:50 +010054const std::string& Annotator::kDateCollection =
Lukas Zilkab23e2122018-02-09 10:25:19 +010055 *[]() { return new std::string("date"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000056const std::string& Annotator::kUrlCollection =
57 *[]() { return new std::string("url"); }();
Tony Mak296b7b62018-12-04 18:09:15 +000058const std::string& Annotator::kEmailCollection =
59 *[]() { return new std::string("email"); }();
Lukas Zilkab23e2122018-02-09 10:25:19 +010060
Lukas Zilka21d8c982018-01-24 11:11:20 +010061namespace {
62const Model* LoadAndVerifyModel(const void* addr, int size) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010063 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
Tony Mak51a9e542018-11-02 13:36:22 +000064 if (VerifyModelBuffer(verifier)) {
65 return GetModel(addr);
Lukas Zilka21d8c982018-01-24 11:11:20 +010066 } else {
67 return nullptr;
68 }
69}
Tony Mak6c4cc672018-09-17 11:48:50 +010070
Tony Mak76d80962020-01-08 17:30:51 +000071const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
72 int size) {
73 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
74 if (VerifyPersonNameModelBuffer(verifier)) {
75 return GetPersonNameModel(addr);
76 } else {
77 return nullptr;
78 }
79}
80
Tony Mak6c4cc672018-09-17 11:48:50 +010081// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
82// create a new instance, assign ownership to owned_lib, and return it.
83const UniLib* MaybeCreateUnilib(const UniLib* lib,
84 std::unique_ptr<UniLib>* owned_lib) {
85 if (lib) {
86 return lib;
87 } else {
88 owned_lib->reset(new UniLib);
89 return owned_lib->get();
90 }
91}
92
93// As above, but for CalendarLib.
94const CalendarLib* MaybeCreateCalendarlib(
95 const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
96 if (lib) {
97 return lib;
98 } else {
99 owned_lib->reset(new CalendarLib);
100 return owned_lib->get();
101 }
102}
103
Tony Mak968412a2019-11-13 15:39:57 +0000104// Returns whether the provided input is valid:
105// * Valid utf8 text.
106// * Sane span indices.
107bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan span) {
108 if (!context.is_valid()) {
109 return false;
110 }
111 return (span.first >= 0 && span.first < span.second &&
112 span.second <= context.size_codepoints());
113}
114
Tony Mak63959242020-02-07 18:31:16 +0000115std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
116 const flatbuffers::Vector<int32_t>* ints) {
117 if (ints == nullptr) {
118 return {};
119 }
120 std::unordered_set<char32> ints_set;
121 for (auto value : *ints) {
122 ints_set.insert(static_cast<char32>(value));
123 }
124 return ints_set;
125}
126
Tony Mak21460022020-03-12 18:29:35 +0000127DateAnnotationOptions ToDateAnnotationOptions(
128 const GrammarDatetimeModel_::AnnotationOptions* fb_annotation_options,
129 const std::string& reference_timezone, const int64 reference_time_ms_utc) {
130 DateAnnotationOptions result_annotation_options;
131 result_annotation_options.base_timestamp_millis = reference_time_ms_utc;
132 result_annotation_options.reference_timezone = reference_timezone;
133 if (fb_annotation_options != nullptr) {
134 result_annotation_options.enable_special_day_offset =
135 fb_annotation_options->enable_special_day_offset();
136 result_annotation_options.merge_adjacent_components =
137 fb_annotation_options->merge_adjacent_components();
138 result_annotation_options.enable_date_range =
139 fb_annotation_options->enable_date_range();
140 result_annotation_options.include_preposition =
141 fb_annotation_options->include_preposition();
Tony Mak21460022020-03-12 18:29:35 +0000142 if (fb_annotation_options->extra_requested_dates() != nullptr) {
143 for (const auto& extra_requested_date :
144 *fb_annotation_options->extra_requested_dates()) {
145 result_annotation_options.extra_requested_dates.push_back(
146 extra_requested_date->str());
147 }
148 }
Tony Makd99d58c2020-03-19 21:52:02 +0000149 if (fb_annotation_options->ignored_spans() != nullptr) {
150 for (const auto& ignored_span : *fb_annotation_options->ignored_spans()) {
151 result_annotation_options.ignored_spans.push_back(ignored_span->str());
Tony Mak0b8b3322020-03-17 16:30:19 +0000152 }
153 }
Tony Mak21460022020-03-12 18:29:35 +0000154 }
155 return result_annotation_options;
156}
157
Lukas Zilka21d8c982018-01-24 11:11:20 +0100158} // namespace
159
Lukas Zilkaba849e72018-03-08 14:48:21 +0100160tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
161 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100162 TC3_CHECK(selection_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100163 selection_interpreter_ = selection_executor_->CreateInterpreter();
164 if (!selection_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100165 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100166 }
167 }
168 return selection_interpreter_.get();
169}
170
171tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
172 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100173 TC3_CHECK(classification_executor_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100174 classification_interpreter_ = classification_executor_->CreateInterpreter();
175 if (!classification_interpreter_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100176 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100177 }
178 }
179 return classification_interpreter_.get();
180}
181
Tony Mak6c4cc672018-09-17 11:48:50 +0100182std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
183 const char* buffer, int size, const UniLib* unilib,
184 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100185 const Model* model = LoadAndVerifyModel(buffer, size);
186 if (model == nullptr) {
187 return nullptr;
188 }
189
Lukas Zilkab23e2122018-02-09 10:25:19 +0100190 auto classifier =
Tony Mak6c4cc672018-09-17 11:48:50 +0100191 std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100192 if (!classifier->IsInitialized()) {
193 return nullptr;
194 }
195
196 return classifier;
197}
198
Tony Mak6c4cc672018-09-17 11:48:50 +0100199std::unique_ptr<Annotator> Annotator::FromScopedMmap(
200 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
201 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100202 if (!(*mmap)->handle().ok()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100203 TC3_VLOG(1) << "Mmap failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100204 return nullptr;
205 }
206
207 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
208 (*mmap)->handle().num_bytes());
209 if (!model) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100210 TC3_LOG(ERROR) << "Model verification failed.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100211 return nullptr;
212 }
213
Tony Mak6c4cc672018-09-17 11:48:50 +0100214 auto classifier = std::unique_ptr<Annotator>(
215 new Annotator(mmap, model, unilib, calendarlib));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100216 if (!classifier->IsInitialized()) {
217 return nullptr;
218 }
219
220 return classifier;
221}
222
Tony Makdf54e742019-03-26 14:04:00 +0000223std::unique_ptr<Annotator> Annotator::FromScopedMmap(
224 std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
225 std::unique_ptr<CalendarLib> calendarlib) {
226 if (!(*mmap)->handle().ok()) {
227 TC3_VLOG(1) << "Mmap failed.";
228 return nullptr;
229 }
230
231 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
232 (*mmap)->handle().num_bytes());
233 if (model == nullptr) {
234 TC3_LOG(ERROR) << "Model verification failed.";
235 return nullptr;
236 }
237
238 auto classifier = std::unique_ptr<Annotator>(
239 new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
240 if (!classifier->IsInitialized()) {
241 return nullptr;
242 }
243
244 return classifier;
245}
246
Tony Mak6c4cc672018-09-17 11:48:50 +0100247std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
248 int fd, int offset, int size, const UniLib* unilib,
249 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100250 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
Tony Mak6c4cc672018-09-17 11:48:50 +0100251 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100252}
253
Tony Mak6c4cc672018-09-17 11:48:50 +0100254std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Makdf54e742019-03-26 14:04:00 +0000255 int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
256 std::unique_ptr<CalendarLib> calendarlib) {
257 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
258 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
259}
260
261std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
Tony Mak6c4cc672018-09-17 11:48:50 +0100262 int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100263 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
Tony Mak6c4cc672018-09-17 11:48:50 +0100264 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100265}
266
Tony Makdf54e742019-03-26 14:04:00 +0000267std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
268 int fd, std::unique_ptr<UniLib> unilib,
269 std::unique_ptr<CalendarLib> calendarlib) {
270 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
271 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
272}
273
Tony Mak6c4cc672018-09-17 11:48:50 +0100274std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
275 const UniLib* unilib,
276 const CalendarLib* calendarlib) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100277 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
Tony Mak6c4cc672018-09-17 11:48:50 +0100278 return FromScopedMmap(&mmap, unilib, calendarlib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100279}
280
Tony Makdf54e742019-03-26 14:04:00 +0000281std::unique_ptr<Annotator> Annotator::FromPath(
282 const std::string& path, std::unique_ptr<UniLib> unilib,
283 std::unique_ptr<CalendarLib> calendarlib) {
284 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
285 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
286}
287
Tony Mak6c4cc672018-09-17 11:48:50 +0100288Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
289 const UniLib* unilib, const CalendarLib* calendarlib)
290 : model_(model),
291 mmap_(std::move(*mmap)),
292 owned_unilib_(nullptr),
293 unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
294 owned_calendarlib_(nullptr),
295 calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
296 ValidateAndInitialize();
297}
298
Tony Makdf54e742019-03-26 14:04:00 +0000299Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
300 std::unique_ptr<UniLib> unilib,
301 std::unique_ptr<CalendarLib> calendarlib)
302 : model_(model),
303 mmap_(std::move(*mmap)),
304 owned_unilib_(std::move(unilib)),
305 unilib_(owned_unilib_.get()),
306 owned_calendarlib_(std::move(calendarlib)),
307 calendarlib_(owned_calendarlib_.get()) {
308 ValidateAndInitialize();
309}
310
Tony Mak6c4cc672018-09-17 11:48:50 +0100311Annotator::Annotator(const Model* model, const UniLib* unilib,
312 const CalendarLib* calendarlib)
313 : model_(model),
314 owned_unilib_(nullptr),
315 unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
316 owned_calendarlib_(nullptr),
317 calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
318 ValidateAndInitialize();
319}
320
321void Annotator::ValidateAndInitialize() {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100322 initialized_ = false;
323
Lukas Zilka21d8c982018-01-24 11:11:20 +0100324 if (model_ == nullptr) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100325 TC3_LOG(ERROR) << "No model specified.";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100326 return;
327 }
328
Lukas Zilkaba849e72018-03-08 14:48:21 +0100329 const bool model_enabled_for_annotation =
330 (model_->triggering_options() != nullptr &&
331 (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
332 const bool model_enabled_for_classification =
333 (model_->triggering_options() != nullptr &&
334 (model_->triggering_options()->enabled_modes() &
335 ModeFlag_CLASSIFICATION));
336 const bool model_enabled_for_selection =
337 (model_->triggering_options() != nullptr &&
338 (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
339
340 // Annotation requires the selection model.
341 if (model_enabled_for_annotation || model_enabled_for_selection) {
342 if (!model_->selection_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100343 TC3_LOG(ERROR) << "No selection options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100344 return;
345 }
346 if (!model_->selection_feature_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100347 TC3_LOG(ERROR) << "No selection feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100348 return;
349 }
350 if (!model_->selection_feature_options()->bounds_sensitive_features()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100351 TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100352 return;
353 }
354 if (!model_->selection_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100355 TC3_LOG(ERROR) << "No selection model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100356 return;
357 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100358 selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100359 if (!selection_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100360 TC3_LOG(ERROR) << "Could not initialize selection executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100361 return;
362 }
363 selection_feature_processor_.reset(
364 new FeatureProcessor(model_->selection_feature_options(), unilib_));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100365 }
366
Lukas Zilkaba849e72018-03-08 14:48:21 +0100367 // Annotation requires the classification model for conflict resolution and
368 // scoring.
369 // Selection requires the classification model for conflict resolution.
370 if (model_enabled_for_annotation || model_enabled_for_classification ||
371 model_enabled_for_selection) {
372 if (!model_->classification_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100373 TC3_LOG(ERROR) << "No classification options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100374 return;
375 }
376
377 if (!model_->classification_feature_options()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100378 TC3_LOG(ERROR) << "No classification feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100379 return;
380 }
381
382 if (!model_->classification_feature_options()
383 ->bounds_sensitive_features()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100384 TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100385 return;
386 }
387 if (!model_->classification_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100388 TC3_LOG(ERROR) << "No clf model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100389 return;
390 }
391
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200392 classification_executor_ =
Tony Mak6c4cc672018-09-17 11:48:50 +0100393 ModelExecutor::FromBuffer(model_->classification_model());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100394 if (!classification_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100395 TC3_LOG(ERROR) << "Could not initialize classification executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100396 return;
397 }
398
399 classification_feature_processor_.reset(new FeatureProcessor(
400 model_->classification_feature_options(), unilib_));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100401 }
402
Lukas Zilkaba849e72018-03-08 14:48:21 +0100403 // The embeddings need to be specified if the model is to be used for
404 // classification or selection.
405 if (model_enabled_for_annotation || model_enabled_for_classification ||
406 model_enabled_for_selection) {
407 if (!model_->embedding_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100408 TC3_LOG(ERROR) << "No embedding model.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100409 return;
410 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100411
Lukas Zilkaba849e72018-03-08 14:48:21 +0100412 // Check that the embedding size of the selection and classification model
413 // matches, as they are using the same embeddings.
414 if (model_enabled_for_selection &&
415 (model_->selection_feature_options()->embedding_size() !=
416 model_->classification_feature_options()->embedding_size() ||
417 model_->selection_feature_options()->embedding_quantization_bits() !=
418 model_->classification_feature_options()
419 ->embedding_quantization_bits())) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100420 TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100421 return;
422 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100423
Tony Mak6c4cc672018-09-17 11:48:50 +0100424 embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200425 model_->embedding_model(),
Lukas Zilkaba849e72018-03-08 14:48:21 +0100426 model_->classification_feature_options()->embedding_size(),
Tony Makdf54e742019-03-26 14:04:00 +0000427 model_->classification_feature_options()->embedding_quantization_bits(),
428 model_->embedding_pruning_mask());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200429 if (!embedding_executor_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100430 TC3_LOG(ERROR) << "Could not initialize embedding executor.";
Lukas Zilkaba849e72018-03-08 14:48:21 +0100431 return;
432 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100433 }
434
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200435 std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
Lukas Zilkab23e2122018-02-09 10:25:19 +0100436 if (model_->regex_model()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200437 if (!InitializeRegexModel(decompressor.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100438 TC3_LOG(ERROR) << "Could not initialize regex model.";
Lukas Zilka434442d2018-04-25 11:38:51 +0200439 return;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100440 }
Lukas Zilka21d8c982018-01-24 11:11:20 +0100441 }
Tony Mak63959242020-02-07 18:31:16 +0000442 if (model_->grammar_datetime_model() &&
443 model_->grammar_datetime_model()->datetime_rules()) {
444 cfg_datetime_parser_.reset(new dates::CfgDatetimeAnnotator(
445 *unilib_,
446 /*tokenizer_options=*/
447 model_->grammar_datetime_model()->grammar_tokenizer_options(),
448 *calendarlib_,
Tony Mak21460022020-03-12 18:29:35 +0000449 /*datetime_rules=*/model_->grammar_datetime_model()->datetime_rules(),
450 model_->grammar_datetime_model()->target_classification_score(),
451 model_->grammar_datetime_model()->priority_score()));
Tony Mak63959242020-02-07 18:31:16 +0000452 if (!cfg_datetime_parser_) {
453 TC3_LOG(ERROR) << "Could not initialize context free grammar based "
454 "datetime parser.";
455 return;
456 }
Tony Makd99d58c2020-03-19 21:52:02 +0000457 }
458
459 if (model_->datetime_model()) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100460 datetime_parser_ = DatetimeParser::Instance(
461 model_->datetime_model(), *unilib_, *calendarlib_, decompressor.get());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100462 if (!datetime_parser_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100463 TC3_LOG(ERROR) << "Could not initialize datetime parser.";
Lukas Zilkab23e2122018-02-09 10:25:19 +0100464 return;
465 }
466 }
467
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200468 if (model_->output_options()) {
469 if (model_->output_options()->filtered_collections_annotation()) {
470 for (const auto collection :
471 *model_->output_options()->filtered_collections_annotation()) {
472 filtered_collections_annotation_.insert(collection->str());
473 }
474 }
475 if (model_->output_options()->filtered_collections_classification()) {
476 for (const auto collection :
477 *model_->output_options()->filtered_collections_classification()) {
478 filtered_collections_classification_.insert(collection->str());
479 }
480 }
481 if (model_->output_options()->filtered_collections_selection()) {
482 for (const auto collection :
483 *model_->output_options()->filtered_collections_selection()) {
484 filtered_collections_selection_.insert(collection->str());
485 }
486 }
487 }
488
Tony Mak378c1f52019-03-04 15:58:11 +0000489 if (model_->number_annotator_options() &&
490 model_->number_annotator_options()->enabled()) {
491 number_annotator_.reset(
Tony Mak63959242020-02-07 18:31:16 +0000492 new NumberAnnotator(model_->number_annotator_options(), unilib_));
493 }
494
495 if (model_->money_parsing_options()) {
496 money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
497 model_->money_parsing_options()->separators());
Tony Mak378c1f52019-03-04 15:58:11 +0000498 }
499
Tony Makad2e22d2019-03-20 17:35:13 +0000500 if (model_->duration_annotator_options() &&
501 model_->duration_annotator_options()->enabled()) {
502 duration_annotator_.reset(
503 new DurationAnnotator(model_->duration_annotator_options(),
Tony Mak8cd7ba62019-10-15 15:29:22 +0100504 selection_feature_processor_.get(), unilib_));
Tony Makad2e22d2019-03-20 17:35:13 +0000505 }
506
Tony Makd9446602019-02-20 18:25:39 +0000507 if (model_->entity_data_schema()) {
508 entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
509 model_->entity_data_schema()->Data(),
510 model_->entity_data_schema()->size());
511 if (entity_data_schema_ == nullptr) {
512 TC3_LOG(ERROR) << "Could not load entity data schema data.";
513 return;
514 }
515
516 entity_data_builder_.reset(
517 new ReflectiveFlatbufferBuilder(entity_data_schema_));
518 } else {
Tony Mak378c1f52019-03-04 15:58:11 +0000519 entity_data_schema_ = nullptr;
Tony Makd9446602019-02-20 18:25:39 +0000520 entity_data_builder_ = nullptr;
521 }
522
Tony Mak21460022020-03-12 18:29:35 +0000523 if (model_->grammar_model()) {
524 grammar_annotator_.reset(new GrammarAnnotator(
525 unilib_, model_->grammar_model(), entity_data_builder_.get()));
526 }
527
Tony Makdf54e742019-03-26 14:04:00 +0000528 if (model_->triggering_locales() &&
529 !ParseLocales(model_->triggering_locales()->c_str(),
530 &model_triggering_locales_)) {
Tony Mak378c1f52019-03-04 15:58:11 +0000531 TC3_LOG(ERROR) << "Could not parse model supported locales.";
532 return;
533 }
534
535 if (model_->triggering_options() != nullptr &&
Tony Makdf54e742019-03-26 14:04:00 +0000536 model_->triggering_options()->locales() != nullptr &&
537 !ParseLocales(model_->triggering_options()->locales()->c_str(),
538 &ml_model_triggering_locales_)) {
539 TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
540 return;
541 }
542
543 if (model_->triggering_options() != nullptr &&
Tony Mak378c1f52019-03-04 15:58:11 +0000544 model_->triggering_options()->dictionary_locales() != nullptr &&
545 !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
546 &dictionary_locales_)) {
547 TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
548 return;
549 }
550
Lukas Zilka21d8c982018-01-24 11:11:20 +0100551 initialized_ = true;
552}
553
Tony Mak6c4cc672018-09-17 11:48:50 +0100554bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100555 if (!model_->regex_model()->patterns()) {
Lukas Zilka434442d2018-04-25 11:38:51 +0200556 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100557 }
558
559 // Initialize pattern recognizers.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100560 int regex_pattern_id = 0;
561 for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200562 std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
Tony Mak378c1f52019-03-04 15:58:11 +0000563 UncompressMakeRegexPattern(
564 *unilib_, regex_pattern->pattern(),
565 regex_pattern->compressed_pattern(),
566 model_->regex_model()->lazy_regex_compilation(), decompressor);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100567 if (!compiled_pattern) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100568 TC3_LOG(INFO) << "Failed to load regex pattern";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200569 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100570 }
571
Lukas Zilkaba849e72018-03-08 14:48:21 +0100572 if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100573 annotation_regex_patterns_.push_back(regex_pattern_id);
574 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100575 if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100576 classification_regex_patterns_.push_back(regex_pattern_id);
577 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100578 if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100579 selection_regex_patterns_.push_back(regex_pattern_id);
580 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100581 regex_patterns_.push_back({
Tony Mak854015a2019-01-16 15:56:48 +0000582 regex_pattern,
Tony Mak6c4cc672018-09-17 11:48:50 +0100583 std::move(compiled_pattern),
Tony Mak6c4cc672018-09-17 11:48:50 +0100584 });
Lukas Zilkab23e2122018-02-09 10:25:19 +0100585 ++regex_pattern_id;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100586 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100587
Lukas Zilkab23e2122018-02-09 10:25:19 +0100588 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100589}
590
Tony Mak6c4cc672018-09-17 11:48:50 +0100591bool Annotator::InitializeKnowledgeEngine(
592 const std::string& serialized_config) {
Tony Maka2a1ff42019-09-12 15:40:32 +0100593 std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
Tony Mak63959242020-02-07 18:31:16 +0000594 if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100595 TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
596 return false;
597 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100598 if (model_->triggering_options() != nullptr) {
599 knowledge_engine->SetPriorityScore(
600 model_->triggering_options()->knowledge_priority_score());
601 }
Tony Mak6c4cc672018-09-17 11:48:50 +0100602 knowledge_engine_ = std::move(knowledge_engine);
603 return true;
604}
605
Tony Mak854015a2019-01-16 15:56:48 +0000606bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
Tony Mak5dc5e112019-02-01 14:52:10 +0000607 std::unique_ptr<ContactEngine> contact_engine(
Tony Mak63959242020-02-07 18:31:16 +0000608 new ContactEngine(selection_feature_processor_.get(), unilib_,
609 model_->contact_annotator_options()));
Tony Mak854015a2019-01-16 15:56:48 +0000610 if (!contact_engine->Initialize(serialized_config)) {
611 TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
612 return false;
613 }
614 contact_engine_ = std::move(contact_engine);
615 return true;
616}
617
Tony Makd9446602019-02-20 18:25:39 +0000618bool Annotator::InitializeInstalledAppEngine(
619 const std::string& serialized_config) {
620 std::unique_ptr<InstalledAppEngine> installed_app_engine(
Tony Mak378c1f52019-03-04 15:58:11 +0000621 new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
Tony Makd9446602019-02-20 18:25:39 +0000622 if (!installed_app_engine->Initialize(serialized_config)) {
623 TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
624 return false;
625 }
626 installed_app_engine_ = std::move(installed_app_engine);
627 return true;
628}
629
Tony Mak63959242020-02-07 18:31:16 +0000630void Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
631 lang_id_ = lang_id;
Tony Mak21460022020-03-12 18:29:35 +0000632 if (lang_id_ != nullptr && model_->translate_annotator_options() &&
Tony Mak63959242020-02-07 18:31:16 +0000633 model_->translate_annotator_options()->enabled()) {
634 translate_annotator_.reset(new TranslateAnnotator(
635 model_->translate_annotator_options(), lang_id_, unilib_));
Tony Mak21460022020-03-12 18:29:35 +0000636 } else {
637 translate_annotator_.reset(nullptr);
Tony Mak63959242020-02-07 18:31:16 +0000638 }
639}
640
Tony Mak21460022020-03-12 18:29:35 +0000641bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
642 int size) {
643 const PersonNameModel* person_name_model =
644 LoadAndVerifyPersonNameModel(buffer, size);
Tony Mak76d80962020-01-08 17:30:51 +0000645
646 if (person_name_model == nullptr) {
647 TC3_LOG(ERROR) << "Person name model verification failed.";
648 return false;
649 }
650
651 if (!person_name_model->enabled()) {
652 return true;
653 }
654
655 std::unique_ptr<PersonNameEngine> person_name_engine(
Tony Mak21460022020-03-12 18:29:35 +0000656 new PersonNameEngine(selection_feature_processor_.get(), unilib_));
Tony Mak76d80962020-01-08 17:30:51 +0000657 if (!person_name_engine->Initialize(person_name_model)) {
658 TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
659 return false;
660 }
661 person_name_engine_ = std::move(person_name_engine);
662 return true;
663}
664
Tony Mak21460022020-03-12 18:29:35 +0000665bool Annotator::InitializePersonNameEngineFromScopedMmap(
666 const ScopedMmap& mmap) {
667 if (!mmap.handle().ok()) {
668 TC3_LOG(ERROR) << "Mmap for person name model failed.";
669 return false;
670 }
671
672 return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
673 mmap.handle().num_bytes());
674}
675
676bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
677 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
678 return InitializePersonNameEngineFromScopedMmap(*mmap);
679}
680
681bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
682 int size) {
683 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
684 return InitializePersonNameEngineFromScopedMmap(*mmap);
685}
686
Lukas Zilka21d8c982018-01-24 11:11:20 +0100687namespace {
688
689int CountDigits(const std::string& str, CodepointSpan selection_indices) {
690 int count = 0;
691 int i = 0;
692 const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
693 for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
694 if (i >= selection_indices.first && i < selection_indices.second &&
Tony Mak21460022020-03-12 18:29:35 +0000695 IsDigit(*it)) {
Lukas Zilka21d8c982018-01-24 11:11:20 +0100696 ++count;
697 }
698 }
699 return count;
700}
701
Lukas Zilka21d8c982018-01-24 11:11:20 +0100702} // namespace
703
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200704namespace internal {
705// Helper function, which if the initial 'span' contains only white-spaces,
706// moves the selection to a single-codepoint selection on a left or right side
707// of this space.
708CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
709 const UnicodeText& context_unicode,
710 const UniLib& unilib) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100711 TC3_CHECK(ValidNonEmptySpan(span));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200712
713 UnicodeText::const_iterator it;
714
715 // Check that the current selection is all whitespaces.
716 it = context_unicode.begin();
717 std::advance(it, span.first);
718 for (int i = 0; i < (span.second - span.first); ++i, ++it) {
719 if (!unilib.IsWhitespace(*it)) {
720 return span;
721 }
722 }
723
724 CodepointSpan result;
725
726 // Try moving left.
727 result = span;
728 it = context_unicode.begin();
729 std::advance(it, span.first);
730 while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
731 --result.first;
732 --it;
733 }
734 result.second = result.first + 1;
735 if (!unilib.IsWhitespace(*it)) {
736 return result;
737 }
738
739 // If moving left didn't find a non-whitespace character, just return the
740 // original span.
741 return span;
742}
743} // namespace internal
744
Tony Mak6c4cc672018-09-17 11:48:50 +0100745bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200746 return !span.classification.empty() &&
747 filtered_collections_annotation_.find(
748 span.classification[0].collection) !=
749 filtered_collections_annotation_.end();
750}
751
Tony Mak6c4cc672018-09-17 11:48:50 +0100752bool Annotator::FilteredForClassification(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200753 const ClassificationResult& classification) const {
754 return filtered_collections_classification_.find(classification.collection) !=
755 filtered_collections_classification_.end();
756}
757
Tony Mak6c4cc672018-09-17 11:48:50 +0100758bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200759 return !span.classification.empty() &&
760 filtered_collections_selection_.find(
761 span.classification[0].collection) !=
762 filtered_collections_selection_.end();
763}
764
Tony Mak378c1f52019-03-04 15:58:11 +0000765namespace {
766inline bool ClassifiedAsOther(
767 const std::vector<ClassificationResult>& classification) {
768 return !classification.empty() &&
769 classification[0].collection == Collections::Other();
770}
771
Tony Maka2a1ff42019-09-12 15:40:32 +0100772} // namespace
773
774float Annotator::GetPriorityScore(
775 const std::vector<ClassificationResult>& classification) const {
Tony Mak378c1f52019-03-04 15:58:11 +0000776 if (!classification.empty() && !ClassifiedAsOther(classification)) {
777 return classification[0].priority_score;
778 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +0100779 if (model_->triggering_options() != nullptr) {
780 return model_->triggering_options()->other_collection_priority_score();
781 } else {
782 return -1000.0;
783 }
Tony Mak378c1f52019-03-04 15:58:11 +0000784 }
785}
Tony Mak378c1f52019-03-04 15:58:11 +0000786
Tony Makdf54e742019-03-26 14:04:00 +0000787bool Annotator::VerifyRegexMatchCandidate(
788 const std::string& context, const VerificationOptions* verification_options,
789 const std::string& match, const UniLib::RegexMatcher* matcher) const {
790 if (verification_options == nullptr) {
791 return true;
792 }
793 if (verification_options->verify_luhn_checksum() &&
794 !VerifyLuhnChecksum(match)) {
795 return false;
796 }
797 const int lua_verifier = verification_options->lua_verifier();
798 if (lua_verifier >= 0) {
799 if (model_->regex_model()->lua_verifier() == nullptr ||
800 lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
801 TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
802 return false;
803 }
804 return VerifyMatch(
805 context, matcher,
806 model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
807 }
808 return true;
809}
810
Tony Mak6c4cc672018-09-17 11:48:50 +0100811CodepointSpan Annotator::SuggestSelection(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100812 const std::string& context, CodepointSpan click_indices,
813 const SelectionOptions& options) const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200814 CodepointSpan original_click_indices = click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100815 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100816 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200817 return original_click_indices;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100818 }
Lukas Zilkaba849e72018-03-08 14:48:21 +0100819 if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200820 return original_click_indices;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100821 }
822
Tony Makdf54e742019-03-26 14:04:00 +0000823 std::vector<Locale> detected_text_language_tags;
824 if (!ParseLocales(options.detected_text_language_tags,
825 &detected_text_language_tags)) {
826 TC3_LOG(WARNING)
827 << "Failed to parse the detected_text_language_tags in options: "
828 << options.detected_text_language_tags;
829 }
830 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
831 model_triggering_locales_,
832 /*default_value=*/true)) {
833 return original_click_indices;
834 }
835
Lukas Zilkadf710db2018-02-27 12:44:09 +0100836 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
837 /*do_copy=*/false);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200838
Tony Mak968412a2019-11-13 15:39:57 +0000839 if (!IsValidSpanInput(context_unicode, click_indices)) {
840 TC3_VLOG(1)
841 << "Trying to run SuggestSelection with invalid input, indices: "
842 << click_indices.first << " " << click_indices.second;
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200843 return original_click_indices;
844 }
845
846 if (model_->snap_whitespace_selections()) {
847 // We want to expand a purely white-space selection to a multi-selection it
848 // would've been part of. But with this feature disabled we would do a no-
849 // op, because no token is found. Therefore, we need to modify the
850 // 'click_indices' a bit to include a part of the token, so that the click-
851 // finding logic finds the clicked token correctly. This modification is
852 // done by the following function. Note, that it's enough to check the left
853 // side of the current selection, because if the white-space is a part of a
Tony Mak6c4cc672018-09-17 11:48:50 +0100854 // multi-selection, necessarily both tokens - on the left and the right
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200855 // sides need to be selected. Thus snapping only to the left is sufficient
856 // (there's a check at the bottom that makes sure that if we snap to the
857 // left token but the result does not contain the initial white-space,
858 // returns the original indices).
859 click_indices = internal::SnapLeftIfWhitespaceSelection(
860 click_indices, context_unicode, *unilib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100861 }
862
Lukas Zilkab23e2122018-02-09 10:25:19 +0100863 std::vector<AnnotatedSpan> candidates;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100864 InterpreterManager interpreter_manager(selection_executor_.get(),
865 classification_executor_.get());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200866 std::vector<Token> tokens;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100867 if (!ModelSuggestSelection(context_unicode, click_indices,
Tony Makdf54e742019-03-26 14:04:00 +0000868 detected_text_language_tags, &interpreter_manager,
869 &tokens, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100870 TC3_LOG(ERROR) << "Model suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200871 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100872 }
Tony Mak83d2de62019-04-10 16:12:15 +0100873 if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
874 /*is_serialized_entity_data_enabled=*/false)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100875 TC3_LOG(ERROR) << "Regex suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200876 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100877 }
Tony Mak83d2de62019-04-10 16:12:15 +0100878 if (!DatetimeChunk(
879 UTF8ToUnicodeText(context, /*do_copy=*/false),
880 /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
881 options.locales, ModeFlag_SELECTION, options.annotation_usecase,
882 /*is_serialized_entity_data_enabled=*/false, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100883 TC3_LOG(ERROR) << "Datetime suggest selection failed.";
884 return original_click_indices;
885 }
Tony Mak378c1f52019-03-04 15:58:11 +0000886 if (knowledge_engine_ != nullptr &&
Tony Maka2a1ff42019-09-12 15:40:32 +0100887 !knowledge_engine_->Chunk(context, options.annotation_usecase,
Tony Mak63959242020-02-07 18:31:16 +0000888 options.location_context, &candidates)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100889 TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200890 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100891 }
Tony Mak378c1f52019-03-04 15:58:11 +0000892 if (contact_engine_ != nullptr &&
Tony Mak854015a2019-01-16 15:56:48 +0000893 !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
894 TC3_LOG(ERROR) << "Contact suggest selection failed.";
895 return original_click_indices;
896 }
Tony Mak378c1f52019-03-04 15:58:11 +0000897 if (installed_app_engine_ != nullptr &&
Tony Makd9446602019-02-20 18:25:39 +0000898 !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
899 TC3_LOG(ERROR) << "Installed app suggest selection failed.";
900 return original_click_indices;
901 }
Tony Mak378c1f52019-03-04 15:58:11 +0000902 if (number_annotator_ != nullptr &&
903 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
904 &candidates)) {
905 TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
906 return original_click_indices;
907 }
Tony Makad2e22d2019-03-20 17:35:13 +0000908 if (duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +0000909 !duration_annotator_->FindAll(context_unicode, tokens,
910 options.annotation_usecase, &candidates)) {
Tony Makad2e22d2019-03-20 17:35:13 +0000911 TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
912 return original_click_indices;
913 }
Tony Mak76d80962020-01-08 17:30:51 +0000914 if (person_name_engine_ != nullptr &&
915 !person_name_engine_->Chunk(context_unicode, tokens, &candidates)) {
916 TC3_LOG(ERROR) << "Person name suggest selection failed.";
917 return original_click_indices;
918 }
Lukas Zilkab23e2122018-02-09 10:25:19 +0100919
Tony Mak21460022020-03-12 18:29:35 +0000920 AnnotatedSpan grammar_suggested_span;
921 if (grammar_annotator_ != nullptr &&
922 grammar_annotator_->SuggestSelection(detected_text_language_tags,
923 context_unicode, click_indices,
924 &grammar_suggested_span)) {
925 candidates.push_back(grammar_suggested_span);
926 }
927
Lukas Zilkab23e2122018-02-09 10:25:19 +0100928 // Sort candidates according to their position in the input, so that the next
929 // code can assume that any connected component of overlapping spans forms a
930 // contiguous block.
931 std::sort(candidates.begin(), candidates.end(),
932 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
933 return a.span.first < b.span.first;
934 });
935
936 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +0000937 if (!ResolveConflicts(candidates, context, tokens,
938 detected_text_language_tags, options.annotation_usecase,
939 &interpreter_manager, &candidate_indices)) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100940 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200941 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100942 }
943
Tony Mak378c1f52019-03-04 15:58:11 +0000944 std::sort(candidate_indices.begin(), candidate_indices.end(),
Tony Maka2a1ff42019-09-12 15:40:32 +0100945 [this, &candidates](int a, int b) {
Tony Mak378c1f52019-03-04 15:58:11 +0000946 return GetPriorityScore(candidates[a].classification) >
947 GetPriorityScore(candidates[b].classification);
948 });
949
Lukas Zilkab23e2122018-02-09 10:25:19 +0100950 for (const int i : candidate_indices) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200951 if (SpansOverlap(candidates[i].span, click_indices) &&
952 SpansOverlap(candidates[i].span, original_click_indices)) {
953 // Run model classification if not present but requested and there's a
954 // classification collection filter specified.
955 if (candidates[i].classification.empty() &&
956 model_->selection_options()->always_classify_suggested_selection() &&
957 !filtered_collections_selection_.empty()) {
Tony Mak378c1f52019-03-04 15:58:11 +0000958 if (!ModelClassifyText(context, detected_text_language_tags,
959 candidates[i].span, &interpreter_manager,
960 /*embedding_cache=*/nullptr,
961 &candidates[i].classification)) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200962 return original_click_indices;
963 }
964 }
965
966 // Ignore if span classification is filtered.
967 if (FilteredForSelection(candidates[i])) {
968 return original_click_indices;
969 }
970
Lukas Zilkab23e2122018-02-09 10:25:19 +0100971 return candidates[i].span;
972 }
973 }
974
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200975 return original_click_indices;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100976}
977
978namespace {
979// Helper function that returns the index of the first candidate that
980// transitively does not overlap with the candidate on 'start_index'. If the end
981// of 'candidates' is reached, it returns the index that points right behind the
982// array.
983int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
984 int start_index) {
985 int first_non_overlapping = start_index + 1;
986 CodepointSpan conflicting_span = candidates[start_index].span;
987 while (
988 first_non_overlapping < candidates.size() &&
989 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
990 // Grow the span to include the current one.
991 conflicting_span.second = std::max(
992 conflicting_span.second, candidates[first_non_overlapping].span.second);
993
994 ++first_non_overlapping;
995 }
996 return first_non_overlapping;
997}
998} // namespace
999
Tony Mak378c1f52019-03-04 15:58:11 +00001000bool Annotator::ResolveConflicts(
1001 const std::vector<AnnotatedSpan>& candidates, const std::string& context,
1002 const std::vector<Token>& cached_tokens,
1003 const std::vector<Locale>& detected_text_language_tags,
1004 AnnotationUsecase annotation_usecase,
1005 InterpreterManager* interpreter_manager, std::vector<int>* result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001006 result->clear();
1007 result->reserve(candidates.size());
1008 for (int i = 0; i < candidates.size();) {
1009 int first_non_overlapping =
1010 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1011
1012 const bool conflict_found = first_non_overlapping != (i + 1);
1013 if (conflict_found) {
1014 std::vector<int> candidate_indices;
Tony Mak378c1f52019-03-04 15:58:11 +00001015 if (!ResolveConflict(context, cached_tokens, candidates,
1016 detected_text_language_tags, i,
1017 first_non_overlapping, annotation_usecase,
1018 interpreter_manager, &candidate_indices)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001019 return false;
1020 }
1021 result->insert(result->end(), candidate_indices.begin(),
1022 candidate_indices.end());
1023 } else {
1024 result->push_back(i);
1025 }
1026
1027 // Skip over the whole conflicting group/go to next candidate.
1028 i = first_non_overlapping;
1029 }
1030 return true;
1031}
1032
1033namespace {
Tony Mak448b5862019-03-22 13:36:41 +00001034// Returns true, if the given two sources do conflict in given annotation
1035// usecase.
1036// - In SMART usecase, all sources do conflict, because there's only 1 possible
1037// annotation for a given span.
1038// - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1039// and duration), while others not (e.g. duration and number).
1040bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1041 const AnnotatedSpan::Source source1,
1042 const AnnotatedSpan::Source source2) {
1043 uint32 source_mask =
1044 (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1045
Tony Mak378c1f52019-03-04 15:58:11 +00001046 switch (annotation_usecase) {
1047 case AnnotationUsecase_ANNOTATION_USECASE_SMART:
Tony Mak448b5862019-03-22 13:36:41 +00001048 // In the SMART mode, all annotations conflict.
Tony Mak378c1f52019-03-04 15:58:11 +00001049 return true;
Tony Mak448b5862019-03-22 13:36:41 +00001050
Tony Mak378c1f52019-03-04 15:58:11 +00001051 case AnnotationUsecase_ANNOTATION_USECASE_RAW:
Tony Mak448b5862019-03-22 13:36:41 +00001052 // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1053 // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1054 // hours" (duration).
1055 if ((source_mask &
1056 (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1057 (source_mask &
1058 (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1059 return false;
Tony Mak378c1f52019-03-04 15:58:11 +00001060 }
Tony Mak448b5862019-03-22 13:36:41 +00001061
1062 // A KNOWLEDGE entity does not conflict with anything.
1063 if ((source_mask &
1064 (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1065 return false;
1066 }
1067
Tony Makd0ae7c62020-03-27 13:58:00 +00001068 // A PERSONNAME entity does not conflict with anything.
1069 if ((source_mask &
1070 (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
1071 return false;
1072 }
1073
Tony Mak448b5862019-03-22 13:36:41 +00001074 // Entities from other sources can conflict.
Tony Mak378c1f52019-03-04 15:58:11 +00001075 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001076 }
1077}
1078} // namespace
1079
Tony Mak378c1f52019-03-04 15:58:11 +00001080bool Annotator::ResolveConflict(
1081 const std::string& context, const std::vector<Token>& cached_tokens,
1082 const std::vector<AnnotatedSpan>& candidates,
1083 const std::vector<Locale>& detected_text_language_tags, int start_index,
1084 int end_index, AnnotationUsecase annotation_usecase,
1085 InterpreterManager* interpreter_manager,
1086 std::vector<int>* chosen_indices) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001087 std::vector<int> conflicting_indices;
Tony Mak76d80962020-01-08 17:30:51 +00001088 std::unordered_map<int, std::pair<float, int>> scores_lengths;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001089 for (int i = start_index; i < end_index; ++i) {
1090 conflicting_indices.push_back(i);
1091 if (!candidates[i].classification.empty()) {
Tony Mak76d80962020-01-08 17:30:51 +00001092 scores_lengths[i] = {
1093 GetPriorityScore(candidates[i].classification),
1094 candidates[i].span.second - candidates[i].span.first};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001095 continue;
1096 }
1097
1098 // OPTIMIZATION: So that we don't have to classify all the ML model
1099 // spans apriori, we wait until we get here, when they conflict with
1100 // something and we need the actual classification scores. So if the
1101 // candidate conflicts and comes from the model, we need to run a
1102 // classification to determine its priority:
1103 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +00001104 if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1105 candidates[i].span, interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001106 /*embedding_cache=*/nullptr, &classification)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001107 return false;
1108 }
1109
1110 if (!classification.empty()) {
Tony Mak76d80962020-01-08 17:30:51 +00001111 scores_lengths[i] = {
1112 GetPriorityScore(classification),
1113 candidates[i].span.second - candidates[i].span.first};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001114 }
1115 }
1116
Tony Mak101bc2a2020-01-09 12:32:17 +00001117 const bool prioritize_longest_annotation =
1118 model_->triggering_options() != nullptr &&
1119 model_->triggering_options()->prioritize_longest_annotation();
1120 std::sort(conflicting_indices.begin(), conflicting_indices.end(),
1121 [&scores_lengths, candidates, conflicting_indices,
1122 prioritize_longest_annotation](int i, int j) {
1123 if (scores_lengths[i].first == scores_lengths[j].first &&
1124 prioritize_longest_annotation) {
1125 return scores_lengths[i].second > scores_lengths[j].second;
1126 }
1127 return scores_lengths[i].first > scores_lengths[j].first;
1128 });
Lukas Zilkab23e2122018-02-09 10:25:19 +01001129
Tony Mak448b5862019-03-22 13:36:41 +00001130 // Here we keep a set of indices that were chosen, per-source, to enable
1131 // effective computation.
1132 std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1133 chosen_indices_for_source_map;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001134
1135 // Greedily place the candidates if they don't conflict with the already
1136 // placed ones.
1137 for (int i = 0; i < conflicting_indices.size(); ++i) {
1138 const int considered_candidate = conflicting_indices[i];
Tony Mak448b5862019-03-22 13:36:41 +00001139
1140 // See if there is a conflict between the candidate and all already placed
1141 // candidates.
1142 bool conflict = false;
1143 SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1144 for (auto& source_set_pair : chosen_indices_for_source_map) {
1145 if (source_set_pair.first == candidates[considered_candidate].source) {
1146 chosen_indices_for_source_ptr = &source_set_pair.second;
1147 }
1148
1149 if (DoSourcesConflict(annotation_usecase, source_set_pair.first,
1150 candidates[considered_candidate].source) &&
1151 DoesCandidateConflict(considered_candidate, candidates,
1152 source_set_pair.second)) {
1153 conflict = true;
1154 break;
1155 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001156 }
Tony Mak448b5862019-03-22 13:36:41 +00001157
1158 // Skip the candidate if a conflict was found.
1159 if (conflict) {
1160 continue;
1161 }
1162
1163 // If the set of indices for the current source doesn't exist yet,
1164 // initialize it.
1165 if (chosen_indices_for_source_ptr == nullptr) {
1166 SortedIntSet new_set([&candidates](int a, int b) {
1167 return candidates[a].span.first < candidates[b].span.first;
1168 });
1169 chosen_indices_for_source_map[candidates[considered_candidate].source] =
1170 std::move(new_set);
1171 chosen_indices_for_source_ptr =
1172 &chosen_indices_for_source_map[candidates[considered_candidate]
1173 .source];
1174 }
1175
1176 // Place the candidate to the output and to the per-source conflict set.
1177 chosen_indices->push_back(considered_candidate);
1178 chosen_indices_for_source_ptr->insert(considered_candidate);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001179 }
1180
Tony Mak378c1f52019-03-04 15:58:11 +00001181 std::sort(chosen_indices->begin(), chosen_indices->end());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001182
1183 return true;
1184}
1185
Tony Mak6c4cc672018-09-17 11:48:50 +01001186bool Annotator::ModelSuggestSelection(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001187 const UnicodeText& context_unicode, CodepointSpan click_indices,
Tony Makdf54e742019-03-26 14:04:00 +00001188 const std::vector<Locale>& detected_text_language_tags,
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001189 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001190 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001191 if (model_->triggering_options() == nullptr ||
1192 !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1193 return true;
1194 }
1195
Tony Makdf54e742019-03-26 14:04:00 +00001196 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1197 ml_model_triggering_locales_,
1198 /*default_value=*/true)) {
1199 return true;
1200 }
1201
Lukas Zilka21d8c982018-01-24 11:11:20 +01001202 int click_pos;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001203 *tokens = selection_feature_processor_->Tokenize(context_unicode);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001204 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001205 context_unicode, click_indices,
1206 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001207 tokens, &click_pos);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001208 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001209 TC3_VLOG(1) << "Could not calculate the click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001210 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001211 }
1212
1213 const int symmetry_context_size =
1214 model_->selection_options()->symmetry_context_size();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001215 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
Lukas Zilkab23e2122018-02-09 10:25:19 +01001216 bounds_sensitive_features = selection_feature_processor_->GetOptions()
1217 ->bounds_sensitive_features();
Lukas Zilka21d8c982018-01-24 11:11:20 +01001218
1219 // The symmetry context span is the clicked token with symmetry_context_size
1220 // tokens on either side.
1221 const TokenSpan symmetry_context_span = IntersectTokenSpans(
1222 ExpandTokenSpan(SingleTokenSpan(click_pos),
1223 /*num_tokens_left=*/symmetry_context_size,
1224 /*num_tokens_right=*/symmetry_context_size),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001225 {0, tokens->size()});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001226
Lukas Zilkab23e2122018-02-09 10:25:19 +01001227 // Compute the extraction span based on the model type.
1228 TokenSpan extraction_span;
1229 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1230 // The extraction span is the symmetry context span expanded to include
1231 // max_selection_span tokens on either side, which is how far a selection
1232 // can stretch from the click, plus a relevant number of tokens outside of
1233 // the bounds of the selection.
1234 const int max_selection_span =
1235 selection_feature_processor_->GetOptions()->max_selection_span();
1236 extraction_span =
1237 ExpandTokenSpan(symmetry_context_span,
1238 /*num_tokens_left=*/max_selection_span +
1239 bounds_sensitive_features->num_tokens_before(),
1240 /*num_tokens_right=*/max_selection_span +
1241 bounds_sensitive_features->num_tokens_after());
1242 } else {
1243 // The extraction span is the symmetry context span expanded to include
1244 // context_size tokens on either side.
1245 const int context_size =
1246 selection_feature_processor_->GetOptions()->context_size();
1247 extraction_span = ExpandTokenSpan(symmetry_context_span,
1248 /*num_tokens_left=*/context_size,
1249 /*num_tokens_right=*/context_size);
1250 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001251 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
Lukas Zilkab23e2122018-02-09 10:25:19 +01001252
Lukas Zilka434442d2018-04-25 11:38:51 +02001253 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1254 *tokens, extraction_span)) {
1255 return true;
1256 }
1257
Lukas Zilkab23e2122018-02-09 10:25:19 +01001258 std::unique_ptr<CachedFeatures> cached_features;
1259 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001260 *tokens, extraction_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001261 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1262 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01001263 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001264 selection_feature_processor_->EmbeddingSize() +
1265 selection_feature_processor_->DenseFeaturesCount(),
1266 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001267 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001268 return false;
1269 }
1270
1271 // Produce selection model candidates.
1272 std::vector<TokenSpan> chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001273 if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001274 interpreter_manager->SelectionInterpreter(), *cached_features,
1275 &chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001276 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001277 return false;
1278 }
1279
1280 for (const TokenSpan& chunk : chunks) {
1281 AnnotatedSpan candidate;
1282 candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001283 context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001284 if (model_->selection_options()->strip_unpaired_brackets()) {
1285 candidate.span =
1286 StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1287 }
1288
1289 // Only output non-empty spans.
1290 if (candidate.span.first != candidate.span.second) {
1291 result->push_back(candidate);
1292 }
1293 }
1294 return true;
1295}
1296
Tony Mak6c4cc672018-09-17 11:48:50 +01001297bool Annotator::ModelClassifyText(
Tony Mak378c1f52019-03-04 15:58:11 +00001298 const std::string& context,
1299 const std::vector<Locale>& detected_text_language_tags,
1300 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001301 FeatureProcessor::EmbeddingCache* embedding_cache,
1302 std::vector<ClassificationResult>* classification_results) const {
Tony Mak378c1f52019-03-04 15:58:11 +00001303 return ModelClassifyText(context, {}, detected_text_language_tags,
1304 selection_indices, interpreter_manager,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001305 embedding_cache, classification_results);
1306}
1307
1308namespace internal {
1309std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1310 CodepointSpan selection_indices,
1311 TokenSpan tokens_around_selection_to_copy) {
1312 const auto first_selection_token = std::upper_bound(
1313 cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1314 [](int selection_start, const Token& token) {
1315 return selection_start < token.end;
1316 });
1317 const auto last_selection_token = std::lower_bound(
1318 cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1319 [](const Token& token, int selection_end) {
1320 return token.start < selection_end;
1321 });
1322
1323 const int64 first_token = std::max(
1324 static_cast<int64>(0),
1325 static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1326 tokens_around_selection_to_copy.first));
1327 const int64 last_token = std::min(
1328 static_cast<int64>(cached_tokens.size()),
1329 static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1330 tokens_around_selection_to_copy.second));
1331
1332 std::vector<Token> tokens;
1333 tokens.reserve(last_token - first_token);
1334 for (int i = first_token; i < last_token; ++i) {
1335 tokens.push_back(cached_tokens[i]);
1336 }
1337 return tokens;
1338}
1339} // namespace internal
1340
Tony Mak6c4cc672018-09-17 11:48:50 +01001341TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001342 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1343 bounds_sensitive_features =
1344 classification_feature_processor_->GetOptions()
1345 ->bounds_sensitive_features();
1346 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1347 // The extraction span is the selection span expanded to include a relevant
1348 // number of tokens outside of the bounds of the selection.
1349 return {bounds_sensitive_features->num_tokens_before(),
1350 bounds_sensitive_features->num_tokens_after()};
1351 } else {
1352 // The extraction span is the clicked token with context_size tokens on
1353 // either side.
1354 const int context_size =
1355 selection_feature_processor_->GetOptions()->context_size();
1356 return {context_size, context_size};
1357 }
1358}
1359
Tony Mak378c1f52019-03-04 15:58:11 +00001360namespace {
1361// Sorts the classification results from high score to low score.
1362void SortClassificationResults(
1363 std::vector<ClassificationResult>* classification_results) {
1364 std::sort(classification_results->begin(), classification_results->end(),
1365 [](const ClassificationResult& a, const ClassificationResult& b) {
1366 return a.score > b.score;
1367 });
1368}
1369} // namespace
1370
Tony Mak6c4cc672018-09-17 11:48:50 +01001371bool Annotator::ModelClassifyText(
Lukas Zilkaba849e72018-03-08 14:48:21 +01001372 const std::string& context, const std::vector<Token>& cached_tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00001373 const std::vector<Locale>& detected_text_language_tags,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001374 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1375 FeatureProcessor::EmbeddingCache* embedding_cache,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001376 std::vector<ClassificationResult>* classification_results) const {
1377 std::vector<Token> tokens;
Tony Mak378c1f52019-03-04 15:58:11 +00001378 return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1379 selection_indices, interpreter_manager,
1380 embedding_cache, classification_results, &tokens);
1381}
1382
1383bool Annotator::ModelClassifyText(
1384 const std::string& context, const std::vector<Token>& cached_tokens,
1385 const std::vector<Locale>& detected_text_language_tags,
1386 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1387 FeatureProcessor::EmbeddingCache* embedding_cache,
1388 std::vector<ClassificationResult>* classification_results,
1389 std::vector<Token>* tokens) const {
1390 if (model_->triggering_options() == nullptr ||
1391 !(model_->triggering_options()->enabled_modes() &
1392 ModeFlag_CLASSIFICATION)) {
1393 return true;
1394 }
1395
Tony Makdf54e742019-03-26 14:04:00 +00001396 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1397 ml_model_triggering_locales_,
1398 /*default_value=*/true)) {
1399 return true;
1400 }
1401
Lukas Zilkaba849e72018-03-08 14:48:21 +01001402 if (cached_tokens.empty()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001403 *tokens = classification_feature_processor_->Tokenize(context);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001404 } else {
Tony Mak378c1f52019-03-04 15:58:11 +00001405 *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1406 ClassifyTextUpperBoundNeededTokens());
Lukas Zilkaba849e72018-03-08 14:48:21 +01001407 }
1408
Lukas Zilkab23e2122018-02-09 10:25:19 +01001409 int click_pos;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001410 classification_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001411 context, selection_indices,
1412 classification_feature_processor_->GetOptions()
1413 ->only_use_line_with_click(),
Tony Mak378c1f52019-03-04 15:58:11 +00001414 tokens, &click_pos);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001415 const TokenSpan selection_token_span =
Tony Mak378c1f52019-03-04 15:58:11 +00001416 CodepointSpanToTokenSpan(*tokens, selection_indices);
Lukas Zilka434442d2018-04-25 11:38:51 +02001417 const int selection_num_tokens = TokenSpanSize(selection_token_span);
1418 if (model_->classification_options()->max_num_tokens() > 0 &&
1419 model_->classification_options()->max_num_tokens() <
1420 selection_num_tokens) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001421 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001422 return true;
1423 }
1424
Lukas Zilkab23e2122018-02-09 10:25:19 +01001425 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1426 bounds_sensitive_features =
1427 classification_feature_processor_->GetOptions()
1428 ->bounds_sensitive_features();
1429 if (selection_token_span.first == kInvalidIndex ||
1430 selection_token_span.second == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001431 TC3_LOG(ERROR) << "Could not determine span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001432 return false;
1433 }
1434
1435 // Compute the extraction span based on the model type.
1436 TokenSpan extraction_span;
1437 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1438 // The extraction span is the selection span expanded to include a relevant
1439 // number of tokens outside of the bounds of the selection.
1440 extraction_span = ExpandTokenSpan(
1441 selection_token_span,
1442 /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1443 /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1444 } else {
1445 if (click_pos == kInvalidIndex) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001446 TC3_LOG(ERROR) << "Couldn't choose a click position.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001447 return false;
1448 }
1449 // The extraction span is the clicked token with context_size tokens on
1450 // either side.
1451 const int context_size =
Lukas Zilkaba849e72018-03-08 14:48:21 +01001452 classification_feature_processor_->GetOptions()->context_size();
Lukas Zilkab23e2122018-02-09 10:25:19 +01001453 extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
1454 /*num_tokens_left=*/context_size,
1455 /*num_tokens_right=*/context_size);
1456 }
Tony Mak378c1f52019-03-04 15:58:11 +00001457 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001458
Lukas Zilka434442d2018-04-25 11:38:51 +02001459 if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
Tony Mak378c1f52019-03-04 15:58:11 +00001460 *tokens, extraction_span)) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001461 *classification_results = {{Collections::Other(), 1.0}};
Lukas Zilka434442d2018-04-25 11:38:51 +02001462 return true;
1463 }
1464
Lukas Zilka21d8c982018-01-24 11:11:20 +01001465 std::unique_ptr<CachedFeatures> cached_features;
1466 if (!classification_feature_processor_->ExtractFeatures(
Tony Mak378c1f52019-03-04 15:58:11 +00001467 *tokens, extraction_span, selection_indices,
1468 embedding_executor_.get(), embedding_cache,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001469 classification_feature_processor_->EmbeddingSize() +
1470 classification_feature_processor_->DenseFeaturesCount(),
1471 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001472 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001473 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001474 }
1475
Lukas Zilkab23e2122018-02-09 10:25:19 +01001476 std::vector<float> features;
1477 features.reserve(cached_features->OutputFeaturesSize());
1478 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1479 cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1480 &features);
1481 } else {
1482 cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001483 }
1484
Lukas Zilkaba849e72018-03-08 14:48:21 +01001485 TensorView<float> logits = classification_executor_->ComputeLogits(
1486 TensorView<float>(features.data(),
1487 {1, static_cast<int>(features.size())}),
1488 interpreter_manager->ClassificationInterpreter());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001489 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001490 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001491 return false;
1492 }
1493
1494 if (logits.dims() != 2 || logits.dim(0) != 1 ||
1495 logits.dim(1) != classification_feature_processor_->NumCollections()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001496 TC3_LOG(ERROR) << "Mismatching output";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001497 return false;
1498 }
1499
1500 const std::vector<float> scores =
1501 ComputeSoftmax(logits.data(), logits.dim(1));
1502
Tony Mak81e52422019-04-30 09:34:45 +01001503 if (scores.empty()) {
1504 *classification_results = {{Collections::Other(), 1.0}};
1505 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001506 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001507
Tony Mak81e52422019-04-30 09:34:45 +01001508 const int best_score_index =
1509 std::max_element(scores.begin(), scores.end()) - scores.begin();
1510 const std::string top_collection =
1511 classification_feature_processor_->LabelToCollection(best_score_index);
1512
1513 // Sanity checks.
1514 if (top_collection == Collections::Phone()) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001515 const int digit_count = CountDigits(context, selection_indices);
1516 if (digit_count <
1517 model_->classification_options()->phone_min_num_digits() ||
1518 digit_count >
1519 model_->classification_options()->phone_max_num_digits()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001520 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001521 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001522 }
Tony Mak81e52422019-04-30 09:34:45 +01001523 } else if (top_collection == Collections::Address()) {
Lukas Zilka434442d2018-04-25 11:38:51 +02001524 if (selection_num_tokens <
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001525 model_->classification_options()->address_min_num_tokens()) {
Tony Mak5dc5e112019-02-01 14:52:10 +00001526 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001527 return true;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001528 }
Tony Mak81e52422019-04-30 09:34:45 +01001529 } else if (top_collection == Collections::Dictionary()) {
Tony Mak378c1f52019-03-04 15:58:11 +00001530 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1531 dictionary_locales_,
Tony Makdf54e742019-03-26 14:04:00 +00001532 /*default_value=*/false)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001533 *classification_results = {{Collections::Other(), 1.0}};
Tony Mak81e52422019-04-30 09:34:45 +01001534 return true;
Tony Mak378c1f52019-03-04 15:58:11 +00001535 }
1536 }
Tony Mak81e52422019-04-30 09:34:45 +01001537
Tony Makd99d58c2020-03-19 21:52:02 +00001538 *classification_results = {{top_collection, /*arg_score=*/1.0,
1539 /*arg_priority_score=*/scores[best_score_index]}};
1540
1541 // For some entities, we might want to clamp the priority score, for better
1542 // conflict resolution between entities.
1543 if (model_->triggering_options() != nullptr &&
1544 model_->triggering_options()->collection_to_priority() != nullptr) {
1545 if (auto entry =
1546 model_->triggering_options()->collection_to_priority()->LookupByKey(
1547 top_collection.c_str())) {
1548 (*classification_results)[0].priority_score *= entry->value();
1549 }
1550 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001551 return true;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001552}
1553
Tony Mak6c4cc672018-09-17 11:48:50 +01001554bool Annotator::RegexClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001555 const std::string& context, CodepointSpan selection_indices,
Tony Mak378c1f52019-03-04 15:58:11 +00001556 std::vector<ClassificationResult>* classification_result) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001557 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001558 UTF8ToUnicodeText(context, /*do_copy=*/false)
1559 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001560 const UnicodeText selection_text_unicode(
1561 UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1562
1563 // Check whether any of the regular expressions match.
1564 for (const int pattern_id : classification_regex_patterns_) {
1565 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1566 const std::unique_ptr<UniLib::RegexMatcher> matcher =
1567 regex_pattern.pattern->Matcher(selection_text_unicode);
1568 int status = UniLib::RegexMatcher::kNoError;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001569 bool matches;
Tony Mak854015a2019-01-16 15:56:48 +00001570 if (regex_pattern.config->use_approximate_matching()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001571 matches = matcher->ApproximatelyMatches(&status);
1572 } else {
1573 matches = matcher->Matches(&status);
1574 }
1575 if (status != UniLib::RegexMatcher::kNoError) {
1576 return false;
1577 }
Tony Makdf54e742019-03-26 14:04:00 +00001578 if (matches && VerifyRegexMatchCandidate(
1579 context, regex_pattern.config->verification_options(),
1580 selection_text, matcher.get())) {
Tony Mak378c1f52019-03-04 15:58:11 +00001581 classification_result->push_back(
1582 {regex_pattern.config->collection_name()->str(),
1583 regex_pattern.config->target_classification_score(),
1584 regex_pattern.config->priority_score()});
Tony Makd9446602019-02-20 18:25:39 +00001585 if (!SerializedEntityDataFromRegexMatch(
1586 regex_pattern.config, matcher.get(),
Tony Mak378c1f52019-03-04 15:58:11 +00001587 &classification_result->back().serialized_entity_data)) {
Tony Makd9446602019-02-20 18:25:39 +00001588 TC3_LOG(ERROR) << "Could not get entity data.";
1589 return false;
1590 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001591 }
1592 }
1593
Tony Mak378c1f52019-03-04 15:58:11 +00001594 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001595}
1596
Tony Mak5dc5e112019-02-01 14:52:10 +00001597namespace {
1598std::string PickCollectionForDatetime(
1599 const DatetimeParseResult& datetime_parse_result) {
1600 switch (datetime_parse_result.granularity) {
1601 case GRANULARITY_HOUR:
1602 case GRANULARITY_MINUTE:
1603 case GRANULARITY_SECOND:
1604 return Collections::DateTime();
1605 default:
1606 return Collections::Date();
1607 }
1608}
Tony Mak83d2de62019-04-10 16:12:15 +01001609
1610std::string CreateDatetimeSerializedEntityData(
1611 const DatetimeParseResult& parse_result) {
1612 EntityDataT entity_data;
1613 entity_data.datetime.reset(new EntityData_::DatetimeT());
1614 entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
1615 entity_data.datetime->granularity =
1616 static_cast<EntityData_::Datetime_::Granularity>(
1617 parse_result.granularity);
1618
Tony Maka2a1ff42019-09-12 15:40:32 +01001619 for (const auto& c : parse_result.datetime_components) {
1620 EntityData_::Datetime_::DatetimeComponentT datetime_component;
1621 datetime_component.absolute_value = c.value;
1622 datetime_component.relative_count = c.relative_count;
1623 datetime_component.component_type =
1624 static_cast<EntityData_::Datetime_::DatetimeComponent_::ComponentType>(
1625 c.component_type);
1626 datetime_component.relation_type =
1627 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE;
1628 if (c.relative_qualifier !=
1629 DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
1630 datetime_component.relation_type =
1631 EntityData_::Datetime_::DatetimeComponent_::RelationType_RELATIVE;
1632 }
1633 entity_data.datetime->datetime_component.emplace_back(
1634 new EntityData_::Datetime_::DatetimeComponentT(datetime_component));
1635 }
Tony Mak83d2de62019-04-10 16:12:15 +01001636 flatbuffers::FlatBufferBuilder builder;
1637 FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
1638 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
1639 builder.GetSize());
1640}
Tony Mak63959242020-02-07 18:31:16 +00001641
Tony Mak5dc5e112019-02-01 14:52:10 +00001642} // namespace
1643
Tony Mak6c4cc672018-09-17 11:48:50 +01001644bool Annotator::DatetimeClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001645 const std::string& context, CodepointSpan selection_indices,
1646 const ClassificationOptions& options,
Tony Mak854015a2019-01-16 15:56:48 +00001647 std::vector<ClassificationResult>* classification_results) const {
Tony Mak63959242020-02-07 18:31:16 +00001648 if (!datetime_parser_ && !cfg_datetime_parser_) {
Tony Makd99d58c2020-03-19 21:52:02 +00001649 return true;
Lukas Zilkaba849e72018-03-08 14:48:21 +01001650 }
1651
Lukas Zilkab23e2122018-02-09 10:25:19 +01001652 const std::string selection_text =
Tony Makd9446602019-02-20 18:25:39 +00001653 UTF8ToUnicodeText(context, /*do_copy=*/false)
1654 .UTF8Substring(selection_indices.first, selection_indices.second);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001655
1656 std::vector<DatetimeParseResultSpan> datetime_spans;
Tony Makd99d58c2020-03-19 21:52:02 +00001657
Tony Mak63959242020-02-07 18:31:16 +00001658 if (cfg_datetime_parser_) {
1659 if (!(model_->grammar_datetime_model()->enabled_modes() &
1660 ModeFlag_CLASSIFICATION)) {
1661 return true;
1662 }
1663 std::vector<Locale> parsed_locales;
1664 ParseLocales(options.locales, &parsed_locales);
Tony Mak21460022020-03-12 18:29:35 +00001665 cfg_datetime_parser_->Parse(
1666 selection_text,
1667 ToDateAnnotationOptions(
1668 model_->grammar_datetime_model()->annotation_options(),
1669 options.reference_timezone, options.reference_time_ms_utc),
1670 parsed_locales, &datetime_spans);
Tony Makd99d58c2020-03-19 21:52:02 +00001671 }
1672
1673 if (datetime_parser_) {
Tony Mak63959242020-02-07 18:31:16 +00001674 if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1675 options.reference_timezone, options.locales,
1676 ModeFlag_CLASSIFICATION,
1677 options.annotation_usecase,
1678 /*anchor_start_end=*/true, &datetime_spans)) {
1679 TC3_LOG(ERROR) << "Error during parsing datetime.";
1680 return false;
1681 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001682 }
Tony Makd99d58c2020-03-19 21:52:02 +00001683
Lukas Zilkab23e2122018-02-09 10:25:19 +01001684 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
1685 // Only consider the result valid if the selection and extracted datetime
1686 // spans exactly match.
1687 if (std::make_pair(datetime_span.span.first + selection_indices.first,
1688 datetime_span.span.second + selection_indices.first) ==
1689 selection_indices) {
Tony Mak854015a2019-01-16 15:56:48 +00001690 for (const DatetimeParseResult& parse_result : datetime_span.data) {
1691 classification_results->emplace_back(
Tony Mak5dc5e112019-02-01 14:52:10 +00001692 PickCollectionForDatetime(parse_result),
1693 datetime_span.target_classification_score);
Tony Mak854015a2019-01-16 15:56:48 +00001694 classification_results->back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01001695 classification_results->back().serialized_entity_data =
1696 CreateDatetimeSerializedEntityData(parse_result);
Tony Mak378c1f52019-03-04 15:58:11 +00001697 classification_results->back().priority_score =
1698 datetime_span.priority_score;
Tony Mak854015a2019-01-16 15:56:48 +00001699 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001700 return true;
1701 }
1702 }
Tony Mak378c1f52019-03-04 15:58:11 +00001703 return true;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001704}
1705
Tony Mak6c4cc672018-09-17 11:48:50 +01001706std::vector<ClassificationResult> Annotator::ClassifyText(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001707 const std::string& context, CodepointSpan selection_indices,
1708 const ClassificationOptions& options) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01001709 if (!initialized_) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001710 TC3_LOG(ERROR) << "Not initialized";
Lukas Zilka21d8c982018-01-24 11:11:20 +01001711 return {};
1712 }
1713
Lukas Zilkaba849e72018-03-08 14:48:21 +01001714 if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1715 return {};
1716 }
1717
Tony Makdf54e742019-03-26 14:04:00 +00001718 std::vector<Locale> detected_text_language_tags;
1719 if (!ParseLocales(options.detected_text_language_tags,
1720 &detected_text_language_tags)) {
1721 TC3_LOG(WARNING)
1722 << "Failed to parse the detected_text_language_tags in options: "
1723 << options.detected_text_language_tags;
1724 }
1725 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1726 model_triggering_locales_,
1727 /*default_value=*/true)) {
1728 return {};
1729 }
1730
Tony Mak968412a2019-11-13 15:39:57 +00001731 if (!IsValidSpanInput(UTF8ToUnicodeText(context, /*do_copy=*/false),
1732 selection_indices)) {
1733 TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
Tony Mak6c4cc672018-09-17 11:48:50 +01001734 << std::get<0>(selection_indices) << " "
1735 << std::get<1>(selection_indices);
Lukas Zilka21d8c982018-01-24 11:11:20 +01001736 return {};
1737 }
1738
Tony Mak378c1f52019-03-04 15:58:11 +00001739 // We'll accumulate a list of candidates, and pick the best candidate in the
1740 // end.
1741 std::vector<AnnotatedSpan> candidates;
1742
Tony Mak6c4cc672018-09-17 11:48:50 +01001743 // Try the knowledge engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001744 // TODO(b/126579108): Propagate error status.
Tony Mak6c4cc672018-09-17 11:48:50 +01001745 ClassificationResult knowledge_result;
Tony Mak63959242020-02-07 18:31:16 +00001746 if (knowledge_engine_ &&
1747 knowledge_engine_->ClassifyText(
1748 context, selection_indices, options.annotation_usecase,
1749 options.location_context, &knowledge_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001750 candidates.push_back({selection_indices, {knowledge_result}});
1751 candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
Tony Mak854015a2019-01-16 15:56:48 +00001752 }
1753
Tony Maka2a1ff42019-09-12 15:40:32 +01001754 AddContactMetadataToKnowledgeClassificationResults(&candidates);
1755
Tony Mak854015a2019-01-16 15:56:48 +00001756 // Try the contact engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001757 // TODO(b/126579108): Propagate error status.
Tony Mak854015a2019-01-16 15:56:48 +00001758 ClassificationResult contact_result;
1759 if (contact_engine_ && contact_engine_->ClassifyText(
1760 context, selection_indices, &contact_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001761 candidates.push_back({selection_indices, {contact_result}});
Tony Mak6c4cc672018-09-17 11:48:50 +01001762 }
1763
Tony Mak76d80962020-01-08 17:30:51 +00001764 // Try the person name engine.
1765 ClassificationResult person_name_result;
1766 if (person_name_engine_ &&
1767 person_name_engine_->ClassifyText(context, selection_indices,
1768 &person_name_result)) {
1769 candidates.push_back({selection_indices, {person_name_result}});
Tony Makd0ae7c62020-03-27 13:58:00 +00001770 candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
Tony Mak76d80962020-01-08 17:30:51 +00001771 }
1772
Tony Makd9446602019-02-20 18:25:39 +00001773 // Try the installed app engine.
Tony Mak378c1f52019-03-04 15:58:11 +00001774 // TODO(b/126579108): Propagate error status.
Tony Makd9446602019-02-20 18:25:39 +00001775 ClassificationResult installed_app_result;
1776 if (installed_app_engine_ &&
1777 installed_app_engine_->ClassifyText(context, selection_indices,
1778 &installed_app_result)) {
Tony Mak378c1f52019-03-04 15:58:11 +00001779 candidates.push_back({selection_indices, {installed_app_result}});
Tony Makd9446602019-02-20 18:25:39 +00001780 }
1781
Lukas Zilkab23e2122018-02-09 10:25:19 +01001782 // Try the regular expression models.
Tony Mak378c1f52019-03-04 15:58:11 +00001783 std::vector<ClassificationResult> regex_results;
1784 if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1785 return {};
1786 }
1787 for (const ClassificationResult& result : regex_results) {
1788 candidates.push_back({selection_indices, {result}});
Lukas Zilka21d8c982018-01-24 11:11:20 +01001789 }
1790
Lukas Zilkab23e2122018-02-09 10:25:19 +01001791 // Try the date model.
Tony Mak378c1f52019-03-04 15:58:11 +00001792 //
1793 // DatetimeClassifyText only returns the first result, which can however have
1794 // more interpretations. They are inserted in the candidates as a single
1795 // AnnotatedSpan, so that they get treated together by the conflict resolution
1796 // algorithm.
Tony Mak854015a2019-01-16 15:56:48 +00001797 std::vector<ClassificationResult> datetime_results;
Tony Mak378c1f52019-03-04 15:58:11 +00001798 if (!DatetimeClassifyText(context, selection_indices, options,
1799 &datetime_results)) {
1800 return {};
1801 }
1802 if (!datetime_results.empty()) {
1803 candidates.push_back({selection_indices, std::move(datetime_results)});
Tony Mak448b5862019-03-22 13:36:41 +00001804 candidates.back().source = AnnotatedSpan::Source::DATETIME;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001805 }
1806
Tony Mak378c1f52019-03-04 15:58:11 +00001807 // Try the number annotator.
1808 // TODO(b/126579108): Propagate error status.
1809 ClassificationResult number_annotator_result;
1810 if (number_annotator_ &&
1811 number_annotator_->ClassifyText(
1812 UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
1813 options.annotation_usecase, &number_annotator_result)) {
1814 candidates.push_back({selection_indices, {number_annotator_result}});
1815 }
Lukas Zilkaba849e72018-03-08 14:48:21 +01001816
Tony Makad2e22d2019-03-20 17:35:13 +00001817 // Try the duration annotator.
1818 ClassificationResult duration_annotator_result;
1819 if (duration_annotator_ &&
1820 duration_annotator_->ClassifyText(
1821 UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
1822 options.annotation_usecase, &duration_annotator_result)) {
1823 candidates.push_back({selection_indices, {duration_annotator_result}});
Tony Mak448b5862019-03-22 13:36:41 +00001824 candidates.back().source = AnnotatedSpan::Source::DURATION;
Tony Makad2e22d2019-03-20 17:35:13 +00001825 }
1826
Tony Mak63959242020-02-07 18:31:16 +00001827 // Try the translate annotator.
1828 ClassificationResult translate_annotator_result;
1829 if (translate_annotator_ &&
1830 translate_annotator_->ClassifyText(
1831 UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
1832 options.user_familiar_language_tags, &translate_annotator_result)) {
1833 candidates.push_back({selection_indices, {translate_annotator_result}});
1834 }
1835
Tony Mak21460022020-03-12 18:29:35 +00001836 // Try the grammar model.
1837 ClassificationResult grammar_annotator_result;
1838 if (grammar_annotator_ && grammar_annotator_->ClassifyText(
1839 detected_text_language_tags,
1840 UTF8ToUnicodeText(context, /*do_copy=*/false),
1841 selection_indices, &grammar_annotator_result)) {
1842 candidates.push_back({selection_indices, {grammar_annotator_result}});
1843 }
1844
Tony Mak378c1f52019-03-04 15:58:11 +00001845 // Try the ML model.
1846 //
1847 // The output of the model is considered as an exclusive 1-of-N choice. That's
1848 // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1849 // span for each candidate, like e.g. the regex model.
Lukas Zilkaba849e72018-03-08 14:48:21 +01001850 InterpreterManager interpreter_manager(selection_executor_.get(),
1851 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00001852 std::vector<ClassificationResult> model_results;
1853 std::vector<Token> tokens;
1854 if (!ModelClassifyText(
1855 context, /*cached_tokens=*/{}, detected_text_language_tags,
1856 selection_indices, &interpreter_manager,
1857 /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1858 return {};
1859 }
1860 if (!model_results.empty()) {
1861 candidates.push_back({selection_indices, std::move(model_results)});
1862 }
1863
1864 std::vector<int> candidate_indices;
1865 if (!ResolveConflicts(candidates, context, tokens,
1866 detected_text_language_tags, options.annotation_usecase,
1867 &interpreter_manager, &candidate_indices)) {
1868 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1869 return {};
1870 }
1871
1872 std::vector<ClassificationResult> results;
1873 for (const int i : candidate_indices) {
1874 for (const ClassificationResult& result : candidates[i].classification) {
1875 if (!FilteredForClassification(result)) {
1876 results.push_back(result);
1877 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001878 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001879 }
1880
Tony Mak378c1f52019-03-04 15:58:11 +00001881 // Sort results according to score.
1882 std::sort(results.begin(), results.end(),
1883 [](const ClassificationResult& a, const ClassificationResult& b) {
1884 return a.score > b.score;
1885 });
1886
1887 if (results.empty()) {
Tony Mak81e52422019-04-30 09:34:45 +01001888 results = {{Collections::Other(), 1.0}};
Tony Mak378c1f52019-03-04 15:58:11 +00001889 }
Tony Mak378c1f52019-03-04 15:58:11 +00001890 return results;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001891}
1892
Tony Mak378c1f52019-03-04 15:58:11 +00001893bool Annotator::ModelAnnotate(
1894 const std::string& context,
1895 const std::vector<Locale>& detected_text_language_tags,
1896 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1897 std::vector<AnnotatedSpan>* result) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001898 if (model_->triggering_options() == nullptr ||
1899 !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1900 return true;
1901 }
1902
Tony Makdf54e742019-03-26 14:04:00 +00001903 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1904 ml_model_triggering_locales_,
1905 /*default_value=*/true)) {
1906 return true;
1907 }
1908
Lukas Zilka21d8c982018-01-24 11:11:20 +01001909 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1910 /*do_copy=*/false);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001911 std::vector<UnicodeTextRange> lines;
1912 if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1913 lines.push_back({context_unicode.begin(), context_unicode.end()});
1914 } else {
Tony Maka2a1ff42019-09-12 15:40:32 +01001915 lines = selection_feature_processor_->SplitContext(
1916 context_unicode, selection_feature_processor_->GetOptions()
1917 ->use_pipe_character_for_newline());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001918 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001919
Lukas Zilkaba849e72018-03-08 14:48:21 +01001920 const float min_annotate_confidence =
1921 (model_->triggering_options() != nullptr
1922 ? model_->triggering_options()->min_annotate_confidence()
1923 : 0.f);
1924
Lukas Zilkab23e2122018-02-09 10:25:19 +01001925 for (const UnicodeTextRange& line : lines) {
Tony Mak408c6b82019-03-08 17:57:27 +00001926 FeatureProcessor::EmbeddingCache embedding_cache;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001927 const std::string line_str =
1928 UnicodeText::UTF8Substring(line.first, line.second);
1929
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001930 *tokens = selection_feature_processor_->Tokenize(line_str);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001931 selection_feature_processor_->RetokenizeAndFindClick(
Lukas Zilkab23e2122018-02-09 10:25:19 +01001932 line_str, {0, std::distance(line.first, line.second)},
1933 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001934 tokens,
Lukas Zilka21d8c982018-01-24 11:11:20 +01001935 /*click_pos=*/nullptr);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001936 const TokenSpan full_line_span = {0, tokens->size()};
Lukas Zilka21d8c982018-01-24 11:11:20 +01001937
Lukas Zilka434442d2018-04-25 11:38:51 +02001938 // TODO(zilka): Add support for greater granularity of this check.
1939 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1940 *tokens, full_line_span)) {
1941 continue;
1942 }
1943
Lukas Zilka21d8c982018-01-24 11:11:20 +01001944 std::unique_ptr<CachedFeatures> cached_features;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001945 if (!selection_feature_processor_->ExtractFeatures(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001946 *tokens, full_line_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001947 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1948 embedding_executor_.get(),
Lukas Zilkaba849e72018-03-08 14:48:21 +01001949 /*embedding_cache=*/nullptr,
Lukas Zilkab23e2122018-02-09 10:25:19 +01001950 selection_feature_processor_->EmbeddingSize() +
1951 selection_feature_processor_->DenseFeaturesCount(),
Lukas Zilka21d8c982018-01-24 11:11:20 +01001952 &cached_features)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001953 TC3_LOG(ERROR) << "Could not extract features.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001954 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001955 }
1956
1957 std::vector<TokenSpan> local_chunks;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001958 if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001959 interpreter_manager->SelectionInterpreter(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01001960 *cached_features, &local_chunks)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001961 TC3_LOG(ERROR) << "Could not chunk.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001962 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01001963 }
1964
1965 const int offset = std::distance(context_unicode.begin(), line.first);
1966 for (const TokenSpan& chunk : local_chunks) {
1967 const CodepointSpan codepoint_span =
1968 selection_feature_processor_->StripBoundaryCodepoints(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001969 line_str, TokenSpanToCodepointSpan(*tokens, chunk));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001970
1971 // Skip empty spans.
1972 if (codepoint_span.first != codepoint_span.second) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01001973 std::vector<ClassificationResult> classification;
Tony Mak378c1f52019-03-04 15:58:11 +00001974 if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
1975 codepoint_span, interpreter_manager,
1976 &embedding_cache, &classification)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001977 TC3_LOG(ERROR) << "Could not classify text: "
1978 << (codepoint_span.first + offset) << " "
1979 << (codepoint_span.second + offset);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001980 return false;
1981 }
1982
1983 // Do not include the span if it's classified as "other".
1984 if (!classification.empty() && !ClassifiedAsOther(classification) &&
1985 classification[0].score >= min_annotate_confidence) {
1986 AnnotatedSpan result_span;
1987 result_span.span = {codepoint_span.first + offset,
1988 codepoint_span.second + offset};
1989 result_span.classification = std::move(classification);
1990 result->push_back(std::move(result_span));
1991 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001992 }
Lukas Zilka21d8c982018-01-24 11:11:20 +01001993 }
1994 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01001995 return true;
1996}
1997
Tony Mak6c4cc672018-09-17 11:48:50 +01001998const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
Lukas Zilka434442d2018-04-25 11:38:51 +02001999 return selection_feature_processor_.get();
2000}
2001
Tony Mak6c4cc672018-09-17 11:48:50 +01002002const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
Lukas Zilka434442d2018-04-25 11:38:51 +02002003 const {
2004 return classification_feature_processor_.get();
Lukas Zilkab23e2122018-02-09 10:25:19 +01002005}
2006
Tony Mak6c4cc672018-09-17 11:48:50 +01002007const DatetimeParser* Annotator::DatetimeParserForTests() const {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002008 return datetime_parser_.get();
2009}
2010
Tony Mak83d2de62019-04-10 16:12:15 +01002011void Annotator::RemoveNotEnabledEntityTypes(
2012 const EnabledEntityTypes& is_entity_type_enabled,
2013 std::vector<AnnotatedSpan>* annotated_spans) const {
2014 for (AnnotatedSpan& annotated_span : *annotated_spans) {
2015 std::vector<ClassificationResult>& classifications =
2016 annotated_span.classification;
2017 classifications.erase(
2018 std::remove_if(classifications.begin(), classifications.end(),
2019 [&is_entity_type_enabled](
2020 const ClassificationResult& classification_result) {
2021 return !is_entity_type_enabled(
2022 classification_result.collection);
2023 }),
2024 classifications.end());
2025 }
2026 annotated_spans->erase(
2027 std::remove_if(annotated_spans->begin(), annotated_spans->end(),
2028 [](const AnnotatedSpan& annotated_span) {
2029 return annotated_span.classification.empty();
2030 }),
2031 annotated_spans->end());
2032}
2033
Tony Maka2a1ff42019-09-12 15:40:32 +01002034void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2035 std::vector<AnnotatedSpan>* candidates) const {
2036 if (candidates == nullptr || contact_engine_ == nullptr) {
2037 return;
2038 }
2039 for (auto& candidate : *candidates) {
2040 for (auto& classification_result : candidate.classification) {
2041 contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2042 &classification_result);
2043 }
2044 }
2045}
2046
Tony Makff31efb2020-03-31 11:13:06 +01002047Status Annotator::AnnotateSingleInput(
2048 const std::string& context, const AnnotationOptions& options,
2049 std::vector<AnnotatedSpan>* candidates) const {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002050 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
Tony Makff31efb2020-03-31 11:13:06 +01002051 return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
Lukas Zilkaba849e72018-03-08 14:48:21 +01002052 }
2053
Tony Mak854015a2019-01-16 15:56:48 +00002054 const UnicodeText context_unicode =
2055 UTF8ToUnicodeText(context, /*do_copy=*/false);
2056 if (!context_unicode.is_valid()) {
Tony Makff31efb2020-03-31 11:13:06 +01002057 return Status(StatusCode::INVALID_ARGUMENT,
2058 "Context string isn't valid UTF8.");
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002059 }
2060
Tony Mak378c1f52019-03-04 15:58:11 +00002061 std::vector<Locale> detected_text_language_tags;
2062 if (!ParseLocales(options.detected_text_language_tags,
2063 &detected_text_language_tags)) {
2064 TC3_LOG(WARNING)
Tony Makdf54e742019-03-26 14:04:00 +00002065 << "Failed to parse the detected_text_language_tags in options: "
Tony Mak378c1f52019-03-04 15:58:11 +00002066 << options.detected_text_language_tags;
2067 }
Tony Makdf54e742019-03-26 14:04:00 +00002068 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2069 model_triggering_locales_,
2070 /*default_value=*/true)) {
Tony Makff31efb2020-03-31 11:13:06 +01002071 return Status(
2072 StatusCode::UNAVAILABLE,
2073 "The detected language tags are not in the supported locales.");
Tony Makdf54e742019-03-26 14:04:00 +00002074 }
2075
2076 InterpreterManager interpreter_manager(selection_executor_.get(),
2077 classification_executor_.get());
Tony Mak378c1f52019-03-04 15:58:11 +00002078
Lukas Zilkab23e2122018-02-09 10:25:19 +01002079 // Annotate with the selection model.
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002080 std::vector<Token> tokens;
Tony Mak378c1f52019-03-04 15:58:11 +00002081 if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
Tony Makff31efb2020-03-31 11:13:06 +01002082 &tokens, candidates)) {
2083 return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002084 }
2085
2086 // Annotate with the regular expression models.
2087 if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
Tony Makff31efb2020-03-31 11:13:06 +01002088 annotation_regex_patterns_, candidates,
Tony Mak83d2de62019-04-10 16:12:15 +01002089 options.is_serialized_entity_data_enabled)) {
Tony Makff31efb2020-03-31 11:13:06 +01002090 return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002091 }
2092
2093 // Annotate with the datetime model.
Tony Mak83d2de62019-04-10 16:12:15 +01002094 const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2095 if ((is_entity_type_enabled(Collections::Date()) ||
2096 is_entity_type_enabled(Collections::DateTime())) &&
2097 !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002098 options.reference_time_ms_utc, options.reference_timezone,
Tony Makd9446602019-02-20 18:25:39 +00002099 options.locales, ModeFlag_ANNOTATION,
Tony Mak83d2de62019-04-10 16:12:15 +01002100 options.annotation_usecase,
Tony Makff31efb2020-03-31 11:13:06 +01002101 options.is_serialized_entity_data_enabled, candidates)) {
2102 return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
Tony Mak6c4cc672018-09-17 11:48:50 +01002103 }
2104
Tony Mak854015a2019-01-16 15:56:48 +00002105 // Annotate with the contact engine.
2106 if (contact_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002107 !contact_engine_->Chunk(context_unicode, tokens, candidates)) {
2108 return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
Tony Mak854015a2019-01-16 15:56:48 +00002109 }
2110
Tony Makd9446602019-02-20 18:25:39 +00002111 // Annotate with the installed app engine.
2112 if (installed_app_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002113 !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
2114 return Status(StatusCode::INTERNAL,
2115 "Couldn't run installed app engine Chunk.");
Tony Makd9446602019-02-20 18:25:39 +00002116 }
2117
Tony Mak378c1f52019-03-04 15:58:11 +00002118 // Annotate with the number annotator.
2119 if (number_annotator_ != nullptr &&
2120 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
Tony Makff31efb2020-03-31 11:13:06 +01002121 candidates)) {
2122 return Status(StatusCode::INTERNAL,
2123 "Couldn't run number annotator FindAll.");
Tony Makad2e22d2019-03-20 17:35:13 +00002124 }
2125
2126 // Annotate with the duration annotator.
Tony Mak83d2de62019-04-10 16:12:15 +01002127 if (is_entity_type_enabled(Collections::Duration()) &&
2128 duration_annotator_ != nullptr &&
Tony Mak448b5862019-03-22 13:36:41 +00002129 !duration_annotator_->FindAll(context_unicode, tokens,
Tony Makff31efb2020-03-31 11:13:06 +01002130 options.annotation_usecase, candidates)) {
2131 return Status(StatusCode::INTERNAL,
2132 "Couldn't run duration annotator FindAll.");
Tony Mak378c1f52019-03-04 15:58:11 +00002133 }
2134
Tony Mak76d80962020-01-08 17:30:51 +00002135 // Annotate with the person name engine.
2136 if (is_entity_type_enabled(Collections::PersonName()) &&
2137 person_name_engine_ &&
Tony Makff31efb2020-03-31 11:13:06 +01002138 !person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
2139 return Status(StatusCode::INTERNAL,
2140 "Couldn't run person name engine Chunk.");
Tony Mak76d80962020-01-08 17:30:51 +00002141 }
2142
Tony Mak21460022020-03-12 18:29:35 +00002143 // Annotate with the grammar annotators.
2144 if (grammar_annotator_ != nullptr &&
2145 !grammar_annotator_->Annotate(detected_text_language_tags,
Tony Makff31efb2020-03-31 11:13:06 +01002146 context_unicode, candidates)) {
2147 return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
Tony Mak21460022020-03-12 18:29:35 +00002148 }
2149
Lukas Zilkab23e2122018-02-09 10:25:19 +01002150 // Sort candidates according to their position in the input, so that the next
2151 // code can assume that any connected component of overlapping spans forms a
2152 // contiguous block.
Tony Makff31efb2020-03-31 11:13:06 +01002153 std::sort(candidates->begin(), candidates->end(),
Lukas Zilkab23e2122018-02-09 10:25:19 +01002154 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
2155 return a.span.first < b.span.first;
2156 });
2157
2158 std::vector<int> candidate_indices;
Tony Makff31efb2020-03-31 11:13:06 +01002159 if (!ResolveConflicts(*candidates, context, tokens,
Tony Mak378c1f52019-03-04 15:58:11 +00002160 detected_text_language_tags, options.annotation_usecase,
2161 &interpreter_manager, &candidate_indices)) {
Tony Makff31efb2020-03-31 11:13:06 +01002162 return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
Lukas Zilkab23e2122018-02-09 10:25:19 +01002163 }
2164
Lukas Zilkab23e2122018-02-09 10:25:19 +01002165 std::vector<AnnotatedSpan> result;
2166 result.reserve(candidate_indices.size());
Tony Mak378c1f52019-03-04 15:58:11 +00002167 AnnotatedSpan aggregated_span;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002168 for (const int i : candidate_indices) {
Tony Makff31efb2020-03-31 11:13:06 +01002169 if ((*candidates)[i].span != aggregated_span.span) {
Tony Mak378c1f52019-03-04 15:58:11 +00002170 if (!aggregated_span.classification.empty()) {
2171 result.push_back(std::move(aggregated_span));
2172 }
2173 aggregated_span =
Tony Makff31efb2020-03-31 11:13:06 +01002174 AnnotatedSpan((*candidates)[i].span, /*arg_classification=*/{});
Lukas Zilkab23e2122018-02-09 10:25:19 +01002175 }
Tony Makff31efb2020-03-31 11:13:06 +01002176 if ((*candidates)[i].classification.empty() ||
2177 ClassifiedAsOther((*candidates)[i].classification) ||
2178 FilteredForAnnotation((*candidates)[i])) {
Tony Mak378c1f52019-03-04 15:58:11 +00002179 continue;
2180 }
Tony Makff31efb2020-03-31 11:13:06 +01002181 for (ClassificationResult& classification :
2182 (*candidates)[i].classification) {
Tony Mak378c1f52019-03-04 15:58:11 +00002183 aggregated_span.classification.push_back(std::move(classification));
2184 }
2185 }
2186 if (!aggregated_span.classification.empty()) {
2187 result.push_back(std::move(aggregated_span));
2188 }
2189
Tony Mak83d2de62019-04-10 16:12:15 +01002190 // We generate all candidates and remove them later (with the exception of
2191 // date/time/duration entities) because there are complex interdependencies
2192 // between the entity types. E.g., the TLD of an email can be interpreted as a
2193 // URL, but most likely a user of the API does not want such annotations if
2194 // "url" is enabled and "email" is not.
2195 RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2196
Tony Mak378c1f52019-03-04 15:58:11 +00002197 for (AnnotatedSpan& annotated_span : result) {
2198 SortClassificationResults(&annotated_span.classification);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002199 }
Tony Makff31efb2020-03-31 11:13:06 +01002200 *candidates = result;
2201 return Status::OK;
2202}
Lukas Zilkab23e2122018-02-09 10:25:19 +01002203
Tony Makff31efb2020-03-31 11:13:06 +01002204StatusOr<std::vector<std::vector<AnnotatedSpan>>>
2205Annotator::AnnotateStructuredInput(
2206 const std::vector<InputFragment>& string_fragments,
2207 const AnnotationOptions& options) const {
2208 std::vector<std::vector<AnnotatedSpan>> annotation_candidates(
2209 string_fragments.size());
2210
2211 std::vector<std::string> text_to_annotate;
2212 text_to_annotate.reserve(string_fragments.size());
2213 for (const auto& string_fragment : string_fragments) {
2214 text_to_annotate.push_back(string_fragment.text);
2215 }
2216
2217 // KnowledgeEngine is special, because it supports annotation of multiple
2218 // fragments at once.
2219 if (knowledge_engine_ &&
2220 !knowledge_engine_
2221 ->ChunkMultipleSpans(text_to_annotate, options.annotation_usecase,
2222 options.location_context,
2223 &annotation_candidates)
2224 .ok()) {
2225 return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
2226 }
2227 // The annotator engines shouldn't change the number of annotation vectors.
2228 if (annotation_candidates.size() != text_to_annotate.size()) {
2229 TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
2230 << " texts to annotate but generated a different number of "
2231 "lists of annotations:"
2232 << annotation_candidates.size();
2233 return Status(StatusCode::INTERNAL,
2234 "Number of annotation candidates differs from "
2235 "number of texts to annotate.");
2236 }
2237
2238 // Other annotators run on each fragment independently.
2239 for (int i = 0; i < text_to_annotate.size(); ++i) {
2240 AnnotationOptions annotation_options = options;
2241 if (string_fragments[i].datetime_options.has_value()) {
2242 DatetimeOptions reference_datetime =
2243 string_fragments[i].datetime_options.value();
2244 annotation_options.reference_time_ms_utc =
2245 reference_datetime.reference_time_ms_utc;
2246 annotation_options.reference_timezone =
2247 reference_datetime.reference_timezone;
2248 }
2249
2250 AddContactMetadataToKnowledgeClassificationResults(
2251 &annotation_candidates[i]);
2252
2253 Status annotation_status = AnnotateSingleInput(
2254 text_to_annotate[i], annotation_options, &annotation_candidates[i]);
2255 if (!annotation_status.ok()) {
2256 return annotation_status;
2257 }
2258 }
2259 return annotation_candidates;
2260}
2261
2262std::vector<AnnotatedSpan> Annotator::Annotate(
2263 const std::string& context, const AnnotationOptions& options) const {
2264 std::vector<InputFragment> string_fragments;
2265 string_fragments.push_back({.text = context});
2266 StatusOr<std::vector<std::vector<AnnotatedSpan>>> annotations =
2267 AnnotateStructuredInput(string_fragments, options);
2268 if (!annotations.ok()) {
2269 TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
2270 << annotations.status().error_message();
2271 return {};
2272 }
2273 return annotations.ValueOrDie()[0];
Lukas Zilka21d8c982018-01-24 11:11:20 +01002274}
2275
Tony Mak854015a2019-01-16 15:56:48 +00002276CodepointSpan Annotator::ComputeSelectionBoundaries(
2277 const UniLib::RegexMatcher* match,
2278 const RegexModel_::Pattern* config) const {
2279 if (config->capturing_group() == nullptr) {
2280 // Use first capturing group to specify the selection.
2281 int status = UniLib::RegexMatcher::kNoError;
2282 const CodepointSpan result = {match->Start(1, &status),
2283 match->End(1, &status)};
2284 if (status != UniLib::RegexMatcher::kNoError) {
2285 return {kInvalidIndex, kInvalidIndex};
2286 }
2287 return result;
2288 }
2289
2290 CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2291 const int num_groups = config->capturing_group()->size();
2292 for (int i = 0; i < num_groups; i++) {
2293 if (!config->capturing_group()->Get(i)->extend_selection()) {
2294 continue;
2295 }
2296
2297 int status = UniLib::RegexMatcher::kNoError;
2298 // Check match and adjust bounds.
2299 const int group_start = match->Start(i, &status);
2300 const int group_end = match->End(i, &status);
2301 if (status != UniLib::RegexMatcher::kNoError) {
2302 return {kInvalidIndex, kInvalidIndex};
2303 }
2304 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2305 continue;
2306 }
2307 if (result.first == kInvalidIndex) {
2308 result = {group_start, group_end};
2309 } else {
2310 result.first = std::min(result.first, group_start);
2311 result.second = std::max(result.second, group_end);
2312 }
2313 }
2314 return result;
2315}
2316
Tony Makd9446602019-02-20 18:25:39 +00002317bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
Tony Mak21460022020-03-12 18:29:35 +00002318 if (pattern->serialized_entity_data() != nullptr ||
2319 pattern->entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002320 return true;
2321 }
2322 if (pattern->capturing_group() != nullptr) {
Tony Mak63959242020-02-07 18:31:16 +00002323 for (const CapturingGroup* group : *pattern->capturing_group()) {
Tony Makd9446602019-02-20 18:25:39 +00002324 if (group->entity_field_path() != nullptr) {
2325 return true;
2326 }
Tony Mak21460022020-03-12 18:29:35 +00002327 if (group->serialized_entity_data() != nullptr ||
2328 group->entity_data() != nullptr) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002329 return true;
2330 }
Tony Makd9446602019-02-20 18:25:39 +00002331 }
2332 }
2333 return false;
2334}
2335
2336bool Annotator::SerializedEntityDataFromRegexMatch(
2337 const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2338 std::string* serialized_entity_data) const {
2339 if (!HasEntityData(pattern)) {
2340 serialized_entity_data->clear();
2341 return true;
2342 }
2343 TC3_CHECK(entity_data_builder_ != nullptr);
2344
2345 std::unique_ptr<ReflectiveFlatbuffer> entity_data =
2346 entity_data_builder_->NewRoot();
2347
2348 TC3_CHECK(entity_data != nullptr);
2349
Tony Mak21460022020-03-12 18:29:35 +00002350 // Set fixed entity data.
Tony Makd9446602019-02-20 18:25:39 +00002351 if (pattern->serialized_entity_data() != nullptr) {
Tony Makd9446602019-02-20 18:25:39 +00002352 entity_data->MergeFromSerializedFlatbuffer(
2353 StringPiece(pattern->serialized_entity_data()->c_str(),
2354 pattern->serialized_entity_data()->size()));
2355 }
Tony Mak21460022020-03-12 18:29:35 +00002356 if (pattern->entity_data() != nullptr) {
2357 entity_data->MergeFrom(
2358 reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2359 }
Tony Makd9446602019-02-20 18:25:39 +00002360
2361 // Add entity data from rule capturing groups.
2362 if (pattern->capturing_group() != nullptr) {
2363 const int num_groups = pattern->capturing_group()->size();
2364 for (int i = 0; i < num_groups; i++) {
Tony Mak63959242020-02-07 18:31:16 +00002365 const CapturingGroup* group = pattern->capturing_group()->Get(i);
Tony Maka2a1ff42019-09-12 15:40:32 +01002366
2367 // Check whether the group matched.
2368 Optional<std::string> group_match_text =
2369 GetCapturingGroupText(matcher, /*group_id=*/i);
2370 if (!group_match_text.has_value()) {
Tony Makd9446602019-02-20 18:25:39 +00002371 continue;
2372 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002373
Tony Mak21460022020-03-12 18:29:35 +00002374 // Set fixed entity data from capturing group match.
Tony Maka2a1ff42019-09-12 15:40:32 +01002375 if (group->serialized_entity_data() != nullptr) {
2376 entity_data->MergeFromSerializedFlatbuffer(
2377 StringPiece(group->serialized_entity_data()->c_str(),
2378 group->serialized_entity_data()->size()));
2379 }
Tony Mak21460022020-03-12 18:29:35 +00002380 if (group->entity_data() != nullptr) {
2381 entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2382 pattern->entity_data()));
2383 }
Tony Maka2a1ff42019-09-12 15:40:32 +01002384
2385 // Set entity field from capturing group text.
2386 if (group->entity_field_path() != nullptr) {
Tony Mak8cd7ba62019-10-15 15:29:22 +01002387 UnicodeText normalized_group_match_text =
2388 UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2389
2390 // Apply normalization if specified.
2391 if (group->normalization_options() != nullptr) {
2392 normalized_group_match_text =
2393 NormalizeText(unilib_, group->normalization_options(),
2394 normalized_group_match_text);
2395 }
2396
2397 if (!entity_data->ParseAndSet(
2398 group->entity_field_path(),
2399 normalized_group_match_text.ToUTF8String())) {
Tony Maka2a1ff42019-09-12 15:40:32 +01002400 TC3_LOG(ERROR)
2401 << "Could not set entity data from rule capturing group.";
2402 return false;
2403 }
Tony Makd9446602019-02-20 18:25:39 +00002404 }
2405 }
2406 }
2407
2408 *serialized_entity_data = entity_data->Serialize();
2409 return true;
2410}
2411
Tony Mak63959242020-02-07 18:31:16 +00002412UnicodeText RemoveMoneySeparators(
2413 const std::unordered_set<char32>& decimal_separators,
2414 const UnicodeText& amount,
2415 UnicodeText::const_iterator it_decimal_separator) {
2416 UnicodeText whole_amount;
2417 for (auto it = amount.begin();
2418 it != amount.end() && it != it_decimal_separator; ++it) {
2419 if (std::find(decimal_separators.begin(), decimal_separators.end(),
2420 static_cast<char32>(*it)) == decimal_separators.end()) {
2421 whole_amount.push_back(*it);
2422 }
2423 }
2424 return whole_amount;
2425}
2426
2427bool Annotator::ParseAndFillInMoneyAmount(
2428 std::string* serialized_entity_data) const {
2429 std::unique_ptr<EntityDataT> data =
2430 LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2431 *serialized_entity_data);
Tony Mak0b8b3322020-03-17 16:30:19 +00002432 if (data == nullptr) {
2433 TC3_LOG(ERROR)
2434 << "Data field is null when trying to parse Money Entity Data";
2435 return false;
2436 }
2437 if (data->money->unnormalized_amount.empty()) {
2438 TC3_LOG(ERROR) << "Data unnormalized_amount is empty when trying to parse "
2439 "Money Entity Data";
Tony Mak63959242020-02-07 18:31:16 +00002440 return false;
2441 }
2442
2443 UnicodeText amount =
2444 UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2445 int separator_back_index = 0;
Tony Mak21460022020-03-12 18:29:35 +00002446 auto it_decimal_separator = --amount.end();
Tony Mak63959242020-02-07 18:31:16 +00002447 for (; it_decimal_separator != amount.begin();
2448 --it_decimal_separator, ++separator_back_index) {
2449 if (std::find(money_separators_.begin(), money_separators_.end(),
2450 static_cast<char32>(*it_decimal_separator)) !=
2451 money_separators_.end()) {
2452 break;
2453 }
2454 }
2455
2456 // If there are 3 digits after the last separator, we consider that a
2457 // thousands separator => the number is an int (e.g. 1.234 is considered int).
2458 // If there is no separator in number, also that number is an int.
Tony Mak21460022020-03-12 18:29:35 +00002459 if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
Tony Mak63959242020-02-07 18:31:16 +00002460 it_decimal_separator = amount.end();
2461 }
2462
2463 if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2464 it_decimal_separator),
2465 &data->money->amount_whole_part)) {
Tony Mak0b8b3322020-03-17 16:30:19 +00002466 TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
2467 "amount: "
2468 << data->money->unnormalized_amount;
Tony Mak63959242020-02-07 18:31:16 +00002469 return false;
2470 }
2471 if (it_decimal_separator == amount.end()) {
2472 data->money->amount_decimal_part = 0;
2473 } else {
2474 const int amount_codepoints_size = amount.size_codepoints();
2475 if (!unilib_->ParseInt32(
2476 UnicodeText::Substring(
Tony Mak21460022020-03-12 18:29:35 +00002477 amount, amount_codepoints_size - separator_back_index,
Tony Mak63959242020-02-07 18:31:16 +00002478 amount_codepoints_size, /*do_copy=*/false),
2479 &data->money->amount_decimal_part)) {
Tony Mak0b8b3322020-03-17 16:30:19 +00002480 TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
2481 "the amount: "
2482 << data->money->unnormalized_amount;
Tony Mak63959242020-02-07 18:31:16 +00002483 return false;
2484 }
2485 }
2486
2487 *serialized_entity_data =
2488 PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2489 return true;
2490}
2491
Tony Mak6c4cc672018-09-17 11:48:50 +01002492bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2493 const std::vector<int>& rules,
Tony Mak83d2de62019-04-10 16:12:15 +01002494 std::vector<AnnotatedSpan>* result,
2495 bool is_serialized_entity_data_enabled) const {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002496 for (int pattern_id : rules) {
2497 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2498 const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2499 if (!matcher) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002500 TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2501 << pattern_id;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002502 return false;
2503 }
2504
2505 int status = UniLib::RegexMatcher::kNoError;
2506 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
Tony Mak854015a2019-01-16 15:56:48 +00002507 if (regex_pattern.config->verification_options()) {
Tony Makdf54e742019-03-26 14:04:00 +00002508 if (!VerifyRegexMatchCandidate(
2509 context_unicode.ToUTF8String(),
2510 regex_pattern.config->verification_options(),
2511 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002512 continue;
2513 }
2514 }
Tony Makd9446602019-02-20 18:25:39 +00002515
2516 std::string serialized_entity_data;
Tony Mak83d2de62019-04-10 16:12:15 +01002517 if (is_serialized_entity_data_enabled) {
2518 if (!SerializedEntityDataFromRegexMatch(
2519 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2520 TC3_LOG(ERROR) << "Could not get entity data.";
2521 return false;
2522 }
Tony Mak63959242020-02-07 18:31:16 +00002523
2524 // Further parsing unnormalized_amount for money into amount_whole_part
2525 // and amount_decimal_part. Can't do this with regexes because we cannot
2526 // have empty groups (amount_decimal_part might be an empty group).
2527 if (regex_pattern.config->collection_name()->str() ==
2528 Collections::Money()) {
2529 if (!ParseAndFillInMoneyAmount(&serialized_entity_data)) {
2530 TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
2531 }
2532 }
Tony Makd9446602019-02-20 18:25:39 +00002533 }
2534
Lukas Zilkab23e2122018-02-09 10:25:19 +01002535 result->emplace_back();
Tony Mak854015a2019-01-16 15:56:48 +00002536
Lukas Zilkab23e2122018-02-09 10:25:19 +01002537 // Selection/annotation regular expressions need to specify a capturing
2538 // group specifying the selection.
Tony Mak854015a2019-01-16 15:56:48 +00002539 result->back().span =
2540 ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2541
Lukas Zilkab23e2122018-02-09 10:25:19 +01002542 result->back().classification = {
Tony Mak854015a2019-01-16 15:56:48 +00002543 {regex_pattern.config->collection_name()->str(),
2544 regex_pattern.config->target_classification_score(),
2545 regex_pattern.config->priority_score()}};
Tony Makd9446602019-02-20 18:25:39 +00002546
2547 result->back().classification[0].serialized_entity_data =
2548 serialized_entity_data;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002549 }
2550 }
2551 return true;
2552}
2553
Tony Mak6c4cc672018-09-17 11:48:50 +01002554bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2555 tflite::Interpreter* selection_interpreter,
2556 const CachedFeatures& cached_features,
2557 std::vector<TokenSpan>* chunks) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +01002558 const int max_selection_span =
2559 selection_feature_processor_->GetOptions()->max_selection_span();
Lukas Zilka21d8c982018-01-24 11:11:20 +01002560 // The inference span is the span of interest expanded to include
2561 // max_selection_span tokens on either side, which is how far a selection can
2562 // stretch from the click.
2563 const TokenSpan inference_span = IntersectTokenSpans(
2564 ExpandTokenSpan(span_of_interest,
2565 /*num_tokens_left=*/max_selection_span,
2566 /*num_tokens_right=*/max_selection_span),
2567 {0, num_tokens});
2568
2569 std::vector<ScoredChunk> scored_chunks;
Lukas Zilkab23e2122018-02-09 10:25:19 +01002570 if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2571 selection_feature_processor_->GetOptions()
2572 ->bounds_sensitive_features()
2573 ->enabled()) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002574 if (!ModelBoundsSensitiveScoreChunks(
2575 num_tokens, span_of_interest, inference_span, cached_features,
2576 selection_interpreter, &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002577 return false;
2578 }
2579 } else {
2580 if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002581 cached_features, selection_interpreter,
2582 &scored_chunks)) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01002583 return false;
Lukas Zilka21d8c982018-01-24 11:11:20 +01002584 }
2585 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002586 std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
2587 [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2588 return lhs.score < rhs.score;
2589 });
Lukas Zilka21d8c982018-01-24 11:11:20 +01002590
2591 // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2592 // them greedily as long as they do not overlap with any previously picked
2593 // chunks.
2594 std::vector<bool> token_used(TokenSpanSize(inference_span));
2595 chunks->clear();
2596 for (const ScoredChunk& scored_chunk : scored_chunks) {
2597 bool feasible = true;
2598 for (int i = scored_chunk.token_span.first;
2599 i < scored_chunk.token_span.second; ++i) {
2600 if (token_used[i - inference_span.first]) {
2601 feasible = false;
2602 break;
2603 }
2604 }
2605
2606 if (!feasible) {
2607 continue;
2608 }
2609
2610 for (int i = scored_chunk.token_span.first;
2611 i < scored_chunk.token_span.second; ++i) {
2612 token_used[i - inference_span.first] = true;
2613 }
2614
2615 chunks->push_back(scored_chunk.token_span);
2616 }
2617
2618 std::sort(chunks->begin(), chunks->end());
2619
2620 return true;
2621}
2622
namespace {
// Stores `value` under `key` in `map` if the key is absent; otherwise replaces
// the stored value with the maximum of the stored value and `value`.
//
// Uses a single lookup (via insert) and, unlike operator[], does not require
// Map::mapped_type to be default-constructible.
template <typename Map>
void UpdateMax(Map* map, const typename Map::key_type& key,
               const typename Map::mapped_type& value) {
  const auto insertion = map->insert(typename Map::value_type(key, value));
  if (!insertion.second) {
    // The key already existed, so insert() left the old value in place; keep
    // whichever of the two is larger.
    insertion.first->second = std::max(insertion.first->second, value);
  }
}
}  // namespace
2637
Tony Mak6c4cc672018-09-17 11:48:50 +01002638bool Annotator::ModelClickContextScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002639 int num_tokens, const TokenSpan& span_of_interest,
2640 const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002641 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002642 std::vector<ScoredChunk>* scored_chunks) const {
2643 const int max_batch_size = model_->selection_options()->batch_size();
2644
2645 std::vector<float> all_features;
2646 std::map<TokenSpan, float> chunk_scores;
2647 for (int batch_start = span_of_interest.first;
2648 batch_start < span_of_interest.second; batch_start += max_batch_size) {
2649 const int batch_end =
2650 std::min(batch_start + max_batch_size, span_of_interest.second);
2651
2652 // Prepare features for the whole batch.
2653 all_features.clear();
2654 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2655 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2656 cached_features.AppendClickContextFeaturesForClick(click_pos,
2657 &all_features);
2658 }
2659
2660 // Run batched inference.
2661 const int batch_size = batch_end - batch_start;
2662 const int features_size = cached_features.OutputFeaturesSize();
2663 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01002664 TensorView<float>(all_features.data(), {batch_size, features_size}),
2665 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002666 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002667 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002668 return false;
2669 }
2670 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2671 logits.dim(1) !=
2672 selection_feature_processor_->GetSelectionLabelCount()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002673 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002674 return false;
2675 }
2676
2677 // Save results.
2678 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2679 const std::vector<float> scores = ComputeSoftmax(
2680 logits.data() + logits.dim(1) * (click_pos - batch_start),
2681 logits.dim(1));
2682 for (int j = 0;
2683 j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
2684 TokenSpan relative_token_span;
2685 if (!selection_feature_processor_->LabelToTokenSpan(
2686 j, &relative_token_span)) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002687 TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002688 return false;
2689 }
2690 const TokenSpan candidate_span = ExpandTokenSpan(
2691 SingleTokenSpan(click_pos), relative_token_span.first,
2692 relative_token_span.second);
2693 if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
2694 UpdateMax(&chunk_scores, candidate_span, scores[j]);
2695 }
2696 }
2697 }
2698 }
2699
2700 scored_chunks->clear();
2701 scored_chunks->reserve(chunk_scores.size());
2702 for (const auto& entry : chunk_scores) {
2703 scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
2704 }
2705
2706 return true;
2707}
2708
Tony Mak6c4cc672018-09-17 11:48:50 +01002709bool Annotator::ModelBoundsSensitiveScoreChunks(
Lukas Zilkab23e2122018-02-09 10:25:19 +01002710 int num_tokens, const TokenSpan& span_of_interest,
2711 const TokenSpan& inference_span, const CachedFeatures& cached_features,
Lukas Zilkaba849e72018-03-08 14:48:21 +01002712 tflite::Interpreter* selection_interpreter,
Lukas Zilkab23e2122018-02-09 10:25:19 +01002713 std::vector<ScoredChunk>* scored_chunks) const {
2714 const int max_selection_span =
2715 selection_feature_processor_->GetOptions()->max_selection_span();
2716 const int max_chunk_length = selection_feature_processor_->GetOptions()
2717 ->selection_reduced_output_space()
2718 ? max_selection_span + 1
2719 : 2 * max_selection_span + 1;
Lukas Zilkaba849e72018-03-08 14:48:21 +01002720 const bool score_single_token_spans_as_zero =
2721 selection_feature_processor_->GetOptions()
2722 ->bounds_sensitive_features()
2723 ->score_single_token_spans_as_zero();
2724
2725 scored_chunks->clear();
2726 if (score_single_token_spans_as_zero) {
2727 scored_chunks->reserve(TokenSpanSize(span_of_interest));
2728 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002729
2730 // Prepare all chunk candidates into one batch:
2731 // - Are contained in the inference span
2732 // - Have a non-empty intersection with the span of interest
2733 // - Are at least one token long
2734 // - Are not longer than the maximum chunk length
2735 std::vector<TokenSpan> candidate_spans;
2736 for (int start = inference_span.first; start < span_of_interest.second;
2737 ++start) {
2738 const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
2739 for (int end = leftmost_end_index;
2740 end <= inference_span.second && end - start <= max_chunk_length;
2741 ++end) {
Lukas Zilkaba849e72018-03-08 14:48:21 +01002742 const TokenSpan candidate_span = {start, end};
2743 if (score_single_token_spans_as_zero &&
2744 TokenSpanSize(candidate_span) == 1) {
2745 // Do not include the single token span in the batch, add a zero score
2746 // for it directly to the output.
2747 scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
2748 } else {
2749 candidate_spans.push_back(candidate_span);
2750 }
Lukas Zilkab23e2122018-02-09 10:25:19 +01002751 }
2752 }
2753
2754 const int max_batch_size = model_->selection_options()->batch_size();
2755
2756 std::vector<float> all_features;
Lukas Zilkaba849e72018-03-08 14:48:21 +01002757 scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
Lukas Zilkab23e2122018-02-09 10:25:19 +01002758 for (int batch_start = 0; batch_start < candidate_spans.size();
2759 batch_start += max_batch_size) {
2760 const int batch_end = std::min(batch_start + max_batch_size,
2761 static_cast<int>(candidate_spans.size()));
2762
2763 // Prepare features for the whole batch.
2764 all_features.clear();
2765 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2766 for (int i = batch_start; i < batch_end; ++i) {
2767 cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
2768 &all_features);
2769 }
2770
2771 // Run batched inference.
2772 const int batch_size = batch_end - batch_start;
2773 const int features_size = cached_features.OutputFeaturesSize();
2774 TensorView<float> logits = selection_executor_->ComputeLogits(
Lukas Zilkaba849e72018-03-08 14:48:21 +01002775 TensorView<float>(all_features.data(), {batch_size, features_size}),
2776 selection_interpreter);
Lukas Zilkab23e2122018-02-09 10:25:19 +01002777 if (!logits.is_valid()) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002778 TC3_LOG(ERROR) << "Couldn't compute logits.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002779 return false;
2780 }
2781 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2782 logits.dim(1) != 1) {
Tony Mak6c4cc672018-09-17 11:48:50 +01002783 TC3_LOG(ERROR) << "Mismatching output.";
Lukas Zilkab23e2122018-02-09 10:25:19 +01002784 return false;
2785 }
2786
2787 // Save results.
2788 for (int i = batch_start; i < batch_end; ++i) {
2789 scored_chunks->push_back(
2790 ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
2791 }
2792 }
2793
2794 return true;
2795}
2796
Tony Mak6c4cc672018-09-17 11:48:50 +01002797bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
2798 int64 reference_time_ms_utc,
2799 const std::string& reference_timezone,
2800 const std::string& locales, ModeFlag mode,
Tony Makd9446602019-02-20 18:25:39 +00002801 AnnotationUsecase annotation_usecase,
Tony Mak83d2de62019-04-10 16:12:15 +01002802 bool is_serialized_entity_data_enabled,
Tony Mak6c4cc672018-09-17 11:48:50 +01002803 std::vector<AnnotatedSpan>* result) const {
Tony Mak63959242020-02-07 18:31:16 +00002804 std::vector<DatetimeParseResultSpan> datetime_spans;
2805 if (cfg_datetime_parser_) {
2806 if (!(model_->grammar_datetime_model()->enabled_modes() & mode)) {
2807 return true;
2808 }
2809 std::vector<Locale> parsed_locales;
2810 ParseLocales(locales, &parsed_locales);
Tony Mak21460022020-03-12 18:29:35 +00002811 cfg_datetime_parser_->Parse(
2812 context_unicode.ToUTF8String(),
2813 ToDateAnnotationOptions(
2814 model_->grammar_datetime_model()->annotation_options(),
2815 reference_timezone, reference_time_ms_utc),
2816 parsed_locales, &datetime_spans);
Tony Makd99d58c2020-03-19 21:52:02 +00002817 }
2818
2819 if (datetime_parser_) {
Tony Mak63959242020-02-07 18:31:16 +00002820 if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
2821 reference_timezone, locales, mode,
2822 annotation_usecase,
2823 /*anchor_start_end=*/false, &datetime_spans)) {
2824 return false;
2825 }
Lukas Zilkae7962cc2018-03-28 18:09:48 +02002826 }
2827
Lukas Zilkab23e2122018-02-09 10:25:19 +01002828 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
Tony Mak378c1f52019-03-04 15:58:11 +00002829 AnnotatedSpan annotated_span;
2830 annotated_span.span = datetime_span.span;
Tony Mak854015a2019-01-16 15:56:48 +00002831 for (const DatetimeParseResult& parse_result : datetime_span.data) {
Tony Mak378c1f52019-03-04 15:58:11 +00002832 annotated_span.classification.emplace_back(
2833 PickCollectionForDatetime(parse_result),
2834 datetime_span.target_classification_score,
2835 datetime_span.priority_score);
2836 annotated_span.classification.back().datetime_parse_result = parse_result;
Tony Mak83d2de62019-04-10 16:12:15 +01002837 if (is_serialized_entity_data_enabled) {
2838 annotated_span.classification.back().serialized_entity_data =
2839 CreateDatetimeSerializedEntityData(parse_result);
2840 }
Tony Mak854015a2019-01-16 15:56:48 +00002841 }
Tony Mak448b5862019-03-22 13:36:41 +00002842 annotated_span.source = AnnotatedSpan::Source::DATETIME;
Tony Mak378c1f52019-03-04 15:58:11 +00002843 result->push_back(std::move(annotated_span));
Lukas Zilkab23e2122018-02-09 10:25:19 +01002844 }
2845 return true;
2846}
2847
Tony Mak378c1f52019-03-04 15:58:11 +00002848const Model* Annotator::model() const { return model_; }
Tony Makd9446602019-02-20 18:25:39 +00002849const reflection::Schema* Annotator::entity_data_schema() const {
2850 return entity_data_schema_;
2851}
Tony Mak854015a2019-01-16 15:56:48 +00002852
Lukas Zilka21d8c982018-01-24 11:11:20 +01002853const Model* ViewModel(const void* buffer, int size) {
2854 if (!buffer) {
2855 return nullptr;
2856 }
2857
2858 return LoadAndVerifyModel(buffer, size);
2859}
2860
Tony Makd9446602019-02-20 18:25:39 +00002861bool Annotator::LookUpKnowledgeEntity(
2862 const std::string& id, std::string* serialized_knowledge_result) const {
2863 return knowledge_engine_ &&
2864 knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
2865}
2866
Tony Mak6c4cc672018-09-17 11:48:50 +01002867} // namespace libtextclassifier3