diff options
author | Lester Solbakken <lesters@yahoo-inc.com> | 2017-11-15 19:43:59 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@yahoo-inc.com> | 2017-11-15 19:43:59 +0100 |
commit | ed9640e21c4b918b26db24a5b2fb3ee877bd0ce8 (patch) | |
tree | c8a45194593ed9d4704666e0f68217cd275270a4 /searchlib/src | |
parent | 498912776725422835259964684a1baf60800cdb (diff) |
Add Java ranking set membership for tensors
Diffstat (limited to 'searchlib/src')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java | 32 |
1 files changed, 28 insertions, 4 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java index f8e44f1087c..f6b1a1a8979 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java @@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.Tensor; -import java.util.*; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; +import java.util.function.Predicate; /** * A node which returns true or false depending on a set membership test @@ -55,11 +60,30 @@ public class SetMembershipNode extends BooleanNode { @Override public Value evaluate(Context context) { Value value = testValue.evaluate(context); + if (value instanceof TensorValue) { + return evaluateTensor(((TensorValue) value).asTensor(), context); + } + return evaluateValue(value, context); + } + + private Value evaluateValue(Value value, Context context) { + return new BooleanValue(testMembership(value::equals, context)); + } + + private Value evaluateTensor(Tensor tensor, Context context) { + return new TensorValue(tensor.map((value) -> contains(value, context) ? 1.0 : 0.0)); + } + + private boolean contains(double value, Context context) { + return testMembership((setValue) -> setValue.asDouble() == value, context); + } + + private boolean testMembership(Predicate<Value> test, Context context) { for (ExpressionNode setValue : setValues) { - if (setValue.evaluate(context).equals(value)) - return new BooleanValue(true); + if (test.test(setValue.evaluate(context))) + return true; } - return new BooleanValue(false); + return false; } @Override |