summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-03-05 16:22:51 +0100
committerLester Solbakken <lesters@oath.com>2018-03-05 16:22:51 +0100
commit18d94ce9ee93db460f40c4e533ef1442768a73a2 (patch)
treed4d3b08055f789c4885af8aee536a88e0a12589e /config-model
parent3dc6c980c74ff9b280a840374c85026297de89a3 (diff)
Add testing of macro generation from stored model
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java35
1 files changed, 21 insertions, 14 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index bd2bbf5c6d5..c650151980c 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -262,8 +262,24 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
+ public void testMacroGeneration() {
+ final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))";
+ final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))";
+
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
+ "tensorflow('mnist/saved')");
+ search.assertFirstPhaseExpression(expression, "my_profile");
+ search.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile");
+ }
+
+ @Test
public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))";
+ final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))";
+
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist/saved')",
@@ -273,6 +289,8 @@ public class RankingExpressionWithTensorFlowTestCase {
application);
search.assertFirstPhaseExpression(expression, "my_profile");
assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d0[1])"), search);
+ search.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile");
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -281,7 +299,7 @@ public class RankingExpressionWithTensorFlowTestCase {
IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
- RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist/saved')",
null,
null,
@@ -289,25 +307,14 @@ public class RankingExpressionWithTensorFlowTestCase {
storedApplication);
searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d0[1])"), search);
+ searchFromStored.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile");
+ searchFromStored.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile");
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
}
}
- @Test
- public void testMacroGeneration() {
- final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))";
- final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))";
-
- RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
- "tensorflow('mnist/saved')");
- search.assertFirstPhaseExpression(expression, "my_profile");
- search.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile");
- search.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile");
- }
-
private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) {
Value value = search.rankProfile("my_profile").getConstants().get(name);
assertNotNull(value);