summary | refs | log | tree | commit | diff | stats
path: root/config-model
diff options
context:
space:
mode:
author: Lester Solbakken <lesters@oath.com> 2023-09-11 13:06:58 +0200
committer: Lester Solbakken <lesters@oath.com> 2023-09-11 13:06:58 +0200
commit: 0c6889ebf0a7842c8c943ea54000f2f4055f7795 (patch)
tree: 9e8c89c9cc9a4b96f90b72300f03a3d803bda9cd /config-model
parent: c38fcd2e6f09273459ade724fd571e615ff3f6c9 (diff)
Add utility function for custom token ids
Diffstat (limited to 'config-model')
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java | 61
-rw-r--r-- config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithTransformerTokensTestCase.java | 11
2 files changed, 54 insertions, 18 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
index f454d941e31..04a31a47190 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
@@ -63,6 +63,8 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
+ if (feature.getName().equals("customTokenInputIds") && shouldTransform(feature, context))
+ return transformCustomTokenInputIds(feature, context);
if (feature.getName().equals("tokenInputIds") && shouldTransform(feature, context))
return transformTokenInputIds(feature, context);
if (feature.getName().equals("tokenTypeIds") && shouldTransform(feature, context))
@@ -102,15 +104,38 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
* Functions calculating lengths of arguments are added to the rank profile.
*/
private ExpressionNode transformTokenInputIds(ReferenceNode feature, RankProfileTransformContext context) {
- checkArguments(feature);
+ return transformTokenInputIds(feature, context, CLS, SEP, 1);
+ }
- TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0));
+ /**
+ * Transforms a feature of the form
+ *
+ * customTokenInputIds(1, 2, 128, a, b, ...)
+ *
+ * to an expression that concatenates the arguments a, b, ... using the first and
+ * second arguments (here 1 and 2) as the CLS and SEP tokens, and the third argument
+ * (here 128) as the length of the output. Otherwise, identical to tokenInputIds.
+ */
+ private ExpressionNode transformCustomTokenInputIds(ReferenceNode feature, RankProfileTransformContext context) {
+ ExpressionNode cls = feature.getArguments().expressions().get(0);
+ ExpressionNode sep = feature.getArguments().expressions().get(1);
+ return transformTokenInputIds(feature, context, cls, sep, 3);
+ }
+
+ private ExpressionNode transformTokenInputIds(ReferenceNode feature,
+ RankProfileTransformContext context,
+ ExpressionNode cls,
+ ExpressionNode sep,
+ int fromArg) {
+ checkReferenceArguments(feature, fromArg);
+
+ TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(fromArg - 1));
// we need to add functions calculating the token lengths of the arguments
- createTokenLengthFunctions(feature, context);
+ createTokenLengthFunctions(feature, context, fromArg);
// create token sequence: CLS + arg1 + SEP + arg2 + SEP + ....
- ExpressionNode tokenSequenceExpr = createTokenSequenceExpr(0, createTokenSequence(feature));
+ ExpressionNode tokenSequenceExpr = createTokenSequenceExpr(0, createTokenSequence(feature, cls, sep, fromArg));
return new TensorFunctionNode(Generate.bound(type, wrapScalar(tokenSequenceExpr)));
}
@@ -133,14 +158,14 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
* )))
*/
private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) {
- checkArguments(feature);
+ checkReferenceArguments(feature, 1);
TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0));
// we need to add functions calculating the token lengths of the arguments
- createTokenLengthFunctions(feature, context);
+ createTokenLengthFunctions(feature, context, 1);
- List<ExpressionNode> tokenSequence = createTokenSequence(feature);
+ List<ExpressionNode> tokenSequence = createTokenSequence(feature, CLS, SEP, 1);
ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence);
ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
ExpressionNode expr = new IfNode(
@@ -170,14 +195,14 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
*
*/
private ExpressionNode transformTokenAttentionMask(ReferenceNode feature, RankProfileTransformContext context) {
- checkArguments(feature);
+ checkReferenceArguments(feature, 1);
TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0));
// we need to add functions calculating the token lengths of the arguments
- createTokenLengthFunctions(feature, context);
+ createTokenLengthFunctions(feature, context, 1);
- List<ExpressionNode> tokenSequence = createTokenSequence(feature);
+ List<ExpressionNode> tokenSequence = createTokenSequence(feature, CLS, SEP, 1);
ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr);
ExpressionNode expr = new IfNode(comparison, ONE, ZERO);
@@ -192,8 +217,8 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
return true;
}
- private void checkArguments(ReferenceNode feature) {
- for (int i = 1; i < feature.getArguments().size(); ++i) {
+ private void checkReferenceArguments(ReferenceNode feature, int fromArg) {
+ for (int i = fromArg; i < feature.getArguments().size(); ++i) {
ExpressionNode arg = feature.getArguments().expressions().get(i);
if ( ! (arg instanceof ReferenceNode)) {
throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() + ": " +
@@ -229,12 +254,12 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
return commonLengthFunction.expand(ctx, List.of(arg), new ArrayDeque<String>());
}
- private List<ExpressionNode> createTokenSequence(ReferenceNode feature) {
+ private List<ExpressionNode> createTokenSequence(ReferenceNode feature, ExpressionNode cls, ExpressionNode sep, int fromArg) {
List<ExpressionNode> sequence = new ArrayList<>();
- sequence.add(CLS);
- for (int i = 1; i < feature.getArguments().size(); ++i) {
+ sequence.add(cls);
+ for (int i = fromArg; i < feature.getArguments().size(); ++i) {
sequence.add(feature.getArguments().expressions().get(i));
- sequence.add(SEP);
+ sequence.add(sep);
}
return sequence;
}
@@ -244,8 +269,8 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
* token sequences are 0-padded, so this returns the number of non-0
* tokens using a map and reduce-sum.
*/
- private void createTokenLengthFunctions(ReferenceNode feature, RankProfileTransformContext context) {
- for (int i = 1; i < feature.getArguments().size(); ++i) {
+ private void createTokenLengthFunctions(ReferenceNode feature, RankProfileTransformContext context, int fromArg) {
+ for (int i = fromArg; i < feature.getArguments().size(); ++i) {
ExpressionNode arg = feature.getArguments().expressions().get(i);
if ( ! (arg instanceof ReferenceNode ref)) {
throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() + ": " +
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithTransformerTokensTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithTransformerTokensTestCase.java
index 5c82be0745e..6cfd7126fff 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithTransformerTokensTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithTransformerTokensTestCase.java
@@ -37,6 +37,17 @@ public class RankingExpressionWithTransformerTokensTestCase {
}
@Test
+ void testTokenInputIdsCustomPadTokens() throws Exception {
+ String expected = "tensor(d0[1],d1[13]):[1,11,12,2,13,14,15,2,16,17,2,0,0]";
+ String a = "tensor(d0[2]):[11,12]";
+ String b = "tensor(d0[3]):[13,14,15]";
+ String c = "tensor(d0[2]):[16,17]";
+ String expression = "customTokenInputIds(1, 2, 13, a, b, c)";
+ Tensor result = evaluateExpression(expression, a, b, c);
+ assertEquals(Tensor.from(expected), result);
+ }
+
+ @Test
void testTokenTypeIds() throws Exception {
String expected = "tensor(d0[1],d1[10]):[0,0,0,0,1,1,1,1,0,0]";
String a = "tensor(d0[2]):[1,2]";