diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-24 15:40:06 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-24 15:40:06 +0100 |
commit | 0114377d53c1e70bba679582abb88cc7af038bc1 (patch) | |
tree | ef7b7c7dc66b3810d8d6b1a0641bdd98fa238c1e | |
parent | bd6d16ba66a7b6745fc15a8b25dc7120fb5580ab (diff) |
Comparison functions on tensors
11 files changed, 59 insertions, 62 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index 2bae382d5bd..f8dcd8a6127 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -39,8 +39,8 @@ public abstract class DoubleCompatibleValue extends Value { } @Override - public boolean compare(TruthOperator operator, Value value) { - return operator.evaluate(asDouble(), value.asDouble()); + public Value compare(TruthOperator operator, Value value) { + return new BooleanValue(operator.evaluate(asDouble(), value.asDouble())); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java index 1cd65c3133a..35c210da0c0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java @@ -98,16 +98,6 @@ public final class DoubleValue extends DoubleCompatibleValue { } @Override - public boolean compare(TruthOperator operator, Value value) { - try { - return operator.evaluate(this.value, value.asDouble()); - } - catch (UnsupportedOperationException e) { - throw unsupported("comparison",value); - } - } - - @Override public Value function(Function function, Value value) { // use the tensor implementation of max and min if the argument is a tensor if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index b2dc5e27b91..9d0f0b692c4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -68,10 +68,10 @@ public class StringValue extends Value { } @Override - public boolean compare(TruthOperator operator, Value value) { + public Value compare(TruthOperator operator, Value value) { if (operator.equals(TruthOperator.EQUAL)) - return this.equals(value); - throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '='"); + return new BooleanValue(this.equals(value)); + throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='"); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index dc422f2c8da..b1f4a7b20ca 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -89,30 +89,6 @@ public class TensorValue extends Value { return new TensorValue(value.map((value) -> value / argument.asDouble())); } - public Value min(Value argument) { - return new TensorValue(value.min(asTensor(argument, "min"))); - } - - public Value max(Value argument) { - return new TensorValue(value.max(asTensor(argument, "max"))); - } - - public Value atan2(Value argument) { - return new TensorValue(value.atan2(asTensor(argument, "atan2"))); - } - - public Value equal(Value argument) { - return new TensorValue(value.equal(asTensor(argument, "equal"))); - } - - public Value sum(String dimension) { - return new TensorValue(value.sum(Collections.singletonList(dimension))); - } - - public Value sum() { - return new TensorValue(value.sum(Collections.emptyList())); - } - private Tensor asTensor(Value value, String operationName) { if ( ! (value instanceof TensorValue)) throw new UnsupportedOperationException("Could not perform " + operationName + @@ -127,22 +103,37 @@ public class TensorValue extends Value { } @Override - public boolean compare(TruthOperator operator, Value value) { - throw new UnsupportedOperationException("A tensor cannot be compared with any value"); + public Value compare(TruthOperator operator, Value argument) { + return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString()))); + } + + private Tensor compareTensor(TruthOperator operator, Tensor argument) { + switch (operator) { + case LARGER: return value.larger(argument); + case LARGEREQUAL: return value.largerOrEqual(argument); + case SMALLER: return value.smaller(argument); + case SMALLEREQUAL: return value.smallerOrEqual(argument); + case EQUAL: return value.equal(argument); + case NOTEQUAL: return value.notEqual(argument); + default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator); + } } @Override - public Value function(Function function, Value argument) { - if (function.equals(Function.min) && argument instanceof TensorValue) - return min(argument); - else if (function.equals(Function.max) && argument instanceof TensorValue) - return max(argument); - else if (function.equals(Function.atan2) && argument instanceof TensorValue) - return atan2(argument); - else if (function.equals(Function.equal) && argument instanceof TensorValue) - return equal(argument); + public Value function(Function function, Value arg) { + if (arg instanceof TensorValue) + return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString()))); else - return new TensorValue(value.map((value) -> function.evaluate(value, argument.asDouble()))); + return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble()))); + } + + private Tensor functionOnTensor(Function function, Tensor argument) { + switch (function) { + case min: return value.min(argument); + case max: return value.max(argument); + case atan2: return value.atan2(argument); + default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function); + } } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 142cd650ee8..76daeccc5e0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -43,7 +43,7 @@ public abstract class Value { public abstract Value divide(Value value); /** Perform the comparison specified by the operator between this value and the given value */ - public abstract boolean compare(TruthOperator operator,Value value); + public abstract Value compare(TruthOperator operator, Value value); /** Perform the given binary function on this value and the given value */ public abstract Value function(Function function,Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index 882d16ebc1c..af05acb365a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -8,10 +8,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import java.util.*; /** - * A node which returns true or false depending on the outcome of a comparison. + * A node which returns the outcome of a comparison. * * @author bratseth - * @since 5.1.21 */ public class ComparisonNode extends BooleanNode { @@ -48,9 +47,9 @@ public class ComparisonNode extends BooleanNode { @Override public Value evaluate(Context context) { - Value leftValue=leftCondition.evaluate(context); - Value rightValue=rightCondition.evaluate(context); - return new BooleanValue(leftValue.compare(operator,rightValue)); + Value leftValue = leftCondition.evaluate(context); + Value rightValue = rightCondition.evaluate(context); + return leftValue.compare(operator,rightValue); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java index b8e48dc2f05..cbea2ad627e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java @@ -43,8 +43,7 @@ public enum Function implements Serializable { max(2) { public double evaluate(double x, double y) { return max(x,y); } }, min(2) { public double evaluate(double x, double y) { return min(x,y); } }, mod(2) { public double evaluate(double x, double y) { return x % y; } }, - pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }, - equal(2) { public double evaluate(double x, double y) { return x==y ? 1.0 : 0.0; } }; + pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }; private final int arity; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java index 60fe19f909f..932975f3b63 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java @@ -15,7 +15,8 @@ public enum TruthOperator implements Serializable { EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } }, APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } }, LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } }, - LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }; + LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }, + NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } }; private final String operatorString; diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index d5f4c67e62a..0fcfdb5d40c 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -73,6 +73,7 @@ TOKEN : <LE: "<="> | <LT: "<"> | <EQ: "=="> | + <NQ: "!="> | <AQ: "~="> | <GE: ">="> | <GT: ">"> | @@ -110,7 +111,6 @@ TOKEN : <TANH: "tanh"> | <ATAN2: "atan2"> | - <EQUAL: "equal"> | <FMOD: "fmod"> | <LDEXP: "ldexp"> | // MAX @@ -206,6 +206,7 @@ TruthOperator comparator() : { } ( <LE> { return TruthOperator.SMALLEREQUAL; } | <LT> { return TruthOperator.SMALLER; } | <EQ> { return TruthOperator.EQUAL; } | + <NQ> { return TruthOperator.NOTEQUAL; } | <AQ> { return TruthOperator.APPROX_EQUAL; } | <GE> { return TruthOperator.LARGEREQUAL; } | <GT> { return TruthOperator.LARGER; } ) @@ -552,7 +553,6 @@ Function unaryFunctionName() : { } Function binaryFunctionName() : { } { <ATAN2> { return Function.atan2; } | - <EQUAL> { return Function.equal; } | <FMOD> { return Function.fmod; } | <LDEXP> { return Function.ldexp; } | <MAX> { return Function.max; } | diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index dc0e2b8bd5e..f69ad7bb9ad 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -230,6 +230,18 @@ public class EvaluationTestCase extends junit.framework.TestCase { // not_equal (!=) // argmax // argmin + assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }", + "tensor0 > tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", + "tensor0 < tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }", + "tensor0 >= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", + "tensor0 <= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }", + "tensor0 == tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", + "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }"); // tensor rename assertEvaluates("{ {newX:1,y:2}:3 }", "rename(tensor0, x, newX)", "{ {x:1,y:2}:3.0 }"); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 9704103d81c..9863303caa2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -106,7 +106,12 @@ public interface Tensor { default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); } default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); } default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); } + default Tensor larger(Tensor argument) { return join(argument, (a, b) -> ( a > b ? 1.0 : 0.0)); } + default Tensor largerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a >= b ? 1.0 : 0.0)); } + default Tensor smaller(Tensor argument) { return join(argument, (a, b) -> ( a < b ? 1.0 : 0.0)); } + default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); } default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); } + default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); } default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); } default Tensor count(List<String> dimensions) { return reduce(Reduce.Aggregator.count, dimensions); } |