From 2a4e18f6a8510d47bc903f37cc50a0b2d304255e Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Wed, 27 Jan 2021 15:27:50 +0100 Subject: Propagate type of transformer token helper functions --- .../searchdefinition/MapEvaluationTypeContext.java | 20 ++++++++++++++++++++ .../expressiontransforms/TokenTransformer.java | 2 +- .../src/test/integration/onnx-model/schemas/test.sd | 10 +++++++++- .../RankingExpressionWithOnnxModelTestCase.java | 14 +++++++++----- 4 files changed, 39 insertions(+), 7 deletions(-) (limited to 'config-model') diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 9dcba72161b..bc98f0ab8c5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -2,6 +2,7 @@ package com.yahoo.searchdefinition; import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; +import com.yahoo.searchdefinition.expressiontransforms.TokenTransformer; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; @@ -165,6 +166,12 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return onnxFeatureType.get(); } + // A reference to a feature for transformer token input? + Optional transformerTokensFeatureType = transformerTokensFeatureType(reference); + if (transformerTokensFeatureType.isPresent()) { + return transformerTokensFeatureType.get(); + } + // A reference to a feature which returns a tensor? Optional featureTensorType = tensorFeatureType(reference); if (featureTensorType.isPresent()) { @@ -237,6 +244,19 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return Optional.of(featureTypes.get(reference)); } + private Optional transformerTokensFeatureType(Reference reference) { + if ( ! reference.name().equals("tokenTypeIds") && + ! reference.name().equals("tokenInputIds") && + ! reference.name().equals("tokenAttentionMask")) + return Optional.empty(); + + if ( ! (reference.arguments().size() > 1)) + throw new IllegalArgumentException(reference.name() + " must have at least 2 arguments"); + + ExpressionNode size = reference.arguments().expressions().get(0); + return Optional.of(TokenTransformer.createTensorType(reference.name(), size)); + } + /** * There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet. * This returns the type of those features if this is a reference to either of them, or empty otherwise. diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java index 1bb38eda9ff..192fb9baa9a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java @@ -186,7 +186,7 @@ public class TokenTransformer extends ExpressionTransformer(d0[1],d1[2])((if (d1 < rankingExpression(__token_length@-1993461420) + 2, 0, 1)))", config.rankprofile(7).fef().property(1).value()); + + assertEquals("test_unbound_model", config.rankprofile(8).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(8).fef().property(3).name()); + assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(8).fef().property(3).value()); } -- cgit v1.2.3