diff options
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 098e6e7a1f6..6148287a536 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -90,11 +90,14 @@ public class RankProfilesConfigImporter { SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo(); ExpressionFunction firstPhase = null; ExpressionFunction secondPhase = null; + ExpressionFunction globalPhase = null; + Map<String, TensorType> declaredTypes = new LinkedHashMap<>(); for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) { Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name()); Optional<FunctionReference> externalReference = FunctionReference.fromExternalSerial(property.name()); Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name()); Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name()); + Optional<String> typeDeclaredFeature = fromTypeDeclarationSerial(property.name()); if (externalReference.isPresent()) { RankingExpression expression = largeExpressions.get(property.value()); ExpressionFunction function = new ExpressionFunction(externalReference.get().functionName(), @@ -140,6 +143,13 @@ public class RankProfilesConfigImporter { secondPhase = new ExpressionFunction("secondphase", new ArrayList<>(), new RankingExpression("second-phase", property.value())); } + else if (property.name().equals("vespa.rank.globalphase")) { // Include in addition to functions + globalPhase = new ExpressionFunction("globalphase", new ArrayList<>(), + new RankingExpression("global-phase", property.value())); + } + else if (typeDeclaredFeature.isPresent()) { + declaredTypes.put(typeDeclaredFeature.get(), TensorType.fromSpec(property.value())); + } else { smallConstantsInfo.addIfSmallConstantInfo(property.name(), property.value()); } @@ -148,11 +158,13 @@ public class RankProfilesConfigImporter { functions.put(FunctionReference.fromName("firstphase"), firstPhase); if (functionByName("secondphase", functions.values()) == null && secondPhase != null) // may be already included, depending on body functions.put(FunctionReference.fromName("secondphase"), secondPhase); + if (functionByName("globalphase", functions.values()) == null && globalPhase != null) // may be already included, depending on body + functions.put(FunctionReference.fromName("globalphase"), globalPhase); constants.addAll(smallConstantsInfo.asConstants()); try { - return new Model(profile.name(), functions, referencedFunctions, constants, onnxModels); + return new Model(profile.name(), functions, referencedFunctions, declaredTypes, constants, onnxModels); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); @@ -299,4 +311,15 @@ public class RankProfilesConfigImporter { } + private static final Pattern typeDeclarationPattern = + Pattern.compile("vespa[.]type[.]([a-zA-Z0-9]+)[.](.+)"); + + static Optional<String> fromTypeDeclarationSerial(String serialForm) { + Matcher expressionMatcher = typeDeclarationPattern.matcher(serialForm); + if ( ! expressionMatcher.matches()) return Optional.empty(); + String name = expressionMatcher.group(1); + String argument = expressionMatcher.group(2); + return Optional.of(name + "(" + argument + ")"); + } + } |