path: root/searchlib/src/main
author	Lester Solbakken <lesters@oath.com>	2018-03-08 16:51:10 +0100
committer	Lester Solbakken <lesters@oath.com>	2018-03-08 17:01:07 +0100
commit	7ce1417a1a0c4cd3cd72d903f0c9ffde93baade8 (patch)
tree	c7b23f9bab1481f31b0eadeb7e4f34ef75939c2b /searchlib/src/main
parent	697f91c9753b25b5074be2a4d99c10987c5cac62 (diff)
Make TensorFlow import joins compatible with broadcasting
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--	searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java	| 116
1 file changed, 91 insertions, 25 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
index 3e6e036636d..ae30c2850bb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
@@ -4,9 +4,11 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.op
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
import org.tensorflow.framework.NodeDef;
+import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
@@ -25,10 +27,52 @@ public class Join extends TensorFlowOperation {
if (!allInputTypesPresent(2)) {
return null;
}
- OrderedTensorType a = inputs.get(0).type().get();
- OrderedTensorType b = inputs.get(1).type().get();
- OrderedTensorType out = a.type().rank() >= b.type().rank() ? a : b;
- return out;
+ OrderedTensorType a = largestInput().type().get();
+ OrderedTensorType b = smallestInput().type().get();
+
+ // Well now we have potentially entered the wonderful world of "broadcasting"
+ // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ // In broadcasting, the size of each dimension is compared element-wise,
+ // starting with the trailing dimensions and working forward. A special
+ // case occurs when the size of one dimension is 1, while the other is not.
+ // Then the dimension with size 1 is "stretched" to be of compatible size.
+ //
+ // An example:
+ //
+ // Tensor A: d0[5], d1[1], d2[3], d3[1]
+ // Tensor B: d1[4], d2[1], d3[2]
+ //
+ // In TensorFlow and using the above rules of broadcasting, the resulting
+ // type is:
+ //   d0[5], d1[4], d2[3], d3[2]
+ //
+ // However, in Vespa's tensor logic, the join of the two above tensors would
+ // result in a tensor of type:
+ //   d0[5], d1[1], d2[1], d3[1]
+ //
+ // By reducing (summing over) the dimensions of size 1 in each tensor before joining,
+ // we get the same result as in TensorFlow; a sum over a size-1 dimension leaves values unchanged.
+
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
+ int sizeDifference = a.rank() - b.rank();
+ for (int i = 0; i < a.rank(); ++i) {
+ TensorType.Dimension aDim = a.dimensions().get(i);
+ long size = aDim.size().orElse(-1L);
+
+ if (i - sizeDifference >= 0) {
+ TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference);
+ size = Math.max(size, bDim.size().orElse(-1L));
+ }
+
+ if (aDim.type() == TensorType.Dimension.Type.indexedBound) {
+ builder.add(TensorType.Dimension.indexed(aDim.name(), size));
+ } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) {
+ builder.add(TensorType.Dimension.indexed(aDim.name()));
+ } else if (aDim.type() == TensorType.Dimension.Type.mapped) {
+ builder.add(TensorType.Dimension.mapped(aDim.name()));
+ }
+ }
+ return builder.build();
}
@Override
@@ -36,15 +80,39 @@ public class Join extends TensorFlowOperation {
if (!allInputTypesPresent(2)) {
return null;
}
- Optional<TensorFunction> aFunction = inputs.get(0).function();
- Optional<TensorFunction> bFunction = inputs.get(1).function();
- if (!aFunction.isPresent() || !bFunction.isPresent()) {
+ if (!allInputFunctionsPresent(2)) {
return null;
}
- // The dimension renaming below takes care of broadcasting.
+ TensorFlowOperation a = largestInput();
+ TensorFlowOperation b = smallestInput();
+
+ List<String> aDimensionsToReduce = new ArrayList<>();
+ List<String> bDimensionsToReduce = new ArrayList<>();
+ int sizeDifference = a.type().get().rank() - b.type().get().rank();
+ for (int i = 0; i < b.type().get().rank(); ++i) {
+ TensorType.Dimension bDim = b.type().get().dimensions().get(i);
+ TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference);
+ long bSize = bDim.size().orElse(-1L);
+ long aSize = aDim.size().orElse(-1L);
+ if (bSize == 1L && aSize != 1L) {
+ bDimensionsToReduce.add(bDim.name());
+ }
+ if (aSize == 1L && bSize != 1L) {
+ aDimensionsToReduce.add(aDim.name());
+ }
+ }
+
+ TensorFunction aReducedFunction = a.function().get();
+ if (aDimensionsToReduce.size() > 0) {
+ aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce);
+ }
+ TensorFunction bReducedFunction = b.function().get();
+ if (bDimensionsToReduce.size() > 0) {
+ bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
+ }
- return new com.yahoo.tensor.functions.Join(aFunction.get(), bFunction.get(), operator);
+ return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
}
@Override
@@ -52,22 +120,8 @@ public class Join extends TensorFlowOperation {
if (!allInputTypesPresent(2)) {
return;
}
-
- // Well now we have potentially entered the wonderful world of "broadcasting"
- // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
- // I'm not able to extract from that any unambiguous specification of which dimensions
- // should be "stretched" when the tensor do not have the same number of dimensions.
- // From trying this with TensorFlow it appears that the second tensor is matched to the
- // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true.
- // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first).
-
- OrderedTensorType a = inputs.get(0).type().get();
- OrderedTensorType b = inputs.get(1).type().get();
- if (a.rank() < b.rank()) {
- OrderedTensorType temp = a;
- a = b;
- b = temp;
- }
+ OrderedTensorType a = largestInput().type().get();
+ OrderedTensorType b = smallestInput().type().get();
int sizeDifference = a.rank() - b.rank();
for (int i = 0; i < b.rank(); ++i) {
String bDim = b.dimensions().get(i).name();
@@ -76,4 +130,16 @@ public class Join extends TensorFlowOperation {
}
}
+ private TensorFlowOperation largestInput() {
+ OrderedTensorType a = inputs.get(0).type().get();
+ OrderedTensorType b = inputs.get(1).type().get();
+ return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
+ }
+
+ private TensorFlowOperation smallestInput() {
+ OrderedTensorType a = inputs.get(0).type().get();
+ OrderedTensorType b = inputs.get(1).type().get();
+ return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
+ }
+
}
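
For readers who want to try out the broadcasting rule described in the comment above outside of Vespa, here is a minimal, self-contained Java sketch. It is not part of this commit and does not use the Vespa tensor API; tensor types are reduced to plain arrays of dimension sizes, and all class and variable names are illustrative. It mirrors the two pieces of logic in the diff: the trailing-dimension alignment in type() and the reduce-before-join decision in function().

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Illustrative sketch of numpy/TensorFlow-style broadcasting; not Vespa code. */
public class BroadcastSketch {

    public static void main(String[] args) {
        long[] a = {5, 1, 3, 1};   // tensor A: d0[5], d1[1], d2[3], d3[1]
        long[] b = {4, 1, 2};      // tensor B:        d1[4], d2[1], d3[2]

        // Output type: element-wise max over trailing-aligned dimensions,
        // as computed in type() above.
        int sizeDifference = a.length - b.length;
        long[] result = new long[a.length];
        for (int i = 0; i < a.length; ++i) {
            long size = a[i];
            if (i - sizeDifference >= 0) {
                size = Math.max(size, b[i - sizeDifference]);   // size-1 dims are "stretched"
            }
            result[i] = size;
        }
        System.out.println(Arrays.toString(result));   // [5, 4, 3, 2]

        // Dimensions of size 1 that must be summed away before a Vespa join,
        // as decided in function() above. Indices refer to each tensor's own
        // dimension order.
        List<Integer> aToReduce = new ArrayList<>();
        List<Integer> bToReduce = new ArrayList<>();
        for (int i = 0; i < b.length; ++i) {
            long aSize = a[i + sizeDifference];
            long bSize = b[i];
            if (bSize == 1L && aSize != 1L) bToReduce.add(i);
            if (aSize == 1L && bSize != 1L) aToReduce.add(i + sizeDifference);
        }
        System.out.println(aToReduce);   // [1, 3] : A sums away d1[1] and d3[1]
        System.out.println(bToReduce);   // [1]    : B sums away d2[1]
    }
}

Joining the two reduced tensors in Vespa then yields the type d0[5], d1[4], d2[3], d3[2], matching what TensorFlow produces for the same inputs under broadcasting.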