Impose more discipline on uses of ModelArgumentInfo
Change ModelArgumentInfo from a struct to a class. Fields are no
longer public, and initialization uses a factory method (e.g.,
createFromPointer) rather than a setter (e.g.,
setFromPointer). Accessor methods are guarded by a check against the
(already existing) state (POINTER, MEMORY, or both).
Calling ANeuralNetworksExecution_setInput*() or
ANeuralNetworksExecution_setOutput*() multiple times on the same
operand of the same execution is now an error.
Bug: 147850002
Test: NeuralNetworksTest_static
Change-Id: Idebe6b4df79bff5c62e56860e42463d610d91085
diff --git a/nn/runtime/ExecutionBuilder.cpp b/nn/runtime/ExecutionBuilder.cpp
index 5905ea5..6c16e5d 100644
--- a/nn/runtime/ExecutionBuilder.cpp
+++ b/nn/runtime/ExecutionBuilder.cpp
@@ -125,8 +125,15 @@
return ANEURALNETWORKS_BAD_DATA;
}
uint32_t l = static_cast<uint32_t>(length);
- return mInputs[index].setFromPointer(mModel->getInputOperand(index), type,
- const_cast<void*>(buffer), l);
+ if (!mInputs[index].unspecified()) {
+ LOG(ERROR) << "ANeuralNetworksExecution_setInput called when an input has already been "
+ "provided";
+ return ANEURALNETWORKS_BAD_STATE;
+ }
+ int n;
+ std::tie(n, mInputs[index]) = ModelArgumentInfo::createFromPointer(
+ mModel->getInputOperand(index), type, const_cast<void*>(buffer), l);
+ return n;
}
int ExecutionBuilder::setInputFromMemory(uint32_t index, const ANeuralNetworksOperandType* type,
@@ -162,8 +169,16 @@
}
// TODO validate the rest
uint32_t poolIndex = mMemories.add(memory);
- return mInputs[index].setFromMemory(mModel->getInputOperand(index), type, poolIndex, offset,
- length);
+ if (!mInputs[index].unspecified()) {
+ LOG(ERROR)
+ << "ANeuralNetworksExecution_setInputFromMemory called when an input has already "
+ "been provided";
+ return ANEURALNETWORKS_BAD_STATE;
+ }
+ int n;
+ std::tie(n, mInputs[index]) = ModelArgumentInfo::createFromMemory(
+ mModel->getInputOperand(index), type, poolIndex, offset, length);
+ return n;
}
int ExecutionBuilder::setOutput(uint32_t index, const ANeuralNetworksOperandType* type,
@@ -187,7 +202,15 @@
return ANEURALNETWORKS_BAD_DATA;
}
uint32_t l = static_cast<uint32_t>(length);
- return mOutputs[index].setFromPointer(mModel->getOutputOperand(index), type, buffer, l);
+ if (!mOutputs[index].unspecified()) {
+ LOG(ERROR) << "ANeuralNetworksExecution_setOutput called when an output has already been "
+ "provided";
+ return ANEURALNETWORKS_BAD_STATE;
+ }
+ int n;
+ std::tie(n, mOutputs[index]) =
+ ModelArgumentInfo::createFromPointer(mModel->getOutputOperand(index), type, buffer, l);
+ return n;
}
int ExecutionBuilder::setOutputFromMemory(uint32_t index, const ANeuralNetworksOperandType* type,
@@ -223,8 +246,15 @@
}
// TODO validate the rest
uint32_t poolIndex = mMemories.add(memory);
- return mOutputs[index].setFromMemory(mModel->getOutputOperand(index), type, poolIndex, offset,
- length);
+ if (!mOutputs[index].unspecified()) {
+ LOG(ERROR) << "ANeuralNetworksExecution_setOutputFromMemory called when an output has "
+ "already been provided";
+ return ANEURALNETWORKS_BAD_STATE;
+ }
+ int n;
+ std::tie(n, mOutputs[index]) = ModelArgumentInfo::createFromMemory(
+ mModel->getOutputOperand(index), type, poolIndex, offset, length);
+ return n;
}
int ExecutionBuilder::setMeasureTiming(bool measure) {
@@ -374,15 +404,15 @@
<< " " << count;
return ANEURALNETWORKS_BAD_DATA;
}
- const auto& dims = mOutputs[index].dimensions;
+ const auto& dims = mOutputs[index].dimensions();
if (dims.empty()) {
LOG(ERROR) << "ANeuralNetworksExecution_getOutputOperandDimensions can not query "
"dimensions of a scalar";
return ANEURALNETWORKS_BAD_DATA;
}
std::copy(dims.begin(), dims.end(), dimensions);
- return mOutputs[index].isSufficient ? ANEURALNETWORKS_NO_ERROR
- : ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE;
+ return mOutputs[index].isSufficient() ? ANEURALNETWORKS_NO_ERROR
+ : ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE;
}
int ExecutionBuilder::getOutputOperandRank(uint32_t index, uint32_t* rank) {
@@ -405,9 +435,9 @@
<< count;
return ANEURALNETWORKS_BAD_DATA;
}
- *rank = static_cast<uint32_t>(mOutputs[index].dimensions.size());
- return mOutputs[index].isSufficient ? ANEURALNETWORKS_NO_ERROR
- : ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE;
+ *rank = static_cast<uint32_t>(mOutputs[index].dimensions().size());
+ return mOutputs[index].isSufficient() ? ANEURALNETWORKS_NO_ERROR
+ : ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE;
}
// Attempt synchronous execution of full model on CPU.
@@ -705,21 +735,21 @@
}
const auto deadline = makeDeadline(mTimeoutDuration);
for (auto& p : mInputs) {
- if (p.state == ModelArgumentInfo::UNSPECIFIED) {
+ if (p.state() == ModelArgumentInfo::UNSPECIFIED) {
LOG(ERROR) << "ANeuralNetworksExecution_startComputeWithDependencies"
" not all inputs specified";
return ANEURALNETWORKS_BAD_DATA;
}
}
for (auto& p : mOutputs) {
- if (p.state == ModelArgumentInfo::UNSPECIFIED) {
+ if (p.state() == ModelArgumentInfo::UNSPECIFIED) {
LOG(ERROR) << "ANeuralNetworksExecution_startComputeWithDependencies"
" not all outputs specified";
return ANEURALNETWORKS_BAD_DATA;
}
}
for (uint32_t i = 0; i < mOutputs.size(); i++) {
- if (mOutputs[i].state != ModelArgumentInfo::HAS_NO_VALUE &&
+ if (mOutputs[i].state() != ModelArgumentInfo::HAS_NO_VALUE &&
!checkDimensionInfo(mModel->getOutputOperand(i), nullptr,
"ANeuralNetworksExecution_startComputeWithDependencies", false)) {
LOG(ERROR) << "ANeuralNetworksExecution_startComputeWithDependencies"
@@ -762,18 +792,18 @@
return ANEURALNETWORKS_BAD_STATE;
}
for (auto& p : mInputs) {
- if (p.state == ModelArgumentInfo::UNSPECIFIED) {
+ if (p.state() == ModelArgumentInfo::UNSPECIFIED) {
LOG(ERROR) << "ANeuralNetworksExecution_" << name() << " not all inputs specified";
return ANEURALNETWORKS_BAD_DATA;
- } else if (p.state == ModelArgumentInfo::MEMORY) {
- const Memory* memory = mMemories[p.locationAndLength.poolIndex];
- if (!memory->getValidator().validateInputDimensions(p.dimensions)) {
+ } else if (p.state() == ModelArgumentInfo::MEMORY) {
+ const Memory* memory = mMemories[p.locationAndLength().poolIndex];
+ if (!memory->getValidator().validateInputDimensions(p.dimensions())) {
return ANEURALNETWORKS_OP_FAILED;
}
}
}
for (auto& p : mOutputs) {
- if (p.state == ModelArgumentInfo::UNSPECIFIED) {
+ if (p.state() == ModelArgumentInfo::UNSPECIFIED) {
LOG(ERROR) << "ANeuralNetworksExecution_" << name() << " not all outputs specified";
return ANEURALNETWORKS_BAD_DATA;
}
@@ -835,7 +865,7 @@
std::vector<OutputShape> outputShapes(mOutputs.size());
std::transform(mOutputs.begin(), mOutputs.end(), outputShapes.begin(),
[](const auto& x) -> OutputShape {
- return {.dimensions = x.dimensions, .isSufficient = true};
+ return {.dimensions = x.dimensions(), .isSufficient = true};
});
return outputShapes;
}
@@ -858,20 +888,20 @@
NN_RET_CHECK_EQ(outputShapes.size(), mOutputs.size());
for (uint32_t i = 0; i < outputShapes.size(); i++) {
// Check if only unspecified dimensions or rank are overwritten.
- NN_RET_CHECK(isUpdatable(mOutputs[i].dimensions, outputShapes[i].dimensions));
+ NN_RET_CHECK(isUpdatable(mOutputs[i].dimensions(), outputShapes[i].dimensions));
}
for (uint32_t i = 0; i < outputShapes.size(); i++) {
- mOutputs[i].dimensions = outputShapes[i].dimensions;
- mOutputs[i].isSufficient = outputShapes[i].isSufficient;
+ mOutputs[i].dimensions() = outputShapes[i].dimensions;
+ mOutputs[i].isSufficient() = outputShapes[i].isSufficient;
}
return true;
}
bool ExecutionBuilder::updateMemories() {
for (const auto& output : mOutputs) {
- if (output.state != ModelArgumentInfo::MEMORY) continue;
- const Memory* memory = mMemories[output.locationAndLength.poolIndex];
- NN_RET_CHECK(memory->getValidator().updateMetadata({.dimensions = output.dimensions}));
+ if (output.state() != ModelArgumentInfo::MEMORY) continue;
+ const Memory* memory = mMemories[output.locationAndLength().poolIndex];
+ NN_RET_CHECK(memory->getValidator().updateMetadata({.dimensions = output.dimensions()}));
}
return true;
}
@@ -885,8 +915,8 @@
}
bool success = status == ErrorStatus::NONE;
for (const auto& output : mOutputs) {
- if (output.state != ModelArgumentInfo::MEMORY) continue;
- const Memory* memory = mMemories[output.locationAndLength.poolIndex];
+ if (output.state() != ModelArgumentInfo::MEMORY) continue;
+ const Memory* memory = mMemories[output.locationAndLength().poolIndex];
memory->getValidator().setInitialized(success);
}
return status;
@@ -940,7 +970,7 @@
void StepExecutor::mapInputOrOutput(const ModelArgumentInfo& builderInputOrOutput,
ModelArgumentInfo* executorInputOrOutput) {
*executorInputOrOutput = builderInputOrOutput;
- switch (executorInputOrOutput->state) {
+ switch (executorInputOrOutput->state()) {
default:
CHECK(false) << "unexpected ModelArgumentInfo::state";
break;
@@ -949,10 +979,10 @@
case ModelArgumentInfo::UNSPECIFIED:
break;
case ModelArgumentInfo::MEMORY: {
- const uint32_t builderPoolIndex = builderInputOrOutput.locationAndLength.poolIndex;
+ const uint32_t builderPoolIndex = builderInputOrOutput.locationAndLength().poolIndex;
const Memory* memory = mExecutionBuilder->mMemories[builderPoolIndex];
const uint32_t executorPoolIndex = mMemories.add(memory);
- executorInputOrOutput->locationAndLength.poolIndex = executorPoolIndex;
+ executorInputOrOutput->locationAndLength().poolIndex = executorPoolIndex;
break;
}
}
@@ -967,22 +997,26 @@
uint32_t poolIndex = mMemories.add(memory);
uint32_t length = TypeManager::get()->getSizeOfData(inputOrOutputOperand);
- return inputOrOutputInfo->setFromMemory(inputOrOutputOperand, /*type=*/nullptr, poolIndex,
- offset, length);
+ CHECK(inputOrOutputInfo->unspecified());
+ int n;
+ std::tie(n, *inputOrOutputInfo) =
+ ModelArgumentInfo::createFromMemory(inputOrOutputOperand,
+ /*type=*/nullptr, poolIndex, offset, length);
+ return n;
}
static void logArguments(const char* kind, const std::vector<ModelArgumentInfo>& args) {
for (unsigned i = 0; i < args.size(); i++) {
const auto& arg = args[i];
std::string prefix = kind + std::string("[") + std::to_string(i) + "] = ";
- switch (arg.state) {
+ switch (arg.state()) {
case ModelArgumentInfo::POINTER:
- VLOG(EXECUTION) << prefix << "POINTER(" << SHOW_IF_DEBUG(arg.buffer) << ")";
+ VLOG(EXECUTION) << prefix << "POINTER(" << SHOW_IF_DEBUG(arg.buffer()) << ")";
break;
case ModelArgumentInfo::MEMORY:
VLOG(EXECUTION) << prefix << "MEMORY("
- << "pool=" << arg.locationAndLength.poolIndex << ", "
- << "off=" << arg.locationAndLength.offset << ")";
+ << "pool=" << arg.locationAndLength().poolIndex << ", "
+ << "off=" << arg.locationAndLength().offset << ")";
break;
case ModelArgumentInfo::HAS_NO_VALUE:
VLOG(EXECUTION) << prefix << "HAS_NO_VALUE";
@@ -991,7 +1025,7 @@
VLOG(EXECUTION) << prefix << "UNSPECIFIED";
break;
default:
- VLOG(EXECUTION) << prefix << "state(" << arg.state << ")";
+ VLOG(EXECUTION) << prefix << "state(" << arg.state() << ")";
break;
}
}
diff --git a/nn/runtime/ExecutionPlan.cpp b/nn/runtime/ExecutionPlan.cpp
index f19065b..91ef6b1 100644
--- a/nn/runtime/ExecutionPlan.cpp
+++ b/nn/runtime/ExecutionPlan.cpp
@@ -1003,14 +1003,14 @@
std::optional<ExecutionPlan::Buffer> ExecutionPlan::getBufferFromModelArgumentInfo(
const ModelArgumentInfo& info, const ExecutionBuilder* executionBuilder) const {
- switch (info.state) {
+ switch (info.state()) {
case ModelArgumentInfo::POINTER: {
- return Buffer(info.buffer, info.locationAndLength.length);
+ return Buffer(info.buffer(), info.length());
} break;
case ModelArgumentInfo::MEMORY: {
if (std::optional<RunTimePoolInfo> poolInfo =
- executionBuilder->getRunTimePoolInfo(info.locationAndLength.poolIndex)) {
- return Buffer(*poolInfo, info.locationAndLength.offset);
+ executionBuilder->getRunTimePoolInfo(info.locationAndLength().poolIndex)) {
+ return Buffer(*poolInfo, info.locationAndLength().offset);
} else {
LOG(ERROR) << "Unable to map operand memory pool";
return std::nullopt;
@@ -1021,7 +1021,7 @@
return std::nullopt;
} break;
default: {
- LOG(ERROR) << "Unexpected operand memory state: " << static_cast<int>(info.state);
+ LOG(ERROR) << "Unexpected operand memory state: " << static_cast<int>(info.state());
return std::nullopt;
} break;
}
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp
index db3cd6b..7f2afb6 100644
--- a/nn/runtime/Manager.cpp
+++ b/nn/runtime/Manager.cpp
@@ -282,14 +282,13 @@
const uint32_t nextPoolIndex = memories->size();
int64_t total = 0;
for (const auto& info : args) {
- if (info.state == ModelArgumentInfo::POINTER) {
- const DataLocation& loc = info.locationAndLength;
+ if (info.state() == ModelArgumentInfo::POINTER) {
// TODO Good enough alignment?
- total += alignBytesNeeded(static_cast<uint32_t>(total), loc.length);
+ total += alignBytesNeeded(static_cast<uint32_t>(total), info.length());
ptrArgsLocations.push_back({.poolIndex = nextPoolIndex,
.offset = static_cast<uint32_t>(total),
- .length = loc.length});
- total += loc.length;
+ .length = info.length()});
+ total += info.length();
}
};
if (total > 0xFFFFFFFF) {
@@ -348,10 +347,10 @@
if (inputPtrArgsMemory != nullptr) {
uint32_t ptrInputIndex = 0;
for (const auto& info : inputs) {
- if (info.state == ModelArgumentInfo::POINTER) {
+ if (info.state() == ModelArgumentInfo::POINTER) {
const DataLocation& loc = inputPtrArgsLocations[ptrInputIndex++];
uint8_t* const data = inputPtrArgsMemory->getPointer();
- memcpy(data + loc.offset, info.buffer, loc.length);
+ memcpy(data + loc.offset, info.buffer(), loc.length);
}
}
}
@@ -412,10 +411,10 @@
if (outputPtrArgsMemory != nullptr) {
uint32_t ptrOutputIndex = 0;
for (const auto& info : outputs) {
- if (info.state == ModelArgumentInfo::POINTER) {
+ if (info.state() == ModelArgumentInfo::POINTER) {
const DataLocation& loc = outputPtrArgsLocations[ptrOutputIndex++];
const uint8_t* const data = outputPtrArgsMemory->getPointer();
- memcpy(info.buffer, data + loc.offset, loc.length);
+ memcpy(info.buffer(), data + loc.offset, loc.length);
}
}
}
@@ -457,10 +456,10 @@
if (inputPtrArgsMemory != nullptr) {
uint32_t ptrInputIndex = 0;
for (const auto& info : inputs) {
- if (info.state == ModelArgumentInfo::POINTER) {
+ if (info.state() == ModelArgumentInfo::POINTER) {
const DataLocation& loc = inputPtrArgsLocations[ptrInputIndex++];
uint8_t* const data = inputPtrArgsMemory->getPointer();
- memcpy(data + loc.offset, info.buffer, loc.length);
+ memcpy(data + loc.offset, info.buffer(), loc.length);
}
}
}
@@ -528,10 +527,10 @@
}
uint32_t ptrOutputIndex = 0;
for (const auto& info : outputs) {
- if (info.state == ModelArgumentInfo::POINTER) {
+ if (info.state() == ModelArgumentInfo::POINTER) {
const DataLocation& loc = outputPtrArgsLocations[ptrOutputIndex++];
const uint8_t* const data = outputPtrArgsMemory->getPointer();
- memcpy(info.buffer, data + loc.offset, loc.length);
+ memcpy(info.buffer(), data + loc.offset, loc.length);
}
}
}
@@ -769,13 +768,13 @@
[&requestPoolInfos](const std::vector<ModelArgumentInfo>& argumentInfos) {
std::vector<DataLocation> ptrArgsLocations;
for (const ModelArgumentInfo& argumentInfo : argumentInfos) {
- if (argumentInfo.state == ModelArgumentInfo::POINTER) {
+ if (argumentInfo.state() == ModelArgumentInfo::POINTER) {
ptrArgsLocations.push_back(
{.poolIndex = static_cast<uint32_t>(requestPoolInfos.size()),
.offset = 0,
- .length = argumentInfo.locationAndLength.length});
+ .length = argumentInfo.length()});
requestPoolInfos.emplace_back(RunTimePoolInfo::createFromExistingBuffer(
- static_cast<uint8_t*>(argumentInfo.buffer)));
+ static_cast<uint8_t*>(argumentInfo.buffer())));
}
}
return ptrArgsLocations;
diff --git a/nn/runtime/Manager.h b/nn/runtime/Manager.h
index 0ece99b..4dac086 100644
--- a/nn/runtime/Manager.h
+++ b/nn/runtime/Manager.h
@@ -39,8 +39,8 @@
class Device;
class ExecutionBurstController;
class MetaModel;
+class ModelArgumentInfo;
class VersionedIPreparedModel;
-struct ModelArgumentInfo;
// A unified interface for actual driver prepared model as well as the CPU.
class PreparedModel {
diff --git a/nn/runtime/ModelArgumentInfo.cpp b/nn/runtime/ModelArgumentInfo.cpp
index f8ddbfe..cf24004 100644
--- a/nn/runtime/ModelArgumentInfo.cpp
+++ b/nn/runtime/ModelArgumentInfo.cpp
@@ -19,6 +19,7 @@
#include "ModelArgumentInfo.h"
#include <algorithm>
+#include <utility>
#include <vector>
#include "HalInterfaces.h"
@@ -31,61 +32,74 @@
using namespace hal;
-int ModelArgumentInfo::setFromPointer(const Operand& operand,
- const ANeuralNetworksOperandType* type, void* data,
- uint32_t length) {
+static const std::pair<int, ModelArgumentInfo> kBadDataModelArgumentInfo{ANEURALNETWORKS_BAD_DATA,
+ {}};
+
+std::pair<int, ModelArgumentInfo> ModelArgumentInfo::createFromPointer(
+ const Operand& operand, const ANeuralNetworksOperandType* type, void* data,
+ uint32_t length) {
if ((data == nullptr) != (length == 0)) {
const char* dataPtrMsg = data ? "NOT_NULLPTR" : "NULLPTR";
LOG(ERROR) << "Data pointer must be nullptr if and only if length is zero (data = "
<< dataPtrMsg << ", length = " << length << ")";
- return ANEURALNETWORKS_BAD_DATA;
+ return kBadDataModelArgumentInfo;
}
+
+ ModelArgumentInfo ret;
if (data == nullptr) {
- state = ModelArgumentInfo::HAS_NO_VALUE;
+ ret.mState = ModelArgumentInfo::HAS_NO_VALUE;
} else {
- NN_RETURN_IF_ERROR(updateDimensionInfo(operand, type));
+ if (int n = ret.updateDimensionInfo(operand, type)) {
+ return {n, ModelArgumentInfo()};
+ }
if (operand.type != OperandType::OEM) {
- uint32_t neededLength = TypeManager::get()->getSizeOfData(operand.type, dimensions);
+ uint32_t neededLength =
+ TypeManager::get()->getSizeOfData(operand.type, ret.mDimensions);
if (neededLength != length && neededLength != 0) {
LOG(ERROR) << "Setting argument with invalid length: " << length
<< ", expected length: " << neededLength;
- return ANEURALNETWORKS_BAD_DATA;
+ return kBadDataModelArgumentInfo;
}
}
- state = ModelArgumentInfo::POINTER;
+ ret.mState = ModelArgumentInfo::POINTER;
}
- buffer = data;
- locationAndLength = {.poolIndex = 0, .offset = 0, .length = length};
- return ANEURALNETWORKS_NO_ERROR;
+ ret.mBuffer = data;
+ ret.mLocationAndLength = {.poolIndex = 0, .offset = 0, .length = length};
+ return {ANEURALNETWORKS_NO_ERROR, ret};
}
-int ModelArgumentInfo::setFromMemory(const Operand& operand, const ANeuralNetworksOperandType* type,
- uint32_t poolIndex, uint32_t offset, uint32_t length) {
- NN_RETURN_IF_ERROR(updateDimensionInfo(operand, type));
+std::pair<int, ModelArgumentInfo> ModelArgumentInfo::createFromMemory(
+ const Operand& operand, const ANeuralNetworksOperandType* type, uint32_t poolIndex,
+ uint32_t offset, uint32_t length) {
+ ModelArgumentInfo ret;
+ if (int n = ret.updateDimensionInfo(operand, type)) {
+ return {n, ModelArgumentInfo()};
+ }
const bool isMemorySizeKnown = offset != 0 || length != 0;
if (isMemorySizeKnown && operand.type != OperandType::OEM) {
- const uint32_t neededLength = TypeManager::get()->getSizeOfData(operand.type, dimensions);
+ const uint32_t neededLength =
+ TypeManager::get()->getSizeOfData(operand.type, ret.mDimensions);
if (neededLength != length && neededLength != 0) {
LOG(ERROR) << "Setting argument with invalid length: " << length
<< " (offset: " << offset << "), expected length: " << neededLength;
- return ANEURALNETWORKS_BAD_DATA;
+ return kBadDataModelArgumentInfo;
}
}
- state = ModelArgumentInfo::MEMORY;
- locationAndLength = {.poolIndex = poolIndex, .offset = offset, .length = length};
- buffer = nullptr;
- return ANEURALNETWORKS_NO_ERROR;
+ ret.mState = ModelArgumentInfo::MEMORY;
+ ret.mLocationAndLength = {.poolIndex = poolIndex, .offset = offset, .length = length};
+ ret.mBuffer = nullptr;
+ return {ANEURALNETWORKS_NO_ERROR, ret};
}
int ModelArgumentInfo::updateDimensionInfo(const Operand& operand,
const ANeuralNetworksOperandType* newType) {
if (newType == nullptr) {
- dimensions = operand.dimensions;
+ mDimensions = operand.dimensions;
} else {
const uint32_t count = newType->dimensionCount;
- dimensions = hidl_vec<uint32_t>(count);
- std::copy(&newType->dimensions[0], &newType->dimensions[count], dimensions.begin());
+ mDimensions = hidl_vec<uint32_t>(count);
+ std::copy(&newType->dimensions[0], &newType->dimensions[count], mDimensions.begin());
}
return ANEURALNETWORKS_NO_ERROR;
}
@@ -98,12 +112,22 @@
uint32_t ptrArgsIndex = 0;
for (size_t i = 0; i < count; i++) {
const auto& info = argumentInfos[i];
- ioInfos[i] = {
- .hasNoValue = info.state == ModelArgumentInfo::HAS_NO_VALUE,
- .location = info.state == ModelArgumentInfo::POINTER
- ? ptrArgsLocations[ptrArgsIndex++]
- : info.locationAndLength,
- .dimensions = info.dimensions,
+ switch (info.state()) {
+ case ModelArgumentInfo::POINTER:
+ ioInfos[i] = {.hasNoValue = false,
+ .location = ptrArgsLocations[ptrArgsIndex++],
+ .dimensions = info.dimensions()};
+ break;
+ case ModelArgumentInfo::MEMORY:
+ ioInfos[i] = {.hasNoValue = false,
+ .location = info.locationAndLength(),
+ .dimensions = info.dimensions()};
+ break;
+ case ModelArgumentInfo::HAS_NO_VALUE:
+ ioInfos[i] = {.hasNoValue = true};
+ break;
+ default:
+ CHECK(false);
};
}
return ioInfos;
diff --git a/nn/runtime/ModelArgumentInfo.h b/nn/runtime/ModelArgumentInfo.h
index 61dfc1e..22dd34c 100644
--- a/nn/runtime/ModelArgumentInfo.h
+++ b/nn/runtime/ModelArgumentInfo.h
@@ -17,36 +17,93 @@
#ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MODEL_ARGUMENT_INFO_H
#define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MODEL_ARGUMENT_INFO_H
+#include <utility>
#include <vector>
#include "HalInterfaces.h"
#include "NeuralNetworks.h"
+#include "Utils.h"
namespace android {
namespace nn {
// TODO move length out of DataLocation
-struct ModelArgumentInfo {
+//
+// NOTE: The primary usage pattern is that a ModelArgumentInfo instance
+// is not modified once it is created (unless it is reassigned to).
+// There are a small number of use cases where it NEEDS to be modified,
+// and we have a limited number of methods that support this.
+class ModelArgumentInfo {
+ public:
+ ModelArgumentInfo() {}
+
+ static std::pair<int, ModelArgumentInfo> createFromPointer(
+ const hal::Operand& operand, const ANeuralNetworksOperandType* type,
+ void* data /* nullptr means HAS_NO_VALUE */, uint32_t length);
+ static std::pair<int, ModelArgumentInfo> createFromMemory(
+ const hal::Operand& operand, const ANeuralNetworksOperandType* type, uint32_t poolIndex,
+ uint32_t offset, uint32_t length);
+
+ enum State { POINTER, MEMORY, HAS_NO_VALUE, UNSPECIFIED };
+
+ State state() const { return mState; }
+
+ bool unspecified() const { return mState == UNSPECIFIED; }
+
+ void* buffer() const {
+ CHECK_EQ(mState, POINTER);
+ return mBuffer;
+ }
+
+ const std::vector<uint32_t>& dimensions() const {
+ CHECK(mState == POINTER || mState == MEMORY);
+ return mDimensions;
+ }
+ std::vector<uint32_t>& dimensions() {
+ CHECK(mState == POINTER || mState == MEMORY);
+ return mDimensions;
+ }
+
+ bool isSufficient() const {
+ CHECK(mState == POINTER || mState == MEMORY);
+ return mIsSufficient;
+ }
+ bool& isSufficient() {
+ CHECK(mState == POINTER || mState == MEMORY);
+ return mIsSufficient;
+ }
+
+ uint32_t length() const {
+ CHECK(mState == POINTER || mState == MEMORY);
+ return mLocationAndLength.length;
+ }
+
+ const hal::DataLocation& locationAndLength() const {
+ CHECK_EQ(mState, MEMORY);
+ return mLocationAndLength;
+ }
+ hal::DataLocation& locationAndLength() {
+ CHECK_EQ(mState, MEMORY);
+ return mLocationAndLength;
+ }
+
+ private:
+ int updateDimensionInfo(const hal::Operand& operand, const ANeuralNetworksOperandType* newType);
+
// Whether the argument was specified as being in a Memory, as a pointer,
// has no value, or has not been specified.
// If POINTER then:
- // locationAndLength.length is valid.
- // dimensions is valid.
- // buffer is valid
+ // mLocationAndLength.length is valid.
+ // mDimensions is valid.
+ // mBuffer is valid.
// If MEMORY then:
- // locationAndLength.{poolIndex, offset, length} is valid.
- // dimensions is valid.
- enum { POINTER, MEMORY, HAS_NO_VALUE, UNSPECIFIED } state = UNSPECIFIED;
- hal::DataLocation locationAndLength;
- std::vector<uint32_t> dimensions;
- void* buffer;
- bool isSufficient = true;
-
- int setFromPointer(const hal::Operand& operand, const ANeuralNetworksOperandType* type,
- void* buffer, uint32_t length);
- int setFromMemory(const hal::Operand& operand, const ANeuralNetworksOperandType* type,
- uint32_t poolIndex, uint32_t offset, uint32_t length);
- int updateDimensionInfo(const hal::Operand& operand, const ANeuralNetworksOperandType* newType);
+ // mLocationAndLength.{poolIndex, offset, length} is valid.
+ // mDimensions is valid.
+ State mState = UNSPECIFIED; // fixed at creation
+ void* mBuffer = nullptr; // fixed at creation
+ hal::DataLocation mLocationAndLength; // can be updated after creation
+ std::vector<uint32_t> mDimensions; // can be updated after creation
+ bool mIsSufficient = true; // can be updated after creation
};
// Convert ModelArgumentInfo to HIDL RequestArgument. For pointer arguments, use the location
diff --git a/nn/runtime/test/TestValidation.cpp b/nn/runtime/test/TestValidation.cpp
index 9e1f910..1b3e6be 100644
--- a/nn/runtime/test/TestValidation.cpp
+++ b/nn/runtime/test/TestValidation.cpp
@@ -1201,6 +1201,12 @@
EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 0, &kInvalidTensorType2, buffer,
sizeof(float)),
ANEURALNETWORKS_BAD_DATA);
+
+ // Cannot do this twice.
+ EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 0, nullptr, buffer, 8),
+ ANEURALNETWORKS_NO_ERROR);
+ EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 0, nullptr, buffer, 8),
+ ANEURALNETWORKS_BAD_STATE);
}
TEST_F(ValidationTestExecution, SetOutput) {
@@ -1229,6 +1235,12 @@
EXPECT_EQ(ANeuralNetworksExecution_setOutput(mExecution, 0, &kInvalidTensorType2, buffer,
sizeof(float)),
ANEURALNETWORKS_BAD_DATA);
+
+ // Cannot do this twice.
+ EXPECT_EQ(ANeuralNetworksExecution_setOutput(mExecution, 0, nullptr, buffer, 8),
+ ANEURALNETWORKS_NO_ERROR);
+ EXPECT_EQ(ANeuralNetworksExecution_setOutput(mExecution, 0, nullptr, buffer, 8),
+ ANEURALNETWORKS_BAD_STATE);
}
TEST_F(ValidationTestExecution, SetInputFromMemory) {
@@ -1281,6 +1293,15 @@
memory, 0, sizeof(float)),
ANEURALNETWORKS_BAD_DATA);
+ // Cannot do this twice.
+ EXPECT_EQ(ANeuralNetworksExecution_setInputFromMemory(mExecution, 0, nullptr, memory, 0, 8),
+ ANEURALNETWORKS_NO_ERROR);
+ EXPECT_EQ(ANeuralNetworksExecution_setInputFromMemory(mExecution, 0, nullptr, memory, 0, 8),
+ ANEURALNETWORKS_BAD_STATE);
+ char buffer[memorySize];
+ EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 0, nullptr, buffer, 8),
+ ANEURALNETWORKS_BAD_STATE);
+
// close memory
close(memoryFd);
}
@@ -1381,6 +1402,15 @@
memory, 0, sizeof(float)),
ANEURALNETWORKS_BAD_DATA);
+ // Cannot do this twice.
+ EXPECT_EQ(ANeuralNetworksExecution_setOutputFromMemory(execution, 0, nullptr, memory, 0, 8),
+ ANEURALNETWORKS_NO_ERROR);
+ EXPECT_EQ(ANeuralNetworksExecution_setOutputFromMemory(execution, 0, nullptr, memory, 0, 8),
+ ANEURALNETWORKS_BAD_STATE);
+ char buffer[memorySize];
+ EXPECT_EQ(ANeuralNetworksExecution_setOutput(execution, 0, nullptr, buffer, 8),
+ ANEURALNETWORKS_BAD_STATE);
+
// close memory
close(memoryFd);
}