blob: a9af0045eb6e360c255c2e1b08d53dccf1d97689 [file] [log] [blame]
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
Michael Butler89e99ba2019-01-24 02:36:37 -080017#define LOG_TAG "ExecutionBurstServer"
18
#include "ExecutionBurstServer.h"

#include <android-base/logging.h>
#include <cstdint>
#include <set>
#include <string>
Michael Butler60296322019-01-17 17:54:51 -080024
25namespace android {
26namespace nn {
27
28BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback) : mCallback(callback) {}
29
30hidl_vec<hidl_memory> BurstMemoryCache::getMemories(const std::vector<int32_t>& slots) {
31 std::lock_guard<std::mutex> guard(mMutex);
32
33 // find unique unknown slots
Michael Butler89e99ba2019-01-24 02:36:37 -080034 std::set<int32_t> setOfUnknownSlots;
35 for (int32_t slot : slots) {
36 if (mSlotToMemoryCache.find(slot) == mSlotToMemoryCache.end()) {
37 setOfUnknownSlots.insert(slot);
38 }
39 }
40 const std::vector<int32_t> unknownSlots(setOfUnknownSlots.begin(), setOfUnknownSlots.end());
Michael Butler60296322019-01-17 17:54:51 -080041
42 // retrieve unknown slots
Michael Butler89e99ba2019-01-24 02:36:37 -080043 if (!unknownSlots.empty()) {
44 LOG(ERROR) << "server calling getMemories";
45 ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
46 std::vector<hidl_memory> returnedMemories;
47 Return<void> ret = mCallback->getMemories(
48 unknownSlots, [&errorStatus, &returnedMemories](
49 ErrorStatus status, const hidl_vec<hidl_memory>& memories) {
50 errorStatus = status;
51 if (status == ErrorStatus::NONE) {
52 returnedMemories = memories;
53 }
54 });
Michael Butler60296322019-01-17 17:54:51 -080055
Michael Butler89e99ba2019-01-24 02:36:37 -080056 if (!ret.isOk() || errorStatus != ErrorStatus::NONE) {
57 LOG(ERROR) << "Error retrieving memories";
58 return {};
59 }
Michael Butler60296322019-01-17 17:54:51 -080060
Michael Butler89e99ba2019-01-24 02:36:37 -080061 // add memories to unknown slots
62 for (size_t i = 0; i < unknownSlots.size(); ++i) {
63 mSlotToMemoryCache[unknownSlots[i]] = returnedMemories[i];
64 }
Michael Butler60296322019-01-17 17:54:51 -080065 }
66
67 // get all slots
68 hidl_vec<hidl_memory> memories(slots.size());
69 for (size_t i = 0; i < slots.size(); ++i) {
70 memories[i] = mSlotToMemoryCache[slots[i]];
71 }
Michael Butler89e99ba2019-01-24 02:36:37 -080072
Michael Butler60296322019-01-17 17:54:51 -080073 return memories;
74}
75
76void BurstMemoryCache::freeMemory(int32_t slot) {
77 std::lock_guard<std::mutex> guard(mMutex);
78 mSlotToMemoryCache.erase(slot);
79}
80
81ExecutionBurstServer::ExecutionBurstServer(const sp<IBurstCallback>& callback,
82 std::unique_ptr<FmqRequestChannel> requestChannel,
83 std::unique_ptr<FmqResultChannel> resultChannel,
84 IPreparedModel* preparedModel)
85 : mMemoryCache(callback),
86 mFmqRequestChannel(std::move(requestChannel)),
87 mFmqResultChannel(std::move(resultChannel)),
88 mPreparedModel(preparedModel),
89 mBlocking(mFmqRequestChannel->getEventFlagWord() != nullptr) {
90 // TODO: highly document the threading behavior of this class
91 mWorker = std::async(std::launch::async, [this] { task(); });
92}
93
94ExecutionBurstServer::~ExecutionBurstServer() {
95 // set teardown flag
96 mTeardown = true;
97
98 // force unblock
Michael Butler89e99ba2019-01-24 02:36:37 -080099 // ExecutionBurstServer is by default waiting on a request packet. If the
100 // client process destroys its burst object, the server will still be
101 // waiting on the futex (assuming mBlocking is true). This force unblock
102 // wakes up any thread waiting on the futex.
Michael Butler60296322019-01-17 17:54:51 -0800103 if (mBlocking) {
Michael Butler89e99ba2019-01-24 02:36:37 -0800104 // TODO: look for a different/better way to signal/notify the futex to
105 // wake up any thread waiting on it
Michael Butler60296322019-01-17 17:54:51 -0800106 FmqRequestDatum datum;
107 datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
108 /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
109 mFmqRequestChannel->writeBlocking(&datum, 1);
110 }
111
112 // wait for task thread to end
113 mWorker.wait();
114}
115
116bool ExecutionBurstServer::sendPacket(const std::vector<FmqResultDatum>& packet) {
117 if (mTeardown) {
118 return false;
119 }
120
121 if (mBlocking) {
122 return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
123 } else {
124 return mFmqResultChannel->write(packet.data(), packet.size());
125 }
126}
127
128std::vector<FmqRequestDatum> ExecutionBurstServer::getPacketBlocking() {
129 using discriminator = FmqRequestDatum::hidl_discriminator;
130
131 if (mTeardown) {
132 return {};
133 }
134
Michael Butler89e99ba2019-01-24 02:36:37 -0800135 // wait for request packet and read first element of request packet
136 // TODO: have a more elegant way to wait for data, and read it all at once.
137 // For example, EventFlag can be used to directly wait on the futex, and all
138 // the data can be read at once with a non-blocking call to
139 // MessageQueue::read. For further optimization, MessageQueue::beginRead and
140 // MessageQueue::commitRead can be used to avoid an extra copy of the
141 // metadata.
Michael Butler60296322019-01-17 17:54:51 -0800142 FmqRequestDatum datum;
143 bool success = false;
144 if (mBlocking) {
145 success = mFmqRequestChannel->readBlocking(&datum, 1);
146 } else {
147 while ((success = !mTeardown.load(std::memory_order_relaxed)) &&
148 !mFmqRequestChannel->read(&datum, 1)) {
149 }
150 }
151
152 // terminate loop
153 if (mTeardown) {
154 return {};
155 }
156
157 // validate packet information
158 if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
159 LOG(ERROR) << "FMQ Request packet ill-formed";
160 return {};
161 }
162
163 // unpack packet information
164 const auto& packetInfo = datum.packetInformation();
165 const size_t count = packetInfo.packetSize;
166
167 // retrieve remaining elements
168 // NOTE: all of the data is already available at this point, so there's no
169 // need to do a blocking wait to wait for more data
170 std::vector<FmqRequestDatum> packet(count);
171 packet.front() = datum;
172 success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
173
174 if (!success) {
175 return {};
176 }
177
178 return packet;
179}
180
181// deserialize request
182std::pair<Request, MeasureTiming> ExecutionBurstServer::deserialize(
183 const std::vector<FmqRequestDatum>& data) {
184 using discriminator = FmqRequestDatum::hidl_discriminator;
185
186 Request request;
187 size_t index = 0;
188
189 // validate packet information
190 if (data[index].getDiscriminator() != discriminator::packetInformation) {
191 LOG(ERROR) << "FMQ Request packet ill-formed";
192 return {{}, MeasureTiming::NO};
193 }
194
195 // unpackage packet information
196 const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
197 index++;
198 const uint32_t packetSize = packetInfo.packetSize;
199 const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
200 const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
201 const uint32_t numberOfPools = packetInfo.numberOfPools;
202
203 // unpackage input operands
204 std::vector<RequestArgument> inputs;
205 inputs.reserve(numberOfInputOperands);
206 for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
207 // validate input operand information
208 if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
209 LOG(ERROR) << "FMQ Request packet ill-formed";
210 return {{}, MeasureTiming::NO};
211 }
212
213 // unpackage operand information
214 const FmqRequestDatum::OperandInformation& operandInfo =
215 data[index].inputOperandInformation();
216 index++;
217 const bool hasNoValue = operandInfo.hasNoValue;
218 const DataLocation location = operandInfo.location;
219 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
220
221 // unpackage operand dimensions
222 std::vector<uint32_t> dimensions;
223 dimensions.reserve(numberOfDimensions);
224 for (size_t i = 0; i < numberOfDimensions; ++i) {
225 // validate dimension
226 if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
227 LOG(ERROR) << "FMQ Request packet ill-formed";
228 return {{}, MeasureTiming::NO};
229 }
230
231 // unpackage dimension
232 const uint32_t dimension = data[index].inputOperandDimensionValue();
233 index++;
234
235 // store result
236 dimensions.push_back(dimension);
237 }
238
239 // store result
240 inputs.push_back(
241 {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
242 }
243
244 // unpackage output operands
245 std::vector<RequestArgument> outputs;
246 outputs.reserve(numberOfOutputOperands);
247 for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
248 // validate output operand information
249 if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
250 LOG(ERROR) << "FMQ Request packet ill-formed";
251 return {{}, MeasureTiming::NO};
252 }
253
254 // unpackage operand information
255 const FmqRequestDatum::OperandInformation& operandInfo =
256 data[index].outputOperandInformation();
257 index++;
258 const bool hasNoValue = operandInfo.hasNoValue;
259 const DataLocation location = operandInfo.location;
260 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
261
262 // unpackage operand dimensions
263 std::vector<uint32_t> dimensions;
264 dimensions.reserve(numberOfDimensions);
265 for (size_t i = 0; i < numberOfDimensions; ++i) {
266 // validate dimension
267 if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
268 LOG(ERROR) << "FMQ Request packet ill-formed";
269 return {{}, MeasureTiming::NO};
270 }
271
272 // unpackage dimension
273 const uint32_t dimension = data[index].outputOperandDimensionValue();
274 index++;
275
276 // store result
277 dimensions.push_back(dimension);
278 }
279
280 // store result
281 outputs.push_back(
282 {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
283 }
284
285 // unpackage pools
286 std::vector<int32_t> slots;
287 slots.reserve(numberOfPools);
288 for (size_t pool = 0; pool < numberOfPools; ++pool) {
289 // validate input operand information
290 if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
291 LOG(ERROR) << "FMQ Request packet ill-formed";
292 return {{}, MeasureTiming::NO};
293 }
294
295 // unpackage operand information
296 const int32_t poolId = data[index].poolIdentifier();
297 index++;
298
299 // store result
300 slots.push_back(poolId);
301 }
302 hidl_vec<hidl_memory> pools = mMemoryCache.getMemories(slots);
303
304 // validate measureTiming
305 if (data[index].getDiscriminator() != discriminator::measureTiming) {
306 LOG(ERROR) << "FMQ Request packet ill-formed";
307 return {{}, MeasureTiming::NO};
308 }
309
310 // unpackage measureTiming
311 const MeasureTiming measure = data[index].measureTiming();
312 index++;
313
314 // validate packet information
315 if (index != packetSize) {
316 LOG(ERROR) << "FMQ Result packet ill-formed";
317 return {{}, MeasureTiming::NO};
318 }
319
320 // return request
321 return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/std::move(pools)}, measure};
322}
323
324// serialize result
325std::vector<FmqResultDatum> ExecutionBurstServer::serialize(
326 ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes, Timing timing) {
327 // count how many elements need to be sent for a request
328 size_t count = 2 + outputShapes.size();
329 for (const auto& outputShape : outputShapes) {
330 count += outputShape.dimensions.size();
331 }
332
333 // create buffer to temporarily store elements
334 std::vector<FmqResultDatum> data;
335 data.reserve(count);
336
337 // package packetInfo
338 {
339 FmqResultDatum datum;
340 datum.packetInformation({/*.packetSize=*/static_cast<uint32_t>(count),
341 /*.errorStatus=*/errorStatus,
342 /*.numberOfOperands=*/static_cast<uint32_t>(outputShapes.size())});
343 data.push_back(datum);
344 }
345
346 // package output shape data
347 for (const auto& operand : outputShapes) {
348 // package operand information
349 FmqResultDatum datum;
350 datum.operandInformation(
351 {/*.isSufficient=*/operand.isSufficient,
352 /*.numberOfDimensions=*/static_cast<uint32_t>(operand.dimensions.size())});
353 data.push_back(datum);
354
355 // package operand dimensions
356 for (uint32_t dimension : operand.dimensions) {
357 FmqResultDatum datum;
358 datum.operandDimensionValue(dimension);
359 data.push_back(datum);
360 }
361 }
362
363 // package executionTiming
364 {
365 FmqResultDatum datum;
366 datum.executionTiming(timing);
367 data.push_back(datum);
368 }
369
370 // return result
371 return data;
372}
373
374Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
375 mMemoryCache.freeMemory(slot);
376 return Void();
377}
378
379void ExecutionBurstServer::task() {
380 while (!mTeardown) {
381 // receive request
382 const std::vector<FmqRequestDatum> requestData = getPacketBlocking();
383
384 // terminate loop
385 if (mTeardown) {
386 return;
387 }
388
389 // continue processing
390 Request request;
391 MeasureTiming measure;
392 std::tie(request, measure) = deserialize(requestData);
393
394 // perform computation
395 ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
396 std::vector<OutputShape> outputShapes;
397 Timing returnedTiming;
Michael Butler89e99ba2019-01-24 02:36:37 -0800398 // This call to IPreparedModel::executeSynchronously occurs entirely
399 // within the same process, so ignore the Return<> errors via .isOk().
400 // TODO: verify it is safe to always call isOk() here, or if there is
401 // any benefit to checking any potential errors.
Michael Butler60296322019-01-17 17:54:51 -0800402 mPreparedModel
403 ->executeSynchronously(request, measure,
404 [&errorStatus, &outputShapes, &returnedTiming](
405 ErrorStatus status,
406 const hidl_vec<OutputShape>& shapes, Timing timing) {
407 errorStatus = status;
408 outputShapes = shapes;
409 returnedTiming = timing;
410 })
411 .isOk();
412
413 // return result
414 const std::vector<FmqResultDatum> result =
415 serialize(errorStatus, outputShapes, returnedTiming);
416 sendPacket(result);
417 }
418}
419
420sp<IBurstContext> createBurstContext(const sp<IBurstCallback>& callback,
421 const MQDescriptorSync<FmqRequestDatum>& requestChannel,
422 const MQDescriptorSync<FmqResultDatum>& resultChannel,
423 IPreparedModel* preparedModel) {
424 // check inputs
425 if (callback == nullptr || preparedModel == nullptr) {
426 LOG(ERROR) << "createExecutionBurstServer passed a nullptr";
427 return nullptr;
428 }
429
430 // create FMQ objects
431 std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
432 FmqRequestChannel(requestChannel)};
433 std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
434 FmqResultChannel(resultChannel)};
435
436 // check FMQ objects
437 if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
438 !fmqResultChannel->isValid()) {
439 LOG(ERROR) << "createExecutionBurstServer failed to create FastMessageQueue";
440 return nullptr;
441 }
442
443 // make and return context
444 return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
445 std::move(fmqResultChannel), preparedModel);
446}
447
448} // namespace nn
449} // namespace android