diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2018-01-17 13:51:14 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-17 13:51:14 +0100 |
commit | fd26b36e3607df463b35e856b37d24b5e3514fb7 (patch) | |
tree | 403836969d050736403f6512a455198a2c63edad /config-model | |
parent | ceec6d572c06ff812715c97d2c35383c48402f24 (diff) | |
parent | c84b8f952ef5857aa44fad479551eda1f3a4e106 (diff) |
Merge pull request #4692 from vespa-engine/bratseth/store-converted-expressions-in-zk
Bratseth/store converted expressions in zk
Diffstat (limited to 'config-model')
13 files changed, 477 insertions, 128 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/ConfigModel.java b/config-model/src/main/java/com/yahoo/config/model/ConfigModel.java index 5daf5ca70a5..385cd883da4 100644 --- a/config-model/src/main/java/com/yahoo/config/model/ConfigModel.java +++ b/config-model/src/main/java/com/yahoo/config/model/ConfigModel.java @@ -8,7 +8,7 @@ package com.yahoo.config.model; * * @author gjoranv * @author bratseth - * @author lulf + * @author Ulf Lilleengen */ public abstract class ConfigModel { diff --git a/config-model/src/main/java/com/yahoo/config/model/admin/AdminModel.java b/config-model/src/main/java/com/yahoo/config/model/admin/AdminModel.java index 5eb4afcc241..5912b476783 100644 --- a/config-model/src/main/java/com/yahoo/config/model/admin/AdminModel.java +++ b/config-model/src/main/java/com/yahoo/config/model/admin/AdminModel.java @@ -21,8 +21,7 @@ import java.util.*; /** * Config model adaptor of the Admin class. * - * @author lulf - * @since 5.1 + * @author Ulf Lilleengen */ public class AdminModel extends ConfigModel { @@ -46,8 +45,9 @@ public class AdminModel extends ConfigModel { @Override public void prepare(ConfigModelRepo configModelRepo) { verifyClusterControllersOnlyDefinedForContent(configModelRepo); - if (admin == null || admin.getClusterControllers() == null) return; - admin.getClusterControllers().prepare(); + if (admin == null) return; + if (admin.getClusterControllers() != null) + admin.getClusterControllers().prepare(); } private void verifyClusterControllersOnlyDefinedForContent(ConfigModelRepo configModelRepo) { @@ -61,9 +61,9 @@ public class AdminModel extends ConfigModel { public static class BuilderV2 extends ConfigModelBuilder<AdminModel> { public static final List<ConfigModelId> configModelIds = - ImmutableList.of(ConfigModelId.fromNameAndVersion("admin", "2.0"), + ImmutableList.of(ConfigModelId.fromNameAndVersion("admin", "2.0"), ConfigModelId.fromNameAndVersion("admin", "1.0")); - + public BuilderV2() { super(AdminModel.class); } @@ -91,7 +91,7 @@ public class AdminModel extends ConfigModel { public static class BuilderV4 extends ConfigModelBuilder<AdminModel> { public static final List<ConfigModelId> configModelIds = - ImmutableList.of(ConfigModelId.fromNameAndVersion("admin", "3.0"), + ImmutableList.of(ConfigModelId.fromNameAndVersion("admin", "3.0"), ConfigModelId.fromNameAndVersion("admin", "4.0")); public BuilderV4() { diff --git a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java index ddee0be6e9c..271ec6958ec 100644 --- a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java +++ b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java @@ -37,8 +37,8 @@ public class MockApplicationPackage implements ApplicationPackage { private final Optional<String> validationOverrides; private final boolean failOnValidateXml; - private MockApplicationPackage(String hosts, String services, List<String> searchDefinitions, String searchDefinitionDir, - String deploymentSpec, String validationOverrides, boolean failOnValidateXml) { + protected MockApplicationPackage(String hosts, String services, List<String> searchDefinitions, String searchDefinitionDir, + String deploymentSpec, String validationOverrides, boolean failOnValidateXml) { this.hostsS = hosts; this.servicesS = services; this.searchDefinitions = searchDefinitions; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index f37ab9fb89f..df5697de0d5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -165,9 +165,8 @@ public class Search implements Serializable, ImmutableSearch { public void addRankingConstant(RankingConstant constant) { constant.validate(); String name = constant.getName(); - if (rankingConstants.get(name) != null) { - throw new IllegalArgumentException("Ranking constant '"+name+"' defined twice"); - } + if (rankingConstants.containsKey(name)) + throw new IllegalArgumentException("Ranking constant '" + name + "' defined twice"); rankingConstants.put(name, constant); } @@ -268,6 +267,8 @@ public class Search implements Serializable, ImmutableSearch { return sourceApplication.getRankingExpression(fileName); } + public ApplicationPackage sourceApplication() { return sourceApplication; } + /** * Returns a field defined in this search definition or one if its documents. Fields in this search definition takes * precedence over document fields having the same name diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 32f8f4871df..0dd5b4166ef 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -2,14 +2,17 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.google.common.base.Joiner; +import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; @@ -17,12 +20,16 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.File; import java.io.IOException; +import java.io.StringReader; import java.io.UncheckedIOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -33,17 +40,14 @@ import java.util.Optional; * * @author bratseth */ -// TODO: - Verify types of macros -// - Avoid name conflicts across models for constants +// TODO: Verify types of macros +// TODO: Avoid name conflicts across models for constants public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - // TODO: Make system test work with this set to true, then remove the "true" path - private static final boolean constantsInConfig = true; - private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<String, TensorFlowModel> importedModels = new HashMap<>(); + private final Map<Path, TensorFlowModel> importedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -56,40 +60,48 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { - try { - if ( ! feature.getName().equals("tensorflow")) return feature; + if ( ! feature.getName().equals("tensorflow")) return feature; - if (feature.getArguments().isEmpty()) - throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + - "the tensorflow model directory under [application]/models"); + try { + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), + feature.getArguments()); + if (store.hasTensorFlowModels()) + return transformFromTensorFlowModel(store, context.rankProfile()); + else // is should have previously stored model information instead + return transformFromStoredModel(store, context.rankProfile()); + } + catch (IllegalArgumentException | UncheckedIOException e) { + throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); + } + } - String modelPath = ApplicationPackage.MODELS_DIR + "/" + asString(feature.getArguments().expressions().get(0)); - TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath)); + private ExpressionNode transformFromTensorFlowModel(ModelStore store, RankProfile profile) { + TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), + k -> tensorFlowImporter.importModel(store.tensorFlowModelDir())); - // Find the specified expression - TensorFlowModel.Signature signature = chooseSignature(result, - optionalArgument(1, feature.getArguments())); - RankingExpression expression = chooseOutput(signature, - optionalArgument(2, feature.getArguments())); + // Find the specified expression + Signature signature = chooseSignature(model, store.arguments().signature()); + String output = chooseOutput(signature, store.arguments().output()); + RankingExpression expression = model.expressions().get(output); + store.writeConverted(expression); - // Add all constants (after finding outputs to fail faster when the output is not found) - if (constantsInConfig) - result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v))); - else // correct way, disabled for now - result.constants().forEach((k, v) -> transformConstant(modelPath, context.rankProfile(), k, v)); + model.constants().forEach((k, v) -> transformConstant(store, profile, k, v)); + return expression.getRoot(); + } - return expression.getRoot(); - } - catch (IllegalArgumentException e) { - throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); + private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { + for (RankingConstant constant : store.readRankingConstants()) { + if (!profile.getSearch().getRankingConstants().containsKey(constant.getName())) + profile.getSearch().addRankingConstant(constant); } + return store.readConverted().getRoot(); } /** * Returns the specified, existing signature, or the only signature if none is specified. * Throws IllegalArgumentException in all other cases. */ - private TensorFlowModel.Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) { + private Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) { if ( ! signatureName.isPresent()) { if (importResult.signatures().size() == 0) throw new IllegalArgumentException("No signatures are available"); @@ -101,7 +113,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil return importResult.signatures().values().stream().findFirst().get(); } else { - TensorFlowModel.Signature signature = importResult.signatures().get(signatureName.get()); + Signature signature = importResult.signatures().get(signatureName.get()); if (signature == null) throw new IllegalArgumentException("Model does not have the specified signature '" + signatureName.get() + "'"); @@ -113,7 +125,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil * Returns the specified, existing output expression, or the only output expression if no output name is specified. * Throws IllegalArgumentException in all other cases. */ - private RankingExpression chooseOutput(TensorFlowModel.Signature signature, Optional<String> outputName) { + private String chooseOutput(Signature signature, Optional<String> outputName) { if ( ! outputName.isPresent()) { if (signature.outputs().size() == 0) throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature)); @@ -122,11 +134,11 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil Joiner.on(", ").join(signature.outputs().keySet()) + "), one must be specified " + "as a third argument to tensorflow()"); - return signature.outputExpression(signature.outputs().keySet().stream().findFirst().get()); + return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get()); } else { - RankingExpression expression = signature.outputExpression(outputName.get()); - if (expression == null) { + String output = signature.outputs().get(outputName.get()); + if (output == null) { if (signature.skippedOutputs().containsKey(outputName.get())) throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + signature.skippedOutputs().get(outputName.get())); @@ -134,28 +146,16 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil throw new IllegalArgumentException("Model does not have the specified output '" + outputName.get() + "'"); } - return expression; + return output; } } - private void transformConstant(String modelPath, RankProfile profile, String constantName, Tensor constantValue) { - try { - if (profile.getSearch().getRankingConstants().containsKey(constantName)) return; + private void transformConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + if (profile.getSearch().getRankingConstants().containsKey(constantName)) return; - File constantFilePath = new File(modelPath, "converted_variables").getCanonicalFile(); - if (!constantFilePath.exists()) { - if (!constantFilePath.mkdir()) - throw new IOException("Could not create directory " + constantFilePath); - } - - // "tbf" ending for "typed binary format" - recognized by the nodes reciving the file: - File constantFile = new File(constantFilePath, constantName + ".tbf"); - IOUtils.writeFile(constantFile, TypedBinaryFormat.encode(constantValue)); - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFile.getPath())); - } - catch (IOException e) { - throw new UncheckedIOException(e); - } + Path constantPath = store.writeConstant(constantName, constantValue); + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); } private String skippedOutputsDescription(TensorFlowModel.Signature signature) { @@ -165,27 +165,176 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil return b.toString(); } - private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { - if (argumentIndex >= arguments.expressions().size()) - return Optional.empty(); - return Optional.of(asString(arguments.expressions().get(argumentIndex))); - } + /** + * Provides read/write access to the correct directories of the application package given by the feature arguments + */ + private static class ModelStore { - private String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); - } + private final ApplicationPackage application; + private final FeatureArguments arguments; + + public ModelStore(ApplicationPackage application, Arguments arguments) { + this.application = application; + this.arguments = new FeatureArguments(arguments); + } + + public FeatureArguments arguments() { return arguments; } + + public boolean hasTensorFlowModels() { + try { + return application.getFileReference(ApplicationPackage.MODELS_DIR).exists(); + } + catch (UnsupportedOperationException e) { + return false; // No files -> no TensorFlow models + } + } + + /** + * Returns the directory which (if hasTensorFlowModels is true) + * contains the source model to use for these arguments + */ + public File tensorFlowModelDir() { + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); + } + + /** + * Adds this expression to the application package, such that it can be read later. + */ + public void writeConverted(RankingExpression expression) { + application.getFile(arguments.expressionPath()) + .writeFile(new StringReader(expression.getRoot().toString())); + } + + /** Reads the previously stored ranking expression for these arguments */ + public RankingExpression readConverted() { + try { + return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); + } + catch (IOException e) { + throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + + /** + * Reads the information about all the constants stored in the application package + * (the constant value itself is replicated with file distribution). + */ + public List<RankingConstant> readRankingConstants() { + try { + List<RankingConstant> constants = new ArrayList<>(); + for (ApplicationFile constantFile : application.getFile(arguments.rankingConstantsPath()).listFiles()) { + String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); + constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Adds this constant to the application package as a file, + * such that it can be distributed using file distribution. + * + * @return the path to the stored constant, relative to the application package root + */ + public Path writeConstant(String name, Tensor constant) { + Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); + + // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: + Path constantPath = constantsPath.append(name + ".tbf"); + + // Remember the constant in a file we replicate in ZooKeeper + application.getFile(arguments.rankingConstantsPath().append(name + ".constant")) + .writeFile(new StringReader(name + ":" + constant.type() + ":" + constantPath)); + + // Write content explicitly as a file on the file system as this is distributed using file distribution + createIfNeeded(constantsPath); + IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); + return constantPath; + } + + private void createIfNeeded(Path path) { + File dir = application.getFileReference(path); + if ( ! dir.exists()) { + if (!dir.mkdirs()) + throw new IllegalStateException("Could not create " + dir); + } + } - private String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); } - private boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; + /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */ + private static class FeatureArguments { + + private final Path modelPath; + + /** Optional arguments */ + private final Optional<String> signature, output; + + public FeatureArguments(Arguments arguments) { + if (arguments.isEmpty()) + throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + + "the tensorflow model directory under [application]/models"); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); + + modelPath = Path.fromString(asString(arguments.expressions().get(0))); + signature = optionalArgument(1, arguments); + output = optionalArgument(2, arguments); + } + + /** Returns relative path to this model below the "models/" dir in the application package */ + public Path modelPath() { return modelPath; } + public Optional<String> signature() { return signature; } + public Optional<String> output() { return output; } + + public Path rankingConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); + } + + public Path expressionPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR + .append(modelPath).append("expressions").append(expressionFileName()); + } + + private String expressionFileName() { + StringBuilder fileName = new StringBuilder(); + signature.ifPresent(s -> fileName.append(s).append(".")); + output.ifPresent(s -> fileName.append(s).append(".")); + if (fileName.length() == 0) // single signature and output + fileName.append("single."); + fileName.append("expression"); + return fileName.toString(); + } + + private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + private String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java index 59b7388f5bb..071b3090f99 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java @@ -73,9 +73,7 @@ public class Admin extends AbstractConfigProducer implements Serializable { this.fileDistribution = fileDistributionConfigProducer; } - public Configserver getConfigserver() { - return defaultConfigserver; - } + public Configserver getConfigserver() { return defaultConfigserver; } /** Returns the configured monitoring endpoint, or null if not configured */ public Monitoring getMonitoring() { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java index fd062dc4ea4..58fc76f1508 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java @@ -49,16 +49,7 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer public static final IndexingMode REALTIME = new IndexingMode("REALTIME"); public static final IndexingMode STREAMING = new IndexingMode("STREAMING"); - public static IndexingMode createIndexingMode(String ixm) { - if ("REALTIME".equalsIgnoreCase(ixm)) { - return REALTIME; - } else if ("STREAMING".equalsIgnoreCase(ixm)) { - return STREAMING; - } - return null; - } - - private String name; + private final String name; private IndexingMode(String name) { this.name = name; @@ -72,6 +63,7 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer } public static final class SearchDefinitionSpec { + private final SearchDefinition searchDefinition; private final UserConfigRepo userConfigRepo; diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/config-model/src/test/integration/tensorflow/models/mnist_softmax/mnist_sftmax_with_saving.py index a1861a1c981..a1861a1c981 100644 --- a/config-model/src/test/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py +++ b/config-model/src/test/integration/tensorflow/models/mnist_softmax/mnist_sftmax_with_saving.py diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/saved_model.pbtxt index 8100dfd594d..8100dfd594d 100644 --- a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt +++ b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/saved_model.pbtxt diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.data-00000-of-00001 Binary files differindex 8474aa0a04c..8474aa0a04c 100644 --- a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 +++ b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.data-00000-of-00001 diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.index Binary files differindex cfcdac20409..cfcdac20409 100644 --- a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.index +++ b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.index diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index ff53fdafacf..7c749608e1f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -1,5 +1,7 @@ package com.yahoo.searchdefinition.processing; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; @@ -22,7 +24,11 @@ class RankProfileSearchFixture { private Search search; RankProfileSearchFixture(String rankProfiles) throws ParseException { - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + this(MockApplicationPackage.createEmpty(), rankProfiles); + } + + RankProfileSearchFixture(ApplicationPackage applicationpackage, String rankProfiles) throws ParseException { + SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry); String sdContent = "search test {\n" + " document test {\n" + " }\n" + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 31f7511155b..0354173f365 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -1,24 +1,36 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.yolean.Exceptions; import org.junit.After; import org.junit.Test; +import java.io.BufferedInputStream; import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; import java.io.IOException; +import java.io.InputStream; +import java.io.Reader; import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -27,47 +39,52 @@ import static org.junit.Assert.fail; */ public class RankingExpressionWithTensorFlowTestCase { - // The "../" is to escape the "models/" element prepended to the path - private final String modelDirectory = "../src/test/integration/tensorflow/mnist_softmax/saved"; + private final Path applicationDir = Path.fromString("src/test/integration/tensorflow/"); private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(Variable_1), d0, d1), f(a,b)(a + b))"; @After public void removeGeneratedConstantTensorFiles() { - IOUtils.recursiveDeleteDir(new File(modelDirectory.substring(3), "converted_variables")); + IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); } @Test public void testMinimalTensorFlowReference() throws ParseException { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "')" + + " expression: tensorflow('mnist_softmax/saved')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertConstant(10, "Variable_1", search); - assertConstant(7840, "Variable", search); + assertConstant("Variable_1", search, Optional.of(10L)); + assertConstant("Variable", search, Optional.of(7840L)); } @Test public void testNestedTensorFlowReference() throws ParseException { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: 5 + sum(tensorflow('" + modelDirectory + "'))" + + " expression: 5 + sum(tensorflow('mnist_softmax/saved'))" + " }\n" + " }"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - assertConstant(10, "Variable_1", search); - assertConstant(7840, "Variable", search); + assertConstant("Variable_1", search, Optional.of(10L)); + assertConstant("Variable", search, Optional.of(7840L)); } @Test public void testTensorFlowReferenceSpecifyingSignature() throws ParseException { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "', 'serving_default')" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -75,10 +92,12 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'y')" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'y')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -87,18 +106,21 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException { try { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "', 'serving_defaultz')" + + " expression: tensorflow('mnist_softmax/saved', 'serving_defaultz')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { - assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from tensorflow('" + - modelDirectory + "','serving_defaultz'): Model does not have the specified signature 'serving_defaultz'", + assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + + "tensorflow('mnist_softmax/saved','serving_defaultz'): " + + "Model does not have the specified signature 'serving_defaultz'", Exceptions.toMessageString(expected)); } } @@ -106,36 +128,83 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException { try { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( + application, " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'x')" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'x')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { - assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from tensorflow('" + - modelDirectory + "','serving_default','x'): Model does not have the specified output 'x'", + assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + + "tensorflow('mnist_softmax/saved','serving_default','x'): " + + "Model does not have the specified output 'x'", Exceptions.toMessageString(expected)); } } - private void assertConstant(int expectedSize, String name, RankProfileSearchFixture search) { + @Test + public void testImportingFromStoredExpressions() throws ParseException, IOException { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); + RankProfileSearchFixture search = new RankProfileSearchFixture( + application, + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertConstant("Variable_1", search, Optional.of(10L)); + assertConstant("Variable", search, Optional.of(7840L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = new RankProfileSearchFixture( + storedApplication, + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + + " }\n" + + " }"); + searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); + // Verify that the constants exists, but don't verify the content as we are not + // simulating file distribution in this test + assertConstant("Variable_1", searchFromStored, Optional.empty()); + assertConstant("Variable", searchFromStored, Optional.empty()); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + + } + + /** + * Verifies that the constant with the given name exists, and - only if an expected size is given - + * that the content of the constant is available and has the expected size. + */ + private void assertConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { try { - TensorValue constant = (TensorValue)search.rankProfile("my_profile").getConstants().get(name); // Old way. TODO: Remove - if (constant == null) { // New way - File constantFile = new File(modelDirectory.substring(3) + "/converted_variables", name + ".tbf"); - RankingConstant rankingConstant = search.search().getRankingConstants().get(name); - assertEquals(name, rankingConstant.getName()); - assertEquals(constantFile.getAbsolutePath(), rankingConstant.getFileName()); - assertTrue("Constant file has been written", constantFile.exists()); - Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantFile))); - assertEquals(expectedSize, deserializedConstant.size()); - } else { // Old way. TODO: Remove - assertNotNull(name + " is imported", constant); - assertEquals(expectedSize, constant.asTensor().size()); + Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf"); + RankingConstant rankingConstant = search.search().getRankingConstants().get(name); + assertEquals(name, rankingConstant.getName()); + assertEquals(constantApplicationPackagePath.toString(), rankingConstant.getFileName()); + + if (expectedSize.isPresent()) { + Path constantPath = applicationDir.append(constantApplicationPackagePath); + assertTrue("Constant file '" + constantPath + "' has been written", + constantPath.toFile().exists()); + Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); + assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); } } catch (IOException e) { @@ -143,4 +212,138 @@ public class RankingExpressionWithTensorFlowTestCase { } } + private static class StoringApplicationPackage extends MockApplicationPackage { + + private final File root; + + StoringApplicationPackage(Path applicationPackageWritableRoot) { + this(applicationPackageWritableRoot.toFile()); + } + + StoringApplicationPackage(File applicationPackageWritableRoot) { + super(null, null, Collections.emptyList(), null, + null, null, false); + this.root = applicationPackageWritableRoot; + } + + @Override + public File getFileReference(Path path) { + return Path.fromString(root.toString()).append(path).toFile(); + } + + @Override + public ApplicationFile getFile(Path file) { + return new StoringApplicationPackageFile(file, Path.fromString(root.toString())); + } + + } + + private static class StoringApplicationPackageFile extends ApplicationFile { + + /** The path to the application package root */ + private final Path root; + + /** The File pointing to the actual file represented by this */ + private final File file; + + StoringApplicationPackageFile(Path filePath, Path applicationPackagePath) { + super(filePath); + this.root = applicationPackagePath; + file = applicationPackagePath.append(filePath).toFile(); + } + + @Override + public boolean isDirectory() { + return file.isDirectory(); + } + + @Override + public boolean exists() { + return file.exists(); + } + + @Override + public Reader createReader() throws FileNotFoundException { + try { + if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist"); + return IOUtils.createReader(file, "UTF-8"); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public InputStream createInputStream() throws FileNotFoundException { + try { + if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist"); + return new BufferedInputStream(new FileInputStream(file)); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public ApplicationFile createDirectory() { + file.mkdirs(); + return this; + } + + @Override + public ApplicationFile writeFile(Reader input) { + try { + IOUtils.writeFile(file, IOUtils.readAll(input), false); + return this; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public List<ApplicationFile> listFiles(PathFilter filter) { + if ( ! isDirectory()) return Collections.emptyList(); + return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString()))) + .map(f -> new StoringApplicationPackageFile(asApplicationRelativePath(f), + root)) + .collect(Collectors.toList()); + } + + @Override + public ApplicationFile delete() { + file.delete(); + return this; + } + + @Override + public MetaData getMetaData() { + throw new UnsupportedOperationException(); + } + + @Override + public int compareTo(ApplicationFile other) { + return this.getPath().getName().compareTo((other).getPath().getName()); + } + + /** Strips the application package root path prefix from the path of the given file */ + private Path asApplicationRelativePath(File file) { + Path path = Path.fromString(file.toString()); + + Iterator<String> pathIterator = path.iterator(); + // Skip the path elements this shares with the root + for (Iterator<String> rootIterator = root.iterator(); rootIterator.hasNext(); ) { + String rootElement = rootIterator.next(); + String pathElement = pathIterator.next(); + if ( ! rootElement.equals(pathElement)) throw new RuntimeException("Assumption broken"); + } + // Build a path from the remaining + Path relative = Path.fromString(""); + while (pathIterator.hasNext()) + relative = relative.append(pathIterator.next()); + return relative; + } + + } + } |