diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2019-10-11 14:17:25 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-10-11 14:17:25 +0200 |
commit | a879e07cbf4885da58ac51844f42176db5604695 (patch) | |
tree | a40b0fb7e34ff176a3b2e8ea157b6b10df4fc6d8 | |
parent | dc2f7ed94ce95ff983700d3e44d5006cbee5bf34 (diff) | |
parent | e97518beb88268146895913ed202b0df5746aef5 (diff) |
Merge pull request #10952 from vespa-engine/lesters/add-xgboost-missing-feature-support-for-java
Add XGBoost missing feature support for Java
23 files changed, 304 insertions, 71 deletions
diff --git a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java index 6a2b7945d73..b7697d30447 100644 --- a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java +++ b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java @@ -50,7 +50,7 @@ public class ContainerModelEvaluationTest { } { - String expected = "{\"cells\":[{\"address\":{},\"value\":-8.17695}]}"; + String expected = "{\"cells\":[{\"address\":{},\"value\":-4.376589999999999}]}"; assertResponse("http://localhost/model-evaluation/v1/xgboost_xgboost_2_2/eval", expected, jdisc); } diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json index 2657779c5cf..c79883450e6 100644 --- a/model-evaluation/abi-spec.json +++ b/model-evaluation/abi-spec.json @@ -37,9 +37,7 @@ "public java.util.Set arguments()", "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)" ], - "fields": [ - "public static final com.yahoo.searchlib.rankingexpression.evaluation.Value defaultContextValue" - ] + "fields": [] }, "ai.vespa.models.evaluation.Model": { "superClass": "java.lang.Object", diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 8c728867f45..9db26d7ecd8 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -65,8 +65,8 @@ public class FunctionEvaluator { public Tensor evaluate() { for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { - if (argument.getValue().rank() == 0) continue; // Scalar argumentds can be skipped (defaults to 0) - if (context.get(argument.getKey()) == LazyArrayContext.defaultContextValue) + if (argument.getValue().rank() == 0) continue; // Scalar arguments can be skipped (defaults to 0) + if (context.isMissing(argument.getKey())) throw new IllegalStateException("Missing argument '" + argument.getKey() + "': Must be bound to a value of type " + argument.getValue()); } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 4d1b5a97583..9045e335167 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -16,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import java.util.Arrays; +import java.util.BitSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -28,10 +29,7 @@ import java.util.Set; */ public final class LazyArrayContext extends Context implements ContextIndex { - public final static Value defaultContextValue = DoubleValue.zero; - private final ExpressionFunction function; - private final IndexedBindings indexedBindings; private LazyArrayContext(ExpressionFunction function, IndexedBindings indexedBindings) { @@ -43,9 +41,10 @@ public final class LazyArrayContext extends Context implements ContextIndex { LazyArrayContext(ExpressionFunction function, Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, - Model model) { + Model model, + Value missingValue) { this.function = function; - this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model); + this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model, missingValue); } /** @@ -121,6 +120,11 @@ public final class LazyArrayContext extends Context implements ContextIndex { return index; } + boolean isMissing(String name) { + Integer index = indexedBindings.indexOf(name); + return index == null || indexedBindings.isMissing(index); + } + /** * Creates a copy of this context suitable for evaluating against the same ranking expression * in a different thread or for re-binding free variables. @@ -134,18 +138,28 @@ public final class LazyArrayContext extends Context implements ContextIndex { /** The mapping from variable name to index */ private final ImmutableMap<String, Integer> nameToIndex; - /** The names which neeeds to be bound externally when envoking this (i.e not constant or invocation */ + /** The names which needs to be bound externally when envoking this (i.e not constant or invocation */ private final ImmutableSet<String> arguments; /** The current values set, pre-converted to doubles */ private final Value[] values; + /** The values that actually have been set */ + private final BitSet setValues; + + /** The value to return if not set */ + private final Value missingValue; + private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values, - ImmutableSet<String> arguments) { + ImmutableSet<String> arguments, + BitSet setValues, + Value missingValue) { this.nameToIndex = nameToIndex; this.values = values; this.arguments = arguments; + this.setValues = setValues; + this.missingValue = missingValue.freeze(); } /** @@ -156,15 +170,18 @@ public final class LazyArrayContext extends Context implements ContextIndex { Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, LazyArrayContext owner, - Model model) { + Model model, + Value missingValue) { // 1. Determine and prepare bind targets Set<String> bindTargets = new LinkedHashSet<>(); Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments); this.arguments = ImmutableSet.copyOf(arguments); + this.missingValue = missingValue.freeze(); values = new Value[bindTargets.size()]; - Arrays.fill(values, defaultContextValue); + Arrays.fill(values, this.missingValue); + setValues = new BitSet(bindTargets.size()); int i = 0; ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); @@ -172,19 +189,22 @@ public final class LazyArrayContext extends Context implements ContextIndex { nameToIndexBuilder.put(variable, i++); nameToIndex = nameToIndexBuilder.build(); - // 2. Bind the bind targets for (Constant constant : constants) { String constantReference = "constant(" + constant.name() + ")"; Integer index = nameToIndex.get(constantReference); - if (index != null) + if (index != null) { values[index] = new TensorValue(constant.value()); + setValues.set(index); + } } for (Map.Entry<FunctionReference, ExpressionFunction> referencedFunction : referencedFunctions.entrySet()) { Integer index = nameToIndex.get(referencedFunction.getKey().serialForm()); - if (index != null) // Referenced in this, so bind it + if (index != null) { // Referenced in this, so bind it values[index] = new LazyValue(referencedFunction.getKey(), owner, model); + setValues.set(index); + } } } @@ -227,16 +247,26 @@ public final class LazyArrayContext extends Context implements ContextIndex { } Value get(int index) { return values[index]; } - void set(int index, Value value) { values[index] = value; } + void set(int index, Value value) { + values[index] = value; + setValues.set(index); + } + Set<String> names() { return nameToIndex.keySet(); } Set<String> arguments() { return arguments; } Integer indexOf(String name) { return nameToIndex.get(name); } + boolean isMissing(int index) { return ! setValues.get(index); } IndexedBindings copy(Context context) { Value[] valueCopy = new Value[values.length]; - for (int i = 0; i < values.length; i++) - valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue)values[i]).copyFor(context) : values[i]; - return new IndexedBindings(nameToIndex, valueCopy, arguments); + BitSet setValuesCopy = new BitSet(values.length); + for (int i = 0; i < values.length; i++) { + valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue) values[i]).copyFor(context) : values[i]; + if (setValues.get(i)) { + setValuesCopy.set(i); + } + } + return new IndexedBindings(nameToIndex, valueCopy, arguments, setValuesCopy, missingValue); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 55da2e78894..bc80989f030 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -6,7 +6,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; import java.util.Arrays; @@ -27,6 +29,9 @@ public class Model { /** The prefix generated by mode-integration/../IntermediateOperation */ private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_"; + /** Default value to return if value is not supplied */ + private final static Value missingValue = DoubleValue.frozen(Double.NaN); + private final String name; /** Free functions */ @@ -61,7 +66,7 @@ public class Model { ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { - LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this); + LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this, missingValue); contextBuilder.put(function.getValue().getName(), context); if ( ! function.getValue().returnType().isPresent()) { functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); @@ -135,7 +140,7 @@ public class Model { return context; } - /** Returns the function withe the given name, or null if none */ // TODO: Parameter overloading? + /** Returns the function with the given name, or null if none */ // TODO: Parameter overloading? ExpressionFunction function(String name) { for (ExpressionFunction function : functions) if (function.getName().equals(name)) diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index db892dce593..9320ac3fad8 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -43,7 +43,7 @@ public class MlModelsImportingTest { // Evaluator FunctionEvaluator evaluator = xgboost.evaluatorOf(); assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); - assertEquals(-8.17695, evaluator.evaluate().sum().asDouble(), delta); + assertEquals(-4.37659, evaluator.evaluate().sum().asDouble(), delta); } { diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index 23f0fa7a571..95f9888024a 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -63,7 +63,7 @@ public class ModelsEvaluationHandlerTest { @Test public void testXgBoostEvaluationWithoutBindings() { String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; // only has a single function - String expected = "{\"cells\":[{\"address\":{},\"value\":-8.17695}]}"; + String expected = "{\"cells\":[{\"address\":{},\"value\":-4.376589999999999}]}"; assertResponse(url, 200, expected); } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java new file mode 100644 index 00000000000..ec2498b3923 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java @@ -0,0 +1,86 @@ +package ai.vespa.rankingexpression.importer.xgboost; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTForestNode; +import com.yahoo.tensor.Tensor; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * @author lesters + */ +public class XGBoostImportEvaluationTestCase { + + @Test + public void testXGBoostEvaluation() { + RankingExpression expression = new XGBoostImporter() + .importModel("xgb", "src/test/models/xgboost/xgboost.test.if_inversion.json") + .expressions().get("xgb"); + + ArrayContext context = new ArrayContext(expression, DoubleValue.NaN); + + assertXGBoostEvaluation(1.0, expression, features(context, "f1", 0.0, "f2", 0.0)); + assertXGBoostEvaluation(2.0, expression, features(context, "f1", 0.0, "f2", 1.0)); + assertXGBoostEvaluation(3.0, expression, features(context, "f1", 1.0, "f2", 0.0)); + assertXGBoostEvaluation(4.0, expression, features(context, "f1", 1.0, "f2", 1.0)); + assertXGBoostEvaluation(5.0, expression, features(context, "f1", 0.0)); + assertXGBoostEvaluation(6.0, expression, features(context, "f1", 1.0)); + assertXGBoostEvaluation(7.0, expression, features(context, "f2", 0.0)); + assertXGBoostEvaluation(9.0, expression, features(context, "f2", 1.0)); + assertXGBoostEvaluation(11.0, expression, features(context)); + assertXGBoostEvaluation(5.0, expression, features(context, "f1", Tensor.from(0.0))); + assertXGBoostEvaluation(6.0, expression, features(context, "f1", Tensor.from(1.0))); + + ExpressionOptimizer optimizer = new ExpressionOptimizer(); + optimizer.optimize(expression, (ContextIndex)context); + assertTrue(expression.getRoot() instanceof GBDTForestNode); + + assertXGBoostEvaluation(1.0, expression, features(context, "f1", 0.0, "f2", 0.0)); + assertXGBoostEvaluation(2.0, expression, features(context, "f1", 0.0, "f2", 1.0)); + assertXGBoostEvaluation(3.0, expression, features(context, "f1", 1.0, "f2", 0.0)); + assertXGBoostEvaluation(4.0, expression, features(context, "f1", 1.0, "f2", 1.0)); + assertXGBoostEvaluation(5.0, expression, features(context, "f1", 0.0)); + assertXGBoostEvaluation(6.0, expression, features(context, "f1", 1.0)); + assertXGBoostEvaluation(7.0, expression, features(context, "f2", 0.0)); + assertXGBoostEvaluation(9.0, expression, features(context, "f2", 1.0)); + assertXGBoostEvaluation(11.0, expression, features(context)); + assertXGBoostEvaluation(5.0, expression, features(context, "f1", Tensor.from(0.0))); + assertXGBoostEvaluation(6.0, expression, features(context, "f1", Tensor.from(1.0))); + } + + private ArrayContext features(ArrayContext context) { + return context.clone(); + } + + private ArrayContext features(ArrayContext context, String f1, double v1) { + context = context.clone(); + context.put(f1, v1); + return context; + } + + private ArrayContext features(ArrayContext context, String f1, Tensor v1) { + context = context.clone(); + context.put(f1, new TensorValue(v1)); + return context; + } + + private ArrayContext features(ArrayContext context, String f1, double v1, String f2, double v2) { + context = context.clone(); + context.put(f1, v1); + context.put(f2, v2); + return context; + } + + private void assertXGBoostEvaluation(double expected, RankingExpression expr, Context context) { + assertEquals(expected, expr.evaluate(context).asDouble(), 1e-9); + } + +} diff --git a/model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json b/model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json new file mode 100644 index 00000000000..8994d89787e --- /dev/null +++ b/model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json @@ -0,0 +1,26 @@ +[ + { "nodeid": 0, "depth": 0, "split": "f1", "split_condition": 0.5, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "f2", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 3, "children": [ + { "nodeid": 3, "leaf": 1.0 }, + { "nodeid": 4, "leaf": 2.0 } + ]}, + { "nodeid": 2, "depth": 1, "split": "f2", "split_condition": 0.5, "yes": 5, "no": 6, "missing": 5, "children": [ + { "nodeid": 5, "leaf": 3.0 }, + { "nodeid": 6, "leaf": 4.0 } + ]} + ]}, + { "nodeid": 0, "depth": 0, "split": "f1", "split_condition": 1.5, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "leaf": 0.0 }, + { "nodeid": 2, "depth": 1, "split": "f2", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "leaf": 4.0 }, + { "nodeid": 4, "leaf": 5.0 } + ]} + ]}, + { "nodeid": 0, "depth": 0, "split": "f2", "split_condition": 1.5, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "leaf": 0.0 }, + { "nodeid": 2, "depth": 1, "split": "f1", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "leaf": 4.0 }, + { "nodeid": 4, "leaf": 3.0 } + ]} + ]} +]
\ No newline at end of file diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 0b9cb06d2a5..c72be5b23e9 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -378,6 +378,7 @@ "methods": [ "protected void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression)", "protected void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, boolean)", + "protected void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, boolean, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "protected final java.util.Map nameToIndex()", "protected final double[] doubleValues()", "protected final boolean ignoreUnknownValues()", @@ -402,6 +403,8 @@ "methods": [ "public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression)", "public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, boolean)", + "public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, boolean, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public final void put(java.lang.String, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public final void put(int, double)", "public final void put(int, com.yahoo.searchlib.rankingexpression.evaluation.Value)", @@ -509,6 +512,7 @@ "methods": [ "public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression)", "public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, boolean)", + "public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, boolean, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public final void put(java.lang.String, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public final void put(int, double)", "public final void put(int, com.yahoo.searchlib.rankingexpression.evaluation.Value)", @@ -550,7 +554,8 @@ "public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.Value asMutable()" ], "fields": [ - "public static final com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue zero" + "public static final com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue zero", + "public static final com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue NaN" ] }, "com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer": { @@ -575,7 +580,9 @@ ], "methods": [ "public void <init>()", + "public void <init>(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public void <init>(java.util.Map)", + "public void <init>(java.util.Map, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.MapContext freeze()", "public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value get(java.lang.String)", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java index 41bf827748a..16549b3ee1c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import java.util.BitSet; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; @@ -33,7 +34,11 @@ public abstract class AbstractArrayContext extends Context implements Cloneable, * This will fail if unknown values are attempted added. */ protected AbstractArrayContext(RankingExpression expression) { - this(expression, false); + this(expression, false, defaultMissingValue); + } + + protected AbstractArrayContext(RankingExpression expression, boolean ignoreUnknownValues) { + this(expression, ignoreUnknownValues, defaultMissingValue); } /** @@ -44,10 +49,11 @@ public abstract class AbstractArrayContext extends Context implements Cloneable, * @param ignoreUnknownValues whether attempts to put values not present in this expression * should fail (false - the default), or be ignored (true) */ - protected AbstractArrayContext(RankingExpression expression, boolean ignoreUnknownValues) { + protected AbstractArrayContext(RankingExpression expression, boolean ignoreUnknownValues, Value missingValue) { + this.missingValue = missingValue.freeze(); this.ignoreUnknownValues = ignoreUnknownValues; this.rankingExpressionName = expression.getName(); - this.indexedBindings = new IndexedBindings(expression); + this.indexedBindings = new IndexedBindings(expression, this.missingValue); } protected final Map<String, Integer> nameToIndex() { return indexedBindings.nameToIndex(); } @@ -77,6 +83,14 @@ public abstract class AbstractArrayContext extends Context implements Cloneable, return indexedBindings.getDouble(index); } + final boolean isMissing(int index) { + return indexedBindings.isMissing(index); + } + + final void clearMissing(int index) { + indexedBindings.clearMissing(index); + } + @Override public String toString() { return "fast lookup context for ranking expression '" + rankingExpressionName + @@ -107,11 +121,22 @@ public abstract class AbstractArrayContext extends Context implements Cloneable, /** The current values set, pre-converted to doubles */ private double[] doubleValues; - public IndexedBindings(RankingExpression expression) { + /** Which values actually are set */ + private BitSet setValues; + + /** Value to return if value is missing. */ + private double missingValue; + + public IndexedBindings(RankingExpression expression, Value missingValue) { Set<String> bindTargets = new LinkedHashSet<>(); extractBindTargets(expression.getRoot(), bindTargets); + this.missingValue = missingValue.asDouble(); + setValues = new BitSet(bindTargets.size()); doubleValues = new double[bindTargets.size()]; + for (int i = 0; i < bindTargets.size(); ++i) { + doubleValues[i] = this.missingValue; + } int i = 0; ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); @@ -136,10 +161,13 @@ public abstract class AbstractArrayContext extends Context implements Cloneable, public Map<String, Integer> nameToIndex() { return nameToIndex; } public double[] doubleValues() { return doubleValues; } + public Set<String> names() { return nameToIndex.keySet(); } public int getIndex(String name) { return nameToIndex.get(name); } public int size() { return doubleValues.length; } public double getDouble(int index) { return doubleValues[index]; } + public boolean isMissing(int index) { return ! setValues.get(index); } + public void clearMissing(int index) { setValues.set(index); } /** * Creates a clone of this context suitable for evaluating against the same ranking expression @@ -149,7 +177,11 @@ public abstract class AbstractArrayContext extends Context implements Cloneable, public IndexedBindings clone() { try { IndexedBindings clone = (IndexedBindings)super.clone(); + clone.setValues = new BitSet(nameToIndex.size()); clone.doubleValues = new double[nameToIndex.size()]; + for (int i = 0; i < nameToIndex.size(); ++i) { + clone.doubleValues[i] = missingValue; + } return clone; } catch (CloneNotSupportedException e) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java index 047d9d761ce..82243fc493d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java @@ -19,8 +19,6 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { /** The current values set */ private Value[] values; - private static DoubleValue constantZero = DoubleValue.frozen(0); - /** * Create a fast lookup context for an expression. * This instance should be reused indefinitely by a single thread. @@ -30,6 +28,14 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { this(expression, false); } + public ArrayContext(RankingExpression expression, boolean ignoreUnknownValues) { + this(expression, ignoreUnknownValues, defaultMissingValue); + } + + public ArrayContext(RankingExpression expression, Value defaultValue) { + this(expression, false, defaultValue); + } + /** * Create a fast lookup context for an expression. * This instance should be reused indefinitely by a single thread. @@ -37,11 +43,12 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { * @param expression the expression to create a context for * @param ignoreUnknownValues whether attempts to put values not present in this expression * should fail (false - the default), or be ignored (true) + * @param missingValue the value to return if not set. */ - public ArrayContext(RankingExpression expression, boolean ignoreUnknownValues) { - super(expression, ignoreUnknownValues); + public ArrayContext(RankingExpression expression, boolean ignoreUnknownValues, Value missingValue) { + super(expression, ignoreUnknownValues, missingValue); values = new Value[doubleValues().length]; - Arrays.fill(values, DoubleValue.zero); + Arrays.fill(values, this.missingValue); } /** @@ -74,6 +81,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { */ public final void put(int index, Value value) { values[index] = value.freeze(); + clearMissing(index); try { doubleValues()[index] = value.asDouble(); } @@ -93,7 +101,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { @Override public Value get(String name) { Integer index = nameToIndex().get(name); - if (index == null) return DoubleValue.zero; + if (index == null) return missingValue; return values[index]; } @@ -107,7 +115,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { @Override public final double getDouble(int index) { double value = doubleValues()[index]; - if (Double.isNaN(value)) + if (Double.isNaN(value) && ! isMissing(index)) // NaN is valid as a missing value throw new UnsupportedOperationException("Value at " + index + " has no double representation"); return value; } @@ -119,7 +127,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { public ArrayContext clone() { ArrayContext clone = (ArrayContext)super.clone(); clone.values = new Value[nameToIndex().size()]; - Arrays.fill(clone.values, constantZero); + Arrays.fill(clone.values, missingValue); return clone; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/BooleanValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/BooleanValue.java index 8ac9a6787da..0e187dfc87c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/BooleanValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/BooleanValue.java @@ -49,8 +49,9 @@ public class BooleanValue extends DoubleCompatibleValue { @Override public boolean equals(Object other) { if (this==other) return true; - if ( ! (other instanceof BooleanValue)) return false; - return ((BooleanValue)other).value==this.value; + if ( ! (other instanceof Value)) return false; + if ( ! ((Value) other).hasDouble()) return false; + return this.value == ((Value) other).asBoolean(); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 4e046df11ca..d68f8c85ad1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -18,6 +18,12 @@ import java.util.stream.Collectors; */ public abstract class Context implements EvaluationContext<Reference> { + /** The default value to return if the value has not been set */ + static Value defaultMissingValue = DoubleValue.zero; + + /** The value to return if the value has not been set */ + Value missingValue; + /** * Returns the value of a simple variable name. * diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java index 0004036da4b..257b344f025 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java @@ -19,7 +19,11 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { * This will fail if unknown values are attempted added. */ public DoubleOnlyArrayContext(RankingExpression expression) { - this(expression, false); + this(expression, false, defaultMissingValue); + } + + public DoubleOnlyArrayContext(RankingExpression expression, boolean ignoreUnknownValues) { + this(expression, ignoreUnknownValues, defaultMissingValue); } /** @@ -29,9 +33,10 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { * @param expression the expression to create a context for * @param ignoreUnknownValues whether attempts to put values not present in this expression * should fail (false - the default), or be ignored (true) + * @param missingValue the value to return if not set. */ - public DoubleOnlyArrayContext(RankingExpression expression, boolean ignoreUnknownValues) { - super(expression, ignoreUnknownValues); + public DoubleOnlyArrayContext(RankingExpression expression, boolean ignoreUnknownValues, Value missingValue) { + super(expression, ignoreUnknownValues, missingValue); } /** @@ -56,6 +61,7 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { /** Same as put(index,DoubleValue.frozen(value)) */ public final void put(int index, double value) { doubleValues()[index] = value; + clearMissing(index); } /** Puts a value by index. */ @@ -77,7 +83,7 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { @Override public Value get(String name) { Integer index = nameToIndex().get(name); - if (index==null) return DoubleValue.zero; + if (index==null) return missingValue; return new DoubleValue(getDouble(index)); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java index 8aa7446cae7..06ab4cba98f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java @@ -20,6 +20,9 @@ public final class DoubleValue extends DoubleCompatibleValue { /** The double value instance for 0 */ public final static DoubleValue zero = DoubleValue.frozen(0); + /** The double value instance for NaN */ + public final static DoubleValue NaN = DoubleValue.frozen(Double.NaN); + public DoubleValue(double value) { this.value = value; } @@ -146,8 +149,9 @@ public final class DoubleValue extends DoubleCompatibleValue { @Override public boolean equals(Object other) { if (this==other) return true; - if ( ! (other instanceof DoubleValue)) return false; - return ((DoubleValue)other).value==this.value; + if ( ! (other instanceof Value)) return false; + if ( ! ((Value) other).hasDouble()) return false; + return this.asDouble() == ((Value) other).asDouble(); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index 4ef24d60bba..f531d77762d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -21,13 +21,23 @@ public class MapContext extends Context { private boolean frozen = false; public MapContext() { + this(defaultMissingValue); + } + + public MapContext(Value missingValue) { + this.missingValue = missingValue.freeze(); + } + + public MapContext(Map<String,Value> bindings) { + this(bindings, defaultMissingValue); } /** * Creates a map context from a map. * All the Values of the map will be frozen. */ - public MapContext(Map<String,Value> bindings) { + public MapContext(Map<String,Value> bindings, Value missingValue) { + this.missingValue = missingValue.freeze(); bindings.forEach((k, v) -> this.bindings.put(k, v.freeze())); } @@ -52,7 +62,7 @@ public class MapContext extends Context { /** Returns the value of a key. 0 is returned if the given key is not bound in this. */ @Override public Value get(String key) { - return bindings.getOrDefault(key, DoubleValue.zero); + return bindings.getOrDefault(key, missingValue); } /** diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index ee66dcc5a03..b109e6503e3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -41,7 +41,9 @@ public class TensorValue extends Value { @Override public boolean asBoolean() { - throw new UnsupportedOperationException("A tensor does not have a boolean value"); + if (hasDouble()) + return asDouble() != 0.0; + throw new UnsupportedOperationException("Tensor does not have a value that can be converted to a boolean"); } @Override @@ -118,18 +120,11 @@ public class TensorValue extends Value { return new TensorValue(value.map((value) -> Math.pow(value, argument.asDouble()))); } - private Tensor asTensor(Value value, String operationName) { - if ( ! (value instanceof TensorValue)) - throw new UnsupportedOperationException("Could not perform " + operationName + - ": The second argument must be a tensor but was " + value); - return ((TensorValue)value).value; - } - public Tensor asTensor() { return value; } @Override public Value compare(TruthOperator operator, Value argument) { - return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString()))); + return new TensorValue(compareTensor(operator, argument.asTensor())); } private Tensor compareTensor(TruthOperator operator, Tensor argument) { @@ -148,7 +143,7 @@ public class TensorValue extends Value { @Override public Value function(Function function, Value arg) { if (arg instanceof TensorValue) - return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString()))); + return new TensorValue(functionOnTensor(function, arg.asTensor())); else return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble()))); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 7809cdd4e1b..39e408d27ca 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -68,7 +68,7 @@ public abstract class Value { public abstract Value compare(TruthOperator operator, Value value); /** Perform the given binary function on this value and the given value */ - public abstract Value function(Function function,Value value); + public abstract Value function(Function function, Value value); /** * Irreversibly makes this immutable. Overriders must always call super.freeze() and return this diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java index 3c0898b5d4f..c1ec72ba0fc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java @@ -26,15 +26,16 @@ public final class GBDTNode extends ExpressionNode { // n=[0,MAX_LEAF_VALUE> : n is data (tree leaf constant value) // n=[MAX_LEAF_VALUE+MAX_VARIABLES*0,MAX_LEAF_VALUE+MAX_VARIABLES*1>: < than var at index n // n=[MAX_LEAF_VALUE+MAX_VARIABLES*1,MAX_LEAF_VALUE+MAX_VARIABLES*2>: = to var at index n-MAX_VARIABLES - // n=[MAX_LEAF_VALUE+MAX_VARIABLES*2,MAX_LEAF_VALUE+MAX_VARIABLES*3]: n-MAX_VARIABLES*2 is IN the following set + // n=[MAX_LEAF_VALUE+MAX_VARIABLES*2,MAX_LEAF_VALUE+MAX_VARIABLES*3>: n-MAX_VARIABLES*2 is IN the following set + // n=[MAX_LEAF_VALUE+MAX_VARIABLES*3,MAX_LEAF_VALUE+MAX_VARIABLES*4]: !( >= ) than var at index n-MAX_VARIABLES*3 (if-inversion) // The full layout of an IF instruction is // COMPARISON,TRUE_BRANCH_LENGTH,TRUE_BRANCH,FALSE_BRANCH - // where COMPARISON is VARIABLE_AND_OPCODE,COMPARE_CONSTANT if the opcode is < or =, + // where COMPARISON is VARIABLE_AND_OPCODE,COMPARE_CONSTANT if the opcode is < or = or !( >= ), // and VARIABLE_AND_OPCODE,COMPARE_CONSTANTS_LENGTH,COMPARE_CONSTANTS if the opcode is IN - // If any change is made to this encoding, this change must also be reflected in GBDTNodeOptimizer + // If any change is made to this encoding, this change must also be reflected in GBDTOptimizer /** The max (absolute) supported value an optimized leaf may have */ public final static int MAX_LEAF_VALUE=2*1000*1000*1000; @@ -72,7 +73,7 @@ public final class GBDTNode extends ExpressionNode { else if (offset < MAX_VARIABLES*2) { comparisonIsTrue = context.getDouble(offset-MAX_VARIABLES)==values[pc++]; } - else { // offset<MAX_VARIABLES*3 + else if (offset<MAX_VARIABLES*3) { double testValue = context.getDouble(offset-MAX_VARIABLES*2); int setValuesLeft = (int)values[pc++]; while (setValuesLeft > 0) { // test each value in the set @@ -84,6 +85,9 @@ public final class GBDTNode extends ExpressionNode { } pc += setValuesLeft; // jump to after the set } + else { // offset<MAX_VARIABLES*4 + comparisonIsTrue = ! (context.getDouble(offset-MAX_VARIABLES*3)>=values[pc++]); + } if (comparisonIsTrue) pc++; // true branch - skip the jump value diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java index 787818b0f42..a6df6b435d6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java @@ -53,7 +53,7 @@ public class GBDTOptimizer extends Optimizer { * anything else.</p> * * <p>Each condition node is converted to the double sequence [(OperatorIsEquals ? GBDTNode.MAX_VARIABLES : 0) + - * IndexOfLeftComparisonFeature+GBDTNode.MAX_LEAFT_VALUE, ValueOfRightComparisonValue,#OfValuesInTrueBranch,true + * IndexOfLeftComparisonFeature+GBDTNode.MAX_LEAF_VALUE, ValueOfRightComparisonValue,#OfValuesInTrueBranch,true * branch values,false branch values]</p> * * <p>Each value node is converted to the double value of the value node itself.</p> @@ -131,6 +131,20 @@ public class GBDTOptimizer extends Optimizer { for (ExpressionNode setElementNode : setMembership.getSetValues()) values.add(toValue(setElementNode)); } + else if (condition instanceof NotNode) { // handle if inversion: !(a >= b) + NotNode notNode = (NotNode)condition; + if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode) { + EmbracedNode embracedNode = (EmbracedNode)notNode.children().get(0); + if (embracedNode.children().size() == 1 && embracedNode.children().get(0) instanceof ComparisonNode) { + ComparisonNode comparison = (ComparisonNode)embracedNode.children().get(0); + if (comparison.getOperator() == TruthOperator.LARGEREQUAL) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.getLeftCondition(), context)); + else + throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator()); + values.add(toValue(comparison.getRightCondition())); + } + } + } else { throw new IllegalArgumentException("Node condition could not be optimized: " + condition); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java index ce78703f842..08f1a872759 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java @@ -24,7 +24,7 @@ public class GBDTForestOptimizerTestCase { RankingExpression gbdt = new RankingExpression(gbdtString); // Regular evaluation - MapContext arguments = new MapContext(); + MapContext arguments = new MapContext(DoubleValue.NaN); arguments.put("LW_NEWS_SEARCHES_RATIO", 1d); arguments.put("SUGG_OVERLAP", 17d); double result1 = gbdt.evaluate(arguments).asDouble(); @@ -36,7 +36,7 @@ public class GBDTForestOptimizerTestCase { double result3 = gbdt.evaluate(arguments).asDouble(); // Optimized evaluation - ArrayContext fArguments = new ArrayContext(gbdt); + ArrayContext fArguments = new ArrayContext(gbdt, DoubleValue.NaN); ExpressionOptimizer optimizer = new ExpressionOptimizer(); OptimizationReport report = optimizer.optimize(gbdt, fArguments); assertEquals(4, report.getMetric("Optimized GDBT trees")); @@ -70,7 +70,7 @@ public class GBDTForestOptimizerTestCase { RankingExpression gbdt = new RankingExpression(gbdtString); // Regular evaluation - MapContext arguments = new MapContext(); + MapContext arguments = new MapContext(DoubleValue.NaN); arguments.put("MYSTRING", new StringValue("string 1")); arguments.put("LW_NEWS_SEARCHES_RATIO", 1d); arguments.put("SUGG_OVERLAP", 17d); @@ -83,7 +83,7 @@ public class GBDTForestOptimizerTestCase { double result3 = gbdt.evaluate(arguments).asDouble(); // Optimized evaluation - ArrayContext fArguments = new ArrayContext(gbdt); + ArrayContext fArguments = new ArrayContext(gbdt, DoubleValue.NaN); ExpressionOptimizer optimizer = new ExpressionOptimizer(); OptimizationReport report = optimizer.optimize(gbdt, fArguments); assertEquals(4, report.getMetric("Optimized GDBT trees")); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java index 4b7462505fc..82ad034e306 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; @@ -43,7 +44,7 @@ public class GBDTOptimizerTestCase { RankingExpression gbdt = new RankingExpression(gbdtString); // Regular evaluation - MapContext arguments = new MapContext(); + MapContext arguments = new MapContext(DoubleValue.NaN); arguments.put("LW_NEWS_SEARCHES_RATIO", 1d); arguments.put("SUGG_OVERLAP", 17d); double result1 = gbdt.evaluate(arguments).asDouble(); @@ -55,7 +56,7 @@ public class GBDTOptimizerTestCase { double result3 = gbdt.evaluate(arguments).asDouble(); // Optimized evaluation - ArrayContext fArguments = new ArrayContext(gbdt); + ArrayContext fArguments = new ArrayContext(gbdt, DoubleValue.NaN); ExpressionOptimizer optimizer = new ExpressionOptimizer(); optimizer.getOptimizer(GBDTForestOptimizer.class).setEnabled(false); OptimizationReport report = optimizer.optimize(gbdt,fArguments); |