summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-13 15:21:44 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-13 15:21:44 +0100
commit3783a9b21f8ab7ca3700903d9780a9f7374cf0c5 (patch)
treeec003528946a37b9f0aeb49e1b314fdc6601c26e /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
parent5b67e6f8f641141f848ad3989156151f9f182441 (diff)
Check agreement between TF and Vespa execution
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java32
1 files changed, 16 insertions, 16 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 9a37127e1f0..ff887e3e9a6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -28,12 +28,12 @@ import java.util.function.DoubleBinaryOperator;
* The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells
* given by the cross product of the cells of the given tensors, having as values the value produced by
* applying the given combinator function on the values from the two source cells.
- *
+ *
* @author bratseth
*/
@Beta
public class Join extends PrimitiveTensorFunction {
-
+
private final TensorFunction argumentA, argumentB;
private final DoubleBinaryOperator combinator;
@@ -56,7 +56,7 @@ public class Join extends PrimitiveTensorFunction {
if (aDim.name().equals(bDim.name())) { // include
if (aDim.isIndexed() && bDim.isIndexed()) {
if (aDim.size().isPresent() || bDim.size().isPresent())
- typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
+ typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
bDim.size().orElse(Integer.MAX_VALUE)));
else
typeBuilder.indexed(aDim.name());
@@ -112,11 +112,11 @@ public class Join extends PrimitiveTensorFunction {
else
return generalJoin(a, b, joinedType);
}
-
+
private boolean hasSingleIndexedDimension(Tensor tensor) {
return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
}
-
+
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
@@ -138,7 +138,7 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
@@ -150,7 +150,7 @@ public class Join extends PrimitiveTensorFunction {
private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
-
+
DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace);
IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes);
@@ -158,14 +158,14 @@ public class Join extends PrimitiveTensorFunction {
// Find dimensions which are only in the supertype
Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames());
superDimensionNames.removeAll(subspace.type().dimensionNames());
-
+
for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
joinSubspaces(subspace.valueIterator(), subspace.size(),
subspaceInSuper, subspaceInSuper.size(),
reversedArgumentOrder, builder);
}
-
+
return builder.build();
}
@@ -224,7 +224,7 @@ public class Join extends PrimitiveTensorFunction {
subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
return subspaceIndexes;
}
-
+
private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
String[] subspaceLabels = new String[subspaceIndexes.length];
for (int i = 0; i < subspaceIndexes.length; i++)
@@ -259,7 +259,7 @@ public class Join extends PrimitiveTensorFunction {
DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);
// for each combination of dimensions only in a
- for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
+ for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
IndexedTensor.SubspaceIterator aSubspace = ia.next();
// for each combination of dimensions in a which is also in b
while (aSubspace.hasNext()) {
@@ -276,7 +276,7 @@ public class Join extends PrimitiveTensorFunction {
}
}
}
-
+
private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
for (int i = 0; i < addressType.dimensions().size(); i++)
@@ -284,7 +284,7 @@ public class Join extends PrimitiveTensorFunction {
builder.add(addressType.dimensions().get(i).name(), address.intLabel(i));
return builder.build();
}
-
+
/** Returns the sizes from the joined sizes which are present in the type argument */
private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
@@ -295,7 +295,7 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
@@ -364,7 +364,7 @@ public class Join extends PrimitiveTensorFunction {
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
+ * That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
@@ -384,7 +384,7 @@ public class Join extends PrimitiveTensorFunction {
return TensorAddress.of(joinedLabels);
}
- /**
+ /**
* Maps the content in the given list to the given array, using the given index map.
*
* @return true if the mapping was successful, false if one of the destination positions was