diff options
author | Arne Juul <arnej@vespa.ai> | 2023-11-06 12:31:36 +0000 |
---|---|---|
committer | Arne Juul <arnej@vespa.ai> | 2023-11-06 12:31:36 +0000 |
commit | 6b1804d1d8b1136e388051b15f8b0522a8a21783 (patch) | |
tree | 27357cce4cd6117e653b3e46bbe6d483d5f1ba83 /searchlib/src/main | |
parent | 036ec7b3f7a4ff3ff8b4ef7cbeabdfbfc1f72e27 (diff) |
special case Generate for features access in LambdaFunctionNode
Diffstat (limited to 'searchlib/src/main')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java | 47 |
1 files changed, 35 insertions, 12 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index b2641fdf229..0f1331515cc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.functions.Generate; import java.util.Collections; import java.util.Deque; @@ -151,19 +152,41 @@ public class LambdaFunctionNode extends CompositeNode { }); } - private static Set<String> featuresAccessedIn(ExpressionNode node) { - if (node instanceof ReferenceNode) { - return Set.of(((ReferenceNode) node).reference().toString()); - } - else if (node instanceof NameNode) { // (This clause probably not necessary) - return Set.of(((NameNode) node).getValue()); - } - else if (node instanceof CompositeNode) { - Set<String> features = new HashSet<>(); - ((CompositeNode)node).children().forEach(child -> features.addAll(featuresAccessedIn(child))); - return features; + private static class FeatureFinder { + private final Set<String> target; + private final Set<String> localVariables = new HashSet<>(); + FeatureFinder(Set<String> target) { this.target = target; } + void process(ExpressionNode node) { + if (node instanceof ReferenceNode refNode) { + String featureName = refNode.reference().toString(); + if (! localVariables.contains(featureName)) { + target.add(featureName); + } + return; + } + Optional<FeatureFinder> subProcessor = Optional.empty(); + if (node instanceof TensorFunctionNode t) { + var fun = t.function(); + if (fun instanceof Generate<?> g) { + var ff = new FeatureFinder(target); + var genType = g.type(null); // Generate knows its own type without any context + for (var dim : genType.dimensions()) { + ff.localVariables.add(dim.name()); + } + subProcessor = Optional.of(ff); + } + } + if (node instanceof CompositeNode composite) { + final FeatureFinder processor = subProcessor.orElse(this); + composite.children().forEach(child -> processor.process(child)); + } } - return Set.of(); + } + + private static Set<String> featuresAccessedIn(ExpressionNode node) { + Set<String> features = new HashSet<>(); + new FeatureFinder(features).process(node); + return features; } @Override |