diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-06-06 11:04:12 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-06-06 11:04:12 +0200 |
commit | 2685f35d1161ec93710a62b465b63e0cc152aba1 (patch) | |
tree | 8b1c71e37e810b0dd46799ef527656b05e6d1cc7 /vespajlib | |
parent | f6cccb8d88a611eaefbeebc1eac5928a64fb7cb5 (diff) |
Validate that tensor dimensions and labbels are identifiers
Diffstat (limited to 'vespajlib')
3 files changed, 53 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 1159d2fb32e..2a713611307 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -4,6 +4,7 @@ package com.yahoo.tensor; import java.util.Arrays; import java.util.Objects; import java.util.Optional; +import java.util.regex.Pattern; /** * An immutable address to a tensor cell. This simply supplies a value to each dimension @@ -158,6 +159,8 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { /** Supports building of a tensor address */ public static class Builder { + private Pattern identifierPattern = Pattern.compile("[A-Za-z0-9_]+"); + private final TensorType type; private final String[] labels; @@ -176,8 +179,8 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { * @return this for convenience */ public Builder add(String dimension, String label) { - Objects.requireNonNull(dimension, "Dimension cannot be null"); - Objects.requireNonNull(label, "Label cannot be null"); + requireIdentifier(dimension, "dimension"); + requireIdentifier(label, "label"); Optional<Integer> labelIndex = type.indexOfDimension(dimension); if ( ! labelIndex.isPresent()) throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); @@ -198,6 +201,13 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return TensorAddress.of(labels); } + private void requireIdentifier(String s, String parameterName) { + if (s == null) + throw new IllegalArgumentException(parameterName + " can not be null"); + if ( ! identifierPattern.matcher(s).matches()) + throw new IllegalArgumentException(parameterName + " must be an identifier or integer, not '" + s + "'"); + } + } } diff --git a/vespajlib/src/main/java/com/yahoo/text/Identifier.java b/vespajlib/src/main/java/com/yahoo/text/Identifier.java index 5285c72fcc0..1746ab7b2bb 100644 --- a/vespajlib/src/main/java/com/yahoo/text/Identifier.java +++ b/vespajlib/src/main/java/com/yahoo/text/Identifier.java @@ -7,6 +7,7 @@ package com.yahoo.text; * @author baldersheim */ public class Identifier extends Utf8Array { + public Identifier(String s) { this(Utf8.toBytes(s)); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java new file mode 100644 index 00000000000..dd1b3ceb823 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java @@ -0,0 +1,40 @@ +package com.yahoo.tensor; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public class TensorParserTestCase { + + @Test + public void testParsing() { + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor()")).build(), + Tensor.from("{}")); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")).cell(1.0, 0).build(), + Tensor.from("{{x:0}:1.0}")); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")).cell().label("x", "l0").value(1.0).build(), + Tensor.from("{{x:l0}:1.0}")); + } + + @Test + public void testIllegalStrings() { + assertIllegal("label must be an identifier or integer, not '\"l0\"'", + "{{x:\"l0\"}:1.0}"); + assertIllegal("dimension must be an identifier or integer, not ''x''", + "{{'x':\"l0\"}:1.0}"); + assertIllegal("dimension must be an identifier or integer, not '\"x\"'", + "{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}"); + } + + private void assertIllegal(String message, String tensor) { + try { + Tensor.from(tensor); + fail("Expected an IllegalArgumentException when parsing " + tensor); + } + catch (IllegalArgumentException e) { + assertEquals(message, e.getMessage()); + } + } + +} |