summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2019-10-11 14:17:25 +0200
committerGitHub <noreply@github.com>2019-10-11 14:17:25 +0200
commita879e07cbf4885da58ac51844f42176db5604695 (patch)
treea40b0fb7e34ff176a3b2e8ea157b6b10df4fc6d8
parentdc2f7ed94ce95ff983700d3e44d5006cbee5bf34 (diff)
parente97518beb88268146895913ed202b0df5746aef5 (diff)
Merge pull request #10952 from vespa-engine/lesters/add-xgboost-missing-feature-support-for-java
Add XGBoost missing feature support for Java
-rw-r--r--application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java2
-rw-r--r--model-evaluation/abi-spec.json4
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java4
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java62
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java9
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java2
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java86
-rw-r--r--model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json26
-rw-r--r--searchlib/abi-spec.json9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java40
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/BooleanValue.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java12
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java16
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java5
23 files changed, 304 insertions, 71 deletions
diff --git a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java
index 6a2b7945d73..b7697d30447 100644
--- a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java
+++ b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java
@@ -50,7 +50,7 @@ public class ContainerModelEvaluationTest {
}
{
- String expected = "{\"cells\":[{\"address\":{},\"value\":-8.17695}]}";
+ String expected = "{\"cells\":[{\"address\":{},\"value\":-4.376589999999999}]}";
assertResponse("http://localhost/model-evaluation/v1/xgboost_xgboost_2_2/eval", expected, jdisc);
}
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json
index 2657779c5cf..c79883450e6 100644
--- a/model-evaluation/abi-spec.json
+++ b/model-evaluation/abi-spec.json
@@ -37,9 +37,7 @@
"public java.util.Set arguments()",
"public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)"
],
- "fields": [
- "public static final com.yahoo.searchlib.rankingexpression.evaluation.Value defaultContextValue"
- ]
+ "fields": []
},
"ai.vespa.models.evaluation.Model": {
"superClass": "java.lang.Object",
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 8c728867f45..9db26d7ecd8 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -65,8 +65,8 @@ public class FunctionEvaluator {
public Tensor evaluate() {
for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
- if (argument.getValue().rank() == 0) continue; // Scalar argumentds can be skipped (defaults to 0)
- if (context.get(argument.getKey()) == LazyArrayContext.defaultContextValue)
+ if (argument.getValue().rank() == 0) continue; // Scalar arguments can be skipped (defaults to 0)
+ if (context.isMissing(argument.getKey()))
throw new IllegalStateException("Missing argument '" + argument.getKey() +
"': Must be bound to a value of type " + argument.getValue());
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index 4d1b5a97583..9045e335167 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -16,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
+import java.util.BitSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
@@ -28,10 +29,7 @@ import java.util.Set;
*/
public final class LazyArrayContext extends Context implements ContextIndex {
- public final static Value defaultContextValue = DoubleValue.zero;
-
private final ExpressionFunction function;
-
private final IndexedBindings indexedBindings;
private LazyArrayContext(ExpressionFunction function, IndexedBindings indexedBindings) {
@@ -43,9 +41,10 @@ public final class LazyArrayContext extends Context implements ContextIndex {
LazyArrayContext(ExpressionFunction function,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants,
- Model model) {
+ Model model,
+ Value missingValue) {
this.function = function;
- this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model);
+ this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model, missingValue);
}
/**
@@ -121,6 +120,11 @@ public final class LazyArrayContext extends Context implements ContextIndex {
return index;
}
+ boolean isMissing(String name) {
+ Integer index = indexedBindings.indexOf(name);
+ return index == null || indexedBindings.isMissing(index);
+ }
+
/**
* Creates a copy of this context suitable for evaluating against the same ranking expression
* in a different thread or for re-binding free variables.
@@ -134,18 +138,28 @@ public final class LazyArrayContext extends Context implements ContextIndex {
/** The mapping from variable name to index */
private final ImmutableMap<String, Integer> nameToIndex;
- /** The names which neeeds to be bound externally when envoking this (i.e not constant or invocation */
+ /** The names which needs to be bound externally when envoking this (i.e not constant or invocation */
private final ImmutableSet<String> arguments;
/** The current values set, pre-converted to doubles */
private final Value[] values;
+ /** The values that actually have been set */
+ private final BitSet setValues;
+
+ /** The value to return if not set */
+ private final Value missingValue;
+
private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
Value[] values,
- ImmutableSet<String> arguments) {
+ ImmutableSet<String> arguments,
+ BitSet setValues,
+ Value missingValue) {
this.nameToIndex = nameToIndex;
this.values = values;
this.arguments = arguments;
+ this.setValues = setValues;
+ this.missingValue = missingValue.freeze();
}
/**
@@ -156,15 +170,18 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants,
LazyArrayContext owner,
- Model model) {
+ Model model,
+ Value missingValue) {
// 1. Determine and prepare bind targets
Set<String> bindTargets = new LinkedHashSet<>();
Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation
extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments);
this.arguments = ImmutableSet.copyOf(arguments);
+ this.missingValue = missingValue.freeze();
values = new Value[bindTargets.size()];
- Arrays.fill(values, defaultContextValue);
+ Arrays.fill(values, this.missingValue);
+ setValues = new BitSet(bindTargets.size());
int i = 0;
ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
@@ -172,19 +189,22 @@ public final class LazyArrayContext extends Context implements ContextIndex {
nameToIndexBuilder.put(variable, i++);
nameToIndex = nameToIndexBuilder.build();
-
// 2. Bind the bind targets
for (Constant constant : constants) {
String constantReference = "constant(" + constant.name() + ")";
Integer index = nameToIndex.get(constantReference);
- if (index != null)
+ if (index != null) {
values[index] = new TensorValue(constant.value());
+ setValues.set(index);
+ }
}
for (Map.Entry<FunctionReference, ExpressionFunction> referencedFunction : referencedFunctions.entrySet()) {
Integer index = nameToIndex.get(referencedFunction.getKey().serialForm());
- if (index != null) // Referenced in this, so bind it
+ if (index != null) { // Referenced in this, so bind it
values[index] = new LazyValue(referencedFunction.getKey(), owner, model);
+ setValues.set(index);
+ }
}
}
@@ -227,16 +247,26 @@ public final class LazyArrayContext extends Context implements ContextIndex {
}
Value get(int index) { return values[index]; }
- void set(int index, Value value) { values[index] = value; }
+ void set(int index, Value value) {
+ values[index] = value;
+ setValues.set(index);
+ }
+
Set<String> names() { return nameToIndex.keySet(); }
Set<String> arguments() { return arguments; }
Integer indexOf(String name) { return nameToIndex.get(name); }
+ boolean isMissing(int index) { return ! setValues.get(index); }
IndexedBindings copy(Context context) {
Value[] valueCopy = new Value[values.length];
- for (int i = 0; i < values.length; i++)
- valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue)values[i]).copyFor(context) : values[i];
- return new IndexedBindings(nameToIndex, valueCopy, arguments);
+ BitSet setValuesCopy = new BitSet(values.length);
+ for (int i = 0; i < values.length; i++) {
+ valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue) values[i]).copyFor(context) : values[i];
+ if (setValues.get(i)) {
+ setValuesCopy.set(i);
+ }
+ }
+ return new IndexedBindings(nameToIndex, valueCopy, arguments, setValuesCopy, missingValue);
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 55da2e78894..bc80989f030 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -6,7 +6,9 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
@@ -27,6 +29,9 @@ public class Model {
/** The prefix generated by mode-integration/../IntermediateOperation */
private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_";
+ /** Default value to return if value is not supplied */
+ private final static Value missingValue = DoubleValue.frozen(Double.NaN);
+
private final String name;
/** Free functions */
@@ -61,7 +66,7 @@ public class Model {
ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
- LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this);
+ LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this, missingValue);
contextBuilder.put(function.getValue().getName(), context);
if ( ! function.getValue().returnType().isPresent()) {
functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));
@@ -135,7 +140,7 @@ public class Model {
return context;
}
- /** Returns the function withe the given name, or null if none */ // TODO: Parameter overloading?
+ /** Returns the function with the given name, or null if none */ // TODO: Parameter overloading?
ExpressionFunction function(String name) {
for (ExpressionFunction function : functions)
if (function.getName().equals(name))
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index db892dce593..9320ac3fad8 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -43,7 +43,7 @@ public class MlModelsImportingTest {
// Evaluator
FunctionEvaluator evaluator = xgboost.evaluatorOf();
assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
- assertEquals(-8.17695, evaluator.evaluate().sum().asDouble(), delta);
+ assertEquals(-4.37659, evaluator.evaluate().sum().asDouble(), delta);
}
{
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 23f0fa7a571..95f9888024a 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -63,7 +63,7 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testXgBoostEvaluationWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; // only has a single function
- String expected = "{\"cells\":[{\"address\":{},\"value\":-8.17695}]}";
+ String expected = "{\"cells\":[{\"address\":{},\"value\":-4.376589999999999}]}";
assertResponse(url, 200, expected);
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java
new file mode 100644
index 00000000000..ec2498b3923
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java
@@ -0,0 +1,86 @@
+package ai.vespa.rankingexpression.importer.xgboost;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTForestNode;
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author lesters
+ */
+public class XGBoostImportEvaluationTestCase {
+
+ @Test
+ public void testXGBoostEvaluation() {
+ RankingExpression expression = new XGBoostImporter()
+ .importModel("xgb", "src/test/models/xgboost/xgboost.test.if_inversion.json")
+ .expressions().get("xgb");
+
+ ArrayContext context = new ArrayContext(expression, DoubleValue.NaN);
+
+ assertXGBoostEvaluation(1.0, expression, features(context, "f1", 0.0, "f2", 0.0));
+ assertXGBoostEvaluation(2.0, expression, features(context, "f1", 0.0, "f2", 1.0));
+ assertXGBoostEvaluation(3.0, expression, features(context, "f1", 1.0, "f2", 0.0));
+ assertXGBoostEvaluation(4.0, expression, features(context, "f1", 1.0, "f2", 1.0));
+ assertXGBoostEvaluation(5.0, expression, features(context, "f1", 0.0));
+ assertXGBoostEvaluation(6.0, expression, features(context, "f1", 1.0));
+ assertXGBoostEvaluation(7.0, expression, features(context, "f2", 0.0));
+ assertXGBoostEvaluation(9.0, expression, features(context, "f2", 1.0));
+ assertXGBoostEvaluation(11.0, expression, features(context));
+ assertXGBoostEvaluation(5.0, expression, features(context, "f1", Tensor.from(0.0)));
+ assertXGBoostEvaluation(6.0, expression, features(context, "f1", Tensor.from(1.0)));
+
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ optimizer.optimize(expression, (ContextIndex)context);
+ assertTrue(expression.getRoot() instanceof GBDTForestNode);
+
+ assertXGBoostEvaluation(1.0, expression, features(context, "f1", 0.0, "f2", 0.0));
+ assertXGBoostEvaluation(2.0, expression, features(context, "f1", 0.0, "f2", 1.0));
+ assertXGBoostEvaluation(3.0, expression, features(context, "f1", 1.0, "f2", 0.0));
+ assertXGBoostEvaluation(4.0, expression, features(context, "f1", 1.0, "f2", 1.0));
+ assertXGBoostEvaluation(5.0, expression, features(context, "f1", 0.0));
+ assertXGBoostEvaluation(6.0, expression, features(context, "f1", 1.0));
+ assertXGBoostEvaluation(7.0, expression, features(context, "f2", 0.0));
+ assertXGBoostEvaluation(9.0, expression, features(context, "f2", 1.0));
+ assertXGBoostEvaluation(11.0, expression, features(context));
+ assertXGBoostEvaluation(5.0, expression, features(context, "f1", Tensor.from(0.0)));
+ assertXGBoostEvaluation(6.0, expression, features(context, "f1", Tensor.from(1.0)));
+ }
+
+ private ArrayContext features(ArrayContext context) {
+ return context.clone();
+ }
+
+ private ArrayContext features(ArrayContext context, String f1, double v1) {
+ context = context.clone();
+ context.put(f1, v1);
+ return context;
+ }
+
+ private ArrayContext features(ArrayContext context, String f1, Tensor v1) {
+ context = context.clone();
+ context.put(f1, new TensorValue(v1));
+ return context;
+ }
+
+ private ArrayContext features(ArrayContext context, String f1, double v1, String f2, double v2) {
+ context = context.clone();
+ context.put(f1, v1);
+ context.put(f2, v2);
+ return context;
+ }
+
+ private void assertXGBoostEvaluation(double expected, RankingExpression expr, Context context) {
+ assertEquals(expected, expr.evaluate(context).asDouble(), 1e-9);
+ }
+
+}
diff --git a/model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json b/model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json
new file mode 100644
index 00000000000..8994d89787e
--- /dev/null
+++ b/model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json
@@ -0,0 +1,26 @@
+[
+ { "nodeid": 0, "depth": 0, "split": "f1", "split_condition": 0.5, "yes": 1, "no": 2, "missing": 2, "children": [
+ { "nodeid": 1, "depth": 1, "split": "f2", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 3, "children": [
+ { "nodeid": 3, "leaf": 1.0 },
+ { "nodeid": 4, "leaf": 2.0 }
+ ]},
+ { "nodeid": 2, "depth": 1, "split": "f2", "split_condition": 0.5, "yes": 5, "no": 6, "missing": 5, "children": [
+ { "nodeid": 5, "leaf": 3.0 },
+ { "nodeid": 6, "leaf": 4.0 }
+ ]}
+ ]},
+ { "nodeid": 0, "depth": 0, "split": "f1", "split_condition": 1.5, "yes": 1, "no": 2, "missing": 2, "children": [
+ { "nodeid": 1, "leaf": 0.0 },
+ { "nodeid": 2, "depth": 1, "split": "f2", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 4, "children": [
+ { "nodeid": 3, "leaf": 4.0 },
+ { "nodeid": 4, "leaf": 5.0 }
+ ]}
+ ]},
+ { "nodeid": 0, "depth": 0, "split": "f2", "split_condition": 1.5, "yes": 1, "no": 2, "missing": 2, "children": [
+ { "nodeid": 1, "leaf": 0.0 },
+ { "nodeid": 2, "depth": 1, "split": "f1", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 4, "children": [
+ { "nodeid": 3, "leaf": 4.0 },
+ { "nodeid": 4, "leaf": 3.0 }
+ ]}
+ ]}
+] \ No newline at end of file
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 0b9cb06d2a5..c72be5b23e9 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -378,6 +378,7 @@
"methods": [
"protected void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression)",
"protected void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, boolean)",
+ "protected void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, boolean, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"protected final java.util.Map nameToIndex()",
"protected final double[] doubleValues()",
"protected final boolean ignoreUnknownValues()",
@@ -402,6 +403,8 @@
"methods": [
"public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression)",
"public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, boolean)",
+ "public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, boolean, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public final void put(java.lang.String, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public final void put(int, double)",
"public final void put(int, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
@@ -509,6 +512,7 @@
"methods": [
"public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression)",
"public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, boolean)",
+ "public void <init>(com.yahoo.searchlib.rankingexpression.RankingExpression, boolean, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public final void put(java.lang.String, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public final void put(int, double)",
"public final void put(int, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
@@ -550,7 +554,8 @@
"public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.Value asMutable()"
],
"fields": [
- "public static final com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue zero"
+ "public static final com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue zero",
+ "public static final com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue NaN"
]
},
"com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer": {
@@ -575,7 +580,9 @@
],
"methods": [
"public void <init>()",
+ "public void <init>(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public void <init>(java.util.Map)",
+ "public void <init>(java.util.Map, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.MapContext freeze()",
"public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value get(java.lang.String)",
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java
index 41bf827748a..16549b3ee1c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java
@@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import java.util.BitSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
@@ -33,7 +34,11 @@ public abstract class AbstractArrayContext extends Context implements Cloneable,
* This will fail if unknown values are attempted added.
*/
protected AbstractArrayContext(RankingExpression expression) {
- this(expression, false);
+ this(expression, false, defaultMissingValue);
+ }
+
+ protected AbstractArrayContext(RankingExpression expression, boolean ignoreUnknownValues) {
+ this(expression, ignoreUnknownValues, defaultMissingValue);
}
/**
@@ -44,10 +49,11 @@ public abstract class AbstractArrayContext extends Context implements Cloneable,
* @param ignoreUnknownValues whether attempts to put values not present in this expression
* should fail (false - the default), or be ignored (true)
*/
- protected AbstractArrayContext(RankingExpression expression, boolean ignoreUnknownValues) {
+ protected AbstractArrayContext(RankingExpression expression, boolean ignoreUnknownValues, Value missingValue) {
+ this.missingValue = missingValue.freeze();
this.ignoreUnknownValues = ignoreUnknownValues;
this.rankingExpressionName = expression.getName();
- this.indexedBindings = new IndexedBindings(expression);
+ this.indexedBindings = new IndexedBindings(expression, this.missingValue);
}
protected final Map<String, Integer> nameToIndex() { return indexedBindings.nameToIndex(); }
@@ -77,6 +83,14 @@ public abstract class AbstractArrayContext extends Context implements Cloneable,
return indexedBindings.getDouble(index);
}
+ final boolean isMissing(int index) {
+ return indexedBindings.isMissing(index);
+ }
+
+ final void clearMissing(int index) {
+ indexedBindings.clearMissing(index);
+ }
+
@Override
public String toString() {
return "fast lookup context for ranking expression '" + rankingExpressionName +
@@ -107,11 +121,22 @@ public abstract class AbstractArrayContext extends Context implements Cloneable,
/** The current values set, pre-converted to doubles */
private double[] doubleValues;
- public IndexedBindings(RankingExpression expression) {
+ /** Which values actually are set */
+ private BitSet setValues;
+
+ /** Value to return if value is missing. */
+ private double missingValue;
+
+ public IndexedBindings(RankingExpression expression, Value missingValue) {
Set<String> bindTargets = new LinkedHashSet<>();
extractBindTargets(expression.getRoot(), bindTargets);
+ this.missingValue = missingValue.asDouble();
+ setValues = new BitSet(bindTargets.size());
doubleValues = new double[bindTargets.size()];
+ for (int i = 0; i < bindTargets.size(); ++i) {
+ doubleValues[i] = this.missingValue;
+ }
int i = 0;
ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
@@ -136,10 +161,13 @@ public abstract class AbstractArrayContext extends Context implements Cloneable,
public Map<String, Integer> nameToIndex() { return nameToIndex; }
public double[] doubleValues() { return doubleValues; }
+
public Set<String> names() { return nameToIndex.keySet(); }
public int getIndex(String name) { return nameToIndex.get(name); }
public int size() { return doubleValues.length; }
public double getDouble(int index) { return doubleValues[index]; }
+ public boolean isMissing(int index) { return ! setValues.get(index); }
+ public void clearMissing(int index) { setValues.set(index); }
/**
* Creates a clone of this context suitable for evaluating against the same ranking expression
@@ -149,7 +177,11 @@ public abstract class AbstractArrayContext extends Context implements Cloneable,
public IndexedBindings clone() {
try {
IndexedBindings clone = (IndexedBindings)super.clone();
+ clone.setValues = new BitSet(nameToIndex.size());
clone.doubleValues = new double[nameToIndex.size()];
+ for (int i = 0; i < nameToIndex.size(); ++i) {
+ clone.doubleValues[i] = missingValue;
+ }
return clone;
}
catch (CloneNotSupportedException e) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
index 047d9d761ce..82243fc493d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
@@ -19,8 +19,6 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
/** The current values set */
private Value[] values;
- private static DoubleValue constantZero = DoubleValue.frozen(0);
-
/**
* Create a fast lookup context for an expression.
* This instance should be reused indefinitely by a single thread.
@@ -30,6 +28,14 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
this(expression, false);
}
+ public ArrayContext(RankingExpression expression, boolean ignoreUnknownValues) {
+ this(expression, ignoreUnknownValues, defaultMissingValue);
+ }
+
+ public ArrayContext(RankingExpression expression, Value defaultValue) {
+ this(expression, false, defaultValue);
+ }
+
/**
* Create a fast lookup context for an expression.
* This instance should be reused indefinitely by a single thread.
@@ -37,11 +43,12 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
* @param expression the expression to create a context for
* @param ignoreUnknownValues whether attempts to put values not present in this expression
* should fail (false - the default), or be ignored (true)
+ * @param missingValue the value to return if not set.
*/
- public ArrayContext(RankingExpression expression, boolean ignoreUnknownValues) {
- super(expression, ignoreUnknownValues);
+ public ArrayContext(RankingExpression expression, boolean ignoreUnknownValues, Value missingValue) {
+ super(expression, ignoreUnknownValues, missingValue);
values = new Value[doubleValues().length];
- Arrays.fill(values, DoubleValue.zero);
+ Arrays.fill(values, this.missingValue);
}
/**
@@ -74,6 +81,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
*/
public final void put(int index, Value value) {
values[index] = value.freeze();
+ clearMissing(index);
try {
doubleValues()[index] = value.asDouble();
}
@@ -93,7 +101,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
@Override
public Value get(String name) {
Integer index = nameToIndex().get(name);
- if (index == null) return DoubleValue.zero;
+ if (index == null) return missingValue;
return values[index];
}
@@ -107,7 +115,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
@Override
public final double getDouble(int index) {
double value = doubleValues()[index];
- if (Double.isNaN(value))
+ if (Double.isNaN(value) && ! isMissing(index)) // NaN is valid as a missing value
throw new UnsupportedOperationException("Value at " + index + " has no double representation");
return value;
}
@@ -119,7 +127,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
public ArrayContext clone() {
ArrayContext clone = (ArrayContext)super.clone();
clone.values = new Value[nameToIndex().size()];
- Arrays.fill(clone.values, constantZero);
+ Arrays.fill(clone.values, missingValue);
return clone;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/BooleanValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/BooleanValue.java
index 8ac9a6787da..0e187dfc87c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/BooleanValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/BooleanValue.java
@@ -49,8 +49,9 @@ public class BooleanValue extends DoubleCompatibleValue {
@Override
public boolean equals(Object other) {
if (this==other) return true;
- if ( ! (other instanceof BooleanValue)) return false;
- return ((BooleanValue)other).value==this.value;
+ if ( ! (other instanceof Value)) return false;
+ if ( ! ((Value) other).hasDouble()) return false;
+ return this.value == ((Value) other).asBoolean();
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 4e046df11ca..d68f8c85ad1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -18,6 +18,12 @@ import java.util.stream.Collectors;
*/
public abstract class Context implements EvaluationContext<Reference> {
+ /** The default value to return if the value has not been set */
+ static Value defaultMissingValue = DoubleValue.zero;
+
+ /** The value to return if the value has not been set */
+ Value missingValue;
+
/**
* Returns the value of a simple variable name.
*
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
index 0004036da4b..257b344f025 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
@@ -19,7 +19,11 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
* This will fail if unknown values are attempted added.
*/
public DoubleOnlyArrayContext(RankingExpression expression) {
- this(expression, false);
+ this(expression, false, defaultMissingValue);
+ }
+
+ public DoubleOnlyArrayContext(RankingExpression expression, boolean ignoreUnknownValues) {
+ this(expression, ignoreUnknownValues, defaultMissingValue);
}
/**
@@ -29,9 +33,10 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
* @param expression the expression to create a context for
* @param ignoreUnknownValues whether attempts to put values not present in this expression
* should fail (false - the default), or be ignored (true)
+ * @param missingValue the value to return if not set.
*/
- public DoubleOnlyArrayContext(RankingExpression expression, boolean ignoreUnknownValues) {
- super(expression, ignoreUnknownValues);
+ public DoubleOnlyArrayContext(RankingExpression expression, boolean ignoreUnknownValues, Value missingValue) {
+ super(expression, ignoreUnknownValues, missingValue);
}
/**
@@ -56,6 +61,7 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
/** Same as put(index,DoubleValue.frozen(value)) */
public final void put(int index, double value) {
doubleValues()[index] = value;
+ clearMissing(index);
}
/** Puts a value by index. */
@@ -77,7 +83,7 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
@Override
public Value get(String name) {
Integer index = nameToIndex().get(name);
- if (index==null) return DoubleValue.zero;
+ if (index==null) return missingValue;
return new DoubleValue(getDouble(index));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
index 8aa7446cae7..06ab4cba98f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
@@ -20,6 +20,9 @@ public final class DoubleValue extends DoubleCompatibleValue {
/** The double value instance for 0 */
public final static DoubleValue zero = DoubleValue.frozen(0);
+ /** The double value instance for NaN */
+ public final static DoubleValue NaN = DoubleValue.frozen(Double.NaN);
+
public DoubleValue(double value) {
this.value = value;
}
@@ -146,8 +149,9 @@ public final class DoubleValue extends DoubleCompatibleValue {
@Override
public boolean equals(Object other) {
if (this==other) return true;
- if ( ! (other instanceof DoubleValue)) return false;
- return ((DoubleValue)other).value==this.value;
+ if ( ! (other instanceof Value)) return false;
+ if ( ! ((Value) other).hasDouble()) return false;
+ return this.asDouble() == ((Value) other).asDouble();
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
index 4ef24d60bba..f531d77762d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
@@ -21,13 +21,23 @@ public class MapContext extends Context {
private boolean frozen = false;
public MapContext() {
+ this(defaultMissingValue);
+ }
+
+ public MapContext(Value missingValue) {
+ this.missingValue = missingValue.freeze();
+ }
+
+ public MapContext(Map<String,Value> bindings) {
+ this(bindings, defaultMissingValue);
}
/**
* Creates a map context from a map.
* All the Values of the map will be frozen.
*/
- public MapContext(Map<String,Value> bindings) {
+ public MapContext(Map<String,Value> bindings, Value missingValue) {
+ this.missingValue = missingValue.freeze();
bindings.forEach((k, v) -> this.bindings.put(k, v.freeze()));
}
@@ -52,7 +62,7 @@ public class MapContext extends Context {
/** Returns the value of a key. 0 is returned if the given key is not bound in this. */
@Override
public Value get(String key) {
- return bindings.getOrDefault(key, DoubleValue.zero);
+ return bindings.getOrDefault(key, missingValue);
}
/**
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index ee66dcc5a03..b109e6503e3 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -41,7 +41,9 @@ public class TensorValue extends Value {
@Override
public boolean asBoolean() {
- throw new UnsupportedOperationException("A tensor does not have a boolean value");
+ if (hasDouble())
+ return asDouble() != 0.0;
+ throw new UnsupportedOperationException("Tensor does not have a value that can be converted to a boolean");
}
@Override
@@ -118,18 +120,11 @@ public class TensorValue extends Value {
return new TensorValue(value.map((value) -> Math.pow(value, argument.asDouble())));
}
- private Tensor asTensor(Value value, String operationName) {
- if ( ! (value instanceof TensorValue))
- throw new UnsupportedOperationException("Could not perform " + operationName +
- ": The second argument must be a tensor but was " + value);
- return ((TensorValue)value).value;
- }
-
public Tensor asTensor() { return value; }
@Override
public Value compare(TruthOperator operator, Value argument) {
- return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString())));
+ return new TensorValue(compareTensor(operator, argument.asTensor()));
}
private Tensor compareTensor(TruthOperator operator, Tensor argument) {
@@ -148,7 +143,7 @@ public class TensorValue extends Value {
@Override
public Value function(Function function, Value arg) {
if (arg instanceof TensorValue)
- return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString())));
+ return new TensorValue(functionOnTensor(function, arg.asTensor()));
else
return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble())));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 7809cdd4e1b..39e408d27ca 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -68,7 +68,7 @@ public abstract class Value {
public abstract Value compare(TruthOperator operator, Value value);
/** Perform the given binary function on this value and the given value */
- public abstract Value function(Function function,Value value);
+ public abstract Value function(Function function, Value value);
/**
* Irreversibly makes this immutable. Overriders must always call super.freeze() and return this
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
index 3c0898b5d4f..c1ec72ba0fc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
@@ -26,15 +26,16 @@ public final class GBDTNode extends ExpressionNode {
// n=[0,MAX_LEAF_VALUE> : n is data (tree leaf constant value)
// n=[MAX_LEAF_VALUE+MAX_VARIABLES*0,MAX_LEAF_VALUE+MAX_VARIABLES*1>: < than var at index n
// n=[MAX_LEAF_VALUE+MAX_VARIABLES*1,MAX_LEAF_VALUE+MAX_VARIABLES*2>: = to var at index n-MAX_VARIABLES
- // n=[MAX_LEAF_VALUE+MAX_VARIABLES*2,MAX_LEAF_VALUE+MAX_VARIABLES*3]: n-MAX_VARIABLES*2 is IN the following set
+ // n=[MAX_LEAF_VALUE+MAX_VARIABLES*2,MAX_LEAF_VALUE+MAX_VARIABLES*3>: n-MAX_VARIABLES*2 is IN the following set
+ // n=[MAX_LEAF_VALUE+MAX_VARIABLES*3,MAX_LEAF_VALUE+MAX_VARIABLES*4]: !( >= ) than var at index n-MAX_VARIABLES*3 (if-inversion)
// The full layout of an IF instruction is
// COMPARISON,TRUE_BRANCH_LENGTH,TRUE_BRANCH,FALSE_BRANCH
- // where COMPARISON is VARIABLE_AND_OPCODE,COMPARE_CONSTANT if the opcode is < or =,
+ // where COMPARISON is VARIABLE_AND_OPCODE,COMPARE_CONSTANT if the opcode is < or = or !( >= ),
// and VARIABLE_AND_OPCODE,COMPARE_CONSTANTS_LENGTH,COMPARE_CONSTANTS if the opcode is IN
- // If any change is made to this encoding, this change must also be reflected in GBDTNodeOptimizer
+ // If any change is made to this encoding, this change must also be reflected in GBDTOptimizer
/** The max (absolute) supported value an optimized leaf may have */
public final static int MAX_LEAF_VALUE=2*1000*1000*1000;
@@ -72,7 +73,7 @@ public final class GBDTNode extends ExpressionNode {
else if (offset < MAX_VARIABLES*2) {
comparisonIsTrue = context.getDouble(offset-MAX_VARIABLES)==values[pc++];
}
- else { // offset<MAX_VARIABLES*3
+ else if (offset<MAX_VARIABLES*3) {
double testValue = context.getDouble(offset-MAX_VARIABLES*2);
int setValuesLeft = (int)values[pc++];
while (setValuesLeft > 0) { // test each value in the set
@@ -84,6 +85,9 @@ public final class GBDTNode extends ExpressionNode {
}
pc += setValuesLeft; // jump to after the set
}
+ else { // offset<MAX_VARIABLES*4
+ comparisonIsTrue = ! (context.getDouble(offset-MAX_VARIABLES*3)>=values[pc++]);
+ }
if (comparisonIsTrue)
pc++; // true branch - skip the jump value
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
index 787818b0f42..a6df6b435d6 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
@@ -53,7 +53,7 @@ public class GBDTOptimizer extends Optimizer {
* anything else.</p>
*
* <p>Each condition node is converted to the double sequence [(OperatorIsEquals ? GBDTNode.MAX_VARIABLES : 0) +
- * IndexOfLeftComparisonFeature+GBDTNode.MAX_LEAFT_VALUE, ValueOfRightComparisonValue,#OfValuesInTrueBranch,true
+ * IndexOfLeftComparisonFeature+GBDTNode.MAX_LEAF_VALUE, ValueOfRightComparisonValue,#OfValuesInTrueBranch,true
* branch values,false branch values]</p>
*
* <p>Each value node is converted to the double value of the value node itself.</p>
@@ -131,6 +131,20 @@ public class GBDTOptimizer extends Optimizer {
for (ExpressionNode setElementNode : setMembership.getSetValues())
values.add(toValue(setElementNode));
}
+ else if (condition instanceof NotNode) { // handle if inversion: !(a >= b)
+ NotNode notNode = (NotNode)condition;
+ if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode) {
+ EmbracedNode embracedNode = (EmbracedNode)notNode.children().get(0);
+ if (embracedNode.children().size() == 1 && embracedNode.children().get(0) instanceof ComparisonNode) {
+ ComparisonNode comparison = (ComparisonNode)embracedNode.children().get(0);
+ if (comparison.getOperator() == TruthOperator.LARGEREQUAL)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.getLeftCondition(), context));
+ else
+ throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator());
+ values.add(toValue(comparison.getRightCondition()));
+ }
+ }
+ }
else {
throw new IllegalArgumentException("Node condition could not be optimized: " + condition);
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java
index ce78703f842..08f1a872759 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java
@@ -24,7 +24,7 @@ public class GBDTForestOptimizerTestCase {
RankingExpression gbdt = new RankingExpression(gbdtString);
// Regular evaluation
- MapContext arguments = new MapContext();
+ MapContext arguments = new MapContext(DoubleValue.NaN);
arguments.put("LW_NEWS_SEARCHES_RATIO", 1d);
arguments.put("SUGG_OVERLAP", 17d);
double result1 = gbdt.evaluate(arguments).asDouble();
@@ -36,7 +36,7 @@ public class GBDTForestOptimizerTestCase {
double result3 = gbdt.evaluate(arguments).asDouble();
// Optimized evaluation
- ArrayContext fArguments = new ArrayContext(gbdt);
+ ArrayContext fArguments = new ArrayContext(gbdt, DoubleValue.NaN);
ExpressionOptimizer optimizer = new ExpressionOptimizer();
OptimizationReport report = optimizer.optimize(gbdt, fArguments);
assertEquals(4, report.getMetric("Optimized GDBT trees"));
@@ -70,7 +70,7 @@ public class GBDTForestOptimizerTestCase {
RankingExpression gbdt = new RankingExpression(gbdtString);
// Regular evaluation
- MapContext arguments = new MapContext();
+ MapContext arguments = new MapContext(DoubleValue.NaN);
arguments.put("MYSTRING", new StringValue("string 1"));
arguments.put("LW_NEWS_SEARCHES_RATIO", 1d);
arguments.put("SUGG_OVERLAP", 17d);
@@ -83,7 +83,7 @@ public class GBDTForestOptimizerTestCase {
double result3 = gbdt.evaluate(arguments).asDouble();
// Optimized evaluation
- ArrayContext fArguments = new ArrayContext(gbdt);
+ ArrayContext fArguments = new ArrayContext(gbdt, DoubleValue.NaN);
ExpressionOptimizer optimizer = new ExpressionOptimizer();
OptimizationReport report = optimizer.optimize(gbdt, fArguments);
assertEquals(4, report.getMetric("Optimized GDBT trees"));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java
index 4b7462505fc..82ad034e306 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
@@ -43,7 +44,7 @@ public class GBDTOptimizerTestCase {
RankingExpression gbdt = new RankingExpression(gbdtString);
// Regular evaluation
- MapContext arguments = new MapContext();
+ MapContext arguments = new MapContext(DoubleValue.NaN);
arguments.put("LW_NEWS_SEARCHES_RATIO", 1d);
arguments.put("SUGG_OVERLAP", 17d);
double result1 = gbdt.evaluate(arguments).asDouble();
@@ -55,7 +56,7 @@ public class GBDTOptimizerTestCase {
double result3 = gbdt.evaluate(arguments).asDouble();
// Optimized evaluation
- ArrayContext fArguments = new ArrayContext(gbdt);
+ ArrayContext fArguments = new ArrayContext(gbdt, DoubleValue.NaN);
ExpressionOptimizer optimizer = new ExpressionOptimizer();
optimizer.getOptimizer(GBDTForestOptimizer.class).setEnabled(false);
OptimizationReport report = optimizer.optimize(gbdt,fArguments);