aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-24 15:40:06 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-24 15:40:06 +0100
commit0114377d53c1e70bba679582abb88cc7af038bc1 (patch)
treeef7b7c7dc66b3810d8d6b1a0641bdd98fa238c1e
parentbd6d16ba66a7b6745fc15a8b25dc7120fb5580ab (diff)
Comparison functions on tensors
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java63
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java3
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java5
11 files changed, 59 insertions, 62 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index 2bae382d5bd..f8dcd8a6127 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -39,8 +39,8 @@ public abstract class DoubleCompatibleValue extends Value {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
- return operator.evaluate(asDouble(), value.asDouble());
+ public Value compare(TruthOperator operator, Value value) {
+ return new BooleanValue(operator.evaluate(asDouble(), value.asDouble()));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
index 1cd65c3133a..35c210da0c0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
@@ -98,16 +98,6 @@ public final class DoubleValue extends DoubleCompatibleValue {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
- try {
- return operator.evaluate(this.value, value.asDouble());
- }
- catch (UnsupportedOperationException e) {
- throw unsupported("comparison",value);
- }
- }
-
- @Override
public Value function(Function function, Value value) {
// use the tensor implementation of max and min if the argument is a tensor
if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index b2dc5e27b91..9d0f0b692c4 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -68,10 +68,10 @@ public class StringValue extends Value {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
+ public Value compare(TruthOperator operator, Value value) {
if (operator.equals(TruthOperator.EQUAL))
- return this.equals(value);
- throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '='");
+ return new BooleanValue(this.equals(value));
+ throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='");
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index dc422f2c8da..b1f4a7b20ca 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -89,30 +89,6 @@ public class TensorValue extends Value {
return new TensorValue(value.map((value) -> value / argument.asDouble()));
}
- public Value min(Value argument) {
- return new TensorValue(value.min(asTensor(argument, "min")));
- }
-
- public Value max(Value argument) {
- return new TensorValue(value.max(asTensor(argument, "max")));
- }
-
- public Value atan2(Value argument) {
- return new TensorValue(value.atan2(asTensor(argument, "atan2")));
- }
-
- public Value equal(Value argument) {
- return new TensorValue(value.equal(asTensor(argument, "equal")));
- }
-
- public Value sum(String dimension) {
- return new TensorValue(value.sum(Collections.singletonList(dimension)));
- }
-
- public Value sum() {
- return new TensorValue(value.sum(Collections.emptyList()));
- }
-
private Tensor asTensor(Value value, String operationName) {
if ( ! (value instanceof TensorValue))
throw new UnsupportedOperationException("Could not perform " + operationName +
@@ -127,22 +103,37 @@ public class TensorValue extends Value {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
- throw new UnsupportedOperationException("A tensor cannot be compared with any value");
+ public Value compare(TruthOperator operator, Value argument) {
+ return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString())));
+ }
+
+ private Tensor compareTensor(TruthOperator operator, Tensor argument) {
+ switch (operator) {
+ case LARGER: return value.larger(argument);
+ case LARGEREQUAL: return value.largerOrEqual(argument);
+ case SMALLER: return value.smaller(argument);
+ case SMALLEREQUAL: return value.smallerOrEqual(argument);
+ case EQUAL: return value.equal(argument);
+ case NOTEQUAL: return value.notEqual(argument);
+ default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator);
+ }
}
@Override
- public Value function(Function function, Value argument) {
- if (function.equals(Function.min) && argument instanceof TensorValue)
- return min(argument);
- else if (function.equals(Function.max) && argument instanceof TensorValue)
- return max(argument);
- else if (function.equals(Function.atan2) && argument instanceof TensorValue)
- return atan2(argument);
- else if (function.equals(Function.equal) && argument instanceof TensorValue)
- return equal(argument);
+ public Value function(Function function, Value arg) {
+ if (arg instanceof TensorValue)
+ return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString())));
else
- return new TensorValue(value.map((value) -> function.evaluate(value, argument.asDouble())));
+ return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble())));
+ }
+
+ private Tensor functionOnTensor(Function function, Tensor argument) {
+ switch (function) {
+ case min: return value.min(argument);
+ case max: return value.max(argument);
+ case atan2: return value.atan2(argument);
+ default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
+ }
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 142cd650ee8..76daeccc5e0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -43,7 +43,7 @@ public abstract class Value {
public abstract Value divide(Value value);
/** Perform the comparison specified by the operator between this value and the given value */
- public abstract boolean compare(TruthOperator operator,Value value);
+ public abstract Value compare(TruthOperator operator, Value value);
/** Perform the given binary function on this value and the given value */
public abstract Value function(Function function,Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
index 882d16ebc1c..af05acb365a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
@@ -8,10 +8,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import java.util.*;
/**
- * A node which returns true or false depending on the outcome of a comparison.
+ * A node which returns the outcome of a comparison.
*
* @author bratseth
- * @since 5.1.21
*/
public class ComparisonNode extends BooleanNode {
@@ -48,9 +47,9 @@ public class ComparisonNode extends BooleanNode {
@Override
public Value evaluate(Context context) {
- Value leftValue=leftCondition.evaluate(context);
- Value rightValue=rightCondition.evaluate(context);
- return new BooleanValue(leftValue.compare(operator,rightValue));
+ Value leftValue = leftCondition.evaluate(context);
+ Value rightValue = rightCondition.evaluate(context);
+ return leftValue.compare(operator,rightValue);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
index b8e48dc2f05..cbea2ad627e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
@@ -43,8 +43,7 @@ public enum Function implements Serializable {
max(2) { public double evaluate(double x, double y) { return max(x,y); } },
min(2) { public double evaluate(double x, double y) { return min(x,y); } },
mod(2) { public double evaluate(double x, double y) { return x % y; } },
- pow(2) { public double evaluate(double x, double y) { return pow(x,y); } },
- equal(2) { public double evaluate(double x, double y) { return x==y ? 1.0 : 0.0; } };
+ pow(2) { public double evaluate(double x, double y) { return pow(x,y); } };
private final int arity;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
index 60fe19f909f..932975f3b63 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
@@ -15,7 +15,8 @@ public enum TruthOperator implements Serializable {
EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } },
APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } },
LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } },
- LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } };
+ LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } },
+ NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } };
private final String operatorString;
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index d5f4c67e62a..0fcfdb5d40c 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -73,6 +73,7 @@ TOKEN :
<LE: "<="> |
<LT: "<"> |
<EQ: "=="> |
+ <NQ: "!="> |
<AQ: "~="> |
<GE: ">="> |
<GT: ">"> |
@@ -110,7 +111,6 @@ TOKEN :
<TANH: "tanh"> |
<ATAN2: "atan2"> |
- <EQUAL: "equal"> |
<FMOD: "fmod"> |
<LDEXP: "ldexp"> |
// MAX
@@ -206,6 +206,7 @@ TruthOperator comparator() : { }
( <LE> { return TruthOperator.SMALLEREQUAL; } |
<LT> { return TruthOperator.SMALLER; } |
<EQ> { return TruthOperator.EQUAL; } |
+ <NQ> { return TruthOperator.NOTEQUAL; } |
<AQ> { return TruthOperator.APPROX_EQUAL; } |
<GE> { return TruthOperator.LARGEREQUAL; } |
<GT> { return TruthOperator.LARGER; } )
@@ -552,7 +553,6 @@ Function unaryFunctionName() : { }
Function binaryFunctionName() : { }
{
<ATAN2> { return Function.atan2; } |
- <EQUAL> { return Function.equal; } |
<FMOD> { return Function.fmod; } |
<LDEXP> { return Function.ldexp; } |
<MAX> { return Function.max; } |
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index dc0e2b8bd5e..f69ad7bb9ad 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -230,6 +230,18 @@ public class EvaluationTestCase extends junit.framework.TestCase {
// not_equal (!=)
// argmax
// argmin
+ assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }",
+ "tensor0 > tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
+ "tensor0 < tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }",
+ "tensor0 >= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
+ "tensor0 <= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }",
+ "tensor0 == tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
+ "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }");
// tensor rename
assertEvaluates("{ {newX:1,y:2}:3 }", "rename(tensor0, x, newX)", "{ {x:1,y:2}:3.0 }");
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 9704103d81c..9863303caa2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -106,7 +106,12 @@ public interface Tensor {
default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); }
default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); }
default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); }
+ default Tensor larger(Tensor argument) { return join(argument, (a, b) -> ( a > b ? 1.0 : 0.0)); }
+ default Tensor largerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a >= b ? 1.0 : 0.0)); }
+ default Tensor smaller(Tensor argument) { return join(argument, (a, b) -> ( a < b ? 1.0 : 0.0)); }
+ default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); }
default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); }
+ default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); }
default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); }
default Tensor count(List<String> dimensions) { return reduce(Reduce.Aggregator.count, dimensions); }