diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 83 |
1 files changed, 14 insertions, 69 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 6128611302f..19b4ad39af3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -2,20 +2,18 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; -import java.util.Arrays; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; -import java.util.Set; import java.util.function.DoubleBinaryOperator; /** @@ -71,7 +69,7 @@ public class Join extends PrimitiveTensorFunction { TensorType joinedType = a.type().combineWith(b.type()); // Choose join algorithm - if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) + if (a.type().equals(b.type()) && a.type().dimensions().size() == 1 && a.type().dimensions().get(0).isIndexed()) return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType); else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size()) return singleSpaceJoin(a, b, joinedType); @@ -83,12 +81,8 @@ public class Join extends PrimitiveTensorFunction { return generalJoin(a, b, joinedType); } - private boolean hasSingleIndexedDimension(Tensor tensor) { - return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed(); - } - private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { - int joinedLength = Math.min(a.size(0), b.size(0)); + int joinedLength = Math.min(a.length(0), b.length(0)); Iterator<Double> aIterator = a.valueIterator(); Iterator<Double> bIterator = b.valueIterator(); IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new int[] { joinedLength}); @@ -111,42 +105,6 @@ public class Join extends PrimitiveTensorFunction { /** Join a tensor into a superspace */ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { - if (subspace.type().isIndexed() && superspace.type().isIndexed()) - return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder); - else - return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder); - } - - private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { - if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes - return Tensor.Builder.of(joinedType, new int[joinedType.dimensions().size()]).build(); - - // Find size of joined tensor - int[] joinedSizes = new int[joinedType.dimensions().size()]; - for (int i = 0; i < joinedSizes.length; i++) { - Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name()); - if (subspaceIndex.isPresent()) - joinedSizes[i] = Math.min(superspace.size(i), subspace.size(subspaceIndex.get())); - else - joinedSizes[i] = superspace.size(i); - } - - Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSizes); - - // Find dimensions which are only in the supertype - Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames()); - superDimensionNames.removeAll(subspace.type().dimensionNames()); - - for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) { - IndexedTensor.SubspaceIterator subspaceInSuper = i.next(); - joinSubspaces(subspace.valueIterator(), subspace.size(), - subspaceInSuper, subspaceInSuper.size(), - reversedArgumentOrder, builder); - } - return builder.build(); - } - - private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type()); Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Map.Entry<TensorAddress, Double>> i = superspace.cellIterator(); i.hasNext(); ) { @@ -154,26 +112,13 @@ public class Join extends PrimitiveTensorFunction { TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); double subspaceValue = subspace.get(subaddress); if ( ! Double.isNaN(subspaceValue)) - builder.cell(supercell.getKey(), + builder.cell(supercell.getKey(), reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) : combinator.applyAsDouble(subspaceValue, supercell.getValue())); } return builder.build(); } - - private void joinSubspaces(Iterator<Double> subspace, int subspaceSize, - Iterator<Map.Entry<TensorAddress, Double>> superspace, int superspaceSize, - boolean reversedArgumentOrder, Tensor.Builder builder) { - int joinedLength = Math.min(subspaceSize, superspaceSize); - for (int i = 0; i < joinedLength; i++) { - Double subvalue = subspace.next(); - Map.Entry<TensorAddress, Double> supercell = superspace.next(); - builder.cell(supercell.getKey(), - reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subvalue) - : combinator.applyAsDouble(subvalue, supercell.getValue())); - } - } - + /** Returns the indexes in the superspace type which should be retained to create the subspace type */ private int[] subspaceIndexes(TensorType supertype, TensorType subtype) { int[] subspaceIndexes = new int[subtype.dimensions().size()]; @@ -185,8 +130,8 @@ public class Join extends PrimitiveTensorFunction { private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { String[] subspaceLabels = new String[subspaceIndexes.length]; for (int i = 0; i < subspaceIndexes.length; i++) - subspaceLabels[i] = superAddress.label(subspaceIndexes[i]); - return TensorAddress.of(subspaceLabels); + subspaceLabels[i] = superAddress.labels().get(subspaceIndexes[i]); + return new TensorAddress(subspaceLabels); } /** Slow join which works for any two tensors */ @@ -224,10 +169,10 @@ public class Join extends PrimitiveTensorFunction { private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType joinedType) { String[] joinedLabels = new String[joinedType.dimensions().size()]; - mapContent(a, joinedLabels, aToIndexes); - boolean compatible = mapContent(b, joinedLabels, bToIndexes); + mapContent(a.labels(), joinedLabels, aToIndexes); + boolean compatible = mapContent(b.labels(), joinedLabels, bToIndexes); if ( ! compatible) return null; - return TensorAddress.of(joinedLabels); + return new TensorAddress(joinedLabels); } /** @@ -236,11 +181,11 @@ public class Join extends PrimitiveTensorFunction { * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ - private boolean mapContent(TensorAddress from, String[] to, int[] indexMap) { + private boolean mapContent(List<String> from, String[] to, int[] indexMap) { for (int i = 0; i < from.size(); i++) { int toIndex = indexMap[i]; - if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false; - to[toIndex] = from.label(i); + if (to[toIndex] != null && ! to[toIndex].equals(from.get(i))) return false; + to[toIndex] = from.get(i); } return true; } |