aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-01-27 15:27:50 +0100
committerLester Solbakken <lesters@oath.com>2021-01-27 15:27:50 +0100
commit2a4e18f6a8510d47bc903f37cc50a0b2d304255e (patch)
tree59d68f8d803370d7a194ed87295e5d71df933636
parent3e3de199bd191b08e63bde89471cfee452ae8986 (diff)
Propagate type of transformer token helper functions
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java20
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java2
-rw-r--r--config-model/src/test/integration/onnx-model/schemas/test.sd10
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java14
4 files changed, 39 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index 9dcba72161b..bc98f0ab8c5 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -2,6 +2,7 @@
package com.yahoo.searchdefinition;
import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
+import com.yahoo.searchdefinition.expressiontransforms.TokenTransformer;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
@@ -165,6 +166,12 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return onnxFeatureType.get();
}
+ // A reference to a feature for transformer token input?
+ Optional<TensorType> transformerTokensFeatureType = transformerTokensFeatureType(reference);
+ if (transformerTokensFeatureType.isPresent()) {
+ return transformerTokensFeatureType.get();
+ }
+
// A reference to a feature which returns a tensor?
Optional<TensorType> featureTensorType = tensorFeatureType(reference);
if (featureTensorType.isPresent()) {
@@ -237,6 +244,19 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return Optional.of(featureTypes.get(reference));
}
+ private Optional<TensorType> transformerTokensFeatureType(Reference reference) {
+ if ( ! reference.name().equals("tokenTypeIds") &&
+ ! reference.name().equals("tokenInputIds") &&
+ ! reference.name().equals("tokenAttentionMask"))
+ return Optional.empty();
+
+ if ( ! (reference.arguments().size() > 1))
+ throw new IllegalArgumentException(reference.name() + " must have at least 2 arguments");
+
+ ExpressionNode size = reference.arguments().expressions().get(0);
+ return Optional.of(TokenTransformer.createTensorType(reference.name(), size));
+ }
+
/**
* There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet.
* This returns the type of those features if this is a reference to either of them, or empty otherwise.
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
index 1bb38eda9ff..192fb9baa9a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
@@ -186,7 +186,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
}
}
- private TensorType createTensorType(String featureName, ExpressionNode argument) {
+ public static TensorType createTensorType(String featureName, ExpressionNode argument) {
try {
int length = Integer.parseInt(argument.toString());
return new TensorType.Builder(TensorType.Value.FLOAT).indexed("d0", 1).indexed("d1", length).build();
diff --git a/config-model/src/test/integration/onnx-model/schemas/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd
index eac6f650fef..5b440e80bed 100644
--- a/config-model/src/test/integration/onnx-model/schemas/test.sd
+++ b/config-model/src/test/integration/onnx-model/schemas/test.sd
@@ -99,6 +99,15 @@ search test {
}
}
+ rank-profile test_dynamic_model_with_transformer_tokens {
+ function my_function() {
+ expression: tokenTypeIds(2, attribute(document_field))
+ }
+ first-phase {
+ expression: onnx(dynamic_model){d0:0,d1:1}
+ }
+ }
+
rank-profile test_unbound_model {
function my_function() {
expression: tensor(d0[1],d1[2])(d1)
@@ -108,5 +117,4 @@ search test {
}
}
-
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index f8a379b4027..73ff4ac3bcd 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -113,7 +113,7 @@ public class RankingExpressionWithOnnxModelTestCase {
RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder();
((RankProfilesConfig.Producer) db).getConfig(builder);
RankProfilesConfig config = new RankProfilesConfig(builder);
- assertEquals(8, config.rankprofile().size());
+ assertEquals(9, config.rankprofile().size());
assertEquals("test_model_config", config.rankprofile(2).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name());
@@ -150,10 +150,14 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name());
assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value());
- assertEquals("test_unbound_model", config.rankprofile(7).name());
- assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(0).name());
- assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(7).fef().property(3).name());
- assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(7).fef().property(3).value());
+ assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name());
+ assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name());
+ assertEquals("tensor<float>(d0[1],d1[2])((if (d1 < rankingExpression(__token_length@-1993461420) + 2, 0, 1)))", config.rankprofile(7).fef().property(1).value());
+
+ assertEquals("test_unbound_model", config.rankprofile(8).name());
+ assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name());
+ assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(8).fef().property(3).name());
+ assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(8).fef().property(3).value());
}