summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java139
1 files changed, 88 insertions, 51 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index b37bbb543eb..25e03c75376 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.api.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -52,6 +51,83 @@ public class TensorValue extends Value {
}
@Override
+ public Value not() {
+ return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value or(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 ));
+ else
+ return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0));
+ }
+
+ @Override
+ public Value and(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 ));
+ else
+ return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0));
+ }
+
+ @Override
+ public Value largerOrEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.largerOrEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value >= argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value larger(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.larger(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value > argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value smallerOrEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value <= argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value smaller(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.smaller(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value < argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value approxEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.approxEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> DoubleCompatibleValue.approxEqual(value, argument.asDouble()) ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value notEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.notEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value != argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value equal(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.equal(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value == argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
public Value add(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.add(((TensorValue)argument).value));
@@ -92,27 +168,6 @@ public class TensorValue extends Value {
}
@Override
- public Value and(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 ));
- else
- return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0));
- }
-
- @Override
- public Value or(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 ));
- else
- return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0));
- }
-
- @Override
- public Value not() {
- return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0));
- }
-
- @Override
public Value power(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.pow(((TensorValue)argument).value));
@@ -123,24 +178,6 @@ public class TensorValue extends Value {
public Tensor asTensor() { return value; }
@Override
- public Value compare(TruthOperator operator, Value argument) {
- return new TensorValue(compareTensor(operator, argument.asTensor()));
- }
-
- private Tensor compareTensor(TruthOperator operator, Tensor argument) {
- switch (operator) {
- case LARGER: return value.larger(argument);
- case LARGEREQUAL: return value.largerOrEqual(argument);
- case SMALLER: return value.smaller(argument);
- case SMALLEREQUAL: return value.smallerOrEqual(argument);
- case EQUAL: return value.equal(argument);
- case NOTEQUAL: return value.notEqual(argument);
- case APPROX_EQUAL: return value.approxEqual(argument);
- default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator);
- }
- }
-
- @Override
public Value function(Function function, Value arg) {
if (arg instanceof TensorValue)
return new TensorValue(functionOnTensor(function, arg.asTensor()));
@@ -149,17 +186,17 @@ public class TensorValue extends Value {
}
private Tensor functionOnTensor(Function function, Tensor argument) {
- switch (function) {
- case min: return value.min(argument);
- case max: return value.max(argument);
- case atan2: return value.atan2(argument);
- case pow: return value.pow(argument);
- case fmod: return value.fmod(argument);
- case ldexp: return value.ldexp(argument);
- case bit: return value.bit(argument);
- case hamming: return value.hamming(argument);
- default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
- }
+ return switch (function) {
+ case min -> value.min(argument);
+ case max -> value.max(argument);
+ case atan2 -> value.atan2(argument);
+ case pow -> value.pow(argument);
+ case fmod -> value.fmod(argument);
+ case ldexp -> value.ldexp(argument);
+ case bit -> value.bit(argument);
+ case hamming -> value.hamming(argument);
+ default -> throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
+ };
}
@Override