summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
author Arne Juul <arnej@yahooinc.com> 2023-03-08 10:31:33 +0000
committer Arne Juul <arnej@yahooinc.com> 2023-03-08 10:35:17 +0000
commit 5831e7054430d261b48446934aca771b962b269e (patch)
tree 2998223b2f667220a9bbca45a9bd1cdc140cf34d /model-evaluation
parent e35952510a7841de872cde216f81d813402fff0f (diff)
refactor binding extraction, with caching
Diffstat (limited to 'model-evaluation')
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java 183
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java 130
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java 4
3 files changed, 199 insertions, 118 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
new file mode 100644
index 00000000000..6b1f60df6f4
--- /dev/null
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
@@ -0,0 +1,183 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER;
+
+/**
+ * Extracts information about needed bindings, arguments, and ONNX models from expression functions, caching the result per function reference.
+ */
+class BindingExtractor {
+
+ private final Map<FunctionReference, ExpressionFunction> referencedFunctions;
+ private final List<OnnxModel> onnxModels;
+
+ public BindingExtractor(Map<FunctionReference, ExpressionFunction> referencedFunctions, List<OnnxModel> onnxModels) {
+ this.referencedFunctions = referencedFunctions;
+ this.onnxModels = onnxModels;
+ }
+
+ static class FunctionInfo {
+ /** The names which may be bound externally */
+ final Set<String> bindTargets = new LinkedHashSet<>();
+
+ /** The names which need to be bound externally; a subset of the bind targets */
+ final Set<String> arguments = new LinkedHashSet<>();
+
+ /** ONNX models in use */
+ final Map<String, OnnxModel> onnxModelsInUse = new LinkedHashMap<>();
+
+ void merge(FunctionInfo other) {
+ bindTargets.addAll(other.bindTargets);
+ arguments.addAll(other.arguments);
+ onnxModelsInUse.putAll(other.onnxModelsInUse);
+ }
+ }
+
+ private final Map<FunctionReference, FunctionInfo> functionsInfo = new LinkedHashMap<>();
+
+ FunctionInfo extractFrom(FunctionReference ref) {
+ if (functionsInfo.containsKey(ref))
+ return functionsInfo.get(ref);
+ ExpressionFunction function = referencedFunctions.get(ref);
+ FunctionInfo result = extractFrom(function);
+ functionsInfo.put(ref, result);
+ return result;
+ }
+
+ FunctionInfo extractFrom(ExpressionFunction function) {
+ if (function == null)
+ return null;
+ ExpressionNode functionNode = function.getBody().getRoot();
+ return extractBindTargets(functionNode);
+ }
+
+ private FunctionInfo extractBindTargets(ExpressionNode node) {
+ var result = new FunctionInfo();
+ if (isFunctionReference(node)) {
+ var opt = FunctionReference.fromSerial(node.toString());
+ if (opt.isEmpty()) {
+ throw new IllegalArgumentException("Could not extract function " + node + " from serialized form '" + node.toString() +"'");
+ }
+ FunctionReference reference = opt.get();
+ result.bindTargets.add(reference.serialForm());
+ FunctionInfo subInfo = extractFrom(reference);
+ if (subInfo == null) {
+ // not available, must be supplied as input
+ result.arguments.add(reference.serialForm());
+ } else {
+ result.merge(subInfo);
+ }
+ return result;
+ }
+ else if (isOnnx(node)) {
+ return extractOnnxTargets(node);
+ }
+ else if (isConstant(node)) {
+ result.bindTargets.add(node.toString());
+ return result;
+ }
+ else if (node instanceof ReferenceNode) {
+ result.bindTargets.add(node.toString());
+ result.arguments.add(node.toString());
+ return result;
+ }
+ else if (node instanceof CompositeNode cNode) {
+ for (ExpressionNode child : cNode.children()) {
+ result.merge(extractBindTargets(child));
+ }
+ return result;
+ }
+ if (node instanceof ConstantNode) {
+ return result;
+ }
+ // TODO check if more node types need consideration here
+ return result;
+ }
+
+ /**
+ * Extract the feature used to evaluate the ONNX model, e.g. onnx(name), and add it
+ * as a bind target (and also as an argument when no matching model is found). During evaluation,
+ * this will be evaluated before the rest of the expression and the result is added to the
+ * context. Also extract the inputs to a matching model and add them as bind targets and arguments.
+ */
+ private FunctionInfo extractOnnxTargets(ExpressionNode node) {
+ var result = new FunctionInfo();
+ String onnxFeature = node.toString();
+ result.bindTargets.add(onnxFeature);
+ Optional<String> modelName = getArgument(node);
+ if (modelName.isPresent()) {
+ for (OnnxModel onnxModel : onnxModels) {
+ if (onnxModel.name().equals(modelName.get())) {
+ // Load the model (if not already loaded) to extract inputs
+ onnxModel.load();
+ for(String input : onnxModel.inputs().keySet()) {
+ result.bindTargets.add(input);
+ result.arguments.add(input);
+ }
+ result.onnxModelsInUse.put(onnxFeature, onnxModel);
+ return result;
+ }
+ }
+ }
+ // not found, must be supplied as argument
+ result.arguments.add(onnxFeature);
+ return result;
+ }
+
+ private Optional<String> getArgument(ExpressionNode node) {
+ if (node instanceof ReferenceNode reference) {
+ if (reference.getArguments().size() > 0) {
+ var arg = reference.getArguments().expressions().get(0);
+ if (arg instanceof ConstantNode) {
+ return Optional.of(stripQuotes(arg.toString()));
+ }
+ if (arg instanceof ReferenceNode refNode) {
+ return Optional.of(refNode.getName());
+ }
+ }
+ }
+ return Optional.empty();
+ }
+
+ public static String stripQuotes(String s) {
+ if (s.length() < 3) {
+ return s;
+ }
+ int lastIdx = s.length() - 1;
+ char first = s.charAt(0);
+ char last = s.charAt(lastIdx);
+ if (first == '"' && last == '"') return s.substring(1, lastIdx);
+ if (first == '\'' && last == '\'') return s.substring(1, lastIdx);
+ return s;
+ }
+
+ private boolean isFunctionReference(ExpressionNode node) {
+ if ( ! (node instanceof ReferenceNode reference)) return false;
+ return reference.getName().equals(RANKING_EXPRESSION_WRAPPER) && reference.getArguments().size() == 1;
+ }
+
+ private boolean isOnnx(ExpressionNode node) {
+ if ( ! (node instanceof ReferenceNode reference)) return false;
+ return reference.getName().equals("onnx") || reference.getName().equals("onnxModel");
+ }
+
+ private boolean isConstant(ExpressionNode node) {
+ if ( ! (node instanceof ReferenceNode reference)) return false;
+ return reference.getName().equals("constant") && reference.getArguments().size() == 1;
+ }
+
+}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index 47c246c008e..898f5a3a73e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -44,12 +44,12 @@ public final class LazyArrayContext extends Context implements ContextIndex {
/** Create a fast lookup, lazy context for a function */
LazyArrayContext(ExpressionFunction function,
+ BindingExtractor bindingExtractor,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants,
- List<OnnxModel> onnxModels,
Model model) {
this.function = function;
- this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, onnxModels, this, model);
+ this.indexedBindings = new IndexedBindings(function, bindingExtractor, referencedFunctions, constants, this, model);
}
/**
@@ -186,19 +186,19 @@ public final class LazyArrayContext extends Context implements ContextIndex {
* The given expression and functions may be inspected but cannot be stored.
*/
IndexedBindings(ExpressionFunction function,
+ BindingExtractor bindingExtractor,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants,
- List<OnnxModel> onnxModels,
LazyArrayContext owner,
- Model model) {
+ Model model)
+ {
// 1. Determine and prepare bind targets
- Set<String> bindTargets = new LinkedHashSet<>();
- Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation
- Map<String, OnnxModel> onnxModelsInUse = new HashMap<>();
- extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse);
+ var functionInfo = bindingExtractor.extractFrom(function);
+ Set<String> bindTargets = functionInfo.bindTargets;
+
+ this.onnxModels = Map.copyOf(functionInfo.onnxModelsInUse);
+ this.arguments = Set.copyOf(functionInfo.arguments); // Arguments: Bind targets which need to be bound before invocation
- this.onnxModels = Map.copyOf(onnxModelsInUse);
- this.arguments = Set.copyOf(arguments);
values = new Value[bindTargets.size()];
Arrays.fill(values, missing);
@@ -215,10 +215,10 @@ public final class LazyArrayContext extends Context implements ContextIndex {
}
}
- for (Map.Entry<FunctionReference, ExpressionFunction> referencedFunction : referencedFunctions.entrySet()) {
- Integer index = nameToIndex.get(referencedFunction.getKey().serialForm());
+ for (FunctionReference referencedFunction : referencedFunctions.keySet()) {
+ Integer index = nameToIndex.get(referencedFunction.serialForm());
if (index != null) { // Referenced in this, so bind it
- values[index] = new LazyValue(referencedFunction.getKey(), owner, model);
+ values[index] = new LazyValue(referencedFunction, owner, model);
}
}
}
@@ -227,110 +227,6 @@ public final class LazyArrayContext extends Context implements ContextIndex {
missingValue = new TensorValue(value).freeze();
}
- private void extractBindTargets(ExpressionNode node,
- Map<FunctionReference, ExpressionFunction> functions,
- Set<String> bindTargets,
- Set<String> arguments,
- List<OnnxModel> onnxModels,
- Map<String, OnnxModel> onnxModelsInUse) {
- if (isFunctionReference(node)) {
- var opt = FunctionReference.fromSerial(node.toString());
- if (opt.isEmpty()) {
- throw new IllegalArgumentException("Could not extract function " + node + " from serialized form '" + node.toString() +"'");
- }
- FunctionReference reference = opt.get();
- bindTargets.add(reference.serialForm());
-
- ExpressionFunction function = functions.get(reference);
- if (function == null) return; // Function not included in this model: Not all models are for standalone use
- ExpressionNode functionNode = function.getBody().getRoot();
- extractBindTargets(functionNode, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
- }
- else if (isOnnx(node)) {
- extractOnnxTargets(node, bindTargets, arguments, onnxModels, onnxModelsInUse);
- }
- else if (isConstant(node)) {
- bindTargets.add(node.toString());
- }
- else if (node instanceof ReferenceNode) {
- bindTargets.add(node.toString());
- arguments.add(node.toString());
- }
- else if (node instanceof CompositeNode cNode) {
- for (ExpressionNode child : cNode.children())
- extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
- }
- }
-
- /**
- * Extract the feature used to evaluate the onnx model. e.g. onnx(name) and add
- * that as a bind target and argument. During evaluation, this will be evaluated before
- * the rest of the expression and the result is added to the context. Also extract the
- * inputs to the model and add them as bind targets and arguments.
- */
- private void extractOnnxTargets(ExpressionNode node,
- Set<String> bindTargets,
- Set<String> arguments,
- List<OnnxModel> onnxModels,
- Map<String, OnnxModel> onnxModelsInUse) {
- Optional<String> modelName = getArgument(node);
- if (modelName.isPresent()) {
- for (OnnxModel onnxModel : onnxModels) {
- if (onnxModel.name().equals(modelName.get())) {
- String onnxFeature = node.toString();
- bindTargets.add(onnxFeature);
-
- // Load the model (if not already loaded) to extract inputs
- onnxModel.load();
-
- for(String input : onnxModel.inputs().keySet()) {
- bindTargets.add(input);
- arguments.add(input);
- }
- onnxModelsInUse.put(onnxFeature, onnxModel);
- }
- }
- }
- }
-
- private Optional<String> getArgument(ExpressionNode node) {
- if (node instanceof ReferenceNode reference) {
- if (reference.getArguments().size() > 0) {
- if (reference.getArguments().expressions().get(0) instanceof ConstantNode) {
- ExpressionNode constantNode = reference.getArguments().expressions().get(0);
- return Optional.of(stripQuotes(constantNode.toString()));
- }
- if (reference.getArguments().expressions().get(0) instanceof ReferenceNode refNode) {
- return Optional.of(refNode.getName());
- }
- }
- }
- return Optional.empty();
- }
-
- public static String stripQuotes(String s) {
- if (s.codePointAt(0) == '"' && s.codePointAt(s.length()-1) == '"')
- return s.substring(1, s.length()-1);
- if (s.codePointAt(0) == '\'' && s.codePointAt(s.length()-1) == '\'')
- return s.substring(1, s.length()-1);
- return s;
- }
-
- private boolean isFunctionReference(ExpressionNode node) {
- if ( ! (node instanceof ReferenceNode reference)) return false;
- return reference.getName().equals(RANKING_EXPRESSION_WRAPPER) && reference.getArguments().size() == 1;
- }
-
- private boolean isOnnx(ExpressionNode node) {
- if ( ! (node instanceof ReferenceNode reference)) return false;
- return reference.getName().equals("onnx") || reference.getName().equals("onnxModel");
- }
-
- private boolean isConstant(ExpressionNode node) {
- if ( ! (node instanceof ReferenceNode reference)) return false;
- return reference.getName().equals("constant") && reference.getArguments().size() == 1;
- }
-
Value get(int index) {
Value value = values[index];
return value == missing ? missingValue : value;
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 1da8121ba8e..f173a6b453f 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -62,11 +62,13 @@ public class Model implements AutoCloseable {
List<OnnxModel> onnxModels) {
this.name = name;
+ var bindingExtractor = new BindingExtractor(referencedFunctions, onnxModels);
+
// Build context and add missing function arguments (missing because it is legal to omit scalar type arguments)
Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
- LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this);
+ LazyArrayContext context = new LazyArrayContext(function.getValue(), bindingExtractor, referencedFunctions, constants, this);
contextBuilder.put(function.getValue().getName(), context);
if (function.getValue().returnType().isEmpty()) {
functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));