diff options
author | jonmv <venstad@gmail.com> | 2024-05-24 14:12:01 +0200 |
---|---|---|
committer | jonmv <venstad@gmail.com> | 2024-05-24 14:12:01 +0200 |
commit | bcf2ed742f80aab5939b885fce5d731e72773259 (patch) | |
tree | 728e6c838a24999839c2f115dfd154f68a3eb180 /config-model/src/test/java | |
parent | e2d154dfd9d5388f0ee79247bb43a2d00197729f (diff) |
Add unit test for onnx cost registered
Diffstat (limited to 'config-model/src/test/java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java | 25 | ||||
-rw-r--r-- | config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java | 2 |
2 files changed, 27 insertions, 0 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java index 4fd61f59ed7..6114cc7d8bf 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java @@ -8,9 +8,14 @@ import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import ai.vespa.rankingexpression.importer.vespa.VespaImporter; import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import com.yahoo.config.FileReference; +import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.model.ApplicationPackageTester; import com.yahoo.config.model.api.ApplicationClusterEndpoint; import com.yahoo.config.model.api.ContainerEndpoint; +import com.yahoo.config.model.api.OnnxModelCost; +import com.yahoo.config.model.api.OnnxModelCost.Calculator; +import com.yahoo.config.model.api.OnnxModelCost.ModelInfo; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; @@ -22,7 +27,10 @@ import org.xml.sax.SAXException; import java.io.IOException; import java.io.UncheckedIOException; +import java.net.URI; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; @@ -47,6 +55,7 @@ public class ImportedModelTester { private final String modelName; private final Path applicationDir; private final DeployState deployState; + public final Calculator calculator = new MockCalculator(); public ImportedModelTester(String modelName, Path applicationDir) { this(modelName, applicationDir, new DeployState.Builder()); @@ -58,6 +67,7 @@ public class ImportedModelTester { deployState = deployStateBuilder.applicationPackage(ApplicationPackageTester.create(applicationDir.toString()).app()) .endpoints(Set.of(new ContainerEndpoint("container", ApplicationClusterEndpoint.Scope.zone, List.of("default.example.com")))) .modelImporters(importers) + .onnxModelCost((pkg, app, cluster) -> calculator) .build(); } @@ -98,4 +108,19 @@ public class ImportedModelTester { } } + public static class MockCalculator implements OnnxModelCost.Calculator { + private final Map<String, ModelInfo> models = new HashMap<>(); + @Override public long aggregatedModelCostInBytes() { return models.size(); } + @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) { + models.put(path.toString(), new ModelInfo(path.toString(), 1, 1, onnxModelOptions)); + } + @Override public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) { + models.put(uri.toString(), new ModelInfo(uri.toString(), 1, 1, onnxModelOptions)); + } + @Override public Map<String, ModelInfo> models() { return Map.copyOf(models); } + @Override public void setRestartOnDeploy() { } + @Override public boolean restartOnDeploy() { return false; } + @Override public void store() { } + } + } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index cc33c8561fc..bced4c546c6 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -55,6 +55,7 @@ public class ModelEvaluationTest { RankProfilesConfig config = new RankProfilesConfig(b); assertEquals(0, config.rankprofile().size()); + assertEquals(0, tester.calculator.aggregatedModelCostInBytes()); } finally { IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); @@ -69,6 +70,7 @@ public class ModelEvaluationTest { try { ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir); assertHasMlModels(tester.createVespaModel(), appDir); + assertEquals(3, tester.calculator.aggregatedModelCostInBytes()); // At this point the expression is stored - copy application to another location which do not have a models dir storedAppDir.toFile().mkdirs(); |