From 400a428fd3ae71684988e93953eb6c89462d057e Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Fri, 30 Apr 2021 14:08:03 +0000 Subject: add api for detecting cell existence * new API "has(TensorAddress)" detects if a Tensor has a cell with the given address. * use new API in join and merge. This will give different results for cells that are present but contain NaN versus cells that aren't present at all. * use new API in slice. This gives a different default (0, not NaN) when trying to access cells that aren't present. --- vespajlib/abi-spec.json | 4 ++++ .../src/main/java/com/yahoo/tensor/IndexedTensor.java | 11 +++++++++++ vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java | 3 +++ vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java | 11 +++++++++++ vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 3 +++ .../src/main/java/com/yahoo/tensor/functions/Join.java | 14 ++++++++------ .../src/main/java/com/yahoo/tensor/functions/Merge.java | 11 ++++++----- .../src/main/java/com/yahoo/tensor/functions/Slice.java | 10 ++++++++-- 8 files changed, 54 insertions(+), 13 deletions(-) (limited to 'vespajlib') diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index ebca0a4d852..ccdd09e4cab 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -891,6 +891,7 @@ "public varargs double get(long[])", "public varargs float getFloat(long[])", "public double get(com.yahoo.tensor.TensorAddress)", + "public boolean has(com.yahoo.tensor.TensorAddress)", "public abstract double get(long)", "public abstract float getFloat(long)", "public com.yahoo.tensor.TensorType type()", @@ -941,6 +942,7 @@ "public com.yahoo.tensor.TensorType type()", "public long size()", "public double get(com.yahoo.tensor.TensorAddress)", + "public boolean has(com.yahoo.tensor.TensorAddress)", "public java.util.Iterator cellIterator()", "public java.util.Iterator valueIterator()", "public java.util.Map cells()", @@ -1031,6 +1033,7 @@ "public com.yahoo.tensor.TensorType type()", "public long size()", "public double get(com.yahoo.tensor.TensorAddress)", + "public boolean has(com.yahoo.tensor.TensorAddress)", "public java.util.Iterator cellIterator()", "public java.util.Iterator valueIterator()", "public java.util.Map cells()", @@ -1151,6 +1154,7 @@ "public boolean isEmpty()", "public abstract long size()", "public abstract double get(com.yahoo.tensor.TensorAddress)", + "public abstract boolean has(com.yahoo.tensor.TensorAddress)", "public abstract java.util.Iterator cellIterator()", "public abstract java.util.Iterator valueIterator()", "public abstract java.util.Map cells()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index dfc26cf0282..67eace78c45 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -118,6 +118,17 @@ public abstract class IndexedTensor implements Tensor { } } + @Override + public boolean has(TensorAddress address) { + try { + long index = toValueIndex(address, dimensionSizes, type); + if (index < 0) return false; + return (index < size()); + } catch (IllegalArgumentException e) { + return false; + } + } + /** * Returns the value at the given standard value order index as a double. * diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 33f904efd42..9d04e10bacb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -34,6 +34,9 @@ public class MappedTensor implements Tensor { @Override public double get(TensorAddress address) { return cells.getOrDefault(address, Double.NaN); } + @Override + public boolean has(TensorAddress address) { return cells.containsKey(address); } + @Override public Iterator cellIterator() { return new CellIteratorAdaptor(cells.entrySet().iterator()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 7631a2e4eab..e686a42d530 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -61,6 +61,17 @@ public class MixedTensor implements Tensor { return cell.getValue(); } + @Override + public boolean has(TensorAddress address) { + long cellIndex = index.indexOf(address); + if (cellIndex < 0 || cellIndex >= cells.size()) + return false; + Cell cell = cells.get((int)cellIndex); + if ( ! address.equals(cell.getKey())) + return false; + return true; + } + /** * Returns an iterator over the cells of this tensor. * Cells are returned in order of increasing indexes in the diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 3378520dc91..8d014edd68f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -70,6 +70,9 @@ public interface Tensor { /** Returns the value of a cell, or NaN if this cell does not exist/have no value */ double get(TensorAddress address); + /** Returns true if this cell exists */ + boolean has(TensorAddress address); + /** * Returns the cell of this in some undefined order. * A cell instances is only valid until next() is called. diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index d43b7889982..0cbcfbb7ad6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -126,9 +126,10 @@ public class Join extends PrimitiveTensorFunction i = a.cellIterator(); i.hasNext(); ) { Map.Entry aCell = i.next(); - double bCellValue = b.get(aCell.getKey()); - if (Double.isNaN(bCellValue)) continue; // no match - builder.cell(aCell.getKey(), combinator.applyAsDouble(aCell.getValue(), bCellValue)); + var key = aCell.getKey(); + if (b.has(key)) { + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + } } return builder.build(); } @@ -203,11 +204,12 @@ public class Join extends PrimitiveTensorFunction i = superspace.cellIterator(); i.hasNext(); ) { Map.Entry supercell = i.next(); TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); - double subspaceValue = subspace.get(subaddress); - if ( ! Double.isNaN(subspaceValue)) + if (subspace.has(subaddress)) { + double subspaceValue = subspace.get(subaddress); builder.cell(supercell.getKey(), reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) - : combinator.applyAsDouble(subspaceValue, supercell.getValue())); + : combinator.applyAsDouble(subspaceValue, supercell.getValue())); + } } return builder.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java index 4aa09f3f4e3..a2387affa67 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -125,11 +125,12 @@ public class Merge extends PrimitiveTensorFunction i = a.cellIterator(); i.hasNext(); ) { Map.Entry aCell = i.next(); - double bCellValue = b.get(aCell.getKey()); - if (Double.isNaN(bCellValue)) - builder.cell(aCell.getKey(), aCell.getValue()); - else if (combinator != null) - builder.cell(aCell.getKey(), combinator.applyAsDouble(aCell.getValue(), bCellValue)); + var key = aCell.getKey(); + if (! b.has(key)) { + builder.cell(key, aCell.getValue()); + } else if (combinator != null) { + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + } } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index 607c9a0ab44..da24aef50bc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -63,8 +63,14 @@ public class Slice extends PrimitiveTensorFunction i = tensor.cellIterator(); i.hasNext(); ) { -- cgit v1.2.3