summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-01 14:49:51 -0800
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-01 14:49:51 -0800
commitba18370c1056dfc675e6183fa234fbefdd7ee545 (patch)
tree5397ad543895934484d1c3b329cdc1fc9fc84088 /searchlib
parentc8c79c3363ad1149fa137f6b4899dad8369b309c (diff)
Import tensor constants
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java27
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java92
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java43
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java14
4 files changed, 140 insertions, 36 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index b0c6cc3fe7b..0717d3e1b2b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -1,12 +1,16 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;
@@ -17,7 +21,7 @@ import java.util.function.DoubleUnaryOperator;
/**
* Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions.
- *
+ *
* @author bratseth
*/
class OperationMapper {
@@ -28,15 +32,17 @@ class OperationMapper {
'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
around dimension renaming operations which mirrors those built into the TF operation definitions.
-
+
To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation
and the result is then renamed again (if necessary) to recover this convention across a full nested
computation.
-
+
This requires us to track tensor types throughout the conversion.
*/
+ private TensorConverter tensorConverter = new TensorConverter();
+
TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
// Note that this generalizes the corresponding TF function as it does not verify that the tensor
// types are the same, with the assumption that this already happened on the TF side
@@ -59,11 +65,10 @@ class OperationMapper {
return new TypedTensorFunction(resultType, function);
}
- TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs) {
- // TODO: Verify with TF documentation
+ TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs, SavedModelBundle model) {
String name;
TensorType inputType;
- if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model TODO: We need to turn those into constants
+ if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model
if (tfNode.getInputList().size() != 1)
throw new IllegalArgumentException("A Variable/read node must have one input but has " +
tfNode.getInputList().size());
@@ -72,6 +77,12 @@ class OperationMapper {
if (shapes == null)
throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape");
inputType = TensorFlowImporter.importTensorType(shapes.getList().getShape(0));
+ Session.Runner fetched = model.session().runner().fetch("Variable");
+ List<org.tensorflow.Tensor<?>> result = fetched.run();
+ if ( result.size() != 1)
+ throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + result.size());
+ Tensor constant = tensorConverter.toVespaTensor(result.get(0));
+ return new TypedTensorFunction(inputType, new ConstantTensor(constant));
}
else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name
name = tfNode.getName();
@@ -79,8 +90,8 @@ class OperationMapper {
if (inputType == null)
throw new IllegalArgumentException("An identity operation node is referencing input '" + name +
"', but there is no such input");
+ return new TypedTensorFunction(inputType, new VariableTensor(name));
}
- return new TypedTensorFunction(inputType, new VariableTensor(name));
}
TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
@@ -93,7 +104,7 @@ class OperationMapper {
throw new IllegalArgumentException("Tensors in matmul must have the same rank");
// Let the second-to-last dimension of the second tensor be the same as the last dimension of the first
- // and the last dimension of the second argument be not present in the first argument, while leaving the
+ // and the last dimension of the second argument be not present in the first argument, while leaving the
// rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.
// TODO: Check if transpose_a or transpose_b is set and rename differently accordingly
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
new file mode 100644
index 00000000000..a74445008b7
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
@@ -0,0 +1,92 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+
+/**
+ * @author bratseth
+ */
+public class TensorConverter {
+
+ public Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
+ TensorType type = toVespaTensorType(tfTensor.shape());
+ Values values = readValuesOf(tfTensor);
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
+ for (int i = 0; i < values.size(); i++)
+ builder.cellByDirectIndex(i, values.get(i));
+ return builder.build();
+ }
+
+ private TensorType toVespaTensorType(long[] shape) {
+ TensorType.Builder b = new TensorType.Builder();
+ int dimensionIndex = 0;
+ for (long dimensionSize : shape)
+ b.indexed("d" + (dimensionIndex++), (int)dimensionSize);
+ return b.build();
+ }
+
+ private Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
+ switch (tfTensor.dataType()) {
+ case DOUBLE: return new DoubleValues(tfTensor);
+ case FLOAT: return new FloatValues(tfTensor);
+ // TODO: The rest
+ default:
+ throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
+ tfTensor.dataType() + " to a Vespa tensor");
+ }
+ }
+
+ /** Allows reading values from buffers of various numeric types as bytes */
+ private static abstract class Values {
+
+ private final int size;
+
+ protected Values(int size) {
+ this.size = size;
+ }
+
+ abstract double get(int i);
+
+ int size() { return size; }
+
+ }
+
+ private static class DoubleValues extends Values {
+
+ private final DoubleBuffer values;
+
+ DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = DoubleBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+
+ @Override
+ double get(int i) {
+ return values.get(i);
+ }
+
+ }
+
+ private static class FloatValues extends Values {
+
+ private final FloatBuffer values;
+
+ FloatValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = FloatBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+
+ @Override
+ double get(int i) {
+ return values.get(i);
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 91a0f863a14..c14f8c71a3e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -10,7 +10,6 @@ import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.SavedModel;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
import org.tensorflow.framework.TensorShapeProto;
@@ -40,7 +39,7 @@ public class TensorFlowImporter {
public List<RankingExpression> importModel(String modelDir, MessageLogger logger) {
try {
SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
- return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), logger);
+ return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model, logger);
}
catch (IOException e) {
@@ -49,15 +48,7 @@ public class TensorFlowImporter {
}
- /** Import all declared inputs in all the graphs in the given model */
- private List<RankingExpression> importModel(SavedModel model, MessageLogger logger) {
- // TODO: Handle name conflicts between output keys in different graphs?
- return model.getMetaGraphsList().stream()
- .flatMap(graph -> importGraph(graph, logger).stream())
- .collect(Collectors.toList());
- }
-
- private List<RankingExpression> importGraph(MetaGraphDef graph, MessageLogger logger) {
+ private List<RankingExpression> importGraph(MetaGraphDef graph, SavedModelBundle model, MessageLogger logger) {
List<RankingExpression> expressions = new ArrayList<>();
for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
Map<String, TensorType> inputs = importInputs(signatureEntry.getValue().getInputsMap());
@@ -65,7 +56,8 @@ public class TensorFlowImporter {
try {
ExpressionNode result = importOutput(output.getValue(),
inputs,
- graph.getGraphDef());
+ graph.getGraphDef(),
+ model);
expressions.add(new RankingExpression(output.getKey(), result));
}
catch (IllegalArgumentException e) {
@@ -97,35 +89,38 @@ public class TensorFlowImporter {
return b.build();
}
- private ExpressionNode importOutput(TensorInfo output, Map<String, TensorType> inputs, GraphDef graph) {
+ private ExpressionNode importOutput(TensorInfo output, Map<String, TensorType> inputs, GraphDef graph, SavedModelBundle model) {
NodeDef node = getNode(nameOf(output.getName()), graph);
- return new TensorFunctionNode(importNode(node, inputs, graph, "").function());
+ return new TensorFunctionNode(importNode(node, inputs, graph, model).function());
}
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
- private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, String indent) {
- return tensorFunctionOf(tfNode, inputs, graph, indent);
+ private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, SavedModelBundle model) {
+ return tensorFunctionOf(tfNode, inputs, graph, model);
}
private TypedTensorFunction tensorFunctionOf(NodeDef tfNode,
Map<String, TensorType> inputs,
GraphDef graph,
- String indent) {
+ SavedModelBundle model) {
// Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
// TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
switch (tfNode.getOp().toLowerCase()) {
- case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.add());
- case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.acos());
- case "identity" : return operationMapper.identity(tfNode, inputs);
- case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, indent));
- case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, indent));
+ case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, model), ScalarFunctions.add());
+ case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, model), ScalarFunctions.acos());
+ case "identity" : return operationMapper.identity(tfNode, inputs, model);
+ case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, model));
+ case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, model));
default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
}
}
- private List<TypedTensorFunction> importArguments(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, String indent) {
+ private List<TypedTensorFunction> importArguments(NodeDef tfNode,
+ Map<String, TensorType> inputs,
+ GraphDef graph,
+ SavedModelBundle model) {
return tfNode.getInputList().stream()
- .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, indent + " "))
+ .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, model))
.collect(Collectors.toList());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
index 9b53b3824e2..936936dc3eb 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
@@ -7,6 +7,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
+import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
@@ -22,11 +23,11 @@ public class TensorFlowImporterTestCase {
new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/", logger);
// Check logged messages
- assertEquals(2, logger.messages.size());
+ assertEquals(2, logger.messages().size());
assertEquals("Skipping output 'TopKV2:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'TopKV2' is not supported",
- logger.messages.get(0));
+ logger.messages().get(0));
assertEquals("Skipping output 'index_to_string_Lookup:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'LookupTableFindV2' is not supported",
- logger.messages.get(1));
+ logger.messages().get(1));
// Check resulting Vespa expression
assertEquals(1, expressions.size());
@@ -46,7 +47,12 @@ public class TensorFlowImporterTestCase {
private class TestLogger implements TensorFlowImporter.MessageLogger {
- List<String> messages = new ArrayList<>();
+ private List<String> messages = new ArrayList<>();
+
+ /** Returns the messages in sorted order */
+ public List<String> messages() {
+ return messages.stream().sorted().collect(Collectors.toList());
+ }
@Override
public void log(Level level, String message) {