aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-02 12:17:39 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-02 12:17:39 +0200
commit6eb80166172e10255841fd3d3cf70bed09d3d8c1 (patch)
tree0fe5b3ebb646460c70b9222ba6b1505a81d3619f /vespajlib
parentda41894e5b4f7525ee59d9c69838bdc21735d0f2 (diff)
Parse tensor value type
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java41
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java64
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java16
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java22
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java18
6 files changed, 104 insertions, 59 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index fa32d385004..998f3170aa0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -11,7 +11,7 @@ class TensorParser {
static Tensor tensorFrom(String tensorString, Optional<TensorType> type) {
tensorString = tensorString.trim();
try {
- if (tensorString.startsWith("tensor(")) {
+ if (tensorString.startsWith("tensor")) {
int colonIndex = tensorString.indexOf(':');
String typeString = tensorString.substring(0, colonIndex);
String valueString = tensorString.substring(colonIndex + 1);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 036f5e3ee5d..bded55405c0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -24,25 +24,19 @@ import java.util.stream.Collectors;
*/
public class TensorType {
- public enum ValueType { DOUBLE, FLOAT};
+ /** The permissible cell value types. Default is double. */
+ // Types added here must also be added to TensorTypeParser.parseValueTypeSpec
+ public enum Value { DOUBLE, FLOAT};
/** The empty tensor type - which is the same as a double */
- public static final TensorType empty = new TensorType(ValueType.DOUBLE, Collections.emptyList());
+ public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList());
- private ValueType valueType;
-
- public final ValueType valueType() { return valueType; }
-
- //TODO Remove once value type is wired in were it should.
- public final TensorType valueType(ValueType valueType) {
- this.valueType = valueType;
- return this;
- }
+ private final Value valueType;
/** Sorted list of the dimensions of this */
private final ImmutableList<Dimension> dimensions;
- private TensorType(ValueType valueType, Collection<Dimension> dimensions) {
+ private TensorType(Value valueType, Collection<Dimension> dimensions) {
this.valueType = valueType;
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
@@ -64,6 +58,9 @@ public class TensorType {
return TensorTypeParser.fromSpec(specString);
}
+ /** Returns the numeric type of the cell values of this */
+ public Value valueType() { return valueType; }
+
/** Returns the number of dimensions of this: dimensions().size() */
public int rank() { return dimensions.size(); }
@@ -386,14 +383,14 @@ public class TensorType {
private final Map<String, Dimension> dimensions = new LinkedHashMap<>();
- private final ValueType valueType;
+ private final Value valueType;
/** Creates an empty builder with cells of type double*/
public Builder() {
- this(ValueType.DOUBLE);
+ this(Value.DOUBLE);
}
- public Builder(ValueType valueType) {
+ public Builder(Value valueType) {
this.valueType = valueType;
}
@@ -405,21 +402,21 @@ public class TensorType {
* If it is indexed in one and mapped in the other it will become mapped.
*/
public Builder(TensorType ... types) {
- this(ValueType.DOUBLE, types);
+ this(Value.DOUBLE, types);
}
- public Builder(ValueType valueType, TensorType ... types) {
+ public Builder(Value valueType, TensorType ... types) {
this.valueType = valueType;
for (TensorType type : types)
addDimensionsOf(type);
}
- /**
- * Creates a builder from the given dimensions.
- */
+ /** Creates a builder from the given dimensions */
public Builder(Iterable<Dimension> dimensions) {
- this(ValueType.DOUBLE, dimensions);
+ this(Value.DOUBLE, dimensions);
}
- public Builder(ValueType valueType, Iterable<Dimension> dimensions) {
+
+ /** Creates a builder from the given value type and dimensions */
+ public Builder(Value valueType, Iterable<Dimension> dimensions) {
this.valueType = valueType;
for (TensorType.Dimension dimension : dimensions) {
dimension(dimension);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index 32ad6171e57..a5733f1cc4c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -2,8 +2,10 @@
package com.yahoo.tensor;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -14,23 +16,32 @@ import java.util.regex.Pattern;
*/
public class TensorTypeParser {
- private final static String START_STRING = "tensor(";
+ private final static String START_STRING = "tensor";
private final static String END_STRING = ")";
private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]");
private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}");
public static TensorType fromSpec(String specString) {
- return new TensorType.Builder(dimensionsFromSpec(specString)).build();
- }
+ if ( ! specString.startsWith(START_STRING) || ! specString.endsWith(END_STRING))
+ throw formatException(specString);
+ String specBody = specString.substring(START_STRING.length(), specString.length() - END_STRING.length());
- public static List<TensorType.Dimension> dimensionsFromSpec(String specString) {
- if ( ! specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) {
- throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" +
- " and end with '" + END_STRING + "', but was '" + specString + "'");
+ String dimensionsSpec;
+ TensorType.Value valueType;
+ if (specBody.startsWith("(")) {
+ valueType = TensorType.Value.DOUBLE; // no value type spec: Use default
+ dimensionsSpec = specBody.substring(1);
+ }
+ else {
+ int parenthesisIndex = specBody.indexOf("(");
+ if (parenthesisIndex < 0)
+ throw formatException(specString);
+ valueType = parseValueTypeSpec(specBody.substring(0, parenthesisIndex), specString);
+ dimensionsSpec = specBody.substring(parenthesisIndex + 1);
}
- String dimensionsSpec = specString.substring(START_STRING.length(), specString.length() - END_STRING.length());
- if (dimensionsSpec.isEmpty()) return Collections.emptyList();
+
+ if (dimensionsSpec.isEmpty()) return new TensorType.Builder(valueType, Collections.emptyList()).build();
List<TensorType.Dimension> dimensions = new ArrayList<>();
for (String element : dimensionsSpec.split(",")) {
@@ -38,10 +49,23 @@ public class TensorTypeParser {
boolean success = tryParseIndexedDimension(trimmedElement, dimensions) ||
tryParseMappedDimension(trimmedElement, dimensions);
if ( ! success)
- throw new IllegalArgumentException("Failed parsing element '" + element +
- "' in type spec '" + specString + "'");
+ throw formatException(specString, "Dimension '" + element + "' is on the wrong format");
+ }
+ return new TensorType.Builder(valueType, dimensions).build();
+ }
+
+ private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) {
+ if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">"))
+ throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>"));
+
+ String valueType = valueTypeSpec.substring(1, valueTypeSpec.length() - 1);
+ switch (valueType) {
+ case "double" : return TensorType.Value.DOUBLE;
+ case "float" : return TensorType.Value.FLOAT;
+ default : throw formatException(fullSpecString,
+ "Value type must be either 'double' or 'float'" +
+ " but was '" + valueType + "'");
}
- return dimensions;
}
private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) {
@@ -69,5 +93,21 @@ public class TensorTypeParser {
return false;
}
+
+ private static IllegalArgumentException formatException(String spec) {
+ return formatException(spec, Optional.empty());
+ }
+
+ private static IllegalArgumentException formatException(String spec, String errorDetail) {
+ return formatException(spec, Optional.of(errorDetail));
+ }
+
+ private static IllegalArgumentException formatException(String spec, Optional<String> errorDetail) {
+ throw new IllegalArgumentException("A tensor type spec must be on the form " +
+ "tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was '" + spec + "'. " +
+ errorDetail.map(s -> s + ". ").orElse("") +
+ "Examples: tensor(x[]), tensor<float>(name{}, x[10])");
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 500c436516f..ecd4f7d1965 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -43,7 +43,7 @@ public class DenseBinaryFormat implements BinaryFormat {
encodeCells(buffer, tensor);
}
- private void encodeValueType(GrowableByteBuffer buffer, TensorType.ValueType valueType) {
+ private void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) {
switch (valueType) {
case DOUBLE:
if (encodeType != EncodeType.DOUBLE_IS_DEFAULT) {
@@ -100,7 +100,7 @@ public class DenseBinaryFormat implements BinaryFormat {
sizes = sizesFromType(serializedType);
}
else {
- type = decodeType(buffer, TensorType.ValueType.DOUBLE);
+ type = decodeType(buffer, TensorType.Value.DOUBLE);
sizes = sizesFromType(type);
}
Tensor.Builder builder = Tensor.Builder.of(type, sizes);
@@ -108,16 +108,16 @@ public class DenseBinaryFormat implements BinaryFormat {
return builder.build();
}
- private TensorType decodeType(GrowableByteBuffer buffer, TensorType.ValueType valueType) {
- TensorType.ValueType serializedValueType = TensorType.ValueType.DOUBLE;
- if ((valueType != TensorType.ValueType.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) {
+ private TensorType decodeType(GrowableByteBuffer buffer, TensorType.Value valueType) {
+ TensorType.Value serializedValueType = TensorType.Value.DOUBLE;
+ if ((valueType != TensorType.Value.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) {
int type = buffer.getInt1_4Bytes();
switch (type) {
case DOUBLE_VALUE_TYPE:
- serializedValueType = TensorType.ValueType.DOUBLE;
+ serializedValueType = TensorType.Value.DOUBLE;
break;
case FLOAT_VALUE_TYPE:
- serializedValueType = TensorType.ValueType.FLOAT;
+ serializedValueType = TensorType.Value.FLOAT;
break;
default:
throw new IllegalArgumentException("Received tensor value type '" + serializedValueType + "'. Only 0(double), or 1(float) are legal.");
@@ -141,7 +141,7 @@ public class DenseBinaryFormat implements BinaryFormat {
return builder.build();
}
- private void decodeCells(TensorType.ValueType valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
+ private void decodeCells(TensorType.Value valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
switch (valueType) {
case DOUBLE:
decodeCellsAsDouble(sizes, buffer, builder);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index f7a0a3cdb7d..d3bb702175a 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -58,10 +58,13 @@ public class TensorTypeTestCase {
@Test
public void requireThatIllegalSyntaxInSpecThrowsException() {
- assertIllegalTensorType("foo(x[10])", "Tensor type spec must start with 'tensor(' and end with ')', but was 'foo(x[10])'");
- assertIllegalTensorType("tensor(x_@[10])", "Failed parsing element 'x_@[10]' in type spec 'tensor(x_@[10])'");
- assertIllegalTensorType("tensor(x[10a])", "Failed parsing element 'x[10a]' in type spec 'tensor(x[10a])'");
- assertIllegalTensorType("tensor(x{10})", "Failed parsing element 'x{10}' in type spec 'tensor(x{10})'");
+ assertIllegalTensorType("foo(x[10])", "but was 'foo(x[10])'.");
+ assertIllegalTensorType("tensor(x_@[10])", "Dimension 'x_@[10]' is on the wrong format");
+ assertIllegalTensorType("tensor(x[10a])", "Dimension 'x[10a]' is on the wrong format");
+ assertIllegalTensorType("tensor(x{10})", "Dimension 'x{10}' is on the wrong format");
+ assertIllegalTensorType("tensor<(x{})", " Value type spec must be enclosed in <>");
+ assertIllegalTensorType("tensor<>(x{})", "Value type must be");
+ assertIllegalTensorType("tensor<notavalue>(x{})", "Value type must be");
}
@Test
@@ -88,6 +91,13 @@ public class TensorTypeTestCase {
assertIsConvertibleTo("tensor(x{},y[10])", "tensor(x{},y[])");
}
+ @Test
+ public void testValueType() {
+ assertValueType(TensorType.Value.DOUBLE, "tensor(x[])");
+ assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])");
+ assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])");
+ }
+
private static void assertTensorType(String typeSpec) {
assertTensorType(typeSpec, typeSpec);
}
@@ -121,4 +131,8 @@ public class TensorTypeTestCase {
assertFalse(TensorType.fromSpec(specificType).isConvertibleTo(TensorType.fromSpec(generalType)));
}
+ private void assertValueType(TensorType.Value expectedValueType, String tensorTypeSpec) {
+ assertEquals(expectedValueType, TensorType.fromSpec(tensorTypeSpec).valueType());
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index e8b17812f32..8fa2537e1e2 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -63,27 +63,21 @@ public class DenseBinaryFormatTestCase {
64, 0, 0, 0, // value 1
64, 64, 0, 0, // value 2
};
- Tensor tensor = Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}");
- tensor.type().valueType(TensorType.ValueType.FLOAT);
- assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ Tensor tensor = Tensor.from("tensor<float>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
}
@Test
public void testSerializationOfDifferentValueTypes() {
- assertSerialization(TensorType.ValueType.DOUBLE, "tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
- assertSerialization(TensorType.ValueType.FLOAT, "tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<double>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<float>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
}
private void assertSerialization(String tensorString) {
- assertSerialization(TensorType.ValueType.DOUBLE, Tensor.from(tensorString));
- }
- private void assertSerialization(TensorType.ValueType valueType, String tensorString) {
- assertSerialization(valueType, Tensor.from(tensorString));
+ assertSerialization(Tensor.from(tensorString));
}
- private void assertSerialization(TensorType.ValueType valueType, Tensor tensor) {
- tensor.type().valueType(valueType);
+ private void assertSerialization(Tensor tensor) {
assertSerialization(tensor, tensor.type());
}