summaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-02-07 09:55:04 +0100
committerJon Bratseth <bratseth@oath.com>2018-02-07 09:55:04 +0100
commit1f6b1bdd519409243cb6e2dec182605599ac1aab (patch)
treeb907614bc2cf127406f0b356516f644d95bda61f /config-model/src/test/java/com
parent67a4cc635d67059c53a1a812d0c7958b1a379ccc (diff)
Test model with small constant
Diffstat (limited to 'config-model/src/test/java/com')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java117
1 files changed, 80 insertions, 37 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 7246b22b0f8..83cc3ae418a 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -6,12 +6,13 @@ import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
-import com.yahoo.io.reader.NamedReader;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.yolean.Exceptions;
import org.junit.After;
@@ -24,7 +25,6 @@ import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
-import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.Collections;
@@ -33,9 +33,7 @@ import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static org.junit.Assert.*;
/**
* @author bratseth
@@ -51,27 +49,27 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testTensorFlowReference() throws ParseException {
+ public void testTensorFlowReference() {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertConstant("layer_Variable_1", search, Optional.of(10L));
- assertConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable", search, Optional.of(7840L));
}
@Test
- public void testTensorFlowReferenceWithConstantFeature() throws ParseException {
+ public void testTensorFlowReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"tensorflow('mnist_softmax/saved')",
"constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertConstant("layer_Variable_1", search, Optional.of(10L));
- assertConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable", search, Optional.of(7840L));
}
@Test
- public void testTensorFlowReferenceWithQueryFeature() throws ParseException {
+ public void testTensorFlowReferenceWithQueryFeature() {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType = "<query-profile-type id='root'>" +
" <field name='mytensor' type='tensor(d0[3],d1[784])'/>" +
@@ -83,27 +81,29 @@ public class RankingExpressionWithTensorFlowTestCase {
"tensorflow('mnist_softmax/saved')",
null,
null,
+ "Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertConstant("layer_Variable_1", search, Optional.of(10L));
- assertConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable", search, Optional.of(7840L));
}
@Test
- public void testTensorFlowReferenceWithDocumentFeature() throws ParseException {
+ public void testTensorFlowReferenceWithDocumentFeature() {
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("attribute(mytensor)",
"tensorflow('mnist_softmax/saved')",
null,
"field mytensor type tensor(d0[],d1[784]) { indexing: attribute }",
+ "Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertConstant("layer_Variable_1", search, Optional.of(10L));
- assertConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable", search, Optional.of(7840L));
}
@Test
- public void testTensorFlowReferenceWithFeatureCombination() throws ParseException {
+ public void testTensorFlowReferenceWithFeatureCombination() {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType = "<query-profile-type id='root'>" +
" <field name='mytensor' type='tensor(d0[3],d1[784],d2[10])'/>" +
@@ -115,30 +115,31 @@ public class RankingExpressionWithTensorFlowTestCase {
"tensorflow('mnist_softmax/saved')",
"constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
"field mytensor type tensor(d0[],d1[784]) { indexing: attribute }",
+ "Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertConstant("layer_Variable_1", search, Optional.of(10L));
- assertConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable", search, Optional.of(7840L));
}
@Test
- public void testNestedTensorFlowReference() throws ParseException {
+ public void testNestedTensorFlowReference() {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"5 + sum(tensorflow('mnist_softmax/saved'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
- assertConstant("layer_Variable_1", search, Optional.of(10L));
- assertConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable", search, Optional.of(7840L));
}
@Test
- public void testTensorFlowReferenceSpecifyingSignature() throws ParseException {
+ public void testTensorFlowReferenceSpecifyingSignature() {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved', 'serving_default')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
- public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException {
+ public void testTensorFlowReferenceSpecifyingSignatureAndOutput() {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved', 'serving_default', 'y')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -168,7 +169,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testTensorFlowReferenceWithWrongMacroType() throws ParseException {
+ public void testTensorFlowReferenceWithWrongMacroType() {
try {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)",
"tensorflow('mnist_softmax/saved')");
@@ -185,7 +186,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException {
+ public void testTensorFlowReferenceSpecifyingNonExistingSignature() {
try {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved', 'serving_defaultz')");
@@ -201,7 +202,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException {
+ public void testTensorFlowReferenceSpecifyingNonExistingOutput() {
try {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved', 'serving_default', 'x')");
@@ -217,12 +218,13 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testImportingFromStoredExpressions() throws ParseException, IOException {
+ public void testImportingFromStoredExpressions() throws IOException {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertConstant("layer_Variable_1", search, Optional.of(10L));
- assertConstant("layer_Variable", search, Optional.of(7840L));
+
+ assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable", search, Optional.of(7840L));
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -235,24 +237,64 @@ public class RankingExpressionWithTensorFlowTestCase {
"tensorflow('mnist_softmax/saved')",
null,
null,
+ "Placeholder",
storedApplication);
searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
// Verify that the constants exists, but don't verify the content as we are not
// simulating file distribution in this test
- assertConstant("layer_Variable_1", searchFromStored, Optional.empty());
- assertConstant("layer_Variable", searchFromStored, Optional.empty());
+ assertLargeConstant("layer_Variable_1", searchFromStored, Optional.empty());
+ assertLargeConstant("layer_Variable", searchFromStored, Optional.empty());
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
+ }
+ }
+
+ @Test
+ public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
+ final String expression = "join(rename(reduce(join(map(join(rename(reduce(join(join(join(constant(\"dnn_hidden1_mul_x\"), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), rename(constant(\"dnn_outputs_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_outputs_bias\"), d0, d1), f(a,b)(a + b))";
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "tensorflow('mnist/saved')",
+ null,
+ null,
+ "input",
+ application);
+ search.assertFirstPhaseExpression(expression, "my_profile");
+ assertSmallConstant("dnn_hidden1_mul_x", TensorType.empty, search);
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
+ try {
+ storedApplicationDirectory.toFile().mkdirs();
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
+ RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "tensorflow('mnist/saved')",
+ null,
+ null,
+ "input",
+ storedApplication);
+ searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
+ assertSmallConstant("dnn_hidden1_mul_x", TensorType.empty, search);
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
}
+ }
+ private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) {
+ Value value = search.rankProfile("my_profile").getConstants().get(name);
+ assertNotNull(value);
+ assertEquals(type, value.type());
}
/**
* Verifies that the constant with the given name exists, and - only if an expected size is given -
* that the content of the constant is available and has the expected size.
*/
- private void assertConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) {
+ private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) {
try {
Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf");
RankingConstant rankingConstant = search.search().getRankingConstants().get(name);
@@ -274,13 +316,13 @@ public class RankingExpressionWithTensorFlowTestCase {
}
private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) {
- return fixtureWith(placeholderExpression, firstPhaseExpression, null, null,
+ return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder",
new StoringApplicationPackage(applicationDir));
}
private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression,
String constant, String field) {
- return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field,
+ return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder",
new StoringApplicationPackage(applicationDir));
}
@@ -288,13 +330,14 @@ public class RankingExpressionWithTensorFlowTestCase {
String firstPhaseExpression,
String constant,
String field,
+ String macroName,
StoringApplicationPackage application) {
try {
return new RankProfileSearchFixture(
application,
application.getQueryProfiles(),
" rank-profile my_profile {\n" +
- " macro Placeholder() {\n" +
+ " macro " + macroName + "() {\n" +
" expression: " + placeholderExpression +
" }\n" +
" first-phase {\n" +