diff options
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java | 36 |
1 files changed, 31 insertions, 5 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index fb424439592..1bdb2810ddf 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -13,6 +13,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import java.io.File; @@ -48,11 +49,13 @@ public class RankProfilesConfigImporter { * Returns a map of the models contained in this config, indexed on name. * The map is modifiable and owned by the caller. */ - public Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) { + public Map<String, Model> importFrom(RankProfilesConfig config, + RankingConstantsConfig constantsConfig, + OnnxModelsConfig onnxModelsConfig) { try { Map<String, Model> models = new HashMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { - Model model = importProfile(profile, constantsConfig); + Model model = importProfile(profile, constantsConfig, onnxModelsConfig); models.put(model.name(), model); } return models; @@ -62,9 +65,12 @@ public class RankProfilesConfigImporter { } } - private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) + private Model importProfile(RankProfilesConfig.Rankprofile profile, + RankingConstantsConfig constantsConfig, + OnnxModelsConfig onnxModelsConfig) throws ParseException { + List<OnnxModel> onnxModels = readOnnxModelsConfig(onnxModelsConfig); List<Constant> constants = readLargeConstants(constantsConfig); Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>(); @@ -76,7 +82,7 @@ public class RankProfilesConfigImporter { Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name()); Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name()); Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name()); - if ( reference.isPresent()) { + if (reference.isPresent()) { RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value()); ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), Collections.emptyList(), @@ -122,7 +128,7 @@ public class RankProfilesConfigImporter { constants.addAll(smallConstantsInfo.asConstants()); try { - return new Model(profile.name(), functions, referencedFunctions, constants); + return new Model(profile.name(), functions, referencedFunctions, constants, onnxModels); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); @@ -136,6 +142,26 @@ public class RankProfilesConfigImporter { return null; } + private List<OnnxModel> readOnnxModelsConfig(OnnxModelsConfig onnxModelsConfig) { + List<OnnxModel> onnxModels = new ArrayList<>(); + if (onnxModelsConfig != null) { + for (OnnxModelsConfig.Model onnxModelConfig : onnxModelsConfig.model()) { + onnxModels.add(readOnnxModelConfig(onnxModelConfig)); + } + } + return onnxModels; + } + + private OnnxModel readOnnxModelConfig(OnnxModelsConfig.Model onnxModelConfig) { + try { + String name = onnxModelConfig.name(); + File file = fileAcquirer.waitFor(onnxModelConfig.fileref(), 7, TimeUnit.DAYS); + return new OnnxModel(name, file); + } catch (InterruptedException e) { + throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name()); + } + } + private List<Constant> readLargeConstants(RankingConstantsConfig constantsConfig) { List<Constant> constants = new ArrayList<>(); |