about | summary | refs | log | tree | commit | diff | stats
path: root/model-evaluation/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/main')
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java 36
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java 102
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java 18
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java 4
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java 57
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java 36
6 files changed, 228 insertions, 25 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index e373a54bcd1..910aca8aa98 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
@@ -100,19 +101,40 @@ public class FunctionEvaluator {
}
public Tensor evaluate() {
+ evaluateOnnxModels();
for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
- if (context.isMissing(argument.getKey()))
- throw new IllegalStateException("Missing argument '" + argument.getKey() +
- "': Must be bound to a value of type " + argument.getValue());
- if (! context.get(argument.getKey()).type().isAssignableTo(argument.getValue()))
- throw new IllegalStateException("Argument '" + argument.getKey() +
- "' must be bound to a value of type " + argument.getValue());
-
+ checkArgument(argument.getKey(), argument.getValue());
}
evaluated = true;
return function.getBody().evaluate(context).asTensor();
}
+ /**
+  * Verifies that a value is bound to the given name in the context and that the
+  * type of that value is assignable to the expected type.
+  *
+  * @throws IllegalStateException if the name is unbound, or bound to a value of an incompatible type
+  */
+ private void checkArgument(String name, TensorType type) {
+     if (context.isMissing(name)) {
+         throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type);
+     }
+     boolean assignable = context.get(name).type().isAssignableTo(type);
+     if ( ! assignable) {
+         throw new IllegalStateException("Argument '" + name + "' must be bound to a value of type " + type);
+     }
+ }
+
+ /**
+  * Evaluates every ONNX model referenced by the context whose feature still holds the
+  * context's default value, and stores the result in the context under that feature.
+  * Each model input must already be bound to a value of the input's declared type.
+  */
+ private void evaluateOnnxModels() {
+     for (Map.Entry<String, OnnxModel> modelEntry : context().onnxModels().entrySet()) {
+         String featureName = modelEntry.getKey();
+         // A feature that no longer holds the default value is treated as already evaluated
+         if ( ! context.get(featureName).equals(context.defaultValue())) continue;
+
+         OnnxModel model = modelEntry.getValue();
+         Map<String, Tensor> boundInputs = new HashMap<>();
+         model.inputs().forEach((inputName, inputType) -> {
+             checkArgument(inputName, inputType);
+             boundInputs.put(inputName, context.get(inputName).asTensor());
+         });
+         // The name of the function is used as the name of the model output to compute
+         Tensor output = model.evaluate(boundInputs, function.getName());
+         context.put(featureName, new TensorValue(output));
+     }
+ }
+
/** Returns the function evaluated by this */
public ExpressionFunction function() { return function; }
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index d66315ef457..a5dcd2719c9 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -11,15 +11,18 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
/**
@@ -41,9 +44,10 @@ public final class LazyArrayContext extends Context implements ContextIndex {
LazyArrayContext(ExpressionFunction function,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants,
+ List<OnnxModel> onnxModels,
Model model) {
this.function = function;
- this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model);
+ this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, onnxModels, this, model);
}
/**
@@ -117,6 +121,9 @@ public final class LazyArrayContext extends Context implements ContextIndex {
/** Returns the (immutable) subset of names in this which must be bound when invoking */
public Set<String> arguments() { return indexedBindings.arguments(); }
+ /** Returns the ONNX models that need to be evaluated on this context, indexed by the rank feature that calls them */
+ public Map<String, OnnxModel> onnxModels() { return indexedBindings.onnxModels(); }
+
private Integer requireIndexOf(String name) {
Integer index = indexedBindings.indexOf(name);
if (index == null)
@@ -152,18 +159,24 @@ public final class LazyArrayContext extends Context implements ContextIndex {
/** The current values set */
private final Value[] values;
+ /** ONNX models indexed by rank feature that calls them */
+ private final ImmutableMap<String, OnnxModel> onnxModels;
+
/** The object instance which encodes "no value is set". The actual value of this is never used. */
private static final Value missing = new DoubleValue(Double.NaN).freeze();
/** The value to return for lookups where no value is set (default: NaN) */
private Value missingValue = new DoubleValue(Double.NaN).freeze();
+
private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
Value[] values,
- ImmutableSet<String> arguments) {
+ ImmutableSet<String> arguments,
+ ImmutableMap<String, OnnxModel> onnxModels) {
this.nameToIndex = nameToIndex;
this.values = values;
this.arguments = arguments;
+ this.onnxModels = onnxModels;
}
/**
@@ -173,13 +186,16 @@ public final class LazyArrayContext extends Context implements ContextIndex {
IndexedBindings(ExpressionFunction function,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants,
+ List<OnnxModel> onnxModels,
LazyArrayContext owner,
Model model) {
// 1. Determine and prepare bind targets
Set<String> bindTargets = new LinkedHashSet<>();
Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation
- extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments);
+ Map<String, OnnxModel> onnxModelsInUse = new HashMap<>();
+ extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse);
+ this.onnxModels = ImmutableMap.copyOf(onnxModelsInUse);
this.arguments = ImmutableSet.copyOf(arguments);
values = new Value[bindTargets.size()];
Arrays.fill(values, missing);
@@ -214,12 +230,18 @@ public final class LazyArrayContext extends Context implements ContextIndex {
private void extractBindTargets(ExpressionNode node,
Map<FunctionReference, ExpressionFunction> functions,
Set<String> bindTargets,
- Set<String> arguments) {
+ Set<String> arguments,
+ List<OnnxModel> onnxModels,
+ Map<String, OnnxModel> onnxModelsInUse) {
if (isFunctionReference(node)) {
FunctionReference reference = FunctionReference.fromSerial(node.toString()).get();
bindTargets.add(reference.serialForm());
- extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets, arguments);
+ ExpressionNode functionNode = functions.get(reference).getBody().getRoot();
+ extractBindTargets(functionNode, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
+ }
+ else if (isOnnx(node)) {
+ extractOnnxTargets(node, bindTargets, arguments, onnxModels, onnxModelsInUse);
}
else if (isConstant(node)) {
bindTargets.add(node.toString());
@@ -231,20 +253,81 @@ public final class LazyArrayContext extends Context implements ContextIndex {
else if (node instanceof CompositeNode) {
CompositeNode cNode = (CompositeNode)node;
for (ExpressionNode child : cNode.children())
- extractBindTargets(child, functions, bindTargets, arguments);
+ extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
+ }
+ }
+
+ /**
+ * Extracts the feature used to evaluate the ONNX model, e.g. onnxModel(name), and adds
+ * it as a bind target and argument. During evaluation, this feature is evaluated before
+ * the rest of the expression and the result is added to the context. The inputs to the
+ * model are also extracted and added as bind targets and arguments.
+ */
+ private void extractOnnxTargets(ExpressionNode node,
+ Set<String> bindTargets,
+ Set<String> arguments,
+ List<OnnxModel> onnxModels,
+ Map<String, OnnxModel> onnxModelsInUse) {
+ Optional<String> modelName = getArgument(node);
+ if (modelName.isPresent()) {
+ for (OnnxModel onnxModel : onnxModels) {
+ if (onnxModel.name().equals(modelName.get())) {
+ String onnxFeature = node.toString();
+ bindTargets.add(onnxFeature);
+ arguments.add(onnxFeature);
+
+ // Load the model (if not already loaded) to extract inputs
+ onnxModel.load();
+
+ for(String input : onnxModel.inputs().keySet()) {
+ bindTargets.add(input);
+ arguments.add(input);
+ }
+ onnxModelsInUse.put(onnxFeature, onnxModel);
+ }
+ }
}
}
+ private Optional<String> getArgument(ExpressionNode node) {
+ if (node instanceof ReferenceNode) {
+ ReferenceNode reference = (ReferenceNode) node;
+ if (reference.getArguments().size() > 0) {
+ if (reference.getArguments().expressions().get(0) instanceof ConstantNode) {
+ ConstantNode constantNode = (ConstantNode) reference.getArguments().expressions().get(0);
+ return Optional.of(stripQuotes(constantNode.sourceString()));
+ }
+ if (reference.getArguments().expressions().get(0) instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) reference.getArguments().expressions().get(0);
+ return Optional.of(referenceNode.getName());
+ }
+ }
+ }
+ return Optional.empty();
+ }
+
+ /**
+  * Returns the given string with one pair of enclosing quotes removed, if it is wrapped
+  * in matching double quotes or in matching single quotes. Any other string, including
+  * the empty string and a string consisting of a single quote character, is returned
+  * unchanged.
+  *
+  * @param s the string to strip, must not be null
+  * @return the string without its enclosing quote pair, or the string itself if it has none
+  */
+ public static String stripQuotes(String s) {
+     if (s.length() < 2) return s;  // too short to hold both an opening and a closing quote
+     int first = s.codePointAt(0);
+     boolean quoted = (first == '"' || first == '\'') && s.codePointAt(s.length() - 1) == first;
+     return quoted ? s.substring(1, s.length() - 1) : s;
+ }
+
private boolean isFunctionReference(ExpressionNode node) {
if ( ! (node instanceof ReferenceNode)) return false;
-
ReferenceNode reference = (ReferenceNode)node;
return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1;
}
- private boolean isConstant(ExpressionNode node) {
+ private boolean isOnnx(ExpressionNode node) {
if ( ! (node instanceof ReferenceNode)) return false;
+ ReferenceNode reference = (ReferenceNode) node;
+ return reference.getName().equals("onnx") || reference.getName().equals("onnxModel");
+ }
+ private boolean isConstant(ExpressionNode node) {
+ if ( ! (node instanceof ReferenceNode)) return false;
ReferenceNode reference = (ReferenceNode)node;
return reference.getName().equals("constant") && reference.getArguments().size() == 1;
}
@@ -261,12 +344,13 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Set<String> names() { return nameToIndex.keySet(); }
Set<String> arguments() { return arguments; }
Integer indexOf(String name) { return nameToIndex.get(name); }
+ Map<String, OnnxModel> onnxModels() { return onnxModels; }
IndexedBindings copy(Context context) {
Value[] valueCopy = new Value[values.length];
for (int i = 0; i < values.length; i++)
valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue) values[i]).copyFor(context) : values[i];
- return new IndexedBindings(nameToIndex, valueCopy, arguments);
+ return new IndexedBindings(nameToIndex, valueCopy, arguments, onnxModels);
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 03bbb436026..40a84a701ec 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -26,7 +26,7 @@ import java.util.stream.Collectors;
@Beta
public class Model {
- /** The prefix generated by mode-integration/../IntermediateOperation */
+ /** The prefix generated by model-integration/../IntermediateOperation */
private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_";
private final String name;
@@ -50,25 +50,37 @@ public class Model {
this(name,
functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
Collections.emptyMap(),
+ Collections.emptyList(),
Collections.emptyList());
}
Model(String name,
Map<FunctionReference, ExpressionFunction> functions,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
- List<Constant> constants) {
+ List<Constant> constants,
+ List<OnnxModel> onnxModels) {
this.name = name;
// Build context and add missing function arguments (missing because it is legal to omit scalar type arguments)
ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
- LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this);
+ LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this);
contextBuilder.put(function.getValue().getName(), context);
if ( ! function.getValue().returnType().isPresent()) {
functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));
}
+ for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) {
+ String onnxFeature = entry.getKey();
+ OnnxModel onnxModel = entry.getValue();
+ for(Map.Entry<String, TensorType> input : onnxModel.inputs().entrySet()) {
+ functions.put(function.getKey(), function.getValue().withArgument(input.getKey(), input.getValue()));
+ }
+ TensorType onnxOutputType = onnxModel.outputs().get(function.getKey().functionName());
+ functions.put(function.getKey(), function.getValue().withArgument(onnxFeature, onnxOutputType));
+ }
+
for (String argument : context.arguments()) {
if (function.getValue().getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)) {
// Internal (generated) functions do not have type info - add arguments
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index a0b859bf930..88766da67fc 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -7,6 +7,7 @@ import com.google.inject.Inject;
import com.yahoo.component.AbstractComponent;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import java.util.Map;
@@ -27,8 +28,9 @@ public class ModelsEvaluator extends AbstractComponent {
@Inject
public ModelsEvaluator(RankProfilesConfig config,
RankingConstantsConfig constantsConfig,
+ OnnxModelsConfig onnxModelsConfig,
FileAcquirer fileAcquirer) {
- this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig));
+ this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig, onnxModelsConfig));
}
public ModelsEvaluator(Map<String, Model> models) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
new file mode 100644
index 00000000000..dc27c43ef70
--- /dev/null
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -0,0 +1,57 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.io.File;
+import java.util.Map;
+
+/**
+ * An ONNX model identified by a name and backed by a model file, evaluated through an
+ * OnnxEvaluator. The evaluator is created by {@link #load()}; until that has been called,
+ * inputs(), outputs() and evaluate() throw IllegalStateException.
+ *
+ * @author lesters
+ */
+class OnnxModel {
+
+    private final String name;
+    private final File modelFile;
+
+    /** The evaluator created by load(), or null if load() has not been called yet */
+    private OnnxEvaluator evaluator;
+
+    OnnxModel(String name, File modelFile) {
+        this.name = name;
+        this.modelFile = modelFile;
+    }
+
+    /** Returns the name of this model */
+    public String name() { return name; }
+
+    /** Creates the evaluator from the path of the model file, unless it has already been created */
+    public void load() {
+        if (evaluator != null) return;
+        evaluator = new OnnxEvaluator(modelFile.getPath());
+    }
+
+    /** Returns the input info (name to type) reported by the evaluator */
+    public Map<String, TensorType> inputs() { return evaluator().getInputInfo(); }
+
+    /** Returns the output info (name to type) reported by the evaluator */
+    public Map<String, TensorType> outputs() { return evaluator().getOutputInfo(); }
+
+    /** Passes the given inputs and output name on to the evaluator and returns the tensor it produces */
+    public Tensor evaluate(Map<String, Tensor> inputs, String output) {
+        return evaluator().evaluate(inputs, output);
+    }
+
+    /** Returns the evaluator, or throws IllegalStateException if load() has not been called */
+    private OnnxEvaluator evaluator() {
+        if (evaluator != null) return evaluator;
+        throw new IllegalStateException("ONNX model has not been loaded.");
+    }
+
+}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index fb424439592..1bdb2810ddf 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -13,6 +13,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import java.io.File;
@@ -48,11 +49,13 @@ public class RankProfilesConfigImporter {
* Returns a map of the models contained in this config, indexed on name.
* The map is modifiable and owned by the caller.
*/
- public Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) {
+ public Map<String, Model> importFrom(RankProfilesConfig config,
+ RankingConstantsConfig constantsConfig,
+ OnnxModelsConfig onnxModelsConfig) {
try {
Map<String, Model> models = new HashMap<>();
for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
- Model model = importProfile(profile, constantsConfig);
+ Model model = importProfile(profile, constantsConfig, onnxModelsConfig);
models.put(model.name(), model);
}
return models;
@@ -62,9 +65,12 @@ public class RankProfilesConfigImporter {
}
}
- private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig)
+ private Model importProfile(RankProfilesConfig.Rankprofile profile,
+ RankingConstantsConfig constantsConfig,
+ OnnxModelsConfig onnxModelsConfig)
throws ParseException {
+ List<OnnxModel> onnxModels = readOnnxModelsConfig(onnxModelsConfig);
List<Constant> constants = readLargeConstants(constantsConfig);
Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>();
@@ -76,7 +82,7 @@ public class RankProfilesConfigImporter {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name());
Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name());
- if ( reference.isPresent()) {
+ if (reference.isPresent()) {
RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value());
ExpressionFunction function = new ExpressionFunction(reference.get().functionName(),
Collections.emptyList(),
@@ -122,7 +128,7 @@ public class RankProfilesConfigImporter {
constants.addAll(smallConstantsInfo.asConstants());
try {
- return new Model(profile.name(), functions, referencedFunctions, constants);
+ return new Model(profile.name(), functions, referencedFunctions, constants, onnxModels);
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
@@ -136,6 +142,26 @@ public class RankProfilesConfigImporter {
return null;
}
+ /**
+  * Creates an OnnxModel for each model entry of the given config, in config order.
+  * A null config yields an empty list. The returned list is modifiable.
+  */
+ private List<OnnxModel> readOnnxModelsConfig(OnnxModelsConfig onnxModelsConfig) {
+     List<OnnxModel> models = new ArrayList<>();
+     if (onnxModelsConfig == null) return models;
+
+     for (OnnxModelsConfig.Model modelConfig : onnxModelsConfig.model())
+         models.add(readOnnxModelConfig(modelConfig));
+     return models;
+ }
+
+ /**
+  * Creates an OnnxModel from one config entry, waiting up to 7 days for the file
+  * acquirer to deliver the file behind the entry's file reference.
+  *
+  * @throws IllegalStateException if the thread is interrupted while waiting; the
+  *         interrupt status is restored and the InterruptedException is kept as the cause
+  */
+ private OnnxModel readOnnxModelConfig(OnnxModelsConfig.Model onnxModelConfig) {
+     String name = onnxModelConfig.name();
+     try {
+         File file = fileAcquirer.waitFor(onnxModelConfig.fileref(), 7, TimeUnit.DAYS);
+         return new OnnxModel(name, file);
+     } catch (InterruptedException e) {
+         Thread.currentThread().interrupt();  // restore the flag so code further up can observe the interruption
+         throw new IllegalStateException("Gave up waiting for ONNX model " + name, e);
+     }
+ }
+
private List<Constant> readLargeConstants(RankingConstantsConfig constantsConfig) {
List<Constant> constants = new ArrayList<>();