diff options
Diffstat (limited to 'vespajlib')
16 files changed, 147 insertions, 293 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java b/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java index eba749bd14e..c33882052b4 100644 --- a/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java +++ b/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java @@ -1,8 +1,6 @@ // Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.io; -import com.yahoo.text.Utf8; - import java.nio.*; /** @@ -22,22 +20,21 @@ import java.nio.*; * No methods except getByteBuffer() expose the encapsulated * ByteBuffer, which is intentional. * - * @author Einar M R Rosenvinge + * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> */ public class GrowableByteBuffer implements Comparable<GrowableByteBuffer> { - public static final int DEFAULT_BASE_SIZE = 64*1024; public static final float DEFAULT_GROW_FACTOR = 2.0f; private ByteBuffer buffer; private float growFactor; private int mark = -1; - // NOTE: It might have been better to subclass HeapByteBuffer, - // but that class is package-private. Subclassing ByteBuffer would involve - // implementing a lot of abstract methods, which would mean reinventing - // some (too many) wheels. + //NOTE: It might have been better to subclass HeapByteBuffer, + //but that class is package-private. Subclassing ByteBuffer would involve + //implementing a lot of abstract methods, which would mean reinventing + //some (too many) wheels. - // CONSTRUCTORS: + //CONSTRUCTORS: public GrowableByteBuffer() { this(DEFAULT_BASE_SIZE, DEFAULT_GROW_FACTOR); @@ -64,7 +61,7 @@ public class GrowableByteBuffer implements Comparable<GrowableByteBuffer> { } - // ACCESSORS: + //ACCESSORS: public float getGrowFactor() { return growFactor; @@ -367,21 +364,6 @@ public class GrowableByteBuffer implements Comparable<GrowableByteBuffer> { } } - /** Writes this string to the buffer as a 1_4 encoded length in bytes followed by the utf8 bytes */ - public void putUtf8String(String value) { - byte[] stringBytes = Utf8.toBytes(value); - putInt1_4Bytes(stringBytes.length); - put(stringBytes); - } - - /** Reads a string from the buffer as a 1_4 encoded length in bytes followed by the utf8 bytes */ - public String getUtf8String() { - int stringLength = getInt1_4Bytes(); - byte[] stringBytes = new byte[stringLength]; - get(stringBytes); - return Utf8.toString(stringBytes); - } - /** * Computes the size used for storing the given integer using 1 or 4 bytes. * diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index 7570a357452..daa85cc51e4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -29,14 +29,6 @@ public final class DimensionSizes { /** Returns the number of dimensions this provides the size of */ public int dimensions() { return sizes.length; } - /** Returns the product of the sizes of this */ - public int totalSize() { - int productSize = 1; - for (int dimensionSize : sizes ) - productSize *= dimensionSize; - return productSize; - } - @Override public boolean equals(Object o) { if (o == this) return true; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index bee93ddb4e0..9315922f57a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -103,6 +103,7 @@ public class IndexedTensor implements Tensor { * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given */ public double get(int ... indexes) { + if (values.length == 0) return Double.NaN; return values[toValueIndex(indexes, dimensionSizes)]; } @@ -156,7 +157,7 @@ public class IndexedTensor implements Tensor { @Override public Map<TensorAddress, Double> cells() { if (dimensionSizes.dimensions() == 0) - return Collections.singletonMap(TensorAddress.empty, values[0]); + return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]); ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>(); Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); @@ -216,6 +217,13 @@ public class IndexedTensor implements Tensor { public abstract Builder cell(double value, int ... indexes); + protected double[] arrayFor(DimensionSizes sizes) { + int productSize = 1; + for (int i = 0; i < sizes.dimensions(); i++ ) + productSize *= sizes.size(i); + return new double[productSize]; + } + @Override public TensorType type() { return type; } @@ -225,7 +233,7 @@ public class IndexedTensor implements Tensor { } /** A bound builder can create the double array directly */ - public static class BoundBuilder extends Builder { + private static class BoundBuilder extends Builder { private DimensionSizes sizes; private double[] values; @@ -234,7 +242,7 @@ public class IndexedTensor implements Tensor { this(type, dimensionSizesOf(type)); } - static DimensionSizes dimensionSizesOf(TensorType type) { + public static DimensionSizes dimensionSizesOf(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); for (int i = 0; i < type.dimensions().size(); i++) b.set(i, type.dimensions().get(i).size().get()); @@ -246,7 +254,8 @@ public class IndexedTensor implements Tensor { if ( sizes.dimensions() != type.dimensions().size()) throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); this.sizes = sizes; - values = new double[sizes.totalSize()]; + values = arrayFor(sizes); + Arrays.fill(values, Double.NaN); } @Override @@ -268,6 +277,10 @@ public class IndexedTensor implements Tensor { @Override public IndexedTensor build() { + // Note that we do not check for no NaN's here for performance reasons. + // NaN's don't get lost so leaving them in place should be quite benign + if (values.length == 1 && Double.isNaN(values[0])) + values = new double[0]; IndexedTensor tensor = new IndexedTensor(type, sizes, values); // prevent further modification sizes = null; @@ -277,6 +290,9 @@ public class IndexedTensor implements Tensor { @Override public Builder cell(Cell cell, double value) { + // TODO: Use internal index if applicable + // values[internalIndex] = value; + // return this; int directIndex = cell.getDirectIndex(); if (directIndex >= 0) // optimization values[directIndex] = value; @@ -285,15 +301,6 @@ public class IndexedTensor implements Tensor { return this; } - /** - * Set a cell value by the index in the internal layout of this cell. - * This requires knowledge of the internal layout of cells in this implementation, and should therefore - * probably not be used (but when it can be used it is fast). - */ - public void cellByDirectIndex(int index, double value) { - values[index] = value; - } - } /** @@ -311,13 +318,13 @@ public class IndexedTensor implements Tensor { @Override public IndexedTensor build() { - if (firstDimension == null) throw new IllegalArgumentException("Tensor of type " + type() + " has no values"); - + if (firstDimension == null) // empty + return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {}); if (type.dimensions().isEmpty()) // single number return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); DimensionSizes dimensionSizes = findDimensionSizes(firstDimension); - double[] values = new double[dimensionSizes.totalSize()]; + double[] values = arrayFor(dimensionSizes); fillValues(0, 0, firstDimension, dimensionSizes, values); return new IndexedTensor(type, dimensionSizes, values); } @@ -326,10 +333,8 @@ public class IndexedTensor implements Tensor { List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size()); findDimensionSizes(0, dimensionSizeList, firstDimension); DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct - for (int i = 0; i < b.dimensions(); i++) { - if (i < dimensionSizeList.size()) - b.set(i, dimensionSizeList.get(i)); - } + for (int i = 0; i < b.dimensions(); i++) + b.set(i, dimensionSizeList.get(i)); return b.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 29c508ce12f..51d40a89f3b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -213,9 +213,10 @@ public interface Tensor { static String contentToString(Tensor tensor) { List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet()); - if (tensor.type().dimensions().isEmpty()) { + if (tensor.type().dimensions().isEmpty()) { // TODO: Decide on one way to represent degeneration to number if (cellEntries.isEmpty()) return "{}"; - return "{" + cellEntries.get(0).getValue() +"}"; + double value = cellEntries.get(0).getValue(); + return value == 0.0 ? "{}" : "{" + value +"}"; } Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index fbc469c1829..82f36972a47 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -53,6 +53,9 @@ public class TensorType { return TensorTypeParser.fromSpec(specString); } + /** Returns true if all dimensions of this are indexed */ + public boolean isIndexed() { return dimensions().stream().allMatch(d -> d.isIndexed()); } + /** Returns an immutable list of the dimensions of this */ public List<Dimension> dimensions() { return dimensions; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index f295e129a0f..ceade39ce42 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -113,7 +113,7 @@ public class Join extends PrimitiveTensorFunction { /** Join a tensor into a superspace */ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { - if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) + if (subspace.type().isIndexed() && superspace.type().isIndexed()) return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder); else return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java index 9b0ccdcb6c8..f3adf63739a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java @@ -4,7 +4,6 @@ package com.yahoo.tensor.serialization; import com.google.common.annotations.Beta; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; /** * Representation of a specific binary format with functions for serializing a Tensor object into @@ -22,10 +21,7 @@ interface BinaryFormat { /** * Deserialize the given binary data into a Tensor object. - * - * @param type the expected abstract type of the tensor to serialize - * @param buffer the buffer containing the tensor binary data */ - Tensor decode(TensorType type, GrowableByteBuffer buffer); + Tensor decode(GrowableByteBuffer buffer); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java deleted file mode 100644 index 0a97576d5b7..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ /dev/null @@ -1,87 +0,0 @@ -package com.yahoo.tensor.serialization; - -import com.google.common.annotations.Beta; -import com.yahoo.io.GrowableByteBuffer; -import com.yahoo.tensor.DimensionSizes; -import com.yahoo.tensor.IndexedTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.text.Utf8; - -import java.util.Iterator; - -/** - * Implementation of a dense binary format for a tensor on the form: - * - * Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]* - * Cell_values = [double, double, double, ...]* - * where values are encoded in order of increasing indexes in each dimension, increasing - * indexes of later dimensions in the dimension type before earlier. - * - * @author bratseth - */ -@Beta -public class DenseBinaryFormat implements BinaryFormat { - - @Override - public void encode(GrowableByteBuffer buffer, Tensor tensor) { - if ( ! ( tensor instanceof IndexedTensor)) - throw new RuntimeException("The dense format is only supported for indexed tensors"); - encodeDimensions(buffer, (IndexedTensor)tensor); - encodeCells(buffer, tensor); - } - - private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) { - buffer.putInt1_4Bytes(tensor.type().dimensions().size()); - for (int i = 0; i < tensor.type().dimensions().size(); i++) { - buffer.putUtf8String(tensor.type().dimensions().get(i).name()); - buffer.putInt1_4Bytes(tensor.dimensionSizes().size(i)); - } - } - - private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { - Iterator<Double> i = tensor.valueIterator(); - while (i.hasNext()) - buffer.putDouble(i.next()); - } - - @Override - public Tensor decode(TensorType type, GrowableByteBuffer buffer) { - DimensionSizes sizes = decodeDimensionSizes(type, buffer); - Tensor.Builder builder = Tensor.Builder.of(type, sizes); - decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder); - return builder.build(); - } - - private DimensionSizes decodeDimensionSizes(TensorType type, GrowableByteBuffer buffer) { - int dimensionCount = buffer.getInt1_4Bytes(); - if (type.dimensions().size() != dimensionCount) - throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount + - " dimensions but type is " + type); - - DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount); - for (int i = 0; i < dimensionCount; i++) { - TensorType.Dimension expectedDimension = type.dimensions().get(i); - - String encodedName = buffer.getUtf8String(); - int encodedSize = buffer.getInt1_4Bytes(); - - if ( ! expectedDimension.name().equals(encodedName)) - throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName + - "' as dimension " + i + " but type is " + type); - - if (expectedDimension.size().isPresent() && expectedDimension.size().get() < encodedSize) - throw new IllegalArgumentException("Type/instance mismatch: Instance has size " + encodedSize + - " in " + expectedDimension + " in type " + type); - - builder.set(i, encodedSize); - } - return builder.build(); - } - - private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { - for (int i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.getDouble()); - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 30b36e83457..27a009b5e7e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -31,14 +31,14 @@ class SparseBinaryFormat implements BinaryFormat { encodeCells(buffer, tensor); } - private void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) { + private static void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) { buffer.putInt1_4Bytes(sortedDimensions.size()); for (TensorType.Dimension dimension : sortedDimensions) { - buffer.putUtf8String(dimension.name()); + encodeString(buffer, dimension.name()); } } - private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { + private static void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { buffer.putInt1_4Bytes(tensor.size()); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); @@ -47,47 +47,35 @@ class SparseBinaryFormat implements BinaryFormat { } } - private void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) { + private static void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) { for (int i = 0; i < address.size(); i++) - buffer.putUtf8String(address.label(i)); + encodeString(buffer, address.label(i)); + } + + private static void encodeString(GrowableByteBuffer buffer, String value) { + byte[] stringBytes = Utf8.toBytes(value); + buffer.putInt1_4Bytes(stringBytes.length); + buffer.put(stringBytes); } @Override - public Tensor decode(TensorType type, GrowableByteBuffer buffer) { - if (type == null) // TODO (January 2017): Remove this when types are available - type = decodeDimensionsToType(buffer); - else - consumeAndValidateDimensions(type, buffer); + public Tensor decode(GrowableByteBuffer buffer) { + TensorType type = decodeDimensions(buffer); Tensor.Builder builder = Tensor.Builder.of(type); decodeCells(buffer, builder, type); return builder.build(); } - private TensorType decodeDimensionsToType(GrowableByteBuffer buffer) { + private static TensorType decodeDimensions(GrowableByteBuffer buffer) { TensorType.Builder builder = new TensorType.Builder(); int numDimensions = buffer.getInt1_4Bytes(); for (int i = 0; i < numDimensions; ++i) { - builder.mapped(buffer.getUtf8String()); + builder.mapped(decodeString(buffer)); // TODO: Support indexed } return builder.build(); } - private void consumeAndValidateDimensions(TensorType type, GrowableByteBuffer buffer) { - int dimensionCount = buffer.getInt1_4Bytes(); - if (type.dimensions().size() != dimensionCount) - throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount + - " dimensions but type is " + type); - - for (int i = 0; i < dimensionCount; ++i) { - TensorType.Dimension expectedDimension = type.dimensions().get(i); - String encodedName = buffer.getUtf8String(); - if ( ! expectedDimension.name().equals(encodedName)) - throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName + - "' as dimension " + i + " but type is " + type); - } - } - - private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) { + private static void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) { int numCells = buffer.getInt1_4Bytes(); for (int i = 0; i < numCells; ++i) { Tensor.Builder.CellBuilder cellBuilder = builder.cell(); @@ -96,13 +84,20 @@ class SparseBinaryFormat implements BinaryFormat { } } - private void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) { + private static void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) { for (TensorType.Dimension dimension : type.dimensions()) { - String label = buffer.getUtf8String(); + String label = decodeString(buffer); if ( ! label.isEmpty()) { builder.label(dimension.name(), label); } } } + private static String decodeString(GrowableByteBuffer buffer) { + int stringLength = buffer.getInt1_4Bytes(); + byte[] stringBytes = new byte[stringLength]; + buffer.get(stringBytes); + return Utf8.toString(stringBytes); + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 65216aa2fcd..5a45f20b6d8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -3,9 +3,7 @@ package com.yahoo.tensor.serialization; import com.google.common.annotations.Beta; import com.yahoo.io.GrowableByteBuffer; -import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; /** * Class used by clients for serializing a Tensor object into binary format or @@ -20,31 +18,25 @@ import com.yahoo.tensor.TensorType; public class TypedBinaryFormat { private static final int SPARSE_BINARY_FORMAT_TYPE = 1; - private static final int DENSE_BINARY_FORMAT_TYPE = 2; public static byte[] encode(Tensor tensor) { GrowableByteBuffer buffer = new GrowableByteBuffer(); - if (tensor instanceof IndexedTensor && 1==2) { // TODO: Activate when we have type information everywhere - buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE); - new DenseBinaryFormat().encode(buffer, tensor); - } - else { - buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE); - new SparseBinaryFormat().encode(buffer, tensor); - } + buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE); + new SparseBinaryFormat().encode(buffer, tensor); buffer.flip(); byte[] result = new byte[buffer.remaining()]; buffer.get(result); return result; } - public static Tensor decode(TensorType type, byte[] data) { + public static Tensor decode(byte[] data) { GrowableByteBuffer buffer = GrowableByteBuffer.wrap(data); int formatType = buffer.getInt1_4Bytes(); switch (formatType) { - case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer); - case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat().decode(type, buffer); - default: throw new IllegalArgumentException("Binary format type " + formatType + " is unknown"); + case SPARSE_BINARY_FORMAT_TYPE: + return new SparseBinaryFormat().decode(buffer); + default: + throw new IllegalArgumentException("Binary format type " + formatType + " is not a known format"); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java index e150b1cf24f..3f7f02c6c00 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java @@ -1,6 +1,5 @@ package com.yahoo.tensor; -import junit.framework.TestCase; import org.junit.Test; import java.util.HashMap; @@ -8,7 +7,6 @@ import java.util.Iterator; import java.util.Map; import static junit.framework.TestCase.assertTrue; -import static junit.framework.TestCase.fail; import static org.junit.Assert.assertEquals; /** @@ -25,12 +23,16 @@ public class IndexedTensorTestCase { @Test public void testEmpty() { Tensor empty = Tensor.Builder.of(TensorType.empty).build(); - assertEquals(1, empty.size()); - assertEquals((double)0.0, (double)empty.valueIterator().next(), 0.00000001); + assertTrue(empty instanceof IndexedTensor); + assertTrue(empty.isEmpty()); + assertEquals("{}", empty.toString()); Tensor emptyFromString = Tensor.from(TensorType.empty, "{}"); + assertEquals("{}", Tensor.from(TensorType.empty, "{}").toString()); + assertTrue(emptyFromString.isEmpty()); + assertTrue(emptyFromString instanceof IndexedTensor); assertEquals(empty, emptyFromString); } - + @Test public void testSingleValue() { Tensor singleValue = Tensor.Builder.of(TensorType.empty).cell(TensorAddress.empty, 3.5).build(); @@ -43,6 +45,22 @@ public class IndexedTensorTestCase { } @Test + public void testSingleValueWithDimensions() { + TensorType type = new TensorType.Builder().indexed("x").indexed("y").build(); + Tensor emptyWithDimensions = Tensor.Builder.of(type).build(); + assertTrue(emptyWithDimensions instanceof IndexedTensor); + assertEquals("tensor(x[],y[]):{}", emptyWithDimensions.toString()); + Tensor emptyWithDimensionsFromString = Tensor.from("tensor(x[],y[]):{}"); + assertEquals("tensor(x[],y[]):{}", emptyWithDimensionsFromString.toString()); + assertTrue(emptyWithDimensionsFromString instanceof IndexedTensor); + assertEquals(emptyWithDimensions, emptyWithDimensionsFromString); + + IndexedTensor emptyWithDimensionsIndexed = (IndexedTensor)emptyWithDimensions; + assertEquals(0, emptyWithDimensionsIndexed.dimensionSizes().size(0)); + assertEquals(0, emptyWithDimensionsIndexed.dimensionSizes().size(1)); + } + + @Test public void testBoundBuilding() { TensorType type = new TensorType.Builder().indexed("v", vSize) .indexed("w", wSize) @@ -73,7 +91,7 @@ public class IndexedTensorTestCase { for (int z = 0; z < zSize; z++) builder.cell(value(v, w, x, y, z), v, w, x, y, z); - IndexedTensor tensor = (IndexedTensor)builder.build(); + IndexedTensor tensor = builder.build(); // Lookup by index arguments for (int v = 0; v < vSize; v++) diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java index 5c2c3b9db32..4c32a80dc11 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java @@ -2,7 +2,6 @@ package com.yahoo.tensor; import com.google.common.collect.Sets; -import junit.framework.TestCase; import org.junit.Test; import java.util.Set; @@ -19,20 +18,6 @@ import static org.junit.Assert.fail; public class MappedTensorTestCase { @Test - public void testEmpty() { - TensorType type = new TensorType.Builder().mapped("x").build(); - Tensor empty = Tensor.Builder.of(type).build(); - TestCase.assertTrue(empty instanceof MappedTensor); - TestCase.assertTrue(empty.isEmpty()); - assertEquals("tensor(x{}):{}", empty.toString()); - Tensor emptyFromString = Tensor.from(type, "{}"); - assertEquals("tensor(x{}):{}", Tensor.from("tensor(x{}):{}").toString()); - TestCase.assertTrue(emptyFromString.isEmpty()); - TestCase.assertTrue(emptyFromString instanceof MappedTensor); - assertEquals(empty, emptyFromString); - } - - @Test public void testOneDimensionalBuilding() { TensorType type = new TensorType.Builder().mapped("x").build(); Tensor tensor = Tensor.Builder.of(type). diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index e2baa1d5ac3..2f060239eb1 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -27,7 +27,6 @@ public class TensorFunctionBenchmark { modelVectors = modelVectors.stream().map(t -> t.multiply(unitVector("k"))).collect(Collectors.toList()); } dotProduct(queryVector, modelVectors, Math.max(iterations/10, 10)); // warmup - System.gc(); long startTime = System.currentTimeMillis(); dotProduct(queryVector, modelVectors, iterations); long totalTime = System.currentTimeMillis() - startTime; @@ -107,41 +106,51 @@ public class TensorFunctionBenchmark { // ---------------- Mapped with extra space (sidesteps current special-case optimizations): // 410 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(20, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time); // 770 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(20, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time); // ---------------- Mapped: // 2.6 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(5000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time); // 6.8 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time); // ---------------- Indexed (unbound) with extra space (sidesteps current special-case optimizations): // 30 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(500, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed vectors, x space time per join: %1$8.3f ms\n", time); // 27 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed matrix, x space time per join: %1$8.3f ms\n", time); // ---------------- Indexed unbound: // 0.14 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time); // 0.14 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); System.out.printf("Indexed unbound matrix, time per join: %1$8.3f ms\n", time); // ---------------- Indexed bound: // 0.14 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); System.out.printf("Indexed bound vectors, time per join: %1$8.3f ms\n", time); // 0.14 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); System.out.printf("Indexed bound matrix, time per join: %1$8.3f ms\n", time); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index b35220cf013..feeba1a7a10 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -21,7 +21,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** - * Tests tensor functionality + * Tests Tensor functionality * * @author bratseth */ @@ -29,8 +29,7 @@ public class TensorTestCase { @Test public void testStringForm() { - assertEquals("{5.7}", Tensor.from("{5.7}").toString()); - assertTrue(Tensor.from("{5.7}") instanceof IndexedTensor); + assertEquals("{}", Tensor.from("{}").toString()); assertEquals("{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString()); assertEquals("{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString()); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java deleted file mode 100644 index d2b2044f3ed..00000000000 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor.serialization; - -import com.google.common.collect.Sets; -import com.yahoo.tensor.Tensor; -import org.junit.Ignore; -import org.junit.Test; - -import java.util.Arrays; -import java.util.Set; - -import static org.junit.Assert.assertEquals; - -/** - * Tests for the dense binary format. - * - * @author bratseth - */ -public class DenseBinaryFormatTestCase { - - @Test - public void testSerialization() { - assertSerialization("{-5.37}"); - assertSerialization("tensor(x[]):{{x:0}:2.0}"); - assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0}"); - assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); - assertSerialization("tensor(x[1],y[2],z[3]):{{y:0,x:0,z:0}:2.0}"); - } - - @Test - @Ignore // TODO: Activate when encoding in this format is activated - public void requireThatSerializationFormatDoNotChange() { - byte[] encodedTensor = new byte[]{2, // binary format type - 2, // dimension count - 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size - 1, (byte) 'z', 1, // dimension z with size - 64, 0, 0, 0, 0, 0, 0, 0, // value 1 - 64, 8, 0, 0, 0, 0, 0, 0 // value 2 - }; - assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}")))); - } - - private void assertSerialization(String tensorString) { - assertSerialization(Tensor.from(tensorString)); - } - - private void assertSerialization(Tensor tensor) { - byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(tensor.type(), encodedTensor); - assertEquals(tensor, decodedTensor); - } - -} - diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index 283aa90cf65..ad908101329 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -13,23 +13,52 @@ import static org.junit.Assert.assertEquals; /** * Tests for the sparse binary format. * + * TODO: When new formats are added we should refactor this test to test all formats + * with the same set of tensor inputs (if feasible). + * * @author geirst */ public class SparseBinaryFormatTestCase { + private static void assertSerialization(String tensorString) { + assertSerialization(Tensor.from(tensorString)); + } + + private static void assertSerialization(String tensorString, Set<String> dimensions) { + Tensor tensor = Tensor.from(tensorString); + assertEquals(dimensions, tensor.type().dimensionNames()); + assertSerialization(tensor); + } + + private static void assertSerialization(Tensor tensor) { + byte[] encodedTensor = TypedBinaryFormat.encode(tensor); + Tensor decodedTensor = TypedBinaryFormat.decode(encodedTensor); + assertEquals(tensor, decodedTensor); + } + @Test - public void testSerialization() { - assertSerialization("tensor(x{}):{}"); - assertSerialization("tensor(x{}):{{x:0}:2.0}"); - assertSerialization("tensor(dimX{},dimY{}):{{dimX:labelA,dimY:labelB}:2.0,{dimY:labelC,dimX:labelD}:3.0}"); - assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0}"); - assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0,{x:1,y:4}:3.0}"); - assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0}"); - assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0,{y:1,x:0,z:6}:3.0}"); + public void testSerializationOfTensorsWithDenseTensorAddresses() { + assertSerialization("{}"); + assertSerialization("{{x:0}:2.0}"); + assertSerialization("{{x:0}:2.0,{x:1}:3.0}"); + assertSerialization("{{x:0,y:0}:2.0}"); + assertSerialization("{{x:0,y:0}:2.0,{x:0,y:1}:3.0}"); + assertSerialization("{{y:0,x:0}:2.0}"); + assertSerialization("{{y:0,x:0}:2.0,{y:1,x:0}:3.0}"); + assertSerialization("{{dimX:labelA,dimY:labelB}:2.0,{dimY:labelC,dimX:labelD}:3.0}"); } @Test - public void requireThatSerializationFormatDoNotChange() { + public void testSerializationOfTensorsWithSparseTensorAddresses() { + assertSerialization("{{x:0}:2.0, {x:1}:3.0}", Sets.newHashSet("x")); + assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0}", Sets.newHashSet("x", "y")); + assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0,{x:1,y:4}:3.0}", Sets.newHashSet("x", "y")); + assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0}", Sets.newHashSet("x", "y", "z")); + assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0,{y:1,x:0,z:6}:3.0}", Sets.newHashSet("x", "y", "z")); + } + + @Test + public void requireThatCompactSerializationFormatDoNotChange() { byte[] encodedTensor = new byte[] {1, // binary format type 2, // num dimensions 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions @@ -37,17 +66,7 @@ public class SparseBinaryFormatTestCase { 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1 assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); - } - - private void assertSerialization(String tensorString) { - assertSerialization(Tensor.from(tensorString)); - } - - private void assertSerialization(Tensor tensor) { - byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(tensor.type(), encodedTensor); - assertEquals(tensor, decodedTensor); + Arrays.toString(TypedBinaryFormat.encode(Tensor.from("{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); } } |