diff options
author | Lester Solbakken <lesters@oath.com> | 2021-03-15 11:02:45 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-03-15 11:02:45 +0100 |
commit | 0cbf5ae978e06651962a6666f28d2724064b2828 (patch) | |
tree | 3d0611a77f946fdd08a288a92bdbd08470bde880 | |
parent | 49412d97fcf8e05972d3961e51d00b83580a7ea1 (diff) |
ONNX: import rankingExpression input as function
3 files changed, 13 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 9b129eb66ce..8bef4c39ba1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -855,10 +855,14 @@ public class RankProfile implements Cloneable { private Optional<TensorType> resolveOnnxInputType(String onnxInputName, OnnxModel model, MapEvaluationTypeContext context) { String source = model.getInputMap().get(onnxInputName); if (source != null) { - // Source is either a simple reference (query/attribute/constant)... + // Source is either a simple reference (query/attribute/constant/rankingExpression)... Optional<Reference> reference = Reference.simple(source); if (reference.isPresent()) { - return Optional.of(context.getType(reference.get())); + if (reference.get().name().equals("rankingExpression") && reference.get().simpleArgument().isPresent()) { + source = reference.get().simpleArgument().get(); // look up function below + } else { + return Optional.of(context.getType(reference.get())); + } } // ... or a function ExpressionFunction func = context.getFunction(source); diff --git a/config-model/src/test/integration/onnx-model/schemas/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd index 4f45e0f6318..5799499e4df 100644 --- a/config-model/src/test/integration/onnx-model/schemas/test.sd +++ b/config-model/src/test/integration/onnx-model/schemas/test.sd @@ -117,4 +117,10 @@ search test { } } + rank-profile test_transformed_dynamic_model inherits test_dynamic_model { + first-phase { + expression: max(onnx(dynamic_model){d1:0},d0) + } + } + } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index a3ad9f4f4ba..f460383f42b 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -113,7 +113,7 @@ public class RankingExpressionWithOnnxModelTestCase { RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder(); ((RankProfilesConfig.Producer) db).getConfig(builder); RankProfilesConfig config = new RankProfilesConfig(builder); - assertEquals(9, config.rankprofile().size()); + assertEquals(10, config.rankprofile().size()); assertEquals("test_model_config", config.rankprofile(2).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name()); |