diff options
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 1ded16636d3..7a336233de0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -114,7 +114,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) { - long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); + int joinedRank = (int)Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator<Double> aIterator = a.valueIterator(); Iterator<Double> bIterator = b.valueIterator(); IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build()); @@ -170,7 +170,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Iterator<Tensor.Cell> superspace, long superspaceSize, boolean reversedArgumentOrder, IndexedTensor.Builder builder, DoubleBinaryOperator combinator) { - long joinedLength = Math.min(subspaceSize, superspaceSize); + int joinedLength = (int)Math.min(subspaceSize, superspaceSize); if (reversedArgumentOrder) { for (int i = 0; i < joinedLength; i++) { Tensor.Cell supercell = superspace.next(); @@ -252,6 +252,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder, DoubleBinaryOperator combinator) { Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()); + int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()); DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize); @@ -263,7 +264,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP // for each combination of dimensions in a which is also in b while (aSubspace.hasNext()) { Tensor.Cell aCell = aSubspace.next(); - PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions); + PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize); // for each matching combination of dimensions ony in b for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) { Tensor.Cell bCell = bSubspace.next(); @@ -275,8 +276,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } } - private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { - PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); + private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, + Set<String> retainDimensions, int sharedDimensionSize) { + PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize); for (int i = 0; i < addressType.dimensions().size(); i++) if (retainDimensions.contains(addressType.dimensions().get(i).name())) builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i)); |