diff options
author | Harald Musum <musum@verizonmedia.com> | 2021-02-08 10:26:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-08 10:26:22 +0100 |
commit | 0e5c47a8666e32eb304ca634c4b96ae8e96bb166 (patch) | |
tree | e85052ce27247d085f9a89ec5466c045f2e26cc7 /config-model | |
parent | 7fad88a0c27615f3a01dab36e8e4ff754aee0350 (diff) |
Revert "Add 0's after token sequence for tokenTypeIds "
Diffstat (limited to 'config-model')
3 files changed, 9 insertions, 22 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java index 032341297bf..192fb9baa9a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java @@ -114,20 +114,14 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform /** * Transforms a feature of the form * - * tokenTypeIds(128, a, b, ...) + * tokenTypeIds(128, a, ...) * * to an expression that generates a tensor that has values 0 for "a" * (including CLS and SEP tokens) and 1 for the rest of the sequence. * * Concretely, transforms to a tensor generation expression: * - * tensor(d0[1],d1[128])( - * if (d1 < 1 + length_a + 1, - * 0, - * if (d1 < 1 + length_a + 1 + length_b + 1 + ..., - * 1, - * 0 - * ))) + * tensor(d0[1],d1[128])(if(d1 < length_a + 2, 0, 1)) */ private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) { checkArguments(feature); @@ -137,18 +131,11 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform // we need to add functions calculating the token lengths of the arguments createTokenLengthFunctions(feature, context); - List<ExpressionNode> tokenSequence = createTokenSequence(feature); - ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence); - ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); - ExpressionNode expr = new IfNode( - new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr), - ZERO, - new IfNode( - new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr), - ONE, - ZERO - ) - ); + ReferenceNode arg = (ReferenceNode) feature.getArguments().expressions().get(1); + ExpressionNode argLength = new ReferenceNode(lengthFunctionName(arg)); + ExpressionNode lengthExpr = new ArithmeticNode(argLength, ArithmeticOperator.PLUS, TWO); + ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr); + ExpressionNode expr = new IfNode(comparison, ZERO, ONE); return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr))); } diff --git a/config-model/src/test/integration/onnx-model/schemas/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd index 4f45e0f6318..5b440e80bed 100644 --- a/config-model/src/test/integration/onnx-model/schemas/test.sd +++ b/config-model/src/test/integration/onnx-model/schemas/test.sd @@ -101,7 +101,7 @@ search test { rank-profile test_dynamic_model_with_transformer_tokens { function my_function() { - expression: tokenTypeIds(10, attribute(document_field), attribute(document_field)) + expression: tokenTypeIds(2, attribute(document_field)) } first-phase { expression: onnx(dynamic_model){d0:0,d1:1} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index a3ad9f4f4ba..73ff4ac3bcd 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -152,7 +152,7 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name()); - assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < 1 + rankingExpression(__token_length@-1993461420) + 1, 0, if (d1 < 1 + rankingExpression(__token_length@-1993461420) + 1 + rankingExpression(__token_length@-1993461420) + 1, 1, 0))))", config.rankprofile(7).fef().property(1).value()); + assertEquals("tensor<float>(d0[1],d1[2])((if (d1 < rankingExpression(__token_length@-1993461420) + 2, 0, 1)))", config.rankprofile(7).fef().property(1).value()); assertEquals("test_unbound_model", config.rankprofile(8).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name()); |