summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-26 14:24:17 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-26 14:24:17 +0200
commitae5d5e058f1bb2fd197886ac374ce807065fdb77 (patch)
tree2966fda95d45f68ccf212e9fe8884528b7ce23f6 /vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
parent94b4b3ad837f9d3f9d43b158c4de8475ff2c2a2d (diff)
Build tensors purely with floats
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/Tensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java21
1 files changed, 18 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index eb16801c306..ebb341147cf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -370,9 +370,9 @@ public interface Tensor {
class Cell implements Map.Entry<TensorAddress, Double> {
private final TensorAddress address;
- private final Double value;
+ private final Number value;
- Cell(TensorAddress address, Double value) {
+ Cell(TensorAddress address, Number value) {
this.address = address;
this.value = value;
}
@@ -387,8 +387,15 @@ public interface Tensor {
*/
long getDirectIndex() { return -1; }
+ /** Returns the value as a double */
@Override
- public Double getValue() { return value; }
+ public Double getValue() { return value.doubleValue(); }
+
+ /** Returns the value as a float */
+ public float getFloatValue() { return value.floatValue(); }
+
+ /** Returns the value as a double */
+ public double getDoubleValue() { return value.doubleValue(); }
@Override
public Double setValue(Double value) {
@@ -446,9 +453,11 @@ public interface Tensor {
/** Add a cell */
Builder cell(TensorAddress address, double value);
+ Builder cell(TensorAddress address, float value);
/** Add a cell */
Builder cell(double value, long ... labels);
+ Builder cell(float value, long ... labels);
/**
* Add a cell
@@ -459,6 +468,9 @@ public interface Tensor {
default Builder cell(Cell cell, double value) {
return cell(cell.getKey(), value);
}
+ default Builder cell(Cell cell, float value) {
+ return cell(cell.getKey(), value);
+ }
Tensor build();
@@ -484,6 +496,9 @@ public interface Tensor {
public Builder value(double cellValue) {
return tensorBuilder.cell(addressBuilder.build(), cellValue);
}
+ public Builder value(float cellValue) {
+ return tensorBuilder.cell(addressBuilder.build(), cellValue);
+ }
}