summaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java40
1 files changed, 20 insertions, 20 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 4693ac5cf4d..1e376824b7b 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -42,7 +42,7 @@ import static org.junit.Assert.*;
public class RankingExpressionWithTensorFlowTestCase {
private final Path applicationDir = Path.fromString("src/test/integration/tensorflow/");
- private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(\"layer_Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"layer_Variable_1\"), d0, d1), f(a,b)(a + b))";
+ private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"layer_Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"layer_Variable_1_read\"), f(a,b)(a + b))";
@After
public void removeGeneratedConstantTensorFiles() {
@@ -54,8 +54,8 @@ public class RankingExpressionWithTensorFlowTestCase {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -65,15 +65,15 @@ public class RankingExpressionWithTensorFlowTestCase {
"constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
}
@Test
public void testTensorFlowReferenceWithQueryFeature() {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType = "<query-profile-type id='root'>" +
- " <field name='mytensor' type='tensor(d0[3],d1[784])'/>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
@@ -85,8 +85,8 @@ public class RankingExpressionWithTensorFlowTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -99,15 +99,15 @@ public class RankingExpressionWithTensorFlowTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
}
@Test
public void testTensorFlowReferenceWithFeatureCombination() {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType = "<query-profile-type id='root'>" +
- " <field name='mytensor' type='tensor(d0[3],d1[784],d2[10])'/>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784],d2[10])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
@@ -119,8 +119,8 @@ public class RankingExpressionWithTensorFlowTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -128,8 +128,8 @@ public class RankingExpressionWithTensorFlowTestCase {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"5 + sum(tensorflow('mnist_softmax/saved'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -224,8 +224,8 @@ public class RankingExpressionWithTensorFlowTestCase {
"tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -243,8 +243,8 @@ public class RankingExpressionWithTensorFlowTestCase {
searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
// Verify that the constants exists, but don't verify the content as we are not
// simulating file distribution in this test
- assertLargeConstant("layer_Variable_1", searchFromStored, Optional.empty());
- assertLargeConstant("layer_Variable", searchFromStored, Optional.empty());
+ assertLargeConstant("layer_Variable_1_read", searchFromStored, Optional.empty());
+ assertLargeConstant("layer_Variable_read", searchFromStored, Optional.empty());
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
@@ -253,7 +253,7 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(rename(reduce(join(join(join(rename(constant(\"dnn_hidden2_Const\"), d0, d1), join(rename(reduce(join(join(join(0.009999999776482582, join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(join(join(0.009999999776482582, join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_outputs_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_outputs_bias\"), d0, d1), f(a,b)(a + b))";
+ final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist/saved')",