diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 97 |
1 files changed, 50 insertions, 47 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index be323313369..62ee471fcf4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -82,25 +82,29 @@ public class Join extends PrimitiveTensorFunction { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); + return evaluate(a, b, joinedType, combinator); + } + static Tensor evaluate(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { // Choose join algorithm if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) - return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType); + return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType, combinator); else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size()) - return singleSpaceJoin(a, b, joinedType); + return singleSpaceJoin(a, b, joinedType, combinator); else if (a.type().dimensions().containsAll(b.type().dimensions())) - return subspaceJoin(b, a, joinedType, true); + return subspaceJoin(b, a, joinedType, true, combinator); else if (b.type().dimensions().containsAll(a.type().dimensions())) - return subspaceJoin(a, b, joinedType, false); + return subspaceJoin(a, b, joinedType, false, combinator); else - return generalJoin(a, b, joinedType); + return generalJoin(a, b, joinedType, combinator); + } - private boolean hasSingleIndexedDimension(Tensor tensor) { + private static boolean hasSingleIndexedDimension(Tensor tensor) { return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed(); } - private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { + private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) { long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator<Double> aIterator = a.valueIterator(); Iterator<Double> bIterator = b.valueIterator(); @@ -111,7 +115,7 @@ public class Join extends PrimitiveTensorFunction { } /** When both tensors have the same dimensions, at most one cell matches a cell in the other tensor */ - private Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType) { + private static Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); @@ -123,14 +127,14 @@ public class Join extends PrimitiveTensorFunction { } /** Join a tensor into a superspace */ - private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { + private static Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) { if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) - return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder); + return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder, combinator); else - return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder); + return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder, combinator); } - private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { + private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) { if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build(); @@ -145,16 +149,17 @@ public class Join extends PrimitiveTensorFunction { for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) { IndexedTensor.SubspaceIterator subspaceInSuper = i.next(); joinSubspaces(subspace.valueIterator(), subspace.size(), - subspaceInSuper, subspaceInSuper.size(), - reversedArgumentOrder, builder); + subspaceInSuper, subspaceInSuper.size(), + reversedArgumentOrder, builder, combinator); } return builder.build(); } - private void joinSubspaces(Iterator<Double> subspace, long subspaceSize, - Iterator<Tensor.Cell> superspace, long superspaceSize, - boolean reversedArgumentOrder, IndexedTensor.Builder builder) { + private static void joinSubspaces(Iterator<Double> subspace, long subspaceSize, + Iterator<Tensor.Cell> superspace, long superspaceSize, + boolean reversedArgumentOrder, IndexedTensor.Builder builder, + DoubleBinaryOperator combinator) { long joinedLength = Math.min(subspaceSize, superspaceSize); if (reversedArgumentOrder) { for (int i = 0; i < joinedLength; i++) { @@ -169,7 +174,7 @@ public class Join extends PrimitiveTensorFunction { } } - private DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) { + private static DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) { DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size()); for (int i = 0; i < builder.dimensions(); i++) { String dimensionName = joinedType.dimensions().get(i).name(); @@ -185,7 +190,7 @@ public class Join extends PrimitiveTensorFunction { return builder.build(); } - private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { + private static Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) { int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type()); Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) { @@ -194,21 +199,21 @@ public class Join extends PrimitiveTensorFunction { double subspaceValue = subspace.get(subaddress); if ( ! Double.isNaN(subspaceValue)) builder.cell(supercell.getKey(), - reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) - : combinator.applyAsDouble(subspaceValue, supercell.getValue())); + reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) + : combinator.applyAsDouble(subspaceValue, supercell.getValue())); } return builder.build(); } /** Returns the indexes in the superspace type which should be retained to create the subspace type */ - private int[] subspaceIndexes(TensorType supertype, TensorType subtype) { + private static int[] subspaceIndexes(TensorType supertype, TensorType subtype) { int[] subspaceIndexes = new int[subtype.dimensions().size()]; for (int i = 0; i < subtype.dimensions().size(); i++) subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get(); return subspaceIndexes; } - private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { + private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { String[] subspaceLabels = new String[subspaceIndexes.length]; for (int i = 0; i < subspaceIndexes.length; i++) subspaceLabels[i] = superAddress.label(subspaceIndexes[i]); @@ -216,25 +221,25 @@ public class Join extends PrimitiveTensorFunction { } /** Slow join which works for any two tensors */ - private Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType) { + private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { if (a instanceof IndexedTensor && b instanceof IndexedTensor) - return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType); + return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator); else - return mappedHashJoin(a, b, joinedType); + return mappedHashJoin(a, b, joinedType, combinator); } - private Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType) { + private static Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType, DoubleBinaryOperator combinator) { DimensionSizes joinedSize = joinedSize(joinedType, a, b); Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize); int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); - joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder); -// joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder); + joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, builder, combinator); return builder.build(); } - private void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize, - int[] aToIndexes, int[] bToIndexes, boolean reversedOrder, Tensor.Builder builder) { + private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize, + int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder, + DoubleBinaryOperator combinator) { Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()); Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()); @@ -252,15 +257,14 @@ public class Join extends PrimitiveTensorFunction { for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) { Tensor.Cell bCell = bSubspace.next(); TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType); - double joinedValue = reversedOrder ? combinator.applyAsDouble(bCell.getValue(), aCell.getValue()) - : combinator.applyAsDouble(aCell.getValue(), bCell.getValue()); + double joinedValue = combinator.applyAsDouble(aCell.getValue(), bCell.getValue()); builder.cell(joinedAddress, joinedValue); } } } } - private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { + private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); for (int i = 0; i < addressType.dimensions().size(); i++) if (retainDimensions.contains(addressType.dimensions().get(i).name())) @@ -269,7 +273,7 @@ public class Join extends PrimitiveTensorFunction { } /** Returns the sizes from the joined sizes which are present in the type argument */ - private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) { + private static DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) { DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); int dimensionIndex = 0; for (int i = 0; i < joinedType.dimensions().size(); i++) { @@ -279,7 +283,7 @@ public class Join extends PrimitiveTensorFunction { return builder.build(); } - private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) { + private static Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); Tensor.Builder builder = Tensor.Builder.of(joinedType); @@ -288,7 +292,7 @@ public class Join extends PrimitiveTensorFunction { for (Iterator<Tensor.Cell> bIterator = b.cellIterator(); bIterator.hasNext(); ) { Map.Entry<TensorAddress, Double> bCell = bIterator.next(); TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes, - bCell.getKey(), bToIndexes, joinedType); + bCell.getKey(), bToIndexes, joinedType); if (combinedAddress == null) continue; // not combinable builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue())); } @@ -296,10 +300,10 @@ public class Join extends PrimitiveTensorFunction { return builder.build(); } - private Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType) { + private static Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { TensorType commonDimensionType = commonDimensions(a, b); if (commonDimensionType.dimensions().isEmpty()) { - return mappedGeneralJoin(a, b, joinedType); // fallback + return mappedGeneralJoin(a, b, joinedType, combinator); // fallback } boolean swapTensors = a.size() > b.size(); @@ -351,15 +355,15 @@ public class Join extends PrimitiveTensorFunction { * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ - private int[] mapIndexes(TensorType fromType, TensorType toType) { + static int[] mapIndexes(TensorType fromType, TensorType toType) { int[] toIndexes = new int[fromType.dimensions().size()]; for (int i = 0; i < fromType.dimensions().size(); i++) toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); return toIndexes; } - private TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, - TensorType joinedType) { + private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, + TensorType joinedType) { String[] joinedLabels = new String[joinedType.dimensions().size()]; mapContent(a, joinedLabels, aToIndexes); boolean compatible = mapContent(b, joinedLabels, bToIndexes); @@ -373,7 +377,7 @@ public class Join extends PrimitiveTensorFunction { * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ - private boolean mapContent(TensorAddress from, String[] to, int[] indexMap) { + private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) { for (int i = 0; i < from.size(); i++) { int toIndex = indexMap[i]; if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false; @@ -382,11 +386,10 @@ public class Join extends PrimitiveTensorFunction { return true; } - /** * Returns common dimension of a and b as a new tensor type */ - private TensorType commonDimensions(Tensor a, Tensor b) { + private static TensorType commonDimensions(Tensor a, Tensor b) { TensorType.Builder typeBuilder = new TensorType.Builder(); TensorType aType = a.type(); TensorType bType = b.type(); @@ -402,14 +405,14 @@ public class Join extends PrimitiveTensorFunction { return typeBuilder.build(); } - private TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) { + private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) { TensorAddress address = cell.getKey(); String[] labels = new String[indexMap.length]; for (int i = 0; i < labels.length; ++i) { labels[i] = address.label(indexMap[i]); } return TensorAddress.of(labels); - } } + |