diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-16 11:00:36 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-16 11:05:35 +0000 |
commit | d883fa1f9fff36baf7231e6e9f8958017a705657 (patch) | |
tree | a99d74c943de8eb2d0b6c47746e39b612aa4c5c6 | |
parent | 9ec2f1a8888610e961ba4c6894abb096a8373850 (diff) |
* make InputRecorder handle slice with embedded expressions
* special handling for Generate and Slice in BindingExtractor
9 files changed, 57 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index 6e7787c2dd1..993a2442fdb 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -13,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Slice; import java.io.StringReader; import java.util.HashSet; @@ -53,6 +54,13 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { } return transformChildren(t, childContext); } + if (f instanceof Slice s) { + for (var tf : s.selectorFunctions()) { + if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) { + transform(expr.wrappedExpression(), context); + } + } + } } if (node instanceof CompositeNode c) return transformChildren(c, context); diff --git a/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg b/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg index 33381060178..bf47dba9a71 100644 --- a/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg +++ b/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg @@ -32,9 +32,11 @@ rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])" rankprofile[].fef.property[].name "vespa.rank.globalphase" rankprofile[].fef.property[].value "rankingExpression(globalphase)" rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript" -rankprofile[].fef.property[].value "onnx(my_ranking_model).score{d0:0}" +rankprofile[].fef.property[].value "onnx(my_ranking_model).score{d0:(attribute(outputidx))}" rankprofile[].fef.property[].name "vespa.match.feature" rankprofile[].fef.property[].value "attribute(tokens)" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(outputidx)" rankprofile[].fef.property[].name "vespa.globalphase.rerankcount" rankprofile[].fef.property[].value "1000" rankprofile[].fef.property[].name "vespa.type.attribute.tokens" diff --git a/config-model/src/test/derived/globalphase_token_functions/test.sd b/config-model/src/test/derived/globalphase_token_functions/test.sd index f8cc8863ad1..a1d14258aab 100644 --- a/config-model/src/test/derived/globalphase_token_functions/test.sd +++ b/config-model/src/test/derived/globalphase_token_functions/test.sd @@ -8,6 +8,9 @@ schema test { field tokens type tensor(d0[128]) { indexing: attribute } + field outputidx type double { + indexing: attribute + } } fieldset default { fields: title @@ -35,7 +38,7 @@ schema test { } global-phase { rerank-count: 1000 - expression: onnx(my_ranking_model).score{d0:0} + expression: onnx(my_ranking_model).score{d0:attribute(outputidx)} } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java index 6b1f60df6f4..126e9f9f4e6 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java @@ -7,6 +7,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Slice; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -45,6 +48,11 @@ class BindingExtractor { arguments.addAll(other.arguments); onnxModelsInUse.putAll(other.onnxModelsInUse); } + + void removeTarget(String name) { + bindTargets.remove(name); + arguments.remove(name); + } } private final Map<FunctionReference, FunctionInfo> functionsInfo = new LinkedHashMap<>(); @@ -83,6 +91,26 @@ class BindingExtractor { } return result; } + else if (node instanceof TensorFunctionNode tfn) { + for (ExpressionNode child : tfn.children()) { + result.merge(extractBindTargets(child)); + } + var f = tfn.function(); + if (f instanceof Generate) { + var tt = f.type(null); + for (var dim : tt.dimensions()) { + result.removeTarget(dim.name()); + } + } + else if (f instanceof Slice<?> s) { + for (var selectorFunc : s.selectorFunctions()) { + if (selectorFunc instanceof TensorFunctionNode.ExpressionTensorFunction expr) { + result.merge(extractBindTargets(expr.wrappedExpression())); + } + } + } + return result; + } else if (isOnnx(node)) { return extractOnnxTargets(node); } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java index 666c3a103b5..a34898fe4e9 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java @@ -109,7 +109,7 @@ class FunctionReference { /** * Returns a function reference from the given return type serial form, - * or empty if the string is not a valid function return typoe serial form + * or empty if the string is not a valid function return type serial form */ static Optional<FunctionReference> fromReturnTypeSerial(String serialForm) { Matcher expressionMatcher = returnTypePattern.matcher(serialForm); diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index be0414421fe..de3ec9648d1 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -1673,6 +1673,7 @@ ], "methods" : [ "public void <init>(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", + "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode wrappedExpression()", "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 75187d8ca19..7577c65527b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -209,6 +209,8 @@ public class TensorFunctionNode extends CompositeNode { this.expression = expression; } + public ExpressionNode wrappedExpression() { return expression; } + @Override public List<TensorFunction<Reference>> arguments() { if (expression instanceof CompositeNode) diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index c3b87278345..011e7e4a31d 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -2763,6 +2763,7 @@ "methods" : [ "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)", "public java.util.List arguments()", + "public java.util.List selectorFunctions()", "public com.yahoo.tensor.functions.Slice withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index a9a6df6ed4d..87e24306031 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -11,6 +11,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Objects; @@ -47,6 +48,14 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY @Override public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); } + public List<TensorFunction<NAMETYPE>> selectorFunctions() { + var result = new ArrayList<TensorFunction<NAMETYPE>>(); + for (var dimVal : subspaceAddress) { + dimVal.index().ifPresent(fun -> fun.asTensorFunction().ifPresent(tf -> result.add(tf))); + } + return result; + } + @Override public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 1) |