summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-06-06 11:04:12 +0200
committerJon Bratseth <bratseth@oath.com>2018-06-06 11:04:12 +0200
commit2685f35d1161ec93710a62b465b63e0cc152aba1 (patch)
tree8b1c71e37e810b0dd46799ef527656b05e6d1cc7 /vespajlib
parentf6cccb8d88a611eaefbeebc1eac5928a64fb7cb5 (diff)
Validate that tensor dimensions and labbels are identifiers
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/text/Identifier.java1
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java40
3 files changed, 53 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 1159d2fb32e..2a713611307 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -4,6 +4,7 @@ package com.yahoo.tensor;
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
+import java.util.regex.Pattern;
/**
* An immutable address to a tensor cell. This simply supplies a value to each dimension
@@ -158,6 +159,8 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
/** Supports building of a tensor address */
public static class Builder {
+ private Pattern identifierPattern = Pattern.compile("[A-Za-z0-9_]+");
+
private final TensorType type;
private final String[] labels;
@@ -176,8 +179,8 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
* @return this for convenience
*/
public Builder add(String dimension, String label) {
- Objects.requireNonNull(dimension, "Dimension cannot be null");
- Objects.requireNonNull(label, "Label cannot be null");
+ requireIdentifier(dimension, "dimension");
+ requireIdentifier(label, "label");
Optional<Integer> labelIndex = type.indexOfDimension(dimension);
if ( ! labelIndex.isPresent())
throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'");
@@ -198,6 +201,13 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return TensorAddress.of(labels);
}
+ private void requireIdentifier(String s, String parameterName) {
+ if (s == null)
+ throw new IllegalArgumentException(parameterName + " can not be null");
+ if ( ! identifierPattern.matcher(s).matches())
+ throw new IllegalArgumentException(parameterName + " must be an identifier or integer, not '" + s + "'");
+ }
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/text/Identifier.java b/vespajlib/src/main/java/com/yahoo/text/Identifier.java
index 5285c72fcc0..1746ab7b2bb 100644
--- a/vespajlib/src/main/java/com/yahoo/text/Identifier.java
+++ b/vespajlib/src/main/java/com/yahoo/text/Identifier.java
@@ -7,6 +7,7 @@ package com.yahoo.text;
* @author baldersheim
*/
public class Identifier extends Utf8Array {
+
public Identifier(String s) {
this(Utf8.toBytes(s));
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
new file mode 100644
index 00000000000..dd1b3ceb823
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
@@ -0,0 +1,40 @@
+package com.yahoo.tensor;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+public class TensorParserTestCase {
+
+ @Test
+ public void testParsing() {
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor()")).build(),
+ Tensor.from("{}"));
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")).cell(1.0, 0).build(),
+ Tensor.from("{{x:0}:1.0}"));
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")).cell().label("x", "l0").value(1.0).build(),
+ Tensor.from("{{x:l0}:1.0}"));
+ }
+
+ @Test
+ public void testIllegalStrings() {
+ assertIllegal("label must be an identifier or integer, not '\"l0\"'",
+ "{{x:\"l0\"}:1.0}");
+ assertIllegal("dimension must be an identifier or integer, not ''x''",
+ "{{'x':\"l0\"}:1.0}");
+ assertIllegal("dimension must be an identifier or integer, not '\"x\"'",
+ "{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}");
+ }
+
+ private void assertIllegal(String message, String tensor) {
+ try {
+ Tensor.from(tensor);
+ fail("Expected an IllegalArgumentException when parsing " + tensor);
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals(message, e.getMessage());
+ }
+ }
+
+}