summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-10 10:04:27 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-10 10:04:27 +0100
commitba6a11e6e2674a2b5c1ef967319fb269f989a216 (patch)
treeb15c1c046989cafeed19d193fdb59634140d3db6 /searchlib
parent3e1477f5fda4a3dcd436a6d41843adc66e19f370 (diff)
Use a context for transform state
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java28
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java22
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java11
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java46
6 files changed, 73 insertions, 51 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
index 46f79dcc6bd..1b8239ba643 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
@@ -10,7 +10,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import java.util.ArrayList;
import java.util.List;
-import java.util.Map;
/**
* Replaces "features" which found in the given constants by their constant value
@@ -19,40 +18,33 @@ import java.util.Map;
*/
public class ConstantDereferencer extends ExpressionTransformer {
- /** The map of constants to dereference */
- private final Map<String, Value> constants;
-
- public ConstantDereferencer(Map<String, Value> constants) {
- this.constants = constants;
- }
-
@Override
- public ExpressionNode transform(ExpressionNode node) {
+ public ExpressionNode transform(ExpressionNode node, TransformContext context) {
if (node instanceof ReferenceNode)
- return transformFeature((ReferenceNode) node);
+ return transformFeature((ReferenceNode) node, context);
else if (node instanceof CompositeNode)
- return transformChildren((CompositeNode)node);
+ return transformChildren((CompositeNode)node, context);
else
return node;
}
- private ExpressionNode transformFeature(ReferenceNode node) {
+ private ExpressionNode transformFeature(ReferenceNode node, TransformContext context) {
if (!node.getArguments().isEmpty())
- return transformArguments(node);
+ return transformArguments(node, context);
else
- return transformConstantReference(node);
+ return transformConstantReference(node, context);
}
- private ExpressionNode transformArguments(ReferenceNode node) {
+ private ExpressionNode transformArguments(ReferenceNode node, TransformContext context) {
List<ExpressionNode> arguments = node.getArguments().expressions();
List<ExpressionNode> transformedArguments = new ArrayList<>(arguments.size());
for (ExpressionNode argument : arguments)
- transformedArguments.add(transform(argument));
+ transformedArguments.add(transform(argument, context));
return node.setArguments(transformedArguments);
}
- private ExpressionNode transformConstantReference(ReferenceNode node) {
- Value value = constants.get(node.getName());
+ private ExpressionNode transformConstantReference(ReferenceNode node, TransformContext context) {
+ Value value = context.constants().get(node.getName());
if (value == null || (value instanceof TensorValue)) {
return node; // not a value constant reference
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
index bcc8b817641..c585c0dea1f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
@@ -15,22 +15,22 @@ import java.util.List;
*/
public abstract class ExpressionTransformer {
- public RankingExpression transform(RankingExpression expression) {
- return new RankingExpression(expression.getName(), transform(expression.getRoot()));
+ public RankingExpression transform(RankingExpression expression, TransformContext context) {
+ return new RankingExpression(expression.getName(), transform(expression.getRoot(), context));
}
/** Transforms an expression node and returns the transformed node */
- public abstract ExpressionNode transform(ExpressionNode node);
+ public abstract ExpressionNode transform(ExpressionNode node, TransformContext context);
/**
* Utility method which calls transform on each child of the given node and return the resulting transformed
* composite
*/
- protected CompositeNode transformChildren(CompositeNode node) {
+ protected CompositeNode transformChildren(CompositeNode node, TransformContext context) {
List<ExpressionNode> children = node.children();
List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
for (ExpressionNode child : children)
- transformedChildren.add(transform(child));
+ transformedChildren.add(transform(child, context));
return node.setChildren(transformedChildren);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index ebad0d5c21f..9e8491340b0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -1,7 +1,6 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.transform;
-import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
@@ -10,8 +9,8 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import java.util.ArrayList;
import java.util.List;
@@ -24,9 +23,9 @@ import java.util.List;
public class Simplifier extends ExpressionTransformer {
@Override
- public ExpressionNode transform(ExpressionNode node) {
+ public ExpressionNode transform(ExpressionNode node, TransformContext context) {
if (node instanceof CompositeNode)
- node = transformChildren((CompositeNode) node); // depth first
+ node = transformChildren((CompositeNode) node, context); // depth first
if (node instanceof IfNode)
node = transformIf((IfNode) node);
if (node instanceof EmbracedNode && hasSingleUndividableChild((EmbracedNode)node))
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java
new file mode 100644
index 00000000000..746ca3b3200
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java
@@ -0,0 +1,22 @@
+package com.yahoo.searchlib.rankingexpression.transform;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Map;
+
+/**
+ * Provides a context in which transforms on ranking expressions take place.
+ *
+ * @author bratseth
+ */
+public class TransformContext {
+
+ private final Map<String, Value> constants;
+
+ public TransformContext(Map<String, Value> constants) {
+ this.constants = constants;
+ }
+
+ public Map<String, Value> constants() { return constants; }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java
index 4035e499a6a..84e51835458 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java
@@ -5,11 +5,12 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import org.junit.Test;
-import static org.junit.Assert.*;
import java.util.HashMap;
import java.util.Map;
+import static org.junit.Assert.assertEquals;
+
/**
* @author bratseth
*/
@@ -17,14 +18,16 @@ public class ConstantDereferencerTestCase {
@Test
public void testConstantDereferencer() throws ParseException {
+ ConstantDereferencer c = new ConstantDereferencer();
+
Map<String, Value> constants = new HashMap<>();
constants.put("a", Value.parse("1.0"));
constants.put("b", Value.parse("2"));
constants.put("c", Value.parse("3.5"));
- ConstantDereferencer c = new ConstantDereferencer(constants);
+ TransformContext context = new TransformContext(constants);
- assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c")).toString());
- assertEquals("myMacro(1.0,2.0)", c.transform(new RankingExpression("myMacro(a, b)")).toString());
+ assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c"), context).toString());
+ assertEquals("myMacro(1.0,2.0)", c.transform(new RankingExpression("myMacro(a, b)"), context).toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
index f9d2472e306..8fac3395ac0 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
@@ -7,6 +7,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import org.junit.Test;
+
+import java.util.Collections;
+
import static org.junit.Assert.*;
/**
@@ -17,31 +20,33 @@ public class SimplifierTestCase {
@Test
public void testSimplify() throws ParseException {
Simplifier s = new Simplifier();
- assertEquals("a + b", s.transform(new RankingExpression("a + b")).toString());
- assertEquals("6.5", s.transform(new RankingExpression("1.0 + 2.0 + 3.5")).toString());
- assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )")).toString());
- assertEquals("6.5", s.transform(new RankingExpression("( 1.0 + 2.0 ) + 3.5 ")).toString());
- assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )")).toString());
- assertEquals("7.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + 1")).toString());
- assertEquals("6.5 + a", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + a")).toString());
- assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * 0.0")).toString());
- assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (0.0)")).toString());
- assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (1.0 - 1.0)")).toString());
- assertEquals("7.5", s.transform(new RankingExpression("if (2 > 0, 3.5 * 2 + 0.5, a *3 )")).toString());
- assertEquals("0.0", s.transform(new RankingExpression("0.0 * (1.3 + 7.0)")).toString());
- assertEquals("6.4", s.transform(new RankingExpression("max(0, 10.0-2.0)*(1-fabs(0.0-0.2))")).toString());
- assertEquals("(query(d) + query(b) - query(a)) * query(c) / query(e)", s.transform(new RankingExpression("(query(d) + query(b) - query(a)) * query(c) / query(e)")).toString());
- assertEquals("14.0", s.transform(new RankingExpression("5 + (2 + 3) + 4")).toString());
- assertEquals("28.0 + bar", s.transform(new RankingExpression("7.0 + 12.0 + 9.0 + bar")).toString());
- assertEquals("1.0 - 0.001 * attribute(number)", s.transform(new RankingExpression("1.0 - 0.001*attribute(number)")).toString());
- assertEquals("attribute(number) * 1.5 - 0.001 * attribute(number)", s.transform(new RankingExpression("attribute(number) * 1.5 - 0.001 * attribute(number)")).toString());
+ TransformContext c = new TransformContext(Collections.emptyMap());
+ assertEquals("a + b", s.transform(new RankingExpression("a + b"), c).toString());
+ assertEquals("6.5", s.transform(new RankingExpression("1.0 + 2.0 + 3.5"), c).toString());
+ assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )"), c).toString());
+ assertEquals("6.5", s.transform(new RankingExpression("( 1.0 + 2.0 ) + 3.5 "), c).toString());
+ assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )"), c).toString());
+ assertEquals("7.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + 1"), c).toString());
+ assertEquals("6.5 + a", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + a"), c).toString());
+ assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * 0.0"), c).toString());
+ assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (0.0)"), c).toString());
+ assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (1.0 - 1.0)"), c).toString());
+ assertEquals("7.5", s.transform(new RankingExpression("if (2 > 0, 3.5 * 2 + 0.5, a *3 )"), c).toString());
+ assertEquals("0.0", s.transform(new RankingExpression("0.0 * (1.3 + 7.0)"), c).toString());
+ assertEquals("6.4", s.transform(new RankingExpression("max(0, 10.0-2.0)*(1-fabs(0.0-0.2))"), c).toString());
+ assertEquals("(query(d) + query(b) - query(a)) * query(c) / query(e)", s.transform(new RankingExpression("(query(d) + query(b) - query(a)) * query(c) / query(e)"), c).toString());
+ assertEquals("14.0", s.transform(new RankingExpression("5 + (2 + 3) + 4"), c).toString());
+ assertEquals("28.0 + bar", s.transform(new RankingExpression("7.0 + 12.0 + 9.0 + bar"), c).toString());
+ assertEquals("1.0 - 0.001 * attribute(number)", s.transform(new RankingExpression("1.0 - 0.001*attribute(number)"), c).toString());
+ assertEquals("attribute(number) * 1.5 - 0.001 * attribute(number)", s.transform(new RankingExpression("attribute(number) * 1.5 - 0.001 * attribute(number)"), c).toString());
}
// A black box test verifying we are not screwing up real expressions
@Test
public void testSimplifyComplexExpression() throws ParseException {
RankingExpression initial = new RankingExpression("sqrt(if (if (INFERRED * 0.9 < INFERRED, GMP, (1 + 1.1) * INFERRED) < INFERRED * INFERRED - INFERRED, if (GMP < 85.80799542793133 * GMP, INFERRED, if (GMP < GMP, tanh(INFERRED), log(76.89956221113943))), tanh(tanh(INFERRED))) * sqrt(sqrt(GMP + INFERRED)) * GMP ) + 13.5 * (1 - GMP) * pow(GMP * 0.1, 2 + 1.1 * 0)");
- RankingExpression simplified = new Simplifier().transform(initial);
+ TransformContext c = new TransformContext(Collections.emptyMap());
+ RankingExpression simplified = new Simplifier().transform(initial, c);
Context context = new MapContext();
context.put("INFERRED", 0.5);
@@ -65,7 +70,8 @@ public class SimplifierTestCase {
@Test
public void testParenthesisPreservation() throws ParseException {
Simplifier s = new Simplifier();
- CompositeNode transformed = (CompositeNode)s.transform(new RankingExpression("a + (b + c) / 100000000.0")).getRoot();
+ TransformContext c = new TransformContext(Collections.emptyMap());
+ CompositeNode transformed = (CompositeNode)s.transform(new RankingExpression("a + (b + c) / 100000000.0"), c).getRoot();
assertEquals("a + (b + c) / 100000000.0", transformed.toString());
}