summaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/searchdefinition
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-09 16:07:43 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-09 16:07:43 +0100
commitdc0f70fac9167acf487453daf565636c675934df (patch)
treeaaccfae7aaf4a48e35655a66c75ea57412ede6a6 /config-model/src/test/java/com/yahoo/searchdefinition
parentfa9fe82c82d6a562e3ae02b9577f536a16c72c92 (diff)
Basic TensorFlow integration
This wil replace any occurrence of tensorflow(...) in ranking expressions with the corresponding translated expression. It is functional but these tings are outstanding - Propagate warnings - Import a model just once even if referred multiple times - Add constants as tensor files rather than config
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java58
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java119
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java56
3 files changed, 184 insertions, 49 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
new file mode 100644
index 00000000000..e71a627d7db
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -0,0 +1,58 @@
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.SearchBuilder;
+import com.yahoo.searchdefinition.parser.ParseException;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Helper class for setting up and asserting over a Search instance with a rank profile given literally
+ * in the search definition language.
+ *
+ * @author geirst
+ */
+class RankProfileSearchFixture {
+
+ private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ private Search search;
+
+ RankProfileSearchFixture(String rankProfiles) throws ParseException {
+ SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
+ String sdContent = "search test {\n" +
+ " document test {\n" +
+ " }\n" +
+ rankProfiles +
+ "\n" +
+ "}";
+ builder.importString(sdContent);
+ builder.build();
+ search = builder.getSearch();
+ }
+
+ public void assertFirstPhaseExpression(String expExpression, String rankProfile) {
+ assertEquals(expExpression, rankProfile(rankProfile).getFirstPhaseRanking().getRoot().toString());
+ }
+
+ public void assertSecondPhaseExpression(String expExpression, String rankProfile) {
+ assertEquals(expExpression, rankProfile(rankProfile).getSecondPhaseRanking().getRoot().toString());
+ }
+
+ public void assertRankProperty(String expValue, String name, String rankProfile) {
+ List<RankProfile.RankProperty> rankPropertyList = rankProfile(rankProfile).getRankPropertyMap().get(name);
+ assertEquals(1, rankPropertyList.size());
+ assertEquals(expValue, rankPropertyList.get(0).getValue());
+ }
+
+ public void assertMacro(String expExpression, String macroName, String rankProfile) {
+ assertEquals(expExpression, rankProfile(rankProfile).getMacros().get(macroName).getRankingExpression().getRoot().toString());
+ }
+
+ public RankProfile rankProfile(String rankProfile) {
+ return rankProfileRegistry.getRankProfile(search, rankProfile).compile();
+ }
+}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
new file mode 100644
index 00000000000..5ad85ac872c
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -0,0 +1,119 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.yolean.Exceptions;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.fail;
+
+/**
+ * @author bratseth
+ */
+public class RankingExpressionWithTensorFlowTestCase {
+
+ private final String modelDirectory = "src/test/integration/tensorflow/mnist_softmax/saved";
+ private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(Variable_1), d0, d1), f(a,b)(a + b))";
+
+ @Test
+ public void testMinimalTensorFlowReference() throws ParseException {
+ RankProfileSearchFixture search = new RankProfileSearchFixture(
+ " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: tensorflow('" + modelDirectory + "')" +
+ " }\n" +
+ " }");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+
+ Tensor variable_1 = search.rankProfile("my_profile").getConstants().get("Variable_1").asTensor();
+ assertNotNull("Variable_1 is imported", variable_1);
+ assertEquals(10, variable_1.size());
+
+ Tensor variable = search.rankProfile("my_profile").getConstants().get("Variable").asTensor();
+ assertNotNull("Variable is imported", variable);
+ assertEquals(7840, variable.size());
+ }
+
+ @Test
+ public void testNestedTensorFlowReference() throws ParseException {
+ RankProfileSearchFixture search = new RankProfileSearchFixture(
+ " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: 5 + sum(tensorflow('" + modelDirectory + "'))" +
+ " }\n" +
+ " }");
+ search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
+
+ Tensor variable_1 = search.rankProfile("my_profile").getConstants().get("Variable_1").asTensor();
+ assertNotNull("Variable_1 is imported", variable_1);
+ assertEquals(10, variable_1.size());
+
+ Tensor variable = search.rankProfile("my_profile").getConstants().get("Variable").asTensor();
+ assertNotNull("Variable is imported", variable);
+ assertEquals(7840, variable.size());
+ }
+
+ @Test
+ public void testTensorFlowReferenceSpecifyingSignature() throws ParseException {
+ RankProfileSearchFixture search = new RankProfileSearchFixture(
+ " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: tensorflow('" + modelDirectory + "', 'serving_default')" +
+ " }\n" +
+ " }");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ }
+
+ @Test
+ public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException {
+ RankProfileSearchFixture search = new RankProfileSearchFixture(
+ " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'y')" +
+ " }\n" +
+ " }");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ }
+
+ @Test
+ public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException {
+ try {
+ RankProfileSearchFixture search = new RankProfileSearchFixture(
+ " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: tensorflow('" + modelDirectory + "', 'serving_defaultz')" +
+ " }\n" +
+ " }");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ fail("Expecting exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Rank profile 'my_profile' is invalid: Could not import tensorflow model from tensorflow('" +
+ modelDirectory + "','serving_defaultz'): Model does not have the specified signatures 'serving_defaultz'",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException {
+ try {
+ RankProfileSearchFixture search = new RankProfileSearchFixture(
+ " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'x')" +
+ " }\n" +
+ " }");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ fail("Expecting exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Rank profile 'my_profile' is invalid: Could not import tensorflow model from tensorflow('" +
+ modelDirectory + "','serving_default','x'): Model does not have the specified outputs 'x'",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
index 4dcf7523fd0..dba2bdbfbbf 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
@@ -1,61 +1,19 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.processing;
-import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchdefinition.RankProfileRegistry;
-import com.yahoo.searchdefinition.Search;
-import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.parser.ParseException;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
-import java.util.List;
-
-import static org.junit.Assert.assertEquals;
-
/**
* @author geirst
*/
public class RankingExpressionWithTensorTestCase {
- private static class SearchFixture {
- RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
- Search search;
- SearchFixture(String rankProfiles) throws ParseException {
- SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
- String sdContent = "search test {\n" +
- " document test {\n" +
- " }\n" +
- rankProfiles +
- "\n" +
- "}";
- builder.importString(sdContent);
- builder.build();
- search = builder.getSearch();
- }
- public void assertFirstPhaseExpression(String expExpression, String rankProfile) {
- assertEquals(expExpression, getRankProfile(rankProfile).getFirstPhaseRanking().getRoot().toString());
- }
- public void assertSecondPhaseExpression(String expExpression, String rankProfile) {
- assertEquals(expExpression, getRankProfile(rankProfile).getSecondPhaseRanking().getRoot().toString());
- }
- public void assertRankProperty(String expValue, String name, String rankProfile) {
- List<RankProfile.RankProperty> rankPropertyList = getRankProfile(rankProfile).getRankPropertyMap().get(name);
- assertEquals(1, rankPropertyList.size());
- assertEquals(expValue, rankPropertyList.get(0).getValue());
- }
- public void assertMacro(String expExpression, String macroName, String rankProfile) {
- assertEquals(expExpression, getRankProfile(rankProfile).getMacros().get(macroName).getRankingExpression().getRoot().toString());
- }
- private RankProfile getRankProfile(String rankProfile) {
- return rankProfileRegistry.getRankProfile(search, rankProfile).compile();
- }
- }
-
@Test
public void requireThatSingleLineConstantTensorAndTypeCanBeParsed() throws ParseException {
- SearchFixture f = new SearchFixture(
+ RankProfileSearchFixture f = new RankProfileSearchFixture(
" rank-profile my_profile {\n" +
" first-phase {\n" +
" expression: sum(my_tensor)\n" +
@@ -74,7 +32,7 @@ public class RankingExpressionWithTensorTestCase {
@Test
public void requireThatMultiLineConstantTensorAndTypeCanBeParsed() throws ParseException {
- SearchFixture f = new SearchFixture(
+ RankProfileSearchFixture f = new RankProfileSearchFixture(
" rank-profile my_profile {\n" +
" first-phase {\n" +
" expression: sum(my_tensor)\n" +
@@ -96,7 +54,7 @@ public class RankingExpressionWithTensorTestCase {
@Test
public void requireThatConstantTensorsCanBeUsedInSecondPhaseExpression() throws ParseException {
- SearchFixture f = new SearchFixture(
+ RankProfileSearchFixture f = new RankProfileSearchFixture(
" rank-profile my_profile {\n" +
" second-phase {\n" +
" expression: sum(my_tensor)\n" +
@@ -114,7 +72,7 @@ public class RankingExpressionWithTensorTestCase {
@Test
public void requireThatConstantTensorsCanBeUsedInInheritedRankProfile() throws ParseException {
- SearchFixture f = new SearchFixture(
+ RankProfileSearchFixture f = new RankProfileSearchFixture(
" rank-profile parent {\n" +
" constants {\n" +
" my_tensor {\n" +
@@ -134,7 +92,7 @@ public class RankingExpressionWithTensorTestCase {
@Test
public void requireThatConstantTensorsCanBeUsedInMacro() throws ParseException {
- SearchFixture f = new SearchFixture(
+ RankProfileSearchFixture f = new RankProfileSearchFixture(
" rank-profile my_profile {\n" +
" macro my_macro() {\n" +
" expression: sum(my_tensor)\n" +
@@ -156,7 +114,7 @@ public class RankingExpressionWithTensorTestCase {
@Test
public void requireThatCombinationOfConstantTensorsAndConstantValuesCanBeUsed() throws ParseException {
- SearchFixture f = new SearchFixture(
+ RankProfileSearchFixture f = new RankProfileSearchFixture(
" rank-profile my_profile {\n" +
" first-phase {\n" +
" expression: my_number_1 + sum(my_tensor) + my_number_2\n" +
@@ -181,7 +139,7 @@ public class RankingExpressionWithTensorTestCase {
public void requireThatInvalidTensorTypeSpecThrowsException() throws ParseException {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: Failed parsing element 'x' in type spec 'tensor(x)'");
- new SearchFixture(
+ new RankProfileSearchFixture(
" rank-profile my_profile {\n" +
" constants {\n" +
" my_tensor {\n" +