// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; import com.google.common.collect.Sets; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.PartialAddress; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.impl.Convert; import com.yahoo.tensor.impl.TensorAddressAny; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.function.DoubleBinaryOperator; /** * The join tensor operation produces a tensor from the argument tensors containing the set of cells * given by the cross product of the cells of the given tensors, having as values the value produced by * applying the given combinator function on the values from the two source cells. * * @author bratseth */ public class Join extends PrimitiveTensorFunction { private final TensorFunction argumentA, argumentB; private final DoubleBinaryOperator combinator; public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) { Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); Objects.requireNonNull(combinator, "The combinator function cannot be null"); this.argumentA = argumentA; this.argumentB = argumentB; this.combinator = combinator; } /** Returns the type resulting from applying Join to the two given types */ public static TensorType outputType(TensorType a, TensorType b) { try { return TypeResolver.join(a, b); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("Can not join " + a + " and " + b, e); } } public DoubleBinaryOperator combinator() { return combinator; } @Override public List> arguments() { return List.of(argumentA, argumentB); } @Override public TensorFunction withArguments(List> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size()); return new Join<>(arguments.get(0), arguments.get(1), combinator); } @Override public PrimitiveTensorFunction toPrimitive() { return new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); } @Override public String toString(ToStringContext context) { return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")"; } @Override public int hashCode() { return Objects.hash("join", argumentA, argumentB, combinator); } @Override public TensorType type(TypeContext context) { return outputType(argumentA.type(context), argumentB.type(context)); } @Override public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = outputType(a.type(), b.type()); return evaluate(a, b, joinedType, combinator); } static Tensor evaluate(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { // Choose join algorithm if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType, combinator); else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size()) return singleSpaceJoin(a, b, joinedType, combinator); else if (a.type().dimensions().containsAll(b.type().dimensions())) return subspaceJoin(b, a, joinedType, true, combinator); else if (b.type().dimensions().containsAll(a.type().dimensions())) return subspaceJoin(a, b, joinedType, false, combinator); else return generalJoin(a, b, joinedType, combinator); } private static boolean hasSingleIndexedDimension(Tensor tensor) { return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed(); } private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) { int joinedRank = (int)Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator aIterator = a.valueIterator(); Iterator bIterator = b.valueIterator(); IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build()); for (int i = 0; i < joinedRank; i++) builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i); return builder.build(); } /** When both tensors have the same dimensions, at most one cell matches a cell in the other tensor */ private static Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator i = a.cellIterator(); i.hasNext(); ) { Map.Entry aCell = i.next(); var key = aCell.getKey(); Double bVal = b.getAsDouble(key); if (bVal != null) { builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal)); } } return builder.build(); } /** Join a tensor into a superspace */ private static Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) { if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder, combinator); else return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder, combinator); } private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) { if (subspace.isEmpty() || superspace.isEmpty()) // special case empty here to avoid doing it when finding sizes return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build(); DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace); IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes); // Find dimensions which are only in the supertype Set superDimensionNames = new HashSet<>(superspace.type().dimensionNames()); superDimensionNames.removeAll(subspace.type().dimensionNames()); for (Iterator i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) { IndexedTensor.SubspaceIterator subspaceInSuper = i.next(); joinSubspaces(subspace.valueIterator(), subspace.size(), subspaceInSuper, subspaceInSuper.size(), reversedArgumentOrder, builder, combinator); } return builder.build(); } private static void joinSubspaces(Iterator subspace, long subspaceSize, Iterator superspace, long superspaceSize, boolean reversedArgumentOrder, IndexedTensor.Builder builder, DoubleBinaryOperator combinator) { int joinedLength = (int)Math.min(subspaceSize, superspaceSize); if (reversedArgumentOrder) { for (int i = 0; i < joinedLength; i++) { Tensor.Cell supercell = superspace.next(); builder.cell(supercell, combinator.applyAsDouble(supercell.getValue(), subspace.next())); } } else { for (int i = 0; i < joinedLength; i++) { Tensor.Cell supercell = superspace.next(); builder.cell(supercell, combinator.applyAsDouble(subspace.next(), supercell.getValue())); } } } private static DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) { DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size()); for (int i = 0; i < builder.dimensions(); i++) { String dimensionName = joinedType.dimensions().get(i).name(); Optional aIndex = a.type().indexOfDimension(dimensionName); Optional bIndex = b.type().indexOfDimension(dimensionName); if (aIndex.isPresent() && bIndex.isPresent()) builder.set(i, Math.min(b.dimensionSizes().size(bIndex.get()), a.dimensionSizes().size(aIndex.get()))); else if (aIndex.isPresent()) builder.set(i, a.dimensionSizes().size(aIndex.get())); else if (bIndex.isPresent()) builder.set(i, b.dimensionSizes().size(bIndex.get())); } return builder.build(); } private static Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) { int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type()); Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator i = superspace.cellIterator(); i.hasNext(); ) { Map.Entry supercell = i.next(); TensorAddress subaddress = supercell.getKey().partialCopy(subspaceIndexes); Double subspaceValue = subspace.getAsDouble(subaddress); if (subspaceValue != null) { builder.cell(supercell.getKey(), reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) : combinator.applyAsDouble(subspaceValue, supercell.getValue())); } } return builder.build(); } /** Returns the indexes in the superspace type which should be retained to create the subspace type */ private static int[] subspaceIndexes(TensorType supertype, TensorType subtype) { int[] subspaceIndexes = new int[subtype.dimensions().size()]; for (int i = 0; i < subtype.dimensions().size(); i++) subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get(); return subspaceIndexes; } /** Slow join which works for any two tensors */ private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { if (a instanceof IndexedTensor && b instanceof IndexedTensor) return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator); else return mappedHashJoin(a, b, joinedType, combinator); } private static Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType, DoubleBinaryOperator combinator) { DimensionSizes joinedSize = joinedSize(joinedType, a, b); Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize); int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, builder, combinator); return builder.build(); } private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder, DoubleBinaryOperator combinator) { Set sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames())); int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection Set dimensionsOnlyInA = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames())); DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize); DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize); // for each combination of dimensions only in a for (Iterator ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) { IndexedTensor.SubspaceIterator aSubspace = ia.next(); // for each combination of dimensions in a which is also in b while (aSubspace.hasNext()) { Tensor.Cell aCell = aSubspace.next(); PartialAddress matchingBCells = sharedDimensionSize > 0 ? partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize) : empty; // for each matching combination of dimensions ony in b for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) { Tensor.Cell bCell = bSubspace.next(); TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType); double joinedValue = combinator.applyAsDouble(aCell.getValue(), bCell.getValue()); builder.cell(joinedAddress, joinedValue); } } } } private static PartialAddress empty = new PartialAddress.Builder(0).build(); private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set retainDimensions, int sharedDimensionSize) { PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize); for (int i = 0; i < addressType.dimensions().size(); i++) { String dimension = addressType.dimensions().get(i).name(); if (retainDimensions.contains(dimension)) builder.add(dimension, address.numericLabel(i)); } return builder.build(); } /** Returns the sizes from the joined sizes which are present in the type argument */ private static DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) { DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); int dimensionIndex = 0; for (int i = 0; i < joinedType.dimensions().size(); i++) { if (type.dimensionNames().contains(joinedType.dimensions().get(i).name())) builder.set(dimensionIndex++, joinedSizes.size(i)); } return builder.build(); } private static Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator aIterator = a.cellIterator(); aIterator.hasNext(); ) { Map.Entry aCell = aIterator.next(); for (Iterator bIterator = b.cellIterator(); bIterator.hasNext(); ) { Map.Entry bCell = bIterator.next(); TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType); if (combinedAddress == null) continue; // not combinable builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue())); } } return builder.build(); } private static Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { TensorType commonDimensionType = commonDimensions(a, b); if (commonDimensionType.dimensions().isEmpty()) { return mappedGeneralJoin(a, b, joinedType, combinator); // fallback } boolean swapTensors = a.size() > b.size(); if (swapTensors) { Tensor temp = a; a = b; b = temp; } // Map dimension indexes to common and joined type int[] aIndexesInCommon = mapIndexes(commonDimensionType, a.type()); int[] bIndexesInCommon = mapIndexes(commonDimensionType, b.type()); int[] aIndexesInJoined = mapIndexes(a.type(), joinedType); int[] bIndexesInJoined = mapIndexes(b.type(), joinedType); // Iterate once through the smaller tensor and construct a hash map for common dimensions Map> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt()); for (Iterator cellIterator = a.cellIterator(); cellIterator.hasNext(); ) { Tensor.Cell aCell = cellIterator.next(); TensorAddress partialCommonAddress = aCell.getKey().partialCopy(aIndexesInCommon); aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell); } // Iterate once through the larger tensor and use the hash map to find joinable cells Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator cellIterator = b.cellIterator(); cellIterator.hasNext(); ) { Tensor.Cell bCell = cellIterator.next(); TensorAddress partialCommonAddress = bCell.getKey().partialCopy(bIndexesInCommon); for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, List.of())) { TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined, bCell.getKey(), bIndexesInJoined, joinedType); if (combinedAddress == null) continue; // not combinable double combinedValue = swapTensors ? combinator.applyAsDouble(bCell.getValue(), aCell.getValue()) : combinator.applyAsDouble(aCell.getValue(), bCell.getValue()); builder.cell(combinedAddress, combinedValue); } } return builder.build(); } /** * Returns an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ static int[] mapIndexes(TensorType fromType, TensorType toType) { int[] toIndexes = new int[fromType.dimensions().size()]; for (int i = 0; i < fromType.dimensions().size(); i++) toIndexes[i] = toType.indexOfDimensionAsInt(fromType.dimensions().get(i).name()); return toIndexes; } private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType joinedType) { int[] joinedLabels = new int[joinedType.dimensions().size()]; Arrays.fill(joinedLabels, Tensor.INVALID_INDEX); mapContent(a, joinedLabels, aToIndexes); boolean compatible = mapContent(b, joinedLabels, bToIndexes); if ( ! compatible) return null; return TensorAddressAny.ofUnsafe(joinedLabels); } /** * Maps the content in the given list to the given array, using the given index map. * * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ private static boolean mapContent(TensorAddress from, int[] to, int[] indexMap) { for (int i = 0, sz = from.size(); i < sz; i++) { int toIndex = indexMap[i]; int label = Convert.safe2Int(from.numericLabel(i)); if (to[toIndex] != Tensor.INVALID_INDEX && to[toIndex] != label) return false; to[toIndex] = label; } return true; } /** Returns common dimension of a and b as a new tensor type */ private static TensorType commonDimensions(Tensor a, Tensor b) { TensorType aType = a.type(); TensorType bType = b.type(); TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.combinedValueType(aType, bType)); for (int i = 0; i < aType.dimensions().size(); ++i) { TensorType.Dimension aDim = aType.dimensions().get(i); for (int j = 0; j < bType.dimensions().size(); ++j) { TensorType.Dimension bDim = bType.dimensions().get(j); if (aDim.equals(bDim)) { typeBuilder.set(bDim); } } } return typeBuilder.build(); } }