From 9b458439aab4b1b48ee2e76b1e0b5e31ce0c3177 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 18 Dec 2019 16:46:13 +0100 Subject: Exmbrace serialized expressions --- .../rankingexpression/rule/TensorFunctionNode.java | 16 ++++++++++++---- .../RankingExpressionTestCase.java | 22 +++++++++++----------- .../evaluation/EvaluationTestCase.java | 2 +- 3 files changed, 24 insertions(+), 16 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 6200515462b..6e1cdf52158 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -169,16 +169,24 @@ public class TensorFunctionNode extends CompositeNode { if (outermost instanceof ExpressionToStringContext) { ExpressionToStringContext context = (ExpressionToStringContext)outermost; - return expression.toString(new StringBuilder(), - new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent), - context.path, - context.parent).toString(); + ExpressionNode root = expression; + if (root instanceof CompositeNode && ! (root instanceof EmbracedNode) && ! isIdentifierReference(root)) + root = new EmbracedNode(root); // Output embraced if composite + return root.toString(new StringBuilder(), + new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent), + context.path, + context.parent).toString(); } else { return expression.toString(); } } + private boolean isIdentifierReference(ExpressionNode node) { + if ( ! (node instanceof ReferenceNode)) return false; + return ((ReferenceNode)node).reference().isIdentifier(); + } + } /** diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index b750a7607cc..ea09de32137 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -150,11 +150,11 @@ public class RankingExpressionTestCase { "map(constant(tensor0), f(a)(cos(a))) + l2_normalize(attribute(tensor1), x)"); assertSerialization("join(reduce(join(reduce(join(constant(tensor0), attribute(tensor1), f(a,b)(a * b)), sum, x), attribute(tensor1), f(a,b)(a * b)), sum, y), query(tensor2), f(a,b)(a + b))", "xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)"); - assertSerialization("tensor(x{}):{{x:a}:1 + 2 + 3,{x:b}:if (1 > 2, 3, 4),{x:c}:reduce(tensor0 * tensor1, sum)}", + assertSerialization("tensor(x{}):{{x:a}:(1 + 2 + 3),{x:b}:(if (1 > 2, 3, 4)),{x:c}:(reduce(tensor0 * tensor1, sum))}", "tensor(x{}):{ {x:a}:1+2+3, {x:b}:if(1>2,3,4), {x:c}:sum(tensor0*tensor1) }"); assertSerialization("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3}", "tensor(x[3]):[1.0, 2.0, 3]"); - assertSerialization("tensor(x[3]):{{x:0}:1.0,{x:1}:reduce(tensor0 * tensor1, sum),{x:2}:3}", + assertSerialization("tensor(x[3]):{{x:0}:1.0,{x:1}:(reduce(tensor0 * tensor1, sum)),{x:2}:3}", "tensor(x[3]):[1.0, sum(tensor0*tensor1), 3]"); } @@ -165,44 +165,44 @@ public class RankingExpressionTestCase { functions.add(new ExpressionFunction("tensorFunction", List.of(), new RankingExpression("tensor(x[3]):[1, 2, 3]"))); // Getting a value from a tensor supplied by a function, inside a tensor generate function - assertSerialization(List.of("tensor(x[3])(rankingExpression(tensorFunction)[x])"), + assertSerialization(List.of("tensor(x[3])((rankingExpression(tensorFunction)[x]))"), "tensor(x[3])(tensorFunction[x])", functions, false); // Getting a value from a tensor supplied by a function, where the value index is supplied by a function, inside a tensor generate function, short form - assertSerialization(List.of("tensor(x[3])(rankingExpression(tensorFunction)[rankingExpression(scalarFunction)])"), + assertSerialization(List.of("tensor(x[3])((rankingExpression(tensorFunction)[(rankingExpression(scalarFunction))]))"), "tensor(x[3])(tensorFunction[scalarFunction()])", functions, false); // 'scalarFunction' is interpreted as a label here since it is the short form of a mapped dimension - assertSerialization(List.of("tensor(x[3])(rankingExpression(tensorFunction){scalarFunction})"), + assertSerialization(List.of("tensor(x[3])((rankingExpression(tensorFunction){scalarFunction}))"), "tensor(x[3])(tensorFunction{scalarFunction})", functions, false); // Getting a value from a tensor supplied by a function, where the value index is supplied by a function, inside a tensor generate function, long form - assertSerialization(List.of("tensor(x[3])(rankingExpression(tensorFunction){x:rankingExpression(scalarFunction)})"), + assertSerialization(List.of("tensor(x[3])((rankingExpression(tensorFunction){x:(rankingExpression(scalarFunction))}))"), "tensor(x[3])(tensorFunction{x:scalarFunction()})", functions, false); // 'scalarFunction' without parentheses is interpreted as a label instead of a reference to the function - assertSerialization(List.of("tensor(x[3])(rankingExpression(tensorFunction){x:scalarFunction})"), + assertSerialization(List.of("tensor(x[3])((rankingExpression(tensorFunction){x:scalarFunction}))"), "tensor(x[3])(tensorFunction{x:scalarFunction})", functions, false); // Accessing a function in a dynamic tensor, short form - assertSerialization(List.of("tensor(x[2]):{{x:0}:rankingExpression(scalarFunction),{x:1}:rankingExpression(scalarFunction)}"), + assertSerialization(List.of("tensor(x[2]):{{x:0}:(rankingExpression(scalarFunction)),{x:1}:(rankingExpression(scalarFunction))}"), "tensor(x[2]):[scalarFunction(), scalarFunction()]", functions, false); // Accessing a function in a dynamic tensor, long form - assertSerialization(List.of("tensor(x{}):{{x:foo}:rankingExpression(scalarFunction),{x:bar}:rankingExpression(scalarFunction)}"), + assertSerialization(List.of("tensor(x{}):{{x:foo}:(rankingExpression(scalarFunction)),{x:bar}:(rankingExpression(scalarFunction))}"), "tensor(x{}):{{x:foo}:scalarFunction(), {x:bar}:scalarFunction()}", functions, false); // Shadowing - assertSerialization(List.of("tensor(scalarFunction[1])(rankingExpression(tensorFunction){x:scalarFunction + rankingExpression(scalarFunction)})"), + assertSerialization(List.of("tensor(scalarFunction[1])((rankingExpression(tensorFunction){x:(scalarFunction + rankingExpression(scalarFunction))}))"), "tensor(scalarFunction[1])(tensorFunction{x: scalarFunction + scalarFunction()})", - functions, false); + functions, true); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 642ef8b873b..ca2f6c6bbec 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -324,7 +324,7 @@ public class EvaluationTestCase { "{y:1}:((1+1)+a)}{y:1}" + "}"); - // tensor value + // tensor slice tester.assertEvaluates("3.0", "tensor0{x:1}", "{ {x:0}:1, {x:1}:3 }"); tester.assertEvaluates("1.2", "tensor0{key:foo,x:0}", true, "{ {key:foo,x:0}:1.2, {key:bar,x:0}:3 }"); tester.assertEvaluates("3.0", "tensor0{bar}", true, "{ {x:foo}:1, {x:bar}:3 }"); -- cgit v1.2.3