diff options
author | Jon Bratseth <bratseth@oath.com> | 2020-01-07 11:15:30 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-07 11:15:30 +0100 |
commit | 25da72ef231afeb7e1689eabdcf9500a5cb5d14d (patch) | |
tree | 2029aeeb3336acf75990f41af406b57d747a5822 | |
parent | 0854f7fdfa6e23ed66f176c75dfc49a19b198589 (diff) | |
parent | a6bf3edfb2584b42062254d6a3ca06e91ba2487c (diff) |
Merge pull request #11643 from vespa-engine/bratseth/validate-lambdas
Validate lambdas
2 files changed, 41 insertions, 1 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index 3a3410aeebb..2a6e6793bcd 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -11,10 +11,13 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; +import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; +import java.util.stream.Collectors; /** * A free, parametrized function @@ -27,7 +30,12 @@ public class LambdaFunctionNode extends CompositeNode { private final ExpressionNode functionExpression; public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) { - // TODO: Verify that the function only accesses the given arguments + if ( ! arguments.containsAll(featuresAccessedIn(functionExpression))) { + throw new IllegalArgumentException("Lambda " + functionExpression + " accesses features outside its scope: " + + featuresAccessedIn(functionExpression).stream() + .filter(f -> ! arguments.contains(f)) + .collect(Collectors.joining(", "))); + } this.arguments = ImmutableList.copyOf(arguments); this.functionExpression = functionExpression; } @@ -134,6 +142,22 @@ public class LambdaFunctionNode extends CompositeNode { }); } + private static Set<String> featuresAccessedIn(ExpressionNode node) { + if (node instanceof ReferenceNode) { + return Set.of(((ReferenceNode) node).reference().toString()); + } + else if (node instanceof NameNode) { // (This clause probably not necessary) + return Set.of(((NameNode) node).getValue()); + } + else if (node instanceof CompositeNode) { + Set<String> features = new HashSet<>(); + ((CompositeNode)node).children().forEach(child -> features.addAll(featuresAccessedIn(child))); + return features; + } + return Set.of(); + } + + private class DoubleUnaryLambda implements DoubleUnaryOperator { @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 6e77ab186e8..6a87e0c6d46 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -459,6 +459,22 @@ public class EvaluationTestCase { } @Test + public void testLambdaValidation() { + EvaluationTester tester = new EvaluationTester(); + try { + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:2, {d1:2 }:3 }", + "map(tensor0, f(x) (log10(x+sum(tensor0)))", "{ {d1:0}:10, {d1:1}:100, {d1:2}:1000 }"); + fail("Expected validation failure"); + } + catch (IllegalArgumentException e) { + // success + assertEquals("Lambda log10(x + reduce(tensor0, sum)) accesses features outside its scope: tensor0", + e.getMessage()); + } + + } + + @Test public void testExpand() { EvaluationTester tester = new EvaluationTester(); // Add a dimension using a literal tensor |