summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorArne Juul <arnej@vespa.ai>2023-11-06 12:31:36 +0000
committerArne Juul <arnej@vespa.ai>2023-11-06 12:31:36 +0000
commit6b1804d1d8b1136e388051b15f8b0522a8a21783 (patch)
tree27357cce4cd6117e653b3e46bbe6d483d5f1ba83 /searchlib/src/main
parent036ec7b3f7a4ff3ff8b4ef7cbeabdfbfc1f72e27 (diff)
special case Generate for features access in LambdaFunctionNode
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java47
1 files changed, 35 insertions, 12 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index b2641fdf229..0f1331515cc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.functions.Generate;
import java.util.Collections;
import java.util.Deque;
@@ -151,19 +152,41 @@ public class LambdaFunctionNode extends CompositeNode {
});
}
- private static Set<String> featuresAccessedIn(ExpressionNode node) {
- if (node instanceof ReferenceNode) {
- return Set.of(((ReferenceNode) node).reference().toString());
- }
- else if (node instanceof NameNode) { // (This clause probably not necessary)
- return Set.of(((NameNode) node).getValue());
- }
- else if (node instanceof CompositeNode) {
- Set<String> features = new HashSet<>();
- ((CompositeNode)node).children().forEach(child -> features.addAll(featuresAccessedIn(child)));
- return features;
+ private static class FeatureFinder {
+ private final Set<String> target;
+ private final Set<String> localVariables = new HashSet<>();
+ FeatureFinder(Set<String> target) { this.target = target; }
+ void process(ExpressionNode node) {
+ if (node instanceof ReferenceNode refNode) {
+ String featureName = refNode.reference().toString();
+ if (! localVariables.contains(featureName)) {
+ target.add(featureName);
+ }
+ return;
+ }
+ Optional<FeatureFinder> subProcessor = Optional.empty();
+ if (node instanceof TensorFunctionNode t) {
+ var fun = t.function();
+ if (fun instanceof Generate<?> g) {
+ var ff = new FeatureFinder(target);
+ var genType = g.type(null); // Generate knows its own type without any context
+ for (var dim : genType.dimensions()) {
+ ff.localVariables.add(dim.name());
+ }
+ subProcessor = Optional.of(ff);
+ }
+ }
+ if (node instanceof CompositeNode composite) {
+ final FeatureFinder processor = subProcessor.orElse(this);
+ composite.children().forEach(child -> processor.process(child));
+ }
}
- return Set.of();
+ }
+
+ private static Set<String> featuresAccessedIn(ExpressionNode node) {
+ Set<String> features = new HashSet<>();
+ new FeatureFinder(features).process(node);
+ return features;
}
@Override