summaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-02-08 10:32:45 +0100
committerLester Solbakken <lesters@oath.com>2018-02-08 10:32:45 +0100
commitb7e875046fa3cc6f8f4e7459684c6351f072925c (patch)
treecc09cdff56e43476efcc544d533053c8479640fd /config-model/src/test/java/com
parent6f48606d0066595ed9a7e4514c1e5fcf0460fbb6 (diff)
Bring back tensorflow import test for small constants
Diffstat (limited to 'config-model/src/test/java/com')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java41
1 files changed, 41 insertions, 0 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 3e11eb72a30..4693ac5cf4d 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -8,6 +8,7 @@ import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -250,6 +251,46 @@ public class RankingExpressionWithTensorFlowTestCase {
}
}
+ @Test
+ public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
+ final String expression = "join(rename(reduce(join(join(join(rename(constant(\"dnn_hidden2_Const\"), d0, d1), join(rename(reduce(join(join(join(0.009999999776482582, join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(join(join(0.009999999776482582, join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_outputs_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_outputs_bias\"), d0, d1), f(a,b)(a + b))";
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "tensorflow('mnist/saved')",
+ null,
+ null,
+ "input",
+ application);
+ search.assertFirstPhaseExpression(expression, "my_profile");
+ assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d0[1])"), search);
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
+ try {
+ storedApplicationDirectory.toFile().mkdirs();
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
+ RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "tensorflow('mnist/saved')",
+ null,
+ null,
+ "input",
+ storedApplication);
+ searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
+ assertSmallConstant("dnn_hidden2_Const", TensorType.fromSpec("tensor(d0[1])"), search);
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
+ }
+ }
+
+ private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) {
+ Value value = search.rankProfile("my_profile").getConstants().get(name);
+ assertNotNull(value);
+ assertEquals(type, value.type());
+ }
+
/**
* Verifies that the constant with the given name exists, and - only if an expected size is given -
* that the content of the constant is available and has the expected size.