summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-05-16 13:51:24 +0200
committerJon Bratseth <bratseth@gmail.com>2022-05-16 13:51:24 +0200
commita37ed1c28091f234f25c9b3649999821eb7f4802 (patch)
tree84d6f2c96e21bb8304f04e38f002869bbfbf394d
parent1d63b5d81c057a8fe99812be22abac38c8195241 (diff)
Support addiong models in rank profiles
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java9
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java3
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java7
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java22
-rw-r--r--config-model/src/main/javacc/IntermediateParser.jj40
-rw-r--r--config-model/src/test/integration/onnx-model/schemas/test.sd15
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java28
10 files changed, 74 insertions, 58 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 07f3048af04..ec560484513 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -116,9 +116,9 @@ public class RankProfile implements Cloneable {
private Map<Reference, Input> inputs = new LinkedHashMap<>();
- private Map<Reference, Constant> constants = new HashMap<>();
+ private Map<Reference, Constant> constants = new LinkedHashMap<>();
- private Map<String, OnnxModel> onnxModels = new HashMap<>();
+ private Map<String, OnnxModel> onnxModels = new LinkedHashMap<>();
private Set<String> filterFields = new HashSet<>();
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 7384f98b121..081450275d1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -158,7 +158,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
private static FileDistributedOnnxModels deriveFileDistributedOnnxModels(Schema schema,
Collection<RawRankProfile> rankProfiles,
DeployState deployState) {
- Map<String, OnnxModel> allModels = new HashMap<>();
+ Map<String, OnnxModel> allModels = new LinkedHashMap<>();
addOnnxModels(schema != null ? schema.onnxModels().values() : List.of(),
allModels,
schema != null ? schema.toString() : "[global]");
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
index 71493df357c..58a9c78254a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -53,9 +53,8 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
return transformFeature(feature, context.rankProfile());
}
- public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) {
- ImmutableSchema search = rankProfile.schema();
- final String featureName = feature.getName();
+ public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile profile) {
+ String featureName = feature.getName();
if ( ! featureName.equals("onnxModel") && ! featureName.equals("onnx")) return feature;
Arguments arguments = feature.getArguments();
@@ -71,11 +70,11 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
// ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store.
String modelConfigName = getModelConfigName(feature.reference());
- OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
+ OnnxModel onnxModel = profile.onnxModels().get(modelConfigName);
if (onnxModel == null) {
String path = asString(arguments.expressions().get(0));
ModelName modelName = new ModelName(null, Path.fromString(path), true);
- ConvertedModel convertedModel = ConvertedModel.fromStore(search.applicationPackage(), modelName, path, rankProfile);
+ ConvertedModel convertedModel = ConvertedModel.fromStore(profile.schema().applicationPackage(), modelName, path, profile);
FeatureArguments featureArguments = new FeatureArguments(arguments);
return convertedModel.expression(featureArguments, null);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java
index f772c5fe903..47d770f609e 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java
@@ -43,6 +43,9 @@ public class ConvertParsedRanking {
for (var constant : parsed.getConstants().values())
profile.add(constant);
+ for (var onnxModel : parsed.getOnnxModels())
+ profile.add(onnxModel);
+
for (var input : parsed.getInputs().entrySet())
profile.addInput(input.getKey(), input.getValue());
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java
index 0ade3bfd76b..8f0f92c4027 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.parser;
+import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfile.MatchPhaseSettings;
import com.yahoo.searchdefinition.RankProfile.MutateOperation;
@@ -54,6 +55,7 @@ class ParsedRankProfile extends ParsedBlock {
private final Map<String, List<String>> rankProperties = new LinkedHashMap<>();
private final Map<Reference, RankProfile.Constant> constants = new LinkedHashMap<>();
private final Map<Reference, RankProfile.Input> inputs = new LinkedHashMap<>();
+ private final List<OnnxModel> onnxModels = new ArrayList<>();
ParsedRankProfile(String name) {
super(name, "rank-profile");
@@ -85,6 +87,7 @@ class ParsedRankProfile extends ParsedBlock {
Map<String, List<String>> getRankProperties() { return Collections.unmodifiableMap(rankProperties); }
Map<Reference, RankProfile.Constant> getConstants() { return Collections.unmodifiableMap(constants); }
Map<Reference, RankProfile.Input> getInputs() { return Collections.unmodifiableMap(inputs); }
+ List<OnnxModel> getOnnxModels() { return List.copyOf(onnxModels); }
Optional<String> getInheritedSummaryFeatures() { return Optional.ofNullable(this.inheritedSummaryFeatures); }
Optional<String> getSecondPhaseExpression() { return Optional.ofNullable(this.secondPhaseExpression); }
@@ -111,6 +114,10 @@ class ParsedRankProfile extends ParsedBlock {
inputs.put(name, input);
}
+ void add(OnnxModel model) {
+ onnxModels.add(model);
+ }
+
void addFieldRankFilter(String field, boolean filter) {
fieldsRankFilter.put(field, filter);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java
index 2bc10554b25..4c102594479 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java
@@ -123,7 +123,7 @@ public class ParsedSchema extends ParsedBlock {
extraIndexes.put(idxName, index);
}
- void addOnnxModel(OnnxModel model) {
+ void add(OnnxModel model) {
onnxModels.add(model);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
index 19fbc116558..70ce051bb21 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
@@ -23,7 +23,7 @@ import java.util.Map;
*
* onnx("files/model.onnx", "path/to/output:1")
*
- * And generates an "onnx-model" configuration as if it was defined in the schema:
+ * And generates an "onnx-model" configuration as if it was defined in the profile:
*
* onnx-model files_model_onnx {
* file: "files/model.onnx"
@@ -45,31 +45,31 @@ public class OnnxModelConfigGenerator extends Processor {
if (documentsOnly) return;
for (RankProfile profile : rankProfileRegistry.rankProfilesOf(schema)) {
if (profile.getFirstPhaseRanking() != null) {
- process(profile.getFirstPhaseRanking().getRoot());
+ process(profile.getFirstPhaseRanking().getRoot(), profile);
}
if (profile.getSecondPhaseRanking() != null) {
- process(profile.getSecondPhaseRanking().getRoot());
+ process(profile.getSecondPhaseRanking().getRoot(), profile);
}
for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) {
- process(function.getValue().function().getBody().getRoot());
+ process(function.getValue().function().getBody().getRoot(), profile);
}
for (ReferenceNode feature : profile.getSummaryFeatures()) {
- process(feature);
+ process(feature, profile);
}
}
}
- private void process(ExpressionNode node) {
+ private void process(ExpressionNode node, RankProfile profile) {
if (node instanceof ReferenceNode) {
- process((ReferenceNode)node);
+ process((ReferenceNode)node, profile);
} else if (node instanceof CompositeNode) {
for (ExpressionNode child : ((CompositeNode) node).children()) {
- process(child);
+ process(child, profile);
}
}
}
- private void process(ReferenceNode feature) {
+ private void process(ReferenceNode feature, RankProfile profile) {
if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) {
if (feature.getArguments().size() > 0) {
if (feature.getArguments().expressions().get(0) instanceof ConstantNode) {
@@ -85,9 +85,9 @@ public class OnnxModelConfigGenerator extends Processor {
}
}
- OnnxModel onnxModel = schema.onnxModels().get(modelConfigName);
+ OnnxModel onnxModel = profile.onnxModels().get(modelConfigName);
if (onnxModel == null)
- schema.add(new OnnxModel(modelConfigName, path));
+ profile.add(new OnnxModel(modelConfigName, path));
}
}
}
diff --git a/config-model/src/main/javacc/IntermediateParser.jj b/config-model/src/main/javacc/IntermediateParser.jj
index 873196d8bda..01f111df284 100644
--- a/config-model/src/main/javacc/IntermediateParser.jj
+++ b/config-model/src/main/javacc/IntermediateParser.jj
@@ -427,7 +427,7 @@ void rootSchemaItem(ParsedSchema schema) : { }
| structOutside(schema)
| annotationOutside(schema)
| fieldSet(schema)
- | onnxModel(schema)
+ | onnxModelInSchema(schema) // Deprecated: TODO: Emit warning when on Vespa 8
)
}
@@ -1703,31 +1703,38 @@ void hnswIndexBody(HnswIndexParams.Builder params) :
| <MULTITHREADEDINDEXING> <COLON> bool = bool() { params.setMultiThreadedIndexing(bool); } )
}
-/**
- * Consumes a onnx-model block of a schema element.
- *
- * @param schema the schema object to add content to.
- */
-void onnxModel(ParsedSchema schema) :
+void onnxModelInSchema(ParsedSchema schema) :
+{
+ OnnxModel onnxModel;
+}
+{
+ onnxModel = onnxModel() { schema.add(onnxModel); }
+}
+
+void onnxModelInProfile(ParsedRankProfile profile) :
+{
+ OnnxModel onnxModel;
+}
+{
+ onnxModel = onnxModel() { profile.add(onnxModel); }
+}
+
+/** Consumes an onnx-model block. */
+OnnxModel onnxModel() :
{
String name;
OnnxModel onnxModel;
}
{
- ( <ONNXMODEL> name = identifier()
- {
- onnxModel = new OnnxModel(name);
- }
+ ( <ONNXMODEL> name = identifier() { onnxModel = new OnnxModel(name); }
lbrace() (onnxModelItem(onnxModel) (<NL>)*)+ <RBRACE> )
- {
- schema.addOnnxModel(onnxModel);
- }
+ { return onnxModel; }
}
/**
- * This rule consumes an onnx-model block.
+ * Consumes an onnx-model block.
*
- * @param onnxModel The onnxModel to modify.
+ * @param onnxModel the onnxModel to modify
*/
void onnxModelItem(OnnxModel onnxModel) :
{
@@ -1849,6 +1856,7 @@ void rankProfileItem(ParsedSchema schema, ParsedRankProfile profile) : { }
| constants(schema, profile)
| matchFeatures(profile)
| summaryFeatures(profile)
+ | onnxModelInProfile(profile)
| strict(profile) )
}
diff --git a/config-model/src/test/integration/onnx-model/schemas/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd
index a15714767ba..82872758dd9 100644
--- a/config-model/src/test/integration/onnx-model/schemas/test.sd
+++ b/config-model/src/test/integration/onnx-model/schemas/test.sd
@@ -21,14 +21,6 @@ search test {
output "path/to/output:0": out
}
- onnx-model another_model {
- file: files/model.onnx
- input first_input: attribute(document_field)
- input "second/input:0": constant(my_constant)
- input "third_input": another_function
- output "path/to/output:2": out
- }
-
onnx-model dynamic_model {
file: files/dynamic_model.onnx
input input: my_function
@@ -72,6 +64,13 @@ search test {
first-phase {
expression: 1
}
+ onnx-model another_model {
+ file: files/model.onnx
+ input first_input: attribute(document_field)
+ input "second/input:0": constant(my_constant)
+ input "third_input": another_function
+ output "path/to/output:2": out
+ }
summary-features {
onnx(another_model).out
onnx("files/summary_model.onnx", "path/to/output:2")
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index 1c23950d972..6820a8d9678 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -93,6 +93,18 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("path_to_output_2", model.output(2).as());
model = config.model(1);
+ assertEquals("dynamic_model", model.name());
+ assertEquals(1, model.input().size());
+ assertEquals(1, model.output().size());
+ assertEquals("rankingExpression(my_function)", model.input(0).source());
+
+ model = config.model(2);
+ assertEquals("unbound_model", model.name());
+ assertEquals(1, model.input().size());
+ assertEquals(1, model.output().size());
+ assertEquals("rankingExpression(my_function)", model.input(0).source());
+
+ model = config.model(3);
assertEquals("files_model_onnx", model.name());
assertEquals(3, model.input().size());
assertEquals(3, model.output().size());
@@ -104,27 +116,15 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("path_to_output_2", model.output(2).as());
assertEquals("files_model_onnx", model.name());
- model = config.model(2);
+ model = config.model(4);
assertEquals("another_model", model.name());
assertEquals("third_input", model.input(2).name());
assertEquals("rankingExpression(another_function)", model.input(2).source());
- model = config.model(3);
+ model = config.model(5);
assertEquals("files_summary_model_onnx", model.name());
assertEquals(3, model.input().size());
assertEquals(3, model.output().size());
-
- model = config.model(4);
- assertEquals("unbound_model", model.name());
- assertEquals(1, model.input().size());
- assertEquals(1, model.output().size());
- assertEquals("rankingExpression(my_function)", model.input(0).source());
-
- model = config.model(5);
- assertEquals("dynamic_model", model.name());
- assertEquals(1, model.input().size());
- assertEquals(1, model.output().size());
- assertEquals("rankingExpression(my_function)", model.input(0).source());
}
private void assertTransformedFeature(VespaModel model) {