summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-30 14:08:03 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-30 14:08:06 +0000
commit400a428fd3ae71684988e93953eb6c89462d057e (patch)
treeacf2e09b984428532dc64a6affe76d4403a951a7 /vespajlib
parent8e2478b8965bbd29709957e2c4fc37e8333a59e5 (diff)
add api for detecting cell existence
* new API "has(TensorAddress)" detects if a Tensor has a cell with the given address. * use new API in join and merge. This will give different results for cells that are present but contain NaN versus cells that aren't present at all. * use new API in slice. This gives a different default (0, not NaN) when trying to access cells that aren't present.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java10
8 files changed, 54 insertions, 13 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index ebca0a4d852..ccdd09e4cab 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -891,6 +891,7 @@
"public varargs double get(long[])",
"public varargs float getFloat(long[])",
"public double get(com.yahoo.tensor.TensorAddress)",
+ "public boolean has(com.yahoo.tensor.TensorAddress)",
"public abstract double get(long)",
"public abstract float getFloat(long)",
"public com.yahoo.tensor.TensorType type()",
@@ -941,6 +942,7 @@
"public com.yahoo.tensor.TensorType type()",
"public long size()",
"public double get(com.yahoo.tensor.TensorAddress)",
+ "public boolean has(com.yahoo.tensor.TensorAddress)",
"public java.util.Iterator cellIterator()",
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
@@ -1031,6 +1033,7 @@
"public com.yahoo.tensor.TensorType type()",
"public long size()",
"public double get(com.yahoo.tensor.TensorAddress)",
+ "public boolean has(com.yahoo.tensor.TensorAddress)",
"public java.util.Iterator cellIterator()",
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
@@ -1151,6 +1154,7 @@
"public boolean isEmpty()",
"public abstract long size()",
"public abstract double get(com.yahoo.tensor.TensorAddress)",
+ "public abstract boolean has(com.yahoo.tensor.TensorAddress)",
"public abstract java.util.Iterator cellIterator()",
"public abstract java.util.Iterator valueIterator()",
"public abstract java.util.Map cells()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index dfc26cf0282..67eace78c45 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -118,6 +118,17 @@ public abstract class IndexedTensor implements Tensor {
}
}
+ @Override
+ public boolean has(TensorAddress address) {
+ try {
+ long index = toValueIndex(address, dimensionSizes, type);
+ if (index < 0) return false;
+ return (index < size());
+ } catch (IllegalArgumentException e) {
+ return false;
+ }
+ }
+
/**
* Returns the value at the given <i>standard value order</i> index as a double.
*
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index 33f904efd42..9d04e10bacb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -35,6 +35,9 @@ public class MappedTensor implements Tensor {
public double get(TensorAddress address) { return cells.getOrDefault(address, Double.NaN); }
@Override
+ public boolean has(TensorAddress address) { return cells.containsKey(address); }
+
+ @Override
public Iterator<Cell> cellIterator() { return new CellIteratorAdaptor(cells.entrySet().iterator()); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 7631a2e4eab..e686a42d530 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -61,6 +61,17 @@ public class MixedTensor implements Tensor {
return cell.getValue();
}
+ @Override
+ public boolean has(TensorAddress address) {
+ long cellIndex = index.indexOf(address);
+ if (cellIndex < 0 || cellIndex >= cells.size())
+ return false;
+ Cell cell = cells.get((int)cellIndex);
+ if ( ! address.equals(cell.getKey()))
+ return false;
+ return true;
+ }
+
/**
* Returns an iterator over the cells of this tensor.
* Cells are returned in order of increasing indexes in the
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 3378520dc91..8d014edd68f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -70,6 +70,9 @@ public interface Tensor {
/** Returns the value of a cell, or NaN if this cell does not exist/have no value */
double get(TensorAddress address);
+ /** Returns true if this cell exists */
+ boolean has(TensorAddress address);
+
/**
* Returns the cell of this in some undefined order.
* A cell instances is only valid until next() is called.
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index d43b7889982..0cbcfbb7ad6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -126,9 +126,10 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> aCell = i.next();
- double bCellValue = b.get(aCell.getKey());
- if (Double.isNaN(bCellValue)) continue; // no match
- builder.cell(aCell.getKey(), combinator.applyAsDouble(aCell.getValue(), bCellValue));
+ var key = aCell.getKey();
+ if (b.has(key)) {
+ builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key)));
+ }
}
return builder.build();
}
@@ -203,11 +204,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> supercell = i.next();
TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes);
- double subspaceValue = subspace.get(subaddress);
- if ( ! Double.isNaN(subspaceValue))
+ if (subspace.has(subaddress)) {
+ double subspaceValue = subspace.get(subaddress);
builder.cell(supercell.getKey(),
reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
- : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
+ : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
+ }
}
return builder.build();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
index 4aa09f3f4e3..a2387affa67 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
@@ -125,11 +125,12 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
private static void addCellsOf(Tensor a, Tensor b, Tensor.Builder builder, DoubleBinaryOperator combinator) {
for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> aCell = i.next();
- double bCellValue = b.get(aCell.getKey());
- if (Double.isNaN(bCellValue))
- builder.cell(aCell.getKey(), aCell.getValue());
- else if (combinator != null)
- builder.cell(aCell.getKey(), combinator.applyAsDouble(aCell.getValue(), bCellValue));
+ var key = aCell.getKey();
+ if (! b.has(key)) {
+ builder.cell(key, aCell.getValue());
+ } else if (combinator != null) {
+ builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key)));
+ }
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index 607c9a0ab44..da24aef50bc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -63,8 +63,14 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
TensorType resultType = resultType(tensor.type());
PartialAddress subspaceAddress = subspaceToAddress(tensor.type(), context);
- if (resultType.rank() == 0) // shortcut common case
- return Tensor.from(tensor.get(subspaceAddress.asAddress(tensor.type())));
+ if (resultType.rank() == 0) { // shortcut common case
+ var key = subspaceAddress.asAddress(tensor.type());
+ if (tensor.has(key)) {
+ return Tensor.from(tensor.get(key));
+ } else {
+ return Tensor.from(0.0);
+ }
+ }
Tensor.Builder b = Tensor.Builder.of(resultType);
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {