From 1d2eb6ff8ce51018e05b79a021d10383d1e85e16 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Mon, 23 Apr 2018 15:19:41 +0200 Subject: Use file distribuition for constant tensors of size 1 --- .../processing/RankingExpressionWithTensorFlowTestCase.java | 7 +++++-- .../integration/tensorflow/TensorFlowImporter.java | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 0b6b1ec9617..b5f64796a83 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.test.MockApplicationPackage; @@ -10,6 +11,8 @@ import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.derived.AttributeFields; +import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; @@ -335,7 +338,7 @@ public class RankingExpressionWithTensorFlowTestCase { "input", application); search.assertFirstPhaseExpression(expression, "my_profile"); - assertSmallConstant("mnist_saved_dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search); + assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); @@ -353,7 +356,7 @@ public class RankingExpressionWithTensorFlowTestCase { "input", storedApplication); searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); - assertSmallConstant("mnist_saved_dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search); + assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); searchFromStored.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); searchFromStored.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 64c777dbfca..4ec23f98fc5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -281,7 +281,7 @@ public class TensorFlowImporter { operation.setConstantValue(new TensorValue(tensor)); } - if (tensor.type().rank() == 0 || tensor.size() <= 1) { + if (tensor.type().rank() == 0) { model.smallConstant(name, tensor); } else { model.largeConstant(name, tensor); -- cgit v1.2.3 From efd856722c86009c86860f2c5c6aa9f7a0b3152b Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Mon, 23 Apr 2018 15:21:44 +0200 Subject: Remove unused imports --- .../processing/RankingExpressionWithTensorFlowTestCase.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index b5f64796a83..623f26a6b27 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; -import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.test.MockApplicationPackage; @@ -9,10 +8,7 @@ import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; -import com.yahoo.searchdefinition.derived.AttributeFields; -import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; -- cgit v1.2.3