diff options
author | Arne Juul <arnej@vespa.ai> | 2023-10-30 11:31:12 +0000 |
---|---|---|
committer | Arne Juul <arnej@vespa.ai> | 2023-10-30 11:51:20 +0000 |
commit | ecf09e55410ba9b7b70e83da2a449ca0efac1cef (patch) | |
tree | ad8ce5f187c08c20aacac433a03afd92d9849f12 /container-search | |
parent | deb7ac732b85a497a545f8df15e9d4e65943031c (diff) |
support renamed match-features better
Diffstat (limited to 'container-search')
8 files changed, 164 insertions, 60 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/FunEvalSpec.java b/container-search/src/main/java/com/yahoo/search/ranking/FunEvalSpec.java index df9c509dd82..ac1b7c8e218 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/FunEvalSpec.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/FunEvalSpec.java @@ -4,4 +4,7 @@ package com.yahoo.search.ranking; import java.util.List; import java.util.function.Supplier; -record FunEvalSpec(Supplier<Evaluator> evalSource, List<String> fromQuery, List<String> fromMF) {} +record FunEvalSpec(Supplier<Evaluator> evalSource, + List<String> fromQuery, + List<MatchFeatureInput> fromMF) +{} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java index 7340e9e2a5e..7783eabcdcc 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java @@ -96,6 +96,41 @@ class GlobalPhaseSetup { return defaultValues; } + static class InputResolver { + final List<String> usedNormalizers = new ArrayList<>(); + final List<String> fromQuery = new ArrayList<>(); + final List<MatchFeatureInput> fromMF = new ArrayList<>(); + private final Set<String> availableMatchFeatures; + private final Map<String, String> renamedFeatures; + private final Set<String> availableNormalizers; + + InputResolver(Set<String> availableMatchFeatures, + Map<String, String> renamedFeatures, + Set<String> availableNormalizers) + { + this.availableMatchFeatures = availableMatchFeatures; + this.renamedFeatures = renamedFeatures; + this.availableNormalizers = availableNormalizers; + } + void resolve(Collection<String> allInputs) { + for (var input : allInputs) { + String queryFeatureName = asQueryFeature(input); + if (queryFeatureName != null) { + fromQuery.add(queryFeatureName); + } else if (availableNormalizers.contains(input)) { + usedNormalizers.add(input); + } else if (availableMatchFeatures.contains(input)) { + String mfName = renamedFeatures.getOrDefault(input, input); + fromMF.add(new MatchFeatureInput(input, mfName)); + } else if (renamedFeatures.values().contains(input)) { + fromMF.add(new MatchFeatureInput(input, input)); + } else { + throw new IllegalArgumentException("Bad config, missing global-phase input: " + input); + } + } + } + } + static GlobalPhaseSetup maybeMakeSetup(RankProfilesConfig.Rankprofile rp, RankProfilesEvaluator modelEvaluator) { var model = modelEvaluator.modelForRankProfile(rp.name()); Map<String, RankProfilesConfig.Rankprofile.Normalizer> availableNormalizers = new HashMap<>(); @@ -130,47 +165,31 @@ class GlobalPhaseSetup { } } } - for (var entry : renameFeatures.entrySet()) { - String old = entry.getKey(); - if (matchFeatures.contains(old)) { - matchFeatures.remove(old); - matchFeatures.add(entry.getValue()); - } - } if (rerankCount < 0) { rerankCount = 100; } if (functionEvaluatorSource != null) { + var mainResolver = new InputResolver(matchFeatures, renameFeatures, availableNormalizers.keySet()); var evaluator = functionEvaluatorSource.get(); var allInputs = List.copyOf(evaluator.function().arguments()); - List<String> fromMF = new ArrayList<>(); - List<String> fromQuery = new ArrayList<>(); + mainResolver.resolve(allInputs); List<NormalizerSetup> normalizers = new ArrayList<>(); - for (var input : allInputs) { - String queryFeatureName = asQueryFeature(input); - if (queryFeatureName != null) { - fromQuery.add(queryFeatureName); - } else if (availableNormalizers.containsKey(input)) { - var cfg = availableNormalizers.get(input); - String normInput = cfg.input(); - if (matchFeatures.contains(normInput)) { - Supplier<Evaluator> normSource = () -> new DummyEvaluator(normInput); - normalizers.add(makeNormalizerSetup(cfg, matchFeatures, normSource, List.of(normInput), rerankCount)); - } else { - Supplier<FunctionEvaluator> normSource = () -> model.evaluatorOf(normInput); - var normInputs = List.copyOf(normSource.get().function().arguments()); - var normSupplier = SimpleEvaluator.wrap(normSource); - normalizers.add(makeNormalizerSetup(cfg, matchFeatures, normSupplier, normInputs, rerankCount)); - } - } else if (matchFeatures.contains(input) || matchFeatures.contains(WrappedHit.alternate(input))) { - fromMF.add(input); + for (var input : mainResolver.usedNormalizers) { + var cfg = availableNormalizers.get(input); + String normInput = cfg.input(); + if (matchFeatures.contains(normInput) || renameFeatures.values().contains(normInput)) { + Supplier<Evaluator> normSource = () -> new DummyEvaluator(normInput); + normalizers.add(makeNormalizerSetup(cfg, matchFeatures, renameFeatures, normSource, List.of(normInput), rerankCount)); } else { - throw new IllegalArgumentException("Bad config, missing global-phase input: " + input); + Supplier<FunctionEvaluator> normSource = () -> model.evaluatorOf(normInput); + var normInputs = List.copyOf(normSource.get().function().arguments()); + var normSupplier = SimpleEvaluator.wrap(normSource); + normalizers.add(makeNormalizerSetup(cfg, matchFeatures, renameFeatures, normSupplier, normInputs, rerankCount)); } } Supplier<Evaluator> supplier = SimpleEvaluator.wrap(functionEvaluatorSource); - var gfun = new FunEvalSpec(supplier, fromQuery, fromMF); - var defaultValues = extraDefaultQueryFeatureValues(rp, fromQuery, normalizers); + var gfun = new FunEvalSpec(supplier, mainResolver.fromQuery, mainResolver.fromMF); + var defaultValues = extraDefaultQueryFeatureValues(rp, mainResolver.fromQuery, normalizers); return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers, defaultValues); } return null; @@ -178,23 +197,14 @@ class GlobalPhaseSetup { private static NormalizerSetup makeNormalizerSetup(RankProfilesConfig.Rankprofile.Normalizer cfg, Set<String> matchFeatures, + Map<String, String> renamedFeatures, Supplier<Evaluator> evalSupplier, List<String> normInputs, int rerankCount) { - List<String> fromQuery = new ArrayList<>(); - List<String> fromMF = new ArrayList<>(); - for (var input : normInputs) { - String queryFeatureName = asQueryFeature(input); - if (queryFeatureName != null) { - fromQuery.add(queryFeatureName); - } else if (matchFeatures.contains(input) || matchFeatures.contains(WrappedHit.alternate(input))) { - fromMF.add(input); - } else { - throw new IllegalArgumentException("Bad config, missing normalizer input: " + input); - } - } - var fun = new FunEvalSpec(evalSupplier, fromQuery, fromMF); + var normResolver = new InputResolver(matchFeatures, renamedFeatures, Set.of()); + normResolver.resolve(normInputs); + var fun = new FunEvalSpec(evalSupplier, normResolver.fromQuery, normResolver.fromMF); return new NormalizerSetup(cfg.name(), makeNormalizerSupplier(cfg, rerankCount), fun); } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java index fee4f5b4160..32eaa4a29c4 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java @@ -14,10 +14,12 @@ class HitRescorer { private static final Logger logger = Logger.getLogger(HitRescorer.class.getName()); private final Supplier<Evaluator> mainEvalSrc; - private final List<String> mainFromMF; + private final List<MatchFeatureInput> mainFromMF; private final List<NormalizerContext> normalizers; - public HitRescorer(Supplier<Evaluator> mainEvalSrc, List<String> mainFromMF, List<NormalizerContext> normalizers) { + public HitRescorer(Supplier<Evaluator> mainEvalSrc, + List<MatchFeatureInput> mainFromMF, + List<NormalizerContext> normalizers) { this.mainEvalSrc = mainEvalSrc; this.mainFromMF = mainFromMF; this.normalizers = normalizers; @@ -48,13 +50,13 @@ class HitRescorer { return newScore; } - private static double evalScorer(WrappedHit wrapped, Evaluator scorer, List<String> fromMF) { - for (String argName : fromMF) { - var asTensor = wrapped.getTensor(argName); + private static double evalScorer(WrappedHit wrapped, Evaluator scorer, List<MatchFeatureInput> fromMF) { + for (var argSpec : fromMF) { + var asTensor = wrapped.getTensor(argSpec.matchFeatureName()); if (asTensor != null) { - scorer.bind(argName, asTensor); + scorer.bind(argSpec.inputName(), asTensor); } else { - logger.warning("Missing match-feature for Evaluator argument: " + argName); + logger.warning("Missing match-feature for Evaluator argument: " + argSpec.inputName()); return 0.0; } } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/MatchFeatureInput.java b/container-search/src/main/java/com/yahoo/search/ranking/MatchFeatureInput.java new file mode 100644 index 00000000000..f80f29b3668 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/MatchFeatureInput.java @@ -0,0 +1,4 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +record MatchFeatureInput(String inputName, String matchFeatureName) {} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/NormalizerContext.java b/container-search/src/main/java/com/yahoo/search/ranking/NormalizerContext.java index 9438b5ea824..ceac202db47 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/NormalizerContext.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/NormalizerContext.java @@ -4,4 +4,8 @@ package com.yahoo.search.ranking; import java.util.List; import java.util.function.Supplier; -record NormalizerContext(String name, Normalizer normalizer, Supplier<Evaluator> evalSource, List<String> fromMF) {} +record NormalizerContext(String name, + Normalizer normalizer, + Supplier<Evaluator> evalSource, + List<MatchFeatureInput> fromMF) +{} diff --git a/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java b/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java index f55130c0c93..39b202daf1e 100644 --- a/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java +++ b/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java @@ -38,7 +38,11 @@ public class GlobalPhaseRerankHitsImplTest { return new FunEvalSpec(() -> new EvalSum(constValue), Collections.emptyList(), Collections.emptyList()); } static FunEvalSpec makeSumSpec(List<String> fromQuery, List<String> fromMF) { - return new FunEvalSpec(() -> new EvalSum(0.0), fromQuery, fromMF); + List<MatchFeatureInput> mfList = new ArrayList<>(); + for (String mf : fromMF) { + mfList.add(new MatchFeatureInput(mf, mf)); + } + return new FunEvalSpec(() -> new EvalSum(0.0), fromQuery, mfList); } static class ExpectingNormalizer extends Normalizer { List<Double> expected; diff --git a/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseSetupTest.java b/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseSetupTest.java index 082531a97dd..dbe26c2ef94 100644 --- a/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseSetupTest.java +++ b/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseSetupTest.java @@ -67,28 +67,28 @@ public class GlobalPhaseSetupTest { assertEquals("normalize@2974853441@linear", n.name()); assertEquals(0, n.inputEvalSpec().fromQuery().size()); assertEquals(1, n.inputEvalSpec().fromMF().size()); - assertEquals("funmf", n.inputEvalSpec().fromMF().get(0)); + assertEquals("funmf", n.inputEvalSpec().fromMF().get(0).matchFeatureName()); assertEquals("linear", n.supplier().get().normalizing()); n = nList.get(1); assertEquals("normalize@3414032797@rrank", n.name()); assertEquals(0, n.inputEvalSpec().fromQuery().size()); assertEquals(1, n.inputEvalSpec().fromMF().size()); - assertEquals("attribute(year)", n.inputEvalSpec().fromMF().get(0)); + assertEquals("attribute(year)", n.inputEvalSpec().fromMF().get(0).inputName()); assertEquals("reciprocal-rank{k:60.0}", n.supplier().get().normalizing()); n = nList.get(2); assertEquals("normalize@3551296680@linear", n.name()); assertEquals(0, n.inputEvalSpec().fromQuery().size()); assertEquals(1, n.inputEvalSpec().fromMF().size()); - assertEquals("nativeRank", n.inputEvalSpec().fromMF().get(0)); + assertEquals("nativeRank", n.inputEvalSpec().fromMF().get(0).inputName()); assertEquals("linear", n.supplier().get().normalizing()); n = nList.get(3); assertEquals("normalize@4280591309@rrank", n.name()); assertEquals(0, n.inputEvalSpec().fromQuery().size()); assertEquals(1, n.inputEvalSpec().fromMF().size()); - assertEquals("bm25(myabstract)", n.inputEvalSpec().fromMF().get(0)); + assertEquals("bm25(myabstract)", n.inputEvalSpec().fromMF().get(0).inputName()); assertEquals("reciprocal-rank{k:42.0}", n.supplier().get().normalizing()); n = nList.get(4); @@ -96,24 +96,42 @@ public class GlobalPhaseSetupTest { assertEquals(1, n.inputEvalSpec().fromQuery().size()); assertEquals("myweight", n.inputEvalSpec().fromQuery().get(0)); assertEquals(1, n.inputEvalSpec().fromMF().size()); - assertEquals("attribute(foo1)", n.inputEvalSpec().fromMF().get(0)); + assertEquals("attribute(foo1)", n.inputEvalSpec().fromMF().get(0).inputName()); assertEquals("linear", n.supplier().get().normalizing()); n = nList.get(5); assertEquals("normalize@4640646880@linear", n.name()); assertEquals(0, n.inputEvalSpec().fromQuery().size()); assertEquals(1, n.inputEvalSpec().fromMF().size()); - assertEquals("attribute(foo1)", n.inputEvalSpec().fromMF().get(0)); + assertEquals("attribute(foo1)", n.inputEvalSpec().fromMF().get(0).inputName()); assertEquals("linear", n.supplier().get().normalizing()); n = nList.get(6); assertEquals("normalize@6283155534@linear", n.name()); assertEquals(0, n.inputEvalSpec().fromQuery().size()); assertEquals(1, n.inputEvalSpec().fromMF().size()); - assertEquals("bm25(mytitle)", n.inputEvalSpec().fromMF().get(0)); + assertEquals("bm25(mytitle)", n.inputEvalSpec().fromMF().get(0).inputName()); assertEquals("linear", n.supplier().get().normalizing()); } + @Test void funcWithArgsSetup() { + RankProfilesConfig rpCfg = readConfig("with_mf_funargs"); + assertEquals(1, rpCfg.rankprofile().size()); + RankProfilesEvaluator rpEvaluator = createEvaluator(rpCfg); + var setup = GlobalPhaseSetup.maybeMakeSetup(rpCfg.rankprofile().get(0), rpEvaluator); + assertNotNull(setup); + assertEquals(0, setup.normalizers.size()); + assertEquals(3, setup.matchFeaturesToHide.size()); + assertEquals(0, setup.globalPhaseEvalSpec.fromQuery().size()); + var wantMF = setup.globalPhaseEvalSpec.fromMF(); + assertEquals(4, wantMF.size()); + wantMF.sort((a, b) -> a.matchFeatureName().compareTo(b.matchFeatureName())); + assertEquals("plusOne(2)", wantMF.get(0).matchFeatureName()); + assertEquals("plusOne(attribute(foo2))", wantMF.get(1).matchFeatureName()); + assertEquals("useAttr(t1,42)", wantMF.get(2).matchFeatureName()); + assertEquals("withIndirect(foo1)", wantMF.get(3).matchFeatureName()); + } + private RankProfilesEvaluator createEvaluator(RankProfilesConfig config) { RankingConstantsConfig constantsConfig = new RankingConstantsConfig.Builder().build(); RankingExpressionsConfig expressionsConfig = new RankingExpressionsConfig.Builder().build(); diff --git a/container-search/src/test/resources/config/with_mf_funargs/rank-profiles.cfg b/container-search/src/test/resources/config/with_mf_funargs/rank-profiles.cfg new file mode 100644 index 00000000000..9acf22f76e5 --- /dev/null +++ b/container-search/src/test/resources/config/with_mf_funargs/rank-profiles.cfg @@ -0,0 +1,59 @@ +rankprofile[0].name "function-with-arg-in-global-phase" +rankprofile[0].fef.property[0].name "rankingExpression(useAttr@6598f1aecaec0a2d.40876484d21a389).rankingScript" +rankprofile[0].fef.property[0].value "attribute(t1) * 42" +rankprofile[0].fef.property[1].name "rankingExpression(plusOne@31852fecfab75f29).rankingScript" +rankprofile[0].fef.property[1].value "2 + 1" +rankprofile[0].fef.property[2].name "rankingExpression(useAttr@93d0729be0db6c70.fe12ed266262cc16).rankingScript" +rankprofile[0].fef.property[2].value "attribute(foo1) * 1.25" +rankprofile[0].fef.property[3].name "rankingExpression(withIndirect@93d0729be0db6c70).rankingScript" +rankprofile[0].fef.property[3].value "rankingExpression(useAttr@93d0729be0db6c70.fe12ed266262cc16)" +rankprofile[0].fef.property[4].name "rankingExpression(plusOne@4a2b16f9107d7185).rankingScript" +rankprofile[0].fef.property[4].value "attribute(foo2) + 1" +rankprofile[0].fef.property[5].name "vespa.type.feature.useAttr(t1,42)" +rankprofile[0].fef.property[5].value "tensor(m{},v[3])" +rankprofile[0].fef.property[6].name "rankingExpression(plusOne).rankingScript" +rankprofile[0].fef.property[6].value "x + 1" +rankprofile[0].fef.property[7].name "rankingExpression(useAttr).rankingScript" +rankprofile[0].fef.property[7].value "attribute(name) * weight" +rankprofile[0].fef.property[8].name "rankingExpression(useAttr@2e0b6bb9bf541103.fe12ed266262cc16).rankingScript" +rankprofile[0].fef.property[8].value "attribute(name) * 1.25" +rankprofile[0].fef.property[9].name "rankingExpression(withIndirect).rankingScript" +rankprofile[0].fef.property[9].value "rankingExpression(useAttr@2e0b6bb9bf541103.fe12ed266262cc16)" +rankprofile[0].fef.property[10].name "vespa.rank.firstphase" +rankprofile[0].fef.property[10].value "nativeRank" +rankprofile[0].fef.property[11].name "vespa.rank.globalphase" +rankprofile[0].fef.property[11].value "rankingExpression(globalphase)" +rankprofile[0].fef.property[12].name "rankingExpression(globalphase).rankingScript" +rankprofile[0].fef.property[12].value "reduce(rankingExpression(useAttr@6598f1aecaec0a2d.40876484d21a389) + rankingExpression(plusOne@31852fecfab75f29) + rankingExpression(withIndirect@93d0729be0db6c70) + rankingExpression(plusOne@4a2b16f9107d7185), sum)" +rankprofile[0].fef.property[13].name "vespa.match.feature" +rankprofile[0].fef.property[13].value "rankingExpression(plusOne@4a2b16f9107d7185)" +rankprofile[0].fef.property[14].name "vespa.match.feature" +rankprofile[0].fef.property[14].value "rankingExpression(plusOne@31852fecfab75f29)" +rankprofile[0].fef.property[15].name "vespa.match.feature" +rankprofile[0].fef.property[15].value "rankingExpression(withIndirect@93d0729be0db6c70)" +rankprofile[0].fef.property[16].name "vespa.match.feature" +rankprofile[0].fef.property[16].value "rankingExpression(useAttr@6598f1aecaec0a2d.40876484d21a389)" +rankprofile[0].fef.property[17].name "vespa.hidden.matchfeature" +rankprofile[0].fef.property[17].value "plusOne(2)" +rankprofile[0].fef.property[18].name "vespa.hidden.matchfeature" +rankprofile[0].fef.property[18].value "withIndirect(foo1)" +rankprofile[0].fef.property[19].name "vespa.hidden.matchfeature" +rankprofile[0].fef.property[19].value "useAttr(t1,42)" +rankprofile[0].fef.property[20].name "vespa.feature.rename" +rankprofile[0].fef.property[20].value "rankingExpression(plusOne@4a2b16f9107d7185)" +rankprofile[0].fef.property[21].name "vespa.feature.rename" +rankprofile[0].fef.property[21].value "plusOne(attribute(foo2))" +rankprofile[0].fef.property[22].name "vespa.feature.rename" +rankprofile[0].fef.property[22].value "rankingExpression(plusOne@31852fecfab75f29)" +rankprofile[0].fef.property[23].name "vespa.feature.rename" +rankprofile[0].fef.property[23].value "plusOne(2)" +rankprofile[0].fef.property[24].name "vespa.feature.rename" +rankprofile[0].fef.property[24].value "rankingExpression(withIndirect@93d0729be0db6c70)" +rankprofile[0].fef.property[25].name "vespa.feature.rename" +rankprofile[0].fef.property[25].value "withIndirect(foo1)" +rankprofile[0].fef.property[26].name "vespa.feature.rename" +rankprofile[0].fef.property[26].value "rankingExpression(useAttr@6598f1aecaec0a2d.40876484d21a389)" +rankprofile[0].fef.property[27].name "vespa.feature.rename" +rankprofile[0].fef.property[27].value "useAttr(t1,42)" +rankprofile[0].fef.property[28].name "vespa.type.attribute.t1" +rankprofile[0].fef.property[28].value "tensor(m{},v[3])" |