diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2022-09-28 22:54:13 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-28 22:54:13 +0200 |
commit | 12992ecdc0e77968eb5c5544f2ae7d855e443162 (patch) | |
tree | ac8cec3ae02f27ae638876940399f490b4ac4ab1 /searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java | |
parent | d50f7bd9c99ed9d8edeabb71825f3966f9cd6bd9 (diff) | |
parent | fb0074925e9e8358d38145dc5753de1c935f737d (diff) |
Merge pull request #24251 from vespa-engine/bratseth/operatorsv8.61.17
Bratseth/operators
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java | 139 |
1 files changed, 88 insertions, 51 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index b37bbb543eb..25e03c75376 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.api.annotations.Beta; import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -52,6 +51,83 @@ public class TensorValue extends Value { } @Override + public Value not() { + return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0)); + } + + @Override + public Value or(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 )); + else + return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0)); + } + + @Override + public Value and(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 )); + else + return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0)); + } + + @Override + public Value largerOrEqual(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.largerOrEqual(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value >= argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value larger(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.larger(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value > argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value smallerOrEqual(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value <= argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value smaller(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.smaller(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value < argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value approxEqual(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.approxEqual(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> DoubleCompatibleValue.approxEqual(value, argument.asDouble()) ? 1.0 : 0.0)); + } + + @Override + public Value notEqual(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.notEqual(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value != argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value equal(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.equal(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value == argument.asDouble() ? 1.0 : 0.0)); + } + + @Override public Value add(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.add(((TensorValue)argument).value)); @@ -92,27 +168,6 @@ public class TensorValue extends Value { } @Override - public Value and(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 )); - else - return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0)); - } - - @Override - public Value or(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 )); - else - return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0)); - } - - @Override - public Value not() { - return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0)); - } - - @Override public Value power(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.pow(((TensorValue)argument).value)); @@ -123,24 +178,6 @@ public class TensorValue extends Value { public Tensor asTensor() { return value; } @Override - public Value compare(TruthOperator operator, Value argument) { - return new TensorValue(compareTensor(operator, argument.asTensor())); - } - - private Tensor compareTensor(TruthOperator operator, Tensor argument) { - switch (operator) { - case LARGER: return value.larger(argument); - case LARGEREQUAL: return value.largerOrEqual(argument); - case SMALLER: return value.smaller(argument); - case SMALLEREQUAL: return value.smallerOrEqual(argument); - case EQUAL: return value.equal(argument); - case NOTEQUAL: return value.notEqual(argument); - case APPROX_EQUAL: return value.approxEqual(argument); - default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator); - } - } - - @Override public Value function(Function function, Value arg) { if (arg instanceof TensorValue) return new TensorValue(functionOnTensor(function, arg.asTensor())); @@ -149,17 +186,17 @@ public class TensorValue extends Value { } private Tensor functionOnTensor(Function function, Tensor argument) { - switch (function) { - case min: return value.min(argument); - case max: return value.max(argument); - case atan2: return value.atan2(argument); - case pow: return value.pow(argument); - case fmod: return value.fmod(argument); - case ldexp: return value.ldexp(argument); - case bit: return value.bit(argument); - case hamming: return value.hamming(argument); - default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function); - } + return switch (function) { + case min -> value.min(argument); + case max -> value.max(argument); + case atan2 -> value.atan2(argument); + case pow -> value.pow(argument); + case fmod -> value.fmod(argument); + case ldexp -> value.ldexp(argument); + case bit -> value.bit(argument); + case hamming -> value.hamming(argument); + default -> throw new UnsupportedOperationException("Cannot combine two tensors using " + function); + }; } @Override |