diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-30 15:27:07 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-30 15:27:07 -0800 |
commit | 1420382a5c16f08cb854d58b2c29b485f51f7f9e (patch) | |
tree | f74d5ac199d3611b757c3973e0b598f430dc287f | |
parent | e0a9e9978266016823b33e1b4f3a6008b641feac (diff) |
Refactor
4 files changed, 168 insertions, 136 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
new file mode 100644
index 00000000000..b0c6cc3fe7b
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -0,0 +1,134 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Matmul;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.Softmax;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Map;
+import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions.
+ *
+ * @author bratseth
+ */
+class OperationMapper {
+
+    /*
+       A note on conversion from implicitly numbered to explicitly named dimensions:
+       Vespa tensor dimensions are explicitly named and thus have an explicit notion of being
+       'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
+       comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
+       around dimension renaming operations which mirrors those built into the TF operation definitions.
+
+       To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
+       dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation
+       and the result is then renamed again (if necessary) to recover this convention across a full nested
+       computation.
+
+       This requires us to track tensor types throughout the conversion.
+     */
+
+    /** Maps a two-argument TF operation to a Vespa join applying doubleFunction to the cells of both arguments */
+    TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
+        // Note that this generalizes the corresponding TF function as it does not verify that the tensor
+        // types are the same, with the assumption that this already happened on the TF side
+        // (and if not, this should do the right thing anyway)
+        ensureArguments(2, arguments, "join");
+        TypedTensorFunction a = arguments.get(0);
+        TypedTensorFunction b = arguments.get(1);
+
+        TensorType resultType = Join.outputType(a.type(), b.type());
+        Join function = new Join(a.function(), b.function(), doubleFunction);
+        return new TypedTensorFunction(resultType, function);
+    }
+
+    /** Maps a one-argument TF operation to a Vespa map applying doubleFunction to each cell of the argument */
+    TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) {
+        ensureArguments(1, arguments, "apply");
+        TypedTensorFunction a = arguments.get(0);
+
+        TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
+        com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
+        return new TypedTensorFunction(resultType, function);
+    }
+
+    /** Maps a TF identity node to a reference to the named variable or input tensor it reads */
+    TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs) {
+        // TODO: Verify with TF documentation
+        String name;
+        TensorType inputType;
+        if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model TODO: We need to turn those into constants
+            if (tfNode.getInputList().size() != 1)
+                throw new IllegalArgumentException("A Variable/read node must have one input but has " +
+                                                   tfNode.getInputList().size());
+            name = tfNode.getInput(0);
+            AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
+            if (shapes == null)
+                throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape");
+            inputType = TensorFlowImporter.importTensorType(shapes.getList().getShape(0));
+        }
+        else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name
+            name = tfNode.getName();
+            inputType = inputs.get(name);
+            if (inputType == null)
+                throw new IllegalArgumentException("An identity operation node is referencing input '" + name +
+                                                   "', but there is no such input");
+        }
+        return new TypedTensorFunction(inputType, new VariableTensor(name));
+    }
+
+    /** Maps TF matmul of two equal-rank tensors (rank 2 or more) to a Vespa matmul wrapped in dimension renames */
+    TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
+        ensureArguments(2, arguments, "matmul");
+        TypedTensorFunction a = arguments.get(0);
+        TypedTensorFunction b = arguments.get(1);
+        if (a.type().rank() < 2 || b.type().rank() < 2)
+            throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
+        if (a.type().rank() != b.type().rank())
+            throw new IllegalArgumentException("Tensors in matmul must have the same rank");
+
+        // Let the second-to-last dimension of the second tensor be the same as the last dimension of the first
+        // and the last dimension of the second argument be not present in the first argument, while leaving the
+        // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.
+
+        // TODO: Check if transpose_a or transpose_b is set and rename differently accordingly
+
+        // Dimensions are named d0 ... d(rank-1), so the last dimension is d(rank-1) and d(rank) is a free name
+        String beforeLastDim = "d" + (a.type().rank() - 2);
+        String lastDim = "d" + (a.type().rank() - 1);
+        String afterLastDim = "d" + a.type().rank();
+
+        Rename renamedB = new Rename(b.function(), ImmutableList.of(beforeLastDim, lastDim),
+                                     ImmutableList.of(lastDim, afterLastDim));
+        Matmul matmul = new Matmul(a.function(), renamedB, lastDim);
+        return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), lastDim),
+                                       new Rename(matmul, afterLastDim, lastDim));
+    }
+
+    /** Maps TF softmax to a Vespa softmax over the last dimension of the single argument */
+    TypedTensorFunction softmax(List<TypedTensorFunction> arguments) {
+        ensureArguments(1, arguments, "softmax");
+        TypedTensorFunction a = arguments.get(0);
+        // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
+        String dimension = "d" + (a.type().rank() - 1);
+        Softmax softmax = new Softmax(a.function(), dimension);
+        return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
+    }
+
+    /** Throws IllegalArgumentException unless exactly count arguments are given */
+    private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) {
+        if ( arguments.size() != count)
+            throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName +
+                                               ", but got " + arguments.size());
+    }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index e47f2ad53d9..167ff684725 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -1,6 +1,5 @@
 package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
 
