diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java | 27 |
1 files changed, 7 insertions, 20 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java index 032341297bf..192fb9baa9a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java @@ -114,20 +114,14 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform /** * Transforms a feature of the form * - * tokenTypeIds(128, a, b, ...) + * tokenTypeIds(128, a, ...) * * to an expression that generates a tensor that has values 0 for "a" * (including CLS and SEP tokens) and 1 for the rest of the sequence. * * Concretely, transforms to a tensor generation expression: * - * tensor(d0[1],d1[128])( - * if (d1 < 1 + length_a + 1, - * 0, - * if (d1 < 1 + length_a + 1 + length_b + 1 + ..., - * 1, - * 0 - * ))) + * tensor(d0[1],d1[128])(if(d1 < length_a + 2, 0, 1)) */ private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) { checkArguments(feature); @@ -137,18 +131,11 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform // we need to add functions calculating the token lengths of the arguments createTokenLengthFunctions(feature, context); - List<ExpressionNode> tokenSequence = createTokenSequence(feature); - ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence); - ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); - ExpressionNode expr = new IfNode( - new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr), - ZERO, - new IfNode( - new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr), - ONE, - ZERO - ) - ); + ReferenceNode arg = (ReferenceNode) feature.getArguments().expressions().get(1); + ExpressionNode argLength = new ReferenceNode(lengthFunctionName(arg)); + ExpressionNode lengthExpr = new ArithmeticNode(argLength, ArithmeticOperator.PLUS, TWO); + ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr); + ExpressionNode expr = new IfNode(comparison, ZERO, ONE); return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr))); } |