summaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-01 05:46:22 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-01 05:46:22 +0200
commit9c80048457caab3881f3319aadd0990f65c04937 (patch)
treed180b1a6a866b53e0c23657a31ebe836d641911f /model-evaluation/src/main
parent8d80010a385f40d4bb852e6b11810692a67e90ed (diff)
Include argument type information in functions
Diffstat (limited to 'model-evaluation/src/main')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java23
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java6
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java44
3 files changed, 55 insertions, 18 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
index 00fcad94ce8..5bb22b23345 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
@@ -1,6 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import com.yahoo.collections.Pair;
+import com.yahoo.tensor.TensorType;
+
import java.util.Objects;
import java.util.Optional;
import java.util.regex.Matcher;
@@ -23,6 +26,8 @@ class FunctionReference {
private static final Pattern referencePattern =
Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?");
+ private static final Pattern argumentTypePattern =
+ Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.([a-zA-Z0-9_]+)\\.type?");
/** The name of the function referenced */
private final String name;
@@ -73,4 +78,22 @@ class FunctionReference {
return Optional.of(new FunctionReference(name, instance));
}
+ /**
+ * Returns a function reference and argument name string from the given serial form,
+ * or empty if the string is not a valid function argument serial form
+ */
+ static Optional<Pair<FunctionReference, String>> fromTypeArgumentSerial(String serialForm) {
+ Matcher expressionMatcher = argumentTypePattern.matcher(serialForm);
+ if ( ! expressionMatcher.matches()) return Optional.empty();
+
+ String name = expressionMatcher.group(1);
+ String instance = expressionMatcher.group(2);
+ String argument = expressionMatcher.group(3);
+ return Optional.of(new Pair<>(new FunctionReference(name, instance), argument));
+ }
+
+ public static FunctionReference fromName(String name) {
+ return new FunctionReference(name, null);
+ }
+
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 3fb43d73187..ac8f28677a4 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -45,7 +45,6 @@ public class Model {
Collection<ExpressionFunction> functions,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants) {
- // TODO: Optimize functions
this.name = name;
this.functions = ImmutableList.copyOf(functions);
@@ -79,7 +78,10 @@ public class Model {
public String name() { return name; }
- /** Returns an immutable list of the free functions of this */
+ /**
+ * Returns an immutable list of the free functions of this.
+ * The functions returned always specifies types of all arguments and the return value
+ */
public List<ExpressionFunction> functions() { return functions; }
/** Returns the given function, or throws a IllegalArgumentException if it does not exist */
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 7bea2d0825a..f48d76e86f3 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import com.yahoo.collections.Pair;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.io.GrowableByteBuffer;
@@ -18,7 +19,9 @@ import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -60,26 +63,34 @@ public class RankProfilesConfigImporter {
private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig)
throws ParseException {
- List<ExpressionFunction> functions = new ArrayList<>();
- Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>();
- SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo();
- ExpressionFunction firstPhase = null;
- ExpressionFunction secondPhase = null;
List<Constant> constants = readLargeConstants(constantsConfig);
+ Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>();
+ Map<FunctionReference, ExpressionFunction> referencedFunctions = new LinkedHashMap<>();
+ SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo();
+ ExpressionFunction firstPhase = null;
+ ExpressionFunction secondPhase = null;
for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
+ Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name());
if ( reference.isPresent()) {
List<String> arguments = new ArrayList<>(); // TODO: Arguments?
RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value());
+ ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), arguments, expression);
if (reference.get().isFree()) // make available in model under configured name
- functions.add(new ExpressionFunction(reference.get().functionName(), arguments, expression)); //
-
- // Make all functions, bound or not available under the name they are referenced by in expressions
- referencedFunctions.put(reference.get(),
- new ExpressionFunction(reference.get().serialForm(), arguments, expression));
+ functions.put(reference.get(), function);
+ // Make all functions, bound or not, available under the name they are referenced by in expressions
+ referencedFunctions.put(reference.get(), function);
+ }
+ else if (argumentType.isPresent()) { // Arguments always follows the function in properties
+ FunctionReference argReference = argumentType.get().getFirst();
+ ExpressionFunction function = referencedFunctions.get(argReference);
+ function = function.withArgument(argumentType.get().getSecond(), TensorType.fromSpec(property.value()));
+ if (argReference.isFree())
+ functions.put(argReference, function);
+ referencedFunctions.put(argReference, function);
}
else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to functions
firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(),
@@ -93,22 +104,23 @@ public class RankProfilesConfigImporter {
smallConstantsInfo.addIfSmallConstantInfo(property.name(), property.value());
}
}
- if (functionByName("firstphase", functions) == null && firstPhase != null) // may be already included, depending on body
- functions.add(firstPhase);
- if (functionByName("secondphase", functions) == null && secondPhase != null) // may be already included, depending on body
- functions.add(secondPhase);
+ if (functionByName("firstphase", functions.values()) == null && firstPhase != null) // may be already included, depending on body
+ functions.put(FunctionReference.fromName("firstphase"), firstPhase);
+ if (functionByName("secondphase", functions.values()) == null && secondPhase != null) // may be already included, depending on body
+ functions.put(FunctionReference.fromName("secondphase"), secondPhase);
constants.addAll(smallConstantsInfo.asConstants());
try {
- return new Model(profile.name(), functions, referencedFunctions, constants);
+ return new Model(profile.name(), functions.values(), referencedFunctions, constants);
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
}
}
- private ExpressionFunction functionByName(String name, List<ExpressionFunction> functions) {
+ // TODO: Replace by lookup in map
+ private ExpressionFunction functionByName(String name, Collection<ExpressionFunction> functions) {
for (ExpressionFunction function : functions)
if (function.getName().equals(name))
return function;