diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-14 08:34:09 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-14 08:34:09 +0100 |
commit | f5ccf036b4f7368f217a6bcbffc1699aac5eac2d (patch) | |
tree | 749afd3b29f52b918c67099c1742cb9db50211cf /vespajlib | |
parent | 3954dbe2403bdbb21e9a558fbc55fd137afa40f8 (diff) |
Interpret dimensions in written order
Diffstat (limited to 'vespajlib')
6 files changed, 180 insertions, 65 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index d91b38a8a96..1fcdf7f5cca 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -693,6 +693,7 @@ "methods": [ "public void <init>(int)", "public com.yahoo.tensor.DimensionSizes$Builder set(int, long)", + "public com.yahoo.tensor.DimensionSizes$Builder add(long)", "public long size(int)", "public int dimensions()", "public com.yahoo.tensor.DimensionSizes build()" @@ -836,10 +837,12 @@ ], "methods": [ "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType)", + "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType, java.util.List)", "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.DimensionSizes)", "public com.yahoo.tensor.TensorAddress toAddress()", "public long[] indexesCopy()", "public long[] indexesForReading()", + "public long toSourceValueIndex()", "public java.util.List toList()", "public java.lang.String toString()", "public abstract long size()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index d81c02fb75f..202817ece42 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -71,6 +71,7 @@ public final class DimensionSizes { */ public final static class Builder { + private int dimensionIndex = 0; private long[] sizes; public Builder(int dimensions) { @@ -82,6 +83,11 @@ public final class DimensionSizes { return this; } + public Builder add(long size) { + sizes[dimensionIndex++] = size; + return this; + } + /** * Returns the length of this in the nth dimension * diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 30923976fa5..ba3a35e8eda 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -218,7 +218,7 @@ public abstract class IndexedTensor implements Tensor { indexes.next(); // start brackets - for (int i = 0; i < indexes.rightDimensionsAtStart(); i++) + for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) b.append("["); // value @@ -230,7 +230,7 @@ public abstract class IndexedTensor implements Tensor { throw new IllegalStateException("Unexpected value type " + type.valueType()); // end bracket and comma - for (int i = 0; i < indexes.rightDimensionsAtEnd(); i++) + for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) b.append("]"); if (index < size() - 1) b.append(", "); @@ -777,6 +777,10 @@ public abstract class IndexedTensor implements Tensor { return of(DimensionSizes.of(type)); } + public static Indexes of(TensorType type, List<String> iterateDimensionOrder) { + return of(DimensionSizes.of(type), toIterationOrder(iterateDimensionOrder, type)); + } + public static Indexes of(DimensionSizes sizes) { return of(sizes, sizes); } @@ -789,6 +793,10 @@ public abstract class IndexedTensor implements Tensor { return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size); } + private static Indexes of(DimensionSizes sizes, List<Integer> iterateDimensions) { + return of(sizes, sizes, iterateDimensions); + } + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions) { return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions)); } @@ -822,6 +830,16 @@ public abstract class IndexedTensor implements Tensor { } } + private static List<Integer> toIterationOrder(List<String> dimensionNames, TensorType type) { + if (dimensionNames == null) return completeIterationOrder(type.rank()); + + List<Integer> iterationDimensions = new ArrayList<>(type.rank()); + for (int i = 0; i < type.rank(); i++) + iterationDimensions.add(type.rank() - 1 - type.indexOfDimension(dimensionNames.get(i)).get()); + return iterationDimensions; + } + + /** Since the right dimensions binds closest, iteration order is the opposite of the tensor order */ private static List<Integer> completeIterationOrder(int length) { List<Integer> iterationDimensions = new ArrayList<>(length); for (int i = 0; i < length; i++) @@ -854,7 +872,7 @@ public abstract class IndexedTensor implements Tensor { /** Returns a copy of the indexes of this which must not be modified */ public long[] indexesForReading() { return indexes; } - long toSourceValueIndex() { + public long toSourceValueIndex() { return IndexedTensor.toValueIndex(indexes, sourceSizes); } @@ -882,27 +900,12 @@ public abstract class IndexedTensor implements Tensor { /** Returns whether further values are available by calling next() */ public abstract boolean hasNext(); - /** Returns the number of dimensions from the right which are currently at the start position (0) */ - int rightDimensionsAtStart() { - int dimension = indexes.length - 1; - int atStartCount = 0; - while (dimension >= 0 && indexes[dimension] == 0) { - atStartCount++; - dimension--; - } - return atStartCount; - } + /** Returns the number of dimensions in iteration order which are currently at the start position (0) */ + abstract int nextDimensionsAtStart(); + + /** Returns the number of dimensions in iteration order which are currently at their end position */ + abstract int nextDimensionsAtEnd(); - /** Returns the number of dimensions from the right which are currently at the end position */ - int rightDimensionsAtEnd() { - int dimension = indexes.length - 1; - int atEndCount = 0; - while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) { - atEndCount++; - dimension--; - } - return atEndCount; - } } private final static class EmptyIndexes extends Indexes { @@ -920,6 +923,12 @@ public abstract class IndexedTensor implements Tensor { @Override public boolean hasNext() { return false; } + @Override + int nextDimensionsAtStart() { return 0; } + + @Override + int nextDimensionsAtEnd() { return 0; } + } private final static class SingleValueIndexes extends Indexes { @@ -939,6 +948,12 @@ public abstract class IndexedTensor implements Tensor { @Override public boolean hasNext() { return ! exhausted; } + @Override + int nextDimensionsAtStart() { return 1; } + + @Override + int nextDimensionsAtEnd() { return 1; } + } private static class MultiDimensionIndexes extends Indexes { @@ -987,6 +1002,22 @@ public abstract class IndexedTensor implements Tensor { return false; } + @Override + int nextDimensionsAtStart() { + int dimension = 0; + while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == 0) + dimension++; + return dimension; + } + + @Override + int nextDimensionsAtEnd() { + int dimension = 0; + while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == dimensionSizes().size(iterateDimensions.get(dimension)) - 1) + dimension++; + return dimension; + } + } /** In this case we can reuse the source index computation for the iteration index */ @@ -999,7 +1030,7 @@ public abstract class IndexedTensor implements Tensor { } @Override - long toSourceValueIndex() { + public long toSourceValueIndex() { return lastComputedSourceValueIndex = super.toSourceValueIndex(); } @@ -1056,7 +1087,7 @@ public abstract class IndexedTensor implements Tensor { } @Override - long toSourceValueIndex() { return currentSourceValueIndex; } + public long toSourceValueIndex() { return currentSourceValueIndex; } @Override long toIterationValueIndex() { return currentIterationValueIndex; } @@ -1066,6 +1097,12 @@ public abstract class IndexedTensor implements Tensor { return indexes[iterateDimension] + 1 < size; } + @Override + int nextDimensionsAtStart() { return currentSourceValueIndex == 0 ? 1 : 0; } + + @Override + int nextDimensionsAtEnd() { return currentSourceValueIndex == size - 1 ? 1 : 0; } + } /** In this case we only need to keep track of one index */ @@ -1117,11 +1154,17 @@ public abstract class IndexedTensor implements Tensor { } @Override - long toSourceValueIndex() { return currentValueIndex; } + public long toSourceValueIndex() { return currentValueIndex; } @Override long toIterationValueIndex() { return currentValueIndex; } + @Override + int nextDimensionsAtStart() { return currentValueIndex == 0 ? 1 : 0; } + + @Override + int nextDimensionsAtEnd() { return currentValueIndex == size - 1 ? 1 : 0; } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 8d07a1ed9a8..ea21249bede 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -23,11 +24,17 @@ class TensorParser { Optional<TensorType> type; String valueString; + // The order in which dimensions are written in the type string. + // This allows the user's explicit dimension order to decide what (dense) dimensions map to what, rather than + // the natural order of the tensor. + List<String> dimensionOrder; + tensorString = tensorString.trim(); if (tensorString.startsWith("tensor")) { int colonIndex = tensorString.indexOf(':'); String typeString = tensorString.substring(0, colonIndex); - TensorType typeFromString = TensorTypeParser.fromSpec(typeString); + dimensionOrder = new ArrayList<>(); + TensorType typeFromString = TensorTypeParser.fromSpec(typeString, dimensionOrder); if (explicitType.isPresent() && ! explicitType.get().equals(typeFromString)) throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " + "passed type " + explicitType.get()); @@ -37,6 +44,7 @@ class TensorParser { else { type = explicitType; valueString = tensorString; + dimensionOrder = null; } valueString = valueString.trim(); @@ -45,10 +53,10 @@ class TensorParser { return tensorFromSparseValueString(valueString, type); } else if (valueString.startsWith("{")) { - return tensorFromMixedValueString(valueString, type); + return tensorFromMixedValueString(valueString, type, dimensionOrder); } else if (valueString.startsWith("[")) { - return tensorFromDenseValueString(valueString, type); + return tensorFromDenseValueString(valueString, type, dimensionOrder); } else { if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty)) @@ -102,7 +110,9 @@ class TensorParser { } } - private static Tensor tensorFromMixedValueString(String valueString, Optional<TensorType> type) { + private static Tensor tensorFromMixedValueString(String valueString, + Optional<TensorType> type, + List<String> dimensionOrder) { if (type.isEmpty()) throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); @@ -117,7 +127,7 @@ class TensorParser { throw new IllegalArgumentException("A mixed tensor must be enclosed in {}"); // TODO: Check if there is also at least one bound indexed dimension MixedTensor.BoundBuilder builder = (MixedTensor.BoundBuilder)Tensor.Builder.of(type.get()); - MixedValueParser parser = new MixedValueParser(valueString, builder); + MixedValueParser parser = new MixedValueParser(valueString, dimensionOrder, builder); parser.parse(); return builder.build(); } @@ -126,7 +136,9 @@ class TensorParser { } } - private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) { + private static Tensor tensorFromDenseValueString(String valueString, + Optional<TensorType> type, + List<String> dimensionOrder) { if (type.isEmpty()) throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); @@ -135,7 +147,7 @@ class TensorParser { "only dense dimensions with a given size"); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(type.get()); - new DenseValueParser(valueString, builder).parse(); + new DenseValueParser(valueString, dimensionOrder, builder).parse(); return builder.build(); } @@ -157,10 +169,10 @@ class TensorParser { skipSpace(); if (position >= string.length()) - throw new IllegalArgumentException("At position " + position + ": Expected a '" + character + + throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character + "' but got the end of the string"); if ( string.charAt(position) != character) - throw new IllegalArgumentException("At position " + position + ": Expected a '" + character + + throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character + "' but got '" + string.charAt(position) + "'"); position++; } @@ -176,10 +188,12 @@ class TensorParser { private long tensorIndex = 0; - public DenseValueParser(String string, IndexedTensor.DirectIndexBuilder builder) { + public DenseValueParser(String string, + List<String> dimensionOrder, + IndexedTensor.DirectIndexBuilder builder) { super(string); this.builder = builder; - indexes = IndexedTensor.Indexes.of(builder.type()); + indexes = IndexedTensor.Indexes.of(builder.type(), dimensionOrder); hasInnerStructure = hasInnerStructure(string); } @@ -189,10 +203,10 @@ class TensorParser { while (indexes.hasNext()) { indexes.next(); - for (int i = 0; i < indexes.rightDimensionsAtStart() && hasInnerStructure; i++) + for (int i = 0; i < indexes.nextDimensionsAtStart() && hasInnerStructure; i++) consume('['); consumeNumber(); - for (int i = 0; i < indexes.rightDimensionsAtEnd() && hasInnerStructure; i++) + for (int i = 0; i < indexes.nextDimensionsAtEnd() && hasInnerStructure; i++) consume(']'); if (indexes.hasNext()) consume(','); @@ -220,14 +234,14 @@ class TensorParser { String cellValueString = string.substring(position, nextNumberEnd); try { if (cellValueType == TensorType.Value.DOUBLE) - builder.cellByDirectIndex(tensorIndex++, Double.parseDouble(cellValueString)); + builder.cellByDirectIndex(indexes.toSourceValueIndex(), Double.parseDouble(cellValueString)); else if (cellValueType == TensorType.Value.FLOAT) - builder.cellByDirectIndex(tensorIndex++, Float.parseFloat(cellValueString)); + builder.cellByDirectIndex(indexes.toSourceValueIndex(), Float.parseFloat(cellValueString)); else throw new IllegalArgumentException(cellValueType + " is not supported"); } catch (NumberFormatException e) { - throw new IllegalArgumentException("At position " + position + ": '" + + throw new IllegalArgumentException("At value position " + position + ": '" + cellValueString + "' is not a valid " + cellValueType); } position = nextNumberEnd; @@ -248,15 +262,19 @@ class TensorParser { private static class MixedValueParser extends ValueParser { private final MixedTensor.BoundBuilder builder; + private List<String> dimensionOrder; - public MixedValueParser(String string, MixedTensor.BoundBuilder builder) { + public MixedValueParser(String string, List<String> dimensionOrder, MixedTensor.BoundBuilder builder) { super(string); + this.dimensionOrder = dimensionOrder; this.builder = builder; } private void parse() { - TensorType.Dimension sparseDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get(); - TensorType sparseSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(sparseDimension)); + TensorType.Dimension mappedDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get(); + TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension)); + if (dimensionOrder != null) + dimensionOrder.remove(mappedDimension.name()); skipSpace(); consume('{'); @@ -269,16 +287,18 @@ class TensorParser { position = labelEnd + 1; skipSpace(); - TensorAddress sparseAddress = new TensorAddress.Builder(sparseSubtype).add(sparseDimension.name(), label).build(); - parseDenseSubspace(sparseAddress); + TensorAddress mappedAddress = new TensorAddress.Builder(mappedSubtype).add(mappedDimension.name(), label).build(); + parseDenseSubspace(mappedAddress, dimensionOrder); if ( ! consumeOptional(',')) consume('}'); skipSpace(); } } - private void parseDenseSubspace(TensorAddress sparseAddress) { - DenseValueParser denseParser = new DenseValueParser(string.substring(position), builder.denseSubspaceBuilder(sparseAddress)); + private void parseDenseSubspace(TensorAddress sparseAddress, List<String> denseDimensionOrder) { + DenseValueParser denseParser = new DenseValueParser(string.substring(position), + denseDimensionOrder, + builder.denseSubspaceBuilder(sparseAddress)); denseParser.parse(); position+= denseParser.position(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java index def3ab6b4ec..4fdb0906740 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java @@ -24,6 +24,13 @@ public class TensorTypeParser { private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}"); public static TensorType fromSpec(String specString) { + return fromSpec(specString, null); + } + + /** + * @param dimensionOrder if not null, this will be populated with the dimension names in the order they are written + */ + static TensorType fromSpec(String specString, List<String> dimensionOrder) { specString = specString.trim(); if ( ! specString.startsWith(START_STRING) || ! specString.endsWith(END_STRING)) throw formatException(specString); @@ -48,10 +55,14 @@ public class TensorTypeParser { List<TensorType.Dimension> dimensions = new ArrayList<>(); for (String element : dimensionsSpec.split(",")) { String trimmedElement = element.trim(); - boolean success = tryParseIndexedDimension(trimmedElement, dimensions) || - tryParseMappedDimension(trimmedElement, dimensions); - if ( ! success) + TensorType.Dimension dimension = tryParseIndexedDimension(trimmedElement); + if (dimension == null) + dimension = tryParseMappedDimension(trimmedElement); + if (dimension == null) throw formatException(specString, "Dimension '" + element + "' is on the wrong format"); + dimensions.add(dimension); + if (dimensionOrder != null) + dimensionOrder.add(dimension.name()); } return new TensorType.Builder(valueType, dimensions).build(); } @@ -68,29 +79,26 @@ public class TensorTypeParser { } } - private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) { + private static TensorType.Dimension tryParseIndexedDimension(String element) { Matcher matcher = indexedPattern.matcher(element); if (matcher.matches()) { String dimensionName = matcher.group(1); String dimensionSize = matcher.group(2); - if (dimensionSize.isEmpty()) { - dimensions.add(TensorType.Dimension.indexed(dimensionName)); - } else { - dimensions.add(TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize))); - } - return true; + if (dimensionSize.isEmpty()) + return TensorType.Dimension.indexed(dimensionName); + else + return TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize)); } - return false; + return null; } - private static boolean tryParseMappedDimension(String element, List<TensorType.Dimension> dimensions) { + private static TensorType.Dimension tryParseMappedDimension(String element) { Matcher matcher = mappedPattern.matcher(element); if (matcher.matches()) { String dimensionName = matcher.group(1); - dimensions.add(TensorType.Dimension.mapped(dimensionName)); - return true; + return TensorType.Dimension.mapped(dimensionName); } - return false; + return null; } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java index b2aba5b02eb..9dfdee29845 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java @@ -75,6 +75,19 @@ public class TensorParserTestCase { } @Test + public void testDenseWrongOrder() { + assertEquals("Opposite order of dimensions", + Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2])")) + .cell(1, 0, 0) + .cell(4, 0, 1) + .cell(2, 1, 0) + .cell(5, 1, 1) + .cell(3, 2, 0) + .cell(6, 2, 1).build(), + Tensor.from("tensor(y[2],x[3]):[[1,2,3],[4,5,6]]")); + } + + @Test public void testMixedParsing() { assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{}, x[2])")) .cell(TensorAddress.ofLabels("a", "0"), 1) @@ -84,6 +97,28 @@ public class TensorParserTestCase { Tensor.from("tensor(key{}, x[2]):{a:[1, 2], b:[3, 4]}")); } + @Test + public void testMixedWrongOrder() { + assertEquals("Opposite order of dimensions", + Tensor.Builder.of(TensorType.fromSpec("tensor(key{},x[3],y[2])")) + .cell(TensorAddress.ofLabels("key1", "0", "0"), 1) + .cell(TensorAddress.ofLabels("key1", "0", "1"), 4) + .cell(TensorAddress.ofLabels("key1", "1", "0"), 2) + .cell(TensorAddress.ofLabels("key1", "1", "1"), 5) + .cell(TensorAddress.ofLabels("key1", "2", "0"), 3) + .cell(TensorAddress.ofLabels("key1", "2", "1"), 6) + .cell(TensorAddress.ofLabels("key2", "0", "0"), 7) + .cell(TensorAddress.ofLabels("key2", "0", "1"), 10) + .cell(TensorAddress.ofLabels("key2", "1", "0"), 8) + .cell(TensorAddress.ofLabels("key2", "1", "1"), 11) + .cell(TensorAddress.ofLabels("key2", "2", "0"), 9) + .cell(TensorAddress.ofLabels("key2", "2", "1"), 12).build(), + Tensor.from("tensor(key{},y[2],x[3]):{key1:[[1,2,3],[4,5,6]], key2:[[7,8,9],[10,11,12]]}")); + assertEquals("Opposite order of dimensions", + Tensor.from("tensor(key{},x[3],y[2]):{key1:[[1,4],[2,5],[3,6]], key2:[[7,10],[8,11],[9,12]]}"), + Tensor.from("tensor(key{},y[2],x[3]):{key1:[[1,2,3],[4,5,6]], key2:[[7,8,9],[10,11,12]]}")); + } + private void assertDense(Tensor expectedTensor, String denseFormat) { assertEquals(denseFormat, expectedTensor, Tensor.from(denseFormat)); assertEquals(denseFormat, expectedTensor.toString()); @@ -99,7 +134,7 @@ public class TensorParserTestCase { "{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}"); assertIllegal("At {x:0}: '1-.0' is not a valid double", "{{x:0}:1-.0}"); - assertIllegal("At position 1: '1-.0' is not a valid double", + assertIllegal("At value position 1: '1-.0' is not a valid double", "tensor(x[1]):[1-.0]"); } |