diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-09 16:07:43 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-09 16:07:43 +0100 |
commit | dc0f70fac9167acf487453daf565636c675934df (patch) | |
tree | aaccfae7aaf4a48e35655a66c75ea57412ede6a6 /config-model/src/test/java/com/yahoo/searchdefinition | |
parent | fa9fe82c82d6a562e3ae02b9577f536a16c72c92 (diff) |
Basic TensorFlow integration
This wil replace any occurrence of tensorflow(...)
in ranking expressions with the corresponding translated expression.
It is functional but these tings are outstanding
- Propagate warnings
- Import a model just once even if referred multiple times
- Add constants as tensor files rather than config
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition')
3 files changed, 184 insertions, 49 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java new file mode 100644 index 00000000000..e71a627d7db --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -0,0 +1,58 @@ +package com.yahoo.searchdefinition.processing; + +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.SearchBuilder; +import com.yahoo.searchdefinition.parser.ParseException; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * Helper class for setting up and asserting over a Search instance with a rank profile given literally + * in the search definition language. + * + * @author geirst + */ +class RankProfileSearchFixture { + + private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + private Search search; + + RankProfileSearchFixture(String rankProfiles) throws ParseException { + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + String sdContent = "search test {\n" + + " document test {\n" + + " }\n" + + rankProfiles + + "\n" + + "}"; + builder.importString(sdContent); + builder.build(); + search = builder.getSearch(); + } + + public void assertFirstPhaseExpression(String expExpression, String rankProfile) { + assertEquals(expExpression, rankProfile(rankProfile).getFirstPhaseRanking().getRoot().toString()); + } + + public void assertSecondPhaseExpression(String expExpression, String rankProfile) { + assertEquals(expExpression, rankProfile(rankProfile).getSecondPhaseRanking().getRoot().toString()); + } + + public void assertRankProperty(String expValue, String name, String rankProfile) { + List<RankProfile.RankProperty> rankPropertyList = rankProfile(rankProfile).getRankPropertyMap().get(name); + assertEquals(1, rankPropertyList.size()); + assertEquals(expValue, rankPropertyList.get(0).getValue()); + } + + public void assertMacro(String expExpression, String macroName, String rankProfile) { + assertEquals(expExpression, rankProfile(rankProfile).getMacros().get(macroName).getRankingExpression().getRoot().toString()); + } + + public RankProfile rankProfile(String rankProfile) { + return rankProfileRegistry.getRankProfile(search, rankProfile).compile(); + } +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java new file mode 100644 index 00000000000..5ad85ac872c --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -0,0 +1,119 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.yolean.Exceptions; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; + +/** + * @author bratseth + */ +public class RankingExpressionWithTensorFlowTestCase { + + private final String modelDirectory = "src/test/integration/tensorflow/mnist_softmax/saved"; + private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(Variable_1), d0, d1), f(a,b)(a + b))"; + + @Test + public void testMinimalTensorFlowReference() throws ParseException { + RankProfileSearchFixture search = new RankProfileSearchFixture( + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('" + modelDirectory + "')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + + Tensor variable_1 = search.rankProfile("my_profile").getConstants().get("Variable_1").asTensor(); + assertNotNull("Variable_1 is imported", variable_1); + assertEquals(10, variable_1.size()); + + Tensor variable = search.rankProfile("my_profile").getConstants().get("Variable").asTensor(); + assertNotNull("Variable is imported", variable); + assertEquals(7840, variable.size()); + } + + @Test + public void testNestedTensorFlowReference() throws ParseException { + RankProfileSearchFixture search = new RankProfileSearchFixture( + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: 5 + sum(tensorflow('" + modelDirectory + "'))" + + " }\n" + + " }"); + search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); + + Tensor variable_1 = search.rankProfile("my_profile").getConstants().get("Variable_1").asTensor(); + assertNotNull("Variable_1 is imported", variable_1); + assertEquals(10, variable_1.size()); + + Tensor variable = search.rankProfile("my_profile").getConstants().get("Variable").asTensor(); + assertNotNull("Variable is imported", variable); + assertEquals(7840, variable.size()); + } + + @Test + public void testTensorFlowReferenceSpecifyingSignature() throws ParseException { + RankProfileSearchFixture search = new RankProfileSearchFixture( + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('" + modelDirectory + "', 'serving_default')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + } + + @Test + public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException { + RankProfileSearchFixture search = new RankProfileSearchFixture( + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'y')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + } + + @Test + public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException { + try { + RankProfileSearchFixture search = new RankProfileSearchFixture( + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('" + modelDirectory + "', 'serving_defaultz')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + fail("Expecting exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Rank profile 'my_profile' is invalid: Could not import tensorflow model from tensorflow('" + + modelDirectory + "','serving_defaultz'): Model does not have the specified signatures 'serving_defaultz'", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException { + try { + RankProfileSearchFixture search = new RankProfileSearchFixture( + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'x')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + fail("Expecting exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Rank profile 'my_profile' is invalid: Could not import tensorflow model from tensorflow('" + + modelDirectory + "','serving_default','x'): Model does not have the specified outputs 'x'", + Exceptions.toMessageString(expected)); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java index 4dcf7523fd0..dba2bdbfbbf 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java @@ -1,61 +1,19 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; -import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchdefinition.RankProfileRegistry; -import com.yahoo.searchdefinition.Search; -import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import java.util.List; - -import static org.junit.Assert.assertEquals; - /** * @author geirst */ public class RankingExpressionWithTensorTestCase { - private static class SearchFixture { - RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - Search search; - SearchFixture(String rankProfiles) throws ParseException { - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); - String sdContent = "search test {\n" + - " document test {\n" + - " }\n" + - rankProfiles + - "\n" + - "}"; - builder.importString(sdContent); - builder.build(); - search = builder.getSearch(); - } - public void assertFirstPhaseExpression(String expExpression, String rankProfile) { - assertEquals(expExpression, getRankProfile(rankProfile).getFirstPhaseRanking().getRoot().toString()); - } - public void assertSecondPhaseExpression(String expExpression, String rankProfile) { - assertEquals(expExpression, getRankProfile(rankProfile).getSecondPhaseRanking().getRoot().toString()); - } - public void assertRankProperty(String expValue, String name, String rankProfile) { - List<RankProfile.RankProperty> rankPropertyList = getRankProfile(rankProfile).getRankPropertyMap().get(name); - assertEquals(1, rankPropertyList.size()); - assertEquals(expValue, rankPropertyList.get(0).getValue()); - } - public void assertMacro(String expExpression, String macroName, String rankProfile) { - assertEquals(expExpression, getRankProfile(rankProfile).getMacros().get(macroName).getRankingExpression().getRoot().toString()); - } - private RankProfile getRankProfile(String rankProfile) { - return rankProfileRegistry.getRankProfile(search, rankProfile).compile(); - } - } - @Test public void requireThatSingleLineConstantTensorAndTypeCanBeParsed() throws ParseException { - SearchFixture f = new SearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " first-phase {\n" + " expression: sum(my_tensor)\n" + @@ -74,7 +32,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatMultiLineConstantTensorAndTypeCanBeParsed() throws ParseException { - SearchFixture f = new SearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " first-phase {\n" + " expression: sum(my_tensor)\n" + @@ -96,7 +54,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatConstantTensorsCanBeUsedInSecondPhaseExpression() throws ParseException { - SearchFixture f = new SearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " second-phase {\n" + " expression: sum(my_tensor)\n" + @@ -114,7 +72,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatConstantTensorsCanBeUsedInInheritedRankProfile() throws ParseException { - SearchFixture f = new SearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile parent {\n" + " constants {\n" + " my_tensor {\n" + @@ -134,7 +92,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatConstantTensorsCanBeUsedInMacro() throws ParseException { - SearchFixture f = new SearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " macro my_macro() {\n" + " expression: sum(my_tensor)\n" + @@ -156,7 +114,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatCombinationOfConstantTensorsAndConstantValuesCanBeUsed() throws ParseException { - SearchFixture f = new SearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " first-phase {\n" + " expression: my_number_1 + sum(my_tensor) + my_number_2\n" + @@ -181,7 +139,7 @@ public class RankingExpressionWithTensorTestCase { public void requireThatInvalidTensorTypeSpecThrowsException() throws ParseException { exception.expect(IllegalArgumentException.class); exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: Failed parsing element 'x' in type spec 'tensor(x)'"); - new SearchFixture( + new RankProfileSearchFixture( " rank-profile my_profile {\n" + " constants {\n" + " my_tensor {\n" + |