summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArnstein Ressem <aressem@verizonmedia.com>2021-02-08 10:30:13 +0100
committerGitHub <noreply@github.com>2021-02-08 10:30:13 +0100
commited3025cbe3c37cac465092b67af18664e4a6838d (patch)
treee85052ce27247d085f9a89ec5466c045f2e26cc7
parent7fad88a0c27615f3a01dab36e8e4ff754aee0350 (diff)
parent0e5c47a8666e32eb304ca634c4b96ae8e96bb166 (diff)
Merge pull request #16430 from vespa-engine/revert-16413-lesters/fix-token-type-ids
Revert "Add 0's after token sequence for tokenTypeIds "
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java27
-rw-r--r--config-model/src/test/integration/onnx-model/schemas/test.sd2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java2
3 files changed, 9 insertions, 22 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
index 032341297bf..192fb9baa9a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
@@ -114,20 +114,14 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
/**
* Transforms a feature of the form
*
- * tokenTypeIds(128, a, b, ...)
+ * tokenTypeIds(128, a, ...)
*
* to an expression that generates a tensor that has values 0 for "a"
* (including CLS and SEP tokens) and 1 for the rest of the sequence.
*
* Concretely, transforms to a tensor generation expression:
*
- * tensor(d0[1],d1[128])(
- * if (d1 < 1 + length_a + 1,
- * 0,
- * if (d1 < 1 + length_a + 1 + length_b + 1 + ...,
- * 1,
- * 0
- * )))
+ * tensor(d0[1],d1[128])(if(d1 < length_a + 2, 0, 1))
*/
private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) {
checkArguments(feature);
@@ -137,18 +131,11 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
// we need to add functions calculating the token lengths of the arguments
createTokenLengthFunctions(feature, context);
- List<ExpressionNode> tokenSequence = createTokenSequence(feature);
- ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence);
- ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
- ExpressionNode expr = new IfNode(
- new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr),
- ZERO,
- new IfNode(
- new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr),
- ONE,
- ZERO
- )
- );
+ ReferenceNode arg = (ReferenceNode) feature.getArguments().expressions().get(1);
+ ExpressionNode argLength = new ReferenceNode(lengthFunctionName(arg));
+ ExpressionNode lengthExpr = new ArithmeticNode(argLength, ArithmeticOperator.PLUS, TWO);
+ ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr);
+ ExpressionNode expr = new IfNode(comparison, ZERO, ONE);
return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr)));
}
diff --git a/config-model/src/test/integration/onnx-model/schemas/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd
index 4f45e0f6318..5b440e80bed 100644
--- a/config-model/src/test/integration/onnx-model/schemas/test.sd
+++ b/config-model/src/test/integration/onnx-model/schemas/test.sd
@@ -101,7 +101,7 @@ search test {
rank-profile test_dynamic_model_with_transformer_tokens {
function my_function() {
- expression: tokenTypeIds(10, attribute(document_field), attribute(document_field))
+ expression: tokenTypeIds(2, attribute(document_field))
}
first-phase {
expression: onnx(dynamic_model){d0:0,d1:1}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index a3ad9f4f4ba..73ff4ac3bcd 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -152,7 +152,7 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name());
- assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < 1 + rankingExpression(__token_length@-1993461420) + 1, 0, if (d1 < 1 + rankingExpression(__token_length@-1993461420) + 1 + rankingExpression(__token_length@-1993461420) + 1, 1, 0))))", config.rankprofile(7).fef().property(1).value());
+ assertEquals("tensor<float>(d0[1],d1[2])((if (d1 < rankingExpression(__token_length@-1993461420) + 2, 0, 1)))", config.rankprofile(7).fef().property(1).value());
assertEquals("test_unbound_model", config.rankprofile(8).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name());