/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models/parts of a model.

#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_

#include <memory>

#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "utils/tflite-model-executor.h"

namespace libtextclassifier3 {
31// Executor for the text selection prediction and classification models.
32class ModelExecutor : public TfLiteModelExecutor {
33 public:
34 static std::unique_ptr<ModelExecutor> FromModelSpec(
35 const tflite::Model* model_spec) {
36 auto model = TfLiteModelFromModelSpec(model_spec);
37 if (!model) {
38 return nullptr;
39 }
40 return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
41 }
42
43 static std::unique_ptr<ModelExecutor> FromBuffer(
44 const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
45 auto model = TfLiteModelFromBuffer(model_spec_buffer);
46 if (!model) {
47 return nullptr;
48 }
49 return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
50 }
51
52 TensorView<float> ComputeLogits(const TensorView<float>& features,
53 tflite::Interpreter* interpreter) const;
54
55 protected:
56 explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
57 : TfLiteModelExecutor(std::move(model)) {}
58
59 static const int kInputIndexFeatures = 0;
60 static const int kOutputIndexLogits = 0;
61};
63// Executor for embedding sparse features into a dense vector.
64class EmbeddingExecutor {
65 public:
66 virtual ~EmbeddingExecutor() {}
67
68 // Embeds the sparse_features into a dense embedding and adds (+) it
69 // element-wise to the dest vector.
70 virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
71 int dest_size) const = 0;
72
73 // Returns true when the model is ready to be used, false otherwise.
74 virtual bool IsReady() const { return true; }
75};
77class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
78 public:
79 static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer(
80 const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
Tony Makdf54e742019-03-26 14:04:00 +000081 int quantization_bits,
82 const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
Tony Mak6c4cc672018-09-17 11:48:50 +010083
84 // Embeds the sparse_features into a dense embedding and adds (+) it
85 // element-wise to the dest vector.
86 bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
87 int dest_size) const;
88
Tony Makdf54e742019-03-26 14:04:00 +000089 // Auxiliary function for computing prefixes used in implementation of
90 // efficient mask indexing data structure.
91 void ComputePrefixCounts();
92
93 // Function implementing mask indexing based on efficient data structure
94 int PruneBucketId(int bucket_id) const;
95
Tony Mak6c4cc672018-09-17 11:48:50 +010096 protected:
97 explicit TFLiteEmbeddingExecutor(
98 std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
99 int num_buckets, int bytes_per_embedding, int output_embedding_size,
100 const TfLiteTensor* scales, const TfLiteTensor* embeddings,
Tony Makdf54e742019-03-26 14:04:00 +0000101 std::unique_ptr<tflite::Interpreter> interpreter,
102 const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
Tony Mak6c4cc672018-09-17 11:48:50 +0100103
104 std::unique_ptr<TfLiteModelExecutor> executor_;
105
106 int quantization_bits_;
107 int num_buckets_ = -1;
108 int bytes_per_embedding_ = -1;
109 int output_embedding_size_ = -1;
110 const TfLiteTensor* scales_ = nullptr;
111 const TfLiteTensor* embeddings_ = nullptr;
112
113 // NOTE: This interpreter is used in a read-only way (as a storage for the
114 // model params), thus is still thread-safe.
115 std::unique_ptr<tflite::Interpreter> interpreter_;
Tony Makdf54e742019-03-26 14:04:00 +0000116
117 std::vector<uint64> pruning_mask_;
118 std::vector<uint16> prefix_counts_;
119 int full_num_buckets_ = -1;
120
121 // Index of row of embedding table corresponding to all pruned buckets.
122 int pruned_row_bucket_id_ = -1;
Tony Mak6c4cc672018-09-17 11:48:50 +0100123};

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_