diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-03-13 14:07:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-03-13 14:07:01 +0100 |
commit | 4454b5db222b28d67962c872a77ce154021e7b20 (patch) | |
tree | bb7a3eb1029b2e4152b76497f31f0e17832b819d /searchlib/src | |
parent | 0f69157c9f65f0ee22787effc2532d692354bc29 (diff) | |
parent | 3f260a9420f3b395a2490df532f8fe883756b0fb (diff) |
Merge pull request #5309 from vespa-engine/bratseth/disambiguate-constants
Prefix constants by model name
Diffstat (limited to 'searchlib/src')
11 files changed, 95 insertions, 46 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 80a9262afeb..217eafd7446 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -49,25 +49,27 @@ public class TensorFlowImporter { * The model should be saved as a .pbtxt or .pb file. * The name of the model is taken as the db/pbtxt file name (not including the file ending). * + * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_] * @param modelDir the directory containing the TensorFlow model files to import */ - public TensorFlowModel importModel(String modelDir) { + public TensorFlowModel importModel(String modelName, String modelDir) { try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - return importModel(model); + + return importModel(modelName, model); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); } } - public TensorFlowModel importModel(File modelDir) { - return importModel(modelDir.toString()); + public TensorFlowModel importModel(String modelName, File modelDir) { + return importModel(modelName, modelDir.toString()); } /** Imports a TensorFlow model */ - public TensorFlowModel importModel(SavedModelBundle model) { + public TensorFlowModel importModel(String modelName, SavedModelBundle model) { try { - return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model); + return importGraph(modelName, MetaGraphDef.parseFrom(model.metaGraphDef()), model); } catch (IOException e) { throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); @@ -79,8 +81,8 @@ public class TensorFlowImporter { * finding a suitable set of dimensions names for each * placeholder/constant/variable, then importing the expressions. */ - private static TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle bundle) { - TensorFlowModel model = new TensorFlowModel(); + private static TensorFlowModel importGraph(String modelName, MetaGraphDef graph, SavedModelBundle bundle) { + TensorFlowModel model = new TensorFlowModel(modelName); OperationIndex index = new OperationIndex(); importSignatures(graph, model); @@ -138,21 +140,21 @@ public class TensorFlowImporter { private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) { for (TensorFlowModel.Signature signature : model.signatures().values()) { for (String outputName : signature.outputs().values()) { - importNode(outputName, graph.getGraphDef(), index); + importNode(model.name(), outputName, graph.getGraphDef(), index); } } } - private static TensorFlowOperation importNode(String name, GraphDef graph, OperationIndex index) { - if (index.alreadyImported(name)) { - return index.get(name); + private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) { + if (index.alreadyImported(nodeName)) { + return index.get(nodeName); } - NodeDef node = getTensorFlowNodeFromGraph(namePartOf(name), graph); - List<TensorFlowOperation> inputs = importNodeInputs(node, graph, index); - TensorFlowOperation operation = OperationMapper.get(node, inputs, portPartOf(name)); - index.put(name, operation); + NodeDef node = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph); + List<TensorFlowOperation> inputs = importNodeInputs(modelName, node, graph, index); + TensorFlowOperation operation = OperationMapper.get(modelName, node, inputs, portPartOf(nodeName)); + index.put(nodeName, operation); - List<TensorFlowOperation> controlInputs = importControlInputs(node, graph, index); + List<TensorFlowOperation> controlInputs = importControlInputs(modelName, node, graph, index); if (controlInputs.size() > 0) { operation.setControlInputs(controlInputs); } @@ -160,17 +162,17 @@ public class TensorFlowImporter { return operation; } - private static List<TensorFlowOperation> importNodeInputs(NodeDef node, GraphDef graph, OperationIndex index) { + private static List<TensorFlowOperation> importNodeInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) { return node.getInputList().stream() .filter(name -> ! isControlDependency(name)) - .map(name -> importNode(name, graph, index)) + .map(nodeName -> importNode(modelName, nodeName, graph, index)) .collect(Collectors.toList()); } - private static List<TensorFlowOperation> importControlInputs(NodeDef node, GraphDef graph, OperationIndex index) { + private static List<TensorFlowOperation> importControlInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) { return node.getInputList().stream() - .filter(name -> isControlDependency(name)) - .map(name -> importNode(name, graph, index)) + .filter(nodeName -> isControlDependency(nodeName)) + .map(nodeName -> importNode(modelName, nodeName, graph, index)) .collect(Collectors.toList()); } @@ -280,9 +282,9 @@ public class TensorFlowImporter { } if (tensor.type().rank() == 0 || tensor.size() <= 1) { - model.smallConstant(operation.vespaName(), tensor); + model.smallConstant(name, tensor); } else { - model.largeConstant(operation.vespaName(), tensor); + model.largeConstant(name, tensor); } return operation.function(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java index 351aa417f9c..721214f9e94 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java @@ -10,6 +10,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.regex.Pattern; /** * The result of importing a TensorFlow model into Vespa. @@ -22,6 +23,25 @@ import java.util.Map; // This object can be built incrementally within this package, but is immutable when observed from outside the package public class TensorFlowModel { + private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); + + private final String name; + + /** + * Creates a TensorFlow model + * + * @param name the name of this mode, containing only characters in [A-Za-z0-9_] + */ + public TensorFlowModel(String name) { + if ( ! nameRegexp.matcher(name).matches()) + throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" + + name + "'"); + this.name = name; + } + + /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ + public String name() { return name; } + private final Map<String, Signature> signatures = new HashMap<>(); private final Map<String, TensorType> arguments = new HashMap<>(); private final Map<String, Tensor> smallConstants = new HashMap<>(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java index ee358f45b22..1b87c302835 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java @@ -32,12 +32,12 @@ import java.util.List; */ public class OperationMapper { - public static TensorFlowOperation get(NodeDef node, List<TensorFlowOperation> inputs, int port) { + public static TensorFlowOperation get(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { switch (node.getOp().toLowerCase()) { // array ops - case "const": return new Const(node, inputs, port); + case "const": return new Const(modelName, node, inputs, port); case "expanddims": return new ExpandDims(node, inputs, port); - case "identity": return new Identity(node, inputs, port); + case "identity": return new Identity(modelName, node, inputs, port); case "placeholder": return new Placeholder(node, inputs, port); case "placeholderwithdefault": return new PlaceholderWithDefault(node, inputs, port); case "reshape": return new Reshape(node, inputs, port); @@ -76,11 +76,11 @@ public class OperationMapper { case "selu": return new Map(node, inputs, port, ScalarFunctions.selu()); // state ops - case "variable": return new Variable(node, inputs, port); - case "variablev2": return new Variable(node, inputs, port); + case "variable": return new Variable(modelName, node, inputs, port); + case "variablev2": return new Variable(modelName, node, inputs, port); // evaluation no-ops - case "stopgradient":return new Identity(node, inputs, port); + case "stopgradient":return new Identity(modelName, node, inputs, port); case "noop": return new NoOp(node, inputs, port); } return new NoOp(node, inputs, port); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java index d06d7b48def..eb2a82fe114 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java @@ -23,8 +23,11 @@ import java.util.Optional; public class Const extends TensorFlowOperation { - public Const(NodeDef node, List<TensorFlowOperation> inputs, int port) { + private final String modelName; + + public Const(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { super(node, inputs, port); + this.modelName = modelName; setConstantValue(value()); } @@ -52,6 +55,12 @@ public class Const extends TensorFlowOperation { return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode); } + /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ + @Override + public String vespaName() { + return modelName + "_" + super.vespaName(); + } + @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { for (TensorType.Dimension dimension : type.type().dimensions()) { @@ -71,7 +80,7 @@ public class Const extends TensorFlowOperation { } private Value value() { - if (!node.getAttrMap().containsKey("value")) { + if ( ! node.getAttrMap().containsKey("value")) { throw new IllegalArgumentException("Node '" + node.getName() + "' of type " + "const has missing 'value' attribute"); } @@ -89,6 +98,6 @@ public class Const extends TensorFlowOperation { return new DoubleValue(attrValue.getF()); } throw new IllegalArgumentException("Requesting value of constant in " + - node.getName() + " but type is not recognized."); + node.getName() + " but type is not recognized."); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java index d79707a42e6..306232bb9ff 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java @@ -9,8 +9,17 @@ import java.util.List; public class Identity extends TensorFlowOperation { - public Identity(NodeDef node, List<TensorFlowOperation> inputs, int port) { + private final String modelName; + + public Identity(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { super(node, inputs, port); + this.modelName = modelName; + } + + /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ + @Override + public String vespaName() { + return modelName + "_" + super.vespaName(); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java index dfe0796d9b8..83f9b37e631 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java @@ -37,7 +37,7 @@ public class Mean extends TensorFlowOperation { TensorFlowOperation reductionIndices = inputs.get(1); if (!reductionIndices.getConstantValue().isPresent()) { throw new IllegalArgumentException("Mean in " + node.getName() + ": " + - "reduction indices must be a constant."); + "reduction indices must be a constant."); } Tensor indices = reductionIndices.getConstantValue().get().asTensor(); reduceDimensions = new ArrayList<>(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java index 6f377c4bda2..7aefac6217c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java @@ -11,13 +11,22 @@ import java.util.List; public class Variable extends TensorFlowOperation { - public Variable(NodeDef node, List<TensorFlowOperation> inputs, int port) { + private final String modelName; + + public Variable(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { super(node, inputs, port); + this.modelName = modelName; + } + + /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ + @Override + public String vespaName() { + return modelName + "_" + super.vespaName(); } @Override protected OrderedTensorType lazyGetType() { - return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); + return OrderedTensorType.fromTensorFlowType(node, super.vespaName() + "_"); } @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java index c6ee586a78c..0f5eec93feb 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java @@ -14,7 +14,7 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/batch_norm/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved"); TensorFlowModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index c0e25a85ed0..b4cd2f11b0e 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -16,7 +16,7 @@ public class DropoutImportTestCase { @Test public void testDropoutImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved"); // Check required macros assertEquals(1, model.get().requiredMacros().size()); @@ -32,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/Maximum", output.getName()); - assertEquals("join(join(tf_macro_outputs_BiasAdd, reduce(constant(outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_outputs_BiasAdd, f(a,b)(max(a,b)))", + assertEquals("join(join(tf_macro_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_outputs_BiasAdd, f(a,b)(max(a,b)))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index 0deac3f8216..9f919c452d6 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -17,18 +17,18 @@ public class MnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved"); // Check constants assertEquals(2, model.get().largeConstants().size()); - Tensor constant0 = model.get().largeConstants().get("Variable_read"); + Tensor constant0 = model.get().largeConstants().get("test_Variable_read"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.get().largeConstants().get("Variable_1_read"); + Tensor constant1 = model.get().largeConstants().get("test_Variable_1_read"); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); @@ -59,7 +59,7 @@ public class MnistSoftmaxImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("add", output.getName()); - assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(Variable_read), f(a,b)(a * b)), sum, d2), constant(Variable_1_read), f(a,b)(a + b))", + assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", output.getRoot().toString()); // Test execution diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index daacd014b63..7ca16939477 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -34,9 +34,9 @@ public class TestableTensorFlowModel { private final int d0Size = 1; private final int d1Size = 784; - public TestableTensorFlowModel(String modelDir) { + public TestableTensorFlowModel(String modelName, String modelDir) { tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); - model = new TensorFlowImporter().importModel(tensorFlowModel); + model = new TensorFlowImporter().importModel(modelName, tensorFlowModel); } public TensorFlowModel get() { return model; } |