From ec8abe27e6c48439526c6fb5b0277e61bfc5e4bb Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Sun, 6 Dec 2020 12:49:45 +0100 Subject: Add convenience functions for Transformer models --- .../expressiontransforms/ExpressionTransforms.java | 1 + .../expressiontransforms/TokenTransformer.java | 307 +++++++++++++++++++++ ...ingExpressionWithTransformerTokensTestCase.java | 95 +++++++ 3 files changed, 403 insertions(+) create mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java create mode 100644 config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java index a723be8b478..b19c0c3152d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java @@ -29,6 +29,7 @@ public class ExpressionTransforms { new OnnxModelTransformer(), new XgboostFeatureConverter(), new LightGBMFeatureConverter(), + new TokenTransformer(), new ConstantDereferencer(), new ConstantTensorTransformer(), new FunctionInliner(), diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java new file mode 100644 index 00000000000..58ae9799f23 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java @@ -0,0 +1,307 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.IfNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Slice; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +/** + * Convenience feature transforms for inputs to Transformer type models. + * + * Replaces features of the form + * + * token_input_ids + * token_type_ids + * token_attention_mask + * + * to tensor generation expressions that generate the required input. + * In general, these models expect input of the form: + * + * CLS + arg1 + SEP + arg2 + SEP + 0's + * + * @author lesters + */ +public class TokenTransformer extends ExpressionTransformer { + + static private final ConstantNode ZERO = new ConstantNode(new DoubleValue(0.0), "0"); + static private final ConstantNode ONE = new ConstantNode(new DoubleValue(1.0), "1"); + static private final ConstantNode TWO = new ConstantNode(new DoubleValue(2.0), "2"); + static private final ConstantNode CLS = new ConstantNode(new DoubleValue(101), "101"); + static private final ConstantNode SEP = new ConstantNode(new DoubleValue(102), "102"); + + @Override + public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + if (node instanceof ReferenceNode) + return transformFeature((ReferenceNode) node, context); + else if (node instanceof CompositeNode) + return super.transformChildren((CompositeNode) node, context); + else + return node; + } + + private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { + if (feature.getName().equals("token_input_ids")) + return transformTokenInputIds(feature, context); + if (feature.getName().equals("token_type_ids")) + return transformTokenTypeIds(feature, context); + if (feature.getName().equals("token_attention_mask")) + return transformTokenAttentionMask(feature, context); + return feature; + } + + /** + * Transforms a feature of the form + * + * token_input_ids(128, a, b, ...) + * + * to an expression that concatenates the arguments a, b, ... using the + * special Transformers sequences of CLS and SEP, up to length 128, so + * that the sequence becomes + * + * CLS + a + SEP + b + SEP + 0's + * + * Concretely, transforms to a tensor generation expression: + * + * tensor(d0[1],d1[128])( + * if (d1 < 1, + * 101, + * if (d1 < 1 + length_a, + * a{d0:(d1 - (1)}, + * if (d1 < 1 + length_a + 1, + * 102, + * if (d1 < 1 + length_a + 1 + length_b, + * b{d0:(d1 - (1 + length_a + 1))}, + * if (d1 < 1 + length_a + 1 + length_b + 1, + * 102, + * 0.0 + * )))))) + * + * Functions calculating lengths of arguments are added to the rank profile. + */ + private ExpressionNode transformTokenInputIds(ReferenceNode feature, RankProfileTransformContext context) { + if (contextHasFunction(feature, context)) + return feature; + checkArguments(feature, context); + + TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0)); + + // we need to add functions calculating the token lengths of the arguments + createTokenLengthFunctions(feature, context); + + // create token sequence: CLS + arg1 + SEP + arg2 + SEP + .... + ExpressionNode tokenSequenceExpr = createTokenSequenceExpr(0, createTokenSequence(feature)); + return new TensorFunctionNode(Generate.bound(type, wrapScalar(tokenSequenceExpr))); + } + + /** + * Transforms a feature of the form + * + * token_type_ids(128, a, ...) + * + * to an expression that generates a tensor that has values 0 for "a" + * (including CLS and SEP tokens) and 1 for the rest of the sequence. + * + * Concretely, transforms to a tensor generation expression: + * + * tensor(d0[1],d1[128])(if(d1 < length_a + 2, 0, 1)) + */ + private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) { + if (contextHasFunction(feature, context)) + return feature; + checkArguments(feature, context); + + TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0)); + + // we need to add functions calculating the token lengths of the arguments + createTokenLengthFunctions(feature, context); + + ReferenceNode arg = (ReferenceNode) feature.getArguments().expressions().get(1); + ExpressionNode argLength = new ReferenceNode(lengthFunctionName(arg)); + ExpressionNode lengthExpr = new ArithmeticNode(argLength, ArithmeticOperator.PLUS, TWO); + ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr); + ExpressionNode expr = new IfNode(comparison, ZERO, ONE); + return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr))); + } + + /** + * Transforms a feature of the form + * + * token_attention_mask(128, a, b, ...) + * + * to an expression that generates a tensor that has values 1 for all + * arguments (including CLS and SEP tokens) and 0 for the rest of the + * sequence. + * + * Concretely, transforms to a tensor generation expression: + * + * tensor(d0[1],d1[128])(if(d1 < 1 + length_a + 1 + length_b + 1 + ..., 1, 0)) + * + */ + private ExpressionNode transformTokenAttentionMask(ReferenceNode feature, RankProfileTransformContext context) { + if (contextHasFunction(feature, context)) + return feature; + checkArguments(feature, context); + + TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0)); + + // we need to add functions calculating the token lengths of the arguments + createTokenLengthFunctions(feature, context); + + List tokenSequence = createTokenSequence(feature); + ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); + ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr); + ExpressionNode expr = new IfNode(comparison, ONE, ZERO); + return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr))); + } + + private boolean contextHasFunction(ReferenceNode feature, RankProfileTransformContext context) { + return context.rankProfile().getFunctions().containsKey(feature.getName()); + } + + private void checkArguments(ReferenceNode feature, RankProfileTransformContext context) { + final String featureName = feature.getName(); + if (feature.getArguments().size() < 2) { + throw new IllegalArgumentException(featureName + " requires at least 2 arguments: the length of the token " + + "sequence and where to retrieve the tokens from."); + } + for (int i = 1; i < feature.getArguments().size(); ++i) { + ExpressionNode arg = feature.getArguments().expressions().get(i); + if ( ! (arg instanceof ReferenceNode)) { + throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() + ": " + + "the argument must be a reference. Got " + arg.toString()); + } + } + } + + private TensorType createTensorType(String featureName, ExpressionNode argument) { + try { + int length = Integer.parseInt(argument.toString()); + return new TensorType.Builder(TensorType.Value.FLOAT).indexed("d0", 1).indexed("d1", length).build(); + } catch (NumberFormatException ex) { + throw new IllegalArgumentException("Invalid argument to " + featureName + ": the first argument must be " + + "the length to the token sequence to generate. Got " + argument.toString()); + } + } + + private String lengthFunctionName(ReferenceNode arg) { + return "__token_length@" + arg.hashCode(); + } + + private List createTokenSequence(ReferenceNode feature) { + List sequence = new ArrayList<>(); + sequence.add(CLS); + for (int i = 1; i < feature.getArguments().size(); ++i) { + sequence.add(feature.getArguments().expressions().get(i)); + sequence.add(SEP); + } + return sequence; + } + + /** + * Adds functions for calculating the token length input. Assumes that + * token sequences are 0-padded, so this returns the number of non-0 + * tokens using a map and reduce-sum. + */ + private void createTokenLengthFunctions(ReferenceNode feature, RankProfileTransformContext context) { + for (int i = 1; i < feature.getArguments().size(); ++i) { + ExpressionNode arg = feature.getArguments().expressions().get(i); + if ( ! (arg instanceof ReferenceNode)) { + throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() + ": " + + "the argument must be a reference. Got " + arg.toString()); + } + ReferenceNode ref = (ReferenceNode) arg; + String functionName = lengthFunctionName(ref); + if ( ! context.rankProfile().getFunctions().containsKey(functionName)) { + context.rankProfile().addFunction(functionName, List.of(), "sum(map(" + ref + ", f(x)(x > 0)))", false); + } + } + } + + /** + * Recursively creates partial expressions of the form + * + * if (d1 < 1 + length_a, + * a{d0:(d1 - 1}, + * ... + * + * for each part of the token sequence. CLS and SEP are added directly, + * and we create a slice expression for each argument to extract the + * actual tokens. + */ + private ExpressionNode createTokenSequenceExpr(int iter, List sequence) { + ExpressionNode lengthExpr = createLengthExpr(iter, sequence); + ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr); + + ExpressionNode trueExpr = sequence.get(iter); + if (sequence.get(iter) instanceof ReferenceNode) { + trueExpr = createTokenExtractExpr(iter, sequence); + } + + ExpressionNode falseExpr; + if (iter < sequence.size() - 1) { + falseExpr = createTokenSequenceExpr(iter + 1, sequence); + } else { + falseExpr = ZERO; // 0-padding for rest of sequence + } + + return new IfNode(comparison, trueExpr, falseExpr); + } + + /** + * Creates an expression for the length of the token sequence so far, where + * the lengths of CLS and SEP are 1, and the length of the arguments are + * calculated using auxiliary functions. + */ + private ExpressionNode createLengthExpr(int iter, List sequence) { + List factors = new ArrayList<>(); + List operators = new ArrayList<>(); + for (int i = 0; i < iter + 1; ++i) { + if (sequence.get(i) instanceof ConstantNode) { + factors.add(ONE); + } else if (sequence.get(i) instanceof ReferenceNode) { + factors.add(new ReferenceNode(lengthFunctionName((ReferenceNode) sequence.get(i)))); + } + if (i >= 1) { + operators.add(ArithmeticOperator.PLUS); + } + } + return new ArithmeticNode(factors, operators); + } + + /** + * Create the slice expression to extract the tokens from arguments + */ + private ExpressionNode createTokenExtractExpr(int iter, List sequence) { + ExpressionNode expr; + if (iter >= 1) { + ExpressionNode lengthExpr = new EmbracedNode(createLengthExpr(iter - 1, sequence)); + expr = new EmbracedNode(new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.MINUS, lengthExpr)); + } else { + expr = new ReferenceNode("d1"); + } + List> slices = List.of(new Slice.DimensionValue<>("d0", wrapScalar(expr)) ); + TensorFunction argument = new TensorFunctionNode.ExpressionTensorFunction(sequence.get(iter)); + return new TensorFunctionNode(new Slice<>(argument, slices)); + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java new file mode 100644 index 00000000000..19d4b4a6778 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTransformerTokensTestCase.java @@ -0,0 +1,95 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.config.model.test.MockApplicationPackage; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.SearchBuilder; +import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext; +import com.yahoo.searchdefinition.expressiontransforms.TokenTransformer; +import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.tensor.Tensor; +import org.junit.Test; + +import java.util.Collections; + +import static org.junit.Assert.assertEquals; + +public class RankingExpressionWithTransformerTokensTestCase { + + @Test + public void testTokenInputIds() throws Exception { + String expected = "tensor(d0[1],d1[12]):[101,1,2,102,3,4,5,102,6,7,102,0]"; + String a = "tensor(d0[2]):[1,2]"; + String b = "tensor(d0[3]):[3,4,5]"; + String c = "tensor(d0[2]):[6,7]"; + String expression = "token_input_ids(12, a, b, c)"; + Tensor result = evaluateExpression(expression, a, b, c); + assertEquals(Tensor.from(expected), result); + } + + @Test + public void testTokenTypeIds() throws Exception { + String expected = "tensor(d0[1],d1[10]):[0,0,0,0,1,1,1,1,1,1]"; + String a = "tensor(d0[2]):[1,2]"; + String b = "tensor(d0[3]):[3,4,5]"; + String expression = "token_type_ids(10, a, b)"; + Tensor result = evaluateExpression(expression, a, b); + assertEquals(Tensor.from(expected), result); + } + + @Test + public void testAttentionMask() throws Exception { + String expected = "tensor(d0[1],d1[10]):[1,1,1,1,1,1,1,1,0,0]"; + String a = "tensor(d0[2]):[1,2]"; + String b = "tensor(d0[3]):[3,4,5]"; + String expression = "token_attention_mask(10, a, b)"; + Tensor result = evaluateExpression(expression, a, b); + assertEquals(Tensor.from(expected), result); + } + + private Tensor evaluateExpression(String expression, String a, String b) throws Exception { + return evaluateExpression(expression, a, b, null, null); + } + + private Tensor evaluateExpression(String expression, String a, String b, String c) throws Exception { + return evaluateExpression(expression, a, b, c, null); + } + + private Tensor evaluateExpression(String expression, String a, String b, String c, String d) throws Exception { + MapContext context = new MapContext(); + if (a != null) context.put("a", new TensorValue(Tensor.from(a))); + if (b != null) context.put("b", new TensorValue(Tensor.from(b))); + if (c != null) context.put("c", new TensorValue(Tensor.from(c))); + if (d != null) context.put("d", new TensorValue(Tensor.from(d))); + var transformContext = createTransformContext(); + var rankingExpression = new RankingExpression(expression); + var transformed = new TokenTransformer().transform(rankingExpression, transformContext); + for (var entry : transformContext.rankProfile().getFunctions().entrySet()) { + context.put(entry.getKey(), entry.getValue().function().getBody().evaluate(context).asDouble()); + } + return transformed.evaluate(context).asTensor(); + } + + private RankProfileTransformContext createTransformContext() throws ParseException { + MockApplicationPackage application = (MockApplicationPackage) MockApplicationPackage.createEmpty(); + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + QueryProfileRegistry queryProfileRegistry = application.getQueryProfiles(); + String sdContent = "search test {\n" + + " document test {}\n" + + " rank-profile my_profile inherits default {}\n" + + "}"; + SearchBuilder searchBuilder = new SearchBuilder(application, rankProfileRegistry, queryProfileRegistry); + searchBuilder.importString(sdContent); + searchBuilder.build(); + Search search = searchBuilder.getSearch(); + RankProfile rp = rankProfileRegistry.get(search, "my_profile"); + return new RankProfileTransformContext(rp, queryProfileRegistry, Collections.EMPTY_MAP, null, Collections.EMPTY_MAP, Collections.EMPTY_MAP); + } + +} -- cgit v1.2.3