From 4c17e82dff2073938d9c44d035bfa659c555f72e Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Fri, 24 Feb 2023 19:42:35 +0000 Subject: handle other Value subtypes holding tensors --- .../rankingexpression/evaluation/TensorValue.java | 81 +++++----------------- .../rankingexpression/evaluation/Value.java | 4 +- 2 files changed, 20 insertions(+), 65 deletions(-) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 25e03c75376..90556bf4b00 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -57,132 +57,87 @@ public class TensorValue extends Value { @Override public Value or(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 )); - else - return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0)); + return new TensorValue(value.join(argument.asTensor(), (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 )); } @Override public Value and(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 )); - else - return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0)); + return new TensorValue(value.join(argument.asTensor(), (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 )); } @Override public Value largerOrEqual(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.largerOrEqual(((TensorValue)argument).value)); - else - return new TensorValue(value.map((value) -> value >= argument.asDouble() ? 1.0 : 0.0)); + return new TensorValue(value.largerOrEqual(argument.asTensor())); } @Override public Value larger(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.larger(((TensorValue)argument).value)); - else - return new TensorValue(value.map((value) -> value > argument.asDouble() ? 1.0 : 0.0)); + return new TensorValue(value.larger(argument.asTensor())); } @Override public Value smallerOrEqual(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value)); - else - return new TensorValue(value.map((value) -> value <= argument.asDouble() ? 1.0 : 0.0)); + return new TensorValue(value.smallerOrEqual(argument.asTensor())); } @Override public Value smaller(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.smaller(((TensorValue)argument).value)); - else - return new TensorValue(value.map((value) -> value < argument.asDouble() ? 1.0 : 0.0)); + return new TensorValue(value.smaller(argument.asTensor())); } @Override public Value approxEqual(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.approxEqual(((TensorValue)argument).value)); - else - return new TensorValue(value.map((value) -> DoubleCompatibleValue.approxEqual(value, argument.asDouble()) ? 1.0 : 0.0)); + return new TensorValue(value.approxEqual(argument.asTensor())); } @Override public Value notEqual(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.notEqual(((TensorValue)argument).value)); - else - return new TensorValue(value.map((value) -> value != argument.asDouble() ? 1.0 : 0.0)); + return new TensorValue(value.notEqual(argument.asTensor())); } @Override public Value equal(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.equal(((TensorValue)argument).value)); - else - return new TensorValue(value.map((value) -> value == argument.asDouble() ? 1.0 : 0.0)); + return new TensorValue(value.equal(argument.asTensor())); } @Override public Value add(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.add(((TensorValue)argument).value)); - else - return new TensorValue(value.map((value) -> value + argument.asDouble())); + return new TensorValue(value.add(argument.asTensor())); } @Override public Value subtract(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.subtract(((TensorValue) argument).value)); - else - return new TensorValue(value.map((value) -> value - argument.asDouble())); + return new TensorValue(value.subtract(argument.asTensor())); } @Override public Value multiply(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.multiply(((TensorValue) argument).value)); - else - return new TensorValue(value.map((value) -> value * argument.asDouble())); + return new TensorValue(value.multiply(argument.asTensor())); } @Override public Value divide(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.divide(((TensorValue) argument).value)); - else - return new TensorValue(value.map((value) -> value / argument.asDouble())); + return new TensorValue(value.divide(argument.asTensor())); } @Override public Value modulo(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.fmod(((TensorValue) argument).value)); - else - return new TensorValue(value.map((value) -> value % argument.asDouble())); + return new TensorValue(value.fmod(argument.asTensor())); } @Override public Value power(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.pow(((TensorValue)argument).value)); - else - return new TensorValue(value.map((value) -> Math.pow(value, argument.asDouble()))); + return new TensorValue(value.pow(argument.asTensor())); } public Tensor asTensor() { return value; } @Override public Value function(Function function, Value arg) { - if (arg instanceof TensorValue) + if (function.arity() != 1) return new TensorValue(functionOnTensor(function, arg.asTensor())); else - return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble()))); + return new TensorValue(value.map((value) -> function.evaluate(value, 0.0))); } private Tensor functionOnTensor(Function function, Tensor argument) { @@ -195,7 +150,7 @@ public class TensorValue extends Value { case ldexp -> value.ldexp(argument); case bit -> value.bit(argument); case hamming -> value.hamming(argument); - default -> throw new UnsupportedOperationException("Cannot combine two tensors using " + function); + default -> value.join(argument, (a, b) -> function.evaluate(a, b)); }; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 5de2138147e..ed53b82f1d5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -114,8 +114,8 @@ public abstract class Value { return new TensorValue(Tensor.from(value)); else if ((value.indexOf('.') == -1) && (value.indexOf('e') == -1) && (value.indexOf('E') == -1)) return new LongValue(Long.parseLong(value)); - else - return new DoubleValue(Double.parseDouble(value)); + else + return new DoubleValue(Double.parseDouble(value)); } public static Value of(Tensor tensor) { -- cgit v1.2.3