diff options
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index cff9abb08ed..7f590d4b230 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.google.common.collect.ImmutableList; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.path.Path; @@ -10,7 +11,11 @@ import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import java.util.HashMap; import java.util.List; @@ -26,6 +31,9 @@ import static org.junit.Assert.assertEquals; */ class RankProfileSearchFixture { + private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new XGBoostImporter()); private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); private final QueryProfileRegistry queryProfileRegistry; private Search search; @@ -83,7 +91,8 @@ class RankProfileSearchFixture { public RankProfile compileRankProfile(String rankProfile, Path applicationDir) { RankProfile compiled = rankProfileRegistry.get(search, rankProfile) - .compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile())); + .compile(queryProfileRegistry, + new ImportedModels(applicationDir.toFile(), importers)); compiledRankProfiles.put(rankProfile, compiled); return compiled; } |