summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2023-02-25 09:11:42 +0100
committerGitHub <noreply@github.com>2023-02-25 09:11:42 +0100
commitf434753332b67499e4f9e71fd04c0e524135b6f1 (patch)
treeaab4172796d6dde0022d46e43d5e1e499fafba58
parenta8659be0f80b80e875222baa259f938bde151023 (diff)
parent4c17e82dff2073938d9c44d035bfa659c555f72e (diff)
Merge pull request #26183 from vespa-engine/arnej/handle-lazy-value
handle other Value subtypes holding tensors
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java81
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java4
2 files changed, 20 insertions, 65 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 25e03c75376..90556bf4b00 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -57,132 +57,87 @@ public class TensorValue extends Value {
@Override
public Value or(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 ));
- else
- return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0));
+ return new TensorValue(value.join(argument.asTensor(), (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 ));
}
@Override
public Value and(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 ));
- else
- return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0));
+ return new TensorValue(value.join(argument.asTensor(), (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 ));
}
@Override
public Value largerOrEqual(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.largerOrEqual(((TensorValue)argument).value));
- else
- return new TensorValue(value.map((value) -> value >= argument.asDouble() ? 1.0 : 0.0));
+ return new TensorValue(value.largerOrEqual(argument.asTensor()));
}
@Override
public Value larger(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.larger(((TensorValue)argument).value));
- else
- return new TensorValue(value.map((value) -> value > argument.asDouble() ? 1.0 : 0.0));
+ return new TensorValue(value.larger(argument.asTensor()));
}
@Override
public Value smallerOrEqual(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value));
- else
- return new TensorValue(value.map((value) -> value <= argument.asDouble() ? 1.0 : 0.0));
+ return new TensorValue(value.smallerOrEqual(argument.asTensor()));
}
@Override
public Value smaller(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.smaller(((TensorValue)argument).value));
- else
- return new TensorValue(value.map((value) -> value < argument.asDouble() ? 1.0 : 0.0));
+ return new TensorValue(value.smaller(argument.asTensor()));
}
@Override
public Value approxEqual(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.approxEqual(((TensorValue)argument).value));
- else
- return new TensorValue(value.map((value) -> DoubleCompatibleValue.approxEqual(value, argument.asDouble()) ? 1.0 : 0.0));
+ return new TensorValue(value.approxEqual(argument.asTensor()));
}
@Override
public Value notEqual(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.notEqual(((TensorValue)argument).value));
- else
- return new TensorValue(value.map((value) -> value != argument.asDouble() ? 1.0 : 0.0));
+ return new TensorValue(value.notEqual(argument.asTensor()));
}
@Override
public Value equal(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.equal(((TensorValue)argument).value));
- else
- return new TensorValue(value.map((value) -> value == argument.asDouble() ? 1.0 : 0.0));
+ return new TensorValue(value.equal(argument.asTensor()));
}
@Override
public Value add(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.add(((TensorValue)argument).value));
- else
- return new TensorValue(value.map((value) -> value + argument.asDouble()));
+ return new TensorValue(value.add(argument.asTensor()));
}
@Override
public Value subtract(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.subtract(((TensorValue) argument).value));
- else
- return new TensorValue(value.map((value) -> value - argument.asDouble()));
+ return new TensorValue(value.subtract(argument.asTensor()));
}
@Override
public Value multiply(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.multiply(((TensorValue) argument).value));
- else
- return new TensorValue(value.map((value) -> value * argument.asDouble()));
+ return new TensorValue(value.multiply(argument.asTensor()));
}
@Override
public Value divide(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.divide(((TensorValue) argument).value));
- else
- return new TensorValue(value.map((value) -> value / argument.asDouble()));
+ return new TensorValue(value.divide(argument.asTensor()));
}
@Override
public Value modulo(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.fmod(((TensorValue) argument).value));
- else
- return new TensorValue(value.map((value) -> value % argument.asDouble()));
+ return new TensorValue(value.fmod(argument.asTensor()));
}
@Override
public Value power(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.pow(((TensorValue)argument).value));
- else
- return new TensorValue(value.map((value) -> Math.pow(value, argument.asDouble())));
+ return new TensorValue(value.pow(argument.asTensor()));
}
public Tensor asTensor() { return value; }
@Override
public Value function(Function function, Value arg) {
- if (arg instanceof TensorValue)
+ if (function.arity() != 1)
return new TensorValue(functionOnTensor(function, arg.asTensor()));
else
- return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble())));
+ return new TensorValue(value.map((value) -> function.evaluate(value, 0.0)));
}
private Tensor functionOnTensor(Function function, Tensor argument) {
@@ -195,7 +150,7 @@ public class TensorValue extends Value {
case ldexp -> value.ldexp(argument);
case bit -> value.bit(argument);
case hamming -> value.hamming(argument);
- default -> throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
+ default -> value.join(argument, (a, b) -> function.evaluate(a, b));
};
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 5de2138147e..ed53b82f1d5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -114,8 +114,8 @@ public abstract class Value {
return new TensorValue(Tensor.from(value));
else if ((value.indexOf('.') == -1) && (value.indexOf('e') == -1) && (value.indexOf('E') == -1))
return new LongValue(Long.parseLong(value));
- else
- return new DoubleValue(Double.parseDouble(value));
+ else
+ return new DoubleValue(Double.parseDouble(value));
}
public static Value of(Tensor tensor) {