arm_compute v19.11
diff --git a/examples/SConscript b/examples/SConscript
index a7cd401..44c14fb 100644
--- a/examples/SConscript
+++ b/examples/SConscript
@@ -86,6 +86,17 @@
alias = examples_env.Alias(example, prog)
Default(alias)
+if env['gemm_tuner'] and env['opencl']:  # GEMM tuner examples are built only when both the gemm_tuner and opencl options are set
+ gemm_tuner_common_options = examples_env.Object("./gemm_tuner/CommonGemmExampleOptions.cpp")  # Compiled once and linked into every tuner example below
+ for file in Glob("./gemm_tuner/cl_*.cpp"):
+ example = os.path.basename(os.path.splitext(str(file))[0])  # Source file name without directory or extension
+ example = os.path.join("gemm_tuner", example)
+ prog = examples_env.Program(example, ["{}.cpp".format(example), utils, gemm_tuner_common_options], CPPDEFINES=['ARM_COMPUTE_CL'], LIBS = examples_libs + arm_compute_libs)
+ Depends(prog, arm_compute_dependency)
+ prog = install_bin(prog)
+ alias = examples_env.Alias(example, prog)
+ Default(alias)
+
if env['neon']:
for file in Glob("./neon_*.cpp"):
example = os.path.basename(os.path.splitext(str(file))[0])
diff --git a/examples/gemm_tuner/CommonGemmExampleOptions.cpp b/examples/gemm_tuner/CommonGemmExampleOptions.cpp
new file mode 100644
index 0000000..a93d019
--- /dev/null
+++ b/examples/gemm_tuner/CommonGemmExampleOptions.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "CommonGemmExampleOptions.h"
+
+namespace gemm_tuner
+{
+using namespace arm_compute;
+using namespace utils;
+
+::std::ostream &operator<<(::std::ostream &os, const CommonGemmExampleParams &common_params)
+{ // Writes one "<name> : <value>" line per parameter, in the order M, N, K, B
+    os << "M : " << common_params.M << std::endl
+       << "N : " << common_params.N << std::endl
+       << "K : " << common_params.K << std::endl
+       << "B : " << common_params.B << std::endl;
+    return os;
+}
+
+CommonGemmExampleOptions::CommonGemmExampleOptions(CommandLineParser &parser) // Registers the help toggle and the M/N/K/B positional options on the given parser
+ : help(parser.add_option<ToggleOption>("help")),
+ M(parser.add_positional_option<SimpleOption<size_t>>("M", 100)), // NOTE(review): 100 is presumably the default value (it matches CommonGemmExampleParams::M) - confirm against SimpleOption
+ N(parser.add_positional_option<SimpleOption<size_t>>("N", 100)),
+ K(parser.add_positional_option<SimpleOption<size_t>>("K", 50)),
+ B(parser.add_positional_option<SimpleOption<size_t>>("B", 1))
+{
+ help->set_help("Show this help message."); // Help texts attached to each registered option
+ M->set_help("Number of lhs matrix rows.");
+ N->set_help("Number of rhs matrix columns.");
+ K->set_help("Number of lhs matrix columns/rhs matrix rows.");
+ B->set_help("Batch size.");
+}
+
+CommonGemmExampleParams consume_common_gemm_example_parameters(const CommonGemmExampleOptions &options)
+{
+    CommonGemmExampleParams params; // Every field is overwritten below with the value held by the matching option
+    params.M = options.M->value();
+    params.N = options.N->value();
+    params.K = options.K->value();
+    params.B = options.B->value();
+    return params;
+}
+} // namespace gemm_tuner
diff --git a/examples/gemm_tuner/CommonGemmExampleOptions.h b/examples/gemm_tuner/CommonGemmExampleOptions.h
new file mode 100644
index 0000000..5f079ab
--- /dev/null
+++ b/examples/gemm_tuner/CommonGemmExampleOptions.h
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_EXAMPLES_GEMM_TUNER_COMMON_GEMM_EXAMPLE_OPTIONS
+#define ARM_COMPUTE_EXAMPLES_GEMM_TUNER_COMMON_GEMM_EXAMPLE_OPTIONS
+
+#include "utils/command_line/CommandLineOptions.h"
+#include "utils/command_line/CommandLineParser.h"
+
+namespace gemm_tuner
+{
+/** Structure holding all the common gemm example parameters */
+struct CommonGemmExampleParams
+{
+ size_t M{ 100 }; /**< Number of lhs matrix rows */
+ size_t N{ 100 }; /**< Number of rhs matrix columns */
+ size_t K{ 50 }; /**< Number of lhs matrix columns/rhs matrix rows */
+ size_t B{ 1 }; /**< Batch size */
+};
+
+/** Formatted output of the CommonGemmExampleParams type
+ *
+ * @param[out] os Output stream.
+ * @param[in] common_params Common parameters to output
+ *
+ * @return Modified output stream.
+ */
+::std::ostream &operator<<(::std::ostream &os, const CommonGemmExampleParams &common_params);
+
+/** Common command line options used to configure the gemm examples
+ *
+ * The options in this object get populated when "parse()" is called on the parser used to construct it.
+ * The expected workflow is:
+ *
+ * CommandLineParser parser;
+ * CommonGemmExampleOptions options( parser );
+ * parser.parse(argc, argv);
+ */
+class CommonGemmExampleOptions
+{
+public:
+    /** Constructor (explicit, so that a parser is never implicitly converted into an options object)
+     *
+     * @param[in,out] parser A parser on which "parse()" hasn't been called yet.
+     */
+    explicit CommonGemmExampleOptions(arm_compute::utils::CommandLineParser &parser);
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    CommonGemmExampleOptions(const CommonGemmExampleOptions &) = delete;
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    CommonGemmExampleOptions &operator=(const CommonGemmExampleOptions &) = delete;
+    /** Allow instances of this class to be moved */
+    CommonGemmExampleOptions(CommonGemmExampleOptions &&) = default;
+    /** Allow instances of this class to be moved */
+    CommonGemmExampleOptions &operator=(CommonGemmExampleOptions &&) = default;
+    /** Default destructor */
+    ~CommonGemmExampleOptions() = default;
+
+    arm_compute::utils::ToggleOption         *help; /**< Show help option */
+    arm_compute::utils::SimpleOption<size_t> *M;    /**< Number of lhs matrix rows option */
+    arm_compute::utils::SimpleOption<size_t> *N;    /**< Number of rhs matrix columns option */
+    arm_compute::utils::SimpleOption<size_t> *K;    /**< Number of lhs matrix columns/rhs matrix rows option */
+    arm_compute::utils::SimpleOption<size_t> *B;    /**< Batch size option */
+};
+
+/** Consumes the common gemm example options and creates a structure containing all information
+ *
+ * @param[in] options Options to consume
+ *
+ * @return Structure containing the common gemm example parameters
+ */
+CommonGemmExampleParams consume_common_gemm_example_parameters(const CommonGemmExampleOptions &options);
+} // namespace gemm_tuner
+#endif /* ARM_COMPUTE_EXAMPLES_GEMM_TUNER_COMMON_GEMM_EXAMPLE_OPTIONS */
diff --git a/examples/gemm_tuner/GemmTuner.py b/examples/gemm_tuner/GemmTuner.py
new file mode 100644
index 0000000..29c414c
--- /dev/null
+++ b/examples/gemm_tuner/GemmTuner.py
@@ -0,0 +1,559 @@
+# Copyright (c) 2019 ARM Limited.
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to
+# deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+# sell copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+#!/usr/bin/python3
+
+import argparse
+import csv
+import json
+import logging
+import math
+import os
+from collections import Counter, defaultdict, deque, namedtuple
+from enum import Enum
+from pathlib import Path
+from typing import Deque, Dict, Generator, List, NamedTuple, Set, Tuple, Union
+
+################################################################################
+# Types
+################################################################################
+
+# Gemm strategy
+Strategy = Enum("Strategy", ["Native", "ReshapedOnlyRHS", "Reshaped"])
+
+# Gemm parameter
+class GEMMParam(NamedTuple):
+ M: int # Number of lhs matrix rows
+ N: int # Number of rhs matrix columns
+ K: int # Number of lhs matrix columns/rhs matrix rows
+ B: int # Batch size
+
+ @staticmethod
+ def parse_from_strs(*args):
+ return GEMMParam(*map(int, args))
+
+ def __str__(self):
+ return "-".join(map(str, self))
+
+
+# Gemm configuration for strategy Native
+class NativeGEMMConfig(NamedTuple):
+ m0: int # Number of rows processed by the matrix multiplication
+ n0: int # Number of columns processed by the matrix multiplication
+ k0: int # Number of partial accumulations performed by the matrix multiplication
+
+ @staticmethod
+ def parse_from_strs(*args):
+ *mnk, = map(int, args)
+ return NativeGEMMConfig(*mnk)
+
+ def __str__(self):
+ return "-".join(map(str, self))
+
+
+# Gemm configuration for strategy Reshaped Only RHS
+class ReshapedOnlyRHSGEMMConfig(NamedTuple):
+ m0: int # Number of rows processed by the matrix multiplication
+ n0: int # Number of columns processed by the matrix multiplication
+ k0: int # Number of partial accumulations performed by the matrix multiplication
+ h0: int # Number of horizontal blocks of size (k0xn0) stored on the same output row
+ interleave_rhs: bool # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
+ transpose_rhs: bool # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
+
+ @staticmethod
+ def parse_from_strs(*args):
+ *mnkh, interleave_rhs, transpose_rhs = map(int, args)
+ interleave_rhs = interleave_rhs == 1
+ transpose_rhs = transpose_rhs == 1
+ return ReshapedOnlyRHSGEMMConfig(*mnkh, interleave_rhs, transpose_rhs)
+
+ def __str__(self):
+ return "-".join(map(str, self))
+
+
+# Gemm configuration for strategy Reshaped
+class ReshapedGEMMConfig(NamedTuple):
+ m0: int # Number of rows processed by the matrix multiplication
+ n0: int # Number of columns processed by the matrix multiplication
+ k0: int # Number of partial accumulations performed by the matrix multiplication
+ v0: int # Number of vertical blocks of size (m0xk0) stored on the same output row
+ h0: int # Number of horizontal blocks of size (k0xn0) stored on the same output row
+ interleave_lhs: bool # Interleave lhs matrix (1) / Do not interleave lhs matrix (0)
+ interleave_rhs: bool # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
+ transpose_rhs: bool # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
+
+ @staticmethod
+ def parse_from_strs(*args):
+ *mnkvh, interleave_lhs, interleave_rhs, transpose_rhs = map(int, args)
+ interleave_lhs = interleave_lhs == 1
+ interleave_rhs = interleave_rhs == 1
+ transpose_rhs = transpose_rhs == 1
+ return ReshapedGEMMConfig(*mnkvh, interleave_lhs, interleave_rhs, transpose_rhs)
+
+ def __str__(self):
+ return "-".join(map(str, self))
+
+
+# Measurement we take from the benchmark result.
+class Measurement(NamedTuple):
+ opencl_timer_ms: float
+
+ def is_close_to(self, other, tol):
+ return math.fabs(self.opencl_timer_ms - other.opencl_timer_ms) < tol
+
+ def is_better_than(self, other, tol):
+ return self < other and not self.is_close_to(other)
+
+ def __add__(self, other):
+ return Measurement(self.opencl_timer_ms + other.opencl_timer_ms)
+
+ def __sub__(self, other):
+ return Measurement(self.opencl_timer_ms - other.opencl_timer_ms)
+
+ def __mul__(self, other):
+ return Measurement(self.opencl_timer_ms * other.opencl_timer_ms)
+
+ def __floordiv__(self, other):
+ return Measurement(self.opencl_timer_ms // other.opencl_timer_ms)
+
+ def __truediv__(self, other):
+ return Measurement(self.opencl_timer_ms / other.opencl_timer_ms)
+
+ def __pow__(self, power):
+ return Measurement(self.opencl_timer_ms ** power)
+
+
+# GEMMConfig Type
+GEMMConfigT = Union[NativeGEMMConfig, ReshapedOnlyRHSGEMMConfig, ReshapedGEMMConfig]
+
+
+# Representation of the benchmark result from a single experiment
+class BenchmarkResult(NamedTuple):
+ gemm_param: GEMMParam
+ strategy: Strategy
+ gemm_config: GEMMConfigT
+ measurement: Measurement
+
+
+# Representation of a single row of BenchmarkResult in CSV
+# NOTE: In the CSV representation, we merge all fields of Gemm Config into a single field "GEMMConfig", but keep the
+# fields of GEMMParam and Measurement
+# The example entry including header would look like:
+# M , N , K , B, Strategy , GEMMConfig , OpenCLTimer_MS
+# 1225, 32, 192, 1, Reshaped , 4-4-4-3-1-1-1-0 , 0.3309
+BenchmarkResultCSVRow = namedtuple(
+ "BenchmarkResultCSVRow", GEMMParam._fields + ("Strategy", "GEMMConfig") + Measurement._fields
+)
+
+
+def benchmark_result_2_csv_row(result: BenchmarkResult) -> BenchmarkResultCSVRow:
+ """ Convert a BenchmarkResult into its CSV row form """
+ return BenchmarkResultCSVRow(
+ *(result.gemm_param + (result.strategy.name, str(result.gemm_config)) + result.measurement)
+ )
+
+
+class GEMMBenchmarkResultRecorder:
+ """ A recorder that records and organises GEMM Benchmark results, and produces various reports on the record.
+ """
+
+ SummaryLevel = Enum("SummaryLevel", ["Short", "Detailed"])
+
+ def __init__(self, tol=0.01):
+ """ Initializer
+ """
+ self._benchmark_result_record: List[BenchmarkResult] = []
+ # Strategies recorded
+ self._strategies = set()
+ self._tol = tol
+
+ def add(self, benchmark_result: BenchmarkResult):
+ """ Add a benchmark result to the record.
+ """
+ gemm_param, strategy, gemm_config, measurement = benchmark_result
+ # Update strategies encoutnered
+ self._strategies.add(strategy)
+
+ self._benchmark_result_record.append(benchmark_result)
+
+ def get_record(self) -> Generator[BenchmarkResult, None, None]:
+ """ Return an iterator that iterates over the record.
+ """
+ yield from self._benchmark_result_record
+
+ def get_best_gemm_configs(self):
+ """ Get the best GEMMConfig set per GEMMParam per Strategy
+ """
+ best_gc_sets: Dict[
+ Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfig, Measurement]]
+ ] = defaultdict(list)
+ for gemm_param, strategy, gemm_config, measurement in self.get_record():
+ best_gc_set = best_gc_sets.setdefault((gemm_param, strategy), [])
+ best_gc_set.append((gemm_config, measurement))
+ # Sort the best config set (list)
+ best_gc_set = sorted(best_gc_set, key=lambda gc_and_m: gc_and_m[1])
+ # Filter out configs that are beyond tolerance to the best GEMMConfig's measurement
+ best_gc, best_m = best_gc_set[0]
+ best_gc_set_new = [
+ (gemm_config, measurement)
+ for gemm_config, measurement in best_gc_set[1:]
+ if measurement.is_close_to(best_m, self._tol)
+ ]
+ # Add back the best config
+ best_gc_set_new.insert(0, (best_gc, best_m))
+ best_gc_sets[(gemm_param, strategy)] = best_gc_set_new
+
+ return best_gc_sets
+
+ def get_best_gemm_configs_as_sequence(self):
+ """ Get the best GEMMConfig set per GEMMParam per Strategy, and flatten the result into a sequence
+ of BenchmarkResults
+ """
+ for (gemm_param, strategy), best_gc_sets in self.get_best_gemm_configs().items():
+ for best_gemm_config, best_measurement in best_gc_sets:
+ yield BenchmarkResult(gemm_param, strategy, best_gemm_config, best_measurement)
+
+ def get_config_distributions(self):
+ """ Return GEMMConfigDistribution for each strategy
+ """
+ gemm_config_distributions: Dict[Strategy, GEMMConfigDistribution] = defaultdict(
+ GEMMConfigDistribution
+ )
+ for benchmark_result in self.get_best_gemm_configs_as_sequence():
+ _, strategy, _, _ = benchmark_result
+ gemm_config_distributions[strategy].add(benchmark_result)
+
+ return gemm_config_distributions
+
+ def save_to_csvs(self, out_dir, only_best_config=True):
+ """ Save records to an output directory of csv files.
+ The directory is organized such that each strategy gets its own CSV file.
+ """
+ if not os.path.exists(out_dir):
+ logging.info("Output directory {} does not exist. Creating...".format(out_dir))
+ os.mkdir(out_dir)
+ for strategy in self._strategies:
+ out_csv_path = os.path.join(out_dir, strategy.name)
+ if os.path.exists(out_csv_path):
+ overwrite = (
+ input(
+ "Output CSV {} already exists. Overwrite? [Y/N]: ".format(out_csv_path)
+ ).lower()
+ == "y"
+ )
+ if not overwrite:
+ logging.info("Skipping {}".format(out_csv_path))
+ continue
+ logging.info("Saving csv file to {}".format(out_csv_path))
+ record = (
+ self.get_best_gemm_configs_as_sequence() if only_best_config else self.get_record()
+ )
+ with open(out_csv_path, "w") as f:
+ csv_writer = csv.DictWriter(f, fieldnames=BenchmarkResultCSVRow._fields)
+ csv_writer.writeheader()
+ csv_writer.writerows(
+ benchmark_result_2_csv_row(res)._asdict()
+ for res in record
+ if res.strategy == strategy
+ )
+ logging.info("Saved")
+
+ def summary(self, sum_level=SummaryLevel.Short):
+ """ Return the summary string of the record
+ """
+ num_raw_records = sum(1 for _ in self.get_record())
+ gemm_params_per_strategy = defaultdict(list)
+ for gemm_param, strategy in self.get_best_gemm_configs().keys():
+ gemm_params_per_strategy[strategy].append(gemm_param)
+ global_summary = f"""
+=== {self.__class__.__name__} Summary ===
+[Global]
+Strategies recorded: {", ".join(map(lambda s: s.name, self._strategies))}
+Total number of results recorded: {num_raw_records}
+
+[Per strategy]
+ """
+ strategy_summaries = []
+ for strategy in gemm_params_per_strategy:
+ summary = f"""
+Strategy {strategy.name}:
+GEMM parameters:
+ Number of: {len(gemm_params_per_strategy[strategy])}
+ """
+ if sum_level == self.__class__.SummaryLevel.Detailed:
+ summary += f"""
+ Content: {gemm_params_per_strategy[strategy]}
+ """
+ strategy_summaries.append(summary)
+ return global_summary + "".join(strategy_summaries)
+
+
+class GEMMConfigDistribution:
+ """ A representation of the GEMM Configuration distribution produced by the GEMMBenchmarkResultRecorder.
+ """
+
+ def __init__(self):
+ """ Initializer
+ """
+ self._gemm_config_dist: Dict[GEMMConfig, List[Tuple[GEMMParam, Measurement]]] = defaultdict(
+ list
+ )
+ self._gemm_config_freq = Counter()
+
+ def add(self, benchmark_result: BenchmarkResult):
+ """ Add a benchmark result to the distribution
+ """
+ gemm_param, _, gemm_config, measurement = benchmark_result
+ self._gemm_config_dist[gemm_config].append((gemm_param, measurement))
+ self._gemm_config_freq[gemm_config] += 1
+
+ def distribution(self):
+ return self._gemm_config_dist
+
+ def frequency(self):
+ """ Get the frequency of each (best) gemm config recorded
+ """
+ return self._gemm_config_freq.most_common()
+
+ def best_config(self):
+ """ Get the overall best config, as voted by all benchmark results.
+ """
+ return self._gemm_config_freq.most_common(1)
+
+ def std(self):
+ """ Get the standard deviation as a measure of dispersion of the distribution. We should aim for higher values
+ as they indicate there is high variation in the distribution. Thus the evidence of the best config is stronger.
+ """
+ freqs = self._gemm_config_freq.values()
+ if len(freqs) == 0:
+ return 0
+ mean_freq = sum(freqs) / len(freqs)
+ return math.sqrt(sum((freq - mean_freq) ** 2 for freq in freqs) / len(freqs))
+
+
+################################################################################
+# Globals
+################################################################################
+
+# Gemm config type factory
+# Produces a GEMMConfig type specific to a Strategy
+GEMM_CONFIG_FACTORY = {
+ Strategy.Native: NativeGEMMConfig,
+ Strategy.ReshapedOnlyRHS: ReshapedOnlyRHSGEMMConfig,
+ Strategy.Reshaped: ReshapedGEMMConfig,
+}
+
+# Mapping from example binary name to Strategy
+# Assume 1-to-1 mapping
+EXAMPLE_FILE_2_STRATEGY = {
+ "benchmark_cl_gemm_native": Strategy.Native,
+ "benchmark_cl_gemm_reshaped_rhs_only": Strategy.ReshapedOnlyRHS,
+ "benchmark_cl_gemm_reshaped": Strategy.Reshaped,
+}
+
+# Gemm example arguments type factory
+# Produces a Gemm_Example_Args type specific to a Strategy
+# Gemm example arguments consist of:
+# GEMMParam + GEMMConfig
+# in that order.
+# For example, the example args of running a reshaped rhs only example could be:
+# 100,100,100,1, 4, 4, 4, 1, 1, 1
+# M ,N ,K, B,m0,n0,k0,h0,interleave_rhs,transpose_rhs
+# <-GEMMParam-><-------------GEMMConfig-------------->
+# Note that the test strategy_name == strategy.name is in place to avoid unwanted enum aliases
+GEMM_EXAMPLE_ARGS_FACTORY = {
+ strategy: namedtuple(
+ "{}_Gemm_Example_Args".format(strategy_name),
+ GEMMParam._fields + GEMM_CONFIG_FACTORY[strategy]._fields,
+ )
+ for strategy_name, strategy in Strategy.__members__.items()
+ if strategy_name == strategy.name
+}
+
+# File extension used for benchmark result json files
+BENCHMARK_RESULT_JSON_EXTENSION = "gemmtuner_benchmark"
+
+################################################################################
+# Functions
+################################################################################
+
+
+def parse_benchmark_commandline(commandline: str) -> Dict[str, str]:
+ """ Parse the benchmark example command-line string into a dictionary of command-line agruments
+ """
+ args = commandline.split()
+ # Discard program name
+ args = args[1:]
+ # Split into a list of (argument name, argument value)
+ args = map(lambda arg: arg.split("="), args)
+
+ def transform(_name):
+ # Strip '-'/"--" if it exists
+ _name = _name.lstrip("-")
+ return _name
+
+ return {transform(name): val for name, val in args}
+
+
+def extract_benchmark_results(
+ json_results: Dict, measurement_method="avg"
+) -> Generator[BenchmarkResult, None, None]:
+ """ Parse the benchmark result and extract relevant information, namely:
+ GEMM param,
+ Strategy,
+ GEMM config,
+ Measurements
+ """
+ for json_res in json_results:
+ # Get example test and test data.
+ # There should only be 1 test per run
+ example_tests = list(json_res["tests"].items())
+ assert len(example_tests) == 1
+ example_fn, example_test_data = example_tests[0]
+
+ # Process example file name
+ example_fn = example_fn.split(os.path.sep)[-1]
+
+ # Get strategy
+ strategy = EXAMPLE_FILE_2_STRATEGY[example_fn]
+
+ # Get gemm params + gemm configs from example args
+ benchmark_args = parse_benchmark_commandline(json_res["CommandLine"])
+ Gemm_Example_Args_T = GEMM_EXAMPLE_ARGS_FACTORY[strategy]
+ example_args = Gemm_Example_Args_T(*(benchmark_args["example_args"].split(",")))
+ # Gemm_Example_Arg consists of GEMMParam first and then GEMMConfig (in that order)
+ gemm_param_fields_len = len(GEMMParam._fields)
+ gemm_param = GEMMParam.parse_from_strs(*example_args[:gemm_param_fields_len])
+ GEMMConfig = GEMM_CONFIG_FACTORY[strategy]
+ gemm_config = GEMMConfig.parse_from_strs(*example_args[gemm_param_fields_len:])
+
+ # Get OpenCL_Time_Ms stats
+ measurements = list(example_test_data["measurements"].items())
+ # There should only be 1 instrument per run
+ assert len(measurements) == 1
+ measurement_instrument, data = measurements.pop()
+ # Get instrument name and assert that it is the one we expect
+ measurement_instrument_name = measurement_instrument.split("/")[0]
+ assert measurement_instrument_name == "OpenCLTimer"
+ # Take either the minimum or the average of the raw data as the measurement value
+ if measurement_method == "min":
+ measurement_val = min(data["raw"])
+ elif measurement_method == "avg":
+ measurement_val = sum(data["raw"]) / len(data["raw"])
+ else:
+ raise ValueError("Invalid measurement method: {}".format(measurement_method))
+
+ measurement = Measurement(measurement_val)
+
+ yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement)
+
+
+def parse_json(dir_name):
+ """ Glob all benchmark result json files and parse them into json objects (dicts).
+ """
+ for res_fn in Path(dir_name).rglob("*.{}".format(BENCHMARK_RESULT_JSON_EXTENSION)):
+ with open(res_fn) as res_fp:
+ yield json.load(res_fp)
+
+
+################################################################################
+# Main
+################################################################################
+
+
+def main(args):  # args: argparse namespace providing benchmark_results_dir, tolerance, debug and output_dir
+ logging.info("Searching best gemm configurations from {}".format(args.benchmark_results_dir))
+
+ benchmark_results = extract_benchmark_results(parse_json(args.benchmark_results_dir))  # Generator chain: files are read as the loop below consumes it
+
+ # Add all benchmark results to the recorder
+ benchmark_result_recorder = GEMMBenchmarkResultRecorder(tol=args.tolerance)
+ for benchmark_result in benchmark_results:
+ benchmark_result_recorder.add(benchmark_result)
+
+ if args.debug:
+ recorder_sum_level = GEMMBenchmarkResultRecorder.SummaryLevel.Detailed
+ else:
+ recorder_sum_level = GEMMBenchmarkResultRecorder.SummaryLevel.Short
+
+ # Print overall summary of the recorded results
+ logging.info(benchmark_result_recorder.summary(sum_level=recorder_sum_level))
+
+ # Get GEMM configuration distributions for each strategy
+ all_config_dists = benchmark_result_recorder.get_config_distributions()
+
+ logging.info("=== Result ===")
+ for strategy, config_dist in all_config_dists.items():
+ logging.info("Strategy: {}".format(strategy.name))
+ logging.debug("GEMM Config, Votes")
+ for config, freq in config_dist.frequency():
+ logging.debug("{}, {}".format(config, freq))
+ logging.info(
+ "Best GEMM Config: {} with std: {}".format(config_dist.best_config(), config_dist.std())
+ )
+
+ # Save the recorded results to csv files in output directory (in debug mode all results are saved, not only the best configs)
+ if args.output_dir is not None:
+ benchmark_result_recorder.save_to_csvs(args.output_dir, only_best_config=(not args.debug))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="CL GEMM Tuner")
+ parser.add_argument(
+ "-b",
+ "--benchmark_results",
+ dest="benchmark_results_dir",
+ metavar="PATH",
+ action="store",
+ type=str,
+ help="Path to benchmark result directory, where benchmark result json files have a file \
+ extension of '{}'".format(
+ BENCHMARK_RESULT_JSON_EXTENSION
+ ),
+ required=True,
+ )
+ parser.add_argument(
+ "-o",
+ "--output_dir",
+ dest="output_dir",
+ metavar="PATH",
+ action="store",
+ type=str,
+ help="Path to directory that holds output csv files. One per strategy",
+ )
+ parser.add_argument(
+ "-t",
+ "--tolerance",
+ action="store",
+ type=float,
+ default=0.01,
+ help="For testing if two GEMMConfigs are equivalent in terms of performance. The tolerance is OpenCL timer in\
+ milliseconds. Recommended value: <= 0.1 ms",
+ )
+ parser.add_argument(
+ "-D", "--debug", dest="debug", action="store_true", help="Enable script debugging output"
+ )
+ args = parser.parse_args()
+ logging_level = logging.DEBUG if args.debug else logging.INFO
+ logging.basicConfig(level=logging_level)
+ logging.debug("Arguments: {}".format(args))
+ main(args)
diff --git a/examples/gemm_tuner/README.md b/examples/gemm_tuner/README.md
new file mode 100644
index 0000000..3238a9d
--- /dev/null
+++ b/examples/gemm_tuner/README.md
@@ -0,0 +1,89 @@
+# Gemm Tuner
+
+## Introduction
+
+This is a set of 2 script tools for tuning the performance of OpenCL GEMM
+kernels (limited to Convolution layer functions only for now). Specifically, we
+tune 3 GEMM kernels, each has a different implementation strategy of the GEMM
+operation: native, reshaped, reshaped only rhs. The details of these strategies
+can be found in the documentation of the corresponding kernels:
+CLGEMMMatrixMultiplyNativeKernel, CLGEMMMatrixMultiplyReshapedKernel and
+CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.
+
+The outputs of the tuning process are 1 optimal configuration (called GEMM
+Configuration or GEMMConfig) for each of the 3 strategies.
+
+## Approach
+
+This section gives a brief description and rationale of the approach adopted by
+the current version of GEMM Tuner.
+
+As explained in the Introduction section, the outputs of the tuner are 1 optimal
+GEMMConfig for each strategy. This is because we can only integrate 1 GEMMConfig
+for each strategy in ACL at compile time. In theory, however, the optimal
+GEMMConfig also depends on different parameters of GEMM (called GEMM Parameter
+or GEMMParam, e.g.: the shape of the operation); thus ideally, for each
+strategy, the optimal configurations should be a mapping from GEMMParam to
+GEMMConfig instead of a single GEMMConfig.
+
+To address this issue, we ensure the one single optimal GEMMConfig can
+generalise well to all potential GEMMParams (or at least the ones that we care
+about). The approach we adopt involves a preliminary stage where a collection of
+common GEMMParams (GEMM shapes from popular networks) are compiled. Then, to
+reduce the final tuning time, rather contradictorily, we spend a lot of time
+searching for near-optimal GEMMConfigs for each GEMMParam first, and then
+discard redundant GEMMParams which share similar optimal GEMMConfigs with
+others. The resultant list of GEMMParams is called a __GEMMParam archetype
+list__, as in these GEMMParams are typical enough to capture the space of
+GEMMParams that we care about.
+
+During this preliminary stage we also produce a list of good GEMMConfigs that
+can be used to search for the optimal one in the actual tuning stage. This,
+again, is to reduce the tuning time, and the resultant list is called a
+__GEMMConfig search list__.
+
+The GEMMParam archetype list and the GEMMConfig search list are investigated and
+prepared by the developers; the users of GEMM tuner need not worry about
+producing them, but they need to obtain them prior to running the tuner.
+
+Once these two lists (2 for each strategy, so 6 in total) are obtained, they can
+be fed to the tuner, to produce the optimal GEMMConfig(s).
+
+## Pre-requisite
+* A target device (Android phones, Linux boards, etc.), on which to tune the
+ GEMM kernels, plus these on the device:
+ * (Preferably) Bash shell
+ * Built ACL with benchmark examples
+ * GEMMParam archetype list
+ * GEMMConfig search list
+* A host machine, plus these on the machine:
+ * python >= 3.6
+
+## Usage
+
+The tuning stage consists of 2 steps:
+
+1. Run benchmarks: Run the runner shell script (benchmark_gemm_examples.sh) on
+your target device. Note that all the built benchmark examples have to be
+present on your target device prior to running. The script will run the selected
+strategy, over all configs defined in GEMMConfig search list, on all GEMMParams
+inside the GEMMParam archetype list, and then save the benchmark results to json
+files in an output directory.
+```
+[$SHELL] ./benchmark_gemm_examples.sh -s <strategy> -e <example_binary_dir>
+-g <gemmparam_archetype_list> -c <gemmconfig_search_list> [-o <out_dir>]
+```
+2. Run analyser: Run the python script (GemmTuner.py) on your host machine.
+You'll need to transfer all the benchmark result json files generated from the
+previous step to your host machine beforehand. Note that this requires python >=
+3.6. The script will output the best configuration, along with some analysis
+statistics for each strategy, and optionally save the parsed benchmark results
+into csv files (one for each strategy) for further analysis.
+An optional tolerance in milliseconds in OpenCL timer is provided to determine
+how far apart in performance two GEMMConfigs have to be, to be considered
+different. A default value of 0.01 ms is used, and it's recommended this value
+should be < 0.1 ms.
+```
+python GemmTuner.py -b <benchmark_results_dir> [-t <tolerance>]
+[-o <out_dir>]
+```
diff --git a/examples/gemm_tuner/benchmark_gemm_examples.sh b/examples/gemm_tuner/benchmark_gemm_examples.sh
new file mode 100755
index 0000000..95bb367
--- /dev/null
+++ b/examples/gemm_tuner/benchmark_gemm_examples.sh
@@ -0,0 +1,458 @@
+# Copyright (c) 2019 ARM Limited.
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to
+# deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+# sell copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+#!/bin/sh
+
+# Global: Global variables and global settings {{{
+# Treat unset variables as an error when substituting
+set -u
+
+CMD=$( basename $0 )
+
+# All supported strategy options
+ALL_STRATEGY_OPTIONS=("native" "reshaped_rhs_only" "reshaped")
+
+# Names of example binary for each strategy
+EXAMPLE_BIN_NATIVE="benchmark_cl_gemm_native"
+EXAMPLE_BIN_RESHAPED_RHS_ONLY="benchmark_cl_gemm_reshaped_rhs_only"
+EXAMPLE_BIN_RESHAPED="benchmark_cl_gemm_reshaped"
+
+# Default output directory
+DEFAULT_OUT_DIR="out"
+
+# Number of iterations for each benchmark run
+NUM_ITERATION=5
+# Global }}}
+
+# Functions {{{
+#######################################
+# Print gemm shape file help message
+# Globals:
+# None
+# Arguments:
+# None
+# Returns:
+# None
+#######################################
+function help_gemm_shape_file() {
+ cat >&2 << EOF
+Gemm shape file:
+ Gemm shape file is a headerless csv file with fields separated by commas and commas only (there cannot be whitespaces
+ around each field).
+ A gemm shape is a list of 4 positive integers <M, N, K, B> describing the shapes of the two matrices (LHS and RHS)
+ with:
+ M - Number of lhs matrix rows
+ N - Number of rhs matrix columns
+ K - Number of lhs matrix columns/rhs matrix rows
+ B - Batch size
+
+ An example gemm shape file looks like:
+ 100,100,30,1
+ 100,100,30,3
+ ...
+
+EOF
+}
+
+#######################################
+# Print the help message describing the gemm config csv file of the native strategy (to stderr)
+# Globals:
+#   None
+# Arguments:
+#   None
+# Returns:
+#   None
+#######################################
+function help_gemm_config_file_native() {
+  cat >&2 << EOF
+Gemm config file (Strategy native):
+ Gemm config file is a headerless csv file with fields separated by commas and commas only (there cannot be whitespaces
+ around each field).
+ A gemm config is a list of 3 positive integers <m0, n0, k0>, with:
+ m0 - Number of rows processed by the matrix multiplication
+ n0 - Number of columns processed by the matrix multiplication
+ k0 - Number of partial accumulations performed by the matrix multiplication
+
+ Only the following configurations of M0, N0 and K0 are currently supported:
+ M0 = 1, 2, 3, 4, 5, 6, 7, 8
+ N0 = 2, 3, 4, 8, 16
+ K0 = 2, 3, 4, 8, 16
+
+ An example gemm config file looks like:
+ 1,4,4
+ 2,3,8
+ ...
+
+EOF
+}
+
+#######################################
+# Print gemm config file for reshaped_rhs_only help message
+# Globals:
+# None
+# Arguments:
+# None
+# Returns:
+# None
+#######################################
+function help_gemm_config_file_reshaped_rhs_only() {
+ cat >&2 << EOF
+Gemm config file (Strategy reshaped_rhs_only):
+ Gemm config file is a headerless csv file with fields separated by commas and commas only (there cannot be whitespaces
+ around each field).
+ A gemm config is a list of 4 positive integers <m0, n0, k0, h0> and 2 boolean values interleave_rhs and transpose_rhs, with:
+ m0 - Number of rows processed by the matrix multiplication
+ n0 - Number of columns processed by the matrix multiplication
+ k0 - Number of partial accumulations performed by the matrix multiplication
+ h0 - Number of horizontal blocks of size (k0xn0) stored on the same output row
+ interleave_rhs - Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
+ transpose_rhs - Transpose rhs matrix (1) / Do not transpose rhs matrix (0)
+
+ Only the following configurations of M0, N0 and K0 are currently supported:
+ M0 = 1, 2, 3, 4, 5, 6, 7, 8
+ N0 = 2, 3, 4, 8, 16
+ K0 = 2, 3, 4, 8, 16
+ H0 >= 1
+
+ An example gemm config file looks like:
+ 4,4,4,1,1,1
+ 4,4,4,3,1,0
+ ...
+
+EOF
+}
+
+#######################################
+# Print gemm config file for reshaped help message
+# Globals:
+# None
+# Arguments:
+# None
+# Returns:
+# None
+#######################################
+function help_gemm_config_file_reshaped() {
+ cat >&2 << EOF
+Gemm config file (Strategy reshaped):
+ Gemm config file is a headerless csv file with fields separated by commas and commas only (there cannot be whitespaces
+ around each field).
+ A gemm config is a list of 5 positive integers <m0, n0, k0, v0, h0> and 3 boolean values interleave_lhs, interleave_rhs and transpose_rhs, with:
+ m0 - Number of rows processed by the matrix multiplication
+ n0 - Number of columns processed by the matrix multiplication
+ k0 - Number of partial accumulations performed by the matrix multiplication
+ v0 - Number of vertical blocks of size (m0xk0) stored on the same output row
+ h0 - Number of horizontal blocks of size (k0xn0) stored on the same output row
+ interleave_lhs - Interleave lhs matrix (1) / Do not interleave lhs matrix (0)
+ interleave_rhs - Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
+ transpose_rhs - Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
+
+ If rhs matrix is transposed only the following configurations are currently supported:
+ M0 = 2, 3, 4, 5, 6, 7, 8
+ N0 = 2, 3, 4, 8, 16
+ K0 = 2, 3, 4, 8, 16
+ V0 >= 1
+ H0 >= 1
+
+ If lhs matrix is transposed only the following configurations are currently supported:
+ M0 = 2, 3, 4, 8
+ N0 = 2, 3, 4, 8, 16
+ K0 = 2, 3, 4, 8, 16
+ V0 >= 1
+ H0 >= 1
+
+ An example gemm config file looks like:
+ 4,4,4,1,3,1,1,1
+ 4,4,4,3,3,1,1,0
+ ...
+
+EOF
+}
+
+#######################################
+# Print usage of this program and exit with Error
+# Globals:
+# Assumes all globals are required
+# Arguments:
+# None
+# Returns:
+# Error(1)
+#######################################
+function usage() {
+ cat >&2 << EOF
+Run gemm examples of a selected strategy, over provided tunable configurationsa and gemm shapes.
+Save the benchmark results to json files in an output directory.
+
+Usage: ${CMD} [-h] -s <strategy> -e <example_binary_dir> -g <gemm_shape_file> -c <gemm_config_file> [-o <out_dir>]
+
+Options:
+ -h
+ Print help messages. If a strategy is specified with -s <strategy>, then only display messages relevant to that
+ strategy. Otherwise if no strategy is specified, display messages for all available strategies.
+
+ -s <strategy>
+ Strategy option.
+ Options: ${ALL_STRATEGY_OPTIONS[@]}.
+
+ -e <example_binary_dir>
+ Path to directory that holds all example binaries
+
+ -g <gemm_shape_file>
+ Path to gemm shape csv file
+
+ -c <gemm_config_file>
+ Path to gemm config csv file
+
+ -o <out_dir>
+ Path to output directory that holds output json files
+ Default: ${DEFAULT_OUT_DIR}
+
+EOF
+# Print help messages about gemm shapes and various gemm configs
+$HELP && help_gemm_shape_file
+$HELP && ( [ "${STRATEGY_OPTION}" == "" ] || [ "${STRATEGY_OPTION}" == "native" ] ) && help_gemm_config_file_native
+$HELP && ( [ "${STRATEGY_OPTION}" == "" ] || [ "${STRATEGY_OPTION}" == "reshaped_rhs_only" ] ) && help_gemm_config_file_reshaped_rhs_only
+$HELP && ( [ "${STRATEGY_OPTION}" == "" ] || [ "${STRATEGY_OPTION}" == "reshaped" ] ) && help_gemm_config_file_reshaped
+exit 1
+}
+
+#######################################
+# Report a fatal error on stderr, then terminate the script with status 1
+# Globals:
+#   None
+# Arguments:
+#   $1 - Error message
+# Returns:
+#   None (the script exits)
+#######################################
+error_msg() {
+  >&2 echo "Error: $1"
+  exit 1
+}
+
+#######################################
+# Lower-case a string
+# Globals:
+#   None
+# Arguments:
+#   text - String to convert
+# Returns:
+#   (stdout) - The string with upper-case characters replaced by lower-case ones
+#######################################
+to_lower() {
+  local text=$1
+  echo "${text}" | tr '[:upper:]' '[:lower:]'
+}
+
+#######################################
+# Test if the argument is an integer
+# Globals:
+# None
+# Arguments:
+# in - Input
+# Returns:
+# true/false
+#######################################
+function is_integer() {
+ local in=$1
+ [ "$in" -eq "$in" ] 2> /dev/null
+}
+
+#######################################
+# Test if a string is in an array of strings
+# Globals:
+#   None
+# Arguments:
+#   target - String to test
+#   array  - Array of strings to search (all remaining arguments)
+# Returns:
+#   true/false
+#######################################
+function arr_contains() {
+  local target=$1
+  shift
+  # Loop variable is declared local so the search does not leave a global behind
+  local candidate
+  for candidate in "$@"
+  do
+    [ "${candidate}" == "${target}" ] && return 0
+  done
+  return 1
+}
+
+#######################################
+# Run a single example with all tunable gemm configurations on all gemm parameters
+# Globals:
+#   OUT_DIR
+#   OUT_EXTENSION
+#   EXAMPLE_BIN_DIR
+#   NUM_ITERATION
+#   GEMM_CONFIGS_FILE
+#   GEMM_SHAPES_FILE
+# Arguments:
+#   example_bin   Name of the example binary to run
+# Returns:
+#   None
+#######################################
+function run() {
+  local example_bin=$1
+  echo "Running all configs for ${example_bin}" 1>&2
+  local example_args
+  local expr_count=1
+  local gemm_shape gemm_config  # csv lines being processed; local so they do not leak into the global scope
+  # Total number of experiment runs scheduled for this session
+  local total_num_experiment
+  local num_params
+  local num_configs
+  # grep -c '' also counts a final line that lacks a trailing newline, and its output needs no field splitting
+  num_params=$( grep -c '' "${GEMM_SHAPES_FILE}" )
+  num_configs=$( grep -c '' "${GEMM_CONFIGS_FILE}" )
+  (( total_num_experiment = num_params * num_configs ))
+  local time_elapsed_s  # Time elapsed since the beginning in seconds
+  local time_est_s      # Time estimated to finish in seconds
+  echo "Running a total number of ${total_num_experiment} experiments" 1>&2
+
+  while read -r gemm_shape || [ -n "${gemm_shape}" ]
+  do
+    while read -r gemm_config || [ -n "${gemm_config}" ]
+    do
+      echo "Running..." 1>&2
+      example_args="${gemm_shape},${gemm_config}"
+      # Run experiment (stdin comes from /dev/null so the binary cannot consume the csv the loop is reading)
+      "${EXAMPLE_BIN_DIR}/${example_bin}" --example_args="${example_args}" --iterations="${NUM_ITERATION}" --json-file="${OUT_DIR}/${expr_count}.${OUT_EXTENSION}" --instruments=OPENCL_TIMER_MS < /dev/null
+      # Print progress
+      print_progress "${expr_count}" "${total_num_experiment}"
+      # Print time statistics
+      time_elapsed_s=$SECONDS
+      echo "Time elapsed since beginning: $(( time_elapsed_s / 60 ))m $(( time_elapsed_s % 60 ))s" 1>&2
+      (( time_est_s = (total_num_experiment - expr_count) * time_elapsed_s / expr_count ))
+      echo "Time estimated to finish: $(( time_est_s / 60 ))m $(( time_est_s % 60 ))s" 1>&2
+      (( expr_count++ ))
+      echo "Done." 1>&2
+    done < "${GEMM_CONFIGS_FILE}"
+  done < "${GEMM_SHAPES_FILE}"
+  echo "Finished running all configs for ${example_bin}" 1>&2
+  echo "All results saved to ${OUT_DIR}" 1>&2
+}
+
+#######################################
+# Print the progress of the current session as a fixed-width bar on stderr
+# Globals:
+#   None
+# Arguments:
+#   current   Current number of items
+#   total     Total number of items
+# Returns:
+#   None
+#######################################
+function print_progress() {
+  local current=$1
+  local total=$2
+  # Width of progress bar
+  local width=20
+  # Number of filled cells and loop index; both local so repeated calls leave no globals behind
+  local current_width
+  local i
+  (( current_width = width * current / total ))
+  echo -n -e "Progress [" 1>&2
+  for i in $(seq 1 "${width}"); do
+    if [[ $i -le ${current_width} ]]; then
+      echo -n "#" 1>&2
+    else
+      echo -n " " 1>&2
+    fi
+  done
+  echo "] $current / $total Experiments" 1>&2
+}
+
+# Functions }}}
+
+# Main: Main script {{{
+# Path to directory containing all benchmark examples binaries
+EXAMPLE_BIN_DIR=""
+# Path to gemm shapes file
+GEMM_SHAPES_FILE=""
+# Path to gemm configs file
+GEMM_CONFIGS_FILE=""
+STRATEGY_OPTION=""
+# Path to output directory
+OUT_DIR=${DEFAULT_OUT_DIR}
+# Output benchmark result file extension
+OUT_EXTENSION="gemmtuner_benchmark"
+# Toggle help
+HELP=false
+
+# Obtain options
+while getopts "hs:e:g:c:o:" opt; do
+  case "$opt" in
+    h) HELP=true ;;
+    s) STRATEGY_OPTION=$(to_lower "${OPTARG}");;
+    e) EXAMPLE_BIN_DIR="${OPTARG}";;
+    g) GEMM_SHAPES_FILE="${OPTARG}";;
+    c) GEMM_CONFIGS_FILE="${OPTARG}";;
+    o) OUT_DIR="${OPTARG}";;
+  esac
+done
+shift $((OPTIND - 1))
+
+# Lazily print usage (after arguments have been parsed)
+$HELP && usage
+
+# Parse and validate options: verify all compulsory arguments are passed in
+( [ ! -z "${STRATEGY_OPTION}" ] && [ ! -z "${EXAMPLE_BIN_DIR}" ] && [ ! -z "${GEMM_SHAPES_FILE}" ] && [ ! -z "${GEMM_CONFIGS_FILE}" ] ) ||
+  usage
+
+# Verify strategy option is valid, then pick the example binary that belongs to it
+arr_contains "${STRATEGY_OPTION}" "${ALL_STRATEGY_OPTIONS[@]}" ||
+  error_msg "Does not support strategy ${STRATEGY_OPTION}"
+case "${STRATEGY_OPTION}" in
+  native) EXAMPLE_BIN="${EXAMPLE_BIN_NATIVE}" ;;
+  reshaped_rhs_only) EXAMPLE_BIN="${EXAMPLE_BIN_RESHAPED_RHS_ONLY}" ;;
+  reshaped) EXAMPLE_BIN="${EXAMPLE_BIN_RESHAPED}" ;;
+esac
+
+# Verify example binaries directory exists
+[ -d "${EXAMPLE_BIN_DIR}" ] ||
+  error_msg "${EXAMPLE_BIN_DIR} does not exist."
+
+# Verify the benchmark example binary of the selected strategy exists (it is the one that will be run)
+[ -f "${EXAMPLE_BIN_DIR}/${EXAMPLE_BIN}" ] ||
+  error_msg "Cannot find ${EXAMPLE_BIN} at ${EXAMPLE_BIN_DIR}"
+
+# Verify Gemm shapes file exists
+[ -f "${GEMM_SHAPES_FILE}" ] ||
+  error_msg "Cannot find gemm shapes file ${GEMM_SHAPES_FILE}"
+
+# Verify Gemm configs file exists
+[ -f "${GEMM_CONFIGS_FILE}" ] ||
+  error_msg "Cannot find gemm configs file ${GEMM_CONFIGS_FILE}"
+
+# Make sure existing benchmark outputs are not overwritten
+[ ! -d "${OUT_DIR}" ] ||
+  error_msg "Output directory ${OUT_DIR} already exists!"
+
+# Make output directory
+mkdir "${OUT_DIR}" || error_msg "Cannot create output directory ${OUT_DIR}"
+
+# Run selected strategy with all configurations (restart the built-in timer first); run's status becomes the exit status
+SECONDS=0
+run "${EXAMPLE_BIN}"
+# Main: Main script }}}
diff --git a/examples/gemm_tuner/cl_gemm_native.cpp b/examples/gemm_tuner/cl_gemm_native.cpp
new file mode 100644
index 0000000..0cacd82
--- /dev/null
+++ b/examples/gemm_tuner/cl_gemm_native.cpp
@@ -0,0 +1,239 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CL /* Needed by Utils.cpp to handle OpenCL exceptions properly */
+#error "This example needs to be built with -DARM_COMPUTE_CL"
+#endif /* ARM_COMPUTE_CL */
+
+#include "CommonGemmExampleOptions.h"
+#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/KernelDescriptors.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/runtime/CL/CLFunctions.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
+#include "arm_compute/runtime/CL/CLTuner.h"
+#include "tests/CL/Helper.h"
+#include "utils/Utils.h"
+#include "utils/command_line/CommandLineOptions.h"
+#include "utils/command_line/CommandLineParser.h"
+
+#include <cstdlib>
+
+using namespace arm_compute;
+using namespace utils;
+using namespace arm_compute::misc::shape_calculator;
+using namespace gemm_tuner;
+
+namespace
+{
+/** Tunable gemm configs of the native strategy; the in-class defaults stay in effect when the command line fails validation */
+struct GemmConfigs
+{
+    size_t m0{ 4 }; /**< Number of rows processed by the matrix multiplication (default: 4) */
+    size_t n0{ 4 }; /**< Number of columns processed by the matrix multiplication (default: 4) */
+    size_t k0{ 4 }; /**< Number of partial accumulations performed by the matrix multiplication (default: 4) */
+};
+
+/** Formatted output of the GemmConfigs type
+ *
+ * Emits one "name : value" line for each of m0, n0 and k0. Every line ends with
+ * std::endl, so the stream is flushed after each value.
+ *
+ * @param[out] os      Output stream.
+ * @param[in]  configs Tunable configurations to output
+ *
+ * @return Modified output stream.
+ */
+::std::ostream &operator<<(::std::ostream &os, const GemmConfigs &configs)
+{
+    os << "m0 : " << configs.m0 << std::endl;
+    os << "n0 : " << configs.n0 << std::endl;
+    os << "k0 : " << configs.k0 << std::endl;
+    return os;
+}
+
+/** Command line options for gemm configs */
+class GemmConfigOptions
+{
+public:
+    /** Constructor
+     * Registers m0, n0 and k0 as positional options (default 4 each). Explicit, so a parser never converts implicitly.
+     * @param[in,out] parser A parser on which "parse()" hasn't been called yet.
+     */
+    explicit GemmConfigOptions(CommandLineParser &parser)
+        : m0(parser.add_positional_option<SimpleOption<size_t>>("m0", 4)),
+          n0(parser.add_positional_option<SimpleOption<size_t>>("n0", 4)),
+          k0(parser.add_positional_option<SimpleOption<size_t>>("k0", 4))
+    {
+        m0->set_help("Number of rows processed by the matrix multiplication");
+        n0->set_help("Number of columns processed by the matrix multiplication");
+        k0->set_help("Number of partial accumulations performed by the matrix multiplication");
+    }
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    GemmConfigOptions(const GemmConfigOptions &) = delete;
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    GemmConfigOptions &operator=(const GemmConfigOptions &) = delete;
+    /** Allow instances of this class to be moved */
+    GemmConfigOptions(GemmConfigOptions &&) = default;
+    /** Allow instances of this class to be moved */
+    GemmConfigOptions &operator=(GemmConfigOptions &&) = default;
+    /** Default destructor */
+    ~GemmConfigOptions() = default;
+
+    SimpleOption<size_t> *m0; /**< Number of rows processed by the matrix multiplication option */
+    SimpleOption<size_t> *n0; /**< Number of columns processed by the matrix multiplication option */
+    SimpleOption<size_t> *k0; /**< Number of partial accumulations performed by the matrix multiplication option */
+};
+
+/** Copy the values held by the parsed gemm config options into a GemmConfigs structure
+ *
+ * @param[in] options Options to consume
+ *
+ * @return Structure containing the gemm configurations
+ */
+GemmConfigs consume_gemm_configs(const GemmConfigOptions &options)
+{
+    GemmConfigs parsed;
+    parsed.k0 = options.k0->value();
+    parsed.n0 = options.n0->value();
+    parsed.m0 = options.m0->value();
+    return parsed;
+}
+
+} // namespace
+// Function type synthesized from CLGEMMMatrixMultiplyNativeKernel via test::CLSynthetizeFunction
+using CLGEMMMatrixMultiplyNative = test::CLSynthetizeFunction<CLGEMMMatrixMultiplyNativeKernel>;
+/** Example that sets up, configures and runs CLGEMMMatrixMultiplyNative using parameters parsed from the command line */
+class CLGEMMMatrixMultiplyNativeExample : public Example
+{
+public:
+    bool do_setup(int argc, char **argv) override
+    {
+        // Fixed settings of this example; params and configs keep their defaults unless valid options are parsed below
+        const DataType            data_type = DataType::F32;
+        const float               alpha     = 1.0f;
+        const float               beta      = 0.0f;
+        const ActivationLayerInfo act_info  = ActivationLayerInfo();
+        CommonGemmExampleParams   params;
+        GemmConfigs               configs;
+
+        // Set up command line parser and options
+        CommandLineParser        parser;
+        CommonGemmExampleOptions param_options(parser);
+        GemmConfigOptions        config_options(parser);
+
+        // Parse command line options
+        parser.parse(argc, argv);
+        if(param_options.help->is_set() && param_options.help->value())
+        {
+            // Print help message and report that nothing was set up
+            parser.print_help(argv[0]);
+            return false;
+        }
+        if(!parser.validate())
+        {
+            // Invalid arguments. Use default parameters and configs
+            std::cerr << "Invalid arguments." << std::endl;
+            parser.print_help(argv[0]);
+            std::cerr << "Falling back to default parameters and configs" << std::endl;
+        }
+        else
+        {
+            // Get parameters and configs from command-line options
+            params  = consume_common_gemm_example_parameters(param_options);
+            configs = consume_gemm_configs(config_options);
+        }
+
+        // Print gemm parameters and configurations
+        std::cerr << "Gemm parameters:" << std::endl;
+        std::cerr << params << std::endl;
+        std::cerr << "Gemm configurations:" << std::endl;
+        std::cerr << configs << std::endl;
+        // Call the scheduler's default_init with this example's tuner
+        CLScheduler::get().default_init(&tuner);
+        // Describe lhs as (K, M, B), rhs as (N, K, B) and bias as (N, 1, B), all of data_type
+        lhs.allocator()->init(TensorInfo(TensorShape(params.K, params.M, params.B), 1, data_type));
+        rhs.allocator()->init(TensorInfo(TensorShape(params.N, params.K, params.B), 1, data_type));
+        bias.allocator()->init(TensorInfo(TensorShape(params.N, 1, params.B), 1, data_type));
+
+        GEMMLHSMatrixInfo lhs_info;
+        lhs_info.m0 = configs.m0;
+        lhs_info.k0 = configs.k0;
+
+        GEMMRHSMatrixInfo rhs_info;
+        rhs_info.n0 = configs.n0;
+        rhs_info.k0 = configs.k0;
+
+        GEMMKernelInfo kernel_info;
+        kernel_info.m                       = params.M;
+        kernel_info.n                       = params.N;
+        kernel_info.k                       = params.K;
+        kernel_info.depth_output_gemm3d     = 0;
+        kernel_info.reinterpret_input_as_3d = false;
+        kernel_info.broadcast_bias          = true;
+        kernel_info.activation_info         = act_info;
+
+        // Configure function. NOTE(review): dst is never init()-ed in this method; presumably configure() sets up its info - confirm
+        gemm.configure(&lhs, &rhs, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
+
+        // Allocate tensors
+        lhs.allocator()->allocate();
+        rhs.allocator()->allocate();
+        bias.allocator()->allocate();
+        dst.allocator()->allocate();
+
+        return true;
+    }
+    void do_run() override
+    {
+        // Execute the function
+        gemm.run();
+
+        // Make sure all the OpenCL jobs are done executing:
+        CLScheduler::get().sync();
+    }
+    /** No explicit teardown work is performed */
+    void do_teardown() override
+    {
+    }
+
+private:
+    CLTensor                   lhs{};   /**< Tensor passed to configure() as lhs */
+    CLTensor                   rhs{};   /**< Tensor passed to configure() as rhs */
+    CLTensor                   bias{};  /**< Tensor passed to configure() as bias */
+    CLTensor                   dst{};   /**< Tensor passed to configure() as destination */
+    CLTuner                    tuner{}; /**< Tuner handed to the CLScheduler in do_setup() */
+    CLGEMMMatrixMultiplyNative gemm{};  /**< Function configured in do_setup() and executed in do_run() */
+};
+
+/** Main program for gemm native test
+ * Forwards the command line to run_example for CLGEMMMatrixMultiplyNativeExample and returns its result.
+ * @param[in] argc Number of arguments
+ * @param[in] argv Arguments ( [optional] M, [optional] N, [optional] K, [optional] B, [optional] m0, [optional] n0, [optional] k0 )
+ */
+int main(int argc, char **argv)
+{
+    return run_example<CLGEMMMatrixMultiplyNativeExample>(argc, argv);
+}
diff --git a/examples/gemm_tuner/cl_gemm_reshaped.cpp b/examples/gemm_tuner/cl_gemm_reshaped.cpp
new file mode 100644
index 0000000..e579ed7
--- /dev/null
+++ b/examples/gemm_tuner/cl_gemm_reshaped.cpp
@@ -0,0 +1,305 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CL /* Needed by Utils.cpp to handle OpenCL exceptions properly */
+#error "This example needs to be built with -DARM_COMPUTE_CL"
+#endif /* ARM_COMPUTE_CL */
+
+#include "CommonGemmExampleOptions.h"
+#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
+#include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/KernelDescriptors.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/runtime/CL/CLFunctions.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
+#include "arm_compute/runtime/CL/CLTuner.h"
+#include "tests/CL/Helper.h"
+#include "utils/Utils.h"
+#include "utils/command_line/CommandLineOptions.h"
+#include "utils/command_line/CommandLineParser.h"
+
+#include <cstdlib>
+
+using namespace arm_compute;
+using namespace utils;
+using namespace arm_compute::misc::shape_calculator;
+using namespace gemm_tuner;
+
+namespace
+{
+/** Structure holding all tunable gemm configs specific to this example/strategy; defaults mirror the command line option defaults */
+struct GemmConfigs
+{
+    size_t m0{ 4 };                /**< Number of rows processed by the matrix multiplication */
+    size_t n0{ 4 };                /**< Number of columns processed by the matrix multiplication */
+    size_t k0{ 4 };                /**< Number of partial accumulations performed by the matrix multiplication */
+    size_t v0{ 1 };                /**< Number of vertical blocks of size (m0xk0) stored on the same output row */
+    size_t h0{ 1 };                /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */
+    bool   interleave_lhs{ true }; /**< Interleave lhs matrix */
+    bool   transpose_lhs{ false }; /**< Transpose lhs matrix. Defaults to the opposite of transpose_rhs: only opposite pairings are supported */
+    bool   interleave_rhs{ true }; /**< Interleave rhs matrix */
+    bool   transpose_rhs{ true };  /**< Transpose rhs matrix. */
+};
+
+/** Stream every tunable of a GemmConfigs as one "name : value" line
+ *
+ * @param[out] os      Output stream.
+ * @param[in]  configs Tunable configurations to output
+ *
+ * @return Modified output stream.
+ */
+::std::ostream &operator<<(::std::ostream &os, const GemmConfigs &configs)
+{
+    // Spell out a flag as the word "true" or "false"
+    const auto as_word = [](bool flag) { return flag ? "true" : "false"; };
+
+    os << "m0 : " << configs.m0 << std::endl
+       << "n0 : " << configs.n0 << std::endl
+       << "k0 : " << configs.k0 << std::endl
+       << "v0 : " << configs.v0 << std::endl
+       << "h0 : " << configs.h0 << std::endl;
+    os << "interleave_lhs : " << as_word(configs.interleave_lhs) << std::endl
+       << "transpose_lhs : " << as_word(configs.transpose_lhs) << std::endl
+       << "interleave_rhs : " << as_word(configs.interleave_rhs) << std::endl
+       << "transpose_rhs : " << as_word(configs.transpose_rhs) << std::endl;
+    return os;
+}
+
+/** Command line options for gemm configs */
+class GemmConfigOptions
+{
+public:
+    /** Constructor
+     * Registers every tunable as a positional option. Explicit, so a parser never converts implicitly.
+     * @param[in,out] parser A parser on which "parse()" hasn't been called yet.
+     */
+    explicit GemmConfigOptions(CommandLineParser &parser)
+        : m0(parser.add_positional_option<SimpleOption<size_t>>("m0", 4)),
+          n0(parser.add_positional_option<SimpleOption<size_t>>("n0", 4)),
+          k0(parser.add_positional_option<SimpleOption<size_t>>("k0", 4)),
+          v0(parser.add_positional_option<SimpleOption<size_t>>("v0", 1)),
+          h0(parser.add_positional_option<SimpleOption<size_t>>("h0", 1)),
+          interleave_lhs(parser.add_positional_option<SimpleOption<size_t>>("interleave_lhs", 1)),
+          interleave_rhs(parser.add_positional_option<SimpleOption<size_t>>("interleave_rhs", 1)),
+          transpose_rhs(parser.add_positional_option<SimpleOption<size_t>>("transpose_rhs", 1))
+    {
+        m0->set_help("Number of rows processed by the matrix multiplication");
+        n0->set_help("Number of columns processed by the matrix multiplication");
+        k0->set_help("Number of partial accumulations performed by the matrix multiplication");
+        v0->set_help("Number of vertical blocks of size (m0xk0) stored on the same output row");
+        h0->set_help("Number of horizontal blocks of size (k0xn0) stored on the same output row");
+        interleave_lhs->set_help("Interleave lhs matrix (1) / Do not interleave lhs matrix (0)");
+        interleave_rhs->set_help("Interleave rhs matrix (1) / Do not interleave rhs matrix (0)");
+        // FIXME: Currently we only support 2 variants of the gemm reshaped kernels in which transpose_lhs and
+        // transpose_rhs are the opposites of each other. In the future we may extend the kernels to include the other
+        // 2 variants (both transposed and none transposed)
+        transpose_rhs->set_help("Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)");
+    }
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    GemmConfigOptions(const GemmConfigOptions &) = delete;
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    GemmConfigOptions &operator=(const GemmConfigOptions &) = delete;
+    /** Allow instances of this class to be moved */
+    GemmConfigOptions(GemmConfigOptions &&) = default;
+    /** Allow instances of this class to be moved */
+    GemmConfigOptions &operator=(GemmConfigOptions &&) = default;
+    /** Default destructor */
+    ~GemmConfigOptions() = default;
+
+    SimpleOption<size_t> *m0;             /**< Number of rows processed by the matrix multiplication option */
+    SimpleOption<size_t> *n0;             /**< Number of columns processed by the matrix multiplication option */
+    SimpleOption<size_t> *k0;             /**< Number of partial accumulations performed by the matrix multiplication option */
+    SimpleOption<size_t> *v0;             /**< Number of vertical blocks of size (m0xk0) stored on the same output row option */
+    SimpleOption<size_t> *h0;             /**< Number of horizontal blocks of size (k0xn0) stored on the same output row option */
+    SimpleOption<size_t> *interleave_lhs; /**< Interleave lhs matrix option (1 enable; 0 disable) */
+    SimpleOption<size_t> *interleave_rhs; /**< Interleave rhs matrix option (1 enable; 0 disable) */
+    // FIXME: Currently we only support 2 variants of the gemm reshaped kernels in which transpose_lhs and
+    // transpose_rhs are the opposites of each other. In the future we may extend the kernels to include the other
+    // 2 variants (both transposed and none transposed)
+    SimpleOption<size_t> *transpose_rhs; /**< Transpose rhs matrix option (1 enable; 0 disable). Also set the lhs matrix transpose option to the opposite. */
+};
+
+/** Copy the values held by the parsed gemm config options into a GemmConfigs structure
+ *
+ * @param[in] options Options to consume
+ *
+ * @return Structure containing the gemm configurations
+ */
+GemmConfigs consume_gemm_configs(const GemmConfigOptions &options)
+{
+    // FIXME: Only the 2 gemm reshaped kernel variants with opposite transpose_lhs / transpose_rhs are supported, so a
+    // single option drives both flags. The other 2 variants (both transposed and none transposed) may be added later.
+    const bool rhs_transposed = options.transpose_rhs->value() != 0;
+    GemmConfigs parsed;
+    parsed.m0             = options.m0->value();
+    parsed.n0             = options.n0->value();
+    parsed.k0             = options.k0->value();
+    parsed.v0             = options.v0->value();
+    parsed.h0             = options.h0->value();
+    parsed.interleave_lhs = options.interleave_lhs->value() != 0;
+    parsed.transpose_lhs  = !rhs_transposed;
+    parsed.interleave_rhs = options.interleave_rhs->value() != 0;
+    parsed.transpose_rhs  = rhs_transposed;
+    return parsed;
+}
+
+} // namespace
+// Function type synthesized from CLGEMMReshapeLHSMatrixKernel via test::CLSynthetizeFunction
+using CLGEMMReshapeLHSMatrix = test::CLSynthetizeFunction<CLGEMMReshapeLHSMatrixKernel>;
+// Function type synthesized from CLGEMMMatrixMultiplyReshapedKernel via test::CLSynthetizeFunction
+using CLGEMMMatrixMultiplyReshaped = test::CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedKernel>;
+
+class CLGEMMMatrixMultiplyReshapedExample : public Example
+{
+public:
+ bool do_setup(int argc, char **argv) override
+ {
+ // Default parameters
+ const DataType data_type = DataType::F32;
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+ const ActivationLayerInfo act_info = ActivationLayerInfo();
+ CommonGemmExampleParams params;
+ GemmConfigs configs;
+
+ // Set up command line parser and options
+ CommandLineParser parser;
+ CommonGemmExampleOptions param_options(parser);
+ GemmConfigOptions config_options(parser);
+
+ // Parse command line options
+ parser.parse(argc, argv);
+ if(param_options.help->is_set() && param_options.help->value())
+ {
+ // Print help message
+ parser.print_help(argv[0]);
+ return false;
+ }
+ if(!parser.validate())
+ {
+ // Invalid arguments. Use default parameters and configs
+ std::cerr << "Invalid arguments." << std::endl;
+ parser.print_help(argv[0]);
+ std::cerr << "Falling back to default parameters and configs" << std::endl;
+ }
+ else
+ {
+ // Get parameters and configs from command-line options
+ params = consume_common_gemm_example_parameters(param_options);
+ configs = consume_gemm_configs(config_options);
+ }
+
+ // Print gemm parameters and configurations
+ std::cerr << "Gemm parameters:" << std::endl;
+ std::cerr << params << std::endl;
+ std::cerr << "Gemm configurations:" << std::endl;
+ std::cerr << configs << std::endl;
+
+ CLScheduler::get().default_init(&tuner);
+
+ lhs.allocator()->init(TensorInfo(TensorShape(params.K, params.M, params.B), 1, data_type));
+ rhs.allocator()->init(TensorInfo(TensorShape(params.N, params.K, params.B), 1, data_type));
+ bias.allocator()->init(TensorInfo(TensorShape(params.N, 1, params.B), 1, data_type));
+
+ GEMMLHSMatrixInfo lhs_info;
+ lhs_info.m0 = configs.m0;
+ lhs_info.k0 = configs.k0;
+ lhs_info.v0 = configs.v0;
+ lhs_info.interleave = configs.interleave_lhs;
+ lhs_info.transpose = configs.transpose_lhs;
+
+ GEMMRHSMatrixInfo rhs_info;
+ rhs_info.n0 = configs.n0;
+ rhs_info.k0 = configs.k0;
+ rhs_info.h0 = configs.h0;
+ rhs_info.interleave = configs.interleave_rhs;
+ rhs_info.transpose = configs.transpose_rhs;
+
+ GEMMKernelInfo kernel_info;
+ kernel_info.m = params.M;
+ kernel_info.n = params.N;
+ kernel_info.k = params.K;
+ kernel_info.depth_output_gemm3d = 0;
+ kernel_info.reinterpret_input_as_3d = false;
+ kernel_info.broadcast_bias = true;
+ kernel_info.activation_info = act_info;
+
+ // Initialise lhs_reshaped tensor info
+ auto_init_if_empty(*lhs_reshaped.info(), lhs.info()->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*lhs.info(), lhs_info)));
+
+ // Initialise rhs_reshaped tensor info
+ auto_init_if_empty(*rhs_reshaped.info(), rhs.info()->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*rhs.info(), rhs_info)));
+
+ // Configure reshape lhs function
+ reshape_lhs.configure(&lhs, &lhs_reshaped, lhs_info);
+
+ // Configure function
+ gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
+
+ // Allocate tensors
+ lhs.allocator()->allocate();
+ rhs.allocator()->allocate();
+ lhs_reshaped.allocator()->allocate();
+ rhs_reshaped.allocator()->allocate();
+ bias.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ return true;
+ }
+ void do_run() override
+ {
+ // Execute the function
+ reshape_lhs.run();
+ gemm.run();
+
+ // Make sure all the OpenCL jobs are done executing:
+ CLScheduler::get().sync();
+ }
+
+ void do_teardown() override
+ {
+ }
+
+private:
+ CLTensor lhs{};
+ CLTensor rhs{};
+ CLTensor lhs_reshaped{};
+ CLTensor rhs_reshaped{};
+ CLTensor bias{};
+ CLTensor dst{};
+ CLTuner tuner{};
+ CLGEMMReshapeLHSMatrix reshape_lhs{};
+ CLGEMMMatrixMultiplyReshaped gemm{};
+};
+
+/** Entry point of the gemm reshaped test: runs CLGEMMMatrixMultiplyReshapedExample and forwards its result
+ * @param[in] argc Number of arguments
+ * @param[in] argv Arguments ( [optional] M, [optional] N, [optional] K, [optional] B, [optional] m0, [optional] n0, [optional] k0, [optional] v0, [optional] h0, [optional] interleave_lhs, [optional] interleave_rhs, [optional] transpose_rhs )
+ */
+int main(int argc, char **argv)
+{
+    const int exit_code = run_example<CLGEMMMatrixMultiplyReshapedExample>(argc, argv);
+    return exit_code;
+}
diff --git a/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp b/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp
new file mode 100644
index 0000000..0d161aa
--- /dev/null
+++ b/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp
@@ -0,0 +1,265 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CL /* Needed by Utils.cpp to handle OpenCL exceptions properly */
+#error "This example needs to be built with -DARM_COMPUTE_CL"
+#endif /* ARM_COMPUTE_CL */
+
+#include "CommonGemmExampleOptions.h"
+#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/KernelDescriptors.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/runtime/CL/CLFunctions.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
+#include "arm_compute/runtime/CL/CLTuner.h"
+#include "tests/CL/Helper.h"
+#include "utils/Utils.h"
+#include "utils/command_line/CommandLineOptions.h"
+#include "utils/command_line/CommandLineParser.h"
+
+#include <cstdlib>
+
+using namespace arm_compute;
+using namespace utils;
+using namespace arm_compute::misc::shape_calculator;
+using namespace gemm_tuner;
+
+namespace
+{
+/** Tunable gemm configurations specific to this example/strategy, together with their default values */
+struct GemmConfigs
+{
+    size_t m0 = 4;                /**< Rows processed by the matrix multiplication */
+    size_t n0 = 4;                /**< Columns processed by the matrix multiplication */
+    size_t k0 = 4;                /**< Partial accumulations performed by the matrix multiplication */
+    size_t h0 = 1;                /**< Horizontal blocks of size (k0xn0) stored on the same output row */
+    bool   interleave_rhs = true; /**< Whether the rhs matrix is interleaved */
+    bool   transpose_rhs  = true; /**< Whether the rhs matrix is transposed */
+};
+
+/** Writes a GemmConfigs to a stream, one "name : value" line per tunable configuration
+ *
+ * @param[out] os      Stream receiving the text.
+ * @param[in]  configs Tunable configurations to print
+ *
+ * @return @p os, so that further insertions can be chained.
+ */
+::std::ostream &operator<<(::std::ostream &os, const GemmConfigs &configs)
+{
+    // Booleans are printed as the words "true" / "false"
+    const auto as_word = [](bool flag) { return flag ? "true" : "false"; };
+
+    os << "m0 : " << configs.m0 << std::endl
+       << "n0 : " << configs.n0 << std::endl
+       << "k0 : " << configs.k0 << std::endl
+       << "h0 : " << configs.h0 << std::endl
+       << "interleave_rhs : " << as_word(configs.interleave_rhs) << std::endl
+       << "transpose_rhs : " << as_word(configs.transpose_rhs) << std::endl;
+    return os;
+}
+
+/** Command line options for the tunable gemm configs; every member points at an option registered on the parser */
+class GemmConfigOptions
+{
+public:
+    /** Registers the positional config options on the parser and sets their help texts
+     *  (members are initialised in declaration order, which fixes the positional order: m0 n0 k0 h0 interleave_rhs transpose_rhs)
+     * @param[in,out] parser A parser on which "parse()" hasn't been called yet.
+     */
+    explicit GemmConfigOptions(CommandLineParser &parser)
+        : m0(parser.add_positional_option<SimpleOption<size_t>>("m0", 4)),
+          n0(parser.add_positional_option<SimpleOption<size_t>>("n0", 4)),
+          k0(parser.add_positional_option<SimpleOption<size_t>>("k0", 4)),
+          h0(parser.add_positional_option<SimpleOption<size_t>>("h0", 1)),
+          interleave_rhs(parser.add_positional_option<SimpleOption<size_t>>("interleave_rhs", 1)),
+          transpose_rhs(parser.add_positional_option<SimpleOption<size_t>>("transpose_rhs", 1))
+    {
+        m0->set_help("Number of rows processed by the matrix multiplication");
+        n0->set_help("Number of columns processed by the matrix multiplication");
+        k0->set_help("Number of partial accumulations performed by the matrix multiplication");
+        h0->set_help("Number of horizontal blocks of size (k0xn0) stored on the same output row");
+        interleave_rhs->set_help("Interleave rhs matrix (1) / Do not interleave rhs matrix (0)");
+        transpose_rhs->set_help("Transpose rhs matrix (1) / Do not transpose rhs matrix (0)");
+    }
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    GemmConfigOptions(const GemmConfigOptions &) = delete;
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    GemmConfigOptions &operator=(const GemmConfigOptions &) = delete;
+    /** Allow instances of this class to be moved */
+    GemmConfigOptions(GemmConfigOptions &&) = default;
+    /** Allow instances of this class to be moved */
+    GemmConfigOptions &operator=(GemmConfigOptions &&) = default;
+    /** Default destructor (the option objects are not deleted by this class) */
+    ~GemmConfigOptions() = default;
+
+    SimpleOption<size_t> *m0;             /**< Number of rows processed by the matrix multiplication option */
+    SimpleOption<size_t> *n0;             /**< Number of columns processed by the matrix multiplication option */
+    SimpleOption<size_t> *k0;             /**< Number of partial accumulations performed by the matrix multiplication option */
+    SimpleOption<size_t> *h0;             /**< Number of horizontal blocks of size (k0xn0) stored on the same output row option */
+    SimpleOption<size_t> *interleave_rhs; /**< Interleave rhs matrix option (1 enable; 0 disable) */
+    SimpleOption<size_t> *transpose_rhs;  /**< Transpose rhs matrix option (1 enable; 0 disable) */
+};
+
+/** Build the GemmConfigs structure from the parsed gemm configuration options
+ *
+ * @param[in] options Options to consume; block sizes are copied as they are, 0/1 flags become bools
+ * @return Structure containing the gemm configurations
+ */
+GemmConfigs consume_gemm_configs(const GemmConfigOptions &options)
+{
+    const auto enabled = [](SimpleOption<size_t> *flag) { return flag->value() != 0; }; // any non-zero value enables a flag
+    GemmConfigs parsed;
+    parsed.m0             = options.m0->value();
+    parsed.n0             = options.n0->value();
+    parsed.k0             = options.k0->value();
+    parsed.h0             = options.h0->value();
+    parsed.interleave_rhs = enabled(options.interleave_rhs);
+    parsed.transpose_rhs  = enabled(options.transpose_rhs);
+    return parsed;
+}
+
+} // namespace
+// Wrap CLGEMMMatrixMultiplyReshapedOnlyRHSKernel in a function object (configure()/run()) using the test helper CLSynthetizeFunction
+using CLGEMMMatrixMultiplyReshapedOnlyRHS = test::CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>;
+/** Example that runs the gemm kernel taking a reshaped rhs matrix, with command-line supplied parameters and configs */
+class CLGEMMMatrixMultiplyReshapedOnlyRHSExample : public Example
+{
+public:
+ bool do_setup(int argc, char **argv) override
+ {
+ // Fixed settings of this example (not read from the command line)
+ const DataType data_type = DataType::F32;
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+ const ActivationLayerInfo act_info = ActivationLayerInfo();
+ CommonGemmExampleParams params;
+ GemmConfigs configs;
+
+ // Set up command line parser and options (common gemm options are registered before the config options)
+ CommandLineParser parser;
+ CommonGemmExampleOptions param_options(parser);
+ GemmConfigOptions config_options(parser);
+
+ // Parse command line options
+ parser.parse(argc, argv);
+ if(param_options.help->is_set() && param_options.help->value())
+ {
+ // Help requested: print the help message and make setup return false
+ parser.print_help(argv[0]);
+ return false;
+ }
+ if(!parser.validate())
+ {
+ // Invalid arguments: params and configs keep their default-constructed values
+ std::cerr << "Invalid arguments." << std::endl;
+ parser.print_help(argv[0]);
+ std::cerr << "Falling back to default parameters and configs" << std::endl;
+ }
+ else
+ {
+ // Get parameters and configs from command-line options
+ params = consume_common_gemm_example_parameters(param_options);
+ configs = consume_gemm_configs(config_options);
+ }
+
+ // Print the gemm parameters and configurations in use (written to std::cerr)
+ std::cerr << "Gemm parameters:" << std::endl;
+ std::cerr << params << std::endl;
+ std::cerr << "Gemm configurations:" << std::endl;
+ std::cerr << configs << std::endl;
+
+ CLScheduler::get().default_init(&tuner); // The scheduler is initialised with this example's CLTuner
+
+ lhs.allocator()->init(TensorInfo(TensorShape(params.K, params.M, params.B), 1, data_type)); // lhs dimensions: (K, M, B)
+ rhs.allocator()->init(TensorInfo(TensorShape(params.N, params.K, params.B), 1, data_type)); // rhs dimensions: (N, K, B)
+ bias.allocator()->init(TensorInfo(TensorShape(params.N, 1, params.B), 1, data_type)); // bias dimensions: (N, 1, B)
+
+ GEMMLHSMatrixInfo lhs_info; // Only m0 and k0 are set here; the remaining fields keep their defaults
+ lhs_info.m0 = configs.m0;
+ lhs_info.k0 = configs.k0;
+
+ GEMMRHSMatrixInfo rhs_info;
+ rhs_info.n0 = configs.n0;
+ rhs_info.k0 = configs.k0; // Same k0 as in lhs_info
+ rhs_info.h0 = configs.h0;
+ rhs_info.interleave = configs.interleave_rhs;
+ rhs_info.transpose = configs.transpose_rhs;
+
+ GEMMKernelInfo kernel_info;
+ kernel_info.m = params.M;
+ kernel_info.n = params.N;
+ kernel_info.k = params.K;
+ kernel_info.depth_output_gemm3d = 0;
+ kernel_info.reinterpret_input_as_3d = false;
+ kernel_info.broadcast_bias = true;
+ kernel_info.activation_info = act_info;
+
+ // Initialise rhs_reshaped tensor info (NOTE(review): no rhs reshape function is configured in this class - presumably intentional, confirm)
+ auto_init_if_empty(*rhs_reshaped.info(), rhs.info()->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*rhs.info(), rhs_info)));
+
+ // Configure the matrix multiplication function on the plain lhs and the reshaped rhs tensors
+ gemm.configure(&lhs, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
+
+ // Allocate tensors (done after the function has been configured)
+ lhs.allocator()->allocate();
+ rhs.allocator()->allocate();
+ rhs_reshaped.allocator()->allocate();
+ bias.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ return true;
+ }
+ void do_run() override
+ {
+ // Execute the matrix multiplication function
+ gemm.run();
+
+ // Make sure all the OpenCL jobs are done executing:
+ CLScheduler::get().sync();
+ }
+
+ void do_teardown() override
+ {
+ }
+
+private:
+ CLTensor lhs{}; // Initialised as (K, M, B) in do_setup(); passed to gemm.configure() without reshaping
+ CLTensor rhs{}; // Initialised as (N, K, B) in do_setup(); only used to derive the rhs_reshaped info
+ CLTensor rhs_reshaped{}; // Passed to gemm.configure(); no function in this class is configured to fill it
+ CLTensor bias{}; // Initialised as (N, 1, B) in do_setup()
+ CLTensor dst{}; // Passed to gemm.configure() after bias
+ CLTuner tuner{}; // Handed to CLScheduler::default_init() in do_setup()
+ CLGEMMMatrixMultiplyReshapedOnlyRHS gemm{};
+};
+
+/** Entry point of the gemm reshaped rhs only test: runs CLGEMMMatrixMultiplyReshapedOnlyRHSExample and forwards its result
+ * @param[in] argc Number of arguments
+ * @param[in] argv Arguments ( [optional] M, [optional] N, [optional] K, [optional] B, [optional] m0, [optional] n0, [optional] k0, [optional] h0, [optional] interleave_rhs, [optional] transpose_rhs )
+ */
+int main(int argc, char **argv)
+{
+    const int exit_code = run_example<CLGEMMMatrixMultiplyReshapedOnlyRHSExample>(argc, argv);
+    return exit_code;
+}
diff --git a/examples/graph_alexnet.cpp b/examples/graph_alexnet.cpp
index 88e0d7e..79d02f6 100644
--- a/examples/graph_alexnet.cpp
+++ b/examples/graph_alexnet.cpp
@@ -43,6 +43,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_deepspeech_v0_4_1.cpp b/examples/graph_deepspeech_v0_4_1.cpp
index a69d235..d2a4832 100644
--- a/examples/graph_deepspeech_v0_4_1.cpp
+++ b/examples/graph_deepspeech_v0_4_1.cpp
@@ -45,6 +45,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
@@ -141,61 +142,56 @@
get_weights_accessor(data_path, "ones.npy"))
.set_name("add_y");
- // TODO(COMPMID-2103): Use sub stream for FC weights and bias in LSTM cells
// Create LSTM Fully Connected weights and bias descriptors
- //TensorDescriptor lstm_weights_descriptor = TensorDescriptor(TensorShape(4096U, 8192U), common_params.data_type).set_layout(common_params.data_layout);
- //TensorDescriptor lstm_bias_descriptor = TensorDescriptor(TensorShape(8192U), common_params.data_type).set_layout(common_params.data_layout);
- //SubStream lstm_fc_weights(graph);
- //SubStream lstm_fc_bias(graph);
-
- //lstm_fc_weights << InputLayer(lstm_weights_descriptor,
- // get_weights_accessor(data_path, "rnn_lstm_cell_kernel_transpose.npy", weights_layout))
- // .set_name("h5/transpose");
- //lstm_fc_bias << InputLayer(lstm_bias_descriptor,
- // get_weights_accessor(data_path, "rnn_lstm_cell_MatMul_bias.npy"))
- // .set_name("MatMul_3_bias");
+ TensorDescriptor lstm_weights_descriptor = TensorDescriptor(TensorShape(4096U, 8192U), common_params.data_type).set_layout(common_params.data_layout);
+ TensorDescriptor lstm_bias_descriptor = TensorDescriptor(TensorShape(8192U), common_params.data_type).set_layout(common_params.data_layout);
+ SubStream lstm_fc_weights(graph);
+ SubStream lstm_fc_bias(graph);
+ lstm_fc_weights << ConstantLayer(lstm_weights_descriptor,
+ get_weights_accessor(data_path, "rnn_lstm_cell_kernel_transpose.npy", weights_layout))
+ .set_name("h5/transpose");
+ lstm_fc_bias << ConstantLayer(lstm_bias_descriptor,
+ get_weights_accessor(data_path, "rnn_lstm_cell_MatMul_bias.npy"))
+ .set_name("MatMul_3_bias");
// LSTM Block
- std::pair<SubStream, SubStream> new_state_1 = add_lstm_cell(data_path, unstack_nid, 0, previous_state, previous_state, add_y);
- std::pair<SubStream, SubStream> new_state_2 = add_lstm_cell(data_path, unstack_nid, 1, new_state_1.first, new_state_1.second, add_y);
- std::pair<SubStream, SubStream> new_state_3 = add_lstm_cell(data_path, unstack_nid, 2, new_state_2.first, new_state_2.second, add_y);
- std::pair<SubStream, SubStream> new_state_4 = add_lstm_cell(data_path, unstack_nid, 3, new_state_3.first, new_state_3.second, add_y);
- std::pair<SubStream, SubStream> new_state_5 = add_lstm_cell(data_path, unstack_nid, 4, new_state_4.first, new_state_4.second, add_y);
- std::pair<SubStream, SubStream> new_state_6 = add_lstm_cell(data_path, unstack_nid, 5, new_state_5.first, new_state_5.second, add_y);
- std::pair<SubStream, SubStream> new_state_7 = add_lstm_cell(data_path, unstack_nid, 6, new_state_6.first, new_state_6.second, add_y);
- std::pair<SubStream, SubStream> new_state_8 = add_lstm_cell(data_path, unstack_nid, 7, new_state_7.first, new_state_7.second, add_y);
- std::pair<SubStream, SubStream> new_state_9 = add_lstm_cell(data_path, unstack_nid, 8, new_state_8.first, new_state_8.second, add_y);
- std::pair<SubStream, SubStream> new_state_10 = add_lstm_cell(data_path, unstack_nid, 9, new_state_9.first, new_state_9.second, add_y);
- std::pair<SubStream, SubStream> new_state_11 = add_lstm_cell(data_path, unstack_nid, 10, new_state_10.first, new_state_10.second, add_y);
- std::pair<SubStream, SubStream> new_state_12 = add_lstm_cell(data_path, unstack_nid, 11, new_state_11.first, new_state_11.second, add_y);
- std::pair<SubStream, SubStream> new_state_13 = add_lstm_cell(data_path, unstack_nid, 12, new_state_12.first, new_state_12.second, add_y);
- std::pair<SubStream, SubStream> new_state_14 = add_lstm_cell(data_path, unstack_nid, 13, new_state_13.first, new_state_13.second, add_y);
- std::pair<SubStream, SubStream> new_state_15 = add_lstm_cell(data_path, unstack_nid, 14, new_state_14.first, new_state_14.second, add_y);
- std::pair<SubStream, SubStream> new_state_16 = add_lstm_cell(data_path, unstack_nid, 15, new_state_15.first, new_state_15.second, add_y);
+ std::pair<SubStream, SubStream> new_state_1 = add_lstm_cell(unstack_nid, 0, previous_state, previous_state, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_2 = add_lstm_cell(unstack_nid, 1, new_state_1.first, new_state_1.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_3 = add_lstm_cell(unstack_nid, 2, new_state_2.first, new_state_2.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_4 = add_lstm_cell(unstack_nid, 3, new_state_3.first, new_state_3.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_5 = add_lstm_cell(unstack_nid, 4, new_state_4.first, new_state_4.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_6 = add_lstm_cell(unstack_nid, 5, new_state_5.first, new_state_5.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_7 = add_lstm_cell(unstack_nid, 6, new_state_6.first, new_state_6.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_8 = add_lstm_cell(unstack_nid, 7, new_state_7.first, new_state_7.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_9 = add_lstm_cell(unstack_nid, 8, new_state_8.first, new_state_8.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_10 = add_lstm_cell(unstack_nid, 9, new_state_9.first, new_state_9.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_11 = add_lstm_cell(unstack_nid, 10, new_state_10.first, new_state_10.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_12 = add_lstm_cell(unstack_nid, 11, new_state_11.first, new_state_11.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_13 = add_lstm_cell(unstack_nid, 12, new_state_12.first, new_state_12.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_14 = add_lstm_cell(unstack_nid, 13, new_state_13.first, new_state_13.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_15 = add_lstm_cell(unstack_nid, 14, new_state_14.first, new_state_14.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_16 = add_lstm_cell(unstack_nid, 15, new_state_15.first, new_state_15.second, add_y, lstm_fc_weights, lstm_fc_bias);
- if(n_steps > 1)
- {
- // Concatenate new states on height
- const int axis = 1;
- graph << StackLayer(axis,
- std::move(new_state_1.second),
- std::move(new_state_2.second),
- std::move(new_state_3.second),
- std::move(new_state_4.second),
- std::move(new_state_5.second),
- std::move(new_state_6.second),
- std::move(new_state_7.second),
- std::move(new_state_8.second),
- std::move(new_state_9.second),
- std::move(new_state_10.second),
- std::move(new_state_11.second),
- std::move(new_state_12.second),
- std::move(new_state_13.second),
- std::move(new_state_14.second),
- std::move(new_state_15.second),
- std::move(new_state_16.second))
- .set_name("concat");
- }
+ // Concatenate new states on height
+ const int axis = 1;
+ graph << StackLayer(axis,
+ std::move(new_state_1.second),
+ std::move(new_state_2.second),
+ std::move(new_state_3.second),
+ std::move(new_state_4.second),
+ std::move(new_state_5.second),
+ std::move(new_state_6.second),
+ std::move(new_state_7.second),
+ std::move(new_state_8.second),
+ std::move(new_state_9.second),
+ std::move(new_state_10.second),
+ std::move(new_state_11.second),
+ std::move(new_state_12.second),
+ std::move(new_state_13.second),
+ std::move(new_state_14.second),
+ std::move(new_state_15.second),
+ std::move(new_state_16.second))
+ .set_name("concat");
graph << FullyConnectedLayer(
2048U,
@@ -245,15 +241,13 @@
return Status{};
}
- std::pair<SubStream, SubStream> add_lstm_cell(const std::string &data_path,
- NodeID unstack_nid,
+ std::pair<SubStream, SubStream> add_lstm_cell(NodeID unstack_nid,
unsigned int unstack_idx,
SubStream previous_state_c,
SubStream previous_state_h,
- SubStream add_y)
- // TODO(COMPMID-2103): Use sub streams for FC weights and bias
- //SubStream lstm_fc_weights,
- //SubStream lstm_fc_bias)
+ SubStream add_y,
+ SubStream lstm_fc_weights,
+ SubStream lstm_fc_bias)
{
const std::string cell_name("rnn/lstm_cell_" + std::to_string(unstack_idx));
const DataLayoutDimension concat_dim = (common_params.data_layout == DataLayout::NHWC) ? DataLayoutDimension::CHANNEL : DataLayoutDimension::WIDTH;
@@ -268,8 +262,8 @@
graph << FullyConnectedLayer(
8192U,
- get_weights_accessor(data_path, "rnn_lstm_cell_kernel_transpose.npy", DataLayout::NHWC),
- get_weights_accessor(data_path, "rnn_lstm_cell_MatMul_bias.npy"))
+ lstm_fc_weights,
+ lstm_fc_bias)
.set_name(cell_name + "/BiasAdd");
// Split Layer
diff --git a/examples/graph_googlenet.cpp b/examples/graph_googlenet.cpp
index 185680a..b768d28 100644
--- a/examples/graph_googlenet.cpp
+++ b/examples/graph_googlenet.cpp
@@ -43,6 +43,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_inception_resnet_v1.cpp b/examples/graph_inception_resnet_v1.cpp
index 64c35e1..89e44ed 100644
--- a/examples/graph_inception_resnet_v1.cpp
+++ b/examples/graph_inception_resnet_v1.cpp
@@ -56,6 +56,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_inception_resnet_v2.cpp b/examples/graph_inception_resnet_v2.cpp
index 921fada..424884f 100644
--- a/examples/graph_inception_resnet_v2.cpp
+++ b/examples/graph_inception_resnet_v2.cpp
@@ -43,6 +43,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_inception_v3.cpp b/examples/graph_inception_v3.cpp
index 0a1e312..1de6a5f 100644
--- a/examples/graph_inception_v3.cpp
+++ b/examples/graph_inception_v3.cpp
@@ -43,6 +43,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
@@ -84,8 +85,8 @@
"/cnn_data/inceptionv3_model/Conv2d_1a_3x3_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path,
"/cnn_data/inceptionv3_model/Conv2d_1a_3x3_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f), get_weights_accessor(data_path,
- "/cnn_data/inceptionv3_model/Conv2d_1a_3x3_BatchNorm_beta.npy"),
+ nullptr, get_weights_accessor(data_path,
+ "/cnn_data/inceptionv3_model/Conv2d_1a_3x3_BatchNorm_beta.npy"),
0.001f)
.set_name("Conv2d_1a_3x3/BatchNorm/batchnorm")
<< ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name("Conv2d_1a_3x3/Relu")
@@ -97,8 +98,8 @@
"/cnn_data/inceptionv3_model/Conv2d_2a_3x3_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path,
"/cnn_data/inceptionv3_model/Conv2d_2a_3x3_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f), get_weights_accessor(data_path,
- "/cnn_data/inceptionv3_model/Conv2d_2a_3x3_BatchNorm_beta.npy"),
+ nullptr, get_weights_accessor(data_path,
+ "/cnn_data/inceptionv3_model/Conv2d_2a_3x3_BatchNorm_beta.npy"),
0.001f)
.set_name("Conv2d_2a_3x3/BatchNorm/batchnorm")
<< ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name("Conv2d_2a_3x3/Relu")
@@ -111,8 +112,8 @@
"/cnn_data/inceptionv3_model/Conv2d_2b_3x3_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path,
"/cnn_data/inceptionv3_model/Conv2d_2b_3x3_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f), get_weights_accessor(data_path,
- "/cnn_data/inceptionv3_model/Conv2d_2b_3x3_BatchNorm_beta.npy"),
+ nullptr, get_weights_accessor(data_path,
+ "/cnn_data/inceptionv3_model/Conv2d_2b_3x3_BatchNorm_beta.npy"),
0.001f)
.set_name("Conv2d_2b_3x3/BatchNorm/batchnorm")
<< ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name("Conv2d_2b_3x3/Relu")
@@ -127,8 +128,8 @@
"/cnn_data/inceptionv3_model/Conv2d_3b_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path,
"/cnn_data/inceptionv3_model/Conv2d_3b_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f), get_weights_accessor(data_path,
- "/cnn_data/inceptionv3_model/Conv2d_3b_1x1_BatchNorm_beta.npy"),
+ nullptr, get_weights_accessor(data_path,
+ "/cnn_data/inceptionv3_model/Conv2d_3b_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name("Conv2d_3b_1x1/BatchNorm/batchnorm")
<< ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name("Conv2d_3b_1x1/Relu")
@@ -141,8 +142,8 @@
"/cnn_data/inceptionv3_model/Conv2d_4a_3x3_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path,
"/cnn_data/inceptionv3_model/Conv2d_4a_3x3_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f), get_weights_accessor(data_path,
- "/cnn_data/inceptionv3_model/Conv2d_4a_3x3_BatchNorm_beta.npy"),
+ nullptr, get_weights_accessor(data_path,
+ "/cnn_data/inceptionv3_model/Conv2d_4a_3x3_BatchNorm_beta.npy"),
0.001f)
.set_name("Conv2d_4a_3x3/BatchNorm/batchnorm")
<< ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name("Conv2d_4a_3x3/Relu")
@@ -248,7 +249,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_0a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_0a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_0a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_0/Conv2d_0a_1x1/BatchNorm/batchnorm")
@@ -264,7 +265,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d" + conv_id0 + "1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d" + conv_id0 + "1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d" + conv_id0 + "1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d" + conv_id0 + "1x1/BatchNorm/batchnorm")
@@ -278,7 +279,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv" + conv_id1 + "5x5_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv" + conv_id1 + "5x5_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv" + conv_id1 + "5x5_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d" + conv_id1 + "5x5/BatchNorm/batchnorm")
@@ -294,7 +295,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_2/Conv2d_0a_1x1/BatchNorm/batchnorm")
@@ -308,7 +309,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0b_3x3_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0b_3x3_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0b_3x3_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_2/Conv2d_0b_3x3/BatchNorm/batchnorm")
@@ -322,7 +323,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0c_3x3_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0c_3x3_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0c_3x3_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_2/Conv2d_0c_3x3/BatchNorm/batcnorm")
@@ -339,7 +340,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_3_Conv2d_0b_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_3_Conv2d_0b_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_3_Conv2d_0b_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_3/Conv2d_0b_1x1/BatchNorm/batchnorm")
@@ -363,7 +364,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_1a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_1a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_1a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_0/Conv2d_1a_1x1/BatchNorm/batchnorm")
@@ -379,7 +380,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d_0a_1x1/BatchNorm/batchnorm")
@@ -393,7 +394,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0b_3x3_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0b_3x3_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0b_3x3_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d_0b_3x3/BatchNorm/batchnorm")
@@ -407,7 +408,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_1a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_1a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_1a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d_1a_1x1/BatchNorm/batchnorm")
@@ -436,7 +437,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_0a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_0a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_0a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_0/Conv2d_0a_1x1/BatchNorm/batchnorm")
@@ -452,7 +453,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d_0a_1x1/BatchNorm/batchnorm")
@@ -466,7 +467,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0b_1x7_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0b_1x7_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0b_1x7_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d_0b_1x7/BatchNorm/batchnorm")
@@ -480,7 +481,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0c_7x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0c_7x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0c_7x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d_0c_7x1/BatchNorm/batchnorm")
@@ -496,7 +497,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_2/Conv2d_0a_1x1/BatchNorm/batchnorm")
@@ -510,7 +511,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0b_7x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0b_7x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0b_7x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_2/Conv2d_0b_7x1/BatchNorm/batchnorm")
@@ -524,7 +525,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0c_1x7_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0c_1x7_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0c_1x7_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_2/Conv2d_0c_1x7/BatchNorm/batchnorm")
@@ -538,7 +539,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0d_7x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0d_7x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0d_7x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_2/Conv2d_0d_7x1/BatchNorm/batchnorm")
@@ -552,7 +553,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0e_1x7_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0e_1x7_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0e_1x7_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_2/Conv2d_0e_1x7/BatchNorm/batchnorm")
@@ -569,7 +570,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_3_Conv2d_0b_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_3_Conv2d_0b_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_3_Conv2d_0b_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_3/Conv2d_0b_1x1/BatchNorm/batchnorm")
@@ -593,7 +594,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_0a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_0a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_0a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_0/Conv2d_0a_1x1/BatchNorm/batchnorm")
@@ -607,7 +608,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_1a_3x3_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_1a_3x3_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_1a_3x3_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_0/Conv2d_1a_3x3/BatchNorm/batchnorm")
@@ -623,7 +624,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d_0a_1x1/BatchNorm/batchnorm")
@@ -637,7 +638,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0b_1x7_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0b_1x7_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0b_1x7_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d_0b_1x7/BatchNorm/batchnorm")
@@ -651,7 +652,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0c_7x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0c_7x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0c_7x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d_0c_7x1/BatchNorm/batchnorm")
@@ -665,7 +666,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_1a_3x3_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_1a_3x3_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_1a_3x3_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d_1a_3x3/BatchNorm/batchnorm")
@@ -702,7 +703,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_0a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_0a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_0_Conv2d_0a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_0/Conv2d_0a_1x1/BatchNorm/batchnorm")
@@ -718,7 +719,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d_0a_1x1/BatchNorm/batchnorm")
@@ -734,7 +735,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0b_1x3_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0b_1x3_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d_0b_1x3_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d_0b_1x3/BatchNorm/batchnorm")
@@ -750,7 +751,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d" + conv_id + "3x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d" + conv_id + "3x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_1_Conv2d" + conv_id + "3x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_1/Conv2d" + conv_id + "3x1/BatchNorm/batchnorm")
@@ -769,7 +770,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0a_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0a_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0a_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_2/Conv2d_0a_1x1/BatchNorm/batchnorm")
@@ -783,7 +784,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0b_3x3_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0b_3x3_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0b_3x3_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_2/Conv2d_0b_3x3/BatchNorm/batchnorm")
@@ -799,7 +800,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0c_1x3_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0c_1x3_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0c_1x3_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_2/Conv2d_0c_1x3/BatchNorm/batchnorm")
@@ -815,7 +816,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0d_3x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0d_3x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_2_Conv2d_0d_3x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_2/Conv2d_0d_3x1/BatchNorm/batchnorm")
@@ -835,7 +836,7 @@
<< BatchNormalizationLayer(
get_weights_accessor(data_path, total_path + "Branch_3_Conv2d_0b_1x1_BatchNorm_moving_mean.npy"),
get_weights_accessor(data_path, total_path + "Branch_3_Conv2d_0b_1x1_BatchNorm_moving_variance.npy"),
- get_random_accessor(1.f, 1.f),
+ nullptr,
get_weights_accessor(data_path, total_path + "Branch_3_Conv2d_0b_1x1_BatchNorm_beta.npy"),
0.001f)
.set_name(param_path + "/Branch_3/Conv2d_0b_1x1/BatchNorm/batchnorm")
diff --git a/examples/graph_inception_v4.cpp b/examples/graph_inception_v4.cpp
index a7f57ec..bac85ee 100644
--- a/examples/graph_inception_v4.cpp
+++ b/examples/graph_inception_v4.cpp
@@ -43,6 +43,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_lenet.cpp b/examples/graph_lenet.cpp
index c75a2f8..9936ea5 100644
--- a/examples/graph_lenet.cpp
+++ b/examples/graph_lenet.cpp
@@ -43,6 +43,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_mnist.cpp b/examples/graph_mnist.cpp
new file mode 100644
index 0000000..eb66138
--- /dev/null
+++ b/examples/graph_mnist.cpp
@@ -0,0 +1,170 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/graph.h"
+#include "support/ToolchainSupport.h"
+#include "utils/CommonGraphOptions.h"
+#include "utils/GraphUtils.h"
+#include "utils/Utils.h"
+
+using namespace arm_compute;
+using namespace arm_compute::utils;
+using namespace arm_compute::graph::frontend;
+using namespace arm_compute::graph_utils;
+
+/** Example demonstrating how to implement an Mnist network (three 3x3 convolutions, two max-poolings, a 10-output fully connected layer and softmax) using the Compute Library's graph API */
+class GraphMnistExample : public Example
+{
+public:
+ GraphMnistExample()
+ : cmd_parser(), common_opts(cmd_parser), common_params(), graph(0, "LeNet") // NOTE(review): the graph is named "LeNet" although this is the Mnist example - looks like a copy-paste leftover, confirm
+ {
+ }
+ bool do_setup(int argc, char **argv) override // Builds and finalizes the graph; returns false when only the help menu was requested, true otherwise
+ {
+ // Parse and validate the command line arguments
+ cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
+
+ // Consume common parameters
+ common_params = consume_common_graph_parameters(common_opts);
+
+ // Return when help menu is requested (no graph is built in that case)
+ if(common_params.help)
+ {
+ cmd_parser.print_help(argv[0]);
+ return false;
+ }
+
+ // Print parameter values
+ std::cout << common_params << std::endl;
+
+ // Get trainable parameters data path
+ std::string data_path = common_params.data_path;
+
+ // Add model path to data path (only for asymmetric-quantized data types; otherwise data_path is used unchanged)
+ if(!data_path.empty() && arm_compute::is_data_type_quantized_asymmetric(common_params.data_type))
+ {
+ data_path += "/cnn_data/mnist_qasymm8_model/";
+ }
+
+ // Create input descriptor: 28x28x1 shape permuted from NCHW to the requested data layout
+ const TensorShape tensor_shape = permute_shape(TensorShape(28U, 28U, 1U), DataLayout::NCHW, common_params.data_layout);
+ TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(common_params.data_layout);
+
+ const QuantizationInfo in_quant_info = QuantizationInfo(0.003921568859368563f, 0); // Attached to the input descriptor below
+
+ const std::vector<std::pair<QuantizationInfo, QuantizationInfo>> conv_quant_info = // One pair per layer; .first/.second are passed as the two trailing layer arguments (presumably weights and output quantization - confirm)
+ {
+ { QuantizationInfo(0.004083447158336639f, 138), QuantizationInfo(0.0046257381327450275f, 0) }, // conv0
+ { QuantizationInfo(0.0048590428195893764f, 149), QuantizationInfo(0.03558270260691643f, 0) }, // conv1
+ { QuantizationInfo(0.004008443560451269f, 146), QuantizationInfo(0.09117382764816284f, 0) }, // conv2
+ { QuantizationInfo(0.004344311077147722f, 160), QuantizationInfo(0.5494495034217834f, 167) }, // fc
+ };
+
+ // Set weights trained layout (also handed to every weights accessor below)
+ const DataLayout weights_layout = DataLayout::NHWC;
+ FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo();
+ fc_info.set_weights_trained_layout(weights_layout);
+
+ graph << common_params.target
+ << common_params.fast_math_hint
+ << InputLayer(input_descriptor.set_quantization_info(in_quant_info),
+ get_input_accessor(common_params))
+ << ConvolutionLayer(
+ 3U, 3U, 32U,
+ get_weights_accessor(data_path, "conv2d_weights_quant_FakeQuantWithMinMaxVars.npy", weights_layout),
+ get_weights_accessor(data_path, "conv2d_Conv2D_bias.npy"),
+ PadStrideInfo(1U, 1U, 1U, 1U), 1, conv_quant_info.at(0).first, conv_quant_info.at(0).second)
+ .set_name("Conv0") // NOTE(review): capitalised unlike "conv1"/"conv2" below - confirm whether intentional
+
+ << ConvolutionLayer(
+ 3U, 3U, 32U,
+ get_weights_accessor(data_path, "conv2d_1_weights_quant_FakeQuantWithMinMaxVars.npy", weights_layout),
+ get_weights_accessor(data_path, "conv2d_1_Conv2D_bias.npy"),
+ PadStrideInfo(1U, 1U, 1U, 1U), 1, conv_quant_info.at(1).first, conv_quant_info.at(1).second)
+ .set_name("conv1")
+
+ << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0))).set_name("maxpool1")
+
+ << ConvolutionLayer(
+ 3U, 3U, 32U,
+ get_weights_accessor(data_path, "conv2d_2_weights_quant_FakeQuantWithMinMaxVars.npy", weights_layout),
+ get_weights_accessor(data_path, "conv2d_2_Conv2D_bias.npy"),
+ PadStrideInfo(1U, 1U, 1U, 1U), 1, conv_quant_info.at(2).first, conv_quant_info.at(2).second)
+ .set_name("conv2")
+
+ << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0))).set_name("maxpool2")
+
+ << FullyConnectedLayer(
+ 10U,
+ get_weights_accessor(data_path, "dense_weights_quant_FakeQuantWithMinMaxVars_transpose.npy", weights_layout),
+ get_weights_accessor(data_path, "dense_MatMul_bias.npy"),
+ fc_info, conv_quant_info.at(3).first, conv_quant_info.at(3).second)
+ .set_name("fc")
+
+ << SoftmaxLayer().set_name("prob");
+
+ if(arm_compute::is_data_type_quantized_asymmetric(common_params.data_type)) // Quantized runs get a dequantization stage before the output layer
+ {
+ graph << DequantizationLayer().set_name("dequantize");
+ }
+
+ graph << OutputLayer(get_output_accessor(common_params, 5)); // NOTE(review): 5 is presumably the number of top predictions reported - confirm against get_output_accessor
+
+ // Finalize graph with the threading and tuner settings taken from the common parameters
+ GraphConfig config;
+ config.num_threads = common_params.threads;
+ config.use_tuner = common_params.enable_tuner;
+ config.tuner_mode = common_params.tuner_mode;
+ config.tuner_file = common_params.tuner_file;
+
+ graph.finalize(common_params.target, config);
+
+ return true;
+ }
+ void do_run() override
+ {
+ // Run the graph finalized in do_setup()
+ graph.run();
+ }
+
+private:
+ CommandLineParser cmd_parser; // Parses and validates argc/argv in do_setup()
+ CommonGraphOptions common_opts; // Constructed from cmd_parser; source of common_params
+ CommonGraphParams common_params; // Values consumed from common_opts in do_setup()
+ Stream graph; // Graph stream the layers are pushed into
+};
+
+/** Main program for the Mnist graph example
+ *
+ * @note To list all the possible arguments execute the binary appended with the --help option
+ *
+ * @param[in] argc Number of arguments
+ * @param[in] argv Arguments
+ */
+int main(int argc, char **argv)
+{
+ return arm_compute::utils::run_example<GraphMnistExample>(argc, argv); // Forwards the command line to run_example and returns its result
+}
diff --git a/examples/graph_mobilenet.cpp b/examples/graph_mobilenet.cpp
index 9c014e7..5a39dc0 100644
--- a/examples/graph_mobilenet.cpp
+++ b/examples/graph_mobilenet.cpp
@@ -52,6 +52,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
@@ -78,7 +79,6 @@
// Set graph hints
graph << common_params.target
- << DepthwiseConvolutionMethod::Optimized3x3 // TODO(COMPMID-1073): Add heuristics to automatically call the optimized 3x3 method
<< common_params.fast_math_hint;
// Create core graph
@@ -235,7 +235,7 @@
};
graph << InputLayer(input_descriptor.set_quantization_info(in_quant_info),
- get_weights_accessor(data_path, common_params.image))
+ get_input_accessor(common_params, nullptr, false))
<< ConvolutionLayer(
3U, 3U, 32U,
get_weights_accessor(data_path, "Conv2d_0_weights.npy"),
diff --git a/examples/graph_mobilenet_v2.cpp b/examples/graph_mobilenet_v2.cpp
index 25690aa..337d685 100644
--- a/examples/graph_mobilenet_v2.cpp
+++ b/examples/graph_mobilenet_v2.cpp
@@ -50,6 +50,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
@@ -70,7 +71,6 @@
// Set graph hints
graph << common_params.target
- << DepthwiseConvolutionMethod::Optimized3x3 // TODO(COMPMID-1073): Add heuristics to automatically call the optimized 3x3 method
<< common_params.fast_math_hint;
// Create core graph
diff --git a/examples/graph_resnet12.cpp b/examples/graph_resnet12.cpp
index db70b53..33f29dd 100644
--- a/examples/graph_resnet12.cpp
+++ b/examples/graph_resnet12.cpp
@@ -54,6 +54,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_resnet50.cpp b/examples/graph_resnet50.cpp
index 7c9b95e..17506dc 100644
--- a/examples/graph_resnet50.cpp
+++ b/examples/graph_resnet50.cpp
@@ -43,6 +43,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_resnet_v2_50.cpp b/examples/graph_resnet_v2_50.cpp
index 78845a8..785ae9c 100644
--- a/examples/graph_resnet_v2_50.cpp
+++ b/examples/graph_resnet_v2_50.cpp
@@ -43,6 +43,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_resnext50.cpp b/examples/graph_resnext50.cpp
index 766b8ff..4e505a0 100644
--- a/examples/graph_resnext50.cpp
+++ b/examples/graph_resnext50.cpp
@@ -43,6 +43,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_shufflenet.cpp b/examples/graph_shufflenet.cpp
index 3704be7..0a67f58 100644
--- a/examples/graph_shufflenet.cpp
+++ b/examples/graph_shufflenet.cpp
@@ -43,6 +43,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_squeezenet.cpp b/examples/graph_squeezenet.cpp
index 4796dd3..9721775 100644
--- a/examples/graph_squeezenet.cpp
+++ b/examples/graph_squeezenet.cpp
@@ -43,6 +43,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_squeezenet_v1_1.cpp b/examples/graph_squeezenet_v1_1.cpp
index fd4561f..0fd52b9 100644
--- a/examples/graph_squeezenet_v1_1.cpp
+++ b/examples/graph_squeezenet_v1_1.cpp
@@ -43,6 +43,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_srcnn955.cpp b/examples/graph_srcnn955.cpp
index 066f16e..b693058 100644
--- a/examples/graph_srcnn955.cpp
+++ b/examples/graph_srcnn955.cpp
@@ -54,6 +54,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_ssd_mobilenet.cpp b/examples/graph_ssd_mobilenet.cpp
index 55c9d75..b3476b8 100644
--- a/examples/graph_ssd_mobilenet.cpp
+++ b/examples/graph_ssd_mobilenet.cpp
@@ -41,7 +41,16 @@
{
// Add topk option
keep_topk_opt = cmd_parser.add_option<SimpleOption<int>>("topk", 100);
- keep_topk_opt->set_help("Top k detections results per image.");
+ keep_topk_opt->set_help("Top k detections results per image. Used for data type F32.");
+ // Add output option
+ detection_boxes_opt = cmd_parser.add_option<SimpleOption<std::string>>("detection_boxes_opt", "");
+ detection_boxes_opt->set_help("Filename containing the reference values for the graph output detection_boxes. Used for data type QASYMM8.");
+ detection_classes_opt = cmd_parser.add_option<SimpleOption<std::string>>("detection_classes_opt", "");
+ detection_classes_opt->set_help("Filename containing the reference values for the output detection_classes. Used for data type QASYMM8.");
+ detection_scores_opt = cmd_parser.add_option<SimpleOption<std::string>>("detection_scores_opt", "");
+ detection_scores_opt->set_help("Filename containing the reference values for the output detection_scores. Used for data type QASYMM8.");
+ num_detections_opt = cmd_parser.add_option<SimpleOption<std::string>>("num_detections_opt", "");
+ num_detections_opt->set_help("Filename containing the reference values for the output num_detections. Used with datatype QASYMM8.");
}
GraphSSDMobilenetExample(const GraphSSDMobilenetExample &) = delete;
GraphSSDMobilenetExample &operator=(const GraphSSDMobilenetExample &) = delete;
@@ -52,6 +61,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
@@ -72,12 +82,140 @@
// Set graph hints
graph << common_params.target
- << DepthwiseConvolutionMethod::Optimized3x3 // TODO(COMPMID-1073): Add heuristics to automatically call the optimized 3x3 method
<< common_params.fast_math_hint;
// Create core graph
- std::string model_path = "/cnn_data/ssd_mobilenet_model/";
+ if(arm_compute::is_data_type_float(common_params.data_type))
+ {
+ create_graph_float(input_descriptor);
+ }
+ else
+ {
+ create_graph_qasymm(input_descriptor);
+ }
+ // Finalize graph
+ GraphConfig config;
+ config.num_threads = common_params.threads;
+ config.use_tuner = common_params.enable_tuner;
+ config.tuner_file = common_params.tuner_file;
+
+ graph.finalize(common_params.target, config);
+
+ return true;
+ }
+ void do_run() override
+ {
+ // Run the graph; the only work done here is the call to graph.run()
+ graph.run();
+ }
+
+private:
+ CommandLineParser cmd_parser; // Parser the example-specific options below are registered with
+ CommonGraphOptions common_opts; // Source of common_params (via consume_common_graph_parameters)
+ SimpleOption<int> *keep_topk_opt{ nullptr }; // "topk" option (default 100): top k detections per image, used for F32 per its help text
+ CommonGraphParams common_params; // Common parameters consumed after parsing
+ Stream graph; // Graph stream that is finalized and run
+
+ SimpleOption<std::string> *detection_boxes_opt{ nullptr }; // "detection_boxes_opt": reference file for the detection_boxes output (QASYMM8 per its help text)
+ SimpleOption<std::string> *detection_classes_opt{ nullptr }; // "detection_classes_opt": reference file for the detection_classes output (QASYMM8 per its help text)
+ SimpleOption<std::string> *detection_scores_opt{ nullptr }; // "detection_scores_opt": reference file for the detection_scores output (QASYMM8 per its help text)
+ SimpleOption<std::string> *num_detections_opt{ nullptr }; // "num_detections_opt": reference file for the num_detections output (QASYMM8 per its help text)
+
+ ConcatLayer get_node_A_float(IStream &master_graph, const std::string &data_path, std::string &&param_path, // master_graph: parent stream; data_path: weights directory; param_path: node name, also the weight file prefix
+ unsigned int conv_filt, // Number of filters of the 1x1 convolution
+ PadStrideInfo dwc_pad_stride_info, PadStrideInfo conv_pad_stride_info) // Pad/stride of the depthwise and of the 1x1 convolution respectively
+ { // Builds a sub-stream on master_graph: 3x3 depthwise conv -> batch norm -> ReLU -> 1x1 conv -> batch norm -> ReLU
+ const std::string total_path = param_path + "_"; // Prefix shared by all weight files of this node, e.g. "conv1_"
+ SubStream sg(master_graph);
+
+ sg << DepthwiseConvolutionLayer(
+ 3U, 3U,
+ get_weights_accessor(data_path, total_path + "dw_w.npy"),
+ std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr), // Null accessor: nothing is loaded for this input (presumably the bias - confirm)
+ dwc_pad_stride_info)
+ .set_name(param_path + "/dw")
+ << BatchNormalizationLayer(get_weights_accessor(data_path, total_path + "dw_bn_mean.npy"),
+ get_weights_accessor(data_path, total_path + "dw_bn_var.npy"),
+ get_weights_accessor(data_path, total_path + "dw_scale_w.npy"),
+ get_weights_accessor(data_path, total_path + "dw_scale_b.npy"), 0.00001f)
+ .set_name(param_path + "/dw/bn")
+ << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name(param_path + "/dw/relu") // Leading '/' keeps the name under "<param_path>/dw" like the two layers above
+
+ << ConvolutionLayer(
+ 1U, 1U, conv_filt,
+ get_weights_accessor(data_path, total_path + "w.npy"),
+ std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr), // Null accessor in the slot where other convolutions pass a bias file
+ conv_pad_stride_info)
+ .set_name(param_path + "/pw")
+ << BatchNormalizationLayer(get_weights_accessor(data_path, total_path + "bn_mean.npy"),
+ get_weights_accessor(data_path, total_path + "bn_var.npy"),
+ get_weights_accessor(data_path, total_path + "scale_w.npy"),
+ get_weights_accessor(data_path, total_path + "scale_b.npy"), 0.00001f)
+ .set_name(param_path + "/pw/bn")
+ << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name(param_path + "/pw/relu"); // Leading '/' keeps the name under "<param_path>/pw" like the two layers above
+
+ return ConcatLayer(std::move(sg)); // Wrap the single sub-stream so the caller can stream it into the parent graph
+ }
+
+ ConcatLayer get_node_B_float(IStream &master_graph, const std::string &data_path, std::string &&param_path, // master_graph: parent stream; data_path: weights directory; param_path: weight file / node name prefix
+ unsigned int conv_filt, // Filters of the 3x3 convolution; the 1x1 convolution uses conv_filt / 2
+ PadStrideInfo conv_pad_stride_info_1, PadStrideInfo conv_pad_stride_info_2) // Pad/stride of the 1x1 and of the 3x3 convolution respectively
+ { // Builds a sub-stream on master_graph: 1x1 conv -> batch norm -> ReLU -> 3x3 conv -> batch norm -> ReLU
+ const std::string total_path = param_path + "_"; // Prefix used for both the weight file names and the node names of this block
+ SubStream sg(master_graph);
+
+ sg << ConvolutionLayer(
+ 1, 1, conv_filt / 2,
+ get_weights_accessor(data_path, total_path + "1_w.npy"),
+ std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr), // Null accessor in the slot where other convolutions pass a bias file
+ conv_pad_stride_info_1)
+ .set_name(total_path + "1/conv")
+ << BatchNormalizationLayer(get_weights_accessor(data_path, total_path + "1_bn_mean.npy"),
+ get_weights_accessor(data_path, total_path + "1_bn_var.npy"),
+ get_weights_accessor(data_path, total_path + "1_scale_w.npy"),
+ get_weights_accessor(data_path, total_path + "1_scale_b.npy"), 0.00001f)
+ .set_name(total_path + "1/bn")
+ << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name(total_path + "1/relu");
+
+ sg << ConvolutionLayer(
+ 3, 3, conv_filt,
+ get_weights_accessor(data_path, total_path + "2_w.npy"),
+ std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr), // Null accessor in the slot where other convolutions pass a bias file
+ conv_pad_stride_info_2)
+ .set_name(total_path + "2/conv")
+ << BatchNormalizationLayer(get_weights_accessor(data_path, total_path + "2_bn_mean.npy"),
+ get_weights_accessor(data_path, total_path + "2_bn_var.npy"),
+ get_weights_accessor(data_path, total_path + "2_scale_w.npy"),
+ get_weights_accessor(data_path, total_path + "2_scale_b.npy"), 0.00001f)
+ .set_name(total_path + "2/bn")
+ << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name(total_path + "2/relu");
+
+ return ConcatLayer(std::move(sg)); // Wrap the single sub-stream so the caller can stream it into the parent graph
+ }
+
+ ConcatLayer get_node_C_float(IStream &master_graph, const std::string &data_path, std::string &&param_path, // master_graph: parent stream; data_path: weights directory; param_path: node name, also the weight file prefix
+ unsigned int conv_filt, PadStrideInfo conv_pad_stride_info) // Filter count and pad/stride of the 1x1 convolution
+ { // Builds a sub-stream on master_graph: 1x1 conv (with bias file) -> optional permute -> flatten
+ const std::string total_path = param_path + "_"; // Prefix shared by the weight and bias file names of this node
+ SubStream sg(master_graph);
+ sg << ConvolutionLayer(
+ 1U, 1U, conv_filt,
+ get_weights_accessor(data_path, total_path + "w.npy"),
+ get_weights_accessor(data_path, total_path + "b.npy"),
+ conv_pad_stride_info)
+ .set_name(param_path + "/conv");
+ if(common_params.data_layout == DataLayout::NCHW) // Only NCHW graphs get the permute; other layouts are flattened directly
+ {
+ sg << PermuteLayer(PermutationVector(2U, 0U, 1U), DataLayout::NHWC).set_name(param_path + "/perm");
+ }
+ sg << FlattenLayer().set_name(param_path + "/flat");
+
+ return ConcatLayer(std::move(sg)); // Wrap the single sub-stream so the caller can stream it into the parent graph
+ }
+
+ void create_graph_float(TensorDescriptor &input_descriptor)
+ {
// Create a preprocessor object
const std::array<float, 3> mean_rgb{ { 127.5f, 127.5f, 127.5f } };
std::unique_ptr<IPreprocessor> preprocessor = arm_compute::support::cpp14::make_unique<CaffePreproccessor>(mean_rgb, true, 0.007843f);
@@ -88,7 +226,7 @@
// Add model path to data path
if(!data_path.empty())
{
- data_path += model_path;
+ data_path += "/cnn_data/ssd_mobilenet_model/";
}
graph << InputLayer(input_descriptor,
@@ -108,52 +246,52 @@
.set_name("conv0/bn")
<< ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name("conv0/relu");
- conv_11 << get_node_A(conv_11, data_path, "conv1", 64, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
- conv_11 << get_node_A(conv_11, data_path, "conv2", 128, PadStrideInfo(2, 2, 1, 1), PadStrideInfo(1, 1, 0, 0));
- conv_11 << get_node_A(conv_11, data_path, "conv3", 128, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
- conv_11 << get_node_A(conv_11, data_path, "conv4", 256, PadStrideInfo(2, 2, 1, 1), PadStrideInfo(1, 1, 0, 0));
- conv_11 << get_node_A(conv_11, data_path, "conv5", 256, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
- conv_11 << get_node_A(conv_11, data_path, "conv6", 512, PadStrideInfo(2, 2, 1, 1), PadStrideInfo(1, 1, 0, 0));
- conv_11 << get_node_A(conv_11, data_path, "conv7", 512, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
- conv_11 << get_node_A(conv_11, data_path, "conv8", 512, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
- conv_11 << get_node_A(conv_11, data_path, "conv9", 512, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
- conv_11 << get_node_A(conv_11, data_path, "conv10", 512, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
- conv_11 << get_node_A(conv_11, data_path, "conv11", 512, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_11 << get_node_A_float(conv_11, data_path, "conv1", 64, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_11 << get_node_A_float(conv_11, data_path, "conv2", 128, PadStrideInfo(2, 2, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_11 << get_node_A_float(conv_11, data_path, "conv3", 128, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_11 << get_node_A_float(conv_11, data_path, "conv4", 256, PadStrideInfo(2, 2, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_11 << get_node_A_float(conv_11, data_path, "conv5", 256, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_11 << get_node_A_float(conv_11, data_path, "conv6", 512, PadStrideInfo(2, 2, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_11 << get_node_A_float(conv_11, data_path, "conv7", 512, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_11 << get_node_A_float(conv_11, data_path, "conv8", 512, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_11 << get_node_A_float(conv_11, data_path, "conv9", 512, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_11 << get_node_A_float(conv_11, data_path, "conv10", 512, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_11 << get_node_A_float(conv_11, data_path, "conv11", 512, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
SubStream conv_13(conv_11);
- conv_13 << get_node_A(conv_11, data_path, "conv12", 1024, PadStrideInfo(2, 2, 1, 1), PadStrideInfo(1, 1, 0, 0));
- conv_13 << get_node_A(conv_13, data_path, "conv13", 1024, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_13 << get_node_A_float(conv_11, data_path, "conv12", 1024, PadStrideInfo(2, 2, 1, 1), PadStrideInfo(1, 1, 0, 0));
+ conv_13 << get_node_A_float(conv_13, data_path, "conv13", 1024, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
SubStream conv_14(conv_13);
- conv_14 << get_node_B(conv_13, data_path, "conv14", 512, PadStrideInfo(1, 1, 0, 0), PadStrideInfo(2, 2, 1, 1));
+ conv_14 << get_node_B_float(conv_13, data_path, "conv14", 512, PadStrideInfo(1, 1, 0, 0), PadStrideInfo(2, 2, 1, 1));
SubStream conv_15(conv_14);
- conv_15 << get_node_B(conv_14, data_path, "conv15", 256, PadStrideInfo(1, 1, 0, 0), PadStrideInfo(2, 2, 1, 1));
+ conv_15 << get_node_B_float(conv_14, data_path, "conv15", 256, PadStrideInfo(1, 1, 0, 0), PadStrideInfo(2, 2, 1, 1));
SubStream conv_16(conv_15);
- conv_16 << get_node_B(conv_15, data_path, "conv16", 256, PadStrideInfo(1, 1, 0, 0), PadStrideInfo(2, 2, 1, 1));
+ conv_16 << get_node_B_float(conv_15, data_path, "conv16", 256, PadStrideInfo(1, 1, 0, 0), PadStrideInfo(2, 2, 1, 1));
SubStream conv_17(conv_16);
- conv_17 << get_node_B(conv_16, data_path, "conv17", 128, PadStrideInfo(1, 1, 0, 0), PadStrideInfo(2, 2, 1, 1));
+ conv_17 << get_node_B_float(conv_16, data_path, "conv17", 128, PadStrideInfo(1, 1, 0, 0), PadStrideInfo(2, 2, 1, 1));
//mbox_loc
SubStream conv_11_mbox_loc(conv_11);
- conv_11_mbox_loc << get_node_C(conv_11, data_path, "conv11_mbox_loc", 12, PadStrideInfo(1, 1, 0, 0));
+ conv_11_mbox_loc << get_node_C_float(conv_11, data_path, "conv11_mbox_loc", 12, PadStrideInfo(1, 1, 0, 0));
SubStream conv_13_mbox_loc(conv_13);
- conv_13_mbox_loc << get_node_C(conv_13, data_path, "conv13_mbox_loc", 24, PadStrideInfo(1, 1, 0, 0));
+ conv_13_mbox_loc << get_node_C_float(conv_13, data_path, "conv13_mbox_loc", 24, PadStrideInfo(1, 1, 0, 0));
SubStream conv_14_2_mbox_loc(conv_14);
- conv_14_2_mbox_loc << get_node_C(conv_14, data_path, "conv14_2_mbox_loc", 24, PadStrideInfo(1, 1, 0, 0));
+ conv_14_2_mbox_loc << get_node_C_float(conv_14, data_path, "conv14_2_mbox_loc", 24, PadStrideInfo(1, 1, 0, 0));
SubStream conv_15_2_mbox_loc(conv_15);
- conv_15_2_mbox_loc << get_node_C(conv_15, data_path, "conv15_2_mbox_loc", 24, PadStrideInfo(1, 1, 0, 0));
+ conv_15_2_mbox_loc << get_node_C_float(conv_15, data_path, "conv15_2_mbox_loc", 24, PadStrideInfo(1, 1, 0, 0));
SubStream conv_16_2_mbox_loc(conv_16);
- conv_16_2_mbox_loc << get_node_C(conv_16, data_path, "conv16_2_mbox_loc", 24, PadStrideInfo(1, 1, 0, 0));
+ conv_16_2_mbox_loc << get_node_C_float(conv_16, data_path, "conv16_2_mbox_loc", 24, PadStrideInfo(1, 1, 0, 0));
SubStream conv_17_2_mbox_loc(conv_17);
- conv_17_2_mbox_loc << get_node_C(conv_17, data_path, "conv17_2_mbox_loc", 24, PadStrideInfo(1, 1, 0, 0));
+ conv_17_2_mbox_loc << get_node_C_float(conv_17, data_path, "conv17_2_mbox_loc", 24, PadStrideInfo(1, 1, 0, 0));
SubStream mbox_loc(graph);
mbox_loc << ConcatLayer(std::move(conv_11_mbox_loc), std::move(conv_13_mbox_loc), conv_14_2_mbox_loc, std::move(conv_15_2_mbox_loc),
@@ -161,22 +299,22 @@
//mbox_conf
SubStream conv_11_mbox_conf(conv_11);
- conv_11_mbox_conf << get_node_C(conv_11, data_path, "conv11_mbox_conf", 63, PadStrideInfo(1, 1, 0, 0));
+ conv_11_mbox_conf << get_node_C_float(conv_11, data_path, "conv11_mbox_conf", 63, PadStrideInfo(1, 1, 0, 0));
SubStream conv_13_mbox_conf(conv_13);
- conv_13_mbox_conf << get_node_C(conv_13, data_path, "conv13_mbox_conf", 126, PadStrideInfo(1, 1, 0, 0));
+ conv_13_mbox_conf << get_node_C_float(conv_13, data_path, "conv13_mbox_conf", 126, PadStrideInfo(1, 1, 0, 0));
SubStream conv_14_2_mbox_conf(conv_14);
- conv_14_2_mbox_conf << get_node_C(conv_14, data_path, "conv14_2_mbox_conf", 126, PadStrideInfo(1, 1, 0, 0));
+ conv_14_2_mbox_conf << get_node_C_float(conv_14, data_path, "conv14_2_mbox_conf", 126, PadStrideInfo(1, 1, 0, 0));
SubStream conv_15_2_mbox_conf(conv_15);
- conv_15_2_mbox_conf << get_node_C(conv_15, data_path, "conv15_2_mbox_conf", 126, PadStrideInfo(1, 1, 0, 0));
+ conv_15_2_mbox_conf << get_node_C_float(conv_15, data_path, "conv15_2_mbox_conf", 126, PadStrideInfo(1, 1, 0, 0));
SubStream conv_16_2_mbox_conf(conv_16);
- conv_16_2_mbox_conf << get_node_C(conv_16, data_path, "conv16_2_mbox_conf", 126, PadStrideInfo(1, 1, 0, 0));
+ conv_16_2_mbox_conf << get_node_C_float(conv_16, data_path, "conv16_2_mbox_conf", 126, PadStrideInfo(1, 1, 0, 0));
SubStream conv_17_2_mbox_conf(conv_17);
- conv_17_2_mbox_conf << get_node_C(conv_17, data_path, "conv17_2_mbox_conf", 126, PadStrideInfo(1, 1, 0, 0));
+ conv_17_2_mbox_conf << get_node_C_float(conv_17, data_path, "conv17_2_mbox_conf", 126, PadStrideInfo(1, 1, 0, 0));
SubStream mbox_conf(graph);
mbox_conf << ConcatLayer(std::move(conv_11_mbox_conf), std::move(conv_13_mbox_conf), std::move(conv_14_2_mbox_conf),
@@ -224,7 +362,8 @@
SubStream mbox_priorbox(graph);
mbox_priorbox << ConcatLayer(
- (common_params.data_layout == DataLayout::NCHW) ? DataLayoutDimension::WIDTH : DataLayoutDimension::CHANNEL,
+ (common_params.data_layout == DataLayout::NCHW) ? arm_compute::graph::descriptors::ConcatLayerDescriptor(DataLayoutDimension::WIDTH) : arm_compute::graph::descriptors::ConcatLayerDescriptor(
+ DataLayoutDimension::CHANNEL),
std::move(conv_11_mbox_priorbox), std::move(conv_13_mbox_priorbox), std::move(conv_14_2_mbox_priorbox),
std::move(conv_15_2_mbox_priorbox), std::move(conv_16_2_mbox_priorbox), std::move(conv_17_2_mbox_priorbox));
@@ -240,35 +379,13 @@
SubStream detection_ouput(mbox_loc);
detection_ouput << DetectionOutputLayer(std::move(mbox_conf), std::move(mbox_priorbox),
DetectionOutputLayerInfo(num_classes, share_location, detection_type, keep_top_k, nms_threshold, top_k, label_id_background, conf_thrs));
- detection_ouput << OutputLayer(get_detection_output_accessor(common_params, { tensor_shape }));
-
- // Finalize graph
- GraphConfig config;
- config.num_threads = common_params.threads;
- config.use_tuner = common_params.enable_tuner;
- config.tuner_mode = common_params.tuner_mode;
- config.tuner_file = common_params.tuner_file;
-
- graph.finalize(common_params.target, config);
-
- return true;
- }
- void do_run() override
- {
- // Run graph
- graph.run();
+ detection_ouput << OutputLayer(get_detection_output_accessor(common_params, { input_descriptor.shape }));
}
-private:
- CommandLineParser cmd_parser;
- CommonGraphOptions common_opts;
- SimpleOption<int> *keep_topk_opt{ nullptr };
- CommonGraphParams common_params;
- Stream graph;
-
-    ConcatLayer get_node_A(IStream &master_graph, const std::string &data_path, std::string &&param_path,
-                           unsigned int  conv_filt,
-                           PadStrideInfo dwc_pad_stride_info, PadStrideInfo conv_pad_stride_info)
+    /** Build a quantized depthwise-separable block.
+     *
+     * Appends a 3x3 depthwise convolution and a 1x1 convolution, each followed by a BOUNDED_RELU activation
+     * with upper bound 6. Both convolutions load weights and biases from .npy files.
+     *
+     * @param[in] master_graph         Stream the block sub-stream is created from
+     * @param[in] data_path            Path prefix passed to get_weights_accessor()
+     * @param[in] param_path           Layer name; prefixes the .npy file names and the node names
+     * @param[in] conv_filt            Third argument of the 1x1 ConvolutionLayer (presumably the number of output feature maps)
+     * @param[in] dwc_pad_stride_info  Padding and stride information of the depthwise convolution
+     * @param[in] conv_pad_stride_info Padding and stride information of the 1x1 convolution
+     * @param[in] depth_quant_info     Quantization info pair forwarded (first, second) to the depthwise convolution
+     * @param[in] point_quant_info     Quantization info pair forwarded (first, second) to the 1x1 convolution
+     *
+     * @return ConcatLayer wrapping the sub-stream
+     */
+    ConcatLayer get_node_A_qasymm(IStream &master_graph, const std::string &data_path, std::string &&param_path,
+                                  unsigned int  conv_filt,
+                                  PadStrideInfo dwc_pad_stride_info, PadStrideInfo conv_pad_stride_info,
+                                  std::pair<QuantizationInfo, QuantizationInfo> depth_quant_info, std::pair<QuantizationInfo, QuantizationInfo> point_quant_info)
     {
         const std::string total_path = param_path + "_";
         SubStream         sg(master_graph);
@@ -276,70 +393,52 @@
         sg << DepthwiseConvolutionLayer(
                3U, 3U,
                get_weights_accessor(data_path, total_path + "dw_w.npy"),
-               std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr),
-               dwc_pad_stride_info)
+               get_weights_accessor(data_path, total_path + "dw_b.npy"),
+               dwc_pad_stride_info, 1, depth_quant_info.first, depth_quant_info.second)
            .set_name(param_path + "/dw")
-           << BatchNormalizationLayer(get_weights_accessor(data_path, total_path + "dw_bn_mean.npy"),
-                                      get_weights_accessor(data_path, total_path + "dw_bn_var.npy"),
-                                      get_weights_accessor(data_path, total_path + "dw_scale_w.npy"),
-                                      get_weights_accessor(data_path, total_path + "dw_scale_b.npy"), 0.00001f)
-           .set_name(param_path + "/dw/bn")
-           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name(param_path + "dw/relu")
+           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f)).set_name(param_path + "/dw/relu6");
-           << ConvolutionLayer(
+        sg << ConvolutionLayer(
                1U, 1U, conv_filt,
                get_weights_accessor(data_path, total_path + "w.npy"),
-               std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr),
-               conv_pad_stride_info)
+               get_weights_accessor(data_path, total_path + "b.npy"),
+               conv_pad_stride_info, 1, point_quant_info.first, point_quant_info.second)
            .set_name(param_path + "/pw")
-           << BatchNormalizationLayer(get_weights_accessor(data_path, total_path + "bn_mean.npy"),
-                                      get_weights_accessor(data_path, total_path + "bn_var.npy"),
-                                      get_weights_accessor(data_path, total_path + "scale_w.npy"),
-                                      get_weights_accessor(data_path, total_path + "scale_b.npy"), 0.00001f)
-           .set_name(param_path + "/pw/bn")
-           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name(param_path + "pw/relu");
+           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f)).set_name(param_path + "/pw/relu6");
         return ConcatLayer(std::move(sg));
     }
-    ConcatLayer get_node_B(IStream &master_graph, const std::string &data_path, std::string &&param_path,
-                           unsigned int  conv_filt,
-                           PadStrideInfo conv_pad_stride_info_1, PadStrideInfo conv_pad_stride_info_2)
+    /** Build a quantized extra-feature block: 1x1 convolution then 3x3 convolution.
+     *
+     * Each convolution loads weights and biases from .npy files and is followed by a BOUNDED_RELU activation
+     * with upper bound 6. The 1x1 convolution is given conv_filt / 2 as its third argument, the 3x3 one conv_filt.
+     *
+     * @param[in] master_graph             Stream the block sub-stream is created from
+     * @param[in] data_path                Path prefix passed to get_weights_accessor()
+     * @param[in] param_path               Layer name; "<param_path>_" prefixes the .npy file names and the node names
+     * @param[in] conv_filt                Third argument of the 3x3 ConvolutionLayer (presumably the number of output feature maps)
+     * @param[in] conv_pad_stride_info_1x1 Padding and stride information of the 1x1 convolution
+     * @param[in] conv_pad_stride_info_3x3 Padding and stride information of the 3x3 convolution
+     * @param[in] quant_info_1x1           Quantization info pair forwarded (first, second) to the 1x1 convolution
+     * @param[in] quant_info_3x3           Quantization info pair forwarded (first, second) to the 3x3 convolution
+     *
+     * @return ConcatLayer wrapping the sub-stream
+     */
+    ConcatLayer get_node_B_qasymm(IStream &master_graph, const std::string &data_path, std::string &&param_path,
+                                  unsigned int  conv_filt,
+                                  PadStrideInfo conv_pad_stride_info_1x1, PadStrideInfo conv_pad_stride_info_3x3,
+                                  const std::pair<QuantizationInfo, QuantizationInfo> quant_info_1x1, const std::pair<QuantizationInfo, QuantizationInfo> quant_info_3x3)
     {
         const std::string total_path = param_path + "_";
         SubStream         sg(master_graph);
         sg << ConvolutionLayer(
                1, 1, conv_filt / 2,
-               get_weights_accessor(data_path, total_path + "1_w.npy"),
-               std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr),
-               conv_pad_stride_info_1)
-           .set_name(total_path + "1/conv")
-           << BatchNormalizationLayer(get_weights_accessor(data_path, total_path + "1_bn_mean.npy"),
-                                      get_weights_accessor(data_path, total_path + "1_bn_var.npy"),
-                                      get_weights_accessor(data_path, total_path + "1_scale_w.npy"),
-                                      get_weights_accessor(data_path, total_path + "1_scale_b.npy"), 0.00001f)
-           .set_name(total_path + "1/bn")
-           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name(total_path + "1/relu");
+               get_weights_accessor(data_path, total_path + "1x1_w.npy"),
+               get_weights_accessor(data_path, total_path + "1x1_b.npy"),
+               conv_pad_stride_info_1x1, 1, quant_info_1x1.first, quant_info_1x1.second)
+           .set_name(total_path + "1x1/conv")
+           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f)).set_name(total_path + "1x1/conv/relu6");
         sg << ConvolutionLayer(
                3, 3, conv_filt,
-               get_weights_accessor(data_path, total_path + "2_w.npy"),
-               std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr),
-               conv_pad_stride_info_2)
-           .set_name(total_path + "2/conv")
-           << BatchNormalizationLayer(get_weights_accessor(data_path, total_path + "2_bn_mean.npy"),
-                                      get_weights_accessor(data_path, total_path + "2_bn_var.npy"),
-                                      get_weights_accessor(data_path, total_path + "2_scale_w.npy"),
-                                      get_weights_accessor(data_path, total_path + "2_scale_b.npy"), 0.00001f)
-           .set_name(total_path + "2/bn")
-           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name(total_path + "2/relu");
+               get_weights_accessor(data_path, total_path + "3x3_w.npy"),
+               get_weights_accessor(data_path, total_path + "3x3_b.npy"),
+               conv_pad_stride_info_3x3, 1, quant_info_3x3.first, quant_info_3x3.second)
+           .set_name(total_path + "3x3/conv")
+           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f)).set_name(total_path + "3x3/conv/relu6");
         return ConcatLayer(std::move(sg));
     }
-    ConcatLayer get_node_C(IStream &master_graph, const std::string &data_path, std::string &&param_path,
-                           unsigned int conv_filt, PadStrideInfo conv_pad_stride_info)
+    /** Build a quantized prediction-head branch that ends in a reshape.
+     *
+     * Appends a node named "<param_path>/conv" built with 1x1 kernel arguments (presumably a ConvolutionLayer;
+     * the line opening the call lies between the two hunks), a permute to NHWC that is only added when
+     * common_params.data_layout is NCHW, and a ReshapeLayer to @p reshape_shape.
+     *
+     * @param[in] master_graph         Stream the branch sub-stream is created from
+     * @param[in] data_path            Path prefix passed to get_weights_accessor()
+     * @param[in] param_path           Layer name; "<param_path>_w.npy" and "<param_path>_b.npy" are loaded and it prefixes the node names
+     * @param[in] conv_filt            Third argument of the convolution (presumably the number of output feature maps)
+     * @param[in] conv_pad_stride_info Padding and stride information of the convolution
+     * @param[in] quant_info           Quantization info pair forwarded (first, second) to the convolution
+     * @param[in] reshape_shape        Shape given to the trailing ReshapeLayer
+     *
+     * @return ConcatLayer wrapping the sub-stream
+     */
+    ConcatLayer get_node_C_qasymm(IStream &master_graph, const std::string &data_path, std::string &&param_path,
+                                  unsigned int conv_filt, PadStrideInfo conv_pad_stride_info,
+                                  const std::pair<QuantizationInfo, QuantizationInfo> quant_info, TensorShape reshape_shape)
     {
         const std::string total_path = param_path + "_";
         SubStream         sg(master_graph);
@@ -347,16 +446,256 @@
                1U, 1U, conv_filt,
                get_weights_accessor(data_path, total_path + "w.npy"),
                get_weights_accessor(data_path, total_path + "b.npy"),
-               conv_pad_stride_info)
+               conv_pad_stride_info, 1, quant_info.first, quant_info.second)
            .set_name(param_path + "/conv");
         if(common_params.data_layout == DataLayout::NCHW)
         {
-            sg << PermuteLayer(PermutationVector(2U, 0U, 1U), DataLayout::NHWC).set_name(param_path + "/perm");
+            sg << PermuteLayer(PermutationVector(2U, 0U, 1U), DataLayout::NHWC);
         }
-        sg << FlattenLayer().set_name(param_path + "/flat");
+        sg << ReshapeLayer(reshape_shape).set_name(param_path + "/reshape");
         return ConcatLayer(std::move(sg));
     }
+
+    /** Build the quantized SSD MobileNet graph on the member stream `graph`.
+     *
+     * Appends, in order: the input layer and conv0, thirteen depthwise-separable blocks (conv1..conv13),
+     * four extra feature blocks (conv13_2..conv13_5), six box-encoding branches and six class-prediction
+     * branches (each set concatenated), a logistic activation on the class predictions, the detection
+     * post-process node and four output nodes writing .npy files.
+     *
+     * All quantization parameters are hard-coded in the tables below.
+     *
+     * @param[in,out] input_descriptor Descriptor of the input tensor; set_quantization_info() is called on it
+     */
+    void create_graph_qasymm(TensorDescriptor &input_descriptor)
+    {
+        // Get trainable parameters data path
+        std::string data_path = common_params.data_path;
+
+        // Add model path to data path
+        if(!data_path.empty())
+        {
+            data_path += "/cnn_data/ssd_mobilenet_qasymm8_model/";
+        }
+
+        // Quantization info is stored as one pair per (pointwise/depthwise) convolution layer: <weight_quant_info, output_quant_info>
+        const std::vector<std::pair<QuantizationInfo, QuantizationInfo>> conv_quant_info =
+        {
+            { QuantizationInfo(0.03624850884079933f, 163), QuantizationInfo(0.22219789028167725f, 113) },   // conv0
+            { QuantizationInfo(0.0028752065263688564f, 113), QuantizationInfo(0.05433657020330429f, 128) }, // conv13_2_1_1
+            { QuantizationInfo(0.0014862528769299388f, 125), QuantizationInfo(0.05037643015384674f, 131) }, // conv13_2_3_3
+            { QuantizationInfo(0.00233650766313076f, 113), QuantizationInfo(0.04468846693634987f, 126) },   // conv13_3_1_1
+            { QuantizationInfo(0.002501056529581547f, 120), QuantizationInfo(0.06026708707213402f, 111) },  // conv13_3_3_3
+            { QuantizationInfo(0.002896666992455721f, 121), QuantizationInfo(0.037775348871946335f, 117) }, // conv13_4_1_1
+            { QuantizationInfo(0.0023875406477600336f, 122), QuantizationInfo(0.03881589323282242f, 108) }, // conv13_4_3_3
+            { QuantizationInfo(0.0022081052884459496f, 77), QuantizationInfo(0.025450613349676132f, 125) }, // conv13_5_1_1
+            { QuantizationInfo(0.00604657270014286f, 121), QuantizationInfo(0.033533502370119095f, 109) }   // conv13_5_3_3
+        };
+
+        const std::vector<std::pair<QuantizationInfo, QuantizationInfo>> depth_quant_info =
+        {
+            { QuantizationInfo(0.03408717364072f, 131), QuantizationInfo(0.29286590218544006f, 108) },     // dwsc1
+            { QuantizationInfo(0.027518004179000854f, 107), QuantizationInfo(0.20796941220760345f, 117) }, // dwsc2
+            { QuantizationInfo(0.052489638328552246f, 85), QuantizationInfo(0.4303881824016571f, 142) },   // dwsc3
+            { QuantizationInfo(0.016570359468460083f, 79), QuantizationInfo(0.10512150079011917f, 116) },  // dwsc4
+            { QuantizationInfo(0.060739465057849884f, 65), QuantizationInfo(0.15331414341926575f, 94) },   // dwsc5
+            { QuantizationInfo(0.01324534136801958f, 124), QuantizationInfo(0.13010895252227783f, 153) },  // dwsc6
+            { QuantizationInfo(0.032326459884643555f, 124), QuantizationInfo(0.11565316468477249f, 156) }, // dwsc7
+            { QuantizationInfo(0.029948478564620018f, 155), QuantizationInfo(0.11413891613483429f, 146) }, // dwsc8
+            { QuantizationInfo(0.028054025024175644f, 129), QuantizationInfo(0.1142905130982399f, 140) },  // dwsc9
+            { QuantizationInfo(0.025204822421073914f, 129), QuantizationInfo(0.14668069779872894f, 149) }, // dwsc10
+            { QuantizationInfo(0.019332280382514f, 110), QuantizationInfo(0.1480235457420349f, 91) },      // dwsc11
+            { QuantizationInfo(0.0319712869822979f, 88), QuantizationInfo(0.10424695909023285f, 117) },    // dwsc12
+            { QuantizationInfo(0.04378943517804146f, 164), QuantizationInfo(0.23176774382591248f, 138) }   // dwsc13
+        };
+
+        const std::vector<std::pair<QuantizationInfo, QuantizationInfo>> point_quant_info =
+        {
+            { QuantizationInfo(0.028777318075299263f, 144), QuantizationInfo(0.2663874328136444f, 121) },  // pw1
+            { QuantizationInfo(0.015796702355146408f, 127), QuantizationInfo(0.1739964485168457f, 111) },  // pw2
+            { QuantizationInfo(0.009349990636110306f, 127), QuantizationInfo(0.1805974692106247f, 104) },  // pw3
+            { QuantizationInfo(0.012920888140797615f, 106), QuantizationInfo(0.1205204650759697f, 100) },  // pw4
+            { QuantizationInfo(0.008119508624076843f, 145), QuantizationInfo(0.12272439152002335f, 97) },  // pw5
+            { QuantizationInfo(0.0070041813887655735f, 115), QuantizationInfo(0.0947074219584465f, 101) }, // pw6
+            { QuantizationInfo(0.004827278666198254f, 115), QuantizationInfo(0.0842885747551918f, 110) },  // pw7
+            { QuantizationInfo(0.004755120258778334f, 128), QuantizationInfo(0.08283159881830215f, 116) }, // pw8
+            { QuantizationInfo(0.007527193054556847f, 142), QuantizationInfo(0.12555131316184998f, 137) }, // pw9
+            { QuantizationInfo(0.006050156895071268f, 109), QuantizationInfo(0.10871313512325287f, 124) }, // pw10
+            { QuantizationInfo(0.00490700313821435f, 127), QuantizationInfo(0.10364262014627457f, 140) },  // pw11
+            { QuantizationInfo(0.006063731852918863f, 124), QuantizationInfo(0.11241862177848816f, 125) }, // pw12
+            { QuantizationInfo(0.007901716977357864f, 139), QuantizationInfo(0.49889302253723145f, 141) }  // pw13
+        };
+
+        // Quantization info taken from the TfLite SSD MobileNet example
+        const QuantizationInfo in_quant_info = QuantizationInfo(0.0078125f, 128);
+        // Create core graph
+        graph << InputLayer(input_descriptor.set_quantization_info(in_quant_info),
+                            get_weights_accessor(data_path, common_params.image, DataLayout::NHWC));
+        graph << ConvolutionLayer(
+                  3U, 3U, 32U,
+                  get_weights_accessor(data_path, "conv0_w.npy"),
+                  get_weights_accessor(data_path, "conv0_b.npy"),
+                  PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::CEIL), 1, conv_quant_info.at(0).first, conv_quant_info.at(0).second)
+              .set_name("conv0");
+        graph << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f)).set_name("conv0/relu");
+        graph << get_node_A_qasymm(graph, data_path, "conv1", 64U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(0),
+                                   point_quant_info.at(0));
+        graph << get_node_A_qasymm(graph, data_path, "conv2", 128U, PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(1),
+                                   point_quant_info.at(1));
+        graph << get_node_A_qasymm(graph, data_path, "conv3", 128U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(2),
+                                   point_quant_info.at(2));
+        graph << get_node_A_qasymm(graph, data_path, "conv4", 256U, PadStrideInfo(2U, 2U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(3),
+                                   point_quant_info.at(3));
+        graph << get_node_A_qasymm(graph, data_path, "conv5", 256U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(4),
+                                   point_quant_info.at(4));
+        graph << get_node_A_qasymm(graph, data_path, "conv6", 512U, PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(5),
+                                   point_quant_info.at(5));
+        graph << get_node_A_qasymm(graph, data_path, "conv7", 512U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(6),
+                                   point_quant_info.at(6));
+        graph << get_node_A_qasymm(graph, data_path, "conv8", 512U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(7),
+                                   point_quant_info.at(7));
+        graph << get_node_A_qasymm(graph, data_path, "conv9", 512U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(8),
+                                   point_quant_info.at(8));
+        graph << get_node_A_qasymm(graph, data_path, "conv10", 512U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(9),
+                                   point_quant_info.at(9));
+        graph << get_node_A_qasymm(graph, data_path, "conv11", 512U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(10),
+                                   point_quant_info.at(10));
+
+        SubStream conv_13(graph);
+        conv_13 << get_node_A_qasymm(graph, data_path, "conv12", 1024U, PadStrideInfo(2U, 2U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(11),
+                                     point_quant_info.at(11));
+        conv_13 << get_node_A_qasymm(conv_13, data_path, "conv13", 1024U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), PadStrideInfo(1U, 1U, 0U, 0U), depth_quant_info.at(12),
+                                     point_quant_info.at(12));
+        SubStream conv_14(conv_13);
+        conv_14 << get_node_B_qasymm(conv_13, data_path, "conv13_2", 512U, PadStrideInfo(1U, 1U, 0U, 0U), PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::CEIL), conv_quant_info.at(1),
+                                     conv_quant_info.at(2));
+        SubStream conv_15(conv_14);
+        conv_15 << get_node_B_qasymm(conv_14, data_path, "conv13_3", 256U, PadStrideInfo(1U, 1U, 0U, 0U), PadStrideInfo(2U, 2U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), conv_quant_info.at(3),
+                                     conv_quant_info.at(4));
+        SubStream conv_16(conv_15);
+        conv_16 << get_node_B_qasymm(conv_15, data_path, "conv13_4", 256U, PadStrideInfo(1U, 1U, 0U, 0U), PadStrideInfo(2U, 2U, 1U, 1U, 1U, 1U, DimensionRoundingType::CEIL), conv_quant_info.at(5),
+                                     conv_quant_info.at(6));
+        SubStream conv_17(conv_16);
+        conv_17 << get_node_B_qasymm(conv_16, data_path, "conv13_5", 128U, PadStrideInfo(1U, 1U, 0U, 0U), PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::CEIL), conv_quant_info.at(7),
+                                     conv_quant_info.at(8));
+
+        // box_predictor
+        const std::vector<std::pair<QuantizationInfo, QuantizationInfo>> box_enc_pred_quant_info =
+        {
+            { QuantizationInfo(0.005202020984143019f, 136), QuantizationInfo(0.08655580133199692f, 183) },   // boxpredictor0_bep
+            { QuantizationInfo(0.003121797926723957f, 132), QuantizationInfo(0.03218776360154152f, 140) },   // boxpredictor1_bep
+            { QuantizationInfo(0.002995674265548587f, 130), QuantizationInfo(0.029072262346744537f, 125) },  // boxpredictor2_bep
+            { QuantizationInfo(0.0023131705820560455f, 130), QuantizationInfo(0.026488754898309708f, 127) }, // boxpredictor3_bep
+            { QuantizationInfo(0.0013905081432312727f, 132), QuantizationInfo(0.0199890099465847f, 137) },   // boxpredictor4_bep
+            { QuantizationInfo(0.00216794665902853f, 121), QuantizationInfo(0.019798893481492996f, 151) }    // boxpredictor5_bep
+        };
+
+        const std::vector<TensorShape> box_reshape = // NHWC
+        {
+            TensorShape(4U, 1U, 1083U), // boxpredictor0_bep_reshape
+            TensorShape(4U, 1U, 600U),  // boxpredictor1_bep_reshape
+            TensorShape(4U, 1U, 150U),  // boxpredictor2_bep_reshape
+            TensorShape(4U, 1U, 54U),   // boxpredictor3_bep_reshape
+            TensorShape(4U, 1U, 24U),   // boxpredictor4_bep_reshape
+            TensorShape(4U, 1U, 6U)     // boxpredictor5_bep_reshape
+        };
+
+        SubStream conv_11_box_enc_pre(graph);
+        conv_11_box_enc_pre << get_node_C_qasymm(graph, data_path, "BoxPredictor_0_BEP", 12U, PadStrideInfo(1U, 1U, 0U, 0U), box_enc_pred_quant_info.at(0), box_reshape.at(0));
+
+        SubStream conv_13_box_enc_pre(conv_13);
+        conv_13_box_enc_pre << get_node_C_qasymm(conv_13, data_path, "BoxPredictor_1_BEP", 24U, PadStrideInfo(1U, 1U, 0U, 0U), box_enc_pred_quant_info.at(1), box_reshape.at(1));
+
+        SubStream conv_14_2_box_enc_pre(conv_14);
+        conv_14_2_box_enc_pre << get_node_C_qasymm(conv_14, data_path, "BoxPredictor_2_BEP", 24U, PadStrideInfo(1U, 1U, 0U, 0U), box_enc_pred_quant_info.at(2), box_reshape.at(2));
+
+        SubStream conv_15_2_box_enc_pre(conv_15);
+        conv_15_2_box_enc_pre << get_node_C_qasymm(conv_15, data_path, "BoxPredictor_3_BEP", 24U, PadStrideInfo(1U, 1U, 0U, 0U), box_enc_pred_quant_info.at(3), box_reshape.at(3));
+
+        SubStream conv_16_2_box_enc_pre(conv_16);
+        conv_16_2_box_enc_pre << get_node_C_qasymm(conv_16, data_path, "BoxPredictor_4_BEP", 24U, PadStrideInfo(1U, 1U, 0U, 0U), box_enc_pred_quant_info.at(4), box_reshape.at(4));
+
+        SubStream conv_17_2_box_enc_pre(conv_17);
+        conv_17_2_box_enc_pre << get_node_C_qasymm(conv_17, data_path, "BoxPredictor_5_BEP", 24U, PadStrideInfo(1U, 1U, 0U, 0U), box_enc_pred_quant_info.at(5), box_reshape.at(5));
+
+        SubStream              box_enc_pre(graph);
+        const QuantizationInfo bep_concate_qinfo = QuantizationInfo(0.08655580133199692f, 183);
+        // Every branch is handed over with std::move, matching the class-prediction concat below
+        box_enc_pre << ConcatLayer(arm_compute::graph::descriptors::ConcatLayerDescriptor(DataLayoutDimension::HEIGHT, bep_concate_qinfo),
+                                   std::move(conv_11_box_enc_pre), std::move(conv_13_box_enc_pre), std::move(conv_14_2_box_enc_pre), std::move(conv_15_2_box_enc_pre),
+                                   std::move(conv_16_2_box_enc_pre), std::move(conv_17_2_box_enc_pre))
+                    .set_name("BoxPredictor/concat");
+        box_enc_pre << ReshapeLayer(TensorShape(4U, 1917U)).set_name("BoxPredictor/reshape");
+
+        // class_predictor
+        const std::vector<std::pair<QuantizationInfo, QuantizationInfo>> class_pred_quant_info =
+        {
+            { QuantizationInfo(0.002744135679677129f, 125), QuantizationInfo(0.05746262148022652f, 234) },   // boxpredictor0_cp
+            { QuantizationInfo(0.0024326108396053314f, 80), QuantizationInfo(0.03764628246426582f, 217) },   // boxpredictor1_cp
+            { QuantizationInfo(0.0013898586621508002f, 141), QuantizationInfo(0.034081317484378815f, 214) }, // boxpredictor2_cp
+            { QuantizationInfo(0.0014176908880472183f, 133), QuantizationInfo(0.033889178186655045f, 215) }, // boxpredictor3_cp
+            { QuantizationInfo(0.001090311910957098f, 125), QuantizationInfo(0.02646234817802906f, 230) },   // boxpredictor4_cp
+            { QuantizationInfo(0.001134163816459477f, 115), QuantizationInfo(0.026926767081022263f, 218) }   // boxpredictor5_cp
+        };
+
+        const std::vector<TensorShape> class_reshape =
+        {
+            TensorShape(91U, 1083U), // boxpredictor0_cp_reshape
+            TensorShape(91U, 600U),  // boxpredictor1_cp_reshape
+            TensorShape(91U, 150U),  // boxpredictor2_cp_reshape
+            TensorShape(91U, 54U),   // boxpredictor3_cp_reshape
+            TensorShape(91U, 24U),   // boxpredictor4_cp_reshape
+            TensorShape(91U, 6U)     // boxpredictor5_cp_reshape
+        };
+
+        SubStream conv_11_class_pre(graph);
+        conv_11_class_pre << get_node_C_qasymm(graph, data_path, "BoxPredictor_0_CP", 273U, PadStrideInfo(1U, 1U, 0U, 0U), class_pred_quant_info.at(0), class_reshape.at(0));
+
+        SubStream conv_13_class_pre(conv_13);
+        conv_13_class_pre << get_node_C_qasymm(conv_13, data_path, "BoxPredictor_1_CP", 546U, PadStrideInfo(1U, 1U, 0U, 0U), class_pred_quant_info.at(1), class_reshape.at(1));
+
+        SubStream conv_14_2_class_pre(conv_14);
+        conv_14_2_class_pre << get_node_C_qasymm(conv_14, data_path, "BoxPredictor_2_CP", 546U, PadStrideInfo(1U, 1U, 0U, 0U), class_pred_quant_info.at(2), class_reshape.at(2));
+
+        SubStream conv_15_2_class_pre(conv_15);
+        conv_15_2_class_pre << get_node_C_qasymm(conv_15, data_path, "BoxPredictor_3_CP", 546U, PadStrideInfo(1U, 1U, 0U, 0U), class_pred_quant_info.at(3), class_reshape.at(3));
+
+        SubStream conv_16_2_class_pre(conv_16);
+        conv_16_2_class_pre << get_node_C_qasymm(conv_16, data_path, "BoxPredictor_4_CP", 546U, PadStrideInfo(1U, 1U, 0U, 0U), class_pred_quant_info.at(4), class_reshape.at(4));
+
+        SubStream conv_17_2_class_pre(conv_17);
+        conv_17_2_class_pre << get_node_C_qasymm(conv_17, data_path, "BoxPredictor_5_CP", 546U, PadStrideInfo(1U, 1U, 0U, 0U), class_pred_quant_info.at(5), class_reshape.at(5));
+
+        const QuantizationInfo cp_concate_qinfo = QuantizationInfo(0.0584389753639698f, 230);
+        SubStream              class_pred(graph);
+        class_pred << ConcatLayer(
+                       arm_compute::graph::descriptors::ConcatLayerDescriptor(DataLayoutDimension::WIDTH, cp_concate_qinfo),
+                       std::move(conv_11_class_pre), std::move(conv_13_class_pre), std::move(conv_14_2_class_pre),
+                       std::move(conv_15_2_class_pre), std::move(conv_16_2_class_pre), std::move(conv_17_2_class_pre))
+                   .set_name("ClassPrediction/concat");
+
+        const QuantizationInfo logistic_out_qinfo = QuantizationInfo(0.00390625f, 0);
+        class_pred << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC), logistic_out_qinfo).set_name("ClassPrediction/logistic");
+
+        const int   max_detections            = 10;
+        const int   max_classes_per_detection = 1;
+        const float nms_score_threshold       = 0.30000001192092896f;
+        const float nms_iou_threshold         = 0.6000000238418579f;
+        const int   num_classes               = 90;
+        const float x_scale                   = 10.f;
+        const float y_scale                   = 10.f;
+        const float h_scale                   = 5.f;
+        const float w_scale                   = 5.f;
+        // NOTE(review): confirm the element order DetectionPostProcessLayerInfo expects for the scales.
+        // h_scale and w_scale are equal here, so swapping the last two entries would not change the values passed.
+        std::array<float, 4>   scales        = { y_scale, x_scale, w_scale, h_scale };
+        const QuantizationInfo anchors_qinfo = QuantizationInfo(0.006453060545027256f, 0);
+
+        SubStream detection_output(box_enc_pre);
+        detection_output << DetectionPostProcessLayer(std::move(class_pred),
+                                                      DetectionPostProcessLayerInfo(max_detections, max_classes_per_detection, nms_score_threshold, nms_iou_threshold, num_classes, scales),
+                                                      get_weights_accessor(data_path, "anchors.npy"), anchors_qinfo)
+                         .set_name("DetectionPostProcess");
+
+        // One output node per result of the post-process node; the trailing integer is the second OutputLayer argument (0..3)
+        SubStream output_0(detection_output);
+        output_0 << OutputLayer(get_npy_output_accessor(detection_boxes_opt->value(), TensorShape(4U, 10U), DataType::F32), 0);
+
+        SubStream output_1(detection_output);
+        output_1 << OutputLayer(get_npy_output_accessor(detection_classes_opt->value(), TensorShape(10U), DataType::F32), 1);
+
+        SubStream output_2(detection_output);
+        output_2 << OutputLayer(get_npy_output_accessor(detection_scores_opt->value(), TensorShape(10U), DataType::F32), 2);
+
+        SubStream output_3(detection_output);
+        output_3 << OutputLayer(get_npy_output_accessor(num_detections_opt->value(), TensorShape(1U), DataType::F32), 3);
+    }
};
/** Main program for MobileNetSSD
diff --git a/examples/graph_vgg16.cpp b/examples/graph_vgg16.cpp
index e8055d4..d58bf6c 100644
--- a/examples/graph_vgg16.cpp
+++ b/examples/graph_vgg16.cpp
@@ -53,6 +53,7 @@
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_vgg19.cpp b/examples/graph_vgg19.cpp
index 63051fb..82895bb 100644
--- a/examples/graph_vgg19.cpp
+++ b/examples/graph_vgg19.cpp
@@ -52,6 +52,7 @@
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_vgg_vdsr.cpp b/examples/graph_vgg_vdsr.cpp
index 9f0b357..f82ae4c 100644
--- a/examples/graph_vgg_vdsr.cpp
+++ b/examples/graph_vgg_vdsr.cpp
@@ -55,6 +55,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);
diff --git a/examples/graph_yolov3.cpp b/examples/graph_yolov3.cpp
index c0a97da..bbc6b72 100644
--- a/examples/graph_yolov3.cpp
+++ b/examples/graph_yolov3.cpp
@@ -44,6 +44,7 @@
{
// Parse arguments
cmd_parser.parse(argc, argv);
+ cmd_parser.validate();
// Consume common parameters
common_params = consume_common_graph_parameters(common_opts);