aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java57
1 files changed, 25 insertions, 32 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index e0ac549651c..047d8ee6ef0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -12,9 +12,11 @@ import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.impl.StringTensorAddress;
+import com.yahoo.tensor.impl.Convert;
+import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@@ -206,7 +208,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> supercell = i.next();
- TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes);
+ TensorAddress subaddress = supercell.getKey().partialCopy(subspaceIndexes);
Double subspaceValue = subspace.getAsDouble(subaddress);
if (subspaceValue != null) {
builder.cell(supercell.getKey(),
@@ -226,13 +228,6 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
return subspaceIndexes;
}
- private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
- String[] subspaceLabels = new String[subspaceIndexes.length];
- for (int i = 0; i < subspaceIndexes.length; i++)
- subspaceLabels[i] = superAddress.label(subspaceIndexes[i]);
- return StringTensorAddress.unsafeOf(subspaceLabels);
- }
-
/** Slow join which works for any two tensors */
private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
if (a instanceof IndexedTensor && b instanceof IndexedTensor)
@@ -253,9 +248,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder,
DoubleBinaryOperator combinator) {
- Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames());
+ Set<String> sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()));
int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection
- Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames());
+ Set<String> dimensionsOnlyInA = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()));
DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize);
DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);
@@ -266,7 +261,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
// for each combination of dimensions in a which is also in b
while (aSubspace.hasNext()) {
Tensor.Cell aCell = aSubspace.next();
- PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize);
+ PartialAddress matchingBCells = sharedDimensionSize > 0
+ ? partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize)
+ : empty;
// for each matching combination of dimensions ony in b
for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) {
Tensor.Cell bCell = bSubspace.next();
@@ -278,12 +275,15 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
}
+ private static PartialAddress empty = new PartialAddress.Builder(0).build();
private static PartialAddress partialAddress(TensorType addressType, TensorAddress address,
Set<String> retainDimensions, int sharedDimensionSize) {
PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize);
- for (int i = 0; i < addressType.dimensions().size(); i++)
- if (retainDimensions.contains(addressType.dimensions().get(i).name()))
- builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i));
+ for (int i = 0; i < addressType.dimensions().size(); i++) {
+ String dimension = addressType.dimensions().get(i).name();
+ if (retainDimensions.contains(dimension))
+ builder.add(dimension, address.numericLabel(i));
+ }
return builder.build();
}
@@ -338,7 +338,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt());
for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) {
Tensor.Cell aCell = cellIterator.next();
- TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon);
+ TensorAddress partialCommonAddress = aCell.getKey().partialCopy(aIndexesInCommon);
aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell);
}
@@ -346,7 +346,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> cellIterator = b.cellIterator(); cellIterator.hasNext(); ) {
Tensor.Cell bCell = cellIterator.next();
- TensorAddress partialCommonAddress = partialCommonAddress(bCell, bIndexesInCommon);
+ TensorAddress partialCommonAddress = bCell.getKey().partialCopy(bIndexesInCommon);
for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, List.of())) {
TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined,
bCell.getKey(), bIndexesInJoined, joinedType);
@@ -377,11 +377,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
TensorType joinedType) {
- String[] joinedLabels = new String[joinedType.dimensions().size()];
+ int[] joinedLabels = new int[joinedType.dimensions().size()];
+ Arrays.fill(joinedLabels, Tensor.INVALID_INDEX);
mapContent(a, joinedLabels, aToIndexes);
boolean compatible = mapContent(b, joinedLabels, bToIndexes);
if ( ! compatible) return null;
- return StringTensorAddress.unsafeOf(joinedLabels);
+ return TensorAddressAny.ofUnsafe(joinedLabels);
}
/**
@@ -390,11 +391,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
- private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
- for (int i = 0; i < from.size(); i++) {
+ private static boolean mapContent(TensorAddress from, int[] to, int[] indexMap) {
+ for (int i = 0, sz = from.size(); i < sz; i++) {
int toIndex = indexMap[i];
- String label = from.label(i);
- if (to[toIndex] != null && ! to[toIndex].equals(label)) return false;
+ int label = Convert.safe2Int(from.numericLabel(i));
+ if (to[toIndex] != Tensor.INVALID_INDEX && to[toIndex] != label)
+ return false;
to[toIndex] = label;
}
return true;
@@ -417,14 +419,5 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
return typeBuilder.build();
}
- private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) {
- TensorAddress address = cell.getKey();
- String[] labels = new String[indexMap.length];
- for (int i = 0; i < labels.length; ++i) {
- labels[i] = address.label(indexMap[i]);
- }
- return StringTensorAddress.unsafeOf(labels);
- }
-
}