diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-05-19 12:03:06 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-05-19 12:03:06 +0200 |
commit | 5c24dc5c9642a8d9ed70aee4c950fd0678a1ebec (patch) | |
tree | bd9b74bf00c832456f0b83c1b2cd7010be387d68 /config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java | |
parent | f17c4fe7de4c55f5c4ee61897eab8c2f588d8405 (diff) |
Rename the 'searchdefinition' package to 'schema'
Diffstat (limited to 'config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java | 184 |
1 files changed, 184 insertions, 0 deletions
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java new file mode 100644 index 00000000000..713e11fd608 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java @@ -0,0 +1,184 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema.processing; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.search.DocumentDatabase; +import com.yahoo.vespa.model.search.IndexedSearchCluster; +import org.junit.After; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class RankingExpressionWithOnnxModelTestCase { + + private final Path applicationDir = Path.fromString("src/test/integration/onnx-model/"); + + @After + public void removeGeneratedModelFiles() { + IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + + @Test + public void testOnnxModelFeature() throws Exception { + VespaModel model = loadModel(applicationDir); + assertTransformedFeature(model); + assertGeneratedConfig(model); + + Path storedApplicationDir = applicationDir.append("copy"); + try { + storedApplicationDir.toFile().mkdirs(); + IOUtils.copy(applicationDir.append("services.xml").toString(), storedApplicationDir.append("services.xml").toString()); + IOUtils.copyDirectory(applicationDir.append("schemas").toFile(), storedApplicationDir.append("schemas").toFile()); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + + VespaModel storedModel = loadModel(storedApplicationDir); + assertTransformedFeature(storedModel); + assertGeneratedConfig(storedModel); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDir.toFile()); + } + } + + private VespaModel loadModel(Path path) throws Exception { + FilesApplicationPackage applicationPackage = FilesApplicationPackage.fromFile(path.toFile()); + DeployState state = new DeployState.Builder().applicationPackage(applicationPackage).build(); + return new VespaModel(state); + } + + private void assertGeneratedConfig(VespaModel vespaModel) { + DocumentDatabase db = ((IndexedSearchCluster)vespaModel.getSearchClusters().get(0)).getDocumentDbs().get(0); + + RankingConstantsConfig.Builder rankingConstantsConfigBuilder = new RankingConstantsConfig.Builder(); + db.getConfig(rankingConstantsConfigBuilder); + var rankingConstantsConfig = rankingConstantsConfigBuilder.build(); + assertEquals(1, rankingConstantsConfig.constant().size()); + assertEquals("my_constant", rankingConstantsConfig.constant(0).name()); + assertEquals("tensor(d0[2])", rankingConstantsConfig.constant(0).type()); + assertEquals("files/constant.json", rankingConstantsConfig.constant(0).fileref().value()); + + OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder(); + ((OnnxModelsConfig.Producer) db).getConfig(builder); + OnnxModelsConfig config = new OnnxModelsConfig(builder); + assertEquals(6, config.model().size()); + for (OnnxModelsConfig.Model model : config.model()) { + assertTrue(model.dry_run_on_setup()); + } + + OnnxModelsConfig.Model model = config.model(0); + assertEquals("my_model", model.name()); + assertEquals(3, model.input().size()); + assertEquals("second/input:0", model.input(0).name()); + assertEquals("constant(my_constant)", model.input(0).source()); + assertEquals("first_input", model.input(1).name()); + assertEquals("attribute(document_field)", model.input(1).source()); + assertEquals("third_input", model.input(2).name()); + assertEquals("rankingExpression(my_function)", model.input(2).source()); + assertEquals(3, model.output().size()); + assertEquals("path/to/output:0", model.output(0).name()); + assertEquals("out", model.output(0).as()); + assertEquals("path/to/output:1", model.output(1).name()); + assertEquals("path_to_output_1", model.output(1).as()); + assertEquals("path/to/output:2", model.output(2).name()); + assertEquals("path_to_output_2", model.output(2).as()); + + model = config.model(1); + assertEquals("dynamic_model", model.name()); + assertEquals(1, model.input().size()); + assertEquals(1, model.output().size()); + assertEquals("rankingExpression(my_function)", model.input(0).source()); + + model = config.model(2); + assertEquals("unbound_model", model.name()); + assertEquals(1, model.input().size()); + assertEquals(1, model.output().size()); + assertEquals("rankingExpression(my_function)", model.input(0).source()); + + model = config.model(3); + assertEquals("files_model_onnx", model.name()); + assertEquals(3, model.input().size()); + assertEquals(3, model.output().size()); + assertEquals("path/to/output:0", model.output(0).name()); + assertEquals("path_to_output_0", model.output(0).as()); + assertEquals("path/to/output:1", model.output(1).name()); + assertEquals("path_to_output_1", model.output(1).as()); + assertEquals("path/to/output:2", model.output(2).name()); + assertEquals("path_to_output_2", model.output(2).as()); + assertEquals("files_model_onnx", model.name()); + + model = config.model(4); + assertEquals("another_model", model.name()); + assertEquals("third_input", model.input(2).name()); + assertEquals("rankingExpression(another_function)", model.input(2).source()); + + model = config.model(5); + assertEquals("files_summary_model_onnx", model.name()); + assertEquals(3, model.input().size()); + assertEquals(3, model.output().size()); + } + + private void assertTransformedFeature(VespaModel model) { + DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0); + RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder(); + ((RankProfilesConfig.Producer) db).getConfig(builder); + RankProfilesConfig config = new RankProfilesConfig(builder); + assertEquals(10, config.rankprofile().size()); + + assertEquals("test_model_config", config.rankprofile(2).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name()); + assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name()); + assertEquals("rankingExpression(firstphase)", config.rankprofile(2).fef().property(2).value()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(2).fef().property(3).name()); + assertEquals("onnxModel(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value()); + + assertEquals("test_generated_model_config", config.rankprofile(3).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name()); + assertEquals("rankingExpression(first_input).rankingScript", config.rankprofile(3).fef().property(2).name()); + assertEquals("rankingExpression(second_input).rankingScript", config.rankprofile(3).fef().property(4).name()); + assertEquals("rankingExpression(third_input).rankingScript", config.rankprofile(3).fef().property(6).name()); + assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name()); + assertEquals("rankingExpression(firstphase)", config.rankprofile(3).fef().property(8).value()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(3).fef().property(9).name()); + assertEquals("onnxModel(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value()); + + assertEquals("test_summary_features", config.rankprofile(4).name()); + assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name()); + assertEquals("1", config.rankprofile(4).fef().property(3).value()); + assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name()); + assertEquals("onnxModel(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(4).value()); + assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name()); + assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(5).value()); + + assertEquals("test_dynamic_model", config.rankprofile(5).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(5).fef().property(0).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(5).fef().property(3).name()); + assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value()); + + assertEquals("test_dynamic_model_2", config.rankprofile(6).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name()); + assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value()); + + assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name()); + assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 0.0, if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value()); + + assertEquals("test_unbound_model", config.rankprofile(8).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(8).fef().property(3).name()); + assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(8).fef().property(3).value()); + + + } + +} |