diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-03-07 15:23:45 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-07 15:23:45 +0100 |
commit | 56057c40b54fc5b031ab9c2bfea9450e0ee77993 (patch) | |
tree | 956656920d38c4f508a19bfa5cd222c263ae9123 | |
parent | 0d569d1ba3661484b3d30c94dadae9d0444ffe98 (diff) | |
parent | 1197f63fe1a32b3e17493c9387527ac1c4e40cff (diff) |
Merge pull request #26338 from vespa-engine/arnej/declare-more-function-return-types
Arnej/declare more function return types
5 files changed, 26 insertions, 13 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java index acb125197d2..6272563f833 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java @@ -248,15 +248,15 @@ public class RawRankProfile implements RankProfilesConfig.Producer { SerializationContext context) { for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) { String propertyName = RankingExpression.propertyName(e.getKey()); - if (context.serializedFunctions().containsKey(propertyName)) continue; + if (! context.serializedFunctions().containsKey(propertyName)) { - String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString(); + String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString(); + context.addFunctionSerialization(propertyName, expressionString); + e.getValue().function().argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey()) + .forEach(argumentType -> context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue())); + } + e.getValue().function().returnType().ifPresent(t -> context.addFunctionTypeSerialization(e.getKey(), t)); - context.addFunctionSerialization(propertyName, expressionString); - e.getValue().function().argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey()) - .forEach(argumentType -> context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue())); - if (e.getValue().function().returnType().isPresent()) - context.addFunctionTypeSerialization(e.getKey(), e.getValue().function().returnType().get()); // else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types // throw new IllegalStateException("Type of function '" + e.getKey() + "' is not resolved"); } @@ -274,6 +274,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { String propertyName = RankingExpression.propertyName(referenceNode.getName()); String expressionString = function.getBody().getRoot().toString(context).toString(); context.addFunctionSerialization(propertyName, expressionString); + function.returnType().ifPresent(t -> context.addFunctionTypeSerialization(referenceNode.getName(), t)); var backendReferenceNode = new ReferenceNode(wrapInRankingExpression(referenceNode.getName()), referenceNode.getArguments().expressions(), referenceNode.getOutput()); diff --git a/config-model/src/test/derived/rankingexpression/rank-profiles.cfg b/config-model/src/test/derived/rankingexpression/rank-profiles.cfg index e3947e9e46f..202669ae049 100644 --- a/config-model/src/test/derived/rankingexpression/rank-profiles.cfg +++ b/config-model/src/test/derived/rankingexpression/rank-profiles.cfg @@ -351,6 +351,8 @@ rankprofile[].fef.property[].name "rankingExpression(myplus).rankingScript" rankprofile[].fef.property[].value "attribute(foo1) + attribute(foo2)" rankprofile[].fef.property[].name "rankingExpression(mymul).rankingScript" rankprofile[].fef.property[].value "attribute(t1) * query(fromq)" +rankprofile[].fef.property[].name "rankingExpression(mymul).type" +rankprofile[].fef.property[].value "tensor(m{},v[3])" rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "attribute(foo1)" rankprofile[].fef.property[].name "vespa.rank.secondphase" diff --git a/config-model/src/test/derived/renamedfeatures/rank-profiles.cfg b/config-model/src/test/derived/renamedfeatures/rank-profiles.cfg index d084401d920..ea2d051484b 100644 --- a/config-model/src/test/derived/renamedfeatures/rank-profiles.cfg +++ b/config-model/src/test/derived/renamedfeatures/rank-profiles.cfg @@ -56,6 +56,8 @@ rankprofile[].fef.property[].value "tensor(m{},v[3])" rankprofile[].name "withmf" rankprofile[].fef.property[].name "rankingExpression(mymul).rankingScript" rankprofile[].fef.property[].value "attribute(t1) * query(fromq)" +rankprofile[].fef.property[].name "rankingExpression(mymul).type" +rankprofile[].fef.property[].value "tensor(m{},v[3])" rankprofile[].fef.property[].name "rankingExpression(myplus).rankingScript" rankprofile[].fef.property[].value "attribute(foo1) + attribute(foo2)" rankprofile[].fef.property[].name "vespa.rank.firstphase" diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 8c520e87001..76869932a3e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -129,11 +129,15 @@ public class RankProfilesConfigImporter { referencedFunctions.put(argReference, function); } else if (returnType.isPresent()) { // Return type always follows the function in properties - ExpressionFunction function = referencedFunctions.get(returnType.get()); - function = function.withReturnType(TensorType.fromSpec(property.value())); - if (returnType.get().isFree()) - functions.put(returnType.get(), function); - referencedFunctions.put(returnType.get(), function); + FunctionReference functionRef = returnType.get(); + ExpressionFunction function = referencedFunctions.get(functionRef); + TensorType type = TensorType.fromSpec(property.value()); + function = function.withReturnType(type); + if (functionRef.isFree()) + functions.put(functionRef, function); + referencedFunctions.put(functionRef, function); + declaredTypes.put(function.getName(), type); // "foo" + declaredTypes.put(functionRef.serialForm(), type); // "rankingExpression(foo)" } else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to functions firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(), diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index e2fffd824b9..c157f44be31 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -104,7 +104,11 @@ public class SerializationContext extends FunctionReferenceContext { /** Adds the serialization of the return type of a function */ public void addFunctionTypeSerialization(String functionName, TensorType type) { if (type.rank() == 0) return; // no explicit type implies scalar (aka rank 0 tensor) - serializedFunctions.put(wrapInRankingExpression(functionName) + ".type", type.toString()); + String key = wrapInRankingExpression(functionName) + ".type"; + var old = serializedFunctions.put(key, type.toString()); + if (old != null && !old.equals(type.toString())) { + throw new IllegalArgumentException("conflicting values for " + key + ": " + old + " != " + type.toString()); + } } @Override |