diff options
author | Lester Solbakken <lesters@oath.com> | 2021-05-20 11:29:12 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-05-20 11:29:12 +0200 |
commit | 4a126bdd16323226411561b969e581af90260692 (patch) | |
tree | 50bd5318dd8e0f174ff26a41a786042b787c9001 /model-evaluation | |
parent | fc0711f7870b55ea77d18d87ec3e70b75e0de2e0 (diff) |
Evaluate ONNX models in model-evaluation with ONNX RT
Diffstat (limited to 'model-evaluation')
24 files changed, 678 insertions, 96 deletions
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json index d465464de7f..63882525808 100644 --- a/model-evaluation/abi-spec.json +++ b/model-evaluation/abi-spec.json @@ -39,6 +39,7 @@ "public int size()", "public java.util.Set names()", "public java.util.Set arguments()", + "public java.util.Map onnxModels()", "public com.yahoo.searchlib.rankingexpression.evaluation.Value defaultValue()", "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)" ], @@ -66,7 +67,7 @@ "public" ], "methods": [ - "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)", + "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)", "public void <init>(java.util.Map)", "public java.util.Map models()", "public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String, java.lang.String[])", @@ -82,7 +83,7 @@ ], "methods": [ "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer)", - "public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig)", + "public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)", "protected com.yahoo.tensor.Tensor readTensorFromFile(java.lang.String, com.yahoo.tensor.TensorType, com.yahoo.config.FileReference)" ], "fields": [] diff --git a/model-evaluation/pom.xml b/model-evaluation/pom.xml index 00560a22bc7..8cdff451b42 100644 --- a/model-evaluation/pom.xml +++ b/model-evaluation/pom.xml @@ -68,6 +68,12 @@ <scope>provided</scope> </dependency> <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>model-integration</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> <scope>provided</scope> diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index e373a54bcd1..910aca8aa98 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; @@ -100,19 +101,40 @@ public class FunctionEvaluator { } public Tensor evaluate() { + evaluateOnnxModels(); for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { - if (context.isMissing(argument.getKey())) - throw new IllegalStateException("Missing argument '" + argument.getKey() + - "': Must be bound to a value of type " + argument.getValue()); - if (! context.get(argument.getKey()).type().isAssignableTo(argument.getValue())) - throw new IllegalStateException("Argument '" + argument.getKey() + - "' must be bound to a value of type " + argument.getValue()); - + checkArgument(argument.getKey(), argument.getValue()); } evaluated = true; return function.getBody().evaluate(context).asTensor(); } + private void checkArgument(String name, TensorType type) { + if (context.isMissing(name)) + throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type); + if (! context.get(name).type().isAssignableTo(type)) + throw new IllegalStateException("Argument '" + name + "' must be bound to a value of type " + type); + } + + /** + * Evaluate ONNX models (if not already evaluated) and add the result back to the context. + */ + private void evaluateOnnxModels() { + for (Map.Entry<String, OnnxModel> entry : context().onnxModels().entrySet()) { + String onnxFeature = entry.getKey(); + OnnxModel onnxModel = entry.getValue(); + if (context.get(onnxFeature).equals(context.defaultValue())) { + Map<String, Tensor> inputs = new HashMap<>(); + for (Map.Entry<String, TensorType> input: onnxModel.inputs().entrySet()) { + checkArgument(input.getKey(), input.getValue()); + inputs.put(input.getKey(), context.get(input.getKey()).asTensor()); + } + Tensor result = onnxModel.evaluate(inputs, function.getName()); // Function name is output of model + context.put(onnxFeature, new TensorValue(result)); + } + } + } + /** Returns the function evaluated by this */ public ExpressionFunction function() { return function; } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index d66315ef457..a5dcd2719c9 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -11,15 +11,18 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.Arrays; +import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; /** @@ -41,9 +44,10 @@ public final class LazyArrayContext extends Context implements ContextIndex { LazyArrayContext(ExpressionFunction function, Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, + List<OnnxModel> onnxModels, Model model) { this.function = function; - this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model); + this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, onnxModels, this, model); } /** @@ -117,6 +121,9 @@ public final class LazyArrayContext extends Context implements ContextIndex { /** Returns the (immutable) subset of names in this which must be bound when invoking */ public Set<String> arguments() { return indexedBindings.arguments(); } + /** Returns the set of ONNX models that need to be evaluated on this context */ + public Map<String, OnnxModel> onnxModels() { return indexedBindings.onnxModels(); } + private Integer requireIndexOf(String name) { Integer index = indexedBindings.indexOf(name); if (index == null) @@ -152,18 +159,24 @@ public final class LazyArrayContext extends Context implements ContextIndex { /** The current values set */ private final Value[] values; + /** ONNX models indexed by rank feature that calls them */ + private final ImmutableMap<String, OnnxModel> onnxModels; + /** The object instance which encodes "no value is set". The actual value of this is never used. */ private static final Value missing = new DoubleValue(Double.NaN).freeze(); /** The value to return for lookups where no value is set (default: NaN) */ private Value missingValue = new DoubleValue(Double.NaN).freeze(); + private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values, - ImmutableSet<String> arguments) { + ImmutableSet<String> arguments, + ImmutableMap<String, OnnxModel> onnxModels) { this.nameToIndex = nameToIndex; this.values = values; this.arguments = arguments; + this.onnxModels = onnxModels; } /** @@ -173,13 +186,16 @@ public final class LazyArrayContext extends Context implements ContextIndex { IndexedBindings(ExpressionFunction function, Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, + List<OnnxModel> onnxModels, LazyArrayContext owner, Model model) { // 1. Determine and prepare bind targets Set<String> bindTargets = new LinkedHashSet<>(); Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation - extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments); + Map<String, OnnxModel> onnxModelsInUse = new HashMap<>(); + extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse); + this.onnxModels = ImmutableMap.copyOf(onnxModelsInUse); this.arguments = ImmutableSet.copyOf(arguments); values = new Value[bindTargets.size()]; Arrays.fill(values, missing); @@ -214,12 +230,18 @@ public final class LazyArrayContext extends Context implements ContextIndex { private void extractBindTargets(ExpressionNode node, Map<FunctionReference, ExpressionFunction> functions, Set<String> bindTargets, - Set<String> arguments) { + Set<String> arguments, + List<OnnxModel> onnxModels, + Map<String, OnnxModel> onnxModelsInUse) { if (isFunctionReference(node)) { FunctionReference reference = FunctionReference.fromSerial(node.toString()).get(); bindTargets.add(reference.serialForm()); - extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets, arguments); + ExpressionNode functionNode = functions.get(reference).getBody().getRoot(); + extractBindTargets(functionNode, functions, bindTargets, arguments, onnxModels, onnxModelsInUse); + } + else if (isOnnx(node)) { + extractOnnxTargets(node, bindTargets, arguments, onnxModels, onnxModelsInUse); } else if (isConstant(node)) { bindTargets.add(node.toString()); @@ -231,20 +253,81 @@ public final class LazyArrayContext extends Context implements ContextIndex { else if (node instanceof CompositeNode) { CompositeNode cNode = (CompositeNode)node; for (ExpressionNode child : cNode.children()) - extractBindTargets(child, functions, bindTargets, arguments); + extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse); + } + } + + /** + * Extract the feature used to evaluate the onnx model. e.g. onnxModel(name) and add + * that as a bind target and argument. During evaluation, this will be evaluated before + * the rest of the expression and the result is added to the context. Also extract the + * inputs to the model and add them as bind targets and arguments. + */ + private void extractOnnxTargets(ExpressionNode node, + Set<String> bindTargets, + Set<String> arguments, + List<OnnxModel> onnxModels, + Map<String, OnnxModel> onnxModelsInUse) { + Optional<String> modelName = getArgument(node); + if (modelName.isPresent()) { + for (OnnxModel onnxModel : onnxModels) { + if (onnxModel.name().equals(modelName.get())) { + String onnxFeature = node.toString(); + bindTargets.add(onnxFeature); + arguments.add(onnxFeature); + + // Load the model (if not already loaded) to extract inputs + onnxModel.load(); + + for(String input : onnxModel.inputs().keySet()) { + bindTargets.add(input); + arguments.add(input); + } + onnxModelsInUse.put(onnxFeature, onnxModel); + } + } } } + private Optional<String> getArgument(ExpressionNode node) { + if (node instanceof ReferenceNode) { + ReferenceNode reference = (ReferenceNode) node; + if (reference.getArguments().size() > 0) { + if (reference.getArguments().expressions().get(0) instanceof ConstantNode) { + ConstantNode constantNode = (ConstantNode) reference.getArguments().expressions().get(0); + return Optional.of(stripQuotes(constantNode.sourceString())); + } + if (reference.getArguments().expressions().get(0) instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) reference.getArguments().expressions().get(0); + return Optional.of(referenceNode.getName()); + } + } + } + return Optional.empty(); + } + + public static String stripQuotes(String s) { + if (s.codePointAt(0) == '"' && s.codePointAt(s.length()-1) == '"') + return s.substring(1, s.length()-1); + if (s.codePointAt(0) == '\'' && s.codePointAt(s.length()-1) == '\'') + return s.substring(1, s.length()-1); + return s; + } + private boolean isFunctionReference(ExpressionNode node) { if ( ! (node instanceof ReferenceNode)) return false; - ReferenceNode reference = (ReferenceNode)node; return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1; } - private boolean isConstant(ExpressionNode node) { + private boolean isOnnx(ExpressionNode node) { if ( ! (node instanceof ReferenceNode)) return false; + ReferenceNode reference = (ReferenceNode) node; + return reference.getName().equals("onnx") || reference.getName().equals("onnxModel"); + } + private boolean isConstant(ExpressionNode node) { + if ( ! (node instanceof ReferenceNode)) return false; ReferenceNode reference = (ReferenceNode)node; return reference.getName().equals("constant") && reference.getArguments().size() == 1; } @@ -261,12 +344,13 @@ public final class LazyArrayContext extends Context implements ContextIndex { Set<String> names() { return nameToIndex.keySet(); } Set<String> arguments() { return arguments; } Integer indexOf(String name) { return nameToIndex.get(name); } + Map<String, OnnxModel> onnxModels() { return onnxModels; } IndexedBindings copy(Context context) { Value[] valueCopy = new Value[values.length]; for (int i = 0; i < values.length; i++) valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue) values[i]).copyFor(context) : values[i]; - return new IndexedBindings(nameToIndex, valueCopy, arguments); + return new IndexedBindings(nameToIndex, valueCopy, arguments, onnxModels); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 03bbb436026..40a84a701ec 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -26,7 +26,7 @@ import java.util.stream.Collectors; @Beta public class Model { - /** The prefix generated by mode-integration/../IntermediateOperation */ + /** The prefix generated by model-integration/../IntermediateOperation */ private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_"; private final String name; @@ -50,25 +50,37 @@ public class Model { this(name, functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)), Collections.emptyMap(), + Collections.emptyList(), Collections.emptyList()); } Model(String name, Map<FunctionReference, ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions, - List<Constant> constants) { + List<Constant> constants, + List<OnnxModel> onnxModels) { this.name = name; // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments) ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { - LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this); + LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this); contextBuilder.put(function.getValue().getName(), context); if ( ! function.getValue().returnType().isPresent()) { functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); } + for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) { + String onnxFeature = entry.getKey(); + OnnxModel onnxModel = entry.getValue(); + for(Map.Entry<String, TensorType> input : onnxModel.inputs().entrySet()) { + functions.put(function.getKey(), function.getValue().withArgument(input.getKey(), input.getValue())); + } + TensorType onnxOutputType = onnxModel.outputs().get(function.getKey().functionName()); + functions.put(function.getKey(), function.getValue().withArgument(onnxFeature, onnxOutputType)); + } + for (String argument : context.arguments()) { if (function.getValue().getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)) { // Internal (generated) functions do not have type info - add arguments diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index a0b859bf930..88766da67fc 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -7,6 +7,7 @@ import com.google.inject.Inject; import com.yahoo.component.AbstractComponent; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import java.util.Map; @@ -27,8 +28,9 @@ public class ModelsEvaluator extends AbstractComponent { @Inject public ModelsEvaluator(RankProfilesConfig config, RankingConstantsConfig constantsConfig, + OnnxModelsConfig onnxModelsConfig, FileAcquirer fileAcquirer) { - this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig)); + this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig, onnxModelsConfig)); } public ModelsEvaluator(Map<String, Model> models) { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java new file mode 100644 index 00000000000..dc27c43ef70 --- /dev/null +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -0,0 +1,57 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.io.File; +import java.util.Map; + +/** + * A named ONNX model that should be evaluated with OnnxEvaluator. + * + * @author lesters + */ +class OnnxModel { + + private final String name; + private final File modelFile; + + private OnnxEvaluator evaluator; + + OnnxModel(String name, File modelFile) { + this.name = name; + this.modelFile = modelFile; + } + + public String name() { + return name; + } + + public void load() { + if (evaluator == null) { + evaluator = new OnnxEvaluator(modelFile.getPath()); + } + } + + public Map<String, TensorType> inputs() { + return evaluator().getInputInfo(); + } + + public Map<String, TensorType> outputs() { + return evaluator().getOutputInfo(); + } + + public Tensor evaluate(Map<String, Tensor> inputs, String output) { + return evaluator().evaluate(inputs, output); + } + + private OnnxEvaluator evaluator() { + if (evaluator == null) { + throw new IllegalStateException("ONNX model has not been loaded."); + } + return evaluator; + } + +} diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index fb424439592..1bdb2810ddf 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -13,6 +13,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import java.io.File; @@ -48,11 +49,13 @@ public class RankProfilesConfigImporter { * Returns a map of the models contained in this config, indexed on name. * The map is modifiable and owned by the caller. */ - public Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) { + public Map<String, Model> importFrom(RankProfilesConfig config, + RankingConstantsConfig constantsConfig, + OnnxModelsConfig onnxModelsConfig) { try { Map<String, Model> models = new HashMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { - Model model = importProfile(profile, constantsConfig); + Model model = importProfile(profile, constantsConfig, onnxModelsConfig); models.put(model.name(), model); } return models; @@ -62,9 +65,12 @@ public class RankProfilesConfigImporter { } } - private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) + private Model importProfile(RankProfilesConfig.Rankprofile profile, + RankingConstantsConfig constantsConfig, + OnnxModelsConfig onnxModelsConfig) throws ParseException { + List<OnnxModel> onnxModels = readOnnxModelsConfig(onnxModelsConfig); List<Constant> constants = readLargeConstants(constantsConfig); Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>(); @@ -76,7 +82,7 @@ public class RankProfilesConfigImporter { Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name()); Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name()); Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name()); - if ( reference.isPresent()) { + if (reference.isPresent()) { RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value()); ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), Collections.emptyList(), @@ -122,7 +128,7 @@ public class RankProfilesConfigImporter { constants.addAll(smallConstantsInfo.asConstants()); try { - return new Model(profile.name(), functions, referencedFunctions, constants); + return new Model(profile.name(), functions, referencedFunctions, constants, onnxModels); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); @@ -136,6 +142,26 @@ public class RankProfilesConfigImporter { return null; } + private List<OnnxModel> readOnnxModelsConfig(OnnxModelsConfig onnxModelsConfig) { + List<OnnxModel> onnxModels = new ArrayList<>(); + if (onnxModelsConfig != null) { + for (OnnxModelsConfig.Model onnxModelConfig : onnxModelsConfig.model()) { + onnxModels.add(readOnnxModelConfig(onnxModelConfig)); + } + } + return onnxModels; + } + + private OnnxModel readOnnxModelConfig(OnnxModelsConfig.Model onnxModelConfig) { + try { + String name = onnxModelConfig.name(); + File file = fileAcquirer.waitFor(onnxModelConfig.fileref(), 7, TimeUnit.DAYS); + return new OnnxModel(name, file); + } catch (InterruptedException e) { + throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name()); + } + } + private List<Constant> readLargeConstants(RankingConstantsConfig constantsConfig) { List<Constant> constants = new ArrayList<>(); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java index bacdb52a201..d252594e729 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java @@ -14,6 +14,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import java.io.IOException; @@ -45,8 +46,10 @@ public class ModelTester { RankProfilesConfig.class).getConfig(""); RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), RankingConstantsConfig.class).getConfig(""); + OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()), + OnnxModelsConfig.class).getConfig(""); return new RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"), MockFileAcquirer.returnFile(null)) - .importFrom(config, constantsConfig); + .importFrom(config, constantsConfig, onnxModelsConfig); } public ExpressionFunction assertFunction(String name, String expression, Model model) { diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index 6fcf76d2815..dce033c79b0 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -10,6 +10,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.yolean.Exceptions; import org.junit.Test; @@ -131,7 +132,9 @@ public class ModelsEvaluatorTest { RankProfilesConfig.class).getConfig(""); RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), RankingConstantsConfig.class).getConfig(""); - return new ModelsEvaluator(config, constantsConfig, MockFileAcquirer.returnFile(null)); + OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()), + OnnxModelsConfig.class).getConfig(""); + return new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, MockFileAcquirer.returnFile(null)); } } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java new file mode 100644 index 00000000000..1d55fdf9e6a --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java @@ -0,0 +1,69 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.config.subscription.ConfigGetter; +import com.yahoo.config.subscription.FileSource; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; +import com.yahoo.path.Path; +import com.yahoo.tensor.Tensor; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import org.junit.Test; + +import java.io.File; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * @author lesters + */ +public class OnnxEvaluatorTest { + + private static final double delta = 0.00000000001; + + @Test + public void testOnnxEvaluation() { + ModelsEvaluator models = createModels("src/test/resources/config/onnx/"); + + assertTrue(models.models().containsKey("add_mul")); + assertTrue(models.models().containsKey("one_layer")); + + FunctionEvaluator function = models.evaluatorOf("add_mul", "output1"); + function.bind("input1", Tensor.from("tensor<float>(d0[1]):[2]")); + function.bind("input2", Tensor.from("tensor<float>(d0[1]):[3]")); + assertEquals(6.0, function.evaluate().sum().asDouble(), delta); + + function = models.evaluatorOf("add_mul", "output2"); + function.bind("input1", Tensor.from("tensor<float>(d0[1]):[2]")); + function.bind("input2", Tensor.from("tensor<float>(d0[1]):[3]")); + assertEquals(5.0, function.evaluate().sum().asDouble(), delta); + + function = models.evaluatorOf("one_layer"); + function.bind("input", Tensor.from("tensor<float>(d0[2],d1[3]):[[0.1, 0.2, 0.3],[0.4,0.5,0.6]]")); + assertEquals(function.evaluate(), Tensor.from("tensor<float>(d0[2],d1[1]):[0.63931,0.67574]")); + } + + private ModelsEvaluator createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()), + OnnxModelsConfig.class).getConfig(""); + + Map<String, File> fileMap = new HashMap<>(); + for (OnnxModelsConfig.Model onnxModel : onnxModelsConfig.model()) { + fileMap.put(onnxModel.fileref().value(), new File(path + onnxModel.fileref().value())); + } + FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap); + + return new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, fileAcquirer); + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java new file mode 100644 index 00000000000..0da7f2ed096 --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java @@ -0,0 +1,76 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.handler; + +import ai.vespa.models.evaluation.ModelsEvaluator; +import com.yahoo.container.jdisc.HttpRequest; +import com.yahoo.container.jdisc.HttpResponse; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.JsonFormat; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.Executors; + +import static org.junit.Assert.assertEquals; + +class HandlerTester { + + private final ModelsEvaluationHandler handler; + + HandlerTester(ModelsEvaluator models) { + this.handler = new ModelsEvaluationHandler(models, Executors.newSingleThreadExecutor()); + } + + void assertResponse(String url, int expectedCode) { + assertResponse(url, Collections.emptyMap(), expectedCode, (String)null); + } + + void assertResponse(String url, int expectedCode, String expectedResult) { + assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult); + } + + void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) { + HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties); + HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties); + assertResponse(getRequest, expectedCode, expectedResult); + assertResponse(postRequest, expectedCode, expectedResult); + } + + void assertResponse(String url, Map<String, String> properties, int expectedCode, Tensor expectedResult) { + HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties); + assertResponse(getRequest, expectedCode, expectedResult); + } + + void assertResponse(HttpRequest request, int expectedCode, String expectedResult) { + HttpResponse response = handler.handle(request); + assertEquals("application/json", response.getContentType()); + assertEquals(expectedCode, response.getStatus()); + if (expectedResult != null) { + assertEquals(expectedResult, getContents(response)); + } + } + + void assertResponse(HttpRequest request, int expectedCode, Tensor expectedResult) { + HttpResponse response = handler.handle(request); + assertEquals("application/json", response.getContentType()); + assertEquals(expectedCode, response.getStatus()); + if (expectedResult != null) { + String contents = getContents(response); + Tensor result = JsonFormat.decode(expectedResult.type(), contents.getBytes(StandardCharsets.UTF_8)); + assertEquals(expectedResult, result); + } + } + + private String getContents(HttpResponse response) { + try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) { + response.render(stream); + return stream.toString(); + } catch (IOException e) { + throw new Error(e); + } + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index c9e49d3be02..a69a220e532 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -5,51 +5,41 @@ import ai.vespa.models.evaluation.ModelTester; import ai.vespa.models.evaluation.ModelsEvaluator; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; -import com.yahoo.container.jdisc.HttpRequest; -import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; import com.yahoo.path.Path; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.BeforeClass; import org.junit.Test; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; - -import static org.junit.Assert.assertEquals; public class ModelsEvaluationHandlerTest { - private static ModelsEvaluationHandler handler; + private static HandlerTester handler; @BeforeClass static public void setUp() { - Executor executor = Executors.newSingleThreadExecutor(); - ModelsEvaluator models = createModels("src/test/resources/config/models/"); - handler = new ModelsEvaluationHandler(models, executor); + handler = new HandlerTester(createModels("src/test/resources/config/models/")); } @Test public void testUnknownAPI() { - assertResponse("http://localhost/wrong-api-binding", 404); + handler.assertResponse("http://localhost/wrong-api-binding", 404); } @Test public void testUnknownVersion() { - assertResponse("http://localhost/model-evaluation/v0", 404); + handler.assertResponse("http://localhost/model-evaluation/v0", 404); } @Test public void testNonExistingModel() { - assertResponse("http://localhost/model-evaluation/v1/non-existing-model", 404); + handler.assertResponse("http://localhost/model-evaluation/v1/non-existing-model", 404); } @Test @@ -57,14 +47,14 @@ public class ModelsEvaluationHandlerTest { String url = "http://localhost/model-evaluation/v1"; String expected = "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"}"; - assertResponse(url, 200, expected); + handler.assertResponse(url, 200, expected); } @Test public void testXgBoostEvaluationWithoutBindings() { String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; // only has a single function String expected = "{\"cells\":[{\"address\":{},\"value\":-4.376589999999999}]}"; - assertResponse(url, 200, expected); + handler.assertResponse(url, 200, expected); } @Test @@ -77,7 +67,7 @@ public class ModelsEvaluationHandlerTest { properties.put("non-existing-binding", "-1"); String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; String expected = "{\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}"; - assertResponse(url, properties, 200, expected); + handler.assertResponse(url, properties, 200, expected); } @Test @@ -90,14 +80,14 @@ public class ModelsEvaluationHandlerTest { properties.put("non-existing-binding", "-1"); String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; String expected = "{\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}"; - assertResponse(url, properties, 200, expected); + handler.assertResponse(url, properties, 200, expected); } @Test public void testLightGBMEvaluationWithoutBindings() { String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval"; String expected = "{\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}"; - assertResponse(url, 200, expected); + handler.assertResponse(url, 200, expected); } @Test @@ -110,7 +100,7 @@ public class ModelsEvaluationHandlerTest { properties.put("non-existing-binding", "-1"); String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval"; String expected = "{\"cells\":[{\"address\":{},\"value\":2.054697758469921}]}"; - assertResponse(url, properties, 200, expected); + handler.assertResponse(url, properties, 200, expected); } @Test @@ -123,35 +113,35 @@ public class ModelsEvaluationHandlerTest { properties.put("non-existing-binding", "-1"); String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval"; String expected = "{\"cells\":[{\"address\":{},\"value\":2.0745534018208094}]}"; - assertResponse(url, properties, 200, expected); + handler.assertResponse(url, properties, 200, expected); } @Test public void testMnistSoftmaxDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax"; String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}]}"; - assertResponse(url, 200, expected); + handler.assertResponse(url, 200, expected); } @Test public void testMnistSoftmaxTypeDetails() { String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/"; String expected = "{\"model\":\"mnist_softmax\",\"function\":\"default.add\",\"info\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}"; - assertResponse(url, 200, expected); + handler.assertResponse(url, 200, expected); } @Test public void testMnistSoftmaxEvaluateDefaultFunctionWithoutBindings() { String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval"; String expected = "{\"error\":\"Argument 'Placeholder' must be bound to a value of type tensor(d0[],d1[784])\"}"; - assertResponse(url, 400, expected); + handler.assertResponse(url, 400, expected); } @Test public void testMnistSoftmaxEvaluateSpecificFunctionWithoutBindings() { String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; String expected = "{\"error\":\"Argument 'Placeholder' must be bound to a value of type tensor(d0[],d1[784])\"}"; - assertResponse(url, 400, expected); + handler.assertResponse(url, 400, expected); } @Test @@ -160,7 +150,7 @@ public class ModelsEvaluationHandlerTest { properties.put("Placeholder", inputTensor()); String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval"; String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; - assertResponse(url, properties, 200, expected); + handler.assertResponse(url, properties, 200, expected); } @Test @@ -169,28 +159,28 @@ public class ModelsEvaluationHandlerTest { properties.put("Placeholder", inputTensor()); String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; - assertResponse(url, properties, 200, expected); + handler.assertResponse(url, properties, 200, expected); } @Test public void testMnistSavedDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_saved"; String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}]}"; - assertResponse(url, 200, expected); + handler.assertResponse(url, 200, expected); } @Test public void testMnistSavedTypeDetails() { String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/"; String expected = "{\"model\":\"mnist_saved\",\"function\":\"serving_default.y\",\"info\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}"; - assertResponse(url, 200, expected); + handler.assertResponse(url, 200, expected); } @Test public void testMnistSavedEvaluateDefaultFunctionShouldFail() { String url = "http://localhost/model-evaluation/v1/mnist_saved/eval"; String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_function_mnist_saved_dnn_hidden1_add, serving_default.y\"}"; - assertResponse(url, 404, expected); + handler.assertResponse(url, 404, expected); } @Test @@ -199,40 +189,7 @@ public class ModelsEvaluationHandlerTest { properties.put("input", inputTensor()); String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval"; String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.6319251673007533},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":-7.577770600619843E-4},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":-0.010707969042025622},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.6344759233540788},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":-0.17529455385847528},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":0.7490809723192187},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.022790284182901716},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.26799240657608936},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-0.3152438845465862},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":0.05949304847735276}]}"; - assertResponse(url, properties, 200, expected); - } - - static private void assertResponse(String url, int expectedCode) { - assertResponse(url, Collections.emptyMap(), expectedCode, null); - } - - static private void assertResponse(String url, int expectedCode, String expectedResult) { - assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult); - } - - static private void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) { - HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties); - HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties); - assertResponse(getRequest, expectedCode, expectedResult); - assertResponse(postRequest, expectedCode, expectedResult); - } - - static private void assertResponse(HttpRequest request, int expectedCode, String expectedResult) { - HttpResponse response = handler.handle(request); - assertEquals("application/json", response.getContentType()); - if (expectedResult != null) { - assertEquals(expectedResult, getContents(response)); - } - assertEquals(expectedCode, response.getStatus()); - } - - static private String getContents(HttpResponse response) { - try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) { - response.render(stream); - return stream.toString(); - } catch (IOException e) { - throw new Error(e); - } + handler.assertResponse(url, properties, 200, expected); } static private ModelsEvaluator createModels(String path) { @@ -241,10 +198,12 @@ public class ModelsEvaluationHandlerTest { RankProfilesConfig.class).getConfig(""); RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), RankingConstantsConfig.class).getConfig(""); + OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()), + OnnxModelsConfig.class).getConfig(""); ModelTester.RankProfilesConfigImporterWithMockedConstants importer = new ModelTester.RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"), MockFileAcquirer.returnFile(null)); - return new ModelsEvaluator(importer.importFrom(config, constantsConfig)); + return new ModelsEvaluator(importer.importFrom(config, constantsConfig, onnxModelsConfig)); } private String inputTensor() { diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java new file mode 100644 index 00000000000..6cfda4d8ce8 --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java @@ -0,0 +1,137 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.handler; + +import ai.vespa.models.evaluation.ModelsEvaluator; +import com.yahoo.config.subscription.ConfigGetter; +import com.yahoo.config.subscription.FileSource; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; +import com.yahoo.path.Path; +import com.yahoo.tensor.Tensor; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.File; +import java.util.HashMap; +import java.util.Map; + +public class OnnxEvaluationHandlerTest { + + private static HandlerTester handler; + + @BeforeClass + static public void setUp() { + handler = new HandlerTester(createModels("src/test/resources/config/onnx/")); + } + + @Test + public void testListModels() { + String url = "http://localhost/model-evaluation/v1"; + String expected = "{\"one_layer\":\"http://localhost/model-evaluation/v1/one_layer\"," + + "\"add_mul\":\"http://localhost/model-evaluation/v1/add_mul\"," + + "\"no_model\":\"http://localhost/model-evaluation/v1/no_model\"}"; + handler.assertResponse(url, 200, expected); + } + + @Test + public void testModelInfo() { + String url = "http://localhost/model-evaluation/v1/add_mul"; + String expected = "{\"model\":\"add_mul\",\"functions\":[" + + "{\"function\":\"output1\"," + + "\"info\":\"http://localhost/model-evaluation/v1/add_mul/output1\"," + + "\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output1/eval\"," + + "\"arguments\":[" + + "{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," + + "{\"name\":\"onnxModel(add_mul).output1\",\"type\":\"tensor<float>(d0[1])\"}," + + "{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" + + "]}," + + "{\"function\":\"output2\"," + + "\"info\":\"http://localhost/model-evaluation/v1/add_mul/output2\"," + + "\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output2/eval\"," + + "\"arguments\":[" + + "{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," + + "{\"name\":\"onnxModel(add_mul).output2\",\"type\":\"tensor<float>(d0[1])\"}," + + "{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" + + "]}]}"; + handler.assertResponse(url, 200, expected); + } + + @Test + public void testEvaluationWithoutSpecifyingOutput() { + String url = "http://localhost/model-evaluation/v1/add_mul/eval"; + String expected = "{\"error\":\"More than one function is available in model 'add_mul', but no name is given. Available functions: output1, output2\"}"; + handler.assertResponse(url, 404, expected); + } + + @Test + public void testEvaluationWithoutBindings() { + String url = "http://localhost/model-evaluation/v1/add_mul/output1/eval"; + String expected = "{\"error\":\"Argument 'input2' must be bound to a value of type tensor<float>(d0[1])\"}"; + handler.assertResponse(url, 400, expected); + } + + @Test + public void testEvaluationOutput1() { + Map<String, String> properties = new HashMap<>(); + properties.put("input1", "tensor<float>(d0[1]):[2]"); + properties.put("input2", "tensor<float>(d0[1]):[3]"); + String url = "http://localhost/model-evaluation/v1/add_mul/output1/eval"; + String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":6.0}]}"; // output1 is a mul + handler.assertResponse(url, properties, 200, expected); + } + + @Test + public void testEvaluationOutput2() { + Map<String, String> properties = new HashMap<>(); + properties.put("input1", "tensor<float>(d0[1]):[2]"); + properties.put("input2", "tensor<float>(d0[1]):[3]"); + String url = "http://localhost/model-evaluation/v1/add_mul/output2/eval"; + String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":5.0}]}"; // output2 is an add + handler.assertResponse(url, properties, 200, expected); + } + + @Test + public void testBatchDimensionModelInfo() { + String url = "http://localhost/model-evaluation/v1/one_layer"; + String expected = "{\"model\":\"one_layer\",\"functions\":[" + + "{\"function\":\"output\"," + + "\"info\":\"http://localhost/model-evaluation/v1/one_layer/output\"," + + "\"eval\":\"http://localhost/model-evaluation/v1/one_layer/output/eval\"," + + "\"arguments\":[" + + "{\"name\":\"input\",\"type\":\"tensor<float>(d0[],d1[3])\"}," + + "{\"name\":\"onnxModel(one_layer)\",\"type\":\"tensor<float>(d0[],d1[1])\"}" + + "]}]}"; + handler.assertResponse(url, 200, expected); + } + + @Test + public void testBatchDimensionEvaluation() { + Map<String, String> properties = new HashMap<>(); + properties.put("input", "tensor<float>(d0[],d1[3]):{{d0:0,d1:0}:0.1,{d0:0,d1:1}:0.2,{d0:0,d1:2}:0.3,{d0:1,d1:0}:0.4,{d0:1,d1:1}:0.5,{d0:1,d1:2}:0.6}"); + String url = "http://localhost/model-evaluation/v1/one_layer/eval"; // output not specified + Tensor expected = Tensor.from("tensor<float>(d0[2],d1[1]):[0.6393113,0.67574286]"); + handler.assertResponse(url, properties, 200, expected); + } + + static private ModelsEvaluator createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()), + OnnxModelsConfig.class).getConfig(""); + + Map<String, File> fileMap = new HashMap<>(); + for (OnnxModelsConfig.Model onnxModel : onnxModelsConfig.model()) { + fileMap.put(onnxModel.fileref().value(), new File(path + onnxModel.fileref().value())); + } + FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap); + + return new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, fileAcquirer); + } + +} diff --git a/model-evaluation/src/test/resources/config/models/onnx-models.cfg b/model-evaluation/src/test/resources/config/models/onnx-models.cfg new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/onnx-models.cfg diff --git a/model-evaluation/src/test/resources/config/onnx/models/add_mul.onnx b/model-evaluation/src/test/resources/config/onnx/models/add_mul.onnx new file mode 100644 index 00000000000..ab054d112e9 --- /dev/null +++ b/model-evaluation/src/test/resources/config/onnx/models/add_mul.onnx @@ -0,0 +1,24 @@ + +add_mul.py:£ + +input1 +input2output1"Mul + +input1 +input2output2"Addadd_mulZ +input1 + + +Z +input2 + + +b +output1 + + +b +output2 + + +B
\ No newline at end of file diff --git a/model-evaluation/src/test/resources/config/onnx/models/add_mul.py b/model-evaluation/src/test/resources/config/onnx/models/add_mul.py new file mode 100755 index 00000000000..3a4522042e8 --- /dev/null +++ b/model-evaluation/src/test/resources/config/onnx/models/add_mul.py @@ -0,0 +1,30 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +from onnx import helper, TensorProto + +INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1]) +INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1]) +OUTPUT_1 = helper.make_tensor_value_info('output1', TensorProto.FLOAT, [1]) +OUTPUT_2 = helper.make_tensor_value_info('output2', TensorProto.FLOAT, [1]) + +nodes = [ + helper.make_node( + 'Mul', + ['input1', 'input2'], + ['output1'], + ), + helper.make_node( + 'Add', + ['input1', 'input2'], + ['output2'], + ), +] +graph_def = helper.make_graph( + nodes, + 'add_mul', + [INPUT_1, INPUT_2], + [OUTPUT_1, OUTPUT_2], +) +model_def = helper.make_model(graph_def, producer_name='add_mul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)]) +onnx.save(model_def, 'add_mul.onnx') diff --git a/model-evaluation/src/test/resources/config/onnx/models/one_layer.onnx b/model-evaluation/src/test/resources/config/onnx/models/one_layer.onnx Binary files differnew file mode 100644 index 00000000000..dc9f664b943 --- /dev/null +++ b/model-evaluation/src/test/resources/config/onnx/models/one_layer.onnx diff --git a/model-evaluation/src/test/resources/config/onnx/models/pytorch_one_layer.py b/model-evaluation/src/test/resources/config/onnx/models/pytorch_one_layer.py new file mode 100755 index 00000000000..1296d84e180 --- /dev/null +++ b/model-evaluation/src/test/resources/config/onnx/models/pytorch_one_layer.py @@ -0,0 +1,38 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import torch +import torch.onnx + + +class MyModel(torch.nn.Module): + def __init__(self): + super(MyModel, self).__init__() + self.linear = torch.nn.Linear(in_features=3, out_features=1) + self.logistic = torch.nn.Sigmoid() + + def forward(self, vec): + return self.logistic(self.linear(vec)) + + +def main(): + model = MyModel() + + # Omit training - just export randomly initialized network + + data = torch.FloatTensor([[0.1, 0.2, 0.3],[0.4, 0.5, 0.6]]) + torch.onnx.export(model, + data, + "one_layer.onnx", + input_names = ["input"], + output_names = ["output"], + dynamic_axes = { + "input": {0: "batch"}, + "output": {0: "batch"}, + }, + opset_version=12) + + +if __name__ == "__main__": + main() + + diff --git a/model-evaluation/src/test/resources/config/onnx/onnx-models.cfg b/model-evaluation/src/test/resources/config/onnx/onnx-models.cfg new file mode 100644 index 00000000000..9ad9c7f6a07 --- /dev/null +++ b/model-evaluation/src/test/resources/config/onnx/onnx-models.cfg @@ -0,0 +1,16 @@ +model[0].name "add_mul" +model[0].fileref "models/add_mul.onnx" +model[0].input[0].name "input1" +model[0].input[0].source "input1" +model[0].input[1].name "input2" +model[0].input[1].source "input2" +model[0].output[0].name "output1" +model[0].output[0].as "output1" +model[0].output[1].name "output2" +model[0].output[1].as "output2" +model[1].name "one_layer" +model[1].fileref "models/one_layer.onnx" +model[1].input[0].name "input" +model[1].input[0].source "input" +model[1].output[0].name "output" +model[1].output[0].as "output" diff --git a/model-evaluation/src/test/resources/config/onnx/rank-profiles.cfg b/model-evaluation/src/test/resources/config/onnx/rank-profiles.cfg new file mode 100644 index 00000000000..047b7c3c77b --- /dev/null +++ b/model-evaluation/src/test/resources/config/onnx/rank-profiles.cfg @@ -0,0 +1,17 @@ +rankprofile[0].name "add_mul" +rankprofile[0].fef.property[0].name "rankingExpression(output1).rankingScript" +rankprofile[0].fef.property[0].value "onnxModel(add_mul).output1" +rankprofile[0].fef.property[1].name "rankingExpression(output1).type" +rankprofile[0].fef.property[1].value "tensor<float>(d0[1])" +rankprofile[0].fef.property[2].name "rankingExpression(output2).rankingScript" +rankprofile[0].fef.property[2].value "onnxModel(add_mul).output2" +rankprofile[0].fef.property[3].name "rankingExpression(output2).type" +rankprofile[0].fef.property[3].value "tensor<float>(d0[1])" +rankprofile[1].name "one_layer" +rankprofile[1].fef.property[0].name "rankingExpression(output).rankingScript" +rankprofile[1].fef.property[0].value "onnxModel(one_layer)" +rankprofile[1].fef.property[1].name "rankingExpression(output).type" +rankprofile[1].fef.property[1].value "tensor<float>(d0[],d1[1])" +rankprofile[2].name "no_model" +rankprofile[2].fef.property[0].name "rankingExpression(output).rankingScript" +rankprofile[2].fef.property[0].value "onnxModel(no_model)" diff --git a/model-evaluation/src/test/resources/config/onnx/ranking-constants.cfg b/model-evaluation/src/test/resources/config/onnx/ranking-constants.cfg new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/model-evaluation/src/test/resources/config/onnx/ranking-constants.cfg diff --git a/model-evaluation/src/test/resources/config/rankexpression/onnx-models.cfg b/model-evaluation/src/test/resources/config/rankexpression/onnx-models.cfg new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/model-evaluation/src/test/resources/config/rankexpression/onnx-models.cfg diff --git a/model-evaluation/src/test/resources/config/smallconstant/onnx-models.cfg b/model-evaluation/src/test/resources/config/smallconstant/onnx-models.cfg new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/model-evaluation/src/test/resources/config/smallconstant/onnx-models.cfg |