diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2021-02-08 10:13:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-08 10:13:35 +0100 |
commit | 259f6258554484f992e94e675ce0123a8c30d186 (patch) | |
tree | 3846bd7d24dfbadab8c46b7523f8952ff06587a0 | |
parent | 310ba5b0887926959270d602f908b0abec2ec55c (diff) | |
parent | a67ade21ce03d381e766d0eabd73d7cf74500d9c (diff) |
Merge pull request #16413 from vespa-engine/lesters/fix-token-type-ids
Add 0's after token sequence for tokenTypeIds
3 files changed, 22 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java index 192fb9baa9a..032341297bf 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java @@ -114,14 +114,20 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform /** * Transforms a feature of the form * - * tokenTypeIds(128, a, ...) + * tokenTypeIds(128, a, b, ...) * * to an expression that generates a tensor that has values 0 for "a" * (including CLS and SEP tokens) and 1 for the rest of the sequence. * * Concretely, transforms to a tensor generation expression: * - * tensor(d0[1],d1[128])(if(d1 < length_a + 2, 0, 1)) + * tensor(d0[1],d1[128])( + * if (d1 < 1 + length_a + 1, + * 0, + * if (d1 < 1 + length_a + 1 + length_b + 1 + ..., + * 1, + * 0 + * ))) */ private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) { checkArguments(feature); @@ -131,11 +137,18 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform // we need to add functions calculating the token lengths of the arguments createTokenLengthFunctions(feature, context); - ReferenceNode arg = (ReferenceNode) feature.getArguments().expressions().get(1); - ExpressionNode argLength = new ReferenceNode(lengthFunctionName(arg)); - ExpressionNode lengthExpr = new ArithmeticNode(argLength, ArithmeticOperator.PLUS, TWO); - ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr); - ExpressionNode expr = new IfNode(comparison, ZERO, ONE); + List<ExpressionNode> tokenSequence = createTokenSequence(feature); + ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence); + ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); + ExpressionNode expr = new IfNode( + new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr), + ZERO, + new IfNode( + new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr), + ONE, + ZERO + ) + ); return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr))); } diff --git a/config-model/src/test/integration/onnx-model/schemas/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd index 5b440e80bed..4f45e0f6318 100644 --- a/config-model/src/test/integration/onnx-model/schemas/test.sd +++ b/config-model/src/test/integration/onnx-model/schemas/test.sd @@ -101,7 +101,7 @@ search test { rank-profile test_dynamic_model_with_transformer_tokens { function my_function() { - expression: tokenTypeIds(2, attribute(document_field)) + expression: tokenTypeIds(10, attribute(document_field), attribute(document_field)) } first-phase { expression: onnx(dynamic_model){d0:0,d1:1} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index 73ff4ac3bcd..a3ad9f4f4ba 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -152,7 +152,7 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name()); - assertEquals("tensor<float>(d0[1],d1[2])((if (d1 < rankingExpression(__token_length@-1993461420) + 2, 0, 1)))", config.rankprofile(7).fef().property(1).value()); + assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < 1 + rankingExpression(__token_length@-1993461420) + 1, 0, if (d1 < 1 + rankingExpression(__token_length@-1993461420) + 1 + rankingExpression(__token_length@-1993461420) + 1, 1, 0))))", config.rankprofile(7).fef().property(1).value()); assertEquals("test_unbound_model", config.rankprofile(8).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name()); |