/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| |
| #ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_ |
| #define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_ |
| |
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "common/feature-extractor.h"
#include "common/task-context.h"
#include "common/workspace.h"
#include "util/base/logging.h"
#include "util/base/macros.h"
| |
| namespace libtextclassifier { |
| namespace nlp_core { |
| |
| // An EmbeddingFeatureExtractor manages the extraction of features for |
| // embedding-based models. It wraps a sequence of underlying classes of feature |
| // extractors, along with associated predicate maps. Each class of feature |
| // extractors is associated with a name, e.g., "words", "labels", "tags". |
| // |
| // The class is split between a generic abstract version, |
| // GenericEmbeddingFeatureExtractor (that can be initialized without knowing the |
| // signature of the ExtractFeatures method) and a typed version. |
| // |
| // The predicate maps must be initialized before use: they can be loaded using |
| // Read() or updated via UpdateMapsForExample. |
| class GenericEmbeddingFeatureExtractor { |
| public: |
| GenericEmbeddingFeatureExtractor() {} |
| virtual ~GenericEmbeddingFeatureExtractor() {} |
| |
| // Get the prefix std::string to put in front of all arguments, so they don't |
| // conflict with other embedding models. |
| virtual const std::string ArgPrefix() const = 0; |
| |
| // Initializes predicate maps and embedding space names that are common for |
| // all embedding-based feature extractors. |
| virtual bool Init(TaskContext *context); |
| |
| // Requests workspace for the underlying feature extractors. This is |
| // implemented in the typed class. |
| virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0; |
| |
| // Returns number of embedding spaces. |
| int NumEmbeddings() const { return embedding_dims_.size(); } |
| |
| // Number of predicates for the embedding at a given index (vocabulary size). |
| // Returns -1 if index is out of bounds. |
| int EmbeddingSize(int index) const { |
| const GenericFeatureExtractor *extractor = generic_feature_extractor(index); |
| return (extractor == nullptr) ? -1 : extractor->GetDomainSize(); |
| } |
| |
| // Returns the dimensionality of the embedding space. |
| int EmbeddingDims(int index) const { return embedding_dims_[index]; } |
| |
| // Accessor for embedding dims (dimensions of the embedding spaces). |
| const std::vector<int> &embedding_dims() const { return embedding_dims_; } |
| |
| const std::vector<std::string> &embedding_fml() const { |
| return embedding_fml_; |
| } |
| |
| // Get parameter name by concatenating the prefix and the original name. |
| std::string GetParamName(const std::string ¶m_name) const { |
| std::string full_name = ArgPrefix(); |
| full_name.push_back('_'); |
| full_name.append(param_name); |
| return full_name; |
| } |
| |
| protected: |
| // Provides the generic class with access to the templated extractors. This is |
| // used to get the type information out of the feature extractor without |
| // knowing the specific calling arguments of the extractor itself. |
| // Returns nullptr for an out-of-bounds idx. |
| virtual const GenericFeatureExtractor *generic_feature_extractor( |
| int idx) const = 0; |
| |
| private: |
| // Embedding space names for parameter sharing. |
| std::vector<std::string> embedding_names_; |
| |
| // FML strings for each feature extractor. |
| std::vector<std::string> embedding_fml_; |
| |
| // Size of each of the embedding spaces (maximum predicate id). |
| std::vector<int> embedding_sizes_; |
| |
| // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.) |
| std::vector<int> embedding_dims_; |
| |
| TC_DISALLOW_COPY_AND_ASSIGN(GenericEmbeddingFeatureExtractor); |
| }; |
| |
| // Templated, object-specific implementation of the |
| // EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ, |
| // ARGS...> class that has the appropriate FeatureTraits() to ensure that |
| // locator type features work. |
| // |
| // Note: for backwards compatibility purposes, this always reads the FML spec |
| // from "<prefix>_features". |
| template <class EXTRACTOR, class OBJ, class... ARGS> |
| class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor { |
| public: |
| // Initializes all predicate maps, feature extractors, etc. |
| bool Init(TaskContext *context) override { |
| if (!GenericEmbeddingFeatureExtractor::Init(context)) { |
| return false; |
| } |
| feature_extractors_.resize(embedding_fml().size()); |
| for (int i = 0; i < embedding_fml().size(); ++i) { |
| feature_extractors_[i].reset(new EXTRACTOR()); |
| if (!feature_extractors_[i]->Parse(embedding_fml()[i])) { |
| return false; |
| } |
| if (!feature_extractors_[i]->Setup(context)) { |
| return false; |
| } |
| } |
| for (auto &feature_extractor : feature_extractors_) { |
| if (!feature_extractor->Init(context)) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| // Requests workspaces from the registry. Must be called after Init(), and |
| // before Preprocess(). |
| void RequestWorkspaces(WorkspaceRegistry *registry) override { |
| for (auto &feature_extractor : feature_extractors_) { |
| feature_extractor->RequestWorkspaces(registry); |
| } |
| } |
| |
| // Must be called on the object one state for each sentence, before any |
| // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures). |
| void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const { |
| for (auto &feature_extractor : feature_extractors_) { |
| feature_extractor->Preprocess(workspaces, obj); |
| } |
| } |
| |
| // Extracts features using the extractors. Note that features must already |
| // be initialized to the correct number of feature extractors. No predicate |
| // mapping is applied. |
| void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj, |
| ARGS... args, |
| std::vector<FeatureVector> *features) const { |
| TC_DCHECK(features != nullptr); |
| TC_DCHECK_EQ(features->size(), feature_extractors_.size()); |
| for (int i = 0; i < feature_extractors_.size(); ++i) { |
| (*features)[i].clear(); |
| feature_extractors_[i]->ExtractFeatures(workspaces, obj, args..., |
| &(*features)[i]); |
| } |
| } |
| |
| protected: |
| // Provides generic access to the feature extractors. |
| const GenericFeatureExtractor *generic_feature_extractor( |
| int idx) const override { |
| if ((idx < 0) || (idx >= feature_extractors_.size())) { |
| TC_LOG(ERROR) << "Out of bounds index " << idx; |
| TC_DCHECK(false); // Crash in debug mode. |
| return nullptr; |
| } |
| return feature_extractors_[idx].get(); |
| } |
| |
| private: |
| // Templated feature extractor class. |
| std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_; |
| }; |
| |
| } // namespace nlp_core |
| } // namespace libtextclassifier |
| |
| #endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_ |