summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-14 08:28:16 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-14 08:28:16 +0100
commit1f9c107864f1635dd3b69e5c73fd83bc78e28756 (patch)
treeffba4b9c4d6306fded70697805459d22367ae7ea /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
parenta078e26db8db9d0d315cda7f467721d26c522f99 (diff)
Optimize subspace join
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java32
1 files changed, 32 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index a44024b301a..5f91ff3034c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -73,6 +73,10 @@ public class Join extends PrimitiveTensorFunction {
return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType);
else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size())
return singleSpaceJoin(a, b, joinedType);
+ else if (a.type().dimensions().containsAll(b.type().dimensions()))
+ return subspaceJoin(b, a, joinedType);
+ else if (b.type().dimensions().containsAll(a.type().dimensions()))
+ return subspaceJoin(a, b,joinedType);
else
return generalJoin(a, b, joinedType);
}
@@ -95,6 +99,34 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
+
+ /** Join a tensor into a superspace */
+ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType) {
+ int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
+ Tensor.Builder builder = Tensor.Builder.of(joinedType);
+ for (Map.Entry<TensorAddress, Double> supercell : superspace.cells().entrySet()) {
+ TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes);
+ double subspaceValue = subspace.get(subaddress);
+ if (subspaceValue != Double.NaN)
+ builder.cell(supercell.getKey(), combinator.applyAsDouble(subspaceValue, supercell.getValue()));
+ }
+ return builder.build();
+ }
+
+ /** Returns the indexes in the superspace type which should be retained to create the subspace type */
+ private int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
+ int[] subspaceIndexes = new int[subtype.dimensions().size()];
+ for (int i = 0; i < subtype.dimensions().size(); i++)
+ subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
+ return subspaceIndexes;
+ }
+
+ private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
+ String[] subspaceLabels = new String[subspaceIndexes.length];
+ for (int i = 0; i < subspaceIndexes.length; i++)
+ subspaceLabels[i] = superAddress.labels().get(subspaceIndexes[i]);
+ return new TensorAddress(subspaceLabels);
+ }
/** Slow join which works for any two tensors */
private Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType) {