diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2016-11-26 22:45:20 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-26 22:45:20 +0100 |
commit | 2f55986b4de9420e5728c5abbaafb69fb2f10a34 (patch) | |
tree | 9a6a77f76d25620771dfe7ab5de49910c4321fc5 /searchlib/src/main | |
parent | 2bc82ba9d9698214e703f19039387609d82b12f8 (diff) |
Revert "Revert "Bratseth/tensor functions 3""
Diffstat (limited to 'searchlib/src/main')
15 files changed, 583 insertions, 298 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 0dff0414ac2..620c6fad0b4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.tensor.functions.EvaluationContext; import java.util.Set; @@ -10,7 +11,7 @@ import java.util.Set; * * @author bratseth */ -public abstract class Context { +public abstract class Context implements EvaluationContext { /** * <p>Returns the value of a simple variable name.</p> @@ -41,7 +42,7 @@ public abstract class Context { * "main" (or only) value. */ public Value get(String name, Arguments arguments,String output) { - if (arguments!=null && arguments.expressions().size()>0) + if (arguments!=null && arguments.expressions().size() > 0) throw new UnsupportedOperationException(this + " does not support structured ranking expression variables, attempted to reference '" + name + arguments + "'"); if (output==null) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index 2bae382d5bd..f8dcd8a6127 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -39,8 +39,8 @@ public abstract class DoubleCompatibleValue extends Value { } @Override - public boolean compare(TruthOperator operator, Value value) { - return operator.evaluate(asDouble(), value.asDouble()); + public Value compare(TruthOperator operator, Value value) { + return new BooleanValue(operator.evaluate(asDouble(), value.asDouble())); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java index 028dad16d21..0e0d793bfd1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java @@ -98,16 +98,6 @@ public final class DoubleValue extends DoubleCompatibleValue { } @Override - public boolean compare(TruthOperator operator, Value value) { - try { - return operator.evaluate(this.value, value.asDouble()); - } - catch (UnsupportedOperationException e) { - throw unsupported("comparison",value); - } - } - - @Override public Value function(Function function, Value value) { // use the tensor implementation of max and min if the argument is a tensor if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index 9ee9a1f7a71..2dffe2a1100 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -34,11 +34,9 @@ public class MapContext extends Context { * Creates a map context from a map. * The ownership of the map is transferred to this - it cannot be further modified by the caller. * All the Values of the map will be frozen. - * - * @since 5.1.5 */ public MapContext(Map<String,Value> bindings) { - this.bindings=bindings; + this.bindings = bindings; for (Value boundValue : bindings.values()) boundValue.freeze(); } @@ -67,6 +65,9 @@ public class MapContext extends Context { if (frozen) return bindings; return Collections.unmodifiableMap(bindings); } + + /** Returns a new, modifiable context containing all the bindings of this */ + public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); } /** Returns an unmodifiable map of the names of this */ public @Override Set<String> names() { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index 379b5755c7b..eb997ab818a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -68,10 +68,10 @@ public class StringValue extends Value { } @Override - public boolean compare(TruthOperator operator, Value value) { + public Value compare(TruthOperator operator, Value value) { if (operator.equals(TruthOperator.EQUAL)) - return this.equals(value); - throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '='"); + return new BooleanValue(this.equals(value)); + throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='"); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 12bede95aae..b1f4a7b20ca 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.TensorType; +import java.util.Collections; import java.util.Optional; /** @@ -17,7 +18,7 @@ import java.util.Optional; * * @author bratseth */ - @Beta +@Beta public class TensorValue extends Value { /** The tensor value of this */ @@ -53,7 +54,7 @@ public class TensorValue extends Value { @Override public Value negate() { - return new TensorValue(value.apply((Double value) -> -value)); + return new TensorValue(value.map((value) -> -value)); } @Override @@ -61,7 +62,7 @@ public class TensorValue extends Value { if (argument instanceof TensorValue) return new TensorValue(value.add(((TensorValue)argument).value)); else - return new TensorValue(value.apply((Double value) -> value + argument.asDouble())); + return new TensorValue(value.map((value) -> value + argument.asDouble())); } @Override @@ -69,7 +70,7 @@ public class TensorValue extends Value { if (argument instanceof TensorValue) return new TensorValue(value.subtract(((TensorValue) argument).value)); else - return new TensorValue(value.apply((Double value) -> value - argument.asDouble())); + return new TensorValue(value.map((value) -> value - argument.asDouble())); } @Override @@ -77,35 +78,15 @@ public class TensorValue extends Value { if (argument instanceof TensorValue) return new TensorValue(value.multiply(((TensorValue) argument).value)); else - return new TensorValue(value.apply((Double value) -> value * argument.asDouble())); + return new TensorValue(value.map((value) -> value * argument.asDouble())); } @Override public Value divide(Value argument) { if (argument instanceof TensorValue) - throw new UnsupportedOperationException("Two tensors cannot be divided"); + return new TensorValue(value.divide(((TensorValue) argument).value)); else - return new TensorValue(value.apply((Double value) -> value / argument.asDouble())); - } - - public Value match(Value argument) { - return new TensorValue(value.match(asTensor(argument, "match"))); - } - - public Value min(Value argument) { - return new TensorValue(value.min(asTensor(argument, "min"))); - } - - public Value max(Value argument) { - return new TensorValue(value.max(asTensor(argument, "max"))); - } - - public Value sum(String dimension) { - return new TensorValue(value.sum(dimension)); - } - - public Value sum() { - return new DoubleValue(value.sum()); + return new TensorValue(value.map((value) -> value / argument.asDouble())); } private Tensor asTensor(Value value, String operationName) { @@ -122,18 +103,37 @@ public class TensorValue extends Value { } @Override - public boolean compare(TruthOperator operator, Value value) { - throw new UnsupportedOperationException("A tensor cannot be compared with any value"); + public Value compare(TruthOperator operator, Value argument) { + return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString()))); + } + + private Tensor compareTensor(TruthOperator operator, Tensor argument) { + switch (operator) { + case LARGER: return value.larger(argument); + case LARGEREQUAL: return value.largerOrEqual(argument); + case SMALLER: return value.smaller(argument); + case SMALLEREQUAL: return value.smallerOrEqual(argument); + case EQUAL: return value.equal(argument); + case NOTEQUAL: return value.notEqual(argument); + default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator); + } } @Override - public Value function(Function function, Value argument) { - if (function.equals(Function.min) && argument instanceof TensorValue) - return min(argument); - else if (function.equals(Function.max) && argument instanceof TensorValue) - return max(argument); + public Value function(Function function, Value arg) { + if (arg instanceof TensorValue) + return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString()))); else - return new TensorValue(value.apply((Double value) -> function.evaluate(value, argument.asDouble()))); + return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble()))); + } + + private Tensor functionOnTensor(Function function, Tensor argument) { + switch (function) { + case min: return value.min(argument); + case max: return value.max(argument); + case atan2: return value.atan2(argument); + default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function); + } } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index e5680edc68a..8ce18265231 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -42,7 +42,7 @@ public abstract class Value { public abstract Value divide(Value value); /** Perform the comparison specified by the operator between this value and the given value */ - public abstract boolean compare(TruthOperator operator,Value value); + public abstract Value compare(TruthOperator operator, Value value); /** Perform the given binary function on this value and the given value */ public abstract Value function(Function function,Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index 882d16ebc1c..af05acb365a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -8,10 +8,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import java.util.*; /** - * A node which returns true or false depending on the outcome of a comparison. + * A node which returns the outcome of a comparison. * * @author bratseth - * @since 5.1.21 */ public class ComparisonNode extends BooleanNode { @@ -48,9 +47,9 @@ public class ComparisonNode extends BooleanNode { @Override public Value evaluate(Context context) { - Value leftValue=leftCondition.evaluate(context); - Value rightValue=rightCondition.evaluate(context); - return new BooleanValue(leftValue.compare(operator,rightValue)); + Value leftValue = leftCondition.evaluate(context); + Value rightValue = rightCondition.evaluate(context); + return leftValue.compare(operator,rightValue); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java index 675ce758faa..19b1a83ed99 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java @@ -12,31 +12,38 @@ import static java.lang.Math.*; */ public enum Function implements Serializable { - cosh { public double evaluate(double x, double y) { return cosh(x); } }, - sinh { public double evaluate(double x, double y) { return sinh(x); } }, - tanh { public double evaluate(double x, double y) { return tanh(x); } }, - cos { public double evaluate(double x, double y) { return cos(x); } }, - sin { public double evaluate(double x, double y) { return sin(x); } }, - tan { public double evaluate(double x, double y) { return tan(x); } }, + abs { public double evaluate(double x, double y) { return abs(x); } }, acos { public double evaluate(double x, double y) { return acos(x); } }, asin { public double evaluate(double x, double y) { return asin(x); } }, atan { public double evaluate(double x, double y) { return atan(x); } }, - exp { public double evaluate(double x, double y) { return exp(x); } }, - log10 { public double evaluate(double x, double y) { return log10(x); } }, - log { public double evaluate(double x, double y) { return log(x); } }, - sqrt { public double evaluate(double x, double y) { return sqrt(x); } }, ceil { public double evaluate(double x, double y) { return ceil(x); } }, + cos { public double evaluate(double x, double y) { return cos(x); } }, + cosh { public double evaluate(double x, double y) { return cosh(x); } }, + elu { public double evaluate(double x, double y) { return x<0 ? exp(x)-1 : x; } }, + exp { public double evaluate(double x, double y) { return exp(x); } }, fabs { public double evaluate(double x, double y) { return abs(x); } }, floor { public double evaluate(double x, double y) { return floor(x); } }, isNan { public double evaluate(double x, double y) { return Double.isNaN(x) ? 1.0 : 0.0; } }, + log { public double evaluate(double x, double y) { return log(x); } }, + log10 { public double evaluate(double x, double y) { return log10(x); } }, relu { public double evaluate(double x, double y) { return max(x,0); } }, + round { public double evaluate(double x, double y) { return round(x); } }, sigmoid { public double evaluate(double x, double y) { return 1.0 / (1.0 + exp(-1.0 * x)); } }, + sign { public double evaluate(double x, double y) { return x >= 0 ? 1 : -1; } }, + sin { public double evaluate(double x, double y) { return sin(x); } }, + sinh { public double evaluate(double x, double y) { return sinh(x); } }, + square { public double evaluate(double x, double y) { return x*x; } }, + sqrt { public double evaluate(double x, double y) { return sqrt(x); } }, + tan { public double evaluate(double x, double y) { return tan(x); } }, + tanh { public double evaluate(double x, double y) { return tanh(x); } }, + atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } }, - pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }, - ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } }, fmod(2) { public double evaluate(double x, double y) { return IEEEremainder(x,y); } }, + ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } }, + max(2) { public double evaluate(double x, double y) { return max(x,y); } }, min(2) { public double evaluate(double x, double y) { return min(x,y); } }, - max(2) { public double evaluate(double x, double y) { return max(x,y); } }; + mod(2) { public double evaluate(double x, double y) { return x % y; } }, + pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }; private final int arity; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java new file mode 100644 index 00000000000..7b48288598d --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -0,0 +1,122 @@ +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; + +/** + * A free, parametrized function + * + * @author bratseth + */ +public class LambdaFunctionNode extends CompositeNode { + + private final ImmutableList<String> arguments; + private final ExpressionNode functionExpression; + + public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) { + // TODO: Verify that the function only accesses the arguments in mapperVariables + this.arguments = ImmutableList.copyOf(arguments); + this.functionExpression = functionExpression; + } + + @Override + public List<ExpressionNode> children() { + return Collections.singletonList(functionExpression); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if ( children.size() != 1) + throw new IllegalArgumentException("A lambda function must have a single child expression"); + return new LambdaFunctionNode(arguments, children.get(0)); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return ("f(" + commaSeparated(arguments) + ")(" + functionExpression.toString(context, path, this)) + ")"; + } + + private String commaSeparated(List<String> list) { + StringBuilder b = new StringBuilder(); + for (String element : list) + b.append(element).append(","); + if (b.length() > 0) + b.setLength(b.length()-1); + return b.toString(); + } + + /** Evaluate this in a context which must have the arguments bound */ + @Override + public Value evaluate(Context context) { + return functionExpression.evaluate(context); + } + + /** + * Returns this as a double unary operator + * + * @throws IllegalStateException if this has more than one argument + */ + public DoubleUnaryOperator asDoubleUnaryOperator() { + if (arguments.size() > 1) + throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " + + "Must have at most one argument " + " but has " + arguments); + return new DoubleUnaryLambda(); + } + + /** + * Returns this as a double binary operator + * + * @throws IllegalStateException if this has more than two arguments + */ + public DoubleBinaryOperator asDoubleBinaryOperator() { + if (arguments.size() > 2) + throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " + + "Must have at most two argument " + " but has " + arguments); + return new DoubleBinaryLambda(); + } + + private class DoubleUnaryLambda implements DoubleUnaryOperator { + + @Override + public double applyAsDouble(double operand) { + MapContext context = new MapContext(); + if (arguments.size() > 0) + context.put(arguments.get(0), operand); + return evaluate(context).asDouble(); + } + + @Override + public String toString() { + return LambdaFunctionNode.this.toString(); + } + + } + + private class DoubleBinaryLambda implements DoubleBinaryOperator { + + @Override + public double applyAsDouble(double left, double right) { + MapContext context = new MapContext(); + if (arguments.size() > 0) + context.put(arguments.get(0), left); + if (arguments.size() > 1) + context.put(arguments.get(1), right); + return evaluate(context).asDouble(); + } + + @Override + public String toString() { + return LambdaFunctionNode.this.toString(); + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java new file mode 100644 index 00000000000..26d3f1dcc0e --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -0,0 +1,111 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.annotations.Beta; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.EvaluationContext; +import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.functions.ToStringContext; + +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.stream.Collectors; + +/** + * A node which performs a tensor function + * + * @author bratseth + */ + @Beta +public class TensorFunctionNode extends CompositeNode { + + private final TensorFunction function; + + public TensorFunctionNode(TensorFunction function) { + this.function = function; + } + + @Override + public List<ExpressionNode> children() { + return function.functionArguments().stream() + .map(f -> ((TensorFunctionExpressionNode)f).expression) + .collect(Collectors.toList()); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + // Serialize as primitive + return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this)); + } + + @Override + public Value evaluate(Context context) { + return new TensorValue(function.evaluate(context)); + } + + public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { + return new TensorFunctionExpressionNode(node); + } + + /** + * A tensor function implemented by an expression. + * This allows us to pass expressions as tensor function arguments. + */ + public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction { + + /** An expression which produces a tensor */ + private final ExpressionNode expression; + + public TensorFunctionExpressionNode(ExpressionNode expression) { + this.expression = expression; + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public PrimitiveTensorFunction toPrimitive() { return this; } + + @Override + public Tensor evaluate(EvaluationContext context) { + Value result = expression.evaluate((Context)context); + if ( ! ( result instanceof TensorValue)) + throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " + + "but this returns " + result + ", not a tensor"); + return ((TensorValue)result).asTensor(); + } + + @Override + public String toString(ToStringContext c) { + ExpressionNodeToStringContext context = (ExpressionNodeToStringContext)c; + return expression.toString(context.context, context.path, context.parent); + } + + } + + /** Allows passing serialization context arguments through TensorFunctions */ + private static class ExpressionNodeToStringContext implements ToStringContext { + + final SerializationContext context; + final Deque<String> path; + final CompositeNode parent; + + public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { + this.context = context; + this.path = path; + this.parent = parent; + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java deleted file mode 100644 index af309b3e8d8..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.google.common.annotations.Beta; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; - -import java.util.ArrayList; -import java.util.Deque; -import java.util.List; - -/** - * @author bratseth - */ - @Beta -public class TensorMatchNode extends CompositeNode { - - private final ExpressionNode left, right; - - public TensorMatchNode(ExpressionNode left, ExpressionNode right) { - this.left = left; - this.right = right; - } - - @Override - public List<ExpressionNode> children() { - List<ExpressionNode> children = new ArrayList<>(2); - children.add(left); - children.add(right); - return children; - } - - @Override - public CompositeNode setChildren(List<ExpressionNode> children) { - if ( children.size() != 2) - throw new IllegalArgumentException("A match product must have two children"); - return new TensorMatchNode(children.get(0), children.get(1)); - - } - - @Override - public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { - return "match(" + left.toString(context, path, parent) + ", " + right.toString(context, path, parent) + ")"; - } - - @Override - public Value evaluate(Context context) { - return asTensor(left.evaluate(context)).match(asTensor(right.evaluate(context))); - } - - private TensorValue asTensor(Value value) { - if ( ! (value instanceof TensorValue)) - throw new IllegalArgumentException("Attempted to take the tensor product with an argument which is " + - "not a tensor: " + value); - return (TensorValue)value; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java deleted file mode 100644 index a1f83157e20..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.google.common.annotations.Beta; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; - -import java.util.Collections; -import java.util.Deque; -import java.util.List; -import java.util.Optional; - -/** - * A node which sums over all cells in the argument tensor - * - * @author bratseth - */ - @Beta -public class TensorSumNode extends CompositeNode { - - /** The tensor to sum */ - private final ExpressionNode argument; - - /** The dimension to sum over, or empty to sum all cells to a scalar */ - private final Optional<String> dimension; - - public TensorSumNode(ExpressionNode argument, Optional<String> dimension) { - this.argument = argument; - this.dimension = dimension; - } - - @Override - public List<ExpressionNode> children() { - return Collections.singletonList(argument); - } - - @Override - public CompositeNode setChildren(List<ExpressionNode> children) { - if (children.size() != 1) throw new IllegalArgumentException("A tensor sum node must have one tensor argument"); - return new TensorSumNode(children.get(0), dimension); - } - - @Override - public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { - return "sum(" + - argument.toString(context, path, parent) + - ( dimension.isPresent() ? ", " + dimension.get() : "" ) + - ")"; - } - - @Override - public Value evaluate(Context context) { - Value argumentValue = argument.evaluate(context); - if ( ! ( argumentValue instanceof TensorValue)) - throw new IllegalArgumentException("Attempted to take the tensor sum of argument '" + argument + "', " + - "but this returns " + argumentValue + ", not a tensor"); - TensorValue tensorArgument = (TensorValue)argumentValue; - if (dimension.isPresent()) - return tensorArgument.sum(dimension.get()); - else - return tensorArgument.sum(); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java index 60fe19f909f..932975f3b63 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java @@ -15,7 +15,8 @@ public enum TruthOperator implements Serializable { EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } }, APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } }, LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } }, - LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }; + LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }, + NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } }; private final String operatorString; diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 78ad665c414..0fcfdb5d40c 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -21,10 +21,9 @@ import com.yahoo.searchlib.rankingexpression.rule.*; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.tensor.MapTensor; -import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.*; +import com.yahoo.tensor.functions.*; import java.util.Collections; -import java.util.Map; import java.util.LinkedHashMap; import java.util.Arrays; import java.util.ArrayList; @@ -60,51 +59,83 @@ TOKEN : <RSQUARE: "]"> | <LCURLY: "{"> | <RCURLY: "}"> | + <ADD: "+"> | <SUB: "-"> | <DIV: "/"> | <MUL: "*"> | <DOT: "."> | + <DOLLAR: "$"> | <COMMA: ","> | <COLON: ":"> | + <LE: "<="> | <LT: "<"> | <EQ: "=="> | + <NQ: "!="> | <AQ: "~="> | <GE: ">="> | <GT: ">"> | + <STRING: ("\"" (~["\""] | "\\\"")* "\"") | ("'" (~["'"] | "\\'")* "'")> | + <IF: "if"> | - <COSH: "cosh"> | - <SINH: "sinh"> | - <TANH: "tanh"> | - <COS: "cos"> | - <SIN: "sin"> | - <TAN: "tan"> | + <IN: "in"> | + <F: "f"> | + + <ABS: "abs"> | <ACOS: "acos"> | <ASIN: "asin"> | - <ATAN2: "atan2"> | <ATAN: "atan"> | - <EXP: "exp"> | - <LDEXP: "ldexp"> | - <LOG10: "log10"> | - <LOG: "log"> | - <POW: "pow"> | - <SQRT: "sqrt"> | <CEIL: "ceil"> | + <COS: "cos"> | + <COSH: "cosh"> | + <ELU: "elu"> | + <EXP: "exp"> | <FABS: "fabs"> | <FLOOR: "floor"> | - <FMOD: "fmod"> | - <MIN: "min"> | - <MAX: "max"> | <ISNAN: "isNan"> | - <IN: "in"> | - <SUM: "sum"> | - <MATCH: "match"> | + <LOG: "log"> | + <LOG10: "log10"> | <RELU: "relu"> | + <ROUND: "round"> | <SIGMOID: "sigmoid"> | + <SIGN: "sign"> | + <SIN: "sin"> | + <SINH: "sinh"> | + <SQUARE: "square"> | + <SQRT: "sqrt"> | + <TAN: "tan"> | + <TANH: "tanh"> | + + <ATAN2: "atan2"> | + <FMOD: "fmod"> | + <LDEXP: "ldexp"> | + // MAX + // MIN + <MOD: "mod"> | + <POW: "pow"> | + + <MAP: "map"> | + <REDUCE: "reduce"> | + <JOIN: "join"> | + <RENAME: "rename"> | + <TENSOR: "tensor"> | + <L1_NORMALIZE: "l1_normalize"> | + <L2_NORMALIZE: "l2_normalize"> | + <MATMUL: "matmul"> | + <SOFTMAX: "softmax"> | + <XW_PLUS_B: "xw_plus_b"> | + + <AVG: "avg" > | + <COUNT: "count"> | + <PROD: "prod"> | + <SUM: "sum"> | + <MAX: "max"> | + <MIN: "min"> | + <IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)> } @@ -175,6 +206,7 @@ TruthOperator comparator() : { } ( <LE> { return TruthOperator.SMALLEREQUAL; } | <LT> { return TruthOperator.SMALLER; } | <EQ> { return TruthOperator.EQUAL; } | + <NQ> { return TruthOperator.NOTEQUAL; } | <AQ> { return TruthOperator.APPROX_EQUAL; } | <GE> { return TruthOperator.LARGEREQUAL; } | <GT> { return TruthOperator.LARGER; } ) @@ -189,7 +221,6 @@ ExpressionNode value() : { ( [ LOOKAHEAD(2) <SUB> { neg = true; } ] ( ret = constantPrimitive() | - ret = constantTensor() | LOOKAHEAD(2) ret = ifExpression() | LOOKAHEAD(2) ret = function() | ret = feature() | @@ -279,7 +310,6 @@ ExpressionNode arg() : } { ( ret = constantPrimitive() | - ret = constantTensor() | LOOKAHEAD(2) ret = feature() | name = identifier() { ret = new NameNode(name); } ) { return ret; } @@ -290,11 +320,11 @@ ExpressionNode function() : ExpressionNode function; } { - ( function = scalarFunction() | function = tensorFunction() ) + ( function = scalarOrTensorFunction() | function = tensorFunction() ) { return function; } } -FunctionNode scalarFunction() : +FunctionNode scalarOrTensorFunction() : { Function function; ExpressionNode arg1, arg2; @@ -312,61 +342,223 @@ FunctionNode scalarFunction() : ExpressionNode tensorFunction() : { + ExpressionNode tensorExpression; +} +{ + ( + tensorExpression = tensorMap() | + tensorExpression = tensorReduce() | + tensorExpression = tensorReduceComposites() | + tensorExpression = tensorJoin() | + tensorExpression = tensorRename() | + tensorExpression = tensorGenerate() | + tensorExpression = tensorL1Normalize() | + tensorExpression = tensorL2Normalize() | + tensorExpression = tensorMatmul() | + tensorExpression = tensorSoftmax() | + tensorExpression = tensorXwPlusB() + ) + { return tensorExpression; } +} + +ExpressionNode tensorMap() : +{ + ExpressionNode tensor; + LambdaFunctionNode doubleMapper; +} +{ + <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE> + { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor), + doubleMapper.asDoubleUnaryOperator())); } +} + +ExpressionNode tensorReduce() : +{ + ExpressionNode tensor; + Reduce.Aggregator aggregator; + List<String> dimensions = null; +} +{ + <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } +} + +ExpressionNode tensorReduceComposites() : +{ + ExpressionNode tensor; + Reduce.Aggregator aggregator; + List<String> dimensions = null; +} +{ + aggregator = tensorReduceAggregator() + <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE> + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } +} + +ExpressionNode tensorJoin() : +{ ExpressionNode tensor1, tensor2; - String dimension = null; - TensorAddress address = null; + LambdaFunctionNode doubleJoiner; } { - ( - <SUM> <LBRACE> tensor1 = expression() ( <COMMA> dimension = identifier() )? <RBRACE> - { return new TensorSumNode(tensor1, Optional.ofNullable(dimension)); } - ) | - ( - <MATCH> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <RBRACE> - { return new TensorMatchNode(tensor1, tensor2); } - ) + <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE> + { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + doubleJoiner.asDoubleBinaryOperator())); } +} + +ExpressionNode tensorRename() : +{ + ExpressionNode tensor; + List<String> fromDimensions, toDimensions; +} +{ + <RENAME> <LBRACE> tensor = expression() <COMMA> + fromDimensions = bracedIdentifierList() <COMMA> + toDimensions = bracedIdentifierList() + <RBRACE> + { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); } +} + +// TODO: Notice that null is parsed below +ExpressionNode tensorGenerate() : +{ + TensorType type; + LambdaFunctionNode generator; +} +{ + <TENSOR> <LBRACE> <RBRACE> <LBRACE> + { return new TensorFunctionNode(new Generate(null, null)); } +} + +ExpressionNode tensorL1Normalize() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorL2Normalize() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorMatmul() : +{ + ExpressionNode tensor1, tensor2; + String dimension; +} +{ + <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + dimension)); } +} + +ExpressionNode tensorSoftmax() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorXwPlusB() : +{ + ExpressionNode tensor1, tensor2, tensor3; + String dimension; +} +{ + <XW_PLUS_B> <LBRACE> tensor1 = expression() <COMMA> + tensor2 = expression() <COMMA> + tensor3 = expression() <COMMA> + dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + TensorFunctionNode.wrapArgument(tensor3), + dimension)); } +} + +LambdaFunctionNode lambdaFunction() : +{ + List<String> variables; + ExpressionNode functionExpression; +} +{ + ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> ) + { return new LambdaFunctionNode(variables, functionExpression); } +} + +Reduce.Aggregator tensorReduceAggregator() : +{ +} +{ + ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> ) + { return Reduce.Aggregator.valueOf(token.image); } } // This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge String tensorFunctionName() : { + Reduce.Aggregator aggregator; } { - ( <SUM> | <MATCH> ) - { return token.image; } + ( <F> { return token.image; } ) | + ( <MAP> { return token.image; } ) | + ( <REDUCE> { return token.image; } ) | + ( <JOIN> { return token.image; } ) | + ( <RENAME> { return token.image; } ) | + ( <TENSOR> { return token.image; } ) | + ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } Function unaryFunctionName() : { } { - <COS> { return Function.cos; } | - <SIN> { return Function.sin; } | - <TAN> { return Function.tan; } | - <COSH> { return Function.cosh; } | - <SINH> { return Function.sinh; } | - <TANH> { return Function.tanh; } | + <ABS> { return Function.abs; } | <ACOS> { return Function.acos; } | <ASIN> { return Function.asin; } | <ATAN> { return Function.atan; } | - <EXP> { return Function.exp; } | - <LOG10> { return Function.log10; } | - <LOG> { return Function.log; } | - <SQRT> { return Function.sqrt; } | <CEIL> { return Function.ceil; } | + <COS> { return Function.cos; } | + <COSH> { return Function.cosh; } | + <ELU> { return Function.elu; } | + <EXP> { return Function.exp; } | <FABS> { return Function.fabs; } | <FLOOR> { return Function.floor; } | <ISNAN> { return Function.isNan; } | + <LOG> { return Function.log; } | + <LOG10> { return Function.log10; } | <RELU> { return Function.relu; } | - <SIGMOID> { return Function.sigmoid; } + <ROUND> { return Function.round; } | + <SIGMOID> { return Function.sigmoid; } | + <SIGN> { return Function.sign; } | + <SIN> { return Function.sin; } | + <SINH> { return Function.sinh; } | + <SQUARE> { return Function.square; } | + <SQRT> { return Function.sqrt; } | + <TAN> { return Function.tan; } | + <TANH> { return Function.tanh; } } Function binaryFunctionName() : { } { <ATAN2> { return Function.atan2; } | - <LDEXP> { return Function.ldexp; } | - <POW> { return Function.pow; } | <FMOD> { return Function.fmod; } | + <LDEXP> { return Function.ldexp; } | + <MAX> { return Function.max; } | <MIN> { return Function.min; } | - <MAX> { return Function.max; } + <MOD> { return Function.mod; } | + <POW> { return Function.pow; } } List<ExpressionNode> expressionList() : @@ -405,79 +597,64 @@ String identifier() : <IDENTIFIER> { return token.image; } } -// An identifier or integer -String tag() : -{ - String name; -} -{ - name = identifier() { return name; } | - <INTEGER> { return token.image; } -} - -ConstantNode constantPrimitive() : +List<String> identifierList() : { - String sign = ""; + List<String> list = new ArrayList<String>(); + String element; } { - ( <SUB> { sign = "-";} ) ? - ( <INTEGER> | <FLOAT> | <STRING> ) - { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); } + ( element = identifier() { list.add(element); } )? + ( <COMMA> element = identifier() { list.add(element); } ) * + { return list; } } -Value primitiveValue() : +List<String> bracedIdentifierList() : { - String sign = ""; + List<String> list = new ArrayList<String>(); + String element; } { - ( <SUB> { sign = "-";} ) ? - ( <INTEGER> | <FLOAT> | <STRING> ) - { return Value.parse(sign + token.image); } + ( element = identifier() { return Collections.singletonList(element); } ) + | + ( <LBRACE> list = identifierList() <RBRACE> { return list; } ) } -ConstantNode constantTensor() : +// An identifier or integer +String tag() : { - Value constantValue; + String name; } { - <LCURLY> constantValue = tensorContent() <RCURLY> - { return new ConstantNode(constantValue); } + name = identifier() { return name; } | + <INTEGER> { return token.image; } } -TensorValue tensorContent() : +List<String> tagCommaLeadingList() : { - Map<TensorAddress, Double> cells = new LinkedHashMap<TensorAddress, Double>(); - TensorAddress address; - Double value; + List<String> list = new ArrayList<String>(); + String element; } { - ( address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) ? - ( <COMMA> address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) * - { return new TensorValue(new MapTensor(cells)); } + ( <COMMA> element = tag() { list.add(element); } ) * + { return list; } } -TensorAddress tensorAddress() : +ConstantNode constantPrimitive() : { - List<TensorAddress.Element> elements = new ArrayList<TensorAddress.Element>(); - String dimension; - String label; + String sign = ""; } { - <LCURLY> - ( dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) ? - ( <COMMA> dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) * - <RCURLY> - { return TensorAddress.fromUnsorted(elements); } + ( <SUB> { sign = "-";} ) ? + ( <INTEGER> | <FLOAT> | <STRING> ) + { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); } } -String label() : +Value primitiveValue() : { - String label; - + String sign = ""; } { - ( label = tag() | - ( "-" { label = "-"; } ) ) - { return label; } + ( <SUB> { sign = "-";} ) ? + ( <INTEGER> | <FLOAT> | <STRING> ) + { return Value.parse(sign + token.image); } } - |