diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2021-04-23 13:03:19 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-23 13:03:19 +0200 |
commit | 69d032ee48c4c28fb874020220990392903480d0 (patch) | |
tree | b81c5b646134122b4030f4d76af06a0ca2e92f90 /config-model | |
parent | dbf15114b4505e0d4ebe6ad5263685d64619f0b8 (diff) | |
parent | 0f20f60145524b13b11453fa0c92f33be0732707 (diff) |
Merge pull request #17560 from vespa-engine/arnej/add-input-params-in-rank-profile
Arnej/add input params in rank profile
Diffstat (limited to 'config-model')
4 files changed, 63 insertions, 14 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 8bef4c39ba1..b460752d7bd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -103,6 +103,8 @@ public class RankProfile implements Cloneable { private Map<String, RankingExpressionFunction> functions = new LinkedHashMap<>(); + private Map<Reference, TensorType> inputFeatures = new LinkedHashMap<>(); + private Set<String> filterFields = new HashSet<>(); private final RankProfileRegistry rankProfileRegistry; @@ -578,6 +580,23 @@ public class RankProfile implements Cloneable { return rankingExpressionFunction; } + /** + * Use for rank profiles representing a model evaluation; it will assume + * that a input is provided with the declared type (for the purpose of + * type resolving). + **/ + public void addInputFeature(String name, TensorType declaredType) { + Reference ref = Reference.fromIdentifier(name); + if (inputFeatures.containsKey(ref)) { + TensorType hadType = inputFeatures.get(ref); + if (! declaredType.equals(hadType)) { + throw new IllegalArgumentException("Tried to replace input feature "+name+" with different type: "+ + hadType+" -> "+declaredType); + } + } + inputFeatures.put(ref, declaredType); + } + public RankingExpressionFunction findFunction(String name) { RankingExpressionFunction function = functions.get(name); return ((function == null) && (getInherited() != null)) @@ -677,6 +696,7 @@ public class RankProfile implements Cloneable { clone.summaryFeatures = summaryFeatures != null ? new LinkedHashSet<>(this.summaryFeatures) : null; clone.rankFeatures = rankFeatures != null ? new LinkedHashSet<>(this.rankFeatures) : null; clone.rankProperties = new LinkedHashMap<>(this.rankProperties); + clone.inputFeatures = new LinkedHashMap<>(this.inputFeatures); clone.functions = new LinkedHashMap<>(this.functions); clone.filterFields = new HashSet<>(this.filterFields); clone.constants = new HashMap<>(this.constants); @@ -790,8 +810,12 @@ public class RankProfile implements Cloneable { return typeContext(queryProfiles, collectFeatureTypes()); } + public MapEvaluationTypeContext typeContext() { return typeContext(new QueryProfileRegistry()); } + private Map<Reference, TensorType> collectFeatureTypes() { Map<Reference, TensorType> featureTypes = new HashMap<>(); + // Add input features + inputFeatures.forEach((k, v) -> featureTypes.put(k, v)); // Add attributes allFields().forEach(field -> addAttributeFeatureTypes(field, featureTypes)); allImportedFields().forEach(field -> addAttributeFeatureTypes(field, featureTypes)); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index b757259102b..9086ca9f40e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -210,6 +210,9 @@ public class ConvertedModel { Map<String, ExpressionFunction> expressions = new HashMap<>(); for (ImportedMlFunction outputFunction : model.outputExpressions()) { ExpressionFunction expression = asExpressionFunction(outputFunction); + for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) { + profile.addInputFeature(input.getKey(), input.getValue()); + } addExpression(expression, expression.getName(), constantsReplacedByFunctions, model, store, profile, queryProfiles, @@ -251,13 +254,20 @@ public class ConvertedModel { QueryProfileRegistry queryProfiles, Map<String, ExpressionFunction> expressions) { expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions)); + if (expression.returnType().isEmpty()) { + TensorType type = expression.getBody().type(profile.typeContext(queryProfiles)); + if (type != null) { + expression = expression.withReturnType(type); + } + } store.writeExpression(expressionName, expression); expressions.put(expressionName, expression); } private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) { - for (Pair<String, Tensor> constant : store.readSmallConstants()) + for (Pair<String, Tensor> constant : store.readSmallConstants()) { profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); + } for (RankingConstant constant : store.readLargeConstants()) { if ( ! profile.rankingConstants().asMap().containsKey(constant.getName())) { @@ -269,7 +279,20 @@ public class ConvertedModel { addGeneratedFunctionToProfile(profile, function.getFirst(), function.getSecond()); } - return store.readExpressions(); + Map<String, ExpressionFunction> expressions = new HashMap<>(); + for (Pair<String, ExpressionFunction> output : store.readExpressions()) { + String name = output.getFirst(); + ExpressionFunction expression = output.getSecond(); + for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) { + profile.addInputFeature(input.getKey(), input.getValue()); + } + TensorType type = expression.getBody().type(profile.typeContext()); + if (type != null) { + expression = expression.withReturnType(type); + } + expressions.put(name, expression); + } + return expressions; } private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, @@ -321,8 +344,9 @@ public class ConvertedModel { "\nwant to add " + expression + "\n"); return; } - var fun = new ExpressionFunction(functionName, expression); - profile.addFunction(fun, false); // TODO: Inline if only used once + ExpressionFunction function = new ExpressionFunction(functionName, expression); + // XXX should we resolve type here? + profile.addFunction(function, false); // TODO: Inline if only used once } /** @@ -465,14 +489,14 @@ public class ConvertedModel { application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString())); } - Map<String, ExpressionFunction> readExpressions() { - Map<String, ExpressionFunction> expressions = new HashMap<>(); + List<Pair<String, ExpressionFunction>> readExpressions() { + List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath()); - if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap(); + if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyList(); for (ApplicationFile expressionFile : expressionPath.listFiles()) { - try (BufferedReader reader = new BufferedReader(expressionFile.createReader())){ + try (BufferedReader reader = new BufferedReader(expressionFile.createReader())) { String name = expressionFile.getPath().getName(); - expressions.put(name, readExpression(name, reader)); + expressions.add(new Pair<>(name, readExpression(name, reader))); } catch (IOException e) { throw new UncheckedIOException("Failed reading " + expressionFile.getPath(), e); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 8fe4a8fb022..d665b7f20f0 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -159,7 +159,7 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { public void testNeuralNetworkSetup() throws ParseException { // Note: the type assigned to query profile and constant tensors here is not the correct type RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(x[1])"); + QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(input[1])"); SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles); builder.importString( "search test {\n" + @@ -184,19 +184,19 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { " }\n" + " }\n" + " constant W_hidden {\n" + - " type: tensor(x[1])\n" + + " type: tensor(hidden[1])\n" + " file: ignored.json\n" + " }\n" + " constant b_input {\n" + - " type: tensor(x[1])\n" + + " type: tensor(hidden[1])\n" + " file: ignored.json\n" + " }\n" + " constant W_final {\n" + - " type: tensor(x[1])\n" + + " type: tensor(final[1])\n" + " file: ignored.json\n" + " }\n" + " constant b_final {\n" + - " type: tensor(x[1])\n" + + " type: tensor(final[1])\n" + " file: ignored.json\n" + " }\n" + "}\n"); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index 4c1c24c9790..1aaa1669377 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -145,6 +145,7 @@ public class ModelEvaluationTest { private final String profile = "rankingExpression(imported_ml_function_small_constants_and_functions_exp_output).rankingScript: map(input, f(a)(exp(a)))\n" + + "rankingExpression(imported_ml_function_small_constants_and_functions_exp_output).type: tensor<float>(d0[3])\n" + "rankingExpression(default.output).rankingScript: join(rankingExpression(imported_ml_function_small_constants_and_functions_exp_output), reduce(join(join(reduce(rankingExpression(imported_ml_function_small_constants_and_functions_exp_output), sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), 9.999999974752427E-7, f(a,b)(a + b)), sum, d0), f(a,b)(a / b))\n" + "rankingExpression(default.output).input.type: tensor<float>(d0[3])\n" + "rankingExpression(default.output).type: tensor<float>(d0[3])\n"; |