aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-09-28 19:00:08 +0200
committerJon Bratseth <bratseth@gmail.com>2022-09-28 19:00:08 +0200
commita1912b44d0b800f96b334a24ddefd0026f3af356 (patch)
tree352e2b4d026ae9373d73dc4fd7e9892c81943f7f /searchlib
parentbcbb2009c44380055b2670e7cdefcad232f9ece4 (diff)
Use tensor vocabulary
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java20
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java16
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java4
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java4
12 files changed, 62 insertions, 67 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index e1db6378fcf..186208e036f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -42,27 +42,27 @@ public abstract class DoubleCompatibleValue extends Value {
}
@Override
- public Value greaterEqual(Value value) {
+ public Value largerOrEqual(Value value) {
return new BooleanValue(this.asDouble() >= value.asDouble());
}
@Override
- public Value greater(Value value) {
+ public Value larger(Value value) {
return new BooleanValue(this.asDouble() > value.asDouble());
}
@Override
- public Value lessEqual(Value value) {
+ public Value smallerOrEqual(Value value) {
return new BooleanValue(this.asDouble() <= value.asDouble());
}
@Override
- public Value less(Value value) {
+ public Value smaller(Value value) {
return new BooleanValue(this.asDouble() < value.asDouble());
}
@Override
- public Value approx(Value value) {
+ public Value approxEqual(Value value) {
return new BooleanValue(approxEqual(this.asDouble(), value.asDouble()));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index 3c09c644147..a585c989954 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -71,27 +71,27 @@ public class StringValue extends Value {
}
@Override
- public Value greaterEqual(Value argument) {
+ public Value largerOrEqual(Value argument) {
throw new UnsupportedOperationException("String values ('" + value + "') do not support greaterEqual");
}
@Override
- public Value greater(Value argument) {
+ public Value larger(Value argument) {
throw new UnsupportedOperationException("String values ('" + value + "') do not support greater");
}
@Override
- public Value lessEqual(Value argument) {
+ public Value smallerOrEqual(Value argument) {
throw new UnsupportedOperationException("String values ('" + value + "') do not support lessEqual");
}
@Override
- public Value less(Value argument) {
+ public Value smaller(Value argument) {
throw new UnsupportedOperationException("String values ('" + value + "') do not support less");
}
@Override
- public Value approx(Value argument) {
+ public Value approxEqual(Value argument) {
return new BooleanValue(this.asDouble() == argument.asDouble());
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 73ea0b23986..25e03c75376 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -72,7 +72,7 @@ public class TensorValue extends Value {
}
@Override
- public Value greaterEqual(Value argument) {
+ public Value largerOrEqual(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.largerOrEqual(((TensorValue)argument).value));
else
@@ -80,7 +80,7 @@ public class TensorValue extends Value {
}
@Override
- public Value greater(Value argument) {
+ public Value larger(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.larger(((TensorValue)argument).value));
else
@@ -88,7 +88,7 @@ public class TensorValue extends Value {
}
@Override
- public Value lessEqual(Value argument) {
+ public Value smallerOrEqual(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value));
else
@@ -96,7 +96,7 @@ public class TensorValue extends Value {
}
@Override
- public Value less(Value argument) {
+ public Value smaller(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.smaller(((TensorValue)argument).value));
else
@@ -104,7 +104,7 @@ public class TensorValue extends Value {
}
@Override
- public Value approx(Value argument) {
+ public Value approxEqual(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.approxEqual(((TensorValue)argument).value));
else
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 99663fe8d0d..5de2138147e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -53,11 +53,11 @@ public abstract class Value {
public abstract Value or(Value value);
public abstract Value and(Value value);
- public abstract Value greaterEqual(Value value);
- public abstract Value greater(Value value);
- public abstract Value lessEqual(Value value);
- public abstract Value less(Value value);
- public abstract Value approx(Value value);
+ public abstract Value largerOrEqual(Value value);
+ public abstract Value larger(Value value);
+ public abstract Value smallerOrEqual(Value value);
+ public abstract Value smaller(Value value);
+ public abstract Value approxEqual(Value value);
public abstract Value notEqual(Value value);
public abstract Value equal(Value value);
public abstract Value add(Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
index 78acd2e5af1..a3fc6aae9ac 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
@@ -89,7 +89,7 @@ public class GBDTForestOptimizer extends Optimizer {
}
OperationNode aNode = (OperationNode)node;
for (Operator op : aNode.operators()) {
- if (op != Operator.PLUS) {
+ if (op != Operator.plus) {
return false;
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
index b3cbe252dfc..7ba671e62eb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
@@ -116,9 +116,9 @@ public class GBDTOptimizer extends Optimizer {
private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) {
if (isBinaryComparison(condition)) {
OperationNode comparison = (OperationNode)condition;
- if (comparison.operators().get(0) == Operator.LESS)
+ if (comparison.operators().get(0) == Operator.smaller)
values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context));
- else if (comparison.operators().get(0) == Operator.EQUAL)
+ else if (comparison.operators().get(0) == Operator.equal)
values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context));
else
throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0));
@@ -135,7 +135,7 @@ public class GBDTOptimizer extends Optimizer {
if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) {
if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) {
OperationNode comparison = (OperationNode)embracedNode.children().get(0);
- if (comparison.operators().get(0) == Operator.GREATEREQUAL)
+ if (comparison.operators().get(0) == Operator.largerOrEqual)
values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context));
else
throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0));
@@ -153,13 +153,13 @@ public class GBDTOptimizer extends Optimizer {
private boolean isBinaryComparison(ExpressionNode condition) {
if ( ! (condition instanceof OperationNode binaryNode)) return false;
if (binaryNode.operators().size() != 1) return false;
- if (binaryNode.operators().get(0) == Operator.GREATEREQUAL) return true;
- if (binaryNode.operators().get(0) == Operator.GREATER) return true;
- if (binaryNode.operators().get(0) == Operator.LESSEQUAL) return true;
- if (binaryNode.operators().get(0) == Operator.LESS) return true;
- if (binaryNode.operators().get(0) == Operator.APPROX) return true;
- if (binaryNode.operators().get(0) == Operator.NOTEQUAL) return true;
- if (binaryNode.operators().get(0) == Operator.EQUAL) return true;
+ if (binaryNode.operators().get(0) == Operator.largerOrEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.larger) return true;
+ if (binaryNode.operators().get(0) == Operator.smallerOrEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.smaller) return true;
+ if (binaryNode.operators().get(0) == Operator.approxEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.notEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.equal) return true;
return false;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index 3c7f48aa38c..97e9a74f9c8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -123,14 +123,14 @@ public class LambdaFunctionNode extends CompositeNode {
}
Operator operator = node.operators().get(0);
switch (operator) {
- case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
- case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
- case PLUS: return asFunctionExpression((left, right) -> left + right);
- case MINUS: return asFunctionExpression((left, right) -> left - right);
- case MULTIPLY: return asFunctionExpression((left, right) -> left * right);
- case DIVIDE: return asFunctionExpression((left, right) -> left / right);
- case MODULO: return asFunctionExpression((left, right) -> left % right);
- case POWER: return asFunctionExpression(Math::pow);
+ case or: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
+ case and: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
+ case plus: return asFunctionExpression((left, right) -> left + right);
+ case minus: return asFunctionExpression((left, right) -> left - right);
+ case multiply: return asFunctionExpression((left, right) -> left * right);
+ case divide: return asFunctionExpression((left, right) -> left / right);
+ case modulo: return asFunctionExpression((left, right) -> left % right);
+ case power: return asFunctionExpression(Math::pow);
}
return Optional.empty();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
index 392f42f6cbe..d08e2270935 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
@@ -127,14 +127,6 @@ public final class OperationNode extends CompositeNode {
@Override
public int hashCode() { return Objects.hash(children, operators); }
- @Override
- public boolean equals(Object o) {
- if ( ! (o instanceof OperationNode other)) return false;
- if ( ! this.children().equals(other.children())) return false;
- if ( ! this.operators().equals(other.operators())) return false;
- return true;
- }
-
public static OperationNode resolve(ExpressionNode left, Operator op, ExpressionNode right) {
if ( ! (left instanceof OperationNode leftArithmetic)) return new OperationNode(left, op, right);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java
index 4ddbfa4ea9f..63144f0ef4a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java
@@ -15,21 +15,21 @@ import java.util.function.BiFunction;
public enum Operator {
// In order from lowest to highest precedence
- OR("||", (x, y) -> x.or(y)),
- AND("&&", (x, y) -> x.and(y)),
- GREATEREQUAL(">=", (x, y) -> x.greaterEqual(y)),
- GREATER(">", (x, y) -> x.greater(y)),
- LESSEQUAL("<=", (x, y) -> x.lessEqual(y)),
- LESS("<", (x, y) -> x.less(y)),
- APPROX("~=", (x, y) -> x.approx(y)),
- NOTEQUAL("!=", (x, y) -> x.notEqual(y)),
- EQUAL("==", (x, y) -> x.equal(y)),
- PLUS("+", (x, y) -> x.add(y)),
- MINUS("-", (x, y) -> x.subtract(y)),
- MULTIPLY("*", (x, y) -> x.multiply(y)),
- DIVIDE("/", (x, y) -> x.divide(y)),
- MODULO("%", (x, y) -> x.modulo(y)),
- POWER("^", true, (x, y) -> x.power(y));
+ or("||", (x, y) -> x.or(y)),
+ and("&&", (x, y) -> x.and(y)),
+ largerOrEqual(">=", (x, y) -> x.largerOrEqual(y)),
+ larger(">", (x, y) -> x.larger(y)),
+ smallerOrEqual("<=", (x, y) -> x.smallerOrEqual(y)),
+ smaller("<", (x, y) -> x.smaller(y)),
+ approxEqual("~=", (x, y) -> x.approxEqual(y)),
+ notEqual("!=", (x, y) -> x.notEqual(y)),
+ equal("==", (x, y) -> x.equal(y)),
+ plus("+", (x, y) -> x.add(y)),
+ minus("-", (x, y) -> x.subtract(y)),
+ multiply("*", (x, y) -> x.multiply(y)),
+ divide("/", (x, y) -> x.divide(y)),
+ modulo("%", (x, y) -> x.modulo(y)),
+ power("^", true, (x, y) -> x.power(y));
/** A list of all the operators in this in order of increasing precedence */
public static final List<Operator> operatorsByPrecedence = Arrays.stream(Operator.values()).toList();
@@ -53,6 +53,9 @@ public enum Operator {
return operatorsByPrecedence.indexOf(this) > operatorsByPrecedence.indexOf(op);
}
+ /** Returns true if a sequence of these operations should be evaluated from right to left rather than left to right. */
+ public boolean bindsRight() { return bindsRight; }
+
public final Value evaluate(Value x, Value y) {
return function.apply(x, y);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index 04728966bc1..1c1f7509ce8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -116,7 +116,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
private boolean allMultiplicationOrDivision(OperationNode node) {
for (Operator o : node.operators())
- if (o != Operator.MULTIPLY && o != Operator.DIVIDE)
+ if (o != Operator.multiply && o != Operator.divide)
return false;
return true;
}
@@ -131,7 +131,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
private boolean hasDivisionByZero(OperationNode node) {
for (int i = 1; i < node.children().size(); i++) {
- if (node.operators().get(i - 1) == Operator.DIVIDE && isZero(node.children().get(i)))
+ if (node.operators().get(i - 1) == Operator.divide && isZero(node.children().get(i)))
return true;
}
return false;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index 83de8e04a7d..f18240c3222 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -66,7 +66,7 @@ public class RankingExpressionTestCase {
public void testProgrammaticBuilding() throws ParseException {
ReferenceNode input = new ReferenceNode("input");
ReferenceNode constant = new ReferenceNode("constant");
- OperationNode product = new OperationNode(input, Operator.MULTIPLY, constant);
+ OperationNode product = new OperationNode(input, Operator.multiply, constant);
Reduce<Reference> sum = new Reduce<>(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum);
RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 019f76521e9..dac7393a168 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -779,8 +779,8 @@ public class EvaluationTestCase {
@Test
public void testProgrammaticBuildingAndPrecedence() {
- RankingExpression standardPrecedence = new RankingExpression(new OperationNode(constant(2), Operator.PLUS, new OperationNode(constant(3), Operator.MULTIPLY, constant(4))));
- RankingExpression oppositePrecedence = new RankingExpression(new OperationNode(new OperationNode(constant(2), Operator.PLUS, constant(3)), Operator.MULTIPLY, constant(4)));
+ RankingExpression standardPrecedence = new RankingExpression(new OperationNode(constant(2), Operator.plus, new OperationNode(constant(3), Operator.multiply, constant(4))));
+ RankingExpression oppositePrecedence = new RankingExpression(new OperationNode(new OperationNode(constant(2), Operator.plus, constant(3)), Operator.multiply, constant(4)));
assertEquals(14.0, standardPrecedence.evaluate(null).asDouble(), tolerance);
assertEquals(20.0, oppositePrecedence.evaluate(null).asDouble(), tolerance);
assertEquals("2.0 + 3.0 * 4.0", standardPrecedence.toString());