summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java24
1 files changed, 23 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
index ac462cc39eb..686cf6cd2df 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
@@ -1,11 +1,13 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.xgboost;
+import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.ModelImporter;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
@@ -22,7 +24,27 @@ public class XGBoostImporter extends ModelImporter {
File modelFile = new File(modelPath);
if ( ! modelFile.isFile()) return false;
- return modelFile.toString().endsWith(".json"); // No other models ends by json yet
+ return modelFile.toString().endsWith(".json") && probe(modelFile);
+ }
+
+ /**
+ * Returns true if the give file looks like an XGBoost json file.
+ * Currently, we just check if the file has an array on the top level.
+ */
+ private boolean probe(File modelFile) {
+ try {
+ BufferedReader reader = IOUtils.createReader(modelFile.getAbsolutePath());
+ String line;
+ while ((line = reader.readLine()) != null) {
+ line = line.trim();
+ if (line.startsWith("[")) return true;
+ if ( ! line.isEmpty()) return false;
+ }
+ return false;
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not read '" + modelFile + "'", e);
+ }
}
@Override