// blob: 4e0a3da5b6a096c1ba2de13a91c1d8c050f69a26 [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Inference code for the feed-forward text classification models.
#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
#define LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "base.h"
#include "common/embedding-network.h"
#include "common/feature-extractor.h"
#include "common/memory_image/embedding-network-params-from-image.h"
#include "smartselect/feature-processor.h"
#include "smartselect/text-classification-model.pb.h"
#include "smartselect/types.h"
namespace libtextclassifier {
// Loads and holds the parameters of the inference network.
//
// This class overrides a couple of methods of EmbeddingNetworkParamsFromImage
// because we only have one embedding matrix for all positions of context,
// whereas the original class would have a separate one for each.
class ModelParams : public nlp_core::EmbeddingNetworkParamsFromImage {
public:
static ModelParams* Build(const void* start, uint64 num_bytes);
const FeatureProcessorOptions& GetFeatureProcessorOptions() const {
return feature_processor_options_;
}
const SelectionModelOptions& GetSelectionModelOptions() const {
return selection_options_;
}
protected:
int embeddings_size() const override { return context_size_ * 2 + 1; }
int embedding_num_features_size() const override {
return context_size_ * 2 + 1;
}
int embedding_num_features(int i) const override { return 1; }
int embeddings_num_rows(int i) const override {
return EmbeddingNetworkParamsFromImage::embeddings_num_rows(0);
};
int embeddings_num_cols(int i) const override {
return EmbeddingNetworkParamsFromImage::embeddings_num_cols(0);
};
const void* embeddings_weights(int i) const override {
return EmbeddingNetworkParamsFromImage::embeddings_weights(0);
};
nlp_core::QuantizationType embeddings_quant_type(int i) const override {
return EmbeddingNetworkParamsFromImage::embeddings_quant_type(0);
}
const nlp_core::float16* embeddings_quant_scales(int i) const override {
return EmbeddingNetworkParamsFromImage::embeddings_quant_scales(0);
}
private:
ModelParams(const void* start, uint64 num_bytes,
const SelectionModelOptions& selection_options,
const FeatureProcessorOptions& feature_processor_options)
: EmbeddingNetworkParamsFromImage(start, num_bytes),
selection_options_(selection_options),
feature_processor_options_(feature_processor_options),
context_size_(feature_processor_options_.context_size()) {}
SelectionModelOptions selection_options_;
FeatureProcessorOptions feature_processor_options_;
int context_size_;
};
// SmartSelection/Sharing feed-forward model.
//
// Holds two sets of parameters, feature processor and network: one for
// selection suggestion and one for sharing (text classification).
class TextClassificationModel {
 public:
  // Loads a TextClassificationModel from the file given by the int file
  // descriptor fd.
  explicit TextClassificationModel(int fd);

  // Returns the value of the internal initialization flag, so that callers
  // can check the model after construction.
  // NOTE(review): the flag is presumably set from the result of LoadModels()
  // in the constructor -- confirm in the implementation file.
  bool IsInitialized() const { return initialized_; }

  // Runs inference for a given context and current selection (i.e. the
  // indices of the first and one past the last selected characters, as utf8
  // codepoint offsets). Returns the indices (utf8 codepoint offsets) of the
  // selection beginning character and one past the selection end character.
  // Returns the original click_indices if an error occurs.
  // NOTE: The selection indices are passed in and returned in terms of
  // UTF8 codepoints (not bytes).
  // Requires that the model is a smart selection model.
  CodepointSpan SuggestSelection(const std::string& context,
                                 CodepointSpan click_indices) const;

  // Same as above but accepts a selection_with_context. Only used for
  // evaluation.
  CodepointSpan SuggestSelection(
      const SelectionWithContext& selection_with_context) const;

  // Classifies the selected text given the context string.
  // Requires that the model is a smart sharing model.
  // Returns a default collection name if an error occurs.
  std::string ClassifyText(const std::string& context,
                           CodepointSpan click_indices) const;

 protected:
  // Removes punctuation from the beginning and end of the selection and
  // returns the new selection span.
  CodepointSpan StripPunctuation(CodepointSpan selection,
                                 const std::string& context) const;

  // During evaluation we need access to the feature processor.
  // Returns a non-owning pointer; it is null when no selection feature
  // processor is held.
  FeatureProcessor* SelectionFeatureProcessor() const {
    return selection_feature_processor_.get();
  }

 private:
  // Loads the models from the file descriptor fd. Returns a bool status.
  // NOTE(review): presumably true means success -- confirm in the
  // implementation file.
  bool LoadModels(int fd);

  // Runs the given network on features produced by feature_processor for the
  // given context, click and selection, and returns the network output
  // vector. The trailing pointer parameters are outputs.
  // NOTE(review): exact semantics of the output parameters are defined in
  // the implementation file; not visible from this header.
  nlp_core::EmbeddingNetwork::Vector InferInternal(
      const std::string& context, CodepointSpan click_indices,
      CodepointSpan selection_indices,
      const FeatureProcessor& feature_processor,
      const nlp_core::EmbeddingNetwork* network,
      std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
      CodepointSpan* selection_codepoint_label,
      int* classification_label) const;

  // Returns a selection suggestion with a score.
  std::pair<CodepointSpan, float> SuggestSelectionInternal(
      const std::string& context, CodepointSpan click_indices) const;

  // Returns a selection suggestion and makes sure it's symmetric. Internally
  // runs SuggestSelectionInternal several times.
  CodepointSpan SuggestSelectionSymmetrical(const std::string& context,
                                            CodepointSpan click_indices) const;

  // Defaults to false so the flag is never read uninitialized, whichever
  // path the constructor takes.
  bool initialized_ = false;

  // NOTE: member order is kept as in the original declaration because it
  // determines construction and destruction order.
  std::unique_ptr<ModelParams> selection_params_;
  std::unique_ptr<FeatureProcessor> selection_feature_processor_;
  std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
  std::unique_ptr<FeatureProcessor> sharing_feature_processor_;
  std::unique_ptr<ModelParams> sharing_params_;
  std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;

  SelectionModelOptions selection_options_;

  // Codepoints treated as punctuation by StripPunctuation().
  // NOTE(review): inferred from the member name -- confirm where it is
  // populated.
  std::set<int> punctuation_to_strip_;
};
} // namespace libtextclassifier
#endif  // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_