diff options
author | Lester Solbakken <lesters@oath.com> | 2018-02-08 09:15:56 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-02-08 09:15:56 +0100 |
commit | ff500ab6c72887f64cfbf0e6b40748c7c6e9dd08 (patch) | |
tree | 08941eb8a4b630e447e4209b519f3aa713f94a73 /config-model/src/test/java/com | |
parent | 74b3ef7b54e8ac8b0473c016185f1476a3fd3db4 (diff) |
Inline small tensor constants imported from tensorflow
Diffstat (limited to 'config-model/src/test/java/com')
2 files changed, 1 insertions, 41 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 83cc3ae418a..3e11eb72a30 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -250,46 +250,6 @@ public class RankingExpressionWithTensorFlowTestCase { } } - @Test - public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { - final String expression = "join(rename(reduce(join(map(join(rename(reduce(join(join(join(constant(\"dnn_hidden1_mul_x\"), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), rename(constant(\"dnn_outputs_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_outputs_bias\"), d0, d1), f(a,b)(a + b))"; - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", - "tensorflow('mnist/saved')", - null, - null, - "input", - application); - search.assertFirstPhaseExpression(expression, "my_profile"); - assertSmallConstant("dnn_hidden1_mul_x", TensorType.empty, search); - - // At this point the expression is stored - copy application to another location which do not have a models dir - Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); - try { - storedApplicationDirectory.toFile().mkdirs(); - IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), - storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); - StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)", - "tensorflow('mnist/saved')", - null, - null, - "input", - storedApplication); - searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); - assertSmallConstant("dnn_hidden1_mul_x", TensorType.empty, search); - } - finally { - IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); - } - } - - private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) { - Value value = search.rankProfile("my_profile").getConstants().get(name); - assertNotNull(value); - assertEquals(type, value.type()); - } - /** * Verifies that the constant with the given name exists, and - only if an expected size is given - * that the content of the constant is available and has the expected size. diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index c18cfcfe1aa..b001db69768 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -64,7 +64,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)"); assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)"); assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) - assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)"); + assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),min,x)"); } @Test |