diff options
author | Lester Solbakken <lesters@oath.com> | 2019-10-11 09:42:58 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-10-11 09:42:58 +0200 |
commit | e5ba4ddccfa45b440b6803eb1665c3b2e6f19be9 (patch) | |
tree | 7e823c36866de9a29ca80af5cd56785339b319d6 /searchlib/src | |
parent | 3acec4a95bc2f75f8384bde14d35f3a5c073460b (diff) |
Recognize if-inverted decision trees in GBDT optimizer
Diffstat (limited to 'searchlib/src')
4 files changed, 30 insertions, 11 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java index 3c0898b5d4f..c1ec72ba0fc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java @@ -26,15 +26,16 @@ public final class GBDTNode extends ExpressionNode { // n=[0,MAX_LEAF_VALUE> : n is data (tree leaf constant value) // n=[MAX_LEAF_VALUE+MAX_VARIABLES*0,MAX_LEAF_VALUE+MAX_VARIABLES*1>: < than var at index n // n=[MAX_LEAF_VALUE+MAX_VARIABLES*1,MAX_LEAF_VALUE+MAX_VARIABLES*2>: = to var at index n-MAX_VARIABLES - // n=[MAX_LEAF_VALUE+MAX_VARIABLES*2,MAX_LEAF_VALUE+MAX_VARIABLES*3]: n-MAX_VARIABLES*2 is IN the following set + // n=[MAX_LEAF_VALUE+MAX_VARIABLES*2,MAX_LEAF_VALUE+MAX_VARIABLES*3>: n-MAX_VARIABLES*2 is IN the following set + // n=[MAX_LEAF_VALUE+MAX_VARIABLES*3,MAX_LEAF_VALUE+MAX_VARIABLES*4]: !( >= ) than var at index n-MAX_VARIABLES*3 (if-inversion) // The full layout of an IF instruction is // COMPARISON,TRUE_BRANCH_LENGTH,TRUE_BRANCH,FALSE_BRANCH - // where COMPARISON is VARIABLE_AND_OPCODE,COMPARE_CONSTANT if the opcode is < or =, + // where COMPARISON is VARIABLE_AND_OPCODE,COMPARE_CONSTANT if the opcode is < or = or !( >= ), // and VARIABLE_AND_OPCODE,COMPARE_CONSTANTS_LENGTH,COMPARE_CONSTANTS if the opcode is IN - // If any change is made to this encoding, this change must also be reflected in GBDTNodeOptimizer + // If any change is made to this encoding, this change must also be reflected in GBDTOptimizer /** The max (absolute) supported value an optimized leaf may have */ public final static int MAX_LEAF_VALUE=2*1000*1000*1000; @@ -72,7 +73,7 @@ public final class GBDTNode extends ExpressionNode { else if (offset < MAX_VARIABLES*2) { comparisonIsTrue = context.getDouble(offset-MAX_VARIABLES)==values[pc++]; } - else { // offset<MAX_VARIABLES*3 + else if (offset<MAX_VARIABLES*3) { double testValue = context.getDouble(offset-MAX_VARIABLES*2); int setValuesLeft = (int)values[pc++]; while (setValuesLeft > 0) { // test each value in the set @@ -84,6 +85,9 @@ public final class GBDTNode extends ExpressionNode { } pc += setValuesLeft; // jump to after the set } + else { // offset<MAX_VARIABLES*4 + comparisonIsTrue = ! (context.getDouble(offset-MAX_VARIABLES*3)>=values[pc++]); + } if (comparisonIsTrue) pc++; // true branch - skip the jump value diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java index 787818b0f42..a6df6b435d6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java @@ -53,7 +53,7 @@ public class GBDTOptimizer extends Optimizer { * anything else.</p> * * <p>Each condition node is converted to the double sequence [(OperatorIsEquals ? GBDTNode.MAX_VARIABLES : 0) + - * IndexOfLeftComparisonFeature+GBDTNode.MAX_LEAFT_VALUE, ValueOfRightComparisonValue,#OfValuesInTrueBranch,true + * IndexOfLeftComparisonFeature+GBDTNode.MAX_LEAF_VALUE, ValueOfRightComparisonValue,#OfValuesInTrueBranch,true * branch values,false branch values]</p> * * <p>Each value node is converted to the double value of the value node itself.</p> @@ -131,6 +131,20 @@ public class GBDTOptimizer extends Optimizer { for (ExpressionNode setElementNode : setMembership.getSetValues()) values.add(toValue(setElementNode)); } + else if (condition instanceof NotNode) { // handle if inversion: !(a >= b) + NotNode notNode = (NotNode)condition; + if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode) { + EmbracedNode embracedNode = (EmbracedNode)notNode.children().get(0); + if (embracedNode.children().size() == 1 && embracedNode.children().get(0) instanceof ComparisonNode) { + ComparisonNode comparison = (ComparisonNode)embracedNode.children().get(0); + if (comparison.getOperator() == TruthOperator.LARGEREQUAL) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.getLeftCondition(), context)); + else + throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator()); + values.add(toValue(comparison.getRightCondition())); + } + } + } else { throw new IllegalArgumentException("Node condition could not be optimized: " + condition); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java index ce78703f842..08f1a872759 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizerTestCase.java @@ -24,7 +24,7 @@ public class GBDTForestOptimizerTestCase { RankingExpression gbdt = new RankingExpression(gbdtString); // Regular evaluation - MapContext arguments = new MapContext(); + MapContext arguments = new MapContext(DoubleValue.NaN); arguments.put("LW_NEWS_SEARCHES_RATIO", 1d); arguments.put("SUGG_OVERLAP", 17d); double result1 = gbdt.evaluate(arguments).asDouble(); @@ -36,7 +36,7 @@ public class GBDTForestOptimizerTestCase { double result3 = gbdt.evaluate(arguments).asDouble(); // Optimized evaluation - ArrayContext fArguments = new ArrayContext(gbdt); + ArrayContext fArguments = new ArrayContext(gbdt, DoubleValue.NaN); ExpressionOptimizer optimizer = new ExpressionOptimizer(); OptimizationReport report = optimizer.optimize(gbdt, fArguments); assertEquals(4, report.getMetric("Optimized GDBT trees")); @@ -70,7 +70,7 @@ public class GBDTForestOptimizerTestCase { RankingExpression gbdt = new RankingExpression(gbdtString); // Regular evaluation - MapContext arguments = new MapContext(); + MapContext arguments = new MapContext(DoubleValue.NaN); arguments.put("MYSTRING", new StringValue("string 1")); arguments.put("LW_NEWS_SEARCHES_RATIO", 1d); arguments.put("SUGG_OVERLAP", 17d); @@ -83,7 +83,7 @@ public class GBDTForestOptimizerTestCase { double result3 = gbdt.evaluate(arguments).asDouble(); // Optimized evaluation - ArrayContext fArguments = new ArrayContext(gbdt); + ArrayContext fArguments = new ArrayContext(gbdt, DoubleValue.NaN); ExpressionOptimizer optimizer = new ExpressionOptimizer(); OptimizationReport report = optimizer.optimize(gbdt, fArguments); assertEquals(4, report.getMetric("Optimized GDBT trees")); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java index 4b7462505fc..82ad034e306 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizerTestCase.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; @@ -43,7 +44,7 @@ public class GBDTOptimizerTestCase { RankingExpression gbdt = new RankingExpression(gbdtString); // Regular evaluation - MapContext arguments = new MapContext(); + MapContext arguments = new MapContext(DoubleValue.NaN); arguments.put("LW_NEWS_SEARCHES_RATIO", 1d); arguments.put("SUGG_OVERLAP", 17d); double result1 = gbdt.evaluate(arguments).asDouble(); @@ -55,7 +56,7 @@ public class GBDTOptimizerTestCase { double result3 = gbdt.evaluate(arguments).asDouble(); // Optimized evaluation - ArrayContext fArguments = new ArrayContext(gbdt); + ArrayContext fArguments = new ArrayContext(gbdt, DoubleValue.NaN); ExpressionOptimizer optimizer = new ExpressionOptimizer(); optimizer.getOptimizer(GBDTForestOptimizer.class).setEnabled(false); OptimizationReport report = optimizer.optimize(gbdt,fArguments); |