summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java97
1 files changed, 50 insertions, 47 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index be323313369..62ee471fcf4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -82,25 +82,29 @@ public class Join extends PrimitiveTensorFunction {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
+ return evaluate(a, b, joinedType, combinator);
+ }
+ static Tensor evaluate(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
// Choose join algorithm
if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
- return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType);
+ return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType, combinator);
else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size())
- return singleSpaceJoin(a, b, joinedType);
+ return singleSpaceJoin(a, b, joinedType, combinator);
else if (a.type().dimensions().containsAll(b.type().dimensions()))
- return subspaceJoin(b, a, joinedType, true);
+ return subspaceJoin(b, a, joinedType, true, combinator);
else if (b.type().dimensions().containsAll(a.type().dimensions()))
- return subspaceJoin(a, b, joinedType, false);
+ return subspaceJoin(a, b, joinedType, false, combinator);
else
- return generalJoin(a, b, joinedType);
+ return generalJoin(a, b, joinedType, combinator);
+
}
- private boolean hasSingleIndexedDimension(Tensor tensor) {
+ private static boolean hasSingleIndexedDimension(Tensor tensor) {
return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
}
- private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
+ private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) {
long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
@@ -111,7 +115,7 @@ public class Join extends PrimitiveTensorFunction {
}
/** When both tensors have the same dimensions, at most one cell matches a cell in the other tensor */
- private Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType) {
+ private static Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> aCell = i.next();
@@ -123,14 +127,14 @@ public class Join extends PrimitiveTensorFunction {
}
/** Join a tensor into a superspace */
- private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
+ private static Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
- return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder);
+ return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder, combinator);
else
- return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder);
+ return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder, combinator);
}
- private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
+ private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
@@ -145,16 +149,17 @@ public class Join extends PrimitiveTensorFunction {
for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
joinSubspaces(subspace.valueIterator(), subspace.size(),
- subspaceInSuper, subspaceInSuper.size(),
- reversedArgumentOrder, builder);
+ subspaceInSuper, subspaceInSuper.size(),
+ reversedArgumentOrder, builder, combinator);
}
return builder.build();
}
- private void joinSubspaces(Iterator<Double> subspace, long subspaceSize,
- Iterator<Tensor.Cell> superspace, long superspaceSize,
- boolean reversedArgumentOrder, IndexedTensor.Builder builder) {
+ private static void joinSubspaces(Iterator<Double> subspace, long subspaceSize,
+ Iterator<Tensor.Cell> superspace, long superspaceSize,
+ boolean reversedArgumentOrder, IndexedTensor.Builder builder,
+ DoubleBinaryOperator combinator) {
long joinedLength = Math.min(subspaceSize, superspaceSize);
if (reversedArgumentOrder) {
for (int i = 0; i < joinedLength; i++) {
@@ -169,7 +174,7 @@ public class Join extends PrimitiveTensorFunction {
}
}
- private DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) {
+ private static DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) {
DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size());
for (int i = 0; i < builder.dimensions(); i++) {
String dimensionName = joinedType.dimensions().get(i).name();
@@ -185,7 +190,7 @@ public class Join extends PrimitiveTensorFunction {
return builder.build();
}
- private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
+ private static Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) {
@@ -194,21 +199,21 @@ public class Join extends PrimitiveTensorFunction {
double subspaceValue = subspace.get(subaddress);
if ( ! Double.isNaN(subspaceValue))
builder.cell(supercell.getKey(),
- reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
- : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
+ reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
+ : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
}
return builder.build();
}
/** Returns the indexes in the superspace type which should be retained to create the subspace type */
- private int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
+ private static int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
int[] subspaceIndexes = new int[subtype.dimensions().size()];
for (int i = 0; i < subtype.dimensions().size(); i++)
subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
return subspaceIndexes;
}
- private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
+ private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
String[] subspaceLabels = new String[subspaceIndexes.length];
for (int i = 0; i < subspaceIndexes.length; i++)
subspaceLabels[i] = superAddress.label(subspaceIndexes[i]);
@@ -216,25 +221,25 @@ public class Join extends PrimitiveTensorFunction {
}
/** Slow join which works for any two tensors */
- private Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType) {
+ private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
if (a instanceof IndexedTensor && b instanceof IndexedTensor)
- return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType);
+ return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator);
else
- return mappedHashJoin(a, b, joinedType);
+ return mappedHashJoin(a, b, joinedType, combinator);
}
- private Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
+ private static Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
DimensionSizes joinedSize = joinedSize(joinedType, a, b);
Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize);
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
- joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder);
-// joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder);
+ joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, builder, combinator);
return builder.build();
}
- private void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
- int[] aToIndexes, int[] bToIndexes, boolean reversedOrder, Tensor.Builder builder) {
+ private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
+ int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder,
+ DoubleBinaryOperator combinator) {
Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames());
Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames());
@@ -252,15 +257,14 @@ public class Join extends PrimitiveTensorFunction {
for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) {
Tensor.Cell bCell = bSubspace.next();
TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType);
- double joinedValue = reversedOrder ? combinator.applyAsDouble(bCell.getValue(), aCell.getValue())
- : combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
+ double joinedValue = combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
builder.cell(joinedAddress, joinedValue);
}
}
}
}
- private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
+ private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
for (int i = 0; i < addressType.dimensions().size(); i++)
if (retainDimensions.contains(addressType.dimensions().get(i).name()))
@@ -269,7 +273,7 @@ public class Join extends PrimitiveTensorFunction {
}
/** Returns the sizes from the joined sizes which are present in the type argument */
- private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
+ private static DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
int dimensionIndex = 0;
for (int i = 0; i < joinedType.dimensions().size(); i++) {
@@ -279,7 +283,7 @@ public class Join extends PrimitiveTensorFunction {
return builder.build();
}
- private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) {
+ private static Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
Tensor.Builder builder = Tensor.Builder.of(joinedType);
@@ -288,7 +292,7 @@ public class Join extends PrimitiveTensorFunction {
for (Iterator<Tensor.Cell> bIterator = b.cellIterator(); bIterator.hasNext(); ) {
Map.Entry<TensorAddress, Double> bCell = bIterator.next();
TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes,
- bCell.getKey(), bToIndexes, joinedType);
+ bCell.getKey(), bToIndexes, joinedType);
if (combinedAddress == null) continue; // not combinable
builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue()));
}
@@ -296,10 +300,10 @@ public class Join extends PrimitiveTensorFunction {
return builder.build();
}
- private Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType) {
+ private static Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
TensorType commonDimensionType = commonDimensions(a, b);
if (commonDimensionType.dimensions().isEmpty()) {
- return mappedGeneralJoin(a, b, joinedType); // fallback
+ return mappedGeneralJoin(a, b, joinedType, combinator); // fallback
}
boolean swapTensors = a.size() > b.size();
@@ -351,15 +355,15 @@ public class Join extends PrimitiveTensorFunction {
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
- private int[] mapIndexes(TensorType fromType, TensorType toType) {
+ static int[] mapIndexes(TensorType fromType, TensorType toType) {
int[] toIndexes = new int[fromType.dimensions().size()];
for (int i = 0; i < fromType.dimensions().size(); i++)
toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
return toIndexes;
}
- private TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
- TensorType joinedType) {
+ private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
+ TensorType joinedType) {
String[] joinedLabels = new String[joinedType.dimensions().size()];
mapContent(a, joinedLabels, aToIndexes);
boolean compatible = mapContent(b, joinedLabels, bToIndexes);
@@ -373,7 +377,7 @@ public class Join extends PrimitiveTensorFunction {
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
- private boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
+ private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
for (int i = 0; i < from.size(); i++) {
int toIndex = indexMap[i];
if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false;
@@ -382,11 +386,10 @@ public class Join extends PrimitiveTensorFunction {
return true;
}
-
/**
* Returns common dimension of a and b as a new tensor type
*/
- private TensorType commonDimensions(Tensor a, Tensor b) {
+ private static TensorType commonDimensions(Tensor a, Tensor b) {
TensorType.Builder typeBuilder = new TensorType.Builder();
TensorType aType = a.type();
TensorType bType = b.type();
@@ -402,14 +405,14 @@ public class Join extends PrimitiveTensorFunction {
return typeBuilder.build();
}
- private TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) {
+ private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) {
TensorAddress address = cell.getKey();
String[] labels = new String[indexMap.length];
for (int i = 0; i < labels.length; ++i) {
labels[i] = address.label(indexMap[i]);
}
return TensorAddress.of(labels);
-
}
}
+