diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java index ac462cc39eb..686cf6cd2df 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java @@ -1,11 +1,13 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.xgboost; +import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.RankingExpression; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.ModelImporter; import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import java.io.BufferedReader; import java.io.File; import java.io.IOException; @@ -22,7 +24,27 @@ public class XGBoostImporter extends ModelImporter { File modelFile = new File(modelPath); if ( ! modelFile.isFile()) return false; - return modelFile.toString().endsWith(".json"); // No other models ends by json yet + return modelFile.toString().endsWith(".json") && probe(modelFile); + } + + /** + * Returns true if the give file looks like an XGBoost json file. + * Currently, we just check if the file has an array on the top level. + */ + private boolean probe(File modelFile) { + try { + BufferedReader reader = IOUtils.createReader(modelFile.getAbsolutePath()); + String line; + while ((line = reader.readLine()) != null) { + line = line.trim(); + if (line.startsWith("[")) return true; + if ( ! line.isEmpty()) return false; + } + return false; + } + catch (IOException e) { + throw new IllegalArgumentException("Could not read '" + modelFile + "'", e); + } } @Override |