/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_CACHED_FEATURES_H_
18#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_CACHED_FEATURES_H_
19
20#include <memory>
21#include <vector>
22
23#include "model-executor.h"
24#include "model_generated.h"
25#include "types.h"
26
27namespace libtextclassifier2 {
28
29// Holds state for extracting features across multiple calls and reusing them.
30// Assumes that features for each Token are independent.
31class CachedFeatures {
32 public:
33 CachedFeatures(
34 const TokenSpan& extraction_span,
35 const std::vector<std::vector<int>>& sparse_features,
36 const std::vector<std::vector<float>>& dense_features,
37 const std::vector<int>& padding_sparse_features,
38 const std::vector<float>& padding_dense_features,
39 const FeatureProcessorOptions_::BoundsSensitiveFeatures* config,
40 EmbeddingExecutor* embedding_executor, int feature_vector_size);
41
42 // Gets a vector of features for the given token span.
43 std::vector<float> Get(TokenSpan selected_span) const;
44
45 private:
46 // Appends token features to the output. The intended_span specifies which
47 // tokens' features should be used in principle. The read_mask_span restricts
48 // which tokens are actually read. For tokens outside of the read_mask_span,
49 // padding tokens are used instead.
50 void AppendFeatures(const TokenSpan& intended_span,
51 const TokenSpan& read_mask_span,
52 std::vector<float>* output_features) const;
53
54 // Appends features of one padding token to the output.
55 void AppendPaddingFeatures(std::vector<float>* output_features) const;
56
57 // Appends the features of tokens from the given span to the output. The
58 // features are summed so that the appended features have the size
59 // corresponding to one token.
60 void AppendSummedFeatures(const TokenSpan& summing_span,
61 std::vector<float>* output_features) const;
62
63 int NumFeaturesPerToken() const;
64
65 const TokenSpan extraction_span_;
66 const FeatureProcessorOptions_::BoundsSensitiveFeatures* config_;
67 int output_features_size_;
68 std::vector<float> features_;
69 std::vector<float> padding_features_;
70};
71
72} // namespace libtextclassifier2
73
74#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_CACHED_FEATURES_H_