summaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-05-20 11:29:12 +0200
committerLester Solbakken <lesters@oath.com>2021-05-20 11:29:12 +0200
commit4a126bdd16323226411561b969e581af90260692 (patch)
tree50bd5318dd8e0f174ff26a41a786042b787c9001 /config-model/src/test/java/com
parentfc0711f7870b55ea77d18d87ec3e70b75e0de2e0 (diff)
Evaluate ONNX models in model-evaluation with ONNX RT
Diffstat (limited to 'config-model/src/test/java/com')
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java7
1 files changed, 6 insertions, 1 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index fc6a4ee2783..d0196ace766 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -15,6 +15,7 @@ import com.yahoo.path.Path;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
@@ -95,6 +96,10 @@ public class ModelEvaluationTest {
cluster.getConfig(cb);
RankingConstantsConfig constantsConfig = new RankingConstantsConfig(cb);
+ OnnxModelsConfig.Builder ob = new OnnxModelsConfig.Builder();
+ cluster.getConfig(ob);
+ OnnxModelsConfig onnxModelsConfig = new OnnxModelsConfig(ob);
+
assertEquals(4, config.rankprofile().size());
Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet());
assertTrue(modelNames.contains("xgboost_2_2"));
@@ -109,7 +114,7 @@ public class ModelEvaluationTest {
assertEquals(profile, sb.toString());
ModelsEvaluator evaluator = new ModelsEvaluator(new ToleratingMissingConstantFilesRankProfilesConfigImporter(MockFileAcquirer.returnFile(null))
- .importFrom(config, constantsConfig));
+ .importFrom(config, constantsConfig, onnxModelsConfig));
assertEquals(4, evaluator.models().size());