aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java')
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java75
1 files changed, 36 insertions, 39 deletions
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java
index bfd0520c62a..83d19b010bb 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java
@@ -12,8 +12,8 @@ import com.yahoo.schema.FeatureNames;
import com.yahoo.schema.parser.ParseException;
import com.yahoo.tensor.TensorType;
import com.yahoo.yolean.Exceptions;
-import org.junit.After;
-import org.junit.Test;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.FileReader;
@@ -23,10 +23,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.fail;
+import static org.junit.jupiter.api.Assertions.*;
public class RankingExpressionWithOnnxTestCase {
@@ -35,13 +32,13 @@ public class RankingExpressionWithOnnxTestCase {
private final static String name = "mnist_softmax";
private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(mnist_softmax_layer_Variable_1) * 1.0, f(a,b)(a + b))";
- @After
+ @AfterEach
public void removeGeneratedModelFiles() {
IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
}
@Test
- public void testOnnxReferenceWithConstantFeature() {
+ void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx_vespa('mnist_softmax.onnx')",
"constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
@@ -50,12 +47,12 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testOnnxReferenceWithQueryFeature() {
+ void testOnnxReferenceWithQueryFeature() {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType =
"<query-profile-type id='root'>" +
- " <field name='query(mytensor)' type='tensor&lt;float&gt;(d0[1],d1[784])'/>" +
- "</query-profile-type>";
+ " <field name='query(mytensor)' type='tensor&lt;float&gt;(d0[1],d1[784])'/>" +
+ "</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
queryProfileType);
@@ -69,7 +66,7 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testOnnxReferenceWithDocumentFeature() {
+ void testOnnxReferenceWithDocumentFeature() {
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("attribute(mytensor)",
"onnx_vespa('mnist_softmax.onnx')",
@@ -82,12 +79,12 @@ public class RankingExpressionWithOnnxTestCase {
@Test
- public void testOnnxReferenceWithFeatureCombination() {
+ void testOnnxReferenceWithFeatureCombination() {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType =
"<query-profile-type id='root'>" +
- " <field name='query(mytensor)' type='tensor&lt;float&gt;(d0[1],d1[784],d2[10])'/>" +
- "</query-profile-type>";
+ " <field name='query(mytensor)' type='tensor&lt;float&gt;(d0[1],d1[784],d2[10])'/>" +
+ "</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType);
RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)",
"onnx_vespa('mnist_softmax.onnx')",
@@ -100,28 +97,28 @@ public class RankingExpressionWithOnnxTestCase {
@Test
- public void testNestedOnnxReference() {
+ void testNestedOnnxReference() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
"5 + sum(onnx_vespa('mnist_softmax.onnx'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
}
@Test
- public void testOnnxReferenceWithSpecifiedOutput() {
+ void testOnnxReferenceWithSpecifiedOutput() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
"onnx_vespa('mnist_softmax.onnx', 'layer_add')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
- public void testOnnxReferenceWithSpecifiedOutputAndSignature() {
+ void testOnnxReferenceWithSpecifiedOutputAndSignature() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
"onnx_vespa('mnist_softmax.onnx', 'default.layer_add')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
- public void testOnnxReferenceMissingFunction() throws ParseException {
+ void testOnnxReferenceMissingFunction() throws ParseException {
try {
RankProfileSearchFixture search = new RankProfileSearchFixture(
new StoringApplicationPackage(applicationDir),
@@ -137,15 +134,15 @@ public class RankingExpressionWithOnnxTestCase {
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
- "onnx_vespa(\"mnist_softmax.onnx\"): " +
- "Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " +
- "not present in rank profile 'my_profile'",
+ "onnx_vespa(\"mnist_softmax.onnx\"): " +
+ "Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " +
+ "not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
}
@Test
- public void testOnnxReferenceWithWrongFunctionType() {
+ void testOnnxReferenceWithWrongFunctionType() {
try {
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)",
"onnx_vespa('mnist_softmax.onnx')");
@@ -154,15 +151,15 @@ public class RankingExpressionWithOnnxTestCase {
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
- "onnx_vespa(\"mnist_softmax.onnx\"): " +
- "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " +
- "but this function returns tensor(d0[1],d5[10])",
+ "onnx_vespa(\"mnist_softmax.onnx\"): " +
+ "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " +
+ "but this function returns tensor(d0[1],d5[10])",
Exceptions.toMessageString(expected));
}
}
@Test
- public void testOnnxReferenceSpecifyingNonExistingOutput() {
+ void testOnnxReferenceSpecifyingNonExistingOutput() {
try {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
"onnx_vespa('mnist_softmax.onnx', 'y')");
@@ -171,14 +168,14 @@ public class RankingExpressionWithOnnxTestCase {
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
- "onnx_vespa(\"mnist_softmax.onnx\",\"y\"): " +
- "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add",
- Exceptions.toMessageString(expected));
+ "onnx_vespa(\"mnist_softmax.onnx\",\"y\"): " +
+ "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add",
+ Exceptions.toMessageString(expected));
}
}
@Test
- public void testImportingFromStoredExpressions() throws IOException {
+ void testImportingFromStoredExpressions() throws IOException {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
"onnx_vespa(\"mnist_softmax.onnx\")");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -206,7 +203,7 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testImportingFromStoredExpressionsWithFunctionOverridingConstantAndInheritance() throws IOException {
+ void testImportingFromStoredExpressionsWithFunctionOverridingConstantAndInheritance() throws IOException {
String rankProfile =
" rank-profile my_profile {\n" +
" function Placeholder() {\n" +
@@ -230,8 +227,8 @@ public class RankingExpressionWithOnnxTestCase {
search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child");
- assertNull("Constant overridden by function is not added",
- search.search().constants().get(name + "_Variable"));
+ assertNull(search.search().constants().get(name + "_Variable"),
+ "Constant overridden by function is not added");
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -245,18 +242,18 @@ public class RankingExpressionWithOnnxTestCase {
searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models"));
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child");
- assertNull("Constant overridden by function is not added",
- searchFromStored.search().constants().get(name + "_Variable"));
+ assertNull(searchFromStored.search().constants().get(name + "_Variable"),
+ "Constant overridden by function is not added");
} finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
}
}
@Test
- public void testFunctionGeneration() {
+ void testFunctionGeneration() {
final String name = "small_constants_and_functions";
final String rankProfiles =
- " rank-profile my_profile {\n" +
+ " rank-profile my_profile {\n" +
" function input() {\n" +
" expression: tensor<float>(d0[3])(0.0)\n" +
" }\n" +
@@ -275,7 +272,7 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException {
+ void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException {
final String name = "small_constants_and_functions";
final String rankProfiles =
" rank-profile my_profile {\n" +