diff options
5 files changed, 100 insertions, 47 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java new file mode 100644 index 00000000000..235771bfa9c --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java @@ -0,0 +1,23 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.tensor.Tensor; + +/** + * A tensor with a name + * + * @author bratseth + */ +public class NamedTensor { + + private final String name; + private final Tensor tensor; + + public NamedTensor(String name, Tensor tensor) { + this.name = name; + this.tensor = tensor; + } + + public String name() { return name; } + public Tensor tensor() { return tensor; } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index 0717d3e1b2b..183cfabbd87 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -1,10 +1,11 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Matmul; import com.yahoo.tensor.functions.Rename; @@ -49,7 +50,7 @@ class OperationMapper { // (and if not, this should do the right thing anyway) ensureArguments(2, arguments, "join"); TypedTensorFunction a = arguments.get(0); - TypedTensorFunction b = arguments.get(0); + TypedTensorFunction b = arguments.get(1); TensorType resultType = Join.outputType(a.type(), b.type()); Join function = new Join(a.function(), b.function(), doubleFunction); @@ -65,9 +66,10 @@ class OperationMapper { return new TypedTensorFunction(resultType, function); } - TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs, SavedModelBundle model) { + TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs, SavedModelBundle model, + List<NamedTensor> constants) { String name; - TensorType inputType; + TensorType type; if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model if (tfNode.getInputList().size() != 1) throw new IllegalArgumentException("A Variable/read node must have one input but has " + @@ -75,29 +77,30 @@ class OperationMapper { name = tfNode.getInput(0); AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); if (shapes == null) - throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape"); - inputType = TensorFlowImporter.importTensorType(shapes.getList().getShape(0)); - Session.Runner fetched = model.session().runner().fetch("Variable"); + throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape"); + Session.Runner fetched = model.session().runner().fetch(name); List<org.tensorflow.Tensor<?>> result = fetched.run(); if ( result.size() != 1) throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + result.size()); Tensor constant = tensorConverter.toVespaTensor(result.get(0)); - return new TypedTensorFunction(inputType, new ConstantTensor(constant)); + constants.add(new NamedTensor(name, constant)); + return new TypedTensorFunction(constant.type(), + new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")"))); } else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name name = tfNode.getName(); - inputType = inputs.get(name); - if (inputType == null) + type = inputs.get(name); + if (type == null) throw new IllegalArgumentException("An identity operation node is referencing input '" + name + "', but there is no such input"); - return new TypedTensorFunction(inputType, new VariableTensor(name)); + return new TypedTensorFunction(type, new VariableTensor(name)); } } TypedTensorFunction matmul(List<TypedTensorFunction> arguments) { ensureArguments(2, arguments, "matmul"); TypedTensorFunction a = arguments.get(0); - TypedTensorFunction b = arguments.get(0); + TypedTensorFunction b = arguments.get(1); if (a.type().rank() < 2 || b.type().rank() < 2) throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); if (a.type().rank() != b.type().rank()) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index c14f8c71a3e..51f1e444e70 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; import org.tensorflow.framework.GraphDef; @@ -35,12 +36,16 @@ public class TensorFlowImporter { * Imports a saved TensorFlow model from a directory. * The model should be saved as a pbtxt file. * The name of the model is taken at the pbtxt file name (not including the .pbtxt ending). + * + * @param modelDir the directory containing the TensorFlow model files to import + * @param constants any constant tensors imported from the TensorFlow model and referenced in the returned expressions + * @param logger a receiver of any messages generated by the import process + * @return the ranking expressions resulting from importing this TenorFlow model */ - public List<RankingExpression> importModel(String modelDir, MessageLogger logger) { + public List<RankingExpression> importModel(String modelDir, List<NamedTensor> constants, MessageLogger logger) { try { SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model, logger); - + return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model, constants, logger); } catch (IOException e) { throw new IllegalArgumentException("Could not open TensorFlow model directory '" + modelDir + "'", e); @@ -48,7 +53,8 @@ public class TensorFlowImporter { } - private List<RankingExpression> importGraph(MetaGraphDef graph, SavedModelBundle model, MessageLogger logger) { + private List<RankingExpression> importGraph(MetaGraphDef graph, SavedModelBundle model, + List<NamedTensor> constants, MessageLogger logger) { List<RankingExpression> expressions = new ArrayList<>(); for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { Map<String, TensorType> inputs = importInputs(signatureEntry.getValue().getInputsMap()); @@ -57,7 +63,8 @@ public class TensorFlowImporter { ExpressionNode result = importOutput(output.getValue(), inputs, graph.getGraphDef(), - model); + model, + constants); expressions.add(new RankingExpression(output.getKey(), result)); } catch (IllegalArgumentException e) { @@ -77,7 +84,7 @@ public class TensorFlowImporter { return inputs; } - static TensorType importTensorType(TensorShapeProto tensorShape) { + private TensorType importTensorType(TensorShapeProto tensorShape) { TensorType.Builder b = new TensorType.Builder(); for (int i = 0; i < tensorShape.getDimCount(); i++) { int dimensionSize = (int) tensorShape.getDim(i).getSize(); @@ -89,28 +96,32 @@ public class TensorFlowImporter { return b.build(); } - private ExpressionNode importOutput(TensorInfo output, Map<String, TensorType> inputs, GraphDef graph, SavedModelBundle model) { + private ExpressionNode importOutput(TensorInfo output, Map<String, TensorType> inputs, GraphDef graph, + SavedModelBundle model, List<NamedTensor> constants) { NodeDef node = getNode(nameOf(output.getName()), graph); - return new TensorFunctionNode(importNode(node, inputs, graph, model).function()); + TensorFunction function = importNode(node, inputs, graph, model, constants).function(); + return new TensorFunctionNode(function); // wrap top level (only) as an expression } /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, SavedModelBundle model) { - return tensorFunctionOf(tfNode, inputs, graph, model); + private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, + SavedModelBundle model, List<NamedTensor> constants) { + return tensorFunctionOf(tfNode, inputs, graph, model, constants); } private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, - SavedModelBundle model) { + SavedModelBundle model, + List<NamedTensor> constants) { // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ switch (tfNode.getOp().toLowerCase()) { - case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, model), ScalarFunctions.add()); - case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, model), ScalarFunctions.acos()); - case "identity" : return operationMapper.identity(tfNode, inputs, model); - case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, model)); - case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, model)); + case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, model, constants), ScalarFunctions.add()); + case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, model, constants), ScalarFunctions.acos()); + case "identity" : return operationMapper.identity(tfNode, inputs, model, constants); + case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, model, constants)); + case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, model, constants)); default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported"); } } @@ -118,9 +129,10 @@ public class TensorFlowImporter { private List<TypedTensorFunction> importArguments(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, - SavedModelBundle model) { + SavedModelBundle model, + List<NamedTensor> constants) { return tfNode.getInputList().stream() - .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, model)) + .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, model, constants)) .collect(Collectors.toList()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index ab5f1e7191d..d1f4cbddf6e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -21,18 +21,18 @@ import java.util.stream.Collectors; * * @author bratseth */ - @Beta +@Beta public class TensorFunctionNode extends CompositeNode { private final TensorFunction function; - + public TensorFunctionNode(TensorFunction function) { this.function = function; } /** Returns the tensor function wrapped by this */ public TensorFunction function() { return function; } - + @Override public List<ExpressionNode> children() { return function.functionArguments().stream() @@ -53,7 +53,7 @@ public class TensorFunctionNode extends CompositeNode { // Serialize as primitive return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this)); } - + @Override public Value evaluate(Context context) { return new TensorValue(function.evaluate(context)); @@ -62,8 +62,8 @@ public class TensorFunctionNode extends CompositeNode { public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { return new TensorFunctionExpressionNode(node); } - - /** + + /** * A tensor function implemented by an expression. * This allows us to pass expressions as tensor function arguments. */ @@ -71,13 +71,13 @@ public class TensorFunctionNode extends CompositeNode { /** An expression which produces a tensor */ private final ExpressionNode expression; - + public TensorFunctionExpressionNode(ExpressionNode expression) { this.expression = expression; } - + @Override - public List<TensorFunction> functionArguments() { + public List<TensorFunction> functionArguments() { if (expression instanceof CompositeNode) return ((CompositeNode)expression).children().stream() .map(TensorFunctionExpressionNode::new) @@ -111,7 +111,7 @@ public class TensorFunctionNode extends CompositeNode { public String toString() { return toString(ExpressionNodeToStringContext.empty); } - + @Override public String toString(ToStringContext c) { if (c instanceof ExpressionNodeToStringContext) { @@ -124,14 +124,14 @@ public class TensorFunctionNode extends CompositeNode { } } - + /** Allows passing serialization context arguments through TensorFunctions */ private static class ExpressionNodeToStringContext implements ToStringContext { - + final SerializationContext context; final Deque<String> path; final CompositeNode parent; - + public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(null, null, null); public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java index 936936dc3eb..aaf198a9e8f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; import org.junit.Test; import java.util.ArrayList; @@ -18,9 +19,23 @@ public class TensorFlowImporterTestCase { @Test public void testModel1() { + List<NamedTensor> constants = new ArrayList<>(); TestLogger logger = new TestLogger(); List<RankingExpression> expressions = - new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/", logger); + new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/", constants, logger); + + // Check constants + assertEquals(2, constants.size()); + + assertEquals("Variable", constants.get(0).name()); + assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), + constants.get(0).tensor().type()); + assertEquals(7840, constants.get(0).tensor().size()); + + assertEquals("Variable_1", constants.get(1).name()); + assertEquals(new TensorType.Builder().indexed("d0", 10).build(), + constants.get(1).tensor().type()); + assertEquals(10, constants.get(1).tensor().size()); // Check logged messages assertEquals(2, logger.messages().size()); @@ -33,10 +48,10 @@ public class TensorFlowImporterTestCase { assertEquals(1, expressions.size()); assertEquals("scores", expressions.get(0).getName()); assertEquals("" + - "softmax(join(rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " + - "rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " + + "softmax(join(rename(matmul(x, rename(constant(Variable), (d1, d2), (d2, d3)), d2), d3, d2), " + + "constant(Variable_1), " + "f(a,b)(a + b)), " + - "d1)", + "d0)", toNonPrimitiveString(expressions.get(0))); } |