diff options
author | Lester Solbakken <lesters@oath.com> | 2018-02-05 16:04:42 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-02-22 12:54:34 +0100 |
commit | b1f46fcd0495dbce905fb8b7318781f4cf5965a7 (patch) | |
tree | d0a0506fe66e5af4af2a927101a0eb9ed9420d38 /config-model | |
parent | e307df56eaaf5b0ebca5aefb7f7e0c5c3a970bdb (diff) |
Refactor TensorFlow import and add dimension renaming.
Diffstat (limited to 'config-model')
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java | 36 |
1 files changed, 18 insertions, 18 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 4693ac5cf4d..962d17541cc 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -42,7 +42,7 @@ import static org.junit.Assert.*; public class RankingExpressionWithTensorFlowTestCase { private final Path applicationDir = Path.fromString("src/test/integration/tensorflow/"); - private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(\"layer_Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"layer_Variable_1\"), d0, d1), f(a,b)(a + b))"; + private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"layer_Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"layer_Variable_1_read\"), f(a,b)(a + b))"; @After public void removeGeneratedConstantTensorFiles() { @@ -54,8 +54,8 @@ public class RankingExpressionWithTensorFlowTestCase { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -65,8 +65,8 @@ public class RankingExpressionWithTensorFlowTestCase { "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -85,8 +85,8 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -99,8 +99,8 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -119,8 +119,8 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -128,8 +128,8 @@ public class RankingExpressionWithTensorFlowTestCase { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "5 + sum(tensorflow('mnist_softmax/saved'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -224,8 +224,8 @@ public class RankingExpressionWithTensorFlowTestCase { "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -243,8 +243,8 @@ public class RankingExpressionWithTensorFlowTestCase { searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); // Verify that the constants exists, but don't verify the content as we are not // simulating file distribution in this test - assertLargeConstant("layer_Variable_1", searchFromStored, Optional.empty()); - assertLargeConstant("layer_Variable", searchFromStored, Optional.empty()); + assertLargeConstant("layer_Variable_1_read", searchFromStored, Optional.empty()); + assertLargeConstant("layer_Variable_read", searchFromStored, Optional.empty()); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -253,7 +253,7 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { - final String expression = "join(rename(reduce(join(join(join(rename(constant(\"dnn_hidden2_Const\"), d0, d1), join(rename(reduce(join(join(join(0.009999999776482582, join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(join(join(0.009999999776482582, join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_outputs_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_outputs_bias\"), d0, d1), f(a,b)(a + b))"; + final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist/saved')", |