aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2021-08-26 14:37:07 +0200
committerHenning Baldersheim <balder@yahoo-inc.com>2021-08-26 16:27:04 +0200
commit3bb0f95cbcad3d3a28168a3fe49de9118fb71ef2 (patch)
treed9578f31b3d92f3d9e8173f6153c08a10909b794 /model-evaluation
parent2c571f88d53efab97b70c67ae4f659bd5e4a1a26 (diff)
Handle external expressions in model evaluation too.
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java13
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java38
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java11
3 files changed, 60 insertions, 2 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
index fa45920f3c8..f7c47e83df9 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
@@ -2,7 +2,6 @@
package ai.vespa.models.evaluation;
import com.yahoo.collections.Pair;
-import com.yahoo.tensor.TensorType;
import java.util.Objects;
import java.util.Optional;
@@ -26,6 +25,8 @@ class FunctionReference {
private static final Pattern referencePattern =
Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?");
+ private static final Pattern externalReferencePattern =
+ Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.expressionName)?");
private static final Pattern argumentTypePattern =
Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.([a-zA-Z0-9_]+)\\.type?");
private static final Pattern returnTypePattern =
@@ -80,6 +81,16 @@ class FunctionReference {
return Optional.of(new FunctionReference(name, instance));
}
+ /** Returns a function reference from the given serial form, or empty if the string is not a valid reference */
+ static Optional<FunctionReference> fromExternalSerial(String serialForm) {
+ Matcher expressionMatcher = externalReferencePattern.matcher(serialForm);
+ if ( ! expressionMatcher.matches()) return Optional.empty();
+
+ String name = expressionMatcher.group(1);
+ String instance = expressionMatcher.group(2);
+ return Optional.of(new FunctionReference(name, instance));
+ }
+
/**
* Returns a function reference and argument name string from the given serial form,
* or empty if the string is not a valid function argument serial form
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 06ca7a60f4c..eccc236f0ca 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -12,6 +12,7 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import com.yahoo.text.Utf8;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
@@ -82,6 +83,7 @@ public class RankProfilesConfigImporter {
List<OnnxModel> onnxModels = readOnnxModelsConfig(onnxModelsConfig);
List<Constant> constants = readLargeConstants(constantsConfig);
+ Map<String, RankingExpression> largeExpressions = readLargeExpressions(expressionsConfig);
Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>();
Map<FunctionReference, ExpressionFunction> referencedFunctions = new LinkedHashMap<>();
@@ -90,9 +92,21 @@ public class RankProfilesConfigImporter {
ExpressionFunction secondPhase = null;
for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
+ Optional<FunctionReference> externalReference = FunctionReference.fromExternalSerial(property.name());
Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name());
Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name());
- if (reference.isPresent()) {
+ if (externalReference.isPresent()) {
+ RankingExpression expression = largeExpressions.get(property.value());
+ ExpressionFunction function = new ExpressionFunction(externalReference.get().functionName(),
+ Collections.emptyList(),
+ expression);
+
+ if (externalReference.get().isFree()) // make available in model under configured name
+ functions.put(externalReference.get(), function);
+ // Make all functions, bound or not, available under the name they are referenced by in expressions
+ referencedFunctions.put(externalReference.get(), function);
+ }
+ else if (reference.isPresent()) {
RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value());
ExpressionFunction function = new ExpressionFunction(reference.get().functionName(),
Collections.emptyList(),
@@ -184,6 +198,28 @@ public class RankProfilesConfigImporter {
return constants;
}
+ private Map<String, RankingExpression> readLargeExpressions(RankingExpressionsConfig expressionsConfig) throws ParseException {
+ Map<String, RankingExpression> expressions = new HashMap<>();
+
+ for (RankingExpressionsConfig.Expression expression : expressionsConfig.expression()) {
+ expressions.put(expression.name(), readExpressionFromFile(expression.name(), expression.fileref()));
+ }
+ return expressions;
+ }
+
+ protected RankingExpression readExpressionFromFile(String name, FileReference fileReference) throws ParseException {
+ try {
+ File file = fileAcquirer.waitFor(fileReference, 7, TimeUnit.DAYS);
+ return new RankingExpression(name, Utf8.toString(IOUtils.readFileBytes(file)));
+ }
+ catch (InterruptedException e) {
+ throw new IllegalStateException("Gave up waiting for expression " + name);
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) {
try {
File file = fileAcquirer.waitFor(fileReference, 7, TimeUnit.DAYS);
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
index b6878f4ea1a..5fe5abef645 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -10,9 +10,12 @@ import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import com.yahoo.text.Utf8;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
@@ -95,6 +98,14 @@ public class ModelTester {
}
}
+ @Override
+ protected RankingExpression readExpressionFromFile(String name, FileReference fileReference) throws ParseException {
+ try {
+ return new RankingExpression(name, Utf8.toString(IOUtils.readFileBytes(constantsPath.append(name).toFile())));
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Missing expression file '" + name + "'", e);
+ }
+ }
}
}