diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2016-11-26 22:45:20 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-26 22:45:20 +0100 |
commit | 2f55986b4de9420e5728c5abbaafb69fb2f10a34 (patch) | |
tree | 9a6a77f76d25620771dfe7ab5de49910c4321fc5 | |
parent | 2bc82ba9d9698214e703f19039387609d82b12f8 (diff) |
Revert "Revert "Bratseth/tensor functions 3""
56 files changed, 1905 insertions, 1091 deletions
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 206ab8e30f0..64bb538eab5 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1049,7 +1049,7 @@ public class JsonReaderTestCase { @Test public void testParsingOfTensorWithDimensions() { - assertTensorField("( {{x:-,y:-}:1.0} * {} )", + assertTensorField("tensor(x{},y{}):{}", createPutWithTensor("{ " + " \"dimensions\": [\"x\",\"y\"] " + "}")); @@ -1101,7 +1101,7 @@ public class JsonReaderTestCase { @Test public void testParsingOfTensorWithDimensionsAndCells() { - assertTensorField("( {{z:-}:1.0} * {{x:a,y:b}:2.0,{x:c}:3.0} )", + assertTensorField("tensor(x{},y{},z{}):{{x:a,y:b}:2.0,{x:c}:3.0}", createPutWithTensor("{ " + " \"dimensions\": [\"x\",\"y\",\"z\"], " + " \"cells\": [ " @@ -1115,7 +1115,7 @@ public class JsonReaderTestCase { @Test public void testParsingOfTensorWithDimensionsAndCellsInDifferentJsonOrder() { - assertTensorField("( {{z:-}:1.0} * {{x:a,y:b}:2.0,{x:c}:3.0} )", + assertTensorField("tensor(x{},y{},z{}):{{x:a,y:b}:2.0,{x:c}:3.0}", createPutWithTensor("{ " + " \"cells\": [ " + " { \"address\": { \"x\": \"a\", \"y\": \"b\" }, " diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java index ba06843f178..252d40b7291 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java @@ -69,9 +69,9 @@ public class RestApiTest { // POST new nodes assertResponse(new Request("http://localhost:8080/nodes/v2/node", ("[" + asNodeJson("host8.yahoo.com", "default") + "," + - asNodeJson("host9.yahoo.com", "large-variant") + "," + - asHostJson("parent2.yahoo.com", "large-variant") + "," + - asDockerNodeJson("host11.yahoo.com", "parent.host.yahoo.com") + "]"). + asNodeJson("host9.yahoo.com", "large-variant") + "," + + asHostJson("parent2.yahoo.com", "large-variant") + "," + + asDockerNodeJson("host11.yahoo.com", "parent.host.yahoo.com") + "]"). getBytes(StandardCharsets.UTF_8), Request.Method.POST), "{\"message\":\"Added 4 nodes to the provisioned state\"}"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 0dff0414ac2..620c6fad0b4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.tensor.functions.EvaluationContext; import java.util.Set; @@ -10,7 +11,7 @@ import java.util.Set; * * @author bratseth */ -public abstract class Context { +public abstract class Context implements EvaluationContext { /** * <p>Returns the value of a simple variable name.</p> @@ -41,7 +42,7 @@ public abstract class Context { * "main" (or only) value. */ public Value get(String name, Arguments arguments,String output) { - if (arguments!=null && arguments.expressions().size()>0) + if (arguments!=null && arguments.expressions().size() > 0) throw new UnsupportedOperationException(this + " does not support structured ranking expression variables, attempted to reference '" + name + arguments + "'"); if (output==null) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index 2bae382d5bd..f8dcd8a6127 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -39,8 +39,8 @@ public abstract class DoubleCompatibleValue extends Value { } @Override - public boolean compare(TruthOperator operator, Value value) { - return operator.evaluate(asDouble(), value.asDouble()); + public Value compare(TruthOperator operator, Value value) { + return new BooleanValue(operator.evaluate(asDouble(), value.asDouble())); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java index 028dad16d21..0e0d793bfd1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java @@ -98,16 +98,6 @@ public final class DoubleValue extends DoubleCompatibleValue { } @Override - public boolean compare(TruthOperator operator, Value value) { - try { - return operator.evaluate(this.value, value.asDouble()); - } - catch (UnsupportedOperationException e) { - throw unsupported("comparison",value); - } - } - - @Override public Value function(Function function, Value value) { // use the tensor implementation of max and min if the argument is a tensor if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index 9ee9a1f7a71..2dffe2a1100 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -34,11 +34,9 @@ public class MapContext extends Context { * Creates a map context from a map. * The ownership of the map is transferred to this - it cannot be further modified by the caller. * All the Values of the map will be frozen. - * - * @since 5.1.5 */ public MapContext(Map<String,Value> bindings) { - this.bindings=bindings; + this.bindings = bindings; for (Value boundValue : bindings.values()) boundValue.freeze(); } @@ -67,6 +65,9 @@ public class MapContext extends Context { if (frozen) return bindings; return Collections.unmodifiableMap(bindings); } + + /** Returns a new, modifiable context containing all the bindings of this */ + public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); } /** Returns an unmodifiable map of the names of this */ public @Override Set<String> names() { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index 379b5755c7b..eb997ab818a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -68,10 +68,10 @@ public class StringValue extends Value { } @Override - public boolean compare(TruthOperator operator, Value value) { + public Value compare(TruthOperator operator, Value value) { if (operator.equals(TruthOperator.EQUAL)) - return this.equals(value); - throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '='"); + return new BooleanValue(this.equals(value)); + throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='"); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 12bede95aae..b1f4a7b20ca 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.TensorType; +import java.util.Collections; import java.util.Optional; /** @@ -17,7 +18,7 @@ import java.util.Optional; * * @author bratseth */ - @Beta +@Beta public class TensorValue extends Value { /** The tensor value of this */ @@ -53,7 +54,7 @@ public class TensorValue extends Value { @Override public Value negate() { - return new TensorValue(value.apply((Double value) -> -value)); + return new TensorValue(value.map((value) -> -value)); } @Override @@ -61,7 +62,7 @@ public class TensorValue extends Value { if (argument instanceof TensorValue) return new TensorValue(value.add(((TensorValue)argument).value)); else - return new TensorValue(value.apply((Double value) -> value + argument.asDouble())); + return new TensorValue(value.map((value) -> value + argument.asDouble())); } @Override @@ -69,7 +70,7 @@ public class TensorValue extends Value { if (argument instanceof TensorValue) return new TensorValue(value.subtract(((TensorValue) argument).value)); else - return new TensorValue(value.apply((Double value) -> value - argument.asDouble())); + return new TensorValue(value.map((value) -> value - argument.asDouble())); } @Override @@ -77,35 +78,15 @@ public class TensorValue extends Value { if (argument instanceof TensorValue) return new TensorValue(value.multiply(((TensorValue) argument).value)); else - return new TensorValue(value.apply((Double value) -> value * argument.asDouble())); + return new TensorValue(value.map((value) -> value * argument.asDouble())); } @Override public Value divide(Value argument) { if (argument instanceof TensorValue) - throw new UnsupportedOperationException("Two tensors cannot be divided"); + return new TensorValue(value.divide(((TensorValue) argument).value)); else - return new TensorValue(value.apply((Double value) -> value / argument.asDouble())); - } - - public Value match(Value argument) { - return new TensorValue(value.match(asTensor(argument, "match"))); - } - - public Value min(Value argument) { - return new TensorValue(value.min(asTensor(argument, "min"))); - } - - public Value max(Value argument) { - return new TensorValue(value.max(asTensor(argument, "max"))); - } - - public Value sum(String dimension) { - return new TensorValue(value.sum(dimension)); - } - - public Value sum() { - return new DoubleValue(value.sum()); + return new TensorValue(value.map((value) -> value / argument.asDouble())); } private Tensor asTensor(Value value, String operationName) { @@ -122,18 +103,37 @@ public class TensorValue extends Value { } @Override - public boolean compare(TruthOperator operator, Value value) { - throw new UnsupportedOperationException("A tensor cannot be compared with any value"); + public Value compare(TruthOperator operator, Value argument) { + return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString()))); + } + + private Tensor compareTensor(TruthOperator operator, Tensor argument) { + switch (operator) { + case LARGER: return value.larger(argument); + case LARGEREQUAL: return value.largerOrEqual(argument); + case SMALLER: return value.smaller(argument); + case SMALLEREQUAL: return value.smallerOrEqual(argument); + case EQUAL: return value.equal(argument); + case NOTEQUAL: return value.notEqual(argument); + default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator); + } } @Override - public Value function(Function function, Value argument) { - if (function.equals(Function.min) && argument instanceof TensorValue) - return min(argument); - else if (function.equals(Function.max) && argument instanceof TensorValue) - return max(argument); + public Value function(Function function, Value arg) { + if (arg instanceof TensorValue) + return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString()))); else - return new TensorValue(value.apply((Double value) -> function.evaluate(value, argument.asDouble()))); + return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble()))); + } + + private Tensor functionOnTensor(Function function, Tensor argument) { + switch (function) { + case min: return value.min(argument); + case max: return value.max(argument); + case atan2: return value.atan2(argument); + default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function); + } } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index e5680edc68a..8ce18265231 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -42,7 +42,7 @@ public abstract class Value { public abstract Value divide(Value value); /** Perform the comparison specified by the operator between this value and the given value */ - public abstract boolean compare(TruthOperator operator,Value value); + public abstract Value compare(TruthOperator operator, Value value); /** Perform the given binary function on this value and the given value */ public abstract Value function(Function function,Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index 882d16ebc1c..af05acb365a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -8,10 +8,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import java.util.*; /** - * A node which returns true or false depending on the outcome of a comparison. + * A node which returns the outcome of a comparison. * * @author bratseth - * @since 5.1.21 */ public class ComparisonNode extends BooleanNode { @@ -48,9 +47,9 @@ public class ComparisonNode extends BooleanNode { @Override public Value evaluate(Context context) { - Value leftValue=leftCondition.evaluate(context); - Value rightValue=rightCondition.evaluate(context); - return new BooleanValue(leftValue.compare(operator,rightValue)); + Value leftValue = leftCondition.evaluate(context); + Value rightValue = rightCondition.evaluate(context); + return leftValue.compare(operator,rightValue); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java index 675ce758faa..19b1a83ed99 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java @@ -12,31 +12,38 @@ import static java.lang.Math.*; */ public enum Function implements Serializable { - cosh { public double evaluate(double x, double y) { return cosh(x); } }, - sinh { public double evaluate(double x, double y) { return sinh(x); } }, - tanh { public double evaluate(double x, double y) { return tanh(x); } }, - cos { public double evaluate(double x, double y) { return cos(x); } }, - sin { public double evaluate(double x, double y) { return sin(x); } }, - tan { public double evaluate(double x, double y) { return tan(x); } }, + abs { public double evaluate(double x, double y) { return abs(x); } }, acos { public double evaluate(double x, double y) { return acos(x); } }, asin { public double evaluate(double x, double y) { return asin(x); } }, atan { public double evaluate(double x, double y) { return atan(x); } }, - exp { public double evaluate(double x, double y) { return exp(x); } }, - log10 { public double evaluate(double x, double y) { return log10(x); } }, - log { public double evaluate(double x, double y) { return log(x); } }, - sqrt { public double evaluate(double x, double y) { return sqrt(x); } }, ceil { public double evaluate(double x, double y) { return ceil(x); } }, + cos { public double evaluate(double x, double y) { return cos(x); } }, + cosh { public double evaluate(double x, double y) { return cosh(x); } }, + elu { public double evaluate(double x, double y) { return x<0 ? exp(x)-1 : x; } }, + exp { public double evaluate(double x, double y) { return exp(x); } }, fabs { public double evaluate(double x, double y) { return abs(x); } }, floor { public double evaluate(double x, double y) { return floor(x); } }, isNan { public double evaluate(double x, double y) { return Double.isNaN(x) ? 1.0 : 0.0; } }, + log { public double evaluate(double x, double y) { return log(x); } }, + log10 { public double evaluate(double x, double y) { return log10(x); } }, relu { public double evaluate(double x, double y) { return max(x,0); } }, + round { public double evaluate(double x, double y) { return round(x); } }, sigmoid { public double evaluate(double x, double y) { return 1.0 / (1.0 + exp(-1.0 * x)); } }, + sign { public double evaluate(double x, double y) { return x >= 0 ? 1 : -1; } }, + sin { public double evaluate(double x, double y) { return sin(x); } }, + sinh { public double evaluate(double x, double y) { return sinh(x); } }, + square { public double evaluate(double x, double y) { return x*x; } }, + sqrt { public double evaluate(double x, double y) { return sqrt(x); } }, + tan { public double evaluate(double x, double y) { return tan(x); } }, + tanh { public double evaluate(double x, double y) { return tanh(x); } }, + atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } }, - pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }, - ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } }, fmod(2) { public double evaluate(double x, double y) { return IEEEremainder(x,y); } }, + ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } }, + max(2) { public double evaluate(double x, double y) { return max(x,y); } }, min(2) { public double evaluate(double x, double y) { return min(x,y); } }, - max(2) { public double evaluate(double x, double y) { return max(x,y); } }; + mod(2) { public double evaluate(double x, double y) { return x % y; } }, + pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }; private final int arity; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java new file mode 100644 index 00000000000..7b48288598d --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -0,0 +1,122 @@ +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; + +/** + * A free, parametrized function + * + * @author bratseth + */ +public class LambdaFunctionNode extends CompositeNode { + + private final ImmutableList<String> arguments; + private final ExpressionNode functionExpression; + + public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) { + // TODO: Verify that the function only accesses the arguments in mapperVariables + this.arguments = ImmutableList.copyOf(arguments); + this.functionExpression = functionExpression; + } + + @Override + public List<ExpressionNode> children() { + return Collections.singletonList(functionExpression); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if ( children.size() != 1) + throw new IllegalArgumentException("A lambda function must have a single child expression"); + return new LambdaFunctionNode(arguments, children.get(0)); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return ("f(" + commaSeparated(arguments) + ")(" + functionExpression.toString(context, path, this)) + ")"; + } + + private String commaSeparated(List<String> list) { + StringBuilder b = new StringBuilder(); + for (String element : list) + b.append(element).append(","); + if (b.length() > 0) + b.setLength(b.length()-1); + return b.toString(); + } + + /** Evaluate this in a context which must have the arguments bound */ + @Override + public Value evaluate(Context context) { + return functionExpression.evaluate(context); + } + + /** + * Returns this as a double unary operator + * + * @throws IllegalStateException if this has more than one argument + */ + public DoubleUnaryOperator asDoubleUnaryOperator() { + if (arguments.size() > 1) + throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " + + "Must have at most one argument " + " but has " + arguments); + return new DoubleUnaryLambda(); + } + + /** + * Returns this as a double binary operator + * + * @throws IllegalStateException if this has more than two arguments + */ + public DoubleBinaryOperator asDoubleBinaryOperator() { + if (arguments.size() > 2) + throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " + + "Must have at most two argument " + " but has " + arguments); + return new DoubleBinaryLambda(); + } + + private class DoubleUnaryLambda implements DoubleUnaryOperator { + + @Override + public double applyAsDouble(double operand) { + MapContext context = new MapContext(); + if (arguments.size() > 0) + context.put(arguments.get(0), operand); + return evaluate(context).asDouble(); + } + + @Override + public String toString() { + return LambdaFunctionNode.this.toString(); + } + + } + + private class DoubleBinaryLambda implements DoubleBinaryOperator { + + @Override + public double applyAsDouble(double left, double right) { + MapContext context = new MapContext(); + if (arguments.size() > 0) + context.put(arguments.get(0), left); + if (arguments.size() > 1) + context.put(arguments.get(1), right); + return evaluate(context).asDouble(); + } + + @Override + public String toString() { + return LambdaFunctionNode.this.toString(); + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java new file mode 100644 index 00000000000..26d3f1dcc0e --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -0,0 +1,111 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.annotations.Beta; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.EvaluationContext; +import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.functions.ToStringContext; + +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.stream.Collectors; + +/** + * A node which performs a tensor function + * + * @author bratseth + */ + @Beta +public class TensorFunctionNode extends CompositeNode { + + private final TensorFunction function; + + public TensorFunctionNode(TensorFunction function) { + this.function = function; + } + + @Override + public List<ExpressionNode> children() { + return function.functionArguments().stream() + .map(f -> ((TensorFunctionExpressionNode)f).expression) + .collect(Collectors.toList()); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + // Serialize as primitive + return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this)); + } + + @Override + public Value evaluate(Context context) { + return new TensorValue(function.evaluate(context)); + } + + public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { + return new TensorFunctionExpressionNode(node); + } + + /** + * A tensor function implemented by an expression. + * This allows us to pass expressions as tensor function arguments. + */ + public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction { + + /** An expression which produces a tensor */ + private final ExpressionNode expression; + + public TensorFunctionExpressionNode(ExpressionNode expression) { + this.expression = expression; + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public PrimitiveTensorFunction toPrimitive() { return this; } + + @Override + public Tensor evaluate(EvaluationContext context) { + Value result = expression.evaluate((Context)context); + if ( ! ( result instanceof TensorValue)) + throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " + + "but this returns " + result + ", not a tensor"); + return ((TensorValue)result).asTensor(); + } + + @Override + public String toString(ToStringContext c) { + ExpressionNodeToStringContext context = (ExpressionNodeToStringContext)c; + return expression.toString(context.context, context.path, context.parent); + } + + } + + /** Allows passing serialization context arguments through TensorFunctions */ + private static class ExpressionNodeToStringContext implements ToStringContext { + + final SerializationContext context; + final Deque<String> path; + final CompositeNode parent; + + public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { + this.context = context; + this.path = path; + this.parent = parent; + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java deleted file mode 100644 index af309b3e8d8..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.google.common.annotations.Beta; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; - -import java.util.ArrayList; -import java.util.Deque; -import java.util.List; - -/** - * @author bratseth - */ - @Beta -public class TensorMatchNode extends CompositeNode { - - private final ExpressionNode left, right; - - public TensorMatchNode(ExpressionNode left, ExpressionNode right) { - this.left = left; - this.right = right; - } - - @Override - public List<ExpressionNode> children() { - List<ExpressionNode> children = new ArrayList<>(2); - children.add(left); - children.add(right); - return children; - } - - @Override - public CompositeNode setChildren(List<ExpressionNode> children) { - if ( children.size() != 2) - throw new IllegalArgumentException("A match product must have two children"); - return new TensorMatchNode(children.get(0), children.get(1)); - - } - - @Override - public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { - return "match(" + left.toString(context, path, parent) + ", " + right.toString(context, path, parent) + ")"; - } - - @Override - public Value evaluate(Context context) { - return asTensor(left.evaluate(context)).match(asTensor(right.evaluate(context))); - } - - private TensorValue asTensor(Value value) { - if ( ! (value instanceof TensorValue)) - throw new IllegalArgumentException("Attempted to take the tensor product with an argument which is " + - "not a tensor: " + value); - return (TensorValue)value; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java deleted file mode 100644 index a1f83157e20..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.google.common.annotations.Beta; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; - -import java.util.Collections; -import java.util.Deque; -import java.util.List; -import java.util.Optional; - -/** - * A node which sums over all cells in the argument tensor - * - * @author bratseth - */ - @Beta -public class TensorSumNode extends CompositeNode { - - /** The tensor to sum */ - private final ExpressionNode argument; - - /** The dimension to sum over, or empty to sum all cells to a scalar */ - private final Optional<String> dimension; - - public TensorSumNode(ExpressionNode argument, Optional<String> dimension) { - this.argument = argument; - this.dimension = dimension; - } - - @Override - public List<ExpressionNode> children() { - return Collections.singletonList(argument); - } - - @Override - public CompositeNode setChildren(List<ExpressionNode> children) { - if (children.size() != 1) throw new IllegalArgumentException("A tensor sum node must have one tensor argument"); - return new TensorSumNode(children.get(0), dimension); - } - - @Override - public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { - return "sum(" + - argument.toString(context, path, parent) + - ( dimension.isPresent() ? ", " + dimension.get() : "" ) + - ")"; - } - - @Override - public Value evaluate(Context context) { - Value argumentValue = argument.evaluate(context); - if ( ! ( argumentValue instanceof TensorValue)) - throw new IllegalArgumentException("Attempted to take the tensor sum of argument '" + argument + "', " + - "but this returns " + argumentValue + ", not a tensor"); - TensorValue tensorArgument = (TensorValue)argumentValue; - if (dimension.isPresent()) - return tensorArgument.sum(dimension.get()); - else - return tensorArgument.sum(); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java index 60fe19f909f..932975f3b63 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java @@ -15,7 +15,8 @@ public enum TruthOperator implements Serializable { EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } }, APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } }, LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } }, - LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }; + LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }, + NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } }; private final String operatorString; diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 78ad665c414..0fcfdb5d40c 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -21,10 +21,9 @@ import com.yahoo.searchlib.rankingexpression.rule.*; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.tensor.MapTensor; -import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.*; +import com.yahoo.tensor.functions.*; import java.util.Collections; -import java.util.Map; import java.util.LinkedHashMap; import java.util.Arrays; import java.util.ArrayList; @@ -60,51 +59,83 @@ TOKEN : <RSQUARE: "]"> | <LCURLY: "{"> | <RCURLY: "}"> | + <ADD: "+"> | <SUB: "-"> | <DIV: "/"> | <MUL: "*"> | <DOT: "."> | + <DOLLAR: "$"> | <COMMA: ","> | <COLON: ":"> | + <LE: "<="> | <LT: "<"> | <EQ: "=="> | + <NQ: "!="> | <AQ: "~="> | <GE: ">="> | <GT: ">"> | + <STRING: ("\"" (~["\""] | "\\\"")* "\"") | ("'" (~["'"] | "\\'")* "'")> | + <IF: "if"> | - <COSH: "cosh"> | - <SINH: "sinh"> | - <TANH: "tanh"> | - <COS: "cos"> | - <SIN: "sin"> | - <TAN: "tan"> | + <IN: "in"> | + <F: "f"> | + + <ABS: "abs"> | <ACOS: "acos"> | <ASIN: "asin"> | - <ATAN2: "atan2"> | <ATAN: "atan"> | - <EXP: "exp"> | - <LDEXP: "ldexp"> | - <LOG10: "log10"> | - <LOG: "log"> | - <POW: "pow"> | - <SQRT: "sqrt"> | <CEIL: "ceil"> | + <COS: "cos"> | + <COSH: "cosh"> | + <ELU: "elu"> | + <EXP: "exp"> | <FABS: "fabs"> | <FLOOR: "floor"> | - <FMOD: "fmod"> | - <MIN: "min"> | - <MAX: "max"> | <ISNAN: "isNan"> | - <IN: "in"> | - <SUM: "sum"> | - <MATCH: "match"> | + <LOG: "log"> | + <LOG10: "log10"> | <RELU: "relu"> | + <ROUND: "round"> | <SIGMOID: "sigmoid"> | + <SIGN: "sign"> | + <SIN: "sin"> | + <SINH: "sinh"> | + <SQUARE: "square"> | + <SQRT: "sqrt"> | + <TAN: "tan"> | + <TANH: "tanh"> | + + <ATAN2: "atan2"> | + <FMOD: "fmod"> | + <LDEXP: "ldexp"> | + // MAX + // MIN + <MOD: "mod"> | + <POW: "pow"> | + + <MAP: "map"> | + <REDUCE: "reduce"> | + <JOIN: "join"> | + <RENAME: "rename"> | + <TENSOR: "tensor"> | + <L1_NORMALIZE: "l1_normalize"> | + <L2_NORMALIZE: "l2_normalize"> | + <MATMUL: "matmul"> | + <SOFTMAX: "softmax"> | + <XW_PLUS_B: "xw_plus_b"> | + + <AVG: "avg" > | + <COUNT: "count"> | + <PROD: "prod"> | + <SUM: "sum"> | + <MAX: "max"> | + <MIN: "min"> | + <IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)> } @@ -175,6 +206,7 @@ TruthOperator comparator() : { } ( <LE> { return TruthOperator.SMALLEREQUAL; } | <LT> { return TruthOperator.SMALLER; } | <EQ> { return TruthOperator.EQUAL; } | + <NQ> { return TruthOperator.NOTEQUAL; } | <AQ> { return TruthOperator.APPROX_EQUAL; } | <GE> { return TruthOperator.LARGEREQUAL; } | <GT> { return TruthOperator.LARGER; } ) @@ -189,7 +221,6 @@ ExpressionNode value() : { ( [ LOOKAHEAD(2) <SUB> { neg = true; } ] ( ret = constantPrimitive() | - ret = constantTensor() | LOOKAHEAD(2) ret = ifExpression() | LOOKAHEAD(2) ret = function() | ret = feature() | @@ -279,7 +310,6 @@ ExpressionNode arg() : } { ( ret = constantPrimitive() | - ret = constantTensor() | LOOKAHEAD(2) ret = feature() | name = identifier() { ret = new NameNode(name); } ) { return ret; } @@ -290,11 +320,11 @@ ExpressionNode function() : ExpressionNode function; } { - ( function = scalarFunction() | function = tensorFunction() ) + ( function = scalarOrTensorFunction() | function = tensorFunction() ) { return function; } } -FunctionNode scalarFunction() : +FunctionNode scalarOrTensorFunction() : { Function function; ExpressionNode arg1, arg2; @@ -312,61 +342,223 @@ FunctionNode scalarFunction() : ExpressionNode tensorFunction() : { + ExpressionNode tensorExpression; +} +{ + ( + tensorExpression = tensorMap() | + tensorExpression = tensorReduce() | + tensorExpression = tensorReduceComposites() | + tensorExpression = tensorJoin() | + tensorExpression = tensorRename() | + tensorExpression = tensorGenerate() | + tensorExpression = tensorL1Normalize() | + tensorExpression = tensorL2Normalize() | + tensorExpression = tensorMatmul() | + tensorExpression = tensorSoftmax() | + tensorExpression = tensorXwPlusB() + ) + { return tensorExpression; } +} + +ExpressionNode tensorMap() : +{ + ExpressionNode tensor; + LambdaFunctionNode doubleMapper; +} +{ + <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE> + { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor), + doubleMapper.asDoubleUnaryOperator())); } +} + +ExpressionNode tensorReduce() : +{ + ExpressionNode tensor; + Reduce.Aggregator aggregator; + List<String> dimensions = null; +} +{ + <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } +} + +ExpressionNode tensorReduceComposites() : +{ + ExpressionNode tensor; + Reduce.Aggregator aggregator; + List<String> dimensions = null; +} +{ + aggregator = tensorReduceAggregator() + <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE> + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } +} + +ExpressionNode tensorJoin() : +{ ExpressionNode tensor1, tensor2; - String dimension = null; - TensorAddress address = null; + LambdaFunctionNode doubleJoiner; } { - ( - <SUM> <LBRACE> tensor1 = expression() ( <COMMA> dimension = identifier() )? <RBRACE> - { return new TensorSumNode(tensor1, Optional.ofNullable(dimension)); } - ) | - ( - <MATCH> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <RBRACE> - { return new TensorMatchNode(tensor1, tensor2); } - ) + <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE> + { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + doubleJoiner.asDoubleBinaryOperator())); } +} + +ExpressionNode tensorRename() : +{ + ExpressionNode tensor; + List<String> fromDimensions, toDimensions; +} +{ + <RENAME> <LBRACE> tensor = expression() <COMMA> + fromDimensions = bracedIdentifierList() <COMMA> + toDimensions = bracedIdentifierList() + <RBRACE> + { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); } +} + +// TODO: Notice that null is parsed below +ExpressionNode tensorGenerate() : +{ + TensorType type; + LambdaFunctionNode generator; +} +{ + <TENSOR> <LBRACE> <RBRACE> <LBRACE> + { return new TensorFunctionNode(new Generate(null, null)); } +} + +ExpressionNode tensorL1Normalize() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorL2Normalize() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorMatmul() : +{ + ExpressionNode tensor1, tensor2; + String dimension; +} +{ + <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + dimension)); } +} + +ExpressionNode tensorSoftmax() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorXwPlusB() : +{ + ExpressionNode tensor1, tensor2, tensor3; + String dimension; +} +{ + <XW_PLUS_B> <LBRACE> tensor1 = expression() <COMMA> + tensor2 = expression() <COMMA> + tensor3 = expression() <COMMA> + dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + TensorFunctionNode.wrapArgument(tensor3), + dimension)); } +} + +LambdaFunctionNode lambdaFunction() : +{ + List<String> variables; + ExpressionNode functionExpression; +} +{ + ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> ) + { return new LambdaFunctionNode(variables, functionExpression); } +} + +Reduce.Aggregator tensorReduceAggregator() : +{ +} +{ + ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> ) + { return Reduce.Aggregator.valueOf(token.image); } } // This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge String tensorFunctionName() : { + Reduce.Aggregator aggregator; } { - ( <SUM> | <MATCH> ) - { return token.image; } + ( <F> { return token.image; } ) | + ( <MAP> { return token.image; } ) | + ( <REDUCE> { return token.image; } ) | + ( <JOIN> { return token.image; } ) | + ( <RENAME> { return token.image; } ) | + ( <TENSOR> { return token.image; } ) | + ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } Function unaryFunctionName() : { } { - <COS> { return Function.cos; } | - <SIN> { return Function.sin; } | - <TAN> { return Function.tan; } | - <COSH> { return Function.cosh; } | - <SINH> { return Function.sinh; } | - <TANH> { return Function.tanh; } | + <ABS> { return Function.abs; } | <ACOS> { return Function.acos; } | <ASIN> { return Function.asin; } | <ATAN> { return Function.atan; } | - <EXP> { return Function.exp; } | - <LOG10> { return Function.log10; } | - <LOG> { return Function.log; } | - <SQRT> { return Function.sqrt; } | <CEIL> { return Function.ceil; } | + <COS> { return Function.cos; } | + <COSH> { return Function.cosh; } | + <ELU> { return Function.elu; } | + <EXP> { return Function.exp; } | <FABS> { return Function.fabs; } | <FLOOR> { return Function.floor; } | <ISNAN> { return Function.isNan; } | + <LOG> { return Function.log; } | + <LOG10> { return Function.log10; } | <RELU> { return Function.relu; } | - <SIGMOID> { return Function.sigmoid; } + <ROUND> { return Function.round; } | + <SIGMOID> { return Function.sigmoid; } | + <SIGN> { return Function.sign; } | + <SIN> { return Function.sin; } | + <SINH> { return Function.sinh; } | + <SQUARE> { return Function.square; } | + <SQRT> { return Function.sqrt; } | + <TAN> { return Function.tan; } | + <TANH> { return Function.tanh; } } Function binaryFunctionName() : { } { <ATAN2> { return Function.atan2; } | - <LDEXP> { return Function.ldexp; } | - <POW> { return Function.pow; } | <FMOD> { return Function.fmod; } | + <LDEXP> { return Function.ldexp; } | + <MAX> { return Function.max; } | <MIN> { return Function.min; } | - <MAX> { return Function.max; } + <MOD> { return Function.mod; } | + <POW> { return Function.pow; } } List<ExpressionNode> expressionList() : @@ -405,79 +597,64 @@ String identifier() : <IDENTIFIER> { return token.image; } } -// An identifier or integer -String tag() : -{ - String name; -} -{ - name = identifier() { return name; } | - <INTEGER> { return token.image; } -} - -ConstantNode constantPrimitive() : +List<String> identifierList() : { - String sign = ""; + List<String> list = new ArrayList<String>(); + String element; } { - ( <SUB> { sign = "-";} ) ? - ( <INTEGER> | <FLOAT> | <STRING> ) - { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); } + ( element = identifier() { list.add(element); } )? + ( <COMMA> element = identifier() { list.add(element); } ) * + { return list; } } -Value primitiveValue() : +List<String> bracedIdentifierList() : { - String sign = ""; + List<String> list = new ArrayList<String>(); + String element; } { - ( <SUB> { sign = "-";} ) ? - ( <INTEGER> | <FLOAT> | <STRING> ) - { return Value.parse(sign + token.image); } + ( element = identifier() { return Collections.singletonList(element); } ) + | + ( <LBRACE> list = identifierList() <RBRACE> { return list; } ) } -ConstantNode constantTensor() : +// An identifier or integer +String tag() : { - Value constantValue; + String name; } { - <LCURLY> constantValue = tensorContent() <RCURLY> - { return new ConstantNode(constantValue); } + name = identifier() { return name; } | + <INTEGER> { return token.image; } } -TensorValue tensorContent() : +List<String> tagCommaLeadingList() : { - Map<TensorAddress, Double> cells = new LinkedHashMap<TensorAddress, Double>(); - TensorAddress address; - Double value; + List<String> list = new ArrayList<String>(); + String element; } { - ( address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) ? - ( <COMMA> address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) * - { return new TensorValue(new MapTensor(cells)); } + ( <COMMA> element = tag() { list.add(element); } ) * + { return list; } } -TensorAddress tensorAddress() : +ConstantNode constantPrimitive() : { - List<TensorAddress.Element> elements = new ArrayList<TensorAddress.Element>(); - String dimension; - String label; + String sign = ""; } { - <LCURLY> - ( dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) ? - ( <COMMA> dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) * - <RCURLY> - { return TensorAddress.fromUnsorted(elements); } + ( <SUB> { sign = "-";} ) ? + ( <INTEGER> | <FLOAT> | <STRING> ) + { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); } } -String label() : +Value primitiveValue() : { - String label; - + String sign = ""; } { - ( label = tag() | - ( "-" { label = "-"; } ) ) - { return label; } + ( <SUB> { sign = "-";} ) ? + ( <INTEGER> | <FLOAT> | <STRING> ) + { return Value.parse(sign + token.image); } } - diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 24d7c82235c..f28ff739b4c 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -6,7 +6,10 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; -import junit.framework.TestCase; +import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; import java.io.BufferedReader; import java.io.File; @@ -14,15 +17,18 @@ import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.*; /** - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen + * @author bratseth */ -public class RankingExpressionTestCase extends TestCase { +public class RankingExpressionTestCase { + @Test public void testParamInFeature() throws ParseException { assertParse("if (1 > 2, dotProduct(allparentid,query(cate1_parentid)), 2)", "if ( 1 > 2,\n" + @@ -31,6 +37,7 @@ public class RankingExpressionTestCase extends TestCase { ")"); } + @Test public void testDollarShorthand() throws ParseException { assertParse("query(var1)", " $var1"); assertParse("query(var1)", " $var1 "); @@ -44,6 +51,7 @@ public class RankingExpressionTestCase extends TestCase { assertParse("if (if (f1.out < query(p1), 0, 1) < if (f2.out < query(p2), 0, 1), f3.out, query(p3))", "if(if(f1.out<$p1,0,1)<if(f2.out<$p2,0,1),f3.out,$p3)"); } + @Test public void testLookaheadIndefinitely() throws Exception { ExecutorService exec = Executors.newSingleThreadExecutor(); Future<Boolean> future = exec.submit(new Callable<Boolean>() { @@ -60,7 +68,8 @@ public class RankingExpressionTestCase extends TestCase { assertTrue(future.get(60, TimeUnit.SECONDS)); } - public void testSelfRecursionScript() throws ParseException { + @Test + public void testSelfRecursionSerialization() throws ParseException { List<ExpressionFunction> macros = new ArrayList<>(); macros.add(new ExpressionFunction("foo", null, new RankingExpression("foo"))); @@ -72,7 +81,8 @@ public class RankingExpressionTestCase extends TestCase { } } - public void testMacroCycleScript() throws ParseException { + @Test + public void testMacroCycleSerialization() throws ParseException { List<ExpressionFunction> macros = new ArrayList<>(); macros.add(new ExpressionFunction("foo", null, new RankingExpression("bar"))); macros.add(new ExpressionFunction("bar", null, new RankingExpression("foo"))); @@ -85,42 +95,48 @@ public class RankingExpressionTestCase extends TestCase { } } - public void testScript() throws ParseException { + @Test + public void testSerialization() throws ParseException { List<ExpressionFunction> macros = new ArrayList<>(); macros.add(new ExpressionFunction("foo", Arrays.asList("arg1", "arg2"), new RankingExpression("min(arg1, pow(arg2, 2))"))); macros.add(new ExpressionFunction("bar", Arrays.asList("arg1", "arg2"), new RankingExpression("arg1 * arg1 + 2 * arg1 * arg2 + arg2 * arg2"))); macros.add(new ExpressionFunction("baz", Arrays.asList("arg1", "arg2"), new RankingExpression("foo(1, 2) / bar(arg1, arg2)"))); macros.add(new ExpressionFunction("cox", null, new RankingExpression("10 + 08 * 1977"))); - assertScript("foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros, - Arrays.asList( - "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)", - "min(5,pow(rankingExpression(foo@d1d1417259cdc651.573bbcd4be18f379),2))", - "min(6,pow(7,2))", - "min(1,pow(2,2))", - "min(3,pow(4,2))", - "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))" - )); - assertScript("foo(1, 2) + bar(3, 4)", macros, - Arrays.asList( - "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)", - "min(1,pow(2,2))", - "3 * 3 + 2 * 3 * 4 + 4 * 4" - )); - assertScript("baz(1, 2)", macros, - Arrays.asList( - "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)", - "min(1,pow(2,2))", - "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)", - "1 * 1 + 2 * 1 * 2 + 2 * 2" - )); - assertScript("cox", macros, - Arrays.asList( - "rankingExpression(cox)", - "10 + 08 * 1977" - )); + assertSerialization(Arrays.asList( + "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)", + "min(5,pow(rankingExpression(foo@d1d1417259cdc651.573bbcd4be18f379),2))", + "min(6,pow(7,2))", + "min(1,pow(2,2))", + "min(3,pow(4,2))", + "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"), "foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros); + assertSerialization(Arrays.asList( + "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)", + "min(1,pow(2,2))", + "3 * 3 + 2 * 3 * 4 + 4 * 4"), "foo(1, 2) + bar(3, 4)", macros); + assertSerialization(Arrays.asList( + "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)", + "min(1,pow(2,2))", + "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)", + "1 * 1 + 2 * 1 * 2 + 2 * 2"), "baz(1, 2)", macros); + assertSerialization(Arrays.asList( + "rankingExpression(cox)", + "10 + 08 * 1977"), "cox", macros + ); + } + + @Test + public void testTensorSerialization() { + assertSerialization("map(constant(tensor0), f(a)(cos(a)))", + "map(constant(tensor0), f(a)(cos(a)))"); + assertSerialization("map(constant(tensor0), f(a)(cos(a))) + join(attribute(tensor1), map(reduce(map(attribute(tensor1), f(a)(a * a)), sum, x), f(a)(sqrt(a))), f(a,b)(a / b))", + "map(constant(tensor0), f(a)(cos(a))) + l2_normalize(attribute(tensor1), x)"); + assertSerialization("join(reduce(join(reduce(join(constant(tensor0), attribute(tensor1), f(a,b)(a * b)), sum, x), attribute(tensor1), f(a,b)(a * b)), sum, y), query(tensor2), f(a,b)(a + b))", + "xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)"); + } + @Test public void testBug3464208() throws ParseException { List<ExpressionFunction> macros = new ArrayList<>(); macros.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69"))); @@ -135,18 +151,11 @@ public class RankingExpressionTestCase extends TestCase { String expRhs = "(rankingExpression(log10tweetage) * rankingExpression(log10tweetage) * " + "rankingExpression(log10tweetage)) + 5.0 * attribute(ythl)"; - assertScript(lhs + " + " + rhs, macros, - Arrays.asList( - expLhs + " + " + expRhs, - "69" - )); - assertScript(lhs + " - " + rhs, macros, - Arrays.asList( - expLhs + " - " + expRhs, - "69" - )); + assertSerialization(Arrays.asList(expLhs + " + " + expRhs, "69"), lhs + " + " + rhs, macros); + assertSerialization(Arrays.asList(expLhs + " - " + expRhs, "69"), lhs + " - " + rhs, macros); } + @Test public void testParse() throws ParseException, IOException { BufferedReader reader = new BufferedReader(new FileReader("src/tests/rankingexpression/rankingexpressionlist")); String line; @@ -181,36 +190,43 @@ public class RankingExpressionTestCase extends TestCase { } } + @Test public void testIssue() throws ParseException { assertEquals("feature.0", new RankingExpression("feature.0").toString()); assertEquals("if (1 > 2, 3, 4) + feature(arg1).out.out", new RankingExpression("if ( 1 > 2 , 3 , 4 ) + feature ( arg1 ) . out.out").toString()); } + @Test public void testNegativeConstantArgument() throws ParseException { assertEquals("foo(-1.2)", new RankingExpression("foo(-1.2)").toString()); } + @Test public void testNaming() throws ParseException { RankingExpression test = new RankingExpression("a+b"); test.setName("test"); assertEquals("test: a + b", test.toString()); } + @Test public void testCondition() throws ParseException { RankingExpression expression = new RankingExpression("if(1<2,3,4)"); assertTrue(expression.getRoot() instanceof IfNode); } + @Test public void testFileImporting() throws ParseException { RankingExpression expression = new RankingExpression(new File("src/test/files/simple.expression")); assertEquals("simple: a + b", expression.toString()); } + @Test public void testNonCanonicalLegalStrings() throws ParseException { assertParse("a * b + c * d", "a* (b) + \nc*d"); } + @Test public void testEquality() throws ParseException { assertEquals(new RankingExpression("if ( attribute(foo)==\"BAR\",log(attribute(popularity)+5),log(fieldMatch(title).proximity)*fieldMatch(title).completeness)"), new RankingExpression("if(attribute(foo)==\"BAR\", log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)")); @@ -219,6 +235,7 @@ public class RankingExpressionTestCase extends TestCase { new RankingExpression("if(attribute(foo)==\"BAR\", log(attribute(popularity)+5),log(fieldMatch(title).earliness) * fieldMatch(title).completeness)"))); } + @Test public void testSetMembershipConditions() throws ParseException { assertEquals(new RankingExpression("if ( attribute(foo) in [\"FOO\", \"BAR\"],log(attribute(popularity)+5),log(fieldMatch(title).proximity)*fieldMatch(title).completeness)"), new RankingExpression("if(attribute(foo) in [\"FOO\",\"BAR\"], log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)")); @@ -231,6 +248,7 @@ public class RankingExpressionTestCase extends TestCase { assertEquals(new RankingExpression("if (GENDER$ in [-1.0, 1.0], 1, 0)"), new RankingExpression("if (GENDER$ in [-1.0, 1.0], 1, 0)")); } + @Test public void testComments() throws ParseException { assertEquals(new RankingExpression("if ( attribute(foo) in [\"FOO\", \"BAR\"],\n" + "# a comment\n" + @@ -241,6 +259,7 @@ public class RankingExpressionTestCase extends TestCase { new RankingExpression("if(attribute(foo) in [\"FOO\",\"BAR\"], log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)")); } + @Test public void testIsNan() throws ParseException { String strExpr = "if (isNan(attribute(foo)) == 1.0, 1.0, attribute(foo))"; RankingExpression expr = new RankingExpression(strExpr); @@ -255,27 +274,59 @@ public class RankingExpressionTestCase extends TestCase { assertEquals(expected, new RankingExpression(expression).toString()); } - private void assertScript(String expression, List<ExpressionFunction> macros, List<String> expectedScripts) - throws ParseException { - boolean print = false; - if (print) - System.out.println("Parsing expression '" + expression + "'."); - - RankingExpression exp = new RankingExpression(expression); - Map<String, String> scripts = exp.getRankProperties(macros); - if (print) { - for (String key : scripts.keySet()) { - System.out.println("Script '" + key + "': " + scripts.get(key)); - } + /** Test serialization with no macros */ + private void assertSerialization(String expectedSerialization, String expressionString) { + String serializedExpression; + try { + RankingExpression expression = new RankingExpression(expressionString); + // No macros -> expect one rank property + serializedExpression = expression.getRankProperties(Collections.emptyList()).values().iterator().next(); + assertEquals(expectedSerialization, serializedExpression); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); } - for (Map.Entry<String, String> m : scripts.entrySet()) - System.out.println(m); - for (int i = 0; i < expectedScripts.size();) { - String val = expectedScripts.get(i++); - assertTrue("Script contains " + val, scripts.containsValue(val)); + try { + // No macros -> output should be parseable to a ranking expression + // (but not the same one due to primitivization) + RankingExpression reparsedExpression = new RankingExpression(serializedExpression); + // Serializing the primitivized expression should yield the same expression again + String reserializedExpression = + reparsedExpression.getRankProperties(Collections.emptyList()).values().iterator().next(); + assertEquals(expectedSerialization, reserializedExpression); + } + catch (ParseException e) { + throw new IllegalArgumentException("Could not parse the serialized expression", e); } - if (print) - System.out.println(""); } + + private void assertSerialization(List<String> expectedSerialization, String expressionString, + List<ExpressionFunction> macros) { + assertSerialization(expectedSerialization, expressionString, macros, false); + } + private void assertSerialization(List<String> expectedSerialization, String expressionString, + List<ExpressionFunction> macros, boolean print) { + try { + if (print) + System.out.println("Parsing expression '" + expressionString + "'."); + + RankingExpression expression = new RankingExpression(expressionString); + Map<String, String> rankProperties = expression.getRankProperties(macros); + if (print) { + for (String key : rankProperties.keySet()) + System.out.println("Property '" + key + "': " + rankProperties.get(key)); + } + for (int i = 0; i < expectedSerialization.size();) { + String val = expectedSerialization.get(i++); + assertTrue("Properties contains " + val, rankProperties.containsValue(val)); + } + if (print) + System.out.println(""); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index b67a423181d..93800e2c246 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -20,7 +20,7 @@ import java.util.Set; */ public class EvaluationTestCase extends junit.framework.TestCase { - private Context defaultContext; + private MapContext defaultContext; @Override protected void setUp() { @@ -100,201 +100,180 @@ public class EvaluationTestCase extends junit.framework.TestCase { @Test public void testTensorEvaluation() { - assertEvaluates("{}", "{}"); // empty - assertEvaluates("( {{x:-}:1} * {} )", "( {{x:-}:1} * {} )"); // empty with dimensions + assertEvaluates("{}", "tensor0", "{}"); - // sum(tensor) - assertEvaluates(5.0, "sum({{}:5.0})"); - assertEvaluates(-5.0, "sum({{}:-5.0})"); - assertEvaluates(12.5, "sum({ {d1:l1}:5.5, {d2:l2}:7.0 })"); - assertEvaluates(0.0, "sum({ {d1:l1}:5.0, {d2:l2}:7.0, {}:-12.0})"); - - // scalar functions on tensors + // tensor map assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }", - "log10({ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 })"); - assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }", - "5 * { {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); - assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }", - "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 } + 3"); - assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }", - "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 } / 10"); + "map(tensor0, f(x) (log10(x)))", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {}:4, {d1:l1}:9, {d1:l1,d2:l1 }:16 }", + "map(tensor0, f(x) (x * x))", "{ {}:2, {d1:l1}:3, {d1:l1,d2:l1}:4 }"); + // -- tensor map composites + assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }", + "log10(tensor0)", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); assertEvaluates("{ {}:-10, {d1:l1}:-100, {d1:l1,d2:l1 }:-1000 }", - "- { {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + "- tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); assertEvaluates("{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1 }:0 }", - "min({ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }, 0)"); + "min(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }"); assertEvaluates("{ {}:0, {d1:l1}:0, {d1:l1,d2:l1 }:10 }", - "max({ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }, 0)"); - assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + {{h:1}:1.0,{h:2}:1.0}"); - - // sum(tensor, dimension) - assertEvaluates("{ {y:1}:4.0, {y:2}:12.0 }", - "sum({ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }, x)"); - assertEvaluates("{ {x:1}:6.0, {x:2}:10.0 }", - "sum({ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }, y)"); - - // tensor sum - assertEvaluates("{ }", "{} + {}"); - assertEvaluates("{ {x:1}:3, {x:2}:5 }", - "{ {x:1}:3 } + { {x:2}:5 }"); - assertEvaluates("{ {x:1}:8 }", - "{ {x:1}:3 } + { {x:1}:5 }"); - assertEvaluates("{ {x:1}:3, {y:1}:5 }", - "{ {x:1}:3 } + { {y:1}:5 }"); - assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }", - "{ {x:1}:3, {x:2}:7 } + { {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }", - "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } + { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "{ {x:1}:5, {x:1,y:1}:1 } + { {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }", - "{ {x:1}:5, {x:1,y:1}:1 } + { {z:1}:11, {y:1,z:1}:7 }"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "{ {}:5, {x:1,y:1}:1 } + { {y:1,z:1}:7 }"); - assertEvaluates("{ {}:16, {x:1,y:1}:1, {y:1,z:1}:7 }", - "{ {}:5, {x:1,y:1}:1 } + { {}:11, {y:1,z:1}:7 }"); - - // tensor difference - assertEvaluates("{ }", "{} - {}"); - assertEvaluates("{ {x:1}:3, {x:2}:-5 }", - "{ {x:1}:3 } - { {x:2}:5 }"); - assertEvaluates("{ {x:1}:-2 }", - "{ {x:1}:3 } - { {x:1}:5 }"); - assertEvaluates("{ {x:1}:3, {y:1}:-5 }", - "{ {x:1}:3 } - { {y:1}:5 }"); - assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:-5 }", - "{ {x:1}:3, {x:2}:7 } - { {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:-7, {y:2,z:1}:-11, {y:1,z:2}:-13 }", - "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } - { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:-7 }", - "{ {x:1}:5, {x:1,y:1}:1 } - { {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:-11, {y:1,z:1}:-7 }", - "{ {x:1}:5, {x:1,y:1}:1 } - { {z:1}:11, {y:1,z:1}:7 }"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:-7 }", - "{ {}:5, {x:1,y:1}:1 } - { {y:1,z:1}:7 }"); - assertEvaluates("{ {}:-6, {x:1,y:1}:1, {y:1,z:1}:-7 }", - "{ {}:5, {x:1,y:1}:1 } - { {}:11, {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1}:0 }", - "{ {x:1}:3 } - { {x:1}:3 }"); - assertEvaluates("{ {x:1}:0, {x:2}:1 }", - "{ {x:1}:3, {x:2}:1 } - { {x:1}:3 }"); - - // tensor product - assertEvaluates("{ }", "{} * {}"); - assertEvaluates("( {{x:-,y:-,z:-}:1}*{} )", "( {{x:-}:1} * {} ) * ( {{y:-,z:-}:1} * {} )"); // empty dimensions are preserved - assertEvaluates("( {{x:-}:1} * {} )", - "{ {x:1}:3 } * { {x:2}:5 }"); + "max(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }"); + // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence) + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "abs(tensor0)", "{ {x:1}:1, {x:2}:-2 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "acos(tensor0)", "{ {x:1}:1, {x:2}:1 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "asin(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "atan(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "ceil(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:1, {x:2}:1 }", "cos(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:1, {x:2}:1 }", "cosh(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "elu(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:1, {x:2}:1 }", "exp(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "fabs(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "floor(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "isNan(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "log(tensor0)", "{ {x:1}:1, {x:2}:1 }"); + assertEvaluates("{ {x:1}:0, {x:2}:1 }", "log10(tensor0)", "{ {x:1}:1, {x:2}:10 }"); + assertEvaluates("{ {x:1}:0, {x:2}:2 }", "mod(tensor0, 3)", "{ {x:1}:3, {x:2}:8 }"); + assertEvaluates("{ {x:1}:1, {x:2}:8 }", "pow(tensor0, 3)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "relu(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "round(tensor0)", "{ {x:1}:1, {x:2}:1.8 }"); + assertEvaluates("{ {x:1}:0.5, {x:2}:0.5 }", "sigmoid(tensor0)","{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:1, {x:2}:-1 }", "sign(tensor0)", "{ {x:1}:3, {x:2}:-5 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "sin(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "sinh(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:1, {x:2}:4 }", "square(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:1, {x:2}:3 }", "sqrt(tensor0)", "{ {x:1}:1, {x:2}:9 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "tan(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "tanh(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + + // tensor reduce + // -- reduce 2 dimensions + assertEvaluates("{ {}:4 }", + "reduce(tensor0, avg, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {}:4 }", + "reduce(tensor0, count, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {}:105 }", + "reduce(tensor0, prod, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {}:16 }", + "reduce(tensor0, sum, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {}:7 }", + "reduce(tensor0, max, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {}:1 }", + "reduce(tensor0, min, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + // -- reduce 2 by specifying no arguments + assertEvaluates("{ {}:4 }", + "reduce(tensor0, avg)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + // -- reduce 1 dimension + assertEvaluates("{ {y:1}:2, {y:2}:6 }", + "reduce(tensor0, avg, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {y:1}:2, {y:2}:2 }", + "reduce(tensor0, count, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {y:1}:3, {y:2}:35 }", + "reduce(tensor0, prod, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {y:1}:4, {y:2}:12 }", + "reduce(tensor0, sum, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {y:1}:3, {y:2}:7 }", + "reduce(tensor0, max, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {y:1}:1, {y:2}:5 }", + "reduce(tensor0, min, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + // -- reduce composites + assertEvaluates("{ {}: 5 }", "sum(tensor0)", "5.0"); + assertEvaluates("{ {}:-5 }", "sum(tensor0)", "-5.0"); + assertEvaluates("{ {}:12.5 }", "sum(tensor0)", "{ {d1:l1}:5.5, {d2:l2}:7.0 }"); + assertEvaluates("{ {}: 0 }", "sum(tensor0)", "{ {d1:l1}:5.0, {d2:l2}:7.0, {}:-12.0}"); + assertEvaluates("{ {y:1}:4, {y:2}:12.0 }", + "sum(tensor0, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {x:1}:6, {x:2}:10.0 }", + "sum(tensor0, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {}:16 }", + "sum(tensor0, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + + // tensor join + assertEvaluates("{ {x:1,y:1}:15, {x:2,y:1}:35 }", "join(tensor0, tensor1, f(x,y) (x*y))", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + // -- join composites + assertEvaluates("{ }", "tensor0 * tensor0", "{}"); + assertEvaluates("tensor(x{},y{},z{}):{}", "( tensor0 * tensor1 ) * ( tensor2 * tensor1 )", + "{{x:-}:1}", "{}", "{{y:-,z:-}:1}"); // empty dimensions are preserved + assertEvaluates("tensor(x{}):{}", + "tensor0 * tensor1", "{ {x:1}:3 }", "{ {x:2}:5 }"); assertEvaluates("{ {x:1}:15 }", - "{ {x:1}:3 } * { {x:1}:5 }"); + "tensor0 * tensor1", "{ {x:1}:3 }", "{ {x:1}:5 }"); assertEvaluates("{ {x:1,y:1}:15 }", - "{ {x:1}:3 } * { {y:1}:5 }"); + "tensor0 * tensor1", "{ {x:1}:3 }", "{ {y:1}:5 }"); assertEvaluates("{ {x:1,y:1}:15, {x:2,y:1}:35 }", - "{ {x:1}:3, {x:2}:7 } * { {y:1}:5 }"); + "tensor0 * tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:8, {x:2,y:1}:12 }", + "tensor0 + tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:-2, {x:2,y:1}:2 }", + "tensor0 - tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:4 }", + "tensor0 / tensor1", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }"); + assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:7 }", + "max(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:3, {x:2,y:1}:5 }", + "min(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,y:1,z:2}:13, {x:2,y:1,z:1}:21, {x:2,y:1,z:2}:39, {x:1,y:2,z:1}:55 }", - "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } * { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); - assertEvaluates("{ {x:1,y:1,z:1}:7 }", - "{ {x:1}:5, {x:1,y:1}:1 } * { {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,z:1}:55 }", - "{ {x:1}:5, {x:1,y:1}:1 } * { {z:1}:11, {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1,y:1,z:1}:7 }", - "{ {}:5, {x:1,y:1}:1 } * { {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1,y:1,z:1}:7, {}:55 }", - "{ {}:5, {x:1,y:1}:1 } * { {}:11, {y:1,z:1}:7 }"); - - // match product - assertEvaluates("{ }", "match({}, {})"); - assertEvaluates("( {{x:-}:1} * {} )", - "match({ {x:1}:3 }, { {x:2}:5 })"); - assertEvaluates("{ {x:1}:15 }", - "match({ {x:1}:3 }, { {x:1}:5 })"); - assertEvaluates("( {{x:-,y:-}:1} * {} )", - "match({ {x:1}:3 }, { {y:1}:5 })"); - assertEvaluates("( {{x:-,y:-}:1} * {} )", - "match({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })"); - assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )", - "match({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })"); - assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )", - "match({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); - assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )", - "match({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })"); - assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )", - "match({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); - assertEvaluates("( {{x:-,y:-,z:-}:1} * { {}:55 } )", - "match({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })"); - assertEvaluates("( {{z:-}:1} * { {x:1}:15, {x:1,y:1}:7 } )", - "match({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })"); - - // min - assertEvaluates("{ {x:1}:3, {x:2}:5 }", - "min({ {x:1}:3 }, { {x:2}:5 })"); - assertEvaluates("{ {x:1}:3 }", - "min({ {x:1}:3 }, { {x:1}:5 })"); - assertEvaluates("{ {x:1}:3, {y:1}:5 }", - "min({ {x:1}:3 }, { {y:1}:5 })"); - assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }", - "min({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }", - "min({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "min({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }", - "min({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "min({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "min({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })"); - assertEvaluates("{ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }", - "min({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })"); - - // max - assertEvaluates("{ {x:1}:3, {x:2}:5 }", - "max({ {x:1}:3 }, { {x:2}:5 })"); - assertEvaluates("{ {x:1}:5 }", - "max({ {x:1}:3 }, { {x:1}:5 })"); - assertEvaluates("{ {x:1}:3, {y:1}:5 }", - "max({ {x:1}:3 }, { {y:1}:5 })"); - assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }", - "max({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }", - "max({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "max({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }", - "max({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "max({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); - assertEvaluates("{ {}:11, {x:1,y:1}:1, {y:1,z:1}:7 }", - "max({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })"); - assertEvaluates("{ {}:5, {x:1}:5, {x:2}:4, {x:1,y:1}:7, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }", - "max({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })"); - - // Combined - assertEvaluates(7.5 + 45 + 1.7, - "sum( " + // model computation - " match( " + // model weight application - " { {x:1}:1, {x:2}:2 } * { {y:1}:3, {y:2}:4 } * { {z:1}:5 }, " + // feature combinations - " { {x:1,y:1,z:1}:0.5, {x:2,y:1,z:1}:1.5, {x:1,y:1,z:2}:4.5 }" + // model weights - "))+1.7"); - - // undefined is not the same as 0 - assertEvaluates(1.0, "sum({ {x:1}:0, {x:2}:0 } * { {x:1}:1, {x:2}:1 } + 0.5)"); - assertEvaluates(0.0, "sum({ } * { {x:1}:1, {x:2}:1 } + 0.5)"); + "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); + assertEvaluates("{ {x:1,y:2,z:1}:35, {x:1,y:2,z:2}:65 }", + "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:2,z:1}:7, {y:3,z:1}:11, {y:2,z:2}:13 }"); + assertEvaluates("{{x:1,y:1}:0.0}","tensor1 * tensor2 * tensor3", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1,y:1}:1 }", "{ {x:1,y:1}:1 }"); + assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }", + "5 * tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }", + "tensor0 + 3","{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }", + "tensor0 / 10", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + tensor0", "{ {h:1}:1.0,{h:2}:1.0 }"); + assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:0 }", + "atan2(tensor0, tensor1)", "{ {x:1}:0, {x:2}:0 }", "{ {y:1}:1 }"); + assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }", + "tensor0 > tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", + "tensor0 < tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }", + "tensor0 >= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", + "tensor0 <= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }", + "tensor0 == tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", + "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }"); + // TODO + // argmax + // argmin + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", + "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }"); + + // tensor rename + assertEvaluates("{ {newX:1,y:2}:3 }", "rename(tensor0, x, newX)", "{ {x:1,y:2}:3.0 }"); + assertEvaluates("{ {x:2,y:1}:3 }", "rename(tensor0, (x, y), (y, x))", "{ {x:1,y:2}:3.0 }"); + + // tensor generate - TODO + // assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0, {x:2,y:2}:1, {x:1,y:2}:0 }", "tensor(x[2],y[2])(x==y)"); + // range + // diag + // fill + // random + + // composite functions + assertEvaluates("{ {x:1}:0.25, {x:2}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:1}:1, {x:2}:3 }"); + assertEvaluates("{ {x:1}:0.31622776601683794, {x:2}:0.9486832980505138 }", "l2_normalize(tensor0, x)", "{ {x:1}:1, {x:2}:3 }"); + assertEvaluates("{ {y:1}:81.0 }", "matmul(tensor0, tensor1, x)", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }"); + assertEvaluates("{ {x:1}:0.5, {x:2}:0.5 }", "softmax(tensor0, x)", "{ {x:1}:1, {x:2}:1 }", "{ {y:1}:1 }"); + assertEvaluates("{ {x:1,y:1}:88.0 }", "xw_plus_b(tensor0, tensor1, tensor2, x)", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }", "{ {x:1}:7 }"); + + // expressions combining functions + assertEvaluates(String.valueOf(7.5 + 45 + 1.7), + "sum( " + // model computation: + " tensor0 * tensor1 * tensor2 " + // - feature combinations + " * tensor3" + // - model weights application + ") + 1.7", + "{ {x:1}:1, {x:2}:2 }", "{ {y:1}:3, {y:2}:4 }", "{ {z:1}:5 }", + "{ {x:1,y:1,z:1}:0.5, {x:2,y:1,z:1}:1.5, {x:1,y:1,z:2}:4.5 }"); + assertEvaluates("1.0", "sum(tensor0 * tensor1 + 0.5)", "{ {x:1}:0, {x:2}:0 }", "{ {x:1}:1, {x:2}:1 }"); + assertEvaluates("0.0", "sum(tensor0 * tensor1 + 0.5)", "{}", "{ {x:1}:1, {x:2}:1 }"); // tensor result dimensions are given from argument dimensions, not the resulting values - assertEvaluates("x", "( {{x:-}:1.0} * {} )", "{ {x:1}:1 } * { {x:2}:1 }"); - assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 }"); - - // demonstration of where this produces different results: { {x:1}:1 } with 2 dimensions ... - assertEvaluates("x, y", "( {{x:-,y:-}:1.0} * {} )","{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 } * { {x:1,y:1}:1 }"); - // ... vs { {x:1}:1 } with only one dimension - assertEvaluates("x, y", "{{x:1,y:1}:1.0}", "{ {x:1}:1 } * { {x:1,y:1}:1 }"); - - // check that dimensions are preserved through other operations - String d2 = "{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 }"; // creates a 2d tensor with only an 1d value - assertEvaluates("x, y", "( {{x:-,y:-}:1.0} * {} )", "match(" + d2 + ", {})"); - assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", d2 + " - {}"); - assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", d2 + " + {}"); - assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "min(1.5, " + d2 +")"); - assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "max({{x:1}:0}, " + d2 +")"); + assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:1}:1 }", "{ {x:2}:1 }"); + assertEvaluates("tensor(x{},y{}):{{x:1}:1.0}", "tensor0 * tensor1", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1}:1 }"); } public void testProgrammaticBuildingAndPrecedence() { @@ -316,12 +295,16 @@ public class EvaluationTestCase extends junit.framework.TestCase { assertEvaluates(77, "average(\"2*3\",\"pow(2,3)\")+average(\"2*3\",\"pow(2,3)\").timesten", context); } - private RankingExpression assertEvaluates(String tensorValue, String expressionString) { - return assertEvaluates(new TensorValue(MapTensor.from(tensorValue)), expressionString, defaultContext); + private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) { + MapContext context = defaultContext.thawedCopy(); + int argumentIndex = 0; + for (String tensorArgument : tensorArguments) + context.put("tensor" + (argumentIndex++), new TensorValue(MapTensor.from(tensorArgument))); + return assertEvaluates(new TensorValue(MapTensor.from(expectedTensor)), expressionString, context); } /** Validate also that the dimension of the resulting tensors are as expected */ - private RankingExpression assertEvaluates(String tensorDimensions, String resultTensor, String expressionString) { + private RankingExpression assertEvaluates_old(String tensorDimensions, String resultTensor, String expressionString) { RankingExpression expression = assertEvaluates(new TensorValue(MapTensor.from(resultTensor)), expressionString, defaultContext); TensorValue value = (TensorValue)expression.evaluate(defaultContext); assertEquals(toSet(tensorDimensions), value.asTensor().dimensions()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java index 95c4402a612..08fdc9917a4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java @@ -17,22 +17,25 @@ public class NeuralNetEvaluationTestCase { /** "XOR" neural network, separate expression per layer */ @Test public void testPerLayerExpression() { - String input = "{ {x:1}:0, {x:2}:1 }"; - - String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }"; - String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }"; - String firstLayerInput = "sum(" + input + "*" + firstLayerWeights + ", x) + " + firstLayerBias; + String input = "{ {x:1}:0, {x:2}:1 }"; // tensor0 + String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }"; // tensor1 + String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }"; // tensor2 + String firstLayerInput = "sum(tensor0 * tensor1, x) + tensor2"; String firstLayerOutput = "min(1.0, max(0.0, 0.5 + " + firstLayerInput + "))"; // non-linearity, "poor man's sigmoid" - assertEvaluates("{ {h:1}:1.0, {h:2}:0.0} }", firstLayerOutput); - String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }"; - String secondLayerBias = "{ {y:1}:-0.5 }"; - String secondLayerInput = "sum(" + firstLayerOutput + "*" + secondLayerWeights + ", h) + " + secondLayerBias; + assertEvaluates("{ {h:1}:1.0, {h:2}:0.0} }", firstLayerOutput, input, firstLayerWeights, firstLayerBias); + String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }"; // tensor3 + String secondLayerBias = "{ {y:1}:-0.5 }"; // tensor4 + String secondLayerInput = "sum(" + firstLayerOutput + "* tensor3, h) + tensor4"; String secondLayerOutput = "min(1.0, max(0.0, 0.5 + " + secondLayerInput + "))"; // non-linearity, "poor man's sigmoid" - assertEvaluates("{ {y:1}:1 }", secondLayerOutput); + assertEvaluates("{ {y:1}:1 }", secondLayerOutput, input, firstLayerWeights, firstLayerBias, secondLayerWeights, secondLayerBias); } - private RankingExpression assertEvaluates(String tensorValue, String expressionString) { - return assertEvaluates(new TensorValue(MapTensor.from(tensorValue)), expressionString, new MapContext()); + private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) { + MapContext context = new MapContext(); + int argumentIndex = 0; + for (String tensorArgument : tensorArguments) + context.put("tensor" + (argumentIndex++), new TensorValue(MapTensor.from(tensorArgument))); + return assertEvaluates(new TensorValue(MapTensor.from(expectedTensor)), expressionString, context); } private RankingExpression assertEvaluates(Value value, String expressionString, Context context) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java index 9d94ec0bc99..61b230ab390 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java @@ -69,12 +69,4 @@ public class SimplifierTestCase { assertEquals("a + (b + c) / 100000000.0", transformed.toString()); } - @Test - public void testSimplificationWithTensorConstants() throws ParseException { - new Simplifier().transform(new RankingExpression( - "sum(sum((tensorFromWeightedSet(query(wset_query),x)+" + - " tensorFromWeightedSet(attribute(wset),x)) * " + - " {{x:0,y:0}:54, {x:0,y:1} :69, {x:1,y:0} :72, {x:1,y:1} :93},x))")); - } - } diff --git a/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java b/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java index d70b55c66a2..a54f1971d21 100644 --- a/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java +++ b/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java @@ -20,7 +20,7 @@ public class ClusterState implements Cloneable { private Map<Node, NodeState> nodeStates = new TreeMap<>(); // TODO: Change to one count for distributor and one for storage, rather than an array - // TODO: Rename, this is not the highest node count but the highest index + // TODO: RenameFunction, this is not the highest node count but the highest index private ArrayList<Integer> nodeCount = new ArrayList<>(2); private String description = ""; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java index 3bda4159ca6..4fd743e4724 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java @@ -21,6 +21,8 @@ import java.util.function.UnaryOperator; @Beta public class MapTensor implements Tensor { + // TODO: Enforce that all addresses are dense (and then avoid storing keys in TensorAddress) + private final ImmutableSet<String> dimensions; private final ImmutableMap<TensorAddress, Double> cells; @@ -31,7 +33,7 @@ public class MapTensor implements Tensor { } /** Creates a sparse tensor */ - MapTensor(Set<String> dimensions, Map<TensorAddress, Double> cells) { + public MapTensor(Set<String> dimensions, Map<TensorAddress, Double> cells) { ensureValidDimensions(cells, dimensions); this.dimensions = ImmutableSet.copyOf(dimensions); this.cells = ImmutableMap.copyOf(cells); @@ -52,24 +54,41 @@ public class MapTensor implements Tensor { */ public static MapTensor from(String s) { s = s.trim(); - if ( s.startsWith("(")) - return fromTensorWithEmptyDimensions(s); - else if ( s.startsWith("{")) - return fromTensor(s, Collections.emptySet()); - else - throw new IllegalArgumentException("Excepted a string starting by { or (, got '" + s + "'"); + try { + if (s.startsWith("tensor(")) + return fromTypedTensor(s); + else if (s.startsWith("{")) + return fromUntypedTensor(s, Collections.emptySet()); + else + return fromNumber(Double.parseDouble(s)); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" + s + "'"); + } } - private static MapTensor fromTensorWithEmptyDimensions(String s) { + private static MapTensor fromTypedTensor(String s) { + if ( ! s.startsWith("tensor(")) throw tensorFormatException(s); + s = s.substring("tensor(".length()); + int typeSpecEnd = s.indexOf(")"); + if (typeSpecEnd < 0 ) throw tensorFormatException(s); + String typeSpec = s.substring(0, typeSpecEnd); + + Set<String> dimensions = new HashSet<>(); + for (String dimensionSpec : typeSpec.split(",")) { + dimensionSpec = dimensionSpec.trim(); + if ( ! dimensionSpec.endsWith("{}")) + throw new IllegalArgumentException("Only mapped dimensions ({}) are supported, got '" + dimensionSpec + "'"); + dimensions.add(dimensionSpec.substring(0, dimensionSpec.length() - 2)); + } + + s = s.substring(typeSpec.length() + 1); + if ( ! s.startsWith(":")) throw tensorFormatException(s); s = s.substring(1).trim(); - int multiplier = s.indexOf("*"); - if (multiplier < 0 || ! s.endsWith(")")) - throw new IllegalArgumentException("Expected a tensor on the form ({dimension:-,...}*{{cells}}), got '" + s + "'"); - MapTensor dimensionTensor = fromTensor(s.substring(0, multiplier).trim(), Collections.emptySet()); - return fromTensor(s.substring(multiplier + 1, s.length() - 1), dimensionTensor.dimensions()); + return fromUntypedTensor(s, dimensions); } - private static MapTensor fromTensor(String s, Set<String> additionalDimensions) { + private static MapTensor fromUntypedTensor(String s, Set<String> additionalDimensions) { s = s.trim().substring(1).trim(); ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); while (s.length() > 1) { @@ -94,6 +113,16 @@ public class MapTensor implements Tensor { dimensions.addAll(additionalDimensions); return new MapTensor(dimensions, cellMap); } + + private static MapTensor fromNumber(double number) { + ImmutableMap.Builder<TensorAddress, Double> singleCell = new ImmutableMap.Builder<>(); + singleCell.put(TensorAddress.empty, number); + return new MapTensor(ImmutableSet.of(), singleCell.build()); + } + + private static IllegalArgumentException tensorFormatException(String s) { + return new IllegalArgumentException("Expected a tensor on the form tensor(dimensionspec):content, but got '" + s + "'"); + } private static Double asDouble(TensorAddress address, String s) { try { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java b/vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java deleted file mode 100644 index 074742acee1..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor; - -import com.google.common.collect.ImmutableMap; - -import java.util.Map; -import java.util.Set; - -/** - * Computes a <i>match product</i>, see {@link Tensor#match} - * - * @author bratseth - */ -class MatchProduct { - - private final Set<String> dimensions; - private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); - - public MatchProduct(Tensor a, Tensor b) { - this.dimensions = TensorOperations.combineDimensions(a, b); - for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) { - Double sameValueInB = b.cells().get(aCell.getKey()); - if (sameValueInB != null) - cells.put(aCell.getKey(), aCell.getValue() * sameValueInB); - } - } - - /** Returns the result of taking this product */ - public MapTensor result() { - return new MapTensor(dimensions, cells.build()); - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 41882738e89..4b17f65ea21 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -2,18 +2,25 @@ package com.yahoo.tensor; import com.google.common.annotations.Beta; +import com.yahoo.tensor.functions.ConstantTensor; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.L1Normalize; +import com.yahoo.tensor.functions.L2Normalize; +import com.yahoo.tensor.functions.Matmul; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.Softmax; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.function.DoubleBinaryOperator; -import java.util.function.DoubleFunction; import java.util.function.DoubleUnaryOperator; -import java.util.function.UnaryOperator; +import java.util.function.Function; /** * A multidimensional array which can be used in computations. @@ -49,128 +56,74 @@ public interface Tensor { /** Returns the value of a cell, or NaN if this cell does not exist/have no value */ double get(TensorAddress address); - // ----------------- Level 0 functions + // ----------------- Primitive tensor functions - default Tensor map(Tensor tensor, DoubleUnaryOperator mapper) { - throw new UnsupportedOperationException("Not implemented"); + default Tensor map(DoubleUnaryOperator mapper) { + return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate(); } - default Tensor reduce(Tensor tensor, String dimension, - DoubleBinaryOperator reductor, Optional<DoubleBinaryOperator> postTransformation) { - throw new UnsupportedOperationException("Not implemented"); + /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */ + default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) { + return new Reduce(new ConstantTensor(this), aggregator, dimensions).evaluate(); } - default Tensor join(Tensor tensorA, Tensor tensorB, DoubleBinaryOperator combinator) { - throw new UnsupportedOperationException("Not implemented"); + default Tensor join(Tensor argument, DoubleBinaryOperator combinator) { + return new Join(new ConstantTensor(this), new ConstantTensor(argument), combinator).evaluate(); } - // ----------------- Old stuff - /** - * Returns the <i>sparse tensor product</i> of this tensor and the argument tensor. - * This is the all-to-all combinations of cells in the argument tenors, except the combinations - * which have conflicting labels for the same dimension. The value of each combination is the product - * of the values of the two input cells. The dimensions of the tensor product is the set union of the - * dimensions of the argument tensors. - * <p> - * If there are no overlapping dimensions this is the regular tensor product. - * If the two tensors have exactly the same dimensions this is the Hadamard product. - * <p> - * The sparse tensor product is associative and commutative. - * - * @param argument the tensor to multiply by this - * @return the resulting tensor. - */ - default Tensor multiply(Tensor argument) { - return new TensorProduct(this, argument).result(); + default Tensor rename(String fromDimension, String toDimension) { + return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension), + Collections.singletonList(toDimension)).evaluate(); } - /** - * Returns the <i>match product</i> of two tensors. - * This returns a tensor which contains the <i>matching</i> cells in the two tensors, with their - * values multiplied. - * <p> - * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, - * and have the value undefined for any non-shared dimension. - * <p> - * The dimensions of the resulting tensor is the set intersection of the two argument tensors. - * <p> - * If the two tensors have exactly the same dimensions, this is the Hadamard product. - */ - default Tensor match(Tensor argument) { - return new MatchProduct(this, argument).result(); + default Tensor rename(List<String> fromDimensions, List<String> toDimensions) { + return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate(); } - - /** - * Returns a tensor which contains the cells of both argument tensors, where the value for - * any <i>matching</i> cell is the min of the two possible values. - * <p> - * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, - * and have the value undefined for any non-shared dimension. - */ - default Tensor min(Tensor argument) { - return new TensorMin(this, argument).result(); + + static Tensor from(TensorType type, Function<List<Integer>, Double> valueSupplier) { + return new Generate(type, valueSupplier).evaluate(); } - - /** - * Returns a tensor which contains the cells of both argument tensors, where the value for - * any <i>matching</i> cell is the max of the two possible values. - * <p> - * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, - * and have the value undefined for any non-shared dimension. - */ - default Tensor max(Tensor argument) { - return new TensorMax(this, argument).result(); + + // ----------------- Composite tensor functions which have a defined primitive mapping + + default Tensor l1Normalize(String dimension) { + return new L1Normalize(new ConstantTensor(this), dimension).evaluate(); } - /** - * Returns a tensor which contains the cells of both argument tensors, where the value for - * any <i>matching</i> cell is the sum of the two possible values. - * <p> - * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, - * and have the value undefined for any non-shared dimension. - */ - default Tensor add(Tensor argument) { - return new TensorSum(this, argument).result(); + default Tensor l2Normalize(String dimension) { + return new L2Normalize(new ConstantTensor(this), dimension).evaluate(); } - /** - * Returns a tensor which contains the cells of both argument tensors, where the value for - * any <i>matching</i> cell is the difference of the two possible values. - * <p> - * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, - * and have the value undefined for any non-shared dimension. - */ - default Tensor subtract(Tensor argument) { - return new TensorDifference(this, argument).result(); + default Tensor matmul(Tensor argument, String dimension) { + return new Matmul(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate(); } - /** - * Returns a tensor with the same cells as this and the given function is applied to all its cell values. - * - * @param function the function to apply to all cells - * @return the tensor with the function applied to all the cells of this - */ - default Tensor apply(UnaryOperator<Double> function) { - return new TensorFunction(this, function).result(); + default Tensor softmax(String dimension) { + return new Softmax(new ConstantTensor(this), dimension).evaluate(); } - /** - * Returns a tensor with the given dimension removed and cells which contains the sum of the values - * in the removed dimension. - */ - default Tensor sum(String dimension) { - return new TensorDimensionSum(dimension, this).result(); - } + // ----------------- Composite tensor functions mapped to primitives here on the fly - /** - * Returns the sum of all the cells of this tensor. - */ - default double sum() { - double sum = 0; - for (Map.Entry<TensorAddress, Double> cell : cells().entrySet()) - sum += cell.getValue(); - return sum; - } + default Tensor multiply(Tensor argument) { return join(argument, (a, b) -> (a * b )); } + default Tensor add(Tensor argument) { return join(argument, (a, b) -> (a + b )); } + default Tensor divide(Tensor argument) { return join(argument, (a, b) -> (a / b )); } + default Tensor subtract(Tensor argument) { return join(argument, (a, b) -> (a - b )); } + default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); } + default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); } + default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); } + default Tensor larger(Tensor argument) { return join(argument, (a, b) -> ( a > b ? 1.0 : 0.0)); } + default Tensor largerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a >= b ? 1.0 : 0.0)); } + default Tensor smaller(Tensor argument) { return join(argument, (a, b) -> ( a < b ? 1.0 : 0.0)); } + default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); } + default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); } + default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); } + + default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); } + default Tensor count(List<String> dimensions) { return reduce(Reduce.Aggregator.count, dimensions); } + default Tensor max(List<String> dimensions) { return reduce(Reduce.Aggregator.max, dimensions); } + default Tensor min(List<String> dimensions) { return reduce(Reduce.Aggregator.min, dimensions); } + default Tensor prod(List<String> dimensions) { return reduce(Reduce.Aggregator.prod, dimensions); } + default Tensor sum(List<String> dimensions) { return reduce(Reduce.Aggregator.sum, dimensions); } /** * Returns true if the given tensor is mathematically equal to this: @@ -226,19 +179,28 @@ public interface Tensor { * @return the tensor on the standard string format */ static String toStandardString(Tensor tensor) { - Set<String> emptyDimensions = emptyDimensions(tensor); - if (emptyDimensions.size() > 0) // explicitly list empty dimensions - return "( " + unitTensorWithDimensions(emptyDimensions) + " * " + contentToString(tensor) + " )"; + if ( emptyDimensions(tensor).size() > 0) // explicitly output type TODO: Always do that + return typeToString(tensor) + ":" + contentToString(tensor); else return contentToString(tensor); } + static String typeToString(Tensor tensor) { + if (tensor.dimensions().isEmpty()) return "tensor()"; + StringBuilder b = new StringBuilder("tensor("); + for (String dimension : tensor.dimensions()) + b.append(dimension).append("{},"); + b.setLength(b.length() -1); + b.append(")"); + return b.toString(); + } + static String contentToString(Tensor tensor) { - List<Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet()); - Collections.sort(cellEntries, Map.Entry.<TensorAddress, Double>comparingByKey()); + List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet()); + Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey()); StringBuilder b = new StringBuilder("{"); - for (Map.Entry<TensorAddress, Double> cell : cellEntries) { + for (java.util.Map.Entry<TensorAddress, Double> cell : cellEntries) { b.append(cell.getKey()).append(":").append(cell.getValue()); b.append(","); } @@ -259,8 +221,4 @@ public interface Tensor { return emptyDimensions; } - static String unitTensorWithDimensions(Set<String> dimensions) { - return new MapTensor(Collections.singletonMap(TensorAddress.emptyWithDimensions(dimensions), 1.0)).toString(); - } - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 11c6a5f6685..e3c089de071 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -8,12 +8,11 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; /** * An immutable address to a tensor cell. - * This is sparse: Only dimensions which have a different label than "undefined" are - * explicitly included. * <p> * Tensor addresses are ordered by increasing size primarily, and by the natural order of the elements in sorted * order secondarily. @@ -66,14 +65,6 @@ public final class TensorAddress implements Comparable<TensorAddress> { return TensorAddress.fromSorted(elements); } - /** Creates an empty address with a set of dimensions */ - public static TensorAddress emptyWithDimensions(Set<String> dimensions) { - List<Element> elements = new ArrayList<>(dimensions.size()); - for (String dimension : dimensions) - elements.add(new Element(dimension, Element.undefinedLabel)); - return TensorAddress.fromUnsorted(elements); - } - /** Returns an immutable list of the elements of this address in sorted order */ public List<Element> elements() { return elements; } @@ -93,6 +84,14 @@ public final class TensorAddress implements Comparable<TensorAddress> { return dimensions; } + /** Returns the label at the given dimension, or empty if this dimension is not present */ + public Optional<String> labelOfDimension(String dimension) { + for (TensorAddress.Element element : elements) + if (element.dimension().equals(dimension)) + return Optional.of(element.label()); + return Optional.empty(); + } + @Override public int compareTo(TensorAddress other) { int sizeComparison = Integer.compare(this.elements.size(), other.elements.size()); @@ -123,7 +122,6 @@ public final class TensorAddress implements Comparable<TensorAddress> { public String toString() { StringBuilder b = new StringBuilder("{"); for (TensorAddress.Element element : elements) { - //if (element.label() == Element.undefinedLabel) continue; b.append(element.toString()); b.append(","); } @@ -136,18 +134,13 @@ public final class TensorAddress implements Comparable<TensorAddress> { /** A tensor address element. Elements have the lexical order of the dimensions as natural order. */ public static class Element implements Comparable<Element> { - static final String undefinedLabel = "-"; - private final String dimension; private final String label; private final int hashCode; public Element(String dimension, String label) { this.dimension = dimension; - if (label.equals(undefinedLabel)) - this.label = undefinedLabel; - else - this.label = label; + this.label = label; this.hashCode = dimension.hashCode() + label.hashCode(); } @@ -175,9 +168,7 @@ public final class TensorAddress implements Comparable<TensorAddress> { @Override public String toString() { - StringBuilder b = new StringBuilder(); - b.append(dimension).append(":").append(label); - return b.toString(); + return dimension + ":" + label; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java deleted file mode 100644 index ceb003b1615..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor; - -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -/** - * Takes the difference between two tensors, see {@link Tensor#subtract} - * - * @author bratseth - */ -class TensorDifference { - - private final Set<String> dimensions; - private final Map<TensorAddress, Double> cells = new HashMap<>(); - - public TensorDifference(Tensor a, Tensor b) { - this.dimensions = TensorOperations.combineDimensions(a, b); - cells.putAll(a.cells()); - for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) - cells.put(bCell.getKey(), a.cells().getOrDefault(bCell.getKey(), 0d) - bCell.getValue()); - } - - /** Returns the result of taking this sum */ - public Tensor result() { - return new MapTensor(dimensions, cells); - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java deleted file mode 100644 index d15e5092476..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor; - -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -/** - * Takes the max of each cell of two tensors, see {@link Tensor#max} - * - * @author bratseth - */ -class TensorMax { - - private final Set<String> dimensions; - private final Map<TensorAddress, Double> cells = new HashMap<>(); - - public TensorMax(Tensor a, Tensor b) { - dimensions = TensorOperations.combineDimensions(a, b); - cells.putAll(a.cells()); - for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { - Double aValue = a.cells().get(bCell.getKey()); - if (aValue == null) - cells.put(bCell.getKey(), bCell.getValue()); - else - cells.put(bCell.getKey(), Math.max(aValue, bCell.getValue())); - } - } - - /** Returns the result of taking this sum */ - public Tensor result() { - return new MapTensor(dimensions, cells); - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java deleted file mode 100644 index e389dea3883..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor; - -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -/** - * Takes the min of each cell of two tensors, see {@link Tensor#min} - * - * @author bratseth - */ -class TensorMin { - - private final Set<String> dimensions; - private final Map<TensorAddress, Double> cells = new HashMap<>(); - - public TensorMin(Tensor a, Tensor b) { - dimensions = TensorOperations.combineDimensions(a, b); - cells.putAll(a.cells()); - for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { - Double aValue = a.cells().get(bCell.getKey()); - if (aValue == null) - cells.put(bCell.getKey(), bCell.getValue()); - else - cells.put(bCell.getKey(), Math.min(aValue, bCell.getValue())); - } - } - - /** Returns the result of taking this sum */ - public Tensor result() { return new MapTensor(dimensions, cells); } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java deleted file mode 100644 index aca306b914c..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor; - -import com.google.common.collect.ImmutableSet; - -import java.util.Set; - -/** - * Functions on tensors - * - * @author bratseth - */ -class TensorOperations { - - /** - * A utility method which returns an ummutable set of the union of the dimensions - * of the two argument tensors. - * - * @return the combined dimensions as an unmodifiable set - */ - static Set<String> combineDimensions(Tensor a, Tensor b) { - ImmutableSet.Builder<String> setBuilder = new ImmutableSet.Builder<>(); - setBuilder.addAll(a.dimensions()); - setBuilder.addAll(b.dimensions()); - return setBuilder.build(); - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java deleted file mode 100644 index 221bd985380..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor; - -import com.google.common.collect.ImmutableMap; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.ListIterator; -import java.util.Map; -import java.util.Set; - -/** - * Computes a <i>sparse tensor product</i>, see {@link Tensor#multiply} - * - * @author bratseth - */ -class TensorProduct { - - private final Set<String> dimensionsA, dimensionsB; - - private final Set<String> dimensions; - private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); - - public TensorProduct(Tensor a, Tensor b) { - dimensionsA = a.dimensions(); - dimensionsB = b.dimensions(); - - // Dimension product - dimensions = TensorOperations.combineDimensions(a, b); - - // Cell product (slow baseline implementation) - for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) { - for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { - TensorAddress combinedAddress = combine(aCell.getKey(), bCell.getKey()); - if (combinedAddress == null) continue; // not combinable - cells.put(combinedAddress, aCell.getValue() * bCell.getValue()); - } - } - } - - private TensorAddress combine(TensorAddress a, TensorAddress b) { - List<TensorAddress.Element> combined = new ArrayList<>(); - combined.addAll(dense(a, dimensionsA)); - combined.addAll(dense(b, dimensionsB)); - Collections.sort(combined); - TensorAddress.Element previous = null; - for (ListIterator<TensorAddress.Element> i = combined.listIterator(); i.hasNext(); ) { - TensorAddress.Element current = i.next(); - if (previous != null && previous.dimension().equals(current.dimension())) { // an overlapping dimension - if (previous.label().equals(current.label())) - i.remove(); // a match: remove the duplicate - else - return null; // no match: a combination isn't viable - } - previous = current; - } - return TensorAddress.fromSorted(sparse(combined)); - } - - /** - * Returns a set of tensor elements which contains an entry for each dimension including "undefined" values - * (which are not present in the sparse elements list). - */ - private List<TensorAddress.Element> dense(TensorAddress sparse, Set<String> dimensions) { - if (sparse.elements().size() == dimensions.size()) return sparse.elements(); - - List<TensorAddress.Element> dense = new ArrayList<>(sparse.elements()); - for (String dimension : dimensions) { - if ( ! sparse.hasDimension(dimension)) - dense.add(new TensorAddress.Element(dimension, TensorAddress.Element.undefinedLabel)); - } - return dense; - } - - /** - * Removes any "undefined" entries from the given elements. - */ - private List<TensorAddress.Element> sparse(List<TensorAddress.Element> dense) { - List<TensorAddress.Element> sparse = new ArrayList<>(); - for (TensorAddress.Element element : dense) { - if ( ! element.label().equals(TensorAddress.Element.undefinedLabel)) - sparse.add(element); - } - return sparse; - } - - /** Returns the result of taking this product */ - public Tensor result() { - return new MapTensor(dimensions, cells.build()); - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java deleted file mode 100644 index 85dfa289bd3..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor; - -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -/** - * Takes the sum of two tensors, see {@link Tensor#add} - * - * @author bratseth - */ -class TensorSum { - - private final Set<String> dimensions; - private final Map<TensorAddress, Double> cells = new HashMap<>(); - - public TensorSum(Tensor a, Tensor b) { - dimensions = TensorOperations.combineDimensions(a, b); - cells.putAll(a.cells()); - for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { - cells.put(bCell.getKey(), a.cells().getOrDefault(bCell.getKey(), 0d) + bCell.getValue()); - } - } - - /** Returns the result of taking this sum */ - public Tensor result() { return new MapTensor(dimensions, cells); } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 23cdc0e6051..31454e28baf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -1,5 +1,7 @@ package com.yahoo.tensor.functions; +import com.yahoo.tensor.Tensor; + /** * A composite tensor function is a tensor function which can be expressed (less tersely) * as a tree of primitive tensor functions. @@ -8,4 +10,8 @@ package com.yahoo.tensor.functions; */ public abstract class CompositeTensorFunction extends TensorFunction { + /** Evaluates this by first converting it to a primitive function */ + @Override + public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java deleted file mode 100644 index 113247be3bb..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java +++ /dev/null @@ -1,24 +0,0 @@ -package com.yahoo.tensor.functions; - -import com.yahoo.tensor.MapTensor; - -/** - * A function which returns a constant tensor. - * - * @author bratseth - */ -public class Constant extends PrimitiveTensorFunction { - - private final MapTensor constant; - - public Constant(String tensorString) { - this.constant = MapTensor.from(tensorString); - } - - @Override - public PrimitiveTensorFunction toPrimitive() { return this; } - - @Override - public String toString() { return constant.toString(); } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java new file mode 100644 index 00000000000..0727579a331 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -0,0 +1,38 @@ +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.MapTensor; +import com.yahoo.tensor.Tensor; + +import java.util.Collections; +import java.util.List; + +/** + * A function which returns a constant tensor. + * + * @author bratseth + */ +public class ConstantTensor extends PrimitiveTensorFunction { + + private final Tensor constant; + + public ConstantTensor(String tensorString) { + this.constant = MapTensor.from(tensorString); + } + + public ConstantTensor(Tensor tensor) { + this.constant = tensor; + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public PrimitiveTensorFunction toPrimitive() { return this; } + + @Override + public Tensor evaluate(EvaluationContext context) { return constant; } + + @Override + public String toString(ToStringContext context) { return constant.toString(); } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java new file mode 100644 index 00000000000..24a4c61a58c --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java @@ -0,0 +1,14 @@ +package com.yahoo.tensor.functions; + +/** + * An evaluation context which is passed down to all nested functions during evaluation. + * The default implementation is empty as this library does not in itself have any need for a + * context. + * + * @author bratseth + */ +public interface EvaluationContext { + + static EvaluationContext empty() { return new EvaluationContext() {}; } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java new file mode 100644 index 00000000000..c0e5776bf48 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -0,0 +1,57 @@ +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.function.Function; + +/** + * An indexed tensor whose values are generated by a function + * + * @author bratseth + */ +public class Generate extends PrimitiveTensorFunction { + + private final TensorType type; + private final Function<List<Integer>, Double> generator; + + /** + * Creates a generated tensor + * + * @param type the type of the tensor + * @param generator the function generating values from a list of ints specifying the indexes of the + * tensor cell which will receive the value + * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound + */ + public Generate(TensorType type, Function<List<Integer>, Double> generator) { + Objects.requireNonNull(type, "The argument tensor type cannot be null"); + Objects.requireNonNull(generator, "The argument function cannot be null"); + validateType(type); + this.type = type; + this.generator = generator; + } + + private void validateType(TensorType type) { + for (TensorType.Dimension dimension : type.dimensions()) + if (dimension.type() != TensorType.Dimension.Type.indexedBound) + throw new IllegalArgumentException("A generated tensor can only have indexed bound dimensions"); + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public PrimitiveTensorFunction toPrimitive() { return this; } + + @Override + public Tensor evaluate(EvaluationContext context) { + throw new UnsupportedOperationException("Not implemented"); // TODO + } + + @Override + public String toString(ToStringContext context) { return type + "(" + generator + ")"; } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 4d945963fdf..323da5906c3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -1,9 +1,24 @@ package com.yahoo.tensor.functions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.yahoo.tensor.MapTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; import java.util.function.DoubleBinaryOperator; /** - * The join tensor function. + * The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells + * given by the cross product of the cells of the given tensors, having as values the value produced by + * applying the given combinator function on the values from the two source cells. * * @author bratseth */ @@ -13,6 +28,9 @@ public class Join extends PrimitiveTensorFunction { private final DoubleBinaryOperator combinator; public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) { + Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); + Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); + Objects.requireNonNull(combinator, "The combinator function cannot be null"); this.argumentA = argumentA; this.argumentB = argumentB; this.combinator = combinator; @@ -21,15 +39,60 @@ public class Join extends PrimitiveTensorFunction { public TensorFunction argumentA() { return argumentA; } public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } - + + @Override + public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); } + @Override public PrimitiveTensorFunction toPrimitive() { return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); } - + + @Override + public String toString(ToStringContext context) { + return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")"; + } + + private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); + @Override - public String toString() { - return "join(" + argumentA.toString() + ", " + argumentB.toString() + ", lambda(a, b) (...))"; + public Tensor evaluate(EvaluationContext context) { + Tensor a = argumentA.evaluate(context); + Tensor b = argumentB.evaluate(context); + + // Dimension product + Set<String> dimensions = combineDimensions(a, b); + + // Cell product (slow baseline implementation) + ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); + for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) { + for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { + TensorAddress combinedAddress = combineAddresses(aCell.getKey(), bCell.getKey()); + if (combinedAddress == null) continue; // not combinable + cells.put(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue())); + } + } + + return new MapTensor(dimensions, cells.build()); } + private Set<String> combineDimensions(Tensor a, Tensor b) { + ImmutableSet.Builder<String> setBuilder = new ImmutableSet.Builder<>(); + setBuilder.addAll(a.dimensions()); + setBuilder.addAll(b.dimensions()); + return setBuilder.build(); + } + + private TensorAddress combineAddresses(TensorAddress a, TensorAddress b) { + List<TensorAddress.Element> combined = new ArrayList<>(a.elements()); + for (TensorAddress.Element bElement : b.elements()) { + Optional<String> aLabel = a.labelOfDimension(bElement.dimension()); + if ( ! aLabel.isPresent()) + combined.add(bElement); + else if ( ! aLabel.get().equals(bElement.label())) + return null; // incompatible + } + return TensorAddress.fromUnsorted(combined); + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java new file mode 100644 index 00000000000..4467b378b3f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -0,0 +1,36 @@ +package com.yahoo.tensor.functions; + +import java.util.Collections; +import java.util.List; + +/** + * @author bratseth + */ +public class L1Normalize extends CompositeTensorFunction { + + private final TensorFunction argument; + private final String dimension; + + public L1Normalize(TensorFunction argument, String dimension) { + this.argument = argument; + this.dimension = dimension; + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + + @Override + public PrimitiveTensorFunction toPrimitive() { + TensorFunction primitiveArgument = argument.toPrimitive(); + // join(x, reduce(x, "avg", "dimension"), f(x,y) (x / y)) + return new Join(primitiveArgument, + new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension), + ScalarFunctions.divide()); + } + + @Override + public String toString(ToStringContext context) { + return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java new file mode 100644 index 00000000000..0e96b43bd22 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -0,0 +1,38 @@ +package com.yahoo.tensor.functions; + +import java.util.Collections; +import java.util.List; + +/** + * @author bratseth + */ +public class L2Normalize extends CompositeTensorFunction { + + private final TensorFunction argument; + private final String dimension; + + public L2Normalize(TensorFunction argument, String dimension) { + this.argument = argument; + this.dimension = dimension; + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + + @Override + public PrimitiveTensorFunction toPrimitive() { + TensorFunction primitiveArgument = argument.toPrimitive(); + return new Join(primitiveArgument, + new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()), + Reduce.Aggregator.sum, + dimension), + ScalarFunctions.sqrt()), + ScalarFunctions.divide()); + } + + @Override + public String toString(ToStringContext context) { + return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 22dd08504d7..5db88953c64 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -1,10 +1,17 @@ package com.yahoo.tensor.functions; -import java.util.function.DoubleBinaryOperator; +import com.google.common.collect.ImmutableMap; +import com.yahoo.tensor.MapTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; import java.util.function.DoubleUnaryOperator; /** - * The join tensor function. + * The <i>map</i> tensor function produces a tensor where the given function is applied on each cell value. * * @author bratseth */ @@ -14,6 +21,8 @@ public class Map extends PrimitiveTensorFunction { private final DoubleUnaryOperator mapper; public Map(TensorFunction argument, DoubleUnaryOperator mapper) { + Objects.requireNonNull(argument, "The argument tensor cannot be null"); + Objects.requireNonNull(mapper, "The argument function cannot be null"); this.argument = argument; this.mapper = mapper; } @@ -22,13 +31,25 @@ public class Map extends PrimitiveTensorFunction { public DoubleUnaryOperator mapper() { return mapper; } @Override + public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + + @Override public PrimitiveTensorFunction toPrimitive() { return new Map(argument.toPrimitive(), mapper); } @Override - public String toString() { - return "map(" + argument.toString() + ", lambda(a) (...))"; + public Tensor evaluate(EvaluationContext context) { + Tensor argument = argument().evaluate(context); + ImmutableMap.Builder<TensorAddress, Double> mappedCells = new ImmutableMap.Builder<>(); + for (java.util.Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) + mappedCells.put(cell.getKey(), mapper.applyAsDouble(cell.getValue())); + return new MapTensor(argument.dimensions(), mappedCells.build()); + } + + @Override + public String toString(ToStringContext context) { + return "map(" + argument.toString(context) + ", " + mapper + ")"; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java new file mode 100644 index 00000000000..4492ab083d4 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -0,0 +1,38 @@ +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * @author bratseth + */ +public class Matmul extends CompositeTensorFunction { + + private final TensorFunction argument1, argument2; + private final String dimension; + + public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) { + this.argument1 = argument1; + this.argument2 = argument2; + this.dimension = dimension; + } + + @Override + public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); } + + @Override + public PrimitiveTensorFunction toPrimitive() { + TensorFunction primitiveArgument1 = argument1.toPrimitive(); + TensorFunction primitiveArgument2 = argument2.toPrimitive(); + return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + dimension); + } + + @Override + public String toString(ToStringContext context) { + return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java index 9c0c9abaeb7..91e58f4bf3b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java @@ -1,5 +1,7 @@ package com.yahoo.tensor.functions; +import com.yahoo.tensor.Tensor; + /** * A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions. * All tensor implementations must implement all primitive tensor functions. @@ -8,4 +10,5 @@ package com.yahoo.tensor.functions; * @author bratseth */ public abstract class PrimitiveTensorFunction extends TensorFunction { + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java deleted file mode 100644 index 09038a294ce..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java +++ /dev/null @@ -1,27 +0,0 @@ -package com.yahoo.tensor.functions; - -/** - * The product tensor function - * - * @author bratseth - */ -public class Product extends CompositeTensorFunction { - - private final TensorFunction argumentA, argumentB; - - public Product(TensorFunction argumentA, TensorFunction argumentB) { - this.argumentA = argumentA; - this.argumentB = argumentB; - } - - @Override - public PrimitiveTensorFunction toPrimitive() { - return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), (a, b) -> a * b); - } - - @Override - public String toString() { - return "product(" + argumentA.toString() + ", " + argumentB.toString() + ")"; - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 4b306d376a6..ef18cb61b17 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -1,38 +1,246 @@ package com.yahoo.tensor.functions; -import java.util.Optional; -import java.util.function.DoubleBinaryOperator; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.yahoo.tensor.MapTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; /** - * The reduce tensor function. + * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions + * are collapsed to a single value using an aggregator function. * * @author bratseth */ public class Reduce extends PrimitiveTensorFunction { + public enum Aggregator { avg, count, prod, sum, max, min; } + private final TensorFunction argument; - private final String dimension; - private final DoubleBinaryOperator reductor; - private final Optional<DoubleBinaryOperator> postTransformation; + private final List<String> dimensions; + private final Aggregator aggregator; + + /** Creates a reduce function reducing aLL dimensions */ + public Reduce(TensorFunction argument, Aggregator aggregator) { + this(argument, aggregator, Collections.emptyList()); + } - public Reduce(TensorFunction argument, String dimension, - DoubleBinaryOperator reductor, Optional<DoubleBinaryOperator> postTransformation) { + /** Creates a reduce function reducing a single dimension */ + public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) { + this(argument, aggregator, Collections.singletonList(dimension)); + } + + /** + * Creates a reduce function. + * + * @param argument the tensor to reduce + * @param aggregator the aggregator function to use + * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced, + * producing a dimensionless tensor (a scalar). + * @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor + */ + public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) { + Objects.requireNonNull(argument, "The argument tensor cannot be null"); + Objects.requireNonNull(aggregator, "The aggregator cannot be null"); + Objects.requireNonNull(dimensions, "The dimensions cannot be null"); this.argument = argument; - this.dimension = dimension; - this.reductor = reductor; - this.postTransformation = postTransformation; + this.aggregator = aggregator; + this.dimensions = ImmutableList.copyOf(dimensions); } public TensorFunction argument() { return argument; } @Override + public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + + @Override public PrimitiveTensorFunction toPrimitive() { - return new Reduce(argument.toPrimitive(), dimension, reductor, postTransformation); + return new Reduce(argument.toPrimitive(), aggregator, dimensions); + } + + @Override + public String toString(ToStringContext context) { + return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; + } + + private String commaSeparated(List<String> list) { + StringBuilder b = new StringBuilder(); + for (String element : list) + b.append(", ").append(element); + return b.toString(); } @Override - public String toString() { - return "reduce(" + argument.toString() + ", " + dimension + ", lambda(a, b) (...), lambda(a, b) (...))"; + public Tensor evaluate(EvaluationContext context) { + Tensor argument = this.argument.evaluate(context); + + if ( ! dimensions.isEmpty() && ! argument.dimensions().containsAll(dimensions)) + throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + + dimensions + ": Not all those dimensions are present in this tensor"); + + if (dimensions.isEmpty() || dimensions.size() == argument.dimensions().size()) + return reduceAll(argument); + + // Reduce dimensions + Set<String> reducedDimensions = new HashSet<>(argument.dimensions()); + reducedDimensions.removeAll(dimensions); + + // Reduce cells + Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); + for (Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) { + TensorAddress reducedAddress = reduceDimensions(cell.getKey(), reducedDimensions); + aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); + aggregatingCells.get(reducedAddress).aggregate(cell.getValue()); + } + ImmutableMap.Builder<TensorAddress, Double> reducedCells = new ImmutableMap.Builder<>(); + for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet()) + reducedCells.put(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); + return new MapTensor(reducedDimensions, reducedCells.build()); + } + + private TensorAddress reduceDimensions(TensorAddress address, Set<String> reducedDimensions) { + return TensorAddress.fromSorted(address.elements().stream() + .filter(e -> reducedDimensions.contains(e.dimension())) + .collect(Collectors.toList())); + } + + private Tensor reduceAll(Tensor argument) { + ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); + for (Double cellValue : argument.cells().values()) + valueAggregator.aggregate(cellValue); + return new MapTensor(ImmutableMap.of(TensorAddress.empty, valueAggregator.aggregatedValue())); + } + + private static abstract class ValueAggregator { + + public static ValueAggregator ofType(Aggregator aggregator) { + switch (aggregator) { + case avg : return new AvgAggregator(); + case count : return new CountAggregator(); + case prod : return new ProdAggregator(); + case sum : return new SumAggregator(); + case max : return new MaxAggregator(); + case min : return new MinAggregator(); + default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented"); + } + + } + + /** Add a new value to those aggregated by this */ + public abstract void aggregate(double value); + + /** Returns the value aggregated by this */ + public abstract double aggregatedValue(); + + } + + private static class AvgAggregator extends ValueAggregator { + + private int valueCount = 0; + private double valueSum = 0.0; + + @Override + public void aggregate(double value) { + valueCount++; + valueSum+= value; + } + + @Override + public double aggregatedValue() { + return valueSum / valueCount; + } + + } + + private static class CountAggregator extends ValueAggregator { + + private int valueCount = 0; + + @Override + public void aggregate(double value) { + valueCount++; + } + + @Override + public double aggregatedValue() { + return valueCount; + } + + } + + private static class ProdAggregator extends ValueAggregator { + + private double valueProd = 1.0; + + @Override + public void aggregate(double value) { + valueProd *= value; + } + + @Override + public double aggregatedValue() { + return valueProd; + } + + } + + private static class SumAggregator extends ValueAggregator { + + private double valueSum = 0.0; + + @Override + public void aggregate(double value) { + valueSum += value; + } + + @Override + public double aggregatedValue() { + return valueSum; + } + + } + + private static class MaxAggregator extends ValueAggregator { + + private double maxValue = Double.MIN_VALUE; + + @Override + public void aggregate(double value) { + if (value > maxValue) + maxValue = value; + } + + @Override + public double aggregatedValue() { + return maxValue; + } + + } + + private static class MinAggregator extends ValueAggregator { + + private double minValue = Double.MAX_VALUE; + + @Override + public void aggregate(double value) { + if (value < minValue) + minValue = value; + } + + @Override + public double aggregatedValue() { + return minValue; + } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java new file mode 100644 index 00000000000..05af86c33e8 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -0,0 +1,100 @@ +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.yahoo.tensor.MapTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names. + * + * @author bratseth + */ +public class Rename extends PrimitiveTensorFunction { + + private final TensorFunction argument; + private final List<String> fromDimensions; + private final List<String> toDimensions; + + public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) { + Objects.requireNonNull(argument, "The argument tensor cannot be null"); + Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null"); + Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null"); + if (fromDimensions.size() < 1) + throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension"); + if (fromDimensions.size() != toDimensions.size()) + throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " + + fromDimensions.size() + " and " + toDimensions.size()); + this.argument = argument; + this.fromDimensions = ImmutableList.copyOf(fromDimensions); + this.toDimensions = ImmutableList.copyOf(toDimensions); + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + + @Override + public PrimitiveTensorFunction toPrimitive() { return this; } + + @Override + public Tensor evaluate(EvaluationContext context) { + Tensor tensor = argument.evaluate(context); + Map<String, String> fromToMap = fromToMap(); + Set<String> renamedDimensions = tensor.dimensions().stream() + .map((d) -> fromToMap.getOrDefault(d, d)) + .collect(Collectors.toSet()); + + ImmutableMap.Builder<TensorAddress, Double> renamedCells = new ImmutableMap.Builder<>(); + for (Map.Entry<TensorAddress, Double> cell : tensor.cells().entrySet()) { + TensorAddress renamedAddress = rename(cell.getKey(), fromToMap); + renamedCells.put(renamedAddress, cell.getValue()); + } + return new MapTensor(renamedDimensions, renamedCells.build()); + } + + private TensorAddress rename(TensorAddress address, Map<String, String> fromToMap) { + List<TensorAddress.Element> renamedElements = new ArrayList<>(); + for (TensorAddress.Element element : address.elements()) { + String toDimension = fromToMap.get(element.dimension()); + if (toDimension == null) + renamedElements.add(element); + else + renamedElements.add(new TensorAddress.Element(toDimension, element.label())); + } + return TensorAddress.fromUnsorted(renamedElements); + } + + @Override + public String toString(ToStringContext context) { + return "rename(" + argument.toString(context) + ", " + + toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; + } + + private Map<String, String> fromToMap() { + Map<String, String> map = new HashMap<>(); + for (int i = 0; i < fromDimensions.size(); i++) + map.put(fromDimensions.get(i), toDimensions.get(i)); + return map; + } + + private String toVectorString(List<String> elements) { + if (elements.size() == 1) + return elements.get(0); + StringBuilder b = new StringBuilder("["); + for (String element : elements) + b.append(element).append(", "); + b.setLength(b.length() - 2); + return b.toString(); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java new file mode 100644 index 00000000000..9438c6c533a --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -0,0 +1,81 @@ +package com.yahoo.tensor.functions; + +import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; + +/** + * Factory of scalar Java functions. + * The purpose of this is to embellish anonymous functions with a runtime type + * such that they can be inspected and will return a parseable toString. + * + * @author bratseth + */ +public class ScalarFunctions { + + public static DoubleBinaryOperator add() { return new Addition(); } + public static DoubleBinaryOperator multiply() { return new Multiplication(); } + public static DoubleBinaryOperator divide() { return new Division(); } + public static DoubleUnaryOperator square() { return new Square(); } + public static DoubleUnaryOperator sqrt() { return new Sqrt(); } + public static DoubleUnaryOperator exp() { return new Exponent(); } + + public static class Addition implements DoubleBinaryOperator { + + @Override + public double applyAsDouble(double left, double right) { return left + right; } + + @Override + public String toString() { return "f(a,b)(a + b)"; } + + } + + public static class Multiplication implements DoubleBinaryOperator { + + @Override + public double applyAsDouble(double left, double right) { return left * right; } + + @Override + public String toString() { return "f(a,b)(a * b)"; } + + } + + public static class Division implements DoubleBinaryOperator { + + @Override + public double applyAsDouble(double left, double right) { return left / right; } + + @Override + public String toString() { return "f(a,b)(a / b)"; } + } + + public static class Square implements DoubleUnaryOperator { + + @Override + public double applyAsDouble(double operand) { return operand * operand; } + + @Override + public String toString() { return "f(a)(a * a)"; } + + } + + public static class Sqrt implements DoubleUnaryOperator { + + @Override + public double applyAsDouble(double operand) { return Math.sqrt(operand); } + + @Override + public String toString() { return "f(a)(sqrt(a))"; } + + } + + public static class Exponent implements DoubleUnaryOperator { + + @Override + public double applyAsDouble(double operand) { return Math.exp(operand); } + + @Override + public String toString() { return "f(a)(exp(a))"; } + + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java new file mode 100644 index 00000000000..b05b8172b42 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -0,0 +1,37 @@ +package com.yahoo.tensor.functions; + +import java.util.Collections; +import java.util.List; + +/** + * @author bratseth + */ +public class Softmax extends CompositeTensorFunction { + + private final TensorFunction argument; + private final String dimension; + + public Softmax(TensorFunction argument, String dimension) { + this.argument = argument; + this.dimension = dimension; + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + + @Override + public PrimitiveTensorFunction toPrimitive() { + TensorFunction primitiveArgument = argument.toPrimitive(); + return new Join(new Map(primitiveArgument, ScalarFunctions.exp()), + new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()), + Reduce.Aggregator.sum, + dimension), + ScalarFunctions.divide()); + } + + @Override + public String toString(ToStringContext context) { + return "softmax(" + argument.toString(context) + ", " + dimension + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 95fca95a042..a717292632e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -1,5 +1,9 @@ package com.yahoo.tensor.functions; +import com.yahoo.tensor.Tensor; + +import java.util.List; + /** * A representation of a tensor function which is able to be translated to a set of primitive * tensor functions if necessary. @@ -9,6 +13,9 @@ package com.yahoo.tensor.functions; */ public abstract class TensorFunction { + /** Returns the function arguments of this node in the order they are applied */ + public abstract List<TensorFunction> functionArguments(); + /** * Translate this function - and all of its arguments recursively - * to a tree of primitive functions only. @@ -17,4 +24,24 @@ public abstract class TensorFunction { */ public abstract PrimitiveTensorFunction toPrimitive(); + /** + * Evaluates this tensor. + * + * @param context a context which must be passed to all nexted functions when evaluating + */ + public abstract Tensor evaluate(EvaluationContext context); + + /** Evaluate with no context */ + public final Tensor evaluate() { return evaluate(EvaluationContext.empty()); } + + /** + * Return a string representation of this context. + * + * @param context a context which must be passed to all nexted functions when requesting the string value + */ + public abstract String toString(ToStringContext context); + + @Override + public final String toString() { return toString(ToStringContext.empty()); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java new file mode 100644 index 00000000000..b71229703d2 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java @@ -0,0 +1,14 @@ +package com.yahoo.tensor.functions; + +/** + * A context which is passed down to all nested functions when returning a string representation. + * The default implementation is empty as this library does not in itself have any need for a + * context. + * + * @author bratseth + */ +public interface ToStringContext { + + static ToStringContext empty() { return new ToStringContext() {}; } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java new file mode 100644 index 00000000000..1988c1d2390 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -0,0 +1,45 @@ +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * @author bratseth + */ +public class XwPlusB extends CompositeTensorFunction { + + private final TensorFunction x, w, b; + private final String dimension; + + public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) { + this.x = x; + this.w = w; + this.b = b; + this.dimension = dimension; + } + + @Override + public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); } + + @Override + public PrimitiveTensorFunction toPrimitive() { + TensorFunction primitiveX = x.toPrimitive(); + TensorFunction primitiveW = w.toPrimitive(); + TensorFunction primitiveB = b.toPrimitive(); + return new Join(new Reduce(new Join(primitiveX, primitiveW, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + dimension), + primitiveB, + ScalarFunctions.add()); + } + + @Override + public String toString(ToStringContext context) { + return "xw_plus_b(" + x.toString(context) + ", " + + w.toString(context) + ", " + + b.toString(context) + ", " + + dimension + ")"; + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java index 889b2851a08..af2260e2f20 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java @@ -42,7 +42,7 @@ public class MapTensorBuilderTestCase { Tensor tensor = new MapTensorBuilder().dimension("y").dimension("z"). cell().label("x", "0").value(1).build(); assertEquals(Sets.newHashSet("x", "y", "z"), tensor.dimensions()); - assertEquals("( {{y:-,z:-}:1.0} * {{x:0}:1.0} )", tensor.toString()); + assertEquals("tensor(x{},y{},z{}):{{x:0}:1.0}", tensor.toString()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java index 13ea0e95dc8..0372f328811 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java @@ -33,7 +33,7 @@ public class MapTensorTestCase { fail("Expected parse error"); } catch (IllegalArgumentException expected) { - assertEquals("Excepted a string starting by { or (, got '--'", expected.getMessage()); + assertEquals("Excepted a number or a string starting by { or tensor(, got '--'", expected.getMessage()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java new file mode 100644 index 00000000000..e403bb56d56 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -0,0 +1,28 @@ +package com.yahoo.tensor; + +import com.google.common.collect.ImmutableList; +import org.junit.Test; +import static org.junit.Assert.assertEquals; + +/** + * Tests functionality on Tensor + * + * @author bratseth + */ +public class TensorTestCase { + + /** This is mostly tested in searchlib - spot checking here */ + @Test + public void testTensorComputation() { + MapTensor tensor1 = MapTensor.from("{ {x:1}:3, {x:2}:7 }"); + MapTensor tensor2 = MapTensor.from("{ {y:1}:5 }"); + assertEquals(MapTensor.from("{ {x:1,y:1}:15, {x:2,y:1}:35 }"), tensor1.multiply(tensor2)); + assertEquals(MapTensor.from("{ {x:1,y:1}:12, {x:2,y:1}:28 }"), tensor1.join(tensor2, (a, b) -> a * b - a )); + assertEquals(MapTensor.from("{ {x:1,y:1}:0, {x:2,y:1}:1 }"), tensor1.larger(tensor2)); + assertEquals(MapTensor.from("{ {y:1}:50.0 }"), tensor1.matmul(tensor2, "x")); + assertEquals(MapTensor.from("{ {z:1}:3, {z:2}:7 }"), tensor1.rename("x", "z")); + assertEquals(MapTensor.from("{ {y:1,x:1}:8, {y:2,x:1}:12 }"), tensor1.add(tensor2).rename(ImmutableList.of("x", "y"), + ImmutableList.of("y", "x"))); + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java index 501397e89bc..cc9328f7274 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java @@ -12,8 +12,8 @@ public class TensorFunctionTestCase { @Test public void testTranslation() { - assertTranslated("join({{x:1}:1.0}, {{x:2}:1.0}, lambda(a, b) (...))", - new Product(new Constant("{{x:1}:1.0}"), new Constant("{{x:2}:1.0}"))); + assertTranslated("join({{x:1}:1.0}, reduce({{x:1}:1.0}, sum, x), f(a,b)(a / b))", + new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x")); } private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index 8580868dfdf..c3a5e24afc2 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -52,14 +52,14 @@ public class SparseBinaryFormatTestCase { @Test public void testSerializationOfTensorsWithSparseTensorAddresses() { assertSerialization("{{x:0}:2.0, {}:3.0}", Sets.newHashSet("x")); - assertSerialization("({{y:-}:1} * {{x:0}:2.0})", Sets.newHashSet("x", "y")); - assertSerialization("({{y:-}:1} * {{x:0}:2.0, {}:3.0})", Sets.newHashSet("x", "y")); - assertSerialization("({{y:-}:1} * {{x:0}:2.0,{x:1}:3.0})", Sets.newHashSet("x", "y")); - assertSerialization("({{z:-}:1} * {{x:0,y:0}:2.0})", Sets.newHashSet("x", "y", "z")); - assertSerialization("({{z:-}:1} * {{x:0,y:0}:2.0,{x:0,y:1}:3.0})", Sets.newHashSet("x", "y", "z")); - assertSerialization("({{z:-}:1} * {{y:0,x:0}:2.0})", Sets.newHashSet("x", "y", "z")); - assertSerialization("({{z:-}:1} * {{y:0,x:0}:2.0,{y:1,x:0}:3.0})", Sets.newHashSet("x", "y", "z")); - assertSerialization("({{z:-}:1} * {{}:2.0,{x:0}:3.0,{x:0,y:0}:5.0})", Sets.newHashSet("x", "y", "z")); + assertSerialization("tensor(x{},y{}):{{x:0}:2.0}", Sets.newHashSet("x", "y")); + assertSerialization("tensor(x{},y{}):{{x:0}:2.0, {}:3.0}", Sets.newHashSet("x", "y")); + assertSerialization("tensor(x{},y{}):{{x:0}:2.0,{x:1}:3.0}", Sets.newHashSet("x", "y")); + assertSerialization("tensor(x{},y{},z{}):{{x:0,y:0}:2.0}", Sets.newHashSet("x", "y", "z")); + assertSerialization("tensor(x{},y{},z{}):{{x:0,y:0}:2.0,{x:0,y:1}:3.0}", Sets.newHashSet("x", "y", "z")); + assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0}:2.0}", Sets.newHashSet("x", "y", "z")); + assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0}:2.0,{y:1,x:0}:3.0}", Sets.newHashSet("x", "y", "z")); + assertSerialization("tensor(x{},y{},z{}):{{}:2.0,{x:0}:3.0,{x:0,y:0}:5.0}", Sets.newHashSet("x", "y", "z")); } @Test |