diff options
author | Lester Solbakken <lesters@oath.com> | 2018-02-08 10:32:45 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-02-08 10:32:45 +0100 |
commit | b7e875046fa3cc6f8f4e7459684c6351f072925c (patch) | |
tree | cc09cdff56e43476efcc544d533053c8479640fd /config-model/src/test/java/com/yahoo/searchdefinition | |
parent | 6f48606d0066595ed9a7e4514c1e5fcf0460fbb6 (diff) |
Bring back tensorflow import test for small constants
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition')
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 3e11eb72a30..4693ac5cf4d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -8,6 +8,7 @@ import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -250,6 +251,46 @@ public class RankingExpressionWithTensorFlowTestCase { } } + @Test + public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { + final String expression = "join(rename(reduce(join(join(join(rename(constant(\"dnn_hidden2_Const\"), d0, d1), join(rename(reduce(join(join(join(0.009999999776482582, join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(join(join(0.009999999776482582, join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_outputs_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_outputs_bias\"), d0, d1), f(a,b)(a + b))"; + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist/saved')", + null, + null, + "input", + application); + search.assertFirstPhaseExpression(expression, "my_profile"); + assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d0[1])"), search); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist/saved')", + null, + null, + "input", + storedApplication); + searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); + assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d0[1])"), search); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + } + + private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) { + Value value = search.rankProfile("my_profile").getConstants().get(name); + assertNotNull(value); + assertEquals(type, value.type()); + } + /** * Verifies that the constant with the given name exists, and - only if an expected size is given - * that the content of the constant is available and has the expected size. |