diff options
author | Lester Solbakken <lesters@oath.com> | 2021-01-27 15:27:50 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-01-27 15:27:50 +0100 |
commit | 2a4e18f6a8510d47bc903f37cc50a0b2d304255e (patch) | |
tree | 59d68f8d803370d7a194ed87295e5d71df933636 | |
parent | 3e3de199bd191b08e63bde89471cfee452ae8986 (diff) |
Propagate type of transformer token helper functions
4 files changed, 39 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 9dcba72161b..bc98f0ab8c5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -2,6 +2,7 @@ package com.yahoo.searchdefinition; import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; +import com.yahoo.searchdefinition.expressiontransforms.TokenTransformer; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; @@ -165,6 +166,12 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return onnxFeatureType.get(); } + // A reference to a feature for transformer token input? + Optional<TensorType> transformerTokensFeatureType = transformerTokensFeatureType(reference); + if (transformerTokensFeatureType.isPresent()) { + return transformerTokensFeatureType.get(); + } + // A reference to a feature which returns a tensor? Optional<TensorType> featureTensorType = tensorFeatureType(reference); if (featureTensorType.isPresent()) { @@ -237,6 +244,19 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return Optional.of(featureTypes.get(reference)); } + private Optional<TensorType> transformerTokensFeatureType(Reference reference) { + if ( ! reference.name().equals("tokenTypeIds") && + ! reference.name().equals("tokenInputIds") && + ! reference.name().equals("tokenAttentionMask")) + return Optional.empty(); + + if ( ! (reference.arguments().size() > 1)) + throw new IllegalArgumentException(reference.name() + " must have at least 2 arguments"); + + ExpressionNode size = reference.arguments().expressions().get(0); + return Optional.of(TokenTransformer.createTensorType(reference.name(), size)); + } + /** * There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet. * This returns the type of those features if this is a reference to either of them, or empty otherwise. diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java index 1bb38eda9ff..192fb9baa9a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java @@ -186,7 +186,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform } } - private TensorType createTensorType(String featureName, ExpressionNode argument) { + public static TensorType createTensorType(String featureName, ExpressionNode argument) { try { int length = Integer.parseInt(argument.toString()); return new TensorType.Builder(TensorType.Value.FLOAT).indexed("d0", 1).indexed("d1", length).build(); diff --git a/config-model/src/test/integration/onnx-model/schemas/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd index eac6f650fef..5b440e80bed 100644 --- a/config-model/src/test/integration/onnx-model/schemas/test.sd +++ b/config-model/src/test/integration/onnx-model/schemas/test.sd @@ -99,6 +99,15 @@ search test { } } + rank-profile test_dynamic_model_with_transformer_tokens { + function my_function() { + expression: tokenTypeIds(2, attribute(document_field)) + } + first-phase { + expression: onnx(dynamic_model){d0:0,d1:1} + } + } + rank-profile test_unbound_model { function my_function() { expression: tensor(d0[1],d1[2])(d1) @@ -108,5 +117,4 @@ search test { } } - } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index f8a379b4027..73ff4ac3bcd 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -113,7 +113,7 @@ public class RankingExpressionWithOnnxModelTestCase { RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder(); ((RankProfilesConfig.Producer) db).getConfig(builder); RankProfilesConfig config = new RankProfilesConfig(builder); - assertEquals(8, config.rankprofile().size()); + assertEquals(9, config.rankprofile().size()); assertEquals("test_model_config", config.rankprofile(2).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name()); @@ -150,10 +150,14 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name()); assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value()); - assertEquals("test_unbound_model", config.rankprofile(7).name()); - assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(0).name()); - assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(7).fef().property(3).name()); - assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(7).fef().property(3).value()); + assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name()); + assertEquals("tensor<float>(d0[1],d1[2])((if (d1 < rankingExpression(__token_length@-1993461420) + 2, 0, 1)))", config.rankprofile(7).fef().property(1).value()); + + assertEquals("test_unbound_model", config.rankprofile(8).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(8).fef().property(3).name()); + assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(8).fef().property(3).value()); } |