telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | 93e4898 | 2018-09-05 13:05:09 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 4 | // |
| 5 | |
| 6 | #pragma once |
| 7 | |
#include "ArmnnDriver.hpp"
#include "ConversionUtils.hpp"

#include <armnn/ArmNN.hpp>

#include <cstdint>
#include <map>
#include <set>
| 14 | |
| 15 | namespace armnn_driver |
| 16 | { |
| 17 | |
// Outcome of converting a driver model into an armnn::INetwork.
// Enumerator values are spelled out explicitly; they match the implicit 0, 1, 2 numbering.
enum class ConversionResult
{
    Success           = 0, // conversion completed
    ErrorMappingPools = 1, // presumably: the model's memory pools could not be mapped - confirm in implementation
    UnsupportedFeature = 2 // presumably: the model uses something the converter cannot handle - confirm in implementation
};
| 24 | |
arovir01 | b0717b5 | 2018-09-05 17:03:25 +0100 | [diff] [blame^] | 25 | // A helper template class performing the conversion from an AndroidNN driver Model representation, |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 26 | // to an armnn::INetwork object |
arovir01 | b0717b5 | 2018-09-05 17:03:25 +0100 | [diff] [blame^] | 27 | template<typename HalPolicy> |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 28 | class ModelToINetworkConverter |
| 29 | { |
| 30 | public: |
arovir01 | b0717b5 | 2018-09-05 17:03:25 +0100 | [diff] [blame^] | 31 | using HalModel = typename HalPolicy::Model; |
kevmay01 | bc5f784 | 2018-08-30 12:34:39 +0100 | [diff] [blame] | 32 | |
telsoa01 | ce3e84a | 2018-08-31 09:31:35 +0100 | [diff] [blame] | 33 | ModelToINetworkConverter(armnn::Compute compute, |
Matteo Martincigh | e48bdff | 2018-09-03 13:50:50 +0100 | [diff] [blame] | 34 | const HalModel& model, |
| 35 | const std::set<unsigned int>& forcedUnsupportedOperations); |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 36 | |
| 37 | ConversionResult GetConversionResult() const { return m_ConversionResult; } |
| 38 | |
| 39 | // Returns the ArmNN INetwork corresponding to the input model, if preparation went smoothly, nullptr otherwise. |
arovir01 | b0717b5 | 2018-09-05 17:03:25 +0100 | [diff] [blame^] | 40 | armnn::INetwork* GetINetwork() const { return m_Data.m_Network.get(); } |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 41 | |
| 42 | bool IsOperationSupported(uint32_t operationIndex) const; |
| 43 | |
| 44 | private: |
| 45 | void Convert(); |
| 46 | |
arovir01 | b0717b5 | 2018-09-05 17:03:25 +0100 | [diff] [blame^] | 47 | // Shared aggregate input/output/internal data |
| 48 | ConversionData m_Data; |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 49 | |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 50 | // Input data |
kevmay01 | bc5f784 | 2018-08-30 12:34:39 +0100 | [diff] [blame] | 51 | const HalModel& m_Model; |
| 52 | const std::set<unsigned int>& m_ForcedUnsupportedOperations; |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 53 | |
| 54 | // Output data |
telsoa01 | ce3e84a | 2018-08-31 09:31:35 +0100 | [diff] [blame] | 55 | ConversionResult m_ConversionResult; |
| 56 | std::map<uint32_t, bool> m_OperationSupported; |
telsoa01 | 5307bc1 | 2018-03-09 13:51:08 +0000 | [diff] [blame] | 57 | }; |
| 58 | |
Matteo Martincigh | 79250ab | 2018-09-04 16:28:10 +0100 | [diff] [blame] | 59 | } // armnn_driver |