diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-01-29 14:51:23 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-01-29 14:51:23 +0100 |
commit | 1b4fde01d98bf724a54b6c1cfe3ffa4b29aec90e (patch) | |
tree | 20a127542b004eceb94e4d1344b3446df8092bd2 /searchlib/src/main/java/com/yahoo/searchlib/rankingexpression | |
parent | 28e3545728977a0be82159b8f278be8e772cb59b (diff) |
Propagate type information
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression')
11 files changed, 117 insertions, 33 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java index 23dd841b0ef..a4d3c111356 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java @@ -48,12 +48,11 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { * * @throws IllegalArgumentException if the name is not present in the ranking expression this was created with, and * ignoredUnknownValues is false - * @since 5.1.5 */ @Override public final void put(String name, Value value) { Integer index = nameToIndex().get(name); - if (index==null) { + if (index == null) { if (ignoreUnknownValues()) return; else @@ -70,24 +69,29 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { /** * Puts a value by index. * The value will be frozen if it isn't already. - * - * @since 5.1.5 */ public final void put(int index, Value value) { - values[index]=value.freeze(); + values[index] = value.freeze(); try { - doubleValues()[index]=value.asDouble(); + doubleValues()[index] = value.asDouble(); } catch (UnsupportedOperationException e) { - doubleValues()[index]=Double.NaN; // see getDouble below + doubleValues()[index] = Double.NaN; // see getDouble below } } + @Override + public ValueType getType(String name) { + Integer index = nameToIndex().get(name); + if (index == null) return null; + return values[index].type(); + } + /** Perform a slow lookup by name */ @Override public Value get(String name) { - Integer index=nameToIndex().get(name); - if (index==null) return DoubleValue.zero; + Integer index = nameToIndex().get(name); + if (index == null) return DoubleValue.zero; return values[index]; } @@ -100,8 +104,8 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { /** Perform a fast lookup directly of the value as a double. This is faster than get(index).asDouble() */ @Override public final double getDouble(int index) { - double value=doubleValues()[index]; - if (value==Double.NaN) + double value = doubleValues()[index]; + if (value == Double.NaN) throw new UnsupportedOperationException("Value at " + index + " has no double representation"); return value; } @@ -111,7 +115,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { * in a different thread (i.e, name name to index map, different value set. */ public ArrayContext clone() { - ArrayContext clone=(ArrayContext)super.clone(); + ArrayContext clone = (ArrayContext)super.clone(); clone.values = new Value[nameToIndex().size()]; Arrays.fill(values,constantZero); return clone; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 0eeb0a9e630..ff8758bd1e7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Set; @@ -17,17 +18,29 @@ import java.util.stream.Collectors; public abstract class Context implements EvaluationContext { /** - * <p>Returns the value of a simple variable name.</p> + * Returns the value of a simple variable name. * * @param name the name of the variable whose value to return. * @return the value of the named variable. */ public abstract Value get(String name); + /** Returns the type of the value of the given variable as a tensor type, or null if there is no such variable */ + @Override + public TensorType getTensorType(String name) { + ValueType type = getType(name); + if (type == null) return null; + if (type.isTensor()) return type.tensorType().get(); + return TensorType.empty; // double as tensor + } + /** Returns a variable as a tensor */ @Override public Tensor getTensor(String name) { return get(name).asTensor(); } + /** Returns the type of the value of the given variable, or null if there is no such variable */ + public abstract ValueType getType(String name); + /** * <p>Returns the value of a <i>structured variable</i> on the form * <code>name(argument*)(.output)?</code>, where <i>argument</i> is any diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index 2ef4a2ede2f..c85a8f1c7e1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -4,18 +4,18 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; /** * A value which acts as a double in numerical context. * * @author bratseth - * @since 5.1.21 */ public abstract class DoubleCompatibleValue extends Value { @Override + public ValueType type() { return ValueType.doubleType(); } + + @Override public boolean hasDouble() { return true; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java index ceec9358b3c..34cd75df9cb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java @@ -38,7 +38,6 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { * * @throws IllegalArgumentException if the name is not present in the ranking expression this was created with, and * ignoredUnknownValues is false - * @since 5.1.5 */ @Override public final void put(String name, Value value) { @@ -57,11 +56,7 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { doubleValues()[index] = value; } - /** - * Puts a value by index. - * - * @since 5.1.5 - */ + /** Puts a value by index. */ public final void put(int index, Value value) { try { put(index, value.asDouble()); @@ -71,6 +66,9 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { } } + @Override + public ValueType getType(String name) { return ValueType.doubleType(); } + /** Perform a slow lookup by name */ @Override public Value get(String name) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index 09895a0c2f6..2672fe6cd8e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -41,14 +41,20 @@ public class MapContext extends Context { boundValue.freeze(); } + /** Returns the type of the given value key, or null if it is not bound. */ + @Override + public ValueType getType(String key) { + Value value = bindings.get(key); + if (value == null) return null; + return value.type(); + } + /** * Returns the value of a key. 0 is returned if the given key is not bound in this. */ @Override public Value get(String key) { - Value value = bindings.get(key); - if (value == null) return DoubleValue.zero; - return value; + return bindings.getOrDefault(key, DoubleValue.zero); } /** @@ -67,7 +73,7 @@ public class MapContext extends Context { if (frozen) return bindings; return Collections.unmodifiableMap(bindings); } - + /** Returns a new, modifiable context containing all the bindings of this */ public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index dad69b31181..874b41ec3e1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -5,8 +5,6 @@ import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; /** * A string value. @@ -30,6 +28,9 @@ public class StringValue extends Value { this.value = value; } + @Override + public ValueType type() { return ValueType.doubleType(); } + /** Returns the hashcode of this, to enable strings to be encoded (with reasonable safely) as doubles for optimization */ @Override public double asDouble() { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 26c30fe5ed2..f1c65dc79d3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -25,6 +25,9 @@ public class TensorValue extends Value { } @Override + public ValueType type() { return ValueType.doubleType(); } + + @Override public double asDouble() { if (hasDouble()) return value.get(TensorAddress.of()); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 40d70e0022c..856bfb3638d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -19,6 +19,9 @@ public abstract class Value { private boolean frozen=false; + /** Returns the type of this value */ + public abstract ValueType type(); + /** Returns this value as a double, or throws UnsupportedOperationException if it cannot be represented as a double */ public abstract double asDouble(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java new file mode 100644 index 00000000000..06301372dcc --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java @@ -0,0 +1,37 @@ +package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import com.yahoo.tensor.TensorType; + +import java.util.Optional; + +/** + * The type of a ranking expression value - either a double or a tensor. + * + * @author bratseth + */ +public class ValueType { + + private static final ValueType doubleValueType = new ValueType(Optional.empty()); + + private final Optional<TensorType> tensorType; + + private ValueType(Optional<TensorType> type) { + this.tensorType = type; + } + + /** Returns true if this is a double type */ + public boolean isDouble() { return ! tensorType.isPresent(); } + + /** Returns true if this is a tensor type */ + public boolean isTensor() { return tensorType.isPresent(); } + + /** The specific tensor type of this, or empty if this is not a tensor type */ + public Optional<TensorType> tensorType() { return tensorType; } + + /** Returns the type representing a double */ + public static ValueType doubleType() { return doubleValueType; } + + /** Returns a type representing the given tensor type */ + public static ValueType tensorType(TensorType type) { return new ValueType(Optional.of(type)); } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java index eb303fc6446..31984dca54d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import java.io.Serializable; import java.util.Deque; @@ -41,6 +42,14 @@ public abstract class ExpressionNode implements Serializable { public abstract String toString(SerializationContext context, Deque<String> path, CompositeNode parent); /** + * Returns the type this will return if evaluated with the given context. + * + * @param context the variable type bindings to use for this evaluation + * @throws IllegalArgumentException if there are variables which are not bound in the given map + */ + public ValueType type(Context context) { return ValueType.doubleType(); } // double is default + + /** * Returns the value of evaluating this expression over the given context. * * @param context the variable bindings to use for this evaluation diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index b42570d3aea..c85a15ada64 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -5,7 +5,9 @@ import com.google.common.annotations.Beta; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.functions.PrimitiveTensorFunction; import com.yahoo.tensor.functions.TensorFunction; @@ -35,7 +37,7 @@ public class TensorFunctionNode extends CompositeNode { @Override public List<ExpressionNode> children() { - return function.functionArguments().stream() + return function.arguments().stream() .map(this::toExpressionNode) .collect(Collectors.toList()); } @@ -52,7 +54,7 @@ public class TensorFunctionNode extends CompositeNode { List<TensorFunction> wrappedChildren = children.stream() .map(TensorFunctionExpressionNode::new) .collect(Collectors.toList()); - return new TensorFunctionNode(function.replaceArguments(wrappedChildren)); + return new TensorFunctionNode(function.withArguments(wrappedChildren)); } @Override @@ -62,6 +64,9 @@ public class TensorFunctionNode extends CompositeNode { } @Override + public ValueType type(Context context) { return ValueType.tensorType(function.type(context)); } + + @Override public Value evaluate(Context context) { return new TensorValue(function.evaluate(context)); } @@ -84,7 +89,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public List<TensorFunction> functionArguments() { + public List<TensorFunction> arguments() { if (expression instanceof CompositeNode) return ((CompositeNode)expression).children().stream() .map(TensorFunctionExpressionNode::new) @@ -94,7 +99,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if (arguments.size() == 0) return this; List<ExpressionNode> unwrappedChildren = arguments.stream() .map(arg -> ((TensorFunctionExpressionNode)arg).expression) @@ -106,12 +111,17 @@ public class TensorFunctionNode extends CompositeNode { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(EvaluationContext context) { + return expression.type((Context)context).tensorType().orElse(TensorType.empty); + } + + @Override public Tensor evaluate(EvaluationContext context) { Value result = expression.evaluate((Context)context); if ( ! ( result instanceof TensorValue)) throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " + "but this returns " + result + ", not a tensor"); - return ((TensorValue)result).asTensor(); + return result.asTensor(); } @Override |