diff options
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java | 43 |
1 files changed, 35 insertions, 8 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 7228af2b0de..29859817736 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -6,6 +6,7 @@ import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; +import com.yahoo.io.reader.NamedReader; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankingConstant; @@ -22,10 +23,12 @@ import java.io.BufferedInputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; +import java.io.FileReader; import java.io.IOException; import java.io.InputStream; import java.io.Reader; import java.io.UncheckedIOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Iterator; @@ -156,6 +159,7 @@ public class RankingExpressionWithTensorFlowTestCase { " expression: tensorflow('mnist_softmax/saved')" + " }\n" + " }"); + search.compileRankProfile("my_profile"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } @@ -196,7 +200,9 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved','serving_defaultz'): " + - "Model does not have the specified signature 'serving_defaultz'", + "No expressions available in model 'mnist_softmax_saved'", +// "No expressions named 'serving_defaultz' in model 'mnist_softmax/saved'. "+ +// "Available expressions: mnist_softmax_saved.serving_default.y", Exceptions.toMessageString(expected)); } } @@ -212,7 +218,9 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved','serving_default','x'): " + - "Model does not have the specified output 'x'", + "No expressions available in model 'mnist_softmax_saved'", +// "No expression 'mnist_softmax_saved.serving_default.x' in model 'mnist_softmax/saved'. " + +// "Available expressions: mnist_softmax_saved.serving_default.y", Exceptions.toMessageString(expected)); } } @@ -268,7 +276,8 @@ public class RankingExpressionWithTensorFlowTestCase { String vespaExpressionWithoutConstant = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_saved_layer_Variable_read, f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))"; - RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); + RankProfileSearchFixture search = fixtureWithUncompiled(rankProfile, new StoringApplicationPackage(applicationDir)); + search.compileRankProfile("my_profile"); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", @@ -282,7 +291,8 @@ public class RankingExpressionWithTensorFlowTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication); + RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfile, storedApplication); + searchFromStored.compileRankProfile("my_profile"); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", searchFromStored.search().getRankingConstants().get("mnist_softmax_saved_layer_Variable_read")); @@ -297,7 +307,7 @@ public class RankingExpressionWithTensorFlowTestCase { public void testTensorFlowReduceBatchDimension() { final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "tensorflow('mnist_softmax/saved')"); + "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(expression, "my_profile"); assertLargeConstant("mnist_softmax_saved_layer_Variable_1_read", search, Optional.of(10L)); assertLargeConstant("mnist_softmax_saved_layer_Variable_read", search, Optional.of(7840L)); @@ -362,7 +372,7 @@ public class RankingExpressionWithTensorFlowTestCase { } private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) { - Value value = search.rankProfile("my_profile").getConstants().get(name); + Value value = search.compiledRankProfile("my_profile").getConstants().get(name); assertNotNull(value); assertEquals(type, value.type()); } @@ -410,7 +420,7 @@ public class RankingExpressionWithTensorFlowTestCase { String macroName, StoringApplicationPackage application) { try { - return new RankProfileSearchFixture( + RankProfileSearchFixture fixture = new RankProfileSearchFixture( application, application.getQueryProfiles(), " rank-profile my_profile {\n" + @@ -423,13 +433,15 @@ public class RankingExpressionWithTensorFlowTestCase { " }", constant, field); + fixture.compileRankProfile("my_profile"); + return fixture; } catch (ParseException e) { throw new IllegalArgumentException(e); } } - private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) { + private RankProfileSearchFixture fixtureWithUncompiled(String rankProfile, StoringApplicationPackage application) { try { return new RankProfileSearchFixture(application, application.getQueryProfiles(), rankProfile, null, null); @@ -463,6 +475,21 @@ public class RankingExpressionWithTensorFlowTestCase { return new StoringApplicationPackageFile(file, Path.fromString(root.toString())); } + @Override + public List<NamedReader> getFiles(Path path, String suffix) { + List<NamedReader> readers = new ArrayList<>(); + for (File file : getFileReference(path).listFiles()) { + if ( ! file.getName().endsWith(suffix)) continue; + try { + readers.add(new NamedReader(file.getName(), new FileReader(file))); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + return readers; + } + } static class StoringApplicationPackageFile extends ApplicationFile { |