diff options
2 files changed, 54 insertions, 18 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java index f454d941e31..04a31a47190 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java @@ -63,6 +63,8 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { + if (feature.getName().equals("customTokenInputIds") && shouldTransform(feature, context)) + return transformCustomTokenInputIds(feature, context); if (feature.getName().equals("tokenInputIds") && shouldTransform(feature, context)) return transformTokenInputIds(feature, context); if (feature.getName().equals("tokenTypeIds") && shouldTransform(feature, context)) @@ -102,15 +104,38 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform * Functions calculating lengths of arguments are added to the rank profile. */ private ExpressionNode transformTokenInputIds(ReferenceNode feature, RankProfileTransformContext context) { - checkArguments(feature); + return transformTokenInputIds(feature, context, CLS, SEP, 1); + } - TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0)); + /** + * Transforms a feature of the form + * + * customTokenInputIds(1, 2, 128, a, b, ...) + * + * to an expression that concatenates the arguments a, b, ... using the + * first and second arguments as the CLS and SEP padding tokens, here + * 1 and 2, respectively. Otherwise, identical to tokenInputIds. + */ + private ExpressionNode transformCustomTokenInputIds(ReferenceNode feature, RankProfileTransformContext context) { + ExpressionNode cls = feature.getArguments().expressions().get(0); + ExpressionNode sep = feature.getArguments().expressions().get(1); + return transformTokenInputIds(feature, context, cls, sep, 3); + } + + private ExpressionNode transformTokenInputIds(ReferenceNode feature, + RankProfileTransformContext context, + ExpressionNode cls, + ExpressionNode sep, + int fromArg) { + checkReferenceArguments(feature, fromArg); + + TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(fromArg - 1)); // we need to add functions calculating the token lengths of the arguments - createTokenLengthFunctions(feature, context); + createTokenLengthFunctions(feature, context, fromArg); // create token sequence: CLS + arg1 + SEP + arg2 + SEP + .... - ExpressionNode tokenSequenceExpr = createTokenSequenceExpr(0, createTokenSequence(feature)); + ExpressionNode tokenSequenceExpr = createTokenSequenceExpr(0, createTokenSequence(feature, cls, sep, fromArg)); return new TensorFunctionNode(Generate.bound(type, wrapScalar(tokenSequenceExpr))); } @@ -133,14 +158,14 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform * ))) */ private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) { - checkArguments(feature); + checkReferenceArguments(feature, 1); TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0)); // we need to add functions calculating the token lengths of the arguments - createTokenLengthFunctions(feature, context); + createTokenLengthFunctions(feature, context, 1); - List<ExpressionNode> tokenSequence = createTokenSequence(feature); + List<ExpressionNode> tokenSequence = createTokenSequence(feature, CLS, SEP, 1); ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence); ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); ExpressionNode expr = new IfNode( @@ -170,14 +195,14 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform * */ private ExpressionNode transformTokenAttentionMask(ReferenceNode feature, RankProfileTransformContext context) { - checkArguments(feature); + checkReferenceArguments(feature, 1); TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0)); // we need to add functions calculating the token lengths of the arguments - createTokenLengthFunctions(feature, context); + createTokenLengthFunctions(feature, context, 1); - List<ExpressionNode> tokenSequence = createTokenSequence(feature); + List<ExpressionNode> tokenSequence = createTokenSequence(feature, CLS, SEP, 1); ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr); ExpressionNode expr = new IfNode(comparison, ONE, ZERO); @@ -192,8 +217,8 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform return true; } - private void checkArguments(ReferenceNode feature) { - for (int i = 1; i < feature.getArguments().size(); ++i) { + private void checkReferenceArguments(ReferenceNode feature, int fromArg) { + for (int i = fromArg; i < feature.getArguments().size(); ++i) { ExpressionNode arg = feature.getArguments().expressions().get(i); if ( ! (arg instanceof ReferenceNode)) { throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() + ": " + @@ -229,12 +254,12 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform return commonLengthFunction.expand(ctx, List.of(arg), new ArrayDeque<String>()); } - private List<ExpressionNode> createTokenSequence(ReferenceNode feature) { + private List<ExpressionNode> createTokenSequence(ReferenceNode feature, ExpressionNode cls, ExpressionNode sep, int fromArg) { List<ExpressionNode> sequence = new ArrayList<>(); - sequence.add(CLS); - for (int i = 1; i < feature.getArguments().size(); ++i) { + sequence.add(cls); + for (int i = fromArg; i < feature.getArguments().size(); ++i) { sequence.add(feature.getArguments().expressions().get(i)); - sequence.add(SEP); + sequence.add(sep); } return sequence; } @@ -244,8 +269,8 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform * token sequences are 0-padded, so this returns the number of non-0 * tokens using a map and reduce-sum. */ - private void createTokenLengthFunctions(ReferenceNode feature, RankProfileTransformContext context) { - for (int i = 1; i < feature.getArguments().size(); ++i) { + private void createTokenLengthFunctions(ReferenceNode feature, RankProfileTransformContext context, int fromArg) { + for (int i = fromArg; i < feature.getArguments().size(); ++i) { ExpressionNode arg = feature.getArguments().expressions().get(i); if ( ! (arg instanceof ReferenceNode ref)) { throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() + ": " + diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithTransformerTokensTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithTransformerTokensTestCase.java index 5c82be0745e..6cfd7126fff 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithTransformerTokensTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithTransformerTokensTestCase.java @@ -37,6 +37,17 @@ public class RankingExpressionWithTransformerTokensTestCase { } @Test + void testTokenInputIdsCustomPadTokens() throws Exception { + String expected = "tensor(d0[1],d1[13]):[1,11,12,2,13,14,15,2,16,17,2,0,0]"; + String a = "tensor(d0[2]):[11,12]"; + String b = "tensor(d0[3]):[13,14,15]"; + String c = "tensor(d0[2]):[16,17]"; + String expression = "customTokenInputIds(1, 2, 13, a, b, c)"; + Tensor result = evaluateExpression(expression, a, b, c); + assertEquals(Tensor.from(expected), result); + } + + @Test void testTokenTypeIds() throws Exception { String expected = "tensor(d0[1],d1[10]):[0,0,0,0,1,1,1,1,0,0]"; String a = "tensor(d0[2]):[1,2]"; |