aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-05-20 11:29:12 +0200
committerLester Solbakken <lesters@oath.com>2021-05-20 11:29:12 +0200
commit4a126bdd16323226411561b969e581af90260692 (patch)
tree50bd5318dd8e0f174ff26a41a786042b787c9001 /model-evaluation
parentfc0711f7870b55ea77d18d87ec3e70b75e0de2e0 (diff)
Evaluate ONNX models in model-evaluation with ONNX RT
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/abi-spec.json5
-rw-r--r--model-evaluation/pom.xml6
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java36
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java102
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java18
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java4
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java57
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java36
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java5
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java5
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java69
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java76
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java93
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java137
-rw-r--r--model-evaluation/src/test/resources/config/models/onnx-models.cfg0
-rw-r--r--model-evaluation/src/test/resources/config/onnx/models/add_mul.onnx24
-rwxr-xr-xmodel-evaluation/src/test/resources/config/onnx/models/add_mul.py30
-rw-r--r--model-evaluation/src/test/resources/config/onnx/models/one_layer.onnxbin0 -> 299 bytes
-rwxr-xr-xmodel-evaluation/src/test/resources/config/onnx/models/pytorch_one_layer.py38
-rw-r--r--model-evaluation/src/test/resources/config/onnx/onnx-models.cfg16
-rw-r--r--model-evaluation/src/test/resources/config/onnx/rank-profiles.cfg17
-rw-r--r--model-evaluation/src/test/resources/config/onnx/ranking-constants.cfg0
-rw-r--r--model-evaluation/src/test/resources/config/rankexpression/onnx-models.cfg0
-rw-r--r--model-evaluation/src/test/resources/config/smallconstant/onnx-models.cfg0
24 files changed, 678 insertions, 96 deletions
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json
index d465464de7f..63882525808 100644
--- a/model-evaluation/abi-spec.json
+++ b/model-evaluation/abi-spec.json
@@ -39,6 +39,7 @@
"public int size()",
"public java.util.Set names()",
"public java.util.Set arguments()",
+ "public java.util.Map onnxModels()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value defaultValue()",
"public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)"
],
@@ -66,7 +67,7 @@
"public"
],
"methods": [
- "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)",
+ "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)",
"public void <init>(java.util.Map)",
"public java.util.Map models()",
"public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String, java.lang.String[])",
@@ -82,7 +83,7 @@
],
"methods": [
"public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer)",
- "public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig)",
+ "public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)",
"protected com.yahoo.tensor.Tensor readTensorFromFile(java.lang.String, com.yahoo.tensor.TensorType, com.yahoo.config.FileReference)"
],
"fields": []
diff --git a/model-evaluation/pom.xml b/model-evaluation/pom.xml
index 00560a22bc7..8cdff451b42 100644
--- a/model-evaluation/pom.xml
+++ b/model-evaluation/pom.xml
@@ -68,6 +68,12 @@
<scope>provided</scope>
</dependency>
<dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>model-integration</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<scope>provided</scope>
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index e373a54bcd1..910aca8aa98 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
@@ -100,19 +101,40 @@ public class FunctionEvaluator {
}
public Tensor evaluate() {
+ evaluateOnnxModels();
for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
- if (context.isMissing(argument.getKey()))
- throw new IllegalStateException("Missing argument '" + argument.getKey() +
- "': Must be bound to a value of type " + argument.getValue());
- if (! context.get(argument.getKey()).type().isAssignableTo(argument.getValue()))
- throw new IllegalStateException("Argument '" + argument.getKey() +
- "' must be bound to a value of type " + argument.getValue());
-
+ checkArgument(argument.getKey(), argument.getValue());
}
evaluated = true;
return function.getBody().evaluate(context).asTensor();
}
+ private void checkArgument(String name, TensorType type) {
+ if (context.isMissing(name))
+ throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type);
+ if (! context.get(name).type().isAssignableTo(type))
+ throw new IllegalStateException("Argument '" + name + "' must be bound to a value of type " + type);
+ }
+
+ /**
+ * Evaluate ONNX models (if not already evaluated) and add the result back to the context.
+ */
+ private void evaluateOnnxModels() {
+ for (Map.Entry<String, OnnxModel> entry : context().onnxModels().entrySet()) {
+ String onnxFeature = entry.getKey();
+ OnnxModel onnxModel = entry.getValue();
+ if (context.get(onnxFeature).equals(context.defaultValue())) {
+ Map<String, Tensor> inputs = new HashMap<>();
+ for (Map.Entry<String, TensorType> input: onnxModel.inputs().entrySet()) {
+ checkArgument(input.getKey(), input.getValue());
+ inputs.put(input.getKey(), context.get(input.getKey()).asTensor());
+ }
+ Tensor result = onnxModel.evaluate(inputs, function.getName()); // Function name is output of model
+ context.put(onnxFeature, new TensorValue(result));
+ }
+ }
+ }
+
/** Returns the function evaluated by this */
public ExpressionFunction function() { return function; }
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index d66315ef457..a5dcd2719c9 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -11,15 +11,18 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
/**
@@ -41,9 +44,10 @@ public final class LazyArrayContext extends Context implements ContextIndex {
LazyArrayContext(ExpressionFunction function,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants,
+ List<OnnxModel> onnxModels,
Model model) {
this.function = function;
- this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model);
+ this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, onnxModels, this, model);
}
/**
@@ -117,6 +121,9 @@ public final class LazyArrayContext extends Context implements ContextIndex {
/** Returns the (immutable) subset of names in this which must be bound when invoking */
public Set<String> arguments() { return indexedBindings.arguments(); }
+ /** Returns the set of ONNX models that need to be evaluated on this context */
+ public Map<String, OnnxModel> onnxModels() { return indexedBindings.onnxModels(); }
+
private Integer requireIndexOf(String name) {
Integer index = indexedBindings.indexOf(name);
if (index == null)
@@ -152,18 +159,24 @@ public final class LazyArrayContext extends Context implements ContextIndex {
/** The current values set */
private final Value[] values;
+ /** ONNX models indexed by rank feature that calls them */
+ private final ImmutableMap<String, OnnxModel> onnxModels;
+
/** The object instance which encodes "no value is set". The actual value of this is never used. */
private static final Value missing = new DoubleValue(Double.NaN).freeze();
/** The value to return for lookups where no value is set (default: NaN) */
private Value missingValue = new DoubleValue(Double.NaN).freeze();
+
private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
Value[] values,
- ImmutableSet<String> arguments) {
+ ImmutableSet<String> arguments,
+ ImmutableMap<String, OnnxModel> onnxModels) {
this.nameToIndex = nameToIndex;
this.values = values;
this.arguments = arguments;
+ this.onnxModels = onnxModels;
}
/**
@@ -173,13 +186,16 @@ public final class LazyArrayContext extends Context implements ContextIndex {
IndexedBindings(ExpressionFunction function,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants,
+ List<OnnxModel> onnxModels,
LazyArrayContext owner,
Model model) {
// 1. Determine and prepare bind targets
Set<String> bindTargets = new LinkedHashSet<>();
Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation
- extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments);
+ Map<String, OnnxModel> onnxModelsInUse = new HashMap<>();
+ extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse);
+ this.onnxModels = ImmutableMap.copyOf(onnxModelsInUse);
this.arguments = ImmutableSet.copyOf(arguments);
values = new Value[bindTargets.size()];
Arrays.fill(values, missing);
@@ -214,12 +230,18 @@ public final class LazyArrayContext extends Context implements ContextIndex {
private void extractBindTargets(ExpressionNode node,
Map<FunctionReference, ExpressionFunction> functions,
Set<String> bindTargets,
- Set<String> arguments) {
+ Set<String> arguments,
+ List<OnnxModel> onnxModels,
+ Map<String, OnnxModel> onnxModelsInUse) {
if (isFunctionReference(node)) {
FunctionReference reference = FunctionReference.fromSerial(node.toString()).get();
bindTargets.add(reference.serialForm());
- extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets, arguments);
+ ExpressionNode functionNode = functions.get(reference).getBody().getRoot();
+ extractBindTargets(functionNode, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
+ }
+ else if (isOnnx(node)) {
+ extractOnnxTargets(node, bindTargets, arguments, onnxModels, onnxModelsInUse);
}
else if (isConstant(node)) {
bindTargets.add(node.toString());
@@ -231,20 +253,81 @@ public final class LazyArrayContext extends Context implements ContextIndex {
else if (node instanceof CompositeNode) {
CompositeNode cNode = (CompositeNode)node;
for (ExpressionNode child : cNode.children())
- extractBindTargets(child, functions, bindTargets, arguments);
+ extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
+ }
+ }
+
+ /**
+ * Extract the feature used to evaluate the onnx model. e.g. onnxModel(name) and add
+ * that as a bind target and argument. During evaluation, this will be evaluated before
+ * the rest of the expression and the result is added to the context. Also extract the
+ * inputs to the model and add them as bind targets and arguments.
+ */
+ private void extractOnnxTargets(ExpressionNode node,
+ Set<String> bindTargets,
+ Set<String> arguments,
+ List<OnnxModel> onnxModels,
+ Map<String, OnnxModel> onnxModelsInUse) {
+ Optional<String> modelName = getArgument(node);
+ if (modelName.isPresent()) {
+ for (OnnxModel onnxModel : onnxModels) {
+ if (onnxModel.name().equals(modelName.get())) {
+ String onnxFeature = node.toString();
+ bindTargets.add(onnxFeature);
+ arguments.add(onnxFeature);
+
+ // Load the model (if not already loaded) to extract inputs
+ onnxModel.load();
+
+ for(String input : onnxModel.inputs().keySet()) {
+ bindTargets.add(input);
+ arguments.add(input);
+ }
+ onnxModelsInUse.put(onnxFeature, onnxModel);
+ }
+ }
}
}
+ private Optional<String> getArgument(ExpressionNode node) {
+ if (node instanceof ReferenceNode) {
+ ReferenceNode reference = (ReferenceNode) node;
+ if (reference.getArguments().size() > 0) {
+ if (reference.getArguments().expressions().get(0) instanceof ConstantNode) {
+ ConstantNode constantNode = (ConstantNode) reference.getArguments().expressions().get(0);
+ return Optional.of(stripQuotes(constantNode.sourceString()));
+ }
+ if (reference.getArguments().expressions().get(0) instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) reference.getArguments().expressions().get(0);
+ return Optional.of(referenceNode.getName());
+ }
+ }
+ }
+ return Optional.empty();
+ }
+
+ public static String stripQuotes(String s) {
+ if (s.codePointAt(0) == '"' && s.codePointAt(s.length()-1) == '"')
+ return s.substring(1, s.length()-1);
+ if (s.codePointAt(0) == '\'' && s.codePointAt(s.length()-1) == '\'')
+ return s.substring(1, s.length()-1);
+ return s;
+ }
+
private boolean isFunctionReference(ExpressionNode node) {
if ( ! (node instanceof ReferenceNode)) return false;
-
ReferenceNode reference = (ReferenceNode)node;
return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1;
}
- private boolean isConstant(ExpressionNode node) {
+ private boolean isOnnx(ExpressionNode node) {
if ( ! (node instanceof ReferenceNode)) return false;
+ ReferenceNode reference = (ReferenceNode) node;
+ return reference.getName().equals("onnx") || reference.getName().equals("onnxModel");
+ }
+ private boolean isConstant(ExpressionNode node) {
+ if ( ! (node instanceof ReferenceNode)) return false;
ReferenceNode reference = (ReferenceNode)node;
return reference.getName().equals("constant") && reference.getArguments().size() == 1;
}
@@ -261,12 +344,13 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Set<String> names() { return nameToIndex.keySet(); }
Set<String> arguments() { return arguments; }
Integer indexOf(String name) { return nameToIndex.get(name); }
+ Map<String, OnnxModel> onnxModels() { return onnxModels; }
IndexedBindings copy(Context context) {
Value[] valueCopy = new Value[values.length];
for (int i = 0; i < values.length; i++)
valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue) values[i]).copyFor(context) : values[i];
- return new IndexedBindings(nameToIndex, valueCopy, arguments);
+ return new IndexedBindings(nameToIndex, valueCopy, arguments, onnxModels);
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 03bbb436026..40a84a701ec 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -26,7 +26,7 @@ import java.util.stream.Collectors;
@Beta
public class Model {
- /** The prefix generated by mode-integration/../IntermediateOperation */
+ /** The prefix generated by model-integration/../IntermediateOperation */
private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_";
private final String name;
@@ -50,25 +50,37 @@ public class Model {
this(name,
functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
Collections.emptyMap(),
+ Collections.emptyList(),
Collections.emptyList());
}
Model(String name,
Map<FunctionReference, ExpressionFunction> functions,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
- List<Constant> constants) {
+ List<Constant> constants,
+ List<OnnxModel> onnxModels) {
this.name = name;
// Build context and add missing function arguments (missing because it is legal to omit scalar type arguments)
ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
- LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this);
+ LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this);
contextBuilder.put(function.getValue().getName(), context);
if ( ! function.getValue().returnType().isPresent()) {
functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));
}
+ for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) {
+ String onnxFeature = entry.getKey();
+ OnnxModel onnxModel = entry.getValue();
+ for(Map.Entry<String, TensorType> input : onnxModel.inputs().entrySet()) {
+ functions.put(function.getKey(), function.getValue().withArgument(input.getKey(), input.getValue()));
+ }
+ TensorType onnxOutputType = onnxModel.outputs().get(function.getKey().functionName());
+ functions.put(function.getKey(), function.getValue().withArgument(onnxFeature, onnxOutputType));
+ }
+
for (String argument : context.arguments()) {
if (function.getValue().getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)) {
// Internal (generated) functions do not have type info - add arguments
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index a0b859bf930..88766da67fc 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -7,6 +7,7 @@ import com.google.inject.Inject;
import com.yahoo.component.AbstractComponent;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import java.util.Map;
@@ -27,8 +28,9 @@ public class ModelsEvaluator extends AbstractComponent {
@Inject
public ModelsEvaluator(RankProfilesConfig config,
RankingConstantsConfig constantsConfig,
+ OnnxModelsConfig onnxModelsConfig,
FileAcquirer fileAcquirer) {
- this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig));
+ this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig, onnxModelsConfig));
}
public ModelsEvaluator(Map<String, Model> models) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
new file mode 100644
index 00000000000..dc27c43ef70
--- /dev/null
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -0,0 +1,57 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.io.File;
+import java.util.Map;
+
+/**
+ * A named ONNX model that should be evaluated with OnnxEvaluator.
+ *
+ * @author lesters
+ */
+class OnnxModel {
+
+ private final String name;
+ private final File modelFile;
+
+ private OnnxEvaluator evaluator;
+
+ OnnxModel(String name, File modelFile) {
+ this.name = name;
+ this.modelFile = modelFile;
+ }
+
+ public String name() {
+ return name;
+ }
+
+ public void load() {
+ if (evaluator == null) {
+ evaluator = new OnnxEvaluator(modelFile.getPath());
+ }
+ }
+
+ public Map<String, TensorType> inputs() {
+ return evaluator().getInputInfo();
+ }
+
+ public Map<String, TensorType> outputs() {
+ return evaluator().getOutputInfo();
+ }
+
+ public Tensor evaluate(Map<String, Tensor> inputs, String output) {
+ return evaluator().evaluate(inputs, output);
+ }
+
+ private OnnxEvaluator evaluator() {
+ if (evaluator == null) {
+ throw new IllegalStateException("ONNX model has not been loaded.");
+ }
+ return evaluator;
+ }
+
+}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index fb424439592..1bdb2810ddf 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -13,6 +13,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import java.io.File;
@@ -48,11 +49,13 @@ public class RankProfilesConfigImporter {
* Returns a map of the models contained in this config, indexed on name.
* The map is modifiable and owned by the caller.
*/
- public Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) {
+ public Map<String, Model> importFrom(RankProfilesConfig config,
+ RankingConstantsConfig constantsConfig,
+ OnnxModelsConfig onnxModelsConfig) {
try {
Map<String, Model> models = new HashMap<>();
for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
- Model model = importProfile(profile, constantsConfig);
+ Model model = importProfile(profile, constantsConfig, onnxModelsConfig);
models.put(model.name(), model);
}
return models;
@@ -62,9 +65,12 @@ public class RankProfilesConfigImporter {
}
}
- private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig)
+ private Model importProfile(RankProfilesConfig.Rankprofile profile,
+ RankingConstantsConfig constantsConfig,
+ OnnxModelsConfig onnxModelsConfig)
throws ParseException {
+ List<OnnxModel> onnxModels = readOnnxModelsConfig(onnxModelsConfig);
List<Constant> constants = readLargeConstants(constantsConfig);
Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>();
@@ -76,7 +82,7 @@ public class RankProfilesConfigImporter {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name());
Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name());
- if ( reference.isPresent()) {
+ if (reference.isPresent()) {
RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value());
ExpressionFunction function = new ExpressionFunction(reference.get().functionName(),
Collections.emptyList(),
@@ -122,7 +128,7 @@ public class RankProfilesConfigImporter {
constants.addAll(smallConstantsInfo.asConstants());
try {
- return new Model(profile.name(), functions, referencedFunctions, constants);
+ return new Model(profile.name(), functions, referencedFunctions, constants, onnxModels);
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
@@ -136,6 +142,26 @@ public class RankProfilesConfigImporter {
return null;
}
+ private List<OnnxModel> readOnnxModelsConfig(OnnxModelsConfig onnxModelsConfig) {
+ List<OnnxModel> onnxModels = new ArrayList<>();
+ if (onnxModelsConfig != null) {
+ for (OnnxModelsConfig.Model onnxModelConfig : onnxModelsConfig.model()) {
+ onnxModels.add(readOnnxModelConfig(onnxModelConfig));
+ }
+ }
+ return onnxModels;
+ }
+
+ private OnnxModel readOnnxModelConfig(OnnxModelsConfig.Model onnxModelConfig) {
+ try {
+ String name = onnxModelConfig.name();
+ File file = fileAcquirer.waitFor(onnxModelConfig.fileref(), 7, TimeUnit.DAYS);
+ return new OnnxModel(name, file);
+ } catch (InterruptedException e) {
+ throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name());
+ }
+ }
+
private List<Constant> readLargeConstants(RankingConstantsConfig constantsConfig) {
List<Constant> constants = new ArrayList<>();
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
index bacdb52a201..d252594e729 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -14,6 +14,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import java.io.IOException;
@@ -45,8 +46,10 @@ public class ModelTester {
RankProfilesConfig.class).getConfig("");
RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
RankingConstantsConfig.class).getConfig("");
+ OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
+ OnnxModelsConfig.class).getConfig("");
return new RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"), MockFileAcquirer.returnFile(null))
- .importFrom(config, constantsConfig);
+ .importFrom(config, constantsConfig, onnxModelsConfig);
}
public ExpressionFunction assertFunction(String name, String expression, Model model) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index 6fcf76d2815..dce033c79b0 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -10,6 +10,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.yolean.Exceptions;
import org.junit.Test;
@@ -131,7 +132,9 @@ public class ModelsEvaluatorTest {
RankProfilesConfig.class).getConfig("");
RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
RankingConstantsConfig.class).getConfig("");
- return new ModelsEvaluator(config, constantsConfig, MockFileAcquirer.returnFile(null));
+ OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
+ OnnxModelsConfig.class).getConfig("");
+ return new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, MockFileAcquirer.returnFile(null));
}
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
new file mode 100644
index 00000000000..1d55fdf9e6a
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
@@ -0,0 +1,69 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.config.subscription.ConfigGetter;
+import com.yahoo.config.subscription.FileSource;
+import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
+import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
+import com.yahoo.path.Path;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author lesters
+ */
+public class OnnxEvaluatorTest {
+
+ private static final double delta = 0.00000000001;
+
+ @Test
+ public void testOnnxEvaluation() {
+ ModelsEvaluator models = createModels("src/test/resources/config/onnx/");
+
+ assertTrue(models.models().containsKey("add_mul"));
+ assertTrue(models.models().containsKey("one_layer"));
+
+ FunctionEvaluator function = models.evaluatorOf("add_mul", "output1");
+ function.bind("input1", Tensor.from("tensor<float>(d0[1]):[2]"));
+ function.bind("input2", Tensor.from("tensor<float>(d0[1]):[3]"));
+ assertEquals(6.0, function.evaluate().sum().asDouble(), delta);
+
+ function = models.evaluatorOf("add_mul", "output2");
+ function.bind("input1", Tensor.from("tensor<float>(d0[1]):[2]"));
+ function.bind("input2", Tensor.from("tensor<float>(d0[1]):[3]"));
+ assertEquals(5.0, function.evaluate().sum().asDouble(), delta);
+
+ function = models.evaluatorOf("one_layer");
+ function.bind("input", Tensor.from("tensor<float>(d0[2],d1[3]):[[0.1, 0.2, 0.3],[0.4,0.5,0.6]]"));
+ assertEquals(function.evaluate(), Tensor.from("tensor<float>(d0[2],d1[1]):[0.63931,0.67574]"));
+ }
+
+ private ModelsEvaluator createModels(String path) {
+ Path configDir = Path.fromString(path);
+ RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()),
+ RankProfilesConfig.class).getConfig("");
+ RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
+ RankingConstantsConfig.class).getConfig("");
+ OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
+ OnnxModelsConfig.class).getConfig("");
+
+ Map<String, File> fileMap = new HashMap<>();
+ for (OnnxModelsConfig.Model onnxModel : onnxModelsConfig.model()) {
+ fileMap.put(onnxModel.fileref().value(), new File(path + onnxModel.fileref().value()));
+ }
+ FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap);
+
+ return new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, fileAcquirer);
+ }
+
+}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
new file mode 100644
index 00000000000..0da7f2ed096
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
@@ -0,0 +1,76 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.handler;
+
+import ai.vespa.models.evaluation.ModelsEvaluator;
+import com.yahoo.container.jdisc.HttpRequest;
+import com.yahoo.container.jdisc.HttpResponse;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.serialization.JsonFormat;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.Executors;
+
+import static org.junit.Assert.assertEquals;
+
+class HandlerTester {
+
+ private final ModelsEvaluationHandler handler;
+
+ HandlerTester(ModelsEvaluator models) {
+ this.handler = new ModelsEvaluationHandler(models, Executors.newSingleThreadExecutor());
+ }
+
+ void assertResponse(String url, int expectedCode) {
+ assertResponse(url, Collections.emptyMap(), expectedCode, (String)null);
+ }
+
+ void assertResponse(String url, int expectedCode, String expectedResult) {
+ assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult);
+ }
+
+ void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) {
+ HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties);
+ HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties);
+ assertResponse(getRequest, expectedCode, expectedResult);
+ assertResponse(postRequest, expectedCode, expectedResult);
+ }
+
+ void assertResponse(String url, Map<String, String> properties, int expectedCode, Tensor expectedResult) {
+ HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties);
+ assertResponse(getRequest, expectedCode, expectedResult);
+ }
+
+ void assertResponse(HttpRequest request, int expectedCode, String expectedResult) {
+ HttpResponse response = handler.handle(request);
+ assertEquals("application/json", response.getContentType());
+ assertEquals(expectedCode, response.getStatus());
+ if (expectedResult != null) {
+ assertEquals(expectedResult, getContents(response));
+ }
+ }
+
+ void assertResponse(HttpRequest request, int expectedCode, Tensor expectedResult) {
+ HttpResponse response = handler.handle(request);
+ assertEquals("application/json", response.getContentType());
+ assertEquals(expectedCode, response.getStatus());
+ if (expectedResult != null) {
+ String contents = getContents(response);
+ Tensor result = JsonFormat.decode(expectedResult.type(), contents.getBytes(StandardCharsets.UTF_8));
+ assertEquals(expectedResult, result);
+ }
+ }
+
+ private String getContents(HttpResponse response) {
+ try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) {
+ response.render(stream);
+ return stream.toString();
+ } catch (IOException e) {
+ throw new Error(e);
+ }
+ }
+
+}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index c9e49d3be02..a69a220e532 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -5,51 +5,41 @@ import ai.vespa.models.evaluation.ModelTester;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.config.subscription.FileSource;
-import com.yahoo.container.jdisc.HttpRequest;
-import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
import com.yahoo.path.Path;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import org.junit.BeforeClass;
import org.junit.Test;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
-import java.util.concurrent.Executor;
-import java.util.concurrent.Executors;
-
-import static org.junit.Assert.assertEquals;
public class ModelsEvaluationHandlerTest {
- private static ModelsEvaluationHandler handler;
+ private static HandlerTester handler;
@BeforeClass
static public void setUp() {
- Executor executor = Executors.newSingleThreadExecutor();
- ModelsEvaluator models = createModels("src/test/resources/config/models/");
- handler = new ModelsEvaluationHandler(models, executor);
+ handler = new HandlerTester(createModels("src/test/resources/config/models/"));
}
@Test
public void testUnknownAPI() {
- assertResponse("http://localhost/wrong-api-binding", 404);
+ handler.assertResponse("http://localhost/wrong-api-binding", 404);
}
@Test
public void testUnknownVersion() {
- assertResponse("http://localhost/model-evaluation/v0", 404);
+ handler.assertResponse("http://localhost/model-evaluation/v0", 404);
}
@Test
public void testNonExistingModel() {
- assertResponse("http://localhost/model-evaluation/v1/non-existing-model", 404);
+ handler.assertResponse("http://localhost/model-evaluation/v1/non-existing-model", 404);
}
@Test
@@ -57,14 +47,14 @@ public class ModelsEvaluationHandlerTest {
String url = "http://localhost/model-evaluation/v1";
String expected =
"{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
public void testXgBoostEvaluationWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; // only has a single function
String expected = "{\"cells\":[{\"address\":{},\"value\":-4.376589999999999}]}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
@@ -77,7 +67,7 @@ public class ModelsEvaluationHandlerTest {
properties.put("non-existing-binding", "-1");
String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval";
String expected = "{\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}";
- assertResponse(url, properties, 200, expected);
+ handler.assertResponse(url, properties, 200, expected);
}
@Test
@@ -90,14 +80,14 @@ public class ModelsEvaluationHandlerTest {
properties.put("non-existing-binding", "-1");
String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval";
String expected = "{\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}";
- assertResponse(url, properties, 200, expected);
+ handler.assertResponse(url, properties, 200, expected);
}
@Test
public void testLightGBMEvaluationWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval";
String expected = "{\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
@@ -110,7 +100,7 @@ public class ModelsEvaluationHandlerTest {
properties.put("non-existing-binding", "-1");
String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval";
String expected = "{\"cells\":[{\"address\":{},\"value\":2.054697758469921}]}";
- assertResponse(url, properties, 200, expected);
+ handler.assertResponse(url, properties, 200, expected);
}
@Test
@@ -123,35 +113,35 @@ public class ModelsEvaluationHandlerTest {
properties.put("non-existing-binding", "-1");
String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval";
String expected = "{\"cells\":[{\"address\":{},\"value\":2.0745534018208094}]}";
- assertResponse(url, properties, 200, expected);
+ handler.assertResponse(url, properties, 200, expected);
}
@Test
public void testMnistSoftmaxDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax";
String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}]}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
public void testMnistSoftmaxTypeDetails() {
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/";
String expected = "{\"model\":\"mnist_softmax\",\"function\":\"default.add\",\"info\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
public void testMnistSoftmaxEvaluateDefaultFunctionWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval";
String expected = "{\"error\":\"Argument 'Placeholder' must be bound to a value of type tensor(d0[],d1[784])\"}";
- assertResponse(url, 400, expected);
+ handler.assertResponse(url, 400, expected);
}
@Test
public void testMnistSoftmaxEvaluateSpecificFunctionWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
String expected = "{\"error\":\"Argument 'Placeholder' must be bound to a value of type tensor(d0[],d1[784])\"}";
- assertResponse(url, 400, expected);
+ handler.assertResponse(url, 400, expected);
}
@Test
@@ -160,7 +150,7 @@ public class ModelsEvaluationHandlerTest {
properties.put("Placeholder", inputTensor());
String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval";
String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
- assertResponse(url, properties, 200, expected);
+ handler.assertResponse(url, properties, 200, expected);
}
@Test
@@ -169,28 +159,28 @@ public class ModelsEvaluationHandlerTest {
properties.put("Placeholder", inputTensor());
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
- assertResponse(url, properties, 200, expected);
+ handler.assertResponse(url, properties, 200, expected);
}
@Test
public void testMnistSavedDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_saved";
String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}]}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
public void testMnistSavedTypeDetails() {
String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/";
String expected = "{\"model\":\"mnist_saved\",\"function\":\"serving_default.y\",\"info\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
public void testMnistSavedEvaluateDefaultFunctionShouldFail() {
String url = "http://localhost/model-evaluation/v1/mnist_saved/eval";
String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_function_mnist_saved_dnn_hidden1_add, serving_default.y\"}";
- assertResponse(url, 404, expected);
+ handler.assertResponse(url, 404, expected);
}
@Test
@@ -199,40 +189,7 @@ public class ModelsEvaluationHandlerTest {
properties.put("input", inputTensor());
String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval";
String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.6319251673007533},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":-7.577770600619843E-4},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":-0.010707969042025622},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.6344759233540788},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":-0.17529455385847528},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":0.7490809723192187},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.022790284182901716},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.26799240657608936},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-0.3152438845465862},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":0.05949304847735276}]}";
- assertResponse(url, properties, 200, expected);
- }
-
- static private void assertResponse(String url, int expectedCode) {
- assertResponse(url, Collections.emptyMap(), expectedCode, null);
- }
-
- static private void assertResponse(String url, int expectedCode, String expectedResult) {
- assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult);
- }
-
- static private void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) {
- HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties);
- HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties);
- assertResponse(getRequest, expectedCode, expectedResult);
- assertResponse(postRequest, expectedCode, expectedResult);
- }
-
- static private void assertResponse(HttpRequest request, int expectedCode, String expectedResult) {
- HttpResponse response = handler.handle(request);
- assertEquals("application/json", response.getContentType());
- if (expectedResult != null) {
- assertEquals(expectedResult, getContents(response));
- }
- assertEquals(expectedCode, response.getStatus());
- }
-
- static private String getContents(HttpResponse response) {
- try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) {
- response.render(stream);
- return stream.toString();
- } catch (IOException e) {
- throw new Error(e);
- }
+ handler.assertResponse(url, properties, 200, expected);
}
static private ModelsEvaluator createModels(String path) {
@@ -241,10 +198,12 @@ public class ModelsEvaluationHandlerTest {
RankProfilesConfig.class).getConfig("");
RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
RankingConstantsConfig.class).getConfig("");
+ OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
+ OnnxModelsConfig.class).getConfig("");
ModelTester.RankProfilesConfigImporterWithMockedConstants importer =
new ModelTester.RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"),
MockFileAcquirer.returnFile(null));
- return new ModelsEvaluator(importer.importFrom(config, constantsConfig));
+ return new ModelsEvaluator(importer.importFrom(config, constantsConfig, onnxModelsConfig));
}
private String inputTensor() {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
new file mode 100644
index 00000000000..6cfda4d8ce8
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
@@ -0,0 +1,137 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.handler;
+
+import ai.vespa.models.evaluation.ModelsEvaluator;
+import com.yahoo.config.subscription.ConfigGetter;
+import com.yahoo.config.subscription.FileSource;
+import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
+import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
+import com.yahoo.path.Path;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.HashMap;
+import java.util.Map;
+
+public class OnnxEvaluationHandlerTest {
+
+ private static HandlerTester handler;
+
+ @BeforeClass
+ static public void setUp() {
+ handler = new HandlerTester(createModels("src/test/resources/config/onnx/"));
+ }
+
+ @Test
+ public void testListModels() {
+ String url = "http://localhost/model-evaluation/v1";
+ String expected = "{\"one_layer\":\"http://localhost/model-evaluation/v1/one_layer\"," +
+ "\"add_mul\":\"http://localhost/model-evaluation/v1/add_mul\"," +
+ "\"no_model\":\"http://localhost/model-evaluation/v1/no_model\"}";
+ handler.assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testModelInfo() {
+ String url = "http://localhost/model-evaluation/v1/add_mul";
+ String expected = "{\"model\":\"add_mul\",\"functions\":[" +
+ "{\"function\":\"output1\"," +
+ "\"info\":\"http://localhost/model-evaluation/v1/add_mul/output1\"," +
+ "\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output1/eval\"," +
+ "\"arguments\":[" +
+ "{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," +
+ "{\"name\":\"onnxModel(add_mul).output1\",\"type\":\"tensor<float>(d0[1])\"}," +
+ "{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" +
+ "]}," +
+ "{\"function\":\"output2\"," +
+ "\"info\":\"http://localhost/model-evaluation/v1/add_mul/output2\"," +
+ "\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output2/eval\"," +
+ "\"arguments\":[" +
+ "{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," +
+ "{\"name\":\"onnxModel(add_mul).output2\",\"type\":\"tensor<float>(d0[1])\"}," +
+ "{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" +
+ "]}]}";
+ handler.assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testEvaluationWithoutSpecifyingOutput() {
+ String url = "http://localhost/model-evaluation/v1/add_mul/eval";
+ String expected = "{\"error\":\"More than one function is available in model 'add_mul', but no name is given. Available functions: output1, output2\"}";
+ handler.assertResponse(url, 404, expected);
+ }
+
+ @Test
+ public void testEvaluationWithoutBindings() {
+ String url = "http://localhost/model-evaluation/v1/add_mul/output1/eval";
+ String expected = "{\"error\":\"Argument 'input2' must be bound to a value of type tensor<float>(d0[1])\"}";
+ handler.assertResponse(url, 400, expected);
+ }
+
+ @Test
+ public void testEvaluationOutput1() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("input1", "tensor<float>(d0[1]):[2]");
+ properties.put("input2", "tensor<float>(d0[1]):[3]");
+ String url = "http://localhost/model-evaluation/v1/add_mul/output1/eval";
+ String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":6.0}]}"; // output1 is a mul
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
+ public void testEvaluationOutput2() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("input1", "tensor<float>(d0[1]):[2]");
+ properties.put("input2", "tensor<float>(d0[1]):[3]");
+ String url = "http://localhost/model-evaluation/v1/add_mul/output2/eval";
+ String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":5.0}]}"; // output2 is an add
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
+ public void testBatchDimensionModelInfo() {
+ String url = "http://localhost/model-evaluation/v1/one_layer";
+ String expected = "{\"model\":\"one_layer\",\"functions\":[" +
+ "{\"function\":\"output\"," +
+ "\"info\":\"http://localhost/model-evaluation/v1/one_layer/output\"," +
+ "\"eval\":\"http://localhost/model-evaluation/v1/one_layer/output/eval\"," +
+ "\"arguments\":[" +
+ "{\"name\":\"input\",\"type\":\"tensor<float>(d0[],d1[3])\"}," +
+ "{\"name\":\"onnxModel(one_layer)\",\"type\":\"tensor<float>(d0[],d1[1])\"}" +
+ "]}]}";
+ handler.assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testBatchDimensionEvaluation() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("input", "tensor<float>(d0[],d1[3]):{{d0:0,d1:0}:0.1,{d0:0,d1:1}:0.2,{d0:0,d1:2}:0.3,{d0:1,d1:0}:0.4,{d0:1,d1:1}:0.5,{d0:1,d1:2}:0.6}");
+ String url = "http://localhost/model-evaluation/v1/one_layer/eval"; // output not specified
+ Tensor expected = Tensor.from("tensor<float>(d0[2],d1[1]):[0.6393113,0.67574286]");
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ static private ModelsEvaluator createModels(String path) {
+ Path configDir = Path.fromString(path);
+ RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()),
+ RankProfilesConfig.class).getConfig("");
+ RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
+ RankingConstantsConfig.class).getConfig("");
+ OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
+ OnnxModelsConfig.class).getConfig("");
+
+ Map<String, File> fileMap = new HashMap<>();
+ for (OnnxModelsConfig.Model onnxModel : onnxModelsConfig.model()) {
+ fileMap.put(onnxModel.fileref().value(), new File(path + onnxModel.fileref().value()));
+ }
+ FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap);
+
+ return new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, fileAcquirer);
+ }
+
+}
diff --git a/model-evaluation/src/test/resources/config/models/onnx-models.cfg b/model-evaluation/src/test/resources/config/models/onnx-models.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/models/onnx-models.cfg
diff --git a/model-evaluation/src/test/resources/config/onnx/models/add_mul.onnx b/model-evaluation/src/test/resources/config/onnx/models/add_mul.onnx
new file mode 100644
index 00000000000..ab054d112e9
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/models/add_mul.onnx
@@ -0,0 +1,24 @@
+
+add_mul.py:£
+
+input1
+input2output1"Mul
+
+input1
+input2output2"Addadd_mulZ
+input1
+
+
+Z
+input2
+
+
+b
+output1
+
+
+b
+output2
+
+
+B \ No newline at end of file
diff --git a/model-evaluation/src/test/resources/config/onnx/models/add_mul.py b/model-evaluation/src/test/resources/config/onnx/models/add_mul.py
new file mode 100755
index 00000000000..3a4522042e8
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/models/add_mul.py
@@ -0,0 +1,30 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1])
+INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1])
+OUTPUT_1 = helper.make_tensor_value_info('output1', TensorProto.FLOAT, [1])
+OUTPUT_2 = helper.make_tensor_value_info('output2', TensorProto.FLOAT, [1])
+
+nodes = [
+ helper.make_node(
+ 'Mul',
+ ['input1', 'input2'],
+ ['output1'],
+ ),
+ helper.make_node(
+ 'Add',
+ ['input1', 'input2'],
+ ['output2'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'add_mul',
+ [INPUT_1, INPUT_2],
+ [OUTPUT_1, OUTPUT_2],
+)
+model_def = helper.make_model(graph_def, producer_name='add_mul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'add_mul.onnx')
diff --git a/model-evaluation/src/test/resources/config/onnx/models/one_layer.onnx b/model-evaluation/src/test/resources/config/onnx/models/one_layer.onnx
new file mode 100644
index 00000000000..dc9f664b943
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/models/one_layer.onnx
Binary files differ
diff --git a/model-evaluation/src/test/resources/config/onnx/models/pytorch_one_layer.py b/model-evaluation/src/test/resources/config/onnx/models/pytorch_one_layer.py
new file mode 100755
index 00000000000..1296d84e180
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/models/pytorch_one_layer.py
@@ -0,0 +1,38 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import torch
+import torch.onnx
+
+
+class MyModel(torch.nn.Module):
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self.linear = torch.nn.Linear(in_features=3, out_features=1)
+ self.logistic = torch.nn.Sigmoid()
+
+ def forward(self, vec):
+ return self.logistic(self.linear(vec))
+
+
+def main():
+ model = MyModel()
+
+ # Omit training - just export randomly initialized network
+
+ data = torch.FloatTensor([[0.1, 0.2, 0.3],[0.4, 0.5, 0.6]])
+ torch.onnx.export(model,
+ data,
+ "one_layer.onnx",
+ input_names = ["input"],
+ output_names = ["output"],
+ dynamic_axes = {
+ "input": {0: "batch"},
+ "output": {0: "batch"},
+ },
+ opset_version=12)
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/model-evaluation/src/test/resources/config/onnx/onnx-models.cfg b/model-evaluation/src/test/resources/config/onnx/onnx-models.cfg
new file mode 100644
index 00000000000..9ad9c7f6a07
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/onnx-models.cfg
@@ -0,0 +1,16 @@
+model[0].name "add_mul"
+model[0].fileref "models/add_mul.onnx"
+model[0].input[0].name "input1"
+model[0].input[0].source "input1"
+model[0].input[1].name "input2"
+model[0].input[1].source "input2"
+model[0].output[0].name "output1"
+model[0].output[0].as "output1"
+model[0].output[1].name "output2"
+model[0].output[1].as "output2"
+model[1].name "one_layer"
+model[1].fileref "models/one_layer.onnx"
+model[1].input[0].name "input"
+model[1].input[0].source "input"
+model[1].output[0].name "output"
+model[1].output[0].as "output"
diff --git a/model-evaluation/src/test/resources/config/onnx/rank-profiles.cfg b/model-evaluation/src/test/resources/config/onnx/rank-profiles.cfg
new file mode 100644
index 00000000000..047b7c3c77b
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/rank-profiles.cfg
@@ -0,0 +1,17 @@
+rankprofile[0].name "add_mul"
+rankprofile[0].fef.property[0].name "rankingExpression(output1).rankingScript"
+rankprofile[0].fef.property[0].value "onnxModel(add_mul).output1"
+rankprofile[0].fef.property[1].name "rankingExpression(output1).type"
+rankprofile[0].fef.property[1].value "tensor<float>(d0[1])"
+rankprofile[0].fef.property[2].name "rankingExpression(output2).rankingScript"
+rankprofile[0].fef.property[2].value "onnxModel(add_mul).output2"
+rankprofile[0].fef.property[3].name "rankingExpression(output2).type"
+rankprofile[0].fef.property[3].value "tensor<float>(d0[1])"
+rankprofile[1].name "one_layer"
+rankprofile[1].fef.property[0].name "rankingExpression(output).rankingScript"
+rankprofile[1].fef.property[0].value "onnxModel(one_layer)"
+rankprofile[1].fef.property[1].name "rankingExpression(output).type"
+rankprofile[1].fef.property[1].value "tensor<float>(d0[],d1[1])"
+rankprofile[2].name "no_model"
+rankprofile[2].fef.property[0].name "rankingExpression(output).rankingScript"
+rankprofile[2].fef.property[0].value "onnxModel(no_model)"
diff --git a/model-evaluation/src/test/resources/config/onnx/ranking-constants.cfg b/model-evaluation/src/test/resources/config/onnx/ranking-constants.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/ranking-constants.cfg
diff --git a/model-evaluation/src/test/resources/config/rankexpression/onnx-models.cfg b/model-evaluation/src/test/resources/config/rankexpression/onnx-models.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/rankexpression/onnx-models.cfg
diff --git a/model-evaluation/src/test/resources/config/smallconstant/onnx-models.cfg b/model-evaluation/src/test/resources/config/smallconstant/onnx-models.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/smallconstant/onnx-models.cfg