diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-04-23 19:02:59 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-23 19:02:59 +0200 |
commit | 25753e098fc7b8ad4dab12d344ac9b4e276f5d2a (patch) | |
tree | 78394bac50b26d89c38d7fa9a9d1a084662d0602 | |
parent | 0cbf66a3e9c024df5e53a1ca14056dfd43c68a74 (diff) | |
parent | efd856722c86009c86860f2c5c6aa9f7a0b3152b (diff) |
Merge pull request #5678 from vespa-engine/lesters/tensorflow-small-constants
Use file distribution for constant tensors of size 1
2 files changed, 3 insertions, 4 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 0b6b1ec9617..623f26a6b27 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -8,7 +8,6 @@ import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -335,7 +334,7 @@ public class RankingExpressionWithTensorFlowTestCase { "input", application); search.assertFirstPhaseExpression(expression, "my_profile"); - assertSmallConstant("mnist_saved_dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search); + assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); @@ -353,7 +352,7 @@ public class RankingExpressionWithTensorFlowTestCase { "input", storedApplication); searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); - assertSmallConstant("mnist_saved_dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search); + assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); searchFromStored.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); searchFromStored.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 64c777dbfca..4ec23f98fc5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -281,7 +281,7 @@ public class TensorFlowImporter { operation.setConstantValue(new TensorValue(tensor)); } - if (tensor.type().rank() == 0 || tensor.size() <= 1) { + if (tensor.type().rank() == 0) { model.smallConstant(name, tensor); } else { model.largeConstant(name, tensor); |