diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-17 11:42:21 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-17 11:42:21 +0100 |
commit | 3a84c90423e86bb95c9a620c1c9ccc1a055b2d37 (patch) | |
tree | bd7ffc5edc8b4f7216a2403e86001755efaf953f /vespajlib | |
parent | cbcd468e0f19421876a52ba1bd74f33fda73b855 (diff) |
Allow quoted labels in tensors
Diffstat (limited to 'vespajlib')
6 files changed, 156 insertions, 99 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index b8ef84cabb7..cffd41905a1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -18,6 +18,7 @@ import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.XwPlusB; +import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; import java.util.Arrays; @@ -31,6 +32,8 @@ import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.function.Function; +import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers; + /** * A multidimensional array which can be used in computations. * <p> diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 4770ad1b1f0..a3805fb789a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -1,14 +1,11 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; -import com.yahoo.text.Ascii7BitMatcher; - import java.util.Arrays; +import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; -import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers; - /** * An immutable address to a tensor cell. This simply supplies a value to each dimension * in a particular tensor type. By itself it is just a list of cell labels, it's meaning depends on its accompanying type. @@ -85,7 +82,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { public final String toString(TensorType type) { StringBuilder b = new StringBuilder("{"); for (int i = 0; i < size(); i++) { - b.append(type.dimensions().get(i).name()).append(":").append(label(i)); + b.append(type.dimensions().get(i).name()).append(":").append(labelToString(label(i))); b.append(","); } if (b.length() > 1) @@ -94,6 +91,12 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return b.toString(); } + private String labelToString(String label) { + if (TensorType.labelMatcher.matches(label)) return label; // no quoting + if (label.contains("'")) return "\"" + label + "\""; + return "'" + label + "'"; + } + private static final class StringTensorAddress extends TensorAddress { private final String[] labels; @@ -166,8 +169,6 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { /** Supports building of a tensor address */ public static class Builder { - static private final Ascii7BitMatcher labelMatcher = new Ascii7BitMatcher("-_@" + charsAndNumbers(), - "_@$" + charsAndNumbers()); private final TensorType type; private final String[] labels; @@ -187,10 +188,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { * @return this for convenience */ public Builder add(String dimension, String label) { - requireIdentifier(dimension, "dimension"); - requireIdentifier(label, "label"); + Objects.requireNonNull(dimension, "dimension cannot be null"); + Objects.requireNonNull(label, "label cannot be null"); Optional<Integer> labelIndex = type.indexOfDimension(dimension); - if ( ! labelIndex.isPresent()) + if ( labelIndex.isEmpty()) throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); labels[labelIndex.get()] = label; return this; @@ -209,13 +210,6 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return TensorAddress.of(labels); } - static private void requireIdentifier(String s, String parameterName) { - if (s == null) - throw new IllegalArgumentException(parameterName + " can not be null"); - if ( ! labelMatcher.matches(s)) - throw new IllegalArgumentException(parameterName + " must be an identifier or integer, not '" + s + "'"); - } - } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 5a1fd98a009..8f8469cc63a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -50,7 +50,7 @@ class TensorParser { valueString = valueString.trim(); if (valueString.startsWith("{") && (type.isEmpty() || type.get().rank() == 0 || valueString.substring(1).trim().startsWith("{") || valueString.substring(1).trim().equals("}"))) { - return tensorFromSparseValueString(valueString, type); + return tensorFromMappedValueString(valueString, type); } else if (valueString.startsWith("{")) { return tensorFromMixedValueString(valueString, type, dimensionOrder); @@ -73,35 +73,18 @@ class TensorParser { } /** Derives the tensor type from the first address string in the given tensor string */ - private static TensorType typeFromSparseValueString(String valueString) { - String s = valueString.substring(1).trim(); // remove tensor start - int firstKeyOrTensorEnd = s.indexOf('}'); - if (firstKeyOrTensorEnd < 0) - throw new IllegalArgumentException("Excepted a number or a string starting by '{', '[' or 'tensor(...):...'"); - String addressBody = s.substring(0, firstKeyOrTensorEnd).trim(); - if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor - if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor - - addressBody = addressBody.substring(1, addressBody.length()); // remove key start - if (addressBody.isEmpty()) return TensorType.empty; // Empty key - - TensorType.Builder builder = new TensorType.Builder(TensorType.Value.DOUBLE); - for (String elementString : addressBody.split(",")) { - String[] pair = elementString.split(":"); - if (pair.length != 2) - throw new IllegalArgumentException("Expecting argument elements to be on the form dimension:label, " + - "got '" + elementString + "'"); - builder.mapped(pair[0].trim()); - } - + private static TensorType typeFromMappedValueString(String valueString) { + TensorType.Builder builder = new TensorType.Builder(); + MappedValueTypeParser parser = new MappedValueTypeParser(valueString, builder); + parser.parse(); return builder.build(); } - private static Tensor tensorFromSparseValueString(String valueString, Optional<TensorType> type) { + private static Tensor tensorFromMappedValueString(String valueString, Optional<TensorType> type) { try { valueString = valueString.trim(); - Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromSparseValueString(valueString))); - SparseValueParser parser = new SparseValueParser(valueString, builder); + Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromMappedValueString(valueString))); + MappedValueParser parser = new MappedValueParser(valueString, builder); parser.parse(); return builder.build(); } @@ -176,6 +159,37 @@ class TensorParser { position++; } + protected String consumeIdentifier() { + int endIdentifier = nextStopCharIndex(position, string); + String identifier = string.substring(position, endIdentifier); + position = endIdentifier; + return identifier; + } + + protected String consumeLabel() { + if (consumeOptional('\'')) { + int endQuote = string.indexOf('\'', position); + if (endQuote < 0) + throw new IllegalArgumentException("At value position " + position + + ": A label quoted by a tick (') must end by another tick"); + String label = string.substring(position, endQuote); + position = endQuote + 1; + return label; + } + else if (consumeOptional('"')) { + int endQuote = string.indexOf('"', position); + if (endQuote < 0) + throw new IllegalArgumentException("At value position " + position + + ": A label quoted by a double quote (\") must end by another double quote"); + String label = string.substring(position, endQuote); + position = endQuote + 1; + return label; + } + else { + return consumeIdentifier(); + } + } + protected Number consumeNumber(TensorType.Value cellValueType) { skipSpace(); @@ -199,15 +213,28 @@ class TensorParser { } } + protected boolean consumeOptional(char character) { + skipSpace(); + + if (position >= string.length()) + return false; + if ( string.charAt(position) != character) + return false; + + position++; + return true; + } + protected int nextStopCharIndex(int position, String valueString) { while (position < valueString.length()) { if (valueString.charAt(position) == ',') return position; if (valueString.charAt(position) == ']') return position; if (valueString.charAt(position) == '}') return position; + if (valueString.charAt(position) == ':') return position; position++; } - throw new IllegalArgumentException("Malformed tensor value '" + valueString + - "': Expected a ',', ']' or '}' after position " + position); + throw new IllegalArgumentException("Malformed tensor string '" + valueString + + "': Expected a ',', ']' or '}', ':' after position " + position); } } @@ -291,13 +318,8 @@ class TensorParser { consume('{'); skipSpace(); while (position + 1 < string.length()) { - int labelEnd = string.indexOf(':', position); - if (labelEnd <= position) - throw new IllegalArgumentException("A mixed tensor value must be on the form {sparse-label:[dense subspace], ...}, or {sparse-label:value, ...}"); - String label = string.substring(position, labelEnd); - position = labelEnd + 1; - skipSpace(); - + String label = consumeLabel(); + consume(':'); TensorAddress mappedAddress = new TensorAddress.Builder(mappedSubtype).add(mappedDimension.name(), label).build(); if (builder.type().rank() > 1) parseDenseSubspace(mappedAddress, dimensionOrder); @@ -309,24 +331,12 @@ class TensorParser { } } - private void parseDenseSubspace(TensorAddress sparseAddress, List<String> denseDimensionOrder) { + private void parseDenseSubspace(TensorAddress mappedAddress, List<String> denseDimensionOrder) { DenseValueParser denseParser = new DenseValueParser(string.substring(position), denseDimensionOrder, - ((MixedTensor.BoundBuilder)builder).denseSubspaceBuilder(sparseAddress)); + ((MixedTensor.BoundBuilder)builder).denseSubspaceBuilder(mappedAddress)); denseParser.parse(); - position+= denseParser.position(); - } - - private boolean consumeOptional(char character) { - skipSpace(); - - if (position >= string.length()) - return false; - if ( string.charAt(position) != character) - return false; - - position++; - return true; + position += denseParser.position(); } private void consumeNumber(TensorAddress address) { @@ -339,11 +349,11 @@ class TensorParser { } - private static class SparseValueParser extends ValueParser { + private static class MappedValueParser extends ValueParser { private final Tensor.Builder builder; - public SparseValueParser(String string, Tensor.Builder builder) { + public MappedValueParser(String string, Tensor.Builder builder) { super(string); this.builder = builder; } @@ -352,22 +362,17 @@ class TensorParser { consume('{'); skipSpace(); while (position + 1 < string.length()) { - int keyOrTensorEnd = string.indexOf('}', position); - TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type()); - if (keyOrTensorEnd < string.length() - 1) { // Key end: This has a key - otherwise TensorAddress is empty - addLabels(string.substring(position, keyOrTensorEnd + 1), addressBuilder); - position = keyOrTensorEnd + 1; - skipSpace(); + TensorAddress address = consumeLabels(); + if ( ! address.isEmpty()) consume(':'); - } + int valueEnd = string.indexOf(',', position); if (valueEnd < 0) { // last value valueEnd = string.indexOf('}', position); if (valueEnd < 0) - throw new IllegalArgumentException("A sparse tensor string must end by '}'"); + throw new IllegalArgumentException("A mapped tensor string must end by '}'"); } - TensorAddress address = addressBuilder.build(); TensorType.Value cellValueType = builder.type().valueType(); String cellValueString = string.substring(position, valueEnd).trim(); try { @@ -389,21 +394,46 @@ class TensorParser { } /** Creates a tensor address from a string on the form {dimension1:label1,dimension2:label2,...} */ - private static void addLabels(String mapAddressString, TensorAddress.Builder builder) { - mapAddressString = mapAddressString.trim(); - if ( ! (mapAddressString.startsWith("{") && mapAddressString.endsWith("}"))) - throw new IllegalArgumentException("Expecting a tensor address enclosed in {}, got '" + mapAddressString + "'"); - - String addressBody = mapAddressString.substring(1, mapAddressString.length() - 1).trim(); - if (addressBody.isEmpty()) return; - - for (String elementString : addressBody.split(",")) { - String[] pair = elementString.split(":"); - if (pair.length != 2) - throw new IllegalArgumentException("Expecting argument elements on the form dimension:label, " + - "got '" + elementString + "'"); - String dimension = pair[0].trim(); - builder.add(dimension, pair[1].trim()); + private TensorAddress consumeLabels() { + TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type()); + if ( ! consumeOptional('{')) return addressBuilder.build(); + while ( ! consumeOptional('}')) { + String dimension = consumeIdentifier(); + consume(':'); + String label = consumeLabel(); + addressBuilder.add(dimension, label); + consumeOptional(','); + } + return addressBuilder.build(); + } + + } + + /** Parses a tensor *value* into a type */ + private static class MappedValueTypeParser extends ValueParser { + + private final TensorType.Builder builder; + + public MappedValueTypeParser(String string, TensorType.Builder builder) { + super(string); + this.builder = builder; + } + + /** Derives the tensor type from the first address string in the given tensor string */ + public void parse() { + consume('{'); + consumeLabels(); + } + + /** Consumes a mapped address into a set of the type builder */ + private void consumeLabels() { + if ( ! consumeOptional('{')) return; + while ( ! consumeOptional('}')) { + String dimension = consumeIdentifier(); + consume(':'); + consumeLabel(); + builder.mapped(dimension); + consumeOptional(','); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index ca3f8ff28a4..58cb151875e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -2,9 +2,9 @@ package com.yahoo.tensor; import com.google.common.collect.ImmutableList; +import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; @@ -15,6 +15,8 @@ import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers; + /** * A tensor type with its dimensions. This is immutable. * <p> @@ -25,6 +27,8 @@ import java.util.stream.Collectors; */ public class TensorType { + static Ascii7BitMatcher labelMatcher = new Ascii7BitMatcher("-_@" + charsAndNumbers(), "_@$" + charsAndNumbers()); + /** The permissible cell value types. Default is double. */ public enum Value { @@ -292,8 +296,7 @@ public class TensorType { private final String name; private Dimension(String name) { - Objects.requireNonNull(name, "A tensor name cannot be null"); - this.name = name; + this.name = requireIdentifier(name); } public final String name() { return name; } @@ -361,6 +364,14 @@ public class TensorType { return new MappedDimension(name); } + static private String requireIdentifier(String name) { + if (name == null) + throw new IllegalArgumentException("A dimension name cannot be null"); + if ( ! TensorType.labelMatcher.matches(name)) + throw new IllegalArgumentException("A dimension name must be an identifier or integer, not '" + name + "'"); + return name; + } + } public static class IndexedBoundDimension extends TensorType.Dimension { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java index 6f9a5c13886..78afa1f7449 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java @@ -19,6 +19,10 @@ public class TensorParserTestCase { assertEquals("If the type is specified, a dense tensor can be created from the sparse text form", Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(1.0, 0).build(), Tensor.from("tensor(x[1]):{{x:0}:1.0}")); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")).cell().label("x", "..\",]}:..").value(1.0).build(), + Tensor.from("{{x:'..\",]}:..'}:1.0}")); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")).cell().label("x", "..'..").value(1.0).build(), + Tensor.from("{{x:\"..'..\"}:1.0}")); } @Test @@ -95,6 +99,12 @@ public class TensorParserTestCase { .cell(TensorAddress.ofLabels("b", "0"), 3) .cell(TensorAddress.ofLabels("b", "1"), 4).build(), Tensor.from("tensor(key{}, x[2]):{a:[1, 2], b:[3, 4]}")); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{}, x[2])")) + .cell(TensorAddress.ofLabels(",:", "0"), 1) + .cell(TensorAddress.ofLabels(",:", "1"), 2) + .cell(TensorAddress.ofLabels("b", "0"), 3) + .cell(TensorAddress.ofLabels("b", "1"), 4).build(), + Tensor.from("tensor(key{}, x[2]):{',:':[1, 2], b:[3, 4]}")); } @Test @@ -103,6 +113,14 @@ public class TensorParserTestCase { .cell(TensorAddress.ofLabels("a"), 1) .cell(TensorAddress.ofLabels("b"), 2).build(), Tensor.from("tensor(key{}):{a:1, b:2}")); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{})")) + .cell(TensorAddress.ofLabels("..\",}]:.."), 1) + .cell(TensorAddress.ofLabels("b"), 2).build(), + Tensor.from("tensor(key{}):{'..\",}]:..':1, b:2}")); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{})")) + .cell(TensorAddress.ofLabels("..'.."), 1) + .cell(TensorAddress.ofLabels("b"), 2).build(), + Tensor.from("tensor(key{}):{\"..'..\":1, b:2}")); } @Test @@ -134,11 +152,9 @@ public class TensorParserTestCase { @Test public void testIllegalStrings() { - assertIllegal("label must be an identifier or integer, not '\"l0\"'", - "{{x:\"l0\"}:1.0}"); - assertIllegal("dimension must be an identifier or integer, not ''x''", + assertIllegal("A dimension name must be an identifier or integer, not ''x''", "{{'x':\"l0\"}:1.0}"); - assertIllegal("dimension must be an identifier or integer, not '\"x\"'", + assertIllegal("A dimension name must be an identifier or integer, not '\"x\"'", "{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}"); assertIllegal("At {x:0}: '1-.0' is not a valid double", "{{x:0}:1-.0}"); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 9f077cb7b00..7932f90d797 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -36,6 +36,9 @@ public class TensorTestCase { assertTrue(Tensor.from("tensor():{5.7}") instanceof IndexedTensor); assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString()); assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString()); + assertEquals("Labels are quoted when necessary", + "tensor(d1{}):{{d1:\"'''\"}:6.0,{d1:'[[\":\"]]'}:5.0}", + Tensor.from("{ {d1:'[[\":\"]]'}: 5, {d1:\"'''\"}:6.0 }").toString()); } @Test |