summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-01-29 21:41:06 +0100
committerJon Bratseth <bratseth@oath.com>2018-01-29 21:41:06 +0100
commit9c4ba9bf5b96b8c62a9b8c5a6c20a9175c698b70 (patch)
tree45c33c04ceb3b03a92e6d2e7fde4fd2cab18ced4 /searchlib
parent1b4fde01d98bf724a54b6c1cfe3ffa4b29aec90e (diff)
Propagate type information through ranking expressions
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java23
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java4
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java23
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java11
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java4
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java6
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java2
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java19
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java4
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java24
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java8
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java4
19 files changed, 141 insertions, 37 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index ff8758bd1e7..a1e79df95e3 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -30,8 +30,7 @@ public abstract class Context implements EvaluationContext {
public TensorType getTensorType(String name) {
ValueType type = getType(name);
if (type == null) return null;
- if (type.isTensor()) return type.tensorType().get();
- return TensorType.empty; // double as tensor
+ return type.tensorType();
}
/** Returns a variable as a tensor */
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java
index 06301372dcc..046ad7861ef 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java
@@ -2,8 +2,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo
import com.yahoo.tensor.TensorType;
-import java.util.Optional;
-
/**
* The type of a ranking expression value - either a double or a tensor.
*
@@ -11,27 +9,24 @@ import java.util.Optional;
*/
public class ValueType {
- private static final ValueType doubleValueType = new ValueType(Optional.empty());
+ private static final ValueType doubleValueType = new ValueType(TensorType.empty);
- private final Optional<TensorType> tensorType;
+ private final TensorType tensorType;
- private ValueType(Optional<TensorType> type) {
- this.tensorType = type;
+ private ValueType(TensorType tensorType) {
+ this.tensorType = tensorType;
}
- /** Returns true if this is a double type */
- public boolean isDouble() { return ! tensorType.isPresent(); }
-
- /** Returns true if this is a tensor type */
- public boolean isTensor() { return tensorType.isPresent(); }
+ /** Returns true if this is the double type */
+ public boolean isDouble() { return tensorType.rank() == 0; }
- /** The specific tensor type of this, or empty if this is not a tensor type */
- public Optional<TensorType> tensorType() { return tensorType; }
+ /** The type of this as a tensor type. The double type is the empty tensor type (rank 0) */
+ public TensorType tensorType() { return tensorType; }
/** Returns the type representing a double */
public static ValueType doubleType() { return doubleValueType; }
/** Returns a type representing the given tensor type */
- public static ValueType tensorType(TensorType type) { return new ValueType(Optional.of(type)); }
+ public static ValueType of(TensorType type) { return new ValueType(type); }
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
index 372fb00431b..b4e126f69e0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
@@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
@@ -24,6 +25,9 @@ public class GBDTForestNode extends ExpressionNode {
}
@Override
+ public final ValueType type(Context context) { return ValueType.doubleType(); }
+
+ @Override
public final Value evaluate(Context context) {
int pc = 0;
double treeSum = 0;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
index 4d7b4835892..f085194a7df 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
@@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
@@ -49,6 +50,9 @@ public final class GBDTNode extends ExpressionNode {
public final double[] values() { return values; }
@Override
+ public final ValueType type(Context context) { return ValueType.doubleType(); }
+
+ @Override
public final Value evaluate(Context context) {
return new DoubleValue(evaluate(values,0,context));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
index 518a15bcc87..d45037b6044 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
@@ -4,8 +4,15 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Join;
-import java.util.*;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
/**
* A binary mathematical operation
@@ -73,14 +80,26 @@ public final class ArithmeticNode extends CompositeNode {
}
@Override
+ public ValueType type(Context context) {
+ // Compute type using tensor types as arithmetic operators are supported on tensors
+ // and is correct also in the special case of doubles.
+ // As all our functions are type-commutative, we don't need to take operator precedence into account
+ TensorType type = children.get(0).type(context).tensorType();
+ for (int i = 1; i < children.size(); i++)
+ type = Join.outputType(type, children.get(i).type(context).tensorType());
+ return ValueType.of(type);
+ }
+
+ @Override
public Value evaluate(Context context) {
Iterator<ExpressionNode> child = children.iterator();
+ // Apply in precedence order:
Deque<ValueItem> stack = new ArrayDeque<>();
stack.push(new ValueItem(ArithmeticOperator.OR, child.next().evaluate(context)));
for (Iterator<ArithmeticOperator> it = operators.iterator(); it.hasNext() && child.hasNext();) {
ArithmeticOperator op = it.next();
- if (!stack.isEmpty()) {
+ if ( ! stack.isEmpty()) {
while (stack.peek().op.hasPrecedenceOver(op)) {
popStack(stack);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
index 9484f789169..fdbb22093ea 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
@@ -1,11 +1,13 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
-import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.List;
/**
* A node which returns the outcome of a comparison.
@@ -46,6 +48,11 @@ public class ComparisonNode extends BooleanNode {
}
@Override
+ public ValueType type(Context context) {
+ return ValueType.doubleType(); // by definition
+ }
+
+ @Override
public Value evaluate(Context context) {
Value leftValue = leftCondition.evaluate(context);
Value rightValue = rightCondition.evaluate(context);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
index cd473ae6a6f..e6074a5f745 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import java.util.Deque;
@@ -47,6 +48,9 @@ public final class ConstantNode extends ExpressionNode {
}
@Override
+ public ValueType type(Context context) { return value.type(); }
+
+ @Override
public Value evaluate(Context context) {
return value;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
index b5d7c41d698..8404226c33b 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import java.util.Collections;
import java.util.Deque;
@@ -48,6 +49,11 @@ public final class EmbracedNode extends CompositeNode {
}
@Override
+ public ValueType type(Context context) {
+ return value.type(context);
+ }
+
+ @Override
public CompositeNode setChildren(List<ExpressionNode> newChildren) {
if (newChildren.size() != 1)
throw new IllegalArgumentException("Expected 1 child but got " + newChildren.size());
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
index 31984dca54d..5d06a562b5d 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
@@ -47,7 +47,7 @@ public abstract class ExpressionNode implements Serializable {
* @param context the variable type bindings to use for this evaluation
* @throws IllegalArgumentException if there are variables which are not bound in the given map
*/
- public ValueType type(Context context) { return ValueType.doubleType(); } // double is default
+ public abstract ValueType type(Context context);
/**
* Returns the value of evaluating this expression over the given context.
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
index 142e282e5c6..b187b8f029c 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
@@ -4,6 +4,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.functions.Join;
import java.util.ArrayList;
import java.util.Collections;
@@ -64,16 +66,29 @@ public final class FunctionNode extends CompositeNode {
}
@Override
+ public ValueType type(Context context) {
+ if (arguments.expressions().size() == 0)
+ return ValueType.doubleType();
+
+ ValueType argument1Type = arguments.expressions().get(0).type(context);
+ if (arguments.expressions().size() == 1)
+ return argument1Type;
+
+ ValueType argument2Type = arguments.expressions().get(1).type(context);
+ return ValueType.of(Join.outputType(argument1Type.tensorType(), argument2Type.tensorType()));
+ }
+
+ @Override
public Value evaluate(Context context) {
if (arguments.expressions().size() == 0)
- return DoubleValue.zero.function(function,DoubleValue.zero);
+ return DoubleValue.zero.function(function ,DoubleValue.zero);
Value argument1 = arguments.expressions().get(0).evaluate(context);
if (arguments.expressions().size() == 1)
return argument1.function(function, DoubleValue.zero);
Value argument2 = arguments.expressions().get(1).evaluate(context);
- return argument1.function(function,argument2);
+ return argument1.function(function, argument2);
}
/** Returns a new function node with the children replaced by the given children */
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
index 9da1ba40144..fcd40bed4d0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
@@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import com.yahoo.tensor.TensorType;
import java.util.Collections;
@@ -46,6 +47,9 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
return generator.toString(context, path, this);
}
+ @Override
+ public ValueType type(Context context) { return ValueType.of(type); }
+
/** Evaluate this in a context which must have the arguments bound */
@Override
public Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
index 1b429de0be5..b9866bec027 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
@@ -3,13 +3,14 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import java.util.*;
/**
* A conditional branch of a ranking expression.
*
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
* @author bratseth
*/
public final class IfNode extends CompositeNode {
@@ -70,6 +71,17 @@ public final class IfNode extends CompositeNode {
}
@Override
+ public ValueType type(Context context) {
+ ValueType trueType = trueExpression.type(context);
+ ValueType falseType = falseExpression.type(context);
+ if ( ! trueType.equals(falseType))
+ throw new IllegalArgumentException("An if expression must produce a value of the same type in both " +
+ "alternatives, but the 'true' type is " + trueType + " while the " +
+ "'false' type is " + falseType);
+ return trueType;
+ }
+
+ @Override
public Value evaluate(Context context) {
if (condition.evaluate(context).asBoolean())
return trueExpression.evaluate(context);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index 78206d75d0d..b898529c4b9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import java.util.Collections;
import java.util.Deque;
@@ -14,20 +15,20 @@ import java.util.function.DoubleUnaryOperator;
/**
* A free, parametrized function
- *
+ *
* @author bratseth
*/
public class LambdaFunctionNode extends CompositeNode {
private final ImmutableList<String> arguments;
private final ExpressionNode functionExpression;
-
+
public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) {
// TODO: Verify that the function only accesses the given arguments
this.arguments = ImmutableList.copyOf(arguments);
this.functionExpression = functionExpression;
}
-
+
@Override
public List<ExpressionNode> children() {
return Collections.singletonList(functionExpression);
@@ -54,19 +55,24 @@ public class LambdaFunctionNode extends CompositeNode {
return b.toString();
}
+ @Override
+ public ValueType type(Context context) {
+ return ValueType.doubleType(); // by definition - no nested lambdas
+ }
+
/** Evaluate this in a context which must have the arguments bound */
@Override
public Value evaluate(Context context) {
return functionExpression.evaluate(context);
}
-
- /**
+
+ /**
* Returns this as a double unary operator
- *
- * @throws IllegalStateException if this has more than one argument
+ *
+ * @throws IllegalStateException if this has more than one argument
*/
public DoubleUnaryOperator asDoubleUnaryOperator() {
- if (arguments.size() > 1)
+ if (arguments.size() > 1)
throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " +
"Must have at most one argument " + " but has " + arguments);
return new DoubleUnaryLambda();
@@ -93,7 +99,7 @@ public class LambdaFunctionNode extends CompositeNode {
context.put(arguments.get(0), operand);
return evaluate(context).asDouble();
}
-
+
@Override
public String toString() {
return LambdaFunctionNode.this.toString();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
index 69df572272a..cf6475238c4 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import java.util.Deque;
@@ -30,6 +31,9 @@ public final class NameNode extends ExpressionNode {
}
@Override
+ public ValueType type(Context context) { throw new RuntimeException("Named nodes can not have a type"); }
+
+ @Override
public Value evaluate(Context context) {
throw new RuntimeException("Name nodes should never be evaluated");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
index 61c20a97b64..2e685a6c8ab 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import java.util.Collections;
import java.util.Deque;
@@ -36,6 +37,11 @@ public class NegativeNode extends CompositeNode {
}
@Override
+ public ValueType type(Context context) {
+ return value.type(context);
+ }
+
+ @Override
public Value evaluate(Context context) {
return value.evaluate(context).negate();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
index 8c459a032bd..c4b940f1bd6 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import java.util.Collections;
import java.util.Deque;
@@ -36,6 +37,11 @@ public class NotNode extends BooleanNode {
}
@Override
+ public ValueType type(Context context) {
+ return value.type(context);
+ }
+
+ @Override
public Value evaluate(Context context) {
return value.evaluate(context).not();
}
@@ -45,6 +51,6 @@ public class NotNode extends BooleanNode {
if (children.size() != 1) throw new IllegalArgumentException("Expected 1 children but got " + children.size());
return new NotNode(children.get(0));
}
-
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index 139709998b4..e5176f9966d 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import java.util.ArrayDeque;
import java.util.Deque;
@@ -105,8 +106,14 @@ public final class ReferenceNode extends CompositeNode {
}
@Override
+ public ValueType type(Context context) {
+ // Don't support outputs of different type, for simplicity
+ return context.getType(name);
+ }
+
+ @Override
public Value evaluate(Context context) {
- if (arguments.expressions().size()==0 && output==null)
+ if (arguments.expressions().isEmpty() && output == null)
return context.get(name);
return context.get(name, arguments, output);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
index f6b1a1a8979..a8b82c560f7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
@@ -6,6 +6,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import com.yahoo.tensor.Tensor;
import java.util.ArrayList;
@@ -58,6 +59,11 @@ public class SetMembershipNode extends BooleanNode {
}
@Override
+ public ValueType type(Context context) {
+ return ValueType.doubleType();
+ }
+
+ @Override
public Value evaluate(Context context) {
Value value = testValue.evaluate(context);
if (value instanceof TensorValue) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index c85a15ada64..97cfa2a5350 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -64,7 +64,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public ValueType type(Context context) { return ValueType.tensorType(function.type(context)); }
+ public ValueType type(Context context) { return ValueType.of(function.type(context)); }
@Override
public Value evaluate(Context context) {
@@ -112,7 +112,7 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public TensorType type(EvaluationContext context) {
- return expression.type((Context)context).tensorType().orElse(TensorType.empty);
+ return expression.type((Context)context).tensorType();
}
@Override