From 8d1451224710fa7a65e9a7410113c64c104a8dd0 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Wed, 23 Jun 2021 10:00:18 +0200 Subject: Add feature falg for controlling onnx dryrun verification. --- .../RankingExpressionWithOnnxModelTestCase.java | 117 +++++++++++---------- 1 file changed, 64 insertions(+), 53 deletions(-) (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java') diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index f460383f42b..ab148130a7d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -2,8 +2,10 @@ package com.yahoo.searchdefinition.processing; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.vespa.config.search.RankProfilesConfig; @@ -27,9 +29,9 @@ public class RankingExpressionWithOnnxModelTestCase { @Test public void testOnnxModelFeature() throws Exception { - VespaModel model = loadModel(applicationDir); + VespaModel model = loadModel(applicationDir, false); assertTransformedFeature(model); - assertGeneratedConfig(model); + assertGeneratedConfig(model, false); Path storedApplicationDir = applicationDir.append("copy"); try { @@ -39,73 +41,82 @@ public class RankingExpressionWithOnnxModelTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); - VespaModel storedModel = loadModel(storedApplicationDir); + VespaModel storedModel = loadModel(storedApplicationDir, true); assertTransformedFeature(storedModel); - assertGeneratedConfig(storedModel); + assertGeneratedConfig(storedModel, true); } finally { IOUtils.recursiveDeleteDir(storedApplicationDir.toFile()); } } - private VespaModel loadModel(Path path) throws Exception { + private VespaModel loadModel(Path path, boolean dryRunOnnx) throws Exception { FilesApplicationPackage applicationPackage = FilesApplicationPackage.fromFile(path.toFile()); - DeployState state = new DeployState.Builder().applicationPackage(applicationPackage).build(); + ModelContext.Properties properties = new TestProperties().setDryRunOnnxOnSetup(dryRunOnnx); + DeployState state = new DeployState.Builder().applicationPackage(applicationPackage).properties(properties).build(); return new VespaModel(state); } - private void assertGeneratedConfig(VespaModel model) { - DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0); + private void assertGeneratedConfig(VespaModel vespaModel, boolean expectDryRunOnnx) { + DocumentDatabase db = ((IndexedSearchCluster)vespaModel.getSearchClusters().get(0)).getDocumentDbs().get(0); OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder(); ((OnnxModelsConfig.Producer) db).getConfig(builder); OnnxModelsConfig config = new OnnxModelsConfig(builder); assertEquals(6, config.model().size()); + for (OnnxModelsConfig.Model model : config.model()) { + assertEquals(expectDryRunOnnx, model.dry_run_on_setup()); + } - assertEquals("my_model", config.model(0).name()); - assertEquals(3, config.model(0).input().size()); - assertEquals("second/input:0", config.model(0).input(0).name()); - assertEquals("constant(my_constant)", config.model(0).input(0).source()); - assertEquals("first_input", config.model(0).input(1).name()); - assertEquals("attribute(document_field)", config.model(0).input(1).source()); - assertEquals("third_input", config.model(0).input(2).name()); - assertEquals("rankingExpression(my_function)", config.model(0).input(2).source()); - assertEquals(3, config.model(0).output().size()); - assertEquals("path/to/output:0", config.model(0).output(0).name()); - assertEquals("out", config.model(0).output(0).as()); - assertEquals("path/to/output:1", config.model(0).output(1).name()); - assertEquals("path_to_output_1", config.model(0).output(1).as()); - assertEquals("path/to/output:2", config.model(0).output(2).name()); - assertEquals("path_to_output_2", config.model(0).output(2).as()); - - assertEquals("files_model_onnx", config.model(1).name()); - assertEquals(3, config.model(1).input().size()); - assertEquals(3, config.model(1).output().size()); - assertEquals("path/to/output:0", config.model(1).output(0).name()); - assertEquals("path_to_output_0", config.model(1).output(0).as()); - assertEquals("path/to/output:1", config.model(1).output(1).name()); - assertEquals("path_to_output_1", config.model(1).output(1).as()); - assertEquals("path/to/output:2", config.model(1).output(2).name()); - assertEquals("path_to_output_2", config.model(1).output(2).as()); - assertEquals("files_model_onnx", config.model(1).name()); - - assertEquals("another_model", config.model(2).name()); - assertEquals("third_input", config.model(2).input(2).name()); - assertEquals("rankingExpression(another_function)", config.model(2).input(2).source()); - - assertEquals("files_summary_model_onnx", config.model(3).name()); - assertEquals(3, config.model(3).input().size()); - assertEquals(3, config.model(3).output().size()); - - assertEquals("dynamic_model", config.model(5).name()); - assertEquals(1, config.model(5).input().size()); - assertEquals(1, config.model(5).output().size()); - assertEquals("rankingExpression(my_function)", config.model(5).input(0).source()); - - assertEquals("unbound_model", config.model(4).name()); - assertEquals(1, config.model(4).input().size()); - assertEquals(1, config.model(4).output().size()); - assertEquals("rankingExpression(my_function)", config.model(4).input(0).source()); - + OnnxModelsConfig.Model model = config.model(0); + assertEquals("my_model", model.name()); + assertEquals(3, model.input().size()); + assertEquals("second/input:0", model.input(0).name()); + assertEquals("constant(my_constant)", model.input(0).source()); + assertEquals("first_input", model.input(1).name()); + assertEquals("attribute(document_field)", model.input(1).source()); + assertEquals("third_input", model.input(2).name()); + assertEquals("rankingExpression(my_function)", model.input(2).source()); + assertEquals(3, model.output().size()); + assertEquals("path/to/output:0", model.output(0).name()); + assertEquals("out", model.output(0).as()); + assertEquals("path/to/output:1", model.output(1).name()); + assertEquals("path_to_output_1", model.output(1).as()); + assertEquals("path/to/output:2", model.output(2).name()); + assertEquals("path_to_output_2", model.output(2).as()); + + model = config.model(1); + assertEquals("files_model_onnx", model.name()); + assertEquals(3, model.input().size()); + assertEquals(3, model.output().size()); + assertEquals("path/to/output:0", model.output(0).name()); + assertEquals("path_to_output_0", model.output(0).as()); + assertEquals("path/to/output:1", model.output(1).name()); + assertEquals("path_to_output_1", model.output(1).as()); + assertEquals("path/to/output:2", model.output(2).name()); + assertEquals("path_to_output_2", model.output(2).as()); + assertEquals("files_model_onnx", model.name()); + + model = config.model(2); + assertEquals("another_model", model.name()); + assertEquals("third_input", model.input(2).name()); + assertEquals("rankingExpression(another_function)", model.input(2).source()); + + model = config.model(3); + assertEquals("files_summary_model_onnx", model.name()); + assertEquals(3, model.input().size()); + assertEquals(3, model.output().size()); + + model = config.model(4); + assertEquals("unbound_model", model.name()); + assertEquals(1, model.input().size()); + assertEquals(1, model.output().size()); + assertEquals("rankingExpression(my_function)", model.input(0).source()); + + model = config.model(5); + assertEquals("dynamic_model", model.name()); + assertEquals(1, model.input().size()); + assertEquals(1, model.output().size()); + assertEquals("rankingExpression(my_function)", model.input(0).source()); } private void assertTransformedFeature(VespaModel model) { -- cgit v1.2.3