summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-04-23 15:19:41 +0200
committerLester Solbakken <lesters@oath.com>2018-04-23 15:19:41 +0200
commit1d2eb6ff8ce51018e05b79a021d10383d1e85e16 (patch)
tree73c42f7b8632b2e0e130cc6e0c37555b889a3e0d
parent9c4c0fb0d7b6d37beab393e842833d5ffec2d524 (diff)
Use file distribuition for constant tensors of size 1
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java2
2 files changed, 6 insertions, 3 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 0b6b1ec9617..b5f64796a83 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.processing;
+import com.yahoo.collections.Pair;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.test.MockApplicationPackage;
@@ -10,6 +11,8 @@ import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchdefinition.derived.AttributeFields;
+import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
@@ -335,7 +338,7 @@ public class RankingExpressionWithTensorFlowTestCase {
"input",
application);
search.assertFirstPhaseExpression(expression, "my_profile");
- assertSmallConstant("mnist_saved_dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search);
+ assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
@@ -353,7 +356,7 @@ public class RankingExpressionWithTensorFlowTestCase {
"input",
storedApplication);
searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
- assertSmallConstant("mnist_saved_dnn_hidden2_Const", TensorType.fromSpec("tensor(d2[1])"), search);
+ assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
searchFromStored.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
searchFromStored.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 64c777dbfca..4ec23f98fc5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -281,7 +281,7 @@ public class TensorFlowImporter {
operation.setConstantValue(new TensorValue(tensor));
}
- if (tensor.type().rank() == 0 || tensor.size() <= 1) {
+ if (tensor.type().rank() == 0) {
model.smallConstant(name, tensor);
} else {
model.largeConstant(name, tensor);