diff options
author | Lester Solbakken <lesters@oath.com> | 2021-09-09 10:52:44 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-09-09 10:52:44 +0200 |
commit | 74f7d71869539aea6a6272124d71e72ce193a248 (patch) | |
tree | 415f29fedce23a37897c3cb8b54c1f5c47103e82 /config-model/src/main/java/com/yahoo | |
parent | 0bd1f4da69f10cfe5f4de2585a7973240e9b42f5 (diff) |
Write blobs (large rank expression files) for model evaluator tester
Diffstat (limited to 'config-model/src/main/java/com/yahoo')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java | 58 |
1 files changed, 48 insertions, 10 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java index b98cabb6f33..6d723961acb 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java @@ -9,7 +9,10 @@ import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import ai.vespa.rankingexpression.importer.vespa.VespaImporter; import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import com.google.common.collect.ImmutableList; +import com.yahoo.config.FileReference; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.application.api.FileRegistry; +import com.yahoo.config.model.application.provider.MockFileRegistry; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; @@ -21,10 +24,13 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; import com.yahoo.vespa.model.VespaModel; +import net.jpountz.lz4.LZ4FrameOutputStream; import org.xml.sax.SAXException; import java.io.File; +import java.io.FileOutputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.file.Files; import java.nio.file.Path; import java.util.HashMap; @@ -66,13 +72,14 @@ public class ModelsEvaluatorTester { File temporaryApplicationDir = null; try { temporaryApplicationDir = createTemporaryApplicationDir(modelsPath); - RankProfileList rankProfileList = createRankProfileList(temporaryApplicationDir); + MockFileRegistry fileRegistry = new MockFileBlobRegistry(temporaryApplicationDir); + RankProfileList rankProfileList = createRankProfileList(temporaryApplicationDir, fileRegistry); RankProfilesConfig rankProfilesConfig = getRankProfilesConfig(rankProfileList); RankingConstantsConfig rankingConstantsConfig = getRankingConstantConfig(rankProfileList); RankingExpressionsConfig rankingExpressionsConfig = getRankingExpressionsConfig(rankProfileList); OnnxModelsConfig onnxModelsConfig = getOnnxModelsConfig(rankProfileList); - FileAcquirer files = createFileAcquirer(rankingConstantsConfig, onnxModelsConfig, temporaryApplicationDir); + FileAcquirer files = createFileAcquirer(fileRegistry, temporaryApplicationDir); return new ModelsEvaluator(rankProfilesConfig, rankingConstantsConfig, rankingExpressionsConfig, onnxModelsConfig, files); @@ -93,12 +100,16 @@ public class ModelsEvaluatorTester { return temporaryApplicationDir; } - private static RankProfileList createRankProfileList(File appDir) throws IOException, SAXException { + private static RankProfileList createRankProfileList(File appDir, FileRegistry registry) throws IOException, SAXException { ApplicationPackage app = new MockApplicationPackage.Builder() .withEmptyHosts() .withServices(modelEvaluationServices) .withRoot(appDir).build(); - DeployState deployState = new DeployState.Builder().applicationPackage(app).modelImporters(importers).build(); + DeployState deployState = new DeployState.Builder() + .applicationPackage(app) + .fileRegistry(registry) + .modelImporters(importers).build(); + VespaModel vespaModel = new VespaModel(deployState); return vespaModel.rankProfileList(); } @@ -127,13 +138,10 @@ public class ModelsEvaluatorTester { return builder.build(); } - private static FileAcquirer createFileAcquirer(RankingConstantsConfig constantsConfig, OnnxModelsConfig onnxModelsConfig, File appDir) { + private static FileAcquirer createFileAcquirer(MockFileRegistry fileRegistry, File appDir) { Map<String, File> fileMap = new HashMap<>(); - for (RankingConstantsConfig.Constant constant : constantsConfig.constant()) { - fileMap.put(constant.fileref().value(), relativePath(appDir, constant.fileref().value())); - } - for (OnnxModelsConfig.Model model : onnxModelsConfig.model()) { - fileMap.put(model.fileref().value(), relativePath(appDir, model.fileref().value())); + for (FileRegistry.Entry entry : fileRegistry.export()) { + fileMap.put(entry.reference.value(), relativePath(appDir, entry.reference.value())); } return MockFileAcquirer.returnFiles(fileMap); } @@ -142,4 +150,34 @@ public class ModelsEvaluatorTester { return new File(root.getAbsolutePath() + File.separator + subpath); } + private static class MockFileBlobRegistry extends MockFileRegistry { + + private final File appDir; + + MockFileBlobRegistry(File appdir) { + this.appDir = appdir; + } + + @Override + public FileReference addBlob(String name, ByteBuffer blob) { + writeBlob(blob, name); + return addFile(name); + } + + private void writeBlob(ByteBuffer blob, String relativePath) { + try (FileOutputStream fos = new FileOutputStream(new File(appDir, relativePath))) { + if (relativePath.endsWith(".lz4")) { + LZ4FrameOutputStream lz4 = new LZ4FrameOutputStream(fos); + lz4.write(blob.array(), blob.arrayOffset(), blob.remaining()); + lz4.close(); + } else { + fos.write(blob.array(), blob.arrayOffset(), blob.remaining()); + } + } catch (IOException e) { + throw new IllegalArgumentException("Failed writing temp file", e); + } + } + + } + } |