summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java6
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java70
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java122
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java111
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java59
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java65
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java3
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj375
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java175
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java359
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java27
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java8
-rw-r--r--vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java57
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java188
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java31
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java30
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java35
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java28
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java93
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java29
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java57
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java73
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java36
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java29
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java27
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java236
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java100
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java81
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java37
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java27
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java45
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java28
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java16
56 files changed, 1905 insertions, 1091 deletions
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 206ab8e30f0..64bb538eab5 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1049,7 +1049,7 @@ public class JsonReaderTestCase {
@Test
public void testParsingOfTensorWithDimensions() {
- assertTensorField("( {{x:-,y:-}:1.0} * {} )",
+ assertTensorField("tensor(x{},y{}):{}",
createPutWithTensor("{ "
+ " \"dimensions\": [\"x\",\"y\"] "
+ "}"));
@@ -1101,7 +1101,7 @@ public class JsonReaderTestCase {
@Test
public void testParsingOfTensorWithDimensionsAndCells() {
- assertTensorField("( {{z:-}:1.0} * {{x:a,y:b}:2.0,{x:c}:3.0} )",
+ assertTensorField("tensor(x{},y{},z{}):{{x:a,y:b}:2.0,{x:c}:3.0}",
createPutWithTensor("{ "
+ " \"dimensions\": [\"x\",\"y\",\"z\"], "
+ " \"cells\": [ "
@@ -1115,7 +1115,7 @@ public class JsonReaderTestCase {
@Test
public void testParsingOfTensorWithDimensionsAndCellsInDifferentJsonOrder() {
- assertTensorField("( {{z:-}:1.0} * {{x:a,y:b}:2.0,{x:c}:3.0} )",
+ assertTensorField("tensor(x{},y{},z{}):{{x:a,y:b}:2.0,{x:c}:3.0}",
createPutWithTensor("{ "
+ " \"cells\": [ "
+ " { \"address\": { \"x\": \"a\", \"y\": \"b\" }, "
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java
index ba06843f178..252d40b7291 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java
@@ -69,9 +69,9 @@ public class RestApiTest {
// POST new nodes
assertResponse(new Request("http://localhost:8080/nodes/v2/node",
("[" + asNodeJson("host8.yahoo.com", "default") + "," +
- asNodeJson("host9.yahoo.com", "large-variant") + "," +
- asHostJson("parent2.yahoo.com", "large-variant") + "," +
- asDockerNodeJson("host11.yahoo.com", "parent.host.yahoo.com") + "]").
+ asNodeJson("host9.yahoo.com", "large-variant") + "," +
+ asHostJson("parent2.yahoo.com", "large-variant") + "," +
+ asDockerNodeJson("host11.yahoo.com", "parent.host.yahoo.com") + "]").
getBytes(StandardCharsets.UTF_8),
Request.Method.POST),
"{\"message\":\"Added 4 nodes to the provisioned state\"}");
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 0dff0414ac2..620c6fad0b4 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.tensor.functions.EvaluationContext;
import java.util.Set;
@@ -10,7 +11,7 @@ import java.util.Set;
*
* @author bratseth
*/
-public abstract class Context {
+public abstract class Context implements EvaluationContext {
/**
* <p>Returns the value of a simple variable name.</p>
@@ -41,7 +42,7 @@ public abstract class Context {
* "main" (or only) value.
*/
public Value get(String name, Arguments arguments,String output) {
- if (arguments!=null && arguments.expressions().size()>0)
+ if (arguments!=null && arguments.expressions().size() > 0)
throw new UnsupportedOperationException(this + " does not support structured ranking expression variables, attempted to reference '" +
name + arguments + "'");
if (output==null)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index 2bae382d5bd..f8dcd8a6127 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -39,8 +39,8 @@ public abstract class DoubleCompatibleValue extends Value {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
- return operator.evaluate(asDouble(), value.asDouble());
+ public Value compare(TruthOperator operator, Value value) {
+ return new BooleanValue(operator.evaluate(asDouble(), value.asDouble()));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
index 028dad16d21..0e0d793bfd1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
@@ -98,16 +98,6 @@ public final class DoubleValue extends DoubleCompatibleValue {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
- try {
- return operator.evaluate(this.value, value.asDouble());
- }
- catch (UnsupportedOperationException e) {
- throw unsupported("comparison",value);
- }
- }
-
- @Override
public Value function(Function function, Value value) {
// use the tensor implementation of max and min if the argument is a tensor
if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
index 9ee9a1f7a71..2dffe2a1100 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
@@ -34,11 +34,9 @@ public class MapContext extends Context {
* Creates a map context from a map.
* The ownership of the map is transferred to this - it cannot be further modified by the caller.
* All the Values of the map will be frozen.
- *
- * @since 5.1.5
*/
public MapContext(Map<String,Value> bindings) {
- this.bindings=bindings;
+ this.bindings = bindings;
for (Value boundValue : bindings.values())
boundValue.freeze();
}
@@ -67,6 +65,9 @@ public class MapContext extends Context {
if (frozen) return bindings;
return Collections.unmodifiableMap(bindings);
}
+
+ /** Returns a new, modifiable context containing all the bindings of this */
+ public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); }
/** Returns an unmodifiable map of the names of this */
public @Override Set<String> names() {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index 379b5755c7b..eb997ab818a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -68,10 +68,10 @@ public class StringValue extends Value {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
+ public Value compare(TruthOperator operator, Value value) {
if (operator.equals(TruthOperator.EQUAL))
- return this.equals(value);
- throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '='");
+ return new BooleanValue(this.equals(value));
+ throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='");
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 12bede95aae..b1f4a7b20ca 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.TensorType;
+import java.util.Collections;
import java.util.Optional;
/**
@@ -17,7 +18,7 @@ import java.util.Optional;
*
* @author bratseth
*/
- @Beta
+@Beta
public class TensorValue extends Value {
/** The tensor value of this */
@@ -53,7 +54,7 @@ public class TensorValue extends Value {
@Override
public Value negate() {
- return new TensorValue(value.apply((Double value) -> -value));
+ return new TensorValue(value.map((value) -> -value));
}
@Override
@@ -61,7 +62,7 @@ public class TensorValue extends Value {
if (argument instanceof TensorValue)
return new TensorValue(value.add(((TensorValue)argument).value));
else
- return new TensorValue(value.apply((Double value) -> value + argument.asDouble()));
+ return new TensorValue(value.map((value) -> value + argument.asDouble()));
}
@Override
@@ -69,7 +70,7 @@ public class TensorValue extends Value {
if (argument instanceof TensorValue)
return new TensorValue(value.subtract(((TensorValue) argument).value));
else
- return new TensorValue(value.apply((Double value) -> value - argument.asDouble()));
+ return new TensorValue(value.map((value) -> value - argument.asDouble()));
}
@Override
@@ -77,35 +78,15 @@ public class TensorValue extends Value {
if (argument instanceof TensorValue)
return new TensorValue(value.multiply(((TensorValue) argument).value));
else
- return new TensorValue(value.apply((Double value) -> value * argument.asDouble()));
+ return new TensorValue(value.map((value) -> value * argument.asDouble()));
}
@Override
public Value divide(Value argument) {
if (argument instanceof TensorValue)
- throw new UnsupportedOperationException("Two tensors cannot be divided");
+ return new TensorValue(value.divide(((TensorValue) argument).value));
else
- return new TensorValue(value.apply((Double value) -> value / argument.asDouble()));
- }
-
- public Value match(Value argument) {
- return new TensorValue(value.match(asTensor(argument, "match")));
- }
-
- public Value min(Value argument) {
- return new TensorValue(value.min(asTensor(argument, "min")));
- }
-
- public Value max(Value argument) {
- return new TensorValue(value.max(asTensor(argument, "max")));
- }
-
- public Value sum(String dimension) {
- return new TensorValue(value.sum(dimension));
- }
-
- public Value sum() {
- return new DoubleValue(value.sum());
+ return new TensorValue(value.map((value) -> value / argument.asDouble()));
}
private Tensor asTensor(Value value, String operationName) {
@@ -122,18 +103,37 @@ public class TensorValue extends Value {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
- throw new UnsupportedOperationException("A tensor cannot be compared with any value");
+ public Value compare(TruthOperator operator, Value argument) {
+ return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString())));
+ }
+
+ private Tensor compareTensor(TruthOperator operator, Tensor argument) {
+ switch (operator) {
+ case LARGER: return value.larger(argument);
+ case LARGEREQUAL: return value.largerOrEqual(argument);
+ case SMALLER: return value.smaller(argument);
+ case SMALLEREQUAL: return value.smallerOrEqual(argument);
+ case EQUAL: return value.equal(argument);
+ case NOTEQUAL: return value.notEqual(argument);
+ default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator);
+ }
}
@Override
- public Value function(Function function, Value argument) {
- if (function.equals(Function.min) && argument instanceof TensorValue)
- return min(argument);
- else if (function.equals(Function.max) && argument instanceof TensorValue)
- return max(argument);
+ public Value function(Function function, Value arg) {
+ if (arg instanceof TensorValue)
+ return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString())));
else
- return new TensorValue(value.apply((Double value) -> function.evaluate(value, argument.asDouble())));
+ return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble())));
+ }
+
+ private Tensor functionOnTensor(Function function, Tensor argument) {
+ switch (function) {
+ case min: return value.min(argument);
+ case max: return value.max(argument);
+ case atan2: return value.atan2(argument);
+ default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
+ }
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index e5680edc68a..8ce18265231 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -42,7 +42,7 @@ public abstract class Value {
public abstract Value divide(Value value);
/** Perform the comparison specified by the operator between this value and the given value */
- public abstract boolean compare(TruthOperator operator,Value value);
+ public abstract Value compare(TruthOperator operator, Value value);
/** Perform the given binary function on this value and the given value */
public abstract Value function(Function function,Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
index 882d16ebc1c..af05acb365a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
@@ -8,10 +8,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import java.util.*;
/**
- * A node which returns true or false depending on the outcome of a comparison.
+ * A node which returns the outcome of a comparison.
*
* @author bratseth
- * @since 5.1.21
*/
public class ComparisonNode extends BooleanNode {
@@ -48,9 +47,9 @@ public class ComparisonNode extends BooleanNode {
@Override
public Value evaluate(Context context) {
- Value leftValue=leftCondition.evaluate(context);
- Value rightValue=rightCondition.evaluate(context);
- return new BooleanValue(leftValue.compare(operator,rightValue));
+ Value leftValue = leftCondition.evaluate(context);
+ Value rightValue = rightCondition.evaluate(context);
+ return leftValue.compare(operator,rightValue);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
index 675ce758faa..19b1a83ed99 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
@@ -12,31 +12,38 @@ import static java.lang.Math.*;
*/
public enum Function implements Serializable {
- cosh { public double evaluate(double x, double y) { return cosh(x); } },
- sinh { public double evaluate(double x, double y) { return sinh(x); } },
- tanh { public double evaluate(double x, double y) { return tanh(x); } },
- cos { public double evaluate(double x, double y) { return cos(x); } },
- sin { public double evaluate(double x, double y) { return sin(x); } },
- tan { public double evaluate(double x, double y) { return tan(x); } },
+ abs { public double evaluate(double x, double y) { return abs(x); } },
acos { public double evaluate(double x, double y) { return acos(x); } },
asin { public double evaluate(double x, double y) { return asin(x); } },
atan { public double evaluate(double x, double y) { return atan(x); } },
- exp { public double evaluate(double x, double y) { return exp(x); } },
- log10 { public double evaluate(double x, double y) { return log10(x); } },
- log { public double evaluate(double x, double y) { return log(x); } },
- sqrt { public double evaluate(double x, double y) { return sqrt(x); } },
ceil { public double evaluate(double x, double y) { return ceil(x); } },
+ cos { public double evaluate(double x, double y) { return cos(x); } },
+ cosh { public double evaluate(double x, double y) { return cosh(x); } },
+ elu { public double evaluate(double x, double y) { return x<0 ? exp(x)-1 : x; } },
+ exp { public double evaluate(double x, double y) { return exp(x); } },
fabs { public double evaluate(double x, double y) { return abs(x); } },
floor { public double evaluate(double x, double y) { return floor(x); } },
isNan { public double evaluate(double x, double y) { return Double.isNaN(x) ? 1.0 : 0.0; } },
+ log { public double evaluate(double x, double y) { return log(x); } },
+ log10 { public double evaluate(double x, double y) { return log10(x); } },
relu { public double evaluate(double x, double y) { return max(x,0); } },
+ round { public double evaluate(double x, double y) { return round(x); } },
sigmoid { public double evaluate(double x, double y) { return 1.0 / (1.0 + exp(-1.0 * x)); } },
+ sign { public double evaluate(double x, double y) { return x >= 0 ? 1 : -1; } },
+ sin { public double evaluate(double x, double y) { return sin(x); } },
+ sinh { public double evaluate(double x, double y) { return sinh(x); } },
+ square { public double evaluate(double x, double y) { return x*x; } },
+ sqrt { public double evaluate(double x, double y) { return sqrt(x); } },
+ tan { public double evaluate(double x, double y) { return tan(x); } },
+ tanh { public double evaluate(double x, double y) { return tanh(x); } },
+
atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } },
- pow(2) { public double evaluate(double x, double y) { return pow(x,y); } },
- ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } },
fmod(2) { public double evaluate(double x, double y) { return IEEEremainder(x,y); } },
+ ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } },
+ max(2) { public double evaluate(double x, double y) { return max(x,y); } },
min(2) { public double evaluate(double x, double y) { return min(x,y); } },
- max(2) { public double evaluate(double x, double y) { return max(x,y); } };
+ mod(2) { public double evaluate(double x, double y) { return x % y; } },
+ pow(2) { public double evaluate(double x, double y) { return pow(x,y); } };
private final int arity;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
new file mode 100644
index 00000000000..7b48288598d
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -0,0 +1,122 @@
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * A free, parametrized function
+ *
+ * @author bratseth
+ */
+public class LambdaFunctionNode extends CompositeNode {
+
+ private final ImmutableList<String> arguments;
+ private final ExpressionNode functionExpression;
+
+ public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) {
+ // TODO: Verify that the function only accesses the arguments in mapperVariables
+ this.arguments = ImmutableList.copyOf(arguments);
+ this.functionExpression = functionExpression;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return Collections.singletonList(functionExpression);
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if ( children.size() != 1)
+ throw new IllegalArgumentException("A lambda function must have a single child expression");
+ return new LambdaFunctionNode(arguments, children.get(0));
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return ("f(" + commaSeparated(arguments) + ")(" + functionExpression.toString(context, path, this)) + ")";
+ }
+
+ private String commaSeparated(List<String> list) {
+ StringBuilder b = new StringBuilder();
+ for (String element : list)
+ b.append(element).append(",");
+ if (b.length() > 0)
+ b.setLength(b.length()-1);
+ return b.toString();
+ }
+
+ /** Evaluate this in a context which must have the arguments bound */
+ @Override
+ public Value evaluate(Context context) {
+ return functionExpression.evaluate(context);
+ }
+
+ /**
+ * Returns this as a double unary operator
+ *
+ * @throws IllegalStateException if this has more than one argument
+ */
+ public DoubleUnaryOperator asDoubleUnaryOperator() {
+ if (arguments.size() > 1)
+ throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " +
+ "Must have at most one argument " + " but has " + arguments);
+ return new DoubleUnaryLambda();
+ }
+
+ /**
+ * Returns this as a double binary operator
+ *
+ * @throws IllegalStateException if this has more than two arguments
+ */
+ public DoubleBinaryOperator asDoubleBinaryOperator() {
+ if (arguments.size() > 2)
+ throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " +
+ "Must have at most two argument " + " but has " + arguments);
+ return new DoubleBinaryLambda();
+ }
+
+ private class DoubleUnaryLambda implements DoubleUnaryOperator {
+
+ @Override
+ public double applyAsDouble(double operand) {
+ MapContext context = new MapContext();
+ if (arguments.size() > 0)
+ context.put(arguments.get(0), operand);
+ return evaluate(context).asDouble();
+ }
+
+ @Override
+ public String toString() {
+ return LambdaFunctionNode.this.toString();
+ }
+
+ }
+
+ private class DoubleBinaryLambda implements DoubleBinaryOperator {
+
+ @Override
+ public double applyAsDouble(double left, double right) {
+ MapContext context = new MapContext();
+ if (arguments.size() > 0)
+ context.put(arguments.get(0), left);
+ if (arguments.size() > 1)
+ context.put(arguments.get(1), right);
+ return evaluate(context).asDouble();
+ }
+
+ @Override
+ public String toString() {
+ return LambdaFunctionNode.this.toString();
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
new file mode 100644
index 00000000000..26d3f1dcc0e
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -0,0 +1,111 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.EvaluationContext;
+import com.yahoo.tensor.functions.PrimitiveTensorFunction;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.functions.ToStringContext;
+
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * A node which performs a tensor function
+ *
+ * @author bratseth
+ */
+ @Beta
+public class TensorFunctionNode extends CompositeNode {
+
+ private final TensorFunction function;
+
+ public TensorFunctionNode(TensorFunction function) {
+ this.function = function;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return function.functionArguments().stream()
+ .map(f -> ((TensorFunctionExpressionNode)f).expression)
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ throw new UnsupportedOperationException("Not implemented");
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ // Serialize as primitive
+ return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this));
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ return new TensorValue(function.evaluate(context));
+ }
+
+ public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
+ return new TensorFunctionExpressionNode(node);
+ }
+
+ /**
+ * A tensor function implemented by an expression.
+ * This allows us to pass expressions as tensor function arguments.
+ */
+ public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction {
+
+ /** An expression which produces a tensor */
+ private final ExpressionNode expression;
+
+ public TensorFunctionExpressionNode(ExpressionNode expression) {
+ this.expression = expression;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() { return this; }
+
+ @Override
+ public Tensor evaluate(EvaluationContext context) {
+ Value result = expression.evaluate((Context)context);
+ if ( ! ( result instanceof TensorValue))
+ throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " +
+ "but this returns " + result + ", not a tensor");
+ return ((TensorValue)result).asTensor();
+ }
+
+ @Override
+ public String toString(ToStringContext c) {
+ ExpressionNodeToStringContext context = (ExpressionNodeToStringContext)c;
+ return expression.toString(context.context, context.path, context.parent);
+ }
+
+ }
+
+ /** Allows passing serialization context arguments through TensorFunctions */
+ private static class ExpressionNodeToStringContext implements ToStringContext {
+
+ final SerializationContext context;
+ final Deque<String> path;
+ final CompositeNode parent;
+
+ public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ this.context = context;
+ this.path = path;
+ this.parent = parent;
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java
deleted file mode 100644
index af309b3e8d8..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java
+++ /dev/null
@@ -1,59 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.google.common.annotations.Beta;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-
-import java.util.ArrayList;
-import java.util.Deque;
-import java.util.List;
-
-/**
- * @author bratseth
- */
- @Beta
-public class TensorMatchNode extends CompositeNode {
-
- private final ExpressionNode left, right;
-
- public TensorMatchNode(ExpressionNode left, ExpressionNode right) {
- this.left = left;
- this.right = right;
- }
-
- @Override
- public List<ExpressionNode> children() {
- List<ExpressionNode> children = new ArrayList<>(2);
- children.add(left);
- children.add(right);
- return children;
- }
-
- @Override
- public CompositeNode setChildren(List<ExpressionNode> children) {
- if ( children.size() != 2)
- throw new IllegalArgumentException("A match product must have two children");
- return new TensorMatchNode(children.get(0), children.get(1));
-
- }
-
- @Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- return "match(" + left.toString(context, path, parent) + ", " + right.toString(context, path, parent) + ")";
- }
-
- @Override
- public Value evaluate(Context context) {
- return asTensor(left.evaluate(context)).match(asTensor(right.evaluate(context)));
- }
-
- private TensorValue asTensor(Value value) {
- if ( ! (value instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to take the tensor product with an argument which is " +
- "not a tensor: " + value);
- return (TensorValue)value;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java
deleted file mode 100644
index a1f83157e20..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java
+++ /dev/null
@@ -1,65 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.google.common.annotations.Beta;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-
-import java.util.Collections;
-import java.util.Deque;
-import java.util.List;
-import java.util.Optional;
-
-/**
- * A node which sums over all cells in the argument tensor
- *
- * @author bratseth
- */
- @Beta
-public class TensorSumNode extends CompositeNode {
-
- /** The tensor to sum */
- private final ExpressionNode argument;
-
- /** The dimension to sum over, or empty to sum all cells to a scalar */
- private final Optional<String> dimension;
-
- public TensorSumNode(ExpressionNode argument, Optional<String> dimension) {
- this.argument = argument;
- this.dimension = dimension;
- }
-
- @Override
- public List<ExpressionNode> children() {
- return Collections.singletonList(argument);
- }
-
- @Override
- public CompositeNode setChildren(List<ExpressionNode> children) {
- if (children.size() != 1) throw new IllegalArgumentException("A tensor sum node must have one tensor argument");
- return new TensorSumNode(children.get(0), dimension);
- }
-
- @Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- return "sum(" +
- argument.toString(context, path, parent) +
- ( dimension.isPresent() ? ", " + dimension.get() : "" ) +
- ")";
- }
-
- @Override
- public Value evaluate(Context context) {
- Value argumentValue = argument.evaluate(context);
- if ( ! ( argumentValue instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to take the tensor sum of argument '" + argument + "', " +
- "but this returns " + argumentValue + ", not a tensor");
- TensorValue tensorArgument = (TensorValue)argumentValue;
- if (dimension.isPresent())
- return tensorArgument.sum(dimension.get());
- else
- return tensorArgument.sum();
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
index 60fe19f909f..932975f3b63 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
@@ -15,7 +15,8 @@ public enum TruthOperator implements Serializable {
EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } },
APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } },
LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } },
- LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } };
+ LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } },
+ NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } };
private final String operatorString;
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 78ad665c414..0fcfdb5d40c 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -21,10 +21,9 @@ import com.yahoo.searchlib.rankingexpression.rule.*;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.tensor.MapTensor;
-import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.*;
+import com.yahoo.tensor.functions.*;
import java.util.Collections;
-import java.util.Map;
import java.util.LinkedHashMap;
import java.util.Arrays;
import java.util.ArrayList;
@@ -60,51 +59,83 @@ TOKEN :
<RSQUARE: "]"> |
<LCURLY: "{"> |
<RCURLY: "}"> |
+
<ADD: "+"> |
<SUB: "-"> |
<DIV: "/"> |
<MUL: "*"> |
<DOT: "."> |
+
<DOLLAR: "$"> |
<COMMA: ","> |
<COLON: ":"> |
+
<LE: "<="> |
<LT: "<"> |
<EQ: "=="> |
+ <NQ: "!="> |
<AQ: "~="> |
<GE: ">="> |
<GT: ">"> |
+
<STRING: ("\"" (~["\""] | "\\\"")* "\"") |
("'" (~["'"] | "\\'")* "'")> |
+
<IF: "if"> |
- <COSH: "cosh"> |
- <SINH: "sinh"> |
- <TANH: "tanh"> |
- <COS: "cos"> |
- <SIN: "sin"> |
- <TAN: "tan"> |
+ <IN: "in"> |
+ <F: "f"> |
+
+ <ABS: "abs"> |
<ACOS: "acos"> |
<ASIN: "asin"> |
- <ATAN2: "atan2"> |
<ATAN: "atan"> |
- <EXP: "exp"> |
- <LDEXP: "ldexp"> |
- <LOG10: "log10"> |
- <LOG: "log"> |
- <POW: "pow"> |
- <SQRT: "sqrt"> |
<CEIL: "ceil"> |
+ <COS: "cos"> |
+ <COSH: "cosh"> |
+ <ELU: "elu"> |
+ <EXP: "exp"> |
<FABS: "fabs"> |
<FLOOR: "floor"> |
- <FMOD: "fmod"> |
- <MIN: "min"> |
- <MAX: "max"> |
<ISNAN: "isNan"> |
- <IN: "in"> |
- <SUM: "sum"> |
- <MATCH: "match"> |
+ <LOG: "log"> |
+ <LOG10: "log10"> |
<RELU: "relu"> |
+ <ROUND: "round"> |
<SIGMOID: "sigmoid"> |
+ <SIGN: "sign"> |
+ <SIN: "sin"> |
+ <SINH: "sinh"> |
+ <SQUARE: "square"> |
+ <SQRT: "sqrt"> |
+ <TAN: "tan"> |
+ <TANH: "tanh"> |
+
+ <ATAN2: "atan2"> |
+ <FMOD: "fmod"> |
+ <LDEXP: "ldexp"> |
+ // MAX
+ // MIN
+ <MOD: "mod"> |
+ <POW: "pow"> |
+
+ <MAP: "map"> |
+ <REDUCE: "reduce"> |
+ <JOIN: "join"> |
+ <RENAME: "rename"> |
+ <TENSOR: "tensor"> |
+ <L1_NORMALIZE: "l1_normalize"> |
+ <L2_NORMALIZE: "l2_normalize"> |
+ <MATMUL: "matmul"> |
+ <SOFTMAX: "softmax"> |
+ <XW_PLUS_B: "xw_plus_b"> |
+
+ <AVG: "avg" > |
+ <COUNT: "count"> |
+ <PROD: "prod"> |
+ <SUM: "sum"> |
+ <MAX: "max"> |
+ <MIN: "min"> |
+
<IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)>
}
@@ -175,6 +206,7 @@ TruthOperator comparator() : { }
( <LE> { return TruthOperator.SMALLEREQUAL; } |
<LT> { return TruthOperator.SMALLER; } |
<EQ> { return TruthOperator.EQUAL; } |
+ <NQ> { return TruthOperator.NOTEQUAL; } |
<AQ> { return TruthOperator.APPROX_EQUAL; } |
<GE> { return TruthOperator.LARGEREQUAL; } |
<GT> { return TruthOperator.LARGER; } )
@@ -189,7 +221,6 @@ ExpressionNode value() :
{
( [ LOOKAHEAD(2) <SUB> { neg = true; } ]
( ret = constantPrimitive() |
- ret = constantTensor() |
LOOKAHEAD(2) ret = ifExpression() |
LOOKAHEAD(2) ret = function() |
ret = feature() |
@@ -279,7 +310,6 @@ ExpressionNode arg() :
}
{
( ret = constantPrimitive() |
- ret = constantTensor() |
LOOKAHEAD(2) ret = feature() |
name = identifier() { ret = new NameNode(name); } )
{ return ret; }
@@ -290,11 +320,11 @@ ExpressionNode function() :
ExpressionNode function;
}
{
- ( function = scalarFunction() | function = tensorFunction() )
+ ( function = scalarOrTensorFunction() | function = tensorFunction() )
{ return function; }
}
-FunctionNode scalarFunction() :
+FunctionNode scalarOrTensorFunction() :
{
Function function;
ExpressionNode arg1, arg2;
@@ -312,61 +342,223 @@ FunctionNode scalarFunction() :
ExpressionNode tensorFunction() :
{
+ ExpressionNode tensorExpression;
+}
+{
+ (
+ tensorExpression = tensorMap() |
+ tensorExpression = tensorReduce() |
+ tensorExpression = tensorReduceComposites() |
+ tensorExpression = tensorJoin() |
+ tensorExpression = tensorRename() |
+ tensorExpression = tensorGenerate() |
+ tensorExpression = tensorL1Normalize() |
+ tensorExpression = tensorL2Normalize() |
+ tensorExpression = tensorMatmul() |
+ tensorExpression = tensorSoftmax() |
+ tensorExpression = tensorXwPlusB()
+ )
+ { return tensorExpression; }
+}
+
+ExpressionNode tensorMap() :
+{
+ ExpressionNode tensor;
+ LambdaFunctionNode doubleMapper;
+}
+{
+ <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE>
+ { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor),
+ doubleMapper.asDoubleUnaryOperator())); }
+}
+
+ExpressionNode tensorReduce() :
+{
+ ExpressionNode tensor;
+ Reduce.Aggregator aggregator;
+ List<String> dimensions = null;
+}
+{
+ <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
+ { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
+}
+
+ExpressionNode tensorReduceComposites() :
+{
+ ExpressionNode tensor;
+ Reduce.Aggregator aggregator;
+ List<String> dimensions = null;
+}
+{
+ aggregator = tensorReduceAggregator()
+ <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE>
+ { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
+}
+
+ExpressionNode tensorJoin() :
+{
ExpressionNode tensor1, tensor2;
- String dimension = null;
- TensorAddress address = null;
+ LambdaFunctionNode doubleJoiner;
}
{
- (
- <SUM> <LBRACE> tensor1 = expression() ( <COMMA> dimension = identifier() )? <RBRACE>
- { return new TensorSumNode(tensor1, Optional.ofNullable(dimension)); }
- ) |
- (
- <MATCH> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <RBRACE>
- { return new TensorMatchNode(tensor1, tensor2); }
- )
+ <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE>
+ { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1),
+ TensorFunctionNode.wrapArgument(tensor2),
+ doubleJoiner.asDoubleBinaryOperator())); }
+}
+
+ExpressionNode tensorRename() :
+{
+ ExpressionNode tensor;
+ List<String> fromDimensions, toDimensions;
+}
+{
+ <RENAME> <LBRACE> tensor = expression() <COMMA>
+ fromDimensions = bracedIdentifierList() <COMMA>
+ toDimensions = bracedIdentifierList()
+ <RBRACE>
+ { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); }
+}
+
+// TODO: Notice that null is parsed below
+ExpressionNode tensorGenerate() :
+{
+ TensorType type;
+ LambdaFunctionNode generator;
+}
+{
+ <TENSOR> <LBRACE> <RBRACE> <LBRACE>
+ { return new TensorFunctionNode(new Generate(null, null)); }
+}
+
+ExpressionNode tensorL1Normalize() :
+{
+ ExpressionNode tensor;
+ String dimension;
+}
+{
+ <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+}
+
+ExpressionNode tensorL2Normalize() :
+{
+ ExpressionNode tensor;
+ String dimension;
+}
+{
+ <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+}
+
+ExpressionNode tensorMatmul() :
+{
+ ExpressionNode tensor1, tensor2;
+ String dimension;
+}
+{
+ <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1),
+ TensorFunctionNode.wrapArgument(tensor2),
+ dimension)); }
+}
+
+ExpressionNode tensorSoftmax() :
+{
+ ExpressionNode tensor;
+ String dimension;
+}
+{
+ <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+}
+
+ExpressionNode tensorXwPlusB() :
+{
+ ExpressionNode tensor1, tensor2, tensor3;
+ String dimension;
+}
+{
+ <XW_PLUS_B> <LBRACE> tensor1 = expression() <COMMA>
+ tensor2 = expression() <COMMA>
+ tensor3 = expression() <COMMA>
+ dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1),
+ TensorFunctionNode.wrapArgument(tensor2),
+ TensorFunctionNode.wrapArgument(tensor3),
+ dimension)); }
+}
+
+LambdaFunctionNode lambdaFunction() :
+{
+ List<String> variables;
+ ExpressionNode functionExpression;
+}
+{
+ ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> )
+ { return new LambdaFunctionNode(variables, functionExpression); }
+}
+
+Reduce.Aggregator tensorReduceAggregator() :
+{
+}
+{
+ ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> )
+ { return Reduce.Aggregator.valueOf(token.image); }
}
// This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge
String tensorFunctionName() :
{
+ Reduce.Aggregator aggregator;
}
{
- ( <SUM> | <MATCH> )
- { return token.image; }
+ ( <F> { return token.image; } ) |
+ ( <MAP> { return token.image; } ) |
+ ( <REDUCE> { return token.image; } ) |
+ ( <JOIN> { return token.image; } ) |
+ ( <RENAME> { return token.image; } ) |
+ ( <TENSOR> { return token.image; } ) |
+ ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}
Function unaryFunctionName() : { }
{
- <COS> { return Function.cos; } |
- <SIN> { return Function.sin; } |
- <TAN> { return Function.tan; } |
- <COSH> { return Function.cosh; } |
- <SINH> { return Function.sinh; } |
- <TANH> { return Function.tanh; } |
+ <ABS> { return Function.abs; } |
<ACOS> { return Function.acos; } |
<ASIN> { return Function.asin; } |
<ATAN> { return Function.atan; } |
- <EXP> { return Function.exp; } |
- <LOG10> { return Function.log10; } |
- <LOG> { return Function.log; } |
- <SQRT> { return Function.sqrt; } |
<CEIL> { return Function.ceil; } |
+ <COS> { return Function.cos; } |
+ <COSH> { return Function.cosh; } |
+ <ELU> { return Function.elu; } |
+ <EXP> { return Function.exp; } |
<FABS> { return Function.fabs; } |
<FLOOR> { return Function.floor; } |
<ISNAN> { return Function.isNan; } |
+ <LOG> { return Function.log; } |
+ <LOG10> { return Function.log10; } |
<RELU> { return Function.relu; } |
- <SIGMOID> { return Function.sigmoid; }
+ <ROUND> { return Function.round; } |
+ <SIGMOID> { return Function.sigmoid; } |
+ <SIGN> { return Function.sign; } |
+ <SIN> { return Function.sin; } |
+ <SINH> { return Function.sinh; } |
+ <SQUARE> { return Function.square; } |
+ <SQRT> { return Function.sqrt; } |
+ <TAN> { return Function.tan; } |
+ <TANH> { return Function.tanh; }
}
Function binaryFunctionName() : { }
{
<ATAN2> { return Function.atan2; } |
- <LDEXP> { return Function.ldexp; } |
- <POW> { return Function.pow; } |
<FMOD> { return Function.fmod; } |
+ <LDEXP> { return Function.ldexp; } |
+ <MAX> { return Function.max; } |
<MIN> { return Function.min; } |
- <MAX> { return Function.max; }
+ <MOD> { return Function.mod; } |
+ <POW> { return Function.pow; }
}
List<ExpressionNode> expressionList() :
@@ -405,79 +597,64 @@ String identifier() :
<IDENTIFIER> { return token.image; }
}
-// An identifier or integer
-String tag() :
-{
- String name;
-}
-{
- name = identifier() { return name; } |
- <INTEGER> { return token.image; }
-}
-
-ConstantNode constantPrimitive() :
+List<String> identifierList() :
{
- String sign = "";
+ List<String> list = new ArrayList<String>();
+ String element;
}
{
- ( <SUB> { sign = "-";} ) ?
- ( <INTEGER> | <FLOAT> | <STRING> )
- { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); }
+ ( element = identifier() { list.add(element); } )?
+ ( <COMMA> element = identifier() { list.add(element); } ) *
+ { return list; }
}
-Value primitiveValue() :
+List<String> bracedIdentifierList() :
{
- String sign = "";
+ List<String> list = new ArrayList<String>();
+ String element;
}
{
- ( <SUB> { sign = "-";} ) ?
- ( <INTEGER> | <FLOAT> | <STRING> )
- { return Value.parse(sign + token.image); }
+ ( element = identifier() { return Collections.singletonList(element); } )
+ |
+ ( <LBRACE> list = identifierList() <RBRACE> { return list; } )
}
-ConstantNode constantTensor() :
+// An identifier or integer
+String tag() :
{
- Value constantValue;
+ String name;
}
{
- <LCURLY> constantValue = tensorContent() <RCURLY>
- { return new ConstantNode(constantValue); }
+ name = identifier() { return name; } |
+ <INTEGER> { return token.image; }
}
-TensorValue tensorContent() :
+List<String> tagCommaLeadingList() :
{
- Map<TensorAddress, Double> cells = new LinkedHashMap<TensorAddress, Double>();
- TensorAddress address;
- Double value;
+ List<String> list = new ArrayList<String>();
+ String element;
}
{
- ( address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) ?
- ( <COMMA> address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) *
- { return new TensorValue(new MapTensor(cells)); }
+ ( <COMMA> element = tag() { list.add(element); } ) *
+ { return list; }
}
-TensorAddress tensorAddress() :
+ConstantNode constantPrimitive() :
{
- List<TensorAddress.Element> elements = new ArrayList<TensorAddress.Element>();
- String dimension;
- String label;
+ String sign = "";
}
{
- <LCURLY>
- ( dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) ?
- ( <COMMA> dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) *
- <RCURLY>
- { return TensorAddress.fromUnsorted(elements); }
+ ( <SUB> { sign = "-";} ) ?
+ ( <INTEGER> | <FLOAT> | <STRING> )
+ { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); }
}
-String label() :
+Value primitiveValue() :
{
- String label;
-
+ String sign = "";
}
{
- ( label = tag() |
- ( "-" { label = "-"; } ) )
- { return label; }
+ ( <SUB> { sign = "-";} ) ?
+ ( <INTEGER> | <FLOAT> | <STRING> )
+ { return Value.parse(sign + token.image); }
}
-
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index 24d7c82235c..f28ff739b4c 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -6,7 +6,10 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
-import junit.framework.TestCase;
+import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertFalse;
import java.io.BufferedReader;
import java.io.File;
@@ -14,15 +17,18 @@ import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
/**
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
+ * @author bratseth
*/
-public class RankingExpressionTestCase extends TestCase {
+public class RankingExpressionTestCase {
+ @Test
public void testParamInFeature() throws ParseException {
assertParse("if (1 > 2, dotProduct(allparentid,query(cate1_parentid)), 2)",
"if ( 1 > 2,\n" +
@@ -31,6 +37,7 @@ public class RankingExpressionTestCase extends TestCase {
")");
}
+ @Test
public void testDollarShorthand() throws ParseException {
assertParse("query(var1)", " $var1");
assertParse("query(var1)", " $var1 ");
@@ -44,6 +51,7 @@ public class RankingExpressionTestCase extends TestCase {
assertParse("if (if (f1.out < query(p1), 0, 1) < if (f2.out < query(p2), 0, 1), f3.out, query(p3))", "if(if(f1.out<$p1,0,1)<if(f2.out<$p2,0,1),f3.out,$p3)");
}
+ @Test
public void testLookaheadIndefinitely() throws Exception {
ExecutorService exec = Executors.newSingleThreadExecutor();
Future<Boolean> future = exec.submit(new Callable<Boolean>() {
@@ -60,7 +68,8 @@ public class RankingExpressionTestCase extends TestCase {
assertTrue(future.get(60, TimeUnit.SECONDS));
}
- public void testSelfRecursionScript() throws ParseException {
+ @Test
+ public void testSelfRecursionSerialization() throws ParseException {
List<ExpressionFunction> macros = new ArrayList<>();
macros.add(new ExpressionFunction("foo", null, new RankingExpression("foo")));
@@ -72,7 +81,8 @@ public class RankingExpressionTestCase extends TestCase {
}
}
- public void testMacroCycleScript() throws ParseException {
+ @Test
+ public void testMacroCycleSerialization() throws ParseException {
List<ExpressionFunction> macros = new ArrayList<>();
macros.add(new ExpressionFunction("foo", null, new RankingExpression("bar")));
macros.add(new ExpressionFunction("bar", null, new RankingExpression("foo")));
@@ -85,42 +95,48 @@ public class RankingExpressionTestCase extends TestCase {
}
}
- public void testScript() throws ParseException {
+ @Test
+ public void testSerialization() throws ParseException {
List<ExpressionFunction> macros = new ArrayList<>();
macros.add(new ExpressionFunction("foo", Arrays.asList("arg1", "arg2"), new RankingExpression("min(arg1, pow(arg2, 2))")));
macros.add(new ExpressionFunction("bar", Arrays.asList("arg1", "arg2"), new RankingExpression("arg1 * arg1 + 2 * arg1 * arg2 + arg2 * arg2")));
macros.add(new ExpressionFunction("baz", Arrays.asList("arg1", "arg2"), new RankingExpression("foo(1, 2) / bar(arg1, arg2)")));
macros.add(new ExpressionFunction("cox", null, new RankingExpression("10 + 08 * 1977")));
- assertScript("foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros,
- Arrays.asList(
- "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)",
- "min(5,pow(rankingExpression(foo@d1d1417259cdc651.573bbcd4be18f379),2))",
- "min(6,pow(7,2))",
- "min(1,pow(2,2))",
- "min(3,pow(4,2))",
- "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"
- ));
- assertScript("foo(1, 2) + bar(3, 4)", macros,
- Arrays.asList(
- "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)",
- "min(1,pow(2,2))",
- "3 * 3 + 2 * 3 * 4 + 4 * 4"
- ));
- assertScript("baz(1, 2)", macros,
- Arrays.asList(
- "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)",
- "min(1,pow(2,2))",
- "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)",
- "1 * 1 + 2 * 1 * 2 + 2 * 2"
- ));
- assertScript("cox", macros,
- Arrays.asList(
- "rankingExpression(cox)",
- "10 + 08 * 1977"
- ));
+ assertSerialization(Arrays.asList(
+ "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)",
+ "min(5,pow(rankingExpression(foo@d1d1417259cdc651.573bbcd4be18f379),2))",
+ "min(6,pow(7,2))",
+ "min(1,pow(2,2))",
+ "min(3,pow(4,2))",
+ "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"), "foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros);
+ assertSerialization(Arrays.asList(
+ "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)",
+ "min(1,pow(2,2))",
+ "3 * 3 + 2 * 3 * 4 + 4 * 4"), "foo(1, 2) + bar(3, 4)", macros);
+ assertSerialization(Arrays.asList(
+ "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)",
+ "min(1,pow(2,2))",
+ "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)",
+ "1 * 1 + 2 * 1 * 2 + 2 * 2"), "baz(1, 2)", macros);
+ assertSerialization(Arrays.asList(
+ "rankingExpression(cox)",
+ "10 + 08 * 1977"), "cox", macros
+ );
+ }
+
+ @Test
+ public void testTensorSerialization() {
+ assertSerialization("map(constant(tensor0), f(a)(cos(a)))",
+ "map(constant(tensor0), f(a)(cos(a)))");
+ assertSerialization("map(constant(tensor0), f(a)(cos(a))) + join(attribute(tensor1), map(reduce(map(attribute(tensor1), f(a)(a * a)), sum, x), f(a)(sqrt(a))), f(a,b)(a / b))",
+ "map(constant(tensor0), f(a)(cos(a))) + l2_normalize(attribute(tensor1), x)");
+ assertSerialization("join(reduce(join(reduce(join(constant(tensor0), attribute(tensor1), f(a,b)(a * b)), sum, x), attribute(tensor1), f(a,b)(a * b)), sum, y), query(tensor2), f(a,b)(a + b))",
+ "xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)");
+
}
+ @Test
public void testBug3464208() throws ParseException {
List<ExpressionFunction> macros = new ArrayList<>();
macros.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69")));
@@ -135,18 +151,11 @@ public class RankingExpressionTestCase extends TestCase {
String expRhs = "(rankingExpression(log10tweetage) * rankingExpression(log10tweetage) * " +
"rankingExpression(log10tweetage)) + 5.0 * attribute(ythl)";
- assertScript(lhs + " + " + rhs, macros,
- Arrays.asList(
- expLhs + " + " + expRhs,
- "69"
- ));
- assertScript(lhs + " - " + rhs, macros,
- Arrays.asList(
- expLhs + " - " + expRhs,
- "69"
- ));
+ assertSerialization(Arrays.asList(expLhs + " + " + expRhs, "69"), lhs + " + " + rhs, macros);
+ assertSerialization(Arrays.asList(expLhs + " - " + expRhs, "69"), lhs + " - " + rhs, macros);
}
+ @Test
public void testParse() throws ParseException, IOException {
BufferedReader reader = new BufferedReader(new FileReader("src/tests/rankingexpression/rankingexpressionlist"));
String line;
@@ -181,36 +190,43 @@ public class RankingExpressionTestCase extends TestCase {
}
}
+ @Test
public void testIssue() throws ParseException {
assertEquals("feature.0", new RankingExpression("feature.0").toString());
assertEquals("if (1 > 2, 3, 4) + feature(arg1).out.out",
new RankingExpression("if ( 1 > 2 , 3 , 4 ) + feature ( arg1 ) . out.out").toString());
}
+ @Test
public void testNegativeConstantArgument() throws ParseException {
assertEquals("foo(-1.2)", new RankingExpression("foo(-1.2)").toString());
}
+ @Test
public void testNaming() throws ParseException {
RankingExpression test = new RankingExpression("a+b");
test.setName("test");
assertEquals("test: a + b", test.toString());
}
+ @Test
public void testCondition() throws ParseException {
RankingExpression expression = new RankingExpression("if(1<2,3,4)");
assertTrue(expression.getRoot() instanceof IfNode);
}
+ @Test
public void testFileImporting() throws ParseException {
RankingExpression expression = new RankingExpression(new File("src/test/files/simple.expression"));
assertEquals("simple: a + b", expression.toString());
}
+ @Test
public void testNonCanonicalLegalStrings() throws ParseException {
assertParse("a * b + c * d", "a* (b) + \nc*d");
}
+ @Test
public void testEquality() throws ParseException {
assertEquals(new RankingExpression("if ( attribute(foo)==\"BAR\",log(attribute(popularity)+5),log(fieldMatch(title).proximity)*fieldMatch(title).completeness)"),
new RankingExpression("if(attribute(foo)==\"BAR\", log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)"));
@@ -219,6 +235,7 @@ public class RankingExpressionTestCase extends TestCase {
new RankingExpression("if(attribute(foo)==\"BAR\", log(attribute(popularity)+5),log(fieldMatch(title).earliness) * fieldMatch(title).completeness)")));
}
+ @Test
public void testSetMembershipConditions() throws ParseException {
assertEquals(new RankingExpression("if ( attribute(foo) in [\"FOO\", \"BAR\"],log(attribute(popularity)+5),log(fieldMatch(title).proximity)*fieldMatch(title).completeness)"),
new RankingExpression("if(attribute(foo) in [\"FOO\",\"BAR\"], log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)"));
@@ -231,6 +248,7 @@ public class RankingExpressionTestCase extends TestCase {
assertEquals(new RankingExpression("if (GENDER$ in [-1.0, 1.0], 1, 0)"), new RankingExpression("if (GENDER$ in [-1.0, 1.0], 1, 0)"));
}
+ @Test
public void testComments() throws ParseException {
assertEquals(new RankingExpression("if ( attribute(foo) in [\"FOO\", \"BAR\"],\n" +
"# a comment\n" +
@@ -241,6 +259,7 @@ public class RankingExpressionTestCase extends TestCase {
new RankingExpression("if(attribute(foo) in [\"FOO\",\"BAR\"], log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)"));
}
+ @Test
public void testIsNan() throws ParseException {
String strExpr = "if (isNan(attribute(foo)) == 1.0, 1.0, attribute(foo))";
RankingExpression expr = new RankingExpression(strExpr);
@@ -255,27 +274,59 @@ public class RankingExpressionTestCase extends TestCase {
assertEquals(expected, new RankingExpression(expression).toString());
}
- private void assertScript(String expression, List<ExpressionFunction> macros, List<String> expectedScripts)
- throws ParseException {
- boolean print = false;
- if (print)
- System.out.println("Parsing expression '" + expression + "'.");
-
- RankingExpression exp = new RankingExpression(expression);
- Map<String, String> scripts = exp.getRankProperties(macros);
- if (print) {
- for (String key : scripts.keySet()) {
- System.out.println("Script '" + key + "': " + scripts.get(key));
- }
+ /** Test serialization with no macros */
+ private void assertSerialization(String expectedSerialization, String expressionString) {
+ String serializedExpression;
+ try {
+ RankingExpression expression = new RankingExpression(expressionString);
+ // No macros -> expect one rank property
+ serializedExpression = expression.getRankProperties(Collections.emptyList()).values().iterator().next();
+ assertEquals(expectedSerialization, serializedExpression);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
}
- for (Map.Entry<String, String> m : scripts.entrySet())
- System.out.println(m);
- for (int i = 0; i < expectedScripts.size();) {
- String val = expectedScripts.get(i++);
- assertTrue("Script contains " + val, scripts.containsValue(val));
+ try {
+ // No macros -> output should be parseable to a ranking expression
+ // (but not the same one due to primitivization)
+ RankingExpression reparsedExpression = new RankingExpression(serializedExpression);
+ // Serializing the primitivized expression should yield the same expression again
+ String reserializedExpression =
+ reparsedExpression.getRankProperties(Collections.emptyList()).values().iterator().next();
+ assertEquals(expectedSerialization, reserializedExpression);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException("Could not parse the serialized expression", e);
}
- if (print)
- System.out.println("");
}
+
+ private void assertSerialization(List<String> expectedSerialization, String expressionString,
+ List<ExpressionFunction> macros) {
+ assertSerialization(expectedSerialization, expressionString, macros, false);
+ }
+ private void assertSerialization(List<String> expectedSerialization, String expressionString,
+ List<ExpressionFunction> macros, boolean print) {
+ try {
+ if (print)
+ System.out.println("Parsing expression '" + expressionString + "'.");
+
+ RankingExpression expression = new RankingExpression(expressionString);
+ Map<String, String> rankProperties = expression.getRankProperties(macros);
+ if (print) {
+ for (String key : rankProperties.keySet())
+ System.out.println("Property '" + key + "': " + rankProperties.get(key));
+ }
+ for (int i = 0; i < expectedSerialization.size();) {
+ String val = expectedSerialization.get(i++);
+ assertTrue("Properties contains " + val, rankProperties.containsValue(val));
+ }
+ if (print)
+ System.out.println("");
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index b67a423181d..93800e2c246 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -20,7 +20,7 @@ import java.util.Set;
*/
public class EvaluationTestCase extends junit.framework.TestCase {
- private Context defaultContext;
+ private MapContext defaultContext;
@Override
protected void setUp() {
@@ -100,201 +100,180 @@ public class EvaluationTestCase extends junit.framework.TestCase {
@Test
public void testTensorEvaluation() {
- assertEvaluates("{}", "{}"); // empty
- assertEvaluates("( {{x:-}:1} * {} )", "( {{x:-}:1} * {} )"); // empty with dimensions
+ assertEvaluates("{}", "tensor0", "{}");
- // sum(tensor)
- assertEvaluates(5.0, "sum({{}:5.0})");
- assertEvaluates(-5.0, "sum({{}:-5.0})");
- assertEvaluates(12.5, "sum({ {d1:l1}:5.5, {d2:l2}:7.0 })");
- assertEvaluates(0.0, "sum({ {d1:l1}:5.0, {d2:l2}:7.0, {}:-12.0})");
-
- // scalar functions on tensors
+ // tensor map
assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }",
- "log10({ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 })");
- assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }",
- "5 * { {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
- assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }",
- "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 } + 3");
- assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }",
- "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 } / 10");
+ "map(tensor0, f(x) (log10(x)))", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {}:4, {d1:l1}:9, {d1:l1,d2:l1 }:16 }",
+ "map(tensor0, f(x) (x * x))", "{ {}:2, {d1:l1}:3, {d1:l1,d2:l1}:4 }");
+ // -- tensor map composites
+ assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }",
+ "log10(tensor0)", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
assertEvaluates("{ {}:-10, {d1:l1}:-100, {d1:l1,d2:l1 }:-1000 }",
- "- { {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ "- tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
assertEvaluates("{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1 }:0 }",
- "min({ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }, 0)");
+ "min(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }");
assertEvaluates("{ {}:0, {d1:l1}:0, {d1:l1,d2:l1 }:10 }",
- "max({ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }, 0)");
- assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + {{h:1}:1.0,{h:2}:1.0}");
-
- // sum(tensor, dimension)
- assertEvaluates("{ {y:1}:4.0, {y:2}:12.0 }",
- "sum({ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }, x)");
- assertEvaluates("{ {x:1}:6.0, {x:2}:10.0 }",
- "sum({ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }, y)");
-
- // tensor sum
- assertEvaluates("{ }", "{} + {}");
- assertEvaluates("{ {x:1}:3, {x:2}:5 }",
- "{ {x:1}:3 } + { {x:2}:5 }");
- assertEvaluates("{ {x:1}:8 }",
- "{ {x:1}:3 } + { {x:1}:5 }");
- assertEvaluates("{ {x:1}:3, {y:1}:5 }",
- "{ {x:1}:3 } + { {y:1}:5 }");
- assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }",
- "{ {x:1}:3, {x:2}:7 } + { {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }",
- "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } + { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "{ {x:1}:5, {x:1,y:1}:1 } + { {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }",
- "{ {x:1}:5, {x:1,y:1}:1 } + { {z:1}:11, {y:1,z:1}:7 }");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "{ {}:5, {x:1,y:1}:1 } + { {y:1,z:1}:7 }");
- assertEvaluates("{ {}:16, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "{ {}:5, {x:1,y:1}:1 } + { {}:11, {y:1,z:1}:7 }");
-
- // tensor difference
- assertEvaluates("{ }", "{} - {}");
- assertEvaluates("{ {x:1}:3, {x:2}:-5 }",
- "{ {x:1}:3 } - { {x:2}:5 }");
- assertEvaluates("{ {x:1}:-2 }",
- "{ {x:1}:3 } - { {x:1}:5 }");
- assertEvaluates("{ {x:1}:3, {y:1}:-5 }",
- "{ {x:1}:3 } - { {y:1}:5 }");
- assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:-5 }",
- "{ {x:1}:3, {x:2}:7 } - { {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:-7, {y:2,z:1}:-11, {y:1,z:2}:-13 }",
- "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } - { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:-7 }",
- "{ {x:1}:5, {x:1,y:1}:1 } - { {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:-11, {y:1,z:1}:-7 }",
- "{ {x:1}:5, {x:1,y:1}:1 } - { {z:1}:11, {y:1,z:1}:7 }");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:-7 }",
- "{ {}:5, {x:1,y:1}:1 } - { {y:1,z:1}:7 }");
- assertEvaluates("{ {}:-6, {x:1,y:1}:1, {y:1,z:1}:-7 }",
- "{ {}:5, {x:1,y:1}:1 } - { {}:11, {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1}:0 }",
- "{ {x:1}:3 } - { {x:1}:3 }");
- assertEvaluates("{ {x:1}:0, {x:2}:1 }",
- "{ {x:1}:3, {x:2}:1 } - { {x:1}:3 }");
-
- // tensor product
- assertEvaluates("{ }", "{} * {}");
- assertEvaluates("( {{x:-,y:-,z:-}:1}*{} )", "( {{x:-}:1} * {} ) * ( {{y:-,z:-}:1} * {} )"); // empty dimensions are preserved
- assertEvaluates("( {{x:-}:1} * {} )",
- "{ {x:1}:3 } * { {x:2}:5 }");
+ "max(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }");
+ // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence)
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "abs(tensor0)", "{ {x:1}:1, {x:2}:-2 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "acos(tensor0)", "{ {x:1}:1, {x:2}:1 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "asin(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "atan(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "ceil(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:1 }", "cos(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:1 }", "cosh(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "elu(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:1 }", "exp(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "fabs(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "floor(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "isNan(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "log(tensor0)", "{ {x:1}:1, {x:2}:1 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:1 }", "log10(tensor0)", "{ {x:1}:1, {x:2}:10 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:2 }", "mod(tensor0, 3)", "{ {x:1}:3, {x:2}:8 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:8 }", "pow(tensor0, 3)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "relu(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "round(tensor0)", "{ {x:1}:1, {x:2}:1.8 }");
+ assertEvaluates("{ {x:1}:0.5, {x:2}:0.5 }", "sigmoid(tensor0)","{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:-1 }", "sign(tensor0)", "{ {x:1}:3, {x:2}:-5 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "sin(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "sinh(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:4 }", "square(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:3 }", "sqrt(tensor0)", "{ {x:1}:1, {x:2}:9 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "tan(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "tanh(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+
+ // tensor reduce
+ // -- reduce 2 dimensions
+ assertEvaluates("{ {}:4 }",
+ "reduce(tensor0, avg, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {}:4 }",
+ "reduce(tensor0, count, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {}:105 }",
+ "reduce(tensor0, prod, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {}:16 }",
+ "reduce(tensor0, sum, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {}:7 }",
+ "reduce(tensor0, max, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {}:1 }",
+ "reduce(tensor0, min, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ // -- reduce 2 by specifying no arguments
+ assertEvaluates("{ {}:4 }",
+ "reduce(tensor0, avg)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ // -- reduce 1 dimension
+ assertEvaluates("{ {y:1}:2, {y:2}:6 }",
+ "reduce(tensor0, avg, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {y:1}:2, {y:2}:2 }",
+ "reduce(tensor0, count, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {y:1}:3, {y:2}:35 }",
+ "reduce(tensor0, prod, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {y:1}:4, {y:2}:12 }",
+ "reduce(tensor0, sum, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {y:1}:3, {y:2}:7 }",
+ "reduce(tensor0, max, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {y:1}:1, {y:2}:5 }",
+ "reduce(tensor0, min, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ // -- reduce composites
+ assertEvaluates("{ {}: 5 }", "sum(tensor0)", "5.0");
+ assertEvaluates("{ {}:-5 }", "sum(tensor0)", "-5.0");
+ assertEvaluates("{ {}:12.5 }", "sum(tensor0)", "{ {d1:l1}:5.5, {d2:l2}:7.0 }");
+ assertEvaluates("{ {}: 0 }", "sum(tensor0)", "{ {d1:l1}:5.0, {d2:l2}:7.0, {}:-12.0}");
+ assertEvaluates("{ {y:1}:4, {y:2}:12.0 }",
+ "sum(tensor0, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {x:1}:6, {x:2}:10.0 }",
+ "sum(tensor0, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {}:16 }",
+ "sum(tensor0, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+
+ // tensor join
+ assertEvaluates("{ {x:1,y:1}:15, {x:2,y:1}:35 }", "join(tensor0, tensor1, f(x,y) (x*y))", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ // -- join composites
+ assertEvaluates("{ }", "tensor0 * tensor0", "{}");
+ assertEvaluates("tensor(x{},y{},z{}):{}", "( tensor0 * tensor1 ) * ( tensor2 * tensor1 )",
+ "{{x:-}:1}", "{}", "{{y:-,z:-}:1}"); // empty dimensions are preserved
+ assertEvaluates("tensor(x{}):{}",
+ "tensor0 * tensor1", "{ {x:1}:3 }", "{ {x:2}:5 }");
assertEvaluates("{ {x:1}:15 }",
- "{ {x:1}:3 } * { {x:1}:5 }");
+ "tensor0 * tensor1", "{ {x:1}:3 }", "{ {x:1}:5 }");
assertEvaluates("{ {x:1,y:1}:15 }",
- "{ {x:1}:3 } * { {y:1}:5 }");
+ "tensor0 * tensor1", "{ {x:1}:3 }", "{ {y:1}:5 }");
assertEvaluates("{ {x:1,y:1}:15, {x:2,y:1}:35 }",
- "{ {x:1}:3, {x:2}:7 } * { {y:1}:5 }");
+ "tensor0 * tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:8, {x:2,y:1}:12 }",
+ "tensor0 + tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:-2, {x:2,y:1}:2 }",
+ "tensor0 - tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:4 }",
+ "tensor0 / tensor1", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }");
+ assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:7 }",
+ "max(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:3, {x:2,y:1}:5 }",
+ "min(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,y:1,z:2}:13, {x:2,y:1,z:1}:21, {x:2,y:1,z:2}:39, {x:1,y:2,z:1}:55 }",
- "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } * { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
- assertEvaluates("{ {x:1,y:1,z:1}:7 }",
- "{ {x:1}:5, {x:1,y:1}:1 } * { {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,z:1}:55 }",
- "{ {x:1}:5, {x:1,y:1}:1 } * { {z:1}:11, {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1,y:1,z:1}:7 }",
- "{ {}:5, {x:1,y:1}:1 } * { {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1,y:1,z:1}:7, {}:55 }",
- "{ {}:5, {x:1,y:1}:1 } * { {}:11, {y:1,z:1}:7 }");
-
- // match product
- assertEvaluates("{ }", "match({}, {})");
- assertEvaluates("( {{x:-}:1} * {} )",
- "match({ {x:1}:3 }, { {x:2}:5 })");
- assertEvaluates("{ {x:1}:15 }",
- "match({ {x:1}:3 }, { {x:1}:5 })");
- assertEvaluates("( {{x:-,y:-}:1} * {} )",
- "match({ {x:1}:3 }, { {y:1}:5 })");
- assertEvaluates("( {{x:-,y:-}:1} * {} )",
- "match({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })");
- assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )",
- "match({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })");
- assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )",
- "match({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
- assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )",
- "match({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })");
- assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )",
- "match({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
- assertEvaluates("( {{x:-,y:-,z:-}:1} * { {}:55 } )",
- "match({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })");
- assertEvaluates("( {{z:-}:1} * { {x:1}:15, {x:1,y:1}:7 } )",
- "match({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })");
-
- // min
- assertEvaluates("{ {x:1}:3, {x:2}:5 }",
- "min({ {x:1}:3 }, { {x:2}:5 })");
- assertEvaluates("{ {x:1}:3 }",
- "min({ {x:1}:3 }, { {x:1}:5 })");
- assertEvaluates("{ {x:1}:3, {y:1}:5 }",
- "min({ {x:1}:3 }, { {y:1}:5 })");
- assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }",
- "min({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }",
- "min({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "min({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }",
- "min({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "min({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "min({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })");
- assertEvaluates("{ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }",
- "min({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })");
-
- // max
- assertEvaluates("{ {x:1}:3, {x:2}:5 }",
- "max({ {x:1}:3 }, { {x:2}:5 })");
- assertEvaluates("{ {x:1}:5 }",
- "max({ {x:1}:3 }, { {x:1}:5 })");
- assertEvaluates("{ {x:1}:3, {y:1}:5 }",
- "max({ {x:1}:3 }, { {y:1}:5 })");
- assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }",
- "max({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }",
- "max({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "max({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }",
- "max({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "max({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
- assertEvaluates("{ {}:11, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "max({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })");
- assertEvaluates("{ {}:5, {x:1}:5, {x:2}:4, {x:1,y:1}:7, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }",
- "max({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })");
-
- // Combined
- assertEvaluates(7.5 + 45 + 1.7,
- "sum( " + // model computation
- " match( " + // model weight application
- " { {x:1}:1, {x:2}:2 } * { {y:1}:3, {y:2}:4 } * { {z:1}:5 }, " + // feature combinations
- " { {x:1,y:1,z:1}:0.5, {x:2,y:1,z:1}:1.5, {x:1,y:1,z:2}:4.5 }" + // model weights
- "))+1.7");
-
- // undefined is not the same as 0
- assertEvaluates(1.0, "sum({ {x:1}:0, {x:2}:0 } * { {x:1}:1, {x:2}:1 } + 0.5)");
- assertEvaluates(0.0, "sum({ } * { {x:1}:1, {x:2}:1 } + 0.5)");
+ "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
+ assertEvaluates("{ {x:1,y:2,z:1}:35, {x:1,y:2,z:2}:65 }",
+ "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:2,z:1}:7, {y:3,z:1}:11, {y:2,z:2}:13 }");
+ assertEvaluates("{{x:1,y:1}:0.0}","tensor1 * tensor2 * tensor3", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1,y:1}:1 }", "{ {x:1,y:1}:1 }");
+ assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }",
+ "5 * tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }",
+ "tensor0 + 3","{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }",
+ "tensor0 / 10", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + tensor0", "{ {h:1}:1.0,{h:2}:1.0 }");
+ assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:0 }",
+ "atan2(tensor0, tensor1)", "{ {x:1}:0, {x:2}:0 }", "{ {y:1}:1 }");
+ assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }",
+ "tensor0 > tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
+ "tensor0 < tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }",
+ "tensor0 >= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
+ "tensor0 <= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }",
+ "tensor0 == tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
+ "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }");
+ // TODO
+ // argmax
+ // argmin
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
+ "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }");
+
+ // tensor rename
+ assertEvaluates("{ {newX:1,y:2}:3 }", "rename(tensor0, x, newX)", "{ {x:1,y:2}:3.0 }");
+ assertEvaluates("{ {x:2,y:1}:3 }", "rename(tensor0, (x, y), (y, x))", "{ {x:1,y:2}:3.0 }");
+
+ // tensor generate - TODO
+ // assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0, {x:2,y:2}:1, {x:1,y:2}:0 }", "tensor(x[2],y[2])(x==y)");
+ // range
+ // diag
+ // fill
+ // random
+
+ // composite functions
+ assertEvaluates("{ {x:1}:0.25, {x:2}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:1}:1, {x:2}:3 }");
+ assertEvaluates("{ {x:1}:0.31622776601683794, {x:2}:0.9486832980505138 }", "l2_normalize(tensor0, x)", "{ {x:1}:1, {x:2}:3 }");
+ assertEvaluates("{ {y:1}:81.0 }", "matmul(tensor0, tensor1, x)", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }");
+ assertEvaluates("{ {x:1}:0.5, {x:2}:0.5 }", "softmax(tensor0, x)", "{ {x:1}:1, {x:2}:1 }", "{ {y:1}:1 }");
+ assertEvaluates("{ {x:1,y:1}:88.0 }", "xw_plus_b(tensor0, tensor1, tensor2, x)", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }", "{ {x:1}:7 }");
+
+ // expressions combining functions
+ assertEvaluates(String.valueOf(7.5 + 45 + 1.7),
+ "sum( " + // model computation:
+ " tensor0 * tensor1 * tensor2 " + // - feature combinations
+ " * tensor3" + // - model weights application
+ ") + 1.7",
+ "{ {x:1}:1, {x:2}:2 }", "{ {y:1}:3, {y:2}:4 }", "{ {z:1}:5 }",
+ "{ {x:1,y:1,z:1}:0.5, {x:2,y:1,z:1}:1.5, {x:1,y:1,z:2}:4.5 }");
+ assertEvaluates("1.0", "sum(tensor0 * tensor1 + 0.5)", "{ {x:1}:0, {x:2}:0 }", "{ {x:1}:1, {x:2}:1 }");
+ assertEvaluates("0.0", "sum(tensor0 * tensor1 + 0.5)", "{}", "{ {x:1}:1, {x:2}:1 }");
// tensor result dimensions are given from argument dimensions, not the resulting values
- assertEvaluates("x", "( {{x:-}:1.0} * {} )", "{ {x:1}:1 } * { {x:2}:1 }");
- assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 }");
-
- // demonstration of where this produces different results: { {x:1}:1 } with 2 dimensions ...
- assertEvaluates("x, y", "( {{x:-,y:-}:1.0} * {} )","{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 } * { {x:1,y:1}:1 }");
- // ... vs { {x:1}:1 } with only one dimension
- assertEvaluates("x, y", "{{x:1,y:1}:1.0}", "{ {x:1}:1 } * { {x:1,y:1}:1 }");
-
- // check that dimensions are preserved through other operations
- String d2 = "{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 }"; // creates a 2d tensor with only an 1d value
- assertEvaluates("x, y", "( {{x:-,y:-}:1.0} * {} )", "match(" + d2 + ", {})");
- assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", d2 + " - {}");
- assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", d2 + " + {}");
- assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "min(1.5, " + d2 +")");
- assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "max({{x:1}:0}, " + d2 +")");
+ assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:1}:1 }", "{ {x:2}:1 }");
+ assertEvaluates("tensor(x{},y{}):{{x:1}:1.0}", "tensor0 * tensor1", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1}:1 }");
}
public void testProgrammaticBuildingAndPrecedence() {
@@ -316,12 +295,16 @@ public class EvaluationTestCase extends junit.framework.TestCase {
assertEvaluates(77, "average(\"2*3\",\"pow(2,3)\")+average(\"2*3\",\"pow(2,3)\").timesten", context);
}
- private RankingExpression assertEvaluates(String tensorValue, String expressionString) {
- return assertEvaluates(new TensorValue(MapTensor.from(tensorValue)), expressionString, defaultContext);
+ private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) {
+ MapContext context = defaultContext.thawedCopy();
+ int argumentIndex = 0;
+ for (String tensorArgument : tensorArguments)
+ context.put("tensor" + (argumentIndex++), new TensorValue(MapTensor.from(tensorArgument)));
+ return assertEvaluates(new TensorValue(MapTensor.from(expectedTensor)), expressionString, context);
}
/** Validate also that the dimension of the resulting tensors are as expected */
- private RankingExpression assertEvaluates(String tensorDimensions, String resultTensor, String expressionString) {
+ private RankingExpression assertEvaluates_old(String tensorDimensions, String resultTensor, String expressionString) {
RankingExpression expression = assertEvaluates(new TensorValue(MapTensor.from(resultTensor)), expressionString, defaultContext);
TensorValue value = (TensorValue)expression.evaluate(defaultContext);
assertEquals(toSet(tensorDimensions), value.asTensor().dimensions());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java
index 95c4402a612..08fdc9917a4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java
@@ -17,22 +17,25 @@ public class NeuralNetEvaluationTestCase {
/** "XOR" neural network, separate expression per layer */
@Test
public void testPerLayerExpression() {
- String input = "{ {x:1}:0, {x:2}:1 }";
-
- String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }";
- String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }";
- String firstLayerInput = "sum(" + input + "*" + firstLayerWeights + ", x) + " + firstLayerBias;
+ String input = "{ {x:1}:0, {x:2}:1 }"; // tensor0
+ String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }"; // tensor1
+ String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }"; // tensor2
+ String firstLayerInput = "sum(tensor0 * tensor1, x) + tensor2";
String firstLayerOutput = "min(1.0, max(0.0, 0.5 + " + firstLayerInput + "))"; // non-linearity, "poor man's sigmoid"
- assertEvaluates("{ {h:1}:1.0, {h:2}:0.0} }", firstLayerOutput);
- String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }";
- String secondLayerBias = "{ {y:1}:-0.5 }";
- String secondLayerInput = "sum(" + firstLayerOutput + "*" + secondLayerWeights + ", h) + " + secondLayerBias;
+ assertEvaluates("{ {h:1}:1.0, {h:2}:0.0} }", firstLayerOutput, input, firstLayerWeights, firstLayerBias);
+ String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }"; // tensor3
+ String secondLayerBias = "{ {y:1}:-0.5 }"; // tensor4
+ String secondLayerInput = "sum(" + firstLayerOutput + "* tensor3, h) + tensor4";
String secondLayerOutput = "min(1.0, max(0.0, 0.5 + " + secondLayerInput + "))"; // non-linearity, "poor man's sigmoid"
- assertEvaluates("{ {y:1}:1 }", secondLayerOutput);
+ assertEvaluates("{ {y:1}:1 }", secondLayerOutput, input, firstLayerWeights, firstLayerBias, secondLayerWeights, secondLayerBias);
}
- private RankingExpression assertEvaluates(String tensorValue, String expressionString) {
- return assertEvaluates(new TensorValue(MapTensor.from(tensorValue)), expressionString, new MapContext());
+ private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) {
+ MapContext context = new MapContext();
+ int argumentIndex = 0;
+ for (String tensorArgument : tensorArguments)
+ context.put("tensor" + (argumentIndex++), new TensorValue(MapTensor.from(tensorArgument)));
+ return assertEvaluates(new TensorValue(MapTensor.from(expectedTensor)), expressionString, context);
}
private RankingExpression assertEvaluates(Value value, String expressionString, Context context) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
index 9d94ec0bc99..61b230ab390 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
@@ -69,12 +69,4 @@ public class SimplifierTestCase {
assertEquals("a + (b + c) / 100000000.0", transformed.toString());
}
- @Test
- public void testSimplificationWithTensorConstants() throws ParseException {
- new Simplifier().transform(new RankingExpression(
- "sum(sum((tensorFromWeightedSet(query(wset_query),x)+" +
- " tensorFromWeightedSet(attribute(wset),x)) * " +
- " {{x:0,y:0}:54, {x:0,y:1} :69, {x:1,y:0} :72, {x:1,y:1} :93},x))"));
- }
-
}
diff --git a/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java b/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java
index d70b55c66a2..a54f1971d21 100644
--- a/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java
+++ b/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java
@@ -20,7 +20,7 @@ public class ClusterState implements Cloneable {
private Map<Node, NodeState> nodeStates = new TreeMap<>();
// TODO: Change to one count for distributor and one for storage, rather than an array
- // TODO: Rename, this is not the highest node count but the highest index
+ // TODO: RenameFunction, this is not the highest node count but the highest index
private ArrayList<Integer> nodeCount = new ArrayList<>(2);
private String description = "";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java
index 3bda4159ca6..4fd743e4724 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java
@@ -21,6 +21,8 @@ import java.util.function.UnaryOperator;
@Beta
public class MapTensor implements Tensor {
+ // TODO: Enforce that all addresses are dense (and then avoid storing keys in TensorAddress)
+
private final ImmutableSet<String> dimensions;
private final ImmutableMap<TensorAddress, Double> cells;
@@ -31,7 +33,7 @@ public class MapTensor implements Tensor {
}
/** Creates a sparse tensor */
- MapTensor(Set<String> dimensions, Map<TensorAddress, Double> cells) {
+ public MapTensor(Set<String> dimensions, Map<TensorAddress, Double> cells) {
ensureValidDimensions(cells, dimensions);
this.dimensions = ImmutableSet.copyOf(dimensions);
this.cells = ImmutableMap.copyOf(cells);
@@ -52,24 +54,41 @@ public class MapTensor implements Tensor {
*/
public static MapTensor from(String s) {
s = s.trim();
- if ( s.startsWith("("))
- return fromTensorWithEmptyDimensions(s);
- else if ( s.startsWith("{"))
- return fromTensor(s, Collections.emptySet());
- else
- throw new IllegalArgumentException("Excepted a string starting by { or (, got '" + s + "'");
+ try {
+ if (s.startsWith("tensor("))
+ return fromTypedTensor(s);
+ else if (s.startsWith("{"))
+ return fromUntypedTensor(s, Collections.emptySet());
+ else
+ return fromNumber(Double.parseDouble(s));
+ }
+ catch (NumberFormatException e) {
+ throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" + s + "'");
+ }
}
- private static MapTensor fromTensorWithEmptyDimensions(String s) {
+ private static MapTensor fromTypedTensor(String s) {
+ if ( ! s.startsWith("tensor(")) throw tensorFormatException(s);
+ s = s.substring("tensor(".length());
+ int typeSpecEnd = s.indexOf(")");
+ if (typeSpecEnd < 0 ) throw tensorFormatException(s);
+ String typeSpec = s.substring(0, typeSpecEnd);
+
+ Set<String> dimensions = new HashSet<>();
+ for (String dimensionSpec : typeSpec.split(",")) {
+ dimensionSpec = dimensionSpec.trim();
+ if ( ! dimensionSpec.endsWith("{}"))
+ throw new IllegalArgumentException("Only mapped dimensions ({}) are supported, got '" + dimensionSpec + "'");
+ dimensions.add(dimensionSpec.substring(0, dimensionSpec.length() - 2));
+ }
+
+ s = s.substring(typeSpec.length() + 1);
+ if ( ! s.startsWith(":")) throw tensorFormatException(s);
s = s.substring(1).trim();
- int multiplier = s.indexOf("*");
- if (multiplier < 0 || ! s.endsWith(")"))
- throw new IllegalArgumentException("Expected a tensor on the form ({dimension:-,...}*{{cells}}), got '" + s + "'");
- MapTensor dimensionTensor = fromTensor(s.substring(0, multiplier).trim(), Collections.emptySet());
- return fromTensor(s.substring(multiplier + 1, s.length() - 1), dimensionTensor.dimensions());
+ return fromUntypedTensor(s, dimensions);
}
- private static MapTensor fromTensor(String s, Set<String> additionalDimensions) {
+ private static MapTensor fromUntypedTensor(String s, Set<String> additionalDimensions) {
s = s.trim().substring(1).trim();
ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
while (s.length() > 1) {
@@ -94,6 +113,16 @@ public class MapTensor implements Tensor {
dimensions.addAll(additionalDimensions);
return new MapTensor(dimensions, cellMap);
}
+
+ private static MapTensor fromNumber(double number) {
+ ImmutableMap.Builder<TensorAddress, Double> singleCell = new ImmutableMap.Builder<>();
+ singleCell.put(TensorAddress.empty, number);
+ return new MapTensor(ImmutableSet.of(), singleCell.build());
+ }
+
+ private static IllegalArgumentException tensorFormatException(String s) {
+ return new IllegalArgumentException("Expected a tensor on the form tensor(dimensionspec):content, but got '" + s + "'");
+ }
private static Double asDouble(TensorAddress address, String s) {
try {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java b/vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java
deleted file mode 100644
index 074742acee1..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor;
-
-import com.google.common.collect.ImmutableMap;
-
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Computes a <i>match product</i>, see {@link Tensor#match}
- *
- * @author bratseth
- */
-class MatchProduct {
-
- private final Set<String> dimensions;
- private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
-
- public MatchProduct(Tensor a, Tensor b) {
- this.dimensions = TensorOperations.combineDimensions(a, b);
- for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) {
- Double sameValueInB = b.cells().get(aCell.getKey());
- if (sameValueInB != null)
- cells.put(aCell.getKey(), aCell.getValue() * sameValueInB);
- }
- }
-
- /** Returns the result of taking this product */
- public MapTensor result() {
- return new MapTensor(dimensions, cells.build());
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 41882738e89..4b17f65ea21 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -2,18 +2,25 @@
package com.yahoo.tensor;
import com.google.common.annotations.Beta;
+import com.yahoo.tensor.functions.ConstantTensor;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.L1Normalize;
+import com.yahoo.tensor.functions.L2Normalize;
+import com.yahoo.tensor.functions.Matmul;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.Softmax;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;
-import java.util.function.DoubleFunction;
import java.util.function.DoubleUnaryOperator;
-import java.util.function.UnaryOperator;
+import java.util.function.Function;
/**
* A multidimensional array which can be used in computations.
@@ -49,128 +56,74 @@ public interface Tensor {
/** Returns the value of a cell, or NaN if this cell does not exist/have no value */
double get(TensorAddress address);
- // ----------------- Level 0 functions
+ // ----------------- Primitive tensor functions
- default Tensor map(Tensor tensor, DoubleUnaryOperator mapper) {
- throw new UnsupportedOperationException("Not implemented");
+ default Tensor map(DoubleUnaryOperator mapper) {
+ return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate();
}
- default Tensor reduce(Tensor tensor, String dimension,
- DoubleBinaryOperator reductor, Optional<DoubleBinaryOperator> postTransformation) {
- throw new UnsupportedOperationException("Not implemented");
+ /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */
+ default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) {
+ return new Reduce(new ConstantTensor(this), aggregator, dimensions).evaluate();
}
- default Tensor join(Tensor tensorA, Tensor tensorB, DoubleBinaryOperator combinator) {
- throw new UnsupportedOperationException("Not implemented");
+ default Tensor join(Tensor argument, DoubleBinaryOperator combinator) {
+ return new Join(new ConstantTensor(this), new ConstantTensor(argument), combinator).evaluate();
}
- // ----------------- Old stuff
- /**
- * Returns the <i>sparse tensor product</i> of this tensor and the argument tensor.
- * This is the all-to-all combinations of cells in the argument tenors, except the combinations
- * which have conflicting labels for the same dimension. The value of each combination is the product
- * of the values of the two input cells. The dimensions of the tensor product is the set union of the
- * dimensions of the argument tensors.
- * <p>
- * If there are no overlapping dimensions this is the regular tensor product.
- * If the two tensors have exactly the same dimensions this is the Hadamard product.
- * <p>
- * The sparse tensor product is associative and commutative.
- *
- * @param argument the tensor to multiply by this
- * @return the resulting tensor.
- */
- default Tensor multiply(Tensor argument) {
- return new TensorProduct(this, argument).result();
+ default Tensor rename(String fromDimension, String toDimension) {
+ return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension),
+ Collections.singletonList(toDimension)).evaluate();
}
- /**
- * Returns the <i>match product</i> of two tensors.
- * This returns a tensor which contains the <i>matching</i> cells in the two tensors, with their
- * values multiplied.
- * <p>
- * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
- * and have the value undefined for any non-shared dimension.
- * <p>
- * The dimensions of the resulting tensor is the set intersection of the two argument tensors.
- * <p>
- * If the two tensors have exactly the same dimensions, this is the Hadamard product.
- */
- default Tensor match(Tensor argument) {
- return new MatchProduct(this, argument).result();
+ default Tensor rename(List<String> fromDimensions, List<String> toDimensions) {
+ return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate();
}
-
- /**
- * Returns a tensor which contains the cells of both argument tensors, where the value for
- * any <i>matching</i> cell is the min of the two possible values.
- * <p>
- * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
- * and have the value undefined for any non-shared dimension.
- */
- default Tensor min(Tensor argument) {
- return new TensorMin(this, argument).result();
+
+ static Tensor from(TensorType type, Function<List<Integer>, Double> valueSupplier) {
+ return new Generate(type, valueSupplier).evaluate();
}
-
- /**
- * Returns a tensor which contains the cells of both argument tensors, where the value for
- * any <i>matching</i> cell is the max of the two possible values.
- * <p>
- * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
- * and have the value undefined for any non-shared dimension.
- */
- default Tensor max(Tensor argument) {
- return new TensorMax(this, argument).result();
+
+ // ----------------- Composite tensor functions which have a defined primitive mapping
+
+ default Tensor l1Normalize(String dimension) {
+ return new L1Normalize(new ConstantTensor(this), dimension).evaluate();
}
- /**
- * Returns a tensor which contains the cells of both argument tensors, where the value for
- * any <i>matching</i> cell is the sum of the two possible values.
- * <p>
- * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
- * and have the value undefined for any non-shared dimension.
- */
- default Tensor add(Tensor argument) {
- return new TensorSum(this, argument).result();
+ default Tensor l2Normalize(String dimension) {
+ return new L2Normalize(new ConstantTensor(this), dimension).evaluate();
}
- /**
- * Returns a tensor which contains the cells of both argument tensors, where the value for
- * any <i>matching</i> cell is the difference of the two possible values.
- * <p>
- * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
- * and have the value undefined for any non-shared dimension.
- */
- default Tensor subtract(Tensor argument) {
- return new TensorDifference(this, argument).result();
+ default Tensor matmul(Tensor argument, String dimension) {
+ return new Matmul(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate();
}
- /**
- * Returns a tensor with the same cells as this and the given function is applied to all its cell values.
- *
- * @param function the function to apply to all cells
- * @return the tensor with the function applied to all the cells of this
- */
- default Tensor apply(UnaryOperator<Double> function) {
- return new TensorFunction(this, function).result();
+ default Tensor softmax(String dimension) {
+ return new Softmax(new ConstantTensor(this), dimension).evaluate();
}
- /**
- * Returns a tensor with the given dimension removed and cells which contains the sum of the values
- * in the removed dimension.
- */
- default Tensor sum(String dimension) {
- return new TensorDimensionSum(dimension, this).result();
- }
+ // ----------------- Composite tensor functions mapped to primitives here on the fly
- /**
- * Returns the sum of all the cells of this tensor.
- */
- default double sum() {
- double sum = 0;
- for (Map.Entry<TensorAddress, Double> cell : cells().entrySet())
- sum += cell.getValue();
- return sum;
- }
+ default Tensor multiply(Tensor argument) { return join(argument, (a, b) -> (a * b )); }
+ default Tensor add(Tensor argument) { return join(argument, (a, b) -> (a + b )); }
+ default Tensor divide(Tensor argument) { return join(argument, (a, b) -> (a / b )); }
+ default Tensor subtract(Tensor argument) { return join(argument, (a, b) -> (a - b )); }
+ default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); }
+ default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); }
+ default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); }
+ default Tensor larger(Tensor argument) { return join(argument, (a, b) -> ( a > b ? 1.0 : 0.0)); }
+ default Tensor largerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a >= b ? 1.0 : 0.0)); }
+ default Tensor smaller(Tensor argument) { return join(argument, (a, b) -> ( a < b ? 1.0 : 0.0)); }
+ default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); }
+ default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); }
+ default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); }
+
+ default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); }
+ default Tensor count(List<String> dimensions) { return reduce(Reduce.Aggregator.count, dimensions); }
+ default Tensor max(List<String> dimensions) { return reduce(Reduce.Aggregator.max, dimensions); }
+ default Tensor min(List<String> dimensions) { return reduce(Reduce.Aggregator.min, dimensions); }
+ default Tensor prod(List<String> dimensions) { return reduce(Reduce.Aggregator.prod, dimensions); }
+ default Tensor sum(List<String> dimensions) { return reduce(Reduce.Aggregator.sum, dimensions); }
/**
* Returns true if the given tensor is mathematically equal to this:
@@ -226,19 +179,28 @@ public interface Tensor {
* @return the tensor on the standard string format
*/
static String toStandardString(Tensor tensor) {
- Set<String> emptyDimensions = emptyDimensions(tensor);
- if (emptyDimensions.size() > 0) // explicitly list empty dimensions
- return "( " + unitTensorWithDimensions(emptyDimensions) + " * " + contentToString(tensor) + " )";
+ if ( emptyDimensions(tensor).size() > 0) // explicitly output type TODO: Always do that
+ return typeToString(tensor) + ":" + contentToString(tensor);
else
return contentToString(tensor);
}
+ static String typeToString(Tensor tensor) {
+ if (tensor.dimensions().isEmpty()) return "tensor()";
+ StringBuilder b = new StringBuilder("tensor(");
+ for (String dimension : tensor.dimensions())
+ b.append(dimension).append("{},");
+ b.setLength(b.length() -1);
+ b.append(")");
+ return b.toString();
+ }
+
static String contentToString(Tensor tensor) {
- List<Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
- Collections.sort(cellEntries, Map.Entry.<TensorAddress, Double>comparingByKey());
+ List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
+ Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey());
StringBuilder b = new StringBuilder("{");
- for (Map.Entry<TensorAddress, Double> cell : cellEntries) {
+ for (java.util.Map.Entry<TensorAddress, Double> cell : cellEntries) {
b.append(cell.getKey()).append(":").append(cell.getValue());
b.append(",");
}
@@ -259,8 +221,4 @@ public interface Tensor {
return emptyDimensions;
}
- static String unitTensorWithDimensions(Set<String> dimensions) {
- return new MapTensor(Collections.singletonMap(TensorAddress.emptyWithDimensions(dimensions), 1.0)).toString();
- }
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 11c6a5f6685..e3c089de071 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -8,12 +8,11 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
+import java.util.Optional;
import java.util.Set;
/**
* An immutable address to a tensor cell.
- * This is sparse: Only dimensions which have a different label than "undefined" are
- * explicitly included.
* <p>
* Tensor addresses are ordered by increasing size primarily, and by the natural order of the elements in sorted
* order secondarily.
@@ -66,14 +65,6 @@ public final class TensorAddress implements Comparable<TensorAddress> {
return TensorAddress.fromSorted(elements);
}
- /** Creates an empty address with a set of dimensions */
- public static TensorAddress emptyWithDimensions(Set<String> dimensions) {
- List<Element> elements = new ArrayList<>(dimensions.size());
- for (String dimension : dimensions)
- elements.add(new Element(dimension, Element.undefinedLabel));
- return TensorAddress.fromUnsorted(elements);
- }
-
/** Returns an immutable list of the elements of this address in sorted order */
public List<Element> elements() { return elements; }
@@ -93,6 +84,14 @@ public final class TensorAddress implements Comparable<TensorAddress> {
return dimensions;
}
+ /** Returns the label at the given dimension, or empty if this dimension is not present */
+ public Optional<String> labelOfDimension(String dimension) {
+ for (TensorAddress.Element element : elements)
+ if (element.dimension().equals(dimension))
+ return Optional.of(element.label());
+ return Optional.empty();
+ }
+
@Override
public int compareTo(TensorAddress other) {
int sizeComparison = Integer.compare(this.elements.size(), other.elements.size());
@@ -123,7 +122,6 @@ public final class TensorAddress implements Comparable<TensorAddress> {
public String toString() {
StringBuilder b = new StringBuilder("{");
for (TensorAddress.Element element : elements) {
- //if (element.label() == Element.undefinedLabel) continue;
b.append(element.toString());
b.append(",");
}
@@ -136,18 +134,13 @@ public final class TensorAddress implements Comparable<TensorAddress> {
/** A tensor address element. Elements have the lexical order of the dimensions as natural order. */
public static class Element implements Comparable<Element> {
- static final String undefinedLabel = "-";
-
private final String dimension;
private final String label;
private final int hashCode;
public Element(String dimension, String label) {
this.dimension = dimension;
- if (label.equals(undefinedLabel))
- this.label = undefinedLabel;
- else
- this.label = label;
+ this.label = label;
this.hashCode = dimension.hashCode() + label.hashCode();
}
@@ -175,9 +168,7 @@ public final class TensorAddress implements Comparable<TensorAddress> {
@Override
public String toString() {
- StringBuilder b = new StringBuilder();
- b.append(dimension).append(":").append(label);
- return b.toString();
+ return dimension + ":" + label;
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java
deleted file mode 100644
index ceb003b1615..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor;
-
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Takes the difference between two tensors, see {@link Tensor#subtract}
- *
- * @author bratseth
- */
-class TensorDifference {
-
- private final Set<String> dimensions;
- private final Map<TensorAddress, Double> cells = new HashMap<>();
-
- public TensorDifference(Tensor a, Tensor b) {
- this.dimensions = TensorOperations.combineDimensions(a, b);
- cells.putAll(a.cells());
- for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet())
- cells.put(bCell.getKey(), a.cells().getOrDefault(bCell.getKey(), 0d) - bCell.getValue());
- }
-
- /** Returns the result of taking this sum */
- public Tensor result() {
- return new MapTensor(dimensions, cells);
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java
deleted file mode 100644
index d15e5092476..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor;
-
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Takes the max of each cell of two tensors, see {@link Tensor#max}
- *
- * @author bratseth
- */
-class TensorMax {
-
- private final Set<String> dimensions;
- private final Map<TensorAddress, Double> cells = new HashMap<>();
-
- public TensorMax(Tensor a, Tensor b) {
- dimensions = TensorOperations.combineDimensions(a, b);
- cells.putAll(a.cells());
- for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
- Double aValue = a.cells().get(bCell.getKey());
- if (aValue == null)
- cells.put(bCell.getKey(), bCell.getValue());
- else
- cells.put(bCell.getKey(), Math.max(aValue, bCell.getValue()));
- }
- }
-
- /** Returns the result of taking this sum */
- public Tensor result() {
- return new MapTensor(dimensions, cells);
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java
deleted file mode 100644
index e389dea3883..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor;
-
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Takes the min of each cell of two tensors, see {@link Tensor#min}
- *
- * @author bratseth
- */
-class TensorMin {
-
- private final Set<String> dimensions;
- private final Map<TensorAddress, Double> cells = new HashMap<>();
-
- public TensorMin(Tensor a, Tensor b) {
- dimensions = TensorOperations.combineDimensions(a, b);
- cells.putAll(a.cells());
- for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
- Double aValue = a.cells().get(bCell.getKey());
- if (aValue == null)
- cells.put(bCell.getKey(), bCell.getValue());
- else
- cells.put(bCell.getKey(), Math.min(aValue, bCell.getValue()));
- }
- }
-
- /** Returns the result of taking this sum */
- public Tensor result() { return new MapTensor(dimensions, cells); }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java
deleted file mode 100644
index aca306b914c..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor;
-
-import com.google.common.collect.ImmutableSet;
-
-import java.util.Set;
-
-/**
- * Functions on tensors
- *
- * @author bratseth
- */
-class TensorOperations {
-
- /**
- * A utility method which returns an ummutable set of the union of the dimensions
- * of the two argument tensors.
- *
- * @return the combined dimensions as an unmodifiable set
- */
- static Set<String> combineDimensions(Tensor a, Tensor b) {
- ImmutableSet.Builder<String> setBuilder = new ImmutableSet.Builder<>();
- setBuilder.addAll(a.dimensions());
- setBuilder.addAll(b.dimensions());
- return setBuilder.build();
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java
deleted file mode 100644
index 221bd985380..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor;
-
-import com.google.common.collect.ImmutableMap;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.ListIterator;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Computes a <i>sparse tensor product</i>, see {@link Tensor#multiply}
- *
- * @author bratseth
- */
-class TensorProduct {
-
- private final Set<String> dimensionsA, dimensionsB;
-
- private final Set<String> dimensions;
- private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
-
- public TensorProduct(Tensor a, Tensor b) {
- dimensionsA = a.dimensions();
- dimensionsB = b.dimensions();
-
- // Dimension product
- dimensions = TensorOperations.combineDimensions(a, b);
-
- // Cell product (slow baseline implementation)
- for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) {
- for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
- TensorAddress combinedAddress = combine(aCell.getKey(), bCell.getKey());
- if (combinedAddress == null) continue; // not combinable
- cells.put(combinedAddress, aCell.getValue() * bCell.getValue());
- }
- }
- }
-
- private TensorAddress combine(TensorAddress a, TensorAddress b) {
- List<TensorAddress.Element> combined = new ArrayList<>();
- combined.addAll(dense(a, dimensionsA));
- combined.addAll(dense(b, dimensionsB));
- Collections.sort(combined);
- TensorAddress.Element previous = null;
- for (ListIterator<TensorAddress.Element> i = combined.listIterator(); i.hasNext(); ) {
- TensorAddress.Element current = i.next();
- if (previous != null && previous.dimension().equals(current.dimension())) { // an overlapping dimension
- if (previous.label().equals(current.label()))
- i.remove(); // a match: remove the duplicate
- else
- return null; // no match: a combination isn't viable
- }
- previous = current;
- }
- return TensorAddress.fromSorted(sparse(combined));
- }
-
- /**
- * Returns a set of tensor elements which contains an entry for each dimension including "undefined" values
- * (which are not present in the sparse elements list).
- */
- private List<TensorAddress.Element> dense(TensorAddress sparse, Set<String> dimensions) {
- if (sparse.elements().size() == dimensions.size()) return sparse.elements();
-
- List<TensorAddress.Element> dense = new ArrayList<>(sparse.elements());
- for (String dimension : dimensions) {
- if ( ! sparse.hasDimension(dimension))
- dense.add(new TensorAddress.Element(dimension, TensorAddress.Element.undefinedLabel));
- }
- return dense;
- }
-
- /**
- * Removes any "undefined" entries from the given elements.
- */
- private List<TensorAddress.Element> sparse(List<TensorAddress.Element> dense) {
- List<TensorAddress.Element> sparse = new ArrayList<>();
- for (TensorAddress.Element element : dense) {
- if ( ! element.label().equals(TensorAddress.Element.undefinedLabel))
- sparse.add(element);
- }
- return sparse;
- }
-
- /** Returns the result of taking this product */
- public Tensor result() {
- return new MapTensor(dimensions, cells.build());
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java
deleted file mode 100644
index 85dfa289bd3..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor;
-
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Takes the sum of two tensors, see {@link Tensor#add}
- *
- * @author bratseth
- */
-class TensorSum {
-
- private final Set<String> dimensions;
- private final Map<TensorAddress, Double> cells = new HashMap<>();
-
- public TensorSum(Tensor a, Tensor b) {
- dimensions = TensorOperations.combineDimensions(a, b);
- cells.putAll(a.cells());
- for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
- cells.put(bCell.getKey(), a.cells().getOrDefault(bCell.getKey(), 0d) + bCell.getValue());
- }
- }
-
- /** Returns the result of taking this sum */
- public Tensor result() { return new MapTensor(dimensions, cells); }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 23cdc0e6051..31454e28baf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -1,5 +1,7 @@
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.Tensor;
+
/**
* A composite tensor function is a tensor function which can be expressed (less tersely)
* as a tree of primitive tensor functions.
@@ -8,4 +10,8 @@ package com.yahoo.tensor.functions;
*/
public abstract class CompositeTensorFunction extends TensorFunction {
+ /** Evaluates this by first converting it to a primitive function */
+ @Override
+ public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java
deleted file mode 100644
index 113247be3bb..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java
+++ /dev/null
@@ -1,24 +0,0 @@
-package com.yahoo.tensor.functions;
-
-import com.yahoo.tensor.MapTensor;
-
-/**
- * A function which returns a constant tensor.
- *
- * @author bratseth
- */
-public class Constant extends PrimitiveTensorFunction {
-
- private final MapTensor constant;
-
- public Constant(String tensorString) {
- this.constant = MapTensor.from(tensorString);
- }
-
- @Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
-
- @Override
- public String toString() { return constant.toString(); }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
new file mode 100644
index 00000000000..0727579a331
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -0,0 +1,38 @@
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.MapTensor;
+import com.yahoo.tensor.Tensor;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * A function which returns a constant tensor.
+ *
+ * @author bratseth
+ */
+public class ConstantTensor extends PrimitiveTensorFunction {
+
+ private final Tensor constant;
+
+ public ConstantTensor(String tensorString) {
+ this.constant = MapTensor.from(tensorString);
+ }
+
+ public ConstantTensor(Tensor tensor) {
+ this.constant = tensor;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() { return this; }
+
+ @Override
+ public Tensor evaluate(EvaluationContext context) { return constant; }
+
+ @Override
+ public String toString(ToStringContext context) { return constant.toString(); }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java
new file mode 100644
index 00000000000..24a4c61a58c
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java
@@ -0,0 +1,14 @@
+package com.yahoo.tensor.functions;
+
+/**
+ * An evaluation context which is passed down to all nested functions during evaluation.
+ * The default implementation is empty as this library does not in itself have any need for a
+ * context.
+ *
+ * @author bratseth
+ */
+public interface EvaluationContext {
+
+ static EvaluationContext empty() { return new EvaluationContext() {}; }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
new file mode 100644
index 00000000000..c0e5776bf48
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -0,0 +1,57 @@
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+import java.util.function.Function;
+
+/**
+ * An indexed tensor whose values are generated by a function
+ *
+ * @author bratseth
+ */
+public class Generate extends PrimitiveTensorFunction {
+
+ private final TensorType type;
+ private final Function<List<Integer>, Double> generator;
+
+ /**
+ * Creates a generated tensor
+ *
+ * @param type the type of the tensor
+ * @param generator the function generating values from a list of ints specifying the indexes of the
+ * tensor cell which will receive the value
+ * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
+ */
+ public Generate(TensorType type, Function<List<Integer>, Double> generator) {
+ Objects.requireNonNull(type, "The argument tensor type cannot be null");
+ Objects.requireNonNull(generator, "The argument function cannot be null");
+ validateType(type);
+ this.type = type;
+ this.generator = generator;
+ }
+
+ private void validateType(TensorType type) {
+ for (TensorType.Dimension dimension : type.dimensions())
+ if (dimension.type() != TensorType.Dimension.Type.indexedBound)
+ throw new IllegalArgumentException("A generated tensor can only have indexed bound dimensions");
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() { return this; }
+
+ @Override
+ public Tensor evaluate(EvaluationContext context) {
+ throw new UnsupportedOperationException("Not implemented"); // TODO
+ }
+
+ @Override
+ public String toString(ToStringContext context) { return type + "(" + generator + ")"; }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 4d945963fdf..323da5906c3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -1,9 +1,24 @@
package com.yahoo.tensor.functions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.yahoo.tensor.MapTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
import java.util.function.DoubleBinaryOperator;
/**
- * The join tensor function.
+ * The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells
+ * given by the cross product of the cells of the given tensors, having as values the value produced by
+ * applying the given combinator function on the values from the two source cells.
*
* @author bratseth
*/
@@ -13,6 +28,9 @@ public class Join extends PrimitiveTensorFunction {
private final DoubleBinaryOperator combinator;
public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) {
+ Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
+ Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
+ Objects.requireNonNull(combinator, "The combinator function cannot be null");
this.argumentA = argumentA;
this.argumentB = argumentB;
this.combinator = combinator;
@@ -21,15 +39,60 @@ public class Join extends PrimitiveTensorFunction {
public TensorFunction argumentA() { return argumentA; }
public TensorFunction argumentB() { return argumentB; }
public DoubleBinaryOperator combinator() { return combinator; }
-
+
+ @Override
+ public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); }
+
@Override
public PrimitiveTensorFunction toPrimitive() {
return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
}
-
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")";
+ }
+
+ private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
+
@Override
- public String toString() {
- return "join(" + argumentA.toString() + ", " + argumentB.toString() + ", lambda(a, b) (...))";
+ public Tensor evaluate(EvaluationContext context) {
+ Tensor a = argumentA.evaluate(context);
+ Tensor b = argumentB.evaluate(context);
+
+ // Dimension product
+ Set<String> dimensions = combineDimensions(a, b);
+
+ // Cell product (slow baseline implementation)
+ ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
+ for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) {
+ for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
+ TensorAddress combinedAddress = combineAddresses(aCell.getKey(), bCell.getKey());
+ if (combinedAddress == null) continue; // not combinable
+ cells.put(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue()));
+ }
+ }
+
+ return new MapTensor(dimensions, cells.build());
}
+ private Set<String> combineDimensions(Tensor a, Tensor b) {
+ ImmutableSet.Builder<String> setBuilder = new ImmutableSet.Builder<>();
+ setBuilder.addAll(a.dimensions());
+ setBuilder.addAll(b.dimensions());
+ return setBuilder.build();
+ }
+
+ private TensorAddress combineAddresses(TensorAddress a, TensorAddress b) {
+ List<TensorAddress.Element> combined = new ArrayList<>(a.elements());
+ for (TensorAddress.Element bElement : b.elements()) {
+ Optional<String> aLabel = a.labelOfDimension(bElement.dimension());
+ if ( ! aLabel.isPresent())
+ combined.add(bElement);
+ else if ( ! aLabel.get().equals(bElement.label()))
+ return null; // incompatible
+ }
+ return TensorAddress.fromUnsorted(combined);
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
new file mode 100644
index 00000000000..4467b378b3f
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
@@ -0,0 +1,36 @@
+package com.yahoo.tensor.functions;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * @author bratseth
+ */
+public class L1Normalize extends CompositeTensorFunction {
+
+ private final TensorFunction argument;
+ private final String dimension;
+
+ public L1Normalize(TensorFunction argument, String dimension) {
+ this.argument = argument;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ TensorFunction primitiveArgument = argument.toPrimitive();
+ // join(x, reduce(x, "avg", "dimension"), f(x,y) (x / y))
+ return new Join(primitiveArgument,
+ new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension),
+ ScalarFunctions.divide());
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")";
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
new file mode 100644
index 00000000000..0e96b43bd22
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -0,0 +1,38 @@
+package com.yahoo.tensor.functions;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * @author bratseth
+ */
+public class L2Normalize extends CompositeTensorFunction {
+
+ private final TensorFunction argument;
+ private final String dimension;
+
+ public L2Normalize(TensorFunction argument, String dimension) {
+ this.argument = argument;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ TensorFunction primitiveArgument = argument.toPrimitive();
+ return new Join(primitiveArgument,
+ new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()),
+ Reduce.Aggregator.sum,
+ dimension),
+ ScalarFunctions.sqrt()),
+ ScalarFunctions.divide());
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")";
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 22dd08504d7..5db88953c64 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -1,10 +1,17 @@
package com.yahoo.tensor.functions;
-import java.util.function.DoubleBinaryOperator;
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.tensor.MapTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
import java.util.function.DoubleUnaryOperator;
/**
- * The join tensor function.
+ * The <i>map</i> tensor function produces a tensor where the given function is applied on each cell value.
*
* @author bratseth
*/
@@ -14,6 +21,8 @@ public class Map extends PrimitiveTensorFunction {
private final DoubleUnaryOperator mapper;
public Map(TensorFunction argument, DoubleUnaryOperator mapper) {
+ Objects.requireNonNull(argument, "The argument tensor cannot be null");
+ Objects.requireNonNull(mapper, "The argument function cannot be null");
this.argument = argument;
this.mapper = mapper;
}
@@ -22,13 +31,25 @@ public class Map extends PrimitiveTensorFunction {
public DoubleUnaryOperator mapper() { return mapper; }
@Override
+ public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() {
return new Map(argument.toPrimitive(), mapper);
}
@Override
- public String toString() {
- return "map(" + argument.toString() + ", lambda(a) (...))";
+ public Tensor evaluate(EvaluationContext context) {
+ Tensor argument = argument().evaluate(context);
+ ImmutableMap.Builder<TensorAddress, Double> mappedCells = new ImmutableMap.Builder<>();
+ for (java.util.Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet())
+ mappedCells.put(cell.getKey(), mapper.applyAsDouble(cell.getValue()));
+ return new MapTensor(argument.dimensions(), mappedCells.build());
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "map(" + argument.toString(context) + ", " + mapper + ")";
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
new file mode 100644
index 00000000000..4492ab083d4
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -0,0 +1,38 @@
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * @author bratseth
+ */
+public class Matmul extends CompositeTensorFunction {
+
+ private final TensorFunction argument1, argument2;
+ private final String dimension;
+
+ public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) {
+ this.argument1 = argument1;
+ this.argument2 = argument2;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ TensorFunction primitiveArgument1 = argument1.toPrimitive();
+ TensorFunction primitiveArgument2 = argument2.toPrimitive();
+ return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ dimension);
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
index 9c0c9abaeb7..91e58f4bf3b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
@@ -1,5 +1,7 @@
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.Tensor;
+
/**
* A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions.
* All tensor implementations must implement all primitive tensor functions.
@@ -8,4 +10,5 @@ package com.yahoo.tensor.functions;
* @author bratseth
*/
public abstract class PrimitiveTensorFunction extends TensorFunction {
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java
deleted file mode 100644
index 09038a294ce..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java
+++ /dev/null
@@ -1,27 +0,0 @@
-package com.yahoo.tensor.functions;
-
-/**
- * The product tensor function
- *
- * @author bratseth
- */
-public class Product extends CompositeTensorFunction {
-
- private final TensorFunction argumentA, argumentB;
-
- public Product(TensorFunction argumentA, TensorFunction argumentB) {
- this.argumentA = argumentA;
- this.argumentB = argumentB;
- }
-
- @Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), (a, b) -> a * b);
- }
-
- @Override
- public String toString() {
- return "product(" + argumentA.toString() + ", " + argumentB.toString() + ")";
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 4b306d376a6..ef18cb61b17 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -1,38 +1,246 @@
package com.yahoo.tensor.functions;
-import java.util.Optional;
-import java.util.function.DoubleBinaryOperator;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.tensor.MapTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
/**
- * The reduce tensor function.
+ * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
+ * are collapsed to a single value using an aggregator function.
*
* @author bratseth
*/
public class Reduce extends PrimitiveTensorFunction {
+ public enum Aggregator { avg, count, prod, sum, max, min; }
+
private final TensorFunction argument;
- private final String dimension;
- private final DoubleBinaryOperator reductor;
- private final Optional<DoubleBinaryOperator> postTransformation;
+ private final List<String> dimensions;
+ private final Aggregator aggregator;
+
+ /** Creates a reduce function reducing aLL dimensions */
+ public Reduce(TensorFunction argument, Aggregator aggregator) {
+ this(argument, aggregator, Collections.emptyList());
+ }
- public Reduce(TensorFunction argument, String dimension,
- DoubleBinaryOperator reductor, Optional<DoubleBinaryOperator> postTransformation) {
+ /** Creates a reduce function reducing a single dimension */
+ public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) {
+ this(argument, aggregator, Collections.singletonList(dimension));
+ }
+
+ /**
+ * Creates a reduce function.
+ *
+ * @param argument the tensor to reduce
+ * @param aggregator the aggregator function to use
+ * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced,
+ * producing a dimensionless tensor (a scalar).
+ * @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor
+ */
+ public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) {
+ Objects.requireNonNull(argument, "The argument tensor cannot be null");
+ Objects.requireNonNull(aggregator, "The aggregator cannot be null");
+ Objects.requireNonNull(dimensions, "The dimensions cannot be null");
this.argument = argument;
- this.dimension = dimension;
- this.reductor = reductor;
- this.postTransformation = postTransformation;
+ this.aggregator = aggregator;
+ this.dimensions = ImmutableList.copyOf(dimensions);
}
public TensorFunction argument() { return argument; }
@Override
+ public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() {
- return new Reduce(argument.toPrimitive(), dimension, reductor, postTransformation);
+ return new Reduce(argument.toPrimitive(), aggregator, dimensions);
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
+ }
+
+ private String commaSeparated(List<String> list) {
+ StringBuilder b = new StringBuilder();
+ for (String element : list)
+ b.append(", ").append(element);
+ return b.toString();
}
@Override
- public String toString() {
- return "reduce(" + argument.toString() + ", " + dimension + ", lambda(a, b) (...), lambda(a, b) (...))";
+ public Tensor evaluate(EvaluationContext context) {
+ Tensor argument = this.argument.evaluate(context);
+
+ if ( ! dimensions.isEmpty() && ! argument.dimensions().containsAll(dimensions))
+ throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
+ dimensions + ": Not all those dimensions are present in this tensor");
+
+ if (dimensions.isEmpty() || dimensions.size() == argument.dimensions().size())
+ return reduceAll(argument);
+
+ // Reduce dimensions
+ Set<String> reducedDimensions = new HashSet<>(argument.dimensions());
+ reducedDimensions.removeAll(dimensions);
+
+ // Reduce cells
+ Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
+ for (Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) {
+ TensorAddress reducedAddress = reduceDimensions(cell.getKey(), reducedDimensions);
+ aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator));
+ aggregatingCells.get(reducedAddress).aggregate(cell.getValue());
+ }
+ ImmutableMap.Builder<TensorAddress, Double> reducedCells = new ImmutableMap.Builder<>();
+ for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
+ reducedCells.put(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
+ return new MapTensor(reducedDimensions, reducedCells.build());
+ }
+
+ private TensorAddress reduceDimensions(TensorAddress address, Set<String> reducedDimensions) {
+ return TensorAddress.fromSorted(address.elements().stream()
+ .filter(e -> reducedDimensions.contains(e.dimension()))
+ .collect(Collectors.toList()));
+ }
+
+ private Tensor reduceAll(Tensor argument) {
+ ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
+ for (Double cellValue : argument.cells().values())
+ valueAggregator.aggregate(cellValue);
+ return new MapTensor(ImmutableMap.of(TensorAddress.empty, valueAggregator.aggregatedValue()));
+ }
+
+ private static abstract class ValueAggregator {
+
+ public static ValueAggregator ofType(Aggregator aggregator) {
+ switch (aggregator) {
+ case avg : return new AvgAggregator();
+ case count : return new CountAggregator();
+ case prod : return new ProdAggregator();
+ case sum : return new SumAggregator();
+ case max : return new MaxAggregator();
+ case min : return new MinAggregator();
+ default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
+ }
+
+ }
+
+ /** Add a new value to those aggregated by this */
+ public abstract void aggregate(double value);
+
+ /** Returns the value aggregated by this */
+ public abstract double aggregatedValue();
+
+ }
+
+ private static class AvgAggregator extends ValueAggregator {
+
+ private int valueCount = 0;
+ private double valueSum = 0.0;
+
+ @Override
+ public void aggregate(double value) {
+ valueCount++;
+ valueSum+= value;
+ }
+
+ @Override
+ public double aggregatedValue() {
+ return valueSum / valueCount;
+ }
+
+ }
+
+ private static class CountAggregator extends ValueAggregator {
+
+ private int valueCount = 0;
+
+ @Override
+ public void aggregate(double value) {
+ valueCount++;
+ }
+
+ @Override
+ public double aggregatedValue() {
+ return valueCount;
+ }
+
+ }
+
+ private static class ProdAggregator extends ValueAggregator {
+
+ private double valueProd = 1.0;
+
+ @Override
+ public void aggregate(double value) {
+ valueProd *= value;
+ }
+
+ @Override
+ public double aggregatedValue() {
+ return valueProd;
+ }
+
+ }
+
+ private static class SumAggregator extends ValueAggregator {
+
+ private double valueSum = 0.0;
+
+ @Override
+ public void aggregate(double value) {
+ valueSum += value;
+ }
+
+ @Override
+ public double aggregatedValue() {
+ return valueSum;
+ }
+
+ }
+
+ private static class MaxAggregator extends ValueAggregator {
+
+ private double maxValue = Double.MIN_VALUE;
+
+ @Override
+ public void aggregate(double value) {
+ if (value > maxValue)
+ maxValue = value;
+ }
+
+ @Override
+ public double aggregatedValue() {
+ return maxValue;
+ }
+
+ }
+
+ private static class MinAggregator extends ValueAggregator {
+
+ private double minValue = Double.MAX_VALUE;
+
+ @Override
+ public void aggregate(double value) {
+ if (value < minValue)
+ minValue = value;
+ }
+
+ @Override
+ public double aggregatedValue() {
+ return minValue;
+ }
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
new file mode 100644
index 00000000000..05af86c33e8
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -0,0 +1,100 @@
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.tensor.MapTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names.
+ *
+ * @author bratseth
+ */
+public class Rename extends PrimitiveTensorFunction {
+
+ private final TensorFunction argument;
+ private final List<String> fromDimensions;
+ private final List<String> toDimensions;
+
+ public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
+ Objects.requireNonNull(argument, "The argument tensor cannot be null");
+ Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
+ Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null");
+ if (fromDimensions.size() < 1)
+ throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension");
+ if (fromDimensions.size() != toDimensions.size())
+ throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " +
+ fromDimensions.size() + " and " + toDimensions.size());
+ this.argument = argument;
+ this.fromDimensions = ImmutableList.copyOf(fromDimensions);
+ this.toDimensions = ImmutableList.copyOf(toDimensions);
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() { return this; }
+
+ @Override
+ public Tensor evaluate(EvaluationContext context) {
+ Tensor tensor = argument.evaluate(context);
+ Map<String, String> fromToMap = fromToMap();
+ Set<String> renamedDimensions = tensor.dimensions().stream()
+ .map((d) -> fromToMap.getOrDefault(d, d))
+ .collect(Collectors.toSet());
+
+ ImmutableMap.Builder<TensorAddress, Double> renamedCells = new ImmutableMap.Builder<>();
+ for (Map.Entry<TensorAddress, Double> cell : tensor.cells().entrySet()) {
+ TensorAddress renamedAddress = rename(cell.getKey(), fromToMap);
+ renamedCells.put(renamedAddress, cell.getValue());
+ }
+ return new MapTensor(renamedDimensions, renamedCells.build());
+ }
+
+ private TensorAddress rename(TensorAddress address, Map<String, String> fromToMap) {
+ List<TensorAddress.Element> renamedElements = new ArrayList<>();
+ for (TensorAddress.Element element : address.elements()) {
+ String toDimension = fromToMap.get(element.dimension());
+ if (toDimension == null)
+ renamedElements.add(element);
+ else
+ renamedElements.add(new TensorAddress.Element(toDimension, element.label()));
+ }
+ return TensorAddress.fromUnsorted(renamedElements);
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "rename(" + argument.toString(context) + ", " +
+ toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
+ }
+
+ private Map<String, String> fromToMap() {
+ Map<String, String> map = new HashMap<>();
+ for (int i = 0; i < fromDimensions.size(); i++)
+ map.put(fromDimensions.get(i), toDimensions.get(i));
+ return map;
+ }
+
+ private String toVectorString(List<String> elements) {
+ if (elements.size() == 1)
+ return elements.get(0);
+ StringBuilder b = new StringBuilder("[");
+ for (String element : elements)
+ b.append(element).append(", ");
+ b.setLength(b.length() - 2);
+ return b.toString();
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
new file mode 100644
index 00000000000..9438c6c533a
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -0,0 +1,81 @@
+package com.yahoo.tensor.functions;
+
+import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * Factory of scalar Java functions.
+ * The purpose of this is to embellish anonymous functions with a runtime type
+ * such that they can be inspected and will return a parseable toString.
+ *
+ * @author bratseth
+ */
+public class ScalarFunctions {
+
+ public static DoubleBinaryOperator add() { return new Addition(); }
+ public static DoubleBinaryOperator multiply() { return new Multiplication(); }
+ public static DoubleBinaryOperator divide() { return new Division(); }
+ public static DoubleUnaryOperator square() { return new Square(); }
+ public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
+ public static DoubleUnaryOperator exp() { return new Exponent(); }
+
+ public static class Addition implements DoubleBinaryOperator {
+
+ @Override
+ public double applyAsDouble(double left, double right) { return left + right; }
+
+ @Override
+ public String toString() { return "f(a,b)(a + b)"; }
+
+ }
+
+ public static class Multiplication implements DoubleBinaryOperator {
+
+ @Override
+ public double applyAsDouble(double left, double right) { return left * right; }
+
+ @Override
+ public String toString() { return "f(a,b)(a * b)"; }
+
+ }
+
+ public static class Division implements DoubleBinaryOperator {
+
+ @Override
+ public double applyAsDouble(double left, double right) { return left / right; }
+
+ @Override
+ public String toString() { return "f(a,b)(a / b)"; }
+ }
+
+ public static class Square implements DoubleUnaryOperator {
+
+ @Override
+ public double applyAsDouble(double operand) { return operand * operand; }
+
+ @Override
+ public String toString() { return "f(a)(a * a)"; }
+
+ }
+
+ public static class Sqrt implements DoubleUnaryOperator {
+
+ @Override
+ public double applyAsDouble(double operand) { return Math.sqrt(operand); }
+
+ @Override
+ public String toString() { return "f(a)(sqrt(a))"; }
+
+ }
+
+ public static class Exponent implements DoubleUnaryOperator {
+
+ @Override
+ public double applyAsDouble(double operand) { return Math.exp(operand); }
+
+ @Override
+ public String toString() { return "f(a)(exp(a))"; }
+
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
new file mode 100644
index 00000000000..b05b8172b42
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -0,0 +1,37 @@
+package com.yahoo.tensor.functions;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * @author bratseth
+ */
+public class Softmax extends CompositeTensorFunction {
+
+ private final TensorFunction argument;
+ private final String dimension;
+
+ public Softmax(TensorFunction argument, String dimension) {
+ this.argument = argument;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ TensorFunction primitiveArgument = argument.toPrimitive();
+ return new Join(new Map(primitiveArgument, ScalarFunctions.exp()),
+ new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()),
+ Reduce.Aggregator.sum,
+ dimension),
+ ScalarFunctions.divide());
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "softmax(" + argument.toString(context) + ", " + dimension + ")";
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 95fca95a042..a717292632e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -1,5 +1,9 @@
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.Tensor;
+
+import java.util.List;
+
/**
* A representation of a tensor function which is able to be translated to a set of primitive
* tensor functions if necessary.
@@ -9,6 +13,9 @@ package com.yahoo.tensor.functions;
*/
public abstract class TensorFunction {
+ /** Returns the function arguments of this node in the order they are applied */
+ public abstract List<TensorFunction> functionArguments();
+
/**
* Translate this function - and all of its arguments recursively -
* to a tree of primitive functions only.
@@ -17,4 +24,24 @@ public abstract class TensorFunction {
*/
public abstract PrimitiveTensorFunction toPrimitive();
+ /**
+ * Evaluates this tensor.
+ *
+ * @param context a context which must be passed to all nexted functions when evaluating
+ */
+ public abstract Tensor evaluate(EvaluationContext context);
+
+ /** Evaluate with no context */
+ public final Tensor evaluate() { return evaluate(EvaluationContext.empty()); }
+
+ /**
+ * Return a string representation of this context.
+ *
+ * @param context a context which must be passed to all nexted functions when requesting the string value
+ */
+ public abstract String toString(ToStringContext context);
+
+ @Override
+ public final String toString() { return toString(ToStringContext.empty()); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java
new file mode 100644
index 00000000000..b71229703d2
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java
@@ -0,0 +1,14 @@
+package com.yahoo.tensor.functions;
+
+/**
+ * A context which is passed down to all nested functions when returning a string representation.
+ * The default implementation is empty as this library does not in itself have any need for a
+ * context.
+ *
+ * @author bratseth
+ */
+public interface ToStringContext {
+
+ static ToStringContext empty() { return new ToStringContext() {}; }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
new file mode 100644
index 00000000000..1988c1d2390
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
@@ -0,0 +1,45 @@
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * @author bratseth
+ */
+public class XwPlusB extends CompositeTensorFunction {
+
+ private final TensorFunction x, w, b;
+ private final String dimension;
+
+ public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) {
+ this.x = x;
+ this.w = w;
+ this.b = b;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ TensorFunction primitiveX = x.toPrimitive();
+ TensorFunction primitiveW = w.toPrimitive();
+ TensorFunction primitiveB = b.toPrimitive();
+ return new Join(new Reduce(new Join(primitiveX, primitiveW, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ dimension),
+ primitiveB,
+ ScalarFunctions.add());
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "xw_plus_b(" + x.toString(context) + ", " +
+ w.toString(context) + ", " +
+ b.toString(context) + ", " +
+ dimension + ")";
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java
index 889b2851a08..af2260e2f20 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java
@@ -42,7 +42,7 @@ public class MapTensorBuilderTestCase {
Tensor tensor = new MapTensorBuilder().dimension("y").dimension("z").
cell().label("x", "0").value(1).build();
assertEquals(Sets.newHashSet("x", "y", "z"), tensor.dimensions());
- assertEquals("( {{y:-,z:-}:1.0} * {{x:0}:1.0} )", tensor.toString());
+ assertEquals("tensor(x{},y{},z{}):{{x:0}:1.0}", tensor.toString());
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java
index 13ea0e95dc8..0372f328811 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java
@@ -33,7 +33,7 @@ public class MapTensorTestCase {
fail("Expected parse error");
}
catch (IllegalArgumentException expected) {
- assertEquals("Excepted a string starting by { or (, got '--'", expected.getMessage());
+ assertEquals("Excepted a number or a string starting by { or tensor(, got '--'", expected.getMessage());
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
new file mode 100644
index 00000000000..e403bb56d56
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -0,0 +1,28 @@
+package com.yahoo.tensor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests functionality on Tensor
+ *
+ * @author bratseth
+ */
+public class TensorTestCase {
+
+ /** This is mostly tested in searchlib - spot checking here */
+ @Test
+ public void testTensorComputation() {
+ MapTensor tensor1 = MapTensor.from("{ {x:1}:3, {x:2}:7 }");
+ MapTensor tensor2 = MapTensor.from("{ {y:1}:5 }");
+ assertEquals(MapTensor.from("{ {x:1,y:1}:15, {x:2,y:1}:35 }"), tensor1.multiply(tensor2));
+ assertEquals(MapTensor.from("{ {x:1,y:1}:12, {x:2,y:1}:28 }"), tensor1.join(tensor2, (a, b) -> a * b - a ));
+ assertEquals(MapTensor.from("{ {x:1,y:1}:0, {x:2,y:1}:1 }"), tensor1.larger(tensor2));
+ assertEquals(MapTensor.from("{ {y:1}:50.0 }"), tensor1.matmul(tensor2, "x"));
+ assertEquals(MapTensor.from("{ {z:1}:3, {z:2}:7 }"), tensor1.rename("x", "z"));
+ assertEquals(MapTensor.from("{ {y:1,x:1}:8, {y:2,x:1}:12 }"), tensor1.add(tensor2).rename(ImmutableList.of("x", "y"),
+ ImmutableList.of("y", "x")));
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
index 501397e89bc..cc9328f7274 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
@@ -12,8 +12,8 @@ public class TensorFunctionTestCase {
@Test
public void testTranslation() {
- assertTranslated("join({{x:1}:1.0}, {{x:2}:1.0}, lambda(a, b) (...))",
- new Product(new Constant("{{x:1}:1.0}"), new Constant("{{x:2}:1.0}")));
+ assertTranslated("join({{x:1}:1.0}, reduce({{x:1}:1.0}, sum, x), f(a,b)(a / b))",
+ new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x"));
}
private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index 8580868dfdf..c3a5e24afc2 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -52,14 +52,14 @@ public class SparseBinaryFormatTestCase {
@Test
public void testSerializationOfTensorsWithSparseTensorAddresses() {
assertSerialization("{{x:0}:2.0, {}:3.0}", Sets.newHashSet("x"));
- assertSerialization("({{y:-}:1} * {{x:0}:2.0})", Sets.newHashSet("x", "y"));
- assertSerialization("({{y:-}:1} * {{x:0}:2.0, {}:3.0})", Sets.newHashSet("x", "y"));
- assertSerialization("({{y:-}:1} * {{x:0}:2.0,{x:1}:3.0})", Sets.newHashSet("x", "y"));
- assertSerialization("({{z:-}:1} * {{x:0,y:0}:2.0})", Sets.newHashSet("x", "y", "z"));
- assertSerialization("({{z:-}:1} * {{x:0,y:0}:2.0,{x:0,y:1}:3.0})", Sets.newHashSet("x", "y", "z"));
- assertSerialization("({{z:-}:1} * {{y:0,x:0}:2.0})", Sets.newHashSet("x", "y", "z"));
- assertSerialization("({{z:-}:1} * {{y:0,x:0}:2.0,{y:1,x:0}:3.0})", Sets.newHashSet("x", "y", "z"));
- assertSerialization("({{z:-}:1} * {{}:2.0,{x:0}:3.0,{x:0,y:0}:5.0})", Sets.newHashSet("x", "y", "z"));
+ assertSerialization("tensor(x{},y{}):{{x:0}:2.0}", Sets.newHashSet("x", "y"));
+ assertSerialization("tensor(x{},y{}):{{x:0}:2.0, {}:3.0}", Sets.newHashSet("x", "y"));
+ assertSerialization("tensor(x{},y{}):{{x:0}:2.0,{x:1}:3.0}", Sets.newHashSet("x", "y"));
+ assertSerialization("tensor(x{},y{},z{}):{{x:0,y:0}:2.0}", Sets.newHashSet("x", "y", "z"));
+ assertSerialization("tensor(x{},y{},z{}):{{x:0,y:0}:2.0,{x:0,y:1}:3.0}", Sets.newHashSet("x", "y", "z"));
+ assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0}:2.0}", Sets.newHashSet("x", "y", "z"));
+ assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0}:2.0,{y:1,x:0}:3.0}", Sets.newHashSet("x", "y", "z"));
+ assertSerialization("tensor(x{},y{},z{}):{{}:2.0,{x:0}:3.0,{x:0,y:0}:5.0}", Sets.newHashSet("x", "y", "z"));
}
@Test