diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java index 284dfea2141..bc247e5561f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java @@ -11,6 +11,8 @@ import com.yahoo.tensor.TensorType; import java.util.Iterator; import java.util.List; import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Supplier; import java.util.stream.Collectors; /** @@ -21,6 +23,15 @@ import java.util.stream.Collectors; */ class MixedBinaryFormat implements BinaryFormat { + private final TensorType.Value serializationValueType; + + MixedBinaryFormat() { + this(TensorType.Value.DOUBLE); + } + MixedBinaryFormat(TensorType.Value serializationValueType) { + this.serializationValueType = serializationValueType; + } + @Override public void encode(GrowableByteBuffer buffer, Tensor tensor) { if ( ! ( tensor instanceof MixedTensor)) @@ -50,6 +61,13 @@ class MixedBinaryFormat implements BinaryFormat { } private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor) { + switch (serializationValueType) { + case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; + case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; + } + } + + private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor, Consumer<Double> consumer) { List<TensorType.Dimension> sparseDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); long denseSubspaceSize = tensor.denseSubspaceSize(); if (sparseDimensions.size() > 0) { @@ -63,9 +81,9 @@ class MixedBinaryFormat implements BinaryFormat { new IllegalStateException("Dimension not found in address.")); buffer.putUtf8String(cell.getKey().label(index)); } - buffer.putDouble(cell.getValue()); + consumer.accept(cell.getValue()); for (int i = 1; i < denseSubspaceSize; ++i ) { - buffer.putDouble(cellIterator.next().getValue()); + consumer.accept(cellIterator.next().getValue()); } } } @@ -75,6 +93,10 @@ class MixedBinaryFormat implements BinaryFormat { TensorType type; if (optionalType.isPresent()) { type = optionalType.get(); + if (type.valueType() != this.serializationValueType) { + throw new IllegalArgumentException("Tensor value type mismatch. Value type " + type.valueType() + + " is not " + this.serializationValueType); + } TensorType serializedType = decodeType(buffer); if ( ! serializedType.isAssignableTo(type)) throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + @@ -89,7 +111,7 @@ class MixedBinaryFormat implements BinaryFormat { } private TensorType decodeType(GrowableByteBuffer buffer) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(serializationValueType); int numMappedDimensions = buffer.getInt1_4Bytes(); for (int i = 0; i < numMappedDimensions; ++i) { builder.mapped(buffer.getUtf8String()); @@ -102,6 +124,13 @@ class MixedBinaryFormat implements BinaryFormat { } private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) { + switch (serializationValueType) { + case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break; + case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break; + } + } + + private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type, Supplier<Double> supplier) { List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); TensorType sparseType = MixedTensor.createPartialType(type.valueType(), sparseDimensions); long denseSubspaceSize = builder.denseSubspaceSize(); @@ -118,7 +147,7 @@ class MixedBinaryFormat implements BinaryFormat { sparseAddress.add(sparseDimension.name(), buffer.getUtf8String()); } for (long denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) { - denseSubspace[(int)denseOffset] = buffer.getDouble(); + denseSubspace[(int)denseOffset] = supplier.get(); } builder.block(sparseAddress.build(), denseSubspace); } |