diff options
Diffstat (limited to 'searchlib')
2 files changed, 41 insertions, 1 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index 3a3410aeebb..2a6e6793bcd 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -11,10 +11,13 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; +import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; +import java.util.stream.Collectors; /** * A free, parametrized function @@ -27,7 +30,12 @@ public class LambdaFunctionNode extends CompositeNode { private final ExpressionNode functionExpression; public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) { - // TODO: Verify that the function only accesses the given arguments + if ( ! arguments.containsAll(featuresAccessedIn(functionExpression))) { + throw new IllegalArgumentException("Lambda " + functionExpression + " accesses features outside its scope: " + + featuresAccessedIn(functionExpression).stream() + .filter(f -> ! arguments.contains(f)) + .collect(Collectors.joining(", "))); + } this.arguments = ImmutableList.copyOf(arguments); this.functionExpression = functionExpression; } @@ -134,6 +142,22 @@ public class LambdaFunctionNode extends CompositeNode { }); } + private static Set<String> featuresAccessedIn(ExpressionNode node) { + if (node instanceof ReferenceNode) { + return Set.of(((ReferenceNode) node).reference().toString()); + } + else if (node instanceof NameNode) { // (This clause probably not necessary) + return Set.of(((NameNode) node).getValue()); + } + else if (node instanceof CompositeNode) { + Set<String> features = new HashSet<>(); + ((CompositeNode)node).children().forEach(child -> features.addAll(featuresAccessedIn(child))); + return features; + } + return Set.of(); + } + + private class DoubleUnaryLambda implements DoubleUnaryOperator { @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index ca2f6c6bbec..bc217983812 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -454,6 +454,22 @@ public class EvaluationTestCase { } @Test + public void testLambdaValidation() { + EvaluationTester tester = new EvaluationTester(); + try { + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:2, {d1:2 }:3 }", + "map(tensor0, f(x) (log10(x+sum(tensor0)))", "{ {d1:0}:10, {d1:1}:100, {d1:2}:1000 }"); + fail("Expected validation failure"); + } + catch (IllegalArgumentException e) { + // success + assertEquals("Lambda log10(x + reduce(tensor0, sum)) accesses features outside its scope: tensor0", + e.getMessage()); + } + + } + + @Test public void testExpand() { EvaluationTester tester = new EvaluationTester(); // Add a dimension using a literal tensor |