IVGCVSW-3351 Run VTS tests

 * Added ArmnnBurstExecutorWithCache to fix test failures.
 * Added support for MeasureTiming to fix test failures.
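
A rough caller-side sketch of the new MeasureTiming path (the names
preparedModel and request are illustrative only, not part of this change):

    V1_2::Timing observedTiming;
    preparedModel->executeSynchronously(
        request, MeasureTiming::YES,
        [&observedTiming](ErrorStatus status,
                          const hidl_vec<OutputShape>& /*outputShapes*/,
                          const Timing& timing)
        {
            // Timing values are reported in microseconds; g_NoTiming
            // (UINT64_MAX fields) is returned when timing is not requested.
            if (status == ErrorStatus::NONE)
            {
                observedTiming = timing;
            }
        });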

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I12b7c6228354bac1f1a9b61ee78066219c0923ad
diff --git a/ArmnnPreparedModel_1_2.cpp b/ArmnnPreparedModel_1_2.cpp
index f03d69d..74da473 100644
--- a/ArmnnPreparedModel_1_2.cpp
+++ b/ArmnnPreparedModel_1_2.cpp
@@ -20,11 +20,22 @@
 using namespace android;
 using namespace android::hardware;
 
-static const Timing g_NoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
-
 namespace {
 
+static const Timing g_NoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
 using namespace armnn_driver;
+using TimePoint = std::chrono::steady_clock::time_point;
+
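+// Helpers for MeasureTiming support: timestamps are taken with
+// std::chrono::steady_clock and durations are reported in microseconds.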
+TimePoint Now()
+{
+    return std::chrono::steady_clock::now();
+}
+
+unsigned long MicrosecondsDuration(TimePoint endPoint, TimePoint startPoint)
+{
+    return static_cast<unsigned long>(std::chrono::duration_cast<std::chrono::microseconds>(
+                                      endPoint - startPoint).count());
+}
 
 void NotifyCallbackAndCheck(const ::android::sp<V1_0::IExecutionCallback>& callback, ErrorStatus errorStatus,
                             std::string callingFunction)
@@ -167,8 +178,8 @@
 
 template<typename HalVersion>
 Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const Request& request,
-                                                                      MeasureTiming,
-                                                                      V1_2::IPreparedModel::executeSynchronously_cb cb)
+                                                                      MeasureTiming measureTiming,
+                                                                      executeSynchronously_cb cb)
 {
     ALOGV("ArmnnPreparedModel_1_2::executeSynchronously(): %s", GetModelSummary(m_Model).c_str());
     m_RequestCount++;
@@ -179,8 +190,16 @@
         return Void();
     }
 
+    TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
+
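+    // Driver time spans request validation through to the result callback;
+    // device time is measured around the EnqueueWorkload call below.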
+    if (measureTiming == MeasureTiming::YES)
+    {
+        driverStart = Now();
+    }
+
     if (!android::nn::validateRequest(request, m_Model))
     {
+        ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid request model");
         cb(ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming);
         return Void();
     }
@@ -247,12 +266,21 @@
     ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() before Execution");
 
     DumpTensorsIfRequired("Input", *pInputTensors);
-
     // run it
     try
     {
+        if (measureTiming == MeasureTiming::YES)
+        {
+            deviceStart = Now();
+        }
+
         armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, *pInputTensors, *pOutputTensors);
 
+        if (measureTiming == MeasureTiming::YES)
+        {
+            deviceEnd = Now();
+        }
+
         if (status != armnn::Status::Success)
         {
             ALOGW("EnqueueWorkload failed");
@@ -277,11 +305,111 @@
         pool.update();
     }
     ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() after Execution");
-    cb(ErrorStatus::NONE, {}, g_NoTiming);
+
+    if (measureTiming == MeasureTiming::YES)
+    {
+        driverEnd = Now();
+        Timing timing;
+        timing.timeOnDevice = MicrosecondsDuration(deviceEnd, deviceStart);
+        timing.timeInDriver = MicrosecondsDuration(driverEnd, driverStart);
+        ALOGV("ArmnnPreparedModel_1_2::executeSynchronously timing Device = %lu Driver = %lu", timing.timeOnDevice,
+                timing.timeInDriver);
+        cb(ErrorStatus::NONE, {}, timing);
+    }
+    else
+    {
+        cb(ErrorStatus::NONE, {}, g_NoTiming);
+    }
     return Void();
 }
 
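+// Burst executor handed to ExecutionBurstServer. It caches the hidl_memory
+// registered against each burst slot and forwards burst requests to
+// ArmnnPreparedModel_1_2::executeSynchronously.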
 template<typename HalVersion>
+class ArmnnBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache
+{
+public:
+    ArmnnBurstExecutorWithCache(ArmnnPreparedModel_1_2<HalVersion>* preparedModel)
+        : m_PreparedModel(preparedModel)
+    {}
+
+    bool isCacheEntryPresent(int slot) const override
+    {
+        const auto it = m_MemoryCache.find(slot);
+        return (it != m_MemoryCache.end()) && it->second.valid();
+    }
+
+    void addCacheEntry(const hidl_memory& memory, int slot) override
+    {
+        m_MemoryCache[slot] = memory;
+    }
+
+    void removeCacheEntry(int slot) override
+    {
+        m_MemoryCache.erase(slot);
+    }
+
+    std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
+            const Request& request, const std::vector<int>& slots,
+            MeasureTiming measure) override
+    {
+        ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache::execute");
+        TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
+
+        if (measure == MeasureTiming::YES)
+        {
+            driverStart = Now();
+        }
+        hidl_vec<hidl_memory> pools(slots.size());
+
+        // Pool indices in the request are positional, so pools[i] must hold
+        // the memory registered for slots[i].
+        for (size_t i = 0u; i < slots.size(); ++i)
+        {
+            const int slot = slots[i];
+            if (!isCacheEntryPresent(slot))
+            {
+                ALOGE("ArmnnPreparedModel_1_2::BurstExecutorWithCache::no cache entry present");
+                return std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing>(ErrorStatus::INVALID_ARGUMENT,
+                                                                              {},
+                                                                              g_NoTiming);
+            }
+            pools[i] = m_MemoryCache[slot];
+        }
+
+        Request fullRequest = request;
+        fullRequest.pools = std::move(pools);
+
+        // Setup callback
+        ErrorStatus returnedStatus = ErrorStatus::GENERAL_FAILURE;
+        hidl_vec<OutputShape> returnedOutputShapes;
+        Timing returnedTiming;
+
+        auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](ErrorStatus status,
+                                                                            const hidl_vec<OutputShape>& outputShapes,
+                                                                            const Timing& timing)
+        {
+            returnedStatus = status;
+            returnedOutputShapes = outputShapes;
+            returnedTiming = timing;
+        };
+
+        // Execute
+        ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache executing");
+        Return<void> ret = m_PreparedModel->executeSynchronously(fullRequest, measure, cb);
+
+        if (!ret.isOk() || returnedStatus != ErrorStatus::NONE)
+        {
+            ALOGE("ArmnnPreparedModel_1_2::BurstExecutorWithCache::error executing");
+            return std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing>(returnedStatus, {}, returnedTiming);
+        }
+
+        return std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing>(returnedStatus,
+                                                                      std::move(returnedOutputShapes),
+                                                                      returnedTiming);
+    }
+
+private:
+    Model m_Model;
+    ArmnnPreparedModel_1_2<HalVersion>* m_PreparedModel;
+    std::map<int, hidl_memory> m_MemoryCache;
+};
+
+template<typename HalVersion>
 Return<void> ArmnnPreparedModel_1_2<HalVersion>::configureExecutionBurst(
         const sp<V1_2::IBurstCallback>& callback,
         const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
@@ -289,12 +417,17 @@
         V1_2::IPreparedModel::configureExecutionBurst_cb cb)
 {
     ALOGV("ArmnnPreparedModel_1_2::configureExecutionBurst");
-    const sp<V1_2::IBurstContext> burst =
-            ExecutionBurstServer::create(callback, requestChannel, resultChannel, this);
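+    // Hand the burst server a cache-aware executor so that memory pools
+    // registered per slot can be reused across burst executions.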
+    const std::shared_ptr<ArmnnBurstExecutorWithCache<HalVersion>> executorWithCache =
+            std::make_shared<ArmnnBurstExecutorWithCache<HalVersion>>(this);
+    const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
+            callback, requestChannel, resultChannel, executorWithCache);
 
-    if (burst == nullptr) {
+    if (burst == nullptr)
+    {
         cb(ErrorStatus::GENERAL_FAILURE, {});
-    } else {
+    }
+    else
+    {
         cb(ErrorStatus::NONE, burst);
     }
     return Void();