diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-05-16 13:51:24 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-05-16 13:51:24 +0200 |
commit | a37ed1c28091f234f25c9b3649999821eb7f4802 (patch) | |
tree | 84d6f2c96e21bb8304f04e38f002869bbfbf394d | |
parent | 1d63b5d81c057a8fe99812be22abac38c8195241 (diff) |
Support addiong models in rank profiles
10 files changed, 74 insertions, 58 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 07f3048af04..ec560484513 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -116,9 +116,9 @@ public class RankProfile implements Cloneable { private Map<Reference, Input> inputs = new LinkedHashMap<>(); - private Map<Reference, Constant> constants = new HashMap<>(); + private Map<Reference, Constant> constants = new LinkedHashMap<>(); - private Map<String, OnnxModel> onnxModels = new HashMap<>(); + private Map<String, OnnxModel> onnxModels = new LinkedHashMap<>(); private Set<String> filterFields = new HashSet<>(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 7384f98b121..081450275d1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -158,7 +158,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ private static FileDistributedOnnxModels deriveFileDistributedOnnxModels(Schema schema, Collection<RawRankProfile> rankProfiles, DeployState deployState) { - Map<String, OnnxModel> allModels = new HashMap<>(); + Map<String, OnnxModel> allModels = new LinkedHashMap<>(); addOnnxModels(schema != null ? schema.onnxModels().values() : List.of(), allModels, schema != null ? schema.toString() : "[global]"); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java index 71493df357c..58a9c78254a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java @@ -53,9 +53,8 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans return transformFeature(feature, context.rankProfile()); } - public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) { - ImmutableSchema search = rankProfile.schema(); - final String featureName = feature.getName(); + public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile profile) { + String featureName = feature.getName(); if ( ! featureName.equals("onnxModel") && ! featureName.equals("onnx")) return feature; Arguments arguments = feature.getArguments(); @@ -71,11 +70,11 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store. String modelConfigName = getModelConfigName(feature.reference()); - OnnxModel onnxModel = search.onnxModels().get(modelConfigName); + OnnxModel onnxModel = profile.onnxModels().get(modelConfigName); if (onnxModel == null) { String path = asString(arguments.expressions().get(0)); ModelName modelName = new ModelName(null, Path.fromString(path), true); - ConvertedModel convertedModel = ConvertedModel.fromStore(search.applicationPackage(), modelName, path, rankProfile); + ConvertedModel convertedModel = ConvertedModel.fromStore(profile.schema().applicationPackage(), modelName, path, profile); FeatureArguments featureArguments = new FeatureArguments(arguments); return convertedModel.expression(featureArguments, null); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java index f772c5fe903..47d770f609e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java @@ -43,6 +43,9 @@ public class ConvertParsedRanking { for (var constant : parsed.getConstants().values()) profile.add(constant); + for (var onnxModel : parsed.getOnnxModels()) + profile.add(onnxModel); + for (var input : parsed.getInputs().entrySet()) profile.addInput(input.getKey(), input.getValue()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java index 0ade3bfd76b..8f0f92c4027 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.parser; +import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfile.MatchPhaseSettings; import com.yahoo.searchdefinition.RankProfile.MutateOperation; @@ -54,6 +55,7 @@ class ParsedRankProfile extends ParsedBlock { private final Map<String, List<String>> rankProperties = new LinkedHashMap<>(); private final Map<Reference, RankProfile.Constant> constants = new LinkedHashMap<>(); private final Map<Reference, RankProfile.Input> inputs = new LinkedHashMap<>(); + private final List<OnnxModel> onnxModels = new ArrayList<>(); ParsedRankProfile(String name) { super(name, "rank-profile"); @@ -85,6 +87,7 @@ class ParsedRankProfile extends ParsedBlock { Map<String, List<String>> getRankProperties() { return Collections.unmodifiableMap(rankProperties); } Map<Reference, RankProfile.Constant> getConstants() { return Collections.unmodifiableMap(constants); } Map<Reference, RankProfile.Input> getInputs() { return Collections.unmodifiableMap(inputs); } + List<OnnxModel> getOnnxModels() { return List.copyOf(onnxModels); } Optional<String> getInheritedSummaryFeatures() { return Optional.ofNullable(this.inheritedSummaryFeatures); } Optional<String> getSecondPhaseExpression() { return Optional.ofNullable(this.secondPhaseExpression); } @@ -111,6 +114,10 @@ class ParsedRankProfile extends ParsedBlock { inputs.put(name, input); } + void add(OnnxModel model) { + onnxModels.add(model); + } + void addFieldRankFilter(String field, boolean filter) { fieldsRankFilter.put(field, filter); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java index 2bc10554b25..4c102594479 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java @@ -123,7 +123,7 @@ public class ParsedSchema extends ParsedBlock { extraIndexes.put(idxName, index); } - void addOnnxModel(OnnxModel model) { + void add(OnnxModel model) { onnxModels.add(model); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java index 19fbc116558..70ce051bb21 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java @@ -23,7 +23,7 @@ import java.util.Map; * * onnx("files/model.onnx", "path/to/output:1") * - * And generates an "onnx-model" configuration as if it was defined in the schema: + * And generates an "onnx-model" configuration as if it was defined in the profile: * * onnx-model files_model_onnx { * file: "files/model.onnx" @@ -45,31 +45,31 @@ public class OnnxModelConfigGenerator extends Processor { if (documentsOnly) return; for (RankProfile profile : rankProfileRegistry.rankProfilesOf(schema)) { if (profile.getFirstPhaseRanking() != null) { - process(profile.getFirstPhaseRanking().getRoot()); + process(profile.getFirstPhaseRanking().getRoot(), profile); } if (profile.getSecondPhaseRanking() != null) { - process(profile.getSecondPhaseRanking().getRoot()); + process(profile.getSecondPhaseRanking().getRoot(), profile); } for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) { - process(function.getValue().function().getBody().getRoot()); + process(function.getValue().function().getBody().getRoot(), profile); } for (ReferenceNode feature : profile.getSummaryFeatures()) { - process(feature); + process(feature, profile); } } } - private void process(ExpressionNode node) { + private void process(ExpressionNode node, RankProfile profile) { if (node instanceof ReferenceNode) { - process((ReferenceNode)node); + process((ReferenceNode)node, profile); } else if (node instanceof CompositeNode) { for (ExpressionNode child : ((CompositeNode) node).children()) { - process(child); + process(child, profile); } } } - private void process(ReferenceNode feature) { + private void process(ReferenceNode feature, RankProfile profile) { if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) { if (feature.getArguments().size() > 0) { if (feature.getArguments().expressions().get(0) instanceof ConstantNode) { @@ -85,9 +85,9 @@ public class OnnxModelConfigGenerator extends Processor { } } - OnnxModel onnxModel = schema.onnxModels().get(modelConfigName); + OnnxModel onnxModel = profile.onnxModels().get(modelConfigName); if (onnxModel == null) - schema.add(new OnnxModel(modelConfigName, path)); + profile.add(new OnnxModel(modelConfigName, path)); } } } diff --git a/config-model/src/main/javacc/IntermediateParser.jj b/config-model/src/main/javacc/IntermediateParser.jj index 873196d8bda..01f111df284 100644 --- a/config-model/src/main/javacc/IntermediateParser.jj +++ b/config-model/src/main/javacc/IntermediateParser.jj @@ -427,7 +427,7 @@ void rootSchemaItem(ParsedSchema schema) : { } | structOutside(schema) | annotationOutside(schema) | fieldSet(schema) - | onnxModel(schema) + | onnxModelInSchema(schema) // Deprecated: TODO: Emit warning when on Vespa 8 ) } @@ -1703,31 +1703,38 @@ void hnswIndexBody(HnswIndexParams.Builder params) : | <MULTITHREADEDINDEXING> <COLON> bool = bool() { params.setMultiThreadedIndexing(bool); } ) } -/** - * Consumes a onnx-model block of a schema element. - * - * @param schema the schema object to add content to. - */ -void onnxModel(ParsedSchema schema) : +void onnxModelInSchema(ParsedSchema schema) : +{ + OnnxModel onnxModel; +} +{ + onnxModel = onnxModel() { schema.add(onnxModel); } +} + +void onnxModelInProfile(ParsedRankProfile profile) : +{ + OnnxModel onnxModel; +} +{ + onnxModel = onnxModel() { profile.add(onnxModel); } +} + +/** Consumes an onnx-model block. */ +OnnxModel onnxModel() : { String name; OnnxModel onnxModel; } { - ( <ONNXMODEL> name = identifier() - { - onnxModel = new OnnxModel(name); - } + ( <ONNXMODEL> name = identifier() { onnxModel = new OnnxModel(name); } lbrace() (onnxModelItem(onnxModel) (<NL>)*)+ <RBRACE> ) - { - schema.addOnnxModel(onnxModel); - } + { return onnxModel; } } /** - * This rule consumes an onnx-model block. + * Consumes an onnx-model block. * - * @param onnxModel The onnxModel to modify. + * @param onnxModel the onnxModel to modify */ void onnxModelItem(OnnxModel onnxModel) : { @@ -1849,6 +1856,7 @@ void rankProfileItem(ParsedSchema schema, ParsedRankProfile profile) : { } | constants(schema, profile) | matchFeatures(profile) | summaryFeatures(profile) + | onnxModelInProfile(profile) | strict(profile) ) } diff --git a/config-model/src/test/integration/onnx-model/schemas/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd index a15714767ba..82872758dd9 100644 --- a/config-model/src/test/integration/onnx-model/schemas/test.sd +++ b/config-model/src/test/integration/onnx-model/schemas/test.sd @@ -21,14 +21,6 @@ search test { output "path/to/output:0": out } - onnx-model another_model { - file: files/model.onnx - input first_input: attribute(document_field) - input "second/input:0": constant(my_constant) - input "third_input": another_function - output "path/to/output:2": out - } - onnx-model dynamic_model { file: files/dynamic_model.onnx input input: my_function @@ -72,6 +64,13 @@ search test { first-phase { expression: 1 } + onnx-model another_model { + file: files/model.onnx + input first_input: attribute(document_field) + input "second/input:0": constant(my_constant) + input "third_input": another_function + output "path/to/output:2": out + } summary-features { onnx(another_model).out onnx("files/summary_model.onnx", "path/to/output:2") diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index 1c23950d972..6820a8d9678 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -93,6 +93,18 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("path_to_output_2", model.output(2).as()); model = config.model(1); + assertEquals("dynamic_model", model.name()); + assertEquals(1, model.input().size()); + assertEquals(1, model.output().size()); + assertEquals("rankingExpression(my_function)", model.input(0).source()); + + model = config.model(2); + assertEquals("unbound_model", model.name()); + assertEquals(1, model.input().size()); + assertEquals(1, model.output().size()); + assertEquals("rankingExpression(my_function)", model.input(0).source()); + + model = config.model(3); assertEquals("files_model_onnx", model.name()); assertEquals(3, model.input().size()); assertEquals(3, model.output().size()); @@ -104,27 +116,15 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("path_to_output_2", model.output(2).as()); assertEquals("files_model_onnx", model.name()); - model = config.model(2); + model = config.model(4); assertEquals("another_model", model.name()); assertEquals("third_input", model.input(2).name()); assertEquals("rankingExpression(another_function)", model.input(2).source()); - model = config.model(3); + model = config.model(5); assertEquals("files_summary_model_onnx", model.name()); assertEquals(3, model.input().size()); assertEquals(3, model.output().size()); - - model = config.model(4); - assertEquals("unbound_model", model.name()); - assertEquals(1, model.input().size()); - assertEquals(1, model.output().size()); - assertEquals("rankingExpression(my_function)", model.input(0).source()); - - model = config.model(5); - assertEquals("dynamic_model", model.name()); - assertEquals(1, model.input().size()); - assertEquals(1, model.output().size()); - assertEquals("rankingExpression(my_function)", model.input(0).source()); } private void assertTransformedFeature(VespaModel model) { |