Merge "Fix hanging issue in ROI_ALIGN." into qt-dev
diff --git a/nn/runtime/test/Android.bp b/nn/runtime/test/Android.bp
index 4d9aaac..ef04e6b 100644
--- a/nn/runtime/test/Android.bp
+++ b/nn/runtime/test/Android.bp
@@ -60,7 +60,7 @@
         // Changes to this list must be reflected in cts/tests/tests/neuralnetworks/Android.mk
         // to ensure CTS tests coverage.
         "generated/tests/*.cpp",
-        "GeneratedUtils.cpp",
+        "TestGenerated.cpp",
         "TestMemory.cpp",
         "TestOperandExtraParams.cpp",
         "TestTrivialModel.cpp",
@@ -185,7 +185,7 @@
         "TestNeuralNetworksWrapper.cpp",
         "TestMain.cpp",
         "generated/tests/*.cpp",
-        "GeneratedUtils.cpp",
+        "TestGenerated.cpp",
     ],
     cflags: [
         "-DNNTEST_MULTITHREADED"
diff --git a/nn/runtime/test/GeneratedUtils.h b/nn/runtime/test/GeneratedUtils.h
deleted file mode 100644
index 6a0ea28..0000000
--- a/nn/runtime/test/GeneratedUtils.h
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_GENERATEDUTILS_H
-#define ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_GENERATEDUTILS_H
-
-#include "TestHarness.h"
-#include "TestNeuralNetworksWrapper.h"
-
-namespace generated_tests {
-using namespace android::nn::test_wrapper;
-
-void execute(std::function<void(Model*)> createModel, std::function<bool(int)> isIgnored,
-             std::vector<test_helper::MixedTypedExample>& examples, std::string dumpFile = "");
-
-}  // namespace generated_tests
-
-
-#endif  // ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_GENERATEDUTILS_H
diff --git a/nn/runtime/test/GeneratedUtils.cpp b/nn/runtime/test/TestGenerated.cpp
similarity index 61%
rename from nn/runtime/test/GeneratedUtils.cpp
rename to nn/runtime/test/TestGenerated.cpp
index 2c91174..8077850 100644
--- a/nn/runtime/test/GeneratedUtils.cpp
+++ b/nn/runtime/test/TestGenerated.cpp
@@ -14,12 +14,13 @@
  * limitations under the License.
  */
 
-#include "GeneratedUtils.h"
-
+#include "TestGenerated.h"
 #include "TestHarness.h"
 
 #include <gtest/gtest.h>
 
+#include <ftw.h>
+#include <unistd.h>
 #include <cassert>
 #include <cmath>
 #include <fstream>
@@ -66,9 +67,8 @@
         os << "],\n";
     });
 }
-}  // namespace
 
-static void printAll(std::ostream& os, const MixedTyped& test) {
+void printAll(std::ostream& os, const MixedTyped& test) {
     print(os, test.float32Operands);
     print(os, test.int32Operands);
     print(os, test.quant8AsymmOperands);
@@ -81,24 +81,31 @@
     static_assert(9 == MixedTyped::kNumTypes,
                   "Number of types in MixedTyped changed, but printAll function wasn't updated");
 }
+}  // namespace
 
