summaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java77
1 files changed, 27 insertions, 50 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 414a77e9164..b046d60f948 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -1,27 +1,22 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
package com.yahoo.searchdefinition.processing;
import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.ml.ImportedModelTester;
import com.yahoo.yolean.Exceptions;
import org.junit.After;
import org.junit.Test;
import java.io.IOException;
-import java.io.UncheckedIOException;
import java.util.Optional;
import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage;
-import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
@@ -41,14 +36,36 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
+ public void testGlobalOnnxModel() throws IOException {
+ ImportedModelTester tester = new ImportedModelTester(name, applicationDir);
+ VespaModel model = tester.createVespaModel();
+ tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L));
+ tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L));
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedAppDir = applicationDir.append("copy");
+ try {
+ storedAppDir.toFile().mkdirs();
+ IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString());
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir);
+ VespaModel storedModel = storedTester.createVespaModel();
+ tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L));
+ tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L));
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedAppDir.toFile());
+ }
+ }
+
+ @Test
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx('mnist_softmax.onnx')",
"constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant(name + "_Variable_1", search, Optional.of(10L));
- assertLargeConstant(name + "_Variable", search, Optional.of(7840L));
}
@Test
@@ -68,8 +85,6 @@ public class RankingExpressionWithOnnxTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant(name + "_Variable_1", search, Optional.of(10L));
- assertLargeConstant(name + "_Variable", search, Optional.of(7840L));
}
@Test
@@ -82,8 +97,6 @@ public class RankingExpressionWithOnnxTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant( name + "_Variable_1", search, Optional.of(10L));
- assertLargeConstant( name + "_Variable", search, Optional.of(7840L));
}
@@ -104,8 +117,6 @@ public class RankingExpressionWithOnnxTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant( name + "_Variable_1", search, Optional.of(10L));
- assertLargeConstant( name + "_Variable", search, Optional.of(7840L));
}
@@ -114,8 +125,6 @@ public class RankingExpressionWithOnnxTestCase {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"5 + sum(onnx('mnist_softmax.onnx'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
- assertLargeConstant( name + "_Variable_1", search, Optional.of(10L));
- assertLargeConstant( name + "_Variable", search, Optional.of(7840L));
}
@Test
@@ -181,9 +190,6 @@ public class RankingExpressionWithOnnxTestCase {
"onnx('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant( name + "_Variable_1", search, Optional.of(10L));
- assertLargeConstant( name + "_Variable", search, Optional.of(7840L));
-
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
try {
@@ -200,8 +206,6 @@ public class RankingExpressionWithOnnxTestCase {
searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
// Verify that the constants exists, but don't verify the content as we are not
// simulating file distribution in this test
- assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.empty());
- assertLargeConstant( name + "_Variable", searchFromStored, Optional.empty());
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
@@ -232,7 +236,6 @@ public class RankingExpressionWithOnnxTestCase {
assertNull("Constant overridden by macro is not added",
search.search().rankingConstants().get( name + "_Variable"));
- assertLargeConstant( name + "_Variable_1", search, Optional.of(10L));
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -245,38 +248,12 @@ public class RankingExpressionWithOnnxTestCase {
searchFromStored.compileRankProfile("my_profile", applicationDir.append("models"));
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
assertNull("Constant overridden by macro is not added",
- searchFromStored.search().rankingConstants().get( name + "_Variable"));
- assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.of(10L));
+ searchFromStored.search().rankingConstants().get( name + "_Variable"));
} finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
}
}
- /**
- * Verifies that the constant with the given name exists, and - only if an expected size is given -
- * that the content of the constant is available and has the expected size.
- */
- private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) {
- try {
- Path constantApplicationPackagePath = Path.fromString("models.generated/my_profile.mnist_softmax.onnx/constants").append(name + ".tbf");
- RankingConstant rankingConstant = search.search().rankingConstants().get(name);
- assertEquals(name, rankingConstant.getName());
- assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString()));
-
- if (expectedSize.isPresent()) {
- Path constantPath = applicationDir.append(constantApplicationPackagePath);
- assertTrue("Constant file '" + constantPath + "' has been written",
- constantPath.toFile().exists());
- Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(),
- GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile())));
- assertEquals(expectedSize.get().longValue(), deserializedConstant.size());
- }
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) {
return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder",
new StoringApplicationPackage(applicationDir));