diff options
author | Lester Solbakken <lesters@oath.com> | 2018-02-05 19:06:08 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-02-05 19:06:08 +0100 |
commit | 69ba083a5edb91609bfb4ac14c88160fb064add9 (patch) | |
tree | 8a98b7b2871abf13cc7c34deb789accd54961417 | |
parent | 0d7be904836d79d3658daed98d3cbb1339c4aeb7 (diff) |
Store small constants separately
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index b8f8e288257..55782c36d18 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -219,7 +219,7 @@ class OperationMapper { private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) { String name = toVespaName(params.node().getInput(0)); Tensor defaultValue = getConstantTensor(params, params.node().getInput(0)); - params.result().constant(name, defaultValue); + params.result().largeConstant(name, defaultValue); params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")"))); // The default value will be provided by the macro. Users can override macro to change value. TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name)); @@ -544,7 +544,11 @@ class OperationMapper { private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) { String name = toVespaName(params.node().getName()); - params.result().constant(name, constant); + if (constant.type().rank() == 0 || constant.size() <= 1) { + params.result().smallConstant(name, constant); + } else { + params.result().largeConstant(name, constant); + } TypedTensorFunction output = new TypedTensorFunction(constant.type(), new TensorFunctionNode.TensorFunctionExpressionNode( new ReferenceNode("constant(\"" + name + "\")"))); @@ -553,8 +557,11 @@ class OperationMapper { private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) { String vespaName = toVespaName(name); - if (params.result().constants().containsKey(vespaName)) { - return params.result().constants().get(vespaName); + if (params.result().smallConstants().containsKey(vespaName)) { + return params.result().smallConstants().get(vespaName); + } + if (params.result().largeConstants().containsKey(vespaName)) { + return params.result().largeConstants().get(vespaName); } Session.Runner fetched = params.model().session().runner().fetch(name); List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); |