diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-14 08:28:16 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-14 08:28:16 +0100 |
commit | 1f9c107864f1635dd3b69e5c73fd83bc78e28756 (patch) | |
tree | ffba4b9c4d6306fded70697805459d22367ae7ea /vespajlib | |
parent | a078e26db8db9d0d315cda7f467721d26c522f99 (diff) |
Optimize subspace join
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 32 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java | 7 |
2 files changed, 36 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index a44024b301a..5f91ff3034c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -73,6 +73,10 @@ public class Join extends PrimitiveTensorFunction {
             return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType);
         else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size())
             return singleSpaceJoin(a, b, joinedType);
+        else if (a.type().dimensions().containsAll(b.type().dimensions()))
+            return subspaceJoin(b, a, joinedType, true);
+        else if (b.type().dimensions().containsAll(a.type().dimensions()))
+            return subspaceJoin(a, b, joinedType, false);
         else
             return generalJoin(a, b, joinedType);
     }
@@ -95,6 +99,34 @@ public class Join extends PrimitiveTensorFunction {
         }
         return builder.build();
     }
+
+    /** Joins each superspace cell with the subspace cell at its projected address; reversedArgumentOrder passes the superspace value to the combinator first */
+    private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
+        int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
+        Tensor.Builder builder = Tensor.Builder.of(joinedType);
+        for (Map.Entry<TensorAddress, Double> supercell : superspace.cells().entrySet()) {
+            double subspaceValue = subspace.get(mapAddressToSubspace(supercell.getKey(), subspaceIndexes));
+            if ( ! Double.isNaN(subspaceValue)) // "x != Double.NaN" is always true, so it cannot be used to skip NaN values
+                builder.cell(supercell.getKey(), reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
+        }
+        return builder.build();
+    }
+
+    /** Returns the indexes in the superspace type which should be retained to create the subspace type */
+    private int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
+        int[] subspaceIndexes = new int[subtype.dimensions().size()];
+        for (int i = 0; i < subtype.dimensions().size(); i++)
+            subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
+        return subspaceIndexes;
+    }
+
+    /** Returns the address made from the labels of superAddress found at the given indexes, in that order */
+    private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
+        String[] subspaceLabels = new String[subspaceIndexes.length];
+        for (int i = 0; i < subspaceIndexes.length; i++)
+            subspaceLabels[i] = superAddress.labels().get(subspaceIndexes[i]);
+        return new TensorAddress(subspaceLabels);
+    }
 
     /** Slow join which works for any two tensors */
     private Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index 2b6b8cf05da..ab62020b265 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -91,7 +91,8 @@ public class TensorFunctionBenchmark {
         time = new TensorFunctionBenchmark().benchmark(5000, generateVectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped);
         System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time);
         // Initial: 760 ms
-        time = new TensorFunctionBenchmark().benchmark(10, generateMatrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped);
+        // - After special-casing subspace: 15 ms
+        time = new TensorFunctionBenchmark().benchmark(500, generateMatrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped);
         System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time);
 
         // ---------------- Indexed:
@@ -102,8 +103,8 @@ public class TensorFunctionBenchmark {
         time = new TensorFunctionBenchmark().benchmark(5000, generateVectors(100, 300, TensorType.Dimension.Type.indexedUnbound),TensorType.Dimension.Type.indexedUnbound);
         System.out.printf("Indexed vectors, time per join: %1$8.3f ms\n", time);
         // Initial: 3500 ms
-        // time = new TensorFunctionBenchmark().benchmark(10, generateMatrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound);
-        // System.out.printf("Indexed matrix, time per join: %1$8.3f ms\n", time);
+        time = new TensorFunctionBenchmark().benchmark(10, generateMatrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound);
+        System.out.printf("Indexed matrix, time per join: %1$8.3f ms\n", time);
     }
 
 }