diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-08 10:31:33 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-08 10:35:17 +0000 |
commit | 5831e7054430d261b48446934aca771b962b269e (patch) | |
tree | 2998223b2f667220a9bbca45a9bd1cdc140cf34d | |
parent | e35952510a7841de872cde216f81d813402fff0f (diff) |
refactor binding extraction, with caching
3 files changed, 199 insertions, 118 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java new file mode 100644 index 00000000000..6b1f60df6f4 --- /dev/null +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java @@ -0,0 +1,183 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; + +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER; + +/** + * extract information about needed bindings, arguments, and onnx models from expression functions + */ +class BindingExtractor { + + private final Map<FunctionReference, ExpressionFunction> referencedFunctions; + private final List<OnnxModel> onnxModels; + + public BindingExtractor(Map<FunctionReference, ExpressionFunction> referencedFunctions, List<OnnxModel> onnxModels) { + this.referencedFunctions = referencedFunctions; + this.onnxModels = onnxModels; + } + + static class FunctionInfo { + /** The names which may be bound externally */ + final Set<String> bindTargets = new LinkedHashSet<>(); + + /** The names which needs to be bound externally, subset of the above */ + final Set<String> arguments = new LinkedHashSet<>(); + + /** ONNX models in use */ + final Map<String, OnnxModel> onnxModelsInUse = new LinkedHashMap<>(); + + void merge(FunctionInfo other) { + bindTargets.addAll(other.bindTargets); + arguments.addAll(other.arguments); + onnxModelsInUse.putAll(other.onnxModelsInUse); + } + } + + private final Map<FunctionReference, FunctionInfo> functionsInfo = new LinkedHashMap<>(); + + FunctionInfo extractFrom(FunctionReference ref) { + if (functionsInfo.containsKey(ref)) + return functionsInfo.get(ref); + ExpressionFunction function = referencedFunctions.get(ref); + FunctionInfo result = extractFrom(function); + functionsInfo.put(ref, result); + return result; + } + + FunctionInfo extractFrom(ExpressionFunction function) { + if (function == null) + return null; + ExpressionNode functionNode = function.getBody().getRoot(); + return extractBindTargets(functionNode); + } + + private FunctionInfo extractBindTargets(ExpressionNode node) { + var result = new FunctionInfo(); + if (isFunctionReference(node)) { + var opt = FunctionReference.fromSerial(node.toString()); + if (opt.isEmpty()) { + throw new IllegalArgumentException("Could not extract function " + node + " from serialized form '" + node.toString() +"'"); + } + FunctionReference reference = opt.get(); + result.bindTargets.add(reference.serialForm()); + FunctionInfo subInfo = extractFrom(reference); + if (subInfo == null) { + // not available, must be supplied as input + result.arguments.add(reference.serialForm()); + } else { + result.merge(subInfo); + } + return result; + } + else if (isOnnx(node)) { + return extractOnnxTargets(node); + } + else if (isConstant(node)) { + result.bindTargets.add(node.toString()); + return result; + } + else if (node instanceof ReferenceNode) { + result.bindTargets.add(node.toString()); + result.arguments.add(node.toString()); + return result; + } + else if (node instanceof CompositeNode cNode) { + for (ExpressionNode child : cNode.children()) { + result.merge(extractBindTargets(child)); + } + return result; + } + if (node instanceof ConstantNode) { + return result; + } + // TODO check if more node types need consideration here + return result; + } + + /** + * Extract the feature used to evaluate the onnx model. e.g. onnx(name) and add + * that as a bind target and argument. During evaluation, this will be evaluated before + * the rest of the expression and the result is added to the context. Also extract the + * inputs to the model and add them as bind targets and arguments. + */ + private FunctionInfo extractOnnxTargets(ExpressionNode node) { + var result = new FunctionInfo(); + String onnxFeature = node.toString(); + result.bindTargets.add(onnxFeature); + Optional<String> modelName = getArgument(node); + if (modelName.isPresent()) { + for (OnnxModel onnxModel : onnxModels) { + if (onnxModel.name().equals(modelName.get())) { + // Load the model (if not already loaded) to extract inputs + onnxModel.load(); + for(String input : onnxModel.inputs().keySet()) { + result.bindTargets.add(input); + result.arguments.add(input); + } + result.onnxModelsInUse.put(onnxFeature, onnxModel); + return result; + } + } + } + // not found, must be supplied as argument + result.arguments.add(onnxFeature); + return result; + } + + private Optional<String> getArgument(ExpressionNode node) { + if (node instanceof ReferenceNode reference) { + if (reference.getArguments().size() > 0) { + var arg = reference.getArguments().expressions().get(0); + if (arg instanceof ConstantNode) { + return Optional.of(stripQuotes(arg.toString())); + } + if (arg instanceof ReferenceNode refNode) { + return Optional.of(refNode.getName()); + } + } + } + return Optional.empty(); + } + + public static String stripQuotes(String s) { + if (s.length() < 3) { + return s; + } + int lastIdx = s.length() - 1; + char first = s.charAt(0); + char last = s.charAt(lastIdx); + if (first == '"' && last == '"') return s.substring(1, lastIdx); + if (first == '\'' && last == '\'') return s.substring(1, lastIdx); + return s; + } + + private boolean isFunctionReference(ExpressionNode node) { + if ( ! (node instanceof ReferenceNode reference)) return false; + return reference.getName().equals(RANKING_EXPRESSION_WRAPPER) && reference.getArguments().size() == 1; + } + + private boolean isOnnx(ExpressionNode node) { + if ( ! (node instanceof ReferenceNode reference)) return false; + return reference.getName().equals("onnx") || reference.getName().equals("onnxModel"); + } + + private boolean isConstant(ExpressionNode node) { + if ( ! (node instanceof ReferenceNode reference)) return false; + return reference.getName().equals("constant") && reference.getArguments().size() == 1; + } + +} diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 47c246c008e..898f5a3a73e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -44,12 +44,12 @@ public final class LazyArrayContext extends Context implements ContextIndex { /** Create a fast lookup, lazy context for a function */ LazyArrayContext(ExpressionFunction function, + BindingExtractor bindingExtractor, Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, - List<OnnxModel> onnxModels, Model model) { this.function = function; - this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, onnxModels, this, model); + this.indexedBindings = new IndexedBindings(function, bindingExtractor, referencedFunctions, constants, this, model); } /** @@ -186,19 +186,19 @@ public final class LazyArrayContext extends Context implements ContextIndex { * The given expression and functions may be inspected but cannot be stored. */ IndexedBindings(ExpressionFunction function, + BindingExtractor bindingExtractor, Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, - List<OnnxModel> onnxModels, LazyArrayContext owner, - Model model) { + Model model) + { // 1. Determine and prepare bind targets - Set<String> bindTargets = new LinkedHashSet<>(); - Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation - Map<String, OnnxModel> onnxModelsInUse = new HashMap<>(); - extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse); + var functionInfo = bindingExtractor.extractFrom(function); + Set<String> bindTargets = functionInfo.bindTargets; + + this.onnxModels = Map.copyOf(functionInfo.onnxModelsInUse); + this.arguments = Set.copyOf(functionInfo.arguments); // Arguments: Bind targets which need to be bound before invocation - this.onnxModels = Map.copyOf(onnxModelsInUse); - this.arguments = Set.copyOf(arguments); values = new Value[bindTargets.size()]; Arrays.fill(values, missing); @@ -215,10 +215,10 @@ public final class LazyArrayContext extends Context implements ContextIndex { } } - for (Map.Entry<FunctionReference, ExpressionFunction> referencedFunction : referencedFunctions.entrySet()) { - Integer index = nameToIndex.get(referencedFunction.getKey().serialForm()); + for (FunctionReference referencedFunction : referencedFunctions.keySet()) { + Integer index = nameToIndex.get(referencedFunction.serialForm()); if (index != null) { // Referenced in this, so bind it - values[index] = new LazyValue(referencedFunction.getKey(), owner, model); + values[index] = new LazyValue(referencedFunction, owner, model); } } } @@ -227,110 +227,6 @@ public final class LazyArrayContext extends Context implements ContextIndex { missingValue = new TensorValue(value).freeze(); } - private void extractBindTargets(ExpressionNode node, - Map<FunctionReference, ExpressionFunction> functions, - Set<String> bindTargets, - Set<String> arguments, - List<OnnxModel> onnxModels, - Map<String, OnnxModel> onnxModelsInUse) { - if (isFunctionReference(node)) { - var opt = FunctionReference.fromSerial(node.toString()); - if (opt.isEmpty()) { - throw new IllegalArgumentException("Could not extract function " + node + " from serialized form '" + node.toString() +"'"); - } - FunctionReference reference = opt.get(); - bindTargets.add(reference.serialForm()); - - ExpressionFunction function = functions.get(reference); - if (function == null) return; // Function not included in this model: Not all models are for standalone use - ExpressionNode functionNode = function.getBody().getRoot(); - extractBindTargets(functionNode, functions, bindTargets, arguments, onnxModels, onnxModelsInUse); - } - else if (isOnnx(node)) { - extractOnnxTargets(node, bindTargets, arguments, onnxModels, onnxModelsInUse); - } - else if (isConstant(node)) { - bindTargets.add(node.toString()); - } - else if (node instanceof ReferenceNode) { - bindTargets.add(node.toString()); - arguments.add(node.toString()); - } - else if (node instanceof CompositeNode cNode) { - for (ExpressionNode child : cNode.children()) - extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse); - } - } - - /** - * Extract the feature used to evaluate the onnx model. e.g. onnx(name) and add - * that as a bind target and argument. During evaluation, this will be evaluated before - * the rest of the expression and the result is added to the context. Also extract the - * inputs to the model and add them as bind targets and arguments. - */ - private void extractOnnxTargets(ExpressionNode node, - Set<String> bindTargets, - Set<String> arguments, - List<OnnxModel> onnxModels, - Map<String, OnnxModel> onnxModelsInUse) { - Optional<String> modelName = getArgument(node); - if (modelName.isPresent()) { - for (OnnxModel onnxModel : onnxModels) { - if (onnxModel.name().equals(modelName.get())) { - String onnxFeature = node.toString(); - bindTargets.add(onnxFeature); - - // Load the model (if not already loaded) to extract inputs - onnxModel.load(); - - for(String input : onnxModel.inputs().keySet()) { - bindTargets.add(input); - arguments.add(input); - } - onnxModelsInUse.put(onnxFeature, onnxModel); - } - } - } - } - - private Optional<String> getArgument(ExpressionNode node) { - if (node instanceof ReferenceNode reference) { - if (reference.getArguments().size() > 0) { - if (reference.getArguments().expressions().get(0) instanceof ConstantNode) { - ExpressionNode constantNode = reference.getArguments().expressions().get(0); - return Optional.of(stripQuotes(constantNode.toString())); - } - if (reference.getArguments().expressions().get(0) instanceof ReferenceNode refNode) { - return Optional.of(refNode.getName()); - } - } - } - return Optional.empty(); - } - - public static String stripQuotes(String s) { - if (s.codePointAt(0) == '"' && s.codePointAt(s.length()-1) == '"') - return s.substring(1, s.length()-1); - if (s.codePointAt(0) == '\'' && s.codePointAt(s.length()-1) == '\'') - return s.substring(1, s.length()-1); - return s; - } - - private boolean isFunctionReference(ExpressionNode node) { - if ( ! (node instanceof ReferenceNode reference)) return false; - return reference.getName().equals(RANKING_EXPRESSION_WRAPPER) && reference.getArguments().size() == 1; - } - - private boolean isOnnx(ExpressionNode node) { - if ( ! (node instanceof ReferenceNode reference)) return false; - return reference.getName().equals("onnx") || reference.getName().equals("onnxModel"); - } - - private boolean isConstant(ExpressionNode node) { - if ( ! (node instanceof ReferenceNode reference)) return false; - return reference.getName().equals("constant") && reference.getArguments().size() == 1; - } - Value get(int index) { Value value = values[index]; return value == missing ? missingValue : value; diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 1da8121ba8e..f173a6b453f 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -62,11 +62,13 @@ public class Model implements AutoCloseable { List<OnnxModel> onnxModels) { this.name = name; + var bindingExtractor = new BindingExtractor(referencedFunctions, onnxModels); + // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments) Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { - LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this); + LazyArrayContext context = new LazyArrayContext(function.getValue(), bindingExtractor, referencedFunctions, constants, this); contextBuilder.put(function.getValue().getName(), context); if (function.getValue().returnType().isEmpty()) { functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); |