/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "smartselect/text-classification-model.h"
#include <cmath>
#include <iterator>
#include <numeric>
#include "common/embedding-network.h"
#include "common/feature-extractor.h"
#include "common/memory_image/embedding-network-params-from-image.h"
#include "common/memory_image/memory-image-reader.h"
#include "common/mmap.h"
#include "common/softmax.h"
#include "smartselect/text-classification-model.pb.h"
#include "util/base/logging.h"
#include "util/utf8/unicodetext.h"
namespace libtextclassifier {
using nlp_core::EmbeddingNetwork;
using nlp_core::EmbeddingNetworkProto;
using nlp_core::FeatureVector;
using nlp_core::MemoryImageReader;
using nlp_core::MmapFile;
using nlp_core::MmapHandle;
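
// Builds ModelParams from a serialized model image in memory. Returns nullptr
// if the image does not carry feature processor options.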
ModelParams* ModelParams::Build(const void* start, uint64 num_bytes) {
MemoryImageReader<EmbeddingNetworkProto> reader(start, num_bytes);
FeatureProcessorOptions feature_processor_options;
auto feature_processor_extension_id =
feature_processor_options_in_embedding_network_proto;
if (reader.trimmed_proto().HasExtension(feature_processor_extension_id)) {
feature_processor_options =
reader.trimmed_proto().GetExtension(feature_processor_extension_id);
// If no tokenization codepoint config is present, tokenize on space.
if (feature_processor_options.tokenization_codepoint_config_size() == 0) {
TokenizationCodepointRange* config =
feature_processor_options.add_tokenization_codepoint_config();
config->set_start(32);  // ASCII space.
config->set_end(33);    // End of the range is exclusive.
config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
}
} else {
return nullptr;
}
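// Selection model options are optional in the image; when missing, fall back
// to the hard-coded defaults below.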
SelectionModelOptions selection_options;
auto selection_options_extension_id =
selection_model_options_in_embedding_network_proto;
if (reader.trimmed_proto().HasExtension(selection_options_extension_id)) {
selection_options =
reader.trimmed_proto().GetExtension(selection_options_extension_id);
} else {
// TODO(zilka): Remove this once the model options are included in the
// exported models.
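// Hard-coded fallback: strip a default set of punctuation codepoint ranges
// and enforce symmetric selections.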
for (const auto codepoint_pair : std::vector<std::pair<int, int>>(
{{33, 35}, {37, 39}, {42, 42}, {44, 47},
{58, 59}, {63, 64}, {91, 93}, {95, 95},
{123, 123}, {125, 125}, {161, 161}, {171, 171},
{183, 183}, {187, 187}, {191, 191}, {894, 894},
{903, 903}, {1370, 1375}, {1417, 1418}, {1470, 1470},
{1472, 1472}, {1475, 1475}, {1478, 1478}, {1523, 1524},
{1548, 1549}, {1563, 1563}, {1566, 1567}, {1642, 1645},
{1748, 1748}, {1792, 1805}, {2404, 2405}, {2416, 2416},
{3572, 3572}, {3663, 3663}, {3674, 3675}, {3844, 3858},
{3898, 3901}, {3973, 3973}, {4048, 4049}, {4170, 4175},
{4347, 4347}, {4961, 4968}, {5741, 5742}, {5787, 5788},
{5867, 5869}, {5941, 5942}, {6100, 6102}, {6104, 6106},
{6144, 6154}, {6468, 6469}, {6622, 6623}, {6686, 6687},
{8208, 8231}, {8240, 8259}, {8261, 8273}, {8275, 8286},
{8317, 8318}, {8333, 8334}, {9001, 9002}, {9140, 9142},
{10088, 10101}, {10181, 10182}, {10214, 10219}, {10627, 10648},
{10712, 10715}, {10748, 10749}, {11513, 11516}, {11518, 11519},
{11776, 11799}, {11804, 11805}, {12289, 12291}, {12296, 12305},
{12308, 12319}, {12336, 12336}, {12349, 12349}, {12448, 12448},
{12539, 12539}, {64830, 64831}, {65040, 65049}, {65072, 65106},
{65108, 65121}, {65123, 65123}, {65128, 65128}, {65130, 65131},
{65281, 65283}, {65285, 65290}, {65292, 65295}, {65306, 65307},
{65311, 65312}, {65339, 65341}, {65343, 65343}, {65371, 65371},
{65373, 65373}, {65375, 65381}, {65792, 65793}, {66463, 66463},
{68176, 68184}})) {
for (int i = codepoint_pair.first; i <= codepoint_pair.second; i++) {
selection_options.add_punctuation_to_strip(i);
}
selection_options.set_strip_punctuation(true);
selection_options.set_enforce_symmetry(true);
selection_options.set_symmetry_context_size(
feature_processor_options.context_size() * 2);
}
}
return new ModelParams(start, num_bytes, selection_options,
feature_processor_options);
}
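
// Removes punctuation codepoints from both ends of the selection span.
// Indices in the span are codepoint offsets into the context.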
CodepointSpan TextClassificationModel::StripPunctuation(
CodepointSpan selection, const std::string& context) const {
UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false);
int context_length =
std::distance(context_unicode.begin(), context_unicode.end());
// Check that the indices are valid.
if (selection.first < 0 || selection.first > context_length ||
selection.second < 0 || selection.second > context_length) {
return selection;
}
UnicodeText::const_iterator it;
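// Advance the start of the selection past any leading punctuation.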
for (it = context_unicode.begin(), std::advance(it, selection.first);
punctuation_to_strip_.find(*it) != punctuation_to_strip_.end();
++it, ++selection.first) {
}
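// Pull the end of the selection back past any trailing punctuation.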
for (it = context_unicode.begin(), std::advance(it, selection.second - 1);
punctuation_to_strip_.find(*it) != punctuation_to_strip_.end();
--it, --selection.second) {
}
return selection;
}
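
// Loads the selection and sharing models from the given file descriptor and
// caches the set of punctuation codepoints to strip.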
TextClassificationModel::TextClassificationModel(int fd) {
initialized_ = LoadModels(fd);
if (!initialized_) {
TC_LOG(ERROR) << "Failed to load models";
return;
}
selection_options_ = selection_params_->GetSelectionModelOptions();
for (const int codepoint : selection_options_.punctuation_to_strip()) {
punctuation_to_strip_.insert(codepoint);
}
}
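
// Expected layout of the mmapped model file:
//   [uint32 selection model length][selection model bytes]
//   [uint32 sharing model length][sharing model bytes]
// The length fields are little-endian.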
bool TextClassificationModel::LoadModels(int fd) {
MmapHandle mmap_handle = MmapFile(fd);
if (!mmap_handle.ok()) {
return false;
}
// Read the length of the selection model.
const char* model_data = reinterpret_cast<const char*>(mmap_handle.start());
uint32 selection_model_length =
LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
model_data += sizeof(selection_model_length);
selection_params_.reset(
ModelParams::Build(model_data, selection_model_length));
if (!selection_params_.get()) {
return false;
}
selection_network_.reset(new EmbeddingNetwork(selection_params_.get()));
selection_feature_processor_.reset(
new FeatureProcessor(selection_params_->GetFeatureProcessorOptions()));
model_data += selection_model_length;
uint32 sharing_model_length =
LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
model_data += sizeof(sharing_model_length);
sharing_params_.reset(ModelParams::Build(model_data, sharing_model_length));
if (!sharing_params_.get()) {
return false;
}
sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get()));
sharing_feature_processor_.reset(
new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions()));
return true;
}
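
// Extracts features for the given context, click and selection spans and runs
// them through the given network. Returns the raw (pre-softmax) scores, or an
// empty vector if feature extraction fails. The label output parameters may
// be null when the corresponding values are not needed.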
EmbeddingNetwork::Vector TextClassificationModel::InferInternal(
const std::string& context, CodepointSpan click_indices,
CodepointSpan selection_indices, const FeatureProcessor& feature_processor,
const EmbeddingNetwork* network,
std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
CodepointSpan* selection_codepoint_label, int* classification_label) const {
SelectionWithContext selection_with_context;
selection_with_context.context = context;
selection_with_context.click_start = std::get<0>(click_indices);
selection_with_context.click_end = std::get<1>(click_indices);
selection_with_context.selection_start = std::get<0>(selection_indices);
selection_with_context.selection_end = std::get<1>(selection_indices);
std::vector<FeatureVector> features;
std::vector<float> extra_features;
const bool features_computed = feature_processor.GetFeaturesAndLabels(
selection_with_context, &features, &extra_features, selection_label_spans,
selection_label, selection_codepoint_label, classification_label);
EmbeddingNetwork::Vector scores;
if (!features_computed) {
TC_LOG(ERROR) << "Features not computed";
return scores;
}
network->ComputeFinalScores(features, extra_features, &scores);
return scores;
}
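
// Suggests a selection around the clicked span; falls back to the click
// indices when the model is not initialized or the indices are invalid.
//
// Example usage (a sketch; assumes `model_fd` is a descriptor of a valid
// model file):
//   TextClassificationModel model(model_fd);
//   CodepointSpan selection =
//       model.SuggestSelection("Call 857 123 4567 today", {5, 8});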
CodepointSpan TextClassificationModel::SuggestSelection(
const std::string& context, CodepointSpan click_indices) const {
if (!initialized_) {
TC_LOG(ERROR) << "Not initialized";
return click_indices;
}
if (std::get<0>(click_indices) >= std::get<1>(click_indices)) {
TC_LOG(ERROR) << "Trying to run SuggestSelection with invalid indices:"
<< std::get<0>(click_indices) << " "
<< std::get<1>(click_indices);
return click_indices;
}
CodepointSpan result;
if (selection_options_.enforce_symmetry()) {
result = SuggestSelectionSymmetrical(context, click_indices);
} else {
float score;
std::tie(result, score) = SuggestSelectionInternal(context, click_indices);
}
if (selection_options_.strip_punctuation()) {
result = StripPunctuation(result, context);
}
return result;
}
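
// Returns the highest-scoring selection proposal for the given click together
// with its softmax score, or the click indices with a score of -1.0 on
// failure.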
std::pair<CodepointSpan, float>
TextClassificationModel::SuggestSelectionInternal(
const std::string& context, CodepointSpan click_indices) const {
if (!initialized_) {
TC_LOG(ERROR) << "Not initialized";
return {click_indices, -1.0};
}
// Invalid selection indices make the feature extraction use the provided
// click indices.
const CodepointSpan selection_indices({kInvalidIndex, kInvalidIndex});
std::vector<CodepointSpan> selection_label_spans;
EmbeddingNetwork::Vector scores = InferInternal(
context, click_indices, selection_indices, *selection_feature_processor_,
selection_network_.get(), &selection_label_spans,
/*selection_label=*/nullptr,
/*selection_codepoint_label=*/nullptr,
/*classification_label=*/nullptr);
if (!scores.empty()) {
scores = nlp_core::ComputeSoftmax(scores);
const int prediction =
std::max_element(scores.begin(), scores.end()) - scores.begin();
std::pair<CodepointIndex, CodepointIndex> selection =
selection_label_spans[prediction];
if (selection.first == kInvalidIndex || selection.second == kInvalidIndex) {
TC_LOG(ERROR) << "Invalid indices predicted, returning input: "
<< prediction << " " << selection.first << " "
<< selection.second;
return {click_indices, -1.0};
}
return {{selection.first, selection.second}, scores[prediction]};
} else {
TC_LOG(ERROR) << "Returning default selection: scores.size() = "
<< scores.size();
return {click_indices, -1.0};
}
}
namespace {
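// Returns the index of the single token spanned by the click, or of a token
// that fully contains the click span; kInvalidIndex if neither exists.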
int GetClickTokenIndex(const std::vector<Token>& tokens,
CodepointSpan click_indices) {
TokenSpan span = CodepointSpanToTokenSpan(tokens, click_indices);
if (span.second - span.first == 1) {
return span.first;
} else {
for (int i = 0; i < tokens.size(); i++) {
if (tokens[i].start <= click_indices.first &&
tokens[i].end >= click_indices.second) {
return i;
}
}
return kInvalidIndex;
}
}
} // namespace
// Implements a greedy-search-like algorithm for making selections symmetric.
//
// Steps:
// 1. Get a set of selection proposals from places around the clicked word.
// 2. For each proposal (starting from the highest-scoring one), check whether
//    the tokens it selects are still free. If they are and the proposal
//    contains the clicked token, return it as the suggestion; otherwise claim
//    its tokens so that lower-scoring proposals cannot use them.
//
// This algorithm should ensure that if a selection is proposed, it does not
// matter which word of it was tapped - all of them will lead to the same
// selection.
CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
const std::string& full_context, CodepointSpan click_indices) const {
// Extract context from the current line only.
std::string context;
int context_shift;
std::tie(context, context_shift) =
ExtractLineWithSpan(full_context, click_indices);
click_indices.first -= context_shift;
click_indices.second -= context_shift;
std::vector<Token> tokens = selection_feature_processor_->Tokenize(context);
const int click_index = GetClickTokenIndex(tokens, click_indices);
if (click_index == kInvalidIndex) {
return click_indices;
}
const int symmetry_context_size = selection_options_.symmetry_context_size();
// Scan in the symmetry context for selection span proposals.
std::vector<std::pair<CodepointSpan, float>> proposals;
for (int i = -symmetry_context_size; i < symmetry_context_size + 1; i++) {
const int token_index = click_index + i;
if (token_index >= 0 && token_index < tokens.size()) {
float score;
CodepointSpan span;
std::tie(span, score) = SuggestSelectionInternal(
context, {tokens[token_index].start, tokens[token_index].end});
proposals.push_back({span, score});
}
}
// Sort selection span proposals by their respective probabilities.
std::sort(
proposals.begin(), proposals.end(),
[](std::pair<CodepointSpan, float> a, std::pair<CodepointSpan, float> b) {
return a.second > b.second;
});
// Go through the proposals from the highest-scoring one and claim their
// tokens. Tokens claimed by higher-scoring proposals cannot be used by
// lower-scoring ones. Return the first feasible proposal that contains the
// clicked token.
std::vector<int> used_tokens(tokens.size(), 0);
for (auto span_result : proposals) {
TokenSpan span = CodepointSpanToTokenSpan(tokens, span_result.first);
if (span.first != kInvalidIndex && span.second != kInvalidIndex) {
bool feasible = true;
for (int i = span.first; i < span.second; i++) {
if (used_tokens[i] != 0) {
feasible = false;
break;
}
}
if (feasible) {
if (span.first <= click_index && span.second > click_index) {
return {span_result.first.first + context_shift,
span_result.first.second + context_shift};
}
for (int i = span.first; i < span.second; i++) {
used_tokens[i] = 1;
}
}
}
}
return {click_indices.first + context_shift,
click_indices.second + context_shift};
}
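
// Convenience overload that takes the context and click indices from a
// SelectionWithContext.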
CodepointSpan TextClassificationModel::SuggestSelection(
const SelectionWithContext& selection_with_context) const {
CodepointSpan click_indices = {selection_with_context.click_start,
selection_with_context.click_end};
// If click_indices are unspecified, click a random token inside the selection.
if (click_indices == CodepointSpan({kInvalidIndex, kInvalidIndex})) {
click_indices = selection_feature_processor_->ClickRandomTokenInSelection(
selection_with_context);
}
return SuggestSelection(selection_with_context.context, click_indices);
}
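
// Classifies the selected span, returning one (collection, score) pair per
// collection known to the sharing model, or an empty result on failure.
//
// Example usage (a sketch, continuing the example above):
//   std::vector<std::pair<std::string, float>> classes =
//       model.ClassifyText("Call 857 123 4567 today", {5, 17});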
std::vector<std::pair<std::string, float>>
TextClassificationModel::ClassifyText(const std::string& context,
CodepointSpan selection_indices) const {
if (!initialized_) {
TC_LOG(ERROR) << "Not initialized";
return {};
}
// Invalid click indices make the feature extraction select the middle word in
// the selection span.
const CodepointSpan click_indices({kInvalidIndex, kInvalidIndex});
EmbeddingNetwork::Vector scores = InferInternal(
context, click_indices, selection_indices, *sharing_feature_processor_,
sharing_network_.get(), nullptr, nullptr, nullptr, nullptr);
if (scores.empty()) {
TC_LOG(ERROR) << "Using default class";
return {};
}
if (!scores.empty() &&
scores.size() == sharing_feature_processor_->NumCollections()) {
scores = nlp_core::ComputeSoftmax(scores);
std::vector<std::pair<std::string, float>> result;
for (int i = 0; i < scores.size(); i++) {
result.push_back(
{sharing_feature_processor_->LabelToCollection(i), scores[i]});
}
return result;
} else {
TC_LOG(ERROR) << "Using default class: scores.size() = " << scores.size();
return {};
}
}
} // namespace libtextclassifier