about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Arne Juul <arnej@yahooinc.com> 2023-03-14 12:54:07 +0000
committer: Arne Juul <arnej@yahooinc.com> 2023-03-14 12:59:16 +0000
commit: b99705e43c8e30fba0406ae450034edcfdc9ab52 (patch)
tree: 19d0aee2414a5d1747c7028e17f4e64cd3cdeabf
parent: 9553cd2709c0791aa530f5388cf156116e857795 (diff)
add special handling of TensorFunctionNode containing Generate function
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/RankProfile.java 2
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java 27
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java 28
-rw-r--r-- config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg 4
-rw-r--r-- config-model/src/test/derived/globalphase_onnx_inside/test.sd 5
-rw-r--r-- config-model/src/test/derived/globalphase_token_functions/files/m.py 51
-rw-r--r-- config-model/src/test/derived/globalphase_token_functions/files/ranking_model.onnx 23
-rw-r--r-- config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg 43
-rw-r--r-- config-model/src/test/derived/globalphase_token_functions/test.sd 42
-rw-r--r-- config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java 5
10 files changed, 224 insertions, 6 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
index a00bbb682a8..f9b3bc77040 100644
--- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
@@ -1017,7 +1017,7 @@ public class RankProfile implements Cloneable {
inlineFunctions);
var needInputs = new HashSet<String>();
var recorder = new InputRecorder(needInputs);
- recorder.transform(globalPhaseRanking.function().getBody(), context);
+ recorder.process(globalPhaseRanking.function().getBody(), context);
for (String input : needInputs) {
if (input.startsWith("constant(") || input.startsWith("query(")) {
continue;
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index 7124628be0c..6e7787c2dd1 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -10,7 +10,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.functions.Generate;
import java.io.StringReader;
import java.util.HashSet;
@@ -21,7 +23,7 @@ import java.util.Set;
*
* @author arnej
*/
-public class InputRecorder extends ExpressionTransformer<RankProfileTransformContext> {
+public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
private final Set<String> neededInputs;
private final Set<String> handled = new HashSet<>();
@@ -30,12 +32,28 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon
this.neededInputs = target;
}
+ public void process(RankingExpression expression, RankProfileTransformContext context) {
+ transform(expression.getRoot(), new InputRecorderContext(context));
+ }
+
@Override
- public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ public ExpressionNode transform(ExpressionNode node, InputRecorderContext context) {
if (node instanceof ReferenceNode r) {
handle(r, context);
return node;
}
+ if (node instanceof TensorFunctionNode t) {
+ var f = t.function();
+ if (f instanceof Generate) {
+ var childContext = new InputRecorderContext(context);
+ var tt = f.type(context.types());
+ // each dimension name of the generated type is a local variable in the child context, not an input; expects only indexed dimensions, should we check?
+ for (var dim : tt.dimensions()) {
+ childContext.localVariables().add(dim.name());
+ }
+ return transformChildren(t, childContext);
+ }
+ }
if (node instanceof CompositeNode c)
return transformChildren(c, context);
if (node instanceof ConstantNode) {
@@ -44,11 +62,14 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon
throw new IllegalArgumentException("Cannot handle node type: "+ node + " [" + node.getClass() + "]");
}
- private void handle(ReferenceNode feature, RankProfileTransformContext context) {
+ private void handle(ReferenceNode feature, InputRecorderContext context) {
Reference ref = feature.reference();
String name = ref.name();
var args = ref.arguments();
boolean simpleFunctionOrIdentifier = (args.size() == 0) && (ref.output() == null);
+ if (simpleFunctionOrIdentifier && context.localVariables().contains(name)) {
+ return;
+ }
if (ref.isSimpleRankingExpressionWrapper()) {
name = ref.simpleArgument().get();
simpleFunctionOrIdentifier = true;
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java
new file mode 100644
index 00000000000..54617374b67
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java
@@ -0,0 +1,28 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema.expressiontransforms;
+
+import com.yahoo.schema.RankProfile;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
+
+import java.util.HashSet;
+import java.util.Set;
+
+class InputRecorderContext extends TransformContext {
+
+ private final RankProfileTransformContext parent;
+ private final Set<String> localVariables = new HashSet<>();
+
+ public RankProfile rankProfile() { return parent.rankProfile(); }
+ public Set<String> localVariables() { return localVariables; }
+
+ public InputRecorderContext(RankProfileTransformContext parent) {
+ super(parent.constants(), parent.types());
+ this.parent = parent;
+ }
+
+ public InputRecorderContext(InputRecorderContext parent) {
+ super(parent.constants(), parent.types());
+ this.parent = parent.parent;
+ this.localVariables.addAll(parent.localVariables);
+ }
+}
diff --git a/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg b/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg
index 1e9cf8ce0e9..e456dec58ed 100644
--- a/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg
+++ b/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg
@@ -4,7 +4,7 @@ rankprofile[].fef.property[].value "query(yy)"
rankprofile[].fef.property[].name "rankingExpression(handicap).type"
rankprofile[].fef.property[].value "tensor(d0[2])"
rankprofile[].fef.property[].name "rankingExpression(indirect_a).rankingScript"
-rankprofile[].fef.property[].value "attribute(aa)"
+rankprofile[].fef.property[].value "attribute(aa) + tensor(d1[3])((d1 + attribute(extra)))"
rankprofile[].fef.property[].name "rankingExpression(indirect_a).type"
rankprofile[].fef.property[].value "tensor(d1[3])"
rankprofile[].fef.property[].name "rankingExpression(indirect_x).rankingScript"
@@ -25,6 +25,8 @@ rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript"
rankprofile[].fef.property[].value "reduce(constant(ww) * (onnx(inside).foobar - rankingExpression(handicap)), sum)"
rankprofile[].fef.property[].name "vespa.match.feature"
rankprofile[].fef.property[].value "attribute(aa)"
+rankprofile[].fef.property[].name "vespa.match.feature"
+rankprofile[].fef.property[].value "attribute(extra)"
rankprofile[].fef.property[].name "vespa.globalphase.rerankcount"
rankprofile[].fef.property[].value "13"
rankprofile[].fef.property[].name "vespa.type.attribute.aa"
diff --git a/config-model/src/test/derived/globalphase_onnx_inside/test.sd b/config-model/src/test/derived/globalphase_onnx_inside/test.sd
index 405a65ceb06..f5788611b0a 100644
--- a/config-model/src/test/derived/globalphase_onnx_inside/test.sd
+++ b/config-model/src/test/derived/globalphase_onnx_inside/test.sd
@@ -4,6 +4,9 @@ schema test {
field aa type tensor(d1[3]) {
indexing: attribute
}
+ field extra type float {
+ indexing: attribute
+ }
}
constant xx {
@@ -38,7 +41,7 @@ schema test {
expression: sum(constant(ww) * (onnx(inside).foobar - handicap))
}
function indirect_a() {
- expression: attribute(aa)
+ expression: attribute(aa) + tensor(d1[3])(d1+attribute(extra))
}
function indirect_x() {
expression: constant(xx)
diff --git a/config-model/src/test/derived/globalphase_token_functions/files/m.py b/config-model/src/test/derived/globalphase_token_functions/files/m.py
new file mode 100644
index 00000000000..004135b32eb
--- /dev/null
+++ b/config-model/src/test/derived/globalphase_token_functions/files/m.py
@@ -0,0 +1,51 @@
+# imports
+
+from onnx import TensorProto
+from onnx.helper import (
+ make_model, make_node, make_graph,
+ make_tensor_value_info, make_value_info)
+from onnx.checker import check_model
+
+# inputs
+
+# TensorProto.DOUBLE is the element type, [128] the shape
+A = make_tensor_value_info('input_ids', TensorProto.DOUBLE, [128])
+B = make_tensor_value_info('attention_mask', TensorProto.DOUBLE, [128])
+C = make_tensor_value_info('token_type_ids', TensorProto.DOUBLE, [128])
+
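+# outputs, with explicitly declared shapes (only S is passed to make_graph as a graph output below; Y is declared but not exported)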
+Y = make_tensor_value_info('vector_Y', TensorProto.DOUBLE, [128])
+S = make_tensor_value_info('score', TensorProto.DOUBLE, [1])
+
+# Creates node defined by the operator type, inputs, outputs, and possibly options
+node1 = make_node('Mul', ['input_ids', 'attention_mask'], ['masked'])
+node2 = make_node('Add', ['masked', 'token_type_ids'], ['vector_Y'])
+node3 = make_node('ReduceSum', inputs=['vector_Y'], outputs=['score'], keepdims=1)
+
+# from nodes to graph
+# the graph is built from the list of nodes, the list of inputs,
+# the list of outputs and a name.
+
+graph = make_graph([node1, node2, node3], # nodes
+ 'ranking_model', # a name
+ [A, B, C], # inputs
+ [S]) # outputs
+
+# onnx graph to model
+onnx_model = make_model(graph)
+
+# ensure we do not get an opset version that is too new:
+del onnx_model.opset_import[:]
+opset = onnx_model.opset_import.add()
+opset.version = 17
+
+# Let's check the model is consistent, this function is described in
+# section Checker and Shape Inference.
+check_model(onnx_model)
+
+# The serialization
+with open("ranking_model.onnx", "wb") as f:
+ f.write(onnx_model.SerializeToString())
+
+# the work is done, let's display it...
+print(onnx_model)
diff --git a/config-model/src/test/derived/globalphase_token_functions/files/ranking_model.onnx b/config-model/src/test/derived/globalphase_token_functions/files/ranking_model.onnx
new file mode 100644
index 00000000000..274fbba3fdb
--- /dev/null
+++ b/config-model/src/test/derived/globalphase_token_functions/files/ranking_model.onnx
@@ -0,0 +1,23 @@
+:þ
+(
+ input_ids
+attention_maskmasked"Mul
+'
+masked
+token_type_idsvector_Y"Add
+-
+vector_Yscore" ReduceSum*
+keepdims  ranking_modelZ
+ input_ids
+  
+€Z
+attention_mask
+  
+€Z
+token_type_ids
+  
+€b
+score
+
+ 
+B \ No newline at end of file
diff --git a/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg b/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg
new file mode 100644
index 00000000000..33381060178
--- /dev/null
+++ b/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg
@@ -0,0 +1,43 @@
+rankprofile[].name "default"
+rankprofile[].fef.property[].name "vespa.type.attribute.tokens"
+rankprofile[].fef.property[].value "tensor(d0[128])"
+rankprofile[].name "unranked"
+rankprofile[].fef.property[].name "vespa.rank.firstphase"
+rankprofile[].fef.property[].value "value(0)"
+rankprofile[].fef.property[].name "vespa.hitcollector.heapsize"
+rankprofile[].fef.property[].value "0"
+rankprofile[].fef.property[].name "vespa.hitcollector.arraysize"
+rankprofile[].fef.property[].value "0"
+rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures"
+rankprofile[].fef.property[].value "true"
+rankprofile[].fef.property[].name "vespa.type.attribute.tokens"
+rankprofile[].fef.property[].value "tensor(d0[128])"
+rankprofile[].name "using_model"
+rankprofile[].fef.property[].name "rankingExpression(__token_length@1019197748).rankingScript"
+rankprofile[].fef.property[].value "reduce(map(query(input), f(x)(x > 0)), sum)"
+rankprofile[].fef.property[].name "rankingExpression(__token_length@-812590320).rankingScript"
+rankprofile[].fef.property[].value "reduce(map(attribute(tokens), f(x)(x > 0)), sum)"
+rankprofile[].fef.property[].name "rankingExpression(input_ids).rankingScript"
+rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < 1.0, 101.0, if (d1 < 1.0 + rankingExpression(__token_length@1019197748), query(input){d0:(d1 - (1.0))}, if (d1 < (1.0 + rankingExpression(__token_length@1019197748) + 1.0), 102.0, if (d1 < (1.0 + rankingExpression(__token_length@1019197748) + 1.0 + rankingExpression(__token_length@-812590320)), attribute(tokens){d0:(d1 - (1.0 + rankingExpression(__token_length@1019197748) + 1.0))}, if (d1 < (1.0 + rankingExpression(__token_length@1019197748) + 1.0 + rankingExpression(__token_length@-812590320) + 1.0), 102.0, 0.0)))))))"
+rankprofile[].fef.property[].name "rankingExpression(input_ids).type"
+rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])"
+rankprofile[].fef.property[].name "rankingExpression(token_type_ids).rankingScript"
+rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < (1.0 + rankingExpression(__token_length@1019197748) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@1019197748) + 1.0 + rankingExpression(__token_length@-812590320) + 1.0), 1.0, 0.0))))"
+rankprofile[].fef.property[].name "rankingExpression(token_type_ids).type"
+rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])"
+rankprofile[].fef.property[].name "rankingExpression(attention_mask).rankingScript"
+rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < (1.0 + rankingExpression(__token_length@1019197748) + 1.0 + rankingExpression(__token_length@-812590320) + 1.0), 1.0, 0.0)))"
+rankprofile[].fef.property[].name "rankingExpression(attention_mask).type"
+rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])"
+rankprofile[].fef.property[].name "vespa.rank.globalphase"
+rankprofile[].fef.property[].value "rankingExpression(globalphase)"
+rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript"
+rankprofile[].fef.property[].value "onnx(my_ranking_model).score{d0:0}"
+rankprofile[].fef.property[].name "vespa.match.feature"
+rankprofile[].fef.property[].value "attribute(tokens)"
+rankprofile[].fef.property[].name "vespa.globalphase.rerankcount"
+rankprofile[].fef.property[].value "1000"
+rankprofile[].fef.property[].name "vespa.type.attribute.tokens"
+rankprofile[].fef.property[].value "tensor(d0[128])"
+rankprofile[].fef.property[].name "vespa.type.query.input"
+rankprofile[].fef.property[].value "tensor(d0[32])"
diff --git a/config-model/src/test/derived/globalphase_token_functions/test.sd b/config-model/src/test/derived/globalphase_token_functions/test.sd
new file mode 100644
index 00000000000..f8cc8863ad1
--- /dev/null
+++ b/config-model/src/test/derived/globalphase_token_functions/test.sd
@@ -0,0 +1,42 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+schema test {
+ document test {
+ field title type string {
+ indexing: index | summary
+ }
+ field tokens type tensor(d0[128]) {
+ indexing: attribute
+ }
+ }
+ fieldset default {
+ fields: title
+ }
+
+ onnx-model my_ranking_model {
+ file: files/ranking_model.onnx
+ input input_ids: input_ids
+ input attention_mask: attention_mask
+ input token_type_ids: token_type_ids
+ }
+
+ rank-profile using_model {
+ inputs {
+ query(input) tensor(d0[32])
+ }
+ function input_ids() {
+ expression: tokenInputIds(128, query(input), attribute(tokens))
+ }
+ function token_type_ids() {
+ expression: tokenTypeIds(128, query(input), attribute(tokens))
+ }
+ function attention_mask() {
+ expression: tokenAttentionMask(128, query(input), attribute(tokens))
+ }
+ global-phase {
+ rerank-count: 1000
+ expression: onnx(my_ranking_model).score{d0:0}
+ }
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java b/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java
index 2ff33dd70d8..3d0d9de13f5 100644
--- a/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java
@@ -19,4 +19,9 @@ public class GlobalPhaseOnnxModelsTestCase extends AbstractExportingTestCase {
assertCorrectDeriving("globalphase_onnx_inside");
}
+ @Test
+ void testWithTokenFunctions() throws IOException, ParseException {
+ assertCorrectDeriving("globalphase_token_functions");
+ }
+
}