diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-10 10:04:27 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-10 10:04:27 +0100 |
commit | ba6a11e6e2674a2b5c1ef967319fb269f989a216 (patch) | |
tree | b15c1c046989cafeed19d193fdb59634140d3db6 /searchlib | |
parent | 3e1477f5fda4a3dcd436a6d41843adc66e19f370 (diff) |
Use a context for transform state
Diffstat (limited to 'searchlib')
6 files changed, 73 insertions, 51 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java index 46f79dcc6bd..1b8239ba643 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java @@ -10,7 +10,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import java.util.ArrayList; import java.util.List; -import java.util.Map; /** * Replaces "features" which found in the given constants by their constant value @@ -19,40 +18,33 @@ import java.util.Map; */ public class ConstantDereferencer extends ExpressionTransformer { - /** The map of constants to dereference */ - private final Map<String, Value> constants; - - public ConstantDereferencer(Map<String, Value> constants) { - this.constants = constants; - } - @Override - public ExpressionNode transform(ExpressionNode node) { + public ExpressionNode transform(ExpressionNode node, TransformContext context) { if (node instanceof ReferenceNode) - return transformFeature((ReferenceNode) node); + return transformFeature((ReferenceNode) node, context); else if (node instanceof CompositeNode) - return transformChildren((CompositeNode)node); + return transformChildren((CompositeNode)node, context); else return node; } - private ExpressionNode transformFeature(ReferenceNode node) { + private ExpressionNode transformFeature(ReferenceNode node, TransformContext context) { if (!node.getArguments().isEmpty()) - return transformArguments(node); + return transformArguments(node, context); else - return transformConstantReference(node); + return transformConstantReference(node, context); } - private ExpressionNode transformArguments(ReferenceNode node) { + private ExpressionNode transformArguments(ReferenceNode node, TransformContext context) { List<ExpressionNode> arguments = node.getArguments().expressions(); List<ExpressionNode> transformedArguments = new ArrayList<>(arguments.size()); for (ExpressionNode argument : arguments) - transformedArguments.add(transform(argument)); + transformedArguments.add(transform(argument, context)); return node.setArguments(transformedArguments); } - private ExpressionNode transformConstantReference(ReferenceNode node) { - Value value = constants.get(node.getName()); + private ExpressionNode transformConstantReference(ReferenceNode node, TransformContext context) { + Value value = context.constants().get(node.getName()); if (value == null || (value instanceof TensorValue)) { return node; // not a value constant reference } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java index bcc8b817641..c585c0dea1f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java @@ -15,22 +15,22 @@ import java.util.List; */ public abstract class ExpressionTransformer { - public RankingExpression transform(RankingExpression expression) { - return new RankingExpression(expression.getName(), transform(expression.getRoot())); + public RankingExpression transform(RankingExpression expression, TransformContext context) { + return new RankingExpression(expression.getName(), transform(expression.getRoot(), context)); } /** Transforms an expression node and returns the transformed node */ - public abstract ExpressionNode transform(ExpressionNode node); + public abstract ExpressionNode transform(ExpressionNode node, TransformContext context); /** * Utility method which calls transform on each child of the given node and return the resulting transformed * composite */ - protected CompositeNode transformChildren(CompositeNode node) { + protected CompositeNode transformChildren(CompositeNode node, TransformContext context) { List<ExpressionNode> children = node.children(); List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); for (ExpressionNode child : children) - transformedChildren.add(transform(child)); + transformedChildren.add(transform(child, context)); return node.setChildren(transformedChildren); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java index ebad0d5c21f..9e8491340b0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.transform; -import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; @@ -10,8 +9,8 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import java.util.ArrayList; import java.util.List; @@ -24,9 +23,9 @@ import java.util.List; public class Simplifier extends ExpressionTransformer { @Override - public ExpressionNode transform(ExpressionNode node) { + public ExpressionNode transform(ExpressionNode node, TransformContext context) { if (node instanceof CompositeNode) - node = transformChildren((CompositeNode) node); // depth first + node = transformChildren((CompositeNode) node, context); // depth first if (node instanceof IfNode) node = transformIf((IfNode) node); if (node instanceof EmbracedNode && hasSingleUndividableChild((EmbracedNode)node)) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java new file mode 100644 index 00000000000..746ca3b3200 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java @@ -0,0 +1,22 @@ +package com.yahoo.searchlib.rankingexpression.transform; + +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Map; + +/** + * Provides a context in which transforms on ranking expressions take place. + * + * @author bratseth + */ +public class TransformContext { + + private final Map<String, Value> constants; + + public TransformContext(Map<String, Value> constants) { + this.constants = constants; + } + + public Map<String, Value> constants() { return constants; } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java index 4035e499a6a..84e51835458 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java @@ -5,11 +5,12 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import org.junit.Test; -import static org.junit.Assert.*; import java.util.HashMap; import java.util.Map; +import static org.junit.Assert.assertEquals; + /** * @author bratseth */ @@ -17,14 +18,16 @@ public class ConstantDereferencerTestCase { @Test public void testConstantDereferencer() throws ParseException { + ConstantDereferencer c = new ConstantDereferencer(); + Map<String, Value> constants = new HashMap<>(); constants.put("a", Value.parse("1.0")); constants.put("b", Value.parse("2")); constants.put("c", Value.parse("3.5")); - ConstantDereferencer c = new ConstantDereferencer(constants); + TransformContext context = new TransformContext(constants); - assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c")).toString()); - assertEquals("myMacro(1.0,2.0)", c.transform(new RankingExpression("myMacro(a, b)")).toString()); + assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c"), context).toString()); + assertEquals("myMacro(1.0,2.0)", c.transform(new RankingExpression("myMacro(a, b)"), context).toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java index f9d2472e306..8fac3395ac0 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java @@ -7,6 +7,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import org.junit.Test; + +import java.util.Collections; + import static org.junit.Assert.*; /** @@ -17,31 +20,33 @@ public class SimplifierTestCase { @Test public void testSimplify() throws ParseException { Simplifier s = new Simplifier(); - assertEquals("a + b", s.transform(new RankingExpression("a + b")).toString()); - assertEquals("6.5", s.transform(new RankingExpression("1.0 + 2.0 + 3.5")).toString()); - assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )")).toString()); - assertEquals("6.5", s.transform(new RankingExpression("( 1.0 + 2.0 ) + 3.5 ")).toString()); - assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )")).toString()); - assertEquals("7.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + 1")).toString()); - assertEquals("6.5 + a", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + a")).toString()); - assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * 0.0")).toString()); - assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (0.0)")).toString()); - assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (1.0 - 1.0)")).toString()); - assertEquals("7.5", s.transform(new RankingExpression("if (2 > 0, 3.5 * 2 + 0.5, a *3 )")).toString()); - assertEquals("0.0", s.transform(new RankingExpression("0.0 * (1.3 + 7.0)")).toString()); - assertEquals("6.4", s.transform(new RankingExpression("max(0, 10.0-2.0)*(1-fabs(0.0-0.2))")).toString()); - assertEquals("(query(d) + query(b) - query(a)) * query(c) / query(e)", s.transform(new RankingExpression("(query(d) + query(b) - query(a)) * query(c) / query(e)")).toString()); - assertEquals("14.0", s.transform(new RankingExpression("5 + (2 + 3) + 4")).toString()); - assertEquals("28.0 + bar", s.transform(new RankingExpression("7.0 + 12.0 + 9.0 + bar")).toString()); - assertEquals("1.0 - 0.001 * attribute(number)", s.transform(new RankingExpression("1.0 - 0.001*attribute(number)")).toString()); - assertEquals("attribute(number) * 1.5 - 0.001 * attribute(number)", s.transform(new RankingExpression("attribute(number) * 1.5 - 0.001 * attribute(number)")).toString()); + TransformContext c = new TransformContext(Collections.emptyMap()); + assertEquals("a + b", s.transform(new RankingExpression("a + b"), c).toString()); + assertEquals("6.5", s.transform(new RankingExpression("1.0 + 2.0 + 3.5"), c).toString()); + assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )"), c).toString()); + assertEquals("6.5", s.transform(new RankingExpression("( 1.0 + 2.0 ) + 3.5 "), c).toString()); + assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )"), c).toString()); + assertEquals("7.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + 1"), c).toString()); + assertEquals("6.5 + a", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + a"), c).toString()); + assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * 0.0"), c).toString()); + assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (0.0)"), c).toString()); + assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (1.0 - 1.0)"), c).toString()); + assertEquals("7.5", s.transform(new RankingExpression("if (2 > 0, 3.5 * 2 + 0.5, a *3 )"), c).toString()); + assertEquals("0.0", s.transform(new RankingExpression("0.0 * (1.3 + 7.0)"), c).toString()); + assertEquals("6.4", s.transform(new RankingExpression("max(0, 10.0-2.0)*(1-fabs(0.0-0.2))"), c).toString()); + assertEquals("(query(d) + query(b) - query(a)) * query(c) / query(e)", s.transform(new RankingExpression("(query(d) + query(b) - query(a)) * query(c) / query(e)"), c).toString()); + assertEquals("14.0", s.transform(new RankingExpression("5 + (2 + 3) + 4"), c).toString()); + assertEquals("28.0 + bar", s.transform(new RankingExpression("7.0 + 12.0 + 9.0 + bar"), c).toString()); + assertEquals("1.0 - 0.001 * attribute(number)", s.transform(new RankingExpression("1.0 - 0.001*attribute(number)"), c).toString()); + assertEquals("attribute(number) * 1.5 - 0.001 * attribute(number)", s.transform(new RankingExpression("attribute(number) * 1.5 - 0.001 * attribute(number)"), c).toString()); } // A black box test verifying we are not screwing up real expressions @Test public void testSimplifyComplexExpression() throws ParseException { RankingExpression initial = new RankingExpression("sqrt(if (if (INFERRED * 0.9 < INFERRED, GMP, (1 + 1.1) * INFERRED) < INFERRED * INFERRED - INFERRED, if (GMP < 85.80799542793133 * GMP, INFERRED, if (GMP < GMP, tanh(INFERRED), log(76.89956221113943))), tanh(tanh(INFERRED))) * sqrt(sqrt(GMP + INFERRED)) * GMP ) + 13.5 * (1 - GMP) * pow(GMP * 0.1, 2 + 1.1 * 0)"); - RankingExpression simplified = new Simplifier().transform(initial); + TransformContext c = new TransformContext(Collections.emptyMap()); + RankingExpression simplified = new Simplifier().transform(initial, c); Context context = new MapContext(); context.put("INFERRED", 0.5); @@ -65,7 +70,8 @@ public class SimplifierTestCase { @Test public void testParenthesisPreservation() throws ParseException { Simplifier s = new Simplifier(); - CompositeNode transformed = (CompositeNode)s.transform(new RankingExpression("a + (b + c) / 100000000.0")).getRoot(); + TransformContext c = new TransformContext(Collections.emptyMap()); + CompositeNode transformed = (CompositeNode)s.transform(new RankingExpression("a + (b + c) / 100000000.0"), c).getRoot(); assertEquals("a + (b + c) / 100000000.0", transformed.toString()); } |