blob: 0224d86f2fe2258289200bbeef87ac22e931bb90 [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
Lukas Zilkab23e2122018-02-09 10:25:19 +010017#ifndef LIBTEXTCLASSIFIER_CACHED_FEATURES_H_
18#define LIBTEXTCLASSIFIER_CACHED_FEATURES_H_
Lukas Zilka21d8c982018-01-24 11:11:20 +010019
20#include <memory>
21#include <vector>
22
23#include "model-executor.h"
24#include "model_generated.h"
25#include "types.h"
26
27namespace libtextclassifier2 {
28
29// Holds state for extracting features across multiple calls and reusing them.
30// Assumes that features for each Token are independent.
31class CachedFeatures {
32 public:
Lukas Zilkab23e2122018-02-09 10:25:19 +010033 static std::unique_ptr<CachedFeatures> Create(
Lukas Zilka21d8c982018-01-24 11:11:20 +010034 const TokenSpan& extraction_span,
Lukas Zilkaba849e72018-03-08 14:48:21 +010035 std::unique_ptr<std::vector<float>> features,
36 std::unique_ptr<std::vector<float>> padding_features,
37 const FeatureProcessorOptions* options, int feature_vector_size);
Lukas Zilka21d8c982018-01-24 11:11:20 +010038
Lukas Zilkab23e2122018-02-09 10:25:19 +010039 // Appends the click context features for the given click position to
40 // 'output_features'.
41 void AppendClickContextFeaturesForClick(
42 int click_pos, std::vector<float>* output_features) const;
43
44 // Appends the bounds-sensitive features for the given token span to
45 // 'output_features'.
46 void AppendBoundsSensitiveFeaturesForSpan(
47 TokenSpan selected_span, std::vector<float>* output_features) const;
48
49 // Returns number of features that 'AppendFeaturesForSpan' appends.
50 int OutputFeaturesSize() const { return output_features_size_; }
Lukas Zilka21d8c982018-01-24 11:11:20 +010051
52 private:
Lukas Zilkab23e2122018-02-09 10:25:19 +010053 CachedFeatures() {}
54
Lukas Zilka21d8c982018-01-24 11:11:20 +010055 // Appends token features to the output. The intended_span specifies which
56 // tokens' features should be used in principle. The read_mask_span restricts
57 // which tokens are actually read. For tokens outside of the read_mask_span,
58 // padding tokens are used instead.
Lukas Zilkab23e2122018-02-09 10:25:19 +010059 void AppendFeaturesInternal(const TokenSpan& intended_span,
60 const TokenSpan& read_mask_span,
61 std::vector<float>* output_features) const;
Lukas Zilka21d8c982018-01-24 11:11:20 +010062
63 // Appends features of one padding token to the output.
64 void AppendPaddingFeatures(std::vector<float>* output_features) const;
65
66 // Appends the features of tokens from the given span to the output. The
Lukas Zilkab23e2122018-02-09 10:25:19 +010067 // features are averaged so that the appended features have the size
Lukas Zilka21d8c982018-01-24 11:11:20 +010068 // corresponding to one token.
Lukas Zilkab23e2122018-02-09 10:25:19 +010069 void AppendBagFeatures(const TokenSpan& bag_span,
70 std::vector<float>* output_features) const;
Lukas Zilka21d8c982018-01-24 11:11:20 +010071
72 int NumFeaturesPerToken() const;
73
Lukas Zilkab23e2122018-02-09 10:25:19 +010074 TokenSpan extraction_span_;
75 const FeatureProcessorOptions* options_;
Lukas Zilka21d8c982018-01-24 11:11:20 +010076 int output_features_size_;
Lukas Zilkaba849e72018-03-08 14:48:21 +010077 std::unique_ptr<std::vector<float>> features_;
78 std::unique_ptr<std::vector<float>> padding_features_;
Lukas Zilka21d8c982018-01-24 11:11:20 +010079};
80
81} // namespace libtextclassifier2
82
Lukas Zilkab23e2122018-02-09 10:25:19 +010083#endif // LIBTEXTCLASSIFIER_CACHED_FEATURES_H_