diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2016-11-25 18:21:25 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-25 18:21:25 +0100 |
commit | 11b208db7d2422828c90aafa638f059306acbc24 (patch) | |
tree | 63d3f766b7a046b13b2b4fdc8e633fe71134847c /searchlib/src/main | |
parent | 5400980ea6bbac6ef385d089b5e9f9b100ecae71 (diff) |
Revert "Bratseth/tensor functions 3"
Diffstat (limited to 'searchlib/src/main')
15 files changed, 298 insertions, 583 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 620c6fad0b4..0dff0414ac2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -2,7 +2,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; -import com.yahoo.tensor.functions.EvaluationContext; import java.util.Set; @@ -11,7 +10,7 @@ import java.util.Set; * * @author bratseth */ -public abstract class Context implements EvaluationContext { +public abstract class Context { /** * <p>Returns the value of a simple variable name.</p> @@ -42,7 +41,7 @@ public abstract class Context implements EvaluationContext { * "main" (or only) value. */ public Value get(String name, Arguments arguments,String output) { - if (arguments!=null && arguments.expressions().size() > 0) + if (arguments!=null && arguments.expressions().size()>0) throw new UnsupportedOperationException(this + " does not support structured ranking expression variables, attempted to reference '" + name + arguments + "'"); if (output==null) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index f8dcd8a6127..2bae382d5bd 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -39,8 +39,8 @@ public abstract class DoubleCompatibleValue extends Value { } @Override - public Value compare(TruthOperator operator, Value value) { - return new BooleanValue(operator.evaluate(asDouble(), value.asDouble())); + public boolean compare(TruthOperator operator, Value value) { + return operator.evaluate(asDouble(), value.asDouble()); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java index 0e0d793bfd1..028dad16d21 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java @@ -98,6 +98,16 @@ public final class DoubleValue extends DoubleCompatibleValue { } @Override + public boolean compare(TruthOperator operator, Value value) { + try { + return operator.evaluate(this.value, value.asDouble()); + } + catch (UnsupportedOperationException e) { + throw unsupported("comparison",value); + } + } + + @Override public Value function(Function function, Value value) { // use the tensor implementation of max and min if the argument is a tensor if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index 2dffe2a1100..9ee9a1f7a71 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -34,9 +34,11 @@ public class MapContext extends Context { * Creates a map context from a map. * The ownership of the map is transferred to this - it cannot be further modified by the caller. * All the Values of the map will be frozen. + * + * @since 5.1.5 */ public MapContext(Map<String,Value> bindings) { - this.bindings = bindings; + this.bindings=bindings; for (Value boundValue : bindings.values()) boundValue.freeze(); } @@ -65,9 +67,6 @@ public class MapContext extends Context { if (frozen) return bindings; return Collections.unmodifiableMap(bindings); } - - /** Returns a new, modifiable context containing all the bindings of this */ - public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); } /** Returns an unmodifiable map of the names of this */ public @Override Set<String> names() { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index eb997ab818a..379b5755c7b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -68,10 +68,10 @@ public class StringValue extends Value { } @Override - public Value compare(TruthOperator operator, Value value) { + public boolean compare(TruthOperator operator, Value value) { if (operator.equals(TruthOperator.EQUAL)) - return new BooleanValue(this.equals(value)); - throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='"); + return this.equals(value); + throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '='"); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index b1f4a7b20ca..12bede95aae 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -8,7 +8,6 @@ import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.TensorType; -import java.util.Collections; import java.util.Optional; /** @@ -18,7 +17,7 @@ import java.util.Optional; * * @author bratseth */ -@Beta + @Beta public class TensorValue extends Value { /** The tensor value of this */ @@ -54,7 +53,7 @@ public class TensorValue extends Value { @Override public Value negate() { - return new TensorValue(value.map((value) -> -value)); + return new TensorValue(value.apply((Double value) -> -value)); } @Override @@ -62,7 +61,7 @@ public class TensorValue extends Value { if (argument instanceof TensorValue) return new TensorValue(value.add(((TensorValue)argument).value)); else - return new TensorValue(value.map((value) -> value + argument.asDouble())); + return new TensorValue(value.apply((Double value) -> value + argument.asDouble())); } @Override @@ -70,7 +69,7 @@ public class TensorValue extends Value { if (argument instanceof TensorValue) return new TensorValue(value.subtract(((TensorValue) argument).value)); else - return new TensorValue(value.map((value) -> value - argument.asDouble())); + return new TensorValue(value.apply((Double value) -> value - argument.asDouble())); } @Override @@ -78,15 +77,35 @@ public class TensorValue extends Value { if (argument instanceof TensorValue) return new TensorValue(value.multiply(((TensorValue) argument).value)); else - return new TensorValue(value.map((value) -> value * argument.asDouble())); + return new TensorValue(value.apply((Double value) -> value * argument.asDouble())); } @Override public Value divide(Value argument) { if (argument instanceof TensorValue) - return new TensorValue(value.divide(((TensorValue) argument).value)); + throw new UnsupportedOperationException("Two tensors cannot be divided"); else - return new TensorValue(value.map((value) -> value / argument.asDouble())); + return new TensorValue(value.apply((Double value) -> value / argument.asDouble())); + } + + public Value match(Value argument) { + return new TensorValue(value.match(asTensor(argument, "match"))); + } + + public Value min(Value argument) { + return new TensorValue(value.min(asTensor(argument, "min"))); + } + + public Value max(Value argument) { + return new TensorValue(value.max(asTensor(argument, "max"))); + } + + public Value sum(String dimension) { + return new TensorValue(value.sum(dimension)); + } + + public Value sum() { + return new DoubleValue(value.sum()); } private Tensor asTensor(Value value, String operationName) { @@ -103,37 +122,18 @@ public class TensorValue extends Value { } @Override - public Value compare(TruthOperator operator, Value argument) { - return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString()))); - } - - private Tensor compareTensor(TruthOperator operator, Tensor argument) { - switch (operator) { - case LARGER: return value.larger(argument); - case LARGEREQUAL: return value.largerOrEqual(argument); - case SMALLER: return value.smaller(argument); - case SMALLEREQUAL: return value.smallerOrEqual(argument); - case EQUAL: return value.equal(argument); - case NOTEQUAL: return value.notEqual(argument); - default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator); - } + public boolean compare(TruthOperator operator, Value value) { + throw new UnsupportedOperationException("A tensor cannot be compared with any value"); } @Override - public Value function(Function function, Value arg) { - if (arg instanceof TensorValue) - return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString()))); + public Value function(Function function, Value argument) { + if (function.equals(Function.min) && argument instanceof TensorValue) + return min(argument); + else if (function.equals(Function.max) && argument instanceof TensorValue) + return max(argument); else - return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble()))); - } - - private Tensor functionOnTensor(Function function, Tensor argument) { - switch (function) { - case min: return value.min(argument); - case max: return value.max(argument); - case atan2: return value.atan2(argument); - default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function); - } + return new TensorValue(value.apply((Double value) -> function.evaluate(value, argument.asDouble()))); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 8ce18265231..e5680edc68a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -42,7 +42,7 @@ public abstract class Value { public abstract Value divide(Value value); /** Perform the comparison specified by the operator between this value and the given value */ - public abstract Value compare(TruthOperator operator, Value value); + public abstract boolean compare(TruthOperator operator,Value value); /** Perform the given binary function on this value and the given value */ public abstract Value function(Function function,Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index af05acb365a..882d16ebc1c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -8,9 +8,10 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import java.util.*; /** - * A node which returns the outcome of a comparison. + * A node which returns true or false depending on the outcome of a comparison. * * @author bratseth + * @since 5.1.21 */ public class ComparisonNode extends BooleanNode { @@ -47,9 +48,9 @@ public class ComparisonNode extends BooleanNode { @Override public Value evaluate(Context context) { - Value leftValue = leftCondition.evaluate(context); - Value rightValue = rightCondition.evaluate(context); - return leftValue.compare(operator,rightValue); + Value leftValue=leftCondition.evaluate(context); + Value rightValue=rightCondition.evaluate(context); + return new BooleanValue(leftValue.compare(operator,rightValue)); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java index 19b1a83ed99..675ce758faa 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java @@ -12,38 +12,31 @@ import static java.lang.Math.*; */ public enum Function implements Serializable { - abs { public double evaluate(double x, double y) { return abs(x); } }, + cosh { public double evaluate(double x, double y) { return cosh(x); } }, + sinh { public double evaluate(double x, double y) { return sinh(x); } }, + tanh { public double evaluate(double x, double y) { return tanh(x); } }, + cos { public double evaluate(double x, double y) { return cos(x); } }, + sin { public double evaluate(double x, double y) { return sin(x); } }, + tan { public double evaluate(double x, double y) { return tan(x); } }, acos { public double evaluate(double x, double y) { return acos(x); } }, asin { public double evaluate(double x, double y) { return asin(x); } }, atan { public double evaluate(double x, double y) { return atan(x); } }, - ceil { public double evaluate(double x, double y) { return ceil(x); } }, - cos { public double evaluate(double x, double y) { return cos(x); } }, - cosh { public double evaluate(double x, double y) { return cosh(x); } }, - elu { public double evaluate(double x, double y) { return x<0 ? exp(x)-1 : x; } }, exp { public double evaluate(double x, double y) { return exp(x); } }, + log10 { public double evaluate(double x, double y) { return log10(x); } }, + log { public double evaluate(double x, double y) { return log(x); } }, + sqrt { public double evaluate(double x, double y) { return sqrt(x); } }, + ceil { public double evaluate(double x, double y) { return ceil(x); } }, fabs { public double evaluate(double x, double y) { return abs(x); } }, floor { public double evaluate(double x, double y) { return floor(x); } }, isNan { public double evaluate(double x, double y) { return Double.isNaN(x) ? 1.0 : 0.0; } }, - log { public double evaluate(double x, double y) { return log(x); } }, - log10 { public double evaluate(double x, double y) { return log10(x); } }, relu { public double evaluate(double x, double y) { return max(x,0); } }, - round { public double evaluate(double x, double y) { return round(x); } }, sigmoid { public double evaluate(double x, double y) { return 1.0 / (1.0 + exp(-1.0 * x)); } }, - sign { public double evaluate(double x, double y) { return x >= 0 ? 1 : -1; } }, - sin { public double evaluate(double x, double y) { return sin(x); } }, - sinh { public double evaluate(double x, double y) { return sinh(x); } }, - square { public double evaluate(double x, double y) { return x*x; } }, - sqrt { public double evaluate(double x, double y) { return sqrt(x); } }, - tan { public double evaluate(double x, double y) { return tan(x); } }, - tanh { public double evaluate(double x, double y) { return tanh(x); } }, - atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } }, - fmod(2) { public double evaluate(double x, double y) { return IEEEremainder(x,y); } }, + pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }, ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } }, - max(2) { public double evaluate(double x, double y) { return max(x,y); } }, + fmod(2) { public double evaluate(double x, double y) { return IEEEremainder(x,y); } }, min(2) { public double evaluate(double x, double y) { return min(x,y); } }, - mod(2) { public double evaluate(double x, double y) { return x % y; } }, - pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }; + max(2) { public double evaluate(double x, double y) { return max(x,y); } }; private final int arity; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java deleted file mode 100644 index 7b48288598d..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ /dev/null @@ -1,122 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.rule; - -import com.google.common.collect.ImmutableList; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; - -import java.util.Collections; -import java.util.Deque; -import java.util.List; -import java.util.function.DoubleBinaryOperator; -import java.util.function.DoubleUnaryOperator; - -/** - * A free, parametrized function - * - * @author bratseth - */ -public class LambdaFunctionNode extends CompositeNode { - - private final ImmutableList<String> arguments; - private final ExpressionNode functionExpression; - - public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) { - // TODO: Verify that the function only accesses the arguments in mapperVariables - this.arguments = ImmutableList.copyOf(arguments); - this.functionExpression = functionExpression; - } - - @Override - public List<ExpressionNode> children() { - return Collections.singletonList(functionExpression); - } - - @Override - public CompositeNode setChildren(List<ExpressionNode> children) { - if ( children.size() != 1) - throw new IllegalArgumentException("A lambda function must have a single child expression"); - return new LambdaFunctionNode(arguments, children.get(0)); - } - - @Override - public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { - return ("f(" + commaSeparated(arguments) + ")(" + functionExpression.toString(context, path, this)) + ")"; - } - - private String commaSeparated(List<String> list) { - StringBuilder b = new StringBuilder(); - for (String element : list) - b.append(element).append(","); - if (b.length() > 0) - b.setLength(b.length()-1); - return b.toString(); - } - - /** Evaluate this in a context which must have the arguments bound */ - @Override - public Value evaluate(Context context) { - return functionExpression.evaluate(context); - } - - /** - * Returns this as a double unary operator - * - * @throws IllegalStateException if this has more than one argument - */ - public DoubleUnaryOperator asDoubleUnaryOperator() { - if (arguments.size() > 1) - throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " + - "Must have at most one argument " + " but has " + arguments); - return new DoubleUnaryLambda(); - } - - /** - * Returns this as a double binary operator - * - * @throws IllegalStateException if this has more than two arguments - */ - public DoubleBinaryOperator asDoubleBinaryOperator() { - if (arguments.size() > 2) - throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " + - "Must have at most two argument " + " but has " + arguments); - return new DoubleBinaryLambda(); - } - - private class DoubleUnaryLambda implements DoubleUnaryOperator { - - @Override - public double applyAsDouble(double operand) { - MapContext context = new MapContext(); - if (arguments.size() > 0) - context.put(arguments.get(0), operand); - return evaluate(context).asDouble(); - } - - @Override - public String toString() { - return LambdaFunctionNode.this.toString(); - } - - } - - private class DoubleBinaryLambda implements DoubleBinaryOperator { - - @Override - public double applyAsDouble(double left, double right) { - MapContext context = new MapContext(); - if (arguments.size() > 0) - context.put(arguments.get(0), left); - if (arguments.size() > 1) - context.put(arguments.get(1), right); - return evaluate(context).asDouble(); - } - - @Override - public String toString() { - return LambdaFunctionNode.this.toString(); - } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java deleted file mode 100644 index 26d3f1dcc0e..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.google.common.annotations.Beta; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.functions.EvaluationContext; -import com.yahoo.tensor.functions.PrimitiveTensorFunction; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.tensor.functions.ToStringContext; - -import java.util.Collections; -import java.util.Deque; -import java.util.List; -import java.util.stream.Collectors; - -/** - * A node which performs a tensor function - * - * @author bratseth - */ - @Beta -public class TensorFunctionNode extends CompositeNode { - - private final TensorFunction function; - - public TensorFunctionNode(TensorFunction function) { - this.function = function; - } - - @Override - public List<ExpressionNode> children() { - return function.functionArguments().stream() - .map(f -> ((TensorFunctionExpressionNode)f).expression) - .collect(Collectors.toList()); - } - - @Override - public CompositeNode setChildren(List<ExpressionNode> children) { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { - // Serialize as primitive - return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this)); - } - - @Override - public Value evaluate(Context context) { - return new TensorValue(function.evaluate(context)); - } - - public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { - return new TensorFunctionExpressionNode(node); - } - - /** - * A tensor function implemented by an expression. - * This allows us to pass expressions as tensor function arguments. - */ - public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction { - - /** An expression which produces a tensor */ - private final ExpressionNode expression; - - public TensorFunctionExpressionNode(ExpressionNode expression) { - this.expression = expression; - } - - @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } - - @Override - public PrimitiveTensorFunction toPrimitive() { return this; } - - @Override - public Tensor evaluate(EvaluationContext context) { - Value result = expression.evaluate((Context)context); - if ( ! ( result instanceof TensorValue)) - throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " + - "but this returns " + result + ", not a tensor"); - return ((TensorValue)result).asTensor(); - } - - @Override - public String toString(ToStringContext c) { - ExpressionNodeToStringContext context = (ExpressionNodeToStringContext)c; - return expression.toString(context.context, context.path, context.parent); - } - - } - - /** Allows passing serialization context arguments through TensorFunctions */ - private static class ExpressionNodeToStringContext implements ToStringContext { - - final SerializationContext context; - final Deque<String> path; - final CompositeNode parent; - - public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { - this.context = context; - this.path = path; - this.parent = parent; - } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java new file mode 100644 index 00000000000..af309b3e8d8 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java @@ -0,0 +1,59 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.annotations.Beta; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; + +/** + * @author bratseth + */ + @Beta +public class TensorMatchNode extends CompositeNode { + + private final ExpressionNode left, right; + + public TensorMatchNode(ExpressionNode left, ExpressionNode right) { + this.left = left; + this.right = right; + } + + @Override + public List<ExpressionNode> children() { + List<ExpressionNode> children = new ArrayList<>(2); + children.add(left); + children.add(right); + return children; + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if ( children.size() != 2) + throw new IllegalArgumentException("A match product must have two children"); + return new TensorMatchNode(children.get(0), children.get(1)); + + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return "match(" + left.toString(context, path, parent) + ", " + right.toString(context, path, parent) + ")"; + } + + @Override + public Value evaluate(Context context) { + return asTensor(left.evaluate(context)).match(asTensor(right.evaluate(context))); + } + + private TensorValue asTensor(Value value) { + if ( ! (value instanceof TensorValue)) + throw new IllegalArgumentException("Attempted to take the tensor product with an argument which is " + + "not a tensor: " + value); + return (TensorValue)value; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java new file mode 100644 index 00000000000..a1f83157e20 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java @@ -0,0 +1,65 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.annotations.Beta; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.Optional; + +/** + * A node which sums over all cells in the argument tensor + * + * @author bratseth + */ + @Beta +public class TensorSumNode extends CompositeNode { + + /** The tensor to sum */ + private final ExpressionNode argument; + + /** The dimension to sum over, or empty to sum all cells to a scalar */ + private final Optional<String> dimension; + + public TensorSumNode(ExpressionNode argument, Optional<String> dimension) { + this.argument = argument; + this.dimension = dimension; + } + + @Override + public List<ExpressionNode> children() { + return Collections.singletonList(argument); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if (children.size() != 1) throw new IllegalArgumentException("A tensor sum node must have one tensor argument"); + return new TensorSumNode(children.get(0), dimension); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return "sum(" + + argument.toString(context, path, parent) + + ( dimension.isPresent() ? ", " + dimension.get() : "" ) + + ")"; + } + + @Override + public Value evaluate(Context context) { + Value argumentValue = argument.evaluate(context); + if ( ! ( argumentValue instanceof TensorValue)) + throw new IllegalArgumentException("Attempted to take the tensor sum of argument '" + argument + "', " + + "but this returns " + argumentValue + ", not a tensor"); + TensorValue tensorArgument = (TensorValue)argumentValue; + if (dimension.isPresent()) + return tensorArgument.sum(dimension.get()); + else + return tensorArgument.sum(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java index 932975f3b63..60fe19f909f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java @@ -15,8 +15,7 @@ public enum TruthOperator implements Serializable { EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } }, APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } }, LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } }, - LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }, - NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } }; + LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }; private final String operatorString; diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 0fcfdb5d40c..78ad665c414 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -21,9 +21,10 @@ import com.yahoo.searchlib.rankingexpression.rule.*; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.tensor.*; -import com.yahoo.tensor.functions.*; +import com.yahoo.tensor.MapTensor; +import com.yahoo.tensor.TensorAddress; import java.util.Collections; +import java.util.Map; import java.util.LinkedHashMap; import java.util.Arrays; import java.util.ArrayList; @@ -59,83 +60,51 @@ TOKEN : <RSQUARE: "]"> | <LCURLY: "{"> | <RCURLY: "}"> | - <ADD: "+"> | <SUB: "-"> | <DIV: "/"> | <MUL: "*"> | <DOT: "."> | - <DOLLAR: "$"> | <COMMA: ","> | <COLON: ":"> | - <LE: "<="> | <LT: "<"> | <EQ: "=="> | - <NQ: "!="> | <AQ: "~="> | <GE: ">="> | <GT: ">"> | - <STRING: ("\"" (~["\""] | "\\\"")* "\"") | ("'" (~["'"] | "\\'")* "'")> | - <IF: "if"> | - <IN: "in"> | - <F: "f"> | - - <ABS: "abs"> | + <COSH: "cosh"> | + <SINH: "sinh"> | + <TANH: "tanh"> | + <COS: "cos"> | + <SIN: "sin"> | + <TAN: "tan"> | <ACOS: "acos"> | <ASIN: "asin"> | + <ATAN2: "atan2"> | <ATAN: "atan"> | - <CEIL: "ceil"> | - <COS: "cos"> | - <COSH: "cosh"> | - <ELU: "elu"> | <EXP: "exp"> | + <LDEXP: "ldexp"> | + <LOG10: "log10"> | + <LOG: "log"> | + <POW: "pow"> | + <SQRT: "sqrt"> | + <CEIL: "ceil"> | <FABS: "fabs"> | <FLOOR: "floor"> | + <FMOD: "fmod"> | + <MIN: "min"> | + <MAX: "max"> | <ISNAN: "isNan"> | - <LOG: "log"> | - <LOG10: "log10"> | + <IN: "in"> | + <SUM: "sum"> | + <MATCH: "match"> | <RELU: "relu"> | - <ROUND: "round"> | <SIGMOID: "sigmoid"> | - <SIGN: "sign"> | - <SIN: "sin"> | - <SINH: "sinh"> | - <SQUARE: "square"> | - <SQRT: "sqrt"> | - <TAN: "tan"> | - <TANH: "tanh"> | - - <ATAN2: "atan2"> | - <FMOD: "fmod"> | - <LDEXP: "ldexp"> | - // MAX - // MIN - <MOD: "mod"> | - <POW: "pow"> | - - <MAP: "map"> | - <REDUCE: "reduce"> | - <JOIN: "join"> | - <RENAME: "rename"> | - <TENSOR: "tensor"> | - <L1_NORMALIZE: "l1_normalize"> | - <L2_NORMALIZE: "l2_normalize"> | - <MATMUL: "matmul"> | - <SOFTMAX: "softmax"> | - <XW_PLUS_B: "xw_plus_b"> | - - <AVG: "avg" > | - <COUNT: "count"> | - <PROD: "prod"> | - <SUM: "sum"> | - <MAX: "max"> | - <MIN: "min"> | - <IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)> } @@ -206,7 +175,6 @@ TruthOperator comparator() : { } ( <LE> { return TruthOperator.SMALLEREQUAL; } | <LT> { return TruthOperator.SMALLER; } | <EQ> { return TruthOperator.EQUAL; } | - <NQ> { return TruthOperator.NOTEQUAL; } | <AQ> { return TruthOperator.APPROX_EQUAL; } | <GE> { return TruthOperator.LARGEREQUAL; } | <GT> { return TruthOperator.LARGER; } ) @@ -221,6 +189,7 @@ ExpressionNode value() : { ( [ LOOKAHEAD(2) <SUB> { neg = true; } ] ( ret = constantPrimitive() | + ret = constantTensor() | LOOKAHEAD(2) ret = ifExpression() | LOOKAHEAD(2) ret = function() | ret = feature() | @@ -310,6 +279,7 @@ ExpressionNode arg() : } { ( ret = constantPrimitive() | + ret = constantTensor() | LOOKAHEAD(2) ret = feature() | name = identifier() { ret = new NameNode(name); } ) { return ret; } @@ -320,11 +290,11 @@ ExpressionNode function() : ExpressionNode function; } { - ( function = scalarOrTensorFunction() | function = tensorFunction() ) + ( function = scalarFunction() | function = tensorFunction() ) { return function; } } -FunctionNode scalarOrTensorFunction() : +FunctionNode scalarFunction() : { Function function; ExpressionNode arg1, arg2; @@ -342,223 +312,61 @@ FunctionNode scalarOrTensorFunction() : ExpressionNode tensorFunction() : { - ExpressionNode tensorExpression; -} -{ - ( - tensorExpression = tensorMap() | - tensorExpression = tensorReduce() | - tensorExpression = tensorReduceComposites() | - tensorExpression = tensorJoin() | - tensorExpression = tensorRename() | - tensorExpression = tensorGenerate() | - tensorExpression = tensorL1Normalize() | - tensorExpression = tensorL2Normalize() | - tensorExpression = tensorMatmul() | - tensorExpression = tensorSoftmax() | - tensorExpression = tensorXwPlusB() - ) - { return tensorExpression; } -} - -ExpressionNode tensorMap() : -{ - ExpressionNode tensor; - LambdaFunctionNode doubleMapper; -} -{ - <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE> - { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor), - doubleMapper.asDoubleUnaryOperator())); } -} - -ExpressionNode tensorReduce() : -{ - ExpressionNode tensor; - Reduce.Aggregator aggregator; - List<String> dimensions = null; -} -{ - <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } -} - -ExpressionNode tensorReduceComposites() : -{ - ExpressionNode tensor; - Reduce.Aggregator aggregator; - List<String> dimensions = null; -} -{ - aggregator = tensorReduceAggregator() - <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } -} - -ExpressionNode tensorJoin() : -{ ExpressionNode tensor1, tensor2; - LambdaFunctionNode doubleJoiner; + String dimension = null; + TensorAddress address = null; } { - <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE> - { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - doubleJoiner.asDoubleBinaryOperator())); } -} - -ExpressionNode tensorRename() : -{ - ExpressionNode tensor; - List<String> fromDimensions, toDimensions; -} -{ - <RENAME> <LBRACE> tensor = expression() <COMMA> - fromDimensions = bracedIdentifierList() <COMMA> - toDimensions = bracedIdentifierList() - <RBRACE> - { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); } -} - -// TODO: Notice that null is parsed below -ExpressionNode tensorGenerate() : -{ - TensorType type; - LambdaFunctionNode generator; -} -{ - <TENSOR> <LBRACE> <RBRACE> <LBRACE> - { return new TensorFunctionNode(new Generate(null, null)); } -} - -ExpressionNode tensorL1Normalize() : -{ - ExpressionNode tensor; - String dimension; -} -{ - <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } -} - -ExpressionNode tensorL2Normalize() : -{ - ExpressionNode tensor; - String dimension; -} -{ - <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } -} - -ExpressionNode tensorMatmul() : -{ - ExpressionNode tensor1, tensor2; - String dimension; -} -{ - <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - dimension)); } -} - -ExpressionNode tensorSoftmax() : -{ - ExpressionNode tensor; - String dimension; -} -{ - <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } -} - -ExpressionNode tensorXwPlusB() : -{ - ExpressionNode tensor1, tensor2, tensor3; - String dimension; -} -{ - <XW_PLUS_B> <LBRACE> tensor1 = expression() <COMMA> - tensor2 = expression() <COMMA> - tensor3 = expression() <COMMA> - dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - TensorFunctionNode.wrapArgument(tensor3), - dimension)); } -} - -LambdaFunctionNode lambdaFunction() : -{ - List<String> variables; - ExpressionNode functionExpression; -} -{ - ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> ) - { return new LambdaFunctionNode(variables, functionExpression); } -} - -Reduce.Aggregator tensorReduceAggregator() : -{ -} -{ - ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> ) - { return Reduce.Aggregator.valueOf(token.image); } + ( + <SUM> <LBRACE> tensor1 = expression() ( <COMMA> dimension = identifier() )? <RBRACE> + { return new TensorSumNode(tensor1, Optional.ofNullable(dimension)); } + ) | + ( + <MATCH> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <RBRACE> + { return new TensorMatchNode(tensor1, tensor2); } + ) } // This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge String tensorFunctionName() : { - Reduce.Aggregator aggregator; } { - ( <F> { return token.image; } ) | - ( <MAP> { return token.image; } ) | - ( <REDUCE> { return token.image; } ) | - ( <JOIN> { return token.image; } ) | - ( <RENAME> { return token.image; } ) | - ( <TENSOR> { return token.image; } ) | - ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) + ( <SUM> | <MATCH> ) + { return token.image; } } Function unaryFunctionName() : { } { - <ABS> { return Function.abs; } | + <COS> { return Function.cos; } | + <SIN> { return Function.sin; } | + <TAN> { return Function.tan; } | + <COSH> { return Function.cosh; } | + <SINH> { return Function.sinh; } | + <TANH> { return Function.tanh; } | <ACOS> { return Function.acos; } | <ASIN> { return Function.asin; } | <ATAN> { return Function.atan; } | - <CEIL> { return Function.ceil; } | - <COS> { return Function.cos; } | - <COSH> { return Function.cosh; } | - <ELU> { return Function.elu; } | <EXP> { return Function.exp; } | + <LOG10> { return Function.log10; } | + <LOG> { return Function.log; } | + <SQRT> { return Function.sqrt; } | + <CEIL> { return Function.ceil; } | <FABS> { return Function.fabs; } | <FLOOR> { return Function.floor; } | <ISNAN> { return Function.isNan; } | - <LOG> { return Function.log; } | - <LOG10> { return Function.log10; } | <RELU> { return Function.relu; } | - <ROUND> { return Function.round; } | - <SIGMOID> { return Function.sigmoid; } | - <SIGN> { return Function.sign; } | - <SIN> { return Function.sin; } | - <SINH> { return Function.sinh; } | - <SQUARE> { return Function.square; } | - <SQRT> { return Function.sqrt; } | - <TAN> { return Function.tan; } | - <TANH> { return Function.tanh; } + <SIGMOID> { return Function.sigmoid; } } Function binaryFunctionName() : { } { <ATAN2> { return Function.atan2; } | - <FMOD> { return Function.fmod; } | <LDEXP> { return Function.ldexp; } | - <MAX> { return Function.max; } | + <POW> { return Function.pow; } | + <FMOD> { return Function.fmod; } | <MIN> { return Function.min; } | - <MOD> { return Function.mod; } | - <POW> { return Function.pow; } + <MAX> { return Function.max; } } List<ExpressionNode> expressionList() : @@ -597,28 +405,6 @@ String identifier() : <IDENTIFIER> { return token.image; } } -List<String> identifierList() : -{ - List<String> list = new ArrayList<String>(); - String element; -} -{ - ( element = identifier() { list.add(element); } )? - ( <COMMA> element = identifier() { list.add(element); } ) * - { return list; } -} - -List<String> bracedIdentifierList() : -{ - List<String> list = new ArrayList<String>(); - String element; -} -{ - ( element = identifier() { return Collections.singletonList(element); } ) - | - ( <LBRACE> list = identifierList() <RBRACE> { return list; } ) -} - // An identifier or integer String tag() : { @@ -629,16 +415,6 @@ String tag() : <INTEGER> { return token.image; } } -List<String> tagCommaLeadingList() : -{ - List<String> list = new ArrayList<String>(); - String element; -} -{ - ( <COMMA> element = tag() { list.add(element); } ) * - { return list; } -} - ConstantNode constantPrimitive() : { String sign = ""; @@ -658,3 +434,50 @@ Value primitiveValue() : ( <INTEGER> | <FLOAT> | <STRING> ) { return Value.parse(sign + token.image); } } + +ConstantNode constantTensor() : +{ + Value constantValue; +} +{ + <LCURLY> constantValue = tensorContent() <RCURLY> + { return new ConstantNode(constantValue); } +} + +TensorValue tensorContent() : +{ + Map<TensorAddress, Double> cells = new LinkedHashMap<TensorAddress, Double>(); + TensorAddress address; + Double value; +} +{ + ( address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) ? + ( <COMMA> address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) * + { return new TensorValue(new MapTensor(cells)); } +} + +TensorAddress tensorAddress() : +{ + List<TensorAddress.Element> elements = new ArrayList<TensorAddress.Element>(); + String dimension; + String label; +} +{ + <LCURLY> + ( dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) ? + ( <COMMA> dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) * + <RCURLY> + { return TensorAddress.fromUnsorted(elements); } +} + +String label() : +{ + String label; + +} +{ + ( label = tag() | + ( "-" { label = "-"; } ) ) + { return label; } +} + |