diff options
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java | 77 |
1 files changed, 27 insertions, 50 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 414a77e9164..b046d60f948 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -1,27 +1,22 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.searchdefinition.processing; import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.ml.ImportedModelTester; import com.yahoo.yolean.Exceptions; import org.junit.After; import org.junit.Test; import java.io.IOException; -import java.io.UncheckedIOException; import java.util.Optional; import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage; -import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; @@ -41,14 +36,36 @@ public class RankingExpressionWithOnnxTestCase { } @Test + public void testGlobalOnnxModel() throws IOException { + ImportedModelTester tester = new ImportedModelTester(name, applicationDir); + VespaModel model = tester.createVespaModel(); + tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L)); + tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedAppDir = applicationDir.append("copy"); + try { + storedAppDir.toFile().mkdirs(); + IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir); + VespaModel storedModel = storedTester.createVespaModel(); + tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L)); + tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L)); + } + finally { + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + + @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant(name + "_Variable", search, Optional.of(7840L)); } @Test @@ -68,8 +85,6 @@ public class RankingExpressionWithOnnxTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant(name + "_Variable", search, Optional.of(7840L)); } @Test @@ -82,8 +97,6 @@ public class RankingExpressionWithOnnxTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); } @@ -104,8 +117,6 @@ public class RankingExpressionWithOnnxTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); } @@ -114,8 +125,6 @@ public class RankingExpressionWithOnnxTestCase { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "5 + sum(onnx('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); } @Test @@ -181,9 +190,6 @@ public class RankingExpressionWithOnnxTestCase { "onnx('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); - // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); try { @@ -200,8 +206,6 @@ public class RankingExpressionWithOnnxTestCase { searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); // Verify that the constants exists, but don't verify the content as we are not // simulating file distribution in this test - assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.empty()); - assertLargeConstant( name + "_Variable", searchFromStored, Optional.empty()); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -232,7 +236,6 @@ public class RankingExpressionWithOnnxTestCase { assertNull("Constant overridden by macro is not added", search.search().rankingConstants().get( name + "_Variable")); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -245,38 +248,12 @@ public class RankingExpressionWithOnnxTestCase { searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", - searchFromStored.search().rankingConstants().get( name + "_Variable")); - assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.of(10L)); + searchFromStored.search().rankingConstants().get( name + "_Variable")); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); } } - /** - * Verifies that the constant with the given name exists, and - only if an expected size is given - - * that the content of the constant is available and has the expected size. - */ - private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { - try { - Path constantApplicationPackagePath = Path.fromString("models.generated/my_profile.mnist_softmax.onnx/constants").append(name + ".tbf"); - RankingConstant rankingConstant = search.search().rankingConstants().get(name); - assertEquals(name, rankingConstant.getName()); - assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); - - if (expectedSize.isPresent()) { - Path constantPath = applicationDir.append(constantApplicationPackagePath); - assertTrue("Constant file '" + constantPath + "' has been written", - constantPath.toFile().exists()); - Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), - GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); - assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); - } - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder", new StoringApplicationPackage(applicationDir)); |