summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-06-28 09:36:37 +0200
committerLester Solbakken <lesters@oath.com>2021-06-28 09:36:37 +0200
commit0b53632d1ce373aa18ed53b1be6dc362a003260e (patch)
tree2c820888d7778a34205cfa3c0e8bba2dc8b85fef /config-model
parentd713569989c88b541305e79ac531b0fc8a8bceaa (diff)
Add unit testing of ModelsEvaluator in applications
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java137
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/ml/package-info.java5
-rw-r--r--config-model/src/test/cfg/application/stateless_eval/constant1asLarge.json7
-rw-r--r--config-model/src/test/cfg/application/stateless_eval/example.model34
-rw-r--r--config-model/src/test/cfg/application/stateless_eval/mul.onnx16
-rwxr-xr-xconfig-model/src/test/cfg/application/stateless_eval/mul.py26
-rw-r--r--config-model/src/test/cfg/application/stateless_eval/test.expression1
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java44
9 files changed, 271 insertions, 1 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index d20247b79fc..93a9600e134 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -187,7 +187,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
deployState.rankProfileRegistry(), deployState.getQueryProfiles());
rankProfileList = new RankProfileList(null, // null search -> global
rankingConstants,
- largeRankExpressions,
+ largeRankExpressions,
AttributeFields.empty,
deployState.rankProfileRegistry(),
deployState.getQueryProfiles().getRegistry(),
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java
new file mode 100644
index 00000000000..c06f5f4c441
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java
@@ -0,0 +1,137 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.container.ml;
+
+import ai.vespa.models.evaluation.ModelsEvaluator;
+import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
+import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
+import ai.vespa.rankingexpression.importer.vespa.VespaImporter;
+import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.config.model.test.MockApplicationPackage;
+import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
+import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
+import com.yahoo.io.IOUtils;
+import com.yahoo.searchdefinition.derived.RankProfileList;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.model.VespaModel;
+import org.xml.sax.SAXException;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A ModelsEvaluator object is usually injected automatically in a component if
+ * requested. This class is for creating a ModelsEvaluator so that the component
+ * can be properly unit tested. Pass a directory containing model files, such
+ * as the application's "models" directory, and it will return a ModelsEvaluator
+ * for the imported models.
+ *
+ * For use in testing only.
+ *
+ * @author lesters
+ */
+public class ModelsEvaluatorTester {
+
+ private static final ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
+ new OnnxImporter(),
+ new LightGBMImporter(),
+ new XGBoostImporter(),
+ new VespaImporter());
+
+ private static final String modelEvaluationServices = "<services version=\"1.0\">" +
+ " <container version=\"1.0\">" +
+ " <model-evaluation/>" +
+ " </container>" +
+ "</services>";
+
+ /**
+ * Create a ModelsEvaluator from the models found in the modelsPath. Does
+ * not need to be in a application package.
+ *
+ * @param modelsPath Path to a directory containing models to import
+ * @return a ModelsEvaluator containing the imported models
+ */
+ public static ModelsEvaluator create(String modelsPath) {
+ File temporaryApplicationDir = null;
+ try {
+ temporaryApplicationDir = createTemporaryApplicationDir(modelsPath);
+ RankProfileList rankProfileList = createRankProfileList(temporaryApplicationDir);
+
+ RankProfilesConfig rankProfilesConfig = getRankProfilesConfig(rankProfileList);
+ RankingConstantsConfig rankingConstantsConfig = getRankingConstantConfig(rankProfileList);
+ OnnxModelsConfig onnxModelsConfig = getOnnxModelsConfig(rankProfileList);
+ FileAcquirer files = createFileAcquirer(rankingConstantsConfig, onnxModelsConfig, temporaryApplicationDir);
+
+ return new ModelsEvaluator(rankProfilesConfig, rankingConstantsConfig, onnxModelsConfig, files);
+
+ } catch (IOException | SAXException e) {
+ throw new RuntimeException(e);
+ } finally {
+ if (temporaryApplicationDir != null) {
+ IOUtils.recursiveDeleteDir(temporaryApplicationDir);
+ }
+ }
+ }
+
+ private static File createTemporaryApplicationDir(String modelsPath) throws IOException {
+ String tmpDir = Files.exists(Path.of("target")) ? "target" : "";
+ File temporaryApplicationDir = Files.createTempDirectory(Path.of(tmpDir), "tmp_").toFile();
+ File modelsDir = relativePath(temporaryApplicationDir, ApplicationPackage.MODELS_DIR.toString());
+ IOUtils.copyDirectory(new File(modelsPath), modelsDir);
+ return temporaryApplicationDir;
+ }
+
+ private static RankProfileList createRankProfileList(File appDir) throws IOException, SAXException {
+ ApplicationPackage app = new MockApplicationPackage.Builder()
+ .withEmptyHosts()
+ .withServices(modelEvaluationServices)
+ .withRoot(appDir).build();
+ DeployState deployState = new DeployState.Builder().applicationPackage(app).modelImporters(importers).build();
+ VespaModel vespaModel = new VespaModel(deployState);
+ return vespaModel.rankProfileList();
+ }
+
+ private static RankProfilesConfig getRankProfilesConfig(RankProfileList rankProfileList) {
+ RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder();
+ rankProfileList.getConfig(builder);
+ return new RankProfilesConfig(builder);
+ }
+
+ private static RankingConstantsConfig getRankingConstantConfig(RankProfileList rankProfileList) {
+ RankingConstantsConfig.Builder builder = new RankingConstantsConfig.Builder();
+ rankProfileList.getConfig(builder);
+ return new RankingConstantsConfig(builder);
+ }
+
+ private static OnnxModelsConfig getOnnxModelsConfig(RankProfileList rankProfileList) {
+ OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder();
+ rankProfileList.getConfig(builder);
+ return new OnnxModelsConfig(builder);
+ }
+
+ private static FileAcquirer createFileAcquirer(RankingConstantsConfig constantsConfig, OnnxModelsConfig onnxModelsConfig, File appDir) {
+ Map<String, File> fileMap = new HashMap<>();
+ for (RankingConstantsConfig.Constant constant : constantsConfig.constant()) {
+ fileMap.put(constant.fileref().value(), relativePath(appDir, constant.fileref().value()));
+ }
+ for (OnnxModelsConfig.Model model : onnxModelsConfig.model()) {
+ fileMap.put(model.fileref().value(), relativePath(appDir, model.fileref().value()));
+ }
+ return MockFileAcquirer.returnFiles(fileMap);
+ }
+
+ private static File relativePath(File root, String subpath) {
+ return new File(root.getAbsolutePath() + File.separator + subpath);
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ml/package-info.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ml/package-info.java
new file mode 100644
index 00000000000..67556f256ed
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ml/package-info.java
@@ -0,0 +1,5 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+package com.yahoo.vespa.model.container.ml;
+
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/config-model/src/test/cfg/application/stateless_eval/constant1asLarge.json b/config-model/src/test/cfg/application/stateless_eval/constant1asLarge.json
new file mode 100644
index 00000000000..d2944d255af
--- /dev/null
+++ b/config-model/src/test/cfg/application/stateless_eval/constant1asLarge.json
@@ -0,0 +1,7 @@
+{
+ "cells": [
+ { "address": { "x": "0" }, "value": 0.5 },
+ { "address": { "x": "1" }, "value": 1.5 },
+ { "address": { "x": "2" }, "value": 2.5 }
+ ]
+} \ No newline at end of file
diff --git a/config-model/src/test/cfg/application/stateless_eval/example.model b/config-model/src/test/cfg/application/stateless_eval/example.model
new file mode 100644
index 00000000000..af1c85be4f0
--- /dev/null
+++ b/config-model/src/test/cfg/application/stateless_eval/example.model
@@ -0,0 +1,34 @@
+model example {
+
+ # All inputs that are not scalar (aka 0-dimensional tensor) must be declared
+ input1: tensor(name{}, x[3])
+ input2: tensor(x[3])
+
+ constants {
+ constant1: tensor(x[3]):{{x:0}:0.5, {x:1}:1.5, {x:2}:2.5}
+ constant2: 3.0
+ }
+
+ constant constant1asLarge {
+ type: tensor(x[3])
+ file: constant1asLarge.json
+ }
+
+ function foo1() {
+ expression: file:test.expression
+ }
+
+ function foo2() {
+ expression: reduce(sum(input1 * input2, name) * constant(constant1asLarge), max, x) * constant2
+ # expression: input1 * input2
+ }
+
+ function my_input1() {
+ expression: tensor(d0[1]):[2]
+ }
+
+ function my_input2() {
+ expression: tensor(d0[1]):[3]
+ }
+
+} \ No newline at end of file
diff --git a/config-model/src/test/cfg/application/stateless_eval/mul.onnx b/config-model/src/test/cfg/application/stateless_eval/mul.onnx
new file mode 100644
index 00000000000..087e2c3427f
--- /dev/null
+++ b/config-model/src/test/cfg/application/stateless_eval/mul.onnx
@@ -0,0 +1,16 @@
+mul.py:f
+
+input1
+input2output"MulmulZ
+input1
+
+
+Z
+input2
+
+
+b
+output
+
+
+B \ No newline at end of file
diff --git a/config-model/src/test/cfg/application/stateless_eval/mul.py b/config-model/src/test/cfg/application/stateless_eval/mul.py
new file mode 100755
index 00000000000..db01561c355
--- /dev/null
+++ b/config-model/src/test/cfg/application/stateless_eval/mul.py
@@ -0,0 +1,26 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1])
+INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
+
+nodes = [
+ helper.make_node(
+ 'Mul',
+ ['input1', 'input2'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'mul',
+ [
+ INPUT_1,
+ INPUT_2
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='mul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'mul.onnx')
diff --git a/config-model/src/test/cfg/application/stateless_eval/test.expression b/config-model/src/test/cfg/application/stateless_eval/test.expression
new file mode 100644
index 00000000000..5db8a720498
--- /dev/null
+++ b/config-model/src/test/cfg/application/stateless_eval/test.expression
@@ -0,0 +1 @@
+reduce(sum(input1 * input2, name) * constant1, max, x) * constant2 \ No newline at end of file
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
new file mode 100644
index 00000000000..771cba673bc
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
@@ -0,0 +1,44 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.container.ml;
+
+import ai.vespa.models.evaluation.FunctionEvaluator;
+import ai.vespa.models.evaluation.ModelsEvaluator;
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests the ModelsEvaluatorTester.
+ *
+ * @author lesters
+ */
+public class ModelsEvaluatorTest {
+
+ @Test
+ public void testModelsEvaluatorTester() {
+ ModelsEvaluator modelsEvaluator = ModelsEvaluatorTester.create("src/test/cfg/application/stateless_eval");
+ assertEquals(2, modelsEvaluator.models().size());
+
+ // ONNX model evaluation
+ FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul");
+ Tensor input1 = Tensor.from("tensor<float>(d0[1]):[2]");
+ Tensor input2 = Tensor.from("tensor<float>(d0[1]):[3]");
+ Tensor output = mul.bind("input1", input1).bind("input2", input2).evaluate();
+ assertEquals(6.0, output.sum().asDouble(), 1e-9);
+
+ // Vespa model evaluation
+ FunctionEvaluator foo1 = modelsEvaluator.evaluatorOf("example", "foo1");
+ input1 = Tensor.from("tensor(name{},x[3]):{{name:n,x:0}:1,{name:n,x:1}:2,{name:n,x:2}:3 }");
+ input2 = Tensor.from("tensor(x[3]):[2,3,4]");
+ output = foo1.bind("input1", input1).bind("input2", input2).evaluate();
+ assertEquals(90, output.asDouble(), 1e-9);
+
+ FunctionEvaluator foo2 = modelsEvaluator.evaluatorOf("example", "foo2");
+ input1 = Tensor.from("tensor(name{},x[3]):{{name:n,x:0}:1,{name:n,x:1}:2,{name:n,x:2}:3 }");
+ input2 = Tensor.from("tensor(x[3]):[2,3,4]");
+ output = foo2.bind("input1", input1).bind("input2", input2).evaluate();
+ assertEquals(90, output.asDouble(), 1e-9);
+ }
+
+}