From 1f9c107864f1635dd3b69e5c73fd83bc78e28756 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 14 Dec 2016 08:28:16 +0100 Subject: Optimize subspace join --- .../main/java/com/yahoo/tensor/functions/Join.java | 32 ++++++++++++++++++++++ 1 file changed, 32 insertions(+) (limited to 'vespajlib/src/main/java/com/yahoo') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index a44024b301a..5f91ff3034c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -73,6 +73,10 @@ public class Join extends PrimitiveTensorFunction { return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType); else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size()) return singleSpaceJoin(a, b, joinedType); + else if (a.type().dimensions().containsAll(b.type().dimensions())) + return subspaceJoin(b, a, joinedType); + else if (b.type().dimensions().containsAll(a.type().dimensions())) + return subspaceJoin(a, b,joinedType); else return generalJoin(a, b, joinedType); } @@ -95,6 +99,34 @@ public class Join extends PrimitiveTensorFunction { } return builder.build(); } + + /** Join a tensor into a superspace */ + private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType) { + int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type()); + Tensor.Builder builder = Tensor.Builder.of(joinedType); + for (Map.Entry supercell : superspace.cells().entrySet()) { + TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); + double subspaceValue = subspace.get(subaddress); + if (subspaceValue != Double.NaN) + builder.cell(supercell.getKey(), combinator.applyAsDouble(subspaceValue, supercell.getValue())); + } + return builder.build(); + } + + /** Returns the indexes in the superspace type which should be retained to create the subspace type */ + private int[] subspaceIndexes(TensorType supertype, TensorType subtype) { + int[] subspaceIndexes = new int[subtype.dimensions().size()]; + for (int i = 0; i < subtype.dimensions().size(); i++) + subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get(); + return subspaceIndexes; + } + + private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { + String[] subspaceLabels = new String[subspaceIndexes.length]; + for (int i = 0; i < subspaceIndexes.length; i++) + subspaceLabels[i] = superAddress.labels().get(subspaceIndexes[i]); + return new TensorAddress(subspaceLabels); + } /** Slow join which works for any two tensors */ private Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType) { -- cgit v1.2.3