summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java37
1 files changed, 33 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
index 284dfea2141..bc247e5561f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
@@ -11,6 +11,8 @@ import com.yahoo.tensor.TensorType;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
import java.util.stream.Collectors;
/**
@@ -21,6 +23,15 @@ import java.util.stream.Collectors;
*/
class MixedBinaryFormat implements BinaryFormat {
+ private final TensorType.Value serializationValueType;
+
+ MixedBinaryFormat() {
+ this(TensorType.Value.DOUBLE);
+ }
+ MixedBinaryFormat(TensorType.Value serializationValueType) {
+ this.serializationValueType = serializationValueType;
+ }
+
@Override
public void encode(GrowableByteBuffer buffer, Tensor tensor) {
if ( ! ( tensor instanceof MixedTensor))
@@ -50,6 +61,13 @@ class MixedBinaryFormat implements BinaryFormat {
}
private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor) {
+ switch (serializationValueType) {
+ case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break;
+ case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break;
+ }
+ }
+
+ private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor, Consumer<Double> consumer) {
List<TensorType.Dimension> sparseDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
long denseSubspaceSize = tensor.denseSubspaceSize();
if (sparseDimensions.size() > 0) {
@@ -63,9 +81,9 @@ class MixedBinaryFormat implements BinaryFormat {
new IllegalStateException("Dimension not found in address."));
buffer.putUtf8String(cell.getKey().label(index));
}
- buffer.putDouble(cell.getValue());
+ consumer.accept(cell.getValue());
for (int i = 1; i < denseSubspaceSize; ++i ) {
- buffer.putDouble(cellIterator.next().getValue());
+ consumer.accept(cellIterator.next().getValue());
}
}
}
@@ -75,6 +93,10 @@ class MixedBinaryFormat implements BinaryFormat {
TensorType type;
if (optionalType.isPresent()) {
type = optionalType.get();
+ if (type.valueType() != this.serializationValueType) {
+ throw new IllegalArgumentException("Tensor value type mismatch. Value type " + type.valueType() +
+ " is not " + this.serializationValueType);
+ }
TensorType serializedType = decodeType(buffer);
if ( ! serializedType.isAssignableTo(type))
throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
@@ -89,7 +111,7 @@ class MixedBinaryFormat implements BinaryFormat {
}
private TensorType decodeType(GrowableByteBuffer buffer) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(serializationValueType);
int numMappedDimensions = buffer.getInt1_4Bytes();
for (int i = 0; i < numMappedDimensions; ++i) {
builder.mapped(buffer.getUtf8String());
@@ -102,6 +124,13 @@ class MixedBinaryFormat implements BinaryFormat {
}
private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) {
+ switch (serializationValueType) {
+ case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break;
+ case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break;
+ }
+ }
+
+ private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type, Supplier<Double> supplier) {
List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
TensorType sparseType = MixedTensor.createPartialType(type.valueType(), sparseDimensions);
long denseSubspaceSize = builder.denseSubspaceSize();
@@ -118,7 +147,7 @@ class MixedBinaryFormat implements BinaryFormat {
sparseAddress.add(sparseDimension.name(), buffer.getUtf8String());
}
for (long denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) {
- denseSubspace[(int)denseOffset] = buffer.getDouble();
+ denseSubspace[(int)denseOffset] = supplier.get();
}
builder.block(sparseAddress.build(), denseSubspace);
}