diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-25 16:40:00 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-25 16:40:00 +0200 |
commit | 7ef86b1fb25f2268d00fa3af87bc1e594de0b1b3 (patch) | |
tree | 5af4bc2b63e291b7e80d2ffc3ea85b5dfdf2b044 /vespajlib | |
parent | a8949c869c613d671886b87ab684b2dfef9d9ca5 (diff) |
Split values into IndexedDoubleTensor subclass
Diffstat (limited to 'vespajlib')
4 files changed, 68 insertions, 39 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 43388e4e18d..e4b6162eeca 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -792,10 +792,10 @@ "com.yahoo.tensor.Tensor" ], "attributes": [ - "public" + "public", + "abstract" ], "methods": [ - "public long size()", "public java.util.Iterator cellIterator()", "public com.yahoo.tensor.IndexedTensor$SubspaceIterator cellIterator(com.yahoo.tensor.PartialAddress, com.yahoo.tensor.DimensionSizes)", "public java.util.Iterator valueIterator()", @@ -803,14 +803,13 @@ "public java.util.Iterator subspaceIterator(java.util.Set)", "public varargs double get(long[])", "public double get(com.yahoo.tensor.TensorAddress)", - "public double get(long)", + "public abstract double get(long)", "public com.yahoo.tensor.TensorType type()", - "public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)", + "public abstract com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.DimensionSizes dimensionSizes()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", "public com.yahoo.tensor.Tensor remove(java.util.Set)", - "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", "public bridge synthetic com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)" diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java new file mode 100644 index 00000000000..27cecdab80c --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -0,0 +1,46 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor; + +import java.util.Arrays; + +/** + * An indexed tensor implementation holding values as doubles + * + * @author bratseth + */ +class IndexedDoubleTensor extends IndexedTensor { + + private final double[] values; + + IndexedDoubleTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) { + super(type, dimensionSizes); + this.values = values; + } + + @Override + public long size() { + return values.length; + } + + /** + * Returns the value at the given index by direct lookup. Only use + * if you know the underlying data layout. + * + * @param valueIndex the direct index into the underlying data. + * @throws IndexOutOfBoundsException if index is out of bounds + */ + @Override + public double get(long valueIndex) { return values[(int)valueIndex]; } + + @Override + public IndexedTensor withType(TensorType type) { + if ( ! this.type().isRenamableTo(type)) + throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type + + ": Types are not compatible"); + return new IndexedDoubleTensor(type, dimensionSizes(), values); + } + + @Override + public int hashCode() { return Arrays.hashCode(values); } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 38d832d01c2..5f2c04bbd56 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -20,7 +20,7 @@ import java.util.function.DoubleBinaryOperator; * * @author bratseth */ -public class IndexedTensor implements Tensor { +public abstract class IndexedTensor implements Tensor { /** The prescribed and possibly abstract type this is an instance of */ private final TensorType type; @@ -28,17 +28,9 @@ public class IndexedTensor implements Tensor { /** The sizes of the dimensions of this in the order of the dimensions of the type */ private final DimensionSizes dimensionSizes; - private final double[] values; - - private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) { + IndexedTensor(TensorType type, DimensionSizes dimensionSizes) { this.type = type; this.dimensionSizes = dimensionSizes; - this.values = values; - } - - @Override - public long size() { - return values.length; } /** @@ -96,13 +88,13 @@ public class IndexedTensor implements Tensor { } /** - * Returns the value at the given indexes + * Returns the value at the given indexes as a double * * @param indexes the indexes into the dimensions of this. Must be one number per dimension of this * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given */ public double get(long ... indexes) { - return values[(int)toValueIndex(indexes, dimensionSizes)]; + return get((int)toValueIndex(indexes, dimensionSizes)); } /** Returns the value at this address, or NaN if there is no value at this address */ @@ -110,7 +102,7 @@ public class IndexedTensor implements Tensor { public double get(TensorAddress address) { // optimize for fast lookup within bounds: try { - return values[(int)toValueIndex(address, dimensionSizes)]; + return get((int)toValueIndex(address, dimensionSizes)); } catch (IndexOutOfBoundsException e) { return Double.NaN; @@ -124,7 +116,7 @@ public class IndexedTensor implements Tensor { * @param valueIndex the direct index into the underlying data. * @throws IndexOutOfBoundsException if index is out of bounds */ - public double get(long valueIndex) { return values[(int)valueIndex]; } + public abstract double get(long valueIndex); private static long toValueIndex(long[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed @@ -164,13 +156,7 @@ public class IndexedTensor implements Tensor { public TensorType type() { return type; } @Override - public IndexedTensor withType(TensorType type) { - if (!this.type.isRenamableTo(type)) { - throw new IllegalArgumentException("IndexedTensor.withType: types are not compatible. Current type: '" + - this.type.toString() + "', requested type: '" + type.toString() + "'"); - } - return new IndexedTensor(type, dimensionSizes, values); - } + public abstract IndexedTensor withType(TensorType type); public DimensionSizes dimensionSizes() { return dimensionSizes; @@ -179,13 +165,13 @@ public class IndexedTensor implements Tensor { @Override public Map<TensorAddress, Double> cells() { if (dimensionSizes.dimensions() == 0) - return Collections.singletonMap(TensorAddress.of(), values[0]); + return Collections.singletonMap(TensorAddress.of(), get(0)); ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>(); - Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); - for (long i = 0; i < values.length; i++) { + Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, size()); + for (long i = 0; i < size(); i++) { indexes.next(); - builder.put(indexes.toAddress(), values[(int)i]); + builder.put(indexes.toAddress(), get(i)); } return builder.build(); } @@ -201,9 +187,6 @@ public class IndexedTensor implements Tensor { } @Override - public int hashCode() { return Arrays.hashCode(values); } - - @Override public String toString() { return Tensor.toStandardString(this); } @Override @@ -302,7 +285,7 @@ public class IndexedTensor implements Tensor { @Override public IndexedTensor build() { - IndexedTensor tensor = new IndexedTensor(type, sizes, values); + IndexedTensor tensor = new IndexedDoubleTensor(type, sizes, values); // TODO // prevent further modification sizes = null; values = null; @@ -348,12 +331,12 @@ public class IndexedTensor implements Tensor { if (firstDimension == null) throw new IllegalArgumentException("Tensor of type " + type() + " has no values"); if (type.dimensions().isEmpty()) // single number - return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); + return new IndexedDoubleTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); DimensionSizes dimensionSizes = findDimensionSizes(firstDimension); double[] values = new double[(int)dimensionSizes.totalSize()]; fillValues(0, 0, firstDimension, dimensionSizes, values); - return new IndexedTensor(type, dimensionSizes, values); + return new IndexedDoubleTensor(type, dimensionSizes, values); } private DimensionSizes findDimensionSizes(List<Object> firstDimension) { @@ -460,7 +443,7 @@ public class IndexedTensor implements Tensor { private final class CellIterator implements Iterator<Cell> { private long count = 0; - private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); + private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, size()); private final LazyCell reusedCell = new LazyCell(indexes, Double.NaN); @Override @@ -485,13 +468,13 @@ public class IndexedTensor implements Tensor { @Override public boolean hasNext() { - return count < values.length; + return count < size(); } @Override public Double next() { try { - return values[(int)count++]; + return get(count++); } catch (IndexOutOfBoundsException e) { throw new NoSuchElementException("No element at position " + count); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index df78f3dfc3a..b1c7a2341c0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -143,6 +143,7 @@ public class TensorType { } private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) { + if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed if (generalization.dimensions().size() != this.dimensions().size()) return false; for (int i = 0; i < generalization.dimensions().size(); i++) { Dimension thisDimension = this.dimensions().get(i); |