summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-03-17 08:24:25 +0100
committerGitHub <noreply@github.com>2023-03-17 08:24:25 +0100
commit5128a20597f89bd2cb420611b8472c36db436daf (patch)
treeee9a20c0c1fcfcb9646908ac7bcdf2de9ef322a7
parent4b47fa5a75f924c18184cc2fc3959c7b3153ab84 (diff)
parentd883fa1f9fff36baf7231e6e9f8958017a705657 (diff)
Merge pull request #26462 from vespa-engine/arnej/handle-expression-in-slice
* make InputRecorder handle slice with embedded expressions
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java8
-rw-r--r--config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg4
-rw-r--r--config-model/src/test/derived/globalphase_token_functions/test.sd5
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java28
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java2
-rw-r--r--searchlib/abi-spec.json1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java2
-rw-r--r--vespajlib/abi-spec.json1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java9
9 files changed, 57 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index 6e7787c2dd1..993a2442fdb 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -13,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Slice;
import java.io.StringReader;
import java.util.HashSet;
@@ -53,6 +54,13 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
}
return transformChildren(t, childContext);
}
+ if (f instanceof Slice s) {
+ for (var tf : s.selectorFunctions()) {
+ if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
+ transform(expr.wrappedExpression(), context);
+ }
+ }
+ }
}
if (node instanceof CompositeNode c)
return transformChildren(c, context);
diff --git a/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg b/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg
index 33381060178..bf47dba9a71 100644
--- a/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg
+++ b/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg
@@ -32,9 +32,11 @@ rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])"
rankprofile[].fef.property[].name "vespa.rank.globalphase"
rankprofile[].fef.property[].value "rankingExpression(globalphase)"
rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript"
-rankprofile[].fef.property[].value "onnx(my_ranking_model).score{d0:0}"
+rankprofile[].fef.property[].value "onnx(my_ranking_model).score{d0:(attribute(outputidx))}"
rankprofile[].fef.property[].name "vespa.match.feature"
rankprofile[].fef.property[].value "attribute(tokens)"
+rankprofile[].fef.property[].name "vespa.match.feature"
+rankprofile[].fef.property[].value "attribute(outputidx)"
rankprofile[].fef.property[].name "vespa.globalphase.rerankcount"
rankprofile[].fef.property[].value "1000"
rankprofile[].fef.property[].name "vespa.type.attribute.tokens"
diff --git a/config-model/src/test/derived/globalphase_token_functions/test.sd b/config-model/src/test/derived/globalphase_token_functions/test.sd
index f8cc8863ad1..a1d14258aab 100644
--- a/config-model/src/test/derived/globalphase_token_functions/test.sd
+++ b/config-model/src/test/derived/globalphase_token_functions/test.sd
@@ -8,6 +8,9 @@ schema test {
field tokens type tensor(d0[128]) {
indexing: attribute
}
+ field outputidx type double {
+ indexing: attribute
+ }
}
fieldset default {
fields: title
@@ -35,7 +38,7 @@ schema test {
}
global-phase {
rerank-count: 1000
- expression: onnx(my_ranking_model).score{d0:0}
+ expression: onnx(my_ranking_model).score{d0:attribute(outputidx)}
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
index 6b1f60df6f4..126e9f9f4e6 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
@@ -7,6 +7,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Slice;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
@@ -45,6 +48,11 @@ class BindingExtractor {
arguments.addAll(other.arguments);
onnxModelsInUse.putAll(other.onnxModelsInUse);
}
+
+ void removeTarget(String name) {
+ bindTargets.remove(name);
+ arguments.remove(name);
+ }
}
private final Map<FunctionReference, FunctionInfo> functionsInfo = new LinkedHashMap<>();
@@ -83,6 +91,26 @@ class BindingExtractor {
}
return result;
}
+ else if (node instanceof TensorFunctionNode tfn) {
+ for (ExpressionNode child : tfn.children()) {
+ result.merge(extractBindTargets(child));
+ }
+ var f = tfn.function();
+ if (f instanceof Generate) {
+ var tt = f.type(null);
+ for (var dim : tt.dimensions()) {
+ result.removeTarget(dim.name());
+ }
+ }
+ else if (f instanceof Slice<?> s) {
+ for (var selectorFunc : s.selectorFunctions()) {
+ if (selectorFunc instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
+ result.merge(extractBindTargets(expr.wrappedExpression()));
+ }
+ }
+ }
+ return result;
+ }
else if (isOnnx(node)) {
return extractOnnxTargets(node);
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
index 666c3a103b5..a34898fe4e9 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
@@ -109,7 +109,7 @@ class FunctionReference {
/**
* Returns a function reference from the given return type serial form,
- * or empty if the string is not a valid function return typoe serial form
+ * or empty if the string is not a valid function return type serial form
*/
static Optional<FunctionReference> fromReturnTypeSerial(String serialForm) {
Matcher expressionMatcher = returnTypePattern.matcher(serialForm);
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index be0414421fe..de3ec9648d1 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1673,6 +1673,7 @@
],
"methods" : [
"public void <init>(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
+ "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode wrappedExpression()",
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 75187d8ca19..7577c65527b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -209,6 +209,8 @@ public class TensorFunctionNode extends CompositeNode {
this.expression = expression;
}
+ public ExpressionNode wrappedExpression() { return expression; }
+
@Override
public List<TensorFunction<Reference>> arguments() {
if (expression instanceof CompositeNode)
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index c3b87278345..011e7e4a31d 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -2763,6 +2763,7 @@
"methods" : [
"public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)",
"public java.util.List arguments()",
+ "public java.util.List selectorFunctions()",
"public com.yahoo.tensor.functions.Slice withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index a9a6df6ed4d..87e24306031 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -11,6 +11,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
+import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
@@ -47,6 +48,14 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
@Override
public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); }
+ public List<TensorFunction<NAMETYPE>> selectorFunctions() {
+ var result = new ArrayList<TensorFunction<NAMETYPE>>();
+ for (var dimVal : subspaceAddress) {
+ dimVal.index().ifPresent(fun -> fun.asTensorFunction().ifPresent(tf -> result.add(tf)));
+ }
+ return result;
+ }
+
@Override
public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if (arguments.size() != 1)