aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2017-01-10 19:15:22 +0100
committerGitHub <noreply@github.com>2017-01-10 19:15:22 +0100
commitf3b8b754e40c346ca2ea23cf8f114adbbab041a7 (patch)
tree1aee34e9b0a15e2000054b859dcc80ca5d352c2a /vespajlib/src
parent4a4b1952754ef75b86d35eb1a85bdc180eeb935c (diff)
Revert "Add (disabled) dense tensor binary format"
Diffstat (limited to 'vespajlib/src')
-rw-r--r--vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java32
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java45
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java87
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java55
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java22
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java30
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java15
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java11
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java5
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java55
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java59
16 files changed, 147 insertions, 293 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java b/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java
index eba749bd14e..c33882052b4 100644
--- a/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java
+++ b/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java
@@ -1,8 +1,6 @@
// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.io;
-import com.yahoo.text.Utf8;
-
import java.nio.*;
/**
@@ -22,22 +20,21 @@ import java.nio.*;
* No methods except getByteBuffer() expose the encapsulated
* ByteBuffer, which is intentional.
*
- * @author Einar M R Rosenvinge
+ * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
*/
public class GrowableByteBuffer implements Comparable<GrowableByteBuffer> {
-
public static final int DEFAULT_BASE_SIZE = 64*1024;
public static final float DEFAULT_GROW_FACTOR = 2.0f;
private ByteBuffer buffer;
private float growFactor;
private int mark = -1;
- // NOTE: It might have been better to subclass HeapByteBuffer,
- // but that class is package-private. Subclassing ByteBuffer would involve
- // implementing a lot of abstract methods, which would mean reinventing
- // some (too many) wheels.
+ //NOTE: It might have been better to subclass HeapByteBuffer,
+ //but that class is package-private. Subclassing ByteBuffer would involve
+ //implementing a lot of abstract methods, which would mean reinventing
+ //some (too many) wheels.
- // CONSTRUCTORS:
+ //CONSTRUCTORS:
public GrowableByteBuffer() {
this(DEFAULT_BASE_SIZE, DEFAULT_GROW_FACTOR);
@@ -64,7 +61,7 @@ public class GrowableByteBuffer implements Comparable<GrowableByteBuffer> {
}
- // ACCESSORS:
+ //ACCESSORS:
public float getGrowFactor() {
return growFactor;
@@ -367,21 +364,6 @@ public class GrowableByteBuffer implements Comparable<GrowableByteBuffer> {
}
}
- /** Writes this string to the buffer as a 1_4 encoded length in bytes followed by the utf8 bytes */
- public void putUtf8String(String value) {
- byte[] stringBytes = Utf8.toBytes(value);
- putInt1_4Bytes(stringBytes.length);
- put(stringBytes);
- }
-
- /** Reads a string from the buffer as a 1_4 encoded length in bytes followed by the utf8 bytes */
- public String getUtf8String() {
- int stringLength = getInt1_4Bytes();
- byte[] stringBytes = new byte[stringLength];
- get(stringBytes);
- return Utf8.toString(stringBytes);
- }
-
/**
* Computes the size used for storing the given integer using 1 or 4 bytes.
*
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index 7570a357452..daa85cc51e4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -29,14 +29,6 @@ public final class DimensionSizes {
/** Returns the number of dimensions this provides the size of */
public int dimensions() { return sizes.length; }
- /** Returns the product of the sizes of this */
- public int totalSize() {
- int productSize = 1;
- for (int dimensionSize : sizes )
- productSize *= dimensionSize;
- return productSize;
- }
-
@Override
public boolean equals(Object o) {
if (o == this) return true;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index bee93ddb4e0..9315922f57a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -103,6 +103,7 @@ public class IndexedTensor implements Tensor {
* @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
*/
public double get(int ... indexes) {
+ if (values.length == 0) return Double.NaN;
return values[toValueIndex(indexes, dimensionSizes)];
}
@@ -156,7 +157,7 @@ public class IndexedTensor implements Tensor {
@Override
public Map<TensorAddress, Double> cells() {
if (dimensionSizes.dimensions() == 0)
- return Collections.singletonMap(TensorAddress.empty, values[0]);
+ return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]);
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
@@ -216,6 +217,13 @@ public class IndexedTensor implements Tensor {
public abstract Builder cell(double value, int ... indexes);
+ protected double[] arrayFor(DimensionSizes sizes) {
+ int productSize = 1;
+ for (int i = 0; i < sizes.dimensions(); i++ )
+ productSize *= sizes.size(i);
+ return new double[productSize];
+ }
+
@Override
public TensorType type() { return type; }
@@ -225,7 +233,7 @@ public class IndexedTensor implements Tensor {
}
/** A bound builder can create the double array directly */
- public static class BoundBuilder extends Builder {
+ private static class BoundBuilder extends Builder {
private DimensionSizes sizes;
private double[] values;
@@ -234,7 +242,7 @@ public class IndexedTensor implements Tensor {
this(type, dimensionSizesOf(type));
}
- static DimensionSizes dimensionSizesOf(TensorType type) {
+ public static DimensionSizes dimensionSizesOf(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
for (int i = 0; i < type.dimensions().size(); i++)
b.set(i, type.dimensions().get(i).size().get());
@@ -246,7 +254,8 @@ public class IndexedTensor implements Tensor {
if ( sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
- values = new double[sizes.totalSize()];
+ values = arrayFor(sizes);
+ Arrays.fill(values, Double.NaN);
}
@Override
@@ -268,6 +277,10 @@ public class IndexedTensor implements Tensor {
@Override
public IndexedTensor build() {
+ // Note that we do not check for no NaN's here for performance reasons.
+ // NaN's don't get lost so leaving them in place should be quite benign
+ if (values.length == 1 && Double.isNaN(values[0]))
+ values = new double[0];
IndexedTensor tensor = new IndexedTensor(type, sizes, values);
// prevent further modification
sizes = null;
@@ -277,6 +290,9 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(Cell cell, double value) {
+ // TODO: Use internal index if applicable
+ // values[internalIndex] = value;
+ // return this;
int directIndex = cell.getDirectIndex();
if (directIndex >= 0) // optimization
values[directIndex] = value;
@@ -285,15 +301,6 @@ public class IndexedTensor implements Tensor {
return this;
}
- /**
- * Set a cell value by the index in the internal layout of this cell.
- * This requires knowledge of the internal layout of cells in this implementation, and should therefore
- * probably not be used (but when it can be used it is fast).
- */
- public void cellByDirectIndex(int index, double value) {
- values[index] = value;
- }
-
}
/**
@@ -311,13 +318,13 @@ public class IndexedTensor implements Tensor {
@Override
public IndexedTensor build() {
- if (firstDimension == null) throw new IllegalArgumentException("Tensor of type " + type() + " has no values");
-
+ if (firstDimension == null) // empty
+ return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {});
if (type.dimensions().isEmpty()) // single number
return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
DimensionSizes dimensionSizes = findDimensionSizes(firstDimension);
- double[] values = new double[dimensionSizes.totalSize()];
+ double[] values = arrayFor(dimensionSizes);
fillValues(0, 0, firstDimension, dimensionSizes, values);
return new IndexedTensor(type, dimensionSizes, values);
}
@@ -326,10 +333,8 @@ public class IndexedTensor implements Tensor {
List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size());
findDimensionSizes(0, dimensionSizeList, firstDimension);
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct
- for (int i = 0; i < b.dimensions(); i++) {
- if (i < dimensionSizeList.size())
- b.set(i, dimensionSizeList.get(i));
- }
+ for (int i = 0; i < b.dimensions(); i++)
+ b.set(i, dimensionSizeList.get(i));
return b.build();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 29c508ce12f..51d40a89f3b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -213,9 +213,10 @@ public interface Tensor {
static String contentToString(Tensor tensor) {
List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
- if (tensor.type().dimensions().isEmpty()) {
+ if (tensor.type().dimensions().isEmpty()) { // TODO: Decide on one way to represent degeneration to number
if (cellEntries.isEmpty()) return "{}";
- return "{" + cellEntries.get(0).getValue() +"}";
+ double value = cellEntries.get(0).getValue();
+ return value == 0.0 ? "{}" : "{" + value +"}";
}
Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index fbc469c1829..82f36972a47 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -53,6 +53,9 @@ public class TensorType {
return TensorTypeParser.fromSpec(specString);
}
+ /** Returns true if all dimensions of this are indexed */
+ public boolean isIndexed() { return dimensions().stream().allMatch(d -> d.isIndexed()); }
+
/** Returns an immutable list of the dimensions of this */
public List<Dimension> dimensions() { return dimensions; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index f295e129a0f..ceade39ce42 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -113,7 +113,7 @@ public class Join extends PrimitiveTensorFunction {
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
- if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
+ if (subspace.type().isIndexed() && superspace.type().isIndexed())
return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder);
else
return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
index 9b0ccdcb6c8..f3adf63739a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
@@ -4,7 +4,6 @@ package com.yahoo.tensor.serialization;
import com.google.common.annotations.Beta;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
/**
* Representation of a specific binary format with functions for serializing a Tensor object into
@@ -22,10 +21,7 @@ interface BinaryFormat {
/**
* Deserialize the given binary data into a Tensor object.
- *
- * @param type the expected abstract type of the tensor to serialize
- * @param buffer the buffer containing the tensor binary data
*/
- Tensor decode(TensorType type, GrowableByteBuffer buffer);
+ Tensor decode(GrowableByteBuffer buffer);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
deleted file mode 100644
index 0a97576d5b7..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ /dev/null
@@ -1,87 +0,0 @@
-package com.yahoo.tensor.serialization;
-
-import com.google.common.annotations.Beta;
-import com.yahoo.io.GrowableByteBuffer;
-import com.yahoo.tensor.DimensionSizes;
-import com.yahoo.tensor.IndexedTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.text.Utf8;
-
-import java.util.Iterator;
-
-/**
- * Implementation of a dense binary format for a tensor on the form:
- *
- * Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]*
- * Cell_values = [double, double, double, ...]*
- * where values are encoded in order of increasing indexes in each dimension, increasing
- * indexes of later dimensions in the dimension type before earlier.
- *
- * @author bratseth
- */
-@Beta
-public class DenseBinaryFormat implements BinaryFormat {
-
- @Override
- public void encode(GrowableByteBuffer buffer, Tensor tensor) {
- if ( ! ( tensor instanceof IndexedTensor))
- throw new RuntimeException("The dense format is only supported for indexed tensors");
- encodeDimensions(buffer, (IndexedTensor)tensor);
- encodeCells(buffer, tensor);
- }
-
- private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) {
- buffer.putInt1_4Bytes(tensor.type().dimensions().size());
- for (int i = 0; i < tensor.type().dimensions().size(); i++) {
- buffer.putUtf8String(tensor.type().dimensions().get(i).name());
- buffer.putInt1_4Bytes(tensor.dimensionSizes().size(i));
- }
- }
-
- private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
- Iterator<Double> i = tensor.valueIterator();
- while (i.hasNext())
- buffer.putDouble(i.next());
- }
-
- @Override
- public Tensor decode(TensorType type, GrowableByteBuffer buffer) {
- DimensionSizes sizes = decodeDimensionSizes(type, buffer);
- Tensor.Builder builder = Tensor.Builder.of(type, sizes);
- decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder);
- return builder.build();
- }
-
- private DimensionSizes decodeDimensionSizes(TensorType type, GrowableByteBuffer buffer) {
- int dimensionCount = buffer.getInt1_4Bytes();
- if (type.dimensions().size() != dimensionCount)
- throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount +
- " dimensions but type is " + type);
-
- DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount);
- for (int i = 0; i < dimensionCount; i++) {
- TensorType.Dimension expectedDimension = type.dimensions().get(i);
-
- String encodedName = buffer.getUtf8String();
- int encodedSize = buffer.getInt1_4Bytes();
-
- if ( ! expectedDimension.name().equals(encodedName))
- throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName +
- "' as dimension " + i + " but type is " + type);
-
- if (expectedDimension.size().isPresent() && expectedDimension.size().get() < encodedSize)
- throw new IllegalArgumentException("Type/instance mismatch: Instance has size " + encodedSize +
- " in " + expectedDimension + " in type " + type);
-
- builder.set(i, encodedSize);
- }
- return builder.build();
- }
-
- private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
- for (int i = 0; i < sizes.totalSize(); i++)
- builder.cellByDirectIndex(i, buffer.getDouble());
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index 30b36e83457..27a009b5e7e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -31,14 +31,14 @@ class SparseBinaryFormat implements BinaryFormat {
encodeCells(buffer, tensor);
}
- private void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) {
+ private static void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) {
buffer.putInt1_4Bytes(sortedDimensions.size());
for (TensorType.Dimension dimension : sortedDimensions) {
- buffer.putUtf8String(dimension.name());
+ encodeString(buffer, dimension.name());
}
}
- private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
+ private static void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
buffer.putInt1_4Bytes(tensor.size());
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
@@ -47,47 +47,35 @@ class SparseBinaryFormat implements BinaryFormat {
}
}
- private void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) {
+ private static void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) {
for (int i = 0; i < address.size(); i++)
- buffer.putUtf8String(address.label(i));
+ encodeString(buffer, address.label(i));
+ }
+
+ private static void encodeString(GrowableByteBuffer buffer, String value) {
+ byte[] stringBytes = Utf8.toBytes(value);
+ buffer.putInt1_4Bytes(stringBytes.length);
+ buffer.put(stringBytes);
}
@Override
- public Tensor decode(TensorType type, GrowableByteBuffer buffer) {
- if (type == null) // TODO (January 2017): Remove this when types are available
- type = decodeDimensionsToType(buffer);
- else
- consumeAndValidateDimensions(type, buffer);
+ public Tensor decode(GrowableByteBuffer buffer) {
+ TensorType type = decodeDimensions(buffer);
Tensor.Builder builder = Tensor.Builder.of(type);
decodeCells(buffer, builder, type);
return builder.build();
}
- private TensorType decodeDimensionsToType(GrowableByteBuffer buffer) {
+ private static TensorType decodeDimensions(GrowableByteBuffer buffer) {
TensorType.Builder builder = new TensorType.Builder();
int numDimensions = buffer.getInt1_4Bytes();
for (int i = 0; i < numDimensions; ++i) {
- builder.mapped(buffer.getUtf8String());
+ builder.mapped(decodeString(buffer)); // TODO: Support indexed
}
return builder.build();
}
- private void consumeAndValidateDimensions(TensorType type, GrowableByteBuffer buffer) {
- int dimensionCount = buffer.getInt1_4Bytes();
- if (type.dimensions().size() != dimensionCount)
- throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount +
- " dimensions but type is " + type);
-
- for (int i = 0; i < dimensionCount; ++i) {
- TensorType.Dimension expectedDimension = type.dimensions().get(i);
- String encodedName = buffer.getUtf8String();
- if ( ! expectedDimension.name().equals(encodedName))
- throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName +
- "' as dimension " + i + " but type is " + type);
- }
- }
-
- private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
+ private static void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
int numCells = buffer.getInt1_4Bytes();
for (int i = 0; i < numCells; ++i) {
Tensor.Builder.CellBuilder cellBuilder = builder.cell();
@@ -96,13 +84,20 @@ class SparseBinaryFormat implements BinaryFormat {
}
}
- private void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) {
+ private static void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) {
for (TensorType.Dimension dimension : type.dimensions()) {
- String label = buffer.getUtf8String();
+ String label = decodeString(buffer);
if ( ! label.isEmpty()) {
builder.label(dimension.name(), label);
}
}
}
+ private static String decodeString(GrowableByteBuffer buffer) {
+ int stringLength = buffer.getInt1_4Bytes();
+ byte[] stringBytes = new byte[stringLength];
+ buffer.get(stringBytes);
+ return Utf8.toString(stringBytes);
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 65216aa2fcd..5a45f20b6d8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -3,9 +3,7 @@ package com.yahoo.tensor.serialization;
import com.google.common.annotations.Beta;
import com.yahoo.io.GrowableByteBuffer;
-import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
/**
* Class used by clients for serializing a Tensor object into binary format or
@@ -20,31 +18,25 @@ import com.yahoo.tensor.TensorType;
public class TypedBinaryFormat {
private static final int SPARSE_BINARY_FORMAT_TYPE = 1;
- private static final int DENSE_BINARY_FORMAT_TYPE = 2;
public static byte[] encode(Tensor tensor) {
GrowableByteBuffer buffer = new GrowableByteBuffer();
- if (tensor instanceof IndexedTensor && 1==2) { // TODO: Activate when we have type information everywhere
- buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE);
- new DenseBinaryFormat().encode(buffer, tensor);
- }
- else {
- buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE);
- new SparseBinaryFormat().encode(buffer, tensor);
- }
+ buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE);
+ new SparseBinaryFormat().encode(buffer, tensor);
buffer.flip();
byte[] result = new byte[buffer.remaining()];
buffer.get(result);
return result;
}
- public static Tensor decode(TensorType type, byte[] data) {
+ public static Tensor decode(byte[] data) {
GrowableByteBuffer buffer = GrowableByteBuffer.wrap(data);
int formatType = buffer.getInt1_4Bytes();
switch (formatType) {
- case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer);
- case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat().decode(type, buffer);
- default: throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
+ case SPARSE_BINARY_FORMAT_TYPE:
+ return new SparseBinaryFormat().decode(buffer);
+ default:
+ throw new IllegalArgumentException("Binary format type " + formatType + " is not a known format");
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
index e150b1cf24f..3f7f02c6c00 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
@@ -1,6 +1,5 @@
package com.yahoo.tensor;
-import junit.framework.TestCase;
import org.junit.Test;
import java.util.HashMap;
@@ -8,7 +7,6 @@ import java.util.Iterator;
import java.util.Map;
import static junit.framework.TestCase.assertTrue;
-import static junit.framework.TestCase.fail;
import static org.junit.Assert.assertEquals;
/**
@@ -25,12 +23,16 @@ public class IndexedTensorTestCase {
@Test
public void testEmpty() {
Tensor empty = Tensor.Builder.of(TensorType.empty).build();
- assertEquals(1, empty.size());
- assertEquals((double)0.0, (double)empty.valueIterator().next(), 0.00000001);
+ assertTrue(empty instanceof IndexedTensor);
+ assertTrue(empty.isEmpty());
+ assertEquals("{}", empty.toString());
Tensor emptyFromString = Tensor.from(TensorType.empty, "{}");
+ assertEquals("{}", Tensor.from(TensorType.empty, "{}").toString());
+ assertTrue(emptyFromString.isEmpty());
+ assertTrue(emptyFromString instanceof IndexedTensor);
assertEquals(empty, emptyFromString);
}
-
+
@Test
public void testSingleValue() {
Tensor singleValue = Tensor.Builder.of(TensorType.empty).cell(TensorAddress.empty, 3.5).build();
@@ -43,6 +45,22 @@ public class IndexedTensorTestCase {
}
@Test
+ public void testSingleValueWithDimensions() {
+ TensorType type = new TensorType.Builder().indexed("x").indexed("y").build();
+ Tensor emptyWithDimensions = Tensor.Builder.of(type).build();
+ assertTrue(emptyWithDimensions instanceof IndexedTensor);
+ assertEquals("tensor(x[],y[]):{}", emptyWithDimensions.toString());
+ Tensor emptyWithDimensionsFromString = Tensor.from("tensor(x[],y[]):{}");
+ assertEquals("tensor(x[],y[]):{}", emptyWithDimensionsFromString.toString());
+ assertTrue(emptyWithDimensionsFromString instanceof IndexedTensor);
+ assertEquals(emptyWithDimensions, emptyWithDimensionsFromString);
+
+ IndexedTensor emptyWithDimensionsIndexed = (IndexedTensor)emptyWithDimensions;
+ assertEquals(0, emptyWithDimensionsIndexed.dimensionSizes().size(0));
+ assertEquals(0, emptyWithDimensionsIndexed.dimensionSizes().size(1));
+ }
+
+ @Test
public void testBoundBuilding() {
TensorType type = new TensorType.Builder().indexed("v", vSize)
.indexed("w", wSize)
@@ -73,7 +91,7 @@ public class IndexedTensorTestCase {
for (int z = 0; z < zSize; z++)
builder.cell(value(v, w, x, y, z), v, w, x, y, z);
- IndexedTensor tensor = (IndexedTensor)builder.build();
+ IndexedTensor tensor = builder.build();
// Lookup by index arguments
for (int v = 0; v < vSize; v++)
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
index 5c2c3b9db32..4c32a80dc11 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
@@ -2,7 +2,6 @@
package com.yahoo.tensor;
import com.google.common.collect.Sets;
-import junit.framework.TestCase;
import org.junit.Test;
import java.util.Set;
@@ -19,20 +18,6 @@ import static org.junit.Assert.fail;
public class MappedTensorTestCase {
@Test
- public void testEmpty() {
- TensorType type = new TensorType.Builder().mapped("x").build();
- Tensor empty = Tensor.Builder.of(type).build();
- TestCase.assertTrue(empty instanceof MappedTensor);
- TestCase.assertTrue(empty.isEmpty());
- assertEquals("tensor(x{}):{}", empty.toString());
- Tensor emptyFromString = Tensor.from(type, "{}");
- assertEquals("tensor(x{}):{}", Tensor.from("tensor(x{}):{}").toString());
- TestCase.assertTrue(emptyFromString.isEmpty());
- TestCase.assertTrue(emptyFromString instanceof MappedTensor);
- assertEquals(empty, emptyFromString);
- }
-
- @Test
public void testOneDimensionalBuilding() {
TensorType type = new TensorType.Builder().mapped("x").build();
Tensor tensor = Tensor.Builder.of(type).
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index e2baa1d5ac3..2f060239eb1 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -27,7 +27,6 @@ public class TensorFunctionBenchmark {
modelVectors = modelVectors.stream().map(t -> t.multiply(unitVector("k"))).collect(Collectors.toList());
}
dotProduct(queryVector, modelVectors, Math.max(iterations/10, 10)); // warmup
- System.gc();
long startTime = System.currentTimeMillis();
dotProduct(queryVector, modelVectors, iterations);
long totalTime = System.currentTimeMillis() - startTime;
@@ -107,41 +106,51 @@ public class TensorFunctionBenchmark {
// ---------------- Mapped with extra space (sidesteps current special-case optimizations):
// 410 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(20, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true);
System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time);
// 770 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(20, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true);
System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time);
// ---------------- Mapped:
// 2.6 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(5000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false);
System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time);
// 6.8 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false);
System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time);
// ---------------- Indexed (unbound) with extra space (sidesteps current special-case optimizations):
// 30 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(500, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true);
System.out.printf("Indexed vectors, x space time per join: %1$8.3f ms\n", time);
// 27 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true);
System.out.printf("Indexed matrix, x space time per join: %1$8.3f ms\n", time);
// ---------------- Indexed unbound:
// 0.14 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false);
System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time);
// 0.14 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false);
System.out.printf("Indexed unbound matrix, time per join: %1$8.3f ms\n", time);
// ---------------- Indexed bound:
// 0.14 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false);
System.out.printf("Indexed bound vectors, time per join: %1$8.3f ms\n", time);
// 0.14 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false);
System.out.printf("Indexed bound matrix, time per join: %1$8.3f ms\n", time);
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index b35220cf013..feeba1a7a10 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -21,7 +21,7 @@ import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
- * Tests tensor functionality
+ * Tests Tensor functionality
*
* @author bratseth
*/
@@ -29,8 +29,7 @@ public class TensorTestCase {
@Test
public void testStringForm() {
- assertEquals("{5.7}", Tensor.from("{5.7}").toString());
- assertTrue(Tensor.from("{5.7}") instanceof IndexedTensor);
+ assertEquals("{}", Tensor.from("{}").toString());
assertEquals("{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString());
assertEquals("{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString());
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
deleted file mode 100644
index d2b2044f3ed..00000000000
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ /dev/null
@@ -1,55 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor.serialization;
-
-import com.google.common.collect.Sets;
-import com.yahoo.tensor.Tensor;
-import org.junit.Ignore;
-import org.junit.Test;
-
-import java.util.Arrays;
-import java.util.Set;
-
-import static org.junit.Assert.assertEquals;
-
-/**
- * Tests for the dense binary format.
- *
- * @author bratseth
- */
-public class DenseBinaryFormatTestCase {
-
- @Test
- public void testSerialization() {
- assertSerialization("{-5.37}");
- assertSerialization("tensor(x[]):{{x:0}:2.0}");
- assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0}");
- assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
- assertSerialization("tensor(x[1],y[2],z[3]):{{y:0,x:0,z:0}:2.0}");
- }
-
- @Test
- @Ignore // TODO: Activate when encoding in this format is activated
- public void requireThatSerializationFormatDoNotChange() {
- byte[] encodedTensor = new byte[]{2, // binary format type
- 2, // dimension count
- 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size
- 1, (byte) 'z', 1, // dimension z with size
- 64, 0, 0, 0, 0, 0, 0, 0, // value 1
- 64, 8, 0, 0, 0, 0, 0, 0 // value 2
- };
- assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"))));
- }
-
- private void assertSerialization(String tensorString) {
- assertSerialization(Tensor.from(tensorString));
- }
-
- private void assertSerialization(Tensor tensor) {
- byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
- Tensor decodedTensor = TypedBinaryFormat.decode(tensor.type(), encodedTensor);
- assertEquals(tensor, decodedTensor);
- }
-
-}
-
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index 283aa90cf65..ad908101329 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -13,23 +13,52 @@ import static org.junit.Assert.assertEquals;
/**
* Tests for the sparse binary format.
*
+ * TODO: When new formats are added we should refactor this test to test all formats
+ * with the same set of tensor inputs (if feasible).
+ *
* @author geirst
*/
public class SparseBinaryFormatTestCase {
+ private static void assertSerialization(String tensorString) {
+ assertSerialization(Tensor.from(tensorString));
+ }
+
+ private static void assertSerialization(String tensorString, Set<String> dimensions) {
+ Tensor tensor = Tensor.from(tensorString);
+ assertEquals(dimensions, tensor.type().dimensionNames());
+ assertSerialization(tensor);
+ }
+
+ private static void assertSerialization(Tensor tensor) {
+ byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
+ Tensor decodedTensor = TypedBinaryFormat.decode(encodedTensor);
+ assertEquals(tensor, decodedTensor);
+ }
+
@Test
- public void testSerialization() {
- assertSerialization("tensor(x{}):{}");
- assertSerialization("tensor(x{}):{{x:0}:2.0}");
- assertSerialization("tensor(dimX{},dimY{}):{{dimX:labelA,dimY:labelB}:2.0,{dimY:labelC,dimX:labelD}:3.0}");
- assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0}");
- assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0,{x:1,y:4}:3.0}");
- assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0}");
- assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0,{y:1,x:0,z:6}:3.0}");
+ public void testSerializationOfTensorsWithDenseTensorAddresses() {
+ assertSerialization("{}");
+ assertSerialization("{{x:0}:2.0}");
+ assertSerialization("{{x:0}:2.0,{x:1}:3.0}");
+ assertSerialization("{{x:0,y:0}:2.0}");
+ assertSerialization("{{x:0,y:0}:2.0,{x:0,y:1}:3.0}");
+ assertSerialization("{{y:0,x:0}:2.0}");
+ assertSerialization("{{y:0,x:0}:2.0,{y:1,x:0}:3.0}");
+ assertSerialization("{{dimX:labelA,dimY:labelB}:2.0,{dimY:labelC,dimX:labelD}:3.0}");
}
@Test
- public void requireThatSerializationFormatDoNotChange() {
+ public void testSerializationOfTensorsWithSparseTensorAddresses() {
+ assertSerialization("{{x:0}:2.0, {x:1}:3.0}", Sets.newHashSet("x"));
+ assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0}", Sets.newHashSet("x", "y"));
+ assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0,{x:1,y:4}:3.0}", Sets.newHashSet("x", "y"));
+ assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0}", Sets.newHashSet("x", "y", "z"));
+ assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0,{y:1,x:0,z:6}:3.0}", Sets.newHashSet("x", "y", "z"));
+ }
+
+ @Test
+ public void requireThatCompactSerializationFormatDoNotChange() {
byte[] encodedTensor = new byte[] {1, // binary format type
2, // num dimensions
2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
@@ -37,17 +66,7 @@ public class SparseBinaryFormatTestCase {
2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0
2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1
assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
- }
-
- private void assertSerialization(String tensorString) {
- assertSerialization(Tensor.from(tensorString));
- }
-
- private void assertSerialization(Tensor tensor) {
- byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
- Tensor decodedTensor = TypedBinaryFormat.decode(tensor.type(), encodedTensor);
- assertEquals(tensor, decodedTensor);
+ Arrays.toString(TypedBinaryFormat.encode(Tensor.from("{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
}
}