diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-11 13:37:59 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-11 13:37:59 +0200 |
commit | 812511539b5ae0146623867229eb297b530a6d35 (patch) | |
tree | 72e4111b4da1b17375c887bf89e04fd74c574c82 /config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java | |
parent | 7b280eec0f27ff793c6467d00784d89fdbe977d3 (diff) |
Only store large constants under global models
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java | 59 |
1 files changed, 35 insertions, 24 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 414a77e9164..3ca4cdaf547 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -3,14 +3,17 @@ package com.yahoo.searchdefinition.processing; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.ApplicationPackageTester; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.vespa.model.VespaModel; import com.yahoo.yolean.Exceptions; import org.junit.After; import org.junit.Test; @@ -20,6 +23,7 @@ import java.io.UncheckedIOException; import java.util.Optional; import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage; +import org.xml.sax.SAXException; import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; @@ -41,14 +45,36 @@ public class RankingExpressionWithOnnxTestCase { } @Test + public void testGlobalOnnxModel() throws SAXException, IOException { + ApplicationPackageTester tester = ApplicationPackageTester.create(applicationDir.toString()); + VespaModel model = new VespaModel(tester.app()); + assertLargeConstant(name + "_Variable_1", model, Optional.of(10L)); + assertLargeConstant(name + "_Variable", model, Optional.of(7840L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedAppDir = applicationDir.append("copy"); + try { + storedAppDir.toFile().mkdirs(); + IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + ApplicationPackageTester storedTester = ApplicationPackageTester.create(storedAppDir.toString()); + VespaModel storedModel = new VespaModel(storedTester.app()); + assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L)); + assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L)); + } + finally { + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + + @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant(name + "_Variable", search, Optional.of(7840L)); } @Test @@ -68,8 +94,6 @@ public class RankingExpressionWithOnnxTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant(name + "_Variable", search, Optional.of(7840L)); } @Test @@ -82,8 +106,6 @@ public class RankingExpressionWithOnnxTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); } @@ -104,8 +126,6 @@ public class RankingExpressionWithOnnxTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); } @@ -114,8 +134,6 @@ public class RankingExpressionWithOnnxTestCase { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "5 + sum(onnx('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); } @Test @@ -181,9 +199,6 @@ public class RankingExpressionWithOnnxTestCase { "onnx('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); - // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); try { @@ -200,8 +215,6 @@ public class RankingExpressionWithOnnxTestCase { searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); // Verify that the constants exists, but don't verify the content as we are not // simulating file distribution in this test - assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.empty()); - assertLargeConstant( name + "_Variable", searchFromStored, Optional.empty()); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -232,7 +245,6 @@ public class RankingExpressionWithOnnxTestCase { assertNull("Constant overridden by macro is not added", search.search().rankingConstants().get( name + "_Variable")); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -245,8 +257,7 @@ public class RankingExpressionWithOnnxTestCase { searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", - searchFromStored.search().rankingConstants().get( name + "_Variable")); - assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.of(10L)); + searchFromStored.search().rankingConstants().get( name + "_Variable")); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); } @@ -256,19 +267,19 @@ public class RankingExpressionWithOnnxTestCase { * Verifies that the constant with the given name exists, and - only if an expected size is given - * that the content of the constant is available and has the expected size. */ - private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { + private void assertLargeConstant(String constantName, VespaModel model, Optional<Long> expectedSize) { try { - Path constantApplicationPackagePath = Path.fromString("models.generated/my_profile.mnist_softmax.onnx/constants").append(name + ".tbf"); - RankingConstant rankingConstant = search.search().rankingConstants().get(name); - assertEquals(name, rankingConstant.getName()); + Path constantApplicationPackagePath = Path.fromString("models.generated/" + name + "/constants").append(constantName + ".tbf"); + RankingConstant rankingConstant = model.rankingConstants().get(constantName); + assertEquals(constantName, rankingConstant.getName()); assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); if (expectedSize.isPresent()) { Path constantPath = applicationDir.append(constantApplicationPackagePath); assertTrue("Constant file '" + constantPath + "' has been written", - constantPath.toFile().exists()); + constantPath.toFile().exists()); Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), - GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); + GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); } } |