diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-21 15:25:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-21 15:25:14 +0200 |
commit | 326075a15aa345c4ee00afc17357c62d07f06cdc (patch) | |
tree | ae5455accba92f2fd1b3392aa68d21e13175016c /config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java | |
parent | d914bbe09f5b6140f9becc474790fe21ac5fac00 (diff) | |
parent | 90ab7315e26b3847f77c4b65755906490da55cf6 (diff) |
Merge pull request #6640 from vespa-engine/bratseth/generate-rank-profiles-for-all-models-part-2-5
Bratseth/generate rank profiles for all models part 2 5
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java | 93 |
1 files changed, 68 insertions, 25 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 7228af2b0de..9a96555bb78 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -6,6 +6,7 @@ import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; +import com.yahoo.io.reader.NamedReader; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankingConstant; @@ -22,10 +23,12 @@ import java.io.BufferedInputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; +import java.io.FileReader; import java.io.IOException; import java.io.InputStream; import java.io.Reader; import java.io.UncheckedIOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Iterator; @@ -156,6 +159,7 @@ public class RankingExpressionWithTensorFlowTestCase { " expression: tensorflow('mnist_softmax/saved')" + " }\n" + " }"); + search.compileRankProfile("my_profile"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } @@ -196,7 +200,9 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved','serving_defaultz'): " + - "Model does not have the specified signature 'serving_defaultz'", + "No expressions available in model 'mnist_softmax_saved'", +// "No expressions named 'serving_defaultz' in model 'mnist_softmax/saved'. "+ +// "Available expressions: mnist_softmax_saved.serving_default.y", Exceptions.toMessageString(expected)); } } @@ -212,7 +218,9 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved','serving_default','x'): " + - "Model does not have the specified output 'x'", + "No expressions available in model 'mnist_softmax_saved'", +// "No expression 'mnist_softmax_saved.serving_default.x' in model 'mnist_softmax/saved'. " + +// "Available expressions: mnist_softmax_saved.serving_default.y", Exceptions.toMessageString(expected)); } } @@ -251,8 +259,8 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testImportingFromStoredExpressionsWithMacroOverridingConstant() throws IOException { - String rankProfile = + public void testImportingFromStoredExpressionsWithMacroOverridingConstantAndInheritance() throws IOException { + String rankProfiles = " rank-profile my_profile {\n" + " macro Placeholder() {\n" + " expression: tensor(d0[2],d1[784])(0.0)\n" + @@ -263,13 +271,17 @@ public class RankingExpressionWithTensorFlowTestCase { " first-phase {\n" + " expression: tensorflow('mnist_softmax/saved')" + " }\n" + + " }" + + " rank-profile my_profile_child inherits my_profile {\n" + " }"; - String vespaExpressionWithoutConstant = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_saved_layer_Variable_read, f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))"; - RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); + RankProfileSearchFixture search = fixtureWithUncompiled(rankProfiles, new StoringApplicationPackage(applicationDir)); + search.compileRankProfile("my_profile"); + search.compileRankProfile("my_profile_child"); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by macro is not added", search.search().getRankingConstants().get("mnist_softmax_saved_layer_Variable_read")); @@ -282,8 +294,11 @@ public class RankingExpressionWithTensorFlowTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication); + RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfiles, storedApplication); + searchFromStored.compileRankProfile("my_profile"); + searchFromStored.compileRankProfile("my_profile_child"); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by macro is not added", searchFromStored.search().getRankingConstants().get("mnist_softmax_saved_layer_Variable_read")); assertLargeConstant("mnist_softmax_saved_layer_Variable_1_read", searchFromStored, Optional.of(10L)); @@ -297,7 +312,7 @@ public class RankingExpressionWithTensorFlowTestCase { public void testTensorFlowReduceBatchDimension() { final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "tensorflow('mnist_softmax/saved')"); + "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(expression, "my_profile"); assertLargeConstant("mnist_softmax_saved_layer_Variable_1_read", search, Optional.of(10L)); assertLargeConstant("mnist_softmax_saved_layer_Variable_read", search, Optional.of(7840L)); @@ -321,22 +336,33 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { + public void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException { + final String rankProfiles = + " rank-profile my_profile {\n" + + " macro input() {\n" + + " expression: tensor(d0[1],d1[784])(0.0)\n" + + " }\n" + + " first-phase {\n" + + " expression: tensorflow('mnist/saved')" + + " }\n" + + " }" + + " rank-profile my_profile_child inherits my_profile {\n" + + " }"; + final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"; final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "tensorflow('mnist/saved')", - null, - null, - "input", - application); + RankProfileSearchFixture search = fixtureWithUncompiled(rankProfiles, new StoringApplicationPackage(applicationDir)); + search.compileRankProfile("my_profile"); + search.compileRankProfile("my_profile_child"); search.assertFirstPhaseExpression(expression, "my_profile"); + search.assertFirstPhaseExpression(expression, "my_profile_child"); assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile_child"); search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); + search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child"); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -345,16 +371,16 @@ public class RankingExpressionWithTensorFlowTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "tensorflow('mnist/saved')", - null, - null, - "input", - storedApplication); + RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfiles, storedApplication); + searchFromStored.compileRankProfile("my_profile"); + searchFromStored.compileRankProfile("my_profile_child"); searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); + searchFromStored.assertFirstPhaseExpression(expression, "my_profile_child"); assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile_child"); searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); + searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child"); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -362,7 +388,7 @@ public class RankingExpressionWithTensorFlowTestCase { } private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) { - Value value = search.rankProfile("my_profile").getConstants().get(name); + Value value = search.compiledRankProfile("my_profile").getConstants().get(name); assertNotNull(value); assertEquals(type, value.type()); } @@ -410,7 +436,7 @@ public class RankingExpressionWithTensorFlowTestCase { String macroName, StoringApplicationPackage application) { try { - return new RankProfileSearchFixture( + RankProfileSearchFixture fixture = new RankProfileSearchFixture( application, application.getQueryProfiles(), " rank-profile my_profile {\n" + @@ -423,13 +449,15 @@ public class RankingExpressionWithTensorFlowTestCase { " }", constant, field); + fixture.compileRankProfile("my_profile"); + return fixture; } catch (ParseException e) { throw new IllegalArgumentException(e); } } - private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) { + private RankProfileSearchFixture fixtureWithUncompiled(String rankProfile, StoringApplicationPackage application) { try { return new RankProfileSearchFixture(application, application.getQueryProfiles(), rankProfile, null, null); @@ -463,6 +491,21 @@ public class RankingExpressionWithTensorFlowTestCase { return new StoringApplicationPackageFile(file, Path.fromString(root.toString())); } + @Override + public List<NamedReader> getFiles(Path path, String suffix) { + List<NamedReader> readers = new ArrayList<>(); + for (File file : getFileReference(path).listFiles()) { + if ( ! file.getName().endsWith(suffix)) continue; + try { + readers.add(new NamedReader(file.getName(), new FileReader(file))); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + return readers; + } + } static class StoringApplicationPackageFile extends ApplicationFile { |