aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java36
1 files changed, 31 insertions, 5 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index fb424439592..1bdb2810ddf 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -13,6 +13,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import java.io.File;
@@ -48,11 +49,13 @@ public class RankProfilesConfigImporter {
* Returns a map of the models contained in this config, indexed on name.
* The map is modifiable and owned by the caller.
*/
- public Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) {
+ public Map<String, Model> importFrom(RankProfilesConfig config,
+ RankingConstantsConfig constantsConfig,
+ OnnxModelsConfig onnxModelsConfig) {
try {
Map<String, Model> models = new HashMap<>();
for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
- Model model = importProfile(profile, constantsConfig);
+ Model model = importProfile(profile, constantsConfig, onnxModelsConfig);
models.put(model.name(), model);
}
return models;
@@ -62,9 +65,12 @@ public class RankProfilesConfigImporter {
}
}
- private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig)
+ private Model importProfile(RankProfilesConfig.Rankprofile profile,
+ RankingConstantsConfig constantsConfig,
+ OnnxModelsConfig onnxModelsConfig)
throws ParseException {
+ List<OnnxModel> onnxModels = readOnnxModelsConfig(onnxModelsConfig);
List<Constant> constants = readLargeConstants(constantsConfig);
Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>();
@@ -76,7 +82,7 @@ public class RankProfilesConfigImporter {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name());
Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name());
- if ( reference.isPresent()) {
+ if (reference.isPresent()) {
RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value());
ExpressionFunction function = new ExpressionFunction(reference.get().functionName(),
Collections.emptyList(),
@@ -122,7 +128,7 @@ public class RankProfilesConfigImporter {
constants.addAll(smallConstantsInfo.asConstants());
try {
- return new Model(profile.name(), functions, referencedFunctions, constants);
+ return new Model(profile.name(), functions, referencedFunctions, constants, onnxModels);
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
@@ -136,6 +142,26 @@ public class RankProfilesConfigImporter {
return null;
}
+ private List<OnnxModel> readOnnxModelsConfig(OnnxModelsConfig onnxModelsConfig) {
+ List<OnnxModel> onnxModels = new ArrayList<>();
+ if (onnxModelsConfig != null) {
+ for (OnnxModelsConfig.Model onnxModelConfig : onnxModelsConfig.model()) {
+ onnxModels.add(readOnnxModelConfig(onnxModelConfig));
+ }
+ }
+ return onnxModels;
+ }
+
+ private OnnxModel readOnnxModelConfig(OnnxModelsConfig.Model onnxModelConfig) {
+ try {
+ String name = onnxModelConfig.name();
+ File file = fileAcquirer.waitFor(onnxModelConfig.fileref(), 7, TimeUnit.DAYS);
+ return new OnnxModel(name, file);
+ } catch (InterruptedException e) {
+ throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name());
+ }
+ }
+
private List<Constant> readLargeConstants(RankingConstantsConfig constantsConfig) {
List<Constant> constants = new ArrayList<>();