blob: 6b2f68c2948e0034d8a5d437029affdd4e8bb2ee [file] [log] [blame]
Jenkinsb3a371b2018-05-23 11:36:53 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/graph/detail/CrossLayerMemoryManagerHelpers.h"
25
26#include "arm_compute/graph/Graph.h"
27#include "arm_compute/graph/GraphContext.h"
28#include "arm_compute/graph/GraphManager.h"
29#include "arm_compute/graph/INode.h"
30#include "arm_compute/graph/Tensor.h"
31#include "arm_compute/graph/Types.h"
32#include "arm_compute/graph/backends/BackendRegistry.h"
33
34#include "arm_compute/core/ITensor.h"
35#include "arm_compute/core/utils/misc/Cast.h"
36
37#include <algorithm>
38#include <map>
39
40namespace arm_compute
41{
42namespace graph
43{
44namespace detail
45{
46namespace
47{
48using HandleCountPair = std::pair<ITensorHandle *, unsigned int>;
49using HandleCounter = std::map<HandleCountPair::first_type, HandleCountPair::second_type>;
50using TargetHandleCounter = std::map<Target, HandleCounter>;
51
52/** Holds managed IO tensor handles if a task */
53struct TaskHandles
54{
55 std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> input_handles = {}; /**< Input handles to a task */
56 std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> output_handles = {}; /**< Output handles of a task */
57};
58
59/** Returns memory group depending on handle backend type
60 *
61 * @param[in] ctx Graph context
62 * @param[in] handle Tensor handle
63 *
64 * @return Memory groupb
65 */
66IMemoryGroup *get_memory_group_from_handle(GraphContext &ctx, ITensorHandle *handle)
67{
68 ARM_COMPUTE_ERROR_ON(handle == nullptr);
69 return ctx.memory_management_ctx(handle->target())->cross_group.get();
70}
71
72/** Get handles of const tensors of graph
73 *
74 * @param[in] g Graph
75 *
76 * @return Handles of const tensors of graph
77 */
78std::set<ITensorHandle *> get_const_handles(const Graph &g)
79{
80 std::set<NodeType> const_node_types = { NodeType::Input, NodeType::Output, NodeType::Const };
81
82 std::set<ITensorHandle *> const_tensors;
83
84 auto &nodes = g.nodes();
85 for(auto &node : nodes)
86 {
87 // If its a const node:
88 if(node != nullptr && const_node_types.find(node->type()) != std::end(const_node_types))
89 {
90 // Add all its inputs / outputs to the list of constant handles
91 for(unsigned int i = 0; i < node->num_inputs(); ++i)
92 {
93 if(node->input(i) != nullptr)
94 {
95 const_tensors.insert(node->input(i)->handle()->parent_handle());
96 }
97 }
98 for(unsigned int i = 0; i < node->num_outputs(); ++i)
99 {
100 if(node->output(i) != nullptr)
101 {
102 const_tensors.insert(node->output(i)->handle()->parent_handle());
103 }
104 }
105 }
106 }
107
108 return const_tensors;
109}
110
111/** Builds a list of all the transition handles (Handles that are used to link two nodes)
112 *
113 * @param[in] ctx Graph context
114 * @param[in] task Workload task
115 * @param[in] const_tensors Constant tensors
116 *
117 * @return List of transition handles
118 */
119TaskHandles get_transition_handles(GraphContext &ctx,
120 ExecutionTask &task,
121 const std::set<ITensorHandle *> &const_tensors)
122{
123 ARM_COMPUTE_ERROR_ON(task.node == nullptr || task.task == nullptr);
124 INode &node = *task.node;
125
126 TaskHandles transition_handles;
127
128 // Add input handles
129 for(unsigned int i = 0; i < node.input_edges().size(); ++i)
130 {
131 Edge *input_edge = node.input_edge(i);
132 // If this input is the output of another node
133 if(input_edge != nullptr && input_edge->tensor() != nullptr && const_tensors.find(input_edge->tensor()->handle()->parent_handle()) == std::end(const_tensors))
134 {
135 // Then add it to the list of transition buffers
136 ITensorHandle *tensor_handle = input_edge->tensor()->handle()->parent_handle();
137 IMemoryGroup *mm_group = get_memory_group_from_handle(ctx, tensor_handle);
138 transition_handles.input_handles.push_back(std::make_pair(tensor_handle, mm_group));
139 }
140 }
141
142 // Add output handles
143 for(unsigned int i = 0; i < node.num_outputs(); ++i)
144 {
145 Tensor *output_tensor = node.output(i);
146 // If this output is used as an input for another node
147 if(output_tensor != nullptr && const_tensors.find(output_tensor->handle()->parent_handle()) == std::end(const_tensors))
148 {
149 ITensorHandle *tensor_handle = output_tensor->handle()->parent_handle();
150 IMemoryGroup *mm_group = get_memory_group_from_handle(ctx, tensor_handle);
151 transition_handles.output_handles.push_back(std::make_pair(tensor_handle, mm_group));
152 }
153 }
154
155 return transition_handles;
156}
157
158/** Counts handles refcount for each input handle of each target
159 *
160 * @param[in] task Execution task containing the managed handles
161 * @param[in,out] handle_counter Data structure that keeps the handles reference count
162 */
163void count_input_handles_per_target(const TaskHandles &task_handles, TargetHandleCounter &handle_counter)
164{
165 for(const auto &handle : task_handles.input_handles)
166 {
167 ITensorHandle *key = handle.first;
168 HandleCounter &target_counter = handle_counter[key->target()];
169 if(target_counter.find(key) == std::end(target_counter))
170 {
171 target_counter.emplace(std::make_pair(key, 1));
172 }
173 else
174 {
175 ++target_counter[key];
176 }
177 }
178}
179
180/** Calculates the lifetime of each tensor handle
181 *
182 * @param[in, out] tasks_handles Tensor handles for each task
183 * @param[in] hc Data structure that keeps the handles reference count
184 */
185void configure_handle_lifetime(std::vector<TaskHandles> &tasks_handles, const HandleCounter &hc)
186{
187 // Identify max number of tensors in flight
188 HandleCounter tensors_in_flight;
189
190 // Acquires the given handles and sets them as in flight if they aren't already
191 auto acquire = [&](std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> &handles)
192 {
193 for(auto &handle : handles)
194 {
195 ITensorHandle *parent_handle = handle.first;
196 ARM_COMPUTE_ERROR_ON(parent_handle == nullptr);
197 // If the tensor is not already in flight:
198 if(tensors_in_flight.find(parent_handle) == std::end(tensors_in_flight))
199 {
200 ARM_COMPUTE_ERROR_ON(hc.find(parent_handle) == std::end(hc));
201 // Then add it to the list of in flight tensors
202 tensors_in_flight.insert(std::make_pair(parent_handle, hc.at(parent_handle)));
203 // Start of allocation's lifetime
204 parent_handle->manage(handle.second);
205 }
206 }
207 };
208
209 for(auto &task_handle : tasks_handles)
210 {
211 // Marking all the input and output tensors of the task as in flight
212 acquire(task_handle.input_handles);
213 acquire(task_handle.output_handles);
214
215 // Releasing the input tensors
216 for(auto &input_handle : task_handle.input_handles)
217 {
218 ITensorHandle *ihandle = input_handle.first;
219 ARM_COMPUTE_ERROR_ON(ihandle == nullptr);
220 ARM_COMPUTE_ERROR_ON(tensors_in_flight.find(ihandle) == std::end(tensors_in_flight));
221 --tensors_in_flight[ihandle];
222 if(tensors_in_flight[ihandle] <= 0)
223 {
224 // Remove tensor for tensors in flight
225 tensors_in_flight.erase(ihandle);
226 // End of allocation's lifetime
227 ihandle->allocate();
228 }
229 }
230 }
231}
232} // namespace
233
234void configure_transition_manager(Graph &g, GraphContext &ctx, ExecutionWorkload &workload)
235{
236 // Get const tensors (un-managed)
237 std::set<ITensorHandle *> const_tensors = get_const_handles(g);
238
239 std::vector<TaskHandles> tasks_handles;
240 TargetHandleCounter target_handle_count;
241
242 // Count handles
243 for(auto &task : workload.tasks)
244 {
245 // Populates IO handles
246 tasks_handles.push_back(get_transition_handles(ctx, task, const_tensors));
247
248 // Count handles
249 count_input_handles_per_target(tasks_handles.back(), target_handle_count);
250 }
251
252 // Setup memory managers
253 for(auto &hc : target_handle_count)
254 {
255 MemoryManagerContext *mm_ctx = ctx.memory_management_ctx(hc.first);
256 if(mm_ctx != nullptr)
257 {
258 if(mm_ctx->cross_mm != nullptr && mm_ctx->cross_group != nullptr)
259 {
260 // Manage and allocate tensors
261 configure_handle_lifetime(tasks_handles, hc.second);
262 }
263 }
264 }
265}
266} // namespace detail
267} // namespace graph
268} // namespace arm_compute