summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-03-07 13:18:53 +0100
committerGitHub <noreply@github.com>2018-03-07 13:18:53 +0100
commitf509303b9d51b33382d3c2635cb6e0b1fef4b00a (patch)
treed2fe1a7480a1a8c0bd9d343fa6eb7a8a20b84780
parent7c1236b77b7e2264bc6199ac9d3ff974d81462d2 (diff)
parent8be35f5e89f6f9c0d94426497364e132f4fc42fe (diff)
Merge pull request #5236 from vespa-engine/lesters/fix-tf-dimension-renaming-at-join
Fix TensorFlow dimension renaming at join
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java6
3 files changed, 7 insertions, 5 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index c650151980c..8e404e72ec7 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -288,7 +288,7 @@ public class RankingExpressionWithTensorFlowTestCase {
"input",
application);
search.assertFirstPhaseExpression(expression, "my_profile");
- assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d0[1])"), search);
+ assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search);
search.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile");
search.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile");
@@ -306,7 +306,7 @@ public class RankingExpressionWithTensorFlowTestCase {
"input",
storedApplication);
searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
- assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d0[1])"), search);
+ assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search);
searchFromStored.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile");
searchFromStored.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
index 3742e443a06..db762d5ddb0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
@@ -42,6 +42,8 @@ public class OrderedTensorType {
return this.type;
}
+ public int rank() { return dimensions.size(); }
+
public List<TensorType.Dimension> dimensions() {
return dimensions;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
index aa27ba2684d..3e6e036636d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
@@ -61,10 +61,10 @@ public class Join extends TensorFlowOperation {
// "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true.
// Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first).
- TensorType a = inputs.get(0).type().get().type();
- TensorType b = inputs.get(1).type().get().type();
+ OrderedTensorType a = inputs.get(0).type().get();
+ OrderedTensorType b = inputs.get(1).type().get();
if (a.rank() < b.rank()) {
- TensorType temp = a;
+ OrderedTensorType temp = a;
a = b;
b = temp;
}