From f656ff5c15d95905f48d5829278ec241f1941577 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Sun, 2 Feb 2020 17:39:44 +0100 Subject: Add support for importing LightGBM models --- .../expressiontransforms/ExpressionTransforms.java | 1 + .../LightGBMFeatureConverter.java | 59 ++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/LightGBMFeatureConverter.java (limited to 'config-model/src/main/java/com/yahoo/searchdefinition') diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java index 6fdf448a39b..a6707ec7ac0 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java @@ -27,6 +27,7 @@ public class ExpressionTransforms { ImmutableList.of(new TensorFlowFeatureConverter(), new OnnxFeatureConverter(), new XgboostFeatureConverter(), + new LightGBMFeatureConverter(), new ConstantDereferencer(), new ConstantTensorTransformer(), new FunctionInliner(), diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/LightGBMFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/LightGBMFeatureConverter.java new file mode 100644 index 00000000000..5bde627dc0a --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/LightGBMFeatureConverter.java @@ -0,0 +1,59 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; + +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; + +/** + * Replaces instances of the lightgbm(model-path) pseudofeature with the + * native Vespa ranking expression implementing the same computation. + * + * @author lesters + */ +public class LightGBMFeatureConverter extends ExpressionTransformer { + + /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ + private final Map convertedLightGBMModels = new HashMap<>(); + + @Override + public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + if (node instanceof ReferenceNode) + return transformFeature((ReferenceNode) node, context); + else if (node instanceof CompositeNode) + return super.transformChildren((CompositeNode) node, context); + else + return node; + } + + private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { + if ( ! feature.getName().equals("lightgbm")) return feature; + + try { + FeatureArguments arguments = asFeatureArguments(feature.getArguments()); + ConvertedModel convertedModel = + convertedLightGBMModels.computeIfAbsent(arguments.path(), + path -> ConvertedModel.fromSourceOrStore(path, true, context)); + return convertedModel.expression(arguments, context); + } catch (IllegalArgumentException | UncheckedIOException e) { + throw new IllegalArgumentException("Could not use LightGBM model from " + feature, e); + } + } + + private FeatureArguments asFeatureArguments(Arguments arguments) { + if (arguments.size() != 1) + throw new IllegalArgumentException("A lightgbm node must take a single argument pointing to " + + "the LightGBM model file under [application]/models"); + return new FeatureArguments(arguments); + } + +} -- cgit v1.2.3