diff options
56 files changed, 1091 insertions, 1905 deletions
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 64bb538eab5..206ab8e30f0 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1049,7 +1049,7 @@ public class JsonReaderTestCase { @Test public void testParsingOfTensorWithDimensions() { - assertTensorField("tensor(x{},y{}):{}", + assertTensorField("( {{x:-,y:-}:1.0} * {} )", createPutWithTensor("{ " + " \"dimensions\": [\"x\",\"y\"] " + "}")); @@ -1101,7 +1101,7 @@ public class JsonReaderTestCase { @Test public void testParsingOfTensorWithDimensionsAndCells() { - assertTensorField("tensor(x{},y{},z{}):{{x:a,y:b}:2.0,{x:c}:3.0}", + assertTensorField("( {{z:-}:1.0} * {{x:a,y:b}:2.0,{x:c}:3.0} )", createPutWithTensor("{ " + " \"dimensions\": [\"x\",\"y\",\"z\"], " + " \"cells\": [ " @@ -1115,7 +1115,7 @@ public class JsonReaderTestCase { @Test public void testParsingOfTensorWithDimensionsAndCellsInDifferentJsonOrder() { - assertTensorField("tensor(x{},y{},z{}):{{x:a,y:b}:2.0,{x:c}:3.0}", + assertTensorField("( {{z:-}:1.0} * {{x:a,y:b}:2.0,{x:c}:3.0} )", createPutWithTensor("{ " + " \"cells\": [ " + " { \"address\": { \"x\": \"a\", \"y\": \"b\" }, " diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java index 252d40b7291..ba06843f178 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java @@ -69,9 +69,9 @@ public class RestApiTest { // POST new nodes assertResponse(new Request("http://localhost:8080/nodes/v2/node", ("[" + asNodeJson("host8.yahoo.com", "default") + "," + - asNodeJson("host9.yahoo.com", "large-variant") + "," + - asHostJson("parent2.yahoo.com", "large-variant") + "," + - asDockerNodeJson("host11.yahoo.com", "parent.host.yahoo.com") + "]"). + asNodeJson("host9.yahoo.com", "large-variant") + "," + + asHostJson("parent2.yahoo.com", "large-variant") + "," + + asDockerNodeJson("host11.yahoo.com", "parent.host.yahoo.com") + "]"). getBytes(StandardCharsets.UTF_8), Request.Method.POST), "{\"message\":\"Added 4 nodes to the provisioned state\"}"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 620c6fad0b4..0dff0414ac2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -2,7 +2,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; -import com.yahoo.tensor.functions.EvaluationContext; import java.util.Set; @@ -11,7 +10,7 @@ import java.util.Set; * * @author bratseth */ -public abstract class Context implements EvaluationContext { +public abstract class Context { /** * <p>Returns the value of a simple variable name.</p> @@ -42,7 +41,7 @@ public abstract class Context implements EvaluationContext { * "main" (or only) value. */ public Value get(String name, Arguments arguments,String output) { - if (arguments!=null && arguments.expressions().size() > 0) + if (arguments!=null && arguments.expressions().size()>0) throw new UnsupportedOperationException(this + " does not support structured ranking expression variables, attempted to reference '" + name + arguments + "'"); if (output==null) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index f8dcd8a6127..2bae382d5bd 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -39,8 +39,8 @@ public abstract class DoubleCompatibleValue extends Value { } @Override - public Value compare(TruthOperator operator, Value value) { - return new BooleanValue(operator.evaluate(asDouble(), value.asDouble())); + public boolean compare(TruthOperator operator, Value value) { + return operator.evaluate(asDouble(), value.asDouble()); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java index 0e0d793bfd1..028dad16d21 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java @@ -98,6 +98,16 @@ public final class DoubleValue extends DoubleCompatibleValue { } @Override + public boolean compare(TruthOperator operator, Value value) { + try { + return operator.evaluate(this.value, value.asDouble()); + } + catch (UnsupportedOperationException e) { + throw unsupported("comparison",value); + } + } + + @Override public Value function(Function function, Value value) { // use the tensor implementation of max and min if the argument is a tensor if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index 2dffe2a1100..9ee9a1f7a71 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -34,9 +34,11 @@ public class MapContext extends Context { * Creates a map context from a map. * The ownership of the map is transferred to this - it cannot be further modified by the caller. * All the Values of the map will be frozen. + * + * @since 5.1.5 */ public MapContext(Map<String,Value> bindings) { - this.bindings = bindings; + this.bindings=bindings; for (Value boundValue : bindings.values()) boundValue.freeze(); } @@ -65,9 +67,6 @@ public class MapContext extends Context { if (frozen) return bindings; return Collections.unmodifiableMap(bindings); } - - /** Returns a new, modifiable context containing all the bindings of this */ - public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); } /** Returns an unmodifiable map of the names of this */ public @Override Set<String> names() { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index eb997ab818a..379b5755c7b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -68,10 +68,10 @@ public class StringValue extends Value { } @Override - public Value compare(TruthOperator operator, Value value) { + public boolean compare(TruthOperator operator, Value value) { if (operator.equals(TruthOperator.EQUAL)) - return new BooleanValue(this.equals(value)); - throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='"); + return this.equals(value); + throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '='"); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index b1f4a7b20ca..12bede95aae 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -8,7 +8,6 @@ import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.TensorType; -import java.util.Collections; import java.util.Optional; /** @@ -18,7 +17,7 @@ import java.util.Optional; * * @author bratseth */ -@Beta + @Beta public class TensorValue extends Value { /** The tensor value of this */ @@ -54,7 +53,7 @@ public class TensorValue extends Value { @Override public Value negate() { - return new TensorValue(value.map((value) -> -value)); + return new TensorValue(value.apply((Double value) -> -value)); } @Override @@ -62,7 +61,7 @@ public class TensorValue extends Value { if (argument instanceof TensorValue) return new TensorValue(value.add(((TensorValue)argument).value)); else - return new TensorValue(value.map((value) -> value + argument.asDouble())); + return new TensorValue(value.apply((Double value) -> value + argument.asDouble())); } @Override @@ -70,7 +69,7 @@ public class TensorValue extends Value { if (argument instanceof TensorValue) return new TensorValue(value.subtract(((TensorValue) argument).value)); else - return new TensorValue(value.map((value) -> value - argument.asDouble())); + return new TensorValue(value.apply((Double value) -> value - argument.asDouble())); } @Override @@ -78,15 +77,35 @@ public class TensorValue extends Value { if (argument instanceof TensorValue) return new TensorValue(value.multiply(((TensorValue) argument).value)); else - return new TensorValue(value.map((value) -> value * argument.asDouble())); + return new TensorValue(value.apply((Double value) -> value * argument.asDouble())); } @Override public Value divide(Value argument) { if (argument instanceof TensorValue) - return new TensorValue(value.divide(((TensorValue) argument).value)); + throw new UnsupportedOperationException("Two tensors cannot be divided"); else - return new TensorValue(value.map((value) -> value / argument.asDouble())); + return new TensorValue(value.apply((Double value) -> value / argument.asDouble())); + } + + public Value match(Value argument) { + return new TensorValue(value.match(asTensor(argument, "match"))); + } + + public Value min(Value argument) { + return new TensorValue(value.min(asTensor(argument, "min"))); + } + + public Value max(Value argument) { + return new TensorValue(value.max(asTensor(argument, "max"))); + } + + public Value sum(String dimension) { + return new TensorValue(value.sum(dimension)); + } + + public Value sum() { + return new DoubleValue(value.sum()); } private Tensor asTensor(Value value, String operationName) { @@ -103,37 +122,18 @@ public class TensorValue extends Value { } @Override - public Value compare(TruthOperator operator, Value argument) { - return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString()))); - } - - private Tensor compareTensor(TruthOperator operator, Tensor argument) { - switch (operator) { - case LARGER: return value.larger(argument); - case LARGEREQUAL: return value.largerOrEqual(argument); - case SMALLER: return value.smaller(argument); - case SMALLEREQUAL: return value.smallerOrEqual(argument); - case EQUAL: return value.equal(argument); - case NOTEQUAL: return value.notEqual(argument); - default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator); - } + public boolean compare(TruthOperator operator, Value value) { + throw new UnsupportedOperationException("A tensor cannot be compared with any value"); } @Override - public Value function(Function function, Value arg) { - if (arg instanceof TensorValue) - return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString()))); + public Value function(Function function, Value argument) { + if (function.equals(Function.min) && argument instanceof TensorValue) + return min(argument); + else if (function.equals(Function.max) && argument instanceof TensorValue) + return max(argument); else - return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble()))); - } - - private Tensor functionOnTensor(Function function, Tensor argument) { - switch (function) { - case min: return value.min(argument); - case max: return value.max(argument); - case atan2: return value.atan2(argument); - default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function); - } + return new TensorValue(value.apply((Double value) -> function.evaluate(value, argument.asDouble()))); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 8ce18265231..e5680edc68a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -42,7 +42,7 @@ public abstract class Value { public abstract Value divide(Value value); /** Perform the comparison specified by the operator between this value and the given value */ - public abstract Value compare(TruthOperator operator, Value value); + public abstract boolean compare(TruthOperator operator,Value value); /** Perform the given binary function on this value and the given value */ public abstract Value function(Function function,Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index af05acb365a..882d16ebc1c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -8,9 +8,10 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import java.util.*; /** - * A node which returns the outcome of a comparison. + * A node which returns true or false depending on the outcome of a comparison. * * @author bratseth + * @since 5.1.21 */ public class ComparisonNode extends BooleanNode { @@ -47,9 +48,9 @@ public class ComparisonNode extends BooleanNode { @Override public Value evaluate(Context context) { - Value leftValue = leftCondition.evaluate(context); - Value rightValue = rightCondition.evaluate(context); - return leftValue.compare(operator,rightValue); + Value leftValue=leftCondition.evaluate(context); + Value rightValue=rightCondition.evaluate(context); + return new BooleanValue(leftValue.compare(operator,rightValue)); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java index 19b1a83ed99..675ce758faa 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java @@ -12,38 +12,31 @@ import static java.lang.Math.*; */ public enum Function implements Serializable { - abs { public double evaluate(double x, double y) { return abs(x); } }, + cosh { public double evaluate(double x, double y) { return cosh(x); } }, + sinh { public double evaluate(double x, double y) { return sinh(x); } }, + tanh { public double evaluate(double x, double y) { return tanh(x); } }, + cos { public double evaluate(double x, double y) { return cos(x); } }, + sin { public double evaluate(double x, double y) { return sin(x); } }, + tan { public double evaluate(double x, double y) { return tan(x); } }, acos { public double evaluate(double x, double y) { return acos(x); } }, asin { public double evaluate(double x, double y) { return asin(x); } }, atan { public double evaluate(double x, double y) { return atan(x); } }, - ceil { public double evaluate(double x, double y) { return ceil(x); } }, - cos { public double evaluate(double x, double y) { return cos(x); } }, - cosh { public double evaluate(double x, double y) { return cosh(x); } }, - elu { public double evaluate(double x, double y) { return x<0 ? exp(x)-1 : x; } }, exp { public double evaluate(double x, double y) { return exp(x); } }, + log10 { public double evaluate(double x, double y) { return log10(x); } }, + log { public double evaluate(double x, double y) { return log(x); } }, + sqrt { public double evaluate(double x, double y) { return sqrt(x); } }, + ceil { public double evaluate(double x, double y) { return ceil(x); } }, fabs { public double evaluate(double x, double y) { return abs(x); } }, floor { public double evaluate(double x, double y) { return floor(x); } }, isNan { public double evaluate(double x, double y) { return Double.isNaN(x) ? 1.0 : 0.0; } }, - log { public double evaluate(double x, double y) { return log(x); } }, - log10 { public double evaluate(double x, double y) { return log10(x); } }, relu { public double evaluate(double x, double y) { return max(x,0); } }, - round { public double evaluate(double x, double y) { return round(x); } }, sigmoid { public double evaluate(double x, double y) { return 1.0 / (1.0 + exp(-1.0 * x)); } }, - sign { public double evaluate(double x, double y) { return x >= 0 ? 1 : -1; } }, - sin { public double evaluate(double x, double y) { return sin(x); } }, - sinh { public double evaluate(double x, double y) { return sinh(x); } }, - square { public double evaluate(double x, double y) { return x*x; } }, - sqrt { public double evaluate(double x, double y) { return sqrt(x); } }, - tan { public double evaluate(double x, double y) { return tan(x); } }, - tanh { public double evaluate(double x, double y) { return tanh(x); } }, - atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } }, - fmod(2) { public double evaluate(double x, double y) { return IEEEremainder(x,y); } }, + pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }, ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } }, - max(2) { public double evaluate(double x, double y) { return max(x,y); } }, + fmod(2) { public double evaluate(double x, double y) { return IEEEremainder(x,y); } }, min(2) { public double evaluate(double x, double y) { return min(x,y); } }, - mod(2) { public double evaluate(double x, double y) { return x % y; } }, - pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }; + max(2) { public double evaluate(double x, double y) { return max(x,y); } }; private final int arity; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java deleted file mode 100644 index 7b48288598d..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ /dev/null @@ -1,122 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.rule; - -import com.google.common.collect.ImmutableList; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; - -import java.util.Collections; -import java.util.Deque; -import java.util.List; -import java.util.function.DoubleBinaryOperator; -import java.util.function.DoubleUnaryOperator; - -/** - * A free, parametrized function - * - * @author bratseth - */ -public class LambdaFunctionNode extends CompositeNode { - - private final ImmutableList<String> arguments; - private final ExpressionNode functionExpression; - - public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) { - // TODO: Verify that the function only accesses the arguments in mapperVariables - this.arguments = ImmutableList.copyOf(arguments); - this.functionExpression = functionExpression; - } - - @Override - public List<ExpressionNode> children() { - return Collections.singletonList(functionExpression); - } - - @Override - public CompositeNode setChildren(List<ExpressionNode> children) { - if ( children.size() != 1) - throw new IllegalArgumentException("A lambda function must have a single child expression"); - return new LambdaFunctionNode(arguments, children.get(0)); - } - - @Override - public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { - return ("f(" + commaSeparated(arguments) + ")(" + functionExpression.toString(context, path, this)) + ")"; - } - - private String commaSeparated(List<String> list) { - StringBuilder b = new StringBuilder(); - for (String element : list) - b.append(element).append(","); - if (b.length() > 0) - b.setLength(b.length()-1); - return b.toString(); - } - - /** Evaluate this in a context which must have the arguments bound */ - @Override - public Value evaluate(Context context) { - return functionExpression.evaluate(context); - } - - /** - * Returns this as a double unary operator - * - * @throws IllegalStateException if this has more than one argument - */ - public DoubleUnaryOperator asDoubleUnaryOperator() { - if (arguments.size() > 1) - throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " + - "Must have at most one argument " + " but has " + arguments); - return new DoubleUnaryLambda(); - } - - /** - * Returns this as a double binary operator - * - * @throws IllegalStateException if this has more than two arguments - */ - public DoubleBinaryOperator asDoubleBinaryOperator() { - if (arguments.size() > 2) - throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " + - "Must have at most two argument " + " but has " + arguments); - return new DoubleBinaryLambda(); - } - - private class DoubleUnaryLambda implements DoubleUnaryOperator { - - @Override - public double applyAsDouble(double operand) { - MapContext context = new MapContext(); - if (arguments.size() > 0) - context.put(arguments.get(0), operand); - return evaluate(context).asDouble(); - } - - @Override - public String toString() { - return LambdaFunctionNode.this.toString(); - } - - } - - private class DoubleBinaryLambda implements DoubleBinaryOperator { - - @Override - public double applyAsDouble(double left, double right) { - MapContext context = new MapContext(); - if (arguments.size() > 0) - context.put(arguments.get(0), left); - if (arguments.size() > 1) - context.put(arguments.get(1), right); - return evaluate(context).asDouble(); - } - - @Override - public String toString() { - return LambdaFunctionNode.this.toString(); - } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java deleted file mode 100644 index 26d3f1dcc0e..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.google.common.annotations.Beta; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.functions.EvaluationContext; -import com.yahoo.tensor.functions.PrimitiveTensorFunction; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.tensor.functions.ToStringContext; - -import java.util.Collections; -import java.util.Deque; -import java.util.List; -import java.util.stream.Collectors; - -/** - * A node which performs a tensor function - * - * @author bratseth - */ - @Beta -public class TensorFunctionNode extends CompositeNode { - - private final TensorFunction function; - - public TensorFunctionNode(TensorFunction function) { - this.function = function; - } - - @Override - public List<ExpressionNode> children() { - return function.functionArguments().stream() - .map(f -> ((TensorFunctionExpressionNode)f).expression) - .collect(Collectors.toList()); - } - - @Override - public CompositeNode setChildren(List<ExpressionNode> children) { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { - // Serialize as primitive - return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this)); - } - - @Override - public Value evaluate(Context context) { - return new TensorValue(function.evaluate(context)); - } - - public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { - return new TensorFunctionExpressionNode(node); - } - - /** - * A tensor function implemented by an expression. - * This allows us to pass expressions as tensor function arguments. - */ - public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction { - - /** An expression which produces a tensor */ - private final ExpressionNode expression; - - public TensorFunctionExpressionNode(ExpressionNode expression) { - this.expression = expression; - } - - @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } - - @Override - public PrimitiveTensorFunction toPrimitive() { return this; } - - @Override - public Tensor evaluate(EvaluationContext context) { - Value result = expression.evaluate((Context)context); - if ( ! ( result instanceof TensorValue)) - throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " + - "but this returns " + result + ", not a tensor"); - return ((TensorValue)result).asTensor(); - } - - @Override - public String toString(ToStringContext c) { - ExpressionNodeToStringContext context = (ExpressionNodeToStringContext)c; - return expression.toString(context.context, context.path, context.parent); - } - - } - - /** Allows passing serialization context arguments through TensorFunctions */ - private static class ExpressionNodeToStringContext implements ToStringContext { - - final SerializationContext context; - final Deque<String> path; - final CompositeNode parent; - - public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { - this.context = context; - this.path = path; - this.parent = parent; - } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java new file mode 100644 index 00000000000..af309b3e8d8 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java @@ -0,0 +1,59 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.annotations.Beta; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; + +/** + * @author bratseth + */ + @Beta +public class TensorMatchNode extends CompositeNode { + + private final ExpressionNode left, right; + + public TensorMatchNode(ExpressionNode left, ExpressionNode right) { + this.left = left; + this.right = right; + } + + @Override + public List<ExpressionNode> children() { + List<ExpressionNode> children = new ArrayList<>(2); + children.add(left); + children.add(right); + return children; + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if ( children.size() != 2) + throw new IllegalArgumentException("A match product must have two children"); + return new TensorMatchNode(children.get(0), children.get(1)); + + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return "match(" + left.toString(context, path, parent) + ", " + right.toString(context, path, parent) + ")"; + } + + @Override + public Value evaluate(Context context) { + return asTensor(left.evaluate(context)).match(asTensor(right.evaluate(context))); + } + + private TensorValue asTensor(Value value) { + if ( ! (value instanceof TensorValue)) + throw new IllegalArgumentException("Attempted to take the tensor product with an argument which is " + + "not a tensor: " + value); + return (TensorValue)value; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java new file mode 100644 index 00000000000..a1f83157e20 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java @@ -0,0 +1,65 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.annotations.Beta; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.Optional; + +/** + * A node which sums over all cells in the argument tensor + * + * @author bratseth + */ + @Beta +public class TensorSumNode extends CompositeNode { + + /** The tensor to sum */ + private final ExpressionNode argument; + + /** The dimension to sum over, or empty to sum all cells to a scalar */ + private final Optional<String> dimension; + + public TensorSumNode(ExpressionNode argument, Optional<String> dimension) { + this.argument = argument; + this.dimension = dimension; + } + + @Override + public List<ExpressionNode> children() { + return Collections.singletonList(argument); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if (children.size() != 1) throw new IllegalArgumentException("A tensor sum node must have one tensor argument"); + return new TensorSumNode(children.get(0), dimension); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return "sum(" + + argument.toString(context, path, parent) + + ( dimension.isPresent() ? ", " + dimension.get() : "" ) + + ")"; + } + + @Override + public Value evaluate(Context context) { + Value argumentValue = argument.evaluate(context); + if ( ! ( argumentValue instanceof TensorValue)) + throw new IllegalArgumentException("Attempted to take the tensor sum of argument '" + argument + "', " + + "but this returns " + argumentValue + ", not a tensor"); + TensorValue tensorArgument = (TensorValue)argumentValue; + if (dimension.isPresent()) + return tensorArgument.sum(dimension.get()); + else + return tensorArgument.sum(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java index 932975f3b63..60fe19f909f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java @@ -15,8 +15,7 @@ public enum TruthOperator implements Serializable { EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } }, APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } }, LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } }, - LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }, - NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } }; + LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }; private final String operatorString; diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 0fcfdb5d40c..78ad665c414 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -21,9 +21,10 @@ import com.yahoo.searchlib.rankingexpression.rule.*; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.tensor.*; -import com.yahoo.tensor.functions.*; +import com.yahoo.tensor.MapTensor; +import com.yahoo.tensor.TensorAddress; import java.util.Collections; +import java.util.Map; import java.util.LinkedHashMap; import java.util.Arrays; import java.util.ArrayList; @@ -59,83 +60,51 @@ TOKEN : <RSQUARE: "]"> | <LCURLY: "{"> | <RCURLY: "}"> | - <ADD: "+"> | <SUB: "-"> | <DIV: "/"> | <MUL: "*"> | <DOT: "."> | - <DOLLAR: "$"> | <COMMA: ","> | <COLON: ":"> | - <LE: "<="> | <LT: "<"> | <EQ: "=="> | - <NQ: "!="> | <AQ: "~="> | <GE: ">="> | <GT: ">"> | - <STRING: ("\"" (~["\""] | "\\\"")* "\"") | ("'" (~["'"] | "\\'")* "'")> | - <IF: "if"> | - <IN: "in"> | - <F: "f"> | - - <ABS: "abs"> | + <COSH: "cosh"> | + <SINH: "sinh"> | + <TANH: "tanh"> | + <COS: "cos"> | + <SIN: "sin"> | + <TAN: "tan"> | <ACOS: "acos"> | <ASIN: "asin"> | + <ATAN2: "atan2"> | <ATAN: "atan"> | - <CEIL: "ceil"> | - <COS: "cos"> | - <COSH: "cosh"> | - <ELU: "elu"> | <EXP: "exp"> | + <LDEXP: "ldexp"> | + <LOG10: "log10"> | + <LOG: "log"> | + <POW: "pow"> | + <SQRT: "sqrt"> | + <CEIL: "ceil"> | <FABS: "fabs"> | <FLOOR: "floor"> | + <FMOD: "fmod"> | + <MIN: "min"> | + <MAX: "max"> | <ISNAN: "isNan"> | - <LOG: "log"> | - <LOG10: "log10"> | + <IN: "in"> | + <SUM: "sum"> | + <MATCH: "match"> | <RELU: "relu"> | - <ROUND: "round"> | <SIGMOID: "sigmoid"> | - <SIGN: "sign"> | - <SIN: "sin"> | - <SINH: "sinh"> | - <SQUARE: "square"> | - <SQRT: "sqrt"> | - <TAN: "tan"> | - <TANH: "tanh"> | - - <ATAN2: "atan2"> | - <FMOD: "fmod"> | - <LDEXP: "ldexp"> | - // MAX - // MIN - <MOD: "mod"> | - <POW: "pow"> | - - <MAP: "map"> | - <REDUCE: "reduce"> | - <JOIN: "join"> | - <RENAME: "rename"> | - <TENSOR: "tensor"> | - <L1_NORMALIZE: "l1_normalize"> | - <L2_NORMALIZE: "l2_normalize"> | - <MATMUL: "matmul"> | - <SOFTMAX: "softmax"> | - <XW_PLUS_B: "xw_plus_b"> | - - <AVG: "avg" > | - <COUNT: "count"> | - <PROD: "prod"> | - <SUM: "sum"> | - <MAX: "max"> | - <MIN: "min"> | - <IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)> } @@ -206,7 +175,6 @@ TruthOperator comparator() : { } ( <LE> { return TruthOperator.SMALLEREQUAL; } | <LT> { return TruthOperator.SMALLER; } | <EQ> { return TruthOperator.EQUAL; } | - <NQ> { return TruthOperator.NOTEQUAL; } | <AQ> { return TruthOperator.APPROX_EQUAL; } | <GE> { return TruthOperator.LARGEREQUAL; } | <GT> { return TruthOperator.LARGER; } ) @@ -221,6 +189,7 @@ ExpressionNode value() : { ( [ LOOKAHEAD(2) <SUB> { neg = true; } ] ( ret = constantPrimitive() | + ret = constantTensor() | LOOKAHEAD(2) ret = ifExpression() | LOOKAHEAD(2) ret = function() | ret = feature() | @@ -310,6 +279,7 @@ ExpressionNode arg() : } { ( ret = constantPrimitive() | + ret = constantTensor() | LOOKAHEAD(2) ret = feature() | name = identifier() { ret = new NameNode(name); } ) { return ret; } @@ -320,11 +290,11 @@ ExpressionNode function() : ExpressionNode function; } { - ( function = scalarOrTensorFunction() | function = tensorFunction() ) + ( function = scalarFunction() | function = tensorFunction() ) { return function; } } -FunctionNode scalarOrTensorFunction() : +FunctionNode scalarFunction() : { Function function; ExpressionNode arg1, arg2; @@ -342,223 +312,61 @@ FunctionNode scalarOrTensorFunction() : ExpressionNode tensorFunction() : { - ExpressionNode tensorExpression; -} -{ - ( - tensorExpression = tensorMap() | - tensorExpression = tensorReduce() | - tensorExpression = tensorReduceComposites() | - tensorExpression = tensorJoin() | - tensorExpression = tensorRename() | - tensorExpression = tensorGenerate() | - tensorExpression = tensorL1Normalize() | - tensorExpression = tensorL2Normalize() | - tensorExpression = tensorMatmul() | - tensorExpression = tensorSoftmax() | - tensorExpression = tensorXwPlusB() - ) - { return tensorExpression; } -} - -ExpressionNode tensorMap() : -{ - ExpressionNode tensor; - LambdaFunctionNode doubleMapper; -} -{ - <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE> - { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor), - doubleMapper.asDoubleUnaryOperator())); } -} - -ExpressionNode tensorReduce() : -{ - ExpressionNode tensor; - Reduce.Aggregator aggregator; - List<String> dimensions = null; -} -{ - <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } -} - -ExpressionNode tensorReduceComposites() : -{ - ExpressionNode tensor; - Reduce.Aggregator aggregator; - List<String> dimensions = null; -} -{ - aggregator = tensorReduceAggregator() - <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } -} - -ExpressionNode tensorJoin() : -{ ExpressionNode tensor1, tensor2; - LambdaFunctionNode doubleJoiner; + String dimension = null; + TensorAddress address = null; } { - <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE> - { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - doubleJoiner.asDoubleBinaryOperator())); } -} - -ExpressionNode tensorRename() : -{ - ExpressionNode tensor; - List<String> fromDimensions, toDimensions; -} -{ - <RENAME> <LBRACE> tensor = expression() <COMMA> - fromDimensions = bracedIdentifierList() <COMMA> - toDimensions = bracedIdentifierList() - <RBRACE> - { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); } -} - -// TODO: Notice that null is parsed below -ExpressionNode tensorGenerate() : -{ - TensorType type; - LambdaFunctionNode generator; -} -{ - <TENSOR> <LBRACE> <RBRACE> <LBRACE> - { return new TensorFunctionNode(new Generate(null, null)); } -} - -ExpressionNode tensorL1Normalize() : -{ - ExpressionNode tensor; - String dimension; -} -{ - <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } -} - -ExpressionNode tensorL2Normalize() : -{ - ExpressionNode tensor; - String dimension; -} -{ - <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } -} - -ExpressionNode tensorMatmul() : -{ - ExpressionNode tensor1, tensor2; - String dimension; -} -{ - <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - dimension)); } -} - -ExpressionNode tensorSoftmax() : -{ - ExpressionNode tensor; - String dimension; -} -{ - <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } -} - -ExpressionNode tensorXwPlusB() : -{ - ExpressionNode tensor1, tensor2, tensor3; - String dimension; -} -{ - <XW_PLUS_B> <LBRACE> tensor1 = expression() <COMMA> - tensor2 = expression() <COMMA> - tensor3 = expression() <COMMA> - dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - TensorFunctionNode.wrapArgument(tensor3), - dimension)); } -} - -LambdaFunctionNode lambdaFunction() : -{ - List<String> variables; - ExpressionNode functionExpression; -} -{ - ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> ) - { return new LambdaFunctionNode(variables, functionExpression); } -} - -Reduce.Aggregator tensorReduceAggregator() : -{ -} -{ - ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> ) - { return Reduce.Aggregator.valueOf(token.image); } + ( + <SUM> <LBRACE> tensor1 = expression() ( <COMMA> dimension = identifier() )? <RBRACE> + { return new TensorSumNode(tensor1, Optional.ofNullable(dimension)); } + ) | + ( + <MATCH> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <RBRACE> + { return new TensorMatchNode(tensor1, tensor2); } + ) } // This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge String tensorFunctionName() : { - Reduce.Aggregator aggregator; } { - ( <F> { return token.image; } ) | - ( <MAP> { return token.image; } ) | - ( <REDUCE> { return token.image; } ) | - ( <JOIN> { return token.image; } ) | - ( <RENAME> { return token.image; } ) | - ( <TENSOR> { return token.image; } ) | - ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) + ( <SUM> | <MATCH> ) + { return token.image; } } Function unaryFunctionName() : { } { - <ABS> { return Function.abs; } | + <COS> { return Function.cos; } | + <SIN> { return Function.sin; } | + <TAN> { return Function.tan; } | + <COSH> { return Function.cosh; } | + <SINH> { return Function.sinh; } | + <TANH> { return Function.tanh; } | <ACOS> { return Function.acos; } | <ASIN> { return Function.asin; } | <ATAN> { return Function.atan; } | - <CEIL> { return Function.ceil; } | - <COS> { return Function.cos; } | - <COSH> { return Function.cosh; } | - <ELU> { return Function.elu; } | <EXP> { return Function.exp; } | + <LOG10> { return Function.log10; } | + <LOG> { return Function.log; } | + <SQRT> { return Function.sqrt; } | + <CEIL> { return Function.ceil; } | <FABS> { return Function.fabs; } | <FLOOR> { return Function.floor; } | <ISNAN> { return Function.isNan; } | - <LOG> { return Function.log; } | - <LOG10> { return Function.log10; } | <RELU> { return Function.relu; } | - <ROUND> { return Function.round; } | - <SIGMOID> { return Function.sigmoid; } | - <SIGN> { return Function.sign; } | - <SIN> { return Function.sin; } | - <SINH> { return Function.sinh; } | - <SQUARE> { return Function.square; } | - <SQRT> { return Function.sqrt; } | - <TAN> { return Function.tan; } | - <TANH> { return Function.tanh; } + <SIGMOID> { return Function.sigmoid; } } Function binaryFunctionName() : { } { <ATAN2> { return Function.atan2; } | - <FMOD> { return Function.fmod; } | <LDEXP> { return Function.ldexp; } | - <MAX> { return Function.max; } | + <POW> { return Function.pow; } | + <FMOD> { return Function.fmod; } | <MIN> { return Function.min; } | - <MOD> { return Function.mod; } | - <POW> { return Function.pow; } + <MAX> { return Function.max; } } List<ExpressionNode> expressionList() : @@ -597,28 +405,6 @@ String identifier() : <IDENTIFIER> { return token.image; } } -List<String> identifierList() : -{ - List<String> list = new ArrayList<String>(); - String element; -} -{ - ( element = identifier() { list.add(element); } )? - ( <COMMA> element = identifier() { list.add(element); } ) * - { return list; } -} - -List<String> bracedIdentifierList() : -{ - List<String> list = new ArrayList<String>(); - String element; -} -{ - ( element = identifier() { return Collections.singletonList(element); } ) - | - ( <LBRACE> list = identifierList() <RBRACE> { return list; } ) -} - // An identifier or integer String tag() : { @@ -629,16 +415,6 @@ String tag() : <INTEGER> { return token.image; } } -List<String> tagCommaLeadingList() : -{ - List<String> list = new ArrayList<String>(); - String element; -} -{ - ( <COMMA> element = tag() { list.add(element); } ) * - { return list; } -} - ConstantNode constantPrimitive() : { String sign = ""; @@ -658,3 +434,50 @@ Value primitiveValue() : ( <INTEGER> | <FLOAT> | <STRING> ) { return Value.parse(sign + token.image); } } + +ConstantNode constantTensor() : +{ + Value constantValue; +} +{ + <LCURLY> constantValue = tensorContent() <RCURLY> + { return new ConstantNode(constantValue); } +} + +TensorValue tensorContent() : +{ + Map<TensorAddress, Double> cells = new LinkedHashMap<TensorAddress, Double>(); + TensorAddress address; + Double value; +} +{ + ( address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) ? + ( <COMMA> address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) * + { return new TensorValue(new MapTensor(cells)); } +} + +TensorAddress tensorAddress() : +{ + List<TensorAddress.Element> elements = new ArrayList<TensorAddress.Element>(); + String dimension; + String label; +} +{ + <LCURLY> + ( dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) ? + ( <COMMA> dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) * + <RCURLY> + { return TensorAddress.fromUnsorted(elements); } +} + +String label() : +{ + String label; + +} +{ + ( label = tag() | + ( "-" { label = "-"; } ) ) + { return label; } +} + diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index f28ff739b4c..24d7c82235c 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -6,10 +6,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; -import org.junit.Test; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.assertFalse; +import junit.framework.TestCase; import java.io.BufferedReader; import java.io.File; @@ -17,18 +14,15 @@ import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.*; /** - * @author Simon Thoresen - * @author bratseth + * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> */ -public class RankingExpressionTestCase { +public class RankingExpressionTestCase extends TestCase { - @Test public void testParamInFeature() throws ParseException { assertParse("if (1 > 2, dotProduct(allparentid,query(cate1_parentid)), 2)", "if ( 1 > 2,\n" + @@ -37,7 +31,6 @@ public class RankingExpressionTestCase { ")"); } - @Test public void testDollarShorthand() throws ParseException { assertParse("query(var1)", " $var1"); assertParse("query(var1)", " $var1 "); @@ -51,7 +44,6 @@ public class RankingExpressionTestCase { assertParse("if (if (f1.out < query(p1), 0, 1) < if (f2.out < query(p2), 0, 1), f3.out, query(p3))", "if(if(f1.out<$p1,0,1)<if(f2.out<$p2,0,1),f3.out,$p3)"); } - @Test public void testLookaheadIndefinitely() throws Exception { ExecutorService exec = Executors.newSingleThreadExecutor(); Future<Boolean> future = exec.submit(new Callable<Boolean>() { @@ -68,8 +60,7 @@ public class RankingExpressionTestCase { assertTrue(future.get(60, TimeUnit.SECONDS)); } - @Test - public void testSelfRecursionSerialization() throws ParseException { + public void testSelfRecursionScript() throws ParseException { List<ExpressionFunction> macros = new ArrayList<>(); macros.add(new ExpressionFunction("foo", null, new RankingExpression("foo"))); @@ -81,8 +72,7 @@ public class RankingExpressionTestCase { } } - @Test - public void testMacroCycleSerialization() throws ParseException { + public void testMacroCycleScript() throws ParseException { List<ExpressionFunction> macros = new ArrayList<>(); macros.add(new ExpressionFunction("foo", null, new RankingExpression("bar"))); macros.add(new ExpressionFunction("bar", null, new RankingExpression("foo"))); @@ -95,48 +85,42 @@ public class RankingExpressionTestCase { } } - @Test - public void testSerialization() throws ParseException { + public void testScript() throws ParseException { List<ExpressionFunction> macros = new ArrayList<>(); macros.add(new ExpressionFunction("foo", Arrays.asList("arg1", "arg2"), new RankingExpression("min(arg1, pow(arg2, 2))"))); macros.add(new ExpressionFunction("bar", Arrays.asList("arg1", "arg2"), new RankingExpression("arg1 * arg1 + 2 * arg1 * arg2 + arg2 * arg2"))); macros.add(new ExpressionFunction("baz", Arrays.asList("arg1", "arg2"), new RankingExpression("foo(1, 2) / bar(arg1, arg2)"))); macros.add(new ExpressionFunction("cox", null, new RankingExpression("10 + 08 * 1977"))); - assertSerialization(Arrays.asList( - "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)", - "min(5,pow(rankingExpression(foo@d1d1417259cdc651.573bbcd4be18f379),2))", - "min(6,pow(7,2))", - "min(1,pow(2,2))", - "min(3,pow(4,2))", - "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"), "foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros); - assertSerialization(Arrays.asList( - "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)", - "min(1,pow(2,2))", - "3 * 3 + 2 * 3 * 4 + 4 * 4"), "foo(1, 2) + bar(3, 4)", macros); - assertSerialization(Arrays.asList( - "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)", - "min(1,pow(2,2))", - "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)", - "1 * 1 + 2 * 1 * 2 + 2 * 2"), "baz(1, 2)", macros); - assertSerialization(Arrays.asList( - "rankingExpression(cox)", - "10 + 08 * 1977"), "cox", macros - ); - } - - @Test - public void testTensorSerialization() { - assertSerialization("map(constant(tensor0), f(a)(cos(a)))", - "map(constant(tensor0), f(a)(cos(a)))"); - assertSerialization("map(constant(tensor0), f(a)(cos(a))) + join(attribute(tensor1), map(reduce(map(attribute(tensor1), f(a)(a * a)), sum, x), f(a)(sqrt(a))), f(a,b)(a / b))", - "map(constant(tensor0), f(a)(cos(a))) + l2_normalize(attribute(tensor1), x)"); - assertSerialization("join(reduce(join(reduce(join(constant(tensor0), attribute(tensor1), f(a,b)(a * b)), sum, x), attribute(tensor1), f(a,b)(a * b)), sum, y), query(tensor2), f(a,b)(a + b))", - "xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)"); - + assertScript("foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros, + Arrays.asList( + "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)", + "min(5,pow(rankingExpression(foo@d1d1417259cdc651.573bbcd4be18f379),2))", + "min(6,pow(7,2))", + "min(1,pow(2,2))", + "min(3,pow(4,2))", + "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))" + )); + assertScript("foo(1, 2) + bar(3, 4)", macros, + Arrays.asList( + "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)", + "min(1,pow(2,2))", + "3 * 3 + 2 * 3 * 4 + 4 * 4" + )); + assertScript("baz(1, 2)", macros, + Arrays.asList( + "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)", + "min(1,pow(2,2))", + "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)", + "1 * 1 + 2 * 1 * 2 + 2 * 2" + )); + assertScript("cox", macros, + Arrays.asList( + "rankingExpression(cox)", + "10 + 08 * 1977" + )); } - @Test public void testBug3464208() throws ParseException { List<ExpressionFunction> macros = new ArrayList<>(); macros.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69"))); @@ -151,11 +135,18 @@ public class RankingExpressionTestCase { String expRhs = "(rankingExpression(log10tweetage) * rankingExpression(log10tweetage) * " + "rankingExpression(log10tweetage)) + 5.0 * attribute(ythl)"; - assertSerialization(Arrays.asList(expLhs + " + " + expRhs, "69"), lhs + " + " + rhs, macros); - assertSerialization(Arrays.asList(expLhs + " - " + expRhs, "69"), lhs + " - " + rhs, macros); + assertScript(lhs + " + " + rhs, macros, + Arrays.asList( + expLhs + " + " + expRhs, + "69" + )); + assertScript(lhs + " - " + rhs, macros, + Arrays.asList( + expLhs + " - " + expRhs, + "69" + )); } - @Test public void testParse() throws ParseException, IOException { BufferedReader reader = new BufferedReader(new FileReader("src/tests/rankingexpression/rankingexpressionlist")); String line; @@ -190,43 +181,36 @@ public class RankingExpressionTestCase { } } - @Test public void testIssue() throws ParseException { assertEquals("feature.0", new RankingExpression("feature.0").toString()); assertEquals("if (1 > 2, 3, 4) + feature(arg1).out.out", new RankingExpression("if ( 1 > 2 , 3 , 4 ) + feature ( arg1 ) . out.out").toString()); } - @Test public void testNegativeConstantArgument() throws ParseException { assertEquals("foo(-1.2)", new RankingExpression("foo(-1.2)").toString()); } - @Test public void testNaming() throws ParseException { RankingExpression test = new RankingExpression("a+b"); test.setName("test"); assertEquals("test: a + b", test.toString()); } - @Test public void testCondition() throws ParseException { RankingExpression expression = new RankingExpression("if(1<2,3,4)"); assertTrue(expression.getRoot() instanceof IfNode); } - @Test public void testFileImporting() throws ParseException { RankingExpression expression = new RankingExpression(new File("src/test/files/simple.expression")); assertEquals("simple: a + b", expression.toString()); } - @Test public void testNonCanonicalLegalStrings() throws ParseException { assertParse("a * b + c * d", "a* (b) + \nc*d"); } - @Test public void testEquality() throws ParseException { assertEquals(new RankingExpression("if ( attribute(foo)==\"BAR\",log(attribute(popularity)+5),log(fieldMatch(title).proximity)*fieldMatch(title).completeness)"), new RankingExpression("if(attribute(foo)==\"BAR\", log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)")); @@ -235,7 +219,6 @@ public class RankingExpressionTestCase { new RankingExpression("if(attribute(foo)==\"BAR\", log(attribute(popularity)+5),log(fieldMatch(title).earliness) * fieldMatch(title).completeness)"))); } - @Test public void testSetMembershipConditions() throws ParseException { assertEquals(new RankingExpression("if ( attribute(foo) in [\"FOO\", \"BAR\"],log(attribute(popularity)+5),log(fieldMatch(title).proximity)*fieldMatch(title).completeness)"), new RankingExpression("if(attribute(foo) in [\"FOO\",\"BAR\"], log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)")); @@ -248,7 +231,6 @@ public class RankingExpressionTestCase { assertEquals(new RankingExpression("if (GENDER$ in [-1.0, 1.0], 1, 0)"), new RankingExpression("if (GENDER$ in [-1.0, 1.0], 1, 0)")); } - @Test public void testComments() throws ParseException { assertEquals(new RankingExpression("if ( attribute(foo) in [\"FOO\", \"BAR\"],\n" + "# a comment\n" + @@ -259,7 +241,6 @@ public class RankingExpressionTestCase { new RankingExpression("if(attribute(foo) in [\"FOO\",\"BAR\"], log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)")); } - @Test public void testIsNan() throws ParseException { String strExpr = "if (isNan(attribute(foo)) == 1.0, 1.0, attribute(foo))"; RankingExpression expr = new RankingExpression(strExpr); @@ -274,59 +255,27 @@ public class RankingExpressionTestCase { assertEquals(expected, new RankingExpression(expression).toString()); } - /** Test serialization with no macros */ - private void assertSerialization(String expectedSerialization, String expressionString) { - String serializedExpression; - try { - RankingExpression expression = new RankingExpression(expressionString); - // No macros -> expect one rank property - serializedExpression = expression.getRankProperties(Collections.emptyList()).values().iterator().next(); - assertEquals(expectedSerialization, serializedExpression); - } - catch (ParseException e) { - throw new IllegalArgumentException(e); - } - - try { - // No macros -> output should be parseable to a ranking expression - // (but not the same one due to primitivization) - RankingExpression reparsedExpression = new RankingExpression(serializedExpression); - // Serializing the primitivized expression should yield the same expression again - String reserializedExpression = - reparsedExpression.getRankProperties(Collections.emptyList()).values().iterator().next(); - assertEquals(expectedSerialization, reserializedExpression); - } - catch (ParseException e) { - throw new IllegalArgumentException("Could not parse the serialized expression", e); - } - } - - private void assertSerialization(List<String> expectedSerialization, String expressionString, - List<ExpressionFunction> macros) { - assertSerialization(expectedSerialization, expressionString, macros, false); - } - private void assertSerialization(List<String> expectedSerialization, String expressionString, - List<ExpressionFunction> macros, boolean print) { - try { - if (print) - System.out.println("Parsing expression '" + expressionString + "'."); - - RankingExpression expression = new RankingExpression(expressionString); - Map<String, String> rankProperties = expression.getRankProperties(macros); - if (print) { - for (String key : rankProperties.keySet()) - System.out.println("Property '" + key + "': " + rankProperties.get(key)); + private void assertScript(String expression, List<ExpressionFunction> macros, List<String> expectedScripts) + throws ParseException { + boolean print = false; + if (print) + System.out.println("Parsing expression '" + expression + "'."); + + RankingExpression exp = new RankingExpression(expression); + Map<String, String> scripts = exp.getRankProperties(macros); + if (print) { + for (String key : scripts.keySet()) { + System.out.println("Script '" + key + "': " + scripts.get(key)); } - for (int i = 0; i < expectedSerialization.size();) { - String val = expectedSerialization.get(i++); - assertTrue("Properties contains " + val, rankProperties.containsValue(val)); - } - if (print) - System.out.println(""); } - catch (ParseException e) { - throw new IllegalArgumentException(e); + + for (Map.Entry<String, String> m : scripts.entrySet()) + System.out.println(m); + for (int i = 0; i < expectedScripts.size();) { + String val = expectedScripts.get(i++); + assertTrue("Script contains " + val, scripts.containsValue(val)); } + if (print) + System.out.println(""); } - } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 93800e2c246..b67a423181d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -20,7 +20,7 @@ import java.util.Set; */ public class EvaluationTestCase extends junit.framework.TestCase { - private MapContext defaultContext; + private Context defaultContext; @Override protected void setUp() { @@ -100,180 +100,201 @@ public class EvaluationTestCase extends junit.framework.TestCase { @Test public void testTensorEvaluation() { - assertEvaluates("{}", "tensor0", "{}"); + assertEvaluates("{}", "{}"); // empty + assertEvaluates("( {{x:-}:1} * {} )", "( {{x:-}:1} * {} )"); // empty with dimensions - // tensor map + // sum(tensor) + assertEvaluates(5.0, "sum({{}:5.0})"); + assertEvaluates(-5.0, "sum({{}:-5.0})"); + assertEvaluates(12.5, "sum({ {d1:l1}:5.5, {d2:l2}:7.0 })"); + assertEvaluates(0.0, "sum({ {d1:l1}:5.0, {d2:l2}:7.0, {}:-12.0})"); + + // scalar functions on tensors assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }", - "map(tensor0, f(x) (log10(x)))", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); - assertEvaluates("{ {}:4, {d1:l1}:9, {d1:l1,d2:l1 }:16 }", - "map(tensor0, f(x) (x * x))", "{ {}:2, {d1:l1}:3, {d1:l1,d2:l1}:4 }"); - // -- tensor map composites - assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }", - "log10(tensor0)", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + "log10({ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 })"); + assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }", + "5 * { {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }", + "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 } + 3"); + assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }", + "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 } / 10"); assertEvaluates("{ {}:-10, {d1:l1}:-100, {d1:l1,d2:l1 }:-1000 }", - "- tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + "- { {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); assertEvaluates("{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1 }:0 }", - "min(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }"); + "min({ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }, 0)"); assertEvaluates("{ {}:0, {d1:l1}:0, {d1:l1,d2:l1 }:10 }", - "max(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }"); - // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence) - assertEvaluates("{ {x:1}:1, {x:2}:2 }", "abs(tensor0)", "{ {x:1}:1, {x:2}:-2 }"); - assertEvaluates("{ {x:1}:0, {x:2}:0 }", "acos(tensor0)", "{ {x:1}:1, {x:2}:1 }"); - assertEvaluates("{ {x:1}:0, {x:2}:0 }", "asin(tensor0)", "{ {x:1}:0, {x:2}:0 }"); - assertEvaluates("{ {x:1}:0, {x:2}:0 }", "atan(tensor0)", "{ {x:1}:0, {x:2}:0 }"); - assertEvaluates("{ {x:1}:1, {x:2}:2 }", "ceil(tensor0)", "{ {x:1}:1, {x:2}:2 }"); - assertEvaluates("{ {x:1}:1, {x:2}:1 }", "cos(tensor0)", "{ {x:1}:0, {x:2}:0 }"); - assertEvaluates("{ {x:1}:1, {x:2}:1 }", "cosh(tensor0)", "{ {x:1}:0, {x:2}:0 }"); - assertEvaluates("{ {x:1}:1, {x:2}:2 }", "elu(tensor0)", "{ {x:1}:1, {x:2}:2 }"); - assertEvaluates("{ {x:1}:1, {x:2}:1 }", "exp(tensor0)", "{ {x:1}:0, {x:2}:0 }"); - assertEvaluates("{ {x:1}:1, {x:2}:2 }", "fabs(tensor0)", "{ {x:1}:1, {x:2}:2 }"); - assertEvaluates("{ {x:1}:1, {x:2}:2 }", "floor(tensor0)", "{ {x:1}:1, {x:2}:2 }"); - assertEvaluates("{ {x:1}:0, {x:2}:0 }", "isNan(tensor0)", "{ {x:1}:1, {x:2}:2 }"); - assertEvaluates("{ {x:1}:0, {x:2}:0 }", "log(tensor0)", "{ {x:1}:1, {x:2}:1 }"); - assertEvaluates("{ {x:1}:0, {x:2}:1 }", "log10(tensor0)", "{ {x:1}:1, {x:2}:10 }"); - assertEvaluates("{ {x:1}:0, {x:2}:2 }", "mod(tensor0, 3)", "{ {x:1}:3, {x:2}:8 }"); - assertEvaluates("{ {x:1}:1, {x:2}:8 }", "pow(tensor0, 3)", "{ {x:1}:1, {x:2}:2 }"); - assertEvaluates("{ {x:1}:1, {x:2}:2 }", "relu(tensor0)", "{ {x:1}:1, {x:2}:2 }"); - assertEvaluates("{ {x:1}:1, {x:2}:2 }", "round(tensor0)", "{ {x:1}:1, {x:2}:1.8 }"); - assertEvaluates("{ {x:1}:0.5, {x:2}:0.5 }", "sigmoid(tensor0)","{ {x:1}:0, {x:2}:0 }"); - assertEvaluates("{ {x:1}:1, {x:2}:-1 }", "sign(tensor0)", "{ {x:1}:3, {x:2}:-5 }"); - assertEvaluates("{ {x:1}:0, {x:2}:0 }", "sin(tensor0)", "{ {x:1}:0, {x:2}:0 }"); - assertEvaluates("{ {x:1}:0, {x:2}:0 }", "sinh(tensor0)", "{ {x:1}:0, {x:2}:0 }"); - assertEvaluates("{ {x:1}:1, {x:2}:4 }", "square(tensor0)", "{ {x:1}:1, {x:2}:2 }"); - assertEvaluates("{ {x:1}:1, {x:2}:3 }", "sqrt(tensor0)", "{ {x:1}:1, {x:2}:9 }"); - assertEvaluates("{ {x:1}:0, {x:2}:0 }", "tan(tensor0)", "{ {x:1}:0, {x:2}:0 }"); - assertEvaluates("{ {x:1}:0, {x:2}:0 }", "tanh(tensor0)", "{ {x:1}:0, {x:2}:0 }"); - - // tensor reduce - // -- reduce 2 dimensions - assertEvaluates("{ {}:4 }", - "reduce(tensor0, avg, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - assertEvaluates("{ {}:4 }", - "reduce(tensor0, count, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - assertEvaluates("{ {}:105 }", - "reduce(tensor0, prod, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - assertEvaluates("{ {}:16 }", - "reduce(tensor0, sum, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - assertEvaluates("{ {}:7 }", - "reduce(tensor0, max, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - assertEvaluates("{ {}:1 }", - "reduce(tensor0, min, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - // -- reduce 2 by specifying no arguments - assertEvaluates("{ {}:4 }", - "reduce(tensor0, avg)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - // -- reduce 1 dimension - assertEvaluates("{ {y:1}:2, {y:2}:6 }", - "reduce(tensor0, avg, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - assertEvaluates("{ {y:1}:2, {y:2}:2 }", - "reduce(tensor0, count, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - assertEvaluates("{ {y:1}:3, {y:2}:35 }", - "reduce(tensor0, prod, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - assertEvaluates("{ {y:1}:4, {y:2}:12 }", - "reduce(tensor0, sum, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - assertEvaluates("{ {y:1}:3, {y:2}:7 }", - "reduce(tensor0, max, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - assertEvaluates("{ {y:1}:1, {y:2}:5 }", - "reduce(tensor0, min, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - // -- reduce composites - assertEvaluates("{ {}: 5 }", "sum(tensor0)", "5.0"); - assertEvaluates("{ {}:-5 }", "sum(tensor0)", "-5.0"); - assertEvaluates("{ {}:12.5 }", "sum(tensor0)", "{ {d1:l1}:5.5, {d2:l2}:7.0 }"); - assertEvaluates("{ {}: 0 }", "sum(tensor0)", "{ {d1:l1}:5.0, {d2:l2}:7.0, {}:-12.0}"); - assertEvaluates("{ {y:1}:4, {y:2}:12.0 }", - "sum(tensor0, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - assertEvaluates("{ {x:1}:6, {x:2}:10.0 }", - "sum(tensor0, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - assertEvaluates("{ {}:16 }", - "sum(tensor0, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - - // tensor join - assertEvaluates("{ {x:1,y:1}:15, {x:2,y:1}:35 }", "join(tensor0, tensor1, f(x,y) (x*y))", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); - // -- join composites - assertEvaluates("{ }", "tensor0 * tensor0", "{}"); - assertEvaluates("tensor(x{},y{},z{}):{}", "( tensor0 * tensor1 ) * ( tensor2 * tensor1 )", - "{{x:-}:1}", "{}", "{{y:-,z:-}:1}"); // empty dimensions are preserved - assertEvaluates("tensor(x{}):{}", - "tensor0 * tensor1", "{ {x:1}:3 }", "{ {x:2}:5 }"); + "max({ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }, 0)"); + assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + {{h:1}:1.0,{h:2}:1.0}"); + + // sum(tensor, dimension) + assertEvaluates("{ {y:1}:4.0, {y:2}:12.0 }", + "sum({ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }, x)"); + assertEvaluates("{ {x:1}:6.0, {x:2}:10.0 }", + "sum({ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }, y)"); + + // tensor sum + assertEvaluates("{ }", "{} + {}"); + assertEvaluates("{ {x:1}:3, {x:2}:5 }", + "{ {x:1}:3 } + { {x:2}:5 }"); + assertEvaluates("{ {x:1}:8 }", + "{ {x:1}:3 } + { {x:1}:5 }"); + assertEvaluates("{ {x:1}:3, {y:1}:5 }", + "{ {x:1}:3 } + { {y:1}:5 }"); + assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }", + "{ {x:1}:3, {x:2}:7 } + { {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }", + "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } + { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); + assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", + "{ {x:1}:5, {x:1,y:1}:1 } + { {y:1,z:1}:7 }"); + assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }", + "{ {x:1}:5, {x:1,y:1}:1 } + { {z:1}:11, {y:1,z:1}:7 }"); + assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", + "{ {}:5, {x:1,y:1}:1 } + { {y:1,z:1}:7 }"); + assertEvaluates("{ {}:16, {x:1,y:1}:1, {y:1,z:1}:7 }", + "{ {}:5, {x:1,y:1}:1 } + { {}:11, {y:1,z:1}:7 }"); + + // tensor difference + assertEvaluates("{ }", "{} - {}"); + assertEvaluates("{ {x:1}:3, {x:2}:-5 }", + "{ {x:1}:3 } - { {x:2}:5 }"); + assertEvaluates("{ {x:1}:-2 }", + "{ {x:1}:3 } - { {x:1}:5 }"); + assertEvaluates("{ {x:1}:3, {y:1}:-5 }", + "{ {x:1}:3 } - { {y:1}:5 }"); + assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:-5 }", + "{ {x:1}:3, {x:2}:7 } - { {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:-7, {y:2,z:1}:-11, {y:1,z:2}:-13 }", + "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } - { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); + assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:-7 }", + "{ {x:1}:5, {x:1,y:1}:1 } - { {y:1,z:1}:7 }"); + assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:-11, {y:1,z:1}:-7 }", + "{ {x:1}:5, {x:1,y:1}:1 } - { {z:1}:11, {y:1,z:1}:7 }"); + assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:-7 }", + "{ {}:5, {x:1,y:1}:1 } - { {y:1,z:1}:7 }"); + assertEvaluates("{ {}:-6, {x:1,y:1}:1, {y:1,z:1}:-7 }", + "{ {}:5, {x:1,y:1}:1 } - { {}:11, {y:1,z:1}:7 }"); + assertEvaluates("{ {x:1}:0 }", + "{ {x:1}:3 } - { {x:1}:3 }"); + assertEvaluates("{ {x:1}:0, {x:2}:1 }", + "{ {x:1}:3, {x:2}:1 } - { {x:1}:3 }"); + + // tensor product + assertEvaluates("{ }", "{} * {}"); + assertEvaluates("( {{x:-,y:-,z:-}:1}*{} )", "( {{x:-}:1} * {} ) * ( {{y:-,z:-}:1} * {} )"); // empty dimensions are preserved + assertEvaluates("( {{x:-}:1} * {} )", + "{ {x:1}:3 } * { {x:2}:5 }"); assertEvaluates("{ {x:1}:15 }", - "tensor0 * tensor1", "{ {x:1}:3 }", "{ {x:1}:5 }"); + "{ {x:1}:3 } * { {x:1}:5 }"); assertEvaluates("{ {x:1,y:1}:15 }", - "tensor0 * tensor1", "{ {x:1}:3 }", "{ {y:1}:5 }"); + "{ {x:1}:3 } * { {y:1}:5 }"); assertEvaluates("{ {x:1,y:1}:15, {x:2,y:1}:35 }", - "tensor0 * tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:8, {x:2,y:1}:12 }", - "tensor0 + tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:-2, {x:2,y:1}:2 }", - "tensor0 - tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:4 }", - "tensor0 / tensor1", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }"); - assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:7 }", - "max(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:3, {x:2,y:1}:5 }", - "min(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + "{ {x:1}:3, {x:2}:7 } * { {y:1}:5 }"); assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,y:1,z:2}:13, {x:2,y:1,z:1}:21, {x:2,y:1,z:2}:39, {x:1,y:2,z:1}:55 }", - "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); - assertEvaluates("{ {x:1,y:2,z:1}:35, {x:1,y:2,z:2}:65 }", - "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:2,z:1}:7, {y:3,z:1}:11, {y:2,z:2}:13 }"); - assertEvaluates("{{x:1,y:1}:0.0}","tensor1 * tensor2 * tensor3", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1,y:1}:1 }", "{ {x:1,y:1}:1 }"); - assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }", - "5 * tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); - assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }", - "tensor0 + 3","{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); - assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }", - "tensor0 / 10", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); - assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + tensor0", "{ {h:1}:1.0,{h:2}:1.0 }"); - assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:0 }", - "atan2(tensor0, tensor1)", "{ {x:1}:0, {x:2}:0 }", "{ {y:1}:1 }"); - assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }", - "tensor0 > tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", - "tensor0 < tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }", - "tensor0 >= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", - "tensor0 <= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }", - "tensor0 == tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", - "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }"); - // TODO - // argmax - // argmin - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", - "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }"); - - // tensor rename - assertEvaluates("{ {newX:1,y:2}:3 }", "rename(tensor0, x, newX)", "{ {x:1,y:2}:3.0 }"); - assertEvaluates("{ {x:2,y:1}:3 }", "rename(tensor0, (x, y), (y, x))", "{ {x:1,y:2}:3.0 }"); - - // tensor generate - TODO - // assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0, {x:2,y:2}:1, {x:1,y:2}:0 }", "tensor(x[2],y[2])(x==y)"); - // range - // diag - // fill - // random - - // composite functions - assertEvaluates("{ {x:1}:0.25, {x:2}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:1}:1, {x:2}:3 }"); - assertEvaluates("{ {x:1}:0.31622776601683794, {x:2}:0.9486832980505138 }", "l2_normalize(tensor0, x)", "{ {x:1}:1, {x:2}:3 }"); - assertEvaluates("{ {y:1}:81.0 }", "matmul(tensor0, tensor1, x)", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }"); - assertEvaluates("{ {x:1}:0.5, {x:2}:0.5 }", "softmax(tensor0, x)", "{ {x:1}:1, {x:2}:1 }", "{ {y:1}:1 }"); - assertEvaluates("{ {x:1,y:1}:88.0 }", "xw_plus_b(tensor0, tensor1, tensor2, x)", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }", "{ {x:1}:7 }"); - - // expressions combining functions - assertEvaluates(String.valueOf(7.5 + 45 + 1.7), - "sum( " + // model computation: - " tensor0 * tensor1 * tensor2 " + // - feature combinations - " * tensor3" + // - model weights application - ") + 1.7", - "{ {x:1}:1, {x:2}:2 }", "{ {y:1}:3, {y:2}:4 }", "{ {z:1}:5 }", - "{ {x:1,y:1,z:1}:0.5, {x:2,y:1,z:1}:1.5, {x:1,y:1,z:2}:4.5 }"); - assertEvaluates("1.0", "sum(tensor0 * tensor1 + 0.5)", "{ {x:1}:0, {x:2}:0 }", "{ {x:1}:1, {x:2}:1 }"); - assertEvaluates("0.0", "sum(tensor0 * tensor1 + 0.5)", "{}", "{ {x:1}:1, {x:2}:1 }"); + "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } * { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); + assertEvaluates("{ {x:1,y:1,z:1}:7 }", + "{ {x:1}:5, {x:1,y:1}:1 } * { {y:1,z:1}:7 }"); + assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,z:1}:55 }", + "{ {x:1}:5, {x:1,y:1}:1 } * { {z:1}:11, {y:1,z:1}:7 }"); + assertEvaluates("{ {x:1,y:1,z:1}:7 }", + "{ {}:5, {x:1,y:1}:1 } * { {y:1,z:1}:7 }"); + assertEvaluates("{ {x:1,y:1,z:1}:7, {}:55 }", + "{ {}:5, {x:1,y:1}:1 } * { {}:11, {y:1,z:1}:7 }"); + + // match product + assertEvaluates("{ }", "match({}, {})"); + assertEvaluates("( {{x:-}:1} * {} )", + "match({ {x:1}:3 }, { {x:2}:5 })"); + assertEvaluates("{ {x:1}:15 }", + "match({ {x:1}:3 }, { {x:1}:5 })"); + assertEvaluates("( {{x:-,y:-}:1} * {} )", + "match({ {x:1}:3 }, { {y:1}:5 })"); + assertEvaluates("( {{x:-,y:-}:1} * {} )", + "match({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })"); + assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )", + "match({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })"); + assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )", + "match({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); + assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )", + "match({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })"); + assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )", + "match({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); + assertEvaluates("( {{x:-,y:-,z:-}:1} * { {}:55 } )", + "match({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })"); + assertEvaluates("( {{z:-}:1} * { {x:1}:15, {x:1,y:1}:7 } )", + "match({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })"); + + // min + assertEvaluates("{ {x:1}:3, {x:2}:5 }", + "min({ {x:1}:3 }, { {x:2}:5 })"); + assertEvaluates("{ {x:1}:3 }", + "min({ {x:1}:3 }, { {x:1}:5 })"); + assertEvaluates("{ {x:1}:3, {y:1}:5 }", + "min({ {x:1}:3 }, { {y:1}:5 })"); + assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }", + "min({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }", + "min({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })"); + assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", + "min({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); + assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }", + "min({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })"); + assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", + "min({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); + assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", + "min({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })"); + assertEvaluates("{ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }", + "min({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })"); + + // max + assertEvaluates("{ {x:1}:3, {x:2}:5 }", + "max({ {x:1}:3 }, { {x:2}:5 })"); + assertEvaluates("{ {x:1}:5 }", + "max({ {x:1}:3 }, { {x:1}:5 })"); + assertEvaluates("{ {x:1}:3, {y:1}:5 }", + "max({ {x:1}:3 }, { {y:1}:5 })"); + assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }", + "max({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }", + "max({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })"); + assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", + "max({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); + assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }", + "max({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })"); + assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", + "max({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); + assertEvaluates("{ {}:11, {x:1,y:1}:1, {y:1,z:1}:7 }", + "max({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })"); + assertEvaluates("{ {}:5, {x:1}:5, {x:2}:4, {x:1,y:1}:7, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }", + "max({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })"); + + // Combined + assertEvaluates(7.5 + 45 + 1.7, + "sum( " + // model computation + " match( " + // model weight application + " { {x:1}:1, {x:2}:2 } * { {y:1}:3, {y:2}:4 } * { {z:1}:5 }, " + // feature combinations + " { {x:1,y:1,z:1}:0.5, {x:2,y:1,z:1}:1.5, {x:1,y:1,z:2}:4.5 }" + // model weights + "))+1.7"); + + // undefined is not the same as 0 + assertEvaluates(1.0, "sum({ {x:1}:0, {x:2}:0 } * { {x:1}:1, {x:2}:1 } + 0.5)"); + assertEvaluates(0.0, "sum({ } * { {x:1}:1, {x:2}:1 } + 0.5)"); // tensor result dimensions are given from argument dimensions, not the resulting values - assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:1}:1 }", "{ {x:2}:1 }"); - assertEvaluates("tensor(x{},y{}):{{x:1}:1.0}", "tensor0 * tensor1", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1}:1 }"); + assertEvaluates("x", "( {{x:-}:1.0} * {} )", "{ {x:1}:1 } * { {x:2}:1 }"); + assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 }"); + + // demonstration of where this produces different results: { {x:1}:1 } with 2 dimensions ... + assertEvaluates("x, y", "( {{x:-,y:-}:1.0} * {} )","{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 } * { {x:1,y:1}:1 }"); + // ... vs { {x:1}:1 } with only one dimension + assertEvaluates("x, y", "{{x:1,y:1}:1.0}", "{ {x:1}:1 } * { {x:1,y:1}:1 }"); + + // check that dimensions are preserved through other operations + String d2 = "{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 }"; // creates a 2d tensor with only an 1d value + assertEvaluates("x, y", "( {{x:-,y:-}:1.0} * {} )", "match(" + d2 + ", {})"); + assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", d2 + " - {}"); + assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", d2 + " + {}"); + assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "min(1.5, " + d2 +")"); + assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "max({{x:1}:0}, " + d2 +")"); } public void testProgrammaticBuildingAndPrecedence() { @@ -295,16 +316,12 @@ public class EvaluationTestCase extends junit.framework.TestCase { assertEvaluates(77, "average(\"2*3\",\"pow(2,3)\")+average(\"2*3\",\"pow(2,3)\").timesten", context); } - private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) { - MapContext context = defaultContext.thawedCopy(); - int argumentIndex = 0; - for (String tensorArgument : tensorArguments) - context.put("tensor" + (argumentIndex++), new TensorValue(MapTensor.from(tensorArgument))); - return assertEvaluates(new TensorValue(MapTensor.from(expectedTensor)), expressionString, context); + private RankingExpression assertEvaluates(String tensorValue, String expressionString) { + return assertEvaluates(new TensorValue(MapTensor.from(tensorValue)), expressionString, defaultContext); } /** Validate also that the dimension of the resulting tensors are as expected */ - private RankingExpression assertEvaluates_old(String tensorDimensions, String resultTensor, String expressionString) { + private RankingExpression assertEvaluates(String tensorDimensions, String resultTensor, String expressionString) { RankingExpression expression = assertEvaluates(new TensorValue(MapTensor.from(resultTensor)), expressionString, defaultContext); TensorValue value = (TensorValue)expression.evaluate(defaultContext); assertEquals(toSet(tensorDimensions), value.asTensor().dimensions()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java index 08fdc9917a4..95c4402a612 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java @@ -17,25 +17,22 @@ public class NeuralNetEvaluationTestCase { /** "XOR" neural network, separate expression per layer */ @Test public void testPerLayerExpression() { - String input = "{ {x:1}:0, {x:2}:1 }"; // tensor0 - String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }"; // tensor1 - String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }"; // tensor2 - String firstLayerInput = "sum(tensor0 * tensor1, x) + tensor2"; + String input = "{ {x:1}:0, {x:2}:1 }"; + + String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }"; + String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }"; + String firstLayerInput = "sum(" + input + "*" + firstLayerWeights + ", x) + " + firstLayerBias; String firstLayerOutput = "min(1.0, max(0.0, 0.5 + " + firstLayerInput + "))"; // non-linearity, "poor man's sigmoid" - assertEvaluates("{ {h:1}:1.0, {h:2}:0.0} }", firstLayerOutput, input, firstLayerWeights, firstLayerBias); - String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }"; // tensor3 - String secondLayerBias = "{ {y:1}:-0.5 }"; // tensor4 - String secondLayerInput = "sum(" + firstLayerOutput + "* tensor3, h) + tensor4"; + assertEvaluates("{ {h:1}:1.0, {h:2}:0.0} }", firstLayerOutput); + String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }"; + String secondLayerBias = "{ {y:1}:-0.5 }"; + String secondLayerInput = "sum(" + firstLayerOutput + "*" + secondLayerWeights + ", h) + " + secondLayerBias; String secondLayerOutput = "min(1.0, max(0.0, 0.5 + " + secondLayerInput + "))"; // non-linearity, "poor man's sigmoid" - assertEvaluates("{ {y:1}:1 }", secondLayerOutput, input, firstLayerWeights, firstLayerBias, secondLayerWeights, secondLayerBias); + assertEvaluates("{ {y:1}:1 }", secondLayerOutput); } - private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) { - MapContext context = new MapContext(); - int argumentIndex = 0; - for (String tensorArgument : tensorArguments) - context.put("tensor" + (argumentIndex++), new TensorValue(MapTensor.from(tensorArgument))); - return assertEvaluates(new TensorValue(MapTensor.from(expectedTensor)), expressionString, context); + private RankingExpression assertEvaluates(String tensorValue, String expressionString) { + return assertEvaluates(new TensorValue(MapTensor.from(tensorValue)), expressionString, new MapContext()); } private RankingExpression assertEvaluates(Value value, String expressionString, Context context) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java index 61b230ab390..9d94ec0bc99 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java @@ -69,4 +69,12 @@ public class SimplifierTestCase { assertEquals("a + (b + c) / 100000000.0", transformed.toString()); } + @Test + public void testSimplificationWithTensorConstants() throws ParseException { + new Simplifier().transform(new RankingExpression( + "sum(sum((tensorFromWeightedSet(query(wset_query),x)+" + + " tensorFromWeightedSet(attribute(wset),x)) * " + + " {{x:0,y:0}:54, {x:0,y:1} :69, {x:1,y:0} :72, {x:1,y:1} :93},x))")); + } + } diff --git a/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java b/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java index a54f1971d21..d70b55c66a2 100644 --- a/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java +++ b/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java @@ -20,7 +20,7 @@ public class ClusterState implements Cloneable { private Map<Node, NodeState> nodeStates = new TreeMap<>(); // TODO: Change to one count for distributor and one for storage, rather than an array - // TODO: RenameFunction, this is not the highest node count but the highest index + // TODO: Rename, this is not the highest node count but the highest index private ArrayList<Integer> nodeCount = new ArrayList<>(2); private String description = ""; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java index 4fd743e4724..3bda4159ca6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java @@ -21,8 +21,6 @@ import java.util.function.UnaryOperator; @Beta public class MapTensor implements Tensor { - // TODO: Enforce that all addresses are dense (and then avoid storing keys in TensorAddress) - private final ImmutableSet<String> dimensions; private final ImmutableMap<TensorAddress, Double> cells; @@ -33,7 +31,7 @@ public class MapTensor implements Tensor { } /** Creates a sparse tensor */ - public MapTensor(Set<String> dimensions, Map<TensorAddress, Double> cells) { + MapTensor(Set<String> dimensions, Map<TensorAddress, Double> cells) { ensureValidDimensions(cells, dimensions); this.dimensions = ImmutableSet.copyOf(dimensions); this.cells = ImmutableMap.copyOf(cells); @@ -54,41 +52,24 @@ public class MapTensor implements Tensor { */ public static MapTensor from(String s) { s = s.trim(); - try { - if (s.startsWith("tensor(")) - return fromTypedTensor(s); - else if (s.startsWith("{")) - return fromUntypedTensor(s, Collections.emptySet()); - else - return fromNumber(Double.parseDouble(s)); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" + s + "'"); - } + if ( s.startsWith("(")) + return fromTensorWithEmptyDimensions(s); + else if ( s.startsWith("{")) + return fromTensor(s, Collections.emptySet()); + else + throw new IllegalArgumentException("Excepted a string starting by { or (, got '" + s + "'"); } - private static MapTensor fromTypedTensor(String s) { - if ( ! s.startsWith("tensor(")) throw tensorFormatException(s); - s = s.substring("tensor(".length()); - int typeSpecEnd = s.indexOf(")"); - if (typeSpecEnd < 0 ) throw tensorFormatException(s); - String typeSpec = s.substring(0, typeSpecEnd); - - Set<String> dimensions = new HashSet<>(); - for (String dimensionSpec : typeSpec.split(",")) { - dimensionSpec = dimensionSpec.trim(); - if ( ! dimensionSpec.endsWith("{}")) - throw new IllegalArgumentException("Only mapped dimensions ({}) are supported, got '" + dimensionSpec + "'"); - dimensions.add(dimensionSpec.substring(0, dimensionSpec.length() - 2)); - } - - s = s.substring(typeSpec.length() + 1); - if ( ! s.startsWith(":")) throw tensorFormatException(s); + private static MapTensor fromTensorWithEmptyDimensions(String s) { s = s.substring(1).trim(); - return fromUntypedTensor(s, dimensions); + int multiplier = s.indexOf("*"); + if (multiplier < 0 || ! s.endsWith(")")) + throw new IllegalArgumentException("Expected a tensor on the form ({dimension:-,...}*{{cells}}), got '" + s + "'"); + MapTensor dimensionTensor = fromTensor(s.substring(0, multiplier).trim(), Collections.emptySet()); + return fromTensor(s.substring(multiplier + 1, s.length() - 1), dimensionTensor.dimensions()); } - private static MapTensor fromUntypedTensor(String s, Set<String> additionalDimensions) { + private static MapTensor fromTensor(String s, Set<String> additionalDimensions) { s = s.trim().substring(1).trim(); ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); while (s.length() > 1) { @@ -113,16 +94,6 @@ public class MapTensor implements Tensor { dimensions.addAll(additionalDimensions); return new MapTensor(dimensions, cellMap); } - - private static MapTensor fromNumber(double number) { - ImmutableMap.Builder<TensorAddress, Double> singleCell = new ImmutableMap.Builder<>(); - singleCell.put(TensorAddress.empty, number); - return new MapTensor(ImmutableSet.of(), singleCell.build()); - } - - private static IllegalArgumentException tensorFormatException(String s) { - return new IllegalArgumentException("Expected a tensor on the form tensor(dimensionspec):content, but got '" + s + "'"); - } private static Double asDouble(TensorAddress address, String s) { try { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java b/vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java new file mode 100644 index 00000000000..074742acee1 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java @@ -0,0 +1,33 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor; + +import com.google.common.collect.ImmutableMap; + +import java.util.Map; +import java.util.Set; + +/** + * Computes a <i>match product</i>, see {@link Tensor#match} + * + * @author bratseth + */ +class MatchProduct { + + private final Set<String> dimensions; + private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); + + public MatchProduct(Tensor a, Tensor b) { + this.dimensions = TensorOperations.combineDimensions(a, b); + for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) { + Double sameValueInB = b.cells().get(aCell.getKey()); + if (sameValueInB != null) + cells.put(aCell.getKey(), aCell.getValue() * sameValueInB); + } + } + + /** Returns the result of taking this product */ + public MapTensor result() { + return new MapTensor(dimensions, cells.build()); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 4b17f65ea21..41882738e89 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -2,25 +2,18 @@ package com.yahoo.tensor; import com.google.common.annotations.Beta; -import com.yahoo.tensor.functions.ConstantTensor; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.L1Normalize; -import com.yahoo.tensor.functions.L2Normalize; -import com.yahoo.tensor.functions.Matmul; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.Softmax; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleFunction; import java.util.function.DoubleUnaryOperator; -import java.util.function.Function; +import java.util.function.UnaryOperator; /** * A multidimensional array which can be used in computations. @@ -56,74 +49,128 @@ public interface Tensor { /** Returns the value of a cell, or NaN if this cell does not exist/have no value */ double get(TensorAddress address); - // ----------------- Primitive tensor functions + // ----------------- Level 0 functions - default Tensor map(DoubleUnaryOperator mapper) { - return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate(); + default Tensor map(Tensor tensor, DoubleUnaryOperator mapper) { + throw new UnsupportedOperationException("Not implemented"); } - /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */ - default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) { - return new Reduce(new ConstantTensor(this), aggregator, dimensions).evaluate(); + default Tensor reduce(Tensor tensor, String dimension, + DoubleBinaryOperator reductor, Optional<DoubleBinaryOperator> postTransformation) { + throw new UnsupportedOperationException("Not implemented"); } - default Tensor join(Tensor argument, DoubleBinaryOperator combinator) { - return new Join(new ConstantTensor(this), new ConstantTensor(argument), combinator).evaluate(); + default Tensor join(Tensor tensorA, Tensor tensorB, DoubleBinaryOperator combinator) { + throw new UnsupportedOperationException("Not implemented"); } - default Tensor rename(String fromDimension, String toDimension) { - return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension), - Collections.singletonList(toDimension)).evaluate(); + // ----------------- Old stuff + /** + * Returns the <i>sparse tensor product</i> of this tensor and the argument tensor. + * This is the all-to-all combinations of cells in the argument tenors, except the combinations + * which have conflicting labels for the same dimension. The value of each combination is the product + * of the values of the two input cells. The dimensions of the tensor product is the set union of the + * dimensions of the argument tensors. + * <p> + * If there are no overlapping dimensions this is the regular tensor product. + * If the two tensors have exactly the same dimensions this is the Hadamard product. + * <p> + * The sparse tensor product is associative and commutative. + * + * @param argument the tensor to multiply by this + * @return the resulting tensor. + */ + default Tensor multiply(Tensor argument) { + return new TensorProduct(this, argument).result(); } - default Tensor rename(List<String> fromDimensions, List<String> toDimensions) { - return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate(); - } - - static Tensor from(TensorType type, Function<List<Integer>, Double> valueSupplier) { - return new Generate(type, valueSupplier).evaluate(); + /** + * Returns the <i>match product</i> of two tensors. + * This returns a tensor which contains the <i>matching</i> cells in the two tensors, with their + * values multiplied. + * <p> + * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, + * and have the value undefined for any non-shared dimension. + * <p> + * The dimensions of the resulting tensor is the set intersection of the two argument tensors. + * <p> + * If the two tensors have exactly the same dimensions, this is the Hadamard product. + */ + default Tensor match(Tensor argument) { + return new MatchProduct(this, argument).result(); } - - // ----------------- Composite tensor functions which have a defined primitive mapping - - default Tensor l1Normalize(String dimension) { - return new L1Normalize(new ConstantTensor(this), dimension).evaluate(); + + /** + * Returns a tensor which contains the cells of both argument tensors, where the value for + * any <i>matching</i> cell is the min of the two possible values. + * <p> + * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, + * and have the value undefined for any non-shared dimension. + */ + default Tensor min(Tensor argument) { + return new TensorMin(this, argument).result(); } - default Tensor l2Normalize(String dimension) { - return new L2Normalize(new ConstantTensor(this), dimension).evaluate(); + /** + * Returns a tensor which contains the cells of both argument tensors, where the value for + * any <i>matching</i> cell is the max of the two possible values. + * <p> + * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, + * and have the value undefined for any non-shared dimension. + */ + default Tensor max(Tensor argument) { + return new TensorMax(this, argument).result(); } - default Tensor matmul(Tensor argument, String dimension) { - return new Matmul(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate(); + /** + * Returns a tensor which contains the cells of both argument tensors, where the value for + * any <i>matching</i> cell is the sum of the two possible values. + * <p> + * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, + * and have the value undefined for any non-shared dimension. + */ + default Tensor add(Tensor argument) { + return new TensorSum(this, argument).result(); } - default Tensor softmax(String dimension) { - return new Softmax(new ConstantTensor(this), dimension).evaluate(); + /** + * Returns a tensor which contains the cells of both argument tensors, where the value for + * any <i>matching</i> cell is the difference of the two possible values. + * <p> + * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, + * and have the value undefined for any non-shared dimension. + */ + default Tensor subtract(Tensor argument) { + return new TensorDifference(this, argument).result(); } - // ----------------- Composite tensor functions mapped to primitives here on the fly + /** + * Returns a tensor with the same cells as this and the given function is applied to all its cell values. + * + * @param function the function to apply to all cells + * @return the tensor with the function applied to all the cells of this + */ + default Tensor apply(UnaryOperator<Double> function) { + return new TensorFunction(this, function).result(); + } - default Tensor multiply(Tensor argument) { return join(argument, (a, b) -> (a * b )); } - default Tensor add(Tensor argument) { return join(argument, (a, b) -> (a + b )); } - default Tensor divide(Tensor argument) { return join(argument, (a, b) -> (a / b )); } - default Tensor subtract(Tensor argument) { return join(argument, (a, b) -> (a - b )); } - default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); } - default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); } - default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); } - default Tensor larger(Tensor argument) { return join(argument, (a, b) -> ( a > b ? 1.0 : 0.0)); } - default Tensor largerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a >= b ? 1.0 : 0.0)); } - default Tensor smaller(Tensor argument) { return join(argument, (a, b) -> ( a < b ? 1.0 : 0.0)); } - default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); } - default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); } - default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); } + /** + * Returns a tensor with the given dimension removed and cells which contains the sum of the values + * in the removed dimension. + */ + default Tensor sum(String dimension) { + return new TensorDimensionSum(dimension, this).result(); + } - default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); } - default Tensor count(List<String> dimensions) { return reduce(Reduce.Aggregator.count, dimensions); } - default Tensor max(List<String> dimensions) { return reduce(Reduce.Aggregator.max, dimensions); } - default Tensor min(List<String> dimensions) { return reduce(Reduce.Aggregator.min, dimensions); } - default Tensor prod(List<String> dimensions) { return reduce(Reduce.Aggregator.prod, dimensions); } - default Tensor sum(List<String> dimensions) { return reduce(Reduce.Aggregator.sum, dimensions); } + /** + * Returns the sum of all the cells of this tensor. + */ + default double sum() { + double sum = 0; + for (Map.Entry<TensorAddress, Double> cell : cells().entrySet()) + sum += cell.getValue(); + return sum; + } /** * Returns true if the given tensor is mathematically equal to this: @@ -179,28 +226,19 @@ public interface Tensor { * @return the tensor on the standard string format */ static String toStandardString(Tensor tensor) { - if ( emptyDimensions(tensor).size() > 0) // explicitly output type TODO: Always do that - return typeToString(tensor) + ":" + contentToString(tensor); + Set<String> emptyDimensions = emptyDimensions(tensor); + if (emptyDimensions.size() > 0) // explicitly list empty dimensions + return "( " + unitTensorWithDimensions(emptyDimensions) + " * " + contentToString(tensor) + " )"; else return contentToString(tensor); } - static String typeToString(Tensor tensor) { - if (tensor.dimensions().isEmpty()) return "tensor()"; - StringBuilder b = new StringBuilder("tensor("); - for (String dimension : tensor.dimensions()) - b.append(dimension).append("{},"); - b.setLength(b.length() -1); - b.append(")"); - return b.toString(); - } - static String contentToString(Tensor tensor) { - List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet()); - Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey()); + List<Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet()); + Collections.sort(cellEntries, Map.Entry.<TensorAddress, Double>comparingByKey()); StringBuilder b = new StringBuilder("{"); - for (java.util.Map.Entry<TensorAddress, Double> cell : cellEntries) { + for (Map.Entry<TensorAddress, Double> cell : cellEntries) { b.append(cell.getKey()).append(":").append(cell.getValue()); b.append(","); } @@ -221,4 +259,8 @@ public interface Tensor { return emptyDimensions; } + static String unitTensorWithDimensions(Set<String> dimensions) { + return new MapTensor(Collections.singletonMap(TensorAddress.emptyWithDimensions(dimensions), 1.0)).toString(); + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index e3c089de071..11c6a5f6685 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -8,11 +8,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; -import java.util.Optional; import java.util.Set; /** * An immutable address to a tensor cell. + * This is sparse: Only dimensions which have a different label than "undefined" are + * explicitly included. * <p> * Tensor addresses are ordered by increasing size primarily, and by the natural order of the elements in sorted * order secondarily. @@ -65,6 +66,14 @@ public final class TensorAddress implements Comparable<TensorAddress> { return TensorAddress.fromSorted(elements); } + /** Creates an empty address with a set of dimensions */ + public static TensorAddress emptyWithDimensions(Set<String> dimensions) { + List<Element> elements = new ArrayList<>(dimensions.size()); + for (String dimension : dimensions) + elements.add(new Element(dimension, Element.undefinedLabel)); + return TensorAddress.fromUnsorted(elements); + } + /** Returns an immutable list of the elements of this address in sorted order */ public List<Element> elements() { return elements; } @@ -84,14 +93,6 @@ public final class TensorAddress implements Comparable<TensorAddress> { return dimensions; } - /** Returns the label at the given dimension, or empty if this dimension is not present */ - public Optional<String> labelOfDimension(String dimension) { - for (TensorAddress.Element element : elements) - if (element.dimension().equals(dimension)) - return Optional.of(element.label()); - return Optional.empty(); - } - @Override public int compareTo(TensorAddress other) { int sizeComparison = Integer.compare(this.elements.size(), other.elements.size()); @@ -122,6 +123,7 @@ public final class TensorAddress implements Comparable<TensorAddress> { public String toString() { StringBuilder b = new StringBuilder("{"); for (TensorAddress.Element element : elements) { + //if (element.label() == Element.undefinedLabel) continue; b.append(element.toString()); b.append(","); } @@ -134,13 +136,18 @@ public final class TensorAddress implements Comparable<TensorAddress> { /** A tensor address element. Elements have the lexical order of the dimensions as natural order. */ public static class Element implements Comparable<Element> { + static final String undefinedLabel = "-"; + private final String dimension; private final String label; private final int hashCode; public Element(String dimension, String label) { this.dimension = dimension; - this.label = label; + if (label.equals(undefinedLabel)) + this.label = undefinedLabel; + else + this.label = label; this.hashCode = dimension.hashCode() + label.hashCode(); } @@ -168,7 +175,9 @@ public final class TensorAddress implements Comparable<TensorAddress> { @Override public String toString() { - return dimension + ":" + label; + StringBuilder b = new StringBuilder(); + b.append(dimension).append(":").append(label); + return b.toString(); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java new file mode 100644 index 00000000000..ceb003b1615 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java @@ -0,0 +1,30 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Takes the difference between two tensors, see {@link Tensor#subtract} + * + * @author bratseth + */ +class TensorDifference { + + private final Set<String> dimensions; + private final Map<TensorAddress, Double> cells = new HashMap<>(); + + public TensorDifference(Tensor a, Tensor b) { + this.dimensions = TensorOperations.combineDimensions(a, b); + cells.putAll(a.cells()); + for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) + cells.put(bCell.getKey(), a.cells().getOrDefault(bCell.getKey(), 0d) - bCell.getValue()); + } + + /** Returns the result of taking this sum */ + public Tensor result() { + return new MapTensor(dimensions, cells); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java new file mode 100644 index 00000000000..d15e5092476 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java @@ -0,0 +1,35 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Takes the max of each cell of two tensors, see {@link Tensor#max} + * + * @author bratseth + */ +class TensorMax { + + private final Set<String> dimensions; + private final Map<TensorAddress, Double> cells = new HashMap<>(); + + public TensorMax(Tensor a, Tensor b) { + dimensions = TensorOperations.combineDimensions(a, b); + cells.putAll(a.cells()); + for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { + Double aValue = a.cells().get(bCell.getKey()); + if (aValue == null) + cells.put(bCell.getKey(), bCell.getValue()); + else + cells.put(bCell.getKey(), Math.max(aValue, bCell.getValue())); + } + } + + /** Returns the result of taking this sum */ + public Tensor result() { + return new MapTensor(dimensions, cells); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java new file mode 100644 index 00000000000..e389dea3883 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java @@ -0,0 +1,33 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Takes the min of each cell of two tensors, see {@link Tensor#min} + * + * @author bratseth + */ +class TensorMin { + + private final Set<String> dimensions; + private final Map<TensorAddress, Double> cells = new HashMap<>(); + + public TensorMin(Tensor a, Tensor b) { + dimensions = TensorOperations.combineDimensions(a, b); + cells.putAll(a.cells()); + for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { + Double aValue = a.cells().get(bCell.getKey()); + if (aValue == null) + cells.put(bCell.getKey(), bCell.getValue()); + else + cells.put(bCell.getKey(), Math.min(aValue, bCell.getValue())); + } + } + + /** Returns the result of taking this sum */ + public Tensor result() { return new MapTensor(dimensions, cells); } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java new file mode 100644 index 00000000000..aca306b914c --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java @@ -0,0 +1,28 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor; + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +/** + * Functions on tensors + * + * @author bratseth + */ +class TensorOperations { + + /** + * A utility method which returns an ummutable set of the union of the dimensions + * of the two argument tensors. + * + * @return the combined dimensions as an unmodifiable set + */ + static Set<String> combineDimensions(Tensor a, Tensor b) { + ImmutableSet.Builder<String> setBuilder = new ImmutableSet.Builder<>(); + setBuilder.addAll(a.dimensions()); + setBuilder.addAll(b.dimensions()); + return setBuilder.build(); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java new file mode 100644 index 00000000000..221bd985380 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java @@ -0,0 +1,93 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor; + +import com.google.common.collect.ImmutableMap; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.ListIterator; +import java.util.Map; +import java.util.Set; + +/** + * Computes a <i>sparse tensor product</i>, see {@link Tensor#multiply} + * + * @author bratseth + */ +class TensorProduct { + + private final Set<String> dimensionsA, dimensionsB; + + private final Set<String> dimensions; + private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); + + public TensorProduct(Tensor a, Tensor b) { + dimensionsA = a.dimensions(); + dimensionsB = b.dimensions(); + + // Dimension product + dimensions = TensorOperations.combineDimensions(a, b); + + // Cell product (slow baseline implementation) + for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) { + for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { + TensorAddress combinedAddress = combine(aCell.getKey(), bCell.getKey()); + if (combinedAddress == null) continue; // not combinable + cells.put(combinedAddress, aCell.getValue() * bCell.getValue()); + } + } + } + + private TensorAddress combine(TensorAddress a, TensorAddress b) { + List<TensorAddress.Element> combined = new ArrayList<>(); + combined.addAll(dense(a, dimensionsA)); + combined.addAll(dense(b, dimensionsB)); + Collections.sort(combined); + TensorAddress.Element previous = null; + for (ListIterator<TensorAddress.Element> i = combined.listIterator(); i.hasNext(); ) { + TensorAddress.Element current = i.next(); + if (previous != null && previous.dimension().equals(current.dimension())) { // an overlapping dimension + if (previous.label().equals(current.label())) + i.remove(); // a match: remove the duplicate + else + return null; // no match: a combination isn't viable + } + previous = current; + } + return TensorAddress.fromSorted(sparse(combined)); + } + + /** + * Returns a set of tensor elements which contains an entry for each dimension including "undefined" values + * (which are not present in the sparse elements list). + */ + private List<TensorAddress.Element> dense(TensorAddress sparse, Set<String> dimensions) { + if (sparse.elements().size() == dimensions.size()) return sparse.elements(); + + List<TensorAddress.Element> dense = new ArrayList<>(sparse.elements()); + for (String dimension : dimensions) { + if ( ! sparse.hasDimension(dimension)) + dense.add(new TensorAddress.Element(dimension, TensorAddress.Element.undefinedLabel)); + } + return dense; + } + + /** + * Removes any "undefined" entries from the given elements. + */ + private List<TensorAddress.Element> sparse(List<TensorAddress.Element> dense) { + List<TensorAddress.Element> sparse = new ArrayList<>(); + for (TensorAddress.Element element : dense) { + if ( ! element.label().equals(TensorAddress.Element.undefinedLabel)) + sparse.add(element); + } + return sparse; + } + + /** Returns the result of taking this product */ + public Tensor result() { + return new MapTensor(dimensions, cells.build()); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java new file mode 100644 index 00000000000..85dfa289bd3 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java @@ -0,0 +1,29 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Takes the sum of two tensors, see {@link Tensor#add} + * + * @author bratseth + */ +class TensorSum { + + private final Set<String> dimensions; + private final Map<TensorAddress, Double> cells = new HashMap<>(); + + public TensorSum(Tensor a, Tensor b) { + dimensions = TensorOperations.combineDimensions(a, b); + cells.putAll(a.cells()); + for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { + cells.put(bCell.getKey(), a.cells().getOrDefault(bCell.getKey(), 0d) + bCell.getValue()); + } + } + + /** Returns the result of taking this sum */ + public Tensor result() { return new MapTensor(dimensions, cells); } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 31454e28baf..23cdc0e6051 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -1,7 +1,5 @@ package com.yahoo.tensor.functions; -import com.yahoo.tensor.Tensor; - /** * A composite tensor function is a tensor function which can be expressed (less tersely) * as a tree of primitive tensor functions. @@ -10,8 +8,4 @@ import com.yahoo.tensor.Tensor; */ public abstract class CompositeTensorFunction extends TensorFunction { - /** Evaluates this by first converting it to a primitive function */ - @Override - public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); } - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java new file mode 100644 index 00000000000..113247be3bb --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java @@ -0,0 +1,24 @@ +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.MapTensor; + +/** + * A function which returns a constant tensor. + * + * @author bratseth + */ +public class Constant extends PrimitiveTensorFunction { + + private final MapTensor constant; + + public Constant(String tensorString) { + this.constant = MapTensor.from(tensorString); + } + + @Override + public PrimitiveTensorFunction toPrimitive() { return this; } + + @Override + public String toString() { return constant.toString(); } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java deleted file mode 100644 index 0727579a331..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ /dev/null @@ -1,38 +0,0 @@ -package com.yahoo.tensor.functions; - -import com.yahoo.tensor.MapTensor; -import com.yahoo.tensor.Tensor; - -import java.util.Collections; -import java.util.List; - -/** - * A function which returns a constant tensor. - * - * @author bratseth - */ -public class ConstantTensor extends PrimitiveTensorFunction { - - private final Tensor constant; - - public ConstantTensor(String tensorString) { - this.constant = MapTensor.from(tensorString); - } - - public ConstantTensor(Tensor tensor) { - this.constant = tensor; - } - - @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } - - @Override - public PrimitiveTensorFunction toPrimitive() { return this; } - - @Override - public Tensor evaluate(EvaluationContext context) { return constant; } - - @Override - public String toString(ToStringContext context) { return constant.toString(); } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java deleted file mode 100644 index 24a4c61a58c..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.yahoo.tensor.functions; - -/** - * An evaluation context which is passed down to all nested functions during evaluation. - * The default implementation is empty as this library does not in itself have any need for a - * context. - * - * @author bratseth - */ -public interface EvaluationContext { - - static EvaluationContext empty() { return new EvaluationContext() {}; } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java deleted file mode 100644 index c0e5776bf48..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ /dev/null @@ -1,57 +0,0 @@ -package com.yahoo.tensor.functions; - -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.util.Collections; -import java.util.List; -import java.util.Objects; -import java.util.function.Function; - -/** - * An indexed tensor whose values are generated by a function - * - * @author bratseth - */ -public class Generate extends PrimitiveTensorFunction { - - private final TensorType type; - private final Function<List<Integer>, Double> generator; - - /** - * Creates a generated tensor - * - * @param type the type of the tensor - * @param generator the function generating values from a list of ints specifying the indexes of the - * tensor cell which will receive the value - * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound - */ - public Generate(TensorType type, Function<List<Integer>, Double> generator) { - Objects.requireNonNull(type, "The argument tensor type cannot be null"); - Objects.requireNonNull(generator, "The argument function cannot be null"); - validateType(type); - this.type = type; - this.generator = generator; - } - - private void validateType(TensorType type) { - for (TensorType.Dimension dimension : type.dimensions()) - if (dimension.type() != TensorType.Dimension.Type.indexedBound) - throw new IllegalArgumentException("A generated tensor can only have indexed bound dimensions"); - } - - @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } - - @Override - public PrimitiveTensorFunction toPrimitive() { return this; } - - @Override - public Tensor evaluate(EvaluationContext context) { - throw new UnsupportedOperationException("Not implemented"); // TODO - } - - @Override - public String toString(ToStringContext context) { return type + "(" + generator + ")"; } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 323da5906c3..4d945963fdf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -1,24 +1,9 @@ package com.yahoo.tensor.functions; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import com.yahoo.tensor.MapTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; import java.util.function.DoubleBinaryOperator; /** - * The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells - * given by the cross product of the cells of the given tensors, having as values the value produced by - * applying the given combinator function on the values from the two source cells. + * The join tensor function. * * @author bratseth */ @@ -28,9 +13,6 @@ public class Join extends PrimitiveTensorFunction { private final DoubleBinaryOperator combinator; public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) { - Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); - Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); - Objects.requireNonNull(combinator, "The combinator function cannot be null"); this.argumentA = argumentA; this.argumentB = argumentB; this.combinator = combinator; @@ -39,60 +21,15 @@ public class Join extends PrimitiveTensorFunction { public TensorFunction argumentA() { return argumentA; } public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } - - @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); } - + @Override public PrimitiveTensorFunction toPrimitive() { return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); } - - @Override - public String toString(ToStringContext context) { - return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")"; - } - - private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); - + @Override - public Tensor evaluate(EvaluationContext context) { - Tensor a = argumentA.evaluate(context); - Tensor b = argumentB.evaluate(context); - - // Dimension product - Set<String> dimensions = combineDimensions(a, b); - - // Cell product (slow baseline implementation) - ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); - for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) { - for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { - TensorAddress combinedAddress = combineAddresses(aCell.getKey(), bCell.getKey()); - if (combinedAddress == null) continue; // not combinable - cells.put(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue())); - } - } - - return new MapTensor(dimensions, cells.build()); + public String toString() { + return "join(" + argumentA.toString() + ", " + argumentB.toString() + ", lambda(a, b) (...))"; } - private Set<String> combineDimensions(Tensor a, Tensor b) { - ImmutableSet.Builder<String> setBuilder = new ImmutableSet.Builder<>(); - setBuilder.addAll(a.dimensions()); - setBuilder.addAll(b.dimensions()); - return setBuilder.build(); - } - - private TensorAddress combineAddresses(TensorAddress a, TensorAddress b) { - List<TensorAddress.Element> combined = new ArrayList<>(a.elements()); - for (TensorAddress.Element bElement : b.elements()) { - Optional<String> aLabel = a.labelOfDimension(bElement.dimension()); - if ( ! aLabel.isPresent()) - combined.add(bElement); - else if ( ! aLabel.get().equals(bElement.label())) - return null; // incompatible - } - return TensorAddress.fromUnsorted(combined); - } - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java deleted file mode 100644 index 4467b378b3f..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ /dev/null @@ -1,36 +0,0 @@ -package com.yahoo.tensor.functions; - -import java.util.Collections; -import java.util.List; - -/** - * @author bratseth - */ -public class L1Normalize extends CompositeTensorFunction { - - private final TensorFunction argument; - private final String dimension; - - public L1Normalize(TensorFunction argument, String dimension) { - this.argument = argument; - this.dimension = dimension; - } - - @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } - - @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); - // join(x, reduce(x, "avg", "dimension"), f(x,y) (x / y)) - return new Join(primitiveArgument, - new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension), - ScalarFunctions.divide()); - } - - @Override - public String toString(ToStringContext context) { - return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")"; - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java deleted file mode 100644 index 0e96b43bd22..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ /dev/null @@ -1,38 +0,0 @@ -package com.yahoo.tensor.functions; - -import java.util.Collections; -import java.util.List; - -/** - * @author bratseth - */ -public class L2Normalize extends CompositeTensorFunction { - - private final TensorFunction argument; - private final String dimension; - - public L2Normalize(TensorFunction argument, String dimension) { - this.argument = argument; - this.dimension = dimension; - } - - @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } - - @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); - return new Join(primitiveArgument, - new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()), - Reduce.Aggregator.sum, - dimension), - ScalarFunctions.sqrt()), - ScalarFunctions.divide()); - } - - @Override - public String toString(ToStringContext context) { - return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")"; - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 5db88953c64..22dd08504d7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -1,17 +1,10 @@ package com.yahoo.tensor.functions; -import com.google.common.collect.ImmutableMap; -import com.yahoo.tensor.MapTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; - -import java.util.Collections; -import java.util.List; -import java.util.Objects; +import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; /** - * The <i>map</i> tensor function produces a tensor where the given function is applied on each cell value. + * The join tensor function. * * @author bratseth */ @@ -21,8 +14,6 @@ public class Map extends PrimitiveTensorFunction { private final DoubleUnaryOperator mapper; public Map(TensorFunction argument, DoubleUnaryOperator mapper) { - Objects.requireNonNull(argument, "The argument tensor cannot be null"); - Objects.requireNonNull(mapper, "The argument function cannot be null"); this.argument = argument; this.mapper = mapper; } @@ -31,25 +22,13 @@ public class Map extends PrimitiveTensorFunction { public DoubleUnaryOperator mapper() { return mapper; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } - - @Override public PrimitiveTensorFunction toPrimitive() { return new Map(argument.toPrimitive(), mapper); } @Override - public Tensor evaluate(EvaluationContext context) { - Tensor argument = argument().evaluate(context); - ImmutableMap.Builder<TensorAddress, Double> mappedCells = new ImmutableMap.Builder<>(); - for (java.util.Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) - mappedCells.put(cell.getKey(), mapper.applyAsDouble(cell.getValue())); - return new MapTensor(argument.dimensions(), mappedCells.build()); - } - - @Override - public String toString(ToStringContext context) { - return "map(" + argument.toString(context) + ", " + mapper + ")"; + public String toString() { + return "map(" + argument.toString() + ", lambda(a) (...))"; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java deleted file mode 100644 index 4492ab083d4..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ /dev/null @@ -1,38 +0,0 @@ -package com.yahoo.tensor.functions; - -import com.google.common.collect.ImmutableList; - -import java.util.List; - -/** - * @author bratseth - */ -public class Matmul extends CompositeTensorFunction { - - private final TensorFunction argument1, argument2; - private final String dimension; - - public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) { - this.argument1 = argument1; - this.argument2 = argument2; - this.dimension = dimension; - } - - @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); } - - @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument1 = argument1.toPrimitive(); - TensorFunction primitiveArgument2 = argument2.toPrimitive(); - return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - dimension); - } - - @Override - public String toString(ToStringContext context) { - return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java index 91e58f4bf3b..9c0c9abaeb7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java @@ -1,7 +1,5 @@ package com.yahoo.tensor.functions; -import com.yahoo.tensor.Tensor; - /** * A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions. * All tensor implementations must implement all primitive tensor functions. @@ -10,5 +8,4 @@ import com.yahoo.tensor.Tensor; * @author bratseth */ public abstract class PrimitiveTensorFunction extends TensorFunction { - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java new file mode 100644 index 00000000000..09038a294ce --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java @@ -0,0 +1,27 @@ +package com.yahoo.tensor.functions; + +/** + * The product tensor function + * + * @author bratseth + */ +public class Product extends CompositeTensorFunction { + + private final TensorFunction argumentA, argumentB; + + public Product(TensorFunction argumentA, TensorFunction argumentB) { + this.argumentA = argumentA; + this.argumentB = argumentB; + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), (a, b) -> a * b); + } + + @Override + public String toString() { + return "product(" + argumentA.toString() + ", " + argumentB.toString() + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index ef18cb61b17..4b306d376a6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -1,246 +1,38 @@ package com.yahoo.tensor.functions; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.yahoo.tensor.MapTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; - -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; +import java.util.Optional; +import java.util.function.DoubleBinaryOperator; /** - * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions - * are collapsed to a single value using an aggregator function. + * The reduce tensor function. * * @author bratseth */ public class Reduce extends PrimitiveTensorFunction { - public enum Aggregator { avg, count, prod, sum, max, min; } - private final TensorFunction argument; - private final List<String> dimensions; - private final Aggregator aggregator; - - /** Creates a reduce function reducing aLL dimensions */ - public Reduce(TensorFunction argument, Aggregator aggregator) { - this(argument, aggregator, Collections.emptyList()); - } + private final String dimension; + private final DoubleBinaryOperator reductor; + private final Optional<DoubleBinaryOperator> postTransformation; - /** Creates a reduce function reducing a single dimension */ - public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) { - this(argument, aggregator, Collections.singletonList(dimension)); - } - - /** - * Creates a reduce function. - * - * @param argument the tensor to reduce - * @param aggregator the aggregator function to use - * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced, - * producing a dimensionless tensor (a scalar). - * @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor - */ - public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) { - Objects.requireNonNull(argument, "The argument tensor cannot be null"); - Objects.requireNonNull(aggregator, "The aggregator cannot be null"); - Objects.requireNonNull(dimensions, "The dimensions cannot be null"); + public Reduce(TensorFunction argument, String dimension, + DoubleBinaryOperator reductor, Optional<DoubleBinaryOperator> postTransformation) { this.argument = argument; - this.aggregator = aggregator; - this.dimensions = ImmutableList.copyOf(dimensions); + this.dimension = dimension; + this.reductor = reductor; + this.postTransformation = postTransformation; } public TensorFunction argument() { return argument; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } - - @Override public PrimitiveTensorFunction toPrimitive() { - return new Reduce(argument.toPrimitive(), aggregator, dimensions); - } - - @Override - public String toString(ToStringContext context) { - return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; - } - - private String commaSeparated(List<String> list) { - StringBuilder b = new StringBuilder(); - for (String element : list) - b.append(", ").append(element); - return b.toString(); + return new Reduce(argument.toPrimitive(), dimension, reductor, postTransformation); } @Override - public Tensor evaluate(EvaluationContext context) { - Tensor argument = this.argument.evaluate(context); - - if ( ! dimensions.isEmpty() && ! argument.dimensions().containsAll(dimensions)) - throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + - dimensions + ": Not all those dimensions are present in this tensor"); - - if (dimensions.isEmpty() || dimensions.size() == argument.dimensions().size()) - return reduceAll(argument); - - // Reduce dimensions - Set<String> reducedDimensions = new HashSet<>(argument.dimensions()); - reducedDimensions.removeAll(dimensions); - - // Reduce cells - Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); - for (Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) { - TensorAddress reducedAddress = reduceDimensions(cell.getKey(), reducedDimensions); - aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); - aggregatingCells.get(reducedAddress).aggregate(cell.getValue()); - } - ImmutableMap.Builder<TensorAddress, Double> reducedCells = new ImmutableMap.Builder<>(); - for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet()) - reducedCells.put(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); - return new MapTensor(reducedDimensions, reducedCells.build()); - } - - private TensorAddress reduceDimensions(TensorAddress address, Set<String> reducedDimensions) { - return TensorAddress.fromSorted(address.elements().stream() - .filter(e -> reducedDimensions.contains(e.dimension())) - .collect(Collectors.toList())); - } - - private Tensor reduceAll(Tensor argument) { - ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); - for (Double cellValue : argument.cells().values()) - valueAggregator.aggregate(cellValue); - return new MapTensor(ImmutableMap.of(TensorAddress.empty, valueAggregator.aggregatedValue())); - } - - private static abstract class ValueAggregator { - - public static ValueAggregator ofType(Aggregator aggregator) { - switch (aggregator) { - case avg : return new AvgAggregator(); - case count : return new CountAggregator(); - case prod : return new ProdAggregator(); - case sum : return new SumAggregator(); - case max : return new MaxAggregator(); - case min : return new MinAggregator(); - default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented"); - } - - } - - /** Add a new value to those aggregated by this */ - public abstract void aggregate(double value); - - /** Returns the value aggregated by this */ - public abstract double aggregatedValue(); - - } - - private static class AvgAggregator extends ValueAggregator { - - private int valueCount = 0; - private double valueSum = 0.0; - - @Override - public void aggregate(double value) { - valueCount++; - valueSum+= value; - } - - @Override - public double aggregatedValue() { - return valueSum / valueCount; - } - - } - - private static class CountAggregator extends ValueAggregator { - - private int valueCount = 0; - - @Override - public void aggregate(double value) { - valueCount++; - } - - @Override - public double aggregatedValue() { - return valueCount; - } - - } - - private static class ProdAggregator extends ValueAggregator { - - private double valueProd = 1.0; - - @Override - public void aggregate(double value) { - valueProd *= value; - } - - @Override - public double aggregatedValue() { - return valueProd; - } - - } - - private static class SumAggregator extends ValueAggregator { - - private double valueSum = 0.0; - - @Override - public void aggregate(double value) { - valueSum += value; - } - - @Override - public double aggregatedValue() { - return valueSum; - } - - } - - private static class MaxAggregator extends ValueAggregator { - - private double maxValue = Double.MIN_VALUE; - - @Override - public void aggregate(double value) { - if (value > maxValue) - maxValue = value; - } - - @Override - public double aggregatedValue() { - return maxValue; - } - - } - - private static class MinAggregator extends ValueAggregator { - - private double minValue = Double.MAX_VALUE; - - @Override - public void aggregate(double value) { - if (value < minValue) - minValue = value; - } - - @Override - public double aggregatedValue() { - return minValue; - } - + public String toString() { + return "reduce(" + argument.toString() + ", " + dimension + ", lambda(a, b) (...), lambda(a, b) (...))"; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java deleted file mode 100644 index 05af86c33e8..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ /dev/null @@ -1,100 +0,0 @@ -package com.yahoo.tensor.functions; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.yahoo.tensor.MapTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names. - * - * @author bratseth - */ -public class Rename extends PrimitiveTensorFunction { - - private final TensorFunction argument; - private final List<String> fromDimensions; - private final List<String> toDimensions; - - public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) { - Objects.requireNonNull(argument, "The argument tensor cannot be null"); - Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null"); - Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null"); - if (fromDimensions.size() < 1) - throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension"); - if (fromDimensions.size() != toDimensions.size()) - throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " + - fromDimensions.size() + " and " + toDimensions.size()); - this.argument = argument; - this.fromDimensions = ImmutableList.copyOf(fromDimensions); - this.toDimensions = ImmutableList.copyOf(toDimensions); - } - - @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } - - @Override - public PrimitiveTensorFunction toPrimitive() { return this; } - - @Override - public Tensor evaluate(EvaluationContext context) { - Tensor tensor = argument.evaluate(context); - Map<String, String> fromToMap = fromToMap(); - Set<String> renamedDimensions = tensor.dimensions().stream() - .map((d) -> fromToMap.getOrDefault(d, d)) - .collect(Collectors.toSet()); - - ImmutableMap.Builder<TensorAddress, Double> renamedCells = new ImmutableMap.Builder<>(); - for (Map.Entry<TensorAddress, Double> cell : tensor.cells().entrySet()) { - TensorAddress renamedAddress = rename(cell.getKey(), fromToMap); - renamedCells.put(renamedAddress, cell.getValue()); - } - return new MapTensor(renamedDimensions, renamedCells.build()); - } - - private TensorAddress rename(TensorAddress address, Map<String, String> fromToMap) { - List<TensorAddress.Element> renamedElements = new ArrayList<>(); - for (TensorAddress.Element element : address.elements()) { - String toDimension = fromToMap.get(element.dimension()); - if (toDimension == null) - renamedElements.add(element); - else - renamedElements.add(new TensorAddress.Element(toDimension, element.label())); - } - return TensorAddress.fromUnsorted(renamedElements); - } - - @Override - public String toString(ToStringContext context) { - return "rename(" + argument.toString(context) + ", " + - toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; - } - - private Map<String, String> fromToMap() { - Map<String, String> map = new HashMap<>(); - for (int i = 0; i < fromDimensions.size(); i++) - map.put(fromDimensions.get(i), toDimensions.get(i)); - return map; - } - - private String toVectorString(List<String> elements) { - if (elements.size() == 1) - return elements.get(0); - StringBuilder b = new StringBuilder("["); - for (String element : elements) - b.append(element).append(", "); - b.setLength(b.length() - 2); - return b.toString(); - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java deleted file mode 100644 index 9438c6c533a..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ /dev/null @@ -1,81 +0,0 @@ -package com.yahoo.tensor.functions; - -import java.util.function.DoubleBinaryOperator; -import java.util.function.DoubleUnaryOperator; - -/** - * Factory of scalar Java functions. - * The purpose of this is to embellish anonymous functions with a runtime type - * such that they can be inspected and will return a parseable toString. - * - * @author bratseth - */ -public class ScalarFunctions { - - public static DoubleBinaryOperator add() { return new Addition(); } - public static DoubleBinaryOperator multiply() { return new Multiplication(); } - public static DoubleBinaryOperator divide() { return new Division(); } - public static DoubleUnaryOperator square() { return new Square(); } - public static DoubleUnaryOperator sqrt() { return new Sqrt(); } - public static DoubleUnaryOperator exp() { return new Exponent(); } - - public static class Addition implements DoubleBinaryOperator { - - @Override - public double applyAsDouble(double left, double right) { return left + right; } - - @Override - public String toString() { return "f(a,b)(a + b)"; } - - } - - public static class Multiplication implements DoubleBinaryOperator { - - @Override - public double applyAsDouble(double left, double right) { return left * right; } - - @Override - public String toString() { return "f(a,b)(a * b)"; } - - } - - public static class Division implements DoubleBinaryOperator { - - @Override - public double applyAsDouble(double left, double right) { return left / right; } - - @Override - public String toString() { return "f(a,b)(a / b)"; } - } - - public static class Square implements DoubleUnaryOperator { - - @Override - public double applyAsDouble(double operand) { return operand * operand; } - - @Override - public String toString() { return "f(a)(a * a)"; } - - } - - public static class Sqrt implements DoubleUnaryOperator { - - @Override - public double applyAsDouble(double operand) { return Math.sqrt(operand); } - - @Override - public String toString() { return "f(a)(sqrt(a))"; } - - } - - public static class Exponent implements DoubleUnaryOperator { - - @Override - public double applyAsDouble(double operand) { return Math.exp(operand); } - - @Override - public String toString() { return "f(a)(exp(a))"; } - - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java deleted file mode 100644 index b05b8172b42..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ /dev/null @@ -1,37 +0,0 @@ -package com.yahoo.tensor.functions; - -import java.util.Collections; -import java.util.List; - -/** - * @author bratseth - */ -public class Softmax extends CompositeTensorFunction { - - private final TensorFunction argument; - private final String dimension; - - public Softmax(TensorFunction argument, String dimension) { - this.argument = argument; - this.dimension = dimension; - } - - @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } - - @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); - return new Join(new Map(primitiveArgument, ScalarFunctions.exp()), - new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()), - Reduce.Aggregator.sum, - dimension), - ScalarFunctions.divide()); - } - - @Override - public String toString(ToStringContext context) { - return "softmax(" + argument.toString(context) + ", " + dimension + ")"; - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index a717292632e..95fca95a042 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -1,9 +1,5 @@ package com.yahoo.tensor.functions; -import com.yahoo.tensor.Tensor; - -import java.util.List; - /** * A representation of a tensor function which is able to be translated to a set of primitive * tensor functions if necessary. @@ -13,9 +9,6 @@ import java.util.List; */ public abstract class TensorFunction { - /** Returns the function arguments of this node in the order they are applied */ - public abstract List<TensorFunction> functionArguments(); - /** * Translate this function - and all of its arguments recursively - * to a tree of primitive functions only. @@ -24,24 +17,4 @@ public abstract class TensorFunction { */ public abstract PrimitiveTensorFunction toPrimitive(); - /** - * Evaluates this tensor. - * - * @param context a context which must be passed to all nexted functions when evaluating - */ - public abstract Tensor evaluate(EvaluationContext context); - - /** Evaluate with no context */ - public final Tensor evaluate() { return evaluate(EvaluationContext.empty()); } - - /** - * Return a string representation of this context. - * - * @param context a context which must be passed to all nexted functions when requesting the string value - */ - public abstract String toString(ToStringContext context); - - @Override - public final String toString() { return toString(ToStringContext.empty()); } - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java deleted file mode 100644 index b71229703d2..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.yahoo.tensor.functions; - -/** - * A context which is passed down to all nested functions when returning a string representation. - * The default implementation is empty as this library does not in itself have any need for a - * context. - * - * @author bratseth - */ -public interface ToStringContext { - - static ToStringContext empty() { return new ToStringContext() {}; } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java deleted file mode 100644 index 1988c1d2390..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java +++ /dev/null @@ -1,45 +0,0 @@ -package com.yahoo.tensor.functions; - -import com.google.common.collect.ImmutableList; - -import java.util.List; - -/** - * @author bratseth - */ -public class XwPlusB extends CompositeTensorFunction { - - private final TensorFunction x, w, b; - private final String dimension; - - public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) { - this.x = x; - this.w = w; - this.b = b; - this.dimension = dimension; - } - - @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); } - - @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveX = x.toPrimitive(); - TensorFunction primitiveW = w.toPrimitive(); - TensorFunction primitiveB = b.toPrimitive(); - return new Join(new Reduce(new Join(primitiveX, primitiveW, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - dimension), - primitiveB, - ScalarFunctions.add()); - } - - @Override - public String toString(ToStringContext context) { - return "xw_plus_b(" + x.toString(context) + ", " + - w.toString(context) + ", " + - b.toString(context) + ", " + - dimension + ")"; - } - -} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java index af2260e2f20..889b2851a08 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java @@ -42,7 +42,7 @@ public class MapTensorBuilderTestCase { Tensor tensor = new MapTensorBuilder().dimension("y").dimension("z"). cell().label("x", "0").value(1).build(); assertEquals(Sets.newHashSet("x", "y", "z"), tensor.dimensions()); - assertEquals("tensor(x{},y{},z{}):{{x:0}:1.0}", tensor.toString()); + assertEquals("( {{y:-,z:-}:1.0} * {{x:0}:1.0} )", tensor.toString()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java index 0372f328811..13ea0e95dc8 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java @@ -33,7 +33,7 @@ public class MapTensorTestCase { fail("Expected parse error"); } catch (IllegalArgumentException expected) { - assertEquals("Excepted a number or a string starting by { or tensor(, got '--'", expected.getMessage()); + assertEquals("Excepted a string starting by { or (, got '--'", expected.getMessage()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java deleted file mode 100644 index e403bb56d56..00000000000 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ /dev/null @@ -1,28 +0,0 @@ -package com.yahoo.tensor; - -import com.google.common.collect.ImmutableList; -import org.junit.Test; -import static org.junit.Assert.assertEquals; - -/** - * Tests functionality on Tensor - * - * @author bratseth - */ -public class TensorTestCase { - - /** This is mostly tested in searchlib - spot checking here */ - @Test - public void testTensorComputation() { - MapTensor tensor1 = MapTensor.from("{ {x:1}:3, {x:2}:7 }"); - MapTensor tensor2 = MapTensor.from("{ {y:1}:5 }"); - assertEquals(MapTensor.from("{ {x:1,y:1}:15, {x:2,y:1}:35 }"), tensor1.multiply(tensor2)); - assertEquals(MapTensor.from("{ {x:1,y:1}:12, {x:2,y:1}:28 }"), tensor1.join(tensor2, (a, b) -> a * b - a )); - assertEquals(MapTensor.from("{ {x:1,y:1}:0, {x:2,y:1}:1 }"), tensor1.larger(tensor2)); - assertEquals(MapTensor.from("{ {y:1}:50.0 }"), tensor1.matmul(tensor2, "x")); - assertEquals(MapTensor.from("{ {z:1}:3, {z:2}:7 }"), tensor1.rename("x", "z")); - assertEquals(MapTensor.from("{ {y:1,x:1}:8, {y:2,x:1}:12 }"), tensor1.add(tensor2).rename(ImmutableList.of("x", "y"), - ImmutableList.of("y", "x"))); - } - -} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java index cc9328f7274..501397e89bc 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java @@ -12,8 +12,8 @@ public class TensorFunctionTestCase { @Test public void testTranslation() { - assertTranslated("join({{x:1}:1.0}, reduce({{x:1}:1.0}, sum, x), f(a,b)(a / b))", - new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x")); + assertTranslated("join({{x:1}:1.0}, {{x:2}:1.0}, lambda(a, b) (...))", + new Product(new Constant("{{x:1}:1.0}"), new Constant("{{x:2}:1.0}"))); } private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index c3a5e24afc2..8580868dfdf 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -52,14 +52,14 @@ public class SparseBinaryFormatTestCase { @Test public void testSerializationOfTensorsWithSparseTensorAddresses() { assertSerialization("{{x:0}:2.0, {}:3.0}", Sets.newHashSet("x")); - assertSerialization("tensor(x{},y{}):{{x:0}:2.0}", Sets.newHashSet("x", "y")); - assertSerialization("tensor(x{},y{}):{{x:0}:2.0, {}:3.0}", Sets.newHashSet("x", "y")); - assertSerialization("tensor(x{},y{}):{{x:0}:2.0,{x:1}:3.0}", Sets.newHashSet("x", "y")); - assertSerialization("tensor(x{},y{},z{}):{{x:0,y:0}:2.0}", Sets.newHashSet("x", "y", "z")); - assertSerialization("tensor(x{},y{},z{}):{{x:0,y:0}:2.0,{x:0,y:1}:3.0}", Sets.newHashSet("x", "y", "z")); - assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0}:2.0}", Sets.newHashSet("x", "y", "z")); - assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0}:2.0,{y:1,x:0}:3.0}", Sets.newHashSet("x", "y", "z")); - assertSerialization("tensor(x{},y{},z{}):{{}:2.0,{x:0}:3.0,{x:0,y:0}:5.0}", Sets.newHashSet("x", "y", "z")); + assertSerialization("({{y:-}:1} * {{x:0}:2.0})", Sets.newHashSet("x", "y")); + assertSerialization("({{y:-}:1} * {{x:0}:2.0, {}:3.0})", Sets.newHashSet("x", "y")); + assertSerialization("({{y:-}:1} * {{x:0}:2.0,{x:1}:3.0})", Sets.newHashSet("x", "y")); + assertSerialization("({{z:-}:1} * {{x:0,y:0}:2.0})", Sets.newHashSet("x", "y", "z")); + assertSerialization("({{z:-}:1} * {{x:0,y:0}:2.0,{x:0,y:1}:3.0})", Sets.newHashSet("x", "y", "z")); + assertSerialization("({{z:-}:1} * {{y:0,x:0}:2.0})", Sets.newHashSet("x", "y", "z")); + assertSerialization("({{z:-}:1} * {{y:0,x:0}:2.0,{y:1,x:0}:3.0})", Sets.newHashSet("x", "y", "z")); + assertSerialization("({{z:-}:1} * {{}:2.0,{x:0}:3.0,{x:0,y:0}:5.0})", Sets.newHashSet("x", "y", "z")); } @Test |