diff options
author | Jon Bratseth <bratseth@oath.com> | 2020-04-26 13:45:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-04-26 13:45:54 +0200 |
commit | 1ea6410fd642bb92a03b5798387fa977286af167 (patch) | |
tree | 2a4d3fa648d581c472e2ceb83484a24edf0b739a /config-model/src/main/java | |
parent | 84e4a15ed3545f59f8f38f4fe8c4770098cce642 (diff) | |
parent | d95c4ceb6c4db635e93b605d719f01db9d5f2e6f (diff) |
Merge pull request #13067 from vespa-engine/lesters/bert-searchlib-and-config-model
Lesters/bert searchlib and config model
Diffstat (limited to 'config-model/src/main/java')
3 files changed, 55 insertions, 18 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 6de7c985326..4011ce43841 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -41,6 +41,9 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement private final Map<Reference, TensorType> resolvedTypes = new HashMap<>(); + /** To avoid re-resolving diamond-shaped dependencies */ + private final Map<Reference, TensorType> globallyResolvedTypes; + /** For invocation loop detection */ private final Deque<Reference> currentResolutionCallStack; @@ -53,6 +56,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement this.currentResolutionCallStack = new ArrayDeque<>(); this.queryFeaturesNotDeclared = new TreeSet<>(); tensorsAreUsed = false; + globallyResolvedTypes = new HashMap<>(); } private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions, @@ -60,12 +64,14 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement Map<Reference, TensorType> featureTypes, Deque<Reference> currentResolutionCallStack, SortedSet<Reference> queryFeaturesNotDeclared, - boolean tensorsAreUsed) { + boolean tensorsAreUsed, + Map<Reference, TensorType> globallyResolvedTypes) { super(functions, bindings); this.featureTypes.putAll(featureTypes); this.currentResolutionCallStack = currentResolutionCallStack; this.queryFeaturesNotDeclared = queryFeaturesNotDeclared; this.tensorsAreUsed = tensorsAreUsed; + this.globallyResolvedTypes = globallyResolvedTypes; } public void setType(Reference reference, TensorType type) { @@ -82,11 +88,25 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement resolvedTypes.clear(); } - @Override + private boolean referenceCanBeResolvedGlobally(Reference reference) { + Optional<ExpressionFunction> function = functionInvocation(reference); + return function.isPresent() && function.get().arguments().size() == 0; + // are there other cases we would like to resolve globally? + } + + @Override public TensorType getType(Reference reference) { // computeIfAbsent without concurrent modification due to resolve adding more resolved entries: + + boolean canBeResolvedGlobally = referenceCanBeResolvedGlobally(reference); + TensorType resolvedType = resolvedTypes.get(reference); - if (resolvedType != null) return resolvedType; + if (resolvedType == null && canBeResolvedGlobally) { + resolvedType = globallyResolvedTypes.get(reference); + } + if (resolvedType != null) { + return resolvedType; + } resolvedType = resolveType(reference); if (resolvedType == null) @@ -94,6 +114,11 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement resolvedTypes.put(reference, resolvedType); if (resolvedType.rank() > 0) tensorsAreUsed = true; + + if (canBeResolvedGlobally) { + globallyResolvedTypes.put(reference, resolvedType); + } + return resolvedType; } @@ -254,7 +279,8 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement featureTypes, currentResolutionCallStack, queryFeaturesNotDeclared, - tensorsAreUsed); + tensorsAreUsed, + globallyResolvedTypes); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 23eb814de81..ea126123a25 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -680,11 +680,12 @@ public class RankProfile implements Cloneable { Map<String, RankingExpressionFunction> inlineFunctions = compileFunctions(this::getInlineFunctions, queryProfiles, featureTypes, importedModels, Collections.emptyMap(), expressionTransforms); + firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); + secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); + // Function compiling second pass: compile all functions and insert previously compiled inline functions functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms); - firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); - secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); } private void checkNameCollisions(Map<String, RankingExpressionFunction> functions, Map<String, Value> constants) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 1a22b98fd9f..c3c10139684 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -200,9 +200,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { if (functions.isEmpty()) return; List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList()); - Map<String, String> functionProperties = new LinkedHashMap<>(); - functionProperties.putAll(deriveFunctionProperties(functions, functionExpressions)); if (firstPhaseRanking != null) { functionProperties.putAll(firstPhaseRanking.getRankProperties(functionExpressions)); @@ -211,20 +209,30 @@ public class RawRankProfile implements RankProfilesConfig.Producer { functionProperties.putAll(secondPhaseRanking.getRankProperties(functionExpressions)); } + SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); + replaceFunctionSummaryFeatures(context); + + // First phase, second phase and summary features should add all required functions to the context. + // However, we need to add any functions not referenced in those anyway for model-evaluation. + deriveFunctionProperties(functions, functionExpressions, functionProperties); + for (Map.Entry<String, String> e : functionProperties.entrySet()) { rankProperties.add(new RankProfile.RankProperty(e.getKey(), e.getValue())); } - SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); - replaceFunctionSummaryFeatures(context); } - private Map<String, String> deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions, - List<ExpressionFunction> functionExpressions) { - SerializationContext context = new SerializationContext(functionExpressions); + private void deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions, + List<ExpressionFunction> functionExpressions, + Map<String, String> functionProperties) { + SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) { + String propertyName = RankingExpression.propertyName(e.getKey()); + if (context.serializedFunctions().containsKey(propertyName)) { + continue; + } String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); - context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString); + context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString); for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet()) context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()); if (e.getValue().function().returnType().isPresent()) @@ -232,7 +240,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { // else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types // throw new IllegalStateException("Type of function '" + e.getKey() + "' is not resolved"); } - return context.serializedFunctions(); + functionProperties.putAll(context.serializedFunctions()); } private void replaceFunctionSummaryFeatures(SerializationContext context) { @@ -241,9 +249,11 @@ public class RawRankProfile implements RankProfilesConfig.Producer { for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) { ReferenceNode referenceNode = i.next(); // Is the feature a function? - if (context.getFunction(referenceNode.getName()) != null) { - context.addFunctionSerialization(RankingExpression.propertyName(referenceNode.getName()), - referenceNode.toString(new StringBuilder(), context, null, null).toString()); + ExpressionFunction function = context.getFunction(referenceNode.getName()); + if (function != null) { + String propertyName = RankingExpression.propertyName(referenceNode.getName()); + String expressionString = function.getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); + context.addFunctionSerialization(propertyName, expressionString); ReferenceNode newReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", referenceNode.getArguments().expressions(), referenceNode.getOutput()); functionSummaryFeatures.put(referenceNode.getName(), newReferenceNode); i.remove(); // Will add the expanded one in next block |