summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2023-02-28 14:20:53 +0100
committerGitHub <noreply@github.com>2023-02-28 14:20:53 +0100
commit00de2a92c9cdc1056f718674b055a9639fe64b3b (patch)
tree66ad9d07b5292d5961353d6454c78ebda6ad1922 /config-model
parent05a7e7a75feae44c6ca9ed27d4d3570873603702 (diff)
parent18aa303fe0959786838aa63ec8f0aca092be2d99 (diff)
Merge pull request #26179 from vespa-engine/arnej/add-new-components-4
add new components for global-phase handling
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/RankProfile.java3
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java11
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java16
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/search/RankProfilesEvaluatorComponent.java49
-rw-r--r--config-model/src/test/derived/globalphase_onnx_inside/files/ax_plus_b.onnx23
-rw-r--r--config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg34
-rw-r--r--config-model/src/test/derived/globalphase_onnx_inside/test.sd42
-rw-r--r--config-model/src/test/derived/rankingexpression/rank-profiles.cfg2
-rw-r--r--config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java22
9 files changed, 187 insertions, 15 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
index ad6eb038058..7cb0a088f5f 100644
--- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
@@ -1019,6 +1019,9 @@ public class RankProfile implements Cloneable {
var recorder = new InputRecorder(needInputs);
recorder.transform(globalPhaseRanking.function().getBody(), context);
for (String input : needInputs) {
+ if (input.startsWith("constant(") || input.startsWith("query(")) {
+ continue;
+ }
try {
addMatchFeatures(new FeatureList(input));
} catch (com.yahoo.searchlib.rankingexpression.parser.ParseException e) {
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index b0f63ebb732..4e7988a2006 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -3,16 +3,13 @@ package com.yahoo.schema.expressiontransforms;
import com.yahoo.schema.FeatureNames;
import com.yahoo.schema.RankProfile;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import java.io.StringReader;
import java.util.Set;
/**
@@ -86,13 +83,7 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon
throw new IllegalArgumentException("missing onnx model: " + arg);
}
for (String onnxInput : model.getInputMap().values()) {
- var reader = new StringReader(onnxInput);
- try {
- var asExpression = new RankingExpression(reader);
- transform(asExpression.getRoot(), context);
- } catch (ParseException e) {
- throw new IllegalArgumentException("illegal onnx input '" + onnxInput + "': " + e.getMessage());
- }
+ neededInputs.add(onnxInput);
}
return;
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java
index 86c48407775..14c25ee7452 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java
@@ -7,6 +7,7 @@ import com.yahoo.search.config.IndexInfoConfig;
import com.yahoo.search.config.SchemaInfoConfig;
import com.yahoo.search.pagetemplates.PageTemplatesConfig;
import com.yahoo.search.query.profile.config.QueryProfilesConfig;
+import com.yahoo.search.ranking.RankProfilesEvaluatorFactory;
import com.yahoo.schema.derived.SchemaInfo;
import com.yahoo.vespa.configdefinition.IlscriptsConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
@@ -56,6 +57,8 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains>
owningCluster.addComponent(Component.fromClassAndBundle(CompiledQueryProfileRegistry.class, SEARCH_AND_DOCPROC_BUNDLE));
owningCluster.addComponent(Component.fromClassAndBundle(com.yahoo.search.schema.SchemaInfo.class, SEARCH_AND_DOCPROC_BUNDLE));
owningCluster.addComponent(Component.fromClassAndBundle(SearchStatusExtension.class, SEARCH_AND_DOCPROC_BUNDLE));
+ owningCluster.addComponent(Component.fromClassAndBundle(RankProfilesEvaluatorFactory.class, SEARCH_AND_DOCPROC_BUNDLE));
+ owningCluster.addComponent(Component.fromClassAndBundle(com.yahoo.search.ranking.GlobalPhaseRanker.class, SEARCH_AND_DOCPROC_BUNDLE));
cluster.addSearchAndDocprocBundles();
}
@@ -68,9 +71,16 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains>
/** Adds a Dispatcher component to the owning container cluster for each search cluster */
private void initializeDispatchers(Collection<SearchCluster> searchClusters) {
for (SearchCluster searchCluster : searchClusters) {
- if ( ! ( searchCluster instanceof IndexedSearchCluster)) continue;
- var dispatcher = new DispatcherComponent((IndexedSearchCluster)searchCluster);
- owningCluster.addComponent(dispatcher);
+ if (searchCluster instanceof IndexedSearchCluster indexed) {
+ var dispatcher = new DispatcherComponent(indexed);
+ owningCluster.addComponent(dispatcher);
+ for (var documentDb : indexed.getDocumentDbs()) {
+ var factory = new RankProfilesEvaluatorComponent(documentDb);
+ if (! owningCluster.getComponentsMap().containsKey(factory.getComponentId())) {
+ owningCluster.addComponent(factory);
+ }
+ }
+ }
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/RankProfilesEvaluatorComponent.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/RankProfilesEvaluatorComponent.java
new file mode 100644
index 00000000000..75a2802ee53
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/RankProfilesEvaluatorComponent.java
@@ -0,0 +1,49 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.container.search;
+
+import com.yahoo.config.model.producer.AnyConfigProducer;
+import com.yahoo.osgi.provider.model.ComponentModel;
+import com.yahoo.search.ranking.RankProfilesEvaluator;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
+import com.yahoo.vespa.model.container.ContainerModelEvaluation;
+import com.yahoo.vespa.model.container.PlatformBundles;
+import com.yahoo.vespa.model.container.component.Component;
+import com.yahoo.vespa.model.search.DocumentDatabase;
+
+public class RankProfilesEvaluatorComponent
+ extends Component<AnyConfigProducer, ComponentModel>
+ implements
+ RankProfilesConfig.Producer,
+ RankingConstantsConfig.Producer,
+ RankingExpressionsConfig.Producer,
+ OnnxModelsConfig.Producer
+{
+ private final DocumentDatabase ddb;
+
+ public RankProfilesEvaluatorComponent(DocumentDatabase db) {
+ super(toComponentModel(db.getSchemaName()));
+ ddb = db;
+ }
+
+ private static ComponentModel toComponentModel(String p) {
+ String myComponentId = "ranking-expression-evaluator." + p;
+ return new ComponentModel(myComponentId,
+ RankProfilesEvaluator.class.getName(),
+ PlatformBundles.SEARCH_AND_DOCPROC_BUNDLE);
+ }
+
+ @Override
+ public void getConfig(RankProfilesConfig.Builder builder) { ddb.getConfig(builder); }
+
+ @Override
+ public void getConfig(RankingExpressionsConfig.Builder builder) { ddb.getConfig(builder); }
+
+ @Override
+ public void getConfig(RankingConstantsConfig.Builder builder) { ddb.getConfig(builder); }
+
+ @Override
+ public void getConfig(OnnxModelsConfig.Builder builder) { ddb.getConfig(builder); }
+}
diff --git a/config-model/src/test/derived/globalphase_onnx_inside/files/ax_plus_b.onnx b/config-model/src/test/derived/globalphase_onnx_inside/files/ax_plus_b.onnx
new file mode 100644
index 00000000000..17282d13dc3
--- /dev/null
+++ b/config-model/src/test/derived/globalphase_onnx_inside/files/ax_plus_b.onnx
@@ -0,0 +1,23 @@
+:©
+
+matrix_X
+vector_AXA"MatMul
+
+XA
+vector_Bvector_Y"AddlrZ
+matrix_X
+  
+
+Z
+vector_A
+
+ 
+Z
+vector_B
+
+ 
+b
+vector_Y
+
+ 
+B \ No newline at end of file
diff --git a/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg b/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg
new file mode 100644
index 00000000000..35bb1ccc3d2
--- /dev/null
+++ b/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg
@@ -0,0 +1,34 @@
+rankprofile[].name "default"
+rankprofile[].fef.property[].name "rankingExpression(handicap).rankingScript"
+rankprofile[].fef.property[].value "query(yy)"
+rankprofile[].fef.property[].name "rankingExpression(handicap).type"
+rankprofile[].fef.property[].value "tensor(d0[2])"
+rankprofile[].fef.property[].name "vespa.rank.firstphase"
+rankprofile[].fef.property[].value "rankingExpression(firstphase)"
+rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
+rankprofile[].fef.property[].value "reduce(attribute(aa), sum)"
+rankprofile[].fef.property[].name "vespa.rank.globalphase"
+rankprofile[].fef.property[].value "rankingExpression(globalphase)"
+rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript"
+rankprofile[].fef.property[].value "reduce(constant(ww) * (onnx(inside).foobar - rankingExpression(handicap)), sum)"
+rankprofile[].fef.property[].name "vespa.match.feature"
+rankprofile[].fef.property[].value "attribute(aa)"
+rankprofile[].fef.property[].name "vespa.globalphase.rerankcount"
+rankprofile[].fef.property[].value "13"
+rankprofile[].fef.property[].name "vespa.type.attribute.aa"
+rankprofile[].fef.property[].value "tensor(d1[3])"
+rankprofile[].fef.property[].name "vespa.type.query.bb"
+rankprofile[].fef.property[].value "tensor(d0[2])"
+rankprofile[].fef.property[].name "vespa.type.query.yy"
+rankprofile[].fef.property[].value "tensor(d0[2])"
+rankprofile[].name "unranked"
+rankprofile[].fef.property[].name "vespa.rank.firstphase"
+rankprofile[].fef.property[].value "value(0)"
+rankprofile[].fef.property[].name "vespa.hitcollector.heapsize"
+rankprofile[].fef.property[].value "0"
+rankprofile[].fef.property[].name "vespa.hitcollector.arraysize"
+rankprofile[].fef.property[].value "0"
+rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures"
+rankprofile[].fef.property[].value "true"
+rankprofile[].fef.property[].name "vespa.type.attribute.aa"
+rankprofile[].fef.property[].value "tensor(d1[3])"
diff --git a/config-model/src/test/derived/globalphase_onnx_inside/test.sd b/config-model/src/test/derived/globalphase_onnx_inside/test.sd
new file mode 100644
index 00000000000..c38e318ce6b
--- /dev/null
+++ b/config-model/src/test/derived/globalphase_onnx_inside/test.sd
@@ -0,0 +1,42 @@
+schema test {
+
+ document test {
+ field aa type tensor(d1[3]) {
+ indexing: attribute
+ }
+ }
+
+ constant xx {
+ file: files/const_xx.json
+ type: tensor(d0[2],d1[3])
+ }
+ constant ww {
+ file: files/const_ww.json
+ type: tensor(d0[2])
+ }
+
+ rank-profile default {
+ inputs {
+ query(bb) tensor(d0[2])
+ query(yy) tensor(d0[2])
+ }
+ onnx-model inside {
+ file: files/ax_plus_b.onnx
+ input vector_A: attribute(aa)
+ input matrix_X: constant(xx)
+ input vector_B: query(bb)
+ output vector_Y: foobar
+ }
+ first-phase {
+ expression: sum(attribute(aa))
+ }
+ function handicap() {
+ expression: query(yy)
+ }
+ global-phase {
+ rerank-count: 13
+ expression: sum(constant(ww) * (onnx(inside).foobar - handicap))
+ }
+ }
+
+}
diff --git a/config-model/src/test/derived/rankingexpression/rank-profiles.cfg b/config-model/src/test/derived/rankingexpression/rank-profiles.cfg
index ea8bc5f77e6..e3947e9e46f 100644
--- a/config-model/src/test/derived/rankingexpression/rank-profiles.cfg
+++ b/config-model/src/test/derived/rankingexpression/rank-profiles.cfg
@@ -410,8 +410,6 @@ rankprofile[].fef.property[].value "rankingExpression(globalphase)"
rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript"
rankprofile[].fef.property[].value "rankingExpression(myplus) + reduce(rankingExpression(mymul), sum) + firstPhase"
rankprofile[].fef.property[].name "vespa.match.feature"
-rankprofile[].fef.property[].value "query(fromq)"
-rankprofile[].fef.property[].name "vespa.match.feature"
rankprofile[].fef.property[].value "firstPhase"
rankprofile[].fef.property[].name "vespa.match.feature"
rankprofile[].fef.property[].value "attribute(t1)"
diff --git a/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java b/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java
new file mode 100644
index 00000000000..2ff33dd70d8
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java
@@ -0,0 +1,22 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema.derived;
+
+import com.yahoo.schema.parser.ParseException;
+import org.junit.jupiter.api.Test;
+import java.io.IOException;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * Tests exporting with global-phase and ONNX models
+ *
+ * @author arnej
+ */
+public class GlobalPhaseOnnxModelsTestCase extends AbstractExportingTestCase {
+
+ @Test
+ void testModelInRankProfile() throws IOException, ParseException {
+ assertCorrectDeriving("globalphase_onnx_inside");
+ }
+
+}