diff options
author | Lester Solbakken <lesters@oath.com> | 2020-12-07 09:32:28 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-12-07 09:32:28 +0100 |
commit | 6b23ba62ddac3886bfeb056ce08800a2baabc241 (patch) | |
tree | 96a8d4e2eee1c72c14da0c93d1850406d0de88fa /config-model | |
parent | 52fbc31bc613198ae6d06b70714b7f7376f17663 (diff) |
Rename tokenizer helper functions
Diffstat (limited to 'config-model')
2 files changed, 12 insertions, 12 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java index 3a04821ec13..1bb38eda9ff 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java @@ -30,9 +30,9 @@ import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrap * * Replaces features of the form * - * token_input_ids - * token_type_ids - * token_attention_mask + * tokenInputIds + * tokenTypeIds + * tokenAttentionMask * * to tensor generation expressions that generate the required input. * In general, these models expect input of the form: @@ -60,11 +60,11 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { - if (feature.getName().equals("token_input_ids") && shouldTransform(feature, context)) + if (feature.getName().equals("tokenInputIds") && shouldTransform(feature, context)) return transformTokenInputIds(feature, context); - if (feature.getName().equals("token_type_ids") && shouldTransform(feature, context)) + if (feature.getName().equals("tokenTypeIds") && shouldTransform(feature, context)) return transformTokenTypeIds(feature, context); - if (feature.getName().equals("token_attention_mask") && shouldTransform(feature, context)) + if (feature.getName().equals("tokenAttentionMask") && shouldTransform(feature, context)) return transformTokenAttentionMask(feature, context); return feature; } @@ -72,7 +72,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform /** * Transforms a feature of the form * - * token_input_ids(128, a, b, ...) + * tokenInputIds(128, a, b, ...) * * to an expression that concatenates the arguments a, b, ... using the * special Transformers sequences of CLS and SEP, up to length 128, so @@ -114,7 +114,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform /** * Transforms a feature of the form * - * token_type_ids(128, a, ...) + * tokenTypeIds(128, a, ...) * * to an expression that generates a tensor that has values 0 for "a" * (including CLS and SEP tokens) and 1 for the rest of the sequence. @@ -142,7 +142,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform /** * Transforms a feature of the form * - * token_attention_mask(128, a, b, ...) + * tokenAttentionMask(128, a, b, ...) * * to an expression that generates a tensor that has values 1 for all * arguments (including CLS and SEP tokens) and 0 for the rest of the diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java index 19d4b4a6778..174e685b112 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java @@ -28,7 +28,7 @@ public class RankingExpressionWithTransformerTokensTestCase { String a = "tensor(d0[2]):[1,2]"; String b = "tensor(d0[3]):[3,4,5]"; String c = "tensor(d0[2]):[6,7]"; - String expression = "token_input_ids(12, a, b, c)"; + String expression = "tokenInputIds(12, a, b, c)"; Tensor result = evaluateExpression(expression, a, b, c); assertEquals(Tensor.from(expected), result); } @@ -38,7 +38,7 @@ public class RankingExpressionWithTransformerTokensTestCase { String expected = "tensor(d0[1],d1[10]):[0,0,0,0,1,1,1,1,1,1]"; String a = "tensor(d0[2]):[1,2]"; String b = "tensor(d0[3]):[3,4,5]"; - String expression = "token_type_ids(10, a, b)"; + String expression = "tokenTypeIds(10, a, b)"; Tensor result = evaluateExpression(expression, a, b); assertEquals(Tensor.from(expected), result); } @@ -48,7 +48,7 @@ public class RankingExpressionWithTransformerTokensTestCase { String expected = "tensor(d0[1],d1[10]):[1,1,1,1,1,1,1,1,0,0]"; String a = "tensor(d0[2]):[1,2]"; String b = "tensor(d0[3]):[3,4,5]"; - String expression = "token_attention_mask(10, a, b)"; + String expression = "tokenAttentionMask(10, a, b)"; Tensor result = evaluateExpression(expression, a, b); assertEquals(Tensor.from(expected), result); } |