diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-14 08:34:09 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-14 08:34:09 +0100 |
commit | f5ccf036b4f7368f217a6bcbffc1699aac5eac2d (patch) | |
tree | 749afd3b29f52b918c67099c1742cb9db50211cf | |
parent | 3954dbe2403bdbb21e9a558fbc55fd137afa40f8 (diff) |
Interpret dimensions in written order
9 files changed, 241 insertions, 88 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 11fc581640d..a248fa6dd45 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -19,6 +19,7 @@ import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; +import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.LinkedHashMap; @@ -93,14 +94,15 @@ public class TensorFunctionNode extends CompositeNode { } public static void wrapScalarBlock(TensorType type, + List<String> dimensionOrder, String mappedDimensionLabel, List<ExpressionNode> nodes, Map<TensorAddress, ScalarFunction<Reference>> receivingMap) { - TensorType.Dimension sparseDimension = type.dimensions().stream().filter(d -> ! d.isIndexed()).findFirst().get(); TensorType denseSubtype = new TensorType(type.valueType(), type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList())); - - IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype); + List<String> denseDimensionOrder = new ArrayList<>(dimensionOrder); + denseDimensionOrder.retainAll(denseSubtype.dimensionNames()); + IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype, denseDimensionOrder); for (ExpressionNode node : nodes) { indexes.next(); @@ -119,7 +121,15 @@ public class TensorFunctionNode extends CompositeNode { } } - public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) { + public static List<ScalarFunction<Reference>> wrapScalars(TensorType type, + List<String> dimensionOrder, + List<ExpressionNode> nodes) { + IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type, dimensionOrder); + List<ScalarFunction<Reference>> wrapped = new ArrayList<>(); + while (indexes.hasNext()) { + indexes.next(); + wrapped.add(wrapScalar(nodes.get((int)indexes.toSourceValueIndex()))); + } return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList()); } diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 71456d0ed00..22d2abd4aef 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -475,13 +475,14 @@ TensorFunctionNode tensorConcat() : TensorFunctionNode tensorGenerate() : { TensorType type; + List dimensionOrder = new ArrayList(); TensorFunctionNode expression; } { - <TENSOR> type = tensorType() + <TENSOR> type = tensorType(dimensionOrder) ( expression = tensorGenerateBody(type) | - expression = tensorValueBody(type) + expression = tensorValueBody(type, dimensionOrder) ) { return expression; } } @@ -500,7 +501,7 @@ TensorFunctionNode tensorRange() : TensorType type; } { - <RANGE> type = tensorType() + <RANGE> type = tensorType(null) { return new TensorFunctionNode(new Range(type)); } } @@ -509,7 +510,7 @@ TensorFunctionNode tensorDiag() : TensorType type; } { - <DIAG> type = tensorType() + <DIAG> type = tensorType(null) { return new TensorFunctionNode(new Diag(type)); } } @@ -518,7 +519,7 @@ TensorFunctionNode tensorRandom() : TensorType type; } { - <RANDOM> type = tensorType() + <RANDOM> type = tensorType(null) { return new TensorFunctionNode(new Random(type)); } } @@ -618,7 +619,7 @@ Reduce.Aggregator tensorReduceAggregator() : { return Reduce.Aggregator.valueOf(token.image); } } -TensorType tensorType() : +TensorType tensorType(List dimensionOrder) : { TensorType.Builder builder; TensorType.Value valueType; @@ -627,8 +628,8 @@ TensorType tensorType() : valueType = optionalTensorValueTypeParameter() { builder = new TensorType.Builder(valueType); } <LBRACE> - ( tensorTypeDimension(builder) ) ? - ( <COMMA> tensorTypeDimension(builder) ) * + ( tensorTypeDimension(builder, dimensionOrder) ) ? + ( <COMMA> tensorTypeDimension(builder, dimensionOrder) ) * <RBRACE> { return builder.build(); } } @@ -642,13 +643,17 @@ TensorType.Value optionalTensorValueTypeParameter() : { return TensorType.Value.fromId(valueType); } } -void tensorTypeDimension(TensorType.Builder builder) : +void tensorTypeDimension(TensorType.Builder builder, List dimensionOrder) : { String name; int size; } { name = identifier() + { // Keep track of the order in which dimensions are written, if necessary + if (dimensionOrder != null) + dimensionOrder.add(name); + } ( ( <LCURLY> <RCURLY> { builder.mapped(name); } ) | LOOKAHEAD(2) ( <LSQUARE> <RSQUARE> { builder.indexed(name); } ) | @@ -831,16 +836,16 @@ Value primitiveValue() : { return Value.parse(sign + token.image); } } -TensorFunctionNode tensorValueBody(TensorType type) : +TensorFunctionNode tensorValueBody(TensorType type, List dimensionOrder) : { DynamicTensor dynamicTensor; } { <COLON> ( - LOOKAHEAD(2) dynamicTensor = mixedTensorValueBody(type) | + LOOKAHEAD(2) dynamicTensor = mixedTensorValueBody(type, dimensionOrder) | dynamicTensor = mappedTensorValueBody(type) | - dynamicTensor = indexedTensorValueBody(type) + dynamicTensor = indexedTensorValueBody(type, dimensionOrder) ) { return new TensorFunctionNode(dynamicTensor); } } @@ -857,35 +862,35 @@ DynamicTensor mappedTensorValueBody(TensorType type) : { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); } } -DynamicTensor mixedTensorValueBody(TensorType type) : +DynamicTensor mixedTensorValueBody(TensorType type, List dimensionOrder) : { java.util.Map cells = new LinkedHashMap(); } { <LCURLY> - mixedBlock(type, cells) - ( <COMMA> mixedBlock(type, cells))* + mixedBlock(type, dimensionOrder, cells) + ( <COMMA> mixedBlock(type, dimensionOrder, cells))* <RCURLY> { return DynamicTensor.from(type, cells); } } -DynamicTensor indexedTensorValueBody(TensorType type) : +DynamicTensor indexedTensorValueBody(TensorType type, List dimensionOrder) : { List cells; } { cells = indexedTensorCells() - { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); } + { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells, type, dimensionOrder)); } } -void mixedBlock(TensorType type, java.util.Map cellMap) : +void mixedBlock(TensorType type, List dimensionOrder, java.util.Map cellMap) : { String label; List cells; } { label = tag() <COLON> cells = indexedTensorCells() - { TensorFunctionNode.wrapScalarBlock(type, label, cells, cellMap); } + { TensorFunctionNode.wrapScalarBlock(type, dimensionOrder, label, cells, cellMap); } } List indexedTensorCells() : diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 0601043f2ce..fa65ce0408b 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -408,6 +408,29 @@ public class EvaluationTestCase { "tensor(x{},y[2]):{{x:a,y:0}:one, {x:a,y:1}:one_half, {x:b,y:0}:a_quarter, {x:b,y:1}:2}"); tester.assertEvaluates("tensor(x{},y[2]):{a:[1.0, 0.5], b:[0.25, 2]}", "tensor(x{},y[2]):{a:[one, one_half], b:[a_quarter, 2]}"); + tester.assertEvaluates("tensor(key{},x[2],y[3]):{key1:[[1.0, 0.5, 0.25],[0.25, 0.5, 1.0]]," + + " key2:[[1.0, 2.0, 3.00],[4.00, 5.0, 6.0]]}", + "tensor(key{},x[2],y[3]):{key1:[[one,one_half,a_quarter],[a_quarter,one_half,one]]," + + " key2:[[1,2,3],[4,5,6]]}"); + + // Opposite order in the expression: + // - indexed + tester.assertEvaluates("tensor(x[3],y[2]):[[1.0, 0.25], [0.5,0.5], [0.25, 1.0]]", + "tensor(y[2],x[3]):[[one,one_half,a_quarter],[a_quarter,one_half,one]]"); + // - mixed + tester.assertEvaluates("tensor(key{},x[3],y[2]):{key1:[[1.0, 0.25], [0.5,0.5], [0.25, 1.0]]," + + " key2:[[1.0, 4.00], [2.0,5.0], [3.00, 6.0]]}", + "tensor(key{},y[2],x[3]):{key1:[[one,one_half,a_quarter],[a_quarter,one_half,one]]," + + " key2:[[1,2,3],[4,5,6]]}"); + // Opposite order in literal parsing: + // - indexed + tester.assertEvaluates("tensor(y[2],x[3]):[[1,0.25,0.5],[0.5,0.25,1]]", + "tensor(x[3],y[2]):[[one,one_half], [a_quarter,a_quarter], [one_half,one]]"); + // - mixed + tester.assertEvaluates("tensor(key{},y[2],x[3]):{key1:[[1.0, 0.5, 0.25],[0.25, 0.5, 1.0]]," + + " key2:[[1.0, 2.0, 3.00],[4.00, 5.0, 6.0]]}", + "tensor(key{},x[3],y[2]):{key1:[[one,a_quarter],[one_half,one_half],[a_quarter,one]]," + + " key2:[[1,4],[2,5],[3,6]]}"); } @Test diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index d91b38a8a96..1fcdf7f5cca 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -693,6 +693,7 @@ "methods": [ "public void <init>(int)", "public com.yahoo.tensor.DimensionSizes$Builder set(int, long)", + "public com.yahoo.tensor.DimensionSizes$Builder add(long)", "public long size(int)", "public int dimensions()", "public com.yahoo.tensor.DimensionSizes build()" @@ -836,10 +837,12 @@ ], "methods": [ "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType)", + "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType, java.util.List)", "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.DimensionSizes)", "public com.yahoo.tensor.TensorAddress toAddress()", "public long[] indexesCopy()", "public long[] indexesForReading()", + "public long toSourceValueIndex()", "public java.util.List toList()", "public java.lang.String toString()", "public abstract long size()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index d81c02fb75f..202817ece42 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -71,6 +71,7 @@ public final class DimensionSizes { */ public final static class Builder { + private int dimensionIndex = 0; private long[] sizes; public Builder(int dimensions) { @@ -82,6 +83,11 @@ public final class DimensionSizes { return this; } + public Builder add(long size) { + sizes[dimensionIndex++] = size; + return this; + } + /** * Returns the length of this in the nth dimension * diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 30923976fa5..ba3a35e8eda 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -218,7 +218,7 @@ public abstract class IndexedTensor implements Tensor { indexes.next(); // start brackets - for (int i = 0; i < indexes.rightDimensionsAtStart(); i++) + for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) b.append("["); // value @@ -230,7 +230,7 @@ public abstract class IndexedTensor implements Tensor { throw new IllegalStateException("Unexpected value type " + type.valueType()); // end bracket and comma - for (int i = 0; i < indexes.rightDimensionsAtEnd(); i++) + for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) b.append("]"); if (index < size() - 1) b.append(", "); @@ -777,6 +777,10 @@ public abstract class IndexedTensor implements Tensor { return of(DimensionSizes.of(type)); } + public static Indexes of(TensorType type, List<String> iterateDimensionOrder) { + return of(DimensionSizes.of(type), toIterationOrder(iterateDimensionOrder, type)); + } + public static Indexes of(DimensionSizes sizes) { return of(sizes, sizes); } @@ -789,6 +793,10 @@ public abstract class IndexedTensor implements Tensor { return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size); } + private static Indexes of(DimensionSizes sizes, List<Integer> iterateDimensions) { + return of(sizes, sizes, iterateDimensions); + } + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions) { return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions)); } @@ -822,6 +830,16 @@ public abstract class IndexedTensor implements Tensor { } } + private static List<Integer> toIterationOrder(List<String> dimensionNames, TensorType type) { + if (dimensionNames == null) return completeIterationOrder(type.rank()); + + List<Integer> iterationDimensions = new ArrayList<>(type.rank()); + for (int i = 0; i < type.rank(); i++) + iterationDimensions.add(type.rank() - 1 - type.indexOfDimension(dimensionNames.get(i)).get()); + return iterationDimensions; + } + + /** Since the right dimensions binds closest, iteration order is the opposite of the tensor order */ private static List<Integer> completeIterationOrder(int length) { List<Integer> iterationDimensions = new ArrayList<>(length); for (int i = 0; i < length; i++) @@ -854,7 +872,7 @@ public abstract class IndexedTensor implements Tensor { /** Returns a copy of the indexes of this which must not be modified */ public long[] indexesForReading() { return indexes; } - long toSourceValueIndex() { + public long toSourceValueIndex() { return IndexedTensor.toValueIndex(indexes, sourceSizes); } @@ -882,27 +900,12 @@ public abstract class IndexedTensor implements Tensor { /** Returns whether further values are available by calling next() */ public abstract boolean hasNext(); - /** Returns the number of dimensions from the right which are currently at the start position (0) */ - int rightDimensionsAtStart() { - int dimension = indexes.length - 1; - int atStartCount = 0; - while (dimension >= 0 && indexes[dimension] == 0) { - atStartCount++; - dimension--; - } - return atStartCount; - } + /** Returns the number of dimensions in iteration order which are currently at the start position (0) */ + abstract int nextDimensionsAtStart(); + + /** Returns the number of dimensions in iteration order which are currently at their end position */ + abstract int nextDimensionsAtEnd(); - /** Returns the number of dimensions from the right which are currently at the end position */ - int rightDimensionsAtEnd() { - int dimension = indexes.length - 1; - int atEndCount = 0; - while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) { - atEndCount++; - dimension--; - } - return atEndCount; - } } private final static class EmptyIndexes extends Indexes { @@ -920,6 +923,12 @@ public abstract class IndexedTensor implements Tensor { @Override public boolean hasNext() { return false; } + @Override + int nextDimensionsAtStart() { return 0; } + + @Override + int nextDimensionsAtEnd() { return 0; } + } private final static class SingleValueIndexes extends Indexes { @@ -939,6 +948,12 @@ public abstract class IndexedTensor implements Tensor { @Override public boolean hasNext() { return ! exhausted; } + @Override + int nextDimensionsAtStart() { return 1; } + + @Override + int nextDimensionsAtEnd() { return 1; } + } private static class MultiDimensionIndexes extends Indexes { @@ -987,6 +1002,22 @@ public abstract class IndexedTensor implements Tensor { return false; } + @Override + int nextDimensionsAtStart() { + int dimension = 0; + while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == 0) + dimension++; + return dimension; + } + + @Override + int nextDimensionsAtEnd() { + int dimension = 0; + while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == dimensionSizes().size(iterateDimensions.get(dimension)) - 1) + dimension++; + return dimension; + } + } /** In this case we can reuse the source index computation for the iteration index */ @@ -999,7 +1030,7 @@ public abstract class IndexedTensor implements Tensor { } @Override - long toSourceValueIndex() { + public long toSourceValueIndex() { return lastComputedSourceValueIndex = super.toSourceValueIndex(); } @@ -1056,7 +1087,7 @@ public abstract class IndexedTensor implements Tensor { } @Override - long toSourceValueIndex() { return currentSourceValueIndex; } + public long toSourceValueIndex() { return currentSourceValueIndex; } @Override long toIterationValueIndex() { return currentIterationValueIndex; } @@ -1066,6 +1097,12 @@ public abstract class IndexedTensor implements Tensor { return indexes[iterateDimension] + 1 < size; } + @Override + int nextDimensionsAtStart() { return currentSourceValueIndex == 0 ? 1 : 0; } + + @Override + int nextDimensionsAtEnd() { return currentSourceValueIndex == size - 1 ? 1 : 0; } + } /** In this case we only need to keep track of one index */ @@ -1117,11 +1154,17 @@ public abstract class IndexedTensor implements Tensor { } @Override - long toSourceValueIndex() { return currentValueIndex; } + public long toSourceValueIndex() { return currentValueIndex; } @Override long toIterationValueIndex() { return currentValueIndex; } + @Override + int nextDimensionsAtStart() { return currentValueIndex == 0 ? 1 : 0; } + + @Override + int nextDimensionsAtEnd() { return currentValueIndex == size - 1 ? 1 : 0; } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 8d07a1ed9a8..ea21249bede 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -23,11 +24,17 @@ class TensorParser { Optional<TensorType> type; String valueString; + // The order in which dimensions are written in the type string. + // This allows the user's explicit dimension order to decide what (dense) dimensions map to what, rather than + // the natural order of the tensor. + List<String> dimensionOrder; + tensorString = tensorString.trim(); if (tensorString.startsWith("tensor")) { int colonIndex = tensorString.indexOf(':'); String typeString = tensorString.substring(0, colonIndex); - TensorType typeFromString = TensorTypeParser.fromSpec(typeString); + dimensionOrder = new ArrayList<>(); + TensorType typeFromString = TensorTypeParser.fromSpec(typeString, dimensionOrder); if (explicitType.isPresent() && ! explicitType.get().equals(typeFromString)) throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " + "passed type " + explicitType.get()); @@ -37,6 +44,7 @@ class TensorParser { else { type = explicitType; valueString = tensorString; + dimensionOrder = null; } valueString = valueString.trim(); @@ -45,10 +53,10 @@ class TensorParser { return tensorFromSparseValueString(valueString, type); } else if (valueString.startsWith("{")) { - return tensorFromMixedValueString(valueString, type); + return tensorFromMixedValueString(valueString, type, dimensionOrder); } else if (valueString.startsWith("[")) { - return tensorFromDenseValueString(valueString, type); + return tensorFromDenseValueString(valueString, type, dimensionOrder); } else { if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty)) @@ -102,7 +110,9 @@ class TensorParser { } } - private static Tensor tensorFromMixedValueString(String valueString, Optional<TensorType> type) { + private static Tensor tensorFromMixedValueString(String valueString, + Optional<TensorType> type, + List<String> dimensionOrder) { if (type.isEmpty()) throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); @@ -117,7 +127,7 @@ class TensorParser { throw new IllegalArgumentException("A mixed tensor must be enclosed in {}"); // TODO: Check if there is also at least one bound indexed dimension MixedTensor.BoundBuilder builder = (MixedTensor.BoundBuilder)Tensor.Builder.of(type.get()); - MixedValueParser parser = new MixedValueParser(valueString, builder); + MixedValueParser parser = new MixedValueParser(valueString, dimensionOrder, builder); parser.parse(); return builder.build(); } @@ -126,7 +136,9 @@ class TensorParser { } } - private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) { + private static Tensor tensorFromDenseValueString(String valueString, + Optional<TensorType> type, + List<String> dimensionOrder) { if (type.isEmpty()) throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); @@ -135,7 +147,7 @@ class TensorParser { "only dense dimensions with a given size"); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(type.get()); - new DenseValueParser(valueString, builder).parse(); + new DenseValueParser(valueString, dimensionOrder, builder).parse(); return builder.build(); } @@ -157,10 +169,10 @@ class TensorParser { skipSpace(); if (position >= string.length()) - throw new IllegalArgumentException("At position " + position + ": Expected a '" + character + + throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character + "' but got the end of the string"); if ( string.charAt(position) != character) - throw new IllegalArgumentException("At position " + position + ": Expected a '" + character + + throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character + "' but got '" + string.charAt(position) + "'"); position++; } @@ -176,10 +188,12 @@ class TensorParser { private long tensorIndex = 0; - public DenseValueParser(String string, IndexedTensor.DirectIndexBuilder builder) { + public DenseValueParser(String string, + List<String> dimensionOrder, + IndexedTensor.DirectIndexBuilder builder) { super(string); this.builder = builder; - indexes = IndexedTensor.Indexes.of(builder.type()); + indexes = IndexedTensor.Indexes.of(builder.type(), dimensionOrder); hasInnerStructure = hasInnerStructure(string); } @@ -189,10 +203,10 @@ class TensorParser { while (indexes.hasNext()) { indexes.next(); - for (int i = 0; i < indexes.rightDimensionsAtStart() && hasInnerStructure; i++) + for (int i = 0; i < indexes.nextDimensionsAtStart() && hasInnerStructure; i++) consume('['); consumeNumber(); - for (int i = 0; i < indexes.rightDimensionsAtEnd() && hasInnerStructure; i++) + for (int i = 0; i < indexes.nextDimensionsAtEnd() && hasInnerStructure; i++) consume(']'); if (indexes.hasNext()) consume(','); @@ -220,14 +234,14 @@ class TensorParser { String cellValueString = string.substring(position, nextNumberEnd); try { if (cellValueType == TensorType.Value.DOUBLE) - builder.cellByDirectIndex(tensorIndex++, Double.parseDouble(cellValueString)); + builder.cellByDirectIndex(indexes.toSourceValueIndex(), Double.parseDouble(cellValueString)); else if (cellValueType == TensorType.Value.FLOAT) - builder.cellByDirectIndex(tensorIndex++, Float.parseFloat(cellValueString)); + builder.cellByDirectIndex(indexes.toSourceValueIndex(), Float.parseFloat(cellValueString)); else throw new IllegalArgumentException(cellValueType + " is not supported"); } catch (NumberFormatException e) { - throw new IllegalArgumentException("At position " + position + ": '" + + throw new IllegalArgumentException("At value position " + position + ": '" + cellValueString + "' is not a valid " + cellValueType); } position = nextNumberEnd; @@ -248,15 +262,19 @@ class TensorParser { private static class MixedValueParser extends ValueParser { private final MixedTensor.BoundBuilder builder; + private List<String> dimensionOrder; - public MixedValueParser(String string, MixedTensor.BoundBuilder builder) { + public MixedValueParser(String string, List<String> dimensionOrder, MixedTensor.BoundBuilder builder) { super(string); + this.dimensionOrder = dimensionOrder; this.builder = builder; } private void parse() { - TensorType.Dimension sparseDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get(); - TensorType sparseSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(sparseDimension)); + TensorType.Dimension mappedDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get(); + TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension)); + if (dimensionOrder != null) + dimensionOrder.remove(mappedDimension.name()); skipSpace(); consume('{'); @@ -269,16 +287,18 @@ class TensorParser { position = labelEnd + 1; skipSpace(); - TensorAddress sparseAddress = new TensorAddress.Builder(sparseSubtype).add(sparseDimension.name(), label).build(); - parseDenseSubspace(sparseAddress); + TensorAddress mappedAddress = new TensorAddress.Builder(mappedSubtype).add(mappedDimension.name(), label).build(); + parseDenseSubspace(mappedAddress, dimensionOrder); if ( ! consumeOptional(',')) consume('}'); skipSpace(); } } - private void parseDenseSubspace(TensorAddress sparseAddress) { - DenseValueParser denseParser = new DenseValueParser(string.substring(position), builder.denseSubspaceBuilder(sparseAddress)); + private void parseDenseSubspace(TensorAddress sparseAddress, List<String> denseDimensionOrder) { + DenseValueParser denseParser = new DenseValueParser(string.substring(position), + denseDimensionOrder, + builder.denseSubspaceBuilder(sparseAddress)); denseParser.parse(); position+= denseParser.position(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java index def3ab6b4ec..4fdb0906740 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java @@ -24,6 +24,13 @@ public class TensorTypeParser { private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}"); public static TensorType fromSpec(String specString) { + return fromSpec(specString, null); + } + + /** + * @param dimensionOrder if not null, this will be populated with the dimension names in the order they are written + */ + static TensorType fromSpec(String specString, List<String> dimensionOrder) { specString = specString.trim(); if ( ! specString.startsWith(START_STRING) || ! specString.endsWith(END_STRING)) throw formatException(specString); @@ -48,10 +55,14 @@ public class TensorTypeParser { List<TensorType.Dimension> dimensions = new ArrayList<>(); for (String element : dimensionsSpec.split(",")) { String trimmedElement = element.trim(); - boolean success = tryParseIndexedDimension(trimmedElement, dimensions) || - tryParseMappedDimension(trimmedElement, dimensions); - if ( ! success) + TensorType.Dimension dimension = tryParseIndexedDimension(trimmedElement); + if (dimension == null) + dimension = tryParseMappedDimension(trimmedElement); + if (dimension == null) throw formatException(specString, "Dimension '" + element + "' is on the wrong format"); + dimensions.add(dimension); + if (dimensionOrder != null) + dimensionOrder.add(dimension.name()); } return new TensorType.Builder(valueType, dimensions).build(); } @@ -68,29 +79,26 @@ public class TensorTypeParser { } } - private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) { + private static TensorType.Dimension tryParseIndexedDimension(String element) { Matcher matcher = indexedPattern.matcher(element); if (matcher.matches()) { String dimensionName = matcher.group(1); String dimensionSize = matcher.group(2); - if (dimensionSize.isEmpty()) { - dimensions.add(TensorType.Dimension.indexed(dimensionName)); - } else { - dimensions.add(TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize))); - } - return true; + if (dimensionSize.isEmpty()) + return TensorType.Dimension.indexed(dimensionName); + else + return TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize)); } - return false; + return null; } - private static boolean tryParseMappedDimension(String element, List<TensorType.Dimension> dimensions) { + private static TensorType.Dimension tryParseMappedDimension(String element) { Matcher matcher = mappedPattern.matcher(element); if (matcher.matches()) { String dimensionName = matcher.group(1); - dimensions.add(TensorType.Dimension.mapped(dimensionName)); - return true; + return TensorType.Dimension.mapped(dimensionName); } - return false; + return null; } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java index b2aba5b02eb..9dfdee29845 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java @@ -75,6 +75,19 @@ public class TensorParserTestCase { } @Test + public void testDenseWrongOrder() { + assertEquals("Opposite order of dimensions", + Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2])")) + .cell(1, 0, 0) + .cell(4, 0, 1) + .cell(2, 1, 0) + .cell(5, 1, 1) + .cell(3, 2, 0) + .cell(6, 2, 1).build(), + Tensor.from("tensor(y[2],x[3]):[[1,2,3],[4,5,6]]")); + } + + @Test public void testMixedParsing() { assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{}, x[2])")) .cell(TensorAddress.ofLabels("a", "0"), 1) @@ -84,6 +97,28 @@ public class TensorParserTestCase { Tensor.from("tensor(key{}, x[2]):{a:[1, 2], b:[3, 4]}")); } + @Test + public void testMixedWrongOrder() { + assertEquals("Opposite order of dimensions", + Tensor.Builder.of(TensorType.fromSpec("tensor(key{},x[3],y[2])")) + .cell(TensorAddress.ofLabels("key1", "0", "0"), 1) + .cell(TensorAddress.ofLabels("key1", "0", "1"), 4) + .cell(TensorAddress.ofLabels("key1", "1", "0"), 2) + .cell(TensorAddress.ofLabels("key1", "1", "1"), 5) + .cell(TensorAddress.ofLabels("key1", "2", "0"), 3) + .cell(TensorAddress.ofLabels("key1", "2", "1"), 6) + .cell(TensorAddress.ofLabels("key2", "0", "0"), 7) + .cell(TensorAddress.ofLabels("key2", "0", "1"), 10) + .cell(TensorAddress.ofLabels("key2", "1", "0"), 8) + .cell(TensorAddress.ofLabels("key2", "1", "1"), 11) + .cell(TensorAddress.ofLabels("key2", "2", "0"), 9) + .cell(TensorAddress.ofLabels("key2", "2", "1"), 12).build(), + Tensor.from("tensor(key{},y[2],x[3]):{key1:[[1,2,3],[4,5,6]], key2:[[7,8,9],[10,11,12]]}")); + assertEquals("Opposite order of dimensions", + Tensor.from("tensor(key{},x[3],y[2]):{key1:[[1,4],[2,5],[3,6]], key2:[[7,10],[8,11],[9,12]]}"), + Tensor.from("tensor(key{},y[2],x[3]):{key1:[[1,2,3],[4,5,6]], key2:[[7,8,9],[10,11,12]]}")); + } + private void assertDense(Tensor expectedTensor, String denseFormat) { assertEquals(denseFormat, expectedTensor, Tensor.from(denseFormat)); assertEquals(denseFormat, expectedTensor.toString()); @@ -99,7 +134,7 @@ public class TensorParserTestCase { "{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}"); assertIllegal("At {x:0}: '1-.0' is not a valid double", "{{x:0}:1-.0}"); - assertIllegal("At position 1: '1-.0' is not a valid double", + assertIllegal("At value position 1: '1-.0' is not a valid double", "tensor(x[1]):[1-.0]"); } |