summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-22 11:42:35 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-22 11:42:35 +0100
commit92abbad9207758578a4c56b6c9fe7f332a6546ee (patch)
tree37d448d8357587e1e620c3babd02ac6ba2f9f654 /searchlib
parent59594cb7ff0d97164eff542f184afe576e342a4b (diff)
Parse generated tensor function trees
To make generated tensor function trees transparent to the config model we need to convert each tensor function node to the corresponding ranking expression node. This is most easily done by parsing the tensor function tree string output as a ranking expression (something which is required to always work in any case).
Diffstat (limited to 'searchlib')
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java36
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java16
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java2
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java4
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj45
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java102
8 files changed, 79 insertions, 130 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
index 6e79877a657..1ec6ea4693b 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
@@ -2,18 +2,24 @@
package com.yahoo.searchlib.rankingexpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.parser.RankingExpressionParser;
import com.yahoo.searchlib.rankingexpression.parser.TokenMgrError;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
-import com.yahoo.searchlib.rankingexpression.rule.SetMembershipNode;
-import java.io.*;
-import java.util.*;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileReader;
+import java.io.Reader;
+import java.io.Serializable;
+import java.io.StringReader;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
/**
* <p>A ranking expression. Ranking expressions are used to calculate a rank score for a searched instance from a set of
@@ -92,13 +98,15 @@ public class RankingExpression implements Serializable {
}
/**
- * Creates a ranking expression from a string
+ * Creates a new ranking expression by consuming from the reader
*
- * @param expression The reader that contains the string to parse.
+ * @param name the name of the ranking expression
+ * @param expression the expression to parse.
* @throws ParseException if the string could not be parsed.
*/
- public RankingExpression(String expression) throws ParseException {
+ public RankingExpression(String name, String expression) throws ParseException {
try {
+ this.name = name;
if (expression == null || expression.length() == 0) {
throw new IllegalArgumentException("Empty ranking expressions are not allowed");
}
@@ -112,6 +120,16 @@ public class RankingExpression implements Serializable {
}
/**
+ * Creates a ranking expression from a string
+ *
+ * @param expression The reader that contains the string to parse.
+ * @throws ParseException if the string could not be parsed.
+ */
+ public RankingExpression(String expression) throws ParseException {
+ this("", expression);
+ }
+
+ /**
* Creates a ranking expression from a file. For convenience, the file.getName() up to any dot becomes the name of
* this expression.
*
@@ -259,7 +277,7 @@ public class RankingExpression implements Serializable {
/**
* Creates a ranking expression from a string
- *
+ *
* @throws IllegalArgumentException if the string is not a valid ranking expression
*/
public static RankingExpression from(String expression) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index 8c1b4a4e5fe..d70872b2048 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -117,7 +117,7 @@ class OperationMapper {
Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0));
result.constant(name, constant);
return new TypedTensorFunction(constant.type(),
- new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
+ new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant('" + name + "')")));
}
TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 45f2b21343f..a8cb5e6e1c7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -2,7 +2,7 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.yolean.Exceptions;
@@ -105,10 +105,16 @@ public class TensorFlowImporter {
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result);
- // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
- // will be used
- result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function())));
- return function;
+ try {
+ // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
+ // will be used. We parse the TensorFunction here to convert it to a RankingExpression tree
+ result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), function.function().toString()));
+ return function;
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Tensorflow function " + function.function() +
+ " cannot be parsed as a ranking expression", e);
+ }
}
private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
index 1632a17748c..69df572272a 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
@@ -10,7 +10,7 @@ import java.util.Deque;
* An opaque name in a ranking expression. This is used to represent names passed to the context
* and interpreted by the given context in a way which is opaque to the ranking expressions.
*
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
*/
public final class NameNode extends ExpressionNode {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index ff2c9d8ea6d..139709998b4 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -72,7 +72,7 @@ public final class ReferenceNode extends CompositeNode {
myArguments = null;
myOutput = null;
} else if (context.getFunction(myName) != null) {
- // Replace this whole node with a reference to another script.
+ // Replace by the referenced expression
ExpressionFunction function = context.getFunction(myName);
if (function != null && myArguments != null && function.arguments().size() == myArguments.size() && myOutput == null) {
String myPath = name + this.arguments.expressions();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 8af3448ca6f..b42570d3aea 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -139,7 +139,9 @@ public class TensorFunctionNode extends CompositeNode {
final Deque<String> path;
final CompositeNode parent;
- public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(null, null, null);
+ public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(new SerializationContext(),
+ null,
+ null);
public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
this.context = context;
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index febd0a60bb3..a796eaa4ac0 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -242,7 +242,7 @@ ExpressionNode value() :
LOOKAHEAD(2) ret = ifExpression() |
LOOKAHEAD(4) ret = function() |
ret = feature() |
- ret = queryFeature() |
+ ret = legacyQueryFeature() |
( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) ) )
{
ret = not ? new NotNode(ret) : ret;
@@ -264,15 +264,6 @@ IfNode ifExpression() :
}
}
-ReferenceNode queryFeature() :
-{
- String name;
-}
-{
- ( <DOLLAR> name = identifier() )
- { return new ReferenceNode("query", Arrays.asList((ExpressionNode)new NameNode(name)), null); }
-}
-
ReferenceNode feature() :
{
List<ExpressionNode> args = null;
@@ -283,6 +274,16 @@ ReferenceNode feature() :
{ return new ReferenceNode(name, args, out); }
}
+// Query features can be referenced as "$name" instead of "query(name)". TODO: Warn this is deprecated
+ReferenceNode legacyQueryFeature() :
+{
+ String name;
+}
+{
+ ( <DOLLAR> name = identifier() )
+ { return new ReferenceNode("query", Arrays.asList((ExpressionNode)new NameNode(name)), null); }
+}
+
String outs() :
{
StringBuilder ret = new StringBuilder();
@@ -333,7 +334,7 @@ ExpressionNode arg() :
{
( ret = constantPrimitive() |
LOOKAHEAD(2) ret = feature() |
- name = identifier() { ret = new NameNode(name); } )
+ name = identifier() { ret = new NameNode(name); } )
{ return ret; }
}
@@ -342,7 +343,7 @@ ExpressionNode function() :
ExpressionNode function;
}
{
- ( function = scalarOrTensorFunction() | function = tensorFunction() )
+ ( LOOKAHEAD(2) function = scalarOrTensorFunction() | function = tensorFunction() )
{ return function; }
}
@@ -717,7 +718,7 @@ String identifier() :
Function func;
}
{
- name = tensorFunctionName() { return name; } |
+ LOOKAHEAD(2) name = tensorFunctionName() { return name; } |
func = unaryFunctionName() { return func.toString(); } |
func = binaryFunctionName() { return func.toString(); } |
<IF> { return token.image; } |
@@ -770,11 +771,25 @@ List<String> tagCommaLeadingList() :
ConstantNode constantPrimitive() :
{
String sign = "";
+ String value;
}
{
( <SUB> { sign = "-";} ) ?
- ( <INTEGER> | <FLOAT> | <STRING> )
- { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); }
+ ( <INTEGER> { value = token.image; } |
+ <FLOAT> { value = token.image; } |
+ value = stringPath() )
+ { return new ConstantNode(Value.parse(sign + value),sign + value); }
+}
+
+// Strings separated by "/"
+String stringPath() :
+{
+ StringBuilder b = new StringBuilder();
+}
+{
+ <STRING> { b.append(token.image); }
+ ( LOOKAHEAD(2) <DIV> <STRING> { b.append("/").append(token.image); } ) *
+ { return b.toString(); }
}
Value primitiveValue() :
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
index 3ec074dc653..e22e4a36bab 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
@@ -60,10 +60,7 @@ public class TensorflowImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("add", output.getName());
- assertEquals("" +
- "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
- "rename(constant(Variable_1), d0, d1), " +
- "f(a,b)(a + b))",
+ assertEquals("join(rename(reduce(join(Placeholder, rename(constant('Variable'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('Variable_1'), d0, d1), f(a,b)(a + b))",
toNonPrimitiveString(output));
// Test execution
@@ -139,97 +136,8 @@ public class TensorflowImportTestCase {
assertNotNull(output);
assertEquals("dnn/outputs/add", output.getName());
assertEquals("" +
- "join(" +
- "rename(" +
- "matmul(" +
- "map(" +
- "join(" +
- "rename(" +
- "matmul(" +
- "map(" +
- "join(" +
- "rename(" +
- "matmul(" +
- "map(" +
- "join(" +
- "rename(" +
- "matmul(" +
- "X, " +
- "rename(" +
- "constant(dnn/hidden1/weights), " +
- "(d0, d1), " +
- "(d1, d3)" +
- "), " +
- "d1" +
- "), " +
- "d3, " +
- "d1" +
- "), " +
- "rename(" +
- "constant(dnn/hidden1/bias), " +
- "d0, " +
- "d1" +
- "), " +
- "f(a,b)(a + b)" +
- "), " +
- "f(a)(if(a < 0, exp(a)-1, a))" +
- "), " +
- "rename(" +
- "constant(dnn/hidden2/weights), " +
- "(d0, d1), " +
- "(d1, d3)" +
- "), " +
- "d1" +
- "), " +
- "d3, " +
- "d1" +
- "), " +
- "rename(" +
- "constant(dnn/hidden2/bias), " +
- "d0, " +
- "d1" +
- "), " +
- "f(a,b)(a + b)" +
- "), " +
- "f(a)(max(0, a))" +
- "), " +
- "rename(" +
- "constant(dnn/hidden3/weights), " +
- "(d0, d1), " +
- "(d1, d3)" +
- "), " +
- "d1" +
- "), " +
- "d3, " +
- "d1" +
- "), " +
- "rename(" +
- "constant(dnn/hidden3/bias), " +
- "d0, " +
- "d1" +
- "), " +
- "f(a,b)(a + b)" +
- "), " +
- "f(a)(1 / (1 + exp(-a)))" +
- "), " +
- "rename(" +
- "constant(dnn/outputs/weights), " +
- "(d0, d1), " +
- "(d1, d3)" +
- "), " +
- "d1" +
- "), " +
- "d3, " +
- "d1" +
- "), " +
- "rename(" +
- "constant(dnn/outputs/bias), " +
- "d0, " +
- "d1" +
- "), " +
- "f(a,b)(a + b)" +
- ")",
- toNonPrimitiveString(output));
+ "join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(X, rename(constant('dnn/hidden1/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden1/bias'), d0, d1), f(a,b)(a + b)), f(a)(if (a < 0, exp(a) - 1, a))), rename(constant('dnn/hidden2/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden2/bias'), d0, d1), f(a,b)(a + b)), f(a)(max(0,a))), rename(constant('dnn/hidden3/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden3/bias'), d0, d1), f(a,b)(a + b)), f(a)(1 / (1 + exp(-a)))), rename(constant('dnn/outputs/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/outputs/bias'), d0, d1), f(a,b)(a + b))",
+ toNonPrimitiveString(output));
// Test constants
assertEqualResult(model, result, "X", "dnn/hidden1/weights/read");
@@ -262,7 +170,7 @@ public class TensorflowImportTestCase {
Tensor placeholder = placeholderArgument();
context.put(inputName, new TensorValue(placeholder));
Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor();
- assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult);
+ assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
}
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
@@ -276,7 +184,7 @@ public class TensorflowImportTestCase {
private Context contextFrom(TensorFlowModel result) {
MapContext context = new MapContext();
- result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ result.constants().forEach((name, tensor) -> context.put("constant('" + name + "')", new TensorValue(tensor)));
return context;
}