| /* |
| * Copyright (C) 2017 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| // Contains classes that can execute different models/parts of a model. |
| |
| #ifndef LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_ |
| #define LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_ |
| |
| #include <memory> |
| |
| #include "tensor-view.h" |
| #include "types.h" |
| #include "util/base/logging.h" |
| #include "tensorflow/contrib/lite/interpreter.h" |
| #include "tensorflow/contrib/lite/kernels/register.h" |
| #include "tensorflow/contrib/lite/model.h" |
| |
| namespace libtextclassifier2 { |
| |
namespace internal {
// Wraps the given flatbuffer model spec in a tflite::FlatBufferModel and
// stores it in |model|. Returns false if the model could not be built.
bool FromModelSpec(const tflite::Model* model_spec,
                   std::unique_ptr<const tflite::FlatBufferModel>* model);
}  // namespace internal
| |
// A helper function that, given the indices of the feature input and logits
// output tensors and the feature values, computes the logits using the given
// interpreter.
TensorView<float> ComputeLogitsHelper(const int input_index_features,
                                      const int output_index_logits,
                                      const TensorView<float>& features,
                                      tflite::Interpreter* interpreter);
| |
| // Executor for the text selection prediction and classification models. |
| class ModelExecutor { |
| public: |
| static std::unique_ptr<const ModelExecutor> Instance( |
| const flatbuffers::Vector<uint8_t>* model_spec_buffer) { |
| const tflite::Model* model = |
| flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data()); |
| flatbuffers::Verifier verifier(model_spec_buffer->data(), |
| model_spec_buffer->Length()); |
| if (!model->Verify(verifier)) { |
| return nullptr; |
| } |
| return Instance(model); |
| } |
| |
| static std::unique_ptr<const ModelExecutor> Instance( |
| const tflite::Model* model_spec) { |
| std::unique_ptr<const tflite::FlatBufferModel> model; |
| if (!internal::FromModelSpec(model_spec, &model)) { |
| return nullptr; |
| } |
| return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model))); |
| } |
| |
| // Creates an Interpreter for the model that serves as a scratch-pad for the |
| // inference. The Interpreter is NOT thread-safe. |
| std::unique_ptr<tflite::Interpreter> CreateInterpreter() const; |
| |
| TensorView<float> ComputeLogits(const TensorView<float>& features, |
| tflite::Interpreter* interpreter) const { |
| return ComputeLogitsHelper(kInputIndexFeatures, kOutputIndexLogits, |
| features, interpreter); |
| } |
| |
| protected: |
| explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model) |
| : model_(std::move(model)) {} |
| |
| static const int kInputIndexFeatures = 0; |
| static const int kOutputIndexLogits = 0; |
| |
| std::unique_ptr<const tflite::FlatBufferModel> model_; |
| tflite::ops::builtin::BuiltinOpResolver builtins_; |
| }; |
| |
| // Executor for embedding sparse features into a dense vector. |
| class EmbeddingExecutor { |
| public: |
| virtual ~EmbeddingExecutor() {} |
| |
| // Embeds the sparse_features into a dense embedding and adds (+) it |
| // element-wise to the dest vector. |
| virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, |
| int dest_size) const = 0; |
| |
| // Returns true when the model is ready to be used, false otherwise. |
| virtual bool IsReady() const { return true; } |
| }; |
| |
// Embedding executor backed by a TFLite model; the model's tensors are read
// directly as the (presumably quantized — see quantization_bits_) embedding
// table.
class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
 public:
  // Creates the executor from a flatbuffer-serialized TFLite model buffer.
  // |embedding_size| is the expected size of a single dense embedding and
  // |quantization_bits| the bit-width of the stored values.
  // NOTE(review): presumably returns nullptr on invalid input — confirm in
  // the .cc implementation.
  static std::unique_ptr<TFLiteEmbeddingExecutor> Instance(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
      int quantization_bits);

  bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
                    int dest_size) const override;

 protected:
  // Takes ownership of the model and the interpreter; |scales| and
  // |embeddings| are tensors owned by |interpreter| (non-owning pointers).
  explicit TFLiteEmbeddingExecutor(
      std::unique_ptr<const tflite::FlatBufferModel> model,
      int quantization_bits, int num_buckets, int bytes_per_embedding,
      int output_embedding_size, const TfLiteTensor* scales,
      const TfLiteTensor* embeddings,
      std::unique_ptr<tflite::Interpreter> interpreter);

  std::unique_ptr<const tflite::FlatBufferModel> model_;

  // Bit-width used for the stored embedding values.
  int quantization_bits_;
  // Number of rows in the embedding table; -1 until initialized.
  int num_buckets_ = -1;
  // Byte size of one stored embedding row; -1 until initialized.
  int bytes_per_embedding_ = -1;
  // Size of a single dense output embedding; -1 until initialized.
  int output_embedding_size_ = -1;
  // Non-owning pointers into the interpreter's tensors.
  const TfLiteTensor* scales_ = nullptr;
  const TfLiteTensor* embeddings_ = nullptr;

  // NOTE: This interpreter is used in a read-only way (as a storage for the
  // model params), thus is still thread-safe.
  std::unique_ptr<tflite::Interpreter> interpreter_;
};
| |
| } // namespace libtextclassifier2 |
| |
| #endif // LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_ |