diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-26 14:24:17 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-26 14:24:17 +0200 |
commit | ae5d5e058f1bb2fd197886ac374ce807065fdb77 (patch) | |
tree | 2966fda95d45f68ccf212e9fe8884528b7ce23f6 /vespajlib/src/main/java | |
parent | 94b4b3ad837f9d3f9d43b158c4de8475ff2c2a2d (diff) |
Build tensors purely with floats
Diffstat (limited to 'vespajlib/src/main/java')
6 files changed, 100 insertions, 17 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index 5d5c2be4576..c9e5be31c15 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -52,6 +52,11 @@ class IndexedDoubleTensor extends IndexedTensor { } @Override + public IndexedTensor.BoundBuilder cell(float value, long ... indexes) { + return cell((double)value, indexes); + } + + @Override public IndexedTensor.BoundBuilder cell(double value, long ... indexes) { values[(int)toValueIndex(indexes, sizes())] = value; return this; @@ -63,6 +68,11 @@ class IndexedDoubleTensor extends IndexedTensor { } @Override + public Builder cell(TensorAddress address, float value) { + return cell(address, (double)value); + } + + @Override public Builder cell(TensorAddress address, double value) { values[(int)toValueIndex(address, sizes())] = value; return this; @@ -77,6 +87,11 @@ class IndexedDoubleTensor extends IndexedTensor { } @Override + public Builder cell(Cell cell, float value) { + return cell(cell, (double)value); + } + + @Override public Builder cell(Cell cell, double value) { long directIndex = cell.getDirectIndex(); if (directIndex >= 0) // optimization @@ -86,11 +101,11 @@ class IndexedDoubleTensor extends IndexedTensor { return this; } - /** - * Set a cell value by the index in the internal layout of this cell. - * This requires knowledge of the internal layout of cells in this implementation, and should therefore - * probably not be used (but when it can be used it is fast). - */ + @Override + public void cellByDirectIndex(long index, float value) { + cellByDirectIndex(index, (double)value); + } + @Override public void cellByDirectIndex(long index, double value) { values[(int)index] = value; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java index 1e2aed1f5b4..4c8af0cbfd6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java @@ -53,7 +53,12 @@ class IndexedFloatTensor extends IndexedTensor { @Override public IndexedTensor.BoundBuilder cell(double value, long ... indexes) { - values[(int)toValueIndex(indexes, sizes())] = (float)value; + return cell((float)value, indexes); + } + + @Override + public IndexedTensor.BoundBuilder cell(float value, long ... indexes) { + values[(int)toValueIndex(indexes, sizes())] = value; return this; } @@ -64,7 +69,12 @@ class IndexedFloatTensor extends IndexedTensor { @Override public Builder cell(TensorAddress address, double value) { - values[(int)toValueIndex(address, sizes())] = (float)value; + return cell(address, (float)value); + } + + @Override + public Builder cell(TensorAddress address, float value) { + values[(int)toValueIndex(address, sizes())] = value; return this; } @@ -78,22 +88,27 @@ class IndexedFloatTensor extends IndexedTensor { @Override public Builder cell(Cell cell, double value) { + return cell(cell, (float)value); + } + + @Override + public Builder cell(Cell cell, float value) { long directIndex = cell.getDirectIndex(); if (directIndex >= 0) // optimization - values[(int)directIndex] = (float)value; + values[(int)directIndex] = value; else super.cell(cell, value); return this; } - /** - * Set a cell value by the index in the internal layout of this cell. - * This requires knowledge of the internal layout of cells in this implementation, and should therefore - * probably not be used (but when it can be used it is fast). - */ @Override public void cellByDirectIndex(long index, double value) { - values[(int)index] = (float)value; + cellByDirectIndex(index, (float)value); + } + + @Override + public void cellByDirectIndex(long index, float value) { + values[(int)index] = value; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 8e2223def83..07375cfa604 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -243,6 +243,7 @@ public abstract class IndexedTensor implements Tensor { } public abstract Builder cell(double value, long ... indexes); + public abstract Builder cell(float value, long ... indexes); @Override public TensorType type() { return type; } @@ -275,6 +276,8 @@ public abstract class IndexedTensor implements Tensor { public abstract void cellByDirectIndex(long index, double value); + public abstract void cellByDirectIndex(long index, float value); + } /** @@ -353,6 +356,11 @@ public abstract class IndexedTensor implements Tensor { } @Override + public Builder cell(TensorAddress address, float value) { + return cell(address, (double)value); + } + + @Override public Builder cell(TensorAddress address, double value) { long[] indexes = new long[address.size()]; for (int i = 0; i < address.size(); i++) { @@ -362,6 +370,11 @@ public abstract class IndexedTensor implements Tensor { return this; } + @Override + public Builder cell(float value, long... indexes) { + return cell((double)value, indexes); + } + /** * Set a value using an index API. The number of indexes must be the same as the dimensions in the type of this. * Values can be written in any order but all values needed to make this dense must be provided diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 22ceed22d3e..693c4b5f2b0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -115,12 +115,22 @@ public class MappedTensor implements Tensor { public TensorType type() { return type; } @Override + public Builder cell(TensorAddress address, float value) { + return cell(address, (double)value); + } + + @Override public Builder cell(TensorAddress address, double value) { cells.put(address, value); return this; } @Override + public Builder cell(float value, long... labels) { + return cell((double)value, labels); + } + + @Override public Builder cell(double value, long... labels) { cells.put(TensorAddress.of(labels), value); return this; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index c06cb2a0986..95f64cec0c1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -193,6 +193,11 @@ public class MixedTensor implements Tensor { } @Override + public Tensor.Builder cell(float value, long... labels) { + return cell((double)value, labels); + } + + @Override public Tensor.Builder cell(double value, long... labels) { throw new UnsupportedOperationException("Not implemented."); } @@ -236,6 +241,11 @@ public class MixedTensor implements Tensor { } @Override + public Tensor.Builder cell(TensorAddress address, float value) { + return cell(address, (double)value); + } + + @Override public Tensor.Builder cell(TensorAddress address, double value) { TensorAddress sparsePart = index.sparsePartialAddress(address); long denseOffset = index.denseOffset(address); @@ -293,6 +303,11 @@ public class MixedTensor implements Tensor { } @Override + public Tensor.Builder cell(TensorAddress address, float value) { + return cell(address, (double)value); + } + + @Override public Tensor.Builder cell(TensorAddress address, double value) { cells.put(address, value); trackBounds(address); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index eb16801c306..ebb341147cf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -370,9 +370,9 @@ public interface Tensor { class Cell implements Map.Entry<TensorAddress, Double> { private final TensorAddress address; - private final Double value; + private final Number value; - Cell(TensorAddress address, Double value) { + Cell(TensorAddress address, Number value) { this.address = address; this.value = value; } @@ -387,8 +387,15 @@ public interface Tensor { */ long getDirectIndex() { return -1; } + /** Returns the value as a double */ @Override - public Double getValue() { return value; } + public Double getValue() { return value.doubleValue(); } + + /** Returns the value as a float */ + public float getFloatValue() { return value.floatValue(); } + + /** Returns the value as a double */ + public double getDoubleValue() { return value.doubleValue(); } @Override public Double setValue(Double value) { @@ -446,9 +453,11 @@ public interface Tensor { /** Add a cell */ Builder cell(TensorAddress address, double value); + Builder cell(TensorAddress address, float value); /** Add a cell */ Builder cell(double value, long ... labels); + Builder cell(float value, long ... labels); /** * Add a cell @@ -459,6 +468,9 @@ public interface Tensor { default Builder cell(Cell cell, double value) { return cell(cell.getKey(), value); } + default Builder cell(Cell cell, float value) { + return cell(cell.getKey(), value); + } Tensor build(); @@ -484,6 +496,9 @@ public interface Tensor { public Builder value(double cellValue) { return tensorBuilder.cell(addressBuilder.build(), cellValue); } + public Builder value(float cellValue) { + return tensorBuilder.cell(addressBuilder.build(), cellValue); + } } |