summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-14 08:28:16 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-14 08:28:16 +0100
commit1f9c107864f1635dd3b69e5c73fd83bc78e28756 (patch)
treeffba4b9c4d6306fded70697805459d22367ae7ea /vespajlib
parenta078e26db8db9d0d315cda7f467721d26c522f99 (diff)
Optimize subspace join
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java32
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java7
2 files changed, 36 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index a44024b301a..5f91ff3034c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -73,6 +73,10 @@ public class Join extends PrimitiveTensorFunction {
return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType);
else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size())
return singleSpaceJoin(a, b, joinedType);
+ else if (a.type().dimensions().containsAll(b.type().dimensions()))
+ return subspaceJoin(b, a, joinedType);
+ else if (b.type().dimensions().containsAll(a.type().dimensions()))
+ return subspaceJoin(a, b,joinedType);
else
return generalJoin(a, b, joinedType);
}
@@ -95,6 +99,34 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
+
+ /** Join a tensor into a superspace */
+ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType) {
+ int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
+ Tensor.Builder builder = Tensor.Builder.of(joinedType);
+ for (Map.Entry<TensorAddress, Double> supercell : superspace.cells().entrySet()) {
+ TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes);
+ double subspaceValue = subspace.get(subaddress);
+ if (subspaceValue != Double.NaN)
+ builder.cell(supercell.getKey(), combinator.applyAsDouble(subspaceValue, supercell.getValue()));
+ }
+ return builder.build();
+ }
+
+ /** Returns the indexes in the superspace type which should be retained to create the subspace type */
+ private int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
+ int[] subspaceIndexes = new int[subtype.dimensions().size()];
+ for (int i = 0; i < subtype.dimensions().size(); i++)
+ subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
+ return subspaceIndexes;
+ }
+
+ private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
+ String[] subspaceLabels = new String[subspaceIndexes.length];
+ for (int i = 0; i < subspaceIndexes.length; i++)
+ subspaceLabels[i] = superAddress.labels().get(subspaceIndexes[i]);
+ return new TensorAddress(subspaceLabels);
+ }
/** Slow join which works for any two tensors */
private Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index 2b6b8cf05da..ab62020b265 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -91,7 +91,8 @@ public class TensorFunctionBenchmark {
time = new TensorFunctionBenchmark().benchmark(5000, generateVectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped);
System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time);
// Initial: 760 ms
- time = new TensorFunctionBenchmark().benchmark(10, generateMatrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped);
+ // - After special-casing subspace: 15 ms
+ time = new TensorFunctionBenchmark().benchmark(500, generateMatrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped);
System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time);
// ---------------- Indexed:
@@ -102,8 +103,8 @@ public class TensorFunctionBenchmark {
time = new TensorFunctionBenchmark().benchmark(5000, generateVectors(100, 300, TensorType.Dimension.Type.indexedUnbound),TensorType.Dimension.Type.indexedUnbound);
System.out.printf("Indexed vectors, time per join: %1$8.3f ms\n", time);
// Initial: 3500 ms
- // time = new TensorFunctionBenchmark().benchmark(10, generateMatrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound);
- // System.out.printf("Indexed matrix, time per join: %1$8.3f ms\n", time);
+ time = new TensorFunctionBenchmark().benchmark(10, generateMatrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound);
+ System.out.printf("Indexed matrix, time per join: %1$8.3f ms\n", time);
}
}