diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-02 12:17:39 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-02 12:17:39 +0200 |
commit | 6eb80166172e10255841fd3d3cf70bed09d3d8c1 (patch) | |
tree | 0fe5b3ebb646460c70b9222ba6b1505a81d3619f /vespajlib/src | |
parent | da41894e5b4f7525ee59d9c69838bdc21735d0f2 (diff) |
Parse tensor value type
Diffstat (limited to 'vespajlib/src')
6 files changed, 104 insertions, 59 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index fa32d385004..998f3170aa0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -11,7 +11,7 @@ class TensorParser { static Tensor tensorFrom(String tensorString, Optional<TensorType> type) { tensorString = tensorString.trim(); try { - if (tensorString.startsWith("tensor(")) { + if (tensorString.startsWith("tensor")) { int colonIndex = tensorString.indexOf(':'); String typeString = tensorString.substring(0, colonIndex); String valueString = tensorString.substring(colonIndex + 1); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 036f5e3ee5d..bded55405c0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -24,25 +24,19 @@ import java.util.stream.Collectors; */ public class TensorType { - public enum ValueType { DOUBLE, FLOAT}; + /** The permissible cell value types. Default is double. */ + // Types added here must also be added to TensorTypeParser.parseValueTypeSpec + public enum Value { DOUBLE, FLOAT}; /** The empty tensor type - which is the same as a double */ - public static final TensorType empty = new TensorType(ValueType.DOUBLE, Collections.emptyList()); + public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList()); - private ValueType valueType; - - public final ValueType valueType() { return valueType; } - - //TODO Remove once value type is wired in were it should. - public final TensorType valueType(ValueType valueType) { - this.valueType = valueType; - return this; - } + private final Value valueType; /** Sorted list of the dimensions of this */ private final ImmutableList<Dimension> dimensions; - private TensorType(ValueType valueType, Collection<Dimension> dimensions) { + private TensorType(Value valueType, Collection<Dimension> dimensions) { this.valueType = valueType; List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); @@ -64,6 +58,9 @@ public class TensorType { return TensorTypeParser.fromSpec(specString); } + /** Returns the numeric type of the cell values of this */ + public Value valueType() { return valueType; } + /** Returns the number of dimensions of this: dimensions().size() */ public int rank() { return dimensions.size(); } @@ -386,14 +383,14 @@ public class TensorType { private final Map<String, Dimension> dimensions = new LinkedHashMap<>(); - private final ValueType valueType; + private final Value valueType; /** Creates an empty builder with cells of type double*/ public Builder() { - this(ValueType.DOUBLE); + this(Value.DOUBLE); } - public Builder(ValueType valueType) { + public Builder(Value valueType) { this.valueType = valueType; } @@ -405,21 +402,21 @@ public class TensorType { * If it is indexed in one and mapped in the other it will become mapped. */ public Builder(TensorType ... types) { - this(ValueType.DOUBLE, types); + this(Value.DOUBLE, types); } - public Builder(ValueType valueType, TensorType ... types) { + public Builder(Value valueType, TensorType ... types) { this.valueType = valueType; for (TensorType type : types) addDimensionsOf(type); } - /** - * Creates a builder from the given dimensions. - */ + /** Creates a builder from the given dimensions */ public Builder(Iterable<Dimension> dimensions) { - this(ValueType.DOUBLE, dimensions); + this(Value.DOUBLE, dimensions); } - public Builder(ValueType valueType, Iterable<Dimension> dimensions) { + + /** Creates a builder from the given value type and dimensions */ + public Builder(Value valueType, Iterable<Dimension> dimensions) { this.valueType = valueType; for (TensorType.Dimension dimension : dimensions) { dimension(dimension); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java index 32ad6171e57..a5733f1cc4c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java @@ -2,8 +2,10 @@ package com.yahoo.tensor; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -14,23 +16,32 @@ import java.util.regex.Pattern; */ public class TensorTypeParser { - private final static String START_STRING = "tensor("; + private final static String START_STRING = "tensor"; private final static String END_STRING = ")"; private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]"); private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}"); public static TensorType fromSpec(String specString) { - return new TensorType.Builder(dimensionsFromSpec(specString)).build(); - } + if ( ! specString.startsWith(START_STRING) || ! specString.endsWith(END_STRING)) + throw formatException(specString); + String specBody = specString.substring(START_STRING.length(), specString.length() - END_STRING.length()); - public static List<TensorType.Dimension> dimensionsFromSpec(String specString) { - if ( ! specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) { - throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" + - " and end with '" + END_STRING + "', but was '" + specString + "'"); + String dimensionsSpec; + TensorType.Value valueType; + if (specBody.startsWith("(")) { + valueType = TensorType.Value.DOUBLE; // no value type spec: Use default + dimensionsSpec = specBody.substring(1); + } + else { + int parenthesisIndex = specBody.indexOf("("); + if (parenthesisIndex < 0) + throw formatException(specString); + valueType = parseValueTypeSpec(specBody.substring(0, parenthesisIndex), specString); + dimensionsSpec = specBody.substring(parenthesisIndex + 1); } - String dimensionsSpec = specString.substring(START_STRING.length(), specString.length() - END_STRING.length()); - if (dimensionsSpec.isEmpty()) return Collections.emptyList(); + + if (dimensionsSpec.isEmpty()) return new TensorType.Builder(valueType, Collections.emptyList()).build(); List<TensorType.Dimension> dimensions = new ArrayList<>(); for (String element : dimensionsSpec.split(",")) { @@ -38,10 +49,23 @@ public class TensorTypeParser { boolean success = tryParseIndexedDimension(trimmedElement, dimensions) || tryParseMappedDimension(trimmedElement, dimensions); if ( ! success) - throw new IllegalArgumentException("Failed parsing element '" + element + - "' in type spec '" + specString + "'"); + throw formatException(specString, "Dimension '" + element + "' is on the wrong format"); + } + return new TensorType.Builder(valueType, dimensions).build(); + } + + private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) { + if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">")) + throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>")); + + String valueType = valueTypeSpec.substring(1, valueTypeSpec.length() - 1); + switch (valueType) { + case "double" : return TensorType.Value.DOUBLE; + case "float" : return TensorType.Value.FLOAT; + default : throw formatException(fullSpecString, + "Value type must be either 'double' or 'float'" + + " but was '" + valueType + "'"); } - return dimensions; } private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) { @@ -69,5 +93,21 @@ public class TensorTypeParser { return false; } + + private static IllegalArgumentException formatException(String spec) { + return formatException(spec, Optional.empty()); + } + + private static IllegalArgumentException formatException(String spec, String errorDetail) { + return formatException(spec, Optional.of(errorDetail)); + } + + private static IllegalArgumentException formatException(String spec, Optional<String> errorDetail) { + throw new IllegalArgumentException("A tensor type spec must be on the form " + + "tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was '" + spec + "'. " + + errorDetail.map(s -> s + ". ").orElse("") + + "Examples: tensor(x[]), tensor<float>(name{}, x[10])"); + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 500c436516f..ecd4f7d1965 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -43,7 +43,7 @@ public class DenseBinaryFormat implements BinaryFormat { encodeCells(buffer, tensor); } - private void encodeValueType(GrowableByteBuffer buffer, TensorType.ValueType valueType) { + private void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) { switch (valueType) { case DOUBLE: if (encodeType != EncodeType.DOUBLE_IS_DEFAULT) { @@ -100,7 +100,7 @@ public class DenseBinaryFormat implements BinaryFormat { sizes = sizesFromType(serializedType); } else { - type = decodeType(buffer, TensorType.ValueType.DOUBLE); + type = decodeType(buffer, TensorType.Value.DOUBLE); sizes = sizesFromType(type); } Tensor.Builder builder = Tensor.Builder.of(type, sizes); @@ -108,16 +108,16 @@ public class DenseBinaryFormat implements BinaryFormat { return builder.build(); } - private TensorType decodeType(GrowableByteBuffer buffer, TensorType.ValueType valueType) { - TensorType.ValueType serializedValueType = TensorType.ValueType.DOUBLE; - if ((valueType != TensorType.ValueType.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) { + private TensorType decodeType(GrowableByteBuffer buffer, TensorType.Value valueType) { + TensorType.Value serializedValueType = TensorType.Value.DOUBLE; + if ((valueType != TensorType.Value.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) { int type = buffer.getInt1_4Bytes(); switch (type) { case DOUBLE_VALUE_TYPE: - serializedValueType = TensorType.ValueType.DOUBLE; + serializedValueType = TensorType.Value.DOUBLE; break; case FLOAT_VALUE_TYPE: - serializedValueType = TensorType.ValueType.FLOAT; + serializedValueType = TensorType.Value.FLOAT; break; default: throw new IllegalArgumentException("Received tensor value type '" + serializedValueType + "'. Only 0(double), or 1(float) are legal."); @@ -141,7 +141,7 @@ public class DenseBinaryFormat implements BinaryFormat { return builder.build(); } - private void decodeCells(TensorType.ValueType valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { + private void decodeCells(TensorType.Value valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { switch (valueType) { case DOUBLE: decodeCellsAsDouble(sizes, buffer, builder); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java index f7a0a3cdb7d..d3bb702175a 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java @@ -58,10 +58,13 @@ public class TensorTypeTestCase { @Test public void requireThatIllegalSyntaxInSpecThrowsException() { - assertIllegalTensorType("foo(x[10])", "Tensor type spec must start with 'tensor(' and end with ')', but was 'foo(x[10])'"); - assertIllegalTensorType("tensor(x_@[10])", "Failed parsing element 'x_@[10]' in type spec 'tensor(x_@[10])'"); - assertIllegalTensorType("tensor(x[10a])", "Failed parsing element 'x[10a]' in type spec 'tensor(x[10a])'"); - assertIllegalTensorType("tensor(x{10})", "Failed parsing element 'x{10}' in type spec 'tensor(x{10})'"); + assertIllegalTensorType("foo(x[10])", "but was 'foo(x[10])'."); + assertIllegalTensorType("tensor(x_@[10])", "Dimension 'x_@[10]' is on the wrong format"); + assertIllegalTensorType("tensor(x[10a])", "Dimension 'x[10a]' is on the wrong format"); + assertIllegalTensorType("tensor(x{10})", "Dimension 'x{10}' is on the wrong format"); + assertIllegalTensorType("tensor<(x{})", " Value type spec must be enclosed in <>"); + assertIllegalTensorType("tensor<>(x{})", "Value type must be"); + assertIllegalTensorType("tensor<notavalue>(x{})", "Value type must be"); } @Test @@ -88,6 +91,13 @@ public class TensorTypeTestCase { assertIsConvertibleTo("tensor(x{},y[10])", "tensor(x{},y[])"); } + @Test + public void testValueType() { + assertValueType(TensorType.Value.DOUBLE, "tensor(x[])"); + assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])"); + assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])"); + } + private static void assertTensorType(String typeSpec) { assertTensorType(typeSpec, typeSpec); } @@ -121,4 +131,8 @@ public class TensorTypeTestCase { assertFalse(TensorType.fromSpec(specificType).isConvertibleTo(TensorType.fromSpec(generalType))); } + private void assertValueType(TensorType.Value expectedValueType, String tensorTypeSpec) { + assertEquals(expectedValueType, TensorType.fromSpec(tensorTypeSpec).valueType()); + } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index e8b17812f32..8fa2537e1e2 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -63,27 +63,21 @@ public class DenseBinaryFormatTestCase { 64, 0, 0, 0, // value 1 64, 64, 0, 0, // value 2 }; - Tensor tensor = Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"); - tensor.type().valueType(TensorType.ValueType.FLOAT); - assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(tensor))); + Tensor tensor = Tensor.from("tensor<float>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } @Test public void testSerializationOfDifferentValueTypes() { - assertSerialization(TensorType.ValueType.DOUBLE, "tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); - assertSerialization(TensorType.ValueType.FLOAT, "tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<double>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<float>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); } private void assertSerialization(String tensorString) { - assertSerialization(TensorType.ValueType.DOUBLE, Tensor.from(tensorString)); - } - private void assertSerialization(TensorType.ValueType valueType, String tensorString) { - assertSerialization(valueType, Tensor.from(tensorString)); + assertSerialization(Tensor.from(tensorString)); } - private void assertSerialization(TensorType.ValueType valueType, Tensor tensor) { - tensor.type().valueType(valueType); + private void assertSerialization(Tensor tensor) { assertSerialization(tensor, tensor.type()); } |