aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java27
1 files changed, 7 insertions, 20 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
index 032341297bf..192fb9baa9a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
@@ -114,20 +114,14 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
/**
* Transforms a feature of the form
*
- * tokenTypeIds(128, a, b, ...)
+ * tokenTypeIds(128, a, ...)
*
* to an expression that generates a tensor that has values 0 for "a"
* (including CLS and SEP tokens) and 1 for the rest of the sequence.
*
* Concretely, transforms to a tensor generation expression:
*
- * tensor(d0[1],d1[128])(
- * if (d1 < 1 + length_a + 1,
- * 0,
- * if (d1 < 1 + length_a + 1 + length_b + 1 + ...,
- * 1,
- * 0
- * )))
+ * tensor(d0[1],d1[128])(if(d1 < length_a + 2, 0, 1))
*/
private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) {
checkArguments(feature);
@@ -137,18 +131,11 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
// we need to add functions calculating the token lengths of the arguments
createTokenLengthFunctions(feature, context);
- List<ExpressionNode> tokenSequence = createTokenSequence(feature);
- ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence);
- ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
- ExpressionNode expr = new IfNode(
- new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr),
- ZERO,
- new IfNode(
- new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr),
- ONE,
- ZERO
- )
- );
+ ReferenceNode arg = (ReferenceNode) feature.getArguments().expressions().get(1);
+ ExpressionNode argLength = new ReferenceNode(lengthFunctionName(arg));
+ ExpressionNode lengthExpr = new ArithmeticNode(argLength, ArithmeticOperator.PLUS, TWO);
+ ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr);
+ ExpressionNode expr = new IfNode(comparison, ZERO, ONE);
return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr)));
}