summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-04-23 19:02:59 +0200
committerGitHub <noreply@github.com>2018-04-23 19:02:59 +0200
commit25753e098fc7b8ad4dab12d344ac9b4e276f5d2a (patch)
tree78394bac50b26d89c38d7fa9a9d1a084662d0602
parent0cbf66a3e9c024df5e53a1ca14056dfd43c68a74 (diff)
parentefd856722c86009c86860f2c5c6aa9f7a0b3152b (diff)
Merge pull request #5678 from vespa-engine/lesters/tensorflow-small-constants
Use file distribution for constant tensors of size 1
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java2
2 files changed, 3 insertions, 4 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 0b6b1ec9617..623f26a6b27 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -8,7 +8,6 @@ import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -335,7 +334,7 @@ public class RankingExpressionWithTensorFlowTestCase {
"input",
application);
search.assertFirstPhaseExpression(expression, "my_profile");
- assertSmallConstant("mnist_saved_dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search);
+ assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
@@ -353,7 +352,7 @@ public class RankingExpressionWithTensorFlowTestCase {
"input",
storedApplication);
searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
- assertSmallConstant("mnist_saved_dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search);
+ assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
searchFromStored.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
searchFromStored.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 64c777dbfca..4ec23f98fc5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -281,7 +281,7 @@ public class TensorFlowImporter {
operation.setConstantValue(new TensorValue(tensor));
}
- if (tensor.type().rank() == 0 || tensor.size() <= 1) {
+ if (tensor.type().rank() == 0) {
model.smallConstant(name, tensor);
} else {
model.largeConstant(name, tensor);