summaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
authorLester Solbakken <lesters@yahoo-inc.com>2017-11-15 19:43:59 +0100
committerLester Solbakken <lesters@yahoo-inc.com>2017-11-15 19:43:59 +0100
commited9640e21c4b918b26db24a5b2fb3ee877bd0ce8 (patch)
treec8a45194593ed9d4704666e0f68217cd275270a4 /searchlib/src
parent498912776725422835259964684a1baf60800cdb (diff)
Add Java ranking set membership for tensors
Diffstat (limited to 'searchlib/src')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java32
1 files changed, 28 insertions, 4 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
index f8e44f1087c..f6b1a1a8979 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
@@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.Tensor;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.List;
+import java.util.function.Predicate;
/**
* A node which returns true or false depending on a set membership test
@@ -55,11 +60,30 @@ public class SetMembershipNode extends BooleanNode {
@Override
public Value evaluate(Context context) {
Value value = testValue.evaluate(context);
+ if (value instanceof TensorValue) {
+ return evaluateTensor(((TensorValue) value).asTensor(), context);
+ }
+ return evaluateValue(value, context);
+ }
+
+ private Value evaluateValue(Value value, Context context) {
+ return new BooleanValue(testMembership(value::equals, context));
+ }
+
+ private Value evaluateTensor(Tensor tensor, Context context) {
+ return new TensorValue(tensor.map((value) -> contains(value, context) ? 1.0 : 0.0));
+ }
+
+ private boolean contains(double value, Context context) {
+ return testMembership((setValue) -> setValue.asDouble() == value, context);
+ }
+
+ private boolean testMembership(Predicate<Value> test, Context context) {
for (ExpressionNode setValue : setValues) {
- if (setValue.evaluate(context).equals(value))
- return new BooleanValue(true);
+ if (test.test(setValue.evaluate(context)))
+ return true;
}
- return new BooleanValue(false);
+ return false;
}
@Override