summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-02-02 11:16:33 +0100
committerLester Solbakken <lesters@oath.com>2018-02-02 11:16:33 +0100
commita55b45e2bb1442b94480b585ab2b973c15d4be36 (patch)
treed8efd7092895096be3a043654bc7885b82180eba /searchlib
parent93f30cd3912913ffbdc292d32415474466e39bd2 (diff)
Replace / in Tensorflow constants and placeholders to _
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java23
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java2
5 files changed, 21 insertions, 14 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index cdcb4df0360..b8f8e288257 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -204,20 +204,21 @@ class OperationMapper {
private static Optional<TypedTensorFunction> placeholder(TensorFlowImporter.Parameters params) {
String name = params.node().getName();
+ String vespaName = toVespaName(params.node().getName());
TensorType type = params.result().arguments().get(name);
if (type == null) {
throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name +
"', but there is no such placeholder");
}
- params.result().requiredMacro(name, type);
+ params.result().requiredMacro(vespaName, type);
// Included literally in the expression and so must be produced by a separate macro in the rank profile
- TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name, type));
+ TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(vespaName, type));
return Optional.of(output);
}
private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) {
- String name = params.node().getInput(0);
- Tensor defaultValue = getConstantTensor(params, name);
+ String name = toVespaName(params.node().getInput(0));
+ Tensor defaultValue = getConstantTensor(params, params.node().getInput(0));
params.result().constant(name, defaultValue);
params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")")));
// The default value will be provided by the macro. Users can override macro to change value.
@@ -542,16 +543,18 @@ class OperationMapper {
}
private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) {
- params.result().constant(params.node().getName(), constant);
+ String name = toVespaName(params.node().getName());
+ params.result().constant(name, constant);
TypedTensorFunction output = new TypedTensorFunction(constant.type(),
new TensorFunctionNode.TensorFunctionExpressionNode(
- new ReferenceNode("constant(\"" + params.node().getName() + "\")")));
+ new ReferenceNode("constant(\"" + name + "\")")));
return Optional.of(output);
}
private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) {
- if (params.result().constants().containsKey(name)) {
- return params.result().constants().get(name);
+ String vespaName = toVespaName(name);
+ if (params.result().constants().containsKey(vespaName)) {
+ return params.result().constants().get(vespaName);
}
Session.Runner fetched = params.model().session().runner().fetch(name);
List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
@@ -692,4 +695,8 @@ class OperationMapper {
return true;
}
+ public static String toVespaName(String name) {
+ return name != null ? name.replace('/', '_') : null;
+ }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 6d78b501fdc..c97ee2b1514 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -120,7 +120,7 @@ public class TensorFlowImporter {
// We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
// will be used. We parse the TensorFunction here to convert it to a RankingExpression tree
params.result().expression(nodeName,
- new RankingExpression(params.node().getName(), function.get().function().toString()));
+ new RankingExpression(nodeName, function.get().function().toString()));
return function;
}
catch (ParseException e) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
index fe725e50a3f..530f4793b62 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
@@ -37,7 +37,7 @@ public class TensorFlowModel {
/** Returns the given signature. If it does not already exist it is added to this. */
Signature signature(String name) {
- return signatures.computeIfAbsent(name, n -> new Signature(n));
+ return signatures.computeIfAbsent(name, Signature::new);
}
/** Returns an immutable map of the arguments ("Placeholders") of this */
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index 445ccf231a7..3b25bfe1b1e 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -20,8 +20,8 @@ public class DropoutImportTestCase {
// Check (provided) macros
assertEquals(1, model.get().macros().size());
- assertTrue(model.get().macros().containsKey("training/input"));
- assertEquals("constant(\"training/input\")", model.get().macros().get("training/input").getRoot().toString());
+ assertTrue(model.get().macros().containsKey("training_input"));
+ assertEquals("constant(\"training_input\")", model.get().macros().get("training_input").getRoot().toString());
// Check required macros
assertEquals(1, model.get().requiredMacros().size());
@@ -37,7 +37,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/BiasAdd", output.getName());
- assertEquals("join(rename(reduce(join(X, rename(constant(\"outputs/kernel\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"outputs/bias\"), d0, d1), f(a,b)(a + b))",
+ assertEquals("join(rename(reduce(join(X, rename(constant(\"outputs_kernel\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"outputs_bias\"), d0, d1), f(a,b)(a + b))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index 127b63c66c9..2c621fd2e92 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -52,7 +52,7 @@ public class TestableTensorFlowModel {
runner.feed(inputName, placeholder);
List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
assertEquals(1, results.size());
- return new TensorConverter().toVespaTensor(results.get(0));
+ return TensorConverter.toVespaTensor(results.get(0));
}
private Context contextFrom(TensorFlowModel result) {