aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2021-05-27 10:41:31 +0200
committerGitHub <noreply@github.com>2021-05-27 10:41:31 +0200
commite3d4dbac364216f8d93493d4a5f34835a268fbcf (patch)
tree90bc2cf28e08123a55854c2db1217f556d349a2e /model-evaluation
parent92efe91ec3d7be1902e7ca9c0e290c7859d535af (diff)
parent6b6e59869ab5259a8cd2e382cd2b5164a963a293 (diff)
Merge branch 'master' into lesters/wire-in-stateless-onnx-rt
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/abi-spec.json2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java14
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java12
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java7
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java5
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java5
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java5
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java5
-rw-r--r--model-evaluation/src/test/resources/config/models/ranking-expressions.cfg0
-rw-r--r--model-evaluation/src/test/resources/config/onnx/ranking-expressions.cfg0
-rw-r--r--model-evaluation/src/test/resources/config/rankexpression/ranking-expressions.cfg0
-rw-r--r--model-evaluation/src/test/resources/config/smallconstant/ranking-expressions.cfg0
12 files changed, 46 insertions, 9 deletions
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json
index 63882525808..fac4baf4682 100644
--- a/model-evaluation/abi-spec.json
+++ b/model-evaluation/abi-spec.json
@@ -67,6 +67,7 @@
"public"
],
"methods": [
+ "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)",
"public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)",
"public void <init>(java.util.Map)",
"public java.util.Map models()",
@@ -83,6 +84,7 @@
],
"methods": [
"public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer)",
+ "public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)",
"public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)",
"protected com.yahoo.tensor.Tensor readTensorFromFile(java.lang.String, com.yahoo.tensor.TensorType, com.yahoo.config.FileReference)"
],
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 88766da67fc..2d1a0d069a8 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -9,9 +9,9 @@ import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import java.util.Map;
-import java.util.stream.Collectors;
/**
* Evaluates machine-learned models added to Vespa applications and available as config form.
@@ -28,9 +28,19 @@ public class ModelsEvaluator extends AbstractComponent {
@Inject
public ModelsEvaluator(RankProfilesConfig config,
RankingConstantsConfig constantsConfig,
+ RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig,
FileAcquirer fileAcquirer) {
- this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig, onnxModelsConfig));
+ this(new RankProfilesConfigImporter(fileAcquirer)
+ .importFrom(config, constantsConfig, expressionsConfig, onnxModelsConfig));
+ }
+
+ public ModelsEvaluator(RankProfilesConfig config,
+ RankingConstantsConfig constantsConfig,
+ OnnxModelsConfig onnxModelsConfig,
+ FileAcquirer fileAcquirer) {
+ this(new RankProfilesConfigImporter(fileAcquirer)
+ .importFrom(config, constantsConfig, new RankingExpressionsConfig.Builder().build(), onnxModelsConfig));
}
public ModelsEvaluator(Map<String, Model> models) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 1bdb2810ddf..06ca7a60f4c 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -15,6 +15,7 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import java.io.File;
import java.io.IOException;
@@ -51,11 +52,12 @@ public class RankProfilesConfigImporter {
*/
public Map<String, Model> importFrom(RankProfilesConfig config,
RankingConstantsConfig constantsConfig,
+ RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig) {
try {
Map<String, Model> models = new HashMap<>();
for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
- Model model = importProfile(profile, constantsConfig, onnxModelsConfig);
+ Model model = importProfile(profile, constantsConfig, expressionsConfig, onnxModelsConfig);
models.put(model.name(), model);
}
return models;
@@ -65,8 +67,16 @@ public class RankProfilesConfigImporter {
}
}
+ @Deprecated
+ public Map<String, Model> importFrom(RankProfilesConfig config,
+ RankingConstantsConfig constantsConfig,
+ OnnxModelsConfig onnxModelsConfig) {
+ return importFrom(config, constantsConfig, new RankingExpressionsConfig.Builder().build(), onnxModelsConfig);
+ }
+
private Model importProfile(RankProfilesConfig.Rankprofile profile,
RankingConstantsConfig constantsConfig,
+ RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig)
throws ParseException {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
index d252594e729..b6878f4ea1a 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -16,6 +16,7 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import java.io.IOException;
import java.util.Map;
@@ -46,10 +47,12 @@ public class ModelTester {
RankProfilesConfig.class).getConfig("");
RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
RankingConstantsConfig.class).getConfig("");
+ RankingExpressionsConfig expresionsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-expressions.cfg").toFile()),
+ RankingExpressionsConfig.class).getConfig("");
OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
- OnnxModelsConfig.class).getConfig("");
+ OnnxModelsConfig.class).getConfig("");
return new RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"), MockFileAcquirer.returnFile(null))
- .importFrom(config, constantsConfig, onnxModelsConfig);
+ .importFrom(config, constantsConfig, expresionsConfig, onnxModelsConfig);
}
public ExpressionFunction assertFunction(String name, String expression, Model model) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index dce033c79b0..2a8b7479b00 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -12,6 +12,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import com.yahoo.yolean.Exceptions;
import org.junit.Test;
@@ -132,9 +133,11 @@ public class ModelsEvaluatorTest {
RankProfilesConfig.class).getConfig("");
RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
RankingConstantsConfig.class).getConfig("");
+ RankingExpressionsConfig expressionsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-expressions.cfg").toFile()),
+ RankingExpressionsConfig.class).getConfig("");
OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
OnnxModelsConfig.class).getConfig("");
- return new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, MockFileAcquirer.returnFile(null));
+ return new ModelsEvaluator(config, constantsConfig, expressionsConfig, onnxModelsConfig, MockFileAcquirer.returnFile(null));
}
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
index 1d55fdf9e6a..c137fb9a522 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import org.junit.Test;
import java.io.File;
@@ -54,6 +55,8 @@ public class OnnxEvaluatorTest {
RankProfilesConfig.class).getConfig("");
RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
RankingConstantsConfig.class).getConfig("");
+ RankingExpressionsConfig expressionsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-expressions.cfg").toFile()),
+ RankingExpressionsConfig.class).getConfig("");
OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
OnnxModelsConfig.class).getConfig("");
@@ -63,7 +66,7 @@ public class OnnxEvaluatorTest {
}
FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap);
- return new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, fileAcquirer);
+ return new ModelsEvaluator(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer);
}
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index a69a220e532..d5282bd52f7 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -12,6 +12,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -198,12 +199,14 @@ public class ModelsEvaluationHandlerTest {
RankProfilesConfig.class).getConfig("");
RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
RankingConstantsConfig.class).getConfig("");
+ RankingExpressionsConfig expressionsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-expressions.cfg").toFile()),
+ RankingExpressionsConfig.class).getConfig("");
OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
OnnxModelsConfig.class).getConfig("");
ModelTester.RankProfilesConfigImporterWithMockedConstants importer =
new ModelTester.RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"),
MockFileAcquirer.returnFile(null));
- return new ModelsEvaluator(importer.importFrom(config, constantsConfig, onnxModelsConfig));
+ return new ModelsEvaluator(importer.importFrom(config, constantsConfig, expressionsConfig, onnxModelsConfig));
}
private String inputTensor() {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
index 6cfda4d8ce8..97692de56ef 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
@@ -11,6 +11,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -122,6 +123,8 @@ public class OnnxEvaluationHandlerTest {
RankProfilesConfig.class).getConfig("");
RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
RankingConstantsConfig.class).getConfig("");
+ RankingExpressionsConfig expressionsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-expressions.cfg").toFile()),
+ RankingExpressionsConfig.class).getConfig("");
OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
OnnxModelsConfig.class).getConfig("");
@@ -131,7 +134,7 @@ public class OnnxEvaluationHandlerTest {
}
FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap);
- return new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, fileAcquirer);
+ return new ModelsEvaluator(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer);
}
}
diff --git a/model-evaluation/src/test/resources/config/models/ranking-expressions.cfg b/model-evaluation/src/test/resources/config/models/ranking-expressions.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/models/ranking-expressions.cfg
diff --git a/model-evaluation/src/test/resources/config/onnx/ranking-expressions.cfg b/model-evaluation/src/test/resources/config/onnx/ranking-expressions.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/ranking-expressions.cfg
diff --git a/model-evaluation/src/test/resources/config/rankexpression/ranking-expressions.cfg b/model-evaluation/src/test/resources/config/rankexpression/ranking-expressions.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/rankexpression/ranking-expressions.cfg
diff --git a/model-evaluation/src/test/resources/config/smallconstant/ranking-expressions.cfg b/model-evaluation/src/test/resources/config/smallconstant/ranking-expressions.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/smallconstant/ranking-expressions.cfg