summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-02-08 10:46:18 +0100
committerLester Solbakken <lesters@oath.com>2021-02-08 10:46:18 +0100
commit2c7d759efe6c71952e433aa978c9f546aefac1d8 (patch)
tree92b637834d35c2633444f30081de605648a98c03
parented3025cbe3c37cac465092b67af18664e4a6838d (diff)
Revert "Revert "Add 0's after token sequence for tokenTypeIds ""
This reverts commit 0e5c47a8666e32eb304ca634c4b96ae8e96bb166.
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java27
-rw-r--r--config-model/src/test/integration/onnx-model/schemas/test.sd2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java2
3 files changed, 22 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
index 192fb9baa9a..032341297bf 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
@@ -114,14 +114,20 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
/**
* Transforms a feature of the form
*
- * tokenTypeIds(128, a, ...)
+ * tokenTypeIds(128, a, b, ...)
*
* to an expression that generates a tensor that has values 0 for "a"
* (including CLS and SEP tokens) and 1 for the rest of the sequence.
*
* Concretely, transforms to a tensor generation expression:
*
- * tensor(d0[1],d1[128])(if(d1 < length_a + 2, 0, 1))
+ * tensor(d0[1],d1[128])(
+ * if (d1 < 1 + length_a + 1,
+ * 0,
+ * if (d1 < 1 + length_a + 1 + length_b + 1 + ...,
+ * 1,
+ * 0
+ * )))
*/
private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) {
checkArguments(feature);
@@ -131,11 +137,18 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
// we need to add functions calculating the token lengths of the arguments
createTokenLengthFunctions(feature, context);
- ReferenceNode arg = (ReferenceNode) feature.getArguments().expressions().get(1);
- ExpressionNode argLength = new ReferenceNode(lengthFunctionName(arg));
- ExpressionNode lengthExpr = new ArithmeticNode(argLength, ArithmeticOperator.PLUS, TWO);
- ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr);
- ExpressionNode expr = new IfNode(comparison, ZERO, ONE);
+ List<ExpressionNode> tokenSequence = createTokenSequence(feature);
+ ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence);
+ ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
+ ExpressionNode expr = new IfNode(
+ new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr),
+ ZERO,
+ new IfNode(
+ new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr),
+ ONE,
+ ZERO
+ )
+ );
return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr)));
}
diff --git a/config-model/src/test/integration/onnx-model/schemas/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd
index 5b440e80bed..4f45e0f6318 100644
--- a/config-model/src/test/integration/onnx-model/schemas/test.sd
+++ b/config-model/src/test/integration/onnx-model/schemas/test.sd
@@ -101,7 +101,7 @@ search test {
rank-profile test_dynamic_model_with_transformer_tokens {
function my_function() {
- expression: tokenTypeIds(2, attribute(document_field))
+ expression: tokenTypeIds(10, attribute(document_field), attribute(document_field))
}
first-phase {
expression: onnx(dynamic_model){d0:0,d1:1}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index 73ff4ac3bcd..a3ad9f4f4ba 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -152,7 +152,7 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name());
- assertEquals("tensor<float>(d0[1],d1[2])((if (d1 < rankingExpression(__token_length@-1993461420) + 2, 0, 1)))", config.rankprofile(7).fef().property(1).value());
+ assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < 1 + rankingExpression(__token_length@-1993461420) + 1, 0, if (d1 < 1 + rankingExpression(__token_length@-1993461420) + 1 + rankingExpression(__token_length@-1993461420) + 1, 1, 0))))", config.rankprofile(7).fef().property(1).value());
assertEquals("test_unbound_model", config.rankprofile(8).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name());