summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2016-12-20 09:22:00 +0100
committerGitHub <noreply@github.com>2016-12-20 09:22:00 +0100
commit5f32c0369cf796e46b70576d2f4eb8e470edb0e6 (patch)
treef15261cc22786afe1bdbab63e9075970501e542b /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
parent3cd484f5a35af1b2fda324e3787c741be02179fa (diff)
Revert "Revert "Bratseth/tensor subiterators""
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java83
1 files changed, 69 insertions, 14 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 19b4ad39af3..6128611302f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -2,18 +2,20 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
import com.yahoo.tensor.IndexedTensor;
-import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import java.util.Arrays;
+import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
import java.util.function.DoubleBinaryOperator;
/**
@@ -69,7 +71,7 @@ public class Join extends PrimitiveTensorFunction {
TensorType joinedType = a.type().combineWith(b.type());
// Choose join algorithm
- if (a.type().equals(b.type()) && a.type().dimensions().size() == 1 && a.type().dimensions().get(0).isIndexed())
+ if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType);
else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size())
return singleSpaceJoin(a, b, joinedType);
@@ -81,8 +83,12 @@ public class Join extends PrimitiveTensorFunction {
return generalJoin(a, b, joinedType);
}
+ private boolean hasSingleIndexedDimension(Tensor tensor) {
+ return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
+ }
+
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
- int joinedLength = Math.min(a.length(0), b.length(0));
+ int joinedLength = Math.min(a.size(0), b.size(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new int[] { joinedLength});
@@ -105,6 +111,42 @@ public class Join extends PrimitiveTensorFunction {
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
+ if (subspace.type().isIndexed() && superspace.type().isIndexed())
+ return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder);
+ else
+ return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder);
+ }
+
+ private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
+ if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
+ return Tensor.Builder.of(joinedType, new int[joinedType.dimensions().size()]).build();
+
+ // Find size of joined tensor
+ int[] joinedSizes = new int[joinedType.dimensions().size()];
+ for (int i = 0; i < joinedSizes.length; i++) {
+ Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name());
+ if (subspaceIndex.isPresent())
+ joinedSizes[i] = Math.min(superspace.size(i), subspace.size(subspaceIndex.get()));
+ else
+ joinedSizes[i] = superspace.size(i);
+ }
+
+ Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSizes);
+
+ // Find dimensions which are only in the supertype
+ Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames());
+ superDimensionNames.removeAll(subspace.type().dimensionNames());
+
+ for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
+ IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
+ joinSubspaces(subspace.valueIterator(), subspace.size(),
+ subspaceInSuper, subspaceInSuper.size(),
+ reversedArgumentOrder, builder);
+ }
+ return builder.build();
+ }
+
+ private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Map.Entry<TensorAddress, Double>> i = superspace.cellIterator(); i.hasNext(); ) {
@@ -112,13 +154,26 @@ public class Join extends PrimitiveTensorFunction {
TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes);
double subspaceValue = subspace.get(subaddress);
if ( ! Double.isNaN(subspaceValue))
- builder.cell(supercell.getKey(),
+ builder.cell(supercell.getKey(),
reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
: combinator.applyAsDouble(subspaceValue, supercell.getValue()));
}
return builder.build();
}
-
+
+ private void joinSubspaces(Iterator<Double> subspace, int subspaceSize,
+ Iterator<Map.Entry<TensorAddress, Double>> superspace, int superspaceSize,
+ boolean reversedArgumentOrder, Tensor.Builder builder) {
+ int joinedLength = Math.min(subspaceSize, superspaceSize);
+ for (int i = 0; i < joinedLength; i++) {
+ Double subvalue = subspace.next();
+ Map.Entry<TensorAddress, Double> supercell = superspace.next();
+ builder.cell(supercell.getKey(),
+ reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subvalue)
+ : combinator.applyAsDouble(subvalue, supercell.getValue()));
+ }
+ }
+
/** Returns the indexes in the superspace type which should be retained to create the subspace type */
private int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
int[] subspaceIndexes = new int[subtype.dimensions().size()];
@@ -130,8 +185,8 @@ public class Join extends PrimitiveTensorFunction {
private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
String[] subspaceLabels = new String[subspaceIndexes.length];
for (int i = 0; i < subspaceIndexes.length; i++)
- subspaceLabels[i] = superAddress.labels().get(subspaceIndexes[i]);
- return new TensorAddress(subspaceLabels);
+ subspaceLabels[i] = superAddress.label(subspaceIndexes[i]);
+ return TensorAddress.of(subspaceLabels);
}
/** Slow join which works for any two tensors */
@@ -169,10 +224,10 @@ public class Join extends PrimitiveTensorFunction {
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
TensorType joinedType) {
String[] joinedLabels = new String[joinedType.dimensions().size()];
- mapContent(a.labels(), joinedLabels, aToIndexes);
- boolean compatible = mapContent(b.labels(), joinedLabels, bToIndexes);
+ mapContent(a, joinedLabels, aToIndexes);
+ boolean compatible = mapContent(b, joinedLabels, bToIndexes);
if ( ! compatible) return null;
- return new TensorAddress(joinedLabels);
+ return TensorAddress.of(joinedLabels);
}
/**
@@ -181,11 +236,11 @@ public class Join extends PrimitiveTensorFunction {
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
- private boolean mapContent(List<String> from, String[] to, int[] indexMap) {
+ private boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
for (int i = 0; i < from.size(); i++) {
int toIndex = indexMap[i];
- if (to[toIndex] != null && ! to[toIndex].equals(from.get(i))) return false;
- to[toIndex] = from.get(i);
+ if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false;
+ to[toIndex] = from.label(i);
}
return true;
}