summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java55
1 files changed, 30 insertions, 25 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index 27a009b5e7e..30b36e83457 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -31,14 +31,14 @@ class SparseBinaryFormat implements BinaryFormat {
encodeCells(buffer, tensor);
}
- private static void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) {
+ private void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) {
buffer.putInt1_4Bytes(sortedDimensions.size());
for (TensorType.Dimension dimension : sortedDimensions) {
- encodeString(buffer, dimension.name());
+ buffer.putUtf8String(dimension.name());
}
}
- private static void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
+ private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
buffer.putInt1_4Bytes(tensor.size());
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
@@ -47,35 +47,47 @@ class SparseBinaryFormat implements BinaryFormat {
}
}
- private static void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) {
+ private void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) {
for (int i = 0; i < address.size(); i++)
- encodeString(buffer, address.label(i));
- }
-
- private static void encodeString(GrowableByteBuffer buffer, String value) {
- byte[] stringBytes = Utf8.toBytes(value);
- buffer.putInt1_4Bytes(stringBytes.length);
- buffer.put(stringBytes);
+ buffer.putUtf8String(address.label(i));
}
@Override
- public Tensor decode(GrowableByteBuffer buffer) {
- TensorType type = decodeDimensions(buffer);
+ public Tensor decode(TensorType type, GrowableByteBuffer buffer) {
+ if (type == null) // TODO (January 2017): Remove this when types are available
+ type = decodeDimensionsToType(buffer);
+ else
+ consumeAndValidateDimensions(type, buffer);
Tensor.Builder builder = Tensor.Builder.of(type);
decodeCells(buffer, builder, type);
return builder.build();
}
- private static TensorType decodeDimensions(GrowableByteBuffer buffer) {
+ private TensorType decodeDimensionsToType(GrowableByteBuffer buffer) {
TensorType.Builder builder = new TensorType.Builder();
int numDimensions = buffer.getInt1_4Bytes();
for (int i = 0; i < numDimensions; ++i) {
- builder.mapped(decodeString(buffer)); // TODO: Support indexed
+ builder.mapped(buffer.getUtf8String());
}
return builder.build();
}
- private static void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
+ private void consumeAndValidateDimensions(TensorType type, GrowableByteBuffer buffer) {
+ int dimensionCount = buffer.getInt1_4Bytes();
+ if (type.dimensions().size() != dimensionCount)
+ throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount +
+ " dimensions but type is " + type);
+
+ for (int i = 0; i < dimensionCount; ++i) {
+ TensorType.Dimension expectedDimension = type.dimensions().get(i);
+ String encodedName = buffer.getUtf8String();
+ if ( ! expectedDimension.name().equals(encodedName))
+ throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName +
+ "' as dimension " + i + " but type is " + type);
+ }
+ }
+
+ private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
int numCells = buffer.getInt1_4Bytes();
for (int i = 0; i < numCells; ++i) {
Tensor.Builder.CellBuilder cellBuilder = builder.cell();
@@ -84,20 +96,13 @@ class SparseBinaryFormat implements BinaryFormat {
}
}
- private static void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) {
+ private void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) {
for (TensorType.Dimension dimension : type.dimensions()) {
- String label = decodeString(buffer);
+ String label = buffer.getUtf8String();
if ( ! label.isEmpty()) {
builder.label(dimension.name(), label);
}
}
}
- private static String decodeString(GrowableByteBuffer buffer) {
- int stringLength = buffer.getInt1_4Bytes();
- byte[] stringBytes = new byte[stringLength];
- buffer.get(stringBytes);
- return Utf8.toString(stringBytes);
- }
-
}