diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2021-08-26 14:37:07 +0200 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2021-08-26 16:27:04 +0200 |
commit | 3bb0f95cbcad3d3a28168a3fe49de9118fb71ef2 (patch) | |
tree | d9578f31b3d92f3d9e8173f6153c08a10909b794 /model-evaluation | |
parent | 2c571f88d53efab97b70c67ae4f659bd5e4a1a26 (diff) |
Handle external expressions in model evaluation too.
Diffstat (limited to 'model-evaluation')
3 files changed, 60 insertions, 2 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java index fa45920f3c8..f7c47e83df9 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java @@ -2,7 +2,6 @@ package ai.vespa.models.evaluation; import com.yahoo.collections.Pair; -import com.yahoo.tensor.TensorType; import java.util.Objects; import java.util.Optional; @@ -26,6 +25,8 @@ class FunctionReference { private static final Pattern referencePattern = Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?"); + private static final Pattern externalReferencePattern = + Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.expressionName)?"); private static final Pattern argumentTypePattern = Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.([a-zA-Z0-9_]+)\\.type?"); private static final Pattern returnTypePattern = @@ -80,6 +81,16 @@ class FunctionReference { return Optional.of(new FunctionReference(name, instance)); } + /** Returns a function reference from the given serial form, or empty if the string is not a valid reference */ + static Optional<FunctionReference> fromExternalSerial(String serialForm) { + Matcher expressionMatcher = externalReferencePattern.matcher(serialForm); + if ( ! expressionMatcher.matches()) return Optional.empty(); + + String name = expressionMatcher.group(1); + String instance = expressionMatcher.group(2); + return Optional.of(new FunctionReference(name, instance)); + } + /** * Returns a function reference and argument name string from the given serial form, * or empty if the string is not a valid function argument serial form diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 06ca7a60f4c..eccc236f0ca 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -12,6 +12,7 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.text.Utf8; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; @@ -82,6 +83,7 @@ public class RankProfilesConfigImporter { List<OnnxModel> onnxModels = readOnnxModelsConfig(onnxModelsConfig); List<Constant> constants = readLargeConstants(constantsConfig); + Map<String, RankingExpression> largeExpressions = readLargeExpressions(expressionsConfig); Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>(); Map<FunctionReference, ExpressionFunction> referencedFunctions = new LinkedHashMap<>(); @@ -90,9 +92,21 @@ public class RankProfilesConfigImporter { ExpressionFunction secondPhase = null; for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) { Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name()); + Optional<FunctionReference> externalReference = FunctionReference.fromExternalSerial(property.name()); Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name()); Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name()); - if (reference.isPresent()) { + if (externalReference.isPresent()) { + RankingExpression expression = largeExpressions.get(property.value()); + ExpressionFunction function = new ExpressionFunction(externalReference.get().functionName(), + Collections.emptyList(), + expression); + + if (externalReference.get().isFree()) // make available in model under configured name + functions.put(externalReference.get(), function); + // Make all functions, bound or not, available under the name they are referenced by in expressions + referencedFunctions.put(externalReference.get(), function); + } + else if (reference.isPresent()) { RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value()); ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), Collections.emptyList(), @@ -184,6 +198,28 @@ public class RankProfilesConfigImporter { return constants; } + private Map<String, RankingExpression> readLargeExpressions(RankingExpressionsConfig expressionsConfig) throws ParseException { + Map<String, RankingExpression> expressions = new HashMap<>(); + + for (RankingExpressionsConfig.Expression expression : expressionsConfig.expression()) { + expressions.put(expression.name(), readExpressionFromFile(expression.name(), expression.fileref())); + } + return expressions; + } + + protected RankingExpression readExpressionFromFile(String name, FileReference fileReference) throws ParseException { + try { + File file = fileAcquirer.waitFor(fileReference, 7, TimeUnit.DAYS); + return new RankingExpression(name, Utf8.toString(IOUtils.readFileBytes(file))); + } + catch (InterruptedException e) { + throw new IllegalStateException("Gave up waiting for expression " + name); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) { try { File file = fileAcquirer.waitFor(fileReference, 7, TimeUnit.DAYS); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java index b6878f4ea1a..5fe5abef645 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java @@ -10,9 +10,12 @@ import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.text.Utf8; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; @@ -95,6 +98,14 @@ public class ModelTester { } } + @Override + protected RankingExpression readExpressionFromFile(String name, FileReference fileReference) throws ParseException { + try { + return new RankingExpression(name, Utf8.toString(IOUtils.readFileBytes(constantsPath.append(name).toFile()))); + } catch (IOException e) { + throw new IllegalArgumentException("Missing expression file '" + name + "'", e); + } + } } } |