diff options
Diffstat (limited to 'config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java | 75 |
1 files changed, 36 insertions, 39 deletions
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java index bfd0520c62a..83d19b010bb 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java @@ -12,8 +12,8 @@ import com.yahoo.schema.FeatureNames; import com.yahoo.schema.parser.ParseException; import com.yahoo.tensor.TensorType; import com.yahoo.yolean.Exceptions; -import org.junit.After; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; import java.io.File; import java.io.FileReader; @@ -23,10 +23,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.*; public class RankingExpressionWithOnnxTestCase { @@ -35,13 +32,13 @@ public class RankingExpressionWithOnnxTestCase { private final static String name = "mnist_softmax"; private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(mnist_softmax_layer_Variable_1) * 1.0, f(a,b)(a + b))"; - @After + @AfterEach public void removeGeneratedModelFiles() { IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); } @Test - public void testOnnxReferenceWithConstantFeature() { + void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx_vespa('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", @@ -50,12 +47,12 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testOnnxReferenceWithQueryFeature() { + void testOnnxReferenceWithQueryFeature() { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + - " <field name='query(mytensor)' type='tensor<float>(d0[1],d1[784])'/>" + - "</query-profile-type>"; + " <field name='query(mytensor)' type='tensor<float>(d0[1],d1[784])'/>" + + "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType); @@ -69,7 +66,7 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testOnnxReferenceWithDocumentFeature() { + void testOnnxReferenceWithDocumentFeature() { StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", "onnx_vespa('mnist_softmax.onnx')", @@ -82,12 +79,12 @@ public class RankingExpressionWithOnnxTestCase { @Test - public void testOnnxReferenceWithFeatureCombination() { + void testOnnxReferenceWithFeatureCombination() { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + - " <field name='query(mytensor)' type='tensor<float>(d0[1],d1[784],d2[10])'/>" + - "</query-profile-type>"; + " <field name='query(mytensor)' type='tensor<float>(d0[1],d1[784],d2[10])'/>" + + "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType); RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)", "onnx_vespa('mnist_softmax.onnx')", @@ -100,28 +97,28 @@ public class RankingExpressionWithOnnxTestCase { @Test - public void testNestedOnnxReference() { + void testNestedOnnxReference() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", "5 + sum(onnx_vespa('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); } @Test - public void testOnnxReferenceWithSpecifiedOutput() { + void testOnnxReferenceWithSpecifiedOutput() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", "onnx_vespa('mnist_softmax.onnx', 'layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test - public void testOnnxReferenceWithSpecifiedOutputAndSignature() { + void testOnnxReferenceWithSpecifiedOutputAndSignature() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test - public void testOnnxReferenceMissingFunction() throws ParseException { + void testOnnxReferenceMissingFunction() throws ParseException { try { RankProfileSearchFixture search = new RankProfileSearchFixture( new StoringApplicationPackage(applicationDir), @@ -137,15 +134,15 @@ public class RankingExpressionWithOnnxTestCase { } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx_vespa(\"mnist_softmax.onnx\"): " + - "Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " + - "not present in rank profile 'my_profile'", + "onnx_vespa(\"mnist_softmax.onnx\"): " + + "Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " + + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); } } @Test - public void testOnnxReferenceWithWrongFunctionType() { + void testOnnxReferenceWithWrongFunctionType() { try { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)", "onnx_vespa('mnist_softmax.onnx')"); @@ -154,15 +151,15 @@ public class RankingExpressionWithOnnxTestCase { } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx_vespa(\"mnist_softmax.onnx\"): " + - "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " + - "but this function returns tensor(d0[1],d5[10])", + "onnx_vespa(\"mnist_softmax.onnx\"): " + + "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " + + "but this function returns tensor(d0[1],d5[10])", Exceptions.toMessageString(expected)); } } @Test - public void testOnnxReferenceSpecifyingNonExistingOutput() { + void testOnnxReferenceSpecifyingNonExistingOutput() { try { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", "onnx_vespa('mnist_softmax.onnx', 'y')"); @@ -171,14 +168,14 @@ public class RankingExpressionWithOnnxTestCase { } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx_vespa(\"mnist_softmax.onnx\",\"y\"): " + - "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add", - Exceptions.toMessageString(expected)); + "onnx_vespa(\"mnist_softmax.onnx\",\"y\"): " + + "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add", + Exceptions.toMessageString(expected)); } } @Test - public void testImportingFromStoredExpressions() throws IOException { + void testImportingFromStoredExpressions() throws IOException { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", "onnx_vespa(\"mnist_softmax.onnx\")"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -206,7 +203,7 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testImportingFromStoredExpressionsWithFunctionOverridingConstantAndInheritance() throws IOException { + void testImportingFromStoredExpressionsWithFunctionOverridingConstantAndInheritance() throws IOException { String rankProfile = " rank-profile my_profile {\n" + " function Placeholder() {\n" + @@ -230,8 +227,8 @@ public class RankingExpressionWithOnnxTestCase { search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); - assertNull("Constant overridden by function is not added", - search.search().constants().get(name + "_Variable")); + assertNull(search.search().constants().get(name + "_Variable"), + "Constant overridden by function is not added"); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -245,18 +242,18 @@ public class RankingExpressionWithOnnxTestCase { searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); - assertNull("Constant overridden by function is not added", - searchFromStored.search().constants().get(name + "_Variable")); + assertNull(searchFromStored.search().constants().get(name + "_Variable"), + "Constant overridden by function is not added"); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); } } @Test - public void testFunctionGeneration() { + void testFunctionGeneration() { final String name = "small_constants_and_functions"; final String rankProfiles = - " rank-profile my_profile {\n" + + " rank-profile my_profile {\n" + " function input() {\n" + " expression: tensor<float>(d0[3])(0.0)\n" + " }\n" + @@ -275,7 +272,7 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException { + void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException { final String name = "small_constants_and_functions"; final String rankProfiles = " rank-profile my_profile {\n" + |