diff options
author | Lester Solbakken <lesters@oath.com> | 2018-03-05 16:22:51 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-03-05 16:22:51 +0100 |
commit | 18d94ce9ee93db460f40c4e533ef1442768a73a2 (patch) | |
tree | d4d3b08055f789c4885af8aee536a88e0a12589e /config-model | |
parent | 3dc6c980c74ff9b280a840374c85026297de89a3 (diff) |
Add testing of macro generation from stored model
Diffstat (limited to 'config-model')
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java | 35 |
1 files changed, 21 insertions, 14 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index bd2bbf5c6d5..c650151980c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -262,8 +262,24 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test + public void testMacroGeneration() { + final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))"; + final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))"; + + RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", + "tensorflow('mnist/saved')"); + search.assertFirstPhaseExpression(expression, "my_profile"); + search.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile"); + } + + @Test public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))"; + final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))"; + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist/saved')", @@ -273,6 +289,8 @@ public class RankingExpressionWithTensorFlowTestCase { application); search.assertFirstPhaseExpression(expression, "my_profile"); assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d0[1])"), search); + search.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile"); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -281,7 +299,7 @@ public class RankingExpressionWithTensorFlowTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)", + RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist/saved')", null, null, @@ -289,25 +307,14 @@ public class RankingExpressionWithTensorFlowTestCase { storedApplication); searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d0[1])"), search); + searchFromStored.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile"); + searchFromStored.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile"); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); } } - @Test - public void testMacroGeneration() { - final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))"; - final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))"; - final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))"; - - RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "tensorflow('mnist/saved')"); - search.assertFirstPhaseExpression(expression, "my_profile"); - search.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile"); - search.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile"); - } - private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) { Value value = search.rankProfile("my_profile").getConstants().get(name); assertNotNull(value); |