diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-14 08:34:09 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-14 08:34:09 +0100 |
commit | f5ccf036b4f7368f217a6bcbffc1699aac5eac2d (patch) | |
tree | 749afd3b29f52b918c67099c1742cb9db50211cf /vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java | |
parent | 3954dbe2403bdbb21e9a558fbc55fd137afa40f8 (diff) |
Interpret dimensions in written order
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java | 66 |
1 files changed, 43 insertions, 23 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 8d07a1ed9a8..ea21249bede 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -23,11 +24,17 @@ class TensorParser { Optional<TensorType> type; String valueString; + // The order in which dimensions are written in the type string. + // This allows the user's explicit dimension order to decide what (dense) dimensions map to what, rather than + // the natural order of the tensor. + List<String> dimensionOrder; + tensorString = tensorString.trim(); if (tensorString.startsWith("tensor")) { int colonIndex = tensorString.indexOf(':'); String typeString = tensorString.substring(0, colonIndex); - TensorType typeFromString = TensorTypeParser.fromSpec(typeString); + dimensionOrder = new ArrayList<>(); + TensorType typeFromString = TensorTypeParser.fromSpec(typeString, dimensionOrder); if (explicitType.isPresent() && ! explicitType.get().equals(typeFromString)) throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " + "passed type " + explicitType.get()); @@ -37,6 +44,7 @@ class TensorParser { else { type = explicitType; valueString = tensorString; + dimensionOrder = null; } valueString = valueString.trim(); @@ -45,10 +53,10 @@ class TensorParser { return tensorFromSparseValueString(valueString, type); } else if (valueString.startsWith("{")) { - return tensorFromMixedValueString(valueString, type); + return tensorFromMixedValueString(valueString, type, dimensionOrder); } else if (valueString.startsWith("[")) { - return tensorFromDenseValueString(valueString, type); + return tensorFromDenseValueString(valueString, type, dimensionOrder); } else { if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty)) @@ -102,7 +110,9 @@ class TensorParser { } } - private static Tensor tensorFromMixedValueString(String valueString, Optional<TensorType> type) { + private static Tensor tensorFromMixedValueString(String valueString, + Optional<TensorType> type, + List<String> dimensionOrder) { if (type.isEmpty()) throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); @@ -117,7 +127,7 @@ class TensorParser { throw new IllegalArgumentException("A mixed tensor must be enclosed in {}"); // TODO: Check if there is also at least one bound indexed dimension MixedTensor.BoundBuilder builder = (MixedTensor.BoundBuilder)Tensor.Builder.of(type.get()); - MixedValueParser parser = new MixedValueParser(valueString, builder); + MixedValueParser parser = new MixedValueParser(valueString, dimensionOrder, builder); parser.parse(); return builder.build(); } @@ -126,7 +136,9 @@ class TensorParser { } } - private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) { + private static Tensor tensorFromDenseValueString(String valueString, + Optional<TensorType> type, + List<String> dimensionOrder) { if (type.isEmpty()) throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); @@ -135,7 +147,7 @@ class TensorParser { "only dense dimensions with a given size"); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(type.get()); - new DenseValueParser(valueString, builder).parse(); + new DenseValueParser(valueString, dimensionOrder, builder).parse(); return builder.build(); } @@ -157,10 +169,10 @@ class TensorParser { skipSpace(); if (position >= string.length()) - throw new IllegalArgumentException("At position " + position + ": Expected a '" + character + + throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character + "' but got the end of the string"); if ( string.charAt(position) != character) - throw new IllegalArgumentException("At position " + position + ": Expected a '" + character + + throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character + "' but got '" + string.charAt(position) + "'"); position++; } @@ -176,10 +188,12 @@ class TensorParser { private long tensorIndex = 0; - public DenseValueParser(String string, IndexedTensor.DirectIndexBuilder builder) { + public DenseValueParser(String string, + List<String> dimensionOrder, + IndexedTensor.DirectIndexBuilder builder) { super(string); this.builder = builder; - indexes = IndexedTensor.Indexes.of(builder.type()); + indexes = IndexedTensor.Indexes.of(builder.type(), dimensionOrder); hasInnerStructure = hasInnerStructure(string); } @@ -189,10 +203,10 @@ class TensorParser { while (indexes.hasNext()) { indexes.next(); - for (int i = 0; i < indexes.rightDimensionsAtStart() && hasInnerStructure; i++) + for (int i = 0; i < indexes.nextDimensionsAtStart() && hasInnerStructure; i++) consume('['); consumeNumber(); - for (int i = 0; i < indexes.rightDimensionsAtEnd() && hasInnerStructure; i++) + for (int i = 0; i < indexes.nextDimensionsAtEnd() && hasInnerStructure; i++) consume(']'); if (indexes.hasNext()) consume(','); @@ -220,14 +234,14 @@ class TensorParser { String cellValueString = string.substring(position, nextNumberEnd); try { if (cellValueType == TensorType.Value.DOUBLE) - builder.cellByDirectIndex(tensorIndex++, Double.parseDouble(cellValueString)); + builder.cellByDirectIndex(indexes.toSourceValueIndex(), Double.parseDouble(cellValueString)); else if (cellValueType == TensorType.Value.FLOAT) - builder.cellByDirectIndex(tensorIndex++, Float.parseFloat(cellValueString)); + builder.cellByDirectIndex(indexes.toSourceValueIndex(), Float.parseFloat(cellValueString)); else throw new IllegalArgumentException(cellValueType + " is not supported"); } catch (NumberFormatException e) { - throw new IllegalArgumentException("At position " + position + ": '" + + throw new IllegalArgumentException("At value position " + position + ": '" + cellValueString + "' is not a valid " + cellValueType); } position = nextNumberEnd; @@ -248,15 +262,19 @@ class TensorParser { private static class MixedValueParser extends ValueParser { private final MixedTensor.BoundBuilder builder; + private List<String> dimensionOrder; - public MixedValueParser(String string, MixedTensor.BoundBuilder builder) { + public MixedValueParser(String string, List<String> dimensionOrder, MixedTensor.BoundBuilder builder) { super(string); + this.dimensionOrder = dimensionOrder; this.builder = builder; } private void parse() { - TensorType.Dimension sparseDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get(); - TensorType sparseSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(sparseDimension)); + TensorType.Dimension mappedDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get(); + TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension)); + if (dimensionOrder != null) + dimensionOrder.remove(mappedDimension.name()); skipSpace(); consume('{'); @@ -269,16 +287,18 @@ class TensorParser { position = labelEnd + 1; skipSpace(); - TensorAddress sparseAddress = new TensorAddress.Builder(sparseSubtype).add(sparseDimension.name(), label).build(); - parseDenseSubspace(sparseAddress); + TensorAddress mappedAddress = new TensorAddress.Builder(mappedSubtype).add(mappedDimension.name(), label).build(); + parseDenseSubspace(mappedAddress, dimensionOrder); if ( ! consumeOptional(',')) consume('}'); skipSpace(); } } - private void parseDenseSubspace(TensorAddress sparseAddress) { - DenseValueParser denseParser = new DenseValueParser(string.substring(position), builder.denseSubspaceBuilder(sparseAddress)); + private void parseDenseSubspace(TensorAddress sparseAddress, List<String> denseDimensionOrder) { + DenseValueParser denseParser = new DenseValueParser(string.substring(position), + denseDimensionOrder, + builder.denseSubspaceBuilder(sparseAddress)); denseParser.parse(); position+= denseParser.position(); } |