summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-12-06 12:49:45 +0100
committerLester Solbakken <lesters@oath.com>2020-12-06 12:49:45 +0100
commitec8abe27e6c48439526c6fb5b0277e61bfc5e4bb (patch)
tree32f35f097f785f0d1a76b59b85f1a8258268f7a6 /config-model
parent23f7bf9b66adc6316c9642b8c29c6aeb93e316b9 (diff)
Add convenience functions for Transformer models
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java1
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java307
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java95
3 files changed, 403 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
index a723be8b478..b19c0c3152d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
@@ -29,6 +29,7 @@ public class ExpressionTransforms {
new OnnxModelTransformer(),
new XgboostFeatureConverter(),
new LightGBMFeatureConverter(),
+ new TokenTransformer(),
new ConstantDereferencer(),
new ConstantTensorTransformer(),
new FunctionInliner(),
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
new file mode 100644
index 00000000000..58ae9799f23
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
@@ -0,0 +1,307 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.expressiontransforms;
+
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.IfNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Slice;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+/**
+ * Convenience feature transforms for inputs to Transformer type models.
+ *
+ * Replaces features of the form
+ *
+ * token_input_ids
+ * token_type_ids
+ * token_attention_mask
+ *
+ * with tensor generation expressions that generate the required input.
+ * In general, these models expect input of the form:
+ *
+ * CLS + arg1 + SEP + arg2 + SEP + 0's
+ *
+ * @author lesters
+ */
+public class TokenTransformer extends ExpressionTransformer<RankProfileTransformContext> {
+
+    // Shared constant nodes used when building the generated expressions.
+    private static final ConstantNode ZERO = new ConstantNode(new DoubleValue(0.0), "0");
+    private static final ConstantNode ONE = new ConstantNode(new DoubleValue(1.0), "1");
+    private static final ConstantNode TWO = new ConstantNode(new DoubleValue(2.0), "2");
+
+    // Values emitted at the CLS (start of sequence) and SEP (separator) positions.
+    private static final ConstantNode CLS = new ConstantNode(new DoubleValue(101), "101");
+    private static final ConstantNode SEP = new ConstantNode(new DoubleValue(102), "102");
+
+ @Override
+ public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ if (node instanceof ReferenceNode)
+ return transformFeature((ReferenceNode) node, context);
+ else if (node instanceof CompositeNode)
+ return super.transformChildren((CompositeNode) node, context);
+ else
+ return node;
+ }
+
+    /**
+     * Dispatches on the feature name: the three token features are replaced by
+     * generated tensor expressions, every other feature is returned unchanged.
+     */
+    private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
+        switch (feature.getName()) {
+            case "token_input_ids":
+                return transformTokenInputIds(feature, context);
+            case "token_type_ids":
+                return transformTokenTypeIds(feature, context);
+            case "token_attention_mask":
+                return transformTokenAttentionMask(feature, context);
+            default:
+                return feature;
+        }
+    }
+
+    /**
+     * Transforms a feature of the form
+     *
+     *     token_input_ids(128, a, b, ...)
+     *
+     * to an expression that concatenates the arguments a, b, ... using the
+     * special Transformers sequences of CLS and SEP, up to length 128, so
+     * that the sequence becomes
+     *
+     *     CLS + a + SEP + b + SEP + 0's
+     *
+     * Concretely, transforms to a tensor generation expression:
+     *
+     *     tensor(d0[1],d1[128])(
+     *         if (d1 < 1,
+     *             101,
+     *         if (d1 < 1 + length_a,
+     *             a{d0:(d1 - (1))},
+     *         if (d1 < 1 + length_a + 1,
+     *             102,
+     *         if (d1 < 1 + length_a + 1 + length_b,
+     *             b{d0:(d1 - (1 + length_a + 1))},
+     *         if (d1 < 1 + length_a + 1 + length_b + 1,
+     *             102,
+     *             0.0
+     *     ))))))
+     *
+     * Functions calculating lengths of arguments are added to the rank profile.
+     * The feature is returned untouched if the rank profile already has a
+     * function with the same name.
+     */
+    private ExpressionNode transformTokenInputIds(ReferenceNode feature, RankProfileTransformContext context) {
+        if (contextHasFunction(feature, context)) {
+            return feature;
+        }
+        checkArguments(feature, context);
+
+        // the first argument is the length of the generated sequence
+        ExpressionNode lengthArgument = feature.getArguments().expressions().get(0);
+        TensorType type = createTensorType(feature.getName(), lengthArgument);
+
+        // the generated expression refers to functions calculating the token lengths of the arguments
+        createTokenLengthFunctions(feature, context);
+
+        // token sequence: CLS + arg1 + SEP + arg2 + SEP + ....
+        List<ExpressionNode> tokenSequence = createTokenSequence(feature);
+        ExpressionNode tokenSequenceExpr = createTokenSequenceExpr(0, tokenSequence);
+        return new TensorFunctionNode(Generate.bound(type, wrapScalar(tokenSequenceExpr)));
+    }
+
+    /**
+     * Transforms a feature of the form
+     *
+     *     token_type_ids(128, a, ...)
+     *
+     * to an expression that generates a tensor that has values 0 for "a"
+     * (including CLS and SEP tokens) and 1 for the rest of the sequence.
+     * Only the first token argument is used to find the boundary; every
+     * position at or after it gets the value 1.
+     *
+     * Concretely, transforms to a tensor generation expression:
+     *
+     *     tensor(d0[1],d1[128])(if(d1 < length_a + 2, 0, 1))
+     *
+     * The feature is returned untouched if the rank profile already has a
+     * function with the same name.
+     */
+    private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) {
+        if (contextHasFunction(feature, context)) {
+            return feature;
+        }
+        checkArguments(feature, context);
+
+        // the first argument is the length of the generated sequence
+        ExpressionNode lengthArgument = feature.getArguments().expressions().get(0);
+        TensorType type = createTensorType(feature.getName(), lengthArgument);
+
+        // the generated expression refers to functions calculating the token lengths of the arguments
+        createTokenLengthFunctions(feature, context);
+
+        // boundary = length of the first token argument + 2 (for CLS and the first SEP)
+        ReferenceNode firstTokens = (ReferenceNode) feature.getArguments().expressions().get(1);
+        ExpressionNode firstTokensLength = new ReferenceNode(lengthFunctionName(firstTokens));
+        ExpressionNode boundary = new ArithmeticNode(firstTokensLength, ArithmeticOperator.PLUS, TWO);
+
+        ComparisonNode beforeBoundary = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, boundary);
+        ExpressionNode typeIdExpr = new IfNode(beforeBoundary, ZERO, ONE);
+        return new TensorFunctionNode(Generate.bound(type, wrapScalar(typeIdExpr)));
+    }
+
+    /**
+     * Transforms a feature of the form
+     *
+     *     token_attention_mask(128, a, b, ...)
+     *
+     * to an expression that generates a tensor that has values 1 for all
+     * arguments (including CLS and SEP tokens) and 0 for the rest of the
+     * sequence.
+     *
+     * Concretely, transforms to a tensor generation expression:
+     *
+     *     tensor(d0[1],d1[128])(if(d1 < 1 + length_a + 1 + length_b + 1 + ..., 1, 0))
+     *
+     * The feature is returned untouched if the rank profile already has a
+     * function with the same name.
+     */
+    private ExpressionNode transformTokenAttentionMask(ReferenceNode feature, RankProfileTransformContext context) {
+        if (contextHasFunction(feature, context)) {
+            return feature;
+        }
+        checkArguments(feature, context);
+
+        // the first argument is the length of the generated sequence
+        ExpressionNode lengthArgument = feature.getArguments().expressions().get(0);
+        TensorType type = createTensorType(feature.getName(), lengthArgument);
+
+        // the generated expression refers to functions calculating the token lengths of the arguments
+        createTokenLengthFunctions(feature, context);
+
+        // total length of CLS + arg1 + SEP + arg2 + SEP + ...: sum over the whole sequence
+        List<ExpressionNode> tokenSequence = createTokenSequence(feature);
+        int lastIndex = tokenSequence.size() - 1;
+        ExpressionNode totalLength = createLengthExpr(lastIndex, tokenSequence);
+
+        ComparisonNode withinTokens = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, totalLength);
+        ExpressionNode maskExpr = new IfNode(withinTokens, ONE, ZERO);
+        return new TensorFunctionNode(Generate.bound(type, wrapScalar(maskExpr)));
+    }
+
+    /** Returns whether the rank profile of the context already has a function named like the given feature. */
+    private boolean contextHasFunction(ReferenceNode feature, RankProfileTransformContext context) {
+        var profileFunctions = context.rankProfile().getFunctions();
+        return profileFunctions.containsKey(feature.getName());
+    }
+
+    /**
+     * Validates the arguments of a token feature: there must be at least two,
+     * and every argument after the first must be a reference.
+     *
+     * @throws IllegalArgumentException if the arguments are not of this form
+     */
+    private void checkArguments(ReferenceNode feature, RankProfileTransformContext context) {
+        int argumentCount = feature.getArguments().size();
+        if (argumentCount < 2) {
+            throw new IllegalArgumentException(feature.getName() + " requires at least 2 arguments: the length of the token " +
+                                               "sequence and where to retrieve the tokens from.");
+        }
+        // argument 0 is the sequence length; it is validated when the tensor type is created
+        for (int i = 1; i < argumentCount; ++i) {
+            ExpressionNode arg = feature.getArguments().expressions().get(i);
+            if (arg instanceof ReferenceNode) {
+                continue;
+            }
+            throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() + ": " +
+                                               "the argument must be a reference. Got " + arg.toString());
+        }
+    }
+
+    /**
+     * Creates the type of the generated tensor: float cells with dimensions
+     * d0 of size 1 and d1 of the size given by the argument.
+     *
+     * @param featureName the name of the feature, used in the error message only
+     * @param argument the first feature argument, whose string form must parse as an integer
+     * @throws IllegalArgumentException if the argument does not parse as an integer
+     */
+    private TensorType createTensorType(String featureName, ExpressionNode argument) {
+        int length;
+        try {
+            length = Integer.parseInt(argument.toString());
+        } catch (NumberFormatException ex) {
+            // keep the parse failure as the cause so the underlying error is not lost
+            throw new IllegalArgumentException("Invalid argument to " + featureName + ": the first argument must be " +
+                                               "the length to the token sequence to generate. Got " + argument.toString(), ex);
+        }
+        return new TensorType.Builder(TensorType.Value.FLOAT).indexed("d0", 1).indexed("d1", length).build();
+    }
+
+    /**
+     * Returns the name of the rank profile function that holds the token
+     * length of the given argument. The name is built from the argument's hash code.
+     */
+    private String lengthFunctionName(ReferenceNode arg) {
+        int argumentHash = arg.hashCode();
+        return "__token_length@" + argumentHash;
+    }
+
+    /**
+     * Builds the list CLS, arg1, SEP, arg2, SEP, ... from the arguments of
+     * the feature, skipping the first argument (the sequence length).
+     */
+    private List<ExpressionNode> createTokenSequence(ReferenceNode feature) {
+        List<ExpressionNode> sequence = new ArrayList<>();
+        sequence.add(CLS);
+        int position = 1;
+        while (position < feature.getArguments().size()) {
+            ExpressionNode tokens = feature.getArguments().expressions().get(position);
+            sequence.add(tokens);
+            sequence.add(SEP);   // every argument is followed by a separator
+            position++;
+        }
+        return sequence;
+    }
+
+    /**
+     * Adds a function to the rank profile for each token argument of the
+     * feature, calculating the token length of that argument. Assumes that
+     * token sequences are 0-padded, so the function returns the number of
+     * tokens greater than 0 using a map and reduce-sum. A function that is
+     * already present in the rank profile is not added again.
+     *
+     * @throws IllegalArgumentException if an argument after the first is not a reference
+     */
+    private void createTokenLengthFunctions(ReferenceNode feature, RankProfileTransformContext context) {
+        for (int i = 1; i < feature.getArguments().size(); ++i) {
+            ExpressionNode arg = feature.getArguments().expressions().get(i);
+            if ( ! (arg instanceof ReferenceNode)) {
+                throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() + ": " +
+                                                   "the argument must be a reference. Got " + arg.toString());
+            }
+            ReferenceNode ref = (ReferenceNode) arg;
+            String functionName = lengthFunctionName(ref);
+            boolean alreadyAdded = context.rankProfile().getFunctions().containsKey(functionName);
+            if (alreadyAdded) {
+                continue;
+            }
+            String lengthExpression = "sum(map(" + ref + ", f(x)(x > 0)))";
+            context.rankProfile().addFunction(functionName, List.of(), lengthExpression, false);
+        }
+    }
+
+    /**
+     * Recursively creates partial expressions of the form
+     *
+     *     if (d1 < 1 + length_a,
+     *         a{d0:(d1 - 1)},
+     *         ...
+     *
+     * for each part of the token sequence. CLS and SEP are added directly,
+     * and we create a slice expression for each argument to extract the
+     * actual tokens. The innermost false branch is 0, which pads the rest
+     * of the sequence.
+     *
+     * @param iter the index of the part of the sequence to create the expression for
+     * @param sequence the full token sequence: CLS, arg1, SEP, arg2, SEP, ...
+     */
+    private ExpressionNode createTokenSequenceExpr(int iter, List<ExpressionNode> sequence) {
+        // this part covers all positions below the accumulated length up to and including it
+        ExpressionNode lengthExpr = createLengthExpr(iter, sequence);
+        ComparisonNode withinPart = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr);
+
+        ExpressionNode part = sequence.get(iter);
+        ExpressionNode trueExpr = (part instanceof ReferenceNode) ? createTokenExtractExpr(iter, sequence) : part;
+
+        boolean isLastPart = iter >= sequence.size() - 1;
+        ExpressionNode falseExpr = isLastPart
+                ? ZERO  // 0-padding for rest of sequence
+                : createTokenSequenceExpr(iter + 1, sequence);
+
+        return new IfNode(withinPart, trueExpr, falseExpr);
+    }
+
+    /**
+     * Creates an expression for the length of the token sequence up to and
+     * including the part at index iter, where the lengths of CLS and SEP are 1,
+     * and the lengths of the arguments are calculated using auxiliary functions.
+     * The result is a sum of the form 1 + length_a + 1 + ...
+     */
+    private ExpressionNode createLengthExpr(int iter, List<ExpressionNode> sequence) {
+        List<ExpressionNode> terms = new ArrayList<>();
+        List<ArithmeticOperator> operators = new ArrayList<>();
+        int partCount = iter + 1;
+        for (int i = 0; i < partCount; ++i) {
+            ExpressionNode part = sequence.get(i);
+            if (part instanceof ConstantNode) {
+                terms.add(ONE);
+            } else if (part instanceof ReferenceNode) {
+                String lengthFunction = lengthFunctionName((ReferenceNode) part);
+                terms.add(new ReferenceNode(lengthFunction));
+            }
+            // one operator between each pair of consecutive parts
+            if (i >= 1) {
+                operators.add(ArithmeticOperator.PLUS);
+            }
+        }
+        return new ArithmeticNode(terms, operators);
+    }
+
+    /**
+     * Creates the slice expression extracting a token from the argument at
+     * index iter of the sequence. The argument is sliced on its d0 dimension
+     * at d1 minus the length of everything preceding it in the sequence,
+     * or at d1 itself when nothing precedes it.
+     */
+    private ExpressionNode createTokenExtractExpr(int iter, List<ExpressionNode> sequence) {
+        ExpressionNode index;
+        if (iter < 1) {
+            index = new ReferenceNode("d1");
+        } else {
+            // offset into the argument: position in the sequence minus the length of the preceding parts
+            ExpressionNode precedingLength = new EmbracedNode(createLengthExpr(iter - 1, sequence));
+            index = new EmbracedNode(new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.MINUS, precedingLength));
+        }
+        TensorFunction<Reference> argument = new TensorFunctionNode.ExpressionTensorFunction(sequence.get(iter));
+        List<Slice.DimensionValue<Reference>> slices = List.of(new Slice.DimensionValue<>("d0", wrapScalar(index)));
+        return new TensorFunctionNode(new Slice<>(argument, slices));
+    }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java
new file mode 100644
index 00000000000..19d4b4a6778
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java
@@ -0,0 +1,95 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.config.model.test.MockApplicationPackage;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.SearchBuilder;
+import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
+import com.yahoo.searchdefinition.expressiontransforms.TokenTransformer;
+import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import java.util.Collections;
+
+import static org.junit.Assert.assertEquals;
+
+public class RankingExpressionWithTransformerTokensTestCase {
+
+    /** Three arguments are joined as CLS a SEP b SEP c SEP, followed by one position of 0-padding. */
+    @Test
+    public void testTokenInputIds() throws Exception {
+        Tensor result = evaluateExpression("token_input_ids(12, a, b, c)",
+                                           "tensor(d0[2]):[1,2]",
+                                           "tensor(d0[3]):[3,4,5]",
+                                           "tensor(d0[2]):[6,7]");
+        Tensor expected = Tensor.from("tensor(d0[1],d1[12]):[101,1,2,102,3,4,5,102,6,7,102,0]");
+        assertEquals(expected, result);
+    }
+
+    /** With a of length 2, the first 4 positions (CLS, a, SEP) are 0 and all remaining positions are 1. */
+    @Test
+    public void testTokenTypeIds() throws Exception {
+        Tensor result = evaluateExpression("token_type_ids(10, a, b)",
+                                           "tensor(d0[2]):[1,2]",
+                                           "tensor(d0[3]):[3,4,5]");
+        Tensor expected = Tensor.from("tensor(d0[1],d1[10]):[0,0,0,0,1,1,1,1,1,1]");
+        assertEquals(expected, result);
+    }
+
+    /** CLS + a (2) + SEP + b (3) + SEP covers 8 positions with value 1; the last 2 positions are 0. */
+    @Test
+    public void testAttentionMask() throws Exception {
+        Tensor result = evaluateExpression("token_attention_mask(10, a, b)",
+                                           "tensor(d0[2]):[1,2]",
+                                           "tensor(d0[3]):[3,4,5]");
+        Tensor expected = Tensor.from("tensor(d0[1],d1[10]):[1,1,1,1,1,1,1,1,0,0]");
+        assertEquals(expected, result);
+    }
+
+    /** Evaluates the expression with only the tensors a and b bound. */
+    private Tensor evaluateExpression(String expression, String a, String b) throws Exception {
+        String noC = null;
+        String noD = null;
+        return evaluateExpression(expression, a, b, noC, noD);
+    }
+
+    /** Evaluates the expression with the tensors a, b and c bound. */
+    private Tensor evaluateExpression(String expression, String a, String b, String c) throws Exception {
+        String noD = null;
+        return evaluateExpression(expression, a, b, c, noD);
+    }
+
+    /**
+     * Transforms the expression with a TokenTransformer and evaluates the result.
+     *
+     * The non-null arguments a to d are parsed as tensors and bound to the
+     * names "a" to "d". Every function found in the rank profile after the
+     * transform is evaluated and bound as a double under its own name before
+     * the transformed expression is evaluated.
+     */
+    private Tensor evaluateExpression(String expression, String a, String b, String c, String d) throws Exception {
+        MapContext context = new MapContext();
+        String[] names = { "a", "b", "c", "d" };
+        String[] tensors = { a, b, c, d };
+        for (int i = 0; i < names.length; ++i) {
+            if (tensors[i] == null) continue;
+            context.put(names[i], new TensorValue(Tensor.from(tensors[i])));
+        }
+
+        var transformContext = createTransformContext();
+        var transformed = new TokenTransformer().transform(new RankingExpression(expression), transformContext);
+
+        // the transformed expression refers to the rank profile functions by name
+        for (var entry : transformContext.rankProfile().getFunctions().entrySet()) {
+            double functionValue = entry.getValue().function().getBody().evaluate(context).asDouble();
+            context.put(entry.getKey(), functionValue);
+        }
+        return transformed.evaluate(context).asTensor();
+    }
+
+    /**
+     * Creates a transform context for the rank profile "my_profile" of a
+     * minimal search definition built in an empty mock application package.
+     * The context is created with empty maps and no imported models.
+     */
+    private RankProfileTransformContext createTransformContext() throws ParseException {
+        MockApplicationPackage application = (MockApplicationPackage) MockApplicationPackage.createEmpty();
+        RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+        QueryProfileRegistry queryProfileRegistry = application.getQueryProfiles();
+        String sdContent = "search test {\n" +
+                           "  document test {}\n" +
+                           "  rank-profile my_profile inherits default {}\n" +
+                           "}";
+        SearchBuilder searchBuilder = new SearchBuilder(application, rankProfileRegistry, queryProfileRegistry);
+        searchBuilder.importString(sdContent);
+        searchBuilder.build();
+        Search search = searchBuilder.getSearch();
+        RankProfile rp = rankProfileRegistry.get(search, "my_profile");
+        // Collections.emptyMap() is inferred to the parameter types, unlike the raw EMPTY_MAP constant
+        return new RankProfileTransformContext(rp, queryProfileRegistry, Collections.emptyMap(), null, Collections.emptyMap(), Collections.emptyMap());
+    }
+
+}