diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-24 14:18:01 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-24 14:18:01 +0100 |
commit | cb2dc3460fa31dffb51e54847283038e8a0ae93c (patch) | |
tree | e96497fe6b167f8867ad9cb225ea979a6e09dab8 /searchlib/src/main | |
parent | 437a2dc519cc991302c01acb8cd1df1e96b1283d (diff) |
Implement composite functions
Diffstat (limited to 'searchlib/src/main')
7 files changed, 207 insertions, 13 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 4cd1e1fc0ee..620c6fad0b4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.tensor.functions.EvaluationContext; import java.util.Set; @@ -10,7 +11,7 @@ import java.util.Set; * * @author bratseth */ -public abstract class Context { +public abstract class Context implements EvaluationContext { /** * <p>Returns the value of a simple variable name.</p> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index f559f8adaf5..dc422f2c8da 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -97,6 +97,14 @@ public class TensorValue extends Value { return new TensorValue(value.max(asTensor(argument, "max"))); } + public Value atan2(Value argument) { + return new TensorValue(value.atan2(asTensor(argument, "atan2"))); + } + + public Value equal(Value argument) { + return new TensorValue(value.equal(asTensor(argument, "equal"))); + } + public Value sum(String dimension) { return new TensorValue(value.sum(Collections.singletonList(dimension))); } @@ -129,6 +137,10 @@ public class TensorValue extends Value { return min(argument); else if (function.equals(Function.max) && argument instanceof TensorValue) return max(argument); + else if (function.equals(Function.atan2) && argument instanceof TensorValue) + return atan2(argument); + else if (function.equals(Function.equal) && argument instanceof TensorValue) + return equal(argument); else return new TensorValue(value.map((value) -> function.evaluate(value, argument.asDouble()))); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java index cbea2ad627e..b8e48dc2f05 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java @@ -43,7 +43,8 @@ public enum Function implements Serializable { max(2) { public double evaluate(double x, double y) { return max(x,y); } }, min(2) { public double evaluate(double x, double y) { return min(x,y); } }, mod(2) { public double evaluate(double x, double y) { return x % y; } }, - pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }; + pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }, + equal(2) { public double evaluate(double x, double y) { return x==y ? 1.0 : 0.0; } }; private final int arity; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java new file mode 100644 index 00000000000..ebd79f65578 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -0,0 +1,112 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.annotations.Beta; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.EvaluationContext; +import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.functions.ToStringContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.stream.Collectors; + +/** + * A node which performs a tensor function + * + * @author bratseth + */ + @Beta +public class TensorFunctionNode extends CompositeNode { + + private final TensorFunction function; + + public TensorFunctionNode(TensorFunction function) { + this.function = function; + } + + @Override + public List<ExpressionNode> children() { + return function.functionArguments().stream() + .map(f -> ((TensorFunctionExpressionNode)f).expression) + .collect(Collectors.toList()); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + // Serialize as primitive + return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this)); + } + + @Override + public Value evaluate(Context context) { + return new TensorValue(function.evaluate(context)); + } + + public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { + return new TensorFunctionExpressionNode(node); + } + + /** + * A tensor function implemented by an expression. + * This allows us to pass expressions as tensor function arguments. + */ + public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction { + + /** An expression which produces a tensor */ + private final ExpressionNode expression; + + public TensorFunctionExpressionNode(ExpressionNode expression) { + this.expression = expression; + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public PrimitiveTensorFunction toPrimitive() { return this; } + + @Override + public Tensor evaluate(EvaluationContext context) { + Value result = expression.evaluate((Context)context); + if ( ! ( result instanceof TensorValue)) + throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " + + "but this returns " + result + ", not a tensor"); + return ((TensorValue)result).asTensor(); + } + + @Override + public String toString(ToStringContext c) { + ExpressionNodeToStringContext context = (ExpressionNodeToStringContext)c; + return expression.toString(context.context, context.path, context.parent); + } + + } + + /** Allows passing serialization context arguments through TensorFunctions */ + private static class ExpressionNodeToStringContext implements ToStringContext { + + final SerializationContext context; + final Deque<String> path; + final CompositeNode parent; + + public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { + this.context = context; + this.path = path; + this.parent = parent; + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java index 4f73c632422..d4b95d12fdd 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java @@ -6,7 +6,7 @@ import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.functions.ReduceFunction; +import com.yahoo.tensor.functions.Reduce; import java.util.Collections; import java.util.Deque; @@ -23,12 +23,12 @@ public class TensorReduceNode extends CompositeNode { /** The tensor to aggregate over */ private final ExpressionNode argument; - private final ReduceFunction.Aggregator aggregator; + private final Reduce.Aggregator aggregator; /** The dimensions to sum over, or empty to sum all cells */ private final ImmutableList<String> dimensions; - public TensorReduceNode(ExpressionNode argument, ReduceFunction.Aggregator aggregator, List<String> dimensions) { + public TensorReduceNode(ExpressionNode argument, Reduce.Aggregator aggregator, List<String> dimensions) { this.argument = argument; this.aggregator = aggregator; this.dimensions = ImmutableList.copyOf(dimensions); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java index 17a08beba8b..b7f21c215dc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java @@ -6,7 +6,6 @@ import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.functions.ReduceFunction; import java.util.Collections; import java.util.Deque; diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 6a2ce356722..0d290bf7688 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -24,7 +24,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.*; import com.yahoo.tensor.functions.*; import java.util.Collections; -import java.util.Map; import java.util.LinkedHashMap; import java.util.Arrays; import java.util.ArrayList; @@ -111,6 +110,7 @@ TOKEN : <TANH: "tanh"> | <ATAN2: "atan2"> | + <EQUAL: "equal"> | <FMOD: "fmod"> | <LDEXP: "ldexp"> | // MAX @@ -123,6 +123,11 @@ TOKEN : <JOIN: "join"> | <RENAME: "rename"> | <TENSOR: "tensor"> | + <L1_NORMALIZE: "l1_normalize"> | + <L2_NORMALIZE: "l2_normalize"> | + <MATMUL: "matmul"> | + <SOFTMAX: "softmax"> | + <XW_PLUS_B: "xw_plus_b"> | <AVG: "avg" > | <COUNT: "count"> | @@ -345,7 +350,12 @@ ExpressionNode tensorFunction() : tensorExpression = tensorReduceComposites() | tensorExpression = tensorJoin() | tensorExpression = tensorRename() | - tensorExpression = tensorGenerate() + tensorExpression = tensorGenerate() | + tensorExpression = tensorL1Normalize() | + tensorExpression = tensorL2Normalize() | + tensorExpression = tensorMatmul() | + tensorExpression = tensorSoftmax() | + tensorExpression = tensorXwPlusB() ) { return tensorExpression; } } @@ -363,7 +373,7 @@ ExpressionNode tensorMap() : ExpressionNode tensorReduce() : { ExpressionNode tensor; - ReduceFunction.Aggregator aggregator; + Reduce.Aggregator aggregator; List<String> dimensions = null; } { @@ -374,7 +384,7 @@ ExpressionNode tensorReduce() : ExpressionNode tensorReduceComposites() : { ExpressionNode tensor; - ReduceFunction.Aggregator aggregator; + Reduce.Aggregator aggregator; List<String> dimensions = null; } { @@ -417,6 +427,64 @@ ExpressionNode tensorGenerate() : { return null; } } +ExpressionNode tensorL1Normalize() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorL2Normalize() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorMatmul() : +{ + ExpressionNode tensor1, tensor2; + String dimension; +} +{ + <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + dimension)); } +} + +ExpressionNode tensorSoftmax() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorXwPlusB() : +{ + ExpressionNode tensor1, tensor2, tensor3; + String dimension; +} +{ + <XW_PLUS_B> <LBRACE> tensor1 = expression() <COMMA> + tensor2 = expression() <COMMA> + tensor3 = expression() <COMMA> + dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + TensorFunctionNode.wrapArgument(tensor3), + dimension)); } +} + LambdaFunctionNode lambdaFunction() : { List<String> variables; @@ -427,18 +495,18 @@ LambdaFunctionNode lambdaFunction() : { return new LambdaFunctionNode(variables, functionExpression); } } -ReduceFunction.Aggregator tensorReduceAggregator() : +Reduce.Aggregator tensorReduceAggregator() : { } { ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> ) - { return ReduceFunction.Aggregator.valueOf(token.image); } + { return Reduce.Aggregator.valueOf(token.image); } } // This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge String tensorFunctionName() : { - ReduceFunction.Aggregator aggregator; + Reduce.Aggregator aggregator; } { ( <F> { return token.image; } ) | @@ -481,6 +549,7 @@ Function unaryFunctionName() : { } Function binaryFunctionName() : { } { <ATAN2> { return Function.atan2; } | + <EQUAL> { return Function.equal; } | <FMOD> { return Function.fmod; } | <LDEXP> { return Function.ldexp; } | <MAX> { return Function.max; } | |