aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-09-09 10:52:44 +0200
committerLester Solbakken <lesters@oath.com>2021-09-09 10:52:44 +0200
commit74f7d71869539aea6a6272124d71e72ce193a248 (patch)
tree415f29fedce23a37897c3cb8b54c1f5c47103e82 /config-model/src/main/java/com/yahoo
parent0bd1f4da69f10cfe5f4de2585a7973240e9b42f5 (diff)
Write blobs (large rank expression files) for model evaluator tester
Diffstat (limited to 'config-model/src/main/java/com/yahoo')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java58
1 files changed, 48 insertions, 10 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java
index b98cabb6f33..6d723961acb 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java
@@ -9,7 +9,10 @@ import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.vespa.VespaImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import com.google.common.collect.ImmutableList;
+import com.yahoo.config.FileReference;
import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.application.api.FileRegistry;
+import com.yahoo.config.model.application.provider.MockFileRegistry;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
@@ -21,10 +24,13 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import com.yahoo.vespa.model.VespaModel;
+import net.jpountz.lz4.LZ4FrameOutputStream;
import org.xml.sax.SAXException;
import java.io.File;
+import java.io.FileOutputStream;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
@@ -66,13 +72,14 @@ public class ModelsEvaluatorTester {
File temporaryApplicationDir = null;
try {
temporaryApplicationDir = createTemporaryApplicationDir(modelsPath);
- RankProfileList rankProfileList = createRankProfileList(temporaryApplicationDir);
+ MockFileRegistry fileRegistry = new MockFileBlobRegistry(temporaryApplicationDir);
+ RankProfileList rankProfileList = createRankProfileList(temporaryApplicationDir, fileRegistry);
RankProfilesConfig rankProfilesConfig = getRankProfilesConfig(rankProfileList);
RankingConstantsConfig rankingConstantsConfig = getRankingConstantConfig(rankProfileList);
RankingExpressionsConfig rankingExpressionsConfig = getRankingExpressionsConfig(rankProfileList);
OnnxModelsConfig onnxModelsConfig = getOnnxModelsConfig(rankProfileList);
- FileAcquirer files = createFileAcquirer(rankingConstantsConfig, onnxModelsConfig, temporaryApplicationDir);
+ FileAcquirer files = createFileAcquirer(fileRegistry, temporaryApplicationDir);
return new ModelsEvaluator(rankProfilesConfig, rankingConstantsConfig, rankingExpressionsConfig, onnxModelsConfig, files);
@@ -93,12 +100,16 @@ public class ModelsEvaluatorTester {
return temporaryApplicationDir;
}
- private static RankProfileList createRankProfileList(File appDir) throws IOException, SAXException {
+ private static RankProfileList createRankProfileList(File appDir, FileRegistry registry) throws IOException, SAXException {
ApplicationPackage app = new MockApplicationPackage.Builder()
.withEmptyHosts()
.withServices(modelEvaluationServices)
.withRoot(appDir).build();
- DeployState deployState = new DeployState.Builder().applicationPackage(app).modelImporters(importers).build();
+ DeployState deployState = new DeployState.Builder()
+ .applicationPackage(app)
+ .fileRegistry(registry)
+ .modelImporters(importers).build();
+
VespaModel vespaModel = new VespaModel(deployState);
return vespaModel.rankProfileList();
}
@@ -127,13 +138,10 @@ public class ModelsEvaluatorTester {
return builder.build();
}
- private static FileAcquirer createFileAcquirer(RankingConstantsConfig constantsConfig, OnnxModelsConfig onnxModelsConfig, File appDir) {
+ private static FileAcquirer createFileAcquirer(MockFileRegistry fileRegistry, File appDir) {
Map<String, File> fileMap = new HashMap<>();
- for (RankingConstantsConfig.Constant constant : constantsConfig.constant()) {
- fileMap.put(constant.fileref().value(), relativePath(appDir, constant.fileref().value()));
- }
- for (OnnxModelsConfig.Model model : onnxModelsConfig.model()) {
- fileMap.put(model.fileref().value(), relativePath(appDir, model.fileref().value()));
+ for (FileRegistry.Entry entry : fileRegistry.export()) {
+ fileMap.put(entry.reference.value(), relativePath(appDir, entry.reference.value()));
}
return MockFileAcquirer.returnFiles(fileMap);
}
@@ -142,4 +150,34 @@ public class ModelsEvaluatorTester {
return new File(root.getAbsolutePath() + File.separator + subpath);
}
+ private static class MockFileBlobRegistry extends MockFileRegistry {
+
+ private final File appDir;
+
+ MockFileBlobRegistry(File appdir) {
+ this.appDir = appdir;
+ }
+
+ @Override
+ public FileReference addBlob(String name, ByteBuffer blob) {
+ writeBlob(blob, name);
+ return addFile(name);
+ }
+
+ private void writeBlob(ByteBuffer blob, String relativePath) {
+ try (FileOutputStream fos = new FileOutputStream(new File(appDir, relativePath))) {
+ if (relativePath.endsWith(".lz4")) {
+ LZ4FrameOutputStream lz4 = new LZ4FrameOutputStream(fos);
+ lz4.write(blob.array(), blob.arrayOffset(), blob.remaining());
+ lz4.close();
+ } else {
+ fos.write(blob.array(), blob.arrayOffset(), blob.remaining());
+ }
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Failed writing temp file", e);
+ }
+ }
+
+ }
+
}