summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-03-13 13:21:16 +0100
committerJon Bratseth <bratseth@oath.com>2018-03-13 13:21:16 +0100
commit3f260a9420f3b395a2490df532f8fe883756b0fb (patch)
tree42851bab5228090b9ac06b02a5e83ebc1960fce2 /searchlib
parent7b9cf1820056d161ce761d0c040f7ceb21728f13 (diff)
Prefix constants by model name
Large constants are shared across rank profiles. This avoids name conflicts when multiple models are used. It is not strictly necessary because the user can always disambiguate when choosing names, but there is a scenario where conflicts are plausible and leaving this to users is inconvenient: Multiple versions of the "same" model are tested in different rank profiles.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java50
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java20
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java12
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java13
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java4
11 files changed, 95 insertions, 46 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 80a9262afeb..217eafd7446 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -49,25 +49,27 @@ public class TensorFlowImporter {
* The model should be saved as a .pbtxt or .pb file.
* The name of the model is taken as the pb/pbtxt file name (not including the file ending).
*
+ * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
* @param modelDir the directory containing the TensorFlow model files to import
*/
- public TensorFlowModel importModel(String modelDir) {
+ public TensorFlowModel importModel(String modelName, String modelDir) {
try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
- return importModel(model);
+
+ return importModel(modelName, model);
}
catch (IllegalArgumentException e) {
throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
}
}
- public TensorFlowModel importModel(File modelDir) {
- return importModel(modelDir.toString());
+ public TensorFlowModel importModel(String modelName, File modelDir) {
+ return importModel(modelName, modelDir.toString());
}
/** Imports a TensorFlow model */
- public TensorFlowModel importModel(SavedModelBundle model) {
+ public TensorFlowModel importModel(String modelName, SavedModelBundle model) {
try {
- return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
+ return importGraph(modelName, MetaGraphDef.parseFrom(model.metaGraphDef()), model);
}
catch (IOException e) {
throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
@@ -79,8 +81,8 @@ public class TensorFlowImporter {
* finding a suitable set of dimensions names for each
* placeholder/constant/variable, then importing the expressions.
*/
- private static TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle bundle) {
- TensorFlowModel model = new TensorFlowModel();
+ private static TensorFlowModel importGraph(String modelName, MetaGraphDef graph, SavedModelBundle bundle) {
+ TensorFlowModel model = new TensorFlowModel(modelName);
OperationIndex index = new OperationIndex();
importSignatures(graph, model);
@@ -138,21 +140,21 @@ public class TensorFlowImporter {
private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) {
for (TensorFlowModel.Signature signature : model.signatures().values()) {
for (String outputName : signature.outputs().values()) {
- importNode(outputName, graph.getGraphDef(), index);
+ importNode(model.name(), outputName, graph.getGraphDef(), index);
}
}
}
- private static TensorFlowOperation importNode(String name, GraphDef graph, OperationIndex index) {
- if (index.alreadyImported(name)) {
- return index.get(name);
+ private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) {
+ if (index.alreadyImported(nodeName)) {
+ return index.get(nodeName);
}
- NodeDef node = getTensorFlowNodeFromGraph(namePartOf(name), graph);
- List<TensorFlowOperation> inputs = importNodeInputs(node, graph, index);
- TensorFlowOperation operation = OperationMapper.get(node, inputs, portPartOf(name));
- index.put(name, operation);
+ NodeDef node = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph);
+ List<TensorFlowOperation> inputs = importNodeInputs(modelName, node, graph, index);
+ TensorFlowOperation operation = OperationMapper.get(modelName, node, inputs, portPartOf(nodeName));
+ index.put(nodeName, operation);
- List<TensorFlowOperation> controlInputs = importControlInputs(node, graph, index);
+ List<TensorFlowOperation> controlInputs = importControlInputs(modelName, node, graph, index);
if (controlInputs.size() > 0) {
operation.setControlInputs(controlInputs);
}
@@ -160,17 +162,17 @@ public class TensorFlowImporter {
return operation;
}
- private static List<TensorFlowOperation> importNodeInputs(NodeDef node, GraphDef graph, OperationIndex index) {
+ private static List<TensorFlowOperation> importNodeInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
return node.getInputList().stream()
.filter(name -> ! isControlDependency(name))
- .map(name -> importNode(name, graph, index))
+ .map(nodeName -> importNode(modelName, nodeName, graph, index))
.collect(Collectors.toList());
}
- private static List<TensorFlowOperation> importControlInputs(NodeDef node, GraphDef graph, OperationIndex index) {
+ private static List<TensorFlowOperation> importControlInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
return node.getInputList().stream()
- .filter(name -> isControlDependency(name))
- .map(name -> importNode(name, graph, index))
+ .filter(nodeName -> isControlDependency(nodeName))
+ .map(nodeName -> importNode(modelName, nodeName, graph, index))
.collect(Collectors.toList());
}
@@ -280,9 +282,9 @@ public class TensorFlowImporter {
}
if (tensor.type().rank() == 0 || tensor.size() <= 1) {
- model.smallConstant(operation.vespaName(), tensor);
+ model.smallConstant(name, tensor);
} else {
- model.largeConstant(operation.vespaName(), tensor);
+ model.largeConstant(name, tensor);
}
return operation.function();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
index 351aa417f9c..721214f9e94 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
@@ -10,6 +10,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.regex.Pattern;
/**
* The result of importing a TensorFlow model into Vespa.
@@ -22,6 +23,25 @@ import java.util.Map;
// This object can be built incrementally within this package, but is immutable when observed from outside the package
public class TensorFlowModel {
+ private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
+
+ private final String name;
+
+ /**
+ * Creates a TensorFlow model
+ *
+ * @param name the name of this model, containing only characters in [A-Za-z0-9_]
+ */
+ public TensorFlowModel(String name) {
+ if ( ! nameRegexp.matcher(name).matches())
+ throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" +
+ name + "'");
+ this.name = name;
+ }
+
+ /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
+ public String name() { return name; }
+
private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> arguments = new HashMap<>();
private final Map<String, Tensor> smallConstants = new HashMap<>();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
index ee358f45b22..1b87c302835 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
@@ -32,12 +32,12 @@ import java.util.List;
*/
public class OperationMapper {
- public static TensorFlowOperation get(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ public static TensorFlowOperation get(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
switch (node.getOp().toLowerCase()) {
// array ops
- case "const": return new Const(node, inputs, port);
+ case "const": return new Const(modelName, node, inputs, port);
case "expanddims": return new ExpandDims(node, inputs, port);
- case "identity": return new Identity(node, inputs, port);
+ case "identity": return new Identity(modelName, node, inputs, port);
case "placeholder": return new Placeholder(node, inputs, port);
case "placeholderwithdefault": return new PlaceholderWithDefault(node, inputs, port);
case "reshape": return new Reshape(node, inputs, port);
@@ -76,11 +76,11 @@ public class OperationMapper {
case "selu": return new Map(node, inputs, port, ScalarFunctions.selu());
// state ops
- case "variable": return new Variable(node, inputs, port);
- case "variablev2": return new Variable(node, inputs, port);
+ case "variable": return new Variable(modelName, node, inputs, port);
+ case "variablev2": return new Variable(modelName, node, inputs, port);
// evaluation no-ops
- case "stopgradient":return new Identity(node, inputs, port);
+ case "stopgradient":return new Identity(modelName, node, inputs, port);
case "noop": return new NoOp(node, inputs, port);
}
return new NoOp(node, inputs, port);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
index d06d7b48def..eb2a82fe114 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
@@ -23,8 +23,11 @@ import java.util.Optional;
public class Const extends TensorFlowOperation {
- public Const(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ private final String modelName;
+
+ public Const(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
super(node, inputs, port);
+ this.modelName = modelName;
setConstantValue(value());
}
@@ -52,6 +55,12 @@ public class Const extends TensorFlowOperation {
return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
}
+ /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
+ @Override
+ public String vespaName() {
+ return modelName + "_" + super.vespaName();
+ }
+
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
for (TensorType.Dimension dimension : type.type().dimensions()) {
@@ -71,7 +80,7 @@ public class Const extends TensorFlowOperation {
}
private Value value() {
- if (!node.getAttrMap().containsKey("value")) {
+ if ( ! node.getAttrMap().containsKey("value")) {
throw new IllegalArgumentException("Node '" + node.getName() + "' of type " +
"const has missing 'value' attribute");
}
@@ -89,6 +98,6 @@ public class Const extends TensorFlowOperation {
return new DoubleValue(attrValue.getF());
}
throw new IllegalArgumentException("Requesting value of constant in " +
- node.getName() + " but type is not recognized.");
+ node.getName() + " but type is not recognized.");
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
index d79707a42e6..306232bb9ff 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
@@ -9,8 +9,17 @@ import java.util.List;
public class Identity extends TensorFlowOperation {
- public Identity(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ private final String modelName;
+
+ public Identity(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
super(node, inputs, port);
+ this.modelName = modelName;
+ }
+
+ /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
+ @Override
+ public String vespaName() {
+ return modelName + "_" + super.vespaName();
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
index dfe0796d9b8..83f9b37e631 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
@@ -37,7 +37,7 @@ public class Mean extends TensorFlowOperation {
TensorFlowOperation reductionIndices = inputs.get(1);
if (!reductionIndices.getConstantValue().isPresent()) {
throw new IllegalArgumentException("Mean in " + node.getName() + ": " +
- "reduction indices must be a constant.");
+ "reduction indices must be a constant.");
}
Tensor indices = reductionIndices.getConstantValue().get().asTensor();
reduceDimensions = new ArrayList<>();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
index 6f377c4bda2..7aefac6217c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
@@ -11,13 +11,22 @@ import java.util.List;
public class Variable extends TensorFlowOperation {
- public Variable(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ private final String modelName;
+
+ public Variable(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
super(node, inputs, port);
+ this.modelName = modelName;
+ }
+
+ /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
+ @Override
+ public String vespaName() {
+ return modelName + "_" + super.vespaName();
}
@Override
protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ return OrderedTensorType.fromTensorFlowType(node, super.vespaName() + "_");
}
@Override
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
index c6ee586a78c..0f5eec93feb 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
@@ -14,7 +14,7 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/batch_norm/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved");
TensorFlowModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index c0e25a85ed0..b4cd2f11b0e 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -16,7 +16,7 @@ public class DropoutImportTestCase {
@Test
public void testDropoutImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved");
// Check required macros
assertEquals(1, model.get().requiredMacros().size());
@@ -32,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/Maximum", output.getName());
- assertEquals("join(join(tf_macro_outputs_BiasAdd, reduce(constant(outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_outputs_BiasAdd, f(a,b)(max(a,b)))",
+ assertEquals("join(join(tf_macro_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_outputs_BiasAdd, f(a,b)(max(a,b)))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index 0deac3f8216..9f919c452d6 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -17,18 +17,18 @@ public class MnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved");
// Check constants
assertEquals(2, model.get().largeConstants().size());
- Tensor constant0 = model.get().largeConstants().get("Variable_read");
+ Tensor constant0 = model.get().largeConstants().get("test_Variable_read");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.get().largeConstants().get("Variable_1_read");
+ Tensor constant1 = model.get().largeConstants().get("test_Variable_1_read");
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
constant1.type());
@@ -59,7 +59,7 @@ public class MnistSoftmaxImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("add", output.getName());
- assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(Variable_read), f(a,b)(a * b)), sum, d2), constant(Variable_1_read), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
output.getRoot().toString());
// Test execution
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index daacd014b63..7ca16939477 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -34,9 +34,9 @@ public class TestableTensorFlowModel {
private final int d0Size = 1;
private final int d1Size = 784;
- public TestableTensorFlowModel(String modelDir) {
+ public TestableTensorFlowModel(String modelName, String modelDir) {
tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
- model = new TensorFlowImporter().importModel(tensorFlowModel);
+ model = new TensorFlowImporter().importModel(modelName, tensorFlowModel);
}
public TensorFlowModel get() { return model; }