diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-03-07 13:18:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-03-07 13:18:53 +0100 |
commit | f509303b9d51b33382d3c2635cb6e0b1fef4b00a (patch) | |
tree | d2fe1a7480a1a8c0bd9d343fa6eb7a8a20b84780 | |
parent | 7c1236b77b7e2264bc6199ac9d3ff974d81462d2 (diff) | |
parent | 8be35f5e89f6f9c0d94426497364e132f4fc42fe (diff) |
Merge pull request #5236 from vespa-engine/lesters/fix-tf-dimension-renaming-at-join
Fix TensorFlow dimension renaming at join
3 files changed, 7 insertions, 5 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index c650151980c..8e404e72ec7 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -288,7 +288,7 @@ public class RankingExpressionWithTensorFlowTestCase { "input", application); search.assertFirstPhaseExpression(expression, "my_profile"); - assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d0[1])"), search); + assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search); search.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile"); search.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile"); @@ -306,7 +306,7 @@ public class RankingExpressionWithTensorFlowTestCase { "input", storedApplication); searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); - assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d0[1])"), search); + assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search); searchFromStored.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile"); searchFromStored.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile"); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java index 3742e443a06..db762d5ddb0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java @@ -42,6 +42,8 @@ public class OrderedTensorType { return this.type; } + public int rank() { return dimensions.size(); } + public List<TensorType.Dimension> dimensions() { return dimensions; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java index aa27ba2684d..3e6e036636d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java @@ -61,10 +61,10 @@ public class Join extends TensorFlowOperation { // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). - TensorType a = inputs.get(0).type().get().type(); - TensorType b = inputs.get(1).type().get().type(); + OrderedTensorType a = inputs.get(0).type().get(); + OrderedTensorType b = inputs.get(1).type().get(); if (a.rank() < b.rank()) { - TensorType temp = a; + OrderedTensorType temp = a; a = b; b = temp; } |