diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-10-01 05:46:22 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-10-01 05:46:22 +0200 |
commit | 9c80048457caab3881f3319aadd0990f65c04937 (patch) | |
tree | d180b1a6a866b53e0c23657a31ebe836d641911f /model-evaluation/src/main | |
parent | 8d80010a385f40d4bb852e6b11810692a67e90ed (diff) |
Include argument type information in functions
Diffstat (limited to 'model-evaluation/src/main')
3 files changed, 55 insertions, 18 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java index 00fcad94ce8..5bb22b23345 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java @@ -1,6 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import com.yahoo.collections.Pair; +import com.yahoo.tensor.TensorType; + import java.util.Objects; import java.util.Optional; import java.util.regex.Matcher; @@ -23,6 +26,8 @@ class FunctionReference { private static final Pattern referencePattern = Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?"); + private static final Pattern argumentTypePattern = + Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.([a-zA-Z0-9_]+)\\.type?"); /** The name of the function referenced */ private final String name; @@ -73,4 +78,22 @@ class FunctionReference { return Optional.of(new FunctionReference(name, instance)); } + /** + * Returns a function reference and argument name string from the given serial form, + * or empty if the string is not a valid function argument serial form + */ + static Optional<Pair<FunctionReference, String>> fromTypeArgumentSerial(String serialForm) { + Matcher expressionMatcher = argumentTypePattern.matcher(serialForm); + if ( ! expressionMatcher.matches()) return Optional.empty(); + + String name = expressionMatcher.group(1); + String instance = expressionMatcher.group(2); + String argument = expressionMatcher.group(3); + return Optional.of(new Pair<>(new FunctionReference(name, instance), argument)); + } + + public static FunctionReference fromName(String name) { + return new FunctionReference(name, null); + } + } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 3fb43d73187..ac8f28677a4 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -45,7 +45,6 @@ public class Model { Collection<ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants) { - // TODO: Optimize functions this.name = name; this.functions = ImmutableList.copyOf(functions); @@ -79,7 +78,10 @@ public class Model { public String name() { return name; } - /** Returns an immutable list of the free functions of this */ + /** + * Returns an immutable list of the free functions of this. + * The functions returned always specifies types of all arguments and the return value + */ public List<ExpressionFunction> functions() { return functions; } /** Returns the given function, or throws a IllegalArgumentException if it does not exist */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 7bea2d0825a..f48d76e86f3 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import com.yahoo.collections.Pair; import com.yahoo.config.FileReference; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.io.GrowableByteBuffer; @@ -18,7 +19,9 @@ import java.io.File; import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -60,26 +63,34 @@ public class RankProfilesConfigImporter { private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) throws ParseException { - List<ExpressionFunction> functions = new ArrayList<>(); - Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>(); - SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo(); - ExpressionFunction firstPhase = null; - ExpressionFunction secondPhase = null; List<Constant> constants = readLargeConstants(constantsConfig); + Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>(); + Map<FunctionReference, ExpressionFunction> referencedFunctions = new LinkedHashMap<>(); + SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo(); + ExpressionFunction firstPhase = null; + ExpressionFunction secondPhase = null; for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) { Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name()); + Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name()); if ( reference.isPresent()) { List<String> arguments = new ArrayList<>(); // TODO: Arguments? RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value()); + ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), arguments, expression); if (reference.get().isFree()) // make available in model under configured name - functions.add(new ExpressionFunction(reference.get().functionName(), arguments, expression)); // - - // Make all functions, bound or not available under the name they are referenced by in expressions - referencedFunctions.put(reference.get(), - new ExpressionFunction(reference.get().serialForm(), arguments, expression)); + functions.put(reference.get(), function); + // Make all functions, bound or not, available under the name they are referenced by in expressions + referencedFunctions.put(reference.get(), function); + } + else if (argumentType.isPresent()) { // Arguments always follows the function in properties + FunctionReference argReference = argumentType.get().getFirst(); + ExpressionFunction function = referencedFunctions.get(argReference); + function = function.withArgument(argumentType.get().getSecond(), TensorType.fromSpec(property.value())); + if (argReference.isFree()) + functions.put(argReference, function); + referencedFunctions.put(argReference, function); } else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to functions firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(), @@ -93,22 +104,23 @@ public class RankProfilesConfigImporter { smallConstantsInfo.addIfSmallConstantInfo(property.name(), property.value()); } } - if (functionByName("firstphase", functions) == null && firstPhase != null) // may be already included, depending on body - functions.add(firstPhase); - if (functionByName("secondphase", functions) == null && secondPhase != null) // may be already included, depending on body - functions.add(secondPhase); + if (functionByName("firstphase", functions.values()) == null && firstPhase != null) // may be already included, depending on body + functions.put(FunctionReference.fromName("firstphase"), firstPhase); + if (functionByName("secondphase", functions.values()) == null && secondPhase != null) // may be already included, depending on body + functions.put(FunctionReference.fromName("secondphase"), secondPhase); constants.addAll(smallConstantsInfo.asConstants()); try { - return new Model(profile.name(), functions, referencedFunctions, constants); + return new Model(profile.name(), functions.values(), referencedFunctions, constants); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); } } - private ExpressionFunction functionByName(String name, List<ExpressionFunction> functions) { + // TODO: Replace by lookup in map + private ExpressionFunction functionByName(String name, Collection<ExpressionFunction> functions) { for (ExpressionFunction function : functions) if (function.getName().equals(name)) return function; |