diff options
Diffstat (limited to 'config-model/src/test/java/com/yahoo/schema/processing')
-rw-r--r-- | config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java | 51 |
1 files changed, 17 insertions, 34 deletions
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java index 2f53dba7bb4..83d19b010bb 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java @@ -31,8 +31,6 @@ public class RankingExpressionWithOnnxTestCase { private final static String name = "mnist_softmax"; private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(mnist_softmax_layer_Variable_1) * 1.0, f(a,b)(a + b))"; - private final static String vespaExpressionConstants = "constant mnist_softmax_layer_Variable { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }\n" + - "constant mnist_softmax_layer_Variable_1 { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }\n"; @AfterEach public void removeGeneratedModelFiles() { @@ -43,7 +41,7 @@ public class RankingExpressionWithOnnxTestCase { void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx_vespa('mnist_softmax.onnx')", - vespaExpressionConstants + "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", + "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -60,7 +58,7 @@ public class RankingExpressionWithOnnxTestCase { queryProfileType); RankProfileSearchFixture search = fixtureWith("query(mytensor)", "onnx_vespa('mnist_softmax.onnx')", - vespaExpressionConstants, + null, null, "Placeholder", application); @@ -72,7 +70,7 @@ public class RankingExpressionWithOnnxTestCase { StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", "onnx_vespa('mnist_softmax.onnx')", - vespaExpressionConstants, + null, "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", application); @@ -90,7 +88,7 @@ public class RankingExpressionWithOnnxTestCase { StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType); RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)", "onnx_vespa('mnist_softmax.onnx')", - vespaExpressionConstants + "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", + "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", application); @@ -101,24 +99,21 @@ public class RankingExpressionWithOnnxTestCase { @Test void testNestedOnnxReference() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "5 + sum(onnx_vespa('mnist_softmax.onnx'))", - vespaExpressionConstants); + "5 + sum(onnx_vespa('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); } @Test void testOnnxReferenceWithSpecifiedOutput() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx_vespa('mnist_softmax.onnx', 'layer_add')", - vespaExpressionConstants); + "onnx_vespa('mnist_softmax.onnx', 'layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test void testOnnxReferenceWithSpecifiedOutputAndSignature() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')", - vespaExpressionConstants); + "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -182,8 +177,7 @@ public class RankingExpressionWithOnnxTestCase { @Test void testImportingFromStoredExpressions() throws IOException { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx_vespa(\"mnist_softmax.onnx\")", - vespaExpressionConstants); + "onnx_vespa(\"mnist_softmax.onnx\")"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); // At this point the expression is stored - copy application to another location which do not have a models dir @@ -193,14 +187,12 @@ public class RankingExpressionWithOnnxTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - String constants = "constant mnist_softmax_layer_Variable { file: ignored\ntype: tensor<float>(d0[2],d1[784]) }\n" + - "constant mnist_softmax_layer_Variable_1 { file: ignored\ntype: tensor<float>(d0[2],d1[784]) }\n"; RankProfileSearchFixture searchFromStored = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", - "onnx_vespa('mnist_softmax.onnx')", - constants, - null, - "Placeholder", - storedApplication); + "onnx_vespa('mnist_softmax.onnx')", + null, + null, + "Placeholder", + storedApplication); searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); // Verify that the constants exists, but don't verify the content as we are not // simulating file distribution in this test @@ -229,8 +221,7 @@ public class RankingExpressionWithOnnxTestCase { String vespaExpressionWithoutConstant = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), " + name + "_layer_Variable, f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b))"; - String constant = "constant mnist_softmax_layer_Variable_1 { file: ignored\ntype: tensor<float>(d0[1],d1[10]) }\n"; - RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir), constant); + RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); search.compileRankProfile("my_profile", applicationDir.append("models")); search.compileRankProfile("my_profile_child", applicationDir.append("models")); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); @@ -246,7 +237,7 @@ public class RankingExpressionWithOnnxTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication, constant); + RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication); searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); @@ -335,11 +326,7 @@ public class RankingExpressionWithOnnxTestCase { } private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { - return fixtureWith(placeholderExpression, firstPhaseExpression, null); - } - - private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression, String constant) { - return fixtureWith(placeholderExpression, firstPhaseExpression, constant, null, "Placeholder", + return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder", new StoringApplicationPackage(applicationDir)); } @@ -350,13 +337,9 @@ public class RankingExpressionWithOnnxTestCase { } private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application) { - return uncompiledFixtureWith(rankProfile, application, null); - } - - private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application, String constant) { try { return new RankProfileSearchFixture(application, application.getQueryProfiles(), - rankProfile, constant, null); + rankProfile, null, null); } catch (ParseException e) { throw new IllegalArgumentException(e); |