| /* |
| * Copyright (C) 2017 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "text-classifier.h" |
| |
| #include <algorithm> |
| #include <cctype> |
| #include <cmath> |
| #include <iterator> |
| #include <numeric> |
| |
| #include "util/base/logging.h" |
| #include "util/math/softmax.h" |
| #include "util/utf8/unicodetext.h" |
| |
| namespace libtextclassifier2 { |
| const std::string& TextClassifier::kOtherCollection = |
| *[]() { return new std::string("other"); }(); |
| const std::string& TextClassifier::kPhoneCollection = |
| *[]() { return new std::string("phone"); }(); |
| const std::string& TextClassifier::kAddressCollection = |
| *[]() { return new std::string("address"); }(); |
| const std::string& TextClassifier::kDateCollection = |
| *[]() { return new std::string("date"); }(); |
| |
| namespace { |
| const Model* LoadAndVerifyModel(const void* addr, int size) { |
| const Model* model = GetModel(addr); |
| |
| flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size); |
| if (model->Verify(verifier)) { |
| return model; |
| } else { |
| return nullptr; |
| } |
| } |
| } // namespace |
| |
| tflite::Interpreter* InterpreterManager::SelectionInterpreter() { |
| if (!selection_interpreter_) { |
| TC_CHECK(selection_executor_); |
| selection_interpreter_ = selection_executor_->CreateInterpreter(); |
| if (!selection_interpreter_) { |
| TC_LOG(ERROR) << "Could not build TFLite interpreter."; |
| } |
| } |
| return selection_interpreter_.get(); |
| } |
| |
| tflite::Interpreter* InterpreterManager::ClassificationInterpreter() { |
| if (!classification_interpreter_) { |
| TC_CHECK(classification_executor_); |
| classification_interpreter_ = classification_executor_->CreateInterpreter(); |
| if (!classification_interpreter_) { |
| TC_LOG(ERROR) << "Could not build TFLite interpreter."; |
| } |
| } |
| return classification_interpreter_.get(); |
| } |
| |
| std::unique_ptr<TextClassifier> TextClassifier::FromUnownedBuffer( |
| const char* buffer, int size, const UniLib* unilib) { |
| const Model* model = LoadAndVerifyModel(buffer, size); |
| if (model == nullptr) { |
| return nullptr; |
| } |
| |
| auto classifier = |
| std::unique_ptr<TextClassifier>(new TextClassifier(model, unilib)); |
| if (!classifier->IsInitialized()) { |
| return nullptr; |
| } |
| |
| return classifier; |
| } |
| |
| std::unique_ptr<TextClassifier> TextClassifier::FromScopedMmap( |
| std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib) { |
| if (!(*mmap)->handle().ok()) { |
| TC_VLOG(1) << "Mmap failed."; |
| return nullptr; |
| } |
| |
| const Model* model = LoadAndVerifyModel((*mmap)->handle().start(), |
| (*mmap)->handle().num_bytes()); |
| if (!model) { |
| TC_LOG(ERROR) << "Model verification failed."; |
| return nullptr; |
| } |
| |
| auto classifier = |
| std::unique_ptr<TextClassifier>(new TextClassifier(mmap, model, unilib)); |
| if (!classifier->IsInitialized()) { |
| return nullptr; |
| } |
| |
| return classifier; |
| } |
| |
| std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor( |
| int fd, int offset, int size, const UniLib* unilib) { |
| std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size)); |
| return FromScopedMmap(&mmap, unilib); |
| } |
| |
| std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor( |
| int fd, const UniLib* unilib) { |
| std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd)); |
| return FromScopedMmap(&mmap, unilib); |
| } |
| |
| std::unique_ptr<TextClassifier> TextClassifier::FromPath( |
| const std::string& path, const UniLib* unilib) { |
| std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path)); |
| return FromScopedMmap(&mmap, unilib); |
| } |
| |
| void TextClassifier::ValidateAndInitialize() { |
| initialized_ = false; |
| |
| if (model_ == nullptr) { |
| TC_LOG(ERROR) << "No model specified."; |
| return; |
| } |
| |
| const bool model_enabled_for_annotation = |
| (model_->triggering_options() != nullptr && |
| (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)); |
| const bool model_enabled_for_classification = |
| (model_->triggering_options() != nullptr && |
| (model_->triggering_options()->enabled_modes() & |
| ModeFlag_CLASSIFICATION)); |
| const bool model_enabled_for_selection = |
| (model_->triggering_options() != nullptr && |
| (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)); |
| |
| // Annotation requires the selection model. |
| if (model_enabled_for_annotation || model_enabled_for_selection) { |
| if (!model_->selection_options()) { |
| TC_LOG(ERROR) << "No selection options."; |
| return; |
| } |
| if (!model_->selection_feature_options()) { |
| TC_LOG(ERROR) << "No selection feature options."; |
| return; |
| } |
| if (!model_->selection_feature_options()->bounds_sensitive_features()) { |
| TC_LOG(ERROR) << "No selection bounds sensitive feature options."; |
| return; |
| } |
| if (!model_->selection_model()) { |
| TC_LOG(ERROR) << "No selection model."; |
| return; |
| } |
| selection_executor_ = ModelExecutor::Instance(model_->selection_model()); |
| if (!selection_executor_) { |
| TC_LOG(ERROR) << "Could not initialize selection executor."; |
| return; |
| } |
| selection_feature_processor_.reset( |
| new FeatureProcessor(model_->selection_feature_options(), unilib_)); |
| } |
| |
| // Annotation requires the classification model for conflict resolution and |
| // scoring. |
| // Selection requires the classification model for conflict resolution. |
| if (model_enabled_for_annotation || model_enabled_for_classification || |
| model_enabled_for_selection) { |
| if (!model_->classification_options()) { |
| TC_LOG(ERROR) << "No classification options."; |
| return; |
| } |
| |
| if (!model_->classification_feature_options()) { |
| TC_LOG(ERROR) << "No classification feature options."; |
| return; |
| } |
| |
| if (!model_->classification_feature_options() |
| ->bounds_sensitive_features()) { |
| TC_LOG(ERROR) << "No classification bounds sensitive feature options."; |
| return; |
| } |
| if (!model_->classification_model()) { |
| TC_LOG(ERROR) << "No clf model."; |
| return; |
| } |
| |
| classification_executor_ = |
| ModelExecutor::Instance(model_->classification_model()); |
| if (!classification_executor_) { |
| TC_LOG(ERROR) << "Could not initialize classification executor."; |
| return; |
| } |
| |
| classification_feature_processor_.reset(new FeatureProcessor( |
| model_->classification_feature_options(), unilib_)); |
| } |
| |
| // The embeddings need to be specified if the model is to be used for |
| // classification or selection. |
| if (model_enabled_for_annotation || model_enabled_for_classification || |
| model_enabled_for_selection) { |
| if (!model_->embedding_model()) { |
| TC_LOG(ERROR) << "No embedding model."; |
| return; |
| } |
| |
| // Check that the embedding size of the selection and classification model |
| // matches, as they are using the same embeddings. |
| if (model_enabled_for_selection && |
| (model_->selection_feature_options()->embedding_size() != |
| model_->classification_feature_options()->embedding_size() || |
| model_->selection_feature_options()->embedding_quantization_bits() != |
| model_->classification_feature_options() |
| ->embedding_quantization_bits())) { |
| TC_LOG(ERROR) << "Mismatching embedding size/quantization."; |
| return; |
| } |
| |
| embedding_executor_ = TFLiteEmbeddingExecutor::Instance( |
| model_->embedding_model(), |
| model_->classification_feature_options()->embedding_size(), |
| model_->classification_feature_options() |
| ->embedding_quantization_bits()); |
| if (!embedding_executor_) { |
| TC_LOG(ERROR) << "Could not initialize embedding executor."; |
| return; |
| } |
| } |
| |
| std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance(); |
| if (model_->regex_model()) { |
| if (!InitializeRegexModel(decompressor.get())) { |
| TC_LOG(ERROR) << "Could not initialize regex model."; |
| return; |
| } |
| } |
| |
| if (model_->datetime_model()) { |
| datetime_parser_ = DatetimeParser::Instance(model_->datetime_model(), |
| *unilib_, decompressor.get()); |
| if (!datetime_parser_) { |
| TC_LOG(ERROR) << "Could not initialize datetime parser."; |
| return; |
| } |
| } |
| |
| if (model_->output_options()) { |
| if (model_->output_options()->filtered_collections_annotation()) { |
| for (const auto collection : |
| *model_->output_options()->filtered_collections_annotation()) { |
| filtered_collections_annotation_.insert(collection->str()); |
| } |
| } |
| if (model_->output_options()->filtered_collections_classification()) { |
| for (const auto collection : |
| *model_->output_options()->filtered_collections_classification()) { |
| filtered_collections_classification_.insert(collection->str()); |
| } |
| } |
| if (model_->output_options()->filtered_collections_selection()) { |
| for (const auto collection : |
| *model_->output_options()->filtered_collections_selection()) { |
| filtered_collections_selection_.insert(collection->str()); |
| } |
| } |
| } |
| |
| initialized_ = true; |
| } |
| |
| bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) { |
| if (!model_->regex_model()->patterns()) { |
| return true; |
| } |
| |
| // Initialize pattern recognizers. |
| int regex_pattern_id = 0; |
| for (const auto& regex_pattern : *model_->regex_model()->patterns()) { |
| std::unique_ptr<UniLib::RegexPattern> compiled_pattern = |
| UncompressMakeRegexPattern(*unilib_, regex_pattern->pattern(), |
| regex_pattern->compressed_pattern(), |
| decompressor); |
| if (!compiled_pattern) { |
| TC_LOG(INFO) << "Failed to load regex pattern"; |
| return false; |
| } |
| |
| if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) { |
| annotation_regex_patterns_.push_back(regex_pattern_id); |
| } |
| if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) { |
| classification_regex_patterns_.push_back(regex_pattern_id); |
| } |
| if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) { |
| selection_regex_patterns_.push_back(regex_pattern_id); |
| } |
| regex_patterns_.push_back({regex_pattern->collection_name()->str(), |
| regex_pattern->target_classification_score(), |
| regex_pattern->priority_score(), |
| std::move(compiled_pattern)}); |
| if (regex_pattern->use_approximate_matching()) { |
| regex_approximate_match_pattern_ids_.insert(regex_pattern_id); |
| } |
| ++regex_pattern_id; |
| } |
| |
| return true; |
| } |
| |
| namespace { |
| |
| int CountDigits(const std::string& str, CodepointSpan selection_indices) { |
| int count = 0; |
| int i = 0; |
| const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false); |
| for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) { |
| if (i >= selection_indices.first && i < selection_indices.second && |
| isdigit(*it)) { |
| ++count; |
| } |
| } |
| return count; |
| } |
| |
| std::string ExtractSelection(const std::string& context, |
| CodepointSpan selection_indices) { |
| const UnicodeText context_unicode = |
| UTF8ToUnicodeText(context, /*do_copy=*/false); |
| auto selection_begin = context_unicode.begin(); |
| std::advance(selection_begin, selection_indices.first); |
| auto selection_end = context_unicode.begin(); |
| std::advance(selection_end, selection_indices.second); |
| return UnicodeText::UTF8Substring(selection_begin, selection_end); |
| } |
| } // namespace |
| |
| namespace internal { |
| // Helper function, which if the initial 'span' contains only white-spaces, |
| // moves the selection to a single-codepoint selection on a left or right side |
| // of this space. |
| CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span, |
| const UnicodeText& context_unicode, |
| const UniLib& unilib) { |
| TC_CHECK(ValidNonEmptySpan(span)); |
| |
| UnicodeText::const_iterator it; |
| |
| // Check that the current selection is all whitespaces. |
| it = context_unicode.begin(); |
| std::advance(it, span.first); |
| for (int i = 0; i < (span.second - span.first); ++i, ++it) { |
| if (!unilib.IsWhitespace(*it)) { |
| return span; |
| } |
| } |
| |
| CodepointSpan result; |
| |
| // Try moving left. |
| result = span; |
| it = context_unicode.begin(); |
| std::advance(it, span.first); |
| while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) { |
| --result.first; |
| --it; |
| } |
| result.second = result.first + 1; |
| if (!unilib.IsWhitespace(*it)) { |
| return result; |
| } |
| |
| // If moving left didn't find a non-whitespace character, just return the |
| // original span. |
| return span; |
| } |
| } // namespace internal |
| |
| bool TextClassifier::FilteredForAnnotation(const AnnotatedSpan& span) const { |
| return !span.classification.empty() && |
| filtered_collections_annotation_.find( |
| span.classification[0].collection) != |
| filtered_collections_annotation_.end(); |
| } |
| |
| bool TextClassifier::FilteredForClassification( |
| const ClassificationResult& classification) const { |
| return filtered_collections_classification_.find(classification.collection) != |
| filtered_collections_classification_.end(); |
| } |
| |
| bool TextClassifier::FilteredForSelection(const AnnotatedSpan& span) const { |
| return !span.classification.empty() && |
| filtered_collections_selection_.find( |
| span.classification[0].collection) != |
| filtered_collections_selection_.end(); |
| } |
| |
| CodepointSpan TextClassifier::SuggestSelection( |
| const std::string& context, CodepointSpan click_indices, |
| const SelectionOptions& options) const { |
| CodepointSpan original_click_indices = click_indices; |
| if (!initialized_) { |
| TC_LOG(ERROR) << "Not initialized"; |
| return original_click_indices; |
| } |
| if (!(model_->enabled_modes() & ModeFlag_SELECTION)) { |
| return original_click_indices; |
| } |
| |
| const UnicodeText context_unicode = UTF8ToUnicodeText(context, |
| /*do_copy=*/false); |
| |
| if (!context_unicode.is_valid()) { |
| return original_click_indices; |
| } |
| |
| const int context_codepoint_size = context_unicode.size_codepoints(); |
| |
| if (click_indices.first < 0 || click_indices.second < 0 || |
| click_indices.first >= context_codepoint_size || |
| click_indices.second > context_codepoint_size || |
| click_indices.first >= click_indices.second) { |
| TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: " |
| << click_indices.first << " " << click_indices.second; |
| return original_click_indices; |
| } |
| |
| if (model_->snap_whitespace_selections()) { |
| // We want to expand a purely white-space selection to a multi-selection it |
| // would've been part of. But with this feature disabled we would do a no- |
| // op, because no token is found. Therefore, we need to modify the |
| // 'click_indices' a bit to include a part of the token, so that the click- |
| // finding logic finds the clicked token correctly. This modification is |
| // done by the following function. Note, that it's enough to check the left |
| // side of the current selection, because if the white-space is a part of a |
| // multi-selection, neccessarily both tokens - on the left and the right |
| // sides need to be selected. Thus snapping only to the left is sufficient |
| // (there's a check at the bottom that makes sure that if we snap to the |
| // left token but the result does not contain the initial white-space, |
| // returns the original indices). |
| click_indices = internal::SnapLeftIfWhitespaceSelection( |
| click_indices, context_unicode, *unilib_); |
| } |
| |
| std::vector<AnnotatedSpan> candidates; |
| InterpreterManager interpreter_manager(selection_executor_.get(), |
| classification_executor_.get()); |
| std::vector<Token> tokens; |
| if (!ModelSuggestSelection(context_unicode, click_indices, |
| &interpreter_manager, &tokens, &candidates)) { |
| TC_LOG(ERROR) << "Model suggest selection failed."; |
| return original_click_indices; |
| } |
| if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) { |
| TC_LOG(ERROR) << "Regex suggest selection failed."; |
| return original_click_indices; |
| } |
| if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), |
| /*reference_time_ms_utc=*/0, /*reference_timezone=*/"", |
| options.locales, ModeFlag_SELECTION, &candidates)) { |
| TC_LOG(ERROR) << "Datetime suggest selection failed."; |
| return original_click_indices; |
| } |
| |
| // Sort candidates according to their position in the input, so that the next |
| // code can assume that any connected component of overlapping spans forms a |
| // contiguous block. |
| std::sort(candidates.begin(), candidates.end(), |
| [](const AnnotatedSpan& a, const AnnotatedSpan& b) { |
| return a.span.first < b.span.first; |
| }); |
| |
| std::vector<int> candidate_indices; |
| if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager, |
| &candidate_indices)) { |
| TC_LOG(ERROR) << "Couldn't resolve conflicts."; |
| return original_click_indices; |
| } |
| |
| for (const int i : candidate_indices) { |
| if (SpansOverlap(candidates[i].span, click_indices) && |
| SpansOverlap(candidates[i].span, original_click_indices)) { |
| // Run model classification if not present but requested and there's a |
| // classification collection filter specified. |
| if (candidates[i].classification.empty() && |
| model_->selection_options()->always_classify_suggested_selection() && |
| !filtered_collections_selection_.empty()) { |
| if (!ModelClassifyText( |
| context, candidates[i].span, &interpreter_manager, |
| /*embedding_cache=*/nullptr, &candidates[i].classification)) { |
| return original_click_indices; |
| } |
| } |
| |
| // Ignore if span classification is filtered. |
| if (FilteredForSelection(candidates[i])) { |
| return original_click_indices; |
| } |
| |
| return candidates[i].span; |
| } |
| } |
| |
| return original_click_indices; |
| } |
| |
| namespace { |
| // Helper function that returns the index of the first candidate that |
| // transitively does not overlap with the candidate on 'start_index'. If the end |
| // of 'candidates' is reached, it returns the index that points right behind the |
| // array. |
| int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates, |
| int start_index) { |
| int first_non_overlapping = start_index + 1; |
| CodepointSpan conflicting_span = candidates[start_index].span; |
| while ( |
| first_non_overlapping < candidates.size() && |
| SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) { |
| // Grow the span to include the current one. |
| conflicting_span.second = std::max( |
| conflicting_span.second, candidates[first_non_overlapping].span.second); |
| |
| ++first_non_overlapping; |
| } |
| return first_non_overlapping; |
| } |
| } // namespace |
| |
| bool TextClassifier::ResolveConflicts( |
| const std::vector<AnnotatedSpan>& candidates, const std::string& context, |
| const std::vector<Token>& cached_tokens, |
| InterpreterManager* interpreter_manager, std::vector<int>* result) const { |
| result->clear(); |
| result->reserve(candidates.size()); |
| for (int i = 0; i < candidates.size();) { |
| int first_non_overlapping = |
| FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i); |
| |
| const bool conflict_found = first_non_overlapping != (i + 1); |
| if (conflict_found) { |
| std::vector<int> candidate_indices; |
| if (!ResolveConflict(context, cached_tokens, candidates, i, |
| first_non_overlapping, interpreter_manager, |
| &candidate_indices)) { |
| return false; |
| } |
| result->insert(result->end(), candidate_indices.begin(), |
| candidate_indices.end()); |
| } else { |
| result->push_back(i); |
| } |
| |
| // Skip over the whole conflicting group/go to next candidate. |
| i = first_non_overlapping; |
| } |
| return true; |
| } |
| |
| namespace { |
| inline bool ClassifiedAsOther( |
| const std::vector<ClassificationResult>& classification) { |
| return !classification.empty() && |
| classification[0].collection == TextClassifier::kOtherCollection; |
| } |
| |
| float GetPriorityScore( |
| const std::vector<ClassificationResult>& classification) { |
| if (!ClassifiedAsOther(classification)) { |
| return classification[0].priority_score; |
| } else { |
| return -1.0; |
| } |
| } |
| } // namespace |
| |
| bool TextClassifier::ResolveConflict( |
| const std::string& context, const std::vector<Token>& cached_tokens, |
| const std::vector<AnnotatedSpan>& candidates, int start_index, |
| int end_index, InterpreterManager* interpreter_manager, |
| std::vector<int>* chosen_indices) const { |
| std::vector<int> conflicting_indices; |
| std::unordered_map<int, float> scores; |
| for (int i = start_index; i < end_index; ++i) { |
| conflicting_indices.push_back(i); |
| if (!candidates[i].classification.empty()) { |
| scores[i] = GetPriorityScore(candidates[i].classification); |
| continue; |
| } |
| |
| // OPTIMIZATION: So that we don't have to classify all the ML model |
| // spans apriori, we wait until we get here, when they conflict with |
| // something and we need the actual classification scores. So if the |
| // candidate conflicts and comes from the model, we need to run a |
| // classification to determine its priority: |
| std::vector<ClassificationResult> classification; |
| if (!ModelClassifyText(context, cached_tokens, candidates[i].span, |
| interpreter_manager, |
| /*embedding_cache=*/nullptr, &classification)) { |
| return false; |
| } |
| |
| if (!classification.empty()) { |
| scores[i] = GetPriorityScore(classification); |
| } |
| } |
| |
| std::sort(conflicting_indices.begin(), conflicting_indices.end(), |
| [&scores](int i, int j) { return scores[i] > scores[j]; }); |
| |
| // Keeps the candidates sorted by their position in the text (their left span |
| // index) for fast retrieval down. |
| std::set<int, std::function<bool(int, int)>> chosen_indices_set( |
| [&candidates](int a, int b) { |
| return candidates[a].span.first < candidates[b].span.first; |
| }); |
| |
| // Greedily place the candidates if they don't conflict with the already |
| // placed ones. |
| for (int i = 0; i < conflicting_indices.size(); ++i) { |
| const int considered_candidate = conflicting_indices[i]; |
| if (!DoesCandidateConflict(considered_candidate, candidates, |
| chosen_indices_set)) { |
| chosen_indices_set.insert(considered_candidate); |
| } |
| } |
| |
| *chosen_indices = |
| std::vector<int>(chosen_indices_set.begin(), chosen_indices_set.end()); |
| |
| return true; |
| } |
| |
| bool TextClassifier::ModelSuggestSelection( |
| const UnicodeText& context_unicode, CodepointSpan click_indices, |
| InterpreterManager* interpreter_manager, std::vector<Token>* tokens, |
| std::vector<AnnotatedSpan>* result) const { |
| if (model_->triggering_options() == nullptr || |
| !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) { |
| return true; |
| } |
| |
| int click_pos; |
| *tokens = selection_feature_processor_->Tokenize(context_unicode); |
| selection_feature_processor_->RetokenizeAndFindClick( |
| context_unicode, click_indices, |
| selection_feature_processor_->GetOptions()->only_use_line_with_click(), |
| tokens, &click_pos); |
| if (click_pos == kInvalidIndex) { |
| TC_VLOG(1) << "Could not calculate the click position."; |
| return false; |
| } |
| |
| const int symmetry_context_size = |
| model_->selection_options()->symmetry_context_size(); |
| const FeatureProcessorOptions_::BoundsSensitiveFeatures* |
| bounds_sensitive_features = selection_feature_processor_->GetOptions() |
| ->bounds_sensitive_features(); |
| |
| // The symmetry context span is the clicked token with symmetry_context_size |
| // tokens on either side. |
| const TokenSpan symmetry_context_span = IntersectTokenSpans( |
| ExpandTokenSpan(SingleTokenSpan(click_pos), |
| /*num_tokens_left=*/symmetry_context_size, |
| /*num_tokens_right=*/symmetry_context_size), |
| {0, tokens->size()}); |
| |
| // Compute the extraction span based on the model type. |
| TokenSpan extraction_span; |
| if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { |
| // The extraction span is the symmetry context span expanded to include |
| // max_selection_span tokens on either side, which is how far a selection |
| // can stretch from the click, plus a relevant number of tokens outside of |
| // the bounds of the selection. |
| const int max_selection_span = |
| selection_feature_processor_->GetOptions()->max_selection_span(); |
| extraction_span = |
| ExpandTokenSpan(symmetry_context_span, |
| /*num_tokens_left=*/max_selection_span + |
| bounds_sensitive_features->num_tokens_before(), |
| /*num_tokens_right=*/max_selection_span + |
| bounds_sensitive_features->num_tokens_after()); |
| } else { |
| // The extraction span is the symmetry context span expanded to include |
| // context_size tokens on either side. |
| const int context_size = |
| selection_feature_processor_->GetOptions()->context_size(); |
| extraction_span = ExpandTokenSpan(symmetry_context_span, |
| /*num_tokens_left=*/context_size, |
| /*num_tokens_right=*/context_size); |
| } |
| extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()}); |
| |
| if (!selection_feature_processor_->HasEnoughSupportedCodepoints( |
| *tokens, extraction_span)) { |
| return true; |
| } |
| |
| std::unique_ptr<CachedFeatures> cached_features; |
| if (!selection_feature_processor_->ExtractFeatures( |
| *tokens, extraction_span, |
| /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex}, |
| embedding_executor_.get(), |
| /*embedding_cache=*/nullptr, |
| selection_feature_processor_->EmbeddingSize() + |
| selection_feature_processor_->DenseFeaturesCount(), |
| &cached_features)) { |
| TC_LOG(ERROR) << "Could not extract features."; |
| return false; |
| } |
| |
| // Produce selection model candidates. |
| std::vector<TokenSpan> chunks; |
| if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span, |
| interpreter_manager->SelectionInterpreter(), *cached_features, |
| &chunks)) { |
| TC_LOG(ERROR) << "Could not chunk."; |
| return false; |
| } |
| |
| for (const TokenSpan& chunk : chunks) { |
| AnnotatedSpan candidate; |
| candidate.span = selection_feature_processor_->StripBoundaryCodepoints( |
| context_unicode, TokenSpanToCodepointSpan(*tokens, chunk)); |
| if (model_->selection_options()->strip_unpaired_brackets()) { |
| candidate.span = |
| StripUnpairedBrackets(context_unicode, candidate.span, *unilib_); |
| } |
| |
| // Only output non-empty spans. |
| if (candidate.span.first != candidate.span.second) { |
| result->push_back(candidate); |
| } |
| } |
| return true; |
| } |
| |
| bool TextClassifier::ModelClassifyText( |
| const std::string& context, CodepointSpan selection_indices, |
| InterpreterManager* interpreter_manager, |
| FeatureProcessor::EmbeddingCache* embedding_cache, |
| std::vector<ClassificationResult>* classification_results) const { |
| if (model_->triggering_options() == nullptr || |
| !(model_->triggering_options()->enabled_modes() & |
| ModeFlag_CLASSIFICATION)) { |
| return true; |
| } |
| return ModelClassifyText(context, {}, selection_indices, interpreter_manager, |
| embedding_cache, classification_results); |
| } |
| |
| namespace internal { |
| std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens, |
| CodepointSpan selection_indices, |
| TokenSpan tokens_around_selection_to_copy) { |
| const auto first_selection_token = std::upper_bound( |
| cached_tokens.begin(), cached_tokens.end(), selection_indices.first, |
| [](int selection_start, const Token& token) { |
| return selection_start < token.end; |
| }); |
| const auto last_selection_token = std::lower_bound( |
| cached_tokens.begin(), cached_tokens.end(), selection_indices.second, |
| [](const Token& token, int selection_end) { |
| return token.start < selection_end; |
| }); |
| |
| const int64 first_token = std::max( |
| static_cast<int64>(0), |
| static_cast<int64>((first_selection_token - cached_tokens.begin()) - |
| tokens_around_selection_to_copy.first)); |
| const int64 last_token = std::min( |
| static_cast<int64>(cached_tokens.size()), |
| static_cast<int64>((last_selection_token - cached_tokens.begin()) + |
| tokens_around_selection_to_copy.second)); |
| |
| std::vector<Token> tokens; |
| tokens.reserve(last_token - first_token); |
| for (int i = first_token; i < last_token; ++i) { |
| tokens.push_back(cached_tokens[i]); |
| } |
| return tokens; |
| } |
| } // namespace internal |
| |
| TokenSpan TextClassifier::ClassifyTextUpperBoundNeededTokens() const { |
| const FeatureProcessorOptions_::BoundsSensitiveFeatures* |
| bounds_sensitive_features = |
| classification_feature_processor_->GetOptions() |
| ->bounds_sensitive_features(); |
| if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { |
| // The extraction span is the selection span expanded to include a relevant |
| // number of tokens outside of the bounds of the selection. |
| return {bounds_sensitive_features->num_tokens_before(), |
| bounds_sensitive_features->num_tokens_after()}; |
| } else { |
| // The extraction span is the clicked token with context_size tokens on |
| // either side. |
| const int context_size = |
| selection_feature_processor_->GetOptions()->context_size(); |
| return {context_size, context_size}; |
| } |
| } |
| |
| bool TextClassifier::ModelClassifyText( |
| const std::string& context, const std::vector<Token>& cached_tokens, |
| CodepointSpan selection_indices, InterpreterManager* interpreter_manager, |
| FeatureProcessor::EmbeddingCache* embedding_cache, |
| std::vector<ClassificationResult>* classification_results) const { |
| std::vector<Token> tokens; |
| if (cached_tokens.empty()) { |
| tokens = classification_feature_processor_->Tokenize(context); |
| } else { |
| tokens = internal::CopyCachedTokens(cached_tokens, selection_indices, |
| ClassifyTextUpperBoundNeededTokens()); |
| } |
| |
| int click_pos; |
| classification_feature_processor_->RetokenizeAndFindClick( |
| context, selection_indices, |
| classification_feature_processor_->GetOptions() |
| ->only_use_line_with_click(), |
| &tokens, &click_pos); |
| const TokenSpan selection_token_span = |
| CodepointSpanToTokenSpan(tokens, selection_indices); |
| const int selection_num_tokens = TokenSpanSize(selection_token_span); |
| if (model_->classification_options()->max_num_tokens() > 0 && |
| model_->classification_options()->max_num_tokens() < |
| selection_num_tokens) { |
| *classification_results = {{kOtherCollection, 1.0}}; |
| return true; |
| } |
| |
| const FeatureProcessorOptions_::BoundsSensitiveFeatures* |
| bounds_sensitive_features = |
| classification_feature_processor_->GetOptions() |
| ->bounds_sensitive_features(); |
| if (selection_token_span.first == kInvalidIndex || |
| selection_token_span.second == kInvalidIndex) { |
| TC_LOG(ERROR) << "Could not determine span."; |
| return false; |
| } |
| |
| // Compute the extraction span based on the model type. |
| TokenSpan extraction_span; |
| if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { |
| // The extraction span is the selection span expanded to include a relevant |
| // number of tokens outside of the bounds of the selection. |
| extraction_span = ExpandTokenSpan( |
| selection_token_span, |
| /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(), |
| /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after()); |
| } else { |
| if (click_pos == kInvalidIndex) { |
| TC_LOG(ERROR) << "Couldn't choose a click position."; |
| return false; |
| } |
| // The extraction span is the clicked token with context_size tokens on |
| // either side. |
| const int context_size = |
| classification_feature_processor_->GetOptions()->context_size(); |
| extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos), |
| /*num_tokens_left=*/context_size, |
| /*num_tokens_right=*/context_size); |
| } |
| extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()}); |
| |
| if (!classification_feature_processor_->HasEnoughSupportedCodepoints( |
| tokens, extraction_span)) { |
| *classification_results = {{kOtherCollection, 1.0}}; |
| return true; |
| } |
| |
| std::unique_ptr<CachedFeatures> cached_features; |
| if (!classification_feature_processor_->ExtractFeatures( |
| tokens, extraction_span, selection_indices, embedding_executor_.get(), |
| embedding_cache, |
| classification_feature_processor_->EmbeddingSize() + |
| classification_feature_processor_->DenseFeaturesCount(), |
| &cached_features)) { |
| TC_LOG(ERROR) << "Could not extract features."; |
| return false; |
| } |
| |
| std::vector<float> features; |
| features.reserve(cached_features->OutputFeaturesSize()); |
| if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { |
| cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span, |
| &features); |
| } else { |
| cached_features->AppendClickContextFeaturesForClick(click_pos, &features); |
| } |
| |
| TensorView<float> logits = classification_executor_->ComputeLogits( |
| TensorView<float>(features.data(), |
| {1, static_cast<int>(features.size())}), |
| interpreter_manager->ClassificationInterpreter()); |
| if (!logits.is_valid()) { |
| TC_LOG(ERROR) << "Couldn't compute logits."; |
| return false; |
| } |
| |
| if (logits.dims() != 2 || logits.dim(0) != 1 || |
| logits.dim(1) != classification_feature_processor_->NumCollections()) { |
| TC_LOG(ERROR) << "Mismatching output"; |
| return false; |
| } |
| |
| const std::vector<float> scores = |
| ComputeSoftmax(logits.data(), logits.dim(1)); |
| |
| classification_results->resize(scores.size()); |
| for (int i = 0; i < scores.size(); i++) { |
| (*classification_results)[i] = { |
| classification_feature_processor_->LabelToCollection(i), scores[i]}; |
| } |
| std::sort(classification_results->begin(), classification_results->end(), |
| [](const ClassificationResult& a, const ClassificationResult& b) { |
| return a.score > b.score; |
| }); |
| |
| // Phone class sanity check. |
| if (!classification_results->empty() && |
| classification_results->begin()->collection == kPhoneCollection) { |
| const int digit_count = CountDigits(context, selection_indices); |
| if (digit_count < |
| model_->classification_options()->phone_min_num_digits() || |
| digit_count > |
| model_->classification_options()->phone_max_num_digits()) { |
| *classification_results = {{kOtherCollection, 1.0}}; |
| } |
| } |
| |
| // Address class sanity check. |
| if (!classification_results->empty() && |
| classification_results->begin()->collection == kAddressCollection) { |
| if (selection_num_tokens < |
| model_->classification_options()->address_min_num_tokens()) { |
| *classification_results = {{kOtherCollection, 1.0}}; |
| } |
| } |
| |
| return true; |
| } |
| |
| bool TextClassifier::RegexClassifyText( |
| const std::string& context, CodepointSpan selection_indices, |
| ClassificationResult* classification_result) const { |
| const std::string selection_text = |
| ExtractSelection(context, selection_indices); |
| const UnicodeText selection_text_unicode( |
| UTF8ToUnicodeText(selection_text, /*do_copy=*/false)); |
| |
| // Check whether any of the regular expressions match. |
| for (const int pattern_id : classification_regex_patterns_) { |
| const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id]; |
| const std::unique_ptr<UniLib::RegexMatcher> matcher = |
| regex_pattern.pattern->Matcher(selection_text_unicode); |
| int status = UniLib::RegexMatcher::kNoError; |
| bool matches; |
| if (regex_approximate_match_pattern_ids_.find(pattern_id) != |
| regex_approximate_match_pattern_ids_.end()) { |
| matches = matcher->ApproximatelyMatches(&status); |
| } else { |
| matches = matcher->Matches(&status); |
| } |
| if (status != UniLib::RegexMatcher::kNoError) { |
| return false; |
| } |
| if (matches) { |
| *classification_result = {regex_pattern.collection_name, |
| regex_pattern.target_classification_score, |
| regex_pattern.priority_score}; |
| return true; |
| } |
| if (status != UniLib::RegexMatcher::kNoError) { |
| TC_LOG(ERROR) << "Cound't match regex: " << pattern_id; |
| } |
| } |
| |
| return false; |
| } |
| |
| bool TextClassifier::DatetimeClassifyText( |
| const std::string& context, CodepointSpan selection_indices, |
| const ClassificationOptions& options, |
| ClassificationResult* classification_result) const { |
| if (!datetime_parser_) { |
| return false; |
| } |
| |
| const std::string selection_text = |
| ExtractSelection(context, selection_indices); |
| |
| std::vector<DatetimeParseResultSpan> datetime_spans; |
| if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc, |
| options.reference_timezone, options.locales, |
| ModeFlag_CLASSIFICATION, |
| /*anchor_start_end=*/true, &datetime_spans)) { |
| TC_LOG(ERROR) << "Error during parsing datetime."; |
| return false; |
| } |
| for (const DatetimeParseResultSpan& datetime_span : datetime_spans) { |
| // Only consider the result valid if the selection and extracted datetime |
| // spans exactly match. |
| if (std::make_pair(datetime_span.span.first + selection_indices.first, |
| datetime_span.span.second + selection_indices.first) == |
| selection_indices) { |
| *classification_result = {kDateCollection, |
| datetime_span.target_classification_score}; |
| classification_result->datetime_parse_result = datetime_span.data; |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| std::vector<ClassificationResult> TextClassifier::ClassifyText( |
| const std::string& context, CodepointSpan selection_indices, |
| const ClassificationOptions& options) const { |
| if (!initialized_) { |
| TC_LOG(ERROR) << "Not initialized"; |
| return {}; |
| } |
| |
| if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) { |
| return {}; |
| } |
| |
| if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) { |
| return {}; |
| } |
| |
| if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) { |
| TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: " |
| << std::get<0>(selection_indices) << " " |
| << std::get<1>(selection_indices); |
| return {}; |
| } |
| |
| // Try the regular expression models. |
| ClassificationResult regex_result; |
| if (RegexClassifyText(context, selection_indices, ®ex_result)) { |
| if (!FilteredForClassification(regex_result)) { |
| return {regex_result}; |
| } else { |
| return {{kOtherCollection, 1.0}}; |
| } |
| } |
| |
| // Try the date model. |
| ClassificationResult datetime_result; |
| if (DatetimeClassifyText(context, selection_indices, options, |
| &datetime_result)) { |
| if (!FilteredForClassification(datetime_result)) { |
| return {datetime_result}; |
| } else { |
| return {{kOtherCollection, 1.0}}; |
| } |
| } |
| |
| // Fallback to the model. |
| std::vector<ClassificationResult> model_result; |
| |
| InterpreterManager interpreter_manager(selection_executor_.get(), |
| classification_executor_.get()); |
| if (ModelClassifyText(context, selection_indices, &interpreter_manager, |
| /*embedding_cache=*/nullptr, &model_result) && |
| !model_result.empty()) { |
| if (!FilteredForClassification(model_result[0])) { |
| return model_result; |
| } else { |
| return {{kOtherCollection, 1.0}}; |
| } |
| } |
| |
| // No classifications. |
| return {}; |
| } |
| |
| bool TextClassifier::ModelAnnotate(const std::string& context, |
| InterpreterManager* interpreter_manager, |
| std::vector<Token>* tokens, |
| std::vector<AnnotatedSpan>* result) const { |
| if (model_->triggering_options() == nullptr || |
| !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) { |
| return true; |
| } |
| |
| const UnicodeText context_unicode = UTF8ToUnicodeText(context, |
| /*do_copy=*/false); |
| std::vector<UnicodeTextRange> lines; |
| if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) { |
| lines.push_back({context_unicode.begin(), context_unicode.end()}); |
| } else { |
| lines = selection_feature_processor_->SplitContext(context_unicode); |
| } |
| |
| const float min_annotate_confidence = |
| (model_->triggering_options() != nullptr |
| ? model_->triggering_options()->min_annotate_confidence() |
| : 0.f); |
| |
| FeatureProcessor::EmbeddingCache embedding_cache; |
| for (const UnicodeTextRange& line : lines) { |
| const std::string line_str = |
| UnicodeText::UTF8Substring(line.first, line.second); |
| |
| *tokens = selection_feature_processor_->Tokenize(line_str); |
| selection_feature_processor_->RetokenizeAndFindClick( |
| line_str, {0, std::distance(line.first, line.second)}, |
| selection_feature_processor_->GetOptions()->only_use_line_with_click(), |
| tokens, |
| /*click_pos=*/nullptr); |
| const TokenSpan full_line_span = {0, tokens->size()}; |
| |
| // TODO(zilka): Add support for greater granularity of this check. |
| if (!selection_feature_processor_->HasEnoughSupportedCodepoints( |
| *tokens, full_line_span)) { |
| continue; |
| } |
| |
| std::unique_ptr<CachedFeatures> cached_features; |
| if (!selection_feature_processor_->ExtractFeatures( |
| *tokens, full_line_span, |
| /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex}, |
| embedding_executor_.get(), |
| /*embedding_cache=*/nullptr, |
| selection_feature_processor_->EmbeddingSize() + |
| selection_feature_processor_->DenseFeaturesCount(), |
| &cached_features)) { |
| TC_LOG(ERROR) << "Could not extract features."; |
| return false; |
| } |
| |
| std::vector<TokenSpan> local_chunks; |
| if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span, |
| interpreter_manager->SelectionInterpreter(), |
| *cached_features, &local_chunks)) { |
| TC_LOG(ERROR) << "Could not chunk."; |
| return false; |
| } |
| |
| const int offset = std::distance(context_unicode.begin(), line.first); |
| for (const TokenSpan& chunk : local_chunks) { |
| const CodepointSpan codepoint_span = |
| selection_feature_processor_->StripBoundaryCodepoints( |
| line_str, TokenSpanToCodepointSpan(*tokens, chunk)); |
| |
| // Skip empty spans. |
| if (codepoint_span.first != codepoint_span.second) { |
| std::vector<ClassificationResult> classification; |
| if (!ModelClassifyText(line_str, *tokens, codepoint_span, |
| interpreter_manager, &embedding_cache, |
| &classification)) { |
| TC_LOG(ERROR) << "Could not classify text: " |
| << (codepoint_span.first + offset) << " " |
| << (codepoint_span.second + offset); |
| return false; |
| } |
| |
| // Do not include the span if it's classified as "other". |
| if (!classification.empty() && !ClassifiedAsOther(classification) && |
| classification[0].score >= min_annotate_confidence) { |
| AnnotatedSpan result_span; |
| result_span.span = {codepoint_span.first + offset, |
| codepoint_span.second + offset}; |
| result_span.classification = std::move(classification); |
| result->push_back(std::move(result_span)); |
| } |
| } |
| } |
| } |
| return true; |
| } |
| |
| const FeatureProcessor* TextClassifier::SelectionFeatureProcessorForTests() |
| const { |
| return selection_feature_processor_.get(); |
| } |
| |
| const FeatureProcessor* TextClassifier::ClassificationFeatureProcessorForTests() |
| const { |
| return classification_feature_processor_.get(); |
| } |
| |
| const DatetimeParser* TextClassifier::DatetimeParserForTests() const { |
| return datetime_parser_.get(); |
| } |
| |
| std::vector<AnnotatedSpan> TextClassifier::Annotate( |
| const std::string& context, const AnnotationOptions& options) const { |
| std::vector<AnnotatedSpan> candidates; |
| |
| if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) { |
| return {}; |
| } |
| |
| if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) { |
| return {}; |
| } |
| |
| InterpreterManager interpreter_manager(selection_executor_.get(), |
| classification_executor_.get()); |
| // Annotate with the selection model. |
| std::vector<Token> tokens; |
| if (!ModelAnnotate(context, &interpreter_manager, &tokens, &candidates)) { |
| TC_LOG(ERROR) << "Couldn't run ModelAnnotate."; |
| return {}; |
| } |
| |
| // Annotate with the regular expression models. |
| if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), |
| annotation_regex_patterns_, &candidates)) { |
| TC_LOG(ERROR) << "Couldn't run RegexChunk."; |
| return {}; |
| } |
| |
| // Annotate with the datetime model. |
| if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), |
| options.reference_time_ms_utc, options.reference_timezone, |
| options.locales, ModeFlag_ANNOTATION, &candidates)) { |
| TC_LOG(ERROR) << "Couldn't run RegexChunk."; |
| return {}; |
| } |
| |
| // Sort candidates according to their position in the input, so that the next |
| // code can assume that any connected component of overlapping spans forms a |
| // contiguous block. |
| std::sort(candidates.begin(), candidates.end(), |
| [](const AnnotatedSpan& a, const AnnotatedSpan& b) { |
| return a.span.first < b.span.first; |
| }); |
| |
| std::vector<int> candidate_indices; |
| if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager, |
| &candidate_indices)) { |
| TC_LOG(ERROR) << "Couldn't resolve conflicts."; |
| return {}; |
| } |
| |
| std::vector<AnnotatedSpan> result; |
| result.reserve(candidate_indices.size()); |
| for (const int i : candidate_indices) { |
| if (!candidates[i].classification.empty() && |
| !ClassifiedAsOther(candidates[i].classification) && |
| !FilteredForAnnotation(candidates[i])) { |
| result.push_back(std::move(candidates[i])); |
| } |
| } |
| |
| return result; |
| } |
| |
| bool TextClassifier::RegexChunk(const UnicodeText& context_unicode, |
| const std::vector<int>& rules, |
| std::vector<AnnotatedSpan>* result) const { |
| for (int pattern_id : rules) { |
| const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id]; |
| const auto matcher = regex_pattern.pattern->Matcher(context_unicode); |
| if (!matcher) { |
| TC_LOG(ERROR) << "Could not get regex matcher for pattern: " |
| << pattern_id; |
| return false; |
| } |
| |
| int status = UniLib::RegexMatcher::kNoError; |
| while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) { |
| result->emplace_back(); |
| // Selection/annotation regular expressions need to specify a capturing |
| // group specifying the selection. |
| result->back().span = {matcher->Start(1, &status), |
| matcher->End(1, &status)}; |
| result->back().classification = { |
| {regex_pattern.collection_name, |
| regex_pattern.target_classification_score, |
| regex_pattern.priority_score}}; |
| } |
| } |
| return true; |
| } |
| |
| bool TextClassifier::ModelChunk(int num_tokens, |
| const TokenSpan& span_of_interest, |
| tflite::Interpreter* selection_interpreter, |
| const CachedFeatures& cached_features, |
| std::vector<TokenSpan>* chunks) const { |
| const int max_selection_span = |
| selection_feature_processor_->GetOptions()->max_selection_span(); |
| // The inference span is the span of interest expanded to include |
| // max_selection_span tokens on either side, which is how far a selection can |
| // stretch from the click. |
| const TokenSpan inference_span = IntersectTokenSpans( |
| ExpandTokenSpan(span_of_interest, |
| /*num_tokens_left=*/max_selection_span, |
| /*num_tokens_right=*/max_selection_span), |
| {0, num_tokens}); |
| |
| std::vector<ScoredChunk> scored_chunks; |
| if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() && |
| selection_feature_processor_->GetOptions() |
| ->bounds_sensitive_features() |
| ->enabled()) { |
| if (!ModelBoundsSensitiveScoreChunks( |
| num_tokens, span_of_interest, inference_span, cached_features, |
| selection_interpreter, &scored_chunks)) { |
| return false; |
| } |
| } else { |
| if (!ModelClickContextScoreChunks(num_tokens, span_of_interest, |
| cached_features, selection_interpreter, |
| &scored_chunks)) { |
| return false; |
| } |
| } |
| std::sort(scored_chunks.rbegin(), scored_chunks.rend(), |
| [](const ScoredChunk& lhs, const ScoredChunk& rhs) { |
| return lhs.score < rhs.score; |
| }); |
| |
| // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick |
| // them greedily as long as they do not overlap with any previously picked |
| // chunks. |
| std::vector<bool> token_used(TokenSpanSize(inference_span)); |
| chunks->clear(); |
| for (const ScoredChunk& scored_chunk : scored_chunks) { |
| bool feasible = true; |
| for (int i = scored_chunk.token_span.first; |
| i < scored_chunk.token_span.second; ++i) { |
| if (token_used[i - inference_span.first]) { |
| feasible = false; |
| break; |
| } |
| } |
| |
| if (!feasible) { |
| continue; |
| } |
| |
| for (int i = scored_chunk.token_span.first; |
| i < scored_chunk.token_span.second; ++i) { |
| token_used[i - inference_span.first] = true; |
| } |
| |
| chunks->push_back(scored_chunk.token_span); |
| } |
| |
| std::sort(chunks->begin(), chunks->end()); |
| |
| return true; |
| } |
| |
namespace {
// Updates the value at the given key in the map to maximum of the current value
// and the given value, or simply inserts the value if the key is not yet there.
//
// A single emplace() both looks the key up and inserts it when missing, so the
// map is searched only once and the mapped type does not need to be
// default-constructible.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto insertion = map->emplace(key, value);
  if (!insertion.second) {
    // Key was already present: keep the larger of the two values.
    auto& current = insertion.first->second;
    current = std::max(current, value);
  }
}
}  // namespace
| |
| bool TextClassifier::ModelClickContextScoreChunks( |
| int num_tokens, const TokenSpan& span_of_interest, |
| const CachedFeatures& cached_features, |
| tflite::Interpreter* selection_interpreter, |
| std::vector<ScoredChunk>* scored_chunks) const { |
| const int max_batch_size = model_->selection_options()->batch_size(); |
| |
| std::vector<float> all_features; |
| std::map<TokenSpan, float> chunk_scores; |
| for (int batch_start = span_of_interest.first; |
| batch_start < span_of_interest.second; batch_start += max_batch_size) { |
| const int batch_end = |
| std::min(batch_start + max_batch_size, span_of_interest.second); |
| |
| // Prepare features for the whole batch. |
| all_features.clear(); |
| all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize()); |
| for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) { |
| cached_features.AppendClickContextFeaturesForClick(click_pos, |
| &all_features); |
| } |
| |
| // Run batched inference. |
| const int batch_size = batch_end - batch_start; |
| const int features_size = cached_features.OutputFeaturesSize(); |
| TensorView<float> logits = selection_executor_->ComputeLogits( |
| TensorView<float>(all_features.data(), {batch_size, features_size}), |
| selection_interpreter); |
| if (!logits.is_valid()) { |
| TC_LOG(ERROR) << "Couldn't compute logits."; |
| return false; |
| } |
| if (logits.dims() != 2 || logits.dim(0) != batch_size || |
| logits.dim(1) != |
| selection_feature_processor_->GetSelectionLabelCount()) { |
| TC_LOG(ERROR) << "Mismatching output."; |
| return false; |
| } |
| |
| // Save results. |
| for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) { |
| const std::vector<float> scores = ComputeSoftmax( |
| logits.data() + logits.dim(1) * (click_pos - batch_start), |
| logits.dim(1)); |
| for (int j = 0; |
| j < selection_feature_processor_->GetSelectionLabelCount(); ++j) { |
| TokenSpan relative_token_span; |
| if (!selection_feature_processor_->LabelToTokenSpan( |
| j, &relative_token_span)) { |
| TC_LOG(ERROR) << "Couldn't map the label to a token span."; |
| return false; |
| } |
| const TokenSpan candidate_span = ExpandTokenSpan( |
| SingleTokenSpan(click_pos), relative_token_span.first, |
| relative_token_span.second); |
| if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) { |
| UpdateMax(&chunk_scores, candidate_span, scores[j]); |
| } |
| } |
| } |
| } |
| |
| scored_chunks->clear(); |
| scored_chunks->reserve(chunk_scores.size()); |
| for (const auto& entry : chunk_scores) { |
| scored_chunks->push_back(ScoredChunk{entry.first, entry.second}); |
| } |
| |
| return true; |
| } |
| |
| bool TextClassifier::ModelBoundsSensitiveScoreChunks( |
| int num_tokens, const TokenSpan& span_of_interest, |
| const TokenSpan& inference_span, const CachedFeatures& cached_features, |
| tflite::Interpreter* selection_interpreter, |
| std::vector<ScoredChunk>* scored_chunks) const { |
| const int max_selection_span = |
| selection_feature_processor_->GetOptions()->max_selection_span(); |
| const int max_chunk_length = selection_feature_processor_->GetOptions() |
| ->selection_reduced_output_space() |
| ? max_selection_span + 1 |
| : 2 * max_selection_span + 1; |
| const bool score_single_token_spans_as_zero = |
| selection_feature_processor_->GetOptions() |
| ->bounds_sensitive_features() |
| ->score_single_token_spans_as_zero(); |
| |
| scored_chunks->clear(); |
| if (score_single_token_spans_as_zero) { |
| scored_chunks->reserve(TokenSpanSize(span_of_interest)); |
| } |
| |
| // Prepare all chunk candidates into one batch: |
| // - Are contained in the inference span |
| // - Have a non-empty intersection with the span of interest |
| // - Are at least one token long |
| // - Are not longer than the maximum chunk length |
| std::vector<TokenSpan> candidate_spans; |
| for (int start = inference_span.first; start < span_of_interest.second; |
| ++start) { |
| const int leftmost_end_index = std::max(start, span_of_interest.first) + 1; |
| for (int end = leftmost_end_index; |
| end <= inference_span.second && end - start <= max_chunk_length; |
| ++end) { |
| const TokenSpan candidate_span = {start, end}; |
| if (score_single_token_spans_as_zero && |
| TokenSpanSize(candidate_span) == 1) { |
| // Do not include the single token span in the batch, add a zero score |
| // for it directly to the output. |
| scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f}); |
| } else { |
| candidate_spans.push_back(candidate_span); |
| } |
| } |
| } |
| |
| const int max_batch_size = model_->selection_options()->batch_size(); |
| |
| std::vector<float> all_features; |
| scored_chunks->reserve(scored_chunks->size() + candidate_spans.size()); |
| for (int batch_start = 0; batch_start < candidate_spans.size(); |
| batch_start += max_batch_size) { |
| const int batch_end = std::min(batch_start + max_batch_size, |
| static_cast<int>(candidate_spans.size())); |
| |
| // Prepare features for the whole batch. |
| all_features.clear(); |
| all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize()); |
| for (int i = batch_start; i < batch_end; ++i) { |
| cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i], |
| &all_features); |
| } |
| |
| // Run batched inference. |
| const int batch_size = batch_end - batch_start; |
| const int features_size = cached_features.OutputFeaturesSize(); |
| TensorView<float> logits = selection_executor_->ComputeLogits( |
| TensorView<float>(all_features.data(), {batch_size, features_size}), |
| selection_interpreter); |
| if (!logits.is_valid()) { |
| TC_LOG(ERROR) << "Couldn't compute logits."; |
| return false; |
| } |
| if (logits.dims() != 2 || logits.dim(0) != batch_size || |
| logits.dim(1) != 1) { |
| TC_LOG(ERROR) << "Mismatching output."; |
| return false; |
| } |
| |
| // Save results. |
| for (int i = batch_start; i < batch_end; ++i) { |
| scored_chunks->push_back( |
| ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]}); |
| } |
| } |
| |
| return true; |
| } |
| |
| bool TextClassifier::DatetimeChunk(const UnicodeText& context_unicode, |
| int64 reference_time_ms_utc, |
| const std::string& reference_timezone, |
| const std::string& locales, ModeFlag mode, |
| std::vector<AnnotatedSpan>* result) const { |
| if (!datetime_parser_) { |
| return true; |
| } |
| |
| std::vector<DatetimeParseResultSpan> datetime_spans; |
| if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc, |
| reference_timezone, locales, mode, |
| /*anchor_start_end=*/false, &datetime_spans)) { |
| return false; |
| } |
| for (const DatetimeParseResultSpan& datetime_span : datetime_spans) { |
| AnnotatedSpan annotated_span; |
| annotated_span.span = datetime_span.span; |
| annotated_span.classification = {{kDateCollection, |
| datetime_span.target_classification_score, |
| datetime_span.priority_score}}; |
| annotated_span.classification[0].datetime_parse_result = datetime_span.data; |
| |
| result->push_back(std::move(annotated_span)); |
| } |
| return true; |
| } |
| |
| const Model* ViewModel(const void* buffer, int size) { |
| if (!buffer) { |
| return nullptr; |
| } |
| |
| return LoadAndVerifyModel(buffer, size); |
| } |
| |
| } // namespace libtextclassifier2 |