diff options
Diffstat (limited to 'model-evaluation/src/main/java/ai')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java | 14 | ||||
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java | 12 |
2 files changed, 23 insertions, 3 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 88766da67fc..2d1a0d069a8 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -9,9 +9,9 @@ import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; import java.util.Map; -import java.util.stream.Collectors; /** * Evaluates machine-learned models added to Vespa applications and available as config form. @@ -28,9 +28,19 @@ public class ModelsEvaluator extends AbstractComponent { @Inject public ModelsEvaluator(RankProfilesConfig config, RankingConstantsConfig constantsConfig, + RankingExpressionsConfig expressionsConfig, OnnxModelsConfig onnxModelsConfig, FileAcquirer fileAcquirer) { - this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig, onnxModelsConfig)); + this(new RankProfilesConfigImporter(fileAcquirer) + .importFrom(config, constantsConfig, expressionsConfig, onnxModelsConfig)); + } + + public ModelsEvaluator(RankProfilesConfig config, + RankingConstantsConfig constantsConfig, + OnnxModelsConfig onnxModelsConfig, + FileAcquirer fileAcquirer) { + this(new RankProfilesConfigImporter(fileAcquirer) + .importFrom(config, constantsConfig, new RankingExpressionsConfig.Builder().build(), onnxModelsConfig)); } public ModelsEvaluator(Map<String, Model> models) { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 1bdb2810ddf..06ca7a60f4c 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -15,6 +15,7 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; import java.io.File; import java.io.IOException; @@ -51,11 +52,12 @@ public class RankProfilesConfigImporter { */ public Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig, + RankingExpressionsConfig expressionsConfig, OnnxModelsConfig onnxModelsConfig) { try { Map<String, Model> models = new HashMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { - Model model = importProfile(profile, constantsConfig, onnxModelsConfig); + Model model = importProfile(profile, constantsConfig, expressionsConfig, onnxModelsConfig); models.put(model.name(), model); } return models; @@ -65,8 +67,16 @@ public class RankProfilesConfigImporter { } } + @Deprecated + public Map<String, Model> importFrom(RankProfilesConfig config, + RankingConstantsConfig constantsConfig, + OnnxModelsConfig onnxModelsConfig) { + return importFrom(config, constantsConfig, new RankingExpressionsConfig.Builder().build(), onnxModelsConfig); + } + private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig, + RankingExpressionsConfig expressionsConfig, OnnxModelsConfig onnxModelsConfig) throws ParseException { |