From 7ef86b1fb25f2268d00fa3af87bc1e594de0b1b3 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Thu, 25 Apr 2019 16:40:00 +0200 Subject: Split values into IndexedDoubleTensor subclass --- vespajlib/abi-spec.json | 9 ++-- .../java/com/yahoo/tensor/IndexedDoubleTensor.java | 46 +++++++++++++++++++ .../main/java/com/yahoo/tensor/IndexedTensor.java | 51 ++++++++-------------- .../src/main/java/com/yahoo/tensor/TensorType.java | 1 + 4 files changed, 68 insertions(+), 39 deletions(-) create mode 100644 vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 43388e4e18d..e4b6162eeca 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -792,10 +792,10 @@ "com.yahoo.tensor.Tensor" ], "attributes": [ - "public" + "public", + "abstract" ], "methods": [ - "public long size()", "public java.util.Iterator cellIterator()", "public com.yahoo.tensor.IndexedTensor$SubspaceIterator cellIterator(com.yahoo.tensor.PartialAddress, com.yahoo.tensor.DimensionSizes)", "public java.util.Iterator valueIterator()", @@ -803,14 +803,13 @@ "public java.util.Iterator subspaceIterator(java.util.Set)", "public varargs double get(long[])", "public double get(com.yahoo.tensor.TensorAddress)", - "public double get(long)", + "public abstract double get(long)", "public com.yahoo.tensor.TensorType type()", - "public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)", + "public abstract com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.DimensionSizes dimensionSizes()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", "public com.yahoo.tensor.Tensor remove(java.util.Set)", - "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", "public bridge synthetic com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)" diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java new file mode 100644 index 00000000000..27cecdab80c --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -0,0 +1,46 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor; + +import java.util.Arrays; + +/** + * An indexed tensor implementation holding values as doubles + * + * @author bratseth + */ +class IndexedDoubleTensor extends IndexedTensor { + + private final double[] values; + + IndexedDoubleTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) { + super(type, dimensionSizes); + this.values = values; + } + + @Override + public long size() { + return values.length; + } + + /** + * Returns the value at the given index by direct lookup. Only use + * if you know the underlying data layout. + * + * @param valueIndex the direct index into the underlying data. + * @throws IndexOutOfBoundsException if index is out of bounds + */ + @Override + public double get(long valueIndex) { return values[(int)valueIndex]; } + + @Override + public IndexedTensor withType(TensorType type) { + if ( ! this.type().isRenamableTo(type)) + throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type + + ": Types are not compatible"); + return new IndexedDoubleTensor(type, dimensionSizes(), values); + } + + @Override + public int hashCode() { return Arrays.hashCode(values); } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 38d832d01c2..5f2c04bbd56 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -20,7 +20,7 @@ import java.util.function.DoubleBinaryOperator; * * @author bratseth */ -public class IndexedTensor implements Tensor { +public abstract class IndexedTensor implements Tensor { /** The prescribed and possibly abstract type this is an instance of */ private final TensorType type; @@ -28,17 +28,9 @@ public class IndexedTensor implements Tensor { /** The sizes of the dimensions of this in the order of the dimensions of the type */ private final DimensionSizes dimensionSizes; - private final double[] values; - - private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) { + IndexedTensor(TensorType type, DimensionSizes dimensionSizes) { this.type = type; this.dimensionSizes = dimensionSizes; - this.values = values; - } - - @Override - public long size() { - return values.length; } /** @@ -96,13 +88,13 @@ public class IndexedTensor implements Tensor { } /** - * Returns the value at the given indexes + * Returns the value at the given indexes as a double * * @param indexes the indexes into the dimensions of this. Must be one number per dimension of this * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given */ public double get(long ... indexes) { - return values[(int)toValueIndex(indexes, dimensionSizes)]; + return get((int)toValueIndex(indexes, dimensionSizes)); } /** Returns the value at this address, or NaN if there is no value at this address */ @@ -110,7 +102,7 @@ public class IndexedTensor implements Tensor { public double get(TensorAddress address) { // optimize for fast lookup within bounds: try { - return values[(int)toValueIndex(address, dimensionSizes)]; + return get((int)toValueIndex(address, dimensionSizes)); } catch (IndexOutOfBoundsException e) { return Double.NaN; @@ -124,7 +116,7 @@ public class IndexedTensor implements Tensor { * @param valueIndex the direct index into the underlying data. * @throws IndexOutOfBoundsException if index is out of bounds */ - public double get(long valueIndex) { return values[(int)valueIndex]; } + public abstract double get(long valueIndex); private static long toValueIndex(long[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed @@ -164,13 +156,7 @@ public class IndexedTensor implements Tensor { public TensorType type() { return type; } @Override - public IndexedTensor withType(TensorType type) { - if (!this.type.isRenamableTo(type)) { - throw new IllegalArgumentException("IndexedTensor.withType: types are not compatible. Current type: '" + - this.type.toString() + "', requested type: '" + type.toString() + "'"); - } - return new IndexedTensor(type, dimensionSizes, values); - } + public abstract IndexedTensor withType(TensorType type); public DimensionSizes dimensionSizes() { return dimensionSizes; @@ -179,13 +165,13 @@ public class IndexedTensor implements Tensor { @Override public Map cells() { if (dimensionSizes.dimensions() == 0) - return Collections.singletonMap(TensorAddress.of(), values[0]); + return Collections.singletonMap(TensorAddress.of(), get(0)); ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); - Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); - for (long i = 0; i < values.length; i++) { + Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, size()); + for (long i = 0; i < size(); i++) { indexes.next(); - builder.put(indexes.toAddress(), values[(int)i]); + builder.put(indexes.toAddress(), get(i)); } return builder.build(); } @@ -200,9 +186,6 @@ public class IndexedTensor implements Tensor { throw new IllegalArgumentException("Remove is not supported for indexed tensors"); } - @Override - public int hashCode() { return Arrays.hashCode(values); } - @Override public String toString() { return Tensor.toStandardString(this); } @@ -302,7 +285,7 @@ public class IndexedTensor implements Tensor { @Override public IndexedTensor build() { - IndexedTensor tensor = new IndexedTensor(type, sizes, values); + IndexedTensor tensor = new IndexedDoubleTensor(type, sizes, values); // TODO // prevent further modification sizes = null; values = null; @@ -348,12 +331,12 @@ public class IndexedTensor implements Tensor { if (firstDimension == null) throw new IllegalArgumentException("Tensor of type " + type() + " has no values"); if (type.dimensions().isEmpty()) // single number - return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); + return new IndexedDoubleTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); DimensionSizes dimensionSizes = findDimensionSizes(firstDimension); double[] values = new double[(int)dimensionSizes.totalSize()]; fillValues(0, 0, firstDimension, dimensionSizes, values); - return new IndexedTensor(type, dimensionSizes, values); + return new IndexedDoubleTensor(type, dimensionSizes, values); } private DimensionSizes findDimensionSizes(List firstDimension) { @@ -460,7 +443,7 @@ public class IndexedTensor implements Tensor { private final class CellIterator implements Iterator { private long count = 0; - private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); + private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, size()); private final LazyCell reusedCell = new LazyCell(indexes, Double.NaN); @Override @@ -485,13 +468,13 @@ public class IndexedTensor implements Tensor { @Override public boolean hasNext() { - return count < values.length; + return count < size(); } @Override public Double next() { try { - return values[(int)count++]; + return get(count++); } catch (IndexOutOfBoundsException e) { throw new NoSuchElementException("No element at position " + count); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index df78f3dfc3a..b1c7a2341c0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -143,6 +143,7 @@ public class TensorType { } private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) { + if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed if (generalization.dimensions().size() != this.dimensions().size()) return false; for (int i = 0; i < generalization.dimensions().size(); i++) { Dimension thisDimension = this.dimensions().get(i); -- cgit v1.2.3 From 3873424bb18acd179441cdd914070c32e41699ee Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 26 Apr 2019 11:26:04 +0200 Subject: Move bound builder double array into double subclass --- .../com/yahoo/language/process/Transformer.java | 2 +- .../java/com/yahoo/tensor/IndexedDoubleTensor.java | 65 +++++++++++++++++- .../java/com/yahoo/tensor/IndexedFloatTensor.java | 44 ++++++++++++ .../main/java/com/yahoo/tensor/IndexedTensor.java | 79 ++++++---------------- 4 files changed, 129 insertions(+), 61 deletions(-) create mode 100644 vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java diff --git a/linguistics/src/main/java/com/yahoo/language/process/Transformer.java b/linguistics/src/main/java/com/yahoo/language/process/Transformer.java index 398ddc0262b..2b84c8ab570 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Transformer.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Transformer.java @@ -6,7 +6,7 @@ import com.yahoo.language.Language; /** * Interface for providers of text transformations such as accent removal. * - * @author Mathias Mølster Lidal + * @authorMathias Mølster Lidal */ public interface Transformer { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index 27cecdab80c..80350d9e5f5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -34,13 +34,72 @@ class IndexedDoubleTensor extends IndexedTensor { @Override public IndexedTensor withType(TensorType type) { - if ( ! this.type().isRenamableTo(type)) - throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type + - ": Types are not compatible"); + throwOnIncompatibleType(type); return new IndexedDoubleTensor(type, dimensionSizes(), values); } @Override public int hashCode() { return Arrays.hashCode(values); } + /** A bound builder can create the double array directly */ + public static class BoundDoubleBuilder extends BoundBuilder { + + private double[] values; + + BoundDoubleBuilder(TensorType type) { + this(type, dimensionSizesOf(type)); + } + + BoundDoubleBuilder(TensorType type, DimensionSizes sizes) { + super(type, sizes); + values = new double[(int)sizes.totalSize()]; + } + + @Override + public IndexedTensor.BoundBuilder cell(double value, long ... indexes) { + values[(int)toValueIndex(indexes, sizes())] = value; + return this; + } + + @Override + public CellBuilder cell() { + return new CellBuilder(type, this); + } + + @Override + public Builder cell(TensorAddress address, double value) { + values[(int)toValueIndex(address, sizes())] = value; + return this; + } + + @Override + public IndexedTensor build() { + IndexedTensor tensor = new IndexedDoubleTensor(type, sizes(), values); + // prevent further modification + values = null; + return tensor; + } + + @Override + public Builder cell(Cell cell, double value) { + long directIndex = cell.getDirectIndex(); + if (directIndex >= 0) // optimization + values[(int)directIndex] = value; + else + super.cell(cell, value); + return this; + } + + /** + * Set a cell value by the index in the internal layout of this cell. + * This requires knowledge of the internal layout of cells in this implementation, and should therefore + * probably not be used (but when it can be used it is fast). + */ + @Override + public void cellByDirectIndex(long index, double value) { + values[(int)index] = value; + } + + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java new file mode 100644 index 00000000000..563d72137e7 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java @@ -0,0 +1,44 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor; + +import java.util.Arrays; + +/** + * An indexed tensor implementation holding values as floats + * + * @author bratseth + */ +class IndexedFloatTensor extends IndexedTensor { + + private final float[] values; + + IndexedFloatTensor(TensorType type, DimensionSizes dimensionSizes, float[] values) { + super(type, dimensionSizes); + this.values = values; + } + + @Override + public long size() { + return values.length; + } + + /** + * Returns the value at the given index by direct lookup. Only use + * if you know the underlying data layout. + * + * @param valueIndex the direct index into the underlying data. + * @throws IndexOutOfBoundsException if index is out of bounds + */ + @Override + public double get(long valueIndex) { return values[(int)valueIndex]; } + + @Override + public IndexedTensor withType(TensorType type) { + throwOnIncompatibleType(type); + return new IndexedFloatTensor(type, dimensionSizes(), values); + } + + @Override + public int hashCode() { return Arrays.hashCode(values); } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 5f2c04bbd56..6e587b05460 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -118,7 +118,7 @@ public abstract class IndexedTensor implements Tensor { */ public abstract double get(long valueIndex); - private static long toValueIndex(long[] indexes, DimensionSizes sizes) { + static long toValueIndex(long[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed if (indexes.length == 0) return 0; // for speed @@ -132,7 +132,7 @@ public abstract class IndexedTensor implements Tensor { return valueIndex; } - private static long toValueIndex(TensorAddress address, DimensionSizes sizes) { + static long toValueIndex(TensorAddress address, DimensionSizes sizes) { if (address.isEmpty()) return 0; long valueIndex = 0; @@ -152,6 +152,12 @@ public abstract class IndexedTensor implements Tensor { return product; } + void throwOnIncompatibleType(TensorType type) { + if ( ! this.type().isRenamableTo(type)) + throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type + + ": Types are not compatible"); + } + @Override public TensorType type() { return type; } @@ -205,7 +211,7 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type) { if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) - return new BoundBuilder(type); + return of(type, BoundBuilder.dimensionSizesOf(type)); else return new UnboundBuilder(type); } @@ -218,8 +224,8 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes) { // validate if (sizes.dimensions() != type.dimensions().size()) - throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " + - "for " + type); + throw new IllegalArgumentException(sizes.dimensions() + + " is the wrong number of dimensions for " + type); for (int i = 0; i < sizes.dimensions(); i++ ) { Optional size = type.dimensions().get(i).size(); if (size.isPresent() && size.get() < sizes.size(i)) @@ -228,7 +234,13 @@ public abstract class IndexedTensor implements Tensor { " but cannot be larger than " + size.get() + " in " + type); } - return new BoundBuilder(type, sizes); + if (type.valueType() == TensorType.Value.FLOAT) + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + // return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); TODO + else if (type.valueType() == TensorType.Value.FLOAT) + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + else + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default } public abstract Builder cell(double value, long ... indexes); @@ -242,14 +254,9 @@ public abstract class IndexedTensor implements Tensor { } /** A bound builder can create the double array directly */ - public static class BoundBuilder extends Builder { + public static abstract class BoundBuilder extends Builder { private DimensionSizes sizes; - private double[] values; - - private BoundBuilder(TensorType type) { - this(type, dimensionSizesOf(type)); - } static DimensionSizes dimensionSizesOf(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); @@ -258,58 +265,16 @@ public abstract class IndexedTensor implements Tensor { return b.build(); } - private BoundBuilder(TensorType type, DimensionSizes sizes) { + BoundBuilder(TensorType type, DimensionSizes sizes) { super(type); if ( sizes.dimensions() != type.dimensions().size()) throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); this.sizes = sizes; - values = new double[(int)sizes.totalSize()]; - } - - @Override - public BoundBuilder cell(double value, long ... indexes) { - values[(int)toValueIndex(indexes, sizes)] = value; - return this; - } - - @Override - public CellBuilder cell() { - return new CellBuilder(type, this); - } - - @Override - public Builder cell(TensorAddress address, double value) { - values[(int)toValueIndex(address, sizes)] = value; - return this; } - @Override - public IndexedTensor build() { - IndexedTensor tensor = new IndexedDoubleTensor(type, sizes, values); // TODO - // prevent further modification - sizes = null; - values = null; - return tensor; - } + DimensionSizes sizes() { return sizes; } - @Override - public Builder cell(Cell cell, double value) { - long directIndex = cell.getDirectIndex(); - if (directIndex >= 0) // optimization - values[(int)directIndex] = value; - else - super.cell(cell, value); - return this; - } - - /** - * Set a cell value by the index in the internal layout of this cell. - * This requires knowledge of the internal layout of cells in this implementation, and should therefore - * probably not be used (but when it can be used it is fast). - */ - public void cellByDirectIndex(long index, double value) { - values[(int)index] = value; - } + public abstract void cellByDirectIndex(long index, double value); } -- cgit v1.2.3 From 94b4b3ad837f9d3f9d43b158c4de8475ff2c2a2d Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 26 Apr 2019 11:34:13 +0200 Subject: Make float builder when appropriate --- .../java/com/yahoo/tensor/IndexedDoubleTensor.java | 4 -- .../java/com/yahoo/tensor/IndexedFloatTensor.java | 57 ++++++++++++++++++++++ .../main/java/com/yahoo/tensor/IndexedTensor.java | 5 +- 3 files changed, 59 insertions(+), 7 deletions(-) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index 80350d9e5f5..5d5c2be4576 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -46,10 +46,6 @@ class IndexedDoubleTensor extends IndexedTensor { private double[] values; - BoundDoubleBuilder(TensorType type) { - this(type, dimensionSizesOf(type)); - } - BoundDoubleBuilder(TensorType type, DimensionSizes sizes) { super(type, sizes); values = new double[(int)sizes.totalSize()]; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java index 563d72137e7..1e2aed1f5b4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java @@ -41,4 +41,61 @@ class IndexedFloatTensor extends IndexedTensor { @Override public int hashCode() { return Arrays.hashCode(values); } + /** A bound builder can create the float array directly */ + public static class BoundFloatBuilder extends BoundBuilder { + + private float[] values; + + BoundFloatBuilder(TensorType type, DimensionSizes sizes) { + super(type, sizes); + values = new float[(int)sizes.totalSize()]; + } + + @Override + public IndexedTensor.BoundBuilder cell(double value, long ... indexes) { + values[(int)toValueIndex(indexes, sizes())] = (float)value; + return this; + } + + @Override + public CellBuilder cell() { + return new CellBuilder(type, this); + } + + @Override + public Builder cell(TensorAddress address, double value) { + values[(int)toValueIndex(address, sizes())] = (float)value; + return this; + } + + @Override + public IndexedTensor build() { + IndexedTensor tensor = new IndexedFloatTensor(type, sizes(), values); + // prevent further modification + values = null; + return tensor; + } + + @Override + public Builder cell(Cell cell, double value) { + long directIndex = cell.getDirectIndex(); + if (directIndex >= 0) // optimization + values[(int)directIndex] = (float)value; + else + super.cell(cell, value); + return this; + } + + /** + * Set a cell value by the index in the internal layout of this cell. + * This requires knowledge of the internal layout of cells in this implementation, and should therefore + * probably not be used (but when it can be used it is fast). + */ + @Override + public void cellByDirectIndex(long index, double value) { + values[(int)index] = (float)value; + } + + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 6e587b05460..8e2223def83 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -235,8 +235,7 @@ public abstract class IndexedTensor implements Tensor { } if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); - // return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); TODO + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); else if (type.valueType() == TensorType.Value.FLOAT) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); else @@ -258,7 +257,7 @@ public abstract class IndexedTensor implements Tensor { private DimensionSizes sizes; - static DimensionSizes dimensionSizesOf(TensorType type) { + private static DimensionSizes dimensionSizesOf(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); for (int i = 0; i < type.dimensions().size(); i++) b.set(i, type.dimensions().get(i).size().get()); -- cgit v1.2.3 From ae5d5e058f1bb2fd197886ac374ce807065fdb77 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 26 Apr 2019 14:24:17 +0200 Subject: Build tensors purely with floats --- .../com/yahoo/language/process/Transformer.java | 2 +- vespajlib/abi-spec.json | 71 +++++++++++++++++++++- .../java/com/yahoo/tensor/IndexedDoubleTensor.java | 25 ++++++-- .../java/com/yahoo/tensor/IndexedFloatTensor.java | 33 +++++++--- .../main/java/com/yahoo/tensor/IndexedTensor.java | 13 ++++ .../main/java/com/yahoo/tensor/MappedTensor.java | 10 +++ .../main/java/com/yahoo/tensor/MixedTensor.java | 15 +++++ .../src/main/java/com/yahoo/tensor/Tensor.java | 21 ++++++- 8 files changed, 169 insertions(+), 21 deletions(-) diff --git a/linguistics/src/main/java/com/yahoo/language/process/Transformer.java b/linguistics/src/main/java/com/yahoo/language/process/Transformer.java index 2b84c8ab570..46f3c060d4e 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Transformer.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Transformer.java @@ -6,7 +6,7 @@ import com.yahoo.language.Language; /** * Interface for providers of text transformations such as accent removal. * - * @authorMathias Mølster Lidal + * @author Mathias Mølster Lidal */ public interface Transformer { diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index e4b6162eeca..c7363dbbd86 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -706,27 +706,77 @@ ], "fields": [] }, - "com.yahoo.tensor.IndexedTensor$BoundBuilder": { - "superClass": "com.yahoo.tensor.IndexedTensor$Builder", + "com.yahoo.tensor.IndexedDoubleTensor$BoundDoubleBuilder": { + "superClass": "com.yahoo.tensor.IndexedTensor$BoundBuilder", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public varargs com.yahoo.tensor.IndexedTensor$BoundBuilder cell(float, long[])", + "public varargs com.yahoo.tensor.IndexedTensor$BoundBuilder cell(double, long[])", + "public com.yahoo.tensor.Tensor$Builder$CellBuilder cell()", + "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.TensorAddress, float)", + "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.TensorAddress, double)", + "public com.yahoo.tensor.IndexedTensor build()", + "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.Tensor$Cell, float)", + "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.Tensor$Cell, double)", + "public void cellByDirectIndex(long, float)", + "public void cellByDirectIndex(long, double)", + "public bridge synthetic com.yahoo.tensor.IndexedTensor$Builder cell(float, long[])", + "public bridge synthetic com.yahoo.tensor.IndexedTensor$Builder cell(double, long[])", + "public bridge synthetic com.yahoo.tensor.Tensor build()", + "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, float)", + "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, double)", + "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(float, long[])", + "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(double, long[])", + "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)", + "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)" + ], + "fields": [] + }, + "com.yahoo.tensor.IndexedFloatTensor$BoundFloatBuilder": { + "superClass": "com.yahoo.tensor.IndexedTensor$BoundBuilder", "interfaces": [], "attributes": [ "public" ], "methods": [ "public varargs com.yahoo.tensor.IndexedTensor$BoundBuilder cell(double, long[])", + "public varargs com.yahoo.tensor.IndexedTensor$BoundBuilder cell(float, long[])", "public com.yahoo.tensor.Tensor$Builder$CellBuilder cell()", "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.TensorAddress, double)", + "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.TensorAddress, float)", "public com.yahoo.tensor.IndexedTensor build()", "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.Tensor$Cell, double)", + "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.Tensor$Cell, float)", "public void cellByDirectIndex(long, double)", + "public void cellByDirectIndex(long, float)", + "public bridge synthetic com.yahoo.tensor.IndexedTensor$Builder cell(float, long[])", "public bridge synthetic com.yahoo.tensor.IndexedTensor$Builder cell(double, long[])", "public bridge synthetic com.yahoo.tensor.Tensor build()", + "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, float)", "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, double)", + "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(float, long[])", "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(double, long[])", + "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)", "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)" ], "fields": [] }, + "com.yahoo.tensor.IndexedTensor$BoundBuilder": { + "superClass": "com.yahoo.tensor.IndexedTensor$Builder", + "interfaces": [], + "attributes": [ + "public", + "abstract" + ], + "methods": [ + "public abstract void cellByDirectIndex(long, double)", + "public abstract void cellByDirectIndex(long, float)" + ], + "fields": [] + }, "com.yahoo.tensor.IndexedTensor$Builder": { "superClass": "java.lang.Object", "interfaces": [ @@ -740,9 +790,11 @@ "public static com.yahoo.tensor.IndexedTensor$Builder of(com.yahoo.tensor.TensorType)", "public static com.yahoo.tensor.IndexedTensor$Builder of(com.yahoo.tensor.TensorType, com.yahoo.tensor.DimensionSizes)", "public varargs abstract com.yahoo.tensor.IndexedTensor$Builder cell(double, long[])", + "public varargs abstract com.yahoo.tensor.IndexedTensor$Builder cell(float, long[])", "public com.yahoo.tensor.TensorType type()", "public abstract com.yahoo.tensor.IndexedTensor build()", "public bridge synthetic com.yahoo.tensor.Tensor build()", + "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(float, long[])", "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(double, long[])" ], "fields": [] @@ -828,11 +880,15 @@ "public static com.yahoo.tensor.MappedTensor$Builder of(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.Tensor$Builder$CellBuilder cell()", "public com.yahoo.tensor.TensorType type()", + "public com.yahoo.tensor.MappedTensor$Builder cell(com.yahoo.tensor.TensorAddress, float)", "public com.yahoo.tensor.MappedTensor$Builder cell(com.yahoo.tensor.TensorAddress, double)", + "public varargs com.yahoo.tensor.MappedTensor$Builder cell(float, long[])", "public varargs com.yahoo.tensor.MappedTensor$Builder cell(double, long[])", "public com.yahoo.tensor.MappedTensor build()", "public bridge synthetic com.yahoo.tensor.Tensor build()", + "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(float, long[])", "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(double, long[])", + "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)", "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)" ], "fields": [] @@ -869,6 +925,7 @@ ], "methods": [ "public long denseSubspaceSize()", + "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)", "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)", "public com.yahoo.tensor.Tensor$Builder block(com.yahoo.tensor.TensorAddress, double[])", "public com.yahoo.tensor.MixedTensor build()", @@ -888,6 +945,7 @@ "methods": [ "public static com.yahoo.tensor.MixedTensor$Builder of(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.TensorType type()", + "public varargs com.yahoo.tensor.Tensor$Builder cell(float, long[])", "public varargs com.yahoo.tensor.Tensor$Builder cell(double, long[])", "public com.yahoo.tensor.Tensor$Builder$CellBuilder cell()", "public abstract com.yahoo.tensor.MixedTensor build()", @@ -916,6 +974,7 @@ "public" ], "methods": [ + "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)", "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)", "public com.yahoo.tensor.MixedTensor build()", "public void trackBounds(com.yahoo.tensor.TensorAddress)", @@ -981,7 +1040,8 @@ "methods": [ "public com.yahoo.tensor.Tensor$Builder$CellBuilder label(java.lang.String, java.lang.String)", "public com.yahoo.tensor.Tensor$Builder$CellBuilder label(java.lang.String, long)", - "public com.yahoo.tensor.Tensor$Builder value(double)" + "public com.yahoo.tensor.Tensor$Builder value(double)", + "public com.yahoo.tensor.Tensor$Builder value(float)" ], "fields": [] }, @@ -999,8 +1059,11 @@ "public abstract com.yahoo.tensor.TensorType type()", "public abstract com.yahoo.tensor.Tensor$Builder$CellBuilder cell()", "public abstract com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)", + "public abstract com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)", "public varargs abstract com.yahoo.tensor.Tensor$Builder cell(double, long[])", + "public varargs abstract com.yahoo.tensor.Tensor$Builder cell(float, long[])", "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, double)", + "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, float)", "public abstract com.yahoo.tensor.Tensor build()" ], "fields": [] @@ -1016,6 +1079,8 @@ "methods": [ "public com.yahoo.tensor.TensorAddress getKey()", "public java.lang.Double getValue()", + "public float getFloatValue()", + "public double getDoubleValue()", "public java.lang.Double setValue(java.lang.Double)", "public boolean equals(java.lang.Object)", "public int hashCode()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index 5d5c2be4576..c9e5be31c15 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -51,6 +51,11 @@ class IndexedDoubleTensor extends IndexedTensor { values = new double[(int)sizes.totalSize()]; } + @Override + public IndexedTensor.BoundBuilder cell(float value, long ... indexes) { + return cell((double)value, indexes); + } + @Override public IndexedTensor.BoundBuilder cell(double value, long ... indexes) { values[(int)toValueIndex(indexes, sizes())] = value; @@ -62,6 +67,11 @@ class IndexedDoubleTensor extends IndexedTensor { return new CellBuilder(type, this); } + @Override + public Builder cell(TensorAddress address, float value) { + return cell(address, (double)value); + } + @Override public Builder cell(TensorAddress address, double value) { values[(int)toValueIndex(address, sizes())] = value; @@ -76,6 +86,11 @@ class IndexedDoubleTensor extends IndexedTensor { return tensor; } + @Override + public Builder cell(Cell cell, float value) { + return cell(cell, (double)value); + } + @Override public Builder cell(Cell cell, double value) { long directIndex = cell.getDirectIndex(); @@ -86,11 +101,11 @@ class IndexedDoubleTensor extends IndexedTensor { return this; } - /** - * Set a cell value by the index in the internal layout of this cell. - * This requires knowledge of the internal layout of cells in this implementation, and should therefore - * probably not be used (but when it can be used it is fast). - */ + @Override + public void cellByDirectIndex(long index, float value) { + cellByDirectIndex(index, (double)value); + } + @Override public void cellByDirectIndex(long index, double value) { values[(int)index] = value; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java index 1e2aed1f5b4..4c8af0cbfd6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java @@ -53,7 +53,12 @@ class IndexedFloatTensor extends IndexedTensor { @Override public IndexedTensor.BoundBuilder cell(double value, long ... indexes) { - values[(int)toValueIndex(indexes, sizes())] = (float)value; + return cell((float)value, indexes); + } + + @Override + public IndexedTensor.BoundBuilder cell(float value, long ... indexes) { + values[(int)toValueIndex(indexes, sizes())] = value; return this; } @@ -64,7 +69,12 @@ class IndexedFloatTensor extends IndexedTensor { @Override public Builder cell(TensorAddress address, double value) { - values[(int)toValueIndex(address, sizes())] = (float)value; + return cell(address, (float)value); + } + + @Override + public Builder cell(TensorAddress address, float value) { + values[(int)toValueIndex(address, sizes())] = value; return this; } @@ -78,22 +88,27 @@ class IndexedFloatTensor extends IndexedTensor { @Override public Builder cell(Cell cell, double value) { + return cell(cell, (float)value); + } + + @Override + public Builder cell(Cell cell, float value) { long directIndex = cell.getDirectIndex(); if (directIndex >= 0) // optimization - values[(int)directIndex] = (float)value; + values[(int)directIndex] = value; else super.cell(cell, value); return this; } - /** - * Set a cell value by the index in the internal layout of this cell. - * This requires knowledge of the internal layout of cells in this implementation, and should therefore - * probably not be used (but when it can be used it is fast). - */ @Override public void cellByDirectIndex(long index, double value) { - values[(int)index] = (float)value; + cellByDirectIndex(index, (float)value); + } + + @Override + public void cellByDirectIndex(long index, float value) { + values[(int)index] = value; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 8e2223def83..07375cfa604 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -243,6 +243,7 @@ public abstract class IndexedTensor implements Tensor { } public abstract Builder cell(double value, long ... indexes); + public abstract Builder cell(float value, long ... indexes); @Override public TensorType type() { return type; } @@ -275,6 +276,8 @@ public abstract class IndexedTensor implements Tensor { public abstract void cellByDirectIndex(long index, double value); + public abstract void cellByDirectIndex(long index, float value); + } /** @@ -352,6 +355,11 @@ public abstract class IndexedTensor implements Tensor { return new CellBuilder(type, this); } + @Override + public Builder cell(TensorAddress address, float value) { + return cell(address, (double)value); + } + @Override public Builder cell(TensorAddress address, double value) { long[] indexes = new long[address.size()]; @@ -362,6 +370,11 @@ public abstract class IndexedTensor implements Tensor { return this; } + @Override + public Builder cell(float value, long... indexes) { + return cell((double)value, indexes); + } + /** * Set a value using an index API. The number of indexes must be the same as the dimensions in the type of this. * Values can be written in any order but all values needed to make this dense must be provided diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 22ceed22d3e..693c4b5f2b0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -114,12 +114,22 @@ public class MappedTensor implements Tensor { @Override public TensorType type() { return type; } + @Override + public Builder cell(TensorAddress address, float value) { + return cell(address, (double)value); + } + @Override public Builder cell(TensorAddress address, double value) { cells.put(address, value); return this; } + @Override + public Builder cell(float value, long... labels) { + return cell((double)value, labels); + } + @Override public Builder cell(double value, long... labels) { cells.put(TensorAddress.of(labels), value); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index c06cb2a0986..95f64cec0c1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -192,6 +192,11 @@ public class MixedTensor implements Tensor { return type; } + @Override + public Tensor.Builder cell(float value, long... labels) { + return cell((double)value, labels); + } + @Override public Tensor.Builder cell(double value, long... labels) { throw new UnsupportedOperationException("Not implemented."); @@ -235,6 +240,11 @@ public class MixedTensor implements Tensor { return denseSubspaceMap.get(sparsePartial); } + @Override + public Tensor.Builder cell(TensorAddress address, float value) { + return cell(address, (double)value); + } + @Override public Tensor.Builder cell(TensorAddress address, double value) { TensorAddress sparsePart = index.sparsePartialAddress(address); @@ -292,6 +302,11 @@ public class MixedTensor implements Tensor { dimensionBounds = new long[type.dimensions().size()]; } + @Override + public Tensor.Builder cell(TensorAddress address, float value) { + return cell(address, (double)value); + } + @Override public Tensor.Builder cell(TensorAddress address, double value) { cells.put(address, value); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index eb16801c306..ebb341147cf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -370,9 +370,9 @@ public interface Tensor { class Cell implements Map.Entry { private final TensorAddress address; - private final Double value; + private final Number value; - Cell(TensorAddress address, Double value) { + Cell(TensorAddress address, Number value) { this.address = address; this.value = value; } @@ -387,8 +387,15 @@ public interface Tensor { */ long getDirectIndex() { return -1; } + /** Returns the value as a double */ @Override - public Double getValue() { return value; } + public Double getValue() { return value.doubleValue(); } + + /** Returns the value as a float */ + public float getFloatValue() { return value.floatValue(); } + + /** Returns the value as a double */ + public double getDoubleValue() { return value.doubleValue(); } @Override public Double setValue(Double value) { @@ -446,9 +453,11 @@ public interface Tensor { /** Add a cell */ Builder cell(TensorAddress address, double value); + Builder cell(TensorAddress address, float value); /** Add a cell */ Builder cell(double value, long ... labels); + Builder cell(float value, long ... labels); /** * Add a cell @@ -459,6 +468,9 @@ public interface Tensor { default Builder cell(Cell cell, double value) { return cell(cell.getKey(), value); } + default Builder cell(Cell cell, float value) { + return cell(cell.getKey(), value); + } Tensor build(); @@ -484,6 +496,9 @@ public interface Tensor { public Builder value(double cellValue) { return tensorBuilder.cell(addressBuilder.build(), cellValue); } + public Builder value(float cellValue) { + return tensorBuilder.cell(addressBuilder.build(), cellValue); + } } -- cgit v1.2.3 From f0f7f4962e6339ad2b4fbd293e89df86a6ec7a0a Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 26 Apr 2019 14:48:50 +0200 Subject: Decode directly as float --- .../tensor/serialization/DenseBinaryFormat.java | 28 +++++++++++++++------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 5072484567d..ec8c7de1e72 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -51,15 +51,22 @@ public class DenseBinaryFormat implements BinaryFormat { private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { switch (serializationValueType) { - case DOUBLE: encodeCells(tensor, buffer::putDouble); break; - case FLOAT: encodeCells(tensor, (i) -> buffer.putFloat(i.floatValue())); break; + case DOUBLE: encodeDoubleCells(tensor, buffer); break; + case FLOAT: encodeFloatCells(tensor, buffer); break; } } - private void encodeCells(Tensor tensor, Consumer consumer) { + private void encodeDoubleCells(Tensor tensor, GrowableByteBuffer buffer) { Iterator i = tensor.valueIterator(); while (i.hasNext()) { - consumer.accept(i.next()); + buffer.putDouble(i.next()); + } + } + + private void encodeFloatCells(Tensor tensor, GrowableByteBuffer buffer) { + Iterator i = tensor.valueIterator(); // TODO: floatValueIterator + while (i.hasNext()) { + buffer.putFloat(i.next().floatValue()); } } @@ -106,14 +113,19 @@ public class DenseBinaryFormat implements BinaryFormat { private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { switch (serializationValueType) { - case DOUBLE: decodeCells(sizes, builder, buffer::getDouble); break; - case FLOAT: decodeCells(sizes, builder, () -> (double)buffer.getFloat()); break; + case DOUBLE: decodeDoubleCells(sizes, builder, buffer); break; + case FLOAT: decodeFloatCells(sizes, builder, buffer); break; } } - private void decodeCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, Supplier supplier) { + private void decodeDoubleCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { + for (long i = 0; i < sizes.totalSize(); i++) + builder.cellByDirectIndex(i, buffer.getDouble()); + } + + private void decodeFloatCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, supplier.get()); + builder.cellByDirectIndex(i, buffer.getFloat()); } } -- cgit v1.2.3 From e92b8dd81cfc469d42f858785919964baf8afb0e Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 26 Apr 2019 14:55:04 +0200 Subject: Encode directly as float --- .../java/com/yahoo/tensor/IndexedDoubleTensor.java | 10 +++------- .../java/com/yahoo/tensor/IndexedFloatTensor.java | 12 ++++-------- .../main/java/com/yahoo/tensor/IndexedTensor.java | 11 ++++++++++- .../tensor/serialization/DenseBinaryFormat.java | 20 ++++++++------------ 4 files changed, 25 insertions(+), 28 deletions(-) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index c9e5be31c15..285837a1bc6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -22,16 +22,12 @@ class IndexedDoubleTensor extends IndexedTensor { return values.length; } - /** - * Returns the value at the given index by direct lookup. Only use - * if you know the underlying data layout. - * - * @param valueIndex the direct index into the underlying data. - * @throws IndexOutOfBoundsException if index is out of bounds - */ @Override public double get(long valueIndex) { return values[(int)valueIndex]; } + @Override + public float getFloat(long valueIndex) { return (float)get(valueIndex); } + @Override public IndexedTensor withType(TensorType type) { throwOnIncompatibleType(type); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java index 4c8af0cbfd6..8f8c24c8421 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java @@ -22,15 +22,11 @@ class IndexedFloatTensor extends IndexedTensor { return values.length; } - /** - * Returns the value at the given index by direct lookup. Only use - * if you know the underlying data layout. - * - * @param valueIndex the direct index into the underlying data. - * @throws IndexOutOfBoundsException if index is out of bounds - */ @Override - public double get(long valueIndex) { return values[(int)valueIndex]; } + public double get(long valueIndex) { return getFloat(valueIndex); } + + @Override + public float getFloat(long valueIndex) { return values[(int)valueIndex]; } @Override public IndexedTensor withType(TensorType type) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 07375cfa604..f6af1cf0ed2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -110,7 +110,7 @@ public abstract class IndexedTensor implements Tensor { } /** - * Returns the value at the given index by direct lookup. Only use + * Returns the value at the given index as a double by direct lookup. Only use * if you know the underlying data layout. * * @param valueIndex the direct index into the underlying data. @@ -118,6 +118,15 @@ public abstract class IndexedTensor implements Tensor { */ public abstract double get(long valueIndex); + /** + * Returns the value at the given index as a float by direct lookup. Only use + * if you know the underlying data layout. + * + * @param valueIndex the direct index into the underlying data. + * @throws IndexOutOfBoundsException if index is out of bounds + */ + public abstract float getFloat(long valueIndex); + static long toValueIndex(long[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed if (indexes.length == 0) return 0; // for speed diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index ec8c7de1e72..0cec09157fb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -38,7 +38,7 @@ public class DenseBinaryFormat implements BinaryFormat { if ( ! ( tensor instanceof IndexedTensor)) throw new RuntimeException("The dense format is only supported for indexed tensors"); encodeDimensions(buffer, (IndexedTensor)tensor); - encodeCells(buffer, tensor); + encodeCells(buffer, (IndexedTensor)tensor); } private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) { @@ -49,25 +49,21 @@ public class DenseBinaryFormat implements BinaryFormat { } } - private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { + private void encodeCells(GrowableByteBuffer buffer, IndexedTensor tensor) { switch (serializationValueType) { case DOUBLE: encodeDoubleCells(tensor, buffer); break; case FLOAT: encodeFloatCells(tensor, buffer); break; } } - private void encodeDoubleCells(Tensor tensor, GrowableByteBuffer buffer) { - Iterator i = tensor.valueIterator(); - while (i.hasNext()) { - buffer.putDouble(i.next()); - } + private void encodeDoubleCells(IndexedTensor tensor, GrowableByteBuffer buffer) { + for (int i = 0; i < tensor.size(); i++) + buffer.putDouble(tensor.get(i)); } - private void encodeFloatCells(Tensor tensor, GrowableByteBuffer buffer) { - Iterator i = tensor.valueIterator(); // TODO: floatValueIterator - while (i.hasNext()) { - buffer.putFloat(i.next().floatValue()); - } + private void encodeFloatCells(IndexedTensor tensor, GrowableByteBuffer buffer) { + for (int i = 0; i < tensor.size(); i++) + buffer.putFloat(tensor.getFloat(i)); } @Override -- cgit v1.2.3 From ecd1dba4a8f11e7e5265c98675fc2fa780a4c6e5 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 26 Apr 2019 14:56:40 +0200 Subject: Add float accessor --- vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index f6af1cf0ed2..d43e9ee74a3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -97,6 +97,16 @@ public abstract class IndexedTensor implements Tensor { return get((int)toValueIndex(indexes, dimensionSizes)); } + /** + * Returns the value at the given indexes as a float + * + * @param indexes the indexes into the dimensions of this. Must be one number per dimension of this + * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given + */ + public float getFloat(long ... indexes) { + return getFloat((int)toValueIndex(indexes, dimensionSizes)); + } + /** Returns the value at this address, or NaN if there is no value at this address */ @Override public double get(TensorAddress address) { -- cgit v1.2.3 From 62b7463c18819588789fe2b848f64b01f3b84f90 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 26 Apr 2019 15:16:57 +0200 Subject: Add methods --- vespajlib/abi-spec.json | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index c7363dbbd86..4f81f3baea8 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -854,8 +854,10 @@ "public java.util.Iterator subspaceIterator(java.util.Set, com.yahoo.tensor.DimensionSizes)", "public java.util.Iterator subspaceIterator(java.util.Set)", "public varargs double get(long[])", + "public varargs float getFloat(long[])", "public double get(com.yahoo.tensor.TensorAddress)", "public abstract double get(long)", + "public abstract float getFloat(long)", "public com.yahoo.tensor.TensorType type()", "public abstract com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.DimensionSizes dimensionSizes()", -- cgit v1.2.3 From 97fcbb11ff8e3ab2173fd09d0af071d06e6629e8 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Mon, 29 Apr 2019 10:25:23 +0200 Subject: Add a test and correct condition --- vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 2 +- vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index d43e9ee74a3..19edfc0269e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -255,7 +255,7 @@ public abstract class IndexedTensor implements Tensor { if (type.valueType() == TensorType.Value.FLOAT) return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.FLOAT) + else if (type.valueType() == TensorType.Value.DOUBLE) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); else return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 02d16e6f3e4..b01d171792c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -36,6 +36,17 @@ public class TensorTestCase { assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString()); } + @Test + public void testValueTypes() { + assertEquals(Tensor.from("tensor(x[1]):{{x:0}:5}").getClass(), IndexedDoubleTensor.class); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(5.0, 0).build().getClass(), + IndexedDoubleTensor.class); + + assertEquals(Tensor.from("tensor(x[1]):{{x:0}:5}").getClass(), IndexedFloatTensor.class); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(5.0, 0).build().getClass(), + IndexedFloatTensor.class); + } + @Test public void testParseError() { try { -- cgit v1.2.3