aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-14 12:54:07 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-14 12:59:16 +0000
commitb99705e43c8e30fba0406ae450034edcfdc9ab52 (patch)
tree19d0aee2414a5d1747c7028e17f4e64cd3cdeabf /config-model/src/main/java/com/yahoo/schema
parent9553cd2709c0791aa530f5388cf156116e857795 (diff)
add special handling of TensorFunctionNode containing Generate function
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/RankProfile.java2
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java27
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java28
3 files changed, 53 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
index a00bbb682a8..f9b3bc77040 100644
--- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
@@ -1017,7 +1017,7 @@ public class RankProfile implements Cloneable {
inlineFunctions);
var needInputs = new HashSet<String>();
var recorder = new InputRecorder(needInputs);
- recorder.transform(globalPhaseRanking.function().getBody(), context);
+ recorder.process(globalPhaseRanking.function().getBody(), context);
for (String input : needInputs) {
if (input.startsWith("constant(") || input.startsWith("query(")) {
continue;
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index 7124628be0c..6e7787c2dd1 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -10,7 +10,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.functions.Generate;
import java.io.StringReader;
import java.util.HashSet;
@@ -21,7 +23,7 @@ import java.util.Set;
*
* @author arnej
*/
-public class InputRecorder extends ExpressionTransformer<RankProfileTransformContext> {
+public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
private final Set<String> neededInputs;
private final Set<String> handled = new HashSet<>();
@@ -30,12 +32,28 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon
this.neededInputs = target;
}
+ public void process(RankingExpression expression, RankProfileTransformContext context) {
+ transform(expression.getRoot(), new InputRecorderContext(context));
+ }
+
@Override
- public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ public ExpressionNode transform(ExpressionNode node, InputRecorderContext context) {
if (node instanceof ReferenceNode r) {
handle(r, context);
return node;
}
+ if (node instanceof TensorFunctionNode t) {
+ var f = t.function();
+ if (f instanceof Generate) {
+ var childContext = new InputRecorderContext(context);
+ var tt = f.type(context.types());
+ // expects only indexed dimensions, should we check?
+ for (var dim : tt.dimensions()) {
+ childContext.localVariables().add(dim.name());
+ }
+ return transformChildren(t, childContext);
+ }
+ }
if (node instanceof CompositeNode c)
return transformChildren(c, context);
if (node instanceof ConstantNode) {
@@ -44,11 +62,14 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon
throw new IllegalArgumentException("Cannot handle node type: "+ node + " [" + node.getClass() + "]");
}
- private void handle(ReferenceNode feature, RankProfileTransformContext context) {
+ private void handle(ReferenceNode feature, InputRecorderContext context) {
Reference ref = feature.reference();
String name = ref.name();
var args = ref.arguments();
boolean simpleFunctionOrIdentifier = (args.size() == 0) && (ref.output() == null);
+ if (simpleFunctionOrIdentifier && context.localVariables().contains(name)) {
+ return;
+ }
if (ref.isSimpleRankingExpressionWrapper()) {
name = ref.simpleArgument().get();
simpleFunctionOrIdentifier = true;
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java
new file mode 100644
index 00000000000..54617374b67
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java
@@ -0,0 +1,28 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema.expressiontransforms;
+
+import com.yahoo.schema.RankProfile;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
+
+import java.util.HashSet;
+import java.util.Set;
+
+class InputRecorderContext extends TransformContext {
+
+ private final RankProfileTransformContext parent;
+ private final Set<String> localVariables = new HashSet<>();
+
+ public RankProfile rankProfile() { return parent.rankProfile(); }
+ public Set<String> localVariables() { return localVariables; }
+
+ public InputRecorderContext(RankProfileTransformContext parent) {
+ super(parent.constants(), parent.types());
+ this.parent = parent;
+ }
+
+ public InputRecorderContext(InputRecorderContext parent) {
+ super(parent.constants(), parent.types());
+ this.parent = parent.parent;
+ this.localVariables.addAll(parent.localVariables);
+ }
+}