diff options
author | Lester Solbakken <lesters@oath.com> | 2021-02-08 10:46:18 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-02-08 10:46:18 +0100 |
commit | 2c7d759efe6c71952e433aa978c9f546aefac1d8 (patch) | |
tree | 92b637834d35c2633444f30081de605648a98c03 /config-model/src/main/java/com/yahoo | |
parent | ed3025cbe3c37cac465092b67af18664e4a6838d (diff) |
Revert "Revert "Add 0's after token sequence for tokenTypeIds ""
This reverts commit 0e5c47a8666e32eb304ca634c4b96ae8e96bb166.
Diffstat (limited to 'config-model/src/main/java/com/yahoo')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java | 27 |
1 files changed, 20 insertions, 7 deletions
```diff
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
index 192fb9baa9a..032341297bf 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
@@ -114,14 +114,20 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
     /**
      * Transforms a feature of the form
      *
-     *     tokenTypeIds(128, a, ...)
+     *     tokenTypeIds(128, a, b, ...)
      *
      * to an expression that generates a tensor that has values 0 for "a"
      * (including CLS and SEP tokens) and 1 for the rest of the sequence.
      *
      * Concretely, transforms to a tensor generation expression:
      *
-     *     tensor(d0[1],d1[128])(if(d1 < length_a + 2, 0, 1))
+     *     tensor(d0[1],d1[128])(
+     *         if (d1 < 1 + length_a + 1,
+     *             0,
+     *             if (d1 < 1 + length_a + 1 + length_b + 1 + ...,
+     *                 1,
+     *                 0
+     *         )))
      */
     private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) {
         checkArguments(feature);
@@ -131,11 +137,18 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
         // we need to add functions calculating the token lengths of the arguments
         createTokenLengthFunctions(feature, context);
 
-        ReferenceNode arg = (ReferenceNode) feature.getArguments().expressions().get(1);
-        ExpressionNode argLength = new ReferenceNode(lengthFunctionName(arg));
-        ExpressionNode lengthExpr = new ArithmeticNode(argLength, ArithmeticOperator.PLUS, TWO);
-        ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr);
-        ExpressionNode expr = new IfNode(comparison, ZERO, ONE);
+        List<ExpressionNode> tokenSequence = createTokenSequence(feature);
+        ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence);
+        ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
+        ExpressionNode expr = new IfNode(
+                new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr),
+                ZERO,
+                new IfNode(
+                        new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr),
+                        ONE,
+                        ZERO
+                )
+        );
         return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr)));
     }
 
```