diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-24 15:40:06 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-24 15:40:06 +0100 |
commit | 0114377d53c1e70bba679582abb88cc7af038bc1 (patch) | |
tree | ef7b7c7dc66b3810d8d6b1a0641bdd98fa238c1e /searchlib/src/main/java/com | |
parent | bd6d16ba66a7b6745fc15a8b25dc7120fb5580ab (diff) |
Comparison functions on tensors
Diffstat (limited to 'searchlib/src/main/java/com')
8 files changed, 40 insertions, 60 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index 2bae382d5bd..f8dcd8a6127 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -39,8 +39,8 @@ public abstract class DoubleCompatibleValue extends Value { } @Override - public boolean compare(TruthOperator operator, Value value) { - return operator.evaluate(asDouble(), value.asDouble()); + public Value compare(TruthOperator operator, Value value) { + return new BooleanValue(operator.evaluate(asDouble(), value.asDouble())); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java index 1cd65c3133a..35c210da0c0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java @@ -98,16 +98,6 @@ public final class DoubleValue extends DoubleCompatibleValue { } @Override - public boolean compare(TruthOperator operator, Value value) { - try { - return operator.evaluate(this.value, value.asDouble()); - } - catch (UnsupportedOperationException e) { - throw unsupported("comparison",value); - } - } - - @Override public Value function(Function function, Value value) { // use the tensor implementation of max and min if the argument is a tensor if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index b2dc5e27b91..9d0f0b692c4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -68,10 +68,10 @@ public class StringValue extends Value { } @Override - public boolean compare(TruthOperator operator, Value value) { + public Value compare(TruthOperator operator, Value value) { if (operator.equals(TruthOperator.EQUAL)) - return this.equals(value); - throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '='"); + return new BooleanValue(this.equals(value)); + throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='"); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index dc422f2c8da..b1f4a7b20ca 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -89,30 +89,6 @@ public class TensorValue extends Value { return new TensorValue(value.map((value) -> value / argument.asDouble())); } - public Value min(Value argument) { - return new TensorValue(value.min(asTensor(argument, "min"))); - } - - public Value max(Value argument) { - return new TensorValue(value.max(asTensor(argument, "max"))); - } - - public Value atan2(Value argument) { - return new TensorValue(value.atan2(asTensor(argument, "atan2"))); - } - - public Value equal(Value argument) { - return new TensorValue(value.equal(asTensor(argument, "equal"))); - } - - public Value sum(String dimension) { - return new TensorValue(value.sum(Collections.singletonList(dimension))); - } - - public Value sum() { - return new TensorValue(value.sum(Collections.emptyList())); - } - private Tensor asTensor(Value value, String operationName) { if ( ! (value instanceof TensorValue)) throw new UnsupportedOperationException("Could not perform " + operationName + @@ -127,22 +103,37 @@ public class TensorValue extends Value { } @Override - public boolean compare(TruthOperator operator, Value value) { - throw new UnsupportedOperationException("A tensor cannot be compared with any value"); + public Value compare(TruthOperator operator, Value argument) { + return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString()))); + } + + private Tensor compareTensor(TruthOperator operator, Tensor argument) { + switch (operator) { + case LARGER: return value.larger(argument); + case LARGEREQUAL: return value.largerOrEqual(argument); + case SMALLER: return value.smaller(argument); + case SMALLEREQUAL: return value.smallerOrEqual(argument); + case EQUAL: return value.equal(argument); + case NOTEQUAL: return value.notEqual(argument); + default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator); + } } @Override - public Value function(Function function, Value argument) { - if (function.equals(Function.min) && argument instanceof TensorValue) - return min(argument); - else if (function.equals(Function.max) && argument instanceof TensorValue) - return max(argument); - else if (function.equals(Function.atan2) && argument instanceof TensorValue) - return atan2(argument); - else if (function.equals(Function.equal) && argument instanceof TensorValue) - return equal(argument); + public Value function(Function function, Value arg) { + if (arg instanceof TensorValue) + return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString()))); else - return new TensorValue(value.map((value) -> function.evaluate(value, argument.asDouble()))); + return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble()))); + } + + private Tensor functionOnTensor(Function function, Tensor argument) { + switch (function) { + case min: return value.min(argument); + case max: return value.max(argument); + case atan2: return value.atan2(argument); + default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function); + } } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 142cd650ee8..76daeccc5e0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -43,7 +43,7 @@ public abstract class Value { public abstract Value divide(Value value); /** Perform the comparison specified by the operator between this value and the given value */ - public abstract boolean compare(TruthOperator operator,Value value); + public abstract Value compare(TruthOperator operator, Value value); /** Perform the given binary function on this value and the given value */ public abstract Value function(Function function,Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index 882d16ebc1c..af05acb365a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -8,10 +8,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import java.util.*; /** - * A node which returns true or false depending on the outcome of a comparison. + * A node which returns the outcome of a comparison. * * @author bratseth - * @since 5.1.21 */ public class ComparisonNode extends BooleanNode { @@ -48,9 +47,9 @@ public class ComparisonNode extends BooleanNode { @Override public Value evaluate(Context context) { - Value leftValue=leftCondition.evaluate(context); - Value rightValue=rightCondition.evaluate(context); - return new BooleanValue(leftValue.compare(operator,rightValue)); + Value leftValue = leftCondition.evaluate(context); + Value rightValue = rightCondition.evaluate(context); + return leftValue.compare(operator,rightValue); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java index b8e48dc2f05..cbea2ad627e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java @@ -43,8 +43,7 @@ public enum Function implements Serializable { max(2) { public double evaluate(double x, double y) { return max(x,y); } }, min(2) { public double evaluate(double x, double y) { return min(x,y); } }, mod(2) { public double evaluate(double x, double y) { return x % y; } }, - pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }, - equal(2) { public double evaluate(double x, double y) { return x==y ? 1.0 : 0.0; } }; + pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }; private final int arity; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java index 60fe19f909f..932975f3b63 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java @@ -15,7 +15,8 @@ public enum TruthOperator implements Serializable { EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } }, APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } }, LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } }, - LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }; + LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }, + NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } }; private final String operatorString; |