aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java25
1 files changed, 24 insertions, 1 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 098e6e7a1f6..6148287a536 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -90,11 +90,14 @@ public class RankProfilesConfigImporter {
SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo();
ExpressionFunction firstPhase = null;
ExpressionFunction secondPhase = null;
+ ExpressionFunction globalPhase = null;
+ Map<String, TensorType> declaredTypes = new LinkedHashMap<>();
for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
Optional<FunctionReference> externalReference = FunctionReference.fromExternalSerial(property.name());
Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name());
Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name());
+ Optional<String> typeDeclaredFeature = fromTypeDeclarationSerial(property.name());
if (externalReference.isPresent()) {
RankingExpression expression = largeExpressions.get(property.value());
ExpressionFunction function = new ExpressionFunction(externalReference.get().functionName(),
@@ -140,6 +143,13 @@ public class RankProfilesConfigImporter {
secondPhase = new ExpressionFunction("secondphase", new ArrayList<>(),
new RankingExpression("second-phase", property.value()));
}
+ else if (property.name().equals("vespa.rank.globalphase")) { // Include in addition to functions
+ globalPhase = new ExpressionFunction("globalphase", new ArrayList<>(),
+ new RankingExpression("global-phase", property.value()));
+ }
+ else if (typeDeclaredFeature.isPresent()) {
+ declaredTypes.put(typeDeclaredFeature.get(), TensorType.fromSpec(property.value()));
+ }
else {
smallConstantsInfo.addIfSmallConstantInfo(property.name(), property.value());
}
@@ -148,11 +158,13 @@ public class RankProfilesConfigImporter {
functions.put(FunctionReference.fromName("firstphase"), firstPhase);
if (functionByName("secondphase", functions.values()) == null && secondPhase != null) // may be already included, depending on body
functions.put(FunctionReference.fromName("secondphase"), secondPhase);
+ if (functionByName("globalphase", functions.values()) == null && globalPhase != null) // may be already included, depending on body
+ functions.put(FunctionReference.fromName("globalphase"), globalPhase);
constants.addAll(smallConstantsInfo.asConstants());
try {
- return new Model(profile.name(), functions, referencedFunctions, constants, onnxModels);
+ return new Model(profile.name(), functions, referencedFunctions, declaredTypes, constants, onnxModels);
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
@@ -299,4 +311,15 @@ public class RankProfilesConfigImporter {
}
+ private static final Pattern typeDeclarationPattern =
+ Pattern.compile("vespa[.]type[.]([a-zA-Z0-9]+)[.](.+)");
+
+ static Optional<String> fromTypeDeclarationSerial(String serialForm) {
+ Matcher expressionMatcher = typeDeclarationPattern.matcher(serialForm);
+ if ( ! expressionMatcher.matches()) return Optional.empty();
+ String name = expressionMatcher.group(1);
+ String argument = expressionMatcher.group(2);
+ return Optional.of(name + "(" + argument + ")");
+ }
+
}