diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-02-28 14:20:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-28 14:20:53 +0100 |
commit | 00de2a92c9cdc1056f718674b055a9639fe64b3b (patch) | |
tree | 66ad9d07b5292d5961353d6454c78ebda6ad1922 /config-model | |
parent | 05a7e7a75feae44c6ca9ed27d4d3570873603702 (diff) | |
parent | 18aa303fe0959786838aa63ec8f0aca092be2d99 (diff) |
Merge pull request #26179 from vespa-engine/arnej/add-new-components-4
add new components for global-phase handling
Diffstat (limited to 'config-model')
9 files changed, 187 insertions, 15 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index ad6eb038058..7cb0a088f5f 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -1019,6 +1019,9 @@ public class RankProfile implements Cloneable { var recorder = new InputRecorder(needInputs); recorder.transform(globalPhaseRanking.function().getBody(), context); for (String input : needInputs) { + if (input.startsWith("constant(") || input.startsWith("query(")) { + continue; + } try { addMatchFeatures(new FeatureList(input)); } catch (com.yahoo.searchlib.rankingexpression.parser.ParseException e) { diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index b0f63ebb732..4e7988a2006 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -3,16 +3,13 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.FeatureNames; import com.yahoo.schema.RankProfile; -import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import java.io.StringReader; import java.util.Set; /** @@ -86,13 +83,7 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon throw new IllegalArgumentException("missing onnx model: " + arg); } for (String onnxInput : model.getInputMap().values()) { - var reader = new StringReader(onnxInput); - try { - var asExpression = new RankingExpression(reader); - transform(asExpression.getRoot(), context); - } catch (ParseException e) { - throw new IllegalArgumentException("illegal onnx input '" + onnxInput + "': " + e.getMessage()); - } + neededInputs.add(onnxInput); } return; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java index 86c48407775..14c25ee7452 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java @@ -7,6 +7,7 @@ import com.yahoo.search.config.IndexInfoConfig; import com.yahoo.search.config.SchemaInfoConfig; import com.yahoo.search.pagetemplates.PageTemplatesConfig; import com.yahoo.search.query.profile.config.QueryProfilesConfig; +import com.yahoo.search.ranking.RankProfilesEvaluatorFactory; import com.yahoo.schema.derived.SchemaInfo; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -56,6 +57,8 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains> owningCluster.addComponent(Component.fromClassAndBundle(CompiledQueryProfileRegistry.class, SEARCH_AND_DOCPROC_BUNDLE)); owningCluster.addComponent(Component.fromClassAndBundle(com.yahoo.search.schema.SchemaInfo.class, SEARCH_AND_DOCPROC_BUNDLE)); owningCluster.addComponent(Component.fromClassAndBundle(SearchStatusExtension.class, SEARCH_AND_DOCPROC_BUNDLE)); + owningCluster.addComponent(Component.fromClassAndBundle(RankProfilesEvaluatorFactory.class, SEARCH_AND_DOCPROC_BUNDLE)); + owningCluster.addComponent(Component.fromClassAndBundle(com.yahoo.search.ranking.GlobalPhaseRanker.class, SEARCH_AND_DOCPROC_BUNDLE)); cluster.addSearchAndDocprocBundles(); } @@ -68,9 +71,16 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains> /** Adds a Dispatcher component to the owning container cluster for each search cluster */ private void initializeDispatchers(Collection<SearchCluster> searchClusters) { for (SearchCluster searchCluster : searchClusters) { - if ( ! ( searchCluster instanceof IndexedSearchCluster)) continue; - var dispatcher = new DispatcherComponent((IndexedSearchCluster)searchCluster); - owningCluster.addComponent(dispatcher); + if (searchCluster instanceof IndexedSearchCluster indexed) { + var dispatcher = new DispatcherComponent(indexed); + owningCluster.addComponent(dispatcher); + for (var documentDb : indexed.getDocumentDbs()) { + var factory = new RankProfilesEvaluatorComponent(documentDb); + if (! owningCluster.getComponentsMap().containsKey(factory.getComponentId())) { + owningCluster.addComponent(factory); + } + } + } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/RankProfilesEvaluatorComponent.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/RankProfilesEvaluatorComponent.java new file mode 100644 index 00000000000..75a2802ee53 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/RankProfilesEvaluatorComponent.java @@ -0,0 +1,49 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.container.search; + +import com.yahoo.config.model.producer.AnyConfigProducer; +import com.yahoo.osgi.provider.model.ComponentModel; +import com.yahoo.search.ranking.RankProfilesEvaluator; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; +import com.yahoo.vespa.model.container.ContainerModelEvaluation; +import com.yahoo.vespa.model.container.PlatformBundles; +import com.yahoo.vespa.model.container.component.Component; +import com.yahoo.vespa.model.search.DocumentDatabase; + +public class RankProfilesEvaluatorComponent + extends Component<AnyConfigProducer, ComponentModel> + implements + RankProfilesConfig.Producer, + RankingConstantsConfig.Producer, + RankingExpressionsConfig.Producer, + OnnxModelsConfig.Producer +{ + private final DocumentDatabase ddb; + + public RankProfilesEvaluatorComponent(DocumentDatabase db) { + super(toComponentModel(db.getSchemaName())); + ddb = db; + } + + private static ComponentModel toComponentModel(String p) { + String myComponentId = "ranking-expression-evaluator." + p; + return new ComponentModel(myComponentId, + RankProfilesEvaluator.class.getName(), + PlatformBundles.SEARCH_AND_DOCPROC_BUNDLE); + } + + @Override + public void getConfig(RankProfilesConfig.Builder builder) { ddb.getConfig(builder); } + + @Override + public void getConfig(RankingExpressionsConfig.Builder builder) { ddb.getConfig(builder); } + + @Override + public void getConfig(RankingConstantsConfig.Builder builder) { ddb.getConfig(builder); } + + @Override + public void getConfig(OnnxModelsConfig.Builder builder) { ddb.getConfig(builder); } +} diff --git a/config-model/src/test/derived/globalphase_onnx_inside/files/ax_plus_b.onnx b/config-model/src/test/derived/globalphase_onnx_inside/files/ax_plus_b.onnx new file mode 100644 index 00000000000..17282d13dc3 --- /dev/null +++ b/config-model/src/test/derived/globalphase_onnx_inside/files/ax_plus_b.onnx @@ -0,0 +1,23 @@ +:© + +matrix_X +vector_AXA"MatMul + +XA +vector_Bvector_Y"AddlrZ +matrix_X + + +Z +vector_A + + +Z +vector_B + + +b +vector_Y + + +B
\ No newline at end of file diff --git a/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg b/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg new file mode 100644 index 00000000000..35bb1ccc3d2 --- /dev/null +++ b/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg @@ -0,0 +1,34 @@ +rankprofile[].name "default" +rankprofile[].fef.property[].name "rankingExpression(handicap).rankingScript" +rankprofile[].fef.property[].value "query(yy)" +rankprofile[].fef.property[].name "rankingExpression(handicap).type" +rankprofile[].fef.property[].value "tensor(d0[2])" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "rankingExpression(firstphase)" +rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" +rankprofile[].fef.property[].value "reduce(attribute(aa), sum)" +rankprofile[].fef.property[].name "vespa.rank.globalphase" +rankprofile[].fef.property[].value "rankingExpression(globalphase)" +rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript" +rankprofile[].fef.property[].value "reduce(constant(ww) * (onnx(inside).foobar - rankingExpression(handicap)), sum)" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(aa)" +rankprofile[].fef.property[].name "vespa.globalphase.rerankcount" +rankprofile[].fef.property[].value "13" +rankprofile[].fef.property[].name "vespa.type.attribute.aa" +rankprofile[].fef.property[].value "tensor(d1[3])" +rankprofile[].fef.property[].name "vespa.type.query.bb" +rankprofile[].fef.property[].value "tensor(d0[2])" +rankprofile[].fef.property[].name "vespa.type.query.yy" +rankprofile[].fef.property[].value "tensor(d0[2])" +rankprofile[].name "unranked" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "value(0)" +rankprofile[].fef.property[].name "vespa.hitcollector.heapsize" +rankprofile[].fef.property[].value "0" +rankprofile[].fef.property[].name "vespa.hitcollector.arraysize" +rankprofile[].fef.property[].value "0" +rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures" +rankprofile[].fef.property[].value "true" +rankprofile[].fef.property[].name "vespa.type.attribute.aa" +rankprofile[].fef.property[].value "tensor(d1[3])" diff --git a/config-model/src/test/derived/globalphase_onnx_inside/test.sd b/config-model/src/test/derived/globalphase_onnx_inside/test.sd new file mode 100644 index 00000000000..c38e318ce6b --- /dev/null +++ b/config-model/src/test/derived/globalphase_onnx_inside/test.sd @@ -0,0 +1,42 @@ +schema test { + + document test { + field aa type tensor(d1[3]) { + indexing: attribute + } + } + + constant xx { + file: files/const_xx.json + type: tensor(d0[2],d1[3]) + } + constant ww { + file: files/const_ww.json + type: tensor(d0[2]) + } + + rank-profile default { + inputs { + query(bb) tensor(d0[2]) + query(yy) tensor(d0[2]) + } + onnx-model inside { + file: files/ax_plus_b.onnx + input vector_A: attribute(aa) + input matrix_X: constant(xx) + input vector_B: query(bb) + output vector_Y: foobar + } + first-phase { + expression: sum(attribute(aa)) + } + function handicap() { + expression: query(yy) + } + global-phase { + rerank-count: 13 + expression: sum(constant(ww) * (onnx(inside).foobar - handicap)) + } + } + +} diff --git a/config-model/src/test/derived/rankingexpression/rank-profiles.cfg b/config-model/src/test/derived/rankingexpression/rank-profiles.cfg index ea8bc5f77e6..e3947e9e46f 100644 --- a/config-model/src/test/derived/rankingexpression/rank-profiles.cfg +++ b/config-model/src/test/derived/rankingexpression/rank-profiles.cfg @@ -410,8 +410,6 @@ rankprofile[].fef.property[].value "rankingExpression(globalphase)" rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript" rankprofile[].fef.property[].value "rankingExpression(myplus) + reduce(rankingExpression(mymul), sum) + firstPhase" rankprofile[].fef.property[].name "vespa.match.feature" -rankprofile[].fef.property[].value "query(fromq)" -rankprofile[].fef.property[].name "vespa.match.feature" rankprofile[].fef.property[].value "firstPhase" rankprofile[].fef.property[].name "vespa.match.feature" rankprofile[].fef.property[].value "attribute(t1)" diff --git a/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java b/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java new file mode 100644 index 00000000000..2ff33dd70d8 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/schema/derived/GlobalPhaseOnnxModelsTestCase.java @@ -0,0 +1,22 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema.derived; + +import com.yahoo.schema.parser.ParseException; +import org.junit.jupiter.api.Test; +import java.io.IOException; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Tests exporting with global-phase and ONNX models + * + * @author arnej + */ +public class GlobalPhaseOnnxModelsTestCase extends AbstractExportingTestCase { + + @Test + void testModelInRankProfile() throws IOException, ParseException { + assertCorrectDeriving("globalphase_onnx_inside"); + } + +} |