From 625a97da76c7056f6bacdd6a9e7c9dda4282623f Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 13 Mar 2018 14:34:54 +0100 Subject: Prefix macro names by modelName to avoid name conflicts --- .../tensorflow/importer/OperationMapper.java | 66 +++++++++++----------- .../tensorflow/importer/operations/Const.java | 7 +-- .../tensorflow/importer/operations/ExpandDims.java | 4 +- .../tensorflow/importer/operations/Identity.java | 7 +-- .../tensorflow/importer/operations/Join.java | 4 +- .../tensorflow/importer/operations/Map.java | 4 +- .../tensorflow/importer/operations/Matmul.java | 4 +- .../tensorflow/importer/operations/Mean.java | 4 +- .../tensorflow/importer/operations/Merge.java | 4 +- .../tensorflow/importer/operations/NoOp.java | 4 +- .../importer/operations/Placeholder.java | 4 +- .../operations/PlaceholderWithDefault.java | 4 +- .../tensorflow/importer/operations/Reshape.java | 4 +- .../tensorflow/importer/operations/Select.java | 4 +- .../tensorflow/importer/operations/Shape.java | 4 +- .../tensorflow/importer/operations/Squeeze.java | 4 +- .../tensorflow/importer/operations/Switch.java | 4 +- .../importer/operations/TensorFlowOperation.java | 9 ++- .../tensorflow/importer/operations/Variable.java | 7 +-- .../tensorflow/DropoutImportTestCase.java | 2 +- 20 files changed, 75 insertions(+), 79 deletions(-) (limited to 'searchlib/src') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java index 1b87c302835..d5a3d2d69a3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java @@ -36,44 +36,44 @@ public class OperationMapper { switch (node.getOp().toLowerCase()) { // array ops case "const": return new Const(modelName, node, inputs, port); - case "expanddims": return new ExpandDims(node, inputs, port); + case "expanddims": return new ExpandDims(modelName, node, inputs, port); case "identity": return new Identity(modelName, node, inputs, port); - case "placeholder": return new Placeholder(node, inputs, port); - case "placeholderwithdefault": return new PlaceholderWithDefault(node, inputs, port); - case "reshape": return new Reshape(node, inputs, port); - case "shape": return new Shape(node, inputs, port); - case "squeeze": return new Squeeze(node, inputs, port); + case "placeholder": return new Placeholder(modelName, node, inputs, port); + case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, node, inputs, port); + case "reshape": return new Reshape(modelName, node, inputs, port); + case "shape": return new Shape(modelName, node, inputs, port); + case "squeeze": return new Squeeze(modelName, node, inputs, port); // control flow - case "merge": return new Merge(node, inputs, port); - case "switch": return new Switch(node, inputs, port); + case "merge": return new Merge(modelName, node, inputs, port); + case "switch": return new Switch(modelName, node, inputs, port); // math ops - case "add": return new Join(node, inputs, port, ScalarFunctions.add()); - case "add_n": return new Join(node, inputs, port, ScalarFunctions.add()); - case "acos": return new Map(node, inputs, port, ScalarFunctions.acos()); - case "div": return new Join(node, inputs, port, ScalarFunctions.divide()); - case "realdiv": return new Join(node, inputs, port, ScalarFunctions.divide()); - case "floor": return new Map(node, inputs, port, ScalarFunctions.floor()); - case "matmul": return new Matmul(node, inputs, port); - case "maximum": return new Join(node, inputs, port, ScalarFunctions.max()); - case "mean": return new Mean(node, inputs, port); - case "reducemean": return new Mean(node, inputs, port); - case "mul": return new Join(node, inputs, port, ScalarFunctions.multiply()); - case "multiply": return new Join(node, inputs, port, ScalarFunctions.multiply()); - case "rsqrt": return new Map(node, inputs, port, ScalarFunctions.rsqrt()); - case "select": return new Select(node, inputs, port); - case "where3": return new Select(node, inputs, port); - case "sigmoid": return new Map(node, inputs, port, ScalarFunctions.sigmoid()); - case "squareddifference": return new Join(node, inputs, port, ScalarFunctions.squareddifference()); - case "sub": return new Join(node, inputs, port, ScalarFunctions.subtract()); - case "subtract": return new Join(node, inputs, port, ScalarFunctions.subtract()); + case "add": return new Join(modelName, node, inputs, port, ScalarFunctions.add()); + case "add_n": return new Join(modelName, node, inputs, port, ScalarFunctions.add()); + case "acos": return new Map(modelName, node, inputs, port, ScalarFunctions.acos()); + case "div": return new Join(modelName, node, inputs, port, ScalarFunctions.divide()); + case "realdiv": return new Join(modelName, node, inputs, port, ScalarFunctions.divide()); + case "floor": return new Map(modelName, node, inputs, port, ScalarFunctions.floor()); + case "matmul": return new Matmul(modelName, node, inputs, port); + case "maximum": return new Join(modelName, node, inputs, port, ScalarFunctions.max()); + case "mean": return new Mean(modelName, node, inputs, port); + case "reducemean": return new Mean(modelName, node, inputs, port); + case "mul": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply()); + case "multiply": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply()); + case "rsqrt": return new Map(modelName, node, inputs, port, ScalarFunctions.rsqrt()); + case "select": return new Select(modelName, node, inputs, port); + case "where3": return new Select(modelName, node, inputs, port); + case "sigmoid": return new Map(modelName, node, inputs, port, ScalarFunctions.sigmoid()); + case "squareddifference": return new Join(modelName, node, inputs, port, ScalarFunctions.squareddifference()); + case "sub": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract()); + case "subtract": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract()); // nn ops - case "biasadd": return new Join(node, inputs, port, ScalarFunctions.add()); - case "elu": return new Map(node, inputs, port, ScalarFunctions.elu()); - case "relu": return new Map(node, inputs, port, ScalarFunctions.relu()); - case "selu": return new Map(node, inputs, port, ScalarFunctions.selu()); + case "biasadd": return new Join(modelName, node, inputs, port, ScalarFunctions.add()); + case "elu": return new Map(modelName, node, inputs, port, ScalarFunctions.elu()); + case "relu": return new Map(modelName, node, inputs, port, ScalarFunctions.relu()); + case "selu": return new Map(modelName, node, inputs, port, ScalarFunctions.selu()); // state ops case "variable": return new Variable(modelName, node, inputs, port); @@ -81,9 +81,9 @@ public class OperationMapper { // evaluation no-ops case "stopgradient":return new Identity(modelName, node, inputs, port); - case "noop": return new NoOp(node, inputs, port); + case "noop": return new NoOp(modelName, node, inputs, port); } - return new NoOp(node, inputs, port); + return new NoOp(modelName, node, inputs, port); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java index eb2a82fe114..718e2a4b3c2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java @@ -23,11 +23,8 @@ import java.util.Optional; public class Const extends TensorFlowOperation { - private final String modelName; - public Const(String modelName, NodeDef node, List inputs, int port) { - super(node, inputs, port); - this.modelName = modelName; + super(modelName, node, inputs, port); setConstantValue(value()); } @@ -58,7 +55,7 @@ public class Const extends TensorFlowOperation { /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { - return modelName + "_" + super.vespaName(); + return modelName() + "_" + super.vespaName(); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java index c1ad21f41d8..2d0f4c7042b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java @@ -22,8 +22,8 @@ public class ExpandDims extends TensorFlowOperation { private List expandDimensions; - public ExpandDims(NodeDef node, List inputs, int port) { - super(node, inputs, port); + public ExpandDims(String modelName, NodeDef node, List inputs, int port) { + super(modelName, node, inputs, port); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java index 306232bb9ff..1408e7e04f0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java @@ -9,17 +9,14 @@ import java.util.List; public class Identity extends TensorFlowOperation { - private final String modelName; - public Identity(String modelName, NodeDef node, List inputs, int port) { - super(node, inputs, port); - this.modelName = modelName; + super(modelName, node, inputs, port); } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { - return modelName + "_" + super.vespaName(); + return modelName() + "_" + super.vespaName(); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java index 0f9833567c7..6cbfe0dfb05 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java @@ -17,8 +17,8 @@ public class Join extends TensorFlowOperation { private final DoubleBinaryOperator operator; - public Join(NodeDef node, List inputs, int port, DoubleBinaryOperator operator) { - super(node, inputs, port); + public Join(String modelName, NodeDef node, List inputs, int port, DoubleBinaryOperator operator) { + super(modelName, node, inputs, port); this.operator = operator; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java index 105d65b3d69..c015f5ecba8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java @@ -13,8 +13,8 @@ public class Map extends TensorFlowOperation { private final DoubleUnaryOperator operator; - public Map(NodeDef node, List inputs, int port, DoubleUnaryOperator operator) { - super(node, inputs, port); + public Map(String modelName, NodeDef node, List inputs, int port, DoubleUnaryOperator operator) { + super(modelName, node, inputs, port); this.operator = operator; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java index ac4f78653d6..b2b9530a161 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java @@ -12,8 +12,8 @@ import java.util.Optional; public class Matmul extends TensorFlowOperation { - public Matmul(NodeDef node, List inputs, int port) { - super(node, inputs, port); + public Matmul(String modelName, NodeDef node, List inputs, int port) { + super(modelName, node, inputs, port); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java index 83f9b37e631..3eba872c6a0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java @@ -25,8 +25,8 @@ public class Mean extends TensorFlowOperation { private List reduceDimensions; - public Mean(NodeDef node, List inputs, int port) { - super(node, inputs, port); + public Mean(String modelName, NodeDef node, List inputs, int port) { + super(modelName, node, inputs, port); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java index d3561716725..4c95e67e184 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java @@ -9,8 +9,8 @@ import java.util.List; public class Merge extends TensorFlowOperation { - public Merge(NodeDef node, List inputs, int port) { - super(node, inputs, port); + public Merge(String modelName, NodeDef node, List inputs, int port) { + super(modelName, node, inputs, port); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java index acf5d13b057..d558ec89e87 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java @@ -10,8 +10,8 @@ import java.util.List; public class NoOp extends TensorFlowOperation { - public NoOp(NodeDef node, List inputs, int port) { - super(node, Collections.emptyList(), port); // don't propagate inputs + public NoOp(String modelName, NodeDef node, List inputs, int port) { + super(modelName, node, Collections.emptyList(), port); // don't propagate inputs } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java index dadce395faf..eb4b615b434 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java @@ -15,8 +15,8 @@ public class Placeholder extends TensorFlowOperation { private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... - public Placeholder(NodeDef node, List inputs, int port) { - super(node, inputs, port); + public Placeholder(String modelName, NodeDef node, List inputs, int port) { + super(modelName, node, inputs, port); standardNamingType = OrderedTensorType.fromTensorFlowType(node); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java index 4e5709505ce..f74d1d6cb75 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java @@ -11,8 +11,8 @@ import java.util.Optional; public class PlaceholderWithDefault extends TensorFlowOperation { - public PlaceholderWithDefault(NodeDef node, List inputs, int port) { - super(node, inputs, port); + public PlaceholderWithDefault(String modelName, NodeDef node, List inputs, int port) { + super(modelName, node, inputs, port); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java index 9b3e28ce56b..e7d90e5fc1f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java @@ -30,8 +30,8 @@ import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.impor public class Reshape extends TensorFlowOperation { - public Reshape(NodeDef node, List inputs, int port) { - super(node, inputs, port); + public Reshape(String modelName, NodeDef node, List inputs, int port) { + super(modelName, node, inputs, port); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java index 6a29d428cf3..5fdcb5a695f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java @@ -17,8 +17,8 @@ import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.impor public class Select extends TensorFlowOperation { - public Select(NodeDef node, List inputs, int port) { - super(node, inputs, port); + public Select(String modelName, NodeDef node, List inputs, int port) { + super(modelName, node, inputs, port); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java index 8f4313022e0..af49d2c108b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java @@ -13,8 +13,8 @@ import java.util.List; public class Shape extends TensorFlowOperation { - public Shape(NodeDef node, List inputs, int port) { - super(node, inputs, port); + public Shape(String modelName, NodeDef node, List inputs, int port) { + super(modelName, node, inputs, port); createConstantValue(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java index d7750b52fc3..17ce9e8b7cb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java @@ -19,8 +19,8 @@ public class Squeeze extends TensorFlowOperation { private List squeezeDimensions; - public Squeeze(NodeDef node, List inputs, int port) { - super(node, inputs, port); + public Squeeze(String modelName, NodeDef node, List inputs, int port) { + super(modelName, node, inputs, port); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java index 1cc0e1936eb..de4d8862fd6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java @@ -10,8 +10,8 @@ import java.util.Optional; public class Switch extends TensorFlowOperation { - public Switch(NodeDef node, List inputs, int port) { - super(node, inputs, port); + public Switch(String modelName, NodeDef node, List inputs, int port) { + super(modelName, node, inputs, port); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java index 2533148e5be..a0a3c71145b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java @@ -32,6 +32,8 @@ public abstract class TensorFlowOperation { protected final static String MACRO_PREFIX = "tf_macro_"; + private final String modelName; + protected final NodeDef node; protected final int port; protected final List inputs; @@ -45,13 +47,16 @@ public abstract class TensorFlowOperation { private Value constantValue = null; private List controlInputs = Collections.emptyList(); - TensorFlowOperation(NodeDef node, List inputs, int port) { + TensorFlowOperation(String modelName, NodeDef node, List inputs, int port) { + this.modelName = modelName; this.node = node; this.port = port; this.inputs = Collections.unmodifiableList(inputs); this.inputs.forEach(i -> i.outputs.add(this)); } + protected String modelName() { return modelName; } + protected abstract OrderedTensorType lazyGetType(); protected abstract TensorFunction lazyGetFunction(); @@ -122,7 +127,7 @@ public abstract class TensorFlowOperation { public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; } /** Retrieve the valid Vespa name of this node if it is a macro */ - public String macroName() { return vespaName() != null ? MACRO_PREFIX + vespaName() : null; } + public String macroName() { return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null; } /** Retrieve the list of warnings produced during its lifetime */ public List warnings() { return Collections.unmodifiableList(importWarnings); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java index 7aefac6217c..b18a8a9b212 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java @@ -11,17 +11,14 @@ import java.util.List; public class Variable extends TensorFlowOperation { - private final String modelName; - public Variable(String modelName, NodeDef node, List inputs, int port) { - super(node, inputs, port); - this.modelName = modelName; + super(modelName, node, inputs, port); } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { - return modelName + "_" + super.vespaName(); + return modelName() + "_" + super.vespaName(); } @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index b4cd2f11b0e..50a467ec581 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -32,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/Maximum", output.getName()); - assertEquals("join(join(tf_macro_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_outputs_BiasAdd, f(a,b)(max(a,b)))", + assertEquals("join(join(tf_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } -- cgit v1.2.3