diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-04-30 14:08:03 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-04-30 14:08:06 +0000 |
commit | 400a428fd3ae71684988e93953eb6c89462d057e (patch) | |
tree | acf2e09b984428532dc64a6affe76d4403a951a7 /vespajlib/src | |
parent | 8e2478b8965bbd29709957e2c4fc37e8333a59e5 (diff) |
add api for detecting cell existence
* new API "has(TensorAddress)" detects if a Tensor has a cell with the given address.
* use new API in join and merge. This will give different results for
cells that are present but contain NaN versus cells that aren't present at all.
* use new API in slice. This gives a different default (0, not NaN) when trying to
access cells that aren't present.
Diffstat (limited to 'vespajlib/src')
7 files changed, 50 insertions, 13 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index dfc26cf0282..67eace78c45 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -118,6 +118,17 @@ public abstract class IndexedTensor implements Tensor { } } + @Override + public boolean has(TensorAddress address) { + try { + long index = toValueIndex(address, dimensionSizes, type); + if (index < 0) return false; + return (index < size()); + } catch (IllegalArgumentException e) { + return false; + } + } + /** * Returns the value at the given <i>standard value order</i> index as a double. * diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 33f904efd42..9d04e10bacb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -35,6 +35,9 @@ public class MappedTensor implements Tensor { public double get(TensorAddress address) { return cells.getOrDefault(address, Double.NaN); } @Override + public boolean has(TensorAddress address) { return cells.containsKey(address); } + + @Override public Iterator<Cell> cellIterator() { return new CellIteratorAdaptor(cells.entrySet().iterator()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 7631a2e4eab..e686a42d530 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -61,6 +61,17 @@ public class MixedTensor implements Tensor { return cell.getValue(); } + @Override + public boolean has(TensorAddress address) { + long cellIndex = index.indexOf(address); + if (cellIndex < 0 || cellIndex >= cells.size()) + return false; + Cell cell = cells.get((int)cellIndex); + if ( ! address.equals(cell.getKey())) + return false; + return true; + } + /** * Returns an iterator over the cells of this tensor. * Cells are returned in order of increasing indexes in the diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 3378520dc91..8d014edd68f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -70,6 +70,9 @@ public interface Tensor { /** Returns the value of a cell, or NaN if this cell does not exist/have no value */ double get(TensorAddress address); + /** Returns true if this cell exists */ + boolean has(TensorAddress address); + /** * Returns the cell of this in some undefined order. * A cell instances is only valid until next() is called. diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index d43b7889982..0cbcfbb7ad6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -126,9 +126,10 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); - double bCellValue = b.get(aCell.getKey()); - if (Double.isNaN(bCellValue)) continue; // no match - builder.cell(aCell.getKey(), combinator.applyAsDouble(aCell.getValue(), bCellValue)); + var key = aCell.getKey(); + if (b.has(key)) { + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + } } return builder.build(); } @@ -203,11 +204,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> supercell = i.next(); TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); - double subspaceValue = subspace.get(subaddress); - if ( ! Double.isNaN(subspaceValue)) + if (subspace.has(subaddress)) { + double subspaceValue = subspace.get(subaddress); builder.cell(supercell.getKey(), reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) - : combinator.applyAsDouble(subspaceValue, supercell.getValue())); + : combinator.applyAsDouble(subspaceValue, supercell.getValue())); + } } return builder.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java index 4aa09f3f4e3..a2387affa67 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -125,11 +125,12 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY private static void addCellsOf(Tensor a, Tensor b, Tensor.Builder builder, DoubleBinaryOperator combinator) { for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); - double bCellValue = b.get(aCell.getKey()); - if (Double.isNaN(bCellValue)) - builder.cell(aCell.getKey(), aCell.getValue()); - else if (combinator != null) - builder.cell(aCell.getKey(), combinator.applyAsDouble(aCell.getValue(), bCellValue)); + var key = aCell.getKey(); + if (! b.has(key)) { + builder.cell(key, aCell.getValue()); + } else if (combinator != null) { + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + } } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index 607c9a0ab44..da24aef50bc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -63,8 +63,14 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY TensorType resultType = resultType(tensor.type()); PartialAddress subspaceAddress = subspaceToAddress(tensor.type(), context); - if (resultType.rank() == 0) // shortcut common case - return Tensor.from(tensor.get(subspaceAddress.asAddress(tensor.type()))); + if (resultType.rank() == 0) { // shortcut common case + var key = subspaceAddress.asAddress(tensor.type()); + if (tensor.has(key)) { + return Tensor.from(tensor.get(key)); + } else { + return Tensor.from(0.0); + } + } Tensor.Builder b = Tensor.Builder.of(resultType); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { |