diff options
author | Lester Solbakken <lesters@oath.com> | 2020-12-07 09:12:37 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-12-07 09:12:37 +0100 |
commit | 2884e5f796de62ed57de95fcf0cdb54a014dee3a (patch) | |
tree | a982e2ef17e8662dab109ebc505ecf83864bb7d6 /config-model | |
parent | 3d0dfbcbfd9454110f9d1459b3e82a29477436bb (diff) |
Don't transform feature if it has wrong arity
Diffstat (limited to 'config-model')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java | 33 |
1 files changed, 13 insertions, 20 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java index 58ae9799f23..3a04821ec13 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java @@ -60,11 +60,11 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { - if (feature.getName().equals("token_input_ids")) + if (feature.getName().equals("token_input_ids") && shouldTransform(feature, context)) return transformTokenInputIds(feature, context); - if (feature.getName().equals("token_type_ids")) + if (feature.getName().equals("token_type_ids") && shouldTransform(feature, context)) return transformTokenTypeIds(feature, context); - if (feature.getName().equals("token_attention_mask")) + if (feature.getName().equals("token_attention_mask") && shouldTransform(feature, context)) return transformTokenAttentionMask(feature, context); return feature; } @@ -99,9 +99,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform * Functions calculating lengths of arguments are added to the rank profile. */ private ExpressionNode transformTokenInputIds(ReferenceNode feature, RankProfileTransformContext context) { - if (contextHasFunction(feature, context)) - return feature; - checkArguments(feature, context); + checkArguments(feature); TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0)); @@ -126,9 +124,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform * tensor(d0[1],d1[128])(if(d1 < length_a + 2, 0, 1)) */ private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) { - if (contextHasFunction(feature, context)) - return feature; - checkArguments(feature, context); + checkArguments(feature); TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0)); @@ -158,9 +154,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform * */ private ExpressionNode transformTokenAttentionMask(ReferenceNode feature, RankProfileTransformContext context) { - if (contextHasFunction(feature, context)) - return feature; - checkArguments(feature, context); + checkArguments(feature); TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0)); @@ -174,16 +168,15 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr))); } - private boolean contextHasFunction(ReferenceNode feature, RankProfileTransformContext context) { - return context.rankProfile().getFunctions().containsKey(feature.getName()); + private boolean shouldTransform(ReferenceNode feature, RankProfileTransformContext context) { + if (context.rankProfile().getFunctions().containsKey(feature.getName())) + return false; + if (feature.getArguments().size() < 2) + return false; + return true; } - private void checkArguments(ReferenceNode feature, RankProfileTransformContext context) { - final String featureName = feature.getName(); - if (feature.getArguments().size() < 2) { - throw new IllegalArgumentException(featureName + " requires at least 2 arguments: the length of the token " + - "sequence and where to retrieve the tokens from."); - } + private void checkArguments(ReferenceNode feature) { for (int i = 1; i < feature.getArguments().size(); ++i) { ExpressionNode arg = feature.getArguments().expressions().get(i); if ( ! (arg instanceof ReferenceNode)) { |