summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-24 15:40:06 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-24 15:40:06 +0100
commit0114377d53c1e70bba679582abb88cc7af038bc1 (patch)
treeef7b7c7dc66b3810d8d6b1a0641bdd98fa238c1e /searchlib/src/main/java/com
parentbd6d16ba66a7b6745fc15a8b25dc7120fb5580ab (diff)
Comparison functions on tensors
Diffstat (limited to 'searchlib/src/main/java/com')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java63
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java3
8 files changed, 40 insertions, 60 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index 2bae382d5bd..f8dcd8a6127 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -39,8 +39,8 @@ public abstract class DoubleCompatibleValue extends Value {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
- return operator.evaluate(asDouble(), value.asDouble());
+ public Value compare(TruthOperator operator, Value value) {
+ return new BooleanValue(operator.evaluate(asDouble(), value.asDouble()));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
index 1cd65c3133a..35c210da0c0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
@@ -98,16 +98,6 @@ public final class DoubleValue extends DoubleCompatibleValue {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
- try {
- return operator.evaluate(this.value, value.asDouble());
- }
- catch (UnsupportedOperationException e) {
- throw unsupported("comparison",value);
- }
- }
-
- @Override
public Value function(Function function, Value value) {
// use the tensor implementation of max and min if the argument is a tensor
if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index b2dc5e27b91..9d0f0b692c4 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -68,10 +68,10 @@ public class StringValue extends Value {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
+ public Value compare(TruthOperator operator, Value value) {
if (operator.equals(TruthOperator.EQUAL))
- return this.equals(value);
- throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '='");
+ return new BooleanValue(this.equals(value));
+ throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='");
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index dc422f2c8da..b1f4a7b20ca 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -89,30 +89,6 @@ public class TensorValue extends Value {
return new TensorValue(value.map((value) -> value / argument.asDouble()));
}
- public Value min(Value argument) {
- return new TensorValue(value.min(asTensor(argument, "min")));
- }
-
- public Value max(Value argument) {
- return new TensorValue(value.max(asTensor(argument, "max")));
- }
-
- public Value atan2(Value argument) {
- return new TensorValue(value.atan2(asTensor(argument, "atan2")));
- }
-
- public Value equal(Value argument) {
- return new TensorValue(value.equal(asTensor(argument, "equal")));
- }
-
- public Value sum(String dimension) {
- return new TensorValue(value.sum(Collections.singletonList(dimension)));
- }
-
- public Value sum() {
- return new TensorValue(value.sum(Collections.emptyList()));
- }
-
private Tensor asTensor(Value value, String operationName) {
if ( ! (value instanceof TensorValue))
throw new UnsupportedOperationException("Could not perform " + operationName +
@@ -127,22 +103,37 @@ public class TensorValue extends Value {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
- throw new UnsupportedOperationException("A tensor cannot be compared with any value");
+ public Value compare(TruthOperator operator, Value argument) {
+ return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString())));
+ }
+
+ private Tensor compareTensor(TruthOperator operator, Tensor argument) {
+ switch (operator) {
+ case LARGER: return value.larger(argument);
+ case LARGEREQUAL: return value.largerOrEqual(argument);
+ case SMALLER: return value.smaller(argument);
+ case SMALLEREQUAL: return value.smallerOrEqual(argument);
+ case EQUAL: return value.equal(argument);
+ case NOTEQUAL: return value.notEqual(argument);
+ default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator);
+ }
}
@Override
- public Value function(Function function, Value argument) {
- if (function.equals(Function.min) && argument instanceof TensorValue)
- return min(argument);
- else if (function.equals(Function.max) && argument instanceof TensorValue)
- return max(argument);
- else if (function.equals(Function.atan2) && argument instanceof TensorValue)
- return atan2(argument);
- else if (function.equals(Function.equal) && argument instanceof TensorValue)
- return equal(argument);
+ public Value function(Function function, Value arg) {
+ if (arg instanceof TensorValue)
+ return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString())));
else
- return new TensorValue(value.map((value) -> function.evaluate(value, argument.asDouble())));
+ return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble())));
+ }
+
+ private Tensor functionOnTensor(Function function, Tensor argument) {
+ switch (function) {
+ case min: return value.min(argument);
+ case max: return value.max(argument);
+ case atan2: return value.atan2(argument);
+ default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
+ }
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 142cd650ee8..76daeccc5e0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -43,7 +43,7 @@ public abstract class Value {
public abstract Value divide(Value value);
/** Perform the comparison specified by the operator between this value and the given value */
- public abstract boolean compare(TruthOperator operator,Value value);
+ public abstract Value compare(TruthOperator operator, Value value);
/** Perform the given binary function on this value and the given value */
public abstract Value function(Function function,Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
index 882d16ebc1c..af05acb365a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
@@ -8,10 +8,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import java.util.*;
/**
- * A node which returns true or false depending on the outcome of a comparison.
+ * A node which returns the outcome of a comparison.
*
* @author bratseth
- * @since 5.1.21
*/
public class ComparisonNode extends BooleanNode {
@@ -48,9 +47,9 @@ public class ComparisonNode extends BooleanNode {
@Override
public Value evaluate(Context context) {
- Value leftValue=leftCondition.evaluate(context);
- Value rightValue=rightCondition.evaluate(context);
- return new BooleanValue(leftValue.compare(operator,rightValue));
+ Value leftValue = leftCondition.evaluate(context);
+ Value rightValue = rightCondition.evaluate(context);
+ return leftValue.compare(operator,rightValue);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
index b8e48dc2f05..cbea2ad627e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
@@ -43,8 +43,7 @@ public enum Function implements Serializable {
max(2) { public double evaluate(double x, double y) { return max(x,y); } },
min(2) { public double evaluate(double x, double y) { return min(x,y); } },
mod(2) { public double evaluate(double x, double y) { return x % y; } },
- pow(2) { public double evaluate(double x, double y) { return pow(x,y); } },
- equal(2) { public double evaluate(double x, double y) { return x==y ? 1.0 : 0.0; } };
+ pow(2) { public double evaluate(double x, double y) { return pow(x,y); } };
private final int arity;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
index 60fe19f909f..932975f3b63 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
@@ -15,7 +15,8 @@ public enum TruthOperator implements Serializable {
EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } },
APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } },
LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } },
- LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } };
+ LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } },
+ NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } };
private final String operatorString;