diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2016-12-20 09:22:00 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-12-20 09:22:00 +0100 |
commit | 5f32c0369cf796e46b70576d2f4eb8e470edb0e6 (patch) | |
tree | f15261cc22786afe1bdbab63e9075970501e542b /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | |
parent | 3cd484f5a35af1b2fda324e3787c741be02179fa (diff) |
Revert "Revert "Bratseth/tensor subiterators""
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 83 |
1 files changed, 69 insertions, 14 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 19b4ad39af3..6128611302f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -2,18 +2,20 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.yahoo.tensor.IndexedTensor; -import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import java.util.Arrays; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; +import java.util.Set; import java.util.function.DoubleBinaryOperator; /** @@ -69,7 +71,7 @@ public class Join extends PrimitiveTensorFunction { TensorType joinedType = a.type().combineWith(b.type()); // Choose join algorithm - if (a.type().equals(b.type()) && a.type().dimensions().size() == 1 && a.type().dimensions().get(0).isIndexed()) + if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType); else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size()) return singleSpaceJoin(a, b, joinedType); @@ -81,8 +83,12 @@ public class Join extends PrimitiveTensorFunction { return generalJoin(a, b, joinedType); } + private boolean hasSingleIndexedDimension(Tensor tensor) { + return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed(); + } + private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { - int joinedLength = Math.min(a.length(0), b.length(0)); + int joinedLength = Math.min(a.size(0), b.size(0)); Iterator<Double> aIterator = a.valueIterator(); Iterator<Double> bIterator = b.valueIterator(); IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new int[] { joinedLength}); @@ -105,6 +111,42 @@ public class Join extends PrimitiveTensorFunction { /** Join a tensor into a superspace */ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { + if (subspace.type().isIndexed() && superspace.type().isIndexed()) + return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder); + else + return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder); + } + + private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { + if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes + return Tensor.Builder.of(joinedType, new int[joinedType.dimensions().size()]).build(); + + // Find size of joined tensor + int[] joinedSizes = new int[joinedType.dimensions().size()]; + for (int i = 0; i < joinedSizes.length; i++) { + Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name()); + if (subspaceIndex.isPresent()) + joinedSizes[i] = Math.min(superspace.size(i), subspace.size(subspaceIndex.get())); + else + joinedSizes[i] = superspace.size(i); + } + + Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSizes); + + // Find dimensions which are only in the supertype + Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames()); + superDimensionNames.removeAll(subspace.type().dimensionNames()); + + for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) { + IndexedTensor.SubspaceIterator subspaceInSuper = i.next(); + joinSubspaces(subspace.valueIterator(), subspace.size(), + subspaceInSuper, subspaceInSuper.size(), + reversedArgumentOrder, builder); + } + return builder.build(); + } + + private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type()); Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Map.Entry<TensorAddress, Double>> i = superspace.cellIterator(); i.hasNext(); ) { @@ -112,13 +154,26 @@ public class Join extends PrimitiveTensorFunction { TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); double subspaceValue = subspace.get(subaddress); if ( ! Double.isNaN(subspaceValue)) - builder.cell(supercell.getKey(), + builder.cell(supercell.getKey(), reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) : combinator.applyAsDouble(subspaceValue, supercell.getValue())); } return builder.build(); } - + + private void joinSubspaces(Iterator<Double> subspace, int subspaceSize, + Iterator<Map.Entry<TensorAddress, Double>> superspace, int superspaceSize, + boolean reversedArgumentOrder, Tensor.Builder builder) { + int joinedLength = Math.min(subspaceSize, superspaceSize); + for (int i = 0; i < joinedLength; i++) { + Double subvalue = subspace.next(); + Map.Entry<TensorAddress, Double> supercell = superspace.next(); + builder.cell(supercell.getKey(), + reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subvalue) + : combinator.applyAsDouble(subvalue, supercell.getValue())); + } + } + /** Returns the indexes in the superspace type which should be retained to create the subspace type */ private int[] subspaceIndexes(TensorType supertype, TensorType subtype) { int[] subspaceIndexes = new int[subtype.dimensions().size()]; @@ -130,8 +185,8 @@ public class Join extends PrimitiveTensorFunction { private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { String[] subspaceLabels = new String[subspaceIndexes.length]; for (int i = 0; i < subspaceIndexes.length; i++) - subspaceLabels[i] = superAddress.labels().get(subspaceIndexes[i]); - return new TensorAddress(subspaceLabels); + subspaceLabels[i] = superAddress.label(subspaceIndexes[i]); + return TensorAddress.of(subspaceLabels); } /** Slow join which works for any two tensors */ @@ -169,10 +224,10 @@ public class Join extends PrimitiveTensorFunction { private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType joinedType) { String[] joinedLabels = new String[joinedType.dimensions().size()]; - mapContent(a.labels(), joinedLabels, aToIndexes); - boolean compatible = mapContent(b.labels(), joinedLabels, bToIndexes); + mapContent(a, joinedLabels, aToIndexes); + boolean compatible = mapContent(b, joinedLabels, bToIndexes); if ( ! compatible) return null; - return new TensorAddress(joinedLabels); + return TensorAddress.of(joinedLabels); } /** @@ -181,11 +236,11 @@ public class Join extends PrimitiveTensorFunction { * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ - private boolean mapContent(List<String> from, String[] to, int[] indexMap) { + private boolean mapContent(TensorAddress from, String[] to, int[] indexMap) { for (int i = 0; i < from.size(); i++) { int toIndex = indexMap[i]; - if (to[toIndex] != null && ! to[toIndex].equals(from.get(i))) return false; - to[toIndex] = from.get(i); + if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false; + to[toIndex] = from.label(i); } return true; } |