diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2018-01-17 13:51:14 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-17 13:51:14 +0100 |
commit | fd26b36e3607df463b35e856b37d24b5e3514fb7 (patch) | |
tree | 403836969d050736403f6512a455198a2c63edad /config-model/src/test/java/com | |
parent | ceec6d572c06ff812715c97d2c35383c48402f24 (diff) | |
parent | c84b8f952ef5857aa44fad479551eda1f3a4e106 (diff) |
Merge pull request #4692 from vespa-engine/bratseth/store-converted-expressions-in-zk
Bratseth/store converted expressions in zk
Diffstat (limited to 'config-model/src/test/java/com')
2 files changed, 242 insertions, 33 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index ff53fdafacf..7c749608e1f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -1,5 +1,7 @@ package com.yahoo.searchdefinition.processing; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; @@ -22,7 +24,11 @@ class RankProfileSearchFixture { private Search search; RankProfileSearchFixture(String rankProfiles) throws ParseException { - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + this(MockApplicationPackage.createEmpty(), rankProfiles); + } + + RankProfileSearchFixture(ApplicationPackage applicationpackage, String rankProfiles) throws ParseException { + SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry); String sdContent = "search test {\n" + " document test {\n" + " }\n" + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 31f7511155b..0354173f365 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -1,24 +1,36 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.yolean.Exceptions; import org.junit.After; import org.junit.Test; +import java.io.BufferedInputStream; import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; import java.io.IOException; +import java.io.InputStream; +import java.io.Reader; import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -27,47 +39,52 @@ import static org.junit.Assert.fail; */ public class RankingExpressionWithTensorFlowTestCase { - // The "../" is to escape the "models/" element prepended to the path - private final String modelDirectory = "../src/test/integration/tensorflow/mnist_softmax/saved"; + private final Path applicationDir = Path.fromString("src/test/integration/tensorflow/"); private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(Variable_1), d0, d1), f(a,b)(a + b))"; @After public void removeGeneratedConstantTensorFiles() { - IOUtils.recursiveDeleteDir(new File(modelDirectory.substring(3), "converted_variables")); + IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); } @Test public void testMinimalTensorFlowReference() throws ParseException { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "')" + + " expression: tensorflow('mnist_softmax/saved')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertConstant(10, "Variable_1", search); - assertConstant(7840, "Variable", search); + assertConstant("Variable_1", search, Optional.of(10L)); + assertConstant("Variable", search, Optional.of(7840L)); } @Test public void testNestedTensorFlowReference() throws ParseException { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: 5 + sum(tensorflow('" + modelDirectory + "'))" + + " expression: 5 + sum(tensorflow('mnist_softmax/saved'))" + " }\n" + " }"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - assertConstant(10, "Variable_1", search); - assertConstant(7840, "Variable", search); + assertConstant("Variable_1", search, Optional.of(10L)); + assertConstant("Variable", search, Optional.of(7840L)); } @Test public void testTensorFlowReferenceSpecifyingSignature() throws ParseException { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "', 'serving_default')" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -75,10 +92,12 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'y')" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'y')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -87,18 +106,21 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException { try { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "', 'serving_defaultz')" + + " expression: tensorflow('mnist_softmax/saved', 'serving_defaultz')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { - assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from tensorflow('" + - modelDirectory + "','serving_defaultz'): Model does not have the specified signature 'serving_defaultz'", + assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + + "tensorflow('mnist_softmax/saved','serving_defaultz'): " + + "Model does not have the specified signature 'serving_defaultz'", Exceptions.toMessageString(expected)); } } @@ -106,36 +128,83 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException { try { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'x')" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'x')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { - assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from tensorflow('" + - modelDirectory + "','serving_default','x'): Model does not have the specified output 'x'", + assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + + "tensorflow('mnist_softmax/saved','serving_default','x'): " + + "Model does not have the specified output 'x'", Exceptions.toMessageString(expected)); } } - private void assertConstant(int expectedSize, String name, RankProfileSearchFixture search) { + @Test + public void testImportingFromStoredExpressions() throws ParseException, IOException { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); + RankProfileSearchFixture search = new RankProfileSearchFixture( + application, + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertConstant("Variable_1", search, Optional.of(10L)); + assertConstant("Variable", search, Optional.of(7840L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = new RankProfileSearchFixture( + storedApplication, + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + + " }\n" + + " }"); + searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); + // Verify that the constants exists, but don't verify the content as we are not + // simulating file distribution in this test + assertConstant("Variable_1", searchFromStored, Optional.empty()); + assertConstant("Variable", searchFromStored, Optional.empty()); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + + } + + /** + * Verifies that the constant with the given name exists, and - only if an expected size is given - + * that the content of the constant is available and has the expected size. + */ + private void assertConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { try { - TensorValue constant = (TensorValue)search.rankProfile("my_profile").getConstants().get(name); // Old way. TODO: Remove - if (constant == null) { // New way - File constantFile = new File(modelDirectory.substring(3) + "/converted_variables", name + ".tbf"); - RankingConstant rankingConstant = search.search().getRankingConstants().get(name); - assertEquals(name, rankingConstant.getName()); - assertEquals(constantFile.getAbsolutePath(), rankingConstant.getFileName()); - assertTrue("Constant file has been written", constantFile.exists()); - Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantFile))); - assertEquals(expectedSize, deserializedConstant.size()); - } else { // Old way. TODO: Remove - assertNotNull(name + " is imported", constant); - assertEquals(expectedSize, constant.asTensor().size()); + Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf"); + RankingConstant rankingConstant = search.search().getRankingConstants().get(name); + assertEquals(name, rankingConstant.getName()); + assertEquals(constantApplicationPackagePath.toString(), rankingConstant.getFileName()); + + if (expectedSize.isPresent()) { + Path constantPath = applicationDir.append(constantApplicationPackagePath); + assertTrue("Constant file '" + constantPath + "' has been written", + constantPath.toFile().exists()); + Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); + assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); } } catch (IOException e) { @@ -143,4 +212,138 @@ public class RankingExpressionWithTensorFlowTestCase { } } + private static class StoringApplicationPackage extends MockApplicationPackage { + + private final File root; + + StoringApplicationPackage(Path applicationPackageWritableRoot) { + this(applicationPackageWritableRoot.toFile()); + } + + StoringApplicationPackage(File applicationPackageWritableRoot) { + super(null, null, Collections.emptyList(), null, + null, null, false); + this.root = applicationPackageWritableRoot; + } + + @Override + public File getFileReference(Path path) { + return Path.fromString(root.toString()).append(path).toFile(); + } + + @Override + public ApplicationFile getFile(Path file) { + return new StoringApplicationPackageFile(file, Path.fromString(root.toString())); + } + + } + + private static class StoringApplicationPackageFile extends ApplicationFile { + + /** The path to the application package root */ + private final Path root; + + /** The File pointing to the actual file represented by this */ + private final File file; + + StoringApplicationPackageFile(Path filePath, Path applicationPackagePath) { + super(filePath); + this.root = applicationPackagePath; + file = applicationPackagePath.append(filePath).toFile(); + } + + @Override + public boolean isDirectory() { + return file.isDirectory(); + } + + @Override + public boolean exists() { + return file.exists(); + } + + @Override + public Reader createReader() throws FileNotFoundException { + try { + if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist"); + return IOUtils.createReader(file, "UTF-8"); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public InputStream createInputStream() throws FileNotFoundException { + try { + if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist"); + return new BufferedInputStream(new FileInputStream(file)); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public ApplicationFile createDirectory() { + file.mkdirs(); + return this; + } + + @Override + public ApplicationFile writeFile(Reader input) { + try { + IOUtils.writeFile(file, IOUtils.readAll(input), false); + return this; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public List<ApplicationFile> listFiles(PathFilter filter) { + if ( ! isDirectory()) return Collections.emptyList(); + return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString()))) + .map(f -> new StoringApplicationPackageFile(asApplicationRelativePath(f), + root)) + .collect(Collectors.toList()); + } + + @Override + public ApplicationFile delete() { + file.delete(); + return this; + } + + @Override + public MetaData getMetaData() { + throw new UnsupportedOperationException(); + } + + @Override + public int compareTo(ApplicationFile other) { + return this.getPath().getName().compareTo((other).getPath().getName()); + } + + /** Strips the application package root path prefix from the path of the given file */ + private Path asApplicationRelativePath(File file) { + Path path = Path.fromString(file.toString()); + + Iterator<String> pathIterator = path.iterator(); + // Skip the path elements this shares with the root + for (Iterator<String> rootIterator = root.iterator(); rootIterator.hasNext(); ) { + String rootElement = rootIterator.next(); + String pathElement = pathIterator.next(); + if ( ! rootElement.equals(pathElement)) throw new RuntimeException("Assumption broken"); + } + // Build a path from the remaining + Path relative = Path.fromString(""); + while (pathIterator.hasNext()) + relative = relative.append(pathIterator.next()); + return relative; + } + + } + } |