diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java | 35 |
1 files changed, 23 insertions, 12 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 1a22b98fd9f..da2b23595a9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -17,6 +17,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; +import org.tensorflow.op.core.Rank; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -200,9 +201,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { if (functions.isEmpty()) return; List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList()); - Map<String, String> functionProperties = new LinkedHashMap<>(); - functionProperties.putAll(deriveFunctionProperties(functions, functionExpressions)); if (firstPhaseRanking != null) { functionProperties.putAll(firstPhaseRanking.getRankProperties(functionExpressions)); @@ -211,20 +210,30 @@ public class RawRankProfile implements RankProfilesConfig.Producer { functionProperties.putAll(secondPhaseRanking.getRankProperties(functionExpressions)); } + SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); + replaceFunctionSummaryFeatures(context); + + // First phase, second phase and summary features should add all required functions to the context. + // However, we need to add any functions not referenced in those anyway for model-evaluation. + deriveFunctionProperties(functions, functionExpressions, functionProperties); + for (Map.Entry<String, String> e : functionProperties.entrySet()) { rankProperties.add(new RankProfile.RankProperty(e.getKey(), e.getValue())); } - SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); - replaceFunctionSummaryFeatures(context); } - private Map<String, String> deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions, - List<ExpressionFunction> functionExpressions) { - SerializationContext context = new SerializationContext(functionExpressions); + private void deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions, + List<ExpressionFunction> functionExpressions, + Map<String, String> functionProperties) { + SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) { + String propertyName = RankingExpression.propertyName(e.getKey()); + if (context.serializedFunctions().containsKey(propertyName)) { + continue; + } String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); - context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString); + context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString); for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet()) context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()); if (e.getValue().function().returnType().isPresent()) @@ -232,7 +241,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { // else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types // throw new IllegalStateException("Type of function '" + e.getKey() + "' is not resolved"); } - return context.serializedFunctions(); + functionProperties.putAll(context.serializedFunctions()); } private void replaceFunctionSummaryFeatures(SerializationContext context) { @@ -241,9 +250,11 @@ public class RawRankProfile implements RankProfilesConfig.Producer { for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) { ReferenceNode referenceNode = i.next(); // Is the feature a function? - if (context.getFunction(referenceNode.getName()) != null) { - context.addFunctionSerialization(RankingExpression.propertyName(referenceNode.getName()), - referenceNode.toString(new StringBuilder(), context, null, null).toString()); + ExpressionFunction function = context.getFunction(referenceNode.getName()); + if (function != null) { + String propertyName = RankingExpression.propertyName(referenceNode.getName()); + String expressionString = function.getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); + context.addFunctionSerialization(propertyName, expressionString); ReferenceNode newReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", referenceNode.getArguments().expressions(), referenceNode.getOutput()); functionSummaryFeatures.put(referenceNode.getName(), newReferenceNode); i.remove(); // Will add the expanded one in next block |