aboutsummaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-12-07 09:12:37 +0100
committerLester Solbakken <lesters@oath.com>2020-12-07 09:12:37 +0100
commit2884e5f796de62ed57de95fcf0cdb54a014dee3a (patch)
treea982e2ef17e8662dab109ebc505ecf83864bb7d6 /config-model
parent3d0dfbcbfd9454110f9d1459b3e82a29477436bb (diff)
Don't transform feature if it has wrong arity
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java33
1 files changed, 13 insertions, 20 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
index 58ae9799f23..3a04821ec13 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
@@ -60,11 +60,11 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
- if (feature.getName().equals("token_input_ids"))
+ if (feature.getName().equals("token_input_ids") && shouldTransform(feature, context))
return transformTokenInputIds(feature, context);
- if (feature.getName().equals("token_type_ids"))
+ if (feature.getName().equals("token_type_ids") && shouldTransform(feature, context))
return transformTokenTypeIds(feature, context);
- if (feature.getName().equals("token_attention_mask"))
+ if (feature.getName().equals("token_attention_mask") && shouldTransform(feature, context))
return transformTokenAttentionMask(feature, context);
return feature;
}
@@ -99,9 +99,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
* Functions calculating lengths of arguments are added to the rank profile.
*/
private ExpressionNode transformTokenInputIds(ReferenceNode feature, RankProfileTransformContext context) {
- if (contextHasFunction(feature, context))
- return feature;
- checkArguments(feature, context);
+ checkArguments(feature);
TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0));
@@ -126,9 +124,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
* tensor(d0[1],d1[128])(if(d1 < length_a + 2, 0, 1))
*/
private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) {
- if (contextHasFunction(feature, context))
- return feature;
- checkArguments(feature, context);
+ checkArguments(feature);
TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0));
@@ -158,9 +154,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
*
*/
private ExpressionNode transformTokenAttentionMask(ReferenceNode feature, RankProfileTransformContext context) {
- if (contextHasFunction(feature, context))
- return feature;
- checkArguments(feature, context);
+ checkArguments(feature);
TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0));
@@ -174,16 +168,15 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr)));
}
- private boolean contextHasFunction(ReferenceNode feature, RankProfileTransformContext context) {
- return context.rankProfile().getFunctions().containsKey(feature.getName());
+ private boolean shouldTransform(ReferenceNode feature, RankProfileTransformContext context) {
+ if (context.rankProfile().getFunctions().containsKey(feature.getName()))
+ return false;
+ if (feature.getArguments().size() < 2)
+ return false;
+ return true;
}
- private void checkArguments(ReferenceNode feature, RankProfileTransformContext context) {
- final String featureName = feature.getName();
- if (feature.getArguments().size() < 2) {
- throw new IllegalArgumentException(featureName + " requires at least 2 arguments: the length of the token " +
- "sequence and where to retrieve the tokens from.");
- }
+ private void checkArguments(ReferenceNode feature) {
for (int i = 1; i < feature.getArguments().size(); ++i) {
ExpressionNode arg = feature.getArguments().expressions().get(i);
if ( ! (arg instanceof ReferenceNode)) {