diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-10 11:39:39 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-10 11:39:39 -0800 |
commit | 4c46e1816d2cdfacd8435ad4d55e831929fc99ba (patch) | |
tree | d55a90aeeddcf9265a74e7f16129517e36f45375 | |
parent | b8d2859a9fece15dac2b9260d71dea39f8ce19b3 (diff) |
Tensor parsing improvements
- Mixed tensor format parsing (outside expressions)
- Validate structure of dense tensor strings
11 files changed, 393 insertions, 111 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 00750c70d2c..0601043f2ce 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -402,6 +402,12 @@ public class EvaluationTestCase { "{ {x:0}:7 }", "tensor(x{}):{ {x:0}:2 }"); tester.assertEvaluates("tensor<float>(d0[1],x[3]):[[1.0, 0.5, 0.25]]", "tensor<float>(d0[1],x[3]):[[one,one_half,a_quarter]]"); + tester.assertEvaluates("tensor(x[2],y[3]):[[1.0, 0.5, 0.25],[0.25, 0.5, 1.0]]", + "tensor(x[2],y[3]):[[one,one_half,a_quarter],[a_quarter,one_half,one]]"); + tester.assertEvaluates("tensor(x{},y[2]):{{x:a,y:0}:1.0, {x:a,y:1}:0.5, {x:b,y:0}:0.25, {x:b,y:1}:2.0}", + "tensor(x{},y[2]):{{x:a,y:0}:one, {x:a,y:1}:one_half, {x:b,y:0}:a_quarter, {x:b,y:1}:2}"); + tester.assertEvaluates("tensor(x{},y[2]):{a:[1.0, 0.5], b:[0.25, 2]}", + "tensor(x{},y[2]):{a:[one, one_half], b:[a_quarter, 2]}"); } @Test diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index e991173805f..d91b38a8a96 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -776,15 +776,14 @@ }, "com.yahoo.tensor.IndexedTensor$BoundBuilder": { "superClass": "com.yahoo.tensor.IndexedTensor$Builder", - "interfaces": [], + "interfaces": [ + "com.yahoo.tensor.IndexedTensor$DirectIndexBuilder" + ], "attributes": [ "public", "abstract" ], - "methods": [ - "public abstract void cellByDirectIndex(long, double)", - "public abstract void cellByDirectIndex(long, float)" - ], + "methods": [], "fields": [] }, "com.yahoo.tensor.IndexedTensor$Builder": { @@ -813,6 +812,21 @@ ], "fields": [] }, + "com.yahoo.tensor.IndexedTensor$DirectIndexBuilder": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public", + "interface", + "abstract" + ], + "methods": [ + "public abstract com.yahoo.tensor.TensorType type()", + "public abstract void cellByDirectIndex(long, double)", + "public abstract void cellByDirectIndex(long, float)" + ], + "fields": [] + }, "com.yahoo.tensor.IndexedTensor$Indexes": { "superClass": "java.lang.Object", "interfaces": [], @@ -829,7 +843,8 @@ "public java.util.List toList()", "public java.lang.String toString()", "public abstract long size()", - "public abstract void next()" + "public abstract void next()", + "public abstract boolean hasNext()" ], "fields": [ "protected final long[] indexes" @@ -943,6 +958,7 @@ ], "methods": [ "public long denseSubspaceSize()", + "public com.yahoo.tensor.IndexedTensor$DirectIndexBuilder denseSubspaceBuilder(com.yahoo.tensor.TensorAddress)", "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)", "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)", "public com.yahoo.tensor.Tensor$Builder block(com.yahoo.tensor.TensorAddress, double[])", @@ -1035,8 +1051,8 @@ ], "methods": [ "public void <init>(int)", - "public void add(java.lang.String, long)", - "public void add(java.lang.String, java.lang.String)", + "public com.yahoo.tensor.PartialAddress$Builder add(java.lang.String, long)", + "public com.yahoo.tensor.PartialAddress$Builder add(java.lang.String, java.lang.String)", "public com.yahoo.tensor.PartialAddress build()" ], "fields": [] @@ -1236,6 +1252,7 @@ "methods": [ "public void <init>()", "public static com.yahoo.tensor.TensorAddress of(java.lang.String[])", + "public static varargs com.yahoo.tensor.TensorAddress ofLabels(java.lang.String[])", "public static varargs com.yahoo.tensor.TensorAddress of(long[])", "public abstract int size()", "public abstract java.lang.String label(int)", @@ -1395,6 +1412,7 @@ "public" ], "methods": [ + "public void <init>(com.yahoo.tensor.TensorType$Value, java.util.Collection)", "public static varargs com.yahoo.tensor.TensorType$Value combinedValueType(com.yahoo.tensor.TensorType[])", "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", "public com.yahoo.tensor.TensorType$Value valueType()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 176ddfefc13..30923976fa5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -218,7 +218,7 @@ public abstract class IndexedTensor implements Tensor { indexes.next(); // start brackets - for (int i = 0; i < indexes.rightDimensionsWhichAreAtStart(); i++) + for (int i = 0; i < indexes.rightDimensionsAtStart(); i++) b.append("["); // value @@ -230,7 +230,7 @@ public abstract class IndexedTensor implements Tensor { throw new IllegalStateException("Unexpected value type " + type.valueType()); // end bracket and comma - for (int i = 0; i < indexes.rightDimensionsWhichAreAtEnd(); i++) + for (int i = 0; i < indexes.rightDimensionsAtEnd(); i++) b.append("]"); if (index < size() - 1) b.append(", "); @@ -375,8 +375,22 @@ public abstract class IndexedTensor implements Tensor { } + public interface DirectIndexBuilder { + + TensorType type(); + + + + /** Sets a value by its <i>standard value order</i> index */ + void cellByDirectIndex(long index, double value); + + /** Sets a value by its <i>standard value order</i> index */ + void cellByDirectIndex(long index, float value); + + } + /** A bound builder can create the double array directly */ - public static abstract class BoundBuilder extends Builder { + public static abstract class BoundBuilder extends Builder implements DirectIndexBuilder { private DimensionSizes sizes; @@ -393,14 +407,16 @@ public abstract class IndexedTensor implements Tensor { throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); this.sizes = sizes; } - BoundBuilder fill(float [] values) { + + BoundBuilder fill(float[] values) { long index = 0; for (float value : values) { cellByDirectIndex(index++, value); } return this; } - BoundBuilder fill(double [] values) { + + BoundBuilder fill(double[] values) { long index = 0; for (double value : values) { cellByDirectIndex(index++, value); @@ -410,12 +426,6 @@ public abstract class IndexedTensor implements Tensor { DimensionSizes sizes() { return sizes; } - /** Sets a value by its <i>standard value order</i> index */ - public abstract void cellByDirectIndex(long index, double value); - - /** Sets a value by its <i>standard value order</i> index */ - public abstract void cellByDirectIndex(long index, float value); - } /** @@ -869,8 +879,11 @@ public abstract class IndexedTensor implements Tensor { public abstract void next(); + /** Returns whether further values are available by calling next() */ + public abstract boolean hasNext(); + /** Returns the number of dimensions from the right which are currently at the start position (0) */ - int rightDimensionsWhichAreAtStart() { + int rightDimensionsAtStart() { int dimension = indexes.length - 1; int atStartCount = 0; while (dimension >= 0 && indexes[dimension] == 0) { @@ -881,7 +894,7 @@ public abstract class IndexedTensor implements Tensor { } /** Returns the number of dimensions from the right which are currently at the end position */ - int rightDimensionsWhichAreAtEnd() { + int rightDimensionsAtEnd() { int dimension = indexes.length - 1; int atEndCount = 0; while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) { @@ -904,10 +917,15 @@ public abstract class IndexedTensor implements Tensor { @Override public void next() {} + @Override + public boolean hasNext() { return false; } + } private final static class SingleValueIndexes extends Indexes { + private boolean exhausted = false; + private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) { super(sourceSizes, iterateSizes, indexes); } @@ -916,7 +934,10 @@ public abstract class IndexedTensor implements Tensor { public long size() { return 1; } @Override - public void next() {} + public void next() { exhausted = true; } + + @Override + public boolean hasNext() { return ! exhausted; } } @@ -945,7 +966,7 @@ public abstract class IndexedTensor implements Tensor { * Advances this to the next cell in the standard indexed tensor cell order. * The first call to this will put it at the first position. * - * @throws RuntimeException if this is called more times than its size + * @throws RuntimeException if this is called when hasNext returns false */ @Override public void next() { @@ -957,6 +978,15 @@ public abstract class IndexedTensor implements Tensor { indexes[iterateDimensions.get(iterateDimensionsIndex)]++; } + @Override + public boolean hasNext() { + for (int iterateDimension : iterateDimensions) { + if (indexes[iterateDimension] + 1 < dimensionSizes().size(iterateDimension)) + return true; // some dimension is not at the end + } + return false; + } + } /** In this case we can reuse the source index computation for the iteration index */ @@ -1016,7 +1046,7 @@ public abstract class IndexedTensor implements Tensor { * Advances this to the next cell in the standard indexed tensor cell order. * The first call to this will put it at the first position. * - * @throws RuntimeException if this is called more times than its size + * @throws RuntimeException if this is called when hasNext returns false */ @Override public void next() { @@ -1031,6 +1061,11 @@ public abstract class IndexedTensor implements Tensor { @Override long toIterationValueIndex() { return currentIterationValueIndex; } + @Override + public boolean hasNext() { + return indexes[iterateDimension] + 1 < size; + } + } /** In this case we only need to keep track of one index */ @@ -1068,7 +1103,7 @@ public abstract class IndexedTensor implements Tensor { * Advances this to the next cell in the standard indexed tensor cell order. * The first call to this will put it at the first position. * - * @throws RuntimeException if this is called more times than its size + * @throws RuntimeException if this is called when hasNext returns false */ @Override public void next() { @@ -1077,6 +1112,11 @@ public abstract class IndexedTensor implements Tensor { } @Override + public boolean hasNext() { + return indexes[iterateDimension] + 1 < size; + } + + @Override long toSourceValueIndex() { return currentValueIndex; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 1cde1fcdbb7..0c4efe78113 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -217,25 +217,34 @@ public class MixedTensor implements Tensor { public static class BoundBuilder extends Builder { /** For each sparse partial address, hold a dense subspace */ - final private Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>(); - final private Index.Builder indexBuilder; - final private Index index; + private final Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>(); + private final Index.Builder indexBuilder; + private final Index index; + private final TensorType denseSubtype; private BoundBuilder(TensorType type) { super(type); indexBuilder = new Index.Builder(type); index = indexBuilder.index(); + denseSubtype = new TensorType(type.valueType(), + type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList())); } public long denseSubspaceSize() { return index.denseSubspaceSize(); } - private double[] denseSubspace(TensorAddress sparsePartial) { - if (!denseSubspaceMap.containsKey(sparsePartial)) { - denseSubspaceMap.put(sparsePartial, new double[(int)denseSubspaceSize()]); + private double[] denseSubspace(TensorAddress sparseAddress) { + if (!denseSubspaceMap.containsKey(sparseAddress)) { + denseSubspaceMap.put(sparseAddress, new double[(int)denseSubspaceSize()]); } - return denseSubspaceMap.get(sparsePartial); + return denseSubspaceMap.get(sparseAddress); + } + + public IndexedTensor.DirectIndexBuilder denseSubspaceBuilder(TensorAddress sparseAddress) { + double[] values = new double[(int)denseSubspaceSize()]; + denseSubspaceMap.put(sparseAddress, values); + return new DenseSubspaceBuilder(denseSubtype, values); } @Override @@ -280,7 +289,6 @@ public class MixedTensor implements Tensor { } - /** * Temporarily stores all cells to find bounds of indexed dimensions, * then creates a tensor using BoundBuilder. This is due to the @@ -491,6 +499,31 @@ public class MixedTensor implements Tensor { } + private static class DenseSubspaceBuilder implements IndexedTensor.DirectIndexBuilder { + + private final TensorType type; + private final double[] values; + + public DenseSubspaceBuilder(TensorType type, double[] values) { + this.type = type; + this.values = values; + } + + @Override + public TensorType type() { return type; } + + @Override + public void cellByDirectIndex(long index, double value) { + values[(int)index] = value; + } + + @Override + public void cellByDirectIndex(long index, float value) { + values[(int)index] = value; + } + + } + public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) { TensorType.Builder builder = new TensorType.Builder(valueType); for (TensorType.Dimension dimension : dimensions) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java index 4eca9c47402..84f26d96725 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java @@ -122,16 +122,18 @@ public class PartialAddress { labels = new Object[size]; } - public void add(String dimensionName, long label) { + public Builder add(String dimensionName, long label) { dimensionNames[index] = dimensionName; labels[index] = label; index++; + return this; } - public void add(String dimensionName, String label) { + public Builder add(String dimensionName, String label) { dimensionNames[index] = dimensionName; labels[index] = label; index++; + return this; } public PartialAddress build() { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 52256293a5b..43d1bb0e468 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -18,6 +18,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return new StringTensorAddress(labels); } + public static TensorAddress ofLabels(String ... labels) { + return new StringTensorAddress(labels); + } + public static TensorAddress of(long ... labels) { return new NumericTensorAddress(labels); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 4d8b34b7dcf..04d3295795f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import java.util.List; import java.util.Optional; /** @@ -9,6 +10,16 @@ import java.util.Optional; class TensorParser { static Tensor tensorFrom(String tensorString, Optional<TensorType> explicitType) { + try { + return tensorFromBody(tensorString, explicitType); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not parse '" + tensorString + "' as a tensor" + + (explicitType.isPresent() ? " of type " + explicitType.get() : ""), + e); + } + } + + static Tensor tensorFromBody(String tensorString, Optional<TensorType> explicitType) { Optional<TensorType> type; String valueString; @@ -29,9 +40,13 @@ class TensorParser { } valueString = valueString.trim(); - if (valueString.startsWith("{")) { + if (valueString.startsWith("{") && + (type.isEmpty() || type.get().rank() == 0 || valueString.substring(1).trim().startsWith("{") || valueString.substring(1).trim().equals("}"))) { return tensorFromSparseValueString(valueString, type); } + else if (valueString.startsWith("{")) { + return tensorFromMixedValueString(valueString, type); + } else if (valueString.startsWith("[")) { return tensorFromDenseValueString(valueString, type); } @@ -54,8 +69,7 @@ class TensorParser { String s = valueString.substring(1).trim(); // remove tensor start int firstKeyOrTensorEnd = s.indexOf('}'); if (firstKeyOrTensorEnd < 0) - throw new IllegalArgumentException("Excepted a number or a string starting by {, [ or tensor(...):, got '" + - valueString + "'"); + throw new IllegalArgumentException("Excepted a number or a string starting by '{', '[' or 'tensor(...):...'"); String addressBody = s.substring(0, firstKeyOrTensorEnd).trim(); if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor @@ -79,73 +93,51 @@ class TensorParser { try { valueString = valueString.trim(); Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromSparseValueString(valueString))); - return fromCellString(builder, valueString); + return tensorFromSparseCellString(builder, valueString); } catch (NumberFormatException e) { - throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" + - valueString + "'"); + throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('"); } } - private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) { + private static Tensor tensorFromMixedValueString(String valueString, Optional<TensorType> type) { if (type.isEmpty()) - throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " + + throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); - if (type.get().dimensions().stream().anyMatch(d -> ( d.size().isEmpty()))) - throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " + - "only dense dimensions with a given size"); + if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() != 1) + throw new IllegalArgumentException("The mixed tensor form requires a type with a single mapped dimension, " + + "but got " + type.get()); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(type.get()); - long index = 0; - int currentChar; - int nextNumberEnd = 0; - // Since we know the dimensions the brackets are just syntactic sugar: - while ((currentChar = nextStartCharIndex(nextNumberEnd + 1, valueString)) < valueString.length()) { - nextNumberEnd = nextStopCharIndex(currentChar, valueString); - if (currentChar == nextNumberEnd) return builder.build(); - TensorType.Value cellValueType = builder.type().valueType(); - String cellValueString = valueString.substring(currentChar, nextNumberEnd); - try { - if (cellValueType == TensorType.Value.DOUBLE) - builder.cellByDirectIndex(index, Double.parseDouble(cellValueString)); - else if (cellValueType == TensorType.Value.FLOAT) - builder.cellByDirectIndex(index, Float.parseFloat(cellValueString)); - else - throw new IllegalArgumentException(cellValueType + " is not supported"); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("At index " + index + ": '" + - cellValueString + "' is not a valid " + cellValueType); - } - index++; + try { + valueString = valueString.trim(); + if ( ! valueString.startsWith("{") && valueString.endsWith("}")) + throw new IllegalArgumentException("A mixed tensor must be enclosed in {}"); + // TODO: Check if there is also at least one bound indexed dimension + MixedTensor.BoundBuilder builder = (MixedTensor.BoundBuilder)Tensor.Builder.of(type.get()); + MixedParser parser = new MixedParser(valueString, builder); + parser.parse(); + return builder.build(); } - return builder.build(); - } - - /** Returns the position of the next character that should contain a number, or if none the string length */ - private static int nextStartCharIndex(int charIndex, String valueString) { - for (; charIndex < valueString.length(); charIndex++) { - if (valueString.charAt(charIndex) == ']') continue; - if (valueString.charAt(charIndex) == '[') continue; - if (valueString.charAt(charIndex) == ',') continue; - if (valueString.charAt(charIndex) == ' ') continue; - return charIndex; + catch (NumberFormatException e) { + throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('"); } - return valueString.length(); } - private static int nextStopCharIndex(int charIndex, String valueString) { - while (charIndex < valueString.length()) { - if (valueString.charAt(charIndex) == ',') return charIndex; - if (valueString.charAt(charIndex) == ']') return charIndex; - charIndex++; - } - throw new IllegalArgumentException("Malformed tensor value '" + valueString + - "': Expected a ',' or ']' after position " + charIndex); + private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) { + if (type.isEmpty()) + throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " + + "on the form 'tensor(dimensions):..."); + if (type.get().dimensions().stream().anyMatch(d -> (d.size().isEmpty()))) + throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " + + "only dense dimensions with a given size"); + + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(type.get()); + new DenseParser(valueString, builder).parse(); + return builder.build(); } - private static Tensor fromCellString(Tensor.Builder builder, String s) { + private static Tensor tensorFromSparseCellString(Tensor.Builder builder, String s) { int index = 1; index = skipSpace(index, s); while (index + 1 < s.length()) { @@ -194,6 +186,16 @@ class TensorParser { return index; } + private static int nextStopCharIndex(int charIndex, String valueString) { + while (charIndex < valueString.length()) { + if (valueString.charAt(charIndex) == ',') return charIndex; + if (valueString.charAt(charIndex) == ']') return charIndex; + charIndex++; + } + throw new IllegalArgumentException("Malformed tensor value '" + valueString + + "': Expected a ',' or ']' after position " + charIndex); + } + /** Creates a tenor address from a string on the form {dimension1:label1,dimension2:label2,...} */ private static void addLabels(String mapAddressString, TensorAddress.Builder builder) { mapAddressString = mapAddressString.trim(); @@ -213,4 +215,157 @@ class TensorParser { } } + private static abstract class ValueParser { + + protected final String string; + protected int position = 0; + + protected ValueParser(String string) { + this.string = string; + } + + protected void skipSpace() { + while (position < string.length() && string.charAt(position) == ' ') + position++; + } + + protected void consume(char character) { + skipSpace(); + + if (position >= string.length()) + throw new IllegalArgumentException("At position " + position + ": Expected a '" + character + + "' but got the end of the string"); + if ( string.charAt(position) != character) + throw new IllegalArgumentException("At position " + position + ": Expected a '" + character + + "' but got '" + string.charAt(position) + "'"); + position++; + } + + } + + /** A single-use dense tensor string parser */ + private static class DenseParser extends ValueParser { + + private final IndexedTensor.DirectIndexBuilder builder; + private final IndexedTensor.Indexes indexes; + private final boolean hasInnerStructure; + + private long tensorIndex = 0; + + public DenseParser(String string, IndexedTensor.DirectIndexBuilder builder) { + super(string); + this.builder = builder; + indexes = IndexedTensor.Indexes.of(builder.type()); + hasInnerStructure = hasInnerStructure(string); + } + + public void parse() { + if (!hasInnerStructure) + consume('['); + + while (indexes.hasNext()) { + indexes.next(); + + for (int i = 0; i < indexes.rightDimensionsAtStart() && hasInnerStructure; i++) + consume('['); + + consumeNumber(); + + for (int i = 0; i < indexes.rightDimensionsAtEnd() && hasInnerStructure; i++) + consume(']'); + + if (indexes.hasNext()) + consume(','); + } + + if (!hasInnerStructure) + consume(']'); + } + + public int position() { return position; } + + /** Are there inner square brackets in this or is it just a flat list of numbers until ']'? */ + private static boolean hasInnerStructure(String valueString) { + valueString = valueString.trim(); + valueString = valueString.substring(1); + int firstLeftBracket = valueString.indexOf('['); + return firstLeftBracket >= 0 && firstLeftBracket < valueString.indexOf(']'); + } + + private void consumeNumber() { + skipSpace(); + + int nextNumberEnd = nextStopCharIndex(position, string); + TensorType.Value cellValueType = builder.type().valueType(); + String cellValueString = string.substring(position, nextNumberEnd); + try { + if (cellValueType == TensorType.Value.DOUBLE) + builder.cellByDirectIndex(tensorIndex++, Double.parseDouble(cellValueString)); + else if (cellValueType == TensorType.Value.FLOAT) + builder.cellByDirectIndex(tensorIndex++, Float.parseFloat(cellValueString)); + else + throw new IllegalArgumentException(cellValueType + " is not supported"); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("At position " + position + ": '" + + cellValueString + "' is not a valid " + cellValueType); + } + position = nextNumberEnd; + } + + } + + private static class MixedParser extends ValueParser { + + private final MixedTensor.BoundBuilder builder; + + public MixedParser(String string, MixedTensor.BoundBuilder builder) { + super(string); + this.builder = builder; + } + + private void parse() { + TensorType.Dimension sparseDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get(); + TensorType sparseSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(sparseDimension)); + + skipSpace(); + consume('{'); + skipSpace(); + while (position + 1 < string.length()) { + int labelEnd = string.indexOf(':', position); + if (labelEnd <= position) + throw new IllegalArgumentException("A mixed tensor value must be on the form {sparse-label:[dense subspace], ...} "); + String label = string.substring(position, labelEnd); + position = labelEnd + 1; + skipSpace(); + + TensorAddress sparseAddress = new TensorAddress.Builder(sparseSubtype).add(sparseDimension.name(), label).build(); + parseDenseSubspace(sparseAddress); + if ( ! consumeOptional(',')) + consume('}'); + skipSpace(); + } + } + + private void parseDenseSubspace(TensorAddress sparseAddress) { + DenseParser denseParser = new DenseParser(string.substring(position), builder.denseSubspaceBuilder(sparseAddress)); + denseParser.parse(); + position+= denseParser.position(); + } + + private boolean consumeOptional(char character) { + skipSpace(); + + if (position >= string.length()) + return false; + if ( string.charAt(position) != character) + return false; + + position++; + return true; + } + + + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 95cc70804e2..ca3f8ff28a4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -82,7 +82,7 @@ public class TensorType { private final TensorType mappedSubtype; - private TensorType(Value valueType, Collection<Dimension> dimensions) { + public TensorType(Value valueType, Collection<Dimension> dimensions) { this.valueType = valueType; List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java index 1928971820c..b2aba5b02eb 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java @@ -22,6 +22,12 @@ public class TensorParserTestCase { } @Test + public void testSingle() { + assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(1.0, 0).build(), + "tensor(x[1]):[1.0]"); + } + + @Test public void testDenseParsing() { assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor()")).build(), "tensor():{0.0}"); @@ -55,18 +61,9 @@ public class TensorParserTestCase { .cell(3.0, 1, 0, 0) .cell(4.0, 1, 1, 0) .cell(5.0, 2, 0, 0) - .cell(6.0, 2, 1, 0).build(), - "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [6.0]]]"); - assertEquals("Messy input", - Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])")) - .cell( 1.0, 0, 0, 0) - .cell( 2.0, 0, 1, 0) - .cell( 3.0, 1, 0, 0) - .cell( 4.0, 1, 1, 0) - .cell( 5.0, 2, 0, 0) .cell(-6.0, 2, 1, 0).build(), - Tensor.from("tensor( x[3],y[2],z[1]) : [ [ [1.0, 2.0, 3.0] , [4.0, 5,-6.0] ] ]")); - assertEquals("Skipping syntactic sugar", + "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [-6.0]]]"); + assertEquals("Skipping structure", Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])")) .cell( 1.0, 0, 0, 0) .cell( 2.0, 0, 1, 0) @@ -77,6 +74,16 @@ public class TensorParserTestCase { Tensor.from("tensor( x[3],y[2],z[1]) : [1.0, 2.0, 3.0 , 4.0, 5, -6.0]")); } + @Test + public void testMixedParsing() { + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{}, x[2])")) + .cell(TensorAddress.ofLabels("a", "0"), 1) + .cell(TensorAddress.ofLabels("a", "1"), 2) + .cell(TensorAddress.ofLabels("b", "0"), 3) + .cell(TensorAddress.ofLabels("b", "1"), 4).build(), + Tensor.from("tensor(key{}, x[2]):{a:[1, 2], b:[3, 4]}")); + } + private void assertDense(Tensor expectedTensor, String denseFormat) { assertEquals(denseFormat, expectedTensor, Tensor.from(denseFormat)); assertEquals(denseFormat, expectedTensor.toString()); @@ -92,7 +99,7 @@ public class TensorParserTestCase { "{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}"); assertIllegal("At {x:0}: '1-.0' is not a valid double", "{{x:0}:1-.0}"); - assertIllegal("At index 0: '1-.0' is not a valid double", + assertIllegal("At position 1: '1-.0' is not a valid double", "tensor(x[1]):[1-.0]"); } @@ -102,7 +109,7 @@ public class TensorParserTestCase { fail("Expected an IllegalArgumentException when parsing " + tensor); } catch (IllegalArgumentException e) { - assertEquals(message, e.getMessage()); + assertEquals(message, e.getCause().getMessage()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 11365531019..9f077cb7b00 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -56,7 +56,8 @@ public class TensorTestCase { fail("Expected parse error"); } catch (IllegalArgumentException expected) { - assertEquals("Excepted a number or a string starting by {, [ or tensor(...):, got '--'", expected.getMessage()); + assertEquals("Excepted a number or a string starting by {, [ or tensor(...):, got '--'", + expected.getCause().getMessage()); } } @@ -259,9 +260,9 @@ public class TensorTestCase { assertLargest("{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0", "tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0,{d1:l1,d2:l3}:5.0,{d1:l1,d2:l2}:6.0}"); assertLargest("{x:1,y:1}:4.0", - "tensor(x[2],y[2]):[[1,2],[3,4]"); + "tensor(x[2],y[2]):[[1,2],[3,4]]"); assertLargest("{x:0,y:0}:4.0, {x:1,y:1}:4.0", - "tensor(x[2],y[2]):[[4,2],[3,4]"); + "tensor(x[2],y[2]):[[4,2],[3,4]]"); } @Test @@ -273,9 +274,9 @@ public class TensorTestCase { assertSmallest("{d1:l1,d2:l1}:5.0, {d1:l1,d2:l2}:5.0", "tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l3}:6.0,{d1:l1,d2:l2}:5.0}"); assertSmallest("{x:0,y:0}:1.0", - "tensor(x[2],y[2]):[[1,2],[3,4]"); + "tensor(x[2],y[2]):[[1,2],[3,4]]"); assertSmallest("{x:0,y:1}:2.0", - "tensor(x[2],y[2]):[[4,2],[3,4]"); + "tensor(x[2],y[2]):[[4,2],[3,4]]"); } private void assertLargest(String expectedCells, String tensorString) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java index e16b7b90a1d..7cddeab1641 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.evaluation.Name; import org.junit.Test; import java.util.Collections; +import java.util.HashMap; import java.util.List; import static org.junit.Assert.assertEquals; @@ -19,21 +20,36 @@ import static org.junit.Assert.assertEquals; public class DynamicTensorTestCase { @Test - public void testDynamicTensorFunction() { + public void testDynamicIndexedRank1TensorFunction() { TensorType dense = TensorType.fromSpec("tensor(x[3])"); DynamicTensor<Name> t1 = DynamicTensor.from(dense, List.of(new Constant(1), new Constant(2), new Constant(3))); assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate()); assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString()); + } + @Test + public void testDynamicMappedRank1TensorFunction() { TensorType sparse = TensorType.fromSpec("tensor(x{})"); DynamicTensor<Name> t2 = DynamicTensor.from(sparse, Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), - new Constant(5))); + new Constant(5))); assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate()); assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString()); } + @Test + public void testDynamicMappedRank2TensorFunction() { + TensorType sparse = TensorType.fromSpec("tensor(x{},y{})"); + HashMap<TensorAddress, ScalarFunction<Name>> values = new HashMap<>(); + values.put(new TensorAddress.Builder(sparse).add("x", "a").add("y", "b").build(), + new Constant(5)); + values.put(new TensorAddress.Builder(sparse).add("x", "a").add("y", "c").build(), + new Constant(7)); + DynamicTensor<Name> t2 = DynamicTensor.from(sparse, values); + assertEquals(Tensor.from(sparse, "{{x:a,y:b}:5, {x:a,y:c}:7}"), t2.evaluate()); + } + private static class Constant implements ScalarFunction<Name> { private final double value; |