aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-03-13 14:07:01 +0100
committerGitHub <noreply@github.com>2018-03-13 14:07:01 +0100
commit4454b5db222b28d67962c872a77ce154021e7b20 (patch)
treebb7a3eb1029b2e4152b76497f31f0e17832b819d /searchlib/src
parent0f69157c9f65f0ee22787effc2532d692354bc29 (diff)
parent3f260a9420f3b395a2490df532f8fe883756b0fb (diff)
Merge pull request #5309 from vespa-engine/bratseth/disambiguate-constants
Prefix constants by model name
Diffstat (limited to 'searchlib/src')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java50
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java20
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java12
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java13
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java4
11 files changed, 95 insertions, 46 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 80a9262afeb..217eafd7446 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -49,25 +49,27 @@ public class TensorFlowImporter {
* The model should be saved as a .pbtxt or .pb file.
* The name of the model is taken as the db/pbtxt file name (not including the file ending).
*
+ * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
* @param modelDir the directory containing the TensorFlow model files to import
*/
- public TensorFlowModel importModel(String modelDir) {
+ public TensorFlowModel importModel(String modelName, String modelDir) {
try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
- return importModel(model);
+
+ return importModel(modelName, model);
}
catch (IllegalArgumentException e) {
throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
}
}
- public TensorFlowModel importModel(File modelDir) {
- return importModel(modelDir.toString());
+ public TensorFlowModel importModel(String modelName, File modelDir) {
+ return importModel(modelName, modelDir.toString());
}
/** Imports a TensorFlow model */
- public TensorFlowModel importModel(SavedModelBundle model) {
+ public TensorFlowModel importModel(String modelName, SavedModelBundle model) {
try {
- return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
+ return importGraph(modelName, MetaGraphDef.parseFrom(model.metaGraphDef()), model);
}
catch (IOException e) {
throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
@@ -79,8 +81,8 @@ public class TensorFlowImporter {
* finding a suitable set of dimensions names for each
* placeholder/constant/variable, then importing the expressions.
*/
- private static TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle bundle) {
- TensorFlowModel model = new TensorFlowModel();
+ private static TensorFlowModel importGraph(String modelName, MetaGraphDef graph, SavedModelBundle bundle) {
+ TensorFlowModel model = new TensorFlowModel(modelName);
OperationIndex index = new OperationIndex();
importSignatures(graph, model);
@@ -138,21 +140,21 @@ public class TensorFlowImporter {
private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) {
for (TensorFlowModel.Signature signature : model.signatures().values()) {
for (String outputName : signature.outputs().values()) {
- importNode(outputName, graph.getGraphDef(), index);
+ importNode(model.name(), outputName, graph.getGraphDef(), index);
}
}
}
- private static TensorFlowOperation importNode(String name, GraphDef graph, OperationIndex index) {
- if (index.alreadyImported(name)) {
- return index.get(name);
+ private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) {
+ if (index.alreadyImported(nodeName)) {
+ return index.get(nodeName);
}
- NodeDef node = getTensorFlowNodeFromGraph(namePartOf(name), graph);
- List<TensorFlowOperation> inputs = importNodeInputs(node, graph, index);
- TensorFlowOperation operation = OperationMapper.get(node, inputs, portPartOf(name));
- index.put(name, operation);
+ NodeDef node = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph);
+ List<TensorFlowOperation> inputs = importNodeInputs(modelName, node, graph, index);
+ TensorFlowOperation operation = OperationMapper.get(modelName, node, inputs, portPartOf(nodeName));
+ index.put(nodeName, operation);
- List<TensorFlowOperation> controlInputs = importControlInputs(node, graph, index);
+ List<TensorFlowOperation> controlInputs = importControlInputs(modelName, node, graph, index);
if (controlInputs.size() > 0) {
operation.setControlInputs(controlInputs);
}
@@ -160,17 +162,17 @@ public class TensorFlowImporter {
return operation;
}
- private static List<TensorFlowOperation> importNodeInputs(NodeDef node, GraphDef graph, OperationIndex index) {
+ private static List<TensorFlowOperation> importNodeInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
return node.getInputList().stream()
.filter(name -> ! isControlDependency(name))
- .map(name -> importNode(name, graph, index))
+ .map(nodeName -> importNode(modelName, nodeName, graph, index))
.collect(Collectors.toList());
}
- private static List<TensorFlowOperation> importControlInputs(NodeDef node, GraphDef graph, OperationIndex index) {
+ private static List<TensorFlowOperation> importControlInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
return node.getInputList().stream()
- .filter(name -> isControlDependency(name))
- .map(name -> importNode(name, graph, index))
+ .filter(nodeName -> isControlDependency(nodeName))
+ .map(nodeName -> importNode(modelName, nodeName, graph, index))
.collect(Collectors.toList());
}
@@ -280,9 +282,9 @@ public class TensorFlowImporter {
}
if (tensor.type().rank() == 0 || tensor.size() <= 1) {
- model.smallConstant(operation.vespaName(), tensor);
+ model.smallConstant(name, tensor);
} else {
- model.largeConstant(operation.vespaName(), tensor);
+ model.largeConstant(name, tensor);
}
return operation.function();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
index 351aa417f9c..721214f9e94 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
@@ -10,6 +10,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.regex.Pattern;
/**
* The result of importing a TensorFlow model into Vespa.
@@ -22,6 +23,25 @@ import java.util.Map;
// This object can be built incrementally within this package, but is immutable when observed from outside the package
public class TensorFlowModel {
+ private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
+
+ private final String name;
+
+ /**
+ * Creates a TensorFlow model
+ *
+ * @param name the name of this mode, containing only characters in [A-Za-z0-9_]
+ */
+ public TensorFlowModel(String name) {
+ if ( ! nameRegexp.matcher(name).matches())
+ throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" +
+ name + "'");
+ this.name = name;
+ }
+
+ /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
+ public String name() { return name; }
+
private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> arguments = new HashMap<>();
private final Map<String, Tensor> smallConstants = new HashMap<>();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
index ee358f45b22..1b87c302835 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
@@ -32,12 +32,12 @@ import java.util.List;
*/
public class OperationMapper {
- public static TensorFlowOperation get(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ public static TensorFlowOperation get(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
switch (node.getOp().toLowerCase()) {
// array ops
- case "const": return new Const(node, inputs, port);
+ case "const": return new Const(modelName, node, inputs, port);
case "expanddims": return new ExpandDims(node, inputs, port);
- case "identity": return new Identity(node, inputs, port);
+ case "identity": return new Identity(modelName, node, inputs, port);
case "placeholder": return new Placeholder(node, inputs, port);
case "placeholderwithdefault": return new PlaceholderWithDefault(node, inputs, port);
case "reshape": return new Reshape(node, inputs, port);
@@ -76,11 +76,11 @@ public class OperationMapper {
case "selu": return new Map(node, inputs, port, ScalarFunctions.selu());
// state ops
- case "variable": return new Variable(node, inputs, port);
- case "variablev2": return new Variable(node, inputs, port);
+ case "variable": return new Variable(modelName, node, inputs, port);
+ case "variablev2": return new Variable(modelName, node, inputs, port);
// evaluation no-ops
- case "stopgradient":return new Identity(node, inputs, port);
+ case "stopgradient":return new Identity(modelName, node, inputs, port);
case "noop": return new NoOp(node, inputs, port);
}
return new NoOp(node, inputs, port);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
index d06d7b48def..eb2a82fe114 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
@@ -23,8 +23,11 @@ import java.util.Optional;
public class Const extends TensorFlowOperation {
- public Const(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ private final String modelName;
+
+ public Const(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
super(node, inputs, port);
+ this.modelName = modelName;
setConstantValue(value());
}
@@ -52,6 +55,12 @@ public class Const extends TensorFlowOperation {
return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
}
+ /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
+ @Override
+ public String vespaName() {
+ return modelName + "_" + super.vespaName();
+ }
+
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
for (TensorType.Dimension dimension : type.type().dimensions()) {
@@ -71,7 +80,7 @@ public class Const extends TensorFlowOperation {
}
private Value value() {
- if (!node.getAttrMap().containsKey("value")) {
+ if ( ! node.getAttrMap().containsKey("value")) {
throw new IllegalArgumentException("Node '" + node.getName() + "' of type " +
"const has missing 'value' attribute");
}
@@ -89,6 +98,6 @@ public class Const extends TensorFlowOperation {
return new DoubleValue(attrValue.getF());
}
throw new IllegalArgumentException("Requesting value of constant in " +
- node.getName() + " but type is not recognized.");
+ node.getName() + " but type is not recognized.");
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
index d79707a42e6..306232bb9ff 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
@@ -9,8 +9,17 @@ import java.util.List;
public class Identity extends TensorFlowOperation {
- public Identity(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ private final String modelName;
+
+ public Identity(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
super(node, inputs, port);
+ this.modelName = modelName;
+ }
+
+ /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
+ @Override
+ public String vespaName() {
+ return modelName + "_" + super.vespaName();
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
index dfe0796d9b8..83f9b37e631 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
@@ -37,7 +37,7 @@ public class Mean extends TensorFlowOperation {
TensorFlowOperation reductionIndices = inputs.get(1);
if (!reductionIndices.getConstantValue().isPresent()) {
throw new IllegalArgumentException("Mean in " + node.getName() + ": " +
- "reduction indices must be a constant.");
+ "reduction indices must be a constant.");
}
Tensor indices = reductionIndices.getConstantValue().get().asTensor();
reduceDimensions = new ArrayList<>();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
index 6f377c4bda2..7aefac6217c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
@@ -11,13 +11,22 @@ import java.util.List;
public class Variable extends TensorFlowOperation {
- public Variable(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ private final String modelName;
+
+ public Variable(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
super(node, inputs, port);
+ this.modelName = modelName;
+ }
+
+ /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
+ @Override
+ public String vespaName() {
+ return modelName + "_" + super.vespaName();
}
@Override
protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ return OrderedTensorType.fromTensorFlowType(node, super.vespaName() + "_");
}
@Override
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
index c6ee586a78c..0f5eec93feb 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
@@ -14,7 +14,7 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/batch_norm/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved");
TensorFlowModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index c0e25a85ed0..b4cd2f11b0e 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -16,7 +16,7 @@ public class DropoutImportTestCase {
@Test
public void testDropoutImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved");
// Check required macros
assertEquals(1, model.get().requiredMacros().size());
@@ -32,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/Maximum", output.getName());
- assertEquals("join(join(tf_macro_outputs_BiasAdd, reduce(constant(outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_outputs_BiasAdd, f(a,b)(max(a,b)))",
+ assertEquals("join(join(tf_macro_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_outputs_BiasAdd, f(a,b)(max(a,b)))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index 0deac3f8216..9f919c452d6 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -17,18 +17,18 @@ public class MnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved");
// Check constants
assertEquals(2, model.get().largeConstants().size());
- Tensor constant0 = model.get().largeConstants().get("Variable_read");
+ Tensor constant0 = model.get().largeConstants().get("test_Variable_read");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.get().largeConstants().get("Variable_1_read");
+ Tensor constant1 = model.get().largeConstants().get("test_Variable_1_read");
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
constant1.type());
@@ -59,7 +59,7 @@ public class MnistSoftmaxImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("add", output.getName());
- assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(Variable_read), f(a,b)(a * b)), sum, d2), constant(Variable_1_read), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
output.getRoot().toString());
// Test execution
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index daacd014b63..7ca16939477 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -34,9 +34,9 @@ public class TestableTensorFlowModel {
private final int d0Size = 1;
private final int d1Size = 784;
- public TestableTensorFlowModel(String modelDir) {
+ public TestableTensorFlowModel(String modelName, String modelDir) {
tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
- model = new TensorFlowImporter().importModel(tensorFlowModel);
+ model = new TensorFlowImporter().importModel(modelName, tensorFlowModel);
}
public TensorFlowModel get() { return model; }