aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-03-15 11:02:45 +0100
committerLester Solbakken <lesters@oath.com>2021-03-15 11:02:45 +0100
commit0cbf5ae978e06651962a6666f28d2724064b2828 (patch)
tree3d0611a77f946fdd08a288a92bdbd08470bde880
parent49412d97fcf8e05972d3961e51d00b83580a7ea1 (diff)
ONNX: import rankingExpression input as function
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java8
-rw-r--r--config-model/src/test/integration/onnx-model/schemas/test.sd6
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java2
3 files changed, 13 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 9b129eb66ce..8bef4c39ba1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -855,10 +855,14 @@ public class RankProfile implements Cloneable {
private Optional<TensorType> resolveOnnxInputType(String onnxInputName, OnnxModel model, MapEvaluationTypeContext context) {
String source = model.getInputMap().get(onnxInputName);
if (source != null) {
- // Source is either a simple reference (query/attribute/constant)...
+ // Source is either a simple reference (query/attribute/constant/rankingExpression)...
Optional<Reference> reference = Reference.simple(source);
if (reference.isPresent()) {
- return Optional.of(context.getType(reference.get()));
+ if (reference.get().name().equals("rankingExpression") && reference.get().simpleArgument().isPresent()) {
+ source = reference.get().simpleArgument().get(); // look up function below
+ } else {
+ return Optional.of(context.getType(reference.get()));
+ }
}
// ... or a function
ExpressionFunction func = context.getFunction(source);
diff --git a/config-model/src/test/integration/onnx-model/schemas/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd
index 4f45e0f6318..5799499e4df 100644
--- a/config-model/src/test/integration/onnx-model/schemas/test.sd
+++ b/config-model/src/test/integration/onnx-model/schemas/test.sd
@@ -117,4 +117,10 @@ search test {
}
}
+ rank-profile test_transformed_dynamic_model inherits test_dynamic_model {
+ first-phase {
+ expression: max(onnx(dynamic_model){d1:0},d0)
+ }
+ }
+
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index a3ad9f4f4ba..f460383f42b 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -113,7 +113,7 @@ public class RankingExpressionWithOnnxModelTestCase {
RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder();
((RankProfilesConfig.Producer) db).getConfig(builder);
RankProfilesConfig config = new RankProfilesConfig(builder);
- assertEquals(9, config.rankprofile().size());
+ assertEquals(10, config.rankprofile().size());
assertEquals("test_model_config", config.rankprofile(2).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name());