From a55b45e2bb1442b94480b585ab2b973c15d4be36 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Fri, 2 Feb 2018 11:16:33 +0100 Subject: Replace / in Tensorflow constants and placeholders to _ --- .../integration/tensorflow/OperationMapper.java | 23 ++++++++++++++-------- .../integration/tensorflow/TensorFlowImporter.java | 2 +- .../integration/tensorflow/TensorFlowModel.java | 2 +- .../tensorflow/DropoutImportTestCase.java | 6 +++--- .../tensorflow/TestableTensorFlowModel.java | 2 +- 5 files changed, 21 insertions(+), 14 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index cdcb4df0360..b8f8e288257 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -204,20 +204,21 @@ class OperationMapper { private static Optional placeholder(TensorFlowImporter.Parameters params) { String name = params.node().getName(); + String vespaName = toVespaName(params.node().getName()); TensorType type = params.result().arguments().get(name); if (type == null) { throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + "', but there is no such placeholder"); } - params.result().requiredMacro(name, type); + params.result().requiredMacro(vespaName, type); // Included literally in the expression and so must be produced by a separate macro in the rank profile - TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name, type)); + TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(vespaName, type)); return Optional.of(output); } private static Optional placeholderWithDefault(TensorFlowImporter.Parameters params) { - String name = params.node().getInput(0); - Tensor defaultValue = getConstantTensor(params, name); + String name = toVespaName(params.node().getInput(0)); + Tensor defaultValue = getConstantTensor(params, params.node().getInput(0)); params.result().constant(name, defaultValue); params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")"))); // The default value will be provided by the macro. Users can override macro to change value. @@ -542,16 +543,18 @@ class OperationMapper { } private static Optional createConstant(TensorFlowImporter.Parameters params, Tensor constant) { - params.result().constant(params.node().getName(), constant); + String name = toVespaName(params.node().getName()); + params.result().constant(name, constant); TypedTensorFunction output = new TypedTensorFunction(constant.type(), new TensorFunctionNode.TensorFunctionExpressionNode( - new ReferenceNode("constant(\"" + params.node().getName() + "\")"))); + new ReferenceNode("constant(\"" + name + "\")"))); return Optional.of(output); } private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) { - if (params.result().constants().containsKey(name)) { - return params.result().constants().get(name); + String vespaName = toVespaName(name); + if (params.result().constants().containsKey(vespaName)) { + return params.result().constants().get(vespaName); } Session.Runner fetched = params.model().session().runner().fetch(name); List> importedTensors = fetched.run(); @@ -692,4 +695,8 @@ class OperationMapper { return true; } + public static String toVespaName(String name) { + return name != null ? name.replace('/', '_') : null; + } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 6d78b501fdc..c97ee2b1514 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -120,7 +120,7 @@ public class TensorFlowImporter { // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output // will be used. We parse the TensorFunction here to convert it to a RankingExpression tree params.result().expression(nodeName, - new RankingExpression(params.node().getName(), function.get().function().toString())); + new RankingExpression(nodeName, function.get().function().toString())); return function; } catch (ParseException e) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java index fe725e50a3f..530f4793b62 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java @@ -37,7 +37,7 @@ public class TensorFlowModel { /** Returns the given signature. If it does not already exist it is added to this. */ Signature signature(String name) { - return signatures.computeIfAbsent(name, n -> new Signature(n)); + return signatures.computeIfAbsent(name, Signature::new); } /** Returns an immutable map of the arguments ("Placeholders") of this */ diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index 445ccf231a7..3b25bfe1b1e 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -20,8 +20,8 @@ public class DropoutImportTestCase { // Check (provided) macros assertEquals(1, model.get().macros().size()); - assertTrue(model.get().macros().containsKey("training/input")); - assertEquals("constant(\"training/input\")", model.get().macros().get("training/input").getRoot().toString()); + assertTrue(model.get().macros().containsKey("training_input")); + assertEquals("constant(\"training_input\")", model.get().macros().get("training_input").getRoot().toString()); // Check required macros assertEquals(1, model.get().requiredMacros().size()); @@ -37,7 +37,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/BiasAdd", output.getName()); - assertEquals("join(rename(reduce(join(X, rename(constant(\"outputs/kernel\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"outputs/bias\"), d0, d1), f(a,b)(a + b))", + assertEquals("join(rename(reduce(join(X, rename(constant(\"outputs_kernel\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"outputs_bias\"), d0, d1), f(a,b)(a + b))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index 127b63c66c9..2c621fd2e92 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -52,7 +52,7 @@ public class TestableTensorFlowModel { runner.feed(inputName, placeholder); List> results = runner.fetch(operationName).run(); assertEquals(1, results.size()); - return new TensorConverter().toVespaTensor(results.get(0)); + return TensorConverter.toVespaTensor(results.get(0)); } private Context contextFrom(TensorFlowModel result) { -- cgit v1.2.3