diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2018-01-22 12:29:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-22 12:29:53 +0100 |
commit | 4148debe89932119346b102a81164921af007d00 (patch) | |
tree | 132a9b0af3002cf1ddbff9048ad63fbca4143bb0 /searchlib | |
parent | c69fdcea1939f225216e95907d30d9b99fcbebae (diff) | |
parent | 92abbad9207758578a4c56b6c9fe7f332a6546ee (diff) |
Merge pull request #4730 from vespa-engine/bratseth/reparse-expressions
Parse generated tensor function trees
Diffstat (limited to 'searchlib')
8 files changed, 79 insertions, 130 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java index 6e79877a657..1ec6ea4693b 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java @@ -2,18 +2,24 @@ package com.yahoo.searchlib.rankingexpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.parser.RankingExpressionParser; import com.yahoo.searchlib.rankingexpression.parser.TokenMgrError; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; -import com.yahoo.searchlib.rankingexpression.rule.SetMembershipNode; -import java.io.*; -import java.util.*; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.Reader; +import java.io.Serializable; +import java.io.StringReader; +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; /** * <p>A ranking expression. Ranking expressions are used to calculate a rank score for a searched instance from a set of @@ -92,13 +98,15 @@ public class RankingExpression implements Serializable { } /** - * Creates a ranking expression from a string + * Creates a new ranking expression by consuming from the reader * - * @param expression The reader that contains the string to parse. + * @param name the name of the ranking expression + * @param expression the expression to parse. * @throws ParseException if the string could not be parsed. */ - public RankingExpression(String expression) throws ParseException { + public RankingExpression(String name, String expression) throws ParseException { try { + this.name = name; if (expression == null || expression.length() == 0) { throw new IllegalArgumentException("Empty ranking expressions are not allowed"); } @@ -112,6 +120,16 @@ public class RankingExpression implements Serializable { } /** + * Creates a ranking expression from a string + * + * @param expression The reader that contains the string to parse. + * @throws ParseException if the string could not be parsed. + */ + public RankingExpression(String expression) throws ParseException { + this("", expression); + } + + /** * Creates a ranking expression from a file. For convenience, the file.getName() up to any dot becomes the name of * this expression. * @@ -259,7 +277,7 @@ public class RankingExpression implements Serializable { /** * Creates a ranking expression from a string - * + * * @throws IllegalArgumentException if the string is not a valid ranking expression */ public static RankingExpression from(String expression) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index 8c1b4a4e5fe..d70872b2048 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -117,7 +117,7 @@ class OperationMapper { Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0)); result.constant(name, constant); return new TypedTensorFunction(constant.type(), - new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")"))); + new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant('" + name + "')"))); } TypedTensorFunction matmul(List<TypedTensorFunction> arguments) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 45f2b21343f..a8cb5e6e1c7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -2,7 +2,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.yolean.Exceptions; @@ -105,10 +105,16 @@ public class TensorFlowImporter { /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) { TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result); - // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output - // will be used - result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function()))); - return function; + try { + // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output + // will be used. We parse the TensorFunction here to convert it to a RankingExpression tree + result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), function.function().toString())); + return function; + } + catch (ParseException e) { + throw new RuntimeException("Tensorflow function " + function.function() + + " cannot be parsed as a ranking expression", e); + } } private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java index 1632a17748c..69df572272a 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java @@ -10,7 +10,7 @@ import java.util.Deque; * An opaque name in a ranking expression. This is used to represent names passed to the context * and interpreted by the given context in a way which is opaque to the ranking expressions. * - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen */ public final class NameNode extends ExpressionNode { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index ff2c9d8ea6d..139709998b4 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -72,7 +72,7 @@ public final class ReferenceNode extends CompositeNode { myArguments = null; myOutput = null; } else if (context.getFunction(myName) != null) { - // Replace this whole node with a reference to another script. + // Replace by the referenced expression ExpressionFunction function = context.getFunction(myName); if (function != null && myArguments != null && function.arguments().size() == myArguments.size() && myOutput == null) { String myPath = name + this.arguments.expressions(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 8af3448ca6f..b42570d3aea 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -139,7 +139,9 @@ public class TensorFunctionNode extends CompositeNode { final Deque<String> path; final CompositeNode parent; - public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(null, null, null); + public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(new SerializationContext(), + null, + null); public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { this.context = context; diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index febd0a60bb3..a796eaa4ac0 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -242,7 +242,7 @@ ExpressionNode value() : LOOKAHEAD(2) ret = ifExpression() | LOOKAHEAD(4) ret = function() | ret = feature() | - ret = queryFeature() | + ret = legacyQueryFeature() | ( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) ) ) { ret = not ? new NotNode(ret) : ret; @@ -264,15 +264,6 @@ IfNode ifExpression() : } } -ReferenceNode queryFeature() : -{ - String name; -} -{ - ( <DOLLAR> name = identifier() ) - { return new ReferenceNode("query", Arrays.asList((ExpressionNode)new NameNode(name)), null); } -} - ReferenceNode feature() : { List<ExpressionNode> args = null; @@ -283,6 +274,16 @@ ReferenceNode feature() : { return new ReferenceNode(name, args, out); } } +// Query features can be referenced as "$name" instead of "query(name)". TODO: Warn this is deprecated +ReferenceNode legacyQueryFeature() : +{ + String name; +} +{ + ( <DOLLAR> name = identifier() ) + { return new ReferenceNode("query", Arrays.asList((ExpressionNode)new NameNode(name)), null); } +} + String outs() : { StringBuilder ret = new StringBuilder(); @@ -333,7 +334,7 @@ ExpressionNode arg() : { ( ret = constantPrimitive() | LOOKAHEAD(2) ret = feature() | - name = identifier() { ret = new NameNode(name); } ) + name = identifier() { ret = new NameNode(name); } ) { return ret; } } @@ -342,7 +343,7 @@ ExpressionNode function() : ExpressionNode function; } { - ( function = scalarOrTensorFunction() | function = tensorFunction() ) + ( LOOKAHEAD(2) function = scalarOrTensorFunction() | function = tensorFunction() ) { return function; } } @@ -717,7 +718,7 @@ String identifier() : Function func; } { - name = tensorFunctionName() { return name; } | + LOOKAHEAD(2) name = tensorFunctionName() { return name; } | func = unaryFunctionName() { return func.toString(); } | func = binaryFunctionName() { return func.toString(); } | <IF> { return token.image; } | @@ -770,11 +771,25 @@ List<String> tagCommaLeadingList() : ConstantNode constantPrimitive() : { String sign = ""; + String value; } { ( <SUB> { sign = "-";} ) ? - ( <INTEGER> | <FLOAT> | <STRING> ) - { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); } + ( <INTEGER> { value = token.image; } | + <FLOAT> { value = token.image; } | + value = stringPath() ) + { return new ConstantNode(Value.parse(sign + value),sign + value); } +} + +// Strings separated by "/" +String stringPath() : +{ + StringBuilder b = new StringBuilder(); +} +{ + <STRING> { b.append(token.image); } + ( LOOKAHEAD(2) <DIV> <STRING> { b.append("/").append(token.image); } ) * + { return b.toString(); } } Value primitiveValue() : diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java index 3ec074dc653..e22e4a36bab 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java @@ -60,10 +60,7 @@ public class TensorflowImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("add", output.getName()); - assertEquals("" + - "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " + - "rename(constant(Variable_1), d0, d1), " + - "f(a,b)(a + b))", + assertEquals("join(rename(reduce(join(Placeholder, rename(constant('Variable'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('Variable_1'), d0, d1), f(a,b)(a + b))", toNonPrimitiveString(output)); // Test execution @@ -139,97 +136,8 @@ public class TensorflowImportTestCase { assertNotNull(output); assertEquals("dnn/outputs/add", output.getName()); assertEquals("" + - "join(" + - "rename(" + - "matmul(" + - "map(" + - "join(" + - "rename(" + - "matmul(" + - "map(" + - "join(" + - "rename(" + - "matmul(" + - "map(" + - "join(" + - "rename(" + - "matmul(" + - "X, " + - "rename(" + - "constant(dnn/hidden1/weights), " + - "(d0, d1), " + - "(d1, d3)" + - "), " + - "d1" + - "), " + - "d3, " + - "d1" + - "), " + - "rename(" + - "constant(dnn/hidden1/bias), " + - "d0, " + - "d1" + - "), " + - "f(a,b)(a + b)" + - "), " + - "f(a)(if(a < 0, exp(a)-1, a))" + - "), " + - "rename(" + - "constant(dnn/hidden2/weights), " + - "(d0, d1), " + - "(d1, d3)" + - "), " + - "d1" + - "), " + - "d3, " + - "d1" + - "), " + - "rename(" + - "constant(dnn/hidden2/bias), " + - "d0, " + - "d1" + - "), " + - "f(a,b)(a + b)" + - "), " + - "f(a)(max(0, a))" + - "), " + - "rename(" + - "constant(dnn/hidden3/weights), " + - "(d0, d1), " + - "(d1, d3)" + - "), " + - "d1" + - "), " + - "d3, " + - "d1" + - "), " + - "rename(" + - "constant(dnn/hidden3/bias), " + - "d0, " + - "d1" + - "), " + - "f(a,b)(a + b)" + - "), " + - "f(a)(1 / (1 + exp(-a)))" + - "), " + - "rename(" + - "constant(dnn/outputs/weights), " + - "(d0, d1), " + - "(d1, d3)" + - "), " + - "d1" + - "), " + - "d3, " + - "d1" + - "), " + - "rename(" + - "constant(dnn/outputs/bias), " + - "d0, " + - "d1" + - "), " + - "f(a,b)(a + b)" + - ")", - toNonPrimitiveString(output)); + "join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(X, rename(constant('dnn/hidden1/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden1/bias'), d0, d1), f(a,b)(a + b)), f(a)(if (a < 0, exp(a) - 1, a))), rename(constant('dnn/hidden2/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden2/bias'), d0, d1), f(a,b)(a + b)), f(a)(max(0,a))), rename(constant('dnn/hidden3/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden3/bias'), d0, d1), f(a,b)(a + b)), f(a)(1 / (1 + exp(-a)))), rename(constant('dnn/outputs/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/outputs/bias'), d0, d1), f(a,b)(a + b))", + toNonPrimitiveString(output)); // Test constants assertEqualResult(model, result, "X", "dnn/hidden1/weights/read"); @@ -262,7 +170,7 @@ public class TensorflowImportTestCase { Tensor placeholder = placeholderArgument(); context.put(inputName, new TensorValue(placeholder)); Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor(); - assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult); + assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); } private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { @@ -276,7 +184,7 @@ public class TensorflowImportTestCase { private Context contextFrom(TensorFlowModel result) { MapContext context = new MapContext(); - result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.constants().forEach((name, tensor) -> context.put("constant('" + name + "')", new TensorValue(tensor))); return context; } |