diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-02-07 09:55:04 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-02-07 09:55:04 +0100 |
commit | 1f6b1bdd519409243cb6e2dec182605599ac1aab (patch) | |
tree | b907614bc2cf127406f0b356516f644d95bda61f /config-model/src/test/java/com | |
parent | 67a4cc635d67059c53a1a812d0c7958b1a379ccc (diff) |
Test model with small constant
Diffstat (limited to 'config-model/src/test/java/com')
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java | 117 |
1 files changed, 80 insertions, 37 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 7246b22b0f8..83cc3ae418a 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -6,12 +6,13 @@ import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; -import com.yahoo.io.reader.NamedReader; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.yolean.Exceptions; import org.junit.After; @@ -24,7 +25,6 @@ import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.io.Reader; -import java.io.StringReader; import java.io.UncheckedIOException; import java.util.Arrays; import java.util.Collections; @@ -33,9 +33,7 @@ import java.util.List; import java.util.Optional; import java.util.stream.Collectors; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.Assert.*; /** * @author bratseth @@ -51,27 +49,27 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReference() throws ParseException { + public void testTensorFlowReference() { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertConstant("layer_Variable_1", search, Optional.of(10L)); - assertConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); + assertLargeConstant("layer_Variable", search, Optional.of(7840L)); } @Test - public void testTensorFlowReferenceWithConstantFeature() throws ParseException { + public void testTensorFlowReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "tensorflow('mnist_softmax/saved')", "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertConstant("layer_Variable_1", search, Optional.of(10L)); - assertConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); + assertLargeConstant("layer_Variable", search, Optional.of(7840L)); } @Test - public void testTensorFlowReferenceWithQueryFeature() throws ParseException { + public void testTensorFlowReferenceWithQueryFeature() { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + " <field name='mytensor' type='tensor(d0[3],d1[784])'/>" + @@ -83,27 +81,29 @@ public class RankingExpressionWithTensorFlowTestCase { "tensorflow('mnist_softmax/saved')", null, null, + "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertConstant("layer_Variable_1", search, Optional.of(10L)); - assertConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); + assertLargeConstant("layer_Variable", search, Optional.of(7840L)); } @Test - public void testTensorFlowReferenceWithDocumentFeature() throws ParseException { + public void testTensorFlowReferenceWithDocumentFeature() { StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", "tensorflow('mnist_softmax/saved')", null, "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }", + "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertConstant("layer_Variable_1", search, Optional.of(10L)); - assertConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); + assertLargeConstant("layer_Variable", search, Optional.of(7840L)); } @Test - public void testTensorFlowReferenceWithFeatureCombination() throws ParseException { + public void testTensorFlowReferenceWithFeatureCombination() { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + " <field name='mytensor' type='tensor(d0[3],d1[784],d2[10])'/>" + @@ -115,30 +115,31 @@ public class RankingExpressionWithTensorFlowTestCase { "tensorflow('mnist_softmax/saved')", "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }", + "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertConstant("layer_Variable_1", search, Optional.of(10L)); - assertConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); + assertLargeConstant("layer_Variable", search, Optional.of(7840L)); } @Test - public void testNestedTensorFlowReference() throws ParseException { + public void testNestedTensorFlowReference() { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "5 + sum(tensorflow('mnist_softmax/saved'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - assertConstant("layer_Variable_1", search, Optional.of(10L)); - assertConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); + assertLargeConstant("layer_Variable", search, Optional.of(7840L)); } @Test - public void testTensorFlowReferenceSpecifyingSignature() throws ParseException { + public void testTensorFlowReferenceSpecifyingSignature() { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved', 'serving_default')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test - public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException { + public void testTensorFlowReferenceSpecifyingSignatureAndOutput() { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved', 'serving_default', 'y')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -168,7 +169,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceWithWrongMacroType() throws ParseException { + public void testTensorFlowReferenceWithWrongMacroType() { try { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)", "tensorflow('mnist_softmax/saved')"); @@ -185,7 +186,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException { + public void testTensorFlowReferenceSpecifyingNonExistingSignature() { try { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved', 'serving_defaultz')"); @@ -201,7 +202,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException { + public void testTensorFlowReferenceSpecifyingNonExistingOutput() { try { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved', 'serving_default', 'x')"); @@ -217,12 +218,13 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testImportingFromStoredExpressions() throws ParseException, IOException { + public void testImportingFromStoredExpressions() throws IOException { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertConstant("layer_Variable_1", search, Optional.of(10L)); - assertConstant("layer_Variable", search, Optional.of(7840L)); + + assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); + assertLargeConstant("layer_Variable", search, Optional.of(7840L)); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -235,24 +237,64 @@ public class RankingExpressionWithTensorFlowTestCase { "tensorflow('mnist_softmax/saved')", null, null, + "Placeholder", storedApplication); searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); // Verify that the constants exists, but don't verify the content as we are not // simulating file distribution in this test - assertConstant("layer_Variable_1", searchFromStored, Optional.empty()); - assertConstant("layer_Variable", searchFromStored, Optional.empty()); + assertLargeConstant("layer_Variable_1", searchFromStored, Optional.empty()); + assertLargeConstant("layer_Variable", searchFromStored, Optional.empty()); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + } + + @Test + public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { + final String expression = "join(rename(reduce(join(map(join(rename(reduce(join(join(join(constant(\"dnn_hidden1_mul_x\"), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), rename(constant(\"dnn_outputs_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_outputs_bias\"), d0, d1), f(a,b)(a + b))"; + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist/saved')", + null, + null, + "input", + application); + search.assertFirstPhaseExpression(expression, "my_profile"); + assertSmallConstant("dnn_hidden1_mul_x", TensorType.empty, search); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist/saved')", + null, + null, + "input", + storedApplication); + searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); + assertSmallConstant("dnn_hidden1_mul_x", TensorType.empty, search); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); } + } + private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) { + Value value = search.rankProfile("my_profile").getConstants().get(name); + assertNotNull(value); + assertEquals(type, value.type()); } /** * Verifies that the constant with the given name exists, and - only if an expected size is given - * that the content of the constant is available and has the expected size. */ - private void assertConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { + private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { try { Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf"); RankingConstant rankingConstant = search.search().getRankingConstants().get(name); @@ -274,13 +316,13 @@ public class RankingExpressionWithTensorFlowTestCase { } private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { - return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, + return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder", new StoringApplicationPackage(applicationDir)); } private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression, String constant, String field) { - return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, + return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder", new StoringApplicationPackage(applicationDir)); } @@ -288,13 +330,14 @@ public class RankingExpressionWithTensorFlowTestCase { String firstPhaseExpression, String constant, String field, + String macroName, StoringApplicationPackage application) { try { return new RankProfileSearchFixture( application, application.getQueryProfiles(), " rank-profile my_profile {\n" + - " macro Placeholder() {\n" + + " macro " + macroName + "() {\n" + " expression: " + placeholderExpression + " }\n" + " first-phase {\n" + |