summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-19 15:55:17 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-19 15:55:17 +0100
commit120b42f1e7f1fa0ce4b34a6e0956d52a62ca6aff (patch)
tree73bba5576289cbf87bb34e4cfab25e0c4dc7c8f9 /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
parent2959b5aefb258cf320f375f63a6555441fd0aa51 (diff)
Split iterating into subspaces for performance
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java67
1 files changed, 61 insertions, 6 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 19b4ad39af3..df4ae4ec534 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -2,18 +2,20 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
import com.yahoo.tensor.IndexedTensor;
-import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import java.util.Arrays;
+import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
import java.util.function.DoubleBinaryOperator;
/**
@@ -69,7 +71,7 @@ public class Join extends PrimitiveTensorFunction {
TensorType joinedType = a.type().combineWith(b.type());
// Choose join algorithm
- if (a.type().equals(b.type()) && a.type().dimensions().size() == 1 && a.type().dimensions().get(0).isIndexed())
+ if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType);
else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size())
return singleSpaceJoin(a, b, joinedType);
@@ -81,8 +83,12 @@ public class Join extends PrimitiveTensorFunction {
return generalJoin(a, b, joinedType);
}
+ private boolean hasSingleIndexedDimension(Tensor tensor) {
+ return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
+ }
+
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
- int joinedLength = Math.min(a.length(0), b.length(0));
+ int joinedLength = Math.min(a.size(0), b.size(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new int[] { joinedLength});
@@ -105,6 +111,42 @@ public class Join extends PrimitiveTensorFunction {
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
+ if (subspace.type().isIndexed() && superspace.type().isIndexed())
+ return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder);
+ else
+ return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder);
+ }
+
+ private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
+ if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
+ return Tensor.Builder.of(joinedType, new int[joinedType.dimensions().size()]).build();
+
+ // Find size of joined tensor
+ int[] joinedSizes = new int[joinedType.dimensions().size()];
+ for (int i = 0; i < joinedSizes.length; i++) {
+ Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name());
+ if (subspaceIndex.isPresent())
+ joinedSizes[i] = Math.min(superspace.size(i), subspace.size(subspaceIndex.get()));
+ else
+ joinedSizes[i] = superspace.size(i);
+ }
+
+ Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSizes);
+
+ // Find dimensions which are only in the supertype
+ Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames());
+ superDimensionNames.removeAll(subspace.type().dimensionNames());
+
+ for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
+ IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
+ joinSubspaces(subspace.valueIterator(), subspace.size(),
+ subspaceInSuper, subspaceInSuper.size(),
+ reversedArgumentOrder, builder);
+ }
+ return builder.build();
+ }
+
+ private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Map.Entry<TensorAddress, Double>> i = superspace.cellIterator(); i.hasNext(); ) {
@@ -112,13 +154,26 @@ public class Join extends PrimitiveTensorFunction {
TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes);
double subspaceValue = subspace.get(subaddress);
if ( ! Double.isNaN(subspaceValue))
- builder.cell(supercell.getKey(),
+ builder.cell(supercell.getKey(),
reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
: combinator.applyAsDouble(subspaceValue, supercell.getValue()));
}
return builder.build();
}
-
+
+ private void joinSubspaces(Iterator<Double> subspace, int subspaceSize,
+ Iterator<Map.Entry<TensorAddress, Double>> superspace, int superspaceSize,
+ boolean reversedArgumentOrder, Tensor.Builder builder) {
+ int joinedLength = Math.min(subspaceSize, superspaceSize);
+ for (int i = 0; i < joinedLength; i++) {
+ Double subvalue = subspace.next();
+ Map.Entry<TensorAddress, Double> supercell = superspace.next();
+ builder.cell(supercell.getKey(),
+ reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subvalue)
+ : combinator.applyAsDouble(subvalue, supercell.getValue()));
+ }
+ }
+
/** Returns the indexes in the superspace type which should be retained to create the subspace type */
private int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
int[] subspaceIndexes = new int[subtype.dimensions().size()];