diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2019-04-01 15:59:42 +0200 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2019-04-01 15:59:42 +0200 |
commit | 06b999904e735420ad5d1a74ae551f88573d2657 (patch) | |
tree | b8d8f134aab5a4adeb166ba56cedb64281d231ae /vespajlib | |
parent | e8440bd1dbafac4ce09797bb3395b0cac54c6d82 (diff) |
Add support for different values during decoding too.
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/abi-spec.json | 3 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 26 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java | 44 |
3 files changed, 62 insertions, 11 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 766b0daedcc..ccfe99119ee 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1162,8 +1162,11 @@ ], "methods": [ "public void <init>()", + "public void <init>(com.yahoo.tensor.TensorType$ValueType)", "public varargs void <init>(com.yahoo.tensor.TensorType[])", + "public varargs void <init>(com.yahoo.tensor.TensorType$ValueType, com.yahoo.tensor.TensorType[])", "public void <init>(java.lang.Iterable)", + "public void <init>(com.yahoo.tensor.TensorType$ValueType, java.lang.Iterable)", "public int rank()", "public com.yahoo.tensor.TensorType$Builder set(com.yahoo.tensor.TensorType$Dimension)", "public com.yahoo.tensor.TensorType$Builder indexed(java.lang.String, long)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index bb7a976af9b..d548b13601f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -27,16 +27,17 @@ public class TensorType { public enum ValueType { DOUBLE, FLOAT}; /** The empty tensor type - which is the same as a double */ - public static final TensorType empty = new TensorType(Collections.emptyList()); + public static final TensorType empty = new TensorType(ValueType.DOUBLE, Collections.emptyList()); - private final ValueType valueType = ValueType.DOUBLE; + private final ValueType valueType; public final ValueType valueType() { return valueType; } /** Sorted list of the dimensions of this */ private final ImmutableList<Dimension> dimensions; - private TensorType(Collection<Dimension> dimensions) { + private TensorType(ValueType valueType, Collection<Dimension> dimensions) { + this.valueType = valueType; List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); this.dimensions = ImmutableList.copyOf(dimensionList); @@ -379,8 +380,15 @@ public class TensorType { private final Map<String, Dimension> dimensions = new LinkedHashMap<>(); - /** Creates an empty builder */ + private final ValueType valueType; + + /** Creates an empty builder with cells of type double*/ public Builder() { + this(ValueType.DOUBLE); + } + + public Builder(ValueType valueType) { + this.valueType = valueType; } /** @@ -391,6 +399,10 @@ public class TensorType { * If it is indexed in one and mapped in the other it will become mapped. */ public Builder(TensorType ... types) { + this(ValueType.DOUBLE, types); + } + public Builder(ValueType valueType, TensorType ... types) { + this.valueType = valueType; for (TensorType type : types) addDimensionsOf(type); } @@ -399,6 +411,10 @@ public class TensorType { * Creates a builder from the given dimensions. */ public Builder(Iterable<Dimension> dimensions) { + this(ValueType.DOUBLE, dimensions); + } + public Builder(ValueType valueType, Iterable<Dimension> dimensions) { + this.valueType = valueType; for (TensorType.Dimension dimension : dimensions) { dimension(dimension); } @@ -497,7 +513,7 @@ public class TensorType { } public TensorType build() { - return new TensorType(dimensions.values()); + return new TensorType(valueType, dimensions.values()); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 6dcc870ef94..2537e7d8669 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -93,24 +93,41 @@ public class DenseBinaryFormat implements BinaryFormat { DimensionSizes sizes; if (optionalType.isPresent()) { type = optionalType.get(); - TensorType serializedType = decodeType(buffer); + TensorType serializedType = decodeType(buffer, type.valueType()); if ( ! serializedType.isAssignableTo(type)) throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + " cannot be assigned to type " + type); sizes = sizesFromType(serializedType); } else { - type = decodeType(buffer); + type = decodeType(buffer, TensorType.ValueType.DOUBLE); sizes = sizesFromType(type); } Tensor.Builder builder = Tensor.Builder.of(type, sizes); - decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder); + decodeCells(type.valueType(), sizes, buffer, (IndexedTensor.BoundBuilder)builder); return builder.build(); } - private TensorType decodeType(GrowableByteBuffer buffer) { + private TensorType decodeType(GrowableByteBuffer buffer, TensorType.ValueType valueType) { + TensorType.ValueType serializedValueType = TensorType.ValueType.DOUBLE; + if ((valueType != TensorType.ValueType.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) { + int type = buffer.getInt1_4Bytes(); + switch (type) { + case DOUBLE_VALUE_TYPE: + serializedValueType = TensorType.ValueType.DOUBLE; + break; + case FLOAT_VALUE_TYPE: + serializedValueType = TensorType.ValueType.DOUBLE; + break; + default: + throw new IllegalArgumentException("Received tensor value type '" + serializedValueType + "'. Only 0(double), or 1(float) are legal."); + } + } + if (valueType != serializedValueType) { + throw new IllegalArgumentException("Expected " + valueType + ", got " + serializedValueType); + } + TensorType.Builder builder = new TensorType.Builder(serializedValueType); int dimensionCount = buffer.getInt1_4Bytes(); - TensorType.Builder builder = new TensorType.Builder(); for (int i = 0; i < dimensionCount; i++) builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); // XXX: Size truncation return builder.build(); @@ -124,9 +141,24 @@ public class DenseBinaryFormat implements BinaryFormat { return builder.build(); } - private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { + private void decodeCells(TensorType.ValueType valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { + switch (valueType) { + case DOUBLE: + decodeCellsAsDouble(sizes, buffer, builder); + break; + case FLOAT: + decodeCellsAsFloat(sizes, buffer, builder); + break; + } + } + + private void decodeCellsAsDouble(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { for (long i = 0; i < sizes.totalSize(); i++) builder.cellByDirectIndex(i, buffer.getDouble()); } + private void decodeCellsAsFloat(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { + for (long i = 0; i < sizes.totalSize(); i++) + builder.cellByDirectIndex(i, buffer.getFloat()); + } } |