summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-11-30 15:27:07 -0800
committerJon Bratseth <bratseth@yahoo-inc.com>2017-11-30 15:27:07 -0800
commit1420382a5c16f08cb854d58b2c29b485f51f7f9e (patch)
treef74d5ac199d3611b757c3973e0b598f430dc287f /searchlib
parente0a9e9978266016823b33e1b4f3a6008b641feac (diff)
Refactor
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java127
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java140
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java24
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java6
4 files changed, 161 insertions, 136 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
new file mode 100644
index 00000000000..b0c6cc3fe7b
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -0,0 +1,127 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Matmul;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.Softmax;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Map;
+import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions.
+ *
+ * @author bratseth
+ */
+class OperationMapper {
+
+ /*
+ A note on conversion from implicitly numbered to explicitly named dimensions:
+ Vespa tensor dimensions are explicitly named and thus have an explicit notion of being
+ 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
+ comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
+ around dimension renaming operations which mirrors those built into the TF operation definitions.
+
+ To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
+ dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation
+ and the result is then renamed again (if necessary) to recover this convention across a full nested
+ computation.
+
+ This requires us to track tensor types throughout the conversion.
+ */
+
+ TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
+ // Note that this generalizes the corresponding TF function as it does not verify that the tensor
+ // types are the same, with the assumption that this already happened on the TF side
+ // (and if not, this should do the right thing anyway)
+ ensureArguments(2, arguments, "join");
+ TypedTensorFunction a = arguments.get(0);
+ TypedTensorFunction b = arguments.get(0);
+
+ TensorType resultType = Join.outputType(a.type(), b.type());
+ Join function = new Join(a.function(), b.function(), doubleFunction);
+ return new TypedTensorFunction(resultType, function);
+ }
+
+ TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) {
+ ensureArguments(1, arguments, "apply");
+ TypedTensorFunction a = arguments.get(0);
+
+ TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
+ com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
+ return new TypedTensorFunction(resultType, function);
+ }
+
+ TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs) {
+ // TODO: Verify with TF documentation
+ String name;
+ TensorType inputType;
+ if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model TODO: We need to turn those into constants
+ if (tfNode.getInputList().size() != 1)
+ throw new IllegalArgumentException("A Variable/read node must have one input but has " +
+ tfNode.getInputList().size());
+ name = tfNode.getInput(0);
+ AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
+ if (shapes == null)
+ throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape");
+ inputType = TensorFlowImporter.importTensorType(shapes.getList().getShape(0));
+ }
+ else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name
+ name = tfNode.getName();
+ inputType = inputs.get(name);
+ if (inputType == null)
+ throw new IllegalArgumentException("An identity operation node is referencing input '" + name +
+ "', but there is no such input");
+ }
+ return new TypedTensorFunction(inputType, new VariableTensor(name));
+ }
+
+ TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
+ ensureArguments(2, arguments, "matmul");
+ TypedTensorFunction a = arguments.get(0);
+ TypedTensorFunction b = arguments.get(0);
+ if (a.type().rank() < 2 || b.type().rank() < 2)
+ throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
+ if (a.type().rank() != b.type().rank())
+ throw new IllegalArgumentException("Tensors in matmul must have the same rank");
+
+ // Let the second-to-last dimension of the second tensor be the same as the last dimension of the first
+ // and the last dimension of the second argument be not present in the first argument, while leaving the
+ // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.
+
+ // TODO: Check if transpose_a or transpose_b is set and rename differently accordingly
+
+ String beforeLastDim = "d" + (a.type().rank() - 1);
+ String lastDim = "d" + a.type().rank();
+ String afterLastDim = "d" + (a.type().rank() + 1);
+
+ Rename renamedB = new Rename(b.function(), ImmutableList.of(beforeLastDim, lastDim),
+ ImmutableList.of(lastDim, afterLastDim));
+ Matmul matmul = new Matmul(a.function(), renamedB, lastDim);
+ return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), lastDim),
+ new Rename(matmul, afterLastDim, lastDim));
+ }
+
+ TypedTensorFunction softmax(List<TypedTensorFunction> arguments) {
+ ensureArguments(1, arguments, "softmax");
+ TypedTensorFunction a = arguments.get(0);
+ // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
+ String dimension = "d" + (a.type().rank() - 1);
+ Softmax softmax = new Softmax(a.function(), dimension);
+ return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
+ }
+
+ private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) {
+ if ( arguments.size() != count)
+ throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName +
+ ", but got " + arguments.size());
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index e47f2ad53d9..167ff684725 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -1,6 +1,5 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-import com.google.common.collect.ImmutableList;
import com.google.protobuf.ProtocolStringList;
import com.google.protobuf.TextFormat;
import com.yahoo.io.IOUtils;
@@ -8,15 +7,8 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.Matmul;
-import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.Softmax;
-import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.yolean.Exceptions;
-import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
@@ -31,8 +23,6 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.function.DoubleBinaryOperator;
-import java.util.function.DoubleUnaryOperator;
import java.util.stream.Collectors;
/**
@@ -42,20 +32,7 @@ import java.util.stream.Collectors;
*/
public class TensorFlowImporter {
- /*
- A note on conversion from implicitly numbered to explicitly named dimensions:
- Vespa tensor dimensions are explicitly named and thus have an explicit notion of being
- 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
- comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
- around dimension renaming operations which mirrors those built into the TF operation definitions.
-
- To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
- dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation
- and the result is then renamed again (if necessary) to recover this convention across a full nested
- computation.
-
- This requires us to track tensor types throughout the conversion.
- */
+ private final OperationMapper operationMapper = new OperationMapper();
/**
* Imports a saved TensorFlow model from a directory.
@@ -116,7 +93,7 @@ public class TensorFlowImporter {
return inputs;
}
- private TensorType importTensorType(TensorShapeProto tensorShape) {
+ static TensorType importTensorType(TensorShapeProto tensorShape) {
TensorType.Builder b = new TensorType.Builder();
for (int i = 0; i < tensorShape.getDimCount(); i++) {
int dimensionSize = (int) tensorShape.getDim(i).getSize();
@@ -147,11 +124,11 @@ public class TensorFlowImporter {
// Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
// TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
switch (tfNode.getOp().toLowerCase()) {
- case "add" : case "add_n" : return join(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.add());
- case "acos" : return map(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.acos());
- case "identity" : return identity(tfNode, inputs);
- case "matmul" : return matmul(importArguments(tfNode, inputs, graph, indent));
- case "softmax" : return softmax(importArguments(tfNode, inputs, graph, indent));
+ case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.add());
+ case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.acos());
+ case "identity" : return operationMapper.identity(tfNode, inputs);
+ case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, indent));
+ case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, indent));
default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
}
}
@@ -162,93 +139,6 @@ public class TensorFlowImporter {
.collect(Collectors.toList());
}
- private TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
- // Note that this generalizes the corresponding TF function as it does not verify that the tensor
- // types are the same, with the assumption that this already happened on the TF side
- // (and if not, this should do the right thing anyway)
- ensureArguments(2, arguments, "join");
- TypedTensorFunction a = arguments.get(0);
- TypedTensorFunction b = arguments.get(0);
-
- TensorType resultType = Join.outputType(a.type(), b.type());
- Join function = new Join(a.function(), b.function(), doubleFunction);
- return new TypedTensorFunction(resultType, function);
- }
-
- private TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) {
- ensureArguments(1, arguments, "apply");
- TypedTensorFunction a = arguments.get(0);
-
- TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
- com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
- return new TypedTensorFunction(resultType, function);
- }
-
- private TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs) {
- // TODO: Verify with TF documentation
- String name;
- TensorType inputType;
- if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model TODO: We need to turn those into constants
- if (tfNode.getInputList().size() != 1)
- throw new IllegalArgumentException("A Variable/read node must have one input but has " +
- tfNode.getInputList().size());
- name = tfNode.getInput(0);
- AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
- if (shapes == null)
- throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape");
- inputType = importTensorType(shapes.getList().getShape(0));
- }
- else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name
- name = tfNode.getName();
- inputType = inputs.get(name);
- if (inputType == null)
- throw new IllegalArgumentException("An identity operation node is referencing input '" + name +
- "', but there is no such input");
- }
- return new TypedTensorFunction(inputType, new VariableTensor(name));
- }
-
- private TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
- ensureArguments(2, arguments, "matmul");
- TypedTensorFunction a = arguments.get(0);
- TypedTensorFunction b = arguments.get(0);
- if (a.type().rank() < 2 || b.type.rank() < 2)
- throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
- if (a.type().rank() != b.type.rank())
- throw new IllegalArgumentException("Tensors in matmul must have the same rank");
-
- // Let the second-to-last dimension of the second tensor be the same as the last dimension of the first
- // and the last dimension of the second argument be not present in the first argument, while leaving the
- // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.
-
- // TODO: Check if transpose_a or transpose_b is set and rename differently accordingly
-
- String beforeLastDim = "d" + (a.type().rank() - 1);
- String lastDim = "d" + a.type().rank();
- String afterLastDim = "d" + (a.type().rank() + 1);
-
- Rename renamedB = new Rename(b.function(), ImmutableList.of(beforeLastDim, lastDim),
- ImmutableList.of(lastDim, afterLastDim));
- Matmul matmul = new Matmul(a.function(), renamedB, lastDim);
- return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), lastDim),
- new Rename(matmul, afterLastDim, lastDim));
- }
-
- private TypedTensorFunction softmax(List<TypedTensorFunction> arguments) {
- ensureArguments(1, arguments, "softmax");
- TypedTensorFunction a = arguments.get(0);
- // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
- String dimension = "d" + (a.type().rank() - 1);
- Softmax softmax = new Softmax(a.function(), dimension);
- return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
- }
-
- private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) {
- if ( arguments.size() != count)
- throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName +
- ", but got " + arguments.size());
- }
-
private NodeDef getNode(String name, GraphDef graph) {
return graph.getNodeList().stream()
.filter(node -> node.getName().equals(name))
@@ -278,21 +168,5 @@ public class TensorFlowImporter {
private boolean contains(String string, ProtocolStringList strings) {
return strings.asByteStringList().stream().anyMatch(s -> s.toStringUtf8().equals(string));
}
-
- /** A tensor function returning a specific tensor type */
- private static final class TypedTensorFunction {
-
- private final TensorType type;
- private final TensorFunction function;
-
- public TypedTensorFunction(TensorType type, TensorFunction function) {
- this.type = type;
- this.function = function;
- }
-
- public TensorType type() { return type; }
- public TensorFunction function() { return function; }
-
- }
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
new file mode 100644
index 00000000000..234d620d02f
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
@@ -0,0 +1,24 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+/**
+ * A tensor function returning a specific tensor type
+ *
+ * @author bratseth
+ */
+final class TypedTensorFunction {
+
+ private final TensorType type;
+ private final TensorFunction function;
+
+ public TypedTensorFunction(TensorType type, TensorFunction function) {
+ this.type = type;
+ this.function = function;
+ }
+
+ public TensorType type() { return type; }
+ public TensorFunction function() { return function; }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
index bfe2bb3a63b..f2164a1b177 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
@@ -12,7 +12,7 @@ import static org.junit.Assert.assertEquals;
* @author bratseth
*/
public class TensorFlowImporterTestCase {
-
+
@Test
public void testModel1() {
List<RankingExpression> expressions =
@@ -26,10 +26,10 @@ public class TensorFlowImporterTestCase {
"d1)",
toNonPrimitiveString(expressions.get(0)));
}
-
+
private String toNonPrimitiveString(RankingExpression expression) {
// toString on the wrapping expression will map to primitives, which is harder to read
return ((TensorFunctionNode)expression.getRoot()).function().toString();
}
-
+
}