diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-14 12:54:07 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-14 12:59:16 +0000 |
commit | b99705e43c8e30fba0406ae450034edcfdc9ab52 (patch) | |
tree | 19d0aee2414a5d1747c7028e17f4e64cd3cdeabf /config-model/src | |
parent | 9553cd2709c0791aa530f5388cf156116e857795 (diff) |
add special handling of TensorFunctionNode containing Generate function
Diffstat (limited to 'config-model/src')
10 files changed, 224 insertions, 6 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index a00bbb682a8..f9b3bc77040 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -1017,7 +1017,7 @@ public class RankProfile implements Cloneable { inlineFunctions); var needInputs = new HashSet<String>(); var recorder = new InputRecorder(needInputs); - recorder.transform(globalPhaseRanking.function().getBody(), context); + recorder.process(globalPhaseRanking.function().getBody(), context); for (String input : needInputs) { if (input.startsWith("constant(") || input.startsWith("query(")) { continue; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index 7124628be0c..6e7787c2dd1 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -10,7 +10,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.functions.Generate; import java.io.StringReader; import java.util.HashSet; @@ -21,7 +23,7 @@ import java.util.Set; * * @author arnej */ -public class InputRecorder extends ExpressionTransformer<RankProfileTransformContext> { +public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { private final Set<String> neededInputs; private final Set<String> handled = new HashSet<>(); @@ -30,12 +32,28 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon this.neededInputs = target; } + public void process(RankingExpression expression, RankProfileTransformContext context) { + transform(expression.getRoot(), new InputRecorderContext(context)); + } + @Override - public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + public ExpressionNode transform(ExpressionNode node, InputRecorderContext context) { if (node instanceof ReferenceNode r) { handle(r, context); return node; } + if (node instanceof TensorFunctionNode t) { + var f = t.function(); + if (f instanceof Generate) { + var childContext = new InputRecorderContext(context); + var tt = f.type(context.types()); + // expects only indexed dimensions, should we check? + for (var dim : tt.dimensions()) { + childContext.localVariables().add(dim.name()); + } + return transformChildren(t, childContext); + } + } if (node instanceof CompositeNode c) return transformChildren(c, context); if (node instanceof ConstantNode) { @@ -44,11 +62,14 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon throw new IllegalArgumentException("Cannot handle node type: "+ node + " [" + node.getClass() + "]"); } - private void handle(ReferenceNode feature, RankProfileTransformContext context) { + private void handle(ReferenceNode feature, InputRecorderContext context) { Reference ref = feature.reference(); String name = ref.name(); var args = ref.arguments(); boolean simpleFunctionOrIdentifier = (args.size() == 0) && (ref.output() == null); + if (simpleFunctionOrIdentifier && context.localVariables().contains(name)) { + return; + } if (ref.isSimpleRankingExpressionWrapper()) { name = ref.simpleArgument().get(); simpleFunctionOrIdentifier = true; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java new file mode 100644 index 00000000000..54617374b67 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java @@ -0,0 +1,28 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema.expressiontransforms; + +import com.yahoo.schema.RankProfile; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; + +import java.util.HashSet; +import java.util.Set; + +class InputRecorderContext extends TransformContext { + + private final RankProfileTransformContext parent; + private final Set<String> localVariables = new HashSet<>(); + + public RankProfile rankProfile() { return parent.rankProfile(); } + public Set<String> localVariables() { return localVariables; } + + public InputRecorderContext(RankProfileTransformContext parent) { + super(parent.constants(), parent.types()); + this.parent = parent; + } + + public InputRecorderContext(InputRecorderContext parent) { + super(parent.constants(), parent.types()); + this.parent = parent.parent; + this.localVariables.addAll(parent.localVariables); + } +} diff --git a/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg b/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg index 1e9cf8ce0e9..e456dec58ed 100644 --- a/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg +++ b/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg @@ -4,7 +4,7 @@ rankprofile[].fef.property[].value "query(yy)" rankprofile[].fef.property[].name "rankingExpression(handicap).type" rankprofile[].fef.property[].value "tensor(d0[2])" rankprofile[].fef.property[].name "rankingExpression(indirect_a).rankingScript" -rankprofile[].fef.property[].value "attribute(aa)" +rankprofile[].fef.property[].value "attribute(aa) + tensor(d1[3])((d1 + attribute(extra)))" rankprofile[].fef.property[].name "rankingExpression(indirect_a).type" rankprofile[].fef.property[].value "tensor(d1[3])" rankprofile[].fef.property[].name "rankingExpression(indirect_x).rankingScript" @@ -25,6 +25,8 @@ rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript" rankprofile[].fef.property[].value "reduce(constant(ww) * (onnx(inside).foobar - rankingExpression(handicap)), sum)" rankprofile[].fef.property[].name "vespa.match.feature" rankprofile[].fef.property[].value "attribute(aa)" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(extra)" rankprofile[].fef.property[].name "vespa.globalphase.rerankcount" rankprofile[].fef.property[].value "13" rankprofile[].fef.property[].name "vespa.type.attribute.aa" diff --git a/config-model/src/test/derived/globalphase_onnx_inside/test.sd b/config-model/src/test/derived/globalphase_onnx_inside/test.sd index 405a65ceb06..f5788611b0a 100644 --- a/config-model/src/test/derived/globalphase_onnx_inside/test.sd +++ b/config-model/src/test/derived/globalphase_onnx_inside/test.sd @@ -4,6 +4,9 @@ schema test { field aa type tensor(d1[3]) { indexing: attribute } + field extra type float { + indexing: attribute + } } constant xx { @@ -38,7 +41,7 @@ schema test { expression: sum(constant(ww) * (onnx(inside).foobar - handicap)) } function indirect_a() { - expression: attribute(aa) + expression: attribute(aa) + tensor(d1[3])(d1+attribute(extra)) } function indirect_x() { expression: constant(xx) diff --git a/config-model/src/test/derived/globalphase_token_functions/files/m.py b/config-model/src/test/derived/globalphase_token_functions/files/m.py new file mode 100644 index 00000000000..004135b32eb --- /dev/null +++ b/config-model/src/test/derived/globalphase_token_functions/files/m.py @@ -0,0 +1,51 @@ +# imports + +from onnx import TensorProto +from onnx.helper import ( + make_model, make_node, make_graph, + make_tensor_value_info, make_value_info) +from onnx.checker import check_model + +# inputs + +# TensorProto.DOUBLE is the element type, [128] the shape +A = make_tensor_value_info('input_ids', TensorProto.DOUBLE, [128]) +B = make_tensor_value_info('attention_mask', TensorProto.DOUBLE, [128]) +C = make_tensor_value_info('token_type_ids', TensorProto.DOUBLE, [128]) + +# outputs, the shape is defined +Y = make_tensor_value_info('vector_Y', TensorProto.DOUBLE, [128]) +S = make_tensor_value_info('score', TensorProto.DOUBLE, [1]) + +# Creates node defined by the operator type, inputs, outputs, and possibly options +node1 = make_node('Mul', ['input_ids', 'attention_mask'], ['masked']) +node2 = make_node('Add', ['masked', 'token_type_ids'], ['vector_Y']) +node3 = make_node('ReduceSum', inputs=['vector_Y'], outputs=['score'], keepdims=1) + +# from nodes to graph +# the graph is built from the list of nodes, the list of inputs, +# the list of outputs and a name. + +graph = make_graph([node1, node2, node3], # nodes + 'ranking_model', # a name + [A, B, C], # inputs + [S]) # outputs + +# onnx graph to model +onnx_model = make_model(graph) + +# ensure we do not get too new opset version: +del onnx_model.opset_import[:] +opset = onnx_model.opset_import.add() +opset.version = 17 + +# Let's check the model is consistent, this function is described in +# section Checker and Shape Inference. +check_model(onnx_model) + +# The serialization +with open("ranking_model.onnx", "wb") as f: + f.write(onnx_model.SerializeToString()) + +# the work is done, let's display it... +print(onnx_model) diff --git a/config-model/src/test/derived/globalphase_token_functions/files/ranking_model.onnx b/config-model/src/test/derived/globalphase_token_functions/files/ranking_model.onnx new file mode 100644 index 00000000000..274fbba3fdb --- /dev/null +++ b/config-model/src/test/derived/globalphase_token_functions/files/ranking_model.onnx @@ -0,0 +1,23 @@ +:þ +( + input_ids +attention_maskmasked"Mul +' +masked +token_type_idsvector_Y"Add +- +vector_Yscore" ReduceSum* +keepdims
ranking_modelZ + input_ids + +€Z +attention_mask + +€Z +token_type_ids + +€b +score + + +B
\ No newline at end of file diff --git a/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg b/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg new file mode 100644 index 00000000000..33381060178 --- /dev/null +++ b/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg @@ -0,0 +1,43 @@ +rankprofile[].name "default" +rankprofile[].fef.property[].name "vespa.type.attribute.tokens" +rankprofile[].fef.property[].value "tensor(d0[128])" +rankprofile[].name "unranked" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "value(0)" +rankprofile[].fef.property[].name "vespa.hitcollector.heapsize" +rankprofile[].fef.property[].value "0" +rankprofile[].fef.property[].name "vespa.hitcollector.arraysize" +rankprofile[].fef.property[].value "0" +rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures" +rankprofile[].fef.property[].value "true" +rankprofile[].fef.property[].name "vespa.type.attribute.tokens" +rankprofile[].fef.property[].value "tensor(d0[128])" +rankprofile[].name "using_model" +rankprofile[].fef.property[].name "rankingExpression(__token_length@1019197748).rankingScript" +rankprofile[].fef.property[].value "reduce(map(query(input), f(x)(x > 0)), sum)" +rankprofile[].fef.property[].name "rankingExpression(__token_length@-812590320).rankingScript" +rankprofile[].fef.property[].value "reduce(map(attribute(tokens), f(x)(x > 0)), sum)" +rankprofile[].fef.property[].name "rankingExpression(input_ids).rankingScript" +rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < 1.0, 101.0, if (d1 < 1.0 + rankingExpression(__token_length@1019197748), query(input){d0:(d1 - (1.0))}, if (d1 < (1.0 + rankingExpression(__token_length@1019197748) + 1.0), 102.0, if (d1 < (1.0 + rankingExpression(__token_length@1019197748) + 1.0 + rankingExpression(__token_length@-812590320)), attribute(tokens){d0:(d1 - (1.0 + rankingExpression(__token_length@1019197748) + 1.0))}, if (d1 < (1.0 + rankingExpression(__token_length@1019197748) + 1.0 + rankingExpression(__token_length@-812590320) + 1.0), 102.0, 0.0)))))))" +rankprofile[].fef.property[].name "rankingExpression(input_ids).type" +rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])" +rankprofile[].fef.property[].name "rankingExpression(token_type_ids).rankingScript" +rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < (1.0 + rankingExpression(__token_length@1019197748) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@1019197748) + 1.0 + rankingExpression(__token_length@-812590320) + 1.0), 1.0, 0.0))))" +rankprofile[].fef.property[].name "rankingExpression(token_type_ids).type" +rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])" +rankprofile[].fef.property[].name "rankingExpression(attention_mask).rankingScript" +rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < (1.0 + rankingExpression(__token_length@1019197748) + 1.0 + rankingExpression(__token_length@-812590320) + 1.0), 1.0, 0.0)))" +rankprofile[].fef.property[].name "rankingExpression(attention_mask).type" +rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])" +rankprofile[].fef.property[].name "vespa.rank.globalphase" +rankprofile[].fef.property[].value "rankingExpression(globalphase)" +rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript" +rankprofile[].fef.property[].value "onnx(my_ranking_model).score{d0:0}" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(tokens)" +rankprofile[].fef.property[].name "vespa.globalphase.rerankcount" +rankprofile[].fef.property[].value "1000" +rankprofile[].fef.property[].name "vespa.type.attribute.tokens" +rankprofile[].fef.property[].value "tensor(d0[128])" +rankprofile[].fef.property[].name "vespa.type.query.input" +rankprofile[].fef.property[].value "tensor(d0[32])" diff --git a/config-model/src/test/derived/globalphase_token_functions/test.sd b/config-model/src/test/derived/globalphase_token_functions/test.sd new file mode 100644 index 00000000000..f8cc8863ad1 --- /dev/null +++ b/config-model/src/test/derived/globalphase_token_functions/test.sd @@ -0,0 +1,42 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +schema test { + document test { + field title type string { + indexing: index | summary + } + field tokens type tensor(d0[128]) { + indexing: attribute + } + } + fieldset default { + fields: title + } + + onnx-model my_ranking_model { + file: files/ranking_model.onnx + input input_ids: input_ids + input attention_mask: attention_mask + input token_type_ids: token_type_ids + } + + rank-profile using_model { + inputs { + query(input) tensor(d0[32]) + } + function input_ids() { + expression: tokenInputIds(128, query(input), attribute(tokens)) + } + function token_type_ids() { + expression: tokenTypeIds(128, query(input), attribute(tokens)) + } + function attention_mask() { + expression: tokenAttentionMask(128, query(input), attribute(tokens)) + } + global-phase { + rerank-count: 1000 + expression: onnx(my_ranking_model).score{d0:0} + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java b/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java index 2ff33dd70d8..3d0d9de13f5 100644 --- a/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java @@ -19,4 +19,9 @@ public class GlobalPhaseOnnxModelsTestCase extends AbstractExportingTestCase { assertCorrectDeriving("globalphase_onnx_inside"); } + @Test + void testWithTokenFunctions() throws IOException, ParseException { + assertCorrectDeriving("globalphase_token_functions"); + } + } |