blob: 84bb424dfd05dbe253752ef369d60498eeeb9d7f [file] [log] [blame]
Michael Butler60296322019-01-17 17:54:51 -08001/*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Michael Butler89e99ba2019-01-24 02:36:37 -080017#define LOG_TAG "ExecutionBurstServer"
18
#include "ExecutionBurstServer.h"

#include <android-base/logging.h>

#include <algorithm>
#include <cstdint>
#include <set>
#include <string>

#include "Tracing.h"
Michael Butler60296322019-01-17 17:54:51 -080025
Michael Butler3db6fe52019-01-29 11:20:30 -080026namespace android::nn {
Michael Butler60296322019-01-17 17:54:51 -080027
Michael Butler3db6fe52019-01-29 11:20:30 -080028ExecutionBurstServer::BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback)
29 : mCallback(callback) {}
Michael Butler60296322019-01-17 17:54:51 -080030
Michael Butler3db6fe52019-01-29 11:20:30 -080031hidl_vec<hidl_memory> ExecutionBurstServer::BurstMemoryCache::getMemories(
32 const std::vector<int32_t>& slots) {
Michael Butler60296322019-01-17 17:54:51 -080033 std::lock_guard<std::mutex> guard(mMutex);
34
Michael Butler47c988f62019-03-14 17:34:48 -070035 const auto slotIsKnown = [this](int32_t slot) {
36 return slot < mMemoryCache.size() && mMemoryCache[slot].valid();
37 };
38
Michael Butler60296322019-01-17 17:54:51 -080039 // find unique unknown slots
Michael Butler47c988f62019-03-14 17:34:48 -070040 std::vector<int32_t> unknownSlots = slots;
41 auto unknownSlotsEnd = unknownSlots.end();
42 std::sort(unknownSlots.begin(), unknownSlotsEnd);
43 unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
44 unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
45 unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
Michael Butler60296322019-01-17 17:54:51 -080046
47 // retrieve unknown slots
Michael Butler47c988f62019-03-14 17:34:48 -070048 if (!unknownSlots.empty()) {
Michael Butler89e99ba2019-01-24 02:36:37 -080049 ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
50 std::vector<hidl_memory> returnedMemories;
Michael Butler47c988f62019-03-14 17:34:48 -070051 auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
52 const hidl_vec<hidl_memory>& memories) {
53 errorStatus = status;
54 returnedMemories = memories;
55 };
Michael Butler60296322019-01-17 17:54:51 -080056
Michael Butler47c988f62019-03-14 17:34:48 -070057 Return<void> ret = mCallback->getMemories(unknownSlots, cb);
58
59 // Ensure that the memories were successfully returned.
60 // IBurstCallback.hal specifies the that the number of memories returned
61 // must match the number of slots requested:
62 // "slots.size() == buffers.size()"
63 if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
64 returnedMemories.size() != unknownSlots.size()) {
Michael Butler89e99ba2019-01-24 02:36:37 -080065 LOG(ERROR) << "Error retrieving memories";
66 return {};
67 }
Michael Butler60296322019-01-17 17:54:51 -080068
Michael Butler47c988f62019-03-14 17:34:48 -070069 // resize cache to fit new slots if necessary
70 const int32_t maxUnknownSlot = unknownSlots.back();
71 if (maxUnknownSlot >= mMemoryCache.size()) {
72 mMemoryCache.resize(maxUnknownSlot + 1);
73 }
74
Michael Butler89e99ba2019-01-24 02:36:37 -080075 // add memories to unknown slots
Michael Butler47c988f62019-03-14 17:34:48 -070076 for (size_t i = 0; i < unknownSlots.size(); ++i) {
77 mMemoryCache[unknownSlots[i]] = returnedMemories[i];
Michael Butler89e99ba2019-01-24 02:36:37 -080078 }
Michael Butler60296322019-01-17 17:54:51 -080079 }
80
81 // get all slots
82 hidl_vec<hidl_memory> memories(slots.size());
Michael Butler47c988f62019-03-14 17:34:48 -070083 std::transform(slots.begin(), slots.end(), memories.begin(),
84 [this](int32_t slot) { return mMemoryCache[slot]; });
85
86 // Ensure all slots are valid. Although this case is never expected to
87 // occur, theoretically IBurstCallback::getMemories could return invalid
88 // hidl_memory objects that must be protected against.
89 if (!std::all_of(memories.begin(), memories.end(),
90 [](const hidl_memory& memory) { return memory.valid(); })) {
91 LOG(ERROR) << "Error, not all slots are valid!";
92 return {};
Michael Butler60296322019-01-17 17:54:51 -080093 }
Michael Butler89e99ba2019-01-24 02:36:37 -080094
Michael Butler60296322019-01-17 17:54:51 -080095 return memories;
96}
97
Michael Butler3db6fe52019-01-29 11:20:30 -080098void ExecutionBurstServer::BurstMemoryCache::freeMemory(int32_t slot) {
Michael Butler60296322019-01-17 17:54:51 -080099 std::lock_guard<std::mutex> guard(mMutex);
Michael Butler47c988f62019-03-14 17:34:48 -0700100 if (slot < mMemoryCache.size()) {
101 mMemoryCache[slot] = {};
102 }
Michael Butler60296322019-01-17 17:54:51 -0800103}
104
Michael Butler3db6fe52019-01-29 11:20:30 -0800105sp<ExecutionBurstServer> ExecutionBurstServer::create(
106 const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
107 const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
108 // check inputs
109 if (callback == nullptr || preparedModel == nullptr) {
110 LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
111 return nullptr;
112 }
113
114 // create FMQ objects
115 std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
116 FmqRequestChannel(requestChannel)};
117 std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
118 FmqResultChannel(resultChannel)};
119
120 // check FMQ objects
121 if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
122 !fmqResultChannel->isValid()) {
123 LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
124 return nullptr;
125 }
126
127 // make and return context
128 return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
129 std::move(fmqResultChannel), preparedModel);
130}
131
Michael Butler60296322019-01-17 17:54:51 -0800132ExecutionBurstServer::ExecutionBurstServer(const sp<IBurstCallback>& callback,
133 std::unique_ptr<FmqRequestChannel> requestChannel,
134 std::unique_ptr<FmqResultChannel> resultChannel,
135 IPreparedModel* preparedModel)
136 : mMemoryCache(callback),
137 mFmqRequestChannel(std::move(requestChannel)),
138 mFmqResultChannel(std::move(resultChannel)),
139 mPreparedModel(preparedModel),
140 mBlocking(mFmqRequestChannel->getEventFlagWord() != nullptr) {
141 // TODO: highly document the threading behavior of this class
142 mWorker = std::async(std::launch::async, [this] { task(); });
143}
144
145ExecutionBurstServer::~ExecutionBurstServer() {
146 // set teardown flag
147 mTeardown = true;
148
149 // force unblock
Michael Butler89e99ba2019-01-24 02:36:37 -0800150 // ExecutionBurstServer is by default waiting on a request packet. If the
151 // client process destroys its burst object, the server will still be
152 // waiting on the futex (assuming mBlocking is true). This force unblock
153 // wakes up any thread waiting on the futex.
Michael Butler60296322019-01-17 17:54:51 -0800154 if (mBlocking) {
Michael Butler89e99ba2019-01-24 02:36:37 -0800155 // TODO: look for a different/better way to signal/notify the futex to
156 // wake up any thread waiting on it
Michael Butler60296322019-01-17 17:54:51 -0800157 FmqRequestDatum datum;
158 datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
159 /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
160 mFmqRequestChannel->writeBlocking(&datum, 1);
161 }
162
163 // wait for task thread to end
164 mWorker.wait();
165}
166
167bool ExecutionBurstServer::sendPacket(const std::vector<FmqResultDatum>& packet) {
168 if (mTeardown) {
169 return false;
170 }
171
172 if (mBlocking) {
173 return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
174 } else {
175 return mFmqResultChannel->write(packet.data(), packet.size());
176 }
177}
178
179std::vector<FmqRequestDatum> ExecutionBurstServer::getPacketBlocking() {
180 using discriminator = FmqRequestDatum::hidl_discriminator;
181
182 if (mTeardown) {
183 return {};
184 }
185
Michael Butler89e99ba2019-01-24 02:36:37 -0800186 // wait for request packet and read first element of request packet
187 // TODO: have a more elegant way to wait for data, and read it all at once.
188 // For example, EventFlag can be used to directly wait on the futex, and all
189 // the data can be read at once with a non-blocking call to
190 // MessageQueue::read. For further optimization, MessageQueue::beginRead and
191 // MessageQueue::commitRead can be used to avoid an extra copy of the
192 // metadata.
Michael Butler60296322019-01-17 17:54:51 -0800193 FmqRequestDatum datum;
194 bool success = false;
195 if (mBlocking) {
196 success = mFmqRequestChannel->readBlocking(&datum, 1);
197 } else {
198 while ((success = !mTeardown.load(std::memory_order_relaxed)) &&
199 !mFmqRequestChannel->read(&datum, 1)) {
200 }
201 }
202
203 // terminate loop
204 if (mTeardown) {
205 return {};
206 }
207
208 // validate packet information
209 if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
210 LOG(ERROR) << "FMQ Request packet ill-formed";
211 return {};
212 }
213
Michael Butler3db6fe52019-01-29 11:20:30 -0800214 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
215
Michael Butler60296322019-01-17 17:54:51 -0800216 // unpack packet information
217 const auto& packetInfo = datum.packetInformation();
218 const size_t count = packetInfo.packetSize;
219
220 // retrieve remaining elements
221 // NOTE: all of the data is already available at this point, so there's no
Michael Butler3db6fe52019-01-29 11:20:30 -0800222 // need to do a blocking wait to wait for more data. This is known because
223 // in FMQ, all writes are published (made available) atomically. Currently,
224 // the producer always publishes the entire packet in one function call, so
225 // if the first element of the packet is available, the remaining elements
226 // are also available.
Michael Butler60296322019-01-17 17:54:51 -0800227 std::vector<FmqRequestDatum> packet(count);
228 packet.front() = datum;
229 success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
230
231 if (!success) {
232 return {};
233 }
234
235 return packet;
236}
237
238// deserialize request
239std::pair<Request, MeasureTiming> ExecutionBurstServer::deserialize(
240 const std::vector<FmqRequestDatum>& data) {
241 using discriminator = FmqRequestDatum::hidl_discriminator;
242
243 Request request;
244 size_t index = 0;
245
246 // validate packet information
247 if (data[index].getDiscriminator() != discriminator::packetInformation) {
248 LOG(ERROR) << "FMQ Request packet ill-formed";
249 return {{}, MeasureTiming::NO};
250 }
251
252 // unpackage packet information
253 const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
254 index++;
255 const uint32_t packetSize = packetInfo.packetSize;
256 const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
257 const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
258 const uint32_t numberOfPools = packetInfo.numberOfPools;
259
260 // unpackage input operands
261 std::vector<RequestArgument> inputs;
262 inputs.reserve(numberOfInputOperands);
263 for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
264 // validate input operand information
265 if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
266 LOG(ERROR) << "FMQ Request packet ill-formed";
267 return {{}, MeasureTiming::NO};
268 }
269
270 // unpackage operand information
271 const FmqRequestDatum::OperandInformation& operandInfo =
272 data[index].inputOperandInformation();
273 index++;
274 const bool hasNoValue = operandInfo.hasNoValue;
275 const DataLocation location = operandInfo.location;
276 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
277
278 // unpackage operand dimensions
279 std::vector<uint32_t> dimensions;
280 dimensions.reserve(numberOfDimensions);
281 for (size_t i = 0; i < numberOfDimensions; ++i) {
282 // validate dimension
283 if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
284 LOG(ERROR) << "FMQ Request packet ill-formed";
285 return {{}, MeasureTiming::NO};
286 }
287
288 // unpackage dimension
289 const uint32_t dimension = data[index].inputOperandDimensionValue();
290 index++;
291
292 // store result
293 dimensions.push_back(dimension);
294 }
295
296 // store result
297 inputs.push_back(
298 {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
299 }
300
301 // unpackage output operands
302 std::vector<RequestArgument> outputs;
303 outputs.reserve(numberOfOutputOperands);
304 for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
305 // validate output operand information
306 if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
307 LOG(ERROR) << "FMQ Request packet ill-formed";
308 return {{}, MeasureTiming::NO};
309 }
310
311 // unpackage operand information
312 const FmqRequestDatum::OperandInformation& operandInfo =
313 data[index].outputOperandInformation();
314 index++;
315 const bool hasNoValue = operandInfo.hasNoValue;
316 const DataLocation location = operandInfo.location;
317 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
318
319 // unpackage operand dimensions
320 std::vector<uint32_t> dimensions;
321 dimensions.reserve(numberOfDimensions);
322 for (size_t i = 0; i < numberOfDimensions; ++i) {
323 // validate dimension
324 if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
325 LOG(ERROR) << "FMQ Request packet ill-formed";
326 return {{}, MeasureTiming::NO};
327 }
328
329 // unpackage dimension
330 const uint32_t dimension = data[index].outputOperandDimensionValue();
331 index++;
332
333 // store result
334 dimensions.push_back(dimension);
335 }
336
337 // store result
338 outputs.push_back(
339 {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
340 }
341
342 // unpackage pools
343 std::vector<int32_t> slots;
344 slots.reserve(numberOfPools);
345 for (size_t pool = 0; pool < numberOfPools; ++pool) {
346 // validate input operand information
347 if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
348 LOG(ERROR) << "FMQ Request packet ill-formed";
349 return {{}, MeasureTiming::NO};
350 }
351
352 // unpackage operand information
353 const int32_t poolId = data[index].poolIdentifier();
354 index++;
355
356 // store result
357 slots.push_back(poolId);
358 }
359 hidl_vec<hidl_memory> pools = mMemoryCache.getMemories(slots);
360
361 // validate measureTiming
362 if (data[index].getDiscriminator() != discriminator::measureTiming) {
363 LOG(ERROR) << "FMQ Request packet ill-formed";
364 return {{}, MeasureTiming::NO};
365 }
366
367 // unpackage measureTiming
368 const MeasureTiming measure = data[index].measureTiming();
369 index++;
370
371 // validate packet information
372 if (index != packetSize) {
373 LOG(ERROR) << "FMQ Result packet ill-formed";
374 return {{}, MeasureTiming::NO};
375 }
376
377 // return request
378 return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/std::move(pools)}, measure};
379}
380
381// serialize result
382std::vector<FmqResultDatum> ExecutionBurstServer::serialize(
383 ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes, Timing timing) {
384 // count how many elements need to be sent for a request
385 size_t count = 2 + outputShapes.size();
386 for (const auto& outputShape : outputShapes) {
387 count += outputShape.dimensions.size();
388 }
389
390 // create buffer to temporarily store elements
391 std::vector<FmqResultDatum> data;
392 data.reserve(count);
393
394 // package packetInfo
395 {
396 FmqResultDatum datum;
397 datum.packetInformation({/*.packetSize=*/static_cast<uint32_t>(count),
398 /*.errorStatus=*/errorStatus,
399 /*.numberOfOperands=*/static_cast<uint32_t>(outputShapes.size())});
400 data.push_back(datum);
401 }
402
403 // package output shape data
404 for (const auto& operand : outputShapes) {
405 // package operand information
406 FmqResultDatum datum;
407 datum.operandInformation(
408 {/*.isSufficient=*/operand.isSufficient,
409 /*.numberOfDimensions=*/static_cast<uint32_t>(operand.dimensions.size())});
410 data.push_back(datum);
411
412 // package operand dimensions
413 for (uint32_t dimension : operand.dimensions) {
414 FmqResultDatum datum;
415 datum.operandDimensionValue(dimension);
416 data.push_back(datum);
417 }
418 }
419
420 // package executionTiming
421 {
422 FmqResultDatum datum;
423 datum.executionTiming(timing);
424 data.push_back(datum);
425 }
426
427 // return result
428 return data;
429}
430
431Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
432 mMemoryCache.freeMemory(slot);
433 return Void();
434}
435
436void ExecutionBurstServer::task() {
437 while (!mTeardown) {
438 // receive request
439 const std::vector<FmqRequestDatum> requestData = getPacketBlocking();
440
441 // terminate loop
442 if (mTeardown) {
443 return;
444 }
445
Michael Butler3db6fe52019-01-29 11:20:30 -0800446 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
447 "ExecutionBurstServer processing packet and returning results");
448
Michael Butler60296322019-01-17 17:54:51 -0800449 // continue processing
450 Request request;
451 MeasureTiming measure;
452 std::tie(request, measure) = deserialize(requestData);
453
454 // perform computation
455 ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
456 std::vector<OutputShape> outputShapes;
457 Timing returnedTiming;
Michael Butler89e99ba2019-01-24 02:36:37 -0800458 // This call to IPreparedModel::executeSynchronously occurs entirely
459 // within the same process, so ignore the Return<> errors via .isOk().
460 // TODO: verify it is safe to always call isOk() here, or if there is
461 // any benefit to checking any potential errors.
Michael Butler60296322019-01-17 17:54:51 -0800462 mPreparedModel
463 ->executeSynchronously(request, measure,
464 [&errorStatus, &outputShapes, &returnedTiming](
465 ErrorStatus status,
466 const hidl_vec<OutputShape>& shapes, Timing timing) {
467 errorStatus = status;
468 outputShapes = shapes;
469 returnedTiming = timing;
470 })
471 .isOk();
472
473 // return result
474 const std::vector<FmqResultDatum> result =
475 serialize(errorStatus, outputShapes, returnedTiming);
476 sendPacket(result);
477 }
478}
479
Michael Butler3db6fe52019-01-29 11:20:30 -0800480} // namespace android::nn