diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java | 116 |
1 files changed, 67 insertions, 49 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index adf5c81283e..fb0109ed32e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -48,6 +48,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -67,14 +68,14 @@ public class ConvertedModel { private final ModelName modelName; private final String modelDescription; - private final ImmutableMap<String, RankingExpression> expressions; + private final ImmutableMap<String, ExpressionFunction> expressions; /** The source importedModel, or empty if this was created from a stored converted model */ private final Optional<ImportedModel> sourceModel; private ConvertedModel(ModelName modelName, String modelDescription, - Map<String, RankingExpression> expressions, + Map<String, ExpressionFunction> expressions, Optional<ImportedModel> sourceModel) { this.modelName = modelName; this.modelDescription = modelDescription; @@ -132,23 +133,23 @@ public class ConvertedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public Map<String, RankingExpression> expressions() { return expressions; } + public Map<String, ExpressionFunction> expressions() { return expressions; } /** * Returns the expression matching the given arguments. */ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { - RankingExpression expression = selectExpression(arguments); - if (sourceModel.isPresent()) // we can verify - verifyRequiredFunctions(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles()); - return expression.getRoot(); + ExpressionFunction expression = selectExpression(arguments); + if (sourceModel.isPresent()) // we should verify + verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); + return expression.getBody().getRoot(); } - private RankingExpression selectExpression(FeatureArguments arguments) { + private ExpressionFunction selectExpression(FeatureArguments arguments) { if (expressions.isEmpty()) throw new IllegalArgumentException("No expressions available in " + this); - RankingExpression expression = expressions.get(arguments.toName()); + ExpressionFunction expression = expressions.get(arguments.toName()); if (expression != null) return expression; if ( ! arguments.signature().isPresent()) { @@ -158,7 +159,7 @@ public class ConvertedModel { } if ( ! arguments.output().isPresent()) { - List<Map.Entry<String, RankingExpression>> entriesWithTheRightPrefix = + List<Map.Entry<String, ExpressionFunction>> entriesWithTheRightPrefix = expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(arguments.signature().get() + ".")).collect(Collectors.toList()); if (entriesWithTheRightPrefix.size() < 1) throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() + @@ -179,10 +180,10 @@ public class ConvertedModel { // ----------------------- Static model conversion/storage below here - private static Map<String, RankingExpression> convertAndStore(ImportedModel model, - RankProfile profile, - QueryProfileRegistry queryProfiles, - ModelStore store) { + private static Map<String, ExpressionFunction> convertAndStore(ImportedModel model, + RankProfile profile, + QueryProfileRegistry queryProfiles, + ModelStore store) { // Add constants Set<String> constantsReplacedByFunctions = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); @@ -193,8 +194,8 @@ public class ConvertedModel { addGeneratedFunctions(model, profile); // Add expressions - Map<String, RankingExpression> expressions = new HashMap<>(); - for (Pair<String, RankingExpression> output : model.outputExpressions()) { + Map<String, ExpressionFunction> expressions = new HashMap<>(); + for (Pair<String, ExpressionFunction> output : model.outputExpressions()) { addExpression(output.getSecond(), output.getFirst(), constantsReplacedByFunctions, model, store, profile, queryProfiles, @@ -210,21 +211,21 @@ public class ConvertedModel { return expressions; } - private static void addExpression(RankingExpression expression, + private static void addExpression(ExpressionFunction expression, String expressionName, Set<String> constantsReplacedByFunctions, ImportedModel model, ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Map<String, RankingExpression> expressions) { - expression = replaceConstantsByFunctions(expression, constantsReplacedByFunctions); - reduceBatchDimensions(expression, model, profile, queryProfiles); + Map<String, ExpressionFunction> expressions) { + expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions)); + reduceBatchDimensions(expression.getBody(), model, profile, queryProfiles); store.writeExpression(expressionName, expression); expressions.put(expressionName, expression); } - private static Map<String, RankingExpression> convertStored(ModelStore store, RankProfile profile) { + private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) { for (Pair<String, Tensor> constant : store.readSmallConstants()) profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); @@ -290,15 +291,15 @@ public class ConvertedModel { } /** - * Verify that the functions referred in the given expression exists in the given rank profile, - * and return tensors of the types specified in requiredFunctions. + * Verify that the inputs declared in the given expression exists in the given rank profile as functions, + * and return tensors of the correct types. */ - private static void verifyRequiredFunctions(RankingExpression expression, ImportedModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { + private static void verifyInputs(RankingExpression expression, ImportedModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { Set<String> functionNames = new HashSet<>(); addFunctionNamesIn(expression.getRoot(), functionNames, model); for (String functionName : functionNames) { - TensorType requiredType = model.requiredFunctions().get(functionName); + TensorType requiredType = model.inputs().get(functionName); if (requiredType == null) continue; // Not a required function RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); @@ -375,7 +376,7 @@ public class ConvertedModel { List<ExpressionNode> children = ((TensorFunctionNode)node).children(); if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.requiredFunctions().containsKey(referenceNode.getName())) { + if (model.inputs().containsKey(referenceNode.getName())) { return reduceBatchDimensionExpression(tensorFunction, typeContext); } } @@ -383,7 +384,7 @@ public class ConvertedModel { } if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) node; - if (model.requiredFunctions().containsKey(referenceNode.getName())) { + if (model.inputs().containsKey(referenceNode.getName())) { return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); } } @@ -451,7 +452,8 @@ public class ConvertedModel { Set<String> constantsReplacedByFunctions) { if (constantsReplacedByFunctions.isEmpty()) return expression; return new RankingExpression(expression.getName(), - replaceConstantsByFunctions(expression.getRoot(), constantsReplacedByFunctions)); + replaceConstantsByFunctions(expression.getRoot(), + constantsReplacedByFunctions)); } private static ExpressionNode replaceConstantsByFunctions(ExpressionNode node, Set<String> constantsReplacedByFunctions) { @@ -524,19 +526,21 @@ public class ConvertedModel { * @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part * is always the model name */ - void writeExpression(String name, RankingExpression expression) { - application.getFile(modelFiles.expressionPath(name)) - .writeFile(new StringReader(expression.getRoot().toString())); + void writeExpression(String name, ExpressionFunction expression) { + StringBuilder b = new StringBuilder(expression.getBody().getRoot().toString()); + for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) + b.append('\n').append(input.getKey()).append('\t').append(input.getValue()); + application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString())); } - Map<String, RankingExpression> readExpressions() { - Map<String, RankingExpression> expressions = new HashMap<>(); + Map<String, ExpressionFunction> readExpressions() { + Map<String, ExpressionFunction> expressions = new HashMap<>(); ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath()); if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap(); for (ApplicationFile expressionFile : expressionPath.listFiles()) { - try (Reader reader = new BufferedReader(expressionFile.createReader())){ + try (BufferedReader reader = new BufferedReader(expressionFile.createReader())){ String name = expressionFile.getPath().getName(); - expressions.put(name, new RankingExpression(name, reader)); + expressions.put(name, readExpression(name, reader)); } catch (IOException e) { throw new UncheckedIOException("Failed reading " + expressionFile.getPath(), e); @@ -548,8 +552,22 @@ public class ConvertedModel { return expressions; } + private ExpressionFunction readExpression(String name, BufferedReader reader) + throws IOException, ParseException { + // First line is expression + RankingExpression expression = new RankingExpression(name, reader.readLine()); + // Next lines are inputs on the format name\ttensorTypeSpec + Map<String, TensorType> inputs = new LinkedHashMap<>(); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + inputs.put(parts[0], TensorType.fromSpec(parts[1])); + } + return new ExpressionFunction(name, new ArrayList<>(inputs.keySet()), expression, inputs, Optional.empty()); + } + /** Adds this function expression to the application package so it can be read later. */ - void writeFunction(String name, RankingExpression expression) { + public void writeFunction(String name, RankingExpression expression) { application.getFile(modelFiles.functionsPath()).appendFile(name + "\t" + expression.getRoot().toString() + "\n"); } @@ -561,20 +579,20 @@ public class ConvertedModel { if ( ! file.exists()) return Collections.emptyList(); List<Pair<String, RankingExpression>> functions = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - try { - RankingExpression expression = new RankingExpression(parts[0], parts[1]); - functions.add(new Pair<>(name, expression)); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + name, e); + try (BufferedReader reader = new BufferedReader(file.createReader())) { + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + try { + RankingExpression expression = new RankingExpression(parts[0], parts[1]); + functions.add(new Pair<>(name, expression)); + } catch (ParseException e) { + throw new IllegalStateException("Could not parse " + name, e); + } } + return functions; } - return functions; } catch (IOException e) { throw new UncheckedIOException(e); |