aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-03-07 15:23:45 +0100
committerGitHub <noreply@github.com>2023-03-07 15:23:45 +0100
commit56057c40b54fc5b031ab9c2bfea9450e0ee77993 (patch)
tree956656920d38c4f508a19bfa5cd222c263ae9123
parent0d569d1ba3661484b3d30c94dadae9d0444ffe98 (diff)
parent1197f63fe1a32b3e17493c9387527ac1c4e40cff (diff)
Merge pull request #26338 from vespa-engine/arnej/declare-more-function-return-types
Arnej/declare more function return types
-rw-r--r--config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java15
-rw-r--r--config-model/src/test/derived/rankingexpression/rank-profiles.cfg2
-rw-r--r--config-model/src/test/derived/renamedfeatures/rank-profiles.cfg2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java6
5 files changed, 26 insertions, 13 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java
index acb125197d2..6272563f833 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java
@@ -248,15 +248,15 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
SerializationContext context) {
for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) {
String propertyName = RankingExpression.propertyName(e.getKey());
- if (context.serializedFunctions().containsKey(propertyName)) continue;
+ if (! context.serializedFunctions().containsKey(propertyName)) {
- String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString();
+ String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString();
+ context.addFunctionSerialization(propertyName, expressionString);
+ e.getValue().function().argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey())
+ .forEach(argumentType -> context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()));
+ }
+ e.getValue().function().returnType().ifPresent(t -> context.addFunctionTypeSerialization(e.getKey(), t));
- context.addFunctionSerialization(propertyName, expressionString);
- e.getValue().function().argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey())
- .forEach(argumentType -> context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()));
- if (e.getValue().function().returnType().isPresent())
- context.addFunctionTypeSerialization(e.getKey(), e.getValue().function().returnType().get());
// else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types
// throw new IllegalStateException("Type of function '" + e.getKey() + "' is not resolved");
}
@@ -274,6 +274,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
String propertyName = RankingExpression.propertyName(referenceNode.getName());
String expressionString = function.getBody().getRoot().toString(context).toString();
context.addFunctionSerialization(propertyName, expressionString);
+ function.returnType().ifPresent(t -> context.addFunctionTypeSerialization(referenceNode.getName(), t));
var backendReferenceNode = new ReferenceNode(wrapInRankingExpression(referenceNode.getName()),
referenceNode.getArguments().expressions(),
referenceNode.getOutput());
diff --git a/config-model/src/test/derived/rankingexpression/rank-profiles.cfg b/config-model/src/test/derived/rankingexpression/rank-profiles.cfg
index e3947e9e46f..202669ae049 100644
--- a/config-model/src/test/derived/rankingexpression/rank-profiles.cfg
+++ b/config-model/src/test/derived/rankingexpression/rank-profiles.cfg
@@ -351,6 +351,8 @@ rankprofile[].fef.property[].name "rankingExpression(myplus).rankingScript"
rankprofile[].fef.property[].value "attribute(foo1) + attribute(foo2)"
rankprofile[].fef.property[].name "rankingExpression(mymul).rankingScript"
rankprofile[].fef.property[].value "attribute(t1) * query(fromq)"
+rankprofile[].fef.property[].name "rankingExpression(mymul).type"
+rankprofile[].fef.property[].value "tensor(m{},v[3])"
rankprofile[].fef.property[].name "vespa.rank.firstphase"
rankprofile[].fef.property[].value "attribute(foo1)"
rankprofile[].fef.property[].name "vespa.rank.secondphase"
diff --git a/config-model/src/test/derived/renamedfeatures/rank-profiles.cfg b/config-model/src/test/derived/renamedfeatures/rank-profiles.cfg
index d084401d920..ea2d051484b 100644
--- a/config-model/src/test/derived/renamedfeatures/rank-profiles.cfg
+++ b/config-model/src/test/derived/renamedfeatures/rank-profiles.cfg
@@ -56,6 +56,8 @@ rankprofile[].fef.property[].value "tensor(m{},v[3])"
rankprofile[].name "withmf"
rankprofile[].fef.property[].name "rankingExpression(mymul).rankingScript"
rankprofile[].fef.property[].value "attribute(t1) * query(fromq)"
+rankprofile[].fef.property[].name "rankingExpression(mymul).type"
+rankprofile[].fef.property[].value "tensor(m{},v[3])"
rankprofile[].fef.property[].name "rankingExpression(myplus).rankingScript"
rankprofile[].fef.property[].value "attribute(foo1) + attribute(foo2)"
rankprofile[].fef.property[].name "vespa.rank.firstphase"
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 8c520e87001..76869932a3e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -129,11 +129,15 @@ public class RankProfilesConfigImporter {
referencedFunctions.put(argReference, function);
}
else if (returnType.isPresent()) { // Return type always follows the function in properties
- ExpressionFunction function = referencedFunctions.get(returnType.get());
- function = function.withReturnType(TensorType.fromSpec(property.value()));
- if (returnType.get().isFree())
- functions.put(returnType.get(), function);
- referencedFunctions.put(returnType.get(), function);
+ FunctionReference functionRef = returnType.get();
+ ExpressionFunction function = referencedFunctions.get(functionRef);
+ TensorType type = TensorType.fromSpec(property.value());
+ function = function.withReturnType(type);
+ if (functionRef.isFree())
+ functions.put(functionRef, function);
+ referencedFunctions.put(functionRef, function);
+ declaredTypes.put(function.getName(), type); // "foo"
+ declaredTypes.put(functionRef.serialForm(), type); // "rankingExpression(foo)"
}
else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to functions
firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(),
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
index e2fffd824b9..c157f44be31 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
@@ -104,7 +104,11 @@ public class SerializationContext extends FunctionReferenceContext {
/** Adds the serialization of the return type of a function */
public void addFunctionTypeSerialization(String functionName, TensorType type) {
if (type.rank() == 0) return; // no explicit type implies scalar (aka rank 0 tensor)
- serializedFunctions.put(wrapInRankingExpression(functionName) + ".type", type.toString());
+ String key = wrapInRankingExpression(functionName) + ".type";
+ var old = serializedFunctions.put(key, type.toString());
+ if (old != null && !old.equals(type.toString())) {
+ throw new IllegalArgumentException("conflicting values for " + key + ": " + old + " != " + type.toString());
+ }
}
@Override