-import com.google.common.collect.ImmutableList;
 import com.google.protobuf.ProtocolStringList;
 import com.google.protobuf.TextFormat;
 import com.yahoo.io.IOUtils;
@@ -8,15 +7,8 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
 import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
 import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
 import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.Matmul;
-import com.yahoo.tensor.functions.Rename;
 import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.Softmax;
-import com.yahoo.tensor.functions.TensorFunction;
 import com.yahoo.yolean.Exceptions;
-import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.MetaGraphDef;
 import org.tensorflow.framework.NodeDef;
@@ -31,8 +23,6 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.function.DoubleBinaryOperator;
-import java.util.function.DoubleUnaryOperator;
 import java.util.stream.Collectors;
 
 /**
@@ -42,20 +32,7 @@ import java.util.stream.Collectors;
  */
 public class TensorFlowImporter {
 
-    /*
-       A note on conversion from implicitly numbered to explicitly named dimensions:
-       Vespa tensor dimensions are explicitly named and thus have an explicit notion of being
-       'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
-       comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
-       around dimension renaming operations which mirrors those built into the TF operation definitions.
-
-       To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
-       dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation
-       and the result is then renamed again (if necessary) to recover this convention across a full nested
-       computation.
-
-       This requires us to track tensor types throughout the conversion.
-     */
+    private final OperationMapper operationMapper = new OperationMapper();
 
     /**
      * Imports a saved TensorFlow model from a directory.
@@ -116,7 +93,7 @@ public class TensorFlowImporter {
         return inputs;
     }
 
-    private TensorType importTensorType(TensorShapeProto tensorShape) {
+    static TensorType importTensorType(TensorShapeProto tensorShape) {
         TensorType.Builder b = new TensorType.Builder();
         for (int i = 0; i < tensorShape.getDimCount(); i++) {
             int dimensionSize = (int) tensorShape.getDim(i).getSize();
@@ -147,11 +124,11 @@ public class TensorFlowImporter {
         // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
         // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
         switch (tfNode.getOp().toLowerCase()) {
-            case "add" : case "add_n" : return join(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.add());
-            case "acos" : return map(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.acos());
-            case "identity" : return identity(tfNode, inputs);
-            case "matmul" : return matmul(importArguments(tfNode, inputs, graph, indent));
-            case "softmax" : return softmax(importArguments(tfNode, inputs, graph, indent));
+            case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.add());
+            case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.acos());
+            case "identity" : return operationMapper.identity(tfNode, inputs);
+            case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, indent));
+            case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, indent));
             default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
         }
     }
@@ -162,93 +139,6 @@ public class TensorFlowImporter {
                 .collect(Collectors.toList());
     }
 
-    private TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
-        // Note that this generalizes the corresponding TF function as it does not verify that the tensor
-        // types are the same, with the assumption that this already happened on the TF side
-        // (and if not, this should do the right thing anyway)
-        ensureArguments(2, arguments, "join");
-        TypedTensorFunction a = arguments.get(0);
-        TypedTensorFunction b = arguments.get(0);
-
-        TensorType resultType = Join.outputType(a.type(), b.type());
-        Join function = new Join(a.function(), b.function(), doubleFunction);
-        return new TypedTensorFunction(resultType, function);
-    }
-
-    private TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) {
-        ensureArguments(1, arguments, "apply");
-        TypedTensorFunction a = arguments.get(0);
-
-        TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
-        com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
-        return new TypedTensorFunction(resultType, function);
-    }
-
-    private TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs) {
-        // TODO: Verify with TF documentation
-        String name;
-        TensorType inputType;
-        if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model TODO: We need to turn those into constants
-            if (tfNode.getInputList().size() != 1)
-                throw new IllegalArgumentException("A Variable/read node must have one input but has " +
-                                                   tfNode.getInputList().size());
-            name = tfNode.getInput(0);
-            AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
-            if (shapes == null)
-                throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape");
-            inputType = importTensorType(shapes.getList().getShape(0));
-        }
-        else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name
-            name = tfNode.getName();
-            inputType = inputs.get(name);
-            if (inputType == null)
-                throw new IllegalArgumentException("An identity operation node is referencing input '" + name +
-                                                   "', but there is no such input");
-        }
-        return new TypedTensorFunction(inputType, new VariableTensor(name));
-    }
-
-    private TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
-        ensureArguments(2, arguments, "matmul");
-        TypedTensorFunction a = arguments.get(0);
-        TypedTensorFunction b = arguments.get(0);
-        if (a.type().rank() < 2 || b.type.rank() < 2)
-            throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
-        if (a.type().rank() != b.type.rank())
-            throw new IllegalArgumentException("Tensors in matmul must have the same rank");
-
-        // Let the second-to-last dimension of the second tensor be the same as the last dimension of the first
-        // and the last dimension of the second argument be not present in the first argument, while leaving the
-        // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.
-
-        // TODO: Check if transpose_a or transpose_b is set and rename differently accordingly
-
-        String beforeLastDim = "d" + (a.type().rank() - 1);
-        String lastDim = "d" + a.type().rank();
-        String afterLastDim = "d" + (a.type().rank() + 1);
-
-        Rename renamedB = new Rename(b.function(), ImmutableList.of(beforeLastDim, lastDim),
-                                     ImmutableList.of(lastDim, afterLastDim));
-        Matmul matmul = new Matmul(a.function(), renamedB, lastDim);
-        return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), lastDim),
-                                       new Rename(matmul, afterLastDim, lastDim));
-    }
-
-    private TypedTensorFunction softmax(List<TypedTensorFunction> arguments) {
-        ensureArguments(1, arguments, "softmax");
-        TypedTensorFunction a = arguments.get(0);
-        // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
-        String dimension = "d" + (a.type().rank() - 1);
-        Softmax softmax = new Softmax(a.function(), dimension);
-        return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
-    }
-
-    private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) {
-        if ( arguments.size() != count)
-            throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName +
-                                               ", but got " + arguments.size());
-    }
-
     private NodeDef getNode(String name, GraphDef graph) {
         return graph.getNodeList().stream()
                 .filter(node -> node.getName().equals(name))
@@ -278,21 +168,5 @@ public class TensorFlowImporter {
     private boolean contains(String string, ProtocolStringList strings) {
         return strings.asByteStringList().stream().anyMatch(s -> s.toStringUtf8().equals(string));
     }
-
-    /** A tensor function returning a specific tensor type */
-    private static final class TypedTensorFunction {
-
-        private final TensorType type;
-        private final TensorFunction function;
-
-        public TypedTensorFunction(TensorType type, TensorFunction function) {
-            this.type = type;
-            this.function = function;
-        }
-
-        public TensorType type() { return type; }
-        public TensorFunction function() { return function; }
-
-    }
 
 }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
new file mode 100644
index 00000000000..234d620d02f
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
@@ -0,0 +1,24 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+/**
+ * A tensor function returning a specific tensor type
+ *
+ * @author bratseth
+ */
+final class TypedTensorFunction {
+
+    private final TensorType type;
+    private final TensorFunction function;
+
+    public TypedTensorFunction(TensorType type, TensorFunction function) {
+        this.type = type;
+        this.function = function;
+    }
+
+    public TensorType type() { return type; }
+    public TensorFunction function() { return function; }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
index bfe2bb3a63b..f2164a1b177 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
@@ -12,7 +12,7 @@ import static org.junit.Assert.assertEquals;
  * @author bratseth
  */
 public class TensorFlowImporterTestCase {
-
+
     @Test
     public void testModel1() {
         List<RankingExpression> expressions =
@@ -26,10 +26,10 @@ public class TensorFlowImporterTestCase {
                      "d1)",
                      toNonPrimitiveString(expressions.get(0)));
     }
-
+
     private String toNonPrimitiveString(RankingExpression expression) {
         // toString on the wrapping expression will map to primitives, which is harder to read
         return ((TensorFunctionNode)expression.getRoot()).function().toString();
     }
-
+
 }