summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java83
1 files changed, 14 insertions, 69 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 6128611302f..19b4ad39af3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -2,20 +2,18 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
-import java.util.Arrays;
-import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
-import java.util.Optional;
-import java.util.Set;
import java.util.function.DoubleBinaryOperator;
/**
@@ -71,7 +69,7 @@ public class Join extends PrimitiveTensorFunction {
TensorType joinedType = a.type().combineWith(b.type());
// Choose join algorithm
- if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
+ if (a.type().equals(b.type()) && a.type().dimensions().size() == 1 && a.type().dimensions().get(0).isIndexed())
return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType);
else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size())
return singleSpaceJoin(a, b, joinedType);
@@ -83,12 +81,8 @@ public class Join extends PrimitiveTensorFunction {
return generalJoin(a, b, joinedType);
}
- private boolean hasSingleIndexedDimension(Tensor tensor) {
- return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
- }
-
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
- int joinedLength = Math.min(a.size(0), b.size(0));
+ int joinedLength = Math.min(a.length(0), b.length(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new int[] { joinedLength});
@@ -111,42 +105,6 @@ public class Join extends PrimitiveTensorFunction {
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
- if (subspace.type().isIndexed() && superspace.type().isIndexed())
- return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder);
- else
- return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder);
- }
-
- private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
- if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
- return Tensor.Builder.of(joinedType, new int[joinedType.dimensions().size()]).build();
-
- // Find size of joined tensor
- int[] joinedSizes = new int[joinedType.dimensions().size()];
- for (int i = 0; i < joinedSizes.length; i++) {
- Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name());
- if (subspaceIndex.isPresent())
- joinedSizes[i] = Math.min(superspace.size(i), subspace.size(subspaceIndex.get()));
- else
- joinedSizes[i] = superspace.size(i);
- }
-
- Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSizes);
-
- // Find dimensions which are only in the supertype
- Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames());
- superDimensionNames.removeAll(subspace.type().dimensionNames());
-
- for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
- IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
- joinSubspaces(subspace.valueIterator(), subspace.size(),
- subspaceInSuper, subspaceInSuper.size(),
- reversedArgumentOrder, builder);
- }
- return builder.build();
- }
-
- private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Map.Entry<TensorAddress, Double>> i = superspace.cellIterator(); i.hasNext(); ) {
@@ -154,26 +112,13 @@ public class Join extends PrimitiveTensorFunction {
TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes);
double subspaceValue = subspace.get(subaddress);
if ( ! Double.isNaN(subspaceValue))
- builder.cell(supercell.getKey(),
+ builder.cell(supercell.getKey(),
reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
: combinator.applyAsDouble(subspaceValue, supercell.getValue()));
}
return builder.build();
}
-
- private void joinSubspaces(Iterator<Double> subspace, int subspaceSize,
- Iterator<Map.Entry<TensorAddress, Double>> superspace, int superspaceSize,
- boolean reversedArgumentOrder, Tensor.Builder builder) {
- int joinedLength = Math.min(subspaceSize, superspaceSize);
- for (int i = 0; i < joinedLength; i++) {
- Double subvalue = subspace.next();
- Map.Entry<TensorAddress, Double> supercell = superspace.next();
- builder.cell(supercell.getKey(),
- reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subvalue)
- : combinator.applyAsDouble(subvalue, supercell.getValue()));
- }
- }
-
+
/** Returns the indexes in the superspace type which should be retained to create the subspace type */
private int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
int[] subspaceIndexes = new int[subtype.dimensions().size()];
@@ -185,8 +130,8 @@ public class Join extends PrimitiveTensorFunction {
private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
String[] subspaceLabels = new String[subspaceIndexes.length];
for (int i = 0; i < subspaceIndexes.length; i++)
- subspaceLabels[i] = superAddress.label(subspaceIndexes[i]);
- return TensorAddress.of(subspaceLabels);
+ subspaceLabels[i] = superAddress.labels().get(subspaceIndexes[i]);
+ return new TensorAddress(subspaceLabels);
}
/** Slow join which works for any two tensors */
@@ -224,10 +169,10 @@ public class Join extends PrimitiveTensorFunction {
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
TensorType joinedType) {
String[] joinedLabels = new String[joinedType.dimensions().size()];
- mapContent(a, joinedLabels, aToIndexes);
- boolean compatible = mapContent(b, joinedLabels, bToIndexes);
+ mapContent(a.labels(), joinedLabels, aToIndexes);
+ boolean compatible = mapContent(b.labels(), joinedLabels, bToIndexes);
if ( ! compatible) return null;
- return TensorAddress.of(joinedLabels);
+ return new TensorAddress(joinedLabels);
}
/**
@@ -236,11 +181,11 @@ public class Join extends PrimitiveTensorFunction {
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
- private boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
+ private boolean mapContent(List<String> from, String[] to, int[] indexMap) {
for (int i = 0; i < from.size(); i++) {
int toIndex = indexMap[i];
- if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false;
- to[toIndex] = from.label(i);
+ if (to[toIndex] != null && ! to[toIndex].equals(from.get(i))) return false;
+ to[toIndex] = from.get(i);
}
return true;
}