diff options
Diffstat (limited to 'searchlib')
19 files changed, 141 insertions, 37 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index ff8758bd1e7..a1e79df95e3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -30,8 +30,7 @@ public abstract class Context implements EvaluationContext { public TensorType getTensorType(String name) { ValueType type = getType(name); if (type == null) return null; - if (type.isTensor()) return type.tensorType().get(); - return TensorType.empty; // double as tensor + return type.tensorType(); } /** Returns a variable as a tensor */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java index 06301372dcc..046ad7861ef 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java @@ -2,8 +2,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo import com.yahoo.tensor.TensorType; -import java.util.Optional; - /** * The type of a ranking expression value - either a double or a tensor. * @@ -11,27 +9,24 @@ import java.util.Optional; */ public class ValueType { - private static final ValueType doubleValueType = new ValueType(Optional.empty()); + private static final ValueType doubleValueType = new ValueType(TensorType.empty); - private final Optional<TensorType> tensorType; + private final TensorType tensorType; - private ValueType(Optional<TensorType> type) { - this.tensorType = type; + private ValueType(TensorType tensorType) { + this.tensorType = tensorType; } - /** Returns true if this is a double type */ - public boolean isDouble() { return ! tensorType.isPresent(); } - - /** Returns true if this is a tensor type */ - public boolean isTensor() { return tensorType.isPresent(); } + /** Returns true if this is the double type */ + public boolean isDouble() { return tensorType.rank() == 0; } - /** The specific tensor type of this, or empty if this is not a tensor type */ - public Optional<TensorType> tensorType() { return tensorType; } + /** The type of this as a tensor type. The double type is the empty tensor type (rank 0) */ + public TensorType tensorType() { return tensorType; } /** Returns the type representing a double */ public static ValueType doubleType() { return doubleValueType; } /** Returns a type representing the given tensor type */ - public static ValueType tensorType(TensorType type) { return new ValueType(Optional.of(type)); } + public static ValueType of(TensorType type) { return new ValueType(type); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java index 372fb00431b..b4e126f69e0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java @@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; @@ -24,6 +25,9 @@ public class GBDTForestNode extends ExpressionNode { } @Override + public final ValueType type(Context context) { return ValueType.doubleType(); } + + @Override public final Value evaluate(Context context) { int pc = 0; double treeSum = 0; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java index 4d7b4835892..f085194a7df 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java @@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; @@ -49,6 +50,9 @@ public final class GBDTNode extends ExpressionNode { public final double[] values() { return values; } @Override + public final ValueType type(Context context) { return ValueType.doubleType(); } + + @Override public final Value evaluate(Context context) { return new DoubleValue(evaluate(values,0,context)); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java index 518a15bcc87..d45037b6044 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java @@ -4,8 +4,15 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Join; -import java.util.*; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; /** * A binary mathematical operation @@ -73,14 +80,26 @@ public final class ArithmeticNode extends CompositeNode { } @Override + public ValueType type(Context context) { + // Compute type using tensor types as arithmetic operators are supported on tensors + // and is correct also in the special case of doubles. + // As all our functions are type-commutative, we don't need to take operator precedence into account + TensorType type = children.get(0).type(context).tensorType(); + for (int i = 1; i < children.size(); i++) + type = Join.outputType(type, children.get(i).type(context).tensorType()); + return ValueType.of(type); + } + + @Override public Value evaluate(Context context) { Iterator<ExpressionNode> child = children.iterator(); + // Apply in precedence order: Deque<ValueItem> stack = new ArrayDeque<>(); stack.push(new ValueItem(ArithmeticOperator.OR, child.next().evaluate(context))); for (Iterator<ArithmeticOperator> it = operators.iterator(); it.hasNext() && child.hasNext();) { ArithmeticOperator op = it.next(); - if (!stack.isEmpty()) { + if ( ! stack.isEmpty()) { while (stack.peek().op.hasPrecedenceOver(op)) { popStack(stack); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index 9484f789169..fdbb22093ea 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -1,11 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; -import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; -import java.util.*; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; /** * A node which returns the outcome of a comparison. @@ -46,6 +48,11 @@ public class ComparisonNode extends BooleanNode { } @Override + public ValueType type(Context context) { + return ValueType.doubleType(); // by definition + } + + @Override public Value evaluate(Context context) { Value leftValue = leftCondition.evaluate(context); Value rightValue = rightCondition.evaluate(context); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java index cd473ae6a6f..e6074a5f745 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import java.util.Deque; @@ -47,6 +48,9 @@ public final class ConstantNode extends ExpressionNode { } @Override + public ValueType type(Context context) { return value.type(); } + + @Override public Value evaluate(Context context) { return value; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java index b5d7c41d698..8404226c33b 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import java.util.Collections; import java.util.Deque; @@ -48,6 +49,11 @@ public final class EmbracedNode extends CompositeNode { } @Override + public ValueType type(Context context) { + return value.type(context); + } + + @Override public CompositeNode setChildren(List<ExpressionNode> newChildren) { if (newChildren.size() != 1) throw new IllegalArgumentException("Expected 1 child but got " + newChildren.size()); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java index 31984dca54d..5d06a562b5d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java @@ -47,7 +47,7 @@ public abstract class ExpressionNode implements Serializable { * @param context the variable type bindings to use for this evaluation * @throws IllegalArgumentException if there are variables which are not bound in the given map */ - public ValueType type(Context context) { return ValueType.doubleType(); } // double is default + public abstract ValueType type(Context context); /** * Returns the value of evaluating this expression over the given context. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java index 142e282e5c6..b187b8f029c 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java @@ -4,6 +4,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.functions.Join; import java.util.ArrayList; import java.util.Collections; @@ -64,16 +66,29 @@ public final class FunctionNode extends CompositeNode { } @Override + public ValueType type(Context context) { + if (arguments.expressions().size() == 0) + return ValueType.doubleType(); + + ValueType argument1Type = arguments.expressions().get(0).type(context); + if (arguments.expressions().size() == 1) + return argument1Type; + + ValueType argument2Type = arguments.expressions().get(1).type(context); + return ValueType.of(Join.outputType(argument1Type.tensorType(), argument2Type.tensorType())); + } + + @Override public Value evaluate(Context context) { if (arguments.expressions().size() == 0) - return DoubleValue.zero.function(function,DoubleValue.zero); + return DoubleValue.zero.function(function ,DoubleValue.zero); Value argument1 = arguments.expressions().get(0).evaluate(context); if (arguments.expressions().size() == 1) return argument1.function(function, DoubleValue.zero); Value argument2 = arguments.expressions().get(1).evaluate(context); - return argument1.function(function,argument2); + return argument1.function(function, argument2); } /** Returns a new function node with the children replaced by the given children */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java index 9da1ba40144..fcd40bed4d0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java @@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.tensor.TensorType; import java.util.Collections; @@ -46,6 +47,9 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { return generator.toString(context, path, this); } + @Override + public ValueType type(Context context) { return ValueType.of(type); } + /** Evaluate this in a context which must have the arguments bound */ @Override public Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java index 1b429de0be5..b9866bec027 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java @@ -3,13 +3,14 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import java.util.*; /** * A conditional branch of a ranking expression. * - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen * @author bratseth */ public final class IfNode extends CompositeNode { @@ -70,6 +71,17 @@ public final class IfNode extends CompositeNode { } @Override + public ValueType type(Context context) { + ValueType trueType = trueExpression.type(context); + ValueType falseType = falseExpression.type(context); + if ( ! trueType.equals(falseType)) + throw new IllegalArgumentException("An if expression must produce a value of the same type in both " + + "alternatives, but the 'true' type is " + trueType + " while the " + + "'false' type is " + falseType); + return trueType; + } + + @Override public Value evaluate(Context context) { if (condition.evaluate(context).asBoolean()) return trueExpression.evaluate(context); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index 78206d75d0d..b898529c4b9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import java.util.Collections; import java.util.Deque; @@ -14,20 +15,20 @@ import java.util.function.DoubleUnaryOperator; /** * A free, parametrized function - * + * * @author bratseth */ public class LambdaFunctionNode extends CompositeNode { private final ImmutableList<String> arguments; private final ExpressionNode functionExpression; - + public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) { // TODO: Verify that the function only accesses the given arguments this.arguments = ImmutableList.copyOf(arguments); this.functionExpression = functionExpression; } - + @Override public List<ExpressionNode> children() { return Collections.singletonList(functionExpression); @@ -54,19 +55,24 @@ public class LambdaFunctionNode extends CompositeNode { return b.toString(); } + @Override + public ValueType type(Context context) { + return ValueType.doubleType(); // by definition - no nested lambdas + } + /** Evaluate this in a context which must have the arguments bound */ @Override public Value evaluate(Context context) { return functionExpression.evaluate(context); } - - /** + + /** * Returns this as a double unary operator - * - * @throws IllegalStateException if this has more than one argument + * + * @throws IllegalStateException if this has more than one argument */ public DoubleUnaryOperator asDoubleUnaryOperator() { - if (arguments.size() > 1) + if (arguments.size() > 1) throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " + "Must have at most one argument " + " but has " + arguments); return new DoubleUnaryLambda(); @@ -93,7 +99,7 @@ public class LambdaFunctionNode extends CompositeNode { context.put(arguments.get(0), operand); return evaluate(context).asDouble(); } - + @Override public String toString() { return LambdaFunctionNode.this.toString(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java index 69df572272a..cf6475238c4 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import java.util.Deque; @@ -30,6 +31,9 @@ public final class NameNode extends ExpressionNode { } @Override + public ValueType type(Context context) { throw new RuntimeException("Named nodes can not have a type"); } + + @Override public Value evaluate(Context context) { throw new RuntimeException("Name nodes should never be evaluated"); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java index 61c20a97b64..2e685a6c8ab 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import java.util.Collections; import java.util.Deque; @@ -36,6 +37,11 @@ public class NegativeNode extends CompositeNode { } @Override + public ValueType type(Context context) { + return value.type(context); + } + + @Override public Value evaluate(Context context) { return value.evaluate(context).negate(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java index 8c459a032bd..c4b940f1bd6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import java.util.Collections; import java.util.Deque; @@ -36,6 +37,11 @@ public class NotNode extends BooleanNode { } @Override + public ValueType type(Context context) { + return value.type(context); + } + + @Override public Value evaluate(Context context) { return value.evaluate(context).not(); } @@ -45,6 +51,6 @@ public class NotNode extends BooleanNode { if (children.size() != 1) throw new IllegalArgumentException("Expected 1 children but got " + children.size()); return new NotNode(children.get(0)); } - + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 139709998b4..e5176f9966d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import java.util.ArrayDeque; import java.util.Deque; @@ -105,8 +106,14 @@ public final class ReferenceNode extends CompositeNode { } @Override + public ValueType type(Context context) { + // Don't support outputs of different type, for simplicity + return context.getType(name); + } + + @Override public Value evaluate(Context context) { - if (arguments.expressions().size()==0 && output==null) + if (arguments.expressions().isEmpty() && output == null) return context.get(name); return context.get(name, arguments, output); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java index f6b1a1a8979..a8b82c560f7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java @@ -6,6 +6,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.tensor.Tensor; import java.util.ArrayList; @@ -58,6 +59,11 @@ public class SetMembershipNode extends BooleanNode { } @Override + public ValueType type(Context context) { + return ValueType.doubleType(); + } + + @Override public Value evaluate(Context context) { Value value = testValue.evaluate(context); if (value instanceof TensorValue) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index c85a15ada64..97cfa2a5350 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -64,7 +64,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public ValueType type(Context context) { return ValueType.tensorType(function.type(context)); } + public ValueType type(Context context) { return ValueType.of(function.type(context)); } @Override public Value evaluate(Context context) { @@ -112,7 +112,7 @@ public class TensorFunctionNode extends CompositeNode { @Override public TensorType type(EvaluationContext context) { - return expression.type((Context)context).tensorType().orElse(TensorType.empty); + return expression.type((Context)context).tensorType(); } @Override |