-Compilation createAndCompileModel(Model* model, std::function<void(Model*)> createModel) {
-    NNTRACE_APP(NNTRACE_PHASE_PREPARATION, "createAndCompileModel");
-
-    createModel(model);
-    model->finish();
-
-    NNTRACE_APP_SWITCH(NNTRACE_PHASE_COMPILATION, "createAndCompileModel");
-    Compilation compilation(model);
-    compilation.finish();
-
-    return compilation;
+Compilation GeneratedTests::compileModel(const Model* model) {
+    NNTRACE_APP(NNTRACE_PHASE_COMPILATION, "compileModel");
+    if (mTestCompilationCaching) {
+        // Compile the model twice with the same cache directory and token, so that compilation
+        // caching will be exercised if supported by the driver. Only the second result is used.
+        Compilation compilation1(model);
+        compilation1.setCaching(mCacheDir, mToken);
+        compilation1.finish();  // NOTE(review): setCaching()/finish() status, if any, is ignored.
+        Compilation compilation2(model);
+        compilation2.setCaching(mCacheDir, mToken);
+        compilation2.finish();
+        return compilation2;
+    } else {
+        Compilation compilation(model);  // Plain compilation; no cache directory or token is set.
+        compilation.finish();
+        return compilation;
+    }
 }
 
-void executeWithCompilation(Model* model, Compilation* compilation,
-                            std::function<bool(int)> isIgnored,
-                            std::vector<MixedTypedExample>& examples,
-                            std::string dumpFile) {
+void GeneratedTests::executeWithCompilation(const Model* model, Compilation* compilation,
+                                            std::function<bool(int)> isIgnored,
+                                            std::vector<MixedTypedExample>& examples,
+                                            std::string dumpFile) {
     bool dumpToFile = !dumpFile.empty();
     std::ofstream s;
     if (dumpToFile) {
@@ -173,67 +180,86 @@
         exampleNo++;
 
         if (example.expectedMultinomialDistributionTolerance > 0) {
-          expectMultinomialDistributionWithinTolerance(test, example);
+            expectMultinomialDistributionWithinTolerance(test, example);
         }
     }
 }
 
-void executeOnce(std::function<void(Model*)> createModel,
-                 std::function<bool(int)> isIgnored,
-                 std::vector<MixedTypedExample>& examples,
-                 std::string dumpFile) {
+void GeneratedTests::executeOnce(const Model* model, std::function<bool(int)> isIgnored,
+                                 std::vector<MixedTypedExample>& examples, std::string dumpFile) {
     NNTRACE_APP(NNTRACE_PHASE_OVERALL, "executeOnce");
-    Model model;
-    Compilation compilation = createAndCompileModel(&model, createModel);
-    executeWithCompilation(&model, &compilation, isIgnored, examples, dumpFile);
+    Compilation compilation = compileModel(model);  // A new compilation is made on every call.
+    executeWithCompilation(model, &compilation, isIgnored, examples, dumpFile);
 }
 
-
-void executeMultithreadedOwnCompilation(std::function<void(Model*)> createModel,
-                                        std::function<bool(int)> isIgnored,
-                                        std::vector<MixedTypedExample>& examples) {
+void GeneratedTests::executeMultithreadedOwnCompilation(const Model* model,
+                                                        std::function<bool(int)> isIgnored,
+                                                        std::vector<MixedTypedExample>& examples) {
     NNTRACE_APP(NNTRACE_PHASE_OVERALL, "executeMultithreadedOwnCompilation");
     SCOPED_TRACE("MultithreadedOwnCompilation");
     std::vector<std::thread> threads;
     for (int i = 0; i < 10; i++) {
-        threads.push_back(
-                std::thread([&]() { executeOnce(createModel, isIgnored, examples, ""); }));
+        threads.emplace_back([&]() { executeOnce(model, isIgnored, examples, ""); });
     }
-    std::for_each(threads.begin(), threads.end(), [](std::thread& t) {
-        t.join();
-    });
+    for (std::thread& t : threads) t.join();  // Block until all ten workers have finished.
 }
 
-void executeMultithreadedSharedCompilation(std::function<void(Model*)> createModel,
-                                           std::function<bool(int)> isIgnored,
-                                           std::vector<MixedTypedExample>& examples) {
+void GeneratedTests::executeMultithreadedSharedCompilation(
+        const Model* model, std::function<bool(int)> isIgnored,
+        std::vector<MixedTypedExample>& examples) {
     NNTRACE_APP(NNTRACE_PHASE_OVERALL, "executeMultithreadedSharedCompilation");
     SCOPED_TRACE("MultithreadedSharedCompilation");
-    Model model;
-    Compilation compilation = createAndCompileModel(&model, createModel);
+    Compilation compilation = compileModel(model);  // One compilation, used by every thread.
     std::vector<std::thread> threads;
     for (int i = 0; i < 10; i++) {
-        threads.push_back(std::thread([&]() {
-            executeWithCompilation(&model, &compilation, isIgnored, examples, "");
-        }));
+        threads.emplace_back(
+                [&]() { executeWithCompilation(model, &compilation, isIgnored, examples, ""); });
     }
-    std::for_each(threads.begin(), threads.end(), [](std::thread& t) {
-        t.join();
-    });
+    for (std::thread& t : threads) t.join();  // Join before `compilation` goes out of scope.
 }
 
-
 // Test driver for those generated from ml/nn/runtime/test/spec
-void execute(std::function<void(Model*)> createModel,
-             std::function<bool(int)> isIgnored,
-             std::vector<MixedTypedExample>& examples,
-             [[maybe_unused]] std::string dumpFile) {
+void GeneratedTests::execute(std::function<void(Model*)> createModel,
+                             std::function<bool(int)> isIgnored,
+                             std::vector<MixedTypedExample>& examples,
+                             [[maybe_unused]] std::string dumpFile) {
+    NNTRACE_APP(NNTRACE_PHASE_OVERALL, "execute");
+    Model model;
+    createModel(&model);
+    model.finish();  // The model is built once here and reused by both passes below.
+    auto executeInternal = [&model, &isIgnored, &examples,
+                            this]([[maybe_unused]] std::string dumpFile) {
+        SCOPED_TRACE("TestCompilationCaching = " + std::to_string(mTestCompilationCaching));
 #ifndef NNTEST_MULTITHREADED
-    executeOnce(createModel, isIgnored, examples, dumpFile);
-#else  // defined(NNTEST_MULTITHREADED)
-    executeMultithreadedOwnCompilation(createModel, isIgnored, examples);
-    executeMultithreadedSharedCompilation(createModel, isIgnored, examples);
+        executeOnce(&model, isIgnored, examples, dumpFile);
+#else   // defined(NNTEST_MULTITHREADED)
+        executeMultithreadedOwnCompilation(&model, isIgnored, examples);
+        executeMultithreadedSharedCompilation(&model, isIgnored, examples);
 #endif  // !defined(NNTEST_MULTITHREADED)
+    };
+    mTestCompilationCaching = false;  // Pass 1: compileModel() takes the non-caching branch.
+    executeInternal(dumpFile);
+    mTestCompilationCaching = true;  // Pass 2: caching branch; an empty dump file name is passed.
+    executeInternal("");
+}
+
+void GeneratedTests::SetUp() {
+    char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
+    char* cacheDir = mkdtemp(cacheDirTemp);  // Unique directory per test; nullptr on failure.
+    ASSERT_NE(cacheDir, nullptr);
+    mCacheDir = cacheDir;
+    mToken = std::vector<uint8_t>(ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN, 0);  // All-zero token.
+}
+
+void GeneratedTests::TearDown() {
+    if (!::testing::Test::HasFailure()) {
+        // TODO: Switch to std::filesystem::remove_all once libc++fs is made available in CTS.
+        // Recursively remove mCacheDir; it is left in place when the test has failed.
+        auto callback = [](const char* child, const struct stat*, int, struct FTW*) {
+            return remove(child);
+        };
+        nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
+    }
 }
 
 }  // namespace generated_tests
diff --git a/nn/runtime/test/TestGenerated.h b/nn/runtime/test/TestGenerated.h
index 7fadc0f..ea26caa 100644
--- a/nn/runtime/test/TestGenerated.h
+++ b/nn/runtime/test/TestGenerated.h
@@ -17,17 +17,39 @@
 #ifndef ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_TESTGENERATED_H
 #define ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_TESTGENERATED_H
 
-#include "GeneratedUtils.h"
-#include "NeuralNetworksWrapper.h"
 #include "TestHarness.h"
+#include "TestNeuralNetworksWrapper.h"
 
 #include <gtest/gtest.h>
 
+using namespace android::nn::test_wrapper;
+using namespace test_helper;
+
 namespace generated_tests {
 
 class GeneratedTests : public ::testing::Test {
-protected:
-    virtual void SetUp() {}
+   protected:
+    void SetUp() override;
+    void TearDown() override;
+
+    Compilation compileModel(const Model* model);
+    void executeWithCompilation(const Model* model, Compilation* compilation,
+                                std::function<bool(int)> isIgnored,
+                                std::vector<MixedTypedExample>& examples, std::string dumpFile);
+    void executeOnce(const Model* model, std::function<bool(int)> isIgnored,
+                     std::vector<MixedTypedExample>& examples, std::string dumpFile);
+    void executeMultithreadedOwnCompilation(const Model* model, std::function<bool(int)> isIgnored,
+                                            std::vector<MixedTypedExample>& examples);
+    void executeMultithreadedSharedCompilation(const Model* model,
+                                               std::function<bool(int)> isIgnored,
+                                               std::vector<MixedTypedExample>& examples);
+    // Test driver for the tests generated from ml/nn/runtime/test/spec
+    void execute(std::function<void(Model*)> createModel, std::function<bool(int)> isIgnored,
+                 std::vector<MixedTypedExample>& examples, std::string dumpFile = "");
+
+    std::string mCacheDir;  // Temp directory created in SetUp(); removed in TearDown() on success.
+    std::vector<uint8_t> mToken;  // Cache token passed to setCaching(); zero-filled in SetUp().
+    bool mTestCompilationCaching = false;  // Set by execute(); selects the compileModel() branch.
 };
 
 // Tag for the dynamic output shape tests
@@ -35,7 +57,6 @@
 
 }  // namespace generated_tests
 
-using namespace test_helper;
 using namespace generated_tests;
 
 #endif  // ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_TESTGENERATED_H