summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-10-11 09:42:58 +0200
committerLester Solbakken <lesters@oath.com>2019-10-11 09:42:58 +0200
commite5ba4ddccfa45b440b6803eb1665c3b2e6f19be9 (patch)
tree7e823c36866de9a29ca80af5cd56785339b319d6 /searchlib
parent3acec4a95bc2f75f8384bde14d35f3a5c073460b (diff)
Recognize if-inverted decision trees in GBDT optimizer
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java12
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java16
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java5
4 files changed, 30 insertions, 11 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
index 3c0898b5d4f..c1ec72ba0fc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
@@ -26,15 +26,16 @@ public final class GBDTNode extends ExpressionNode {
// n=[0,MAX_LEAF_VALUE> : n is data (tree leaf constant value)
// n=[MAX_LEAF_VALUE+MAX_VARIABLES*0,MAX_LEAF_VALUE+MAX_VARIABLES*1>: < than var at index n
// n=[MAX_LEAF_VALUE+MAX_VARIABLES*1,MAX_LEAF_VALUE+MAX_VARIABLES*2>: = to var at index n-MAX_VARIABLES
- // n=[MAX_LEAF_VALUE+MAX_VARIABLES*2,MAX_LEAF_VALUE+MAX_VARIABLES*3]: n-MAX_VARIABLES*2 is IN the following set
+ // n=[MAX_LEAF_VALUE+MAX_VARIABLES*2,MAX_LEAF_VALUE+MAX_VARIABLES*3>: n-MAX_VARIABLES*2 is IN the following set
+ // n=[MAX_LEAF_VALUE+MAX_VARIABLES*3,MAX_LEAF_VALUE+MAX_VARIABLES*4]: !( >= ) than var at index n-MAX_VARIABLES*3 (if-inversion)
// The full layout of an IF instruction is
// COMPARISON,TRUE_BRANCH_LENGTH,TRUE_BRANCH,FALSE_BRANCH
- // where COMPARISON is VARIABLE_AND_OPCODE,COMPARE_CONSTANT if the opcode is < or =,
+ // where COMPARISON is VARIABLE_AND_OPCODE,COMPARE_CONSTANT if the opcode is < or = or !( >= ),
// and VARIABLE_AND_OPCODE,COMPARE_CONSTANTS_LENGTH,COMPARE_CONSTANTS if the opcode is IN
- // If any change is made to this encoding, this change must also be reflected in GBDTNodeOptimizer
+ // If any change is made to this encoding, this change must also be reflected in GBDTOptimizer
/** The max (absolute) supported value an optimized leaf may have */
public final static int MAX_LEAF_VALUE=2*1000*1000*1000;
@@ -72,7 +73,7 @@ public final class GBDTNode extends ExpressionNode {
else if (offset < MAX_VARIABLES*2) {
comparisonIsTrue = context.getDouble(offset-MAX_VARIABLES)==values[pc++];
}
- else { // offset<MAX_VARIABLES*3
+ else if (offset<MAX_VARIABLES*3) {
double testValue = context.getDouble(offset-MAX_VARIABLES*2);
int setValuesLeft = (int)values[pc++];
while (setValuesLeft > 0) { // test each value in the set
@@ -84,6 +85,9 @@ public final class GBDTNode extends ExpressionNode {
}
pc += setValuesLeft; // jump to after the set
}
+ else { // offset<MAX_VARIABLES*4
+ comparisonIsTrue = ! (context.getDouble(offset-MAX_VARIABLES*3)>=values[pc++]);
+ }
if (comparisonIsTrue)
pc++; // true branch - skip the jump value
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
index 787818b0f42..a6df6b435d6 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
@@ -53,7 +53,7 @@ public class GBDTOptimizer extends Optimizer {
* anything else.</p>
*
* <p>Each condition node is converted to the double sequence [(OperatorIsEquals ? GBDTNode.MAX_VARIABLES : 0) +
- * IndexOfLeftComparisonFeature+GBDTNode.MAX_LEAFT_VALUE, ValueOfRightComparisonValue,#OfValuesInTrueBranch,true
+ * IndexOfLeftComparisonFeature+GBDTNode.MAX_LEAF_VALUE, ValueOfRightComparisonValue,#OfValuesInTrueBranch,true
* branch values,false branch values]</p>
*
* <p>Each value node is converted to the double value of the value node itself.</p>
@@ -131,6 +131,20 @@ public class GBDTOptimizer extends Optimizer {
for (ExpressionNode setElementNode : setMembership.getSetValues())
values.add(toValue(setElementNode));
}
+ else if (condition instanceof NotNode) { // handle if inversion: !(a >= b)
+ NotNode notNode = (NotNode)condition;
+ if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode) {
+ EmbracedNode embracedNode = (EmbracedNode)notNode.children().get(0);
+ if (embracedNode.children().size() == 1 && embracedNode.children().get(0) instanceof ComparisonNode) {
+ ComparisonNode comparison = (ComparisonNode)embracedNode.children().get(0);
+ if (comparison.getOperator() == TruthOperator.LARGEREQUAL)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.getLeftCondition(), context));
+ else
+ throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator());
+ values.add(toValue(comparison.getRightCondition()));
+ }
+ }
+ }
else {
throw new IllegalArgumentException("Node condition could not be optimized: " + condition);
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java
index ce78703f842..08f1a872759 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java
@@ -24,7 +24,7 @@ public class GBDTForestOptimizerTestCase {
RankingExpression gbdt = new RankingExpression(gbdtString);
// Regular evaluation
- MapContext arguments = new MapContext();
+ MapContext arguments = new MapContext(DoubleValue.NaN);
arguments.put("LW_NEWS_SEARCHES_RATIO", 1d);
arguments.put("SUGG_OVERLAP", 17d);
double result1 = gbdt.evaluate(arguments).asDouble();
@@ -36,7 +36,7 @@ public class GBDTForestOptimizerTestCase {
double result3 = gbdt.evaluate(arguments).asDouble();
// Optimized evaluation
- ArrayContext fArguments = new ArrayContext(gbdt);
+ ArrayContext fArguments = new ArrayContext(gbdt, DoubleValue.NaN);
ExpressionOptimizer optimizer = new ExpressionOptimizer();
OptimizationReport report = optimizer.optimize(gbdt, fArguments);
assertEquals(4, report.getMetric("Optimized GDBT trees"));
@@ -70,7 +70,7 @@ public class GBDTForestOptimizerTestCase {
RankingExpression gbdt = new RankingExpression(gbdtString);
// Regular evaluation
- MapContext arguments = new MapContext();
+ MapContext arguments = new MapContext(DoubleValue.NaN);
arguments.put("MYSTRING", new StringValue("string 1"));
arguments.put("LW_NEWS_SEARCHES_RATIO", 1d);
arguments.put("SUGG_OVERLAP", 17d);
@@ -83,7 +83,7 @@ public class GBDTForestOptimizerTestCase {
double result3 = gbdt.evaluate(arguments).asDouble();
// Optimized evaluation
- ArrayContext fArguments = new ArrayContext(gbdt);
+ ArrayContext fArguments = new ArrayContext(gbdt, DoubleValue.NaN);
ExpressionOptimizer optimizer = new ExpressionOptimizer();
OptimizationReport report = optimizer.optimize(gbdt, fArguments);
assertEquals(4, report.getMetric("Optimized GDBT trees"));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java
index 4b7462505fc..82ad034e306 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
@@ -43,7 +44,7 @@ public class GBDTOptimizerTestCase {
RankingExpression gbdt = new RankingExpression(gbdtString);
// Regular evaluation
- MapContext arguments = new MapContext();
+ MapContext arguments = new MapContext(DoubleValue.NaN);
arguments.put("LW_NEWS_SEARCHES_RATIO", 1d);
arguments.put("SUGG_OVERLAP", 17d);
double result1 = gbdt.evaluate(arguments).asDouble();
@@ -55,7 +56,7 @@ public class GBDTOptimizerTestCase {
double result3 = gbdt.evaluate(arguments).asDouble();
// Optimized evaluation
- ArrayContext fArguments = new ArrayContext(gbdt);
+ ArrayContext fArguments = new ArrayContext(gbdt, DoubleValue.NaN);
ExpressionOptimizer optimizer = new ExpressionOptimizer();
optimizer.getOptimizer(GBDTForestOptimizer.class).setEnabled(false);
OptimizationReport report = optimizer.optimize(gbdt,fArguments);