diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2023-11-06 13:57:12 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-06 13:57:12 +0100 |
commit | 93eea4d5a70d05c5bf0cf716f1cd51a2e92c6c25 (patch) | |
tree | 34d0d0829410c51db4f484c63fcf4f42ed2ea335 | |
parent | 70a18df2765b149600b58a738c73182ed56ac361 (diff) | |
parent | 6b1804d1d8b1136e388051b15f8b0522a8a21783 (diff) |
Merge pull request #29246 from vespa-engine/arnej/allow-using-dimension-name-in-generate
special case Generate for features access in LambdaFunctionNode
2 files changed, 39 insertions, 13 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index b2641fdf229..0f1331515cc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.functions.Generate; import java.util.Collections; import java.util.Deque; @@ -151,19 +152,41 @@ public class LambdaFunctionNode extends CompositeNode { }); } - private static Set<String> featuresAccessedIn(ExpressionNode node) { - if (node instanceof ReferenceNode) { - return Set.of(((ReferenceNode) node).reference().toString()); - } - else if (node instanceof NameNode) { // (This clause probably not necessary) - return Set.of(((NameNode) node).getValue()); - } - else if (node instanceof CompositeNode) { - Set<String> features = new HashSet<>(); - ((CompositeNode)node).children().forEach(child -> features.addAll(featuresAccessedIn(child))); - return features; + private static class FeatureFinder { + private final Set<String> target; + private final Set<String> localVariables = new HashSet<>(); + FeatureFinder(Set<String> target) { this.target = target; } + void process(ExpressionNode node) { + if (node instanceof ReferenceNode refNode) { + String featureName = refNode.reference().toString(); + if (! localVariables.contains(featureName)) { + target.add(featureName); + } + return; + } + Optional<FeatureFinder> subProcessor = Optional.empty(); + if (node instanceof TensorFunctionNode t) { + var fun = t.function(); + if (fun instanceof Generate<?> g) { + var ff = new FeatureFinder(target); + var genType = g.type(null); // Generate knows its own type without any context + for (var dim : genType.dimensions()) { + ff.localVariables.add(dim.name()); + } + subProcessor = Optional.of(ff); + } + } + if (node instanceof CompositeNode composite) { + final FeatureFinder processor = subProcessor.orElse(this); + composite.children().forEach(child -> processor.process(child)); + } } - return Set.of(); + } + + private static Set<String> featuresAccessedIn(ExpressionNode node) { + Set<String> features = new HashSet<>(); + new FeatureFinder(features).process(node); + return features; } @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index f9ba7552560..637e9be5fc3 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -769,6 +769,10 @@ public class EvaluationTestCase { @Test public void testLambdaValidation() { EvaluationTester tester = new EvaluationTester(); + // check that we are allowed to access dimension name "y" inside Generate + tester.assertEvaluates("{ {d1:0}:15, {d1:1}:150, {d1:2 }:1500 }", + "map(tensor0, f(x) (sum(tensor(y[6])(x*y))))", + "{ {d1:0}:1, {d1:1}:10, {d1:2}:100 }"); try { tester.assertEvaluates("{ {d1:0}:1, {d1:1}:2, {d1:2 }:3 }", "map(tensor0, f(x) (log10(x+sum(tensor0)))", "{ {d1:0}:10, {d1:1}:100, {d1:2}:1000 }"); @@ -779,7 +783,6 @@ public class EvaluationTestCase { assertEquals("Lambda log10(x + reduce(tensor0, sum)) accesses features outside its scope: tensor0", e.getMessage()); } - } @Test |