summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2020-04-26 13:45:54 +0200
committerGitHub <noreply@github.com>2020-04-26 13:45:54 +0200
commit1ea6410fd642bb92a03b5798387fa977286af167 (patch)
tree2a4d3fa648d581c472e2ceb83484a24edf0b739a /config-model/src/main/java
parent84e4a15ed3545f59f8f38f4fe8c4770098cce642 (diff)
parentd95c4ceb6c4db635e93b605d719f01db9d5f2e6f (diff)
Merge pull request #13067 from vespa-engine/lesters/bert-searchlib-and-config-model
Lesters/bert searchlib and config model
Diffstat (limited to 'config-model/src/main/java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java34
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java5
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java34
3 files changed, 55 insertions, 18 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index 6de7c985326..4011ce43841 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -41,6 +41,9 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
private final Map<Reference, TensorType> resolvedTypes = new HashMap<>();
+ /** To avoid re-resolving diamond-shaped dependencies */
+ private final Map<Reference, TensorType> globallyResolvedTypes;
+
/** For invocation loop detection */
private final Deque<Reference> currentResolutionCallStack;
@@ -53,6 +56,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
this.currentResolutionCallStack = new ArrayDeque<>();
this.queryFeaturesNotDeclared = new TreeSet<>();
tensorsAreUsed = false;
+ globallyResolvedTypes = new HashMap<>();
}
private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions,
@@ -60,12 +64,14 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
Map<Reference, TensorType> featureTypes,
Deque<Reference> currentResolutionCallStack,
SortedSet<Reference> queryFeaturesNotDeclared,
- boolean tensorsAreUsed) {
+ boolean tensorsAreUsed,
+ Map<Reference, TensorType> globallyResolvedTypes) {
super(functions, bindings);
this.featureTypes.putAll(featureTypes);
this.currentResolutionCallStack = currentResolutionCallStack;
this.queryFeaturesNotDeclared = queryFeaturesNotDeclared;
this.tensorsAreUsed = tensorsAreUsed;
+ this.globallyResolvedTypes = globallyResolvedTypes;
}
public void setType(Reference reference, TensorType type) {
@@ -82,11 +88,25 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
resolvedTypes.clear();
}
- @Override
+ private boolean referenceCanBeResolvedGlobally(Reference reference) {
+ Optional<ExpressionFunction> function = functionInvocation(reference);
+ return function.isPresent() && function.get().arguments().size() == 0;
+ // are there other cases we would like to resolve globally?
+ }
+
+ @Override
public TensorType getType(Reference reference) {
// computeIfAbsent without concurrent modification due to resolve adding more resolved entries:
+
+ boolean canBeResolvedGlobally = referenceCanBeResolvedGlobally(reference);
+
TensorType resolvedType = resolvedTypes.get(reference);
- if (resolvedType != null) return resolvedType;
+ if (resolvedType == null && canBeResolvedGlobally) {
+ resolvedType = globallyResolvedTypes.get(reference);
+ }
+ if (resolvedType != null) {
+ return resolvedType;
+ }
resolvedType = resolveType(reference);
if (resolvedType == null)
@@ -94,6 +114,11 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
resolvedTypes.put(reference, resolvedType);
if (resolvedType.rank() > 0)
tensorsAreUsed = true;
+
+ if (canBeResolvedGlobally) {
+ globallyResolvedTypes.put(reference, resolvedType);
+ }
+
return resolvedType;
}
@@ -254,7 +279,8 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
featureTypes,
currentResolutionCallStack,
queryFeaturesNotDeclared,
- tensorsAreUsed);
+ tensorsAreUsed,
+ globallyResolvedTypes);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 23eb814de81..ea126123a25 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -680,11 +680,12 @@ public class RankProfile implements Cloneable {
Map<String, RankingExpressionFunction> inlineFunctions =
compileFunctions(this::getInlineFunctions, queryProfiles, featureTypes, importedModels, Collections.emptyMap(), expressionTransforms);
+ firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
+ secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
+
// Function compiling second pass: compile all functions and insert previously compiled inline functions
functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms);
- firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
- secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
}
private void checkNameCollisions(Map<String, RankingExpressionFunction> functions, Map<String, Value> constants) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 1a22b98fd9f..c3c10139684 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -200,9 +200,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
if (functions.isEmpty()) return;
List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList());
-
Map<String, String> functionProperties = new LinkedHashMap<>();
- functionProperties.putAll(deriveFunctionProperties(functions, functionExpressions));
if (firstPhaseRanking != null) {
functionProperties.putAll(firstPhaseRanking.getRankProperties(functionExpressions));
@@ -211,20 +209,30 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
functionProperties.putAll(secondPhaseRanking.getRankProperties(functionExpressions));
}
+ SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties);
+ replaceFunctionSummaryFeatures(context);
+
+ // First phase, second phase and summary features should add all required functions to the context.
+ // However, we need to add any functions not referenced in those anyway for model-evaluation.
+ deriveFunctionProperties(functions, functionExpressions, functionProperties);
+
for (Map.Entry<String, String> e : functionProperties.entrySet()) {
rankProperties.add(new RankProfile.RankProperty(e.getKey(), e.getValue()));
}
- SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties);
- replaceFunctionSummaryFeatures(context);
}
- private Map<String, String> deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions,
- List<ExpressionFunction> functionExpressions) {
- SerializationContext context = new SerializationContext(functionExpressions);
+ private void deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions,
+ List<ExpressionFunction> functionExpressions,
+ Map<String, String> functionProperties) {
+ SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties);
for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) {
+ String propertyName = RankingExpression.propertyName(e.getKey());
+ if (context.serializedFunctions().containsKey(propertyName)) {
+ continue;
+ }
String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString();
- context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString);
+ context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString);
for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet())
context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue());
if (e.getValue().function().returnType().isPresent())
@@ -232,7 +240,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
// else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types
// throw new IllegalStateException("Type of function '" + e.getKey() + "' is not resolved");
}
- return context.serializedFunctions();
+ functionProperties.putAll(context.serializedFunctions());
}
private void replaceFunctionSummaryFeatures(SerializationContext context) {
@@ -241,9 +249,11 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) {
ReferenceNode referenceNode = i.next();
// Is the feature a function?
- if (context.getFunction(referenceNode.getName()) != null) {
- context.addFunctionSerialization(RankingExpression.propertyName(referenceNode.getName()),
- referenceNode.toString(new StringBuilder(), context, null, null).toString());
+ ExpressionFunction function = context.getFunction(referenceNode.getName());
+ if (function != null) {
+ String propertyName = RankingExpression.propertyName(referenceNode.getName());
+ String expressionString = function.getBody().getRoot().toString(new StringBuilder(), context, null, null).toString();
+ context.addFunctionSerialization(propertyName, expressionString);
ReferenceNode newReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", referenceNode.getArguments().expressions(), referenceNode.getOutput());
functionSummaryFeatures.put(referenceNode.getName(), newReferenceNode);
i.remove(); // Will add the expanded one in next block