blob: 64a4ee2c2cdff2da4d43cf5dda96a57d4d0aa0a8 [file] [log] [blame]
Michael Butler92399812019-01-17 17:54:51 -08001/*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "ExecutionBurstServer.h"
18
19#include <android-base/logging.h>
20
21namespace android {
22namespace nn {
23
24BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback) : mCallback(callback) {}
25
26hidl_vec<hidl_memory> BurstMemoryCache::getMemories(const std::vector<int32_t>& slots) {
27 std::lock_guard<std::mutex> guard(mMutex);
28
29 // find unique unknown slots
30 std::vector<int32_t> unknownSlots = slots;
31 std::sort(unknownSlots.begin(), unknownSlots.end());
32 auto last = std::unique(unknownSlots.begin(), unknownSlots.end());
33 unknownSlots.erase(last, unknownSlots.end());
34
35 // retrieve unknown slots
36 ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
37 std::vector<hidl_memory> returnedMemories;
38 Return<void> ret = mCallback->getMemories(
39 unknownSlots, [&errorStatus, &returnedMemories](ErrorStatus status,
40 const hidl_vec<hidl_memory>& memories) {
41 errorStatus = status;
42 if (status == ErrorStatus::NONE) {
43 returnedMemories = memories;
44 }
45 });
46
47 if (!ret.isOk() || errorStatus != ErrorStatus::NONE) {
48 LOG(ERROR) << "Error retrieving memories";
49 return {};
50 }
51
52 // add memories to unknown slots
53 for (size_t i = 0; i < unknownSlots.size(); ++i) {
54 mSlotToMemoryCache[unknownSlots[i]] = returnedMemories[i];
55 }
56
57 // get all slots
58 hidl_vec<hidl_memory> memories(slots.size());
59 for (size_t i = 0; i < slots.size(); ++i) {
60 memories[i] = mSlotToMemoryCache[slots[i]];
61 }
62 return memories;
63}
64
65void BurstMemoryCache::freeMemory(int32_t slot) {
66 std::lock_guard<std::mutex> guard(mMutex);
67 mSlotToMemoryCache.erase(slot);
68}
69
70ExecutionBurstServer::ExecutionBurstServer(const sp<IBurstCallback>& callback,
71 std::unique_ptr<FmqRequestChannel> requestChannel,
72 std::unique_ptr<FmqResultChannel> resultChannel,
73 IPreparedModel* preparedModel)
74 : mMemoryCache(callback),
75 mFmqRequestChannel(std::move(requestChannel)),
76 mFmqResultChannel(std::move(resultChannel)),
77 mPreparedModel(preparedModel),
78 mBlocking(mFmqRequestChannel->getEventFlagWord() != nullptr) {
79 // TODO: highly document the threading behavior of this class
80 mWorker = std::async(std::launch::async, [this] { task(); });
81}
82
83ExecutionBurstServer::~ExecutionBurstServer() {
84 // set teardown flag
85 mTeardown = true;
86
87 // force unblock
88 if (mBlocking) {
89 // TODO: look for a different/better way to signal/notify the futex to wake
90 // up any thread waiting on it
91 FmqRequestDatum datum;
92 datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
93 /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
94 mFmqRequestChannel->writeBlocking(&datum, 1);
95 }
96
97 // wait for task thread to end
98 mWorker.wait();
99}
100
101bool ExecutionBurstServer::sendPacket(const std::vector<FmqResultDatum>& packet) {
102 if (mTeardown) {
103 return false;
104 }
105
106 if (mBlocking) {
107 return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
108 } else {
109 return mFmqResultChannel->write(packet.data(), packet.size());
110 }
111}
112
113std::vector<FmqRequestDatum> ExecutionBurstServer::getPacketBlocking() {
114 using discriminator = FmqRequestDatum::hidl_discriminator;
115
116 if (mTeardown) {
117 return {};
118 }
119
120 // wait for request packet and read first element of result packet
121 FmqRequestDatum datum;
122 bool success = false;
123 if (mBlocking) {
124 success = mFmqRequestChannel->readBlocking(&datum, 1);
125 } else {
126 while ((success = !mTeardown.load(std::memory_order_relaxed)) &&
127 !mFmqRequestChannel->read(&datum, 1)) {
128 }
129 }
130
131 // terminate loop
132 if (mTeardown) {
133 return {};
134 }
135
136 // validate packet information
137 if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
138 LOG(ERROR) << "FMQ Request packet ill-formed";
139 return {};
140 }
141
142 // unpack packet information
143 const auto& packetInfo = datum.packetInformation();
144 const size_t count = packetInfo.packetSize;
145
146 // retrieve remaining elements
147 // NOTE: all of the data is already available at this point, so there's no
148 // need to do a blocking wait to wait for more data
149 std::vector<FmqRequestDatum> packet(count);
150 packet.front() = datum;
151 success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
152
153 if (!success) {
154 return {};
155 }
156
157 return packet;
158}
159
160// deserialize request
161std::pair<Request, MeasureTiming> ExecutionBurstServer::deserialize(
162 const std::vector<FmqRequestDatum>& data) {
163 using discriminator = FmqRequestDatum::hidl_discriminator;
164
165 Request request;
166 size_t index = 0;
167
168 // validate packet information
169 if (data[index].getDiscriminator() != discriminator::packetInformation) {
170 LOG(ERROR) << "FMQ Request packet ill-formed";
171 return {{}, MeasureTiming::NO};
172 }
173
174 // unpackage packet information
175 const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
176 index++;
177 const uint32_t packetSize = packetInfo.packetSize;
178 const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
179 const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
180 const uint32_t numberOfPools = packetInfo.numberOfPools;
181
182 // unpackage input operands
183 std::vector<RequestArgument> inputs;
184 inputs.reserve(numberOfInputOperands);
185 for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
186 // validate input operand information
187 if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
188 LOG(ERROR) << "FMQ Request packet ill-formed";
189 return {{}, MeasureTiming::NO};
190 }
191
192 // unpackage operand information
193 const FmqRequestDatum::OperandInformation& operandInfo =
194 data[index].inputOperandInformation();
195 index++;
196 const bool hasNoValue = operandInfo.hasNoValue;
197 const DataLocation location = operandInfo.location;
198 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
199
200 // unpackage operand dimensions
201 std::vector<uint32_t> dimensions;
202 dimensions.reserve(numberOfDimensions);
203 for (size_t i = 0; i < numberOfDimensions; ++i) {
204 // validate dimension
205 if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
206 LOG(ERROR) << "FMQ Request packet ill-formed";
207 return {{}, MeasureTiming::NO};
208 }
209
210 // unpackage dimension
211 const uint32_t dimension = data[index].inputOperandDimensionValue();
212 index++;
213
214 // store result
215 dimensions.push_back(dimension);
216 }
217
218 // store result
219 inputs.push_back(
220 {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
221 }
222
223 // unpackage output operands
224 std::vector<RequestArgument> outputs;
225 outputs.reserve(numberOfOutputOperands);
226 for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
227 // validate output operand information
228 if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
229 LOG(ERROR) << "FMQ Request packet ill-formed";
230 return {{}, MeasureTiming::NO};
231 }
232
233 // unpackage operand information
234 const FmqRequestDatum::OperandInformation& operandInfo =
235 data[index].outputOperandInformation();
236 index++;
237 const bool hasNoValue = operandInfo.hasNoValue;
238 const DataLocation location = operandInfo.location;
239 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
240
241 // unpackage operand dimensions
242 std::vector<uint32_t> dimensions;
243 dimensions.reserve(numberOfDimensions);
244 for (size_t i = 0; i < numberOfDimensions; ++i) {
245 // validate dimension
246 if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
247 LOG(ERROR) << "FMQ Request packet ill-formed";
248 return {{}, MeasureTiming::NO};
249 }
250
251 // unpackage dimension
252 const uint32_t dimension = data[index].outputOperandDimensionValue();
253 index++;
254
255 // store result
256 dimensions.push_back(dimension);
257 }
258
259 // store result
260 outputs.push_back(
261 {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
262 }
263
264 // unpackage pools
265 std::vector<int32_t> slots;
266 slots.reserve(numberOfPools);
267 for (size_t pool = 0; pool < numberOfPools; ++pool) {
268 // validate input operand information
269 if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
270 LOG(ERROR) << "FMQ Request packet ill-formed";
271 return {{}, MeasureTiming::NO};
272 }
273
274 // unpackage operand information
275 const int32_t poolId = data[index].poolIdentifier();
276 index++;
277
278 // store result
279 slots.push_back(poolId);
280 }
281 hidl_vec<hidl_memory> pools = mMemoryCache.getMemories(slots);
282
283 // validate measureTiming
284 if (data[index].getDiscriminator() != discriminator::measureTiming) {
285 LOG(ERROR) << "FMQ Request packet ill-formed";
286 return {{}, MeasureTiming::NO};
287 }
288
289 // unpackage measureTiming
290 const MeasureTiming measure = data[index].measureTiming();
291 index++;
292
293 // validate packet information
294 if (index != packetSize) {
295 LOG(ERROR) << "FMQ Result packet ill-formed";
296 return {{}, MeasureTiming::NO};
297 }
298
299 // return request
300 return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/std::move(pools)}, measure};
301}
302
303// serialize result
304std::vector<FmqResultDatum> ExecutionBurstServer::serialize(
305 ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes, Timing timing) {
306 // count how many elements need to be sent for a request
307 size_t count = 2 + outputShapes.size();
308 for (const auto& outputShape : outputShapes) {
309 count += outputShape.dimensions.size();
310 }
311
312 // create buffer to temporarily store elements
313 std::vector<FmqResultDatum> data;
314 data.reserve(count);
315
316 // package packetInfo
317 {
318 FmqResultDatum datum;
319 datum.packetInformation({/*.packetSize=*/static_cast<uint32_t>(count),
320 /*.errorStatus=*/errorStatus,
321 /*.numberOfOperands=*/static_cast<uint32_t>(outputShapes.size())});
322 data.push_back(datum);
323 }
324
325 // package output shape data
326 for (const auto& operand : outputShapes) {
327 // package operand information
328 FmqResultDatum datum;
329 datum.operandInformation(
330 {/*.isSufficient=*/operand.isSufficient,
331 /*.numberOfDimensions=*/static_cast<uint32_t>(operand.dimensions.size())});
332 data.push_back(datum);
333
334 // package operand dimensions
335 for (uint32_t dimension : operand.dimensions) {
336 FmqResultDatum datum;
337 datum.operandDimensionValue(dimension);
338 data.push_back(datum);
339 }
340 }
341
342 // package executionTiming
343 {
344 FmqResultDatum datum;
345 datum.executionTiming(timing);
346 data.push_back(datum);
347 }
348
349 // return result
350 return data;
351}
352
353Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
354 mMemoryCache.freeMemory(slot);
355 return Void();
356}
357
358void ExecutionBurstServer::task() {
359 while (!mTeardown) {
360 // receive request
361 const std::vector<FmqRequestDatum> requestData = getPacketBlocking();
362
363 // terminate loop
364 if (mTeardown) {
365 return;
366 }
367
368 // continue processing
369 Request request;
370 MeasureTiming measure;
371 std::tie(request, measure) = deserialize(requestData);
372
373 // perform computation
374 ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
375 std::vector<OutputShape> outputShapes;
376 Timing returnedTiming;
377 mPreparedModel
378 ->executeSynchronously(request, measure,
379 [&errorStatus, &outputShapes, &returnedTiming](
380 ErrorStatus status,
381 const hidl_vec<OutputShape>& shapes, Timing timing) {
382 errorStatus = status;
383 outputShapes = shapes;
384 returnedTiming = timing;
385 })
386 .isOk();
387
388 // return result
389 const std::vector<FmqResultDatum> result =
390 serialize(errorStatus, outputShapes, returnedTiming);
391 sendPacket(result);
392 }
393}
394
395sp<IBurstContext> createBurstContext(const sp<IBurstCallback>& callback,
396 const MQDescriptorSync<FmqRequestDatum>& requestChannel,
397 const MQDescriptorSync<FmqResultDatum>& resultChannel,
398 IPreparedModel* preparedModel) {
399 // check inputs
400 if (callback == nullptr || preparedModel == nullptr) {
401 LOG(ERROR) << "createExecutionBurstServer passed a nullptr";
402 return nullptr;
403 }
404
405 // create FMQ objects
406 std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
407 FmqRequestChannel(requestChannel)};
408 std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
409 FmqResultChannel(resultChannel)};
410
411 // check FMQ objects
412 if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
413 !fmqResultChannel->isValid()) {
414 LOG(ERROR) << "createExecutionBurstServer failed to create FastMessageQueue";
415 return nullptr;
416 }
417
418 // make and return context
419 return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
420 std::move(fmqResultChannel), preparedModel);
421}
422
423} // namespace nn
424} // namespace android