diff options
author | Lester Solbakken <lesters@oath.com> | 2018-04-23 15:19:41 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-04-23 15:19:41 +0200 |
commit | 1d2eb6ff8ce51018e05b79a021d10383d1e85e16 (patch) | |
tree | 73c42f7b8632b2e0e130cc6e0c37555b889a3e0d | |
parent | 9c4c0fb0d7b6d37beab393e842833d5ffec2d524 (diff) |
Use file distribuition for constant tensors of size 1
2 files changed, 6 insertions, 3 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 0b6b1ec9617..b5f64796a83 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.test.MockApplicationPackage; @@ -10,6 +11,8 @@ import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.derived.AttributeFields; +import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; @@ -335,7 +338,7 @@ public class RankingExpressionWithTensorFlowTestCase { "input", application); search.assertFirstPhaseExpression(expression, "my_profile"); - assertSmallConstant("mnist_saved_dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search); + assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); @@ -353,7 +356,7 @@ public class RankingExpressionWithTensorFlowTestCase { "input", storedApplication); searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); - assertSmallConstant("mnist_saved_dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search); + assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); searchFromStored.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); searchFromStored.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 64c777dbfca..4ec23f98fc5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -281,7 +281,7 @@ public class TensorFlowImporter { operation.setConstantValue(new TensorValue(tensor)); } - if (tensor.type().rank() == 0 || tensor.size() <= 1) { + if (tensor.type().rank() == 0) { model.smallConstant(name, tensor); } else { model.largeConstant(name, tensor); |