diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-11 13:44:04 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-11 13:44:04 +0200 |
commit | e5fedf20729be081f80f23f9f458ebe465e1c194 (patch) | |
tree | 958a8e1fa438cd828d417767661271d17b7c7c2e | |
parent | 812511539b5ae0146623867229eb297b530a6d35 (diff) |
Refactor test
3 files changed, 67 insertions, 35 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImportedModelTester.java new file mode 100644 index 00000000000..04b242babc7 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImportedModelTester.java @@ -0,0 +1,57 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.vespa.model.VespaModel; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Optional; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class ImportedModelTester { + + private final String modelName; + private final Path applicationDir; + + public ImportedModelTester(String modelName, Path applicationDir) { + this.modelName = modelName; + this.applicationDir = applicationDir; + } + + /** + * Verifies that the constant with the given name exists, and - only if an expected size is given - + * that the content of the constant is available and has the expected size. + */ + public void assertLargeConstant(String constantName, VespaModel model, Optional<Long> expectedSize) { + try { + Path constantApplicationPackagePath = Path.fromString("models.generated/" + modelName + "/constants").append(constantName + ".tbf"); + RankingConstant rankingConstant = model.rankingConstants().get(constantName); + assertEquals(constantName, rankingConstant.getName()); + assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); + + if (expectedSize.isPresent()) { + Path constantPath = applicationDir.append(constantApplicationPackagePath); + assertTrue("Constant file '" + constantPath + "' has been written", + constantPath.toFile().exists()); + Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); + assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); + } + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 3ca4cdaf547..4ebfdede4a5 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -46,10 +46,10 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testGlobalOnnxModel() throws SAXException, IOException { - ApplicationPackageTester tester = ApplicationPackageTester.create(applicationDir.toString()); - VespaModel model = new VespaModel(tester.app()); - assertLargeConstant(name + "_Variable_1", model, Optional.of(10L)); - assertLargeConstant(name + "_Variable", model, Optional.of(7840L)); + ImportedModelTester tester = new ImportedModelTester(name, applicationDir); + VespaModel model = new VespaModel(ApplicationPackageTester.create(applicationDir.toString()).app()); + tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L)); + tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L)); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedAppDir = applicationDir.append("copy"); @@ -60,8 +60,8 @@ public class RankingExpressionWithOnnxTestCase { storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); ApplicationPackageTester storedTester = ApplicationPackageTester.create(storedAppDir.toString()); VespaModel storedModel = new VespaModel(storedTester.app()); - assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L)); - assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L)); + tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L)); + tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L)); } finally { IOUtils.recursiveDeleteDir(storedAppDir.toFile()); @@ -263,31 +263,6 @@ public class RankingExpressionWithOnnxTestCase { } } - /** - * Verifies that the constant with the given name exists, and - only if an expected size is given - - * that the content of the constant is available and has the expected size. - */ - private void assertLargeConstant(String constantName, VespaModel model, Optional<Long> expectedSize) { - try { - Path constantApplicationPackagePath = Path.fromString("models.generated/" + name + "/constants").append(constantName + ".tbf"); - RankingConstant rankingConstant = model.rankingConstants().get(constantName); - assertEquals(constantName, rankingConstant.getName()); - assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); - - if (expectedSize.isPresent()) { - Path constantPath = applicationDir.append(constantApplicationPackagePath); - assertTrue("Constant file '" + constantPath + "' has been written", - constantPath.toFile().exists()); - Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), - GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); - assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); - } - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder", new StoringApplicationPackage(applicationDir)); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 71b45ec628a..1583a3d66ad 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -61,8 +61,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testGlobalTensorFlowModel() throws SAXException, IOException { - ApplicationPackageTester tester = ApplicationPackageTester.create(applicationDir.toString()); - VespaModel model = new VespaModel(tester.app()); + ImportedModelTester tester = new ImportedModelTester(name, applicationDir); + VespaModel model = new VespaModel(ApplicationPackageTester.create(applicationDir.toString()).app()); assertLargeConstant(name + "_layer_Variable_1_read", model, Optional.of(10L)); assertLargeConstant(name + "_layer_Variable_read", model, Optional.of(7840L)); @@ -75,8 +75,8 @@ public class RankingExpressionWithTensorFlowTestCase { storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); ApplicationPackageTester storedTester = ApplicationPackageTester.create(storedAppDir.toString()); VespaModel storedModel = new VespaModel(storedTester.app()); - assertLargeConstant(name + "_layer_Variable_1_read", storedModel, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", storedModel, Optional.of(7840L)); + tester.assertLargeConstant(name + "_layer_Variable_1_read", storedModel, Optional.of(10L)); + tester.assertLargeConstant(name + "_layer_Variable_read", storedModel, Optional.of(7840L)); } finally { IOUtils.recursiveDeleteDir(storedAppDir.toFile()); |