diff options
author | jonmv <venstad@gmail.com> | 2022-11-07 16:18:50 +0100 |
---|---|---|
committer | jonmv <venstad@gmail.com> | 2022-11-07 16:18:50 +0100 |
commit | 36b75df47a8df37e249ecde40c64bd636ea1e455 (patch) | |
tree | 8e6bd29735851696ca94a71e0a5724e0ca7309dd | |
parent | 00c51fa7f8fcc00a79fb22f0dc00d3c597aee6d1 (diff) |
Explicitly consume underlying zip content
3 files changed, 67 insertions, 48 deletions
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java index 7e733e1b74e..021064417ac 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java @@ -87,7 +87,7 @@ public class ApplicationPackageStream { * and the first to be exhausted will populate the truncated application package. */ public InputStream zipStream() { - return new Stream(new ZipInputStream(in.get()), replacer.get(), filter.get(), truncatedPackage); + return new Stream(in.get(), replacer.get(), filter.get(), truncatedPackage); } /** @@ -108,6 +108,7 @@ public class ApplicationPackageStream { private final ByteArrayOutputStream out = new ByteArrayOutputStream(1 << 16); private final ZipOutputStream outZip = new ZipOutputStream(out); private final AtomicReference<ApplicationPackage> truncatedPackage; + private final InputStream in; private final ZipInputStream inZip; private final Replacer replacer; private final Predicate<String> filter; @@ -118,8 +119,9 @@ public class ApplicationPackageStream { private boolean closed = false; private boolean done = false; - private Stream(ZipInputStream inZip, Replacer replacer, Predicate<String> filter, AtomicReference<ApplicationPackage> truncatedPackage) { - this.inZip = inZip; + private Stream(InputStream in, Replacer replacer, Predicate<String> filter, AtomicReference<ApplicationPackage> truncatedPackage) { + this.in = in; + this.inZip = new ZipInputStream(in); this.replacer = replacer; this.filter = filter; this.truncatedPackage = truncatedPackage; @@ -215,7 +217,8 @@ public class ApplicationPackageStream { @Override public void close() { if ( ! closed) try { - transferTo(nullOutputStream()); + transferTo(nullOutputStream()); // Finish reading the zip, to populate the truncated package in case of errors. + in.transferTo(nullOutputStream()); // For some inane reason, ZipInputStream doesn't exhaust its wrapped input. inZip.close(); closed = true; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ZipEntries.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ZipEntries.java index 63915c5050f..185c97f866e 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ZipEntries.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ZipEntries.java @@ -36,36 +36,6 @@ public class ZipEntries { this.entries = List.copyOf(Objects.requireNonNull(entries)); } - /** Copies the zipped content from in to out, adding/overwriting an entry with the given name and content. */ - public static void transferAndWrite(OutputStream out, InputStream in, String name, byte[] content) { - transferAndWrite(out, in, Map.of(name, content)); - } - - /** Copies the zipped content from in to out, adding/overwriting/removing (on {@code null}) entries as specified. */ - public static void transferAndWrite(OutputStream out, InputStream in, Map<String, byte[]> entries) { - try (ZipOutputStream zipOut = new ZipOutputStream(out); - ZipInputStream zipIn = new ZipInputStream(in)) { - for (ZipEntry entry = zipIn.getNextEntry(); entry != null; entry = zipIn.getNextEntry()) { - if (entries.containsKey(entry.getName())) - continue; - - zipOut.putNextEntry(new ZipEntry(entry.getName())); - zipIn.transferTo(zipOut); - zipOut.closeEntry(); - } - for (Entry<String, byte[]> entry : entries.entrySet()) { - if (entry.getValue() != null) { - zipOut.putNextEntry(new ZipEntry(entry.getKey())); - zipOut.write(entry.getValue()); - zipOut.closeEntry(); - } - } - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - /** Read ZIP entries from inputStream */ public static ZipEntries from(byte[] zip, Predicate<String> entryNameMatcher, int maxEntrySizeInBytes, boolean throwIfEntryExceedsMaxSize) { diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageTest.java index ab8f696492a..8ac8b87ac45 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageTest.java @@ -15,6 +15,8 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.io.OutputStream; +import java.io.PrintStream; import java.io.SequenceInputStream; import java.math.BigInteger; import java.nio.file.Files; @@ -22,7 +24,10 @@ import java.nio.file.Path; import java.security.KeyPair; import java.security.cert.X509Certificate; import java.time.Instant; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -36,6 +41,7 @@ import static com.yahoo.vespa.hosted.controller.application.pkg.ApplicationPacka import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -97,10 +103,10 @@ public class ApplicationPackageTest { @Test void testMetaData() { byte[] zip = filesZip(Map.of("services.xml", servicesXml.getBytes(UTF_8), - "jdisc.xml", jdiscXml.getBytes(UTF_8), - "content/content.xml", contentXml.getBytes(UTF_8), - "content/nodes.xml", nodesXml.getBytes(UTF_8), - "gurba", "gurba".getBytes(UTF_8))); + "jdisc.xml", jdiscXml.getBytes(UTF_8), + "content/content.xml", contentXml.getBytes(UTF_8), + "content/nodes.xml", nodesXml.getBytes(UTF_8), + "gurba", "gurba".getBytes(UTF_8))); assertEquals(Map.of("services.xml", servicesXml, "jdisc.xml", jdiscXml, @@ -201,6 +207,36 @@ public class ApplicationPackageTest { entry -> entry.getValue().getBytes(UTF_8)))); } + private static class AngryStreams { + + private final byte[] content; + private final Map<ByteArrayInputStream, Throwable> streams = new LinkedHashMap<>(); + + AngryStreams(byte[] content) { + this.content = content; + } + + InputStream stream() { + ByteArrayInputStream stream = new ByteArrayInputStream(Arrays.copyOf(content, content.length)) { + boolean closed = false; + @Override public void close() { closed = true; } + @Override public int read() { assertFalse(closed); return super.read(); } + @Override public int read(byte[] b, int off, int len) { assertFalse(closed); return super.read(b, off, len); } + @Override public long transferTo(OutputStream out) throws IOException { assertFalse(closed); return super.transferTo(out); } + @Override public byte[] readAllBytes() { assertFalse(closed); return super.readAllBytes(); } + }; + streams.put(stream, new Throwable()); + return stream; + } + + void verifyAllRead() { + streams.forEach((stream, stack) -> assertEquals(0, stream.available(), + "unconsumed content in stream created at " + + new ByteArrayOutputStream() {{ stack.printStackTrace(new PrintStream(this)); }})); + } + + } + @Test void testApplicationPackageStream() throws Exception { Map<String, String> content = Map.of("deployment.xml", deploymentXml, @@ -212,15 +248,16 @@ public class ApplicationPackageTest { "gurba", "gurba"); byte[] zip = zip(content); assertEquals(content, unzip(zip)); + AngryStreams angry = new AngryStreams(zip); - ApplicationPackageStream identity = new ApplicationPackageStream(() -> new ByteArrayInputStream(zip)); + ApplicationPackageStream identity = new ApplicationPackageStream(angry::stream); InputStream lazy = new LazyInputStream(() -> new ByteArrayInputStream(identity.truncatedPackage().zippedContent())); assertEquals("must completely exhaust input before reading package", assertThrows(IllegalStateException.class, identity::truncatedPackage).getMessage()); // Verify no content has changed when passing through the stream. ByteArrayOutputStream out = new ByteArrayOutputStream(); - identity.zipStream().transferTo(out); + try (InputStream stream = identity.zipStream()) { stream.transferTo(out); } assertEquals(content, unzip(out.toByteArray())); assertEquals(content, unzip(identity.truncatedPackage().zippedContent())); assertEquals(content, unzip(lazy.readAllBytes())); @@ -233,13 +270,13 @@ public class ApplicationPackageTest { "unused1.xml", in -> null, "unused2.xml", __ -> new ByteArrayInputStream(jdiscXml.getBytes(UTF_8))); Predicate<String> truncation = name -> name.endsWith(".xml"); - ApplicationPackageStream modifier = new ApplicationPackageStream(() -> new ByteArrayInputStream(Arrays.copyOf(zip, zip.length)), () -> truncation, replacements); + ApplicationPackageStream modifier = new ApplicationPackageStream(angry::stream, () -> truncation, replacements); out.reset(); InputStream partiallyRead = modifier.zipStream(); assertEquals(15, partiallyRead.readNBytes(15).length); - modifier.zipStream().transferTo(out); + try (InputStream stream = modifier.zipStream()) { stream.transferTo(out); } assertEquals(Map.of("deployment.xml", deploymentXml + "\n\n", "services.xml", servicesXml, @@ -268,18 +305,27 @@ public class ApplicationPackageTest { "gurba", "gurba"))).metaDataZip()), unzip(modifier.truncatedPackage().metaDataZip())); - assertArrayEquals(modifier.zipStream().readAllBytes(), - modifier.zipStream().readAllBytes()); + try (InputStream stream1 = modifier.zipStream(); + InputStream stream2 = modifier.zipStream()) { + assertArrayEquals(stream1.readAllBytes(), + stream2.readAllBytes()); + } ByteArrayOutputStream byteAtATime = new ByteArrayOutputStream(); - try (InputStream stream = modifier.zipStream()) { - for (int b; (b = stream.read()) != -1; ) byteAtATime.write(b); - assertArrayEquals(modifier.zipStream().readAllBytes(), + try (InputStream stream1 = modifier.zipStream(); + InputStream stream2 = modifier.zipStream()) { + for (int b; (b = stream1.read()) != -1; ) byteAtATime.write(b); + assertArrayEquals(stream2.readAllBytes(), byteAtATime.toByteArray()); } - assertEquals(modifier.zipStream().readAllBytes().length, + assertEquals(byteAtATime.size(), 15 + partiallyRead.readAllBytes().length); + partiallyRead.close(); + + try (InputStream stream = modifier.zipStream()) { stream.readNBytes(12); } + + angry.verifyAllRead(); } } |