diff options
Diffstat (limited to 'model-evaluation/src/main')
6 files changed, 228 insertions, 25 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index e373a54bcd1..910aca8aa98 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; @@ -100,19 +101,40 @@ public class FunctionEvaluator { } public Tensor evaluate() { + evaluateOnnxModels(); for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { - if (context.isMissing(argument.getKey())) - throw new IllegalStateException("Missing argument '" + argument.getKey() + - "': Must be bound to a value of type " + argument.getValue()); - if (! context.get(argument.getKey()).type().isAssignableTo(argument.getValue())) - throw new IllegalStateException("Argument '" + argument.getKey() + - "' must be bound to a value of type " + argument.getValue()); - + checkArgument(argument.getKey(), argument.getValue()); } evaluated = true; return function.getBody().evaluate(context).asTensor(); } + private void checkArgument(String name, TensorType type) { + if (context.isMissing(name)) + throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type); + if (! context.get(name).type().isAssignableTo(type)) + throw new IllegalStateException("Argument '" + name + "' must be bound to a value of type " + type); + } + + /** + * Evaluate ONNX models (if not already evaluated) and add the result back to the context. + */ + private void evaluateOnnxModels() { + for (Map.Entry<String, OnnxModel> entry : context().onnxModels().entrySet()) { + String onnxFeature = entry.getKey(); + OnnxModel onnxModel = entry.getValue(); + if (context.get(onnxFeature).equals(context.defaultValue())) { + Map<String, Tensor> inputs = new HashMap<>(); + for (Map.Entry<String, TensorType> input: onnxModel.inputs().entrySet()) { + checkArgument(input.getKey(), input.getValue()); + inputs.put(input.getKey(), context.get(input.getKey()).asTensor()); + } + Tensor result = onnxModel.evaluate(inputs, function.getName()); // Function name is output of model + context.put(onnxFeature, new TensorValue(result)); + } + } + } + /** Returns the function evaluated by this */ public ExpressionFunction function() { return function; } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index d66315ef457..a5dcd2719c9 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -11,15 +11,18 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.Arrays; +import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; /** @@ -41,9 +44,10 @@ public final class LazyArrayContext extends Context implements ContextIndex { LazyArrayContext(ExpressionFunction function, Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, + List<OnnxModel> onnxModels, Model model) { this.function = function; - this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model); + this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, onnxModels, this, model); } /** @@ -117,6 +121,9 @@ public final class LazyArrayContext extends Context implements ContextIndex { /** Returns the (immutable) subset of names in this which must be bound when invoking */ public Set<String> arguments() { return indexedBindings.arguments(); } + /** Returns the set of ONNX models that need to be evaluated on this context */ + public Map<String, OnnxModel> onnxModels() { return indexedBindings.onnxModels(); } + private Integer requireIndexOf(String name) { Integer index = indexedBindings.indexOf(name); if (index == null) @@ -152,18 +159,24 @@ public final class LazyArrayContext extends Context implements ContextIndex { /** The current values set */ private final Value[] values; + /** ONNX models indexed by rank feature that calls them */ + private final ImmutableMap<String, OnnxModel> onnxModels; + /** The object instance which encodes "no value is set". The actual value of this is never used. */ private static final Value missing = new DoubleValue(Double.NaN).freeze(); /** The value to return for lookups where no value is set (default: NaN) */ private Value missingValue = new DoubleValue(Double.NaN).freeze(); + private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values, - ImmutableSet<String> arguments) { + ImmutableSet<String> arguments, + ImmutableMap<String, OnnxModel> onnxModels) { this.nameToIndex = nameToIndex; this.values = values; this.arguments = arguments; + this.onnxModels = onnxModels; } /** @@ -173,13 +186,16 @@ public final class LazyArrayContext extends Context implements ContextIndex { IndexedBindings(ExpressionFunction function, Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, + List<OnnxModel> onnxModels, LazyArrayContext owner, Model model) { // 1. Determine and prepare bind targets Set<String> bindTargets = new LinkedHashSet<>(); Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation - extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments); + Map<String, OnnxModel> onnxModelsInUse = new HashMap<>(); + extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse); + this.onnxModels = ImmutableMap.copyOf(onnxModelsInUse); this.arguments = ImmutableSet.copyOf(arguments); values = new Value[bindTargets.size()]; Arrays.fill(values, missing); @@ -214,12 +230,18 @@ public final class LazyArrayContext extends Context implements ContextIndex { private void extractBindTargets(ExpressionNode node, Map<FunctionReference, ExpressionFunction> functions, Set<String> bindTargets, - Set<String> arguments) { + Set<String> arguments, + List<OnnxModel> onnxModels, + Map<String, OnnxModel> onnxModelsInUse) { if (isFunctionReference(node)) { FunctionReference reference = FunctionReference.fromSerial(node.toString()).get(); bindTargets.add(reference.serialForm()); - extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets, arguments); + ExpressionNode functionNode = functions.get(reference).getBody().getRoot(); + extractBindTargets(functionNode, functions, bindTargets, arguments, onnxModels, onnxModelsInUse); + } + else if (isOnnx(node)) { + extractOnnxTargets(node, bindTargets, arguments, onnxModels, onnxModelsInUse); } else if (isConstant(node)) { bindTargets.add(node.toString()); @@ -231,20 +253,81 @@ public final class LazyArrayContext extends Context implements ContextIndex { else if (node instanceof CompositeNode) { CompositeNode cNode = (CompositeNode)node; for (ExpressionNode child : cNode.children()) - extractBindTargets(child, functions, bindTargets, arguments); + extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse); + } + } + + /** + * Extract the feature used to evaluate the onnx model. e.g. onnxModel(name) and add + * that as a bind target and argument. During evaluation, this will be evaluated before + * the rest of the expression and the result is added to the context. Also extract the + * inputs to the model and add them as bind targets and arguments. + */ + private void extractOnnxTargets(ExpressionNode node, + Set<String> bindTargets, + Set<String> arguments, + List<OnnxModel> onnxModels, + Map<String, OnnxModel> onnxModelsInUse) { + Optional<String> modelName = getArgument(node); + if (modelName.isPresent()) { + for (OnnxModel onnxModel : onnxModels) { + if (onnxModel.name().equals(modelName.get())) { + String onnxFeature = node.toString(); + bindTargets.add(onnxFeature); + arguments.add(onnxFeature); + + // Load the model (if not already loaded) to extract inputs + onnxModel.load(); + + for(String input : onnxModel.inputs().keySet()) { + bindTargets.add(input); + arguments.add(input); + } + onnxModelsInUse.put(onnxFeature, onnxModel); + } + } } } + private Optional<String> getArgument(ExpressionNode node) { + if (node instanceof ReferenceNode) { + ReferenceNode reference = (ReferenceNode) node; + if (reference.getArguments().size() > 0) { + if (reference.getArguments().expressions().get(0) instanceof ConstantNode) { + ConstantNode constantNode = (ConstantNode) reference.getArguments().expressions().get(0); + return Optional.of(stripQuotes(constantNode.sourceString())); + } + if (reference.getArguments().expressions().get(0) instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) reference.getArguments().expressions().get(0); + return Optional.of(referenceNode.getName()); + } + } + } + return Optional.empty(); + } + + public static String stripQuotes(String s) { + if (s.codePointAt(0) == '"' && s.codePointAt(s.length()-1) == '"') + return s.substring(1, s.length()-1); + if (s.codePointAt(0) == '\'' && s.codePointAt(s.length()-1) == '\'') + return s.substring(1, s.length()-1); + return s; + } + private boolean isFunctionReference(ExpressionNode node) { if ( ! (node instanceof ReferenceNode)) return false; - ReferenceNode reference = (ReferenceNode)node; return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1; } - private boolean isConstant(ExpressionNode node) { + private boolean isOnnx(ExpressionNode node) { if ( ! (node instanceof ReferenceNode)) return false; + ReferenceNode reference = (ReferenceNode) node; + return reference.getName().equals("onnx") || reference.getName().equals("onnxModel"); + } + private boolean isConstant(ExpressionNode node) { + if ( ! (node instanceof ReferenceNode)) return false; ReferenceNode reference = (ReferenceNode)node; return reference.getName().equals("constant") && reference.getArguments().size() == 1; } @@ -261,12 +344,13 @@ public final class LazyArrayContext extends Context implements ContextIndex { Set<String> names() { return nameToIndex.keySet(); } Set<String> arguments() { return arguments; } Integer indexOf(String name) { return nameToIndex.get(name); } + Map<String, OnnxModel> onnxModels() { return onnxModels; } IndexedBindings copy(Context context) { Value[] valueCopy = new Value[values.length]; for (int i = 0; i < values.length; i++) valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue) values[i]).copyFor(context) : values[i]; - return new IndexedBindings(nameToIndex, valueCopy, arguments); + return new IndexedBindings(nameToIndex, valueCopy, arguments, onnxModels); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 03bbb436026..40a84a701ec 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -26,7 +26,7 @@ import java.util.stream.Collectors; @Beta public class Model { - /** The prefix generated by mode-integration/../IntermediateOperation */ + /** The prefix generated by model-integration/../IntermediateOperation */ private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_"; private final String name; @@ -50,25 +50,37 @@ public class Model { this(name, functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)), Collections.emptyMap(), + Collections.emptyList(), Collections.emptyList()); } Model(String name, Map<FunctionReference, ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions, - List<Constant> constants) { + List<Constant> constants, + List<OnnxModel> onnxModels) { this.name = name; // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments) ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { - LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this); + LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this); contextBuilder.put(function.getValue().getName(), context); if ( ! function.getValue().returnType().isPresent()) { functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); } + for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) { + String onnxFeature = entry.getKey(); + OnnxModel onnxModel = entry.getValue(); + for(Map.Entry<String, TensorType> input : onnxModel.inputs().entrySet()) { + functions.put(function.getKey(), function.getValue().withArgument(input.getKey(), input.getValue())); + } + TensorType onnxOutputType = onnxModel.outputs().get(function.getKey().functionName()); + functions.put(function.getKey(), function.getValue().withArgument(onnxFeature, onnxOutputType)); + } + for (String argument : context.arguments()) { if (function.getValue().getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)) { // Internal (generated) functions do not have type info - add arguments diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index a0b859bf930..88766da67fc 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -7,6 +7,7 @@ import com.google.inject.Inject; import com.yahoo.component.AbstractComponent; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import java.util.Map; @@ -27,8 +28,9 @@ public class ModelsEvaluator extends AbstractComponent { @Inject public ModelsEvaluator(RankProfilesConfig config, RankingConstantsConfig constantsConfig, + OnnxModelsConfig onnxModelsConfig, FileAcquirer fileAcquirer) { - this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig)); + this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig, onnxModelsConfig)); } public ModelsEvaluator(Map<String, Model> models) { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java new file mode 100644 index 00000000000..dc27c43ef70 --- /dev/null +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -0,0 +1,57 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.io.File; +import java.util.Map; + +/** + * A named ONNX model that should be evaluated with OnnxEvaluator. + * + * @author lesters + */ +class OnnxModel { + + private final String name; + private final File modelFile; + + private OnnxEvaluator evaluator; + + OnnxModel(String name, File modelFile) { + this.name = name; + this.modelFile = modelFile; + } + + public String name() { + return name; + } + + public void load() { + if (evaluator == null) { + evaluator = new OnnxEvaluator(modelFile.getPath()); + } + } + + public Map<String, TensorType> inputs() { + return evaluator().getInputInfo(); + } + + public Map<String, TensorType> outputs() { + return evaluator().getOutputInfo(); + } + + public Tensor evaluate(Map<String, Tensor> inputs, String output) { + return evaluator().evaluate(inputs, output); + } + + private OnnxEvaluator evaluator() { + if (evaluator == null) { + throw new IllegalStateException("ONNX model has not been loaded."); + } + return evaluator; + } + +} diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index fb424439592..1bdb2810ddf 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -13,6 +13,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import java.io.File; @@ -48,11 +49,13 @@ public class RankProfilesConfigImporter { * Returns a map of the models contained in this config, indexed on name. * The map is modifiable and owned by the caller. */ - public Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) { + public Map<String, Model> importFrom(RankProfilesConfig config, + RankingConstantsConfig constantsConfig, + OnnxModelsConfig onnxModelsConfig) { try { Map<String, Model> models = new HashMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { - Model model = importProfile(profile, constantsConfig); + Model model = importProfile(profile, constantsConfig, onnxModelsConfig); models.put(model.name(), model); } return models; @@ -62,9 +65,12 @@ public class RankProfilesConfigImporter { } } - private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) + private Model importProfile(RankProfilesConfig.Rankprofile profile, + RankingConstantsConfig constantsConfig, + OnnxModelsConfig onnxModelsConfig) throws ParseException { + List<OnnxModel> onnxModels = readOnnxModelsConfig(onnxModelsConfig); List<Constant> constants = readLargeConstants(constantsConfig); Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>(); @@ -76,7 +82,7 @@ public class RankProfilesConfigImporter { Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name()); Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name()); Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name()); - if ( reference.isPresent()) { + if (reference.isPresent()) { RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value()); ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), Collections.emptyList(), @@ -122,7 +128,7 @@ public class RankProfilesConfigImporter { constants.addAll(smallConstantsInfo.asConstants()); try { - return new Model(profile.name(), functions, referencedFunctions, constants); + return new Model(profile.name(), functions, referencedFunctions, constants, onnxModels); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); @@ -136,6 +142,26 @@ public class RankProfilesConfigImporter { return null; } + private List<OnnxModel> readOnnxModelsConfig(OnnxModelsConfig onnxModelsConfig) { + List<OnnxModel> onnxModels = new ArrayList<>(); + if (onnxModelsConfig != null) { + for (OnnxModelsConfig.Model onnxModelConfig : onnxModelsConfig.model()) { + onnxModels.add(readOnnxModelConfig(onnxModelConfig)); + } + } + return onnxModels; + } + + private OnnxModel readOnnxModelConfig(OnnxModelsConfig.Model onnxModelConfig) { + try { + String name = onnxModelConfig.name(); + File file = fileAcquirer.waitFor(onnxModelConfig.fileref(), 7, TimeUnit.DAYS); + return new OnnxModel(name, file); + } catch (InterruptedException e) { + throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name()); + } + } + private List<Constant> readLargeConstants(RankingConstantsConfig constantsConfig) { List<Constant> constants = new ArrayList<>(); |