summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-12-17 11:42:21 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2019-12-17 11:42:21 +0100
commit3a84c90423e86bb95c9a620c1c9ccc1a055b2d37 (patch)
treebd7ffc5edc8b4f7216a2403e86001755efaf953f /vespajlib
parentcbcd468e0f19421876a52ba1bd74f33fda73b855 (diff)
Allow quoted labels in tensors
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java28
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java180
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java17
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java24
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java3
6 files changed, 156 insertions, 99 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index b8ef84cabb7..cffd41905a1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -18,6 +18,7 @@ import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.XwPlusB;
+import com.yahoo.text.Ascii7BitMatcher;
import java.util.ArrayList;
import java.util.Arrays;
@@ -31,6 +32,8 @@ import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
+import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers;
+
/**
* A multidimensional array which can be used in computations.
* <p>
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 4770ad1b1f0..a3805fb789a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -1,14 +1,11 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
-import com.yahoo.text.Ascii7BitMatcher;
-
import java.util.Arrays;
+import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
-import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers;
-
/**
* An immutable address to a tensor cell. This simply supplies a value to each dimension
* in a particular tensor type. By itself it is just a list of cell labels; its meaning depends on its accompanying type.
@@ -85,7 +82,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
public final String toString(TensorType type) {
StringBuilder b = new StringBuilder("{");
for (int i = 0; i < size(); i++) {
- b.append(type.dimensions().get(i).name()).append(":").append(label(i));
+ b.append(type.dimensions().get(i).name()).append(":").append(labelToString(label(i)));
b.append(",");
}
if (b.length() > 1)
@@ -94,6 +91,12 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return b.toString();
}
+ /**
+ * Renders a label for the textual tensor form, quoting it only when needed.
+ * A label accepted by TensorType.labelMatcher is returned as-is. Any other label is wrapped
+ * in single quotes, or in double quotes if the label itself contains a single quote.
+ * NOTE(review): a label containing both quote characters is wrapped in double quotes
+ * without escaping - confirm whether such output can be parsed back.
+ */
+ private String labelToString(String label) {
+ boolean plain = TensorType.labelMatcher.matches(label);
+ if (plain) return label; // no quoting
+ String quote = label.indexOf('\'') >= 0 ? "\"" : "'";
+ return quote + label + quote;
+ }
+
private static final class StringTensorAddress extends TensorAddress {
private final String[] labels;
@@ -166,8 +169,6 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
/** Supports building of a tensor address */
public static class Builder {
- static private final Ascii7BitMatcher labelMatcher = new Ascii7BitMatcher("-_@" + charsAndNumbers(),
- "_@$" + charsAndNumbers());
private final TensorType type;
private final String[] labels;
@@ -187,10 +188,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
* @return this for convenience
*/
public Builder add(String dimension, String label) {
- requireIdentifier(dimension, "dimension");
- requireIdentifier(label, "label");
+ Objects.requireNonNull(dimension, "dimension cannot be null");
+ Objects.requireNonNull(label, "label cannot be null");
Optional<Integer> labelIndex = type.indexOfDimension(dimension);
- if ( ! labelIndex.isPresent())
+ if ( labelIndex.isEmpty())
throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'");
labels[labelIndex.get()] = label;
return this;
@@ -209,13 +210,6 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return TensorAddress.of(labels);
}
- static private void requireIdentifier(String s, String parameterName) {
- if (s == null)
- throw new IllegalArgumentException(parameterName + " can not be null");
- if ( ! labelMatcher.matches(s))
- throw new IllegalArgumentException(parameterName + " must be an identifier or integer, not '" + s + "'");
- }
-
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 5a1fd98a009..8f8469cc63a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -50,7 +50,7 @@ class TensorParser {
valueString = valueString.trim();
if (valueString.startsWith("{") &&
(type.isEmpty() || type.get().rank() == 0 || valueString.substring(1).trim().startsWith("{") || valueString.substring(1).trim().equals("}"))) {
- return tensorFromSparseValueString(valueString, type);
+ return tensorFromMappedValueString(valueString, type);
}
else if (valueString.startsWith("{")) {
return tensorFromMixedValueString(valueString, type, dimensionOrder);
@@ -73,35 +73,18 @@ class TensorParser {
}
/** Derives the tensor type from the first address string in the given tensor string */
- private static TensorType typeFromSparseValueString(String valueString) {
- String s = valueString.substring(1).trim(); // remove tensor start
- int firstKeyOrTensorEnd = s.indexOf('}');
- if (firstKeyOrTensorEnd < 0)
- throw new IllegalArgumentException("Excepted a number or a string starting by '{', '[' or 'tensor(...):...'");
- String addressBody = s.substring(0, firstKeyOrTensorEnd).trim();
- if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor
- if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor
-
- addressBody = addressBody.substring(1, addressBody.length()); // remove key start
- if (addressBody.isEmpty()) return TensorType.empty; // Empty key
-
- TensorType.Builder builder = new TensorType.Builder(TensorType.Value.DOUBLE);
- for (String elementString : addressBody.split(",")) {
- String[] pair = elementString.split(":");
- if (pair.length != 2)
- throw new IllegalArgumentException("Expecting argument elements to be on the form dimension:label, " +
- "got '" + elementString + "'");
- builder.mapped(pair[0].trim());
- }
-
+ private static TensorType typeFromMappedValueString(String valueString) {
+ TensorType.Builder builder = new TensorType.Builder();
+ MappedValueTypeParser parser = new MappedValueTypeParser(valueString, builder);
+ parser.parse();
return builder.build();
}
- private static Tensor tensorFromSparseValueString(String valueString, Optional<TensorType> type) {
+ private static Tensor tensorFromMappedValueString(String valueString, Optional<TensorType> type) {
try {
valueString = valueString.trim();
- Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromSparseValueString(valueString)));
- SparseValueParser parser = new SparseValueParser(valueString, builder);
+ Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromMappedValueString(valueString)));
+ MappedValueParser parser = new MappedValueParser(valueString, builder);
parser.parse();
return builder.build();
}
@@ -176,6 +159,37 @@ class TensorParser {
position++;
}
+ /**
+ * Consumes an unquoted identifier: the text from the current position up to, but not including,
+ * the next stop character as found by nextStopCharIndex. The position is left at that stop character.
+ *
+ * @return the consumed text with surrounding whitespace removed
+ * @throws IllegalArgumentException if nextStopCharIndex finds no stop character
+ */
+ protected String consumeIdentifier() {
+ int endIdentifier = nextStopCharIndex(position, string);
+ // Trim so that whitespace around a name, as in "{ x : a }", does not become part of the name
+ String identifier = string.substring(position, endIdentifier).trim();
+ position = endIdentifier;
+ return identifier;
+ }
+
+ /**
+ * Consumes a label at the current position. A label may be enclosed in single quotes (')
+ * or double quotes ("), in which case everything up to the next quote of the same kind is the label.
+ * Otherwise it is consumed as a plain identifier.
+ *
+ * @return the label, without any enclosing quotes
+ * @throws IllegalArgumentException if a quoted label has no closing quote
+ */
+ protected String consumeLabel() {
+ boolean singleQuoted = consumeOptional('\'');
+ boolean doubleQuoted = ! singleQuoted && consumeOptional('"');
+ if ( ! singleQuoted && ! doubleQuoted)
+ return consumeIdentifier();
+
+ // The position is now just past the opening quote
+ int closingQuote = string.indexOf(singleQuoted ? '\'' : '"', position);
+ if (closingQuote < 0)
+ throw new IllegalArgumentException("At value position " + position +
+ (singleQuoted ? ": A label quoted by a tick (') must end by another tick"
+ : ": A label quoted by a double quote (\") must end by another double quote"));
+ String quotedLabel = string.substring(position, closingQuote);
+ position = closingQuote + 1; // continue after the closing quote
+ return quotedLabel;
+ }
+
protected Number consumeNumber(TensorType.Value cellValueType) {
skipSpace();
@@ -199,15 +213,28 @@ class TensorParser {
}
}
+ /**
+ * Calls skipSpace() and then consumes the given character if it is the next one in the string.
+ *
+ * @return true if the character was found and consumed, false if the string has ended
+ * or the next character is a different one (the position is then not advanced past it)
+ */
+ protected boolean consumeOptional(char character) {
+ skipSpace();
+
+ boolean present = position < string.length() && string.charAt(position) == character;
+ if (present)
+ position++;
+ return present;
+ }
+
+ /**
+ * Returns the index of the first stop character (',', ']', '}' or ':') in the given string,
+ * searching from the given position.
+ *
+ * @throws IllegalArgumentException if the end of the string is reached without finding a stop character
+ */
protected int nextStopCharIndex(int position, String valueString) {
while (position < valueString.length()) {
if (valueString.charAt(position) == ',') return position;
if (valueString.charAt(position) == ']') return position;
if (valueString.charAt(position) == '}') return position;
+ if (valueString.charAt(position) == ':') return position;
position++;
}
+ // The loop ran to the end of the string, so the position reported below is the string length
- throw new IllegalArgumentException("Malformed tensor value '" + valueString +
- "': Expected a ',', ']' or '}' after position " + position);
+ throw new IllegalArgumentException("Malformed tensor string '" + valueString +
+ "': Expected a ',', ']' or '}', ':' after position " + position);
}
}
@@ -291,13 +318,8 @@ class TensorParser {
consume('{');
skipSpace();
while (position + 1 < string.length()) {
- int labelEnd = string.indexOf(':', position);
- if (labelEnd <= position)
- throw new IllegalArgumentException("A mixed tensor value must be on the form {sparse-label:[dense subspace], ...}, or {sparse-label:value, ...}");
- String label = string.substring(position, labelEnd);
- position = labelEnd + 1;
- skipSpace();
-
+ String label = consumeLabel();
+ consume(':');
TensorAddress mappedAddress = new TensorAddress.Builder(mappedSubtype).add(mappedDimension.name(), label).build();
if (builder.type().rank() > 1)
parseDenseSubspace(mappedAddress, dimensionOrder);
@@ -309,24 +331,12 @@ class TensorParser {
}
}
- private void parseDenseSubspace(TensorAddress sparseAddress, List<String> denseDimensionOrder) {
+ private void parseDenseSubspace(TensorAddress mappedAddress, List<String> denseDimensionOrder) {
DenseValueParser denseParser = new DenseValueParser(string.substring(position),
denseDimensionOrder,
- ((MixedTensor.BoundBuilder)builder).denseSubspaceBuilder(sparseAddress));
+ ((MixedTensor.BoundBuilder)builder).denseSubspaceBuilder(mappedAddress));
denseParser.parse();
- position+= denseParser.position();
- }
-
- private boolean consumeOptional(char character) {
- skipSpace();
-
- if (position >= string.length())
- return false;
- if ( string.charAt(position) != character)
- return false;
-
- position++;
- return true;
+ position += denseParser.position();
}
private void consumeNumber(TensorAddress address) {
@@ -339,11 +349,11 @@ class TensorParser {
}
- private static class SparseValueParser extends ValueParser {
+ private static class MappedValueParser extends ValueParser {
private final Tensor.Builder builder;
- public SparseValueParser(String string, Tensor.Builder builder) {
+ public MappedValueParser(String string, Tensor.Builder builder) {
super(string);
this.builder = builder;
}
@@ -352,22 +362,17 @@ class TensorParser {
consume('{');
skipSpace();
while (position + 1 < string.length()) {
- int keyOrTensorEnd = string.indexOf('}', position);
- TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type());
- if (keyOrTensorEnd < string.length() - 1) { // Key end: This has a key - otherwise TensorAddress is empty
- addLabels(string.substring(position, keyOrTensorEnd + 1), addressBuilder);
- position = keyOrTensorEnd + 1;
- skipSpace();
+ TensorAddress address = consumeLabels();
+ if ( ! address.isEmpty())
consume(':');
- }
+
int valueEnd = string.indexOf(',', position);
if (valueEnd < 0) { // last value
valueEnd = string.indexOf('}', position);
if (valueEnd < 0)
- throw new IllegalArgumentException("A sparse tensor string must end by '}'");
+ throw new IllegalArgumentException("A mapped tensor string must end by '}'");
}
- TensorAddress address = addressBuilder.build();
TensorType.Value cellValueType = builder.type().valueType();
String cellValueString = string.substring(position, valueEnd).trim();
try {
@@ -389,21 +394,46 @@ class TensorParser {
}
/** Creates a tensor address from a string on the form {dimension1:label1,dimension2:label2,...} */
- private static void addLabels(String mapAddressString, TensorAddress.Builder builder) {
- mapAddressString = mapAddressString.trim();
- if ( ! (mapAddressString.startsWith("{") && mapAddressString.endsWith("}")))
- throw new IllegalArgumentException("Expecting a tensor address enclosed in {}, got '" + mapAddressString + "'");
-
- String addressBody = mapAddressString.substring(1, mapAddressString.length() - 1).trim();
- if (addressBody.isEmpty()) return;
-
- for (String elementString : addressBody.split(",")) {
- String[] pair = elementString.split(":");
- if (pair.length != 2)
- throw new IllegalArgumentException("Expecting argument elements on the form dimension:label, " +
- "got '" + elementString + "'");
- String dimension = pair[0].trim();
- builder.add(dimension, pair[1].trim());
+ private TensorAddress consumeLabels() {
+ TensorAddress.Builder address = new TensorAddress.Builder(builder.type());
+ if (consumeOptional('{')) { // without an opening brace there is no address: the builder is built as-is
+ while ( ! consumeOptional('}')) {
+ String dimensionName = consumeIdentifier();
+ consume(':');
+ address.add(dimensionName, consumeLabel());
+ consumeOptional(','); // separator between entries; its absence is not an error here
+ }
+ }
+ return address.build();
+ }
+
+ }
+
+ /** Parses a tensor *value* into a type */
+ private static class MappedValueTypeParser extends ValueParser {
+
+ private final TensorType.Builder builder;
+
+ /**
+ * Creates a parser of the given tensor value string.
+ *
+ * @param string the tensor value string to parse
+ * @param builder the type builder which receives the dimensions found while parsing
+ */
+ public MappedValueTypeParser(String string, TensorType.Builder builder) {
+ super(string);
+ this.builder = builder;
+ }
+
+ /**
+ * Derives the tensor type from the first address string in the given tensor string.
+ * Only the tensor start and the first address are consumed; the rest of the string is not read.
+ */
+ public void parse() {
+ consume('{'); // tensor start
+ consumeLabels(); // the first address, if any, supplies the dimensions
+ }
+
+ /** Consumes a mapped address, adding each dimension in it as a mapped dimension to the type builder */
+ private void consumeLabels() {
+ if ( ! consumeOptional('{')) return; // no address: nothing is added to the type builder
+ while ( ! consumeOptional('}')) {
+ String dimension = consumeIdentifier();
+ consume(':');
+ consumeLabel(); // the label value is discarded; it is consumed only to move past it
+ builder.mapped(dimension);
+ consumeOptional(',');
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index ca3f8ff28a4..58cb151875e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -2,9 +2,9 @@
package com.yahoo.tensor;
import com.google.common.collect.ImmutableList;
+import com.yahoo.text.Ascii7BitMatcher;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
@@ -15,6 +15,8 @@ import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
+import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers;
+
/**
* A tensor type with its dimensions. This is immutable.
* <p>
@@ -25,6 +27,8 @@ import java.util.stream.Collectors;
*/
public class TensorType {
+ /** Matches the strings accepted as dimension names, and as labels which can be written without quotes */
+ static final Ascii7BitMatcher labelMatcher = new Ascii7BitMatcher("-_@" + charsAndNumbers(), "_@$" + charsAndNumbers());
+
/** The permissible cell value types. Default is double. */
public enum Value {
@@ -292,8 +296,7 @@ public class TensorType {
private final String name;
private Dimension(String name) {
- Objects.requireNonNull(name, "A tensor name cannot be null");
- this.name = name;
+ this.name = requireIdentifier(name);
}
public final String name() { return name; }
@@ -361,6 +364,14 @@ public class TensorType {
return new MappedDimension(name);
}
+ /**
+ * Verifies that a dimension name is non-null and accepted by TensorType.labelMatcher.
+ *
+ * @return the given name, when it is valid
+ * @throws IllegalArgumentException if the name is null or is not accepted by the matcher
+ */
+ static private String requireIdentifier(String name) {
+ if (name != null && TensorType.labelMatcher.matches(name))
+ return name;
+ String problem = name == null ? "A dimension name cannot be null"
+ : "A dimension name must be an identifier or integer, not '" + name + "'";
+ throw new IllegalArgumentException(problem);
+ }
+
}
public static class IndexedBoundDimension extends TensorType.Dimension {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
index 6f9a5c13886..78afa1f7449 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
@@ -19,6 +19,10 @@ public class TensorParserTestCase {
assertEquals("If the type is specified, a dense tensor can be created from the sparse text form",
Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(1.0, 0).build(),
Tensor.from("tensor(x[1]):{{x:0}:1.0}"));
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")).cell().label("x", "..\",]}:..").value(1.0).build(),
+ Tensor.from("{{x:'..\",]}:..'}:1.0}"));
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")).cell().label("x", "..'..").value(1.0).build(),
+ Tensor.from("{{x:\"..'..\"}:1.0}"));
}
@Test
@@ -95,6 +99,12 @@ public class TensorParserTestCase {
.cell(TensorAddress.ofLabels("b", "0"), 3)
.cell(TensorAddress.ofLabels("b", "1"), 4).build(),
Tensor.from("tensor(key{}, x[2]):{a:[1, 2], b:[3, 4]}"));
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{}, x[2])"))
+ .cell(TensorAddress.ofLabels(",:", "0"), 1)
+ .cell(TensorAddress.ofLabels(",:", "1"), 2)
+ .cell(TensorAddress.ofLabels("b", "0"), 3)
+ .cell(TensorAddress.ofLabels("b", "1"), 4).build(),
+ Tensor.from("tensor(key{}, x[2]):{',:':[1, 2], b:[3, 4]}"));
}
@Test
@@ -103,6 +113,14 @@ public class TensorParserTestCase {
.cell(TensorAddress.ofLabels("a"), 1)
.cell(TensorAddress.ofLabels("b"), 2).build(),
Tensor.from("tensor(key{}):{a:1, b:2}"));
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{})"))
+ .cell(TensorAddress.ofLabels("..\",}]:.."), 1)
+ .cell(TensorAddress.ofLabels("b"), 2).build(),
+ Tensor.from("tensor(key{}):{'..\",}]:..':1, b:2}"));
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{})"))
+ .cell(TensorAddress.ofLabels("..'.."), 1)
+ .cell(TensorAddress.ofLabels("b"), 2).build(),
+ Tensor.from("tensor(key{}):{\"..'..\":1, b:2}"));
}
@Test
@@ -134,11 +152,9 @@ public class TensorParserTestCase {
@Test
public void testIllegalStrings() {
- assertIllegal("label must be an identifier or integer, not '\"l0\"'",
- "{{x:\"l0\"}:1.0}");
- assertIllegal("dimension must be an identifier or integer, not ''x''",
+ assertIllegal("A dimension name must be an identifier or integer, not ''x''",
"{{'x':\"l0\"}:1.0}");
- assertIllegal("dimension must be an identifier or integer, not '\"x\"'",
+ assertIllegal("A dimension name must be an identifier or integer, not '\"x\"'",
"{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}");
assertIllegal("At {x:0}: '1-.0' is not a valid double",
"{{x:0}:1-.0}");
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 9f077cb7b00..7932f90d797 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -36,6 +36,9 @@ public class TensorTestCase {
assertTrue(Tensor.from("tensor():{5.7}") instanceof IndexedTensor);
assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString());
assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString());
+ assertEquals("Labels are quoted when necessary",
+ "tensor(d1{}):{{d1:\"'''\"}:6.0,{d1:'[[\":\"]]'}:5.0}",
+ Tensor.from("{ {d1:'[[\":\"]]'}: 5, {d1:\"'''\"}:6.0 }").toString());
}
@Test