blob: 86b700f90792fddbf2cfe1b8fd4a171d184fa0b5 [file] [log] [blame]
Lukas Zilka21d8c982018-01-24 11:11:20 +01001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Lukas Zilkab23e2122018-02-09 10:25:19 +010017#ifndef LIBTEXTCLASSIFIER_CACHED_FEATURES_H_
18#define LIBTEXTCLASSIFIER_CACHED_FEATURES_H_
Lukas Zilka21d8c982018-01-24 11:11:20 +010019
20#include <memory>
21#include <vector>
22
23#include "model-executor.h"
24#include "model_generated.h"
25#include "types.h"
26
27namespace libtextclassifier2 {
28
29// Holds state for extracting features across multiple calls and reusing them.
30// Assumes that features for each Token are independent.
31class CachedFeatures {
32 public:
Lukas Zilkab23e2122018-02-09 10:25:19 +010033 static std::unique_ptr<CachedFeatures> Create(
Lukas Zilka21d8c982018-01-24 11:11:20 +010034 const TokenSpan& extraction_span,
35 const std::vector<std::vector<int>>& sparse_features,
36 const std::vector<std::vector<float>>& dense_features,
37 const std::vector<int>& padding_sparse_features,
38 const std::vector<float>& padding_dense_features,
Lukas Zilkab23e2122018-02-09 10:25:19 +010039 const FeatureProcessorOptions* options,
Lukas Zilka21d8c982018-01-24 11:11:20 +010040 EmbeddingExecutor* embedding_executor, int feature_vector_size);
41
Lukas Zilkab23e2122018-02-09 10:25:19 +010042 // Appends the click context features for the given click position to
43 // 'output_features'.
44 void AppendClickContextFeaturesForClick(
45 int click_pos, std::vector<float>* output_features) const;
46
47 // Appends the bounds-sensitive features for the given token span to
48 // 'output_features'.
49 void AppendBoundsSensitiveFeaturesForSpan(
50 TokenSpan selected_span, std::vector<float>* output_features) const;
51
52 // Returns number of features that 'AppendFeaturesForSpan' appends.
53 int OutputFeaturesSize() const { return output_features_size_; }
Lukas Zilka21d8c982018-01-24 11:11:20 +010054
55 private:
Lukas Zilkab23e2122018-02-09 10:25:19 +010056 CachedFeatures() {}
57
Lukas Zilka21d8c982018-01-24 11:11:20 +010058 // Appends token features to the output. The intended_span specifies which
59 // tokens' features should be used in principle. The read_mask_span restricts
60 // which tokens are actually read. For tokens outside of the read_mask_span,
61 // padding tokens are used instead.
Lukas Zilkab23e2122018-02-09 10:25:19 +010062 void AppendFeaturesInternal(const TokenSpan& intended_span,
63 const TokenSpan& read_mask_span,
64 std::vector<float>* output_features) const;
Lukas Zilka21d8c982018-01-24 11:11:20 +010065
66 // Appends features of one padding token to the output.
67 void AppendPaddingFeatures(std::vector<float>* output_features) const;
68
69 // Appends the features of tokens from the given span to the output. The
Lukas Zilkab23e2122018-02-09 10:25:19 +010070 // features are averaged so that the appended features have the size
Lukas Zilka21d8c982018-01-24 11:11:20 +010071 // corresponding to one token.
Lukas Zilkab23e2122018-02-09 10:25:19 +010072 void AppendBagFeatures(const TokenSpan& bag_span,
73 std::vector<float>* output_features) const;
Lukas Zilka21d8c982018-01-24 11:11:20 +010074
75 int NumFeaturesPerToken() const;
76
Lukas Zilkab23e2122018-02-09 10:25:19 +010077 TokenSpan extraction_span_;
78 const FeatureProcessorOptions* options_;
Lukas Zilka21d8c982018-01-24 11:11:20 +010079 int output_features_size_;
80 std::vector<float> features_;
81 std::vector<float> padding_features_;
82};
83
84} // namespace libtextclassifier2
85
Lukas Zilkab23e2122018-02-09 10:25:19 +010086#endif // LIBTEXTCLASSIFIER_CACHED_FEATURES_H_