telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | 93e4898 | 2018-09-05 13:05:09 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 4 | // |
| 5 | |
| 6 | #pragma once |
| 7 | |
#include "ArmnnDriver.hpp"
#include "ConversionUtils.hpp"

#include <armnn/ArmNN.hpp>

#include <cstdint>
#include <map>
#include <set>
#include <vector>
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 15 | |
| 16 | namespace armnn_driver |
| 17 | { |
| 18 | |
// Outcome of converting a driver model into an armnn::INetwork.
// NOTE(review): the per-value meanings below are inferred from the enumerator names only;
// confirm against the converter implementation.
enum class ConversionResult
{
    Success            = 0, // presumably: the model was converted completely
    ErrorMappingPools  = 1, // presumably: the model's memory pools could not be mapped
    UnsupportedFeature = 2  // presumably: the model uses something the converter cannot handle
};
| 25 | |
arovir01 | b0717b5 | 2018-09-05 17:03:25 +0100 | [diff] [blame] | 26 | // A helper template class performing the conversion from an AndroidNN driver Model representation, |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 27 | // to an armnn::INetwork object |
arovir01 | b0717b5 | 2018-09-05 17:03:25 +0100 | [diff] [blame] | 28 | template<typename HalPolicy> |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 29 | class ModelToINetworkConverter |
| 30 | { |
| 31 | public: |
arovir01 | b0717b5 | 2018-09-05 17:03:25 +0100 | [diff] [blame] | 32 | using HalModel = typename HalPolicy::Model; |
kevmay01 | bc5f784 | 2018-08-30 12:34:39 +0100 | [diff] [blame] | 33 | |
Nattapat Chaimanowong | d5fd976 | 2019-04-04 13:33:10 +0100 | [diff] [blame] | 34 | ModelToINetworkConverter(const std::vector<armnn::BackendId>& backends, |
Matteo Martincigh | e48bdff | 2018-09-03 13:50:50 +0100 | [diff] [blame] | 35 | const HalModel& model, |
| 36 | const std::set<unsigned int>& forcedUnsupportedOperations); |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 37 | |
| 38 | ConversionResult GetConversionResult() const { return m_ConversionResult; } |
| 39 | |
| 40 | // Returns the ArmNN INetwork corresponding to the input model, if preparation went smoothly, nullptr otherwise. |
arovir01 | b0717b5 | 2018-09-05 17:03:25 +0100 | [diff] [blame] | 41 | armnn::INetwork* GetINetwork() const { return m_Data.m_Network.get(); } |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 42 | |
| 43 | bool IsOperationSupported(uint32_t operationIndex) const; |
| 44 | |
| 45 | private: |
| 46 | void Convert(); |
| 47 | |
arovir01 | b0717b5 | 2018-09-05 17:03:25 +0100 | [diff] [blame] | 48 | // Shared aggregate input/output/internal data |
| 49 | ConversionData m_Data; |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 50 | |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 51 | // Input data |
kevmay01 | bc5f784 | 2018-08-30 12:34:39 +0100 | [diff] [blame] | 52 | const HalModel& m_Model; |
| 53 | const std::set<unsigned int>& m_ForcedUnsupportedOperations; |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 54 | |
| 55 | // Output data |
telsoa01 | ce3e84a | 2018-08-31 09:31:35 +0100 | [diff] [blame] | 56 | ConversionResult m_ConversionResult; |
| 57 | std::map<uint32_t, bool> m_OperationSupported; |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 58 | }; |
| 59 | |
Matteo Martincigh | 79250ab | 2018-09-04 16:28:10 +0100 | [diff] [blame] | 60 | } // armnn_driver |