aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-02 12:17:39 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-02 12:17:39 +0200
commit6eb80166172e10255841fd3d3cf70bed09d3d8c1 (patch)
tree0fe5b3ebb646460c70b9222ba6b1505a81d3619f /vespajlib/src/main/java
parentda41894e5b4f7525ee59d9c69838bdc21735d0f2 (diff)
Parse tensor value type
Diffstat (limited to 'vespajlib/src/main/java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java41
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java64
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java16
4 files changed, 80 insertions, 43 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index fa32d385004..998f3170aa0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -11,7 +11,7 @@ class TensorParser {
static Tensor tensorFrom(String tensorString, Optional<TensorType> type) {
tensorString = tensorString.trim();
try {
- if (tensorString.startsWith("tensor(")) {
+ if (tensorString.startsWith("tensor")) {
int colonIndex = tensorString.indexOf(':');
String typeString = tensorString.substring(0, colonIndex);
String valueString = tensorString.substring(colonIndex + 1);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 036f5e3ee5d..bded55405c0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -24,25 +24,19 @@ import java.util.stream.Collectors;
*/
public class TensorType {
- public enum ValueType { DOUBLE, FLOAT};
+ /** The permissible cell value types. Default is double. */
+ // Types added here must also be added to TensorTypeParser.parseValueTypeSpec
+ public enum Value { DOUBLE, FLOAT};
/** The empty tensor type - which is the same as a double */
- public static final TensorType empty = new TensorType(ValueType.DOUBLE, Collections.emptyList());
+ public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList());
- private ValueType valueType;
-
- public final ValueType valueType() { return valueType; }
-
- //TODO Remove once value type is wired in were it should.
- public final TensorType valueType(ValueType valueType) {
- this.valueType = valueType;
- return this;
- }
+ private final Value valueType;
/** Sorted list of the dimensions of this */
private final ImmutableList<Dimension> dimensions;
- private TensorType(ValueType valueType, Collection<Dimension> dimensions) {
+ private TensorType(Value valueType, Collection<Dimension> dimensions) {
this.valueType = valueType;
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
@@ -64,6 +58,9 @@ public class TensorType {
return TensorTypeParser.fromSpec(specString);
}
+ /** Returns the numeric type of the cell values of this */
+ public Value valueType() { return valueType; }
+
/** Returns the number of dimensions of this: dimensions().size() */
public int rank() { return dimensions.size(); }
@@ -386,14 +383,14 @@ public class TensorType {
private final Map<String, Dimension> dimensions = new LinkedHashMap<>();
- private final ValueType valueType;
+ private final Value valueType;
/** Creates an empty builder with cells of type double*/
public Builder() {
- this(ValueType.DOUBLE);
+ this(Value.DOUBLE);
}
- public Builder(ValueType valueType) {
+ public Builder(Value valueType) {
this.valueType = valueType;
}
@@ -405,21 +402,21 @@ public class TensorType {
* If it is indexed in one and mapped in the other it will become mapped.
*/
public Builder(TensorType ... types) {
- this(ValueType.DOUBLE, types);
+ this(Value.DOUBLE, types);
}
- public Builder(ValueType valueType, TensorType ... types) {
+ public Builder(Value valueType, TensorType ... types) {
this.valueType = valueType;
for (TensorType type : types)
addDimensionsOf(type);
}
- /**
- * Creates a builder from the given dimensions.
- */
+ /** Creates a builder from the given dimensions */
public Builder(Iterable<Dimension> dimensions) {
- this(ValueType.DOUBLE, dimensions);
+ this(Value.DOUBLE, dimensions);
}
- public Builder(ValueType valueType, Iterable<Dimension> dimensions) {
+
+ /** Creates a builder from the given value type and dimensions */
+ public Builder(Value valueType, Iterable<Dimension> dimensions) {
this.valueType = valueType;
for (TensorType.Dimension dimension : dimensions) {
dimension(dimension);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index 32ad6171e57..a5733f1cc4c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -2,8 +2,10 @@
package com.yahoo.tensor;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -14,23 +16,32 @@ import java.util.regex.Pattern;
*/
public class TensorTypeParser {
- private final static String START_STRING = "tensor(";
+ private final static String START_STRING = "tensor";
private final static String END_STRING = ")";
private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]");
private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}");
public static TensorType fromSpec(String specString) {
- return new TensorType.Builder(dimensionsFromSpec(specString)).build();
- }
+ if ( ! specString.startsWith(START_STRING) || ! specString.endsWith(END_STRING))
+ throw formatException(specString);
+ String specBody = specString.substring(START_STRING.length(), specString.length() - END_STRING.length());
- public static List<TensorType.Dimension> dimensionsFromSpec(String specString) {
- if ( ! specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) {
- throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" +
- " and end with '" + END_STRING + "', but was '" + specString + "'");
+ String dimensionsSpec;
+ TensorType.Value valueType;
+ if (specBody.startsWith("(")) {
+ valueType = TensorType.Value.DOUBLE; // no value type spec: Use default
+ dimensionsSpec = specBody.substring(1);
+ }
+ else {
+ int parenthesisIndex = specBody.indexOf("(");
+ if (parenthesisIndex < 0)
+ throw formatException(specString);
+ valueType = parseValueTypeSpec(specBody.substring(0, parenthesisIndex), specString);
+ dimensionsSpec = specBody.substring(parenthesisIndex + 1);
}
- String dimensionsSpec = specString.substring(START_STRING.length(), specString.length() - END_STRING.length());
- if (dimensionsSpec.isEmpty()) return Collections.emptyList();
+
+ if (dimensionsSpec.isEmpty()) return new TensorType.Builder(valueType, Collections.emptyList()).build();
List<TensorType.Dimension> dimensions = new ArrayList<>();
for (String element : dimensionsSpec.split(",")) {
@@ -38,10 +49,23 @@ public class TensorTypeParser {
boolean success = tryParseIndexedDimension(trimmedElement, dimensions) ||
tryParseMappedDimension(trimmedElement, dimensions);
if ( ! success)
- throw new IllegalArgumentException("Failed parsing element '" + element +
- "' in type spec '" + specString + "'");
+ throw formatException(specString, "Dimension '" + element + "' is on the wrong format");
+ }
+ return new TensorType.Builder(valueType, dimensions).build();
+ }
+
+ private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) {
+ if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">"))
+ throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>"));
+
+ String valueType = valueTypeSpec.substring(1, valueTypeSpec.length() - 1);
+ switch (valueType) {
+ case "double" : return TensorType.Value.DOUBLE;
+ case "float" : return TensorType.Value.FLOAT;
+ default : throw formatException(fullSpecString,
+ "Value type must be either 'double' or 'float'" +
+ " but was '" + valueType + "'");
}
- return dimensions;
}
private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) {
@@ -69,5 +93,21 @@ public class TensorTypeParser {
return false;
}
+
+ private static IllegalArgumentException formatException(String spec) {
+ return formatException(spec, Optional.empty());
+ }
+
+ private static IllegalArgumentException formatException(String spec, String errorDetail) {
+ return formatException(spec, Optional.of(errorDetail));
+ }
+
+ private static IllegalArgumentException formatException(String spec, Optional<String> errorDetail) {
+ throw new IllegalArgumentException("A tensor type spec must be on the form " +
+ "tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was '" + spec + "'. " +
+ errorDetail.map(s -> s + ". ").orElse("") +
+ "Examples: tensor(x[]), tensor<float>(name{}, x[10])");
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 500c436516f..ecd4f7d1965 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -43,7 +43,7 @@ public class DenseBinaryFormat implements BinaryFormat {
encodeCells(buffer, tensor);
}
- private void encodeValueType(GrowableByteBuffer buffer, TensorType.ValueType valueType) {
+ private void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) {
switch (valueType) {
case DOUBLE:
if (encodeType != EncodeType.DOUBLE_IS_DEFAULT) {
@@ -100,7 +100,7 @@ public class DenseBinaryFormat implements BinaryFormat {
sizes = sizesFromType(serializedType);
}
else {
- type = decodeType(buffer, TensorType.ValueType.DOUBLE);
+ type = decodeType(buffer, TensorType.Value.DOUBLE);
sizes = sizesFromType(type);
}
Tensor.Builder builder = Tensor.Builder.of(type, sizes);
@@ -108,16 +108,16 @@ public class DenseBinaryFormat implements BinaryFormat {
return builder.build();
}
- private TensorType decodeType(GrowableByteBuffer buffer, TensorType.ValueType valueType) {
- TensorType.ValueType serializedValueType = TensorType.ValueType.DOUBLE;
- if ((valueType != TensorType.ValueType.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) {
+ private TensorType decodeType(GrowableByteBuffer buffer, TensorType.Value valueType) {
+ TensorType.Value serializedValueType = TensorType.Value.DOUBLE;
+ if ((valueType != TensorType.Value.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) {
int type = buffer.getInt1_4Bytes();
switch (type) {
case DOUBLE_VALUE_TYPE:
- serializedValueType = TensorType.ValueType.DOUBLE;
+ serializedValueType = TensorType.Value.DOUBLE;
break;
case FLOAT_VALUE_TYPE:
- serializedValueType = TensorType.ValueType.FLOAT;
+ serializedValueType = TensorType.Value.FLOAT;
break;
default:
throw new IllegalArgumentException("Received tensor value type '" + serializedValueType + "'. Only 0(double), or 1(float) are legal.");
@@ -141,7 +141,7 @@ public class DenseBinaryFormat implements BinaryFormat {
return builder.build();
}
- private void decodeCells(TensorType.ValueType valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
+ private void decodeCells(TensorType.Value valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
switch (valueType) {
case DOUBLE:
decodeCellsAsDouble(sizes, buffer, builder);