diff options
author | Jon Marius Venstad <jonmv@users.noreply.github.com> | 2024-05-24 15:22:18 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-24 15:22:18 +0200 |
commit | 8b7dbbeba78a596515c87fae8351a38f673c1c1a (patch) | |
tree | a8f1b719254fc0f97739b5ed88ee1e32f5a2af40 | |
parent | 2165583576933da1bfd3bafedabb66f303491fd9 (diff) | |
parent | bcf2ed742f80aab5939b885fce5d731e72773259 (diff) |
Merge pull request #31298 from vespa-engine/jonmv/add-cost-for-all-distributed-onnx-models
Include cost of all distributed ONNX models, not just those with cust…
4 files changed, 36 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 6c2b9ef8e59..5bcd21a5b9b 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -815,8 +815,9 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { !container.getHostResource().realResources().gpuResources().isZero()); onnxModel.setGpuDevice(gpuDevice, hasGpu); } - cluster.onnxModelCostCalculator().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions()); } + for (OnnxModel onnxModel : models.asMap().values()) + cluster.onnxModelCostCalculator().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions()); cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models)); } diff --git a/config-model/src/test/cfg/application/ml_serving/services.xml b/config-model/src/test/cfg/application/ml_serving/services.xml index 3a5a4438c78..b1271b1297f 100644 --- a/config-model/src/test/cfg/application/ml_serving/services.xml +++ b/config-model/src/test/cfg/application/ml_serving/services.xml @@ -3,7 +3,13 @@ <services version="1.0"> <container version="1.0"> - <model-evaluation/> + <model-evaluation> + <onnx> + <models> + <model name="sqrt" /> <!-- list one of the models --> + </models> + </onnx> + </model-evaluation> <nodes> <node hostalias="node1" /> </nodes> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java index 4fd61f59ed7..6114cc7d8bf 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java @@ -8,9 +8,14 @@ import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import ai.vespa.rankingexpression.importer.vespa.VespaImporter; import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import com.yahoo.config.FileReference; +import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.model.ApplicationPackageTester; import com.yahoo.config.model.api.ApplicationClusterEndpoint; import com.yahoo.config.model.api.ContainerEndpoint; +import com.yahoo.config.model.api.OnnxModelCost; +import com.yahoo.config.model.api.OnnxModelCost.Calculator; +import com.yahoo.config.model.api.OnnxModelCost.ModelInfo; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; @@ -22,7 +27,10 @@ import org.xml.sax.SAXException; import java.io.IOException; import java.io.UncheckedIOException; +import java.net.URI; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; @@ -47,6 +55,7 @@ public class ImportedModelTester { private final String modelName; private final Path applicationDir; private final DeployState deployState; + public final Calculator calculator = new MockCalculator(); public ImportedModelTester(String modelName, Path applicationDir) { this(modelName, applicationDir, new DeployState.Builder()); @@ -58,6 +67,7 @@ public class ImportedModelTester { deployState = deployStateBuilder.applicationPackage(ApplicationPackageTester.create(applicationDir.toString()).app()) .endpoints(Set.of(new ContainerEndpoint("container", ApplicationClusterEndpoint.Scope.zone, List.of("default.example.com")))) .modelImporters(importers) + .onnxModelCost((pkg, app, cluster) -> calculator) .build(); } @@ -98,4 +108,19 @@ public class ImportedModelTester { } } + public static class MockCalculator implements OnnxModelCost.Calculator { + private final Map<String, ModelInfo> models = new HashMap<>(); + @Override public long aggregatedModelCostInBytes() { return models.size(); } + @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) { + models.put(path.toString(), new ModelInfo(path.toString(), 1, 1, onnxModelOptions)); + } + @Override public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) { + models.put(uri.toString(), new ModelInfo(uri.toString(), 1, 1, onnxModelOptions)); + } + @Override public Map<String, ModelInfo> models() { return Map.copyOf(models); } + @Override public void setRestartOnDeploy() { } + @Override public boolean restartOnDeploy() { return false; } + @Override public void store() { } + } + } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index cc33c8561fc..bced4c546c6 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -55,6 +55,7 @@ public class ModelEvaluationTest { RankProfilesConfig config = new RankProfilesConfig(b); assertEquals(0, config.rankprofile().size()); + assertEquals(0, tester.calculator.aggregatedModelCostInBytes()); } finally { IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); @@ -69,6 +70,7 @@ public class ModelEvaluationTest { try { ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir); assertHasMlModels(tester.createVespaModel(), appDir); + assertEquals(3, tester.calculator.aggregatedModelCostInBytes()); // At this point the expression is stored - copy application to another location which do not have a models dir storedAppDir.toFile().mkdirs(); |