aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java5
6 files changed, 37 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 5d384e0329b..f26174d9576 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -120,6 +120,17 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
+ public Double getAsDouble(TensorAddress address) {
+ try {
+ long index = toValueIndex(address, dimensionSizes, type);
+ if (index < 0 || size() <= index) return null;
+ return get(index);
+ } catch (IllegalArgumentException e) {
+ return null;
+ }
+ }
+
+ @Override
public boolean has(TensorAddress address) {
try {
long index = toValueIndex(address, dimensionSizes, type);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index 5471ea65b97..3e0df5f2261 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -41,6 +41,9 @@ public class MappedTensor implements Tensor {
public boolean has(TensorAddress address) { return cells.containsKey(address); }
@Override
+ public Double getAsDouble(TensorAddress address) { return cells.get(address); }
+
+ @Override
public Iterator<Cell> cellIterator() { return new CellIteratorAdaptor(cells.entrySet().iterator()); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index cf6e737bf27..aab4a4e1297 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -83,6 +83,16 @@ public class MixedTensor implements Tensor {
}
@Override
+ public Double getAsDouble(TensorAddress address) {
+ var block = index.blockOf(address);
+ int denseOffset = index.denseOffsetOf(address);
+ if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) {
+ return null;
+ }
+ return block.cells[denseOffset];
+ }
+
+ @Override
public boolean has(TensorAddress address) {
var block = index.blockOf(address);
int denseOffset = index.denseOffsetOf(address);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index cff17fdfd7c..d034ac551f8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -90,6 +90,8 @@ public interface Tensor {
/** Returns true if this cell exists */
boolean has(TensorAddress address);
+ /** null = no value present. More efficient that if (t.has(key)) t.get(key) */
+ Double getAsDouble(TensorAddress address);
/**
* Returns the cell of this in some undefined order.
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 7a336233de0..712e5528fc6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -129,8 +129,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> aCell = i.next();
var key = aCell.getKey();
- if (b.has(key)) {
- builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key)));
+ Double bVal = b.getAsDouble(key);
+ if (bVal != null) {
+ builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal));
}
}
return builder.build();
@@ -206,11 +207,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> supercell = i.next();
TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes);
- if (subspace.has(subaddress)) {
- double subspaceValue = subspace.get(subaddress);
+ Double subspaceValue = subspace.getAsDouble(subaddress);
+ if (subspaceValue != null) {
builder.cell(supercell.getKey(),
- reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
- : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
+ reversedArgumentOrder
+ ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
+ : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
}
}
return builder.build();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
index 59394785382..ddad91dc060 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
@@ -121,10 +121,11 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> aCell = i.next();
var key = aCell.getKey();
- if (! b.has(key)) {
+ Double bVal = b.getAsDouble(key);
+ if (bVal == null) {
builder.cell(key, aCell.getValue());
} else if (combinator != null) {
- builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key)));
+ builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal));
}
}
}