aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@vespa.ai>2023-11-06 12:31:36 +0000
committerArne Juul <arnej@vespa.ai>2023-11-06 12:31:36 +0000
commit6b1804d1d8b1136e388051b15f8b0522a8a21783 (patch)
tree27357cce4cd6117e653b3e46bbe6d483d5f1ba83
parent036ec7b3f7a4ff3ff8b4ef7cbeabdfbfc1f72e27 (diff)
special case Generate for features access in LambdaFunctionNode
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java47
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java5
2 files changed, 39 insertions, 13 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index b2641fdf229..0f1331515cc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.functions.Generate;
import java.util.Collections;
import java.util.Deque;
@@ -151,19 +152,41 @@ public class LambdaFunctionNode extends CompositeNode {
});
}
- private static Set<String> featuresAccessedIn(ExpressionNode node) {
- if (node instanceof ReferenceNode) {
- return Set.of(((ReferenceNode) node).reference().toString());
- }
- else if (node instanceof NameNode) { // (This clause probably not necessary)
- return Set.of(((NameNode) node).getValue());
- }
- else if (node instanceof CompositeNode) {
- Set<String> features = new HashSet<>();
- ((CompositeNode)node).children().forEach(child -> features.addAll(featuresAccessedIn(child)));
- return features;
+ private static class FeatureFinder {
+ private final Set<String> target;
+ private final Set<String> localVariables = new HashSet<>();
+ FeatureFinder(Set<String> target) { this.target = target; }
+ void process(ExpressionNode node) {
+ if (node instanceof ReferenceNode refNode) {
+ String featureName = refNode.reference().toString();
+ if (! localVariables.contains(featureName)) {
+ target.add(featureName);
+ }
+ return;
+ }
+ Optional<FeatureFinder> subProcessor = Optional.empty();
+ if (node instanceof TensorFunctionNode t) {
+ var fun = t.function();
+ if (fun instanceof Generate<?> g) {
+ var ff = new FeatureFinder(target);
+ var genType = g.type(null); // Generate knows its own type without any context
+ for (var dim : genType.dimensions()) {
+ ff.localVariables.add(dim.name());
+ }
+ subProcessor = Optional.of(ff);
+ }
+ }
+ if (node instanceof CompositeNode composite) {
+ final FeatureFinder processor = subProcessor.orElse(this);
+ composite.children().forEach(child -> processor.process(child));
+ }
}
- return Set.of();
+ }
+
+ private static Set<String> featuresAccessedIn(ExpressionNode node) {
+ Set<String> features = new HashSet<>();
+ new FeatureFinder(features).process(node);
+ return features;
}
@Override
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index f9ba7552560..637e9be5fc3 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -769,6 +769,10 @@ public class EvaluationTestCase {
@Test
public void testLambdaValidation() {
EvaluationTester tester = new EvaluationTester();
+ // check that we are allowed to access dimension name "y" inside Generate
+ tester.assertEvaluates("{ {d1:0}:15, {d1:1}:150, {d1:2 }:1500 }",
+ "map(tensor0, f(x) (sum(tensor(y[6])(x*y))))",
+ "{ {d1:0}:1, {d1:1}:10, {d1:2}:100 }");
try {
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:2, {d1:2 }:3 }",
"map(tensor0, f(x) (log10(x+sum(tensor0)))", "{ {d1:0}:10, {d1:1}:100, {d1:2}:1000 }");
@@ -779,7 +783,6 @@ public class EvaluationTestCase {
assertEquals("Lambda log10(x + reduce(tensor0, sum)) accesses features outside its scope: tensor0",
e.getMessage());
}
-
}
@Test