diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java | 394 |
1 files changed, 232 insertions, 162 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index 867740c7912..d85d0983509 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -1,11 +1,11 @@ package com.yahoo.searchdefinition.expressiontransforms; -import com.google.common.base.Joiner; import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.io.IOUtils; +import com.yahoo.io.reader.NamedReader; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.FeatureNames; @@ -39,11 +39,14 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.BufferedReader; import java.io.File; +import java.io.FileNotFoundException; import java.io.IOException; +import java.io.Reader; import java.io.StringReader; import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -62,69 +65,155 @@ import java.util.stream.Collectors; */ public class ConvertedModel { - private final ExpressionNode convertedExpression; + private final String modelName; + private final Path modelPath; - public ConvertedModel(FeatureArguments arguments, + /** + * The ranking expressions of this, indexed by their name. which is a 1-3 part string separated by dots + * where the first part is always the model name, the second the signature or (if none) + * expression name (if more than one), and the third is the output name (if any). + */ + private final Map<String, RankingExpression> expressions; + + public ConvertedModel(Path modelPath, RankProfileTransformContext context, ModelImporter modelImporter, - Map<Path, ImportedModel> importedModels) { - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); + FeatureArguments arguments) { // TODO: Remove + this.modelPath = modelPath; + this.modelName = toModelName(modelPath); + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), modelPath); if ( ! store.hasStoredModel()) // not converted yet - access from models/ directory - convertedExpression = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, importedModels); + expressions = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, arguments); else - convertedExpression = transformFromStoredModel(store, context.rankProfile()); + expressions = transformFromStoredModel(store, context.rankProfile()); } - private ExpressionNode importModel(ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles, - ModelImporter modelImporter, - Map<Path, ImportedModel> importedModels) { - ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), - k -> modelImporter.importModel(store.arguments().modelName(), - store.modelDir())); - return transformFromImportedModel(model, store, profile, queryProfiles); + private Map<String, RankingExpression> importModel(ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + ModelImporter modelImporter, + FeatureArguments arguments) { + ImportedModel model = modelImporter.importModel(store.modelFiles.modelName(), store.modelDir()); + return transformFromImportedModel(model, store, profile, queryProfiles, arguments); + } + + /** Returns the expression matching the given arguments */ + public ExpressionNode expression(FeatureArguments arguments) { + if (expressions.isEmpty()) + throw new IllegalArgumentException("No expressions available in " + this); + + RankingExpression expression = expressions.get(arguments.toName()); + if (expression != null) return expression.getRoot(); + + if ( ! arguments.signature().isPresent()) { + if (expressions.size() > 1) + throw new IllegalArgumentException("Multiple candidate expressions " + missingExpressionMessageSuffix()); + return expressions.values().iterator().next().getRoot(); + } + + if ( ! arguments.output().isPresent()) { + List<Map.Entry<String, RankingExpression>> entriesWithTheRightPrefix = + expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(modelName + "." + arguments.signature().get() + ".")).collect(Collectors.toList()); + if (entriesWithTheRightPrefix.size() < 1) + throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() + + missingExpressionMessageSuffix()); + if (entriesWithTheRightPrefix.size() > 1) + throw new IllegalArgumentException("Multiple candidate expression named '" + arguments.signature().get() + + missingExpressionMessageSuffix()); + return entriesWithTheRightPrefix.get(0).getValue().getRoot(); + } + + throw new IllegalArgumentException("No expression '" + arguments.toName() + missingExpressionMessageSuffix()); } - public ExpressionNode expression() { return convertedExpression; } + private String missingExpressionMessageSuffix() { + return "' in model '" + this.modelPath + "'. " + + "Available expressions: " + expressions.keySet().stream().collect(Collectors.joining(", ")); + } - private ExpressionNode transformFromImportedModel(ImportedModel model, - ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles) { + private Map<String, RankingExpression> transformFromImportedModel(ImportedModel model, + ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + FeatureArguments arguments) { // Add constants Set<String> constantsReplacedByMacros = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, constantsReplacedByMacros, k, v)); - // Find the specified expression - ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature()); - String output = chooseOutput(signature, store.arguments().output()); - if (signature.skippedOutputs().containsKey(output)) { - String message = "Could not import model output '" + output + "'"; - if (!signature.skippedOutputs().get(output).isEmpty()) { - message += ": " + signature.skippedOutputs().get(output); + // Add macros + addGeneratedMacros(model, profile); + + // Add expressions + Map<String, RankingExpression> expressions = new HashMap<>(); + for (Map.Entry<String, ImportedModel.Signature> signatureEntry : model.signatures().entrySet()) { + if ( ! matches(signatureEntry.getValue(), arguments, Optional.empty())) continue; + + for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) { + if ( ! matches(signatureEntry.getValue(), arguments, Optional.of(outputEntry.getKey()))) continue; + addExpression(model.expressions().get(outputEntry.getValue()), + modelName + "." + signatureEntry.getKey() + "." + outputEntry.getKey(), + constantsReplacedByMacros, + model, store, profile, queryProfiles, + expressions); } - if (!signature.importWarnings().isEmpty()) { - message += ": " + String.join(", ", signature.importWarnings()); + if (signatureEntry.getValue().outputs().isEmpty()) { // fallback: Signature without outputs + addExpression(model.expressions().get(signatureEntry.getKey()), + modelName + "." + signatureEntry.getKey(), + constantsReplacedByMacros, + model, store, profile, queryProfiles, + expressions); } - throw new IllegalArgumentException(message); } + if (model.signatures().isEmpty()) { // fallback: Model without signatures + if (model.expressions().size() == 1) { // Use just model name + addExpression(model.expressions().values().iterator().next(), + modelName, + constantsReplacedByMacros, + model, store, profile, queryProfiles, + expressions); + } + else { + for (Map.Entry<String, RankingExpression> expressionEntry : model.expressions().entrySet()) { + addExpression(expressionEntry.getValue(), + modelName + "." + expressionEntry.getKey(), + constantsReplacedByMacros, + model, store, profile, queryProfiles, + expressions); + } + } + } + + // Transform and save macro - must come after reading expressions due to optimization transforms + model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); - RankingExpression expression = model.expressions().get(output); + return expressions; + } + + private boolean matches(ImportedModel.Signature signature, FeatureArguments arguments, Optional<String> output) { + if ( ! modelName.equals(arguments.modelName)) return false; + if ( arguments.signature.isPresent() && ! signature.name().equals(arguments.signature().get())) return false; + if (output.isPresent() && arguments.output().isPresent() && ! output.get().matches(arguments.output().get())) return false; + return true; + } + + private void addExpression(RankingExpression expression, + String expressionName, + Set<String> constantsReplacedByMacros, + ImportedModel model, + ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + Map<String, RankingExpression> expressions) { expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); verifyRequiredMacros(expression, model, profile, queryProfiles); - addGeneratedMacros(model, profile); reduceBatchDimensions(expression, model, profile, queryProfiles); - - model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); - - store.writeConverted(expression); - return expression.getRoot(); + store.writeExpression(expressionName, expression); + expressions.put(expressionName, expression); } - ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { + private Map<String, RankingExpression> transformFromStoredModel(ModelStore store, RankProfile profile) { for (Pair<String, Tensor> constant : store.readSmallConstants()) profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); @@ -137,60 +226,7 @@ public class ConvertedModel { addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); } - return store.readConverted().getRoot(); - } - - /** - * Returns the specified, existing signature, or the only signature if none is specified. - * Throws IllegalArgumentException in all other cases. - */ - private ImportedModel.Signature chooseSignature(ImportedModel importResult, Optional<String> signatureName) { - if ( ! signatureName.isPresent()) { - if (importResult.signatures().size() == 0) - throw new IllegalArgumentException("No signatures are available"); - if (importResult.signatures().size() > 1) - throw new IllegalArgumentException("Model has multiple signatures (" + - Joiner.on(", ").join(importResult.signatures().keySet()) + - "), one must be specified " + - "as a second argument to tensorflow()"); - return importResult.signatures().values().stream().findFirst().get(); - } - else { - ImportedModel.Signature signature = importResult.signatures().get(signatureName.get()); - if (signature == null) - throw new IllegalArgumentException("Model does not have the specified signature '" + - signatureName.get() + "'"); - return signature; - } - } - - /** - * Returns the specified, existing output expression, or the only output expression if no output name is specified. - * Throws IllegalArgumentException in all other cases. - */ - private String chooseOutput(ImportedModel.Signature signature, Optional<String> outputName) { - if ( ! outputName.isPresent()) { - if (signature.outputs().size() == 0) - throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature)); - if (signature.outputs().size() > 1) - throw new IllegalArgumentException(signature + " has multiple outputs (" + - Joiner.on(", ").join(signature.outputs().keySet()) + - "), one must be specified " + - "as a third argument to tensorflow()"); - return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get()); - } - else { - String output = signature.outputs().get(outputName.get()); - if (output == null) { - if (signature.skippedOutputs().containsKey(outputName.get())) - throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + - signature.skippedOutputs().get(outputName.get())); - else - throw new IllegalArgumentException("Model does not have the specified output '" + - outputName.get() + "'"); - } - return output; - } + return store.readExpressions(); } private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { @@ -227,22 +263,15 @@ public class ConvertedModel { } private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { - if (profile.getMacros().containsKey(macroName)) { + if (profile.getMacros().containsKey(macroName)) throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists."); - } + profile.addMacro(macroName, false); // todo: inline if only used once RankProfile.Macro macro = profile.getMacros().get(macroName); macro.setRankingExpression(expression); macro.setTextualExpression(expression.getRoot().toString()); } - private String skippedOutputsDescription(ImportedModel.Signature signature) { - if (signature.skippedOutputs().isEmpty()) return ""; - StringBuilder b = new StringBuilder(": "); - signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); - return b.toString(); - } - /** * Verify that the macros referred in the given expression exists in the given rank profile, * and return tensors of the types specified in requiredMacros. @@ -375,8 +404,8 @@ public class ConvertedModel { /** * If batch dimensions have been reduced away above, bring them back here * for any following computation of the tensor. - * Todo: determine when this is not necessary! */ + // TODO: determine when this is not necessary! private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { if (after.equals(before)) { return node; @@ -452,24 +481,29 @@ public class ConvertedModel { return new TensorValue(tensor); } + private static String toModelName(Path modelPath) { + return modelPath.toString().replace("/", "_"); + } + + @Override + public String toString() { return "model '" + modelName + "'"; } + /** * Provides read/write access to the correct directories of the application package given by the feature arguments */ static class ModelStore { private final ApplicationPackage application; - private final FeatureArguments arguments; + private final ModelFiles modelFiles; - ModelStore(ApplicationPackage application, FeatureArguments arguments) { + ModelStore(ApplicationPackage application, Path modelPath) { this.application = application; - this.arguments = arguments; + this.modelFiles = new ModelFiles(modelPath); } - public FeatureArguments arguments() { return arguments; } - public boolean hasStoredModel() { try { - return application.getFile(arguments.expressionPath()).exists(); + return application.getFileReference(modelFiles.storedModelPath()).exists(); } catch (UnsupportedOperationException e) { return false; @@ -480,40 +514,49 @@ public class ConvertedModel { * Returns the directory which contains the source model to use for these arguments */ public File modelDir() { - return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(modelFiles.modelPath())); } /** * Adds this expression to the application package, such that it can be read later. + * + * @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part + * is always the model name */ - void writeConverted(RankingExpression expression) { - application.getFile(arguments.expressionPath()) + void writeExpression(String name, RankingExpression expression) { + application.getFile(modelFiles.expressionPath(name)) .writeFile(new StringReader(expression.getRoot().toString())); } - /** Reads the previously stored ranking expression for these arguments */ - RankingExpression readConverted() { - try { - return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); - } - catch (IOException e) { - throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + Map<String, RankingExpression> readExpressions() { + Map<String, RankingExpression> expressions = new HashMap<>(); + ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath()); + if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap(); + for (ApplicationFile expressionFile : expressionPath.listFiles()) { + try { + String name = expressionFile.getPath().getName(); + expressions.put(name, new RankingExpression(name, expressionFile.createReader())); + } + catch (FileNotFoundException e) { + throw new IllegalStateException("Expression file removed while reading: " + expressionFile, e); + } + catch (ParseException e) { + throw new IllegalStateException("Invalid stored expression in " + expressionFile, e); + } } + return expressions; } /** Adds this macro expression to the application package to it can be read later. */ void writeMacro(String name, RankingExpression expression) { - application.getFile(arguments.macrosPath()).appendFile(name + "\t" + + application.getFile(modelFiles.macrosPath()).appendFile(name + "\t" + expression.getRoot().toString() + "\n"); } /** Reads the previously stored macro expressions for these arguments */ List<Pair<String, RankingExpression>> readMacros() { try { - ApplicationFile file = application.getFile(arguments.macrosPath()); + ApplicationFile file = application.getFile(modelFiles.macrosPath()); if (!file.exists()) return Collections.emptyList(); List<Pair<String, RankingExpression>> macros = new ArrayList<>(); @@ -527,7 +570,7 @@ public class ConvertedModel { macros.add(new Pair<>(name, expression)); } catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + throw new IllegalStateException("Could not parse " + name, e); } } return macros; @@ -544,7 +587,7 @@ public class ConvertedModel { List<RankingConstant> readLargeConstants() { try { List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { + for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsPath()).listFiles()) { String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); } @@ -562,13 +605,13 @@ public class ConvertedModel { * @return the path to the stored constant, relative to the application package root */ Path writeLargeConstant(String name, Tensor constant) { - Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); + Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(modelFiles.modelPath()).append("constants"); // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: Path constantPath = constantsPath.append(name + ".tbf"); // Remember the constant in a file we replicate in ZooKeeper - application.getFile(arguments.largeConstantsPath().append(name + ".constant")) + application.getFile(modelFiles.largeConstantsPath().append(name + ".constant")) .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); // Write content explicitly as a file on the file system as this is distributed using file distribution @@ -579,7 +622,7 @@ public class ConvertedModel { private List<Pair<String, Tensor>> readSmallConstants() { try { - ApplicationFile file = application.getFile(arguments.smallConstantsPath()); + ApplicationFile file = application.getFile(modelFiles.smallConstantsPath()); if (!file.exists()) return Collections.emptyList(); List<Pair<String, Tensor>> constants = new ArrayList<>(); @@ -604,7 +647,7 @@ public class ConvertedModel { */ public void writeSmallConstant(String name, Tensor constant) { // Secret file format for remembering constants: - application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + + application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" + constant.type().toString() + "\t" + constant.toString() + "\n"); } @@ -628,26 +671,24 @@ public class ConvertedModel { } } + private void close(Reader reader) { + try { + if (reader != null) + reader.close(); + } + catch (IOException e) { + // ignore + } + } + } - /** Encapsulates the arguments to the import feature */ - static class FeatureArguments { + static class ModelFiles { Path modelPath; - /** Optional arguments */ - Optional<String> signature, output; - - public FeatureArguments(Arguments arguments) { - this(Path.fromString(asString(arguments.expressions().get(0))), - optionalArgument(1, arguments), - optionalArgument(2, arguments)); - } - - public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) { + public ModelFiles(Path modelPath) { this.modelPath = modelPath; - this.signature = signature; - this.output = output; } /** Returns modelPath with slashes replaced by underscores */ @@ -655,37 +696,66 @@ public class ConvertedModel { /** Returns relative path to this model below the "models/" dir in the application package */ public Path modelPath() { return modelPath; } - public Optional<String> signature() { return signature; } - public Optional<String> output() { return output; } - /** Path to the small constants file */ + public Path storedModelPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath()); + } + + public Path expressionPath(String name) { + return storedModelPath().append("expressions").append(name); + } + + public Path expressionsPath() { + return storedModelPath().append("expressions"); + } + public Path smallConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); + return storedModelPath().append("constants.txt"); } /** Path to the large (ranking) constants directory */ public Path largeConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); + return storedModelPath().append("constants"); } /** Path to the macros file */ public Path macrosPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); + return storedModelPath().append("macros.txt"); } - public Path expressionPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR - .append(modelPath).append("expressions").append(expressionFileName()); + } + + /** Encapsulates the arguments of a specific model output */ + static class FeatureArguments { + + private final String modelName; + private final Path modelPath; + + /** Optional arguments */ + private final Optional<String> signature, output; + + public FeatureArguments(Arguments arguments) { + this(Path.fromString(asString(arguments.expressions().get(0))), + optionalArgument(1, arguments), + optionalArgument(2, arguments)); } - private String expressionFileName() { - StringBuilder fileName = new StringBuilder(); - signature.ifPresent(s -> fileName.append(s).append(".")); - output.ifPresent(s -> fileName.append(s).append(".")); - if (fileName.length() == 0) // single signature and output - fileName.append("single."); - fileName.append("expression"); - return fileName.toString(); + public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) { + this.modelPath = modelPath; + this.modelName = toModelName(modelPath); + this.signature = signature; + this.output = output; + } + + public Path modelPath() { return modelPath; } + + public Optional<String> signature() { return signature; } + public Optional<String> output() { return output; } + + public String toName() { + return modelName + + (signature.isPresent() ? "." + signature.get() : "") + + (output.isPresent() ? "." + output.get() : ""); } private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { @@ -694,7 +764,7 @@ public class ConvertedModel { return Optional.of(asString(arguments.expressions().get(argumentIndex))); } - private static String asString(ExpressionNode node) { + public static String asString(ExpressionNode node) { if ( ! (node instanceof ConstantNode)) throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); return stripQuotes(((ConstantNode)node).sourceString()); |