diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-16 21:55:09 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-16 21:55:09 +0100 |
commit | ff593a0ce8745cbccf86eb538e705a63b94b55e1 (patch) | |
tree | 8582941ee9f19aee612094f235007d6ab447dc62 | |
parent | d9e17187fe49f662520d282c38e5cf779cbb8195 (diff) |
Access files through application package
15 files changed, 215 insertions, 112 deletions
diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java index 97322fc1c55..1f5c39e90e0 100644 --- a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java +++ b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java @@ -55,7 +55,7 @@ import static com.yahoo.text.Lowercase.toLowerCase; * Construct using {@link com.yahoo.config.model.application.provider.FilesApplicationPackage#fromFile(java.io.File)} or * {@link com.yahoo.config.model.application.provider.FilesApplicationPackage#fromFileWithDeployData(java.io.File, DeployData)}. * - * @author vegardh + * @author Vegard Havdal */ public class FilesApplicationPackage implements ApplicationPackage { @@ -97,15 +97,15 @@ public class FilesApplicationPackage implements ApplicationPackage { } /** Creates package from a local directory, typically deploy app */ - public static FilesApplicationPackage fromFileWithDeployData(File appDir, DeployData deployData, + public static FilesApplicationPackage fromFileWithDeployData(File appDir, DeployData deployData, boolean includeSourceFiles) { return new Builder(appDir).includeSourceFiles(includeSourceFiles).deployData(deployData).build(); } private static ApplicationMetaData metaDataFromDeployData(File appDir, DeployData deployData) { - return new ApplicationMetaData(deployData.getDeployedByUser(), deployData.getDeployedFromDir(), - deployData.getDeployTimestamp(), deployData.getApplicationName(), - computeCheckSum(appDir), deployData.getGeneration(), + return new ApplicationMetaData(deployData.getDeployedByUser(), deployData.getDeployedFromDir(), + deployData.getDeployTimestamp(), deployData.getApplicationName(), + computeCheckSum(appDir), deployData.getGeneration(), deployData.getCurrentlyActiveGeneration()); } @@ -385,9 +385,9 @@ public class FilesApplicationPackage implements ApplicationPackage { } } - /** + /** * Creates a reader for a config definition - * + * * @param defPath the path to the application package * @return the reader of this config definition */ @@ -456,10 +456,10 @@ public class FilesApplicationPackage implements ApplicationPackage { if (defs.containsKey(key)) { if (nv[0].contains(".")) { - log.log(LogLevel.INFO, "Two config definitions found for the same name and namespace: " + key + + log.log(LogLevel.INFO, "Two config definitions found for the same name and namespace: " + key + ". The file '" + def + "' will take precedence"); } else { - log.log(LogLevel.INFO, "Two config definitions found for the same name and namespace: " + key + + log.log(LogLevel.INFO, "Two config definitions found for the same name and namespace: " + key + ". Skipping '" + def + "', as it does not contain namespace in filename"); continue; // skip } @@ -704,7 +704,7 @@ public class FilesApplicationPackage implements ApplicationPackage { } } - + /** * Adds the given path to the digest, or does nothing if path is neither file nor dir * @param path path to add to message digest diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java index 83d12718b6a..9a7e1960696 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java @@ -135,7 +135,7 @@ public interface ApplicationPackage { */ List<NamedReader> getFiles(Path pathFromRoot, String suffix, boolean recurse); - /** Same as getFiles(pathFromRoot,suffix,false) */ + /** Same as getFiles(pathFromRoot, suffix, false) */ default List<NamedReader> getFiles(Path pathFromRoot, String suffix) { return getFiles(pathFromRoot,suffix,false); } diff --git a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java index ddee0be6e9c..271ec6958ec 100644 --- a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java +++ b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java @@ -37,8 +37,8 @@ public class MockApplicationPackage implements ApplicationPackage { private final Optional<String> validationOverrides; private final boolean failOnValidateXml; - private MockApplicationPackage(String hosts, String services, List<String> searchDefinitions, String searchDefinitionDir, - String deploymentSpec, String validationOverrides, boolean failOnValidateXml) { + protected MockApplicationPackage(String hosts, String services, List<String> searchDefinitions, String searchDefinitionDir, + String deploymentSpec, String validationOverrides, boolean failOnValidateXml) { this.hostsS = hosts; this.servicesS = services; this.searchDefinitions = searchDefinitions; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index c6b07b25bb4..df5697de0d5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -267,6 +267,8 @@ public class Search implements Serializable, ImmutableSearch { return sourceApplication.getRankingExpression(fileName); } + public ApplicationPackage sourceApplication() { return sourceApplication; } + /** * Returns a field defined in this search definition or one if its documents. Fields in this search definition takes * precedence over document fields having the same name diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 7cefa9d9187..6c1d104fc89 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -41,7 +41,7 @@ import java.util.Optional; public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { // TODO: Make system test work with this set to true, then remove the "true" path - private static final boolean constantsInConfig = false; + private static final boolean constantsInConfig = true; private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); @@ -62,49 +62,37 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if ( ! feature.getName().equals("tensorflow")) return feature; try { - FeatureArguments arguments = new FeatureArguments(feature.getArguments()); - if (arguments.modelDir().exists()) - return transformFromTensorFlowModel(arguments, context.rankProfile()); - else - return transformFromStoredConvertedModel(arguments); + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), + feature.getArguments()); + if (store.hasTensorFlowModels()) + return transformFromTensorFlowModel(store, context.rankProfile()); + else // is should have previously stored model information instead + return store.readConverted().getRoot(); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); } } - private ExpressionNode transformFromTensorFlowModel(FeatureArguments arguments, RankProfile rankProfile) { - TensorFlowModel model = importedModels.computeIfAbsent(arguments.modelPath(), - k -> tensorFlowImporter.importModel(arguments.modelDir().toString())); + private ExpressionNode transformFromTensorFlowModel(ModelStore store, RankProfile rankProfile) { + TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), + k -> tensorFlowImporter.importModel(store.tensorFlowModelDir())); // Find the specified expression - Signature signature = chooseSignature(model, arguments.signature()); - String output = chooseOutput(signature, arguments.output()); + Signature signature = chooseSignature(model, store.arguments().signature()); + String output = chooseOutput(signature, store.arguments().output()); RankingExpression expression = model.expressions().get(output); - writeConverted(arguments, expression); + store.writeConverted(expression); - // Add all constants (after finding outputs to fail faster when the output is not found) + // Add all constants (after finding outputs to fail faster when the output is not found) TODO: Remove the first path if (constantsInConfig) model.constants().forEach((k, v) -> rankProfile.addConstantTensor(k, new TensorValue(v))); else // correct way, disabled for now - model.constants().forEach((k, v) -> transformConstant(arguments, rankProfile, k, v)); + model.constants().forEach((k, v) -> transformConstant(store, rankProfile, k, v)); return expression.getRoot(); } - private ExpressionNode transformFromStoredConvertedModel(FeatureArguments arguments) { - File expressionFile = null; - try { - return new RankingExpression(IOUtils.readFile(arguments.expressionFile())).getRoot(); - } - catch (IOException e) { - throw new UncheckedIOException("Could not read " + expressionFile, e); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + expressionFile, e); - } - } - /** * Returns the specified, existing signature, or the only signature if none is specified. * Throws IllegalArgumentException in all other cases. @@ -158,38 +146,111 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - private void writeConverted(FeatureArguments arguments, RankingExpression expression) { - try { - IOUtils.writeFile(arguments.expressionFile(), expression.getRoot().toString(), false); + private void transformConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + if (profile.getSearch().getRankingConstants().containsKey(constantName)) return; + + Path constantPath = store.writeConstant(constantName, constantValue); + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); + } + + private String skippedOutputsDescription(TensorFlowModel.Signature signature) { + if (signature.skippedOutputs().isEmpty()) return ""; + StringBuilder b = new StringBuilder(": "); + signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); + return b.toString(); + } + + /** + * Provides read/write access to the correct directories of the application package given by the feature arguments + */ + private static class ModelStore { + + private final ApplicationPackage application; + private final FeatureArguments arguments; + + public ModelStore(ApplicationPackage application, Arguments arguments) { + this.application = application; + this.arguments = new FeatureArguments(arguments); + } + + public FeatureArguments arguments() { return arguments; } + + public boolean hasTensorFlowModels() { + try { + return application.getFileReference(ApplicationPackage.MODELS_DIR).exists(); + } + catch (UnsupportedOperationException e) { + return false; // No files -> no TensorFlow models + } } - catch (IOException e) { - throw new UncheckedIOException(e); + + /** + * Returns the directory which (if hasTensorFlowModels is true) + * contains the source model to use for these arguments + */ + public File tensorFlowModelDir() { + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); } - } - private void transformConstant(FeatureArguments arguments, RankProfile profile, String constantName, Tensor constantValue) { - try { - if (profile.getSearch().getRankingConstants().containsKey(constantName)) return; + /** + * Adds this expression to the application package, such that it can be read later. + */ + public void writeConverted(RankingExpression expression) { + try { + // We don't really need to store this as a file - we could keep it in memory in the application + // package until we write it to ZooKeeper. However, we need to write constants to the models_generated + // directory in any case (as they are distributed over file distribution), + // so we just reuse the same mechanism for expressions + Path expressionsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("expressions"); + createIfNeeded(expressionsPath); + IOUtils.writeFile(application.getFileReference(expressionsPath.append(arguments.expressionFileName())), + expression.getRoot().toString(), false); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } - if ( ! arguments.constantsDir().exists()) - if ( ! arguments.constantsDir().mkdir()) - throw new IOException("Could not create directory " + arguments.constantsDir()); + /** Reads the previously stored ranking expression for these arguments */ + public RankingExpression readConverted() { + // TODO: ZK integrate + Path expressionPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("expressions").append(arguments.expressionFileName()); + try { + return new RankingExpression(IOUtils.readFile(application.getFileReference(expressionPath))); + } + catch (IOException e) { + throw new UncheckedIOException("Could not read " + expressionPath, e); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + expressionPath, e); + } + } + + /** + * Adds this constant to the application package as a file, + * such that it can be distributed using file distribution. + * + * @return the path to the stored constant, relative to the application package root + */ + public Path writeConstant(String name, Tensor constant) { + Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); + createIfNeeded(constantsPath); // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: - File constantFile = new File(arguments.constantsDir(), constantName + ".tbf"); - IOUtils.writeFile(constantFile, TypedBinaryFormat.encode(constantValue)); - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFile.getPath())); + Path constantPath = constantsPath.append(name + ".tbf"); + IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); + return constantPath; } - catch (IOException e) { - throw new UncheckedIOException(e); + + private void createIfNeeded(Path path) { + File dir = application.getFileReference(path); + if ( ! dir.exists()) { + if (!dir.mkdirs()) + throw new IllegalStateException("Could not create " + dir); + } } - } - private String skippedOutputsDescription(TensorFlowModel.Signature signature) { - if (signature.skippedOutputs().isEmpty()) return ""; - StringBuilder b = new StringBuilder(": "); - signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); - return b.toString(); } /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */ @@ -231,24 +292,14 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - public File expressionFile() { - try { - StringBuilder fileName = new StringBuilder(); - signature.ifPresent(s -> fileName.append(s).append(".")); - output.ifPresent(s -> fileName.append(s).append(".")); - if (fileName.length() == 0) // single signature and output - fileName.append("single."); - fileName.append("expression"); - - return new File(ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath) - .append("expressions") - .append(fileName.toString()) - .getRelative()) - .getCanonicalFile(); - } - catch (IOException e) { - throw new UncheckedIOException(e); - } + public String expressionFileName() { + StringBuilder fileName = new StringBuilder(); + signature.ifPresent(s -> fileName.append(s).append(".")); + output.ifPresent(s -> fileName.append(s).append(".")); + if (fileName.length() == 0) // single signature and output + fileName.append("single."); + fileName.append("expression"); + return fileName.toString(); } public File constantsDir() { diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/config-model/src/test/integration/tensorflow/models/mnist_softmax/mnist_sftmax_with_saving.py index a1861a1c981..a1861a1c981 100644 --- a/config-model/src/test/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py +++ b/config-model/src/test/integration/tensorflow/models/mnist_softmax/mnist_sftmax_with_saving.py diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/saved_model.pbtxt index 8100dfd594d..8100dfd594d 100644 --- a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt +++ b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/saved_model.pbtxt diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.data-00000-of-00001 Binary files differindex 8474aa0a04c..8474aa0a04c 100644 --- a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 +++ b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.data-00000-of-00001 diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.index Binary files differindex cfcdac20409..cfcdac20409 100644 --- a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.index +++ b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.index diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index ff53fdafacf..7c749608e1f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -1,5 +1,7 @@ package com.yahoo.searchdefinition.processing; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; @@ -22,7 +24,11 @@ class RankProfileSearchFixture { private Search search; RankProfileSearchFixture(String rankProfiles) throws ParseException { - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + this(MockApplicationPackage.createEmpty(), rankProfiles); + } + + RankProfileSearchFixture(ApplicationPackage applicationpackage, String rankProfiles) throws ParseException { + SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry); String sdContent = "search test {\n" + " document test {\n" + " }\n" + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index efd2f51ba42..d70230d75d8 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -1,8 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -15,6 +18,7 @@ import org.junit.Test; import java.io.File; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.Collections; import java.util.Optional; import static org.junit.Assert.assertEquals; @@ -27,22 +31,22 @@ import static org.junit.Assert.fail; */ public class RankingExpressionWithTensorFlowTestCase { - // The "../" is to escape the "models/" element prepended to the path - private final String modelDirectory = "../src/test/integration/tensorflow/mnist_softmax/saved"; + private final String applicationDirectory = "src/test/integration/tensorflow/"; private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(Variable_1), d0, d1), f(a,b)(a + b))"; @After public void removeGeneratedConstantTensorFiles() { - IOUtils.recursiveDeleteDir(new File(modelDirectory.substring(3), "constants")); - IOUtils.recursiveDeleteDir(new File(modelDirectory.substring(3), "expressions")); + IOUtils.recursiveDeleteDir(new File(applicationDirectory, ApplicationPackage.MODELS_GENERATED_DIR.toString())); } @Test public void testMinimalTensorFlowReference() throws ParseException { + MockStoringApplicationPackage application = new MockStoringApplicationPackage(applicationDirectory); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "')" + + " expression: tensorflow('mnist_softmax/saved')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -52,10 +56,12 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testNestedTensorFlowReference() throws ParseException { + MockStoringApplicationPackage application = new MockStoringApplicationPackage(applicationDirectory); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: 5 + sum(tensorflow('" + modelDirectory + "'))" + + " expression: 5 + sum(tensorflow('mnist_softmax/saved'))" + " }\n" + " }"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); @@ -65,10 +71,12 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingSignature() throws ParseException { + MockStoringApplicationPackage application = new MockStoringApplicationPackage(applicationDirectory); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "', 'serving_default')" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -76,10 +84,12 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException { + MockStoringApplicationPackage application = new MockStoringApplicationPackage(applicationDirectory); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'y')" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'y')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -88,18 +98,21 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException { try { + MockStoringApplicationPackage application = new MockStoringApplicationPackage(applicationDirectory); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "', 'serving_defaultz')" + + " expression: tensorflow('mnist_softmax/saved', 'serving_defaultz')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { - assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from tensorflow('" + - modelDirectory + "','serving_defaultz'): Model does not have the specified signature 'serving_defaultz'", + assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + + "tensorflow('mnist_softmax/saved','serving_defaultz'): " + + "Model does not have the specified signature 'serving_defaultz'", Exceptions.toMessageString(expected)); } } @@ -107,18 +120,21 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException { try { + MockStoringApplicationPackage application = new MockStoringApplicationPackage(applicationDirectory); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'x')" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'x')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { - assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from tensorflow('" + - modelDirectory + "','serving_default','x'): Model does not have the specified output 'x'", + assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + + "tensorflow('mnist_softmax/saved','serving_default','x'): " + + "Model does not have the specified output 'x'", Exceptions.toMessageString(expected)); } } @@ -127,12 +143,16 @@ public class RankingExpressionWithTensorFlowTestCase { try { TensorValue constant = (TensorValue)search.rankProfile("my_profile").getConstants().get(name); // Old way. TODO: Remove if (constant == null) { // New way - File constantFile = new File(modelDirectory.substring(3) + "/constants", name + ".tbf"); + Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf"); RankingConstant rankingConstant = search.search().getRankingConstants().get(name); assertEquals(name, rankingConstant.getName()); - assertEquals(constantFile.getAbsolutePath(), rankingConstant.getFileName()); - assertTrue("Constant file has been written", constantFile.exists()); - Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantFile))); + assertEquals(constantApplicationPackagePath.toString(), rankingConstant.getFileName()); + + Path constantPath = Path.fromString(applicationDirectory).append(constantApplicationPackagePath); + assertTrue("Constant file '" + constantPath + "' has been written", + constantPath.toFile().exists()); + Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); assertEquals(expectedSize, deserializedConstant.size()); } else { // Old way. TODO: Remove assertNotNull(name + " is imported", constant); @@ -144,4 +164,25 @@ public class RankingExpressionWithTensorFlowTestCase { } } + private static class MockStoringApplicationPackage extends MockApplicationPackage { + + private final File root; + + public MockStoringApplicationPackage(String applicationPackageWritableRoot) { + this(new File(applicationPackageWritableRoot)); + } + + public MockStoringApplicationPackage(File applicationPackageWritableRoot) { + super(null, null, Collections.emptyList(), null, + null, null, false); + this.root = applicationPackageWritableRoot; + } + + @Override + public File getFileReference(Path path) { + return Path.fromString(root.toString()).append(path).toFile(); + } + + } + } diff --git a/config/src/test/java/com/yahoo/config/subscription/ConfigApiTest.java b/config/src/test/java/com/yahoo/config/subscription/ConfigApiTest.java index 5419100fcf1..c0080091db6 100755 --- a/config/src/test/java/com/yahoo/config/subscription/ConfigApiTest.java +++ b/config/src/test/java/com/yahoo/config/subscription/ConfigApiTest.java @@ -12,10 +12,10 @@ import static org.hamcrest.CoreMatchers.is; /** * Tests ConfigSubscriber API, and the ConfigHandle class. * - * @author <a href="gv@yahoo-inc.com">Harald Musum</a> - * @since 5.1 + * @author Harald Musum */ public class ConfigApiTest { + private static final String CONFIG_ID = "raw:" + "times 1\n"; @Test diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java index 94707635950..8ac992821fd 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java @@ -2,25 +2,24 @@ package com.yahoo.vespa.config.server.http; import com.google.inject.Inject; +import com.yahoo.config.provision.ApplicationId; import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; -import com.yahoo.container.logging.AccessLog; import com.yahoo.log.LogLevel; import com.yahoo.vespa.config.protocol.ConfigResponse; import com.yahoo.vespa.config.server.RequestHandler; import com.yahoo.vespa.config.server.tenant.Tenants; -import com.yahoo.config.provision.ApplicationId; import java.util.Optional; -import java.util.concurrent.Executor; /** * HTTP handler for a v2 getConfig operation * - * @author lulf - * @since 5.1 + * @author Ulf Lilleengen */ +// TODO: Make this API discoverable public class HttpGetConfigHandler extends HttpHandler { + private final RequestHandler requestHandler; public HttpGetConfigHandler(HttpHandler.Context ctx, RequestHandler requestHandler) { @@ -28,11 +27,12 @@ public class HttpGetConfigHandler extends HttpHandler { this.requestHandler = requestHandler; } + @SuppressWarnings("unused") // injected @Inject public HttpGetConfigHandler(HttpHandler.Context ctx, Tenants tenants) { this(ctx, tenants.defaultTenant().getRequestHandler()); } - + @Override public HttpResponse handleGET(HttpRequest req) { HttpConfigRequest request = HttpConfigRequest.createFromRequestV1(req); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 42945c59105..45f2b21343f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -14,6 +14,7 @@ import org.tensorflow.framework.SignatureDef; import org.tensorflow.framework.TensorInfo; import org.tensorflow.framework.TensorShapeProto; +import java.io.File; import java.io.IOException; import java.util.List; import java.util.Map; @@ -30,7 +31,7 @@ public class TensorFlowImporter { /** * Imports a saved TensorFlow model from a directory. - * The model should be saved as a pbtxt file. + * The model should be saved as a .pbtxt or .pb file. * The name of the model is taken as the db/pbtxt file name (not including the file ending). * * @param modelDir the directory containing the TensorFlow model files to import @@ -44,6 +45,10 @@ public class TensorFlowImporter { } } + public TensorFlowModel importModel(File modelDir) { + return importModel(modelDir.toString()); + } + /** Imports a TensorFlow model */ public TensorFlowModel importModel(SavedModelBundle model) { try { diff --git a/vespajlib/src/main/java/com/yahoo/path/Path.java b/vespajlib/src/main/java/com/yahoo/path/Path.java index 17bf9e9fe35..da55c6767d1 100644 --- a/vespajlib/src/main/java/com/yahoo/path/Path.java +++ b/vespajlib/src/main/java/com/yahoo/path/Path.java @@ -8,12 +8,10 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; -// TODO: Remove and replace usage by java.nio.file.Path - /** * Represents a path represented by a list of elements. Immutable * - * @author lulf + * @author Ulf Lilleengen */ @Beta public final class Path { |