diff options
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java index ef731730c84..76fcf8890c2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java @@ -4,6 +4,8 @@ package ai.vespa.rankingexpression.importer.lightgbm; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.ModelImporter; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; import com.fasterxml.jackson.databind.ObjectMapper; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; @@ -18,6 +20,7 @@ import java.io.IOException; */ public class LightGBMImporter extends ModelImporter { + private final ObjectMapper objectMapper = new ObjectMapper(); @Override public boolean canImport(String modelPath) { File modelFile = new File(modelPath); @@ -31,7 +34,14 @@ public class LightGBMImporter extends ModelImporter { */ private boolean probe(File modelFile) { try { - return new ObjectMapper().readTree(modelFile).has("tree_info"); + JsonParser parser = objectMapper.createParser(modelFile); + while (parser.nextToken() != null) { + JsonToken token = parser.getCurrentToken(); + if (token == JsonToken.FIELD_NAME) { + if ("tree_info".equals(parser.getCurrentName())) return true; + } + } + return false; } catch (IOException e) { throw new IllegalArgumentException("Could not read '" + modelFile + "'", e); } |