diff options
-rw-r--r-- | container-core/src/main/java/com/yahoo/restapi/HttpURL.java | 286 | ||||
-rw-r--r-- | container-core/src/test/java/com/yahoo/restapi/HttpURLTest.java | 39 |
2 files changed, 159 insertions, 166 deletions
diff --git a/container-core/src/main/java/com/yahoo/restapi/HttpURL.java b/container-core/src/main/java/com/yahoo/restapi/HttpURL.java index a43c5998c79..7a5986ed067 100644 --- a/container-core/src/main/java/com/yahoo/restapi/HttpURL.java +++ b/container-core/src/main/java/com/yahoo/restapi/HttpURL.java @@ -14,7 +14,8 @@ import java.util.Map; import java.util.Objects; import java.util.OptionalInt; import java.util.StringJoiner; -import java.util.function.Function; +import java.util.function.Consumer; +import java.util.function.UnaryOperator; import static ai.vespa.validation.Validation.require; import static ai.vespa.validation.Validation.requireInRange; @@ -23,7 +24,6 @@ import static java.net.URLEncoder.encode; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Collections.unmodifiableMap; import static java.util.Objects.requireNonNull; -import static java.util.function.Function.identity; /** * This is the best class for creating, manipulating and inspecting HTTP URLs, because: @@ -49,15 +49,15 @@ import static java.util.function.Function.identity; * * @author jonmv */ -public class HttpURL<T> { +public class HttpURL { private final Scheme scheme; private final DomainName domain; private final int port; - private final Path<T> path; - private final Query<T> query; + private final Path path; + private final Query query; - private HttpURL(Scheme scheme, DomainName domain, int port, Path<T> path, Query<T> query) { + private HttpURL(Scheme scheme, DomainName domain, int port, Path path, Query query) { this.scheme = requireNonNull(scheme); this.domain = requireNonNull(domain); this.port = requireInRange(port, "port number", -1, (1 << 16) - 1); @@ -65,74 +65,58 @@ public class HttpURL<T> { this.query = requireNonNull(query); } - public static <T> HttpURL<T> create(Scheme scheme, DomainName domain, int port, Path<T> path, Query<T> query) { - return new HttpURL<>(scheme, domain, port, path, query); + public static HttpURL create(Scheme scheme, DomainName domain, int port, Path path, Query query) { + return new HttpURL(scheme, domain, port, path, query); } - public static HttpURL<String> create(Scheme scheme, DomainName domain, int port, Path<String> path) { + public static HttpURL create(Scheme scheme, DomainName domain, int port, Path path) { return create(scheme, domain, port, path, Query.empty()); } - public static <T extends StringWrapper<T>> HttpURL<T> create(Scheme scheme, DomainName domain, int port, Path<T> path, Function<String, T> validator) { - return create(scheme, domain, port, path, Query.empty(validator)); + public static HttpURL create(Scheme scheme, DomainName domain, int port) { + return create(scheme, domain, port, Path.empty(), Query.empty()); } - public static <T extends StringWrapper<T>> HttpURL<T> create(Scheme scheme, DomainName domain, int port, Function<String, T> validator) { - return create(scheme, domain, port, Path.empty(validator), validator); - } - - public static HttpURL<String> create(Scheme scheme, DomainName domain, int port) { - return create(scheme, domain, port, Path.empty()); - } - - public static <T extends StringWrapper<T>> HttpURL<T> create(Scheme scheme, DomainName domain, Function<String, T> validator) { - return create(scheme, domain, -1, validator); - } - - public static HttpURL<String> create(Scheme scheme, DomainName domain) { + public static HttpURL create(Scheme scheme, DomainName domain) { return create(scheme, domain, -1); } - public static HttpURL<String> from(URI uri) { - return from(uri, identity(), identity()); - } - - public static <T extends StringWrapper<T>> HttpURL<T> from(URI uri, Function<String, T> validator) { - return from(uri, validator, T::value); + public static HttpURL from(URI uri) { + return from(uri, HttpURL::requirePathSegment, HttpURL::requireNothing); } - private static <T> HttpURL<T> from(URI uri, Function<String, T> validator, Function<T, String> inverse) { + public static HttpURL from(URI uri, Consumer<String> pathValidator, Consumer<String> queryValidator) { if ( ! uri.normalize().equals(uri)) throw new IllegalArgumentException("uri should be normalized, but got: " + uri); return create(Scheme.of(uri.getScheme()), DomainName.of(requireNonNull(uri.getHost(), "URI must specify a host")), uri.getPort(), - Path.parse(uri.getRawPath(), validator, inverse), - Query.parse(uri.getRawQuery(), validator, inverse)); + Path.parse(uri.getRawPath(), pathValidator), + Query.parse(uri.getRawQuery(), queryValidator)); } - public HttpURL<T> withScheme(Scheme scheme) { + public HttpURL withScheme(Scheme scheme) { return create(scheme, domain, port, path, query); } - public HttpURL<T> withDomain(DomainName domain) { + public HttpURL withDomain(DomainName domain) { return create(scheme, domain, port, path, query); } - public HttpURL<T> withPort(int port) { + public HttpURL withPort(int port) { return create(scheme, domain, port, path, query); } - public HttpURL<T> withoutPort() { + public HttpURL withoutPort() { return create(scheme, domain, -1, path, query); } - public HttpURL<T> withPath(Path<T> path) { + public HttpURL withPath(Path path) { return create(scheme, domain, port, path, query); } - public HttpURL<T> withQuery(Query<T> query) { + public HttpURL withQuery(Query query) { return create(scheme, domain, port, path, query); } @@ -148,11 +132,11 @@ public class HttpURL<T> { return port == -1 ? OptionalInt.empty() : OptionalInt.of(port); } - public Path<T> path() { + public Path path() { return path; } - public Query<T> query() { + public Query query() { return query; } @@ -166,59 +150,68 @@ public class HttpURL<T> { } } + /** Require that the given string contains no {@code '/'}, or anything that could be URL-decoded to one. */ + public static void requirePathSegment(String value) { + require( ! value.contains("/") && ! value.matches(".*%(25)*(%(25)*32|2)(%(25)*([46])6|F|f).*"), value, "path segment cannot contain '/'"); + } + + private static void requireNothing(String value) { } - public static class Path<T> { + public static class Path { - private final List<T> segments; + private final List<String> segments; private final boolean trailingSlash; - private final Function<String, T> validator; - private final Function<T, String> inverse; + private final UnaryOperator<String> validator; - private Path(List<T> segments, boolean trailingSlash, Function<String, T> validator, Function<T, String> inverse) { + private Path(List<String> segments, boolean trailingSlash, UnaryOperator<String> validator) { this.segments = requireNonNull(segments); this.trailingSlash = trailingSlash; this.validator = requireNonNull(validator); - this.inverse = requireNonNull(inverse); } /** Creates a new, empty path, with a trailing slash. */ - public static Path<String> empty() { - return new Path<>(List.of(), true, identity(), identity()); + public static Path empty() { + return empty(__ -> { }); } - /** Creates a new, empty path, with a trailing slash, using the indicated string wrapper for segments. */ - public static <T extends StringWrapper<T>> Path<T> empty(Function<String, T> validator) { - return new Path<>(List.of(), true, validator, T::value); + /** Creates a new, empty path, with a trailing slash, using the indicated validator for segments. */ + public static Path empty(Consumer<String> validator) { + return new Path(List.of(), true, segmentValidator(validator)); } + /** Creates a new path with the given <em>decoded</em> segments. */ - public static Path<String> from(List<String> segments) { - return empty().append(segments); + public static Path from(List<String> segments) { + return from(segments, __ -> { }); } /** Creates a new path with the given <em>decoded</em> segments, and the validator applied to each segment. */ - public static <T extends StringWrapper<T>> Path<T> from(List<String> segments, Function<String, T> validator) { - return empty(validator).append(segments, identity(), true); - } - - /** Parses the given raw, normalized path string; this ignores whether the path is absolute or relative.) */ - public static <T extends StringWrapper<T>> Path<T> parse(String raw, Function<String, T> validator) { - return parse(raw, validator, T::value); + public static Path from(List<String> segments, Consumer<String> validator) { + return empty(validator).append(segments, true); } /** Parses the given raw, normalized path string; this ignores whether the path is absolute or relative. */ - public static Path<String> parse(String raw) { - return parse(raw, identity(), identity()); + public static Path parse(String raw) { + return parse(raw, HttpURL::requirePathSegment); } - private static <T> Path<T> parse(String raw, Function<String, T> validator, Function<T, String> inverse) { - boolean trailingSlash = raw.endsWith("/"); + /** Parses the given raw, normalized path string; this ignores whether the path is absolute or relative.) */ + public static Path parse(String raw, Consumer<String> validator) { + Path base = new Path(List.of(), raw.endsWith("/"), segmentValidator(validator)); if (raw.startsWith("/")) raw = raw.substring(1); - if (raw.isEmpty()) return new Path<>(List.of(), trailingSlash, validator, inverse); - List<T> segments = new ArrayList<>(); - for (String segment : raw.split("/")) - segments.add(validator.apply(requireNonNormalizable(decode(segment, UTF_8)))); - if (segments.size() == 0) requireNonNormalizable(""); // Raw path was only slashes. - return new Path<>(segments, trailingSlash, validator, inverse); + if (raw.isEmpty()) return base; + List<String> segments = new ArrayList<>(); + for (String segment : raw.split("/")) segments.add(decode(segment, UTF_8)); + if (segments.isEmpty()) requireNonNormalizable(""); // Raw path was only slashes. + return base.append(segments); + } + + private static UnaryOperator<String> segmentValidator(Consumer<String> validator) { + requireNonNull(validator, "segment validator cannot be null"); + return value -> { + requireNonNormalizable(value); + validator.accept(value); + return value; + }; } private static String requireNonNormalizable(String segment) { @@ -227,56 +220,55 @@ public class HttpURL<T> { } /** Returns a copy of this where the first segments are skipped. */ - public Path<T> skip(int count) { - return new Path<>(segments.subList(count, segments.size()), trailingSlash, validator, inverse); + public Path skip(int count) { + return new Path(segments.subList(count, segments.size()), trailingSlash, validator); } /** Returns a copy of this where the last segments are cut off. */ - public Path<T> cut(int count) { - return new Path<>(segments.subList(0, segments.size() - count), trailingSlash, validator, inverse); + public Path cut(int count) { + return new Path(segments.subList(0, segments.size() - count), trailingSlash, validator); } /** Returns a copy of this with the <em>decoded</em> segment appended at the end; it may not be either of {@code ""}, {@code "."} or {@code ".."}. */ - public Path<T> append(String segment) { - return append(List.of(segment), identity(), trailingSlash); + public Path append(String segment) { + return append(List.of(segment), trailingSlash); } /** Returns a copy of this all segments of the other path appended, with a trailing slash as per the appendage. */ - public <U> Path<T> append(Path<U> other) { - return append(other.segments, other.inverse, other.trailingSlash); + public Path append(Path other) { + return append(other.segments, other.trailingSlash); } /** Returns a copy of this all given segments appended, with a trailing slash as per this path. */ - public Path<T> append(List<T> segments) { - return append(segments, inverse, trailingSlash); + public Path append(List<String> segments) { + return append(segments, trailingSlash); } - private <U> Path<T> append(List<U> segments, Function<U, String> inverse, boolean trailingSlash) { - List<T> copy = new ArrayList<>(this.segments); - for (U segment : segments) copy.add(validator.apply(requireNonNormalizable(inverse.apply(segment)))); - return new Path<>(copy, trailingSlash, validator, this.inverse); + private Path append(List<String> segments, boolean trailingSlash) { + List<String> copy = new ArrayList<>(this.segments); + for (String segment : segments) copy.add(validator.apply(segment)); + return new Path(copy, trailingSlash, validator); } /** Returns a copy of this which encodes a trailing slash. */ - public Path<T> withTrailingSlash() { - return new Path<>(segments, true, validator, inverse); + public Path withTrailingSlash() { + return new Path(segments, true, validator); } /** Returns a copy of this which does not encode a trailing slash. */ - public Path<T> withoutTrailingSlash() { - return new Path<>(segments, false, validator, inverse); + public Path withoutTrailingSlash() { + return new Path(segments, false, validator); } /** The <em>URL decoded</em> segments that make up this path; never {@code null}, {@code ""}, {@code "."} or {@code ".."}. */ - public List<T> segments() { + public List<String> segments() { return Collections.unmodifiableList(segments); } /** A raw path string which parses to this, by splitting on {@code "/"}, and then URL decoding. */ - private String raw() { + String raw() { StringJoiner joiner = new StringJoiner("/", "/", trailingSlash ? "/" : "").setEmptyValue(trailingSlash ? "/" : ""); - for (T segment : segments) - joiner.add(encode(inverse.apply(segment), UTF_8)); + for (String segment : segments) joiner.add(encode(segment, UTF_8)); return joiner.toString(); } @@ -290,7 +282,7 @@ public class HttpURL<T> { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - Path<?> path = (Path<?>) o; + Path path = (Path) o; return trailingSlash == path.trailingSlash && segments.equals(path.segments); } @@ -302,108 +294,106 @@ public class HttpURL<T> { } - public static class Query<T> { + public static class Query { - private final Map<T, T> values; - private final Function<String, T> validator; - private final Function<T, String> inverse; + private final Map<String, String> values; + private final UnaryOperator<String> validator; - private Query(Map<T, T> values, Function<String, T> validator, Function<T, String> inverse) { + private Query(Map<String, String> values, UnaryOperator<String> validator) { this.values = requireNonNull(values); this.validator = requireNonNull(validator); - this.inverse = requireNonNull(inverse); } /** Creates a new, empty query part. */ - public static Query<String> empty() { - return new Query<>(Map.of(), identity(), identity()); + public static Query empty() { + return empty(__ -> { }); } /** Creates a new, empty query part, using the indicated string wrapper for keys and non-null values. */ - public static <T extends StringWrapper<T>> Query<T> empty(Function<String, T> validator) { - return new Query<>(Map.of(), validator, T::value); + public static Query empty(Consumer<String> validator) { + return new Query(Map.of(), entryValidator(validator)); } + /** Creates a new query part with the given <em>decoded</em> values. */ - public static Query<String> from(Map<String, String> values) { - return empty().merge(values); + public static Query from(Map<String, String> values) { + return from(values, __ -> { }); } /** Creates a new query part with the given <em>decoded</em> values, and the validator applied to each pair. */ - public static <T extends StringWrapper<T>> Query<T> from(Map<String, String> values, Function<String, T> validator) { - return empty(validator).merge(values, identity()); - } - - /** Parses the given raw query string, using the indicated string wrapper to hold keys and non-null values. */ - public static <T extends StringWrapper<T>> Query<T> parse(String raw, Function<String, T> validator) { - return parse(raw, validator, T::value); + public static Query from(Map<String, String> values, Consumer<String> validator) { + return empty(validator).merge(values); } /** Parses the given raw query string. */ - public static Query<String> parse(String raw) { - return parse(raw, identity(), identity()); + public static Query parse(String raw) { + return parse(raw, __-> { }); } - - private static <T> Query<T> parse(String raw, Function<String, T> validator, Function<T, String> inverse) { - if (raw == null) return new Query<>(Map.of(), validator, inverse); - Map<T, T> values = new LinkedHashMap<>(); + /** Parses the given raw query string, using the indicated string wrapper to hold keys and non-null values. */ + public static Query parse(String raw, Consumer<String> validator) { + if (raw == null) return empty(validator); + Map<String, String> values = new LinkedHashMap<>(); for (String pair : raw.split("&")) { int split = pair.indexOf("="); String key, value; if (split == -1) { key = pair; value = null; } else { key = pair.substring(0, split); value = pair.substring(split + 1); } - values.put(validator.apply(decode(key, UTF_8)), value == null ? null : validator.apply(decode(value, UTF_8))); + values.put(decode(key, UTF_8), value == null ? null : decode(value, UTF_8)); } - return new Query<>(values, validator, inverse); + return empty(validator).merge(values); + } + + private static UnaryOperator<String> entryValidator(Consumer<String> validator) { + requireNonNull(validator); + return value -> { + validator.accept(value); + return value; + }; } /** Returns a copy of this with the <em>decoded</em> non-null key pointing to the <em>decoded</em> non-null value. */ - public Query<T> put(String key, String value) { - Map<T, T> copy = new LinkedHashMap<>(values); - copy.put(requireNonNull(validator.apply(key)), requireNonNull(validator.apply(value))); - return new Query<>(copy, validator, inverse); + public Query put(String key, String value) { + Map<String, String> copy = new LinkedHashMap<>(values); + copy.put(validator.apply(requireNonNull(key)), validator.apply(requireNonNull(value))); + return new Query(copy, validator); } /** Returns a copy of this with the <em>decoded</em> non-null key pointing to "nothing". */ - public Query<T> add(String key) { - Map<T, T> copy = new LinkedHashMap<>(values); - copy.put(requireNonNull(validator.apply(key)), null); - return new Query<>(copy, validator, inverse); + public Query add(String key) { + Map<String, String> copy = new LinkedHashMap<>(values); + copy.put(validator.apply(requireNonNull(key)), null); + return new Query(copy, validator); } /** Returns a copy of this without any key-value pair with the <em>decoded</em> key. */ - public Query<T> remove(String key) { - Map<T, T> copy = new LinkedHashMap<>(values); - copy.remove(requireNonNull(validator.apply(key))); - return new Query<>(copy, validator, inverse); + public Query remove(String key) { + Map<String, String> copy = new LinkedHashMap<>(values); + copy.remove(validator.apply(requireNonNull(key))); + return new Query(copy, validator); } /** Returns a copy of this with all mappings from the other query added to this, possibly overwriting existing mappings. */ - public <U> Query<T> merge(Query<U> other) { - return merge(other.values, other.inverse); + public Query merge(Query other) { + return merge(other.values); } /** Returns a copy of this with all given mappings added to this, possibly overwriting existing mappings. */ - public Query<T> merge(Map<T, T> values) { - return merge(values, inverse); - } - - private <U> Query<T> merge(Map<U, U> values, Function<U, String> inverse) { - Map<T, T> copy = new LinkedHashMap<>(this.values); - values.forEach((key, value) -> copy.put(validator.apply(inverse.apply(requireNonNull(key, "keys cannot be null"))), - value == null ? null : validator.apply(inverse.apply(value)))); - return new Query<>(copy, validator, this.inverse); + public Query merge(Map<String, String> values) { + Map<String, String> copy = new LinkedHashMap<>(this.values); + values.forEach((key, value) -> copy.put(validator.apply(requireNonNull(key, "keys cannot be null")), + value == null ? null : validator.apply(value))); + return new Query(copy, validator); } /** The <em>URL decoded</em> key-value pairs that make up this query; keys and values may be {@code ""}, and values are {@code null} when only key was specified. */ - public Map<T, T> entries() { + public Map<String, String> entries() { return unmodifiableMap(values); } /** A raw query string, with {@code '?'} prepended, that parses to this, by splitting on {@code "&"}, then on {@code "="}, and then URL decoding; or the empty string if this is empty. */ private String raw() { StringJoiner joiner = new StringJoiner("&", "?", "").setEmptyValue(""); - values.forEach((key, value) -> joiner.add(encode(inverse.apply(key), UTF_8) + - (value == null ? "" : "=" + encode(inverse.apply(value), UTF_8)))); + values.forEach((key, value) -> joiner.add(encode(key, UTF_8) + + (value == null ? "" : "=" + encode(value, UTF_8)))); return joiner.toString(); } @@ -417,7 +407,7 @@ public class HttpURL<T> { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - Query<?> query = (Query<?>) o; + Query query = (Query) o; return values.equals(query.values); } diff --git a/container-core/src/test/java/com/yahoo/restapi/HttpURLTest.java b/container-core/src/test/java/com/yahoo/restapi/HttpURLTest.java index de20fcb3193..4354f5ee3ea 100644 --- a/container-core/src/test/java/com/yahoo/restapi/HttpURLTest.java +++ b/container-core/src/test/java/com/yahoo/restapi/HttpURLTest.java @@ -12,6 +12,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.OptionalInt; +import java.util.function.Consumer; import static com.yahoo.net.DomainName.localhost; import static com.yahoo.restapi.HttpURL.Scheme.http; @@ -33,14 +34,16 @@ class HttpURLTest { "https://strange/queries?=&foo", "https://weirdness?=foo", "https://encoded/%3F%3D%26%2F?%3F%3D%26%2F=%3F%3D%26%2F", - "https://host.at.domain:123/one/two/?three=four&five")) - assertEquals(uri, HttpURL.from(URI.create(uri)).asURI().toString(), + "https://host.at.domain:123/one/two/?three=four&five")) { + Consumer<String> pathValidator = __ -> { }; + assertEquals(uri, HttpURL.from(URI.create(uri), pathValidator, pathValidator).asURI().toString(), "uri '" + uri + "' should be returned unchanged"); + } } @Test void testModification() { - HttpURL<Name> url = HttpURL.create(http, localhost, Name::of); + HttpURL url = HttpURL.create(http, localhost).withPath(HttpURL.Path.empty(Name::of)); assertEquals(http, url.scheme()); assertEquals(localhost, url.domain()); assertEquals(OptionalInt.empty(), url.port()); @@ -91,21 +94,21 @@ class HttpURLTest { assertEquals("name must match '[A-Za-z][A-Za-z0-9_-]{0,63}', but got: '/'", assertThrows(IllegalArgumentException.class, - () -> HttpURL.from(URI.create("http://foo/%2F"), Name::of)).getMessage()); + () -> HttpURL.from(URI.create("http://foo/%2F"), Name::of, Name::of)).getMessage()); assertEquals("name must match '[A-Za-z][A-Za-z0-9_-]{0,63}', but got: '/'", assertThrows(IllegalArgumentException.class, - () -> HttpURL.from(URI.create("http://foo?%2F"), Name::of)).getMessage()); + () -> HttpURL.from(URI.create("http://foo?%2F"), Name::of, Name::of)).getMessage()); assertEquals("name must match '[A-Za-z][A-Za-z0-9_-]{0,63}', but got: ''", assertThrows(IllegalArgumentException.class, - () -> HttpURL.from(URI.create("http://foo?"), Name::of)).getMessage()); + () -> HttpURL.from(URI.create("http://foo?"), Name::of, Name::of)).getMessage()); } @Test void testPath() { - HttpURL.Path<Name> path = HttpURL.Path.parse("foo/bar/baz", Name::of); - List<Name> expected = List.of(Name.of("foo"), Name.of("bar"), Name.of("baz")); + HttpURL.Path path = HttpURL.Path.parse("foo/bar/baz", Name::of); + List<String> expected = List.of("foo", "bar", "baz"); assertEquals(expected, path.segments()); assertEquals(expected.subList(1, 3), path.skip(1).segments()); @@ -124,7 +127,7 @@ class HttpURLTest { assertThrows(NullPointerException.class, () -> path.append((String) null)); - List<Name> names = new ArrayList<>(); + List<String> names = new ArrayList<>(); names.add(null); assertThrows(NullPointerException.class, () -> path.append(names)); @@ -140,17 +143,17 @@ class HttpURLTest { @Test void testQuery() { - Query<Name> query = Query.parse("foo=bar&baz", Name::of); - Map<Name, Name> expected = new LinkedHashMap<>(); - expected.put(Name.of("foo"), Name.of("bar")); - expected.put(Name.of("baz"), null); + Query query = Query.parse("foo=bar&baz", Name::of); + Map<String, String> expected = new LinkedHashMap<>(); + expected.put("foo", "bar"); + expected.put("baz", null); assertEquals(expected, query.entries()); - expected.remove(Name.of("baz")); + expected.remove("baz"); assertEquals(expected, query.remove("baz").entries()); - expected.put(Name.of("baz"), null); - expected.remove(Name.of("foo")); + expected.put("baz", null); + expected.remove("foo"); assertEquals(expected, query.remove("foo").entries()); assertEquals(expected, Query.empty(Name::of).add("baz").entries()); @@ -169,8 +172,8 @@ class HttpURLTest { assertThrows(NullPointerException.class, () -> query.put("hax", null)); - Map<Name, Name> names = new LinkedHashMap<>(); - names.put(null, Name.of("hax")); + Map<String, String> names = new LinkedHashMap<>(); + names.put(null, "hax"); assertThrows(NullPointerException.class, () -> query.merge(names)); } |