aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-01 06:08:08 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-01 06:08:08 +0200
commit0ff988ecf9704faac33f6201cb59349e48846457 (patch)
tree0cceb9c6961836a7b6149798d041e341bedcf903 /model-evaluation/src/main
parent9c80048457caab3881f3319aadd0990f65c04937 (diff)
Resoløve return types whenever possible
Diffstat (limited to 'model-evaluation/src/main')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java15
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java12
2 files changed, 25 insertions, 2 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
index 5bb22b23345..fa45920f3c8 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
@@ -28,6 +28,8 @@ class FunctionReference {
Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?");
private static final Pattern argumentTypePattern =
Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.([a-zA-Z0-9_]+)\\.type?");
+ private static final Pattern returnTypePattern =
+ Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.type?");
/** The name of the function referenced */
private final String name;
@@ -92,6 +94,19 @@ class FunctionReference {
return Optional.of(new Pair<>(new FunctionReference(name, instance), argument));
}
+ /**
+ * Returns a function reference from the given return type serial form,
+ * or empty if the string is not a valid function return typoe serial form
+ */
+ static Optional<FunctionReference> fromReturnTypeSerial(String serialForm) {
+ Matcher expressionMatcher = returnTypePattern.matcher(serialForm);
+ if ( ! expressionMatcher.matches()) return Optional.empty();
+
+ String name = expressionMatcher.group(1);
+ String instance = expressionMatcher.group(2);
+ return Optional.of(new FunctionReference(name, instance));
+ }
+
public static FunctionReference fromName(String name) {
return new FunctionReference(name, null);
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index f48d76e86f3..648c6d931a9 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -20,6 +20,7 @@ import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
@@ -74,10 +75,10 @@ public class RankProfilesConfigImporter {
for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name());
+ Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name());
if ( reference.isPresent()) {
- List<String> arguments = new ArrayList<>(); // TODO: Arguments?
RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value());
- ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), arguments, expression);
+ ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), Collections.emptyList(), expression);
if (reference.get().isFree()) // make available in model under configured name
functions.put(reference.get(), function);
@@ -92,6 +93,13 @@ public class RankProfilesConfigImporter {
functions.put(argReference, function);
referencedFunctions.put(argReference, function);
}
+ else if (returnType.isPresent()) { // Return type always follows the function in properties
+ ExpressionFunction function = referencedFunctions.get(returnType.get());
+ function = function.withReturnType(TensorType.fromSpec(property.value()));
+ if (returnType.get().isFree())
+ functions.put(returnType.get(), function);
+ referencedFunctions.put(returnType.get(), function);
+ }
else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to functions
firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(),
new RankingExpression("first-phase", property.value()));