diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-24 14:18:01 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-24 14:18:01 +0100 |
commit | cb2dc3460fa31dffb51e54847283038e8a0ae93c (patch) | |
tree | e96497fe6b167f8867ad9cb225ea979a6e09dab8 /searchlib/src/main/java/com | |
parent | 437a2dc519cc991302c01acb8cd1df1e96b1283d (diff) |
Implement composite functions
Diffstat (limited to 'searchlib/src/main/java/com')
6 files changed, 131 insertions, 6 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 4cd1e1fc0ee..620c6fad0b4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.tensor.functions.EvaluationContext; import java.util.Set; @@ -10,7 +11,7 @@ import java.util.Set; * * @author bratseth */ -public abstract class Context { +public abstract class Context implements EvaluationContext { /** * <p>Returns the value of a simple variable name.</p> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index f559f8adaf5..dc422f2c8da 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -97,6 +97,14 @@ public class TensorValue extends Value { return new TensorValue(value.max(asTensor(argument, "max"))); } + public Value atan2(Value argument) { + return new TensorValue(value.atan2(asTensor(argument, "atan2"))); + } + + public Value equal(Value argument) { + return new TensorValue(value.equal(asTensor(argument, "equal"))); + } + public Value sum(String dimension) { return new TensorValue(value.sum(Collections.singletonList(dimension))); } @@ -129,6 +137,10 @@ public class TensorValue extends Value { return min(argument); else if (function.equals(Function.max) && argument instanceof TensorValue) return max(argument); + else if (function.equals(Function.atan2) && argument instanceof TensorValue) + return atan2(argument); + else if (function.equals(Function.equal) && argument instanceof TensorValue) + return equal(argument); else return new TensorValue(value.map((value) -> function.evaluate(value, argument.asDouble()))); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java index cbea2ad627e..b8e48dc2f05 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java @@ -43,7 +43,8 @@ public enum Function implements Serializable { max(2) { public double evaluate(double x, double y) { return max(x,y); } }, min(2) { public double evaluate(double x, double y) { return min(x,y); } }, mod(2) { public double evaluate(double x, double y) { return x % y; } }, - pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }; + pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }, + equal(2) { public double evaluate(double x, double y) { return x==y ? 1.0 : 0.0; } }; private final int arity; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java new file mode 100644 index 00000000000..ebd79f65578 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -0,0 +1,112 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.annotations.Beta; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.EvaluationContext; +import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.functions.ToStringContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.stream.Collectors; + +/** + * A node which performs a tensor function + * + * @author bratseth + */ + @Beta +public class TensorFunctionNode extends CompositeNode { + + private final TensorFunction function; + + public TensorFunctionNode(TensorFunction function) { + this.function = function; + } + + @Override + public List<ExpressionNode> children() { + return function.functionArguments().stream() + .map(f -> ((TensorFunctionExpressionNode)f).expression) + .collect(Collectors.toList()); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + // Serialize as primitive + return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this)); + } + + @Override + public Value evaluate(Context context) { + return new TensorValue(function.evaluate(context)); + } + + public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { + return new TensorFunctionExpressionNode(node); + } + + /** + * A tensor function implemented by an expression. + * This allows us to pass expressions as tensor function arguments. + */ + public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction { + + /** An expression which produces a tensor */ + private final ExpressionNode expression; + + public TensorFunctionExpressionNode(ExpressionNode expression) { + this.expression = expression; + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public PrimitiveTensorFunction toPrimitive() { return this; } + + @Override + public Tensor evaluate(EvaluationContext context) { + Value result = expression.evaluate((Context)context); + if ( ! ( result instanceof TensorValue)) + throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " + + "but this returns " + result + ", not a tensor"); + return ((TensorValue)result).asTensor(); + } + + @Override + public String toString(ToStringContext c) { + ExpressionNodeToStringContext context = (ExpressionNodeToStringContext)c; + return expression.toString(context.context, context.path, context.parent); + } + + } + + /** Allows passing serialization context arguments through TensorFunctions */ + private static class ExpressionNodeToStringContext implements ToStringContext { + + final SerializationContext context; + final Deque<String> path; + final CompositeNode parent; + + public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { + this.context = context; + this.path = path; + this.parent = parent; + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java index 4f73c632422..d4b95d12fdd 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java @@ -6,7 +6,7 @@ import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.functions.ReduceFunction; +import com.yahoo.tensor.functions.Reduce; import java.util.Collections; import java.util.Deque; @@ -23,12 +23,12 @@ public class TensorReduceNode extends CompositeNode { /** The tensor to aggregate over */ private final ExpressionNode argument; - private final ReduceFunction.Aggregator aggregator; + private final Reduce.Aggregator aggregator; /** The dimensions to sum over, or empty to sum all cells */ private final ImmutableList<String> dimensions; - public TensorReduceNode(ExpressionNode argument, ReduceFunction.Aggregator aggregator, List<String> dimensions) { + public TensorReduceNode(ExpressionNode argument, Reduce.Aggregator aggregator, List<String> dimensions) { this.argument = argument; this.aggregator = aggregator; this.dimensions = ImmutableList.copyOf(dimensions); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java index 17a08beba8b..b7f21c215dc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java @@ -6,7 +6,6 @@ import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.functions.ReduceFunction; import java.util.Collections; import java.util.Deque; |