diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 57 |
1 files changed, 25 insertions, 32 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index e0ac549651c..047d8ee6ef0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -12,9 +12,11 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.impl.StringTensorAddress; +import com.yahoo.tensor.impl.Convert; +import com.yahoo.tensor.impl.TensorAddressAny; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -206,7 +208,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> supercell = i.next(); - TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); + TensorAddress subaddress = supercell.getKey().partialCopy(subspaceIndexes); Double subspaceValue = subspace.getAsDouble(subaddress); if (subspaceValue != null) { builder.cell(supercell.getKey(), @@ -226,13 +228,6 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP return subspaceIndexes; } - private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { - String[] subspaceLabels = new String[subspaceIndexes.length]; - for (int i = 0; i < subspaceIndexes.length; i++) - subspaceLabels[i] = superAddress.label(subspaceIndexes[i]); - return StringTensorAddress.unsafeOf(subspaceLabels); - } - /** Slow join which works for any two tensors */ private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { if (a instanceof IndexedTensor && b instanceof IndexedTensor) @@ -253,9 +248,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder, DoubleBinaryOperator combinator) { - Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()); + Set<String> sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames())); int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection - Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()); + Set<String> dimensionsOnlyInA = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames())); DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize); DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize); @@ -266,7 +261,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP // for each combination of dimensions in a which is also in b while (aSubspace.hasNext()) { Tensor.Cell aCell = aSubspace.next(); - PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize); + PartialAddress matchingBCells = sharedDimensionSize > 0 + ? partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize) + : empty; // for each matching combination of dimensions ony in b for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) { Tensor.Cell bCell = bSubspace.next(); @@ -278,12 +275,15 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } } + private static PartialAddress empty = new PartialAddress.Builder(0).build(); private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions, int sharedDimensionSize) { PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize); - for (int i = 0; i < addressType.dimensions().size(); i++) - if (retainDimensions.contains(addressType.dimensions().get(i).name())) - builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i)); + for (int i = 0; i < addressType.dimensions().size(); i++) { + String dimension = addressType.dimensions().get(i).name(); + if (retainDimensions.contains(dimension)) + builder.add(dimension, address.numericLabel(i)); + } return builder.build(); } @@ -338,7 +338,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt()); for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) { Tensor.Cell aCell = cellIterator.next(); - TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon); + TensorAddress partialCommonAddress = aCell.getKey().partialCopy(aIndexesInCommon); aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell); } @@ -346,7 +346,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> cellIterator = b.cellIterator(); cellIterator.hasNext(); ) { Tensor.Cell bCell = cellIterator.next(); - TensorAddress partialCommonAddress = partialCommonAddress(bCell, bIndexesInCommon); + TensorAddress partialCommonAddress = bCell.getKey().partialCopy(bIndexesInCommon); for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, List.of())) { TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined, bCell.getKey(), bIndexesInJoined, joinedType); @@ -377,11 +377,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType joinedType) { - String[] joinedLabels = new String[joinedType.dimensions().size()]; + int[] joinedLabels = new int[joinedType.dimensions().size()]; + Arrays.fill(joinedLabels, Tensor.INVALID_INDEX); mapContent(a, joinedLabels, aToIndexes); boolean compatible = mapContent(b, joinedLabels, bToIndexes); if ( ! compatible) return null; - return StringTensorAddress.unsafeOf(joinedLabels); + return TensorAddressAny.ofUnsafe(joinedLabels); } /** @@ -390,11 +391,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ - private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) { - for (int i = 0; i < from.size(); i++) { + private static boolean mapContent(TensorAddress from, int[] to, int[] indexMap) { + for (int i = 0, sz = from.size(); i < sz; i++) { int toIndex = indexMap[i]; - String label = from.label(i); - if (to[toIndex] != null && ! to[toIndex].equals(label)) return false; + int label = Convert.safe2Int(from.numericLabel(i)); + if (to[toIndex] != Tensor.INVALID_INDEX && to[toIndex] != label) + return false; to[toIndex] = label; } return true; @@ -417,14 +419,5 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP return typeBuilder.build(); } - private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) { - TensorAddress address = cell.getKey(); - String[] labels = new String[indexMap.length]; - for (int i = 0; i < labels.length; ++i) { - labels[i] = address.label(indexMap[i]); - } - return StringTensorAddress.unsafeOf(labels); - } - } |