Extend TimeZoneDistro to support non-memory origin

By switching to a Supplier<InputStream> instead of byte[]
the data does not need to be loaded into memory.

Test: run cts -m CtsLibcoreTestCases
Bug: 31008728
Change-Id: I980fc05a64276da99376b5688bc19b2bb4615c21
diff --git a/distro/core/src/main/com/android/timezone/distro/TimeZoneDistro.java b/distro/core/src/main/com/android/timezone/distro/TimeZoneDistro.java
index 8fe503d..d2c26db 100644
--- a/distro/core/src/main/com/android/timezone/distro/TimeZoneDistro.java
+++ b/distro/core/src/main/com/android/timezone/distro/TimeZoneDistro.java
@@ -22,12 +22,13 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.util.Arrays;
+import java.util.function.Supplier;
 import java.util.zip.ZipEntry;
 import java.util.zip.ZipInputStream;
 
 /**
- * A time zone distro. This is a thin wrapper around some in-memory bytes representing a zip
- * archive and logic for its safe extraction.
+ * A time zone distro. This is a thin wrapper around a supplier of bytes for a zip archive and logic
+ * for its safe extraction.
  */
 public final class TimeZoneDistro {
 
@@ -55,15 +56,27 @@
      */
     private static final long MAX_GET_ENTRY_CONTENTS_SIZE = 128 * 1024;
 
-    private final byte[] bytes;
+    private final Supplier<InputStream> inputStreamSupplier;
 
+    /**
+     * Creates a TimeZoneDistro using a byte array. Objects created in this way can be compared
+     * using {@link #equals(Object)} to compare backing arrays.
+     */
     public TimeZoneDistro(byte[] bytes) {
-        this.bytes = bytes;
+        this(new ByteStreamSupplier(bytes));
+    }
+
+    /**
+     * Creates a TimeZoneDistro using a {@link Supplier<InputStream>}. Objects created in this way
+     * can only be compared using {@link #equals(Object)} if the supplier implementation correctly
+     * implements {@link #equals(Object)}.
+     */
+    public TimeZoneDistro(Supplier<InputStream> inputStreamSupplier) {
+        this.inputStreamSupplier = inputStreamSupplier;
     }
 
     public DistroVersion getDistroVersion() throws DistroException, IOException {
-        byte[] contents = getEntryContents(
-                new ByteArrayInputStream(bytes), DISTRO_VERSION_FILE_NAME);
+        byte[] contents = getEntryContents(inputStreamSupplier.get(), DISTRO_VERSION_FILE_NAME);
         if (contents == null) {
             throw new DistroException("Distro version file entry not found");
         }
@@ -98,7 +111,7 @@
     }
 
     public void extractTo(File targetDir) throws IOException {
-        extractZipSafely(new ByteArrayInputStream(bytes), targetDir, true /* makeWorldReadable */);
+        extractZipSafely(inputStreamSupplier.get(), targetDir, true /* makeWorldReadable */);
     }
 
     /** Visible for testing */
@@ -156,10 +169,48 @@
 
         TimeZoneDistro that = (TimeZoneDistro) o;
 
-        if (!Arrays.equals(bytes, that.bytes)) {
-            return false;
+        return inputStreamSupplier.equals(that.inputStreamSupplier);
+    }
+
+    @Override
+    public int hashCode() {
+        return inputStreamSupplier.hashCode();
+    }
+
+    /**
+     * An implementation of {@link Supplier<InputStream>} wrapping a byte array that implements
+     * equals() for convenient comparison during tests.
+     */
+    private static class ByteStreamSupplier implements Supplier<InputStream> {
+
+        private final byte[] bytes;
+
+        ByteStreamSupplier(byte[] bytes) {
+            this.bytes = bytes;
         }
 
-        return true;
+        @Override
+        public InputStream get() {
+            return new ByteArrayInputStream(bytes);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+
+            ByteStreamSupplier that = (ByteStreamSupplier) o;
+
+            return Arrays.equals(bytes, that.bytes);
+        }
+
+        @Override
+        public int hashCode() {
+            return Arrays.hashCode(bytes);
+        }
     }
 }
diff --git a/distro/core/src/test/com/android/timezone/distro/TimeZoneDistroTest.java b/distro/core/src/test/com/android/timezone/distro/TimeZoneDistroTest.java
index 8f5fcc0..61f6a05 100644
--- a/distro/core/src/test/com/android/timezone/distro/TimeZoneDistroTest.java
+++ b/distro/core/src/test/com/android/timezone/distro/TimeZoneDistroTest.java
@@ -27,6 +27,7 @@
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.function.Supplier;
 import java.util.zip.ZipEntry;
 import java.util.zip.ZipOutputStream;
 import libcore.io.IoUtils;
@@ -61,6 +62,55 @@
         assertEquals(distroVersion, distro.getDistroVersion());
     }
 
+    public void testGetDistroVersion_closesStream() throws Exception {
+        DistroVersion distroVersion = new DistroVersion(DistroVersion.CURRENT_FORMAT_MAJOR_VERSION,
+                DistroVersion.CURRENT_FORMAT_MINOR_VERSION, "2016c", 1);
+        ByteArrayOutputStream baos = new ByteArrayOutputStream();
+        try (ZipOutputStream zipOutputStream = new ZipOutputStream(baos)) {
+            addZipEntry(zipOutputStream, TimeZoneDistro.DISTRO_VERSION_FILE_NAME,
+                    distroVersion.toBytes());
+        }
+        byte[] bytes = baos.toByteArray();
+
+        TestInputStreamSupplier inputStreamSupplier = new TestInputStreamSupplier(bytes);
+        TimeZoneDistro distro = new TimeZoneDistro(inputStreamSupplier);
+        assertEquals(distroVersion, distro.getDistroVersion());
+
+        inputStreamSupplier.assertStreamCount(1);
+        inputStreamSupplier.getInputStreamStream(0).assertClosed();
+    }
+
+    public void testExtractTo_closesStream() throws Exception {
+        DistroVersion distroVersion = new DistroVersion(DistroVersion.CURRENT_FORMAT_MAJOR_VERSION,
+                DistroVersion.CURRENT_FORMAT_MINOR_VERSION, "2016c", 1);
+        ByteArrayOutputStream baos = new ByteArrayOutputStream();
+        try (ZipOutputStream zipOutputStream = new ZipOutputStream(baos)) {
+            addZipEntry(zipOutputStream, TimeZoneDistro.DISTRO_VERSION_FILE_NAME,
+                    distroVersion.toBytes());
+        }
+        byte[] bytes = baos.toByteArray();
+
+        TestInputStreamSupplier inputStreamSupplier = new TestInputStreamSupplier(bytes);
+        TimeZoneDistro distro = new TimeZoneDistro(inputStreamSupplier);
+        distro.extractTo(createTempDir());
+
+        inputStreamSupplier.assertStreamCount(1);
+        inputStreamSupplier.getInputStreamStream(0).assertClosed();
+    }
+
+    public void testBytesConstructorEquals() throws Exception {
+        byte[] bytes1 = new byte[4];
+        byte[] sameAsBytes1 = new byte[4];
+        byte[] bytes2 = new byte[5];
+
+        TimeZoneDistro distro1 = new TimeZoneDistro(bytes1);
+        assertEquals(distro1, distro1);
+
+        assertEquals(new TimeZoneDistro(sameAsBytes1), distro1);
+
+        assertFalse(new TimeZoneDistro(bytes2).equals(distro1));
+    }
+
     public void testExtractZipSafely_goodZip() throws Exception {
         ByteArrayOutputStream baos = new ByteArrayOutputStream();
         try (ZipOutputStream zipOutputStream = new ZipOutputStream(baos)) {
@@ -167,4 +217,29 @@
             assertTrue(closed);
         }
     }
+
+    private static class TestInputStreamSupplier implements Supplier<InputStream> {
+
+        private List<TestInputStream> inputStreams = new ArrayList<>();
+        private final byte[] bytes;
+
+        TestInputStreamSupplier(byte[] bytes) {
+            this.bytes = bytes;
+        }
+
+        @Override
+        public InputStream get() {
+            TestInputStream is = new TestInputStream(new ByteArrayInputStream(bytes));
+            inputStreams.add(is);
+            return is;
+        }
+
+        public void assertStreamCount(int expected) {
+            assertEquals(expected, inputStreams.size());
+        }
+
+        public TestInputStream getInputStreamStream(int index) {
+            return inputStreams.get(index);
+        }
+    }
 }