aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java/ai
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/main/java/ai')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java14
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java12
2 files changed, 23 insertions, 3 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 88766da67fc..2d1a0d069a8 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -9,9 +9,9 @@ import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import java.util.Map;
-import java.util.stream.Collectors;
/**
* Evaluates machine-learned models added to Vespa applications and available as config form.
@@ -28,9 +28,19 @@ public class ModelsEvaluator extends AbstractComponent {
@Inject
public ModelsEvaluator(RankProfilesConfig config,
RankingConstantsConfig constantsConfig,
+ RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig,
FileAcquirer fileAcquirer) {
- this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig, onnxModelsConfig));
+ this(new RankProfilesConfigImporter(fileAcquirer)
+ .importFrom(config, constantsConfig, expressionsConfig, onnxModelsConfig));
+ }
+
+ public ModelsEvaluator(RankProfilesConfig config,
+ RankingConstantsConfig constantsConfig,
+ OnnxModelsConfig onnxModelsConfig,
+ FileAcquirer fileAcquirer) {
+ this(new RankProfilesConfigImporter(fileAcquirer)
+ .importFrom(config, constantsConfig, new RankingExpressionsConfig.Builder().build(), onnxModelsConfig));
}
public ModelsEvaluator(Map<String, Model> models) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 1bdb2810ddf..06ca7a60f4c 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -15,6 +15,7 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import java.io.File;
import java.io.IOException;
@@ -51,11 +52,12 @@ public class RankProfilesConfigImporter {
*/
public Map<String, Model> importFrom(RankProfilesConfig config,
RankingConstantsConfig constantsConfig,
+ RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig) {
try {
Map<String, Model> models = new HashMap<>();
for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
- Model model = importProfile(profile, constantsConfig, onnxModelsConfig);
+ Model model = importProfile(profile, constantsConfig, expressionsConfig, onnxModelsConfig);
models.put(model.name(), model);
}
return models;
@@ -65,8 +67,16 @@ public class RankProfilesConfigImporter {
}
}
+ @Deprecated
+ public Map<String, Model> importFrom(RankProfilesConfig config,
+ RankingConstantsConfig constantsConfig,
+ OnnxModelsConfig onnxModelsConfig) {
+ return importFrom(config, constantsConfig, new RankingExpressionsConfig.Builder().build(), onnxModelsConfig);
+ }
+
private Model importProfile(RankProfilesConfig.Rankprofile profile,
RankingConstantsConfig constantsConfig,
+ RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig)
throws ParseException {