diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java | 258 |
1 files changed, 139 insertions, 119 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index 0911f567fa1..f7a06f86ab7 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -1,5 +1,6 @@ package com.yahoo.searchdefinition.expressiontransforms; +import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; @@ -16,7 +17,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -38,7 +38,6 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.BufferedReader; import java.io.File; -import java.io.FileNotFoundException; import java.io.IOException; import java.io.Reader; import java.io.StringReader; @@ -65,49 +64,91 @@ import java.util.stream.Collectors; public class ConvertedModel { private final String modelName; - private final Path modelPath; - - /** - * The ranking expressions of this, indexed by their name. which is a 1-3 part string separated by dots - * where the first part is always the model name, the second the signature or (if none) - * expression name (if more than one), and the third is the output name (if any). - */ - private final Map<String, RankingExpression> expressions; + private final String modelDescription; + private final ImmutableMap<String, RankingExpression> expressions; + + /** The source importedModel, or empty if this was created from a stored converted model */ + private final Optional<ImportedModel> sourceModel; + + private ConvertedModel(String modelName, + String modelDescription, + Map<String, RankingExpression> expressions, + Optional<ImportedModel> sourceModel) { + this.modelName = modelName; + this.modelDescription = modelDescription; + this.expressions = ImmutableMap.copyOf(expressions); + this.sourceModel = sourceModel; + } /** - * Create a converted model for a rank profile given from either an imported model, + * Create and store a converted model for a rank profile given from either an imported model, * or (if unavailable) from stored application package data. */ - public ConvertedModel(Path modelPath, RankProfileTransformContext context) { - this.modelPath = modelPath; - this.modelName = toModelName(modelPath); - ModelStore store = new ModelStore(context.rankProfile().applicationPackage(), modelPath); - if ( store.hasSourceModel()) - expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), context.importedModels()); + public static ConvertedModel fromSourceOrStore(Path modelPath, RankProfileTransformContext context) { + File sourceModel = sourceModelFile(context.rankProfile().applicationPackage(), modelPath); + if (sourceModel.exists()) + return fromSource(toModelName(modelPath), + modelPath.toString(), + context.rankProfile(), + context.queryProfiles(), + context.importedModels().get(sourceModel)); // TODO: Convert to name here, make sure its done just one way else - expressions = transformFromStoredModel(store, context.rankProfile()); + return fromStore(toModelName(modelPath), + modelPath.toString(), + context.rankProfile()); + } + + public static ConvertedModel fromSource(String modelName, + String modelDescription, + RankProfile rankProfile, + QueryProfileRegistry queryProfileRegistry, + ImportedModel importedModel) { + ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName); + return new ConvertedModel(modelName, + modelDescription, + convertAndStore(importedModel, rankProfile, queryProfileRegistry, modelStore), + Optional.of(importedModel)); } - private Map<String, RankingExpression> convertModel(ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles, - ImportedModels importedModels) { - ImportedModel model = importedModels.get(store.sourceModelFile()); - return transformFromImportedModel(model, store, profile, queryProfiles); + public static ConvertedModel fromStore(String modelName, + String modelDescription, + RankProfile rankProfile) { + ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName); + return new ConvertedModel(modelName, + modelDescription, + convertStored(modelStore, rankProfile), + Optional.empty()); } - /** Returns the expression matching the given arguments */ - public ExpressionNode expression(FeatureArguments arguments) { + /** + * Returns all the output expressions of this indexed by name. The names consist of one or two parts + * separated by dot, where the first part is the signature name + * if signatures are used, or the expression name if signatures are not used and there are multiple + * expressions, and the second is the output name if signature names are used. + */ + public Map<String, RankingExpression> expressions() { return expressions; } + + /** + * Returns the expression matching the given arguments. + */ + public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { + RankingExpression expression = selectExpression(arguments); + if (sourceModel.isPresent()) // we can verify + verifyRequiredMacros(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles()); + return expression.getRoot(); + } + + private RankingExpression selectExpression(FeatureArguments arguments) { if (expressions.isEmpty()) throw new IllegalArgumentException("No expressions available in " + this); RankingExpression expression = expressions.get(arguments.toName()); - if (expression != null) return expression.getRoot(); + if (expression != null) return expression; if ( ! arguments.signature().isPresent()) { if (expressions.size() > 1) throw new IllegalArgumentException("Multiple candidate expressions " + missingExpressionMessageSuffix()); - return expressions.values().iterator().next().getRoot(); + return expressions.values().iterator().next(); } if ( ! arguments.output().isPresent()) { @@ -119,21 +160,23 @@ public class ConvertedModel { if (entriesWithTheRightPrefix.size() > 1) throw new IllegalArgumentException("Multiple candidate expression named '" + arguments.signature().get() + missingExpressionMessageSuffix()); - return entriesWithTheRightPrefix.get(0).getValue().getRoot(); + return entriesWithTheRightPrefix.get(0).getValue(); } throw new IllegalArgumentException("No expression '" + arguments.toName() + missingExpressionMessageSuffix()); } private String missingExpressionMessageSuffix() { - return "' in model '" + this.modelPath + "'. " + + return "' in model '" + modelDescription + "'. " + "Available expressions: " + expressions.keySet().stream().collect(Collectors.joining(", ")); } - private Map<String, RankingExpression> transformFromImportedModel(ImportedModel model, - ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles) { + // ----------------------- Static model conversion/storage below here + + private static Map<String, RankingExpression> convertAndStore(ImportedModel model, + RankProfile profile, + QueryProfileRegistry queryProfiles, + ModelStore store) { // Add constants Set<String> constantsReplacedByMacros = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); @@ -161,22 +204,21 @@ public class ConvertedModel { return expressions; } - private void addExpression(RankingExpression expression, - String expressionName, - Set<String> constantsReplacedByMacros, - ImportedModel model, - ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles, - Map<String, RankingExpression> expressions) { + private static void addExpression(RankingExpression expression, + String expressionName, + Set<String> constantsReplacedByMacros, + ImportedModel model, + ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + Map<String, RankingExpression> expressions) { expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - verifyRequiredMacros(expression, model, profile, queryProfiles); reduceBatchDimensions(expression, model, profile, queryProfiles); store.writeExpression(expressionName, expression); expressions.put(expressionName, expression); } - private Map<String, RankingExpression> transformFromStoredModel(ModelStore store, RankProfile profile) { + private static Map<String, RankingExpression> convertStored(ModelStore store, RankProfile profile) { for (Pair<String, Tensor> constant : store.readSmallConstants()) profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); @@ -192,12 +234,12 @@ public class ConvertedModel { return store.readExpressions(); } - private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { store.writeSmallConstant(constantName, constantValue); profile.addConstant(constantName, asValue(constantValue)); } - private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, + private static void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, Set<String> constantsReplacedByMacros, String constantName, Tensor constantValue) { RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); @@ -217,7 +259,7 @@ public class ConvertedModel { } } - private void transformGeneratedMacro(ModelStore store, + private static void transformGeneratedMacro(ModelStore store, Set<String> constantsReplacedByMacros, String macroName, RankingExpression expression) { @@ -226,15 +268,16 @@ public class ConvertedModel { store.writeMacro(macroName, expression); } - private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { + private static void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { if (profile.getMacros().containsKey(macroName)) { if ( ! profile.getMacros().get(macroName).getRankingExpression().equals(expression)) throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists in " + profile + - " - with a different definition"); + " - with a different definition" + + ": Has\n" + profile.getMacros().get(macroName).getRankingExpression() + + "\nwant to add " + expression + "\n"); return; } - profile.addMacro(macroName, false); // todo: inline if only used once - RankProfile.Macro macro = profile.getMacros().get(macroName); + RankProfile.Macro macro = profile.addMacro(macroName, false); // TODO: Inline if only used once macro.setRankingExpression(expression); macro.setTextualExpression(expression.getRoot().toString()); } @@ -243,8 +286,8 @@ public class ConvertedModel { * Verify that the macros referred in the given expression exists in the given rank profile, * and return tensors of the types specified in requiredMacros. */ - private void verifyRequiredMacros(RankingExpression expression, ImportedModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { + private static void verifyRequiredMacros(RankingExpression expression, ImportedModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { Set<String> macroNames = new HashSet<>(); addMacroNamesIn(expression.getRoot(), macroNames, model); for (String macroName : macroNames) { @@ -272,7 +315,7 @@ public class ConvertedModel { } } - private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { + private static String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { return "The required type of this is " + requiredType + ", but this macro returns " + actualType + (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " + "in query profile types - see the documentation." @@ -282,7 +325,7 @@ public class ConvertedModel { /** * Add the generated macros to the rank profile */ - private void addGeneratedMacros(ImportedModel model, RankProfile profile) { + private static void addGeneratedMacros(ImportedModel model, RankProfile profile) { model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v.copy())); } @@ -291,8 +334,8 @@ public class ConvertedModel { * macro specifies that a single exemplar should be evaluated, we can * reduce the batch dimension out. */ - private void reduceBatchDimensions(RankingExpression expression, ImportedModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { + private static void reduceBatchDimensions(RankingExpression expression, ImportedModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); TensorType typeBeforeReducing = expression.getRoot().type(typeContext); @@ -319,8 +362,8 @@ public class ConvertedModel { expression.setRoot(root); } - private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model, - TypeContext<Reference> typeContext) { + private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model, + TypeContext<Reference> typeContext) { if (node instanceof TensorFunctionNode) { TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); if (tensorFunction instanceof Rename) { @@ -350,7 +393,7 @@ public class ConvertedModel { return node; } - private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { + private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { TensorFunction result = function; TensorType type = function.type(context); if (type.dimensions().size() > 1) { @@ -372,7 +415,7 @@ public class ConvertedModel { * for any following computation of the tensor. */ // TODO: determine when this is not necessary! - private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { + private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { if (after.equals(before)) { return node; } @@ -399,14 +442,14 @@ public class ConvertedModel { * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. * This method does that for the given expression and returns the result. */ - private RankingExpression replaceConstantsByMacros(RankingExpression expression, + private static RankingExpression replaceConstantsByMacros(RankingExpression expression, Set<String> constantsReplacedByMacros) { if (constantsReplacedByMacros.isEmpty()) return expression; return new RankingExpression(expression.getName(), replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); } - private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { + private static ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { if (node instanceof ReferenceNode) { Reference reference = ((ReferenceNode)node).reference(); if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { @@ -424,7 +467,7 @@ public class ConvertedModel { return node; } - private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) { + private static void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) { if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode)node; if (referenceNode.getOutput() == null) { // macro references cannot specify outputs @@ -440,7 +483,7 @@ public class ConvertedModel { } } - private Value asValue(Tensor tensor) { + private static Value asValue(Tensor tensor) { if (tensor.type().rank() == 0) return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors else @@ -455,6 +498,13 @@ public class ConvertedModel { public String toString() { return "model '" + modelName + "'"; } /** + * Returns the directory which contains the source model to use for these arguments + */ + public static File sourceModelFile(ApplicationPackage application, Path sourceModelPath) { + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(sourceModelPath)); + } + + /** * Provides read/write access to the correct directories of the application package given by the feature arguments */ static class ModelStore { @@ -462,20 +512,9 @@ public class ConvertedModel { private final ApplicationPackage application; private final ModelFiles modelFiles; - ModelStore(ApplicationPackage application, Path modelPath) { + ModelStore(ApplicationPackage application, String modelName) { this.application = application; - this.modelFiles = new ModelFiles(modelPath); - } - - public boolean hasSourceModel() { - return sourceModelFile().exists(); - } - - /** - * Returns the directory which contains the source model to use for these arguments - */ - public File sourceModelFile() { - return application.getFileReference(ApplicationPackage.MODELS_DIR.append(modelFiles.modelPath())); + this.modelFiles = new ModelFiles(modelName); } /** @@ -508,7 +547,7 @@ public class ConvertedModel { return expressions; } - /** Adds this macro expression to the application package to it can be read later. */ + /** Adds this macro expression to the application package so it can be read later. */ void writeMacro(String name, RankingExpression expression) { application.getFile(modelFiles.macrosPath()).appendFile(name + "\t" + expression.getRoot().toString() + "\n"); @@ -518,7 +557,7 @@ public class ConvertedModel { List<Pair<String, RankingExpression>> readMacros() { try { ApplicationFile file = application.getFile(modelFiles.macrosPath()); - if (!file.exists()) return Collections.emptyList(); + if ( ! file.exists()) return Collections.emptyList(); List<Pair<String, RankingExpression>> macros = new ArrayList<>(); BufferedReader reader = new BufferedReader(file.createReader()); @@ -527,7 +566,7 @@ public class ConvertedModel { String[] parts = line.split("\t"); String name = parts[0]; try { - RankingExpression expression = new RankingExpression(parts[1]); + RankingExpression expression = new RankingExpression(parts[0], parts[1]); macros.add(new Pair<>(name, expression)); } catch (ParseException e) { @@ -548,7 +587,7 @@ public class ConvertedModel { List<RankingConstant> readLargeConstants() { try { List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsPath()).listFiles()) { + for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsInfoPath()).listFiles()) { String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); } @@ -566,13 +605,13 @@ public class ConvertedModel { * @return the path to the stored constant, relative to the application package root */ Path writeLargeConstant(String name, Tensor constant) { - Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(modelFiles.modelPath()).append("constants"); + Path constantsPath = modelFiles.largeConstantsContentPath(); // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: Path constantPath = constantsPath.append(name + ".tbf"); // Remember the constant in a file we replicate in ZooKeeper - application.getFile(modelFiles.largeConstantsPath().append(name + ".constant")) + application.getFile(modelFiles.largeConstantsInfoPath().append(name + ".constant")) .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); // Write content explicitly as a file on the file system as this is distributed using file distribution @@ -609,8 +648,8 @@ public class ConvertedModel { public void writeSmallConstant(String name, Tensor constant) { // Secret file format for remembering constants: application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" + - constant.type().toString() + "\t" + - constant.toString() + "\n"); + constant.type().toString() + "\t" + + constant.toString() + "\n"); } /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ @@ -632,40 +671,24 @@ public class ConvertedModel { } } - private void close(Reader reader) { - try { - if (reader != null) - reader.close(); - } - catch (IOException e) { - // ignore - } - } - } static class ModelFiles { - Path modelPath; + String modelName; - public ModelFiles(Path modelPath) { - this.modelPath = modelPath; + public ModelFiles(String modelName) { + this.modelName = modelName; } - /** Returns modelPath with slashes replaced by underscores */ - public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); } - - /** Returns relative path to this model below the "models/" dir in the application package */ - public Path modelPath() { return modelPath; } - /** Files stored below this path will be replicated in zookeeper */ public Path storedModelReplicatedPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath()); + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelName); } - /** Files stored below this path will not be replicated */ + /** Files stored below this path will not be replicated in zookeeper */ public Path storedModelPath() { - return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath()); + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelName); } public Path expressionPath(String name) { @@ -681,7 +704,12 @@ public class ConvertedModel { } /** Path to the large (ranking) constants directory */ - public Path largeConstantsPath() { + public Path largeConstantsContentPath() { + return storedModelPath().append("constants"); + } + + /** Path to the large (ranking) constants directory */ + public Path largeConstantsInfoPath() { return storedModelReplicatedPath().append("constants"); } @@ -695,27 +723,19 @@ public class ConvertedModel { /** Encapsulates the arguments of a specific model output */ static class FeatureArguments { - private final String modelName; - private final Path modelPath; - /** Optional arguments */ private final Optional<String> signature, output; public FeatureArguments(Arguments arguments) { - this(Path.fromString(asString(arguments.expressions().get(0))), - optionalArgument(1, arguments), + this(optionalArgument(1, arguments), optionalArgument(2, arguments)); } - public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) { - this.modelPath = modelPath; - this.modelName = toModelName(modelPath); + public FeatureArguments(Optional<String> signature, Optional<String> output) { this.signature = signature; this.output = output; } - public Path modelPath() { return modelPath; } - public Optional<String> signature() { return signature; } public Optional<String> output() { return output; } |