aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Marius Venstad <jonmv@users.noreply.github.com>2024-05-24 15:22:18 +0200
committerGitHub <noreply@github.com>2024-05-24 15:22:18 +0200
commit8b7dbbeba78a596515c87fae8351a38f673c1c1a (patch)
treea8f1b719254fc0f97739b5ed88ee1e32f5a2af40
parent2165583576933da1bfd3bafedabb66f303491fd9 (diff)
parentbcf2ed742f80aab5939b885fce5d731e72773259 (diff)
Merge pull request #31298 from vespa-engine/jonmv/add-cost-for-all-distributed-onnx-models
Include cost of all distributed ONNX models, not just those with cust…
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java3
-rw-r--r--config-model/src/test/cfg/application/ml_serving/services.xml8
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java25
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java2
4 files changed, 36 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index 6c2b9ef8e59..5bcd21a5b9b 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -815,8 +815,9 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
!container.getHostResource().realResources().gpuResources().isZero());
onnxModel.setGpuDevice(gpuDevice, hasGpu);
}
- cluster.onnxModelCostCalculator().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions());
}
+ for (OnnxModel onnxModel : models.asMap().values())
+ cluster.onnxModelCostCalculator().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions());
cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models));
}
diff --git a/config-model/src/test/cfg/application/ml_serving/services.xml b/config-model/src/test/cfg/application/ml_serving/services.xml
index 3a5a4438c78..b1271b1297f 100644
--- a/config-model/src/test/cfg/application/ml_serving/services.xml
+++ b/config-model/src/test/cfg/application/ml_serving/services.xml
@@ -3,7 +3,13 @@
<services version="1.0">
<container version="1.0">
- <model-evaluation/>
+ <model-evaluation>
+ <onnx>
+ <models>
+ <model name="sqrt" /> <!-- list one of the models -->
+ </models>
+ </onnx>
+ </model-evaluation>
<nodes>
<node hostalias="node1" />
</nodes>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
index 4fd61f59ed7..6114cc7d8bf 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
@@ -8,9 +8,14 @@ import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.vespa.VespaImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import com.yahoo.config.FileReference;
+import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.model.ApplicationPackageTester;
import com.yahoo.config.model.api.ApplicationClusterEndpoint;
import com.yahoo.config.model.api.ContainerEndpoint;
+import com.yahoo.config.model.api.OnnxModelCost;
+import com.yahoo.config.model.api.OnnxModelCost.Calculator;
+import com.yahoo.config.model.api.OnnxModelCost.ModelInfo;
+import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
@@ -22,7 +27,10 @@ import org.xml.sax.SAXException;
import java.io.IOException;
import java.io.UncheckedIOException;
+import java.net.URI;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
import java.util.Set;
@@ -47,6 +55,7 @@ public class ImportedModelTester {
private final String modelName;
private final Path applicationDir;
private final DeployState deployState;
+ public final Calculator calculator = new MockCalculator();
public ImportedModelTester(String modelName, Path applicationDir) {
this(modelName, applicationDir, new DeployState.Builder());
@@ -58,6 +67,7 @@ public class ImportedModelTester {
deployState = deployStateBuilder.applicationPackage(ApplicationPackageTester.create(applicationDir.toString()).app())
.endpoints(Set.of(new ContainerEndpoint("container", ApplicationClusterEndpoint.Scope.zone, List.of("default.example.com"))))
.modelImporters(importers)
+ .onnxModelCost((pkg, app, cluster) -> calculator)
.build();
}
@@ -98,4 +108,19 @@ public class ImportedModelTester {
}
}
+ public static class MockCalculator implements OnnxModelCost.Calculator {
+ private final Map<String, ModelInfo> models = new HashMap<>();
+ @Override public long aggregatedModelCostInBytes() { return models.size(); }
+ @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {
+ models.put(path.toString(), new ModelInfo(path.toString(), 1, 1, onnxModelOptions));
+ }
+ @Override public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {
+ models.put(uri.toString(), new ModelInfo(uri.toString(), 1, 1, onnxModelOptions));
+ }
+ @Override public Map<String, ModelInfo> models() { return Map.copyOf(models); }
+ @Override public void setRestartOnDeploy() { }
+ @Override public boolean restartOnDeploy() { return false; }
+ @Override public void store() { }
+ }
+
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index cc33c8561fc..bced4c546c6 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -55,6 +55,7 @@ public class ModelEvaluationTest {
RankProfilesConfig config = new RankProfilesConfig(b);
assertEquals(0, config.rankprofile().size());
+ assertEquals(0, tester.calculator.aggregatedModelCostInBytes());
}
finally {
IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
@@ -69,6 +70,7 @@ public class ModelEvaluationTest {
try {
ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir);
assertHasMlModels(tester.createVespaModel(), appDir);
+ assertEquals(3, tester.calculator.aggregatedModelCostInBytes());
// At this point the expression is stored - copy application to another location which do not have a models dir
storedAppDir.toFile().mkdirs();