From 3bb0f95cbcad3d3a28168a3fe49de9118fb71ef2 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Thu, 26 Aug 2021 14:37:07 +0200 Subject: Handle external expressions in model evaluation too. --- .../vespa/models/evaluation/FunctionReference.java | 13 +++++++- .../evaluation/RankProfilesConfigImporter.java | 38 +++++++++++++++++++++- .../ai/vespa/models/evaluation/ModelTester.java | 11 +++++++ 3 files changed, 60 insertions(+), 2 deletions(-) (limited to 'model-evaluation/src') diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java index fa45920f3c8..f7c47e83df9 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java @@ -2,7 +2,6 @@ package ai.vespa.models.evaluation; import com.yahoo.collections.Pair; -import com.yahoo.tensor.TensorType; import java.util.Objects; import java.util.Optional; @@ -26,6 +25,8 @@ class FunctionReference { private static final Pattern referencePattern = Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?"); + private static final Pattern externalReferencePattern = + Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.expressionName)?"); private static final Pattern argumentTypePattern = Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.([a-zA-Z0-9_]+)\\.type?"); private static final Pattern returnTypePattern = @@ -80,6 +81,16 @@ class FunctionReference { return Optional.of(new FunctionReference(name, instance)); } + /** Returns a function reference from the given serial form, or empty if the string is not a valid reference */ + static Optional fromExternalSerial(String serialForm) { + Matcher expressionMatcher = externalReferencePattern.matcher(serialForm); + if ( ! expressionMatcher.matches()) return Optional.empty(); + + String name = expressionMatcher.group(1); + String instance = expressionMatcher.group(2); + return Optional.of(new FunctionReference(name, instance)); + } + /** * Returns a function reference and argument name string from the given serial form, * or empty if the string is not a valid function argument serial form diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 06ca7a60f4c..eccc236f0ca 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -12,6 +12,7 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.text.Utf8; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; @@ -82,6 +83,7 @@ public class RankProfilesConfigImporter { List onnxModels = readOnnxModelsConfig(onnxModelsConfig); List constants = readLargeConstants(constantsConfig); + Map largeExpressions = readLargeExpressions(expressionsConfig); Map functions = new LinkedHashMap<>(); Map referencedFunctions = new LinkedHashMap<>(); @@ -90,9 +92,21 @@ public class RankProfilesConfigImporter { ExpressionFunction secondPhase = null; for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) { Optional reference = FunctionReference.fromSerial(property.name()); + Optional externalReference = FunctionReference.fromExternalSerial(property.name()); Optional> argumentType = FunctionReference.fromTypeArgumentSerial(property.name()); Optional returnType = FunctionReference.fromReturnTypeSerial(property.name()); - if (reference.isPresent()) { + if (externalReference.isPresent()) { + RankingExpression expression = largeExpressions.get(property.value()); + ExpressionFunction function = new ExpressionFunction(externalReference.get().functionName(), + Collections.emptyList(), + expression); + + if (externalReference.get().isFree()) // make available in model under configured name + functions.put(externalReference.get(), function); + // Make all functions, bound or not, available under the name they are referenced by in expressions + referencedFunctions.put(externalReference.get(), function); + } + else if (reference.isPresent()) { RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value()); ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), Collections.emptyList(), @@ -184,6 +198,28 @@ public class RankProfilesConfigImporter { return constants; } + private Map readLargeExpressions(RankingExpressionsConfig expressionsConfig) throws ParseException { + Map expressions = new HashMap<>(); + + for (RankingExpressionsConfig.Expression expression : expressionsConfig.expression()) { + expressions.put(expression.name(), readExpressionFromFile(expression.name(), expression.fileref())); + } + return expressions; + } + + protected RankingExpression readExpressionFromFile(String name, FileReference fileReference) throws ParseException { + try { + File file = fileAcquirer.waitFor(fileReference, 7, TimeUnit.DAYS); + return new RankingExpression(name, Utf8.toString(IOUtils.readFileBytes(file))); + } + catch (InterruptedException e) { + throw new IllegalStateException("Gave up waiting for expression " + name); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) { try { File file = fileAcquirer.waitFor(fileReference, 7, TimeUnit.DAYS); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java index b6878f4ea1a..5fe5abef645 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java @@ -10,9 +10,12 @@ import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.text.Utf8; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; @@ -95,6 +98,14 @@ public class ModelTester { } } + @Override + protected RankingExpression readExpressionFromFile(String name, FileReference fileReference) throws ParseException { + try { + return new RankingExpression(name, Utf8.toString(IOUtils.readFileBytes(constantsPath.append(name).toFile()))); + } catch (IOException e) { + throw new IllegalArgumentException("Missing expression file '" + name + "'", e); + } + } } } -- cgit v1.2.3