diff options
Diffstat (limited to 'vespajlib/src/main/java/com')
6 files changed, 37 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 5d384e0329b..f26174d9576 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -120,6 +120,17 @@ public abstract class IndexedTensor implements Tensor { } @Override + public Double getAsDouble(TensorAddress address) { + try { + long index = toValueIndex(address, dimensionSizes, type); + if (index < 0 || size() <= index) return null; + return get(index); + } catch (IllegalArgumentException e) { + return null; + } + } + + @Override public boolean has(TensorAddress address) { try { long index = toValueIndex(address, dimensionSizes, type); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 5471ea65b97..3e0df5f2261 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -41,6 +41,9 @@ public class MappedTensor implements Tensor { public boolean has(TensorAddress address) { return cells.containsKey(address); } @Override + public Double getAsDouble(TensorAddress address) { return cells.get(address); } + + @Override public Iterator<Cell> cellIterator() { return new CellIteratorAdaptor(cells.entrySet().iterator()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index cf6e737bf27..aab4a4e1297 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -83,6 +83,16 @@ public class MixedTensor implements Tensor { } @Override + public Double getAsDouble(TensorAddress address) { + var block = index.blockOf(address); + int denseOffset = index.denseOffsetOf(address); + if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) { + return null; + } + return block.cells[denseOffset]; + } + + @Override public boolean has(TensorAddress address) { var block = index.blockOf(address); int denseOffset = index.denseOffsetOf(address); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index cff17fdfd7c..d034ac551f8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -90,6 +90,8 @@ public interface Tensor { /** Returns true if this cell exists */ boolean has(TensorAddress address); + /** null = no value present. More efficient that if (t.has(key)) t.get(key) */ + Double getAsDouble(TensorAddress address); /** * Returns the cell of this in some undefined order. diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 7a336233de0..712e5528fc6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -129,8 +129,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); var key = aCell.getKey(); - if (b.has(key)) { - builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + Double bVal = b.getAsDouble(key); + if (bVal != null) { + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal)); } } return builder.build(); @@ -206,11 +207,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> supercell = i.next(); TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); - if (subspace.has(subaddress)) { - double subspaceValue = subspace.get(subaddress); + Double subspaceValue = subspace.getAsDouble(subaddress); + if (subspaceValue != null) { builder.cell(supercell.getKey(), - reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) - : combinator.applyAsDouble(subspaceValue, supercell.getValue())); + reversedArgumentOrder + ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) + : combinator.applyAsDouble(subspaceValue, supercell.getValue())); } } return builder.build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java index 59394785382..ddad91dc060 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -121,10 +121,11 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); var key = aCell.getKey(); - if (! b.has(key)) { + Double bVal = b.getAsDouble(key); + if (bVal == null) { builder.cell(key, aCell.getValue()); } else if (combinator != null) { - builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal)); } } } |