summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-02-05 19:06:08 +0100
committerLester Solbakken <lesters@oath.com>2018-02-05 19:06:08 +0100
commit69ba083a5edb91609bfb4ac14c88160fb064add9 (patch)
tree8a98b7b2871abf13cc7c34deb789accd54961417
parent0d7be904836d79d3658daed98d3cbb1339c4aeb7 (diff)
Store small constants separately
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java15
1 files changed, 11 insertions, 4 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index b8f8e288257..55782c36d18 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -219,7 +219,7 @@ class OperationMapper {
private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) {
String name = toVespaName(params.node().getInput(0));
Tensor defaultValue = getConstantTensor(params, params.node().getInput(0));
- params.result().constant(name, defaultValue);
+ params.result().largeConstant(name, defaultValue);
params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")")));
// The default value will be provided by the macro. Users can override macro to change value.
TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name));
@@ -544,7 +544,11 @@ class OperationMapper {
private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) {
String name = toVespaName(params.node().getName());
- params.result().constant(name, constant);
+ if (constant.type().rank() == 0 || constant.size() <= 1) {
+ params.result().smallConstant(name, constant);
+ } else {
+ params.result().largeConstant(name, constant);
+ }
TypedTensorFunction output = new TypedTensorFunction(constant.type(),
new TensorFunctionNode.TensorFunctionExpressionNode(
new ReferenceNode("constant(\"" + name + "\")")));
@@ -553,8 +557,11 @@ class OperationMapper {
private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) {
String vespaName = toVespaName(name);
- if (params.result().constants().containsKey(vespaName)) {
- return params.result().constants().get(vespaName);
+ if (params.result().smallConstants().containsKey(vespaName)) {
+ return params.result().smallConstants().get(vespaName);
+ }
+ if (params.result().largeConstants().containsKey(vespaName)) {
+ return params.result().largeConstants().get(vespaName);
}
Session.Runner fetched = params.model().session().runner().fetch(name);
List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();