diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-23 16:20:32 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-23 16:20:32 -0700 |
commit | 4e44e5472829c033c3d995c618f2febcc4463eb7 (patch) | |
tree | 402dc48f0fce44759ce7bca8068c6b98097dd031 /config-model | |
parent | 2ee637ff5ef12924e77d5fbf087fb9fb803f0143 (diff) |
Use ExpressionFunction
Diffstat (limited to 'config-model')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java | 8 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java | 57 |
2 files changed, 33 insertions, 32 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 04481a3bc8d..2c6a7941772 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -237,8 +237,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), model.name(), profile, queryProfiles.getRegistry(), model); - for (Map.Entry<String, ImportedModel.ExpressionWithInputs> entry : convertedModel.expressions().entrySet()) { - profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().expression()), false); // TODO: Use inputs + for (Map.Entry<String, ExpressionFunction> entry : convertedModel.expressions().entrySet()) { + profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().getBody()), false); // TODO: Use arguments } } } @@ -249,8 +249,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry); rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); - for (Map.Entry<String, ImportedModel.ExpressionWithInputs> entry : convertedModel.expressions().entrySet()) { - profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().expression()), false); // TODO: Use inputs + for (Map.Entry<String, ExpressionFunction> entry : convertedModel.expressions().entrySet()) { + profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().getBody()), false); // TODO: Use inputs } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index d72a22f7c5e..fb0109ed32e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -48,6 +48,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -67,14 +68,14 @@ public class ConvertedModel { private final ModelName modelName; private final String modelDescription; - private final ImmutableMap<String, ImportedModel.ExpressionWithInputs> expressions; + private final ImmutableMap<String, ExpressionFunction> expressions; /** The source importedModel, or empty if this was created from a stored converted model */ private final Optional<ImportedModel> sourceModel; private ConvertedModel(ModelName modelName, String modelDescription, - Map<String, ImportedModel.ExpressionWithInputs> expressions, + Map<String, ExpressionFunction> expressions, Optional<ImportedModel> sourceModel) { this.modelName = modelName; this.modelDescription = modelDescription; @@ -132,23 +133,23 @@ public class ConvertedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public Map<String, ImportedModel.ExpressionWithInputs> expressions() { return expressions; } + public Map<String, ExpressionFunction> expressions() { return expressions; } /** * Returns the expression matching the given arguments. */ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { - ImportedModel.ExpressionWithInputs expression = selectExpression(arguments); + ExpressionFunction expression = selectExpression(arguments); if (sourceModel.isPresent()) // we should verify - verifyInputs(expression.expression(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); - return expression.expression().getRoot(); + verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); + return expression.getBody().getRoot(); } - private ImportedModel.ExpressionWithInputs selectExpression(FeatureArguments arguments) { + private ExpressionFunction selectExpression(FeatureArguments arguments) { if (expressions.isEmpty()) throw new IllegalArgumentException("No expressions available in " + this); - ImportedModel.ExpressionWithInputs expression = expressions.get(arguments.toName()); + ExpressionFunction expression = expressions.get(arguments.toName()); if (expression != null) return expression; if ( ! arguments.signature().isPresent()) { @@ -158,7 +159,7 @@ public class ConvertedModel { } if ( ! arguments.output().isPresent()) { - List<Map.Entry<String, ImportedModel.ExpressionWithInputs>> entriesWithTheRightPrefix = + List<Map.Entry<String, ExpressionFunction>> entriesWithTheRightPrefix = expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(arguments.signature().get() + ".")).collect(Collectors.toList()); if (entriesWithTheRightPrefix.size() < 1) throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() + @@ -179,10 +180,10 @@ public class ConvertedModel { // ----------------------- Static model conversion/storage below here - private static Map<String, ImportedModel.ExpressionWithInputs> convertAndStore(ImportedModel model, - RankProfile profile, - QueryProfileRegistry queryProfiles, - ModelStore store) { + private static Map<String, ExpressionFunction> convertAndStore(ImportedModel model, + RankProfile profile, + QueryProfileRegistry queryProfiles, + ModelStore store) { // Add constants Set<String> constantsReplacedByFunctions = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); @@ -193,8 +194,8 @@ public class ConvertedModel { addGeneratedFunctions(model, profile); // Add expressions - Map<String, ImportedModel.ExpressionWithInputs> expressions = new HashMap<>(); - for (Pair<String, ImportedModel.ExpressionWithInputs> output : model.outputExpressions()) { + Map<String, ExpressionFunction> expressions = new HashMap<>(); + for (Pair<String, ExpressionFunction> output : model.outputExpressions()) { addExpression(output.getSecond(), output.getFirst(), constantsReplacedByFunctions, model, store, profile, queryProfiles, @@ -210,21 +211,21 @@ public class ConvertedModel { return expressions; } - private static void addExpression(ImportedModel.ExpressionWithInputs expression, + private static void addExpression(ExpressionFunction expression, String expressionName, Set<String> constantsReplacedByFunctions, ImportedModel model, ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Map<String, ImportedModel.ExpressionWithInputs> expressions) { - expression = expression.with(replaceConstantsByFunctions(expression.expression(), constantsReplacedByFunctions)); - reduceBatchDimensions(expression.expression(), model, profile, queryProfiles); + Map<String, ExpressionFunction> expressions) { + expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions)); + reduceBatchDimensions(expression.getBody(), model, profile, queryProfiles); store.writeExpression(expressionName, expression); expressions.put(expressionName, expression); } - private static Map<String, ImportedModel.ExpressionWithInputs> convertStored(ModelStore store, RankProfile profile) { + private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) { for (Pair<String, Tensor> constant : store.readSmallConstants()) profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); @@ -525,15 +526,15 @@ public class ConvertedModel { * @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part * is always the model name */ - void writeExpression(String name, ImportedModel.ExpressionWithInputs expression) { - StringBuilder b = new StringBuilder(expression.expression().getRoot().toString()); - for (Map.Entry<String, TensorType> input : expression.inputs().entrySet()) + void writeExpression(String name, ExpressionFunction expression) { + StringBuilder b = new StringBuilder(expression.getBody().getRoot().toString()); + for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) b.append('\n').append(input.getKey()).append('\t').append(input.getValue()); application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString())); } - Map<String, ImportedModel.ExpressionWithInputs> readExpressions() { - Map<String, ImportedModel.ExpressionWithInputs> expressions = new HashMap<>(); + Map<String, ExpressionFunction> readExpressions() { + Map<String, ExpressionFunction> expressions = new HashMap<>(); ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath()); if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap(); for (ApplicationFile expressionFile : expressionPath.listFiles()) { @@ -551,18 +552,18 @@ public class ConvertedModel { return expressions; } - private ImportedModel.ExpressionWithInputs readExpression(String name, BufferedReader reader) + private ExpressionFunction readExpression(String name, BufferedReader reader) throws IOException, ParseException { // First line is expression RankingExpression expression = new RankingExpression(name, reader.readLine()); // Next lines are inputs on the format name\ttensorTypeSpec - Map<String, TensorType> inputs = new HashMap<>(); + Map<String, TensorType> inputs = new LinkedHashMap<>(); String line; while (null != (line = reader.readLine())) { String[] parts = line.split("\t"); inputs.put(parts[0], TensorType.fromSpec(parts[1])); } - return new ImportedModel.ExpressionWithInputs(expression, inputs); + return new ExpressionFunction(name, new ArrayList<>(inputs.keySet()), expression, inputs, Optional.empty()); } /** Adds this function expression to the application package so it can be read later. */ |