diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java | 55 |
1 files changed, 30 insertions, 25 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 27a009b5e7e..30b36e83457 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -31,14 +31,14 @@ class SparseBinaryFormat implements BinaryFormat { encodeCells(buffer, tensor); } - private static void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) { + private void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) { buffer.putInt1_4Bytes(sortedDimensions.size()); for (TensorType.Dimension dimension : sortedDimensions) { - encodeString(buffer, dimension.name()); + buffer.putUtf8String(dimension.name()); } } - private static void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { + private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { buffer.putInt1_4Bytes(tensor.size()); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); @@ -47,35 +47,47 @@ class SparseBinaryFormat implements BinaryFormat { } } - private static void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) { + private void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) { for (int i = 0; i < address.size(); i++) - encodeString(buffer, address.label(i)); - } - - private static void encodeString(GrowableByteBuffer buffer, String value) { - byte[] stringBytes = Utf8.toBytes(value); - buffer.putInt1_4Bytes(stringBytes.length); - buffer.put(stringBytes); + buffer.putUtf8String(address.label(i)); } @Override - public Tensor decode(GrowableByteBuffer buffer) { - TensorType type = decodeDimensions(buffer); + public Tensor decode(TensorType type, GrowableByteBuffer buffer) { + if (type == null) // TODO (January 2017): Remove this when types are available + type = decodeDimensionsToType(buffer); + else + consumeAndValidateDimensions(type, buffer); Tensor.Builder builder = Tensor.Builder.of(type); decodeCells(buffer, builder, type); return builder.build(); } - private static TensorType decodeDimensions(GrowableByteBuffer buffer) { + private TensorType decodeDimensionsToType(GrowableByteBuffer buffer) { TensorType.Builder builder = new TensorType.Builder(); int numDimensions = buffer.getInt1_4Bytes(); for (int i = 0; i < numDimensions; ++i) { - builder.mapped(decodeString(buffer)); // TODO: Support indexed + builder.mapped(buffer.getUtf8String()); } return builder.build(); } - private static void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) { + private void consumeAndValidateDimensions(TensorType type, GrowableByteBuffer buffer) { + int dimensionCount = buffer.getInt1_4Bytes(); + if (type.dimensions().size() != dimensionCount) + throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount + + " dimensions but type is " + type); + + for (int i = 0; i < dimensionCount; ++i) { + TensorType.Dimension expectedDimension = type.dimensions().get(i); + String encodedName = buffer.getUtf8String(); + if ( ! expectedDimension.name().equals(encodedName)) + throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName + + "' as dimension " + i + " but type is " + type); + } + } + + private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) { int numCells = buffer.getInt1_4Bytes(); for (int i = 0; i < numCells; ++i) { Tensor.Builder.CellBuilder cellBuilder = builder.cell(); @@ -84,20 +96,13 @@ class SparseBinaryFormat implements BinaryFormat { } } - private static void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) { + private void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) { for (TensorType.Dimension dimension : type.dimensions()) { - String label = decodeString(buffer); + String label = buffer.getUtf8String(); if ( ! label.isEmpty()) { builder.label(dimension.name(), label); } } } - private static String decodeString(GrowableByteBuffer buffer) { - int stringLength = buffer.getInt1_4Bytes(); - byte[] stringBytes = new byte[stringLength]; - buffer.get(stringBytes); - return Utf8.toString(stringBytes); - } - } |