From 0d7be904836d79d3658daed98d3cbb1339c4aeb7 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Mon, 5 Feb 2018 18:54:32 +0100 Subject: Support small constants --- .../integration/tensorflow/TensorFlowModel.java | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java index 530f4793b62..351aa417f9c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java @@ -24,13 +24,15 @@ public class TensorFlowModel { private final Map signatures = new HashMap<>(); private final Map arguments = new HashMap<>(); - private final Map constants = new HashMap<>(); + private final Map smallConstants = new HashMap<>(); + private final Map largeConstants = new HashMap<>(); private final Map expressions = new HashMap<>(); private final Map macros = new HashMap<>(); private final Map requiredMacros = new HashMap<>(); void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } - void constant(String name, Tensor constant) { constants.put(name, constant); } + void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } + void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } void macro(String name, RankingExpression expression) { macros.put(name, expression); } void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } @@ -43,8 +45,19 @@ public class TensorFlowModel { /** Returns an immutable map of the arguments ("Placeholders") of this */ public Map arguments() { return Collections.unmodifiableMap(arguments); } - /** Returns an immutable map of the constants of this */ - public Map constants() { return Collections.unmodifiableMap(constants); } + /** + * Returns an immutable map of the small constants of this. + * These should have sizes up to a few kb at most, and correspond to constant + * values given in the TensorFlow source. + */ + public Map smallConstants() { return Collections.unmodifiableMap(smallConstants); } + + /** + * Returns an immutable map of the large constants of this. + * These can have sizes in gigabytes and must be distributed to nodes separately from configuration, + * and correspond to Variable files stored separately in TensorFlow. + */ + public Map largeConstants() { return Collections.unmodifiableMap(largeConstants); } /** * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes -- cgit v1.2.3 From 69ba083a5edb91609bfb4ac14c88160fb064add9 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Mon, 5 Feb 2018 19:06:08 +0100 Subject: Store small constants separately --- .../integration/tensorflow/OperationMapper.java | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index b8f8e288257..55782c36d18 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -219,7 +219,7 @@ class OperationMapper { private static Optional placeholderWithDefault(TensorFlowImporter.Parameters params) { String name = toVespaName(params.node().getInput(0)); Tensor defaultValue = getConstantTensor(params, params.node().getInput(0)); - params.result().constant(name, defaultValue); + params.result().largeConstant(name, defaultValue); params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")"))); // The default value will be provided by the macro. Users can override macro to change value. TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name)); @@ -544,7 +544,11 @@ class OperationMapper { private static Optional createConstant(TensorFlowImporter.Parameters params, Tensor constant) { String name = toVespaName(params.node().getName()); - params.result().constant(name, constant); + if (constant.type().rank() == 0 || constant.size() <= 1) { + params.result().smallConstant(name, constant); + } else { + params.result().largeConstant(name, constant); + } TypedTensorFunction output = new TypedTensorFunction(constant.type(), new TensorFunctionNode.TensorFunctionExpressionNode( new ReferenceNode("constant(\"" + name + "\")"))); @@ -553,8 +557,11 @@ class OperationMapper { private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) { String vespaName = toVespaName(name); - if (params.result().constants().containsKey(vespaName)) { - return params.result().constants().get(vespaName); + if (params.result().smallConstants().containsKey(vespaName)) { + return params.result().smallConstants().get(vespaName); + } + if (params.result().largeConstants().containsKey(vespaName)) { + return params.result().largeConstants().get(vespaName); } Session.Runner fetched = params.model().session().runner().fetch(name); List> importedTensors = fetched.run(); -- cgit v1.2.3 From 62de95451cf663f3f43532d2c4746eaa1b678d95 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Mon, 5 Feb 2018 19:10:56 +0100 Subject: Fix unit tests for large/small tensor constants --- .../integration/tensorflow/MnistSoftmaxImportTestCase.java | 6 +++--- .../integration/tensorflow/TestableTensorFlowModel.java | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index 01dd15d5fa0..ad5abd4c03d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -20,15 +20,15 @@ public class MnistSoftmaxImportTestCase { TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved"); // Check constants - assertEquals(2, model.get().constants().size()); + assertEquals(2, model.get().largeConstants().size()); - Tensor constant0 = model.get().constants().get("Variable"); + Tensor constant0 = model.get().largeConstants().get("Variable"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.get().constants().get("Variable_1"); + Tensor constant1 = model.get().largeConstants().get("Variable_1"); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d0", 10).build(), constant1.type()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index 2c621fd2e92..ae7714b271a 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -57,7 +57,8 @@ public class TestableTensorFlowModel { private Context contextFrom(TensorFlowModel result) { MapContext context = new MapContext(); - result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); return context; } -- cgit v1.2.3