summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-12-14 08:34:09 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2019-12-14 08:34:09 +0100
commitf5ccf036b4f7368f217a6bcbffc1699aac5eac2d (patch)
tree749afd3b29f52b918c67099c1742cb9db50211cf /vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
parent3954dbe2403bdbb21e9a558fbc55fd137afa40f8 (diff)
Interpret dimensions in written order
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java66
1 files changed, 43 insertions, 23 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 8d07a1ed9a8..ea21249bede 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
@@ -23,11 +24,17 @@ class TensorParser {
Optional<TensorType> type;
String valueString;
+ // The order in which dimensions are written in the type string.
+ // This allows the user's explicit dimension order to decide what (dense) dimensions map to what, rather than
+ // the natural order of the tensor.
+ List<String> dimensionOrder;
+
tensorString = tensorString.trim();
if (tensorString.startsWith("tensor")) {
int colonIndex = tensorString.indexOf(':');
String typeString = tensorString.substring(0, colonIndex);
- TensorType typeFromString = TensorTypeParser.fromSpec(typeString);
+ dimensionOrder = new ArrayList<>();
+ TensorType typeFromString = TensorTypeParser.fromSpec(typeString, dimensionOrder);
if (explicitType.isPresent() && ! explicitType.get().equals(typeFromString))
throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
"passed type " + explicitType.get());
@@ -37,6 +44,7 @@ class TensorParser {
else {
type = explicitType;
valueString = tensorString;
+ dimensionOrder = null;
}
valueString = valueString.trim();
@@ -45,10 +53,10 @@ class TensorParser {
return tensorFromSparseValueString(valueString, type);
}
else if (valueString.startsWith("{")) {
- return tensorFromMixedValueString(valueString, type);
+ return tensorFromMixedValueString(valueString, type, dimensionOrder);
}
else if (valueString.startsWith("[")) {
- return tensorFromDenseValueString(valueString, type);
+ return tensorFromDenseValueString(valueString, type, dimensionOrder);
}
else {
if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty))
@@ -102,7 +110,9 @@ class TensorParser {
}
}
- private static Tensor tensorFromMixedValueString(String valueString, Optional<TensorType> type) {
+ private static Tensor tensorFromMixedValueString(String valueString,
+ Optional<TensorType> type,
+ List<String> dimensionOrder) {
if (type.isEmpty())
throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " +
"on the form 'tensor(dimensions):...");
@@ -117,7 +127,7 @@ class TensorParser {
throw new IllegalArgumentException("A mixed tensor must be enclosed in {}");
// TODO: Check if there is also at least one bound indexed dimension
MixedTensor.BoundBuilder builder = (MixedTensor.BoundBuilder)Tensor.Builder.of(type.get());
- MixedValueParser parser = new MixedValueParser(valueString, builder);
+ MixedValueParser parser = new MixedValueParser(valueString, dimensionOrder, builder);
parser.parse();
return builder.build();
}
@@ -126,7 +136,9 @@ class TensorParser {
}
}
- private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) {
+ private static Tensor tensorFromDenseValueString(String valueString,
+ Optional<TensorType> type,
+ List<String> dimensionOrder) {
if (type.isEmpty())
throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " +
"on the form 'tensor(dimensions):...");
@@ -135,7 +147,7 @@ class TensorParser {
"only dense dimensions with a given size");
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(type.get());
- new DenseValueParser(valueString, builder).parse();
+ new DenseValueParser(valueString, dimensionOrder, builder).parse();
return builder.build();
}
@@ -157,10 +169,10 @@ class TensorParser {
skipSpace();
if (position >= string.length())
- throw new IllegalArgumentException("At position " + position + ": Expected a '" + character +
+ throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character +
"' but got the end of the string");
if ( string.charAt(position) != character)
- throw new IllegalArgumentException("At position " + position + ": Expected a '" + character +
+ throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character +
"' but got '" + string.charAt(position) + "'");
position++;
}
@@ -176,10 +188,12 @@ class TensorParser {
private long tensorIndex = 0;
- public DenseValueParser(String string, IndexedTensor.DirectIndexBuilder builder) {
+ public DenseValueParser(String string,
+ List<String> dimensionOrder,
+ IndexedTensor.DirectIndexBuilder builder) {
super(string);
this.builder = builder;
- indexes = IndexedTensor.Indexes.of(builder.type());
+ indexes = IndexedTensor.Indexes.of(builder.type(), dimensionOrder);
hasInnerStructure = hasInnerStructure(string);
}
@@ -189,10 +203,10 @@ class TensorParser {
while (indexes.hasNext()) {
indexes.next();
- for (int i = 0; i < indexes.rightDimensionsAtStart() && hasInnerStructure; i++)
+ for (int i = 0; i < indexes.nextDimensionsAtStart() && hasInnerStructure; i++)
consume('[');
consumeNumber();
- for (int i = 0; i < indexes.rightDimensionsAtEnd() && hasInnerStructure; i++)
+ for (int i = 0; i < indexes.nextDimensionsAtEnd() && hasInnerStructure; i++)
consume(']');
if (indexes.hasNext())
consume(',');
@@ -220,14 +234,14 @@ class TensorParser {
String cellValueString = string.substring(position, nextNumberEnd);
try {
if (cellValueType == TensorType.Value.DOUBLE)
- builder.cellByDirectIndex(tensorIndex++, Double.parseDouble(cellValueString));
+ builder.cellByDirectIndex(indexes.toSourceValueIndex(), Double.parseDouble(cellValueString));
else if (cellValueType == TensorType.Value.FLOAT)
- builder.cellByDirectIndex(tensorIndex++, Float.parseFloat(cellValueString));
+ builder.cellByDirectIndex(indexes.toSourceValueIndex(), Float.parseFloat(cellValueString));
else
throw new IllegalArgumentException(cellValueType + " is not supported");
}
catch (NumberFormatException e) {
- throw new IllegalArgumentException("At position " + position + ": '" +
+ throw new IllegalArgumentException("At value position " + position + ": '" +
cellValueString + "' is not a valid " + cellValueType);
}
position = nextNumberEnd;
@@ -248,15 +262,19 @@ class TensorParser {
private static class MixedValueParser extends ValueParser {
private final MixedTensor.BoundBuilder builder;
+ private List<String> dimensionOrder;
- public MixedValueParser(String string, MixedTensor.BoundBuilder builder) {
+ public MixedValueParser(String string, List<String> dimensionOrder, MixedTensor.BoundBuilder builder) {
super(string);
+ this.dimensionOrder = dimensionOrder;
this.builder = builder;
}
private void parse() {
- TensorType.Dimension sparseDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get();
- TensorType sparseSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(sparseDimension));
+ TensorType.Dimension mappedDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get();
+ TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension));
+ if (dimensionOrder != null)
+ dimensionOrder.remove(mappedDimension.name());
skipSpace();
consume('{');
@@ -269,16 +287,18 @@ class TensorParser {
position = labelEnd + 1;
skipSpace();
- TensorAddress sparseAddress = new TensorAddress.Builder(sparseSubtype).add(sparseDimension.name(), label).build();
- parseDenseSubspace(sparseAddress);
+ TensorAddress mappedAddress = new TensorAddress.Builder(mappedSubtype).add(mappedDimension.name(), label).build();
+ parseDenseSubspace(mappedAddress, dimensionOrder);
if ( ! consumeOptional(','))
consume('}');
skipSpace();
}
}
- private void parseDenseSubspace(TensorAddress sparseAddress) {
- DenseValueParser denseParser = new DenseValueParser(string.substring(position), builder.denseSubspaceBuilder(sparseAddress));
+ private void parseDenseSubspace(TensorAddress sparseAddress, List<String> denseDimensionOrder) {
+ DenseValueParser denseParser = new DenseValueParser(string.substring(position),
+ denseDimensionOrder,
+ builder.denseSubspaceBuilder(sparseAddress));
denseParser.parse();
position+= denseParser.position();
}