aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main')
-rw-r--r--vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java32
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java45
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java87
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java55
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java22
10 files changed, 72 insertions, 193 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java b/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java
index eba749bd14e..c33882052b4 100644
--- a/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java
+++ b/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java
@@ -1,8 +1,6 @@
// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.io;
-import com.yahoo.text.Utf8;
-
import java.nio.*;
/**
@@ -22,22 +20,21 @@ import java.nio.*;
* No methods except getByteBuffer() expose the encapsulated
* ByteBuffer, which is intentional.
*
- * @author Einar M R Rosenvinge
+ * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
*/
public class GrowableByteBuffer implements Comparable<GrowableByteBuffer> {
-
public static final int DEFAULT_BASE_SIZE = 64*1024;
public static final float DEFAULT_GROW_FACTOR = 2.0f;
private ByteBuffer buffer;
private float growFactor;
private int mark = -1;
- // NOTE: It might have been better to subclass HeapByteBuffer,
- // but that class is package-private. Subclassing ByteBuffer would involve
- // implementing a lot of abstract methods, which would mean reinventing
- // some (too many) wheels.
+ //NOTE: It might have been better to subclass HeapByteBuffer,
+ //but that class is package-private. Subclassing ByteBuffer would involve
+ //implementing a lot of abstract methods, which would mean reinventing
+ //some (too many) wheels.
- // CONSTRUCTORS:
+ //CONSTRUCTORS:
public GrowableByteBuffer() {
this(DEFAULT_BASE_SIZE, DEFAULT_GROW_FACTOR);
@@ -64,7 +61,7 @@ public class GrowableByteBuffer implements Comparable<GrowableByteBuffer> {
}
- // ACCESSORS:
+ //ACCESSORS:
public float getGrowFactor() {
return growFactor;
@@ -367,21 +364,6 @@ public class GrowableByteBuffer implements Comparable<GrowableByteBuffer> {
}
}
- /** Writes this string to the buffer as a 1_4 encoded length in bytes followed by the utf8 bytes */
- public void putUtf8String(String value) {
- byte[] stringBytes = Utf8.toBytes(value);
- putInt1_4Bytes(stringBytes.length);
- put(stringBytes);
- }
-
- /** Reads a string from the buffer as a 1_4 encoded length in bytes followed by the utf8 bytes */
- public String getUtf8String() {
- int stringLength = getInt1_4Bytes();
- byte[] stringBytes = new byte[stringLength];
- get(stringBytes);
- return Utf8.toString(stringBytes);
- }
-
/**
* Computes the size used for storing the given integer using 1 or 4 bytes.
*
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index 7570a357452..daa85cc51e4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -29,14 +29,6 @@ public final class DimensionSizes {
/** Returns the number of dimensions this provides the size of */
public int dimensions() { return sizes.length; }
- /** Returns the product of the sizes of this */
- public int totalSize() {
- int productSize = 1;
- for (int dimensionSize : sizes )
- productSize *= dimensionSize;
- return productSize;
- }
-
@Override
public boolean equals(Object o) {
if (o == this) return true;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index bee93ddb4e0..9315922f57a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -103,6 +103,7 @@ public class IndexedTensor implements Tensor {
* @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
*/
public double get(int ... indexes) {
+ if (values.length == 0) return Double.NaN;
return values[toValueIndex(indexes, dimensionSizes)];
}
@@ -156,7 +157,7 @@ public class IndexedTensor implements Tensor {
@Override
public Map<TensorAddress, Double> cells() {
if (dimensionSizes.dimensions() == 0)
- return Collections.singletonMap(TensorAddress.empty, values[0]);
+ return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]);
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
@@ -216,6 +217,13 @@ public class IndexedTensor implements Tensor {
public abstract Builder cell(double value, int ... indexes);
+ protected double[] arrayFor(DimensionSizes sizes) {
+ int productSize = 1;
+ for (int i = 0; i < sizes.dimensions(); i++ )
+ productSize *= sizes.size(i);
+ return new double[productSize];
+ }
+
@Override
public TensorType type() { return type; }
@@ -225,7 +233,7 @@ public class IndexedTensor implements Tensor {
}
/** A bound builder can create the double array directly */
- public static class BoundBuilder extends Builder {
+ private static class BoundBuilder extends Builder {
private DimensionSizes sizes;
private double[] values;
@@ -234,7 +242,7 @@ public class IndexedTensor implements Tensor {
this(type, dimensionSizesOf(type));
}
- static DimensionSizes dimensionSizesOf(TensorType type) {
+ public static DimensionSizes dimensionSizesOf(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
for (int i = 0; i < type.dimensions().size(); i++)
b.set(i, type.dimensions().get(i).size().get());
@@ -246,7 +254,8 @@ public class IndexedTensor implements Tensor {
if ( sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
- values = new double[sizes.totalSize()];
+ values = arrayFor(sizes);
+ Arrays.fill(values, Double.NaN);
}
@Override
@@ -268,6 +277,10 @@ public class IndexedTensor implements Tensor {
@Override
public IndexedTensor build() {
+ // Note that we do not check for no NaN's here for performance reasons.
+ // NaN's don't get lost so leaving them in place should be quite benign
+ if (values.length == 1 && Double.isNaN(values[0]))
+ values = new double[0];
IndexedTensor tensor = new IndexedTensor(type, sizes, values);
// prevent further modification
sizes = null;
@@ -277,6 +290,9 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(Cell cell, double value) {
+ // TODO: Use internal index if applicable
+ // values[internalIndex] = value;
+ // return this;
int directIndex = cell.getDirectIndex();
if (directIndex >= 0) // optimization
values[directIndex] = value;
@@ -285,15 +301,6 @@ public class IndexedTensor implements Tensor {
return this;
}
- /**
- * Set a cell value by the index in the internal layout of this cell.
- * This requires knowledge of the internal layout of cells in this implementation, and should therefore
- * probably not be used (but when it can be used it is fast).
- */
- public void cellByDirectIndex(int index, double value) {
- values[index] = value;
- }
-
}
/**
@@ -311,13 +318,13 @@ public class IndexedTensor implements Tensor {
@Override
public IndexedTensor build() {
- if (firstDimension == null) throw new IllegalArgumentException("Tensor of type " + type() + " has no values");
-
+ if (firstDimension == null) // empty
+ return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {});
if (type.dimensions().isEmpty()) // single number
return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
DimensionSizes dimensionSizes = findDimensionSizes(firstDimension);
- double[] values = new double[dimensionSizes.totalSize()];
+ double[] values = arrayFor(dimensionSizes);
fillValues(0, 0, firstDimension, dimensionSizes, values);
return new IndexedTensor(type, dimensionSizes, values);
}
@@ -326,10 +333,8 @@ public class IndexedTensor implements Tensor {
List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size());
findDimensionSizes(0, dimensionSizeList, firstDimension);
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct
- for (int i = 0; i < b.dimensions(); i++) {
- if (i < dimensionSizeList.size())
- b.set(i, dimensionSizeList.get(i));
- }
+ for (int i = 0; i < b.dimensions(); i++)
+ b.set(i, dimensionSizeList.get(i));
return b.build();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 29c508ce12f..51d40a89f3b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -213,9 +213,10 @@ public interface Tensor {
static String contentToString(Tensor tensor) {
List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
- if (tensor.type().dimensions().isEmpty()) {
+ if (tensor.type().dimensions().isEmpty()) { // TODO: Decide on one way to represent degeneration to number
if (cellEntries.isEmpty()) return "{}";
- return "{" + cellEntries.get(0).getValue() +"}";
+ double value = cellEntries.get(0).getValue();
+ return value == 0.0 ? "{}" : "{" + value +"}";
}
Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index fbc469c1829..82f36972a47 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -53,6 +53,9 @@ public class TensorType {
return TensorTypeParser.fromSpec(specString);
}
+ /** Returns true if all dimensions of this are indexed */
+ public boolean isIndexed() { return dimensions().stream().allMatch(d -> d.isIndexed()); }
+
/** Returns an immutable list of the dimensions of this */
public List<Dimension> dimensions() { return dimensions; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index f295e129a0f..ceade39ce42 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -113,7 +113,7 @@ public class Join extends PrimitiveTensorFunction {
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
- if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
+ if (subspace.type().isIndexed() && superspace.type().isIndexed())
return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder);
else
return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
index 9b0ccdcb6c8..f3adf63739a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
@@ -4,7 +4,6 @@ package com.yahoo.tensor.serialization;
import com.google.common.annotations.Beta;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
/**
* Representation of a specific binary format with functions for serializing a Tensor object into
@@ -22,10 +21,7 @@ interface BinaryFormat {
/**
* Deserialize the given binary data into a Tensor object.
- *
- * @param type the expected abstract type of the tensor to serialize
- * @param buffer the buffer containing the tensor binary data
*/
- Tensor decode(TensorType type, GrowableByteBuffer buffer);
+ Tensor decode(GrowableByteBuffer buffer);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
deleted file mode 100644
index 0a97576d5b7..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ /dev/null
@@ -1,87 +0,0 @@
-package com.yahoo.tensor.serialization;
-
-import com.google.common.annotations.Beta;
-import com.yahoo.io.GrowableByteBuffer;
-import com.yahoo.tensor.DimensionSizes;
-import com.yahoo.tensor.IndexedTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.text.Utf8;
-
-import java.util.Iterator;
-
-/**
- * Implementation of a dense binary format for a tensor on the form:
- *
- * Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]*
- * Cell_values = [double, double, double, ...]*
- * where values are encoded in order of increasing indexes in each dimension, increasing
- * indexes of later dimensions in the dimension type before earlier.
- *
- * @author bratseth
- */
-@Beta
-public class DenseBinaryFormat implements BinaryFormat {
-
- @Override
- public void encode(GrowableByteBuffer buffer, Tensor tensor) {
- if ( ! ( tensor instanceof IndexedTensor))
- throw new RuntimeException("The dense format is only supported for indexed tensors");
- encodeDimensions(buffer, (IndexedTensor)tensor);
- encodeCells(buffer, tensor);
- }
-
- private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) {
- buffer.putInt1_4Bytes(tensor.type().dimensions().size());
- for (int i = 0; i < tensor.type().dimensions().size(); i++) {
- buffer.putUtf8String(tensor.type().dimensions().get(i).name());
- buffer.putInt1_4Bytes(tensor.dimensionSizes().size(i));
- }
- }
-
- private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
- Iterator<Double> i = tensor.valueIterator();
- while (i.hasNext())
- buffer.putDouble(i.next());
- }
-
- @Override
- public Tensor decode(TensorType type, GrowableByteBuffer buffer) {
- DimensionSizes sizes = decodeDimensionSizes(type, buffer);
- Tensor.Builder builder = Tensor.Builder.of(type, sizes);
- decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder);
- return builder.build();
- }
-
- private DimensionSizes decodeDimensionSizes(TensorType type, GrowableByteBuffer buffer) {
- int dimensionCount = buffer.getInt1_4Bytes();
- if (type.dimensions().size() != dimensionCount)
- throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount +
- " dimensions but type is " + type);
-
- DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount);
- for (int i = 0; i < dimensionCount; i++) {
- TensorType.Dimension expectedDimension = type.dimensions().get(i);
-
- String encodedName = buffer.getUtf8String();
- int encodedSize = buffer.getInt1_4Bytes();
-
- if ( ! expectedDimension.name().equals(encodedName))
- throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName +
- "' as dimension " + i + " but type is " + type);
-
- if (expectedDimension.size().isPresent() && expectedDimension.size().get() < encodedSize)
- throw new IllegalArgumentException("Type/instance mismatch: Instance has size " + encodedSize +
- " in " + expectedDimension + " in type " + type);
-
- builder.set(i, encodedSize);
- }
- return builder.build();
- }
-
- private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
- for (int i = 0; i < sizes.totalSize(); i++)
- builder.cellByDirectIndex(i, buffer.getDouble());
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index 30b36e83457..27a009b5e7e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -31,14 +31,14 @@ class SparseBinaryFormat implements BinaryFormat {
encodeCells(buffer, tensor);
}
- private void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) {
+ private static void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) {
buffer.putInt1_4Bytes(sortedDimensions.size());
for (TensorType.Dimension dimension : sortedDimensions) {
- buffer.putUtf8String(dimension.name());
+ encodeString(buffer, dimension.name());
}
}
- private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
+ private static void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
buffer.putInt1_4Bytes(tensor.size());
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
@@ -47,47 +47,35 @@ class SparseBinaryFormat implements BinaryFormat {
}
}
- private void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) {
+ private static void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) {
for (int i = 0; i < address.size(); i++)
- buffer.putUtf8String(address.label(i));
+ encodeString(buffer, address.label(i));
+ }
+
+ private static void encodeString(GrowableByteBuffer buffer, String value) {
+ byte[] stringBytes = Utf8.toBytes(value);
+ buffer.putInt1_4Bytes(stringBytes.length);
+ buffer.put(stringBytes);
}
@Override
- public Tensor decode(TensorType type, GrowableByteBuffer buffer) {
- if (type == null) // TODO (January 2017): Remove this when types are available
- type = decodeDimensionsToType(buffer);
- else
- consumeAndValidateDimensions(type, buffer);
+ public Tensor decode(GrowableByteBuffer buffer) {
+ TensorType type = decodeDimensions(buffer);
Tensor.Builder builder = Tensor.Builder.of(type);
decodeCells(buffer, builder, type);
return builder.build();
}
- private TensorType decodeDimensionsToType(GrowableByteBuffer buffer) {
+ private static TensorType decodeDimensions(GrowableByteBuffer buffer) {
TensorType.Builder builder = new TensorType.Builder();
int numDimensions = buffer.getInt1_4Bytes();
for (int i = 0; i < numDimensions; ++i) {
- builder.mapped(buffer.getUtf8String());
+ builder.mapped(decodeString(buffer)); // TODO: Support indexed
}
return builder.build();
}
- private void consumeAndValidateDimensions(TensorType type, GrowableByteBuffer buffer) {
- int dimensionCount = buffer.getInt1_4Bytes();
- if (type.dimensions().size() != dimensionCount)
- throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount +
- " dimensions but type is " + type);
-
- for (int i = 0; i < dimensionCount; ++i) {
- TensorType.Dimension expectedDimension = type.dimensions().get(i);
- String encodedName = buffer.getUtf8String();
- if ( ! expectedDimension.name().equals(encodedName))
- throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName +
- "' as dimension " + i + " but type is " + type);
- }
- }
-
- private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
+ private static void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
int numCells = buffer.getInt1_4Bytes();
for (int i = 0; i < numCells; ++i) {
Tensor.Builder.CellBuilder cellBuilder = builder.cell();
@@ -96,13 +84,20 @@ class SparseBinaryFormat implements BinaryFormat {
}
}
- private void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) {
+ private static void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) {
for (TensorType.Dimension dimension : type.dimensions()) {
- String label = buffer.getUtf8String();
+ String label = decodeString(buffer);
if ( ! label.isEmpty()) {
builder.label(dimension.name(), label);
}
}
}
+ private static String decodeString(GrowableByteBuffer buffer) {
+ int stringLength = buffer.getInt1_4Bytes();
+ byte[] stringBytes = new byte[stringLength];
+ buffer.get(stringBytes);
+ return Utf8.toString(stringBytes);
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 65216aa2fcd..5a45f20b6d8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -3,9 +3,7 @@ package com.yahoo.tensor.serialization;
import com.google.common.annotations.Beta;
import com.yahoo.io.GrowableByteBuffer;
-import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
/**
* Class used by clients for serializing a Tensor object into binary format or
@@ -20,31 +18,25 @@ import com.yahoo.tensor.TensorType;
public class TypedBinaryFormat {
private static final int SPARSE_BINARY_FORMAT_TYPE = 1;
- private static final int DENSE_BINARY_FORMAT_TYPE = 2;
public static byte[] encode(Tensor tensor) {
GrowableByteBuffer buffer = new GrowableByteBuffer();
- if (tensor instanceof IndexedTensor && 1==2) { // TODO: Activate when we have type information everywhere
- buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE);
- new DenseBinaryFormat().encode(buffer, tensor);
- }
- else {
- buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE);
- new SparseBinaryFormat().encode(buffer, tensor);
- }
+ buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE);
+ new SparseBinaryFormat().encode(buffer, tensor);
buffer.flip();
byte[] result = new byte[buffer.remaining()];
buffer.get(result);
return result;
}
- public static Tensor decode(TensorType type, byte[] data) {
+ public static Tensor decode(byte[] data) {
GrowableByteBuffer buffer = GrowableByteBuffer.wrap(data);
int formatType = buffer.getInt1_4Bytes();
switch (formatType) {
- case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer);
- case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat().decode(type, buffer);
- default: throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
+ case SPARSE_BINARY_FORMAT_TYPE:
+ return new SparseBinaryFormat().decode(buffer);
+ default:
+ throw new IllegalArgumentException("Binary format type " + formatType + " is not a known format");
}
}