From 3a84c90423e86bb95c9a620c1c9ccc1a055b2d37 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 17 Dec 2019 11:42:21 +0100 Subject: Allow quoted labels in tensors --- .../src/main/java/com/yahoo/tensor/Tensor.java | 3 + .../main/java/com/yahoo/tensor/TensorAddress.java | 28 ++-- .../main/java/com/yahoo/tensor/TensorParser.java | 180 ++++++++++++--------- .../src/main/java/com/yahoo/tensor/TensorType.java | 17 +- 4 files changed, 133 insertions(+), 95 deletions(-) (limited to 'vespajlib/src/main/java/com/yahoo') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index b8ef84cabb7..cffd41905a1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -18,6 +18,7 @@ import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.XwPlusB; +import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; import java.util.Arrays; @@ -31,6 +32,8 @@ import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.function.Function; +import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers; + /** * A multidimensional array which can be used in computations. *

diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 4770ad1b1f0..a3805fb789a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -1,14 +1,11 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; -import com.yahoo.text.Ascii7BitMatcher; - import java.util.Arrays; +import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; -import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers; - /** * An immutable address to a tensor cell. This simply supplies a value to each dimension * in a particular tensor type. By itself it is just a list of cell labels, it's meaning depends on its accompanying type. @@ -85,7 +82,7 @@ public abstract class TensorAddress implements Comparable { public final String toString(TensorType type) { StringBuilder b = new StringBuilder("{"); for (int i = 0; i < size(); i++) { - b.append(type.dimensions().get(i).name()).append(":").append(label(i)); + b.append(type.dimensions().get(i).name()).append(":").append(labelToString(label(i))); b.append(","); } if (b.length() > 1) @@ -94,6 +91,12 @@ public abstract class TensorAddress implements Comparable { return b.toString(); } + private String labelToString(String label) { + if (TensorType.labelMatcher.matches(label)) return label; // no quoting + if (label.contains("'")) return "\"" + label + "\""; + return "'" + label + "'"; + } + private static final class StringTensorAddress extends TensorAddress { private final String[] labels; @@ -166,8 +169,6 @@ public abstract class TensorAddress implements Comparable { /** Supports building of a tensor address */ public static class Builder { - static private final Ascii7BitMatcher labelMatcher = new Ascii7BitMatcher("-_@" + charsAndNumbers(), - "_@$" + charsAndNumbers()); private final TensorType type; private final String[] labels; @@ -187,10 +188,10 @@ public abstract class TensorAddress implements Comparable { * @return this for convenience */ public Builder add(String dimension, String label) { - requireIdentifier(dimension, "dimension"); - requireIdentifier(label, "label"); + Objects.requireNonNull(dimension, "dimension cannot be null"); + Objects.requireNonNull(label, "label cannot be null"); Optional labelIndex = type.indexOfDimension(dimension); - if ( ! labelIndex.isPresent()) + if ( labelIndex.isEmpty()) throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); labels[labelIndex.get()] = label; return this; @@ -209,13 +210,6 @@ public abstract class TensorAddress implements Comparable { return TensorAddress.of(labels); } - static private void requireIdentifier(String s, String parameterName) { - if (s == null) - throw new IllegalArgumentException(parameterName + " can not be null"); - if ( ! labelMatcher.matches(s)) - throw new IllegalArgumentException(parameterName + " must be an identifier or integer, not '" + s + "'"); - } - } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 5a1fd98a009..8f8469cc63a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -50,7 +50,7 @@ class TensorParser { valueString = valueString.trim(); if (valueString.startsWith("{") && (type.isEmpty() || type.get().rank() == 0 || valueString.substring(1).trim().startsWith("{") || valueString.substring(1).trim().equals("}"))) { - return tensorFromSparseValueString(valueString, type); + return tensorFromMappedValueString(valueString, type); } else if (valueString.startsWith("{")) { return tensorFromMixedValueString(valueString, type, dimensionOrder); @@ -73,35 +73,18 @@ class TensorParser { } /** Derives the tensor type from the first address string in the given tensor string */ - private static TensorType typeFromSparseValueString(String valueString) { - String s = valueString.substring(1).trim(); // remove tensor start - int firstKeyOrTensorEnd = s.indexOf('}'); - if (firstKeyOrTensorEnd < 0) - throw new IllegalArgumentException("Excepted a number or a string starting by '{', '[' or 'tensor(...):...'"); - String addressBody = s.substring(0, firstKeyOrTensorEnd).trim(); - if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor - if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor - - addressBody = addressBody.substring(1, addressBody.length()); // remove key start - if (addressBody.isEmpty()) return TensorType.empty; // Empty key - - TensorType.Builder builder = new TensorType.Builder(TensorType.Value.DOUBLE); - for (String elementString : addressBody.split(",")) { - String[] pair = elementString.split(":"); - if (pair.length != 2) - throw new IllegalArgumentException("Expecting argument elements to be on the form dimension:label, " + - "got '" + elementString + "'"); - builder.mapped(pair[0].trim()); - } - + private static TensorType typeFromMappedValueString(String valueString) { + TensorType.Builder builder = new TensorType.Builder(); + MappedValueTypeParser parser = new MappedValueTypeParser(valueString, builder); + parser.parse(); return builder.build(); } - private static Tensor tensorFromSparseValueString(String valueString, Optional type) { + private static Tensor tensorFromMappedValueString(String valueString, Optional type) { try { valueString = valueString.trim(); - Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromSparseValueString(valueString))); - SparseValueParser parser = new SparseValueParser(valueString, builder); + Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromMappedValueString(valueString))); + MappedValueParser parser = new MappedValueParser(valueString, builder); parser.parse(); return builder.build(); } @@ -176,6 +159,37 @@ class TensorParser { position++; } + protected String consumeIdentifier() { + int endIdentifier = nextStopCharIndex(position, string); + String identifier = string.substring(position, endIdentifier); + position = endIdentifier; + return identifier; + } + + protected String consumeLabel() { + if (consumeOptional('\'')) { + int endQuote = string.indexOf('\'', position); + if (endQuote < 0) + throw new IllegalArgumentException("At value position " + position + + ": A label quoted by a tick (') must end by another tick"); + String label = string.substring(position, endQuote); + position = endQuote + 1; + return label; + } + else if (consumeOptional('"')) { + int endQuote = string.indexOf('"', position); + if (endQuote < 0) + throw new IllegalArgumentException("At value position " + position + + ": A label quoted by a double quote (\") must end by another double quote"); + String label = string.substring(position, endQuote); + position = endQuote + 1; + return label; + } + else { + return consumeIdentifier(); + } + } + protected Number consumeNumber(TensorType.Value cellValueType) { skipSpace(); @@ -199,15 +213,28 @@ class TensorParser { } } + protected boolean consumeOptional(char character) { + skipSpace(); + + if (position >= string.length()) + return false; + if ( string.charAt(position) != character) + return false; + + position++; + return true; + } + protected int nextStopCharIndex(int position, String valueString) { while (position < valueString.length()) { if (valueString.charAt(position) == ',') return position; if (valueString.charAt(position) == ']') return position; if (valueString.charAt(position) == '}') return position; + if (valueString.charAt(position) == ':') return position; position++; } - throw new IllegalArgumentException("Malformed tensor value '" + valueString + - "': Expected a ',', ']' or '}' after position " + position); + throw new IllegalArgumentException("Malformed tensor string '" + valueString + + "': Expected a ',', ']' or '}', ':' after position " + position); } } @@ -291,13 +318,8 @@ class TensorParser { consume('{'); skipSpace(); while (position + 1 < string.length()) { - int labelEnd = string.indexOf(':', position); - if (labelEnd <= position) - throw new IllegalArgumentException("A mixed tensor value must be on the form {sparse-label:[dense subspace], ...}, or {sparse-label:value, ...}"); - String label = string.substring(position, labelEnd); - position = labelEnd + 1; - skipSpace(); - + String label = consumeLabel(); + consume(':'); TensorAddress mappedAddress = new TensorAddress.Builder(mappedSubtype).add(mappedDimension.name(), label).build(); if (builder.type().rank() > 1) parseDenseSubspace(mappedAddress, dimensionOrder); @@ -309,24 +331,12 @@ class TensorParser { } } - private void parseDenseSubspace(TensorAddress sparseAddress, List denseDimensionOrder) { + private void parseDenseSubspace(TensorAddress mappedAddress, List denseDimensionOrder) { DenseValueParser denseParser = new DenseValueParser(string.substring(position), denseDimensionOrder, - ((MixedTensor.BoundBuilder)builder).denseSubspaceBuilder(sparseAddress)); + ((MixedTensor.BoundBuilder)builder).denseSubspaceBuilder(mappedAddress)); denseParser.parse(); - position+= denseParser.position(); - } - - private boolean consumeOptional(char character) { - skipSpace(); - - if (position >= string.length()) - return false; - if ( string.charAt(position) != character) - return false; - - position++; - return true; + position += denseParser.position(); } private void consumeNumber(TensorAddress address) { @@ -339,11 +349,11 @@ class TensorParser { } - private static class SparseValueParser extends ValueParser { + private static class MappedValueParser extends ValueParser { private final Tensor.Builder builder; - public SparseValueParser(String string, Tensor.Builder builder) { + public MappedValueParser(String string, Tensor.Builder builder) { super(string); this.builder = builder; } @@ -352,22 +362,17 @@ class TensorParser { consume('{'); skipSpace(); while (position + 1 < string.length()) { - int keyOrTensorEnd = string.indexOf('}', position); - TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type()); - if (keyOrTensorEnd < string.length() - 1) { // Key end: This has a key - otherwise TensorAddress is empty - addLabels(string.substring(position, keyOrTensorEnd + 1), addressBuilder); - position = keyOrTensorEnd + 1; - skipSpace(); + TensorAddress address = consumeLabels(); + if ( ! address.isEmpty()) consume(':'); - } + int valueEnd = string.indexOf(',', position); if (valueEnd < 0) { // last value valueEnd = string.indexOf('}', position); if (valueEnd < 0) - throw new IllegalArgumentException("A sparse tensor string must end by '}'"); + throw new IllegalArgumentException("A mapped tensor string must end by '}'"); } - TensorAddress address = addressBuilder.build(); TensorType.Value cellValueType = builder.type().valueType(); String cellValueString = string.substring(position, valueEnd).trim(); try { @@ -389,21 +394,46 @@ class TensorParser { } /** Creates a tensor address from a string on the form {dimension1:label1,dimension2:label2,...} */ - private static void addLabels(String mapAddressString, TensorAddress.Builder builder) { - mapAddressString = mapAddressString.trim(); - if ( ! (mapAddressString.startsWith("{") && mapAddressString.endsWith("}"))) - throw new IllegalArgumentException("Expecting a tensor address enclosed in {}, got '" + mapAddressString + "'"); - - String addressBody = mapAddressString.substring(1, mapAddressString.length() - 1).trim(); - if (addressBody.isEmpty()) return; - - for (String elementString : addressBody.split(",")) { - String[] pair = elementString.split(":"); - if (pair.length != 2) - throw new IllegalArgumentException("Expecting argument elements on the form dimension:label, " + - "got '" + elementString + "'"); - String dimension = pair[0].trim(); - builder.add(dimension, pair[1].trim()); + private TensorAddress consumeLabels() { + TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type()); + if ( ! consumeOptional('{')) return addressBuilder.build(); + while ( ! consumeOptional('}')) { + String dimension = consumeIdentifier(); + consume(':'); + String label = consumeLabel(); + addressBuilder.add(dimension, label); + consumeOptional(','); + } + return addressBuilder.build(); + } + + } + + /** Parses a tensor *value* into a type */ + private static class MappedValueTypeParser extends ValueParser { + + private final TensorType.Builder builder; + + public MappedValueTypeParser(String string, TensorType.Builder builder) { + super(string); + this.builder = builder; + } + + /** Derives the tensor type from the first address string in the given tensor string */ + public void parse() { + consume('{'); + consumeLabels(); + } + + /** Consumes a mapped address into a set of the type builder */ + private void consumeLabels() { + if ( ! consumeOptional('{')) return; + while ( ! consumeOptional('}')) { + String dimension = consumeIdentifier(); + consume(':'); + consumeLabel(); + builder.mapped(dimension); + consumeOptional(','); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index ca3f8ff28a4..58cb151875e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -2,9 +2,9 @@ package com.yahoo.tensor; import com.google.common.collect.ImmutableList; +import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; @@ -15,6 +15,8 @@ import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers; + /** * A tensor type with its dimensions. This is immutable. *

@@ -25,6 +27,8 @@ import java.util.stream.Collectors; */ public class TensorType { + static Ascii7BitMatcher labelMatcher = new Ascii7BitMatcher("-_@" + charsAndNumbers(), "_@$" + charsAndNumbers()); + /** The permissible cell value types. Default is double. */ public enum Value { @@ -292,8 +296,7 @@ public class TensorType { private final String name; private Dimension(String name) { - Objects.requireNonNull(name, "A tensor name cannot be null"); - this.name = name; + this.name = requireIdentifier(name); } public final String name() { return name; } @@ -361,6 +364,14 @@ public class TensorType { return new MappedDimension(name); } + static private String requireIdentifier(String name) { + if (name == null) + throw new IllegalArgumentException("A dimension name cannot be null"); + if ( ! TensorType.labelMatcher.matches(name)) + throw new IllegalArgumentException("A dimension name must be an identifier or integer, not '" + name + "'"); + return name; + } + } public static class IndexedBoundDimension extends TensorType.Dimension { -- cgit v1.2.3