Renames AssetUtils to IOUtils and moves readFloats() there.
readFloats() logic was duplicated in MelCepLogF0 and TopK evaluators.
Bug: 113651411
Test: Covered by existing test:
com.example.android.nn.benchmark.NNTest
Change-Id: If5445319ae349af676392830109857f035728a0f
Merged-In: If5445319ae349af676392830109857f035728a0f
(cherry picked from commit bee83fac65c0348d9d8b2023d5e991c146b6ddd5)
diff --git a/src/com/android/nn/benchmark/core/EvaluatorConfig.java b/src/com/android/nn/benchmark/core/EvaluatorConfig.java
index 530b570..a32a462 100644
--- a/src/com/android/nn/benchmark/core/EvaluatorConfig.java
+++ b/src/com/android/nn/benchmark/core/EvaluatorConfig.java
@@ -17,7 +17,7 @@
package com.android.nn.benchmark.core;
import android.content.res.AssetManager;
-import com.android.nn.benchmark.util.AssetUtils;
+import com.android.nn.benchmark.util.IOUtils;
/**
* Config options for inference accuracy evaluators.
@@ -39,7 +39,7 @@
"com.android.nn.benchmark.evaluators." + className);
EvaluatorInterface evaluator = (EvaluatorInterface) clazz.getConstructor().newInstance();
if (outputMeanStdDev != null) {
- evaluator.setOutputMeanStdDev(new OutputMeanStdDev(AssetUtils.readAsset(
+ evaluator.setOutputMeanStdDev(new OutputMeanStdDev(IOUtils.readAsset(
assetManager, outputMeanStdDev, MeanStdDev.ELEMENT_SIZE_BYTES)));
}
return evaluator;
diff --git a/src/com/android/nn/benchmark/core/InferenceInOutSequence.java b/src/com/android/nn/benchmark/core/InferenceInOutSequence.java
index fc3af48..3ebd434 100644
--- a/src/com/android/nn/benchmark/core/InferenceInOutSequence.java
+++ b/src/com/android/nn/benchmark/core/InferenceInOutSequence.java
@@ -21,7 +21,7 @@
import java.io.IOException;
import android.content.res.AssetManager;
-import com.android.nn.benchmark.util.AssetUtils;
+import com.android.nn.benchmark.util.IOUtils;
import java.io.InputStream;
import java.io.InputStreamReader;
@@ -76,8 +76,8 @@
}
public InferenceInOutSequence readAssets(AssetManager assetManager) throws IOException {
- byte[] inputs = AssetUtils.readAsset(assetManager, mInputAssetName, mDataBytesSize);
- byte[] outputs = AssetUtils.readAsset(assetManager, mOutputAssetName, mDataBytesSize);
+ byte[] inputs = IOUtils.readAsset(assetManager, mInputAssetName, mDataBytesSize);
+ byte[] outputs = IOUtils.readAsset(assetManager, mOutputAssetName, mDataBytesSize);
if (inputs.length % mInputSizeBytes != 0) {
throw new IllegalArgumentException("Input data size (in bytes): " + inputs.length +
" is not a multiple of input size (in bytes): " + mInputSizeBytes);
diff --git a/src/com/android/nn/benchmark/evaluators/MelCepLogF0.java b/src/com/android/nn/benchmark/evaluators/MelCepLogF0.java
index ffbdb6f..b0e9042 100644
--- a/src/com/android/nn/benchmark/evaluators/MelCepLogF0.java
+++ b/src/com/android/nn/benchmark/evaluators/MelCepLogF0.java
@@ -19,9 +19,8 @@
import android.util.Log;
import com.android.nn.benchmark.core.*;
+import com.android.nn.benchmark.util.IOUtils;
-import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
import java.util.List;
/**
@@ -75,14 +74,16 @@
for (int i = 0; i < sequenceLength; ++i, ++inferenceIndex) {
InferenceResult result = inferenceResults.get(inferenceIndex);
System.arraycopy(
- mOutputMeanStdDev.denormalize(readBytes(result.mInferenceOutput, dataSize)),
- 0, outputs[i], 0, outputSize);
+ mOutputMeanStdDev.denormalize(IOUtils.readFloats(result.mInferenceOutput,
+ dataSize)), 0,
+ outputs[i], 0, outputSize);
InferenceInOut inOut = inferenceInOuts.get(result.mInputOutputSequenceIndex)
.get(result.mInputOutputIndex);
System.arraycopy(
- mOutputMeanStdDev.denormalize(readBytes(inOut.mExpectedOutput, dataSize)),
- 0, expectedOutputs[i], 0, outputSize);
+ mOutputMeanStdDev.denormalize(IOUtils.readFloats(inOut.mExpectedOutput,
+ dataSize)), 0,
+ expectedOutputs[i], 0, outputSize);
}
float melCepDistortion = calculateMelCepDistortion(outputs, expectedOutputs);
@@ -106,20 +107,6 @@
values.add(maxLogF0Error);
}
- private static float[] readBytes(byte[] bytes, int dataSize) {
- ByteBuffer buffer = ByteBuffer.wrap(bytes);
- buffer.order(ByteOrder.LITTLE_ENDIAN);
- int size = bytes.length / dataSize;
- float[] result = new float[size];
- for (int i = 0; i < size; ++i) {
- if (dataSize == 4) {
- result[i] = buffer.getFloat();
- }
- // TODO: Handle dataSize == 1 when adding the quantized TTS model.
- }
- return result;
- }
-
private static float calculateMelCepDistortion(float[][] outputs, float[][] expectedOutputs) {
int inferenceCount = outputs.length;
float squared_error = 0;
@@ -134,6 +121,7 @@
}
}
}
+
return (float)Math.sqrt(squared_error /
(inferenceCount * FRAMES_PER_INFERENCE * (AMPLITUDE_DIMENSION - 1)));
}
diff --git a/src/com/android/nn/benchmark/evaluators/TopK.java b/src/com/android/nn/benchmark/evaluators/TopK.java
index c49b964..845115c 100644
--- a/src/com/android/nn/benchmark/evaluators/TopK.java
+++ b/src/com/android/nn/benchmark/evaluators/TopK.java
@@ -24,6 +24,7 @@
import com.android.nn.benchmark.core.InferenceResult;
import com.android.nn.benchmark.core.OutputMeanStdDev;
import com.android.nn.benchmark.core.ValidationException;
+import com.android.nn.benchmark.util.IOUtils;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -73,17 +74,9 @@
return o2.second.compareTo(o1.second);
}
});
- ByteBuffer buf = ByteBuffer.wrap(result.mInferenceOutput);
- buf.order(ByteOrder.LITTLE_ENDIAN);
- int count = result.mInferenceOutput.length / sequence.mDatasize;
- for (int index = 0; index < count; index++) {
- float probability;
- if (sequence.mDatasize == 4) {
- probability = buf.getFloat();
- } else {
- probability = (float)(buf.get() & 0xff);
- }
- sorted.add(new Pair<Integer, Float>(new Integer(index), new Float(probability)));
+ float[] probabilities = IOUtils.readFloats(result.mInferenceOutput, sequence.mDatasize);
+ for (int index = 0; index < probabilities.length; index++) {
+ sorted.add(new Pair<>(index, probabilities[index]));
}
total++;
boolean seen = false;
diff --git a/src/com/android/nn/benchmark/util/AssetUtils.java b/src/com/android/nn/benchmark/util/IOUtils.java
similarity index 76%
rename from src/com/android/nn/benchmark/util/AssetUtils.java
rename to src/com/android/nn/benchmark/util/IOUtils.java
index fc2dc4d..f16882d 100644
--- a/src/com/android/nn/benchmark/util/AssetUtils.java
+++ b/src/com/android/nn/benchmark/util/IOUtils.java
@@ -21,13 +21,32 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
+import java.nio.ByteBuffer;
import java.nio.ByteOrder;
/**
- * Utilities for reading assets, including handling endianness adjustment.
+ * Input/Output utilities.
*/
-public class AssetUtils {
- /** Read input/output data in native byte order */
+public final class IOUtils {
+ private IOUtils() {}
+
+ /** Reads float values from a byte array. */
+ public static float[] readFloats(byte[] bytes, int dataSize) {
+ ByteBuffer buffer = ByteBuffer.wrap(bytes);
+ buffer.order(ByteOrder.LITTLE_ENDIAN);
+ int size = bytes.length / dataSize;
+ float[] result = new float[size];
+ for (int i = 0; i < size; ++i) {
+ if (dataSize == 4) {
+ result[i] = buffer.getFloat();
+ } else if (dataSize == 1) {
+ result[i] = (float)(buffer.get() & 0xff);
+ }
+ }
+ return result;
+ }
+
+ /** Reads data in native byte order */
public static byte[] readAsset(AssetManager assetManager, String assetFilename,
int dataBytesSize)
throws IOException {
@@ -56,7 +75,7 @@
}
}
- /** Reverse endianness on array of 4 byte elements */
+ /** Reverses endianness on array of 4 byte elements */
private static void invertOrder4(byte[] data) {
if (data.length % 4 != 0) {
throw new IllegalArgumentException("Data is not 4 byte aligned");
@@ -71,7 +90,7 @@
}
}
- /** Reverse endianness on array of 2 byte elements */
+ /** Reverses endianness on array of 2 byte elements */
private static void invertOrder2(byte[] data) {
if (data.length % 2 != 0) {
throw new IllegalArgumentException("Data is not 2 byte aligned");