summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2018-01-22 12:29:53 +0100
committerGitHub <noreply@github.com>2018-01-22 12:29:53 +0100
commit4148debe89932119346b102a81164921af007d00 (patch)
tree132a9b0af3002cf1ddbff9048ad63fbca4143bb0 /searchlib
parentc69fdcea1939f225216e95907d30d9b99fcbebae (diff)
parent92abbad9207758578a4c56b6c9fe7f332a6546ee (diff)
Merge pull request #4730 from vespa-engine/bratseth/reparse-expressions
Parse generated tensor function trees
Diffstat (limited to 'searchlib')
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java36
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java16
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java2
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java4
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj45
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java102
8 files changed, 79 insertions, 130 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
index 6e79877a657..1ec6ea4693b 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
@@ -2,18 +2,24 @@
package com.yahoo.searchlib.rankingexpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.parser.RankingExpressionParser;
import com.yahoo.searchlib.rankingexpression.parser.TokenMgrError;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
-import com.yahoo.searchlib.rankingexpression.rule.SetMembershipNode;
-import java.io.*;
-import java.util.*;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileReader;
+import java.io.Reader;
+import java.io.Serializable;
+import java.io.StringReader;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
/**
* <p>A ranking expression. Ranking expressions are used to calculate a rank score for a searched instance from a set of
@@ -92,13 +98,15 @@ public class RankingExpression implements Serializable {
}
/**
- * Creates a ranking expression from a string
+ * Creates a new ranking expression by consuming from the reader
*
- * @param expression The reader that contains the string to parse.
+ * @param name the name of the ranking expression
+ * @param expression the expression to parse.
* @throws ParseException if the string could not be parsed.
*/
- public RankingExpression(String expression) throws ParseException {
+ public RankingExpression(String name, String expression) throws ParseException {
try {
+ this.name = name;
if (expression == null || expression.length() == 0) {
throw new IllegalArgumentException("Empty ranking expressions are not allowed");
}
@@ -112,6 +120,16 @@ public class RankingExpression implements Serializable {
}
/**
+ * Creates a ranking expression from a string
+ *
+ * @param expression The reader that contains the string to parse.
+ * @throws ParseException if the string could not be parsed.
+ */
+ public RankingExpression(String expression) throws ParseException {
+ this("", expression);
+ }
+
+ /**
* Creates a ranking expression from a file. For convenience, the file.getName() up to any dot becomes the name of
* this expression.
*
@@ -259,7 +277,7 @@ public class RankingExpression implements Serializable {
/**
* Creates a ranking expression from a string
- *
+ *
* @throws IllegalArgumentException if the string is not a valid ranking expression
*/
public static RankingExpression from(String expression) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index 8c1b4a4e5fe..d70872b2048 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -117,7 +117,7 @@ class OperationMapper {
Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0));
result.constant(name, constant);
return new TypedTensorFunction(constant.type(),
- new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
+ new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant('" + name + "')")));
}
TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 45f2b21343f..a8cb5e6e1c7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -2,7 +2,7 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.yolean.Exceptions;
@@ -105,10 +105,16 @@ public class TensorFlowImporter {
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result);
- // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
- // will be used
- result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function())));
- return function;
+ try {
+ // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
+ // will be used. We parse the TensorFunction here to convert it to a RankingExpression tree
+ result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), function.function().toString()));
+ return function;
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Tensorflow function " + function.function() +
+ " cannot be parsed as a ranking expression", e);
+ }
}
private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
index 1632a17748c..69df572272a 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
@@ -10,7 +10,7 @@ import java.util.Deque;
* An opaque name in a ranking expression. This is used to represent names passed to the context
* and interpreted by the given context in a way which is opaque to the ranking expressions.
*
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
*/
public final class NameNode extends ExpressionNode {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index ff2c9d8ea6d..139709998b4 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -72,7 +72,7 @@ public final class ReferenceNode extends CompositeNode {
myArguments = null;
myOutput = null;
} else if (context.getFunction(myName) != null) {
- // Replace this whole node with a reference to another script.
+ // Replace by the referenced expression
ExpressionFunction function = context.getFunction(myName);
if (function != null && myArguments != null && function.arguments().size() == myArguments.size() && myOutput == null) {
String myPath = name + this.arguments.expressions();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 8af3448ca6f..b42570d3aea 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -139,7 +139,9 @@ public class TensorFunctionNode extends CompositeNode {
final Deque<String> path;
final CompositeNode parent;
- public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(null, null, null);
+ public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(new SerializationContext(),
+ null,
+ null);
public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
this.context = context;
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index febd0a60bb3..a796eaa4ac0 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -242,7 +242,7 @@ ExpressionNode value() :
LOOKAHEAD(2) ret = ifExpression() |
LOOKAHEAD(4) ret = function() |
ret = feature() |
- ret = queryFeature() |
+ ret = legacyQueryFeature() |
( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) ) )
{
ret = not ? new NotNode(ret) : ret;
@@ -264,15 +264,6 @@ IfNode ifExpression() :
}
}
-ReferenceNode queryFeature() :
-{
- String name;
-}
-{
- ( <DOLLAR> name = identifier() )
- { return new ReferenceNode("query", Arrays.asList((ExpressionNode)new NameNode(name)), null); }
-}
-
ReferenceNode feature() :
{
List<ExpressionNode> args = null;
@@ -283,6 +274,16 @@ ReferenceNode feature() :
{ return new ReferenceNode(name, args, out); }
}
+// Query features can be referenced as "$name" instead of "query(name)". TODO: Warn this is deprecated
+ReferenceNode legacyQueryFeature() :
+{
+ String name;
+}
+{
+ ( <DOLLAR> name = identifier() )
+ { return new ReferenceNode("query", Arrays.asList((ExpressionNode)new NameNode(name)), null); }
+}
+
String outs() :
{
StringBuilder ret = new StringBuilder();
@@ -333,7 +334,7 @@ ExpressionNode arg() :
{
( ret = constantPrimitive() |
LOOKAHEAD(2) ret = feature() |
- name = identifier() { ret = new NameNode(name); } )
+ name = identifier() { ret = new NameNode(name); } )
{ return ret; }
}
@@ -342,7 +343,7 @@ ExpressionNode function() :
ExpressionNode function;
}
{
- ( function = scalarOrTensorFunction() | function = tensorFunction() )
+ ( LOOKAHEAD(2) function = scalarOrTensorFunction() | function = tensorFunction() )
{ return function; }
}
@@ -717,7 +718,7 @@ String identifier() :
Function func;
}
{
- name = tensorFunctionName() { return name; } |
+ LOOKAHEAD(2) name = tensorFunctionName() { return name; } |
func = unaryFunctionName() { return func.toString(); } |
func = binaryFunctionName() { return func.toString(); } |
<IF> { return token.image; } |
@@ -770,11 +771,25 @@ List<String> tagCommaLeadingList() :
ConstantNode constantPrimitive() :
{
String sign = "";
+ String value;
}
{
( <SUB> { sign = "-";} ) ?
- ( <INTEGER> | <FLOAT> | <STRING> )
- { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); }
+ ( <INTEGER> { value = token.image; } |
+ <FLOAT> { value = token.image; } |
+ value = stringPath() )
+ { return new ConstantNode(Value.parse(sign + value),sign + value); }
+}
+
+// Strings separated by "/"
+String stringPath() :
+{
+ StringBuilder b = new StringBuilder();
+}
+{
+ <STRING> { b.append(token.image); }
+ ( LOOKAHEAD(2) <DIV> <STRING> { b.append("/").append(token.image); } ) *
+ { return b.toString(); }
}
Value primitiveValue() :
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
index 3ec074dc653..e22e4a36bab 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
@@ -60,10 +60,7 @@ public class TensorflowImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("add", output.getName());
- assertEquals("" +
- "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
- "rename(constant(Variable_1), d0, d1), " +
- "f(a,b)(a + b))",
+ assertEquals("join(rename(reduce(join(Placeholder, rename(constant('Variable'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('Variable_1'), d0, d1), f(a,b)(a + b))",
toNonPrimitiveString(output));
// Test execution
@@ -139,97 +136,8 @@ public class TensorflowImportTestCase {
assertNotNull(output);
assertEquals("dnn/outputs/add", output.getName());
assertEquals("" +
- "join(" +
- "rename(" +
- "matmul(" +
- "map(" +
- "join(" +
- "rename(" +
- "matmul(" +
- "map(" +
- "join(" +
- "rename(" +
- "matmul(" +
- "map(" +
- "join(" +
- "rename(" +
- "matmul(" +
- "X, " +
- "rename(" +
- "constant(dnn/hidden1/weights), " +
- "(d0, d1), " +
- "(d1, d3)" +
- "), " +
- "d1" +
- "), " +
- "d3, " +
- "d1" +
- "), " +
- "rename(" +
- "constant(dnn/hidden1/bias), " +
- "d0, " +
- "d1" +
- "), " +
- "f(a,b)(a + b)" +
- "), " +
- "f(a)(if(a < 0, exp(a)-1, a))" +
- "), " +
- "rename(" +
- "constant(dnn/hidden2/weights), " +
- "(d0, d1), " +
- "(d1, d3)" +
- "), " +
- "d1" +
- "), " +
- "d3, " +
- "d1" +
- "), " +
- "rename(" +
- "constant(dnn/hidden2/bias), " +
- "d0, " +
- "d1" +
- "), " +
- "f(a,b)(a + b)" +
- "), " +
- "f(a)(max(0, a))" +
- "), " +
- "rename(" +
- "constant(dnn/hidden3/weights), " +
- "(d0, d1), " +
- "(d1, d3)" +
- "), " +
- "d1" +
- "), " +
- "d3, " +
- "d1" +
- "), " +
- "rename(" +
- "constant(dnn/hidden3/bias), " +
- "d0, " +
- "d1" +
- "), " +
- "f(a,b)(a + b)" +
- "), " +
- "f(a)(1 / (1 + exp(-a)))" +
- "), " +
- "rename(" +
- "constant(dnn/outputs/weights), " +
- "(d0, d1), " +
- "(d1, d3)" +
- "), " +
- "d1" +
- "), " +
- "d3, " +
- "d1" +
- "), " +
- "rename(" +
- "constant(dnn/outputs/bias), " +
- "d0, " +
- "d1" +
- "), " +
- "f(a,b)(a + b)" +
- ")",
- toNonPrimitiveString(output));
+ "join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(X, rename(constant('dnn/hidden1/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden1/bias'), d0, d1), f(a,b)(a + b)), f(a)(if (a < 0, exp(a) - 1, a))), rename(constant('dnn/hidden2/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden2/bias'), d0, d1), f(a,b)(a + b)), f(a)(max(0,a))), rename(constant('dnn/hidden3/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden3/bias'), d0, d1), f(a,b)(a + b)), f(a)(1 / (1 + exp(-a)))), rename(constant('dnn/outputs/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/outputs/bias'), d0, d1), f(a,b)(a + b))",
+ toNonPrimitiveString(output));
// Test constants
assertEqualResult(model, result, "X", "dnn/hidden1/weights/read");
@@ -262,7 +170,7 @@ public class TensorflowImportTestCase {
Tensor placeholder = placeholderArgument();
context.put(inputName, new TensorValue(placeholder));
Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor();
- assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult);
+ assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
}
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
@@ -276,7 +184,7 @@ public class TensorflowImportTestCase {
private Context contextFrom(TensorFlowModel result) {
MapContext context = new MapContext();
- result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ result.constants().forEach((name, tensor) -> context.put("constant('" + name + "')", new TensorValue(tensor)));
return context;
}