summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-02 16:52:34 -0800
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-02 16:52:34 -0800
commit3df3c57607c73bda31a60af6695aeafd8a57fabb (patch)
tree7473bdc08dc64140fb8d0ac5bfb6a7197b636231 /searchlib
parentba18370c1056dfc675e6183fa234fbefdd7ee545 (diff)
Import and return constant tensors
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java23
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java27
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java48
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java26
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java23
5 files changed, 100 insertions, 47 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java
new file mode 100644
index 00000000000..235771bfa9c
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java
@@ -0,0 +1,23 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.tensor.Tensor;
+
+/**
+ * A tensor with a name
+ *
+ * @author bratseth
+ */
+public class NamedTensor {
+
+ private final String name;
+ private final Tensor tensor;
+
+ public NamedTensor(String name, Tensor tensor) {
+ this.name = name;
+ this.tensor = tensor;
+ }
+
+ public String name() { return name; }
+ public Tensor tensor() { return tensor; }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index 0717d3e1b2b..183cfabbd87 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -1,10 +1,11 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.Rename;
@@ -49,7 +50,7 @@ class OperationMapper {
// (and if not, this should do the right thing anyway)
ensureArguments(2, arguments, "join");
TypedTensorFunction a = arguments.get(0);
- TypedTensorFunction b = arguments.get(0);
+ TypedTensorFunction b = arguments.get(1);
TensorType resultType = Join.outputType(a.type(), b.type());
Join function = new Join(a.function(), b.function(), doubleFunction);
@@ -65,9 +66,10 @@ class OperationMapper {
return new TypedTensorFunction(resultType, function);
}
- TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs, SavedModelBundle model) {
+ TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs, SavedModelBundle model,
+ List<NamedTensor> constants) {
String name;
- TensorType inputType;
+ TensorType type;
if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model
if (tfNode.getInputList().size() != 1)
throw new IllegalArgumentException("A Variable/read node must have one input but has " +
@@ -75,29 +77,30 @@ class OperationMapper {
name = tfNode.getInput(0);
AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
if (shapes == null)
- throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape");
- inputType = TensorFlowImporter.importTensorType(shapes.getList().getShape(0));
- Session.Runner fetched = model.session().runner().fetch("Variable");
+ throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape");
+ Session.Runner fetched = model.session().runner().fetch(name);
List<org.tensorflow.Tensor<?>> result = fetched.run();
if ( result.size() != 1)
throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + result.size());
Tensor constant = tensorConverter.toVespaTensor(result.get(0));
- return new TypedTensorFunction(inputType, new ConstantTensor(constant));
+ constants.add(new NamedTensor(name, constant));
+ return new TypedTensorFunction(constant.type(),
+ new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
}
else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name
name = tfNode.getName();
- inputType = inputs.get(name);
- if (inputType == null)
+ type = inputs.get(name);
+ if (type == null)
throw new IllegalArgumentException("An identity operation node is referencing input '" + name +
"', but there is no such input");
- return new TypedTensorFunction(inputType, new VariableTensor(name));
+ return new TypedTensorFunction(type, new VariableTensor(name));
}
}
TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
ensureArguments(2, arguments, "matmul");
TypedTensorFunction a = arguments.get(0);
- TypedTensorFunction b = arguments.get(0);
+ TypedTensorFunction b = arguments.get(1);
if (a.type().rank() < 2 || b.type().rank() < 2)
throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
if (a.type().rank() != b.type().rank())
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index c14f8c71a3e..51f1e444e70 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.GraphDef;
@@ -35,12 +36,16 @@ public class TensorFlowImporter {
* Imports a saved TensorFlow model from a directory.
* The model should be saved as a pbtxt file.
* The name of the model is taken at the pbtxt file name (not including the .pbtxt ending).
+ *
+ * @param modelDir the directory containing the TensorFlow model files to import
+ * @param constants any constant tensors imported from the TensorFlow model and referenced in the returned expressions
+ * @param logger a receiver of any messages generated by the import process
+ * @return the ranking expressions resulting from importing this TenorFlow model
*/
- public List<RankingExpression> importModel(String modelDir, MessageLogger logger) {
+ public List<RankingExpression> importModel(String modelDir, List<NamedTensor> constants, MessageLogger logger) {
try {
SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
- return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model, logger);
-
+ return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model, constants, logger);
}
catch (IOException e) {
throw new IllegalArgumentException("Could not open TensorFlow model directory '" + modelDir + "'", e);
@@ -48,7 +53,8 @@ public class TensorFlowImporter {
}
- private List<RankingExpression> importGraph(MetaGraphDef graph, SavedModelBundle model, MessageLogger logger) {
+ private List<RankingExpression> importGraph(MetaGraphDef graph, SavedModelBundle model,
+ List<NamedTensor> constants, MessageLogger logger) {
List<RankingExpression> expressions = new ArrayList<>();
for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
Map<String, TensorType> inputs = importInputs(signatureEntry.getValue().getInputsMap());
@@ -57,7 +63,8 @@ public class TensorFlowImporter {
ExpressionNode result = importOutput(output.getValue(),
inputs,
graph.getGraphDef(),
- model);
+ model,
+ constants);
expressions.add(new RankingExpression(output.getKey(), result));
}
catch (IllegalArgumentException e) {
@@ -77,7 +84,7 @@ public class TensorFlowImporter {
return inputs;
}
- static TensorType importTensorType(TensorShapeProto tensorShape) {
+ private TensorType importTensorType(TensorShapeProto tensorShape) {
TensorType.Builder b = new TensorType.Builder();
for (int i = 0; i < tensorShape.getDimCount(); i++) {
int dimensionSize = (int) tensorShape.getDim(i).getSize();
@@ -89,28 +96,32 @@ public class TensorFlowImporter {
return b.build();
}
- private ExpressionNode importOutput(TensorInfo output, Map<String, TensorType> inputs, GraphDef graph, SavedModelBundle model) {
+ private ExpressionNode importOutput(TensorInfo output, Map<String, TensorType> inputs, GraphDef graph,
+ SavedModelBundle model, List<NamedTensor> constants) {
NodeDef node = getNode(nameOf(output.getName()), graph);
- return new TensorFunctionNode(importNode(node, inputs, graph, model).function());
+ TensorFunction function = importNode(node, inputs, graph, model, constants).function();
+ return new TensorFunctionNode(function); // wrap top level (only) as an expression
}
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
- private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, SavedModelBundle model) {
- return tensorFunctionOf(tfNode, inputs, graph, model);
+ private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph,
+ SavedModelBundle model, List<NamedTensor> constants) {
+ return tensorFunctionOf(tfNode, inputs, graph, model, constants);
}
private TypedTensorFunction tensorFunctionOf(NodeDef tfNode,
Map<String, TensorType> inputs,
GraphDef graph,
- SavedModelBundle model) {
+ SavedModelBundle model,
+ List<NamedTensor> constants) {
// Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
// TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
switch (tfNode.getOp().toLowerCase()) {
- case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, model), ScalarFunctions.add());
- case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, model), ScalarFunctions.acos());
- case "identity" : return operationMapper.identity(tfNode, inputs, model);
- case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, model));
- case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, model));
+ case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, model, constants), ScalarFunctions.add());
+ case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, model, constants), ScalarFunctions.acos());
+ case "identity" : return operationMapper.identity(tfNode, inputs, model, constants);
+ case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, model, constants));
+ case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, model, constants));
default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
}
}
@@ -118,9 +129,10 @@ public class TensorFlowImporter {
private List<TypedTensorFunction> importArguments(NodeDef tfNode,
Map<String, TensorType> inputs,
GraphDef graph,
- SavedModelBundle model) {
+ SavedModelBundle model,
+ List<NamedTensor> constants) {
return tfNode.getInputList().stream()
- .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, model))
+ .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, model, constants))
.collect(Collectors.toList());
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index ab5f1e7191d..d1f4cbddf6e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -21,18 +21,18 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
- @Beta
+@Beta
public class TensorFunctionNode extends CompositeNode {
private final TensorFunction function;
-
+
public TensorFunctionNode(TensorFunction function) {
this.function = function;
}
/** Returns the tensor function wrapped by this */
public TensorFunction function() { return function; }
-
+
@Override
public List<ExpressionNode> children() {
return function.functionArguments().stream()
@@ -53,7 +53,7 @@ public class TensorFunctionNode extends CompositeNode {
// Serialize as primitive
return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this));
}
-
+
@Override
public Value evaluate(Context context) {
return new TensorValue(function.evaluate(context));
@@ -62,8 +62,8 @@ public class TensorFunctionNode extends CompositeNode {
public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
return new TensorFunctionExpressionNode(node);
}
-
- /**
+
+ /**
* A tensor function implemented by an expression.
* This allows us to pass expressions as tensor function arguments.
*/
@@ -71,13 +71,13 @@ public class TensorFunctionNode extends CompositeNode {
/** An expression which produces a tensor */
private final ExpressionNode expression;
-
+
public TensorFunctionExpressionNode(ExpressionNode expression) {
this.expression = expression;
}
-
+
@Override
- public List<TensorFunction> functionArguments() {
+ public List<TensorFunction> functionArguments() {
if (expression instanceof CompositeNode)
return ((CompositeNode)expression).children().stream()
.map(TensorFunctionExpressionNode::new)
@@ -111,7 +111,7 @@ public class TensorFunctionNode extends CompositeNode {
public String toString() {
return toString(ExpressionNodeToStringContext.empty);
}
-
+
@Override
public String toString(ToStringContext c) {
if (c instanceof ExpressionNodeToStringContext) {
@@ -124,14 +124,14 @@ public class TensorFunctionNode extends CompositeNode {
}
}
-
+
/** Allows passing serialization context arguments through TensorFunctions */
private static class ExpressionNodeToStringContext implements ToStringContext {
-
+
final SerializationContext context;
final Deque<String> path;
final CompositeNode parent;
-
+
public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(null, null, null);
public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
index 936936dc3eb..aaf198a9e8f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
@@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
import org.junit.Test;
import java.util.ArrayList;
@@ -18,9 +19,23 @@ public class TensorFlowImporterTestCase {
@Test
public void testModel1() {
+ List<NamedTensor> constants = new ArrayList<>();
TestLogger logger = new TestLogger();
List<RankingExpression> expressions =
- new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/", logger);
+ new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/", constants, logger);
+
+ // Check constants
+ assertEquals(2, constants.size());
+
+ assertEquals("Variable", constants.get(0).name());
+ assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
+ constants.get(0).tensor().type());
+ assertEquals(7840, constants.get(0).tensor().size());
+
+ assertEquals("Variable_1", constants.get(1).name());
+ assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
+ constants.get(1).tensor().type());
+ assertEquals(10, constants.get(1).tensor().size());
// Check logged messages
assertEquals(2, logger.messages().size());
@@ -33,10 +48,10 @@ public class TensorFlowImporterTestCase {
assertEquals(1, expressions.size());
assertEquals("scores", expressions.get(0).getName());
assertEquals("" +
- "softmax(join(rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " +
- "rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " +
+ "softmax(join(rename(matmul(x, rename(constant(Variable), (d1, d2), (d2, d3)), d2), d3, d2), " +
+ "constant(Variable_1), " +
"f(a,b)(a + b)), " +
- "d1)",
+ "d0)",
toNonPrimitiveString(expressions.get(0)));
}