summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-12-07 09:32:28 +0100
committerLester Solbakken <lesters@oath.com>2020-12-07 09:32:28 +0100
commit6b23ba62ddac3886bfeb056ce08800a2baabc241 (patch)
tree96a8d4e2eee1c72c14da0c93d1850406d0de88fa /config-model
parent52fbc31bc613198ae6d06b70714b7f7376f17663 (diff)
Rename tokenizer helper functions
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java18
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java6
2 files changed, 12 insertions, 12 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
index 3a04821ec13..1bb38eda9ff 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
@@ -30,9 +30,9 @@ import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrap
*
* Replaces features of the form
*
- * token_input_ids
- * token_type_ids
- * token_attention_mask
+ * tokenInputIds
+ * tokenTypeIds
+ * tokenAttentionMask
*
* to tensor generation expressions that generate the required input.
* In general, these models expect input of the form:
@@ -60,11 +60,11 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
- if (feature.getName().equals("token_input_ids") && shouldTransform(feature, context))
+ if (feature.getName().equals("tokenInputIds") && shouldTransform(feature, context))
return transformTokenInputIds(feature, context);
- if (feature.getName().equals("token_type_ids") && shouldTransform(feature, context))
+ if (feature.getName().equals("tokenTypeIds") && shouldTransform(feature, context))
return transformTokenTypeIds(feature, context);
- if (feature.getName().equals("token_attention_mask") && shouldTransform(feature, context))
+ if (feature.getName().equals("tokenAttentionMask") && shouldTransform(feature, context))
return transformTokenAttentionMask(feature, context);
return feature;
}
@@ -72,7 +72,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
/**
* Transforms a feature of the form
*
- * token_input_ids(128, a, b, ...)
+ * tokenInputIds(128, a, b, ...)
*
* to an expression that concatenates the arguments a, b, ... using the
* special Transformers sequences of CLS and SEP, up to length 128, so
@@ -114,7 +114,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
/**
* Transforms a feature of the form
*
- * token_type_ids(128, a, ...)
+ * tokenTypeIds(128, a, ...)
*
* to an expression that generates a tensor that has values 0 for "a"
* (including CLS and SEP tokens) and 1 for the rest of the sequence.
@@ -142,7 +142,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
/**
* Transforms a feature of the form
*
- * token_attention_mask(128, a, b, ...)
+ * tokenAttentionMask(128, a, b, ...)
*
* to an expression that generates a tensor that has values 1 for all
* arguments (including CLS and SEP tokens) and 0 for the rest of the
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java
index 19d4b4a6778..174e685b112 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java
@@ -28,7 +28,7 @@ public class RankingExpressionWithTransformerTokensTestCase {
String a = "tensor(d0[2]):[1,2]";
String b = "tensor(d0[3]):[3,4,5]";
String c = "tensor(d0[2]):[6,7]";
- String expression = "token_input_ids(12, a, b, c)";
+ String expression = "tokenInputIds(12, a, b, c)";
Tensor result = evaluateExpression(expression, a, b, c);
assertEquals(Tensor.from(expected), result);
}
@@ -38,7 +38,7 @@ public class RankingExpressionWithTransformerTokensTestCase {
String expected = "tensor(d0[1],d1[10]):[0,0,0,0,1,1,1,1,1,1]";
String a = "tensor(d0[2]):[1,2]";
String b = "tensor(d0[3]):[3,4,5]";
- String expression = "token_type_ids(10, a, b)";
+ String expression = "tokenTypeIds(10, a, b)";
Tensor result = evaluateExpression(expression, a, b);
assertEquals(Tensor.from(expected), result);
}
@@ -48,7 +48,7 @@ public class RankingExpressionWithTransformerTokensTestCase {
String expected = "tensor(d0[1],d1[10]):[1,1,1,1,1,1,1,1,0,0]";
String a = "tensor(d0[2]):[1,2]";
String b = "tensor(d0[3]):[3,4,5]";
- String expression = "token_attention_mask(10, a, b)";
+ String expression = "tokenAttentionMask(10, a, b)";
Tensor result = evaluateExpression(expression, a, b);
assertEquals(Tensor.from(expected), result);
}