summary | refs | log | tree | commit | diff | stats
path: root/config-model/src/main/java/com/yahoo
diff options
context:
space:
mode:
author: Lester Solbakken <lesters@oath.com> 2021-02-08 10:46:18 +0100
committer: Lester Solbakken <lesters@oath.com> 2021-02-08 10:46:18 +0100
commit: 2c7d759efe6c71952e433aa978c9f546aefac1d8 (patch)
tree: 92b637834d35c2633444f30081de605648a98c03 /config-model/src/main/java/com/yahoo
parent: ed3025cbe3c37cac465092b67af18664e4a6838d (diff)
Revert "Revert "Add 0's after token sequence for tokenTypeIds ""
This reverts commit 0e5c47a8666e32eb304ca634c4b96ae8e96bb166.
Diffstat (limited to 'config-model/src/main/java/com/yahoo')
-rw-r--r-- config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java | 27
1 files changed, 20 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
index 192fb9baa9a..032341297bf 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
@@ -114,14 +114,20 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
/**
* Transforms a feature of the form
*
- * tokenTypeIds(128, a, ...)
+ * tokenTypeIds(128, a, b, ...)
*
* to an expression that generates a tensor that has values 0 for "a"
* (including CLS and SEP tokens) and 1 for the rest of the sequence.
*
* Concretely, transforms to a tensor generation expression:
*
- * tensor(d0[1],d1[128])(if(d1 < length_a + 2, 0, 1))
+ * tensor(d0[1],d1[128])(
+ * if (d1 < 1 + length_a + 1,
+ * 0,
+ * if (d1 < 1 + length_a + 1 + length_b + 1 + ...,
+ * 1,
+ * 0
+ * )))
*/
private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) {
checkArguments(feature);
@@ -131,11 +137,18 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
// we need to add functions calculating the token lengths of the arguments
createTokenLengthFunctions(feature, context);
- ReferenceNode arg = (ReferenceNode) feature.getArguments().expressions().get(1);
- ExpressionNode argLength = new ReferenceNode(lengthFunctionName(arg));
- ExpressionNode lengthExpr = new ArithmeticNode(argLength, ArithmeticOperator.PLUS, TWO);
- ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr);
- ExpressionNode expr = new IfNode(comparison, ZERO, ONE);
+ List<ExpressionNode> tokenSequence = createTokenSequence(feature);
+ ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence);
+ ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
+ ExpressionNode expr = new IfNode(
+ new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr),
+ ZERO,
+ new IfNode(
+ new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr),
+ ONE,
+ ZERO
+ )
+ );
return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr)));
}