blob: bcba920f80b2a85fc7918c539f2cfcc77c495593 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
|
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.vespa;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.ModelImporter;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import ai.vespa.rankingexpression.importer.vespa.parser.ModelParser;
import ai.vespa.rankingexpression.importer.vespa.parser.ParseException;
import ai.vespa.rankingexpression.importer.vespa.parser.SimpleCharStream;
import com.yahoo.io.IOUtils;
import java.io.File;
import java.io.IOException;
/**
 * A {@link ModelImporter} which reads models written as Vespa native
 * ranking expression "model" files.
 */
public class VespaImporter extends ModelImporter {

    /**
     * Decides whether this importer handles the given path.
     *
     * @param modelPath the path to the candidate model
     * @return true only if the path denotes an existing regular file whose
     *         path string ends with ".model"
     */
    @Override
    public boolean canImport(String modelPath) {
        File candidate = new File(modelPath);
        return candidate.isFile() && candidate.toString().endsWith(".model");
    }

    /**
     * Reads the file at the given path and runs the model parser over its content,
     * passing it a newly created {@link ImportedModel} of type VESPA.
     *
     * @param modelName the name given to the imported model
     * @param modelPath the path of the file to read
     * @return the model instance that was handed to the parser
     * @throws IllegalArgumentException if reading the file or parsing its content fails;
     *         the underlying exception is attached as the cause
     */
    @Override
    public ImportedModel importModel(String modelName, String modelPath) {
        try {
            ImportedModel imported = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.VESPA);
            File source = new File(modelPath);
            SimpleCharStream stream = new SimpleCharStream(IOUtils.readFile(source));
            ModelParser parser = new ModelParser(stream, imported);
            parser.model();
            return imported;
        }
        catch (IOException | ParseException e) {
            throw new IllegalArgumentException("Could not import a Vespa model from '" + modelPath + "'", e);
        }
    }

}
|