summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java6
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java70
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java122
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java111
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java59
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java65
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java3
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj375
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java175
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java359
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java27
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java8
-rw-r--r--vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java57
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java188
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java31
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java30
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java35
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java28
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java93
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java29
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java57
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java73
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java36
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java29
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java27
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java236
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java100
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java81
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java37
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java27
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java45
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java28
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java16
56 files changed, 1091 insertions, 1905 deletions
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 64bb538eab5..206ab8e30f0 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1049,7 +1049,7 @@ public class JsonReaderTestCase {
@Test
public void testParsingOfTensorWithDimensions() {
- assertTensorField("tensor(x{},y{}):{}",
+ assertTensorField("( {{x:-,y:-}:1.0} * {} )",
createPutWithTensor("{ "
+ " \"dimensions\": [\"x\",\"y\"] "
+ "}"));
@@ -1101,7 +1101,7 @@ public class JsonReaderTestCase {
@Test
public void testParsingOfTensorWithDimensionsAndCells() {
- assertTensorField("tensor(x{},y{},z{}):{{x:a,y:b}:2.0,{x:c}:3.0}",
+ assertTensorField("( {{z:-}:1.0} * {{x:a,y:b}:2.0,{x:c}:3.0} )",
createPutWithTensor("{ "
+ " \"dimensions\": [\"x\",\"y\",\"z\"], "
+ " \"cells\": [ "
@@ -1115,7 +1115,7 @@ public class JsonReaderTestCase {
@Test
public void testParsingOfTensorWithDimensionsAndCellsInDifferentJsonOrder() {
- assertTensorField("tensor(x{},y{},z{}):{{x:a,y:b}:2.0,{x:c}:3.0}",
+ assertTensorField("( {{z:-}:1.0} * {{x:a,y:b}:2.0,{x:c}:3.0} )",
createPutWithTensor("{ "
+ " \"cells\": [ "
+ " { \"address\": { \"x\": \"a\", \"y\": \"b\" }, "
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java
index 252d40b7291..ba06843f178 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java
@@ -69,9 +69,9 @@ public class RestApiTest {
// POST new nodes
assertResponse(new Request("http://localhost:8080/nodes/v2/node",
("[" + asNodeJson("host8.yahoo.com", "default") + "," +
- asNodeJson("host9.yahoo.com", "large-variant") + "," +
- asHostJson("parent2.yahoo.com", "large-variant") + "," +
- asDockerNodeJson("host11.yahoo.com", "parent.host.yahoo.com") + "]").
+ asNodeJson("host9.yahoo.com", "large-variant") + "," +
+ asHostJson("parent2.yahoo.com", "large-variant") + "," +
+ asDockerNodeJson("host11.yahoo.com", "parent.host.yahoo.com") + "]").
getBytes(StandardCharsets.UTF_8),
Request.Method.POST),
"{\"message\":\"Added 4 nodes to the provisioned state\"}");
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 620c6fad0b4..0dff0414ac2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -2,7 +2,6 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
-import com.yahoo.tensor.functions.EvaluationContext;
import java.util.Set;
@@ -11,7 +10,7 @@ import java.util.Set;
*
* @author bratseth
*/
-public abstract class Context implements EvaluationContext {
+public abstract class Context {
/**
* <p>Returns the value of a simple variable name.</p>
@@ -42,7 +41,7 @@ public abstract class Context implements EvaluationContext {
* "main" (or only) value.
*/
public Value get(String name, Arguments arguments,String output) {
- if (arguments!=null && arguments.expressions().size() > 0)
+ if (arguments!=null && arguments.expressions().size()>0)
throw new UnsupportedOperationException(this + " does not support structured ranking expression variables, attempted to reference '" +
name + arguments + "'");
if (output==null)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index f8dcd8a6127..2bae382d5bd 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -39,8 +39,8 @@ public abstract class DoubleCompatibleValue extends Value {
}
@Override
- public Value compare(TruthOperator operator, Value value) {
- return new BooleanValue(operator.evaluate(asDouble(), value.asDouble()));
+ public boolean compare(TruthOperator operator, Value value) {
+ return operator.evaluate(asDouble(), value.asDouble());
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
index 0e0d793bfd1..028dad16d21 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
@@ -98,6 +98,16 @@ public final class DoubleValue extends DoubleCompatibleValue {
}
@Override
+ public boolean compare(TruthOperator operator, Value value) {
+ try {
+ return operator.evaluate(this.value, value.asDouble());
+ }
+ catch (UnsupportedOperationException e) {
+ throw unsupported("comparison",value);
+ }
+ }
+
+ @Override
public Value function(Function function, Value value) {
// use the tensor implementation of max and min if the argument is a tensor
if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
index 2dffe2a1100..9ee9a1f7a71 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
@@ -34,9 +34,11 @@ public class MapContext extends Context {
* Creates a map context from a map.
* The ownership of the map is transferred to this - it cannot be further modified by the caller.
* All the Values of the map will be frozen.
+ *
+ * @since 5.1.5
*/
public MapContext(Map<String,Value> bindings) {
- this.bindings = bindings;
+ this.bindings=bindings;
for (Value boundValue : bindings.values())
boundValue.freeze();
}
@@ -65,9 +67,6 @@ public class MapContext extends Context {
if (frozen) return bindings;
return Collections.unmodifiableMap(bindings);
}
-
- /** Returns a new, modifiable context containing all the bindings of this */
- public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); }
/** Returns an unmodifiable map of the names of this */
public @Override Set<String> names() {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index eb997ab818a..379b5755c7b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -68,10 +68,10 @@ public class StringValue extends Value {
}
@Override
- public Value compare(TruthOperator operator, Value value) {
+ public boolean compare(TruthOperator operator, Value value) {
if (operator.equals(TruthOperator.EQUAL))
- return new BooleanValue(this.equals(value));
- throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='");
+ return this.equals(value);
+ throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '='");
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index b1f4a7b20ca..12bede95aae 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -8,7 +8,6 @@ import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.TensorType;
-import java.util.Collections;
import java.util.Optional;
/**
@@ -18,7 +17,7 @@ import java.util.Optional;
*
* @author bratseth
*/
-@Beta
+ @Beta
public class TensorValue extends Value {
/** The tensor value of this */
@@ -54,7 +53,7 @@ public class TensorValue extends Value {
@Override
public Value negate() {
- return new TensorValue(value.map((value) -> -value));
+ return new TensorValue(value.apply((Double value) -> -value));
}
@Override
@@ -62,7 +61,7 @@ public class TensorValue extends Value {
if (argument instanceof TensorValue)
return new TensorValue(value.add(((TensorValue)argument).value));
else
- return new TensorValue(value.map((value) -> value + argument.asDouble()));
+ return new TensorValue(value.apply((Double value) -> value + argument.asDouble()));
}
@Override
@@ -70,7 +69,7 @@ public class TensorValue extends Value {
if (argument instanceof TensorValue)
return new TensorValue(value.subtract(((TensorValue) argument).value));
else
- return new TensorValue(value.map((value) -> value - argument.asDouble()));
+ return new TensorValue(value.apply((Double value) -> value - argument.asDouble()));
}
@Override
@@ -78,15 +77,35 @@ public class TensorValue extends Value {
if (argument instanceof TensorValue)
return new TensorValue(value.multiply(((TensorValue) argument).value));
else
- return new TensorValue(value.map((value) -> value * argument.asDouble()));
+ return new TensorValue(value.apply((Double value) -> value * argument.asDouble()));
}
@Override
public Value divide(Value argument) {
if (argument instanceof TensorValue)
- return new TensorValue(value.divide(((TensorValue) argument).value));
+ throw new UnsupportedOperationException("Two tensors cannot be divided");
else
- return new TensorValue(value.map((value) -> value / argument.asDouble()));
+ return new TensorValue(value.apply((Double value) -> value / argument.asDouble()));
+ }
+
+ public Value match(Value argument) {
+ return new TensorValue(value.match(asTensor(argument, "match")));
+ }
+
+ public Value min(Value argument) {
+ return new TensorValue(value.min(asTensor(argument, "min")));
+ }
+
+ public Value max(Value argument) {
+ return new TensorValue(value.max(asTensor(argument, "max")));
+ }
+
+ public Value sum(String dimension) {
+ return new TensorValue(value.sum(dimension));
+ }
+
+ public Value sum() {
+ return new DoubleValue(value.sum());
}
private Tensor asTensor(Value value, String operationName) {
@@ -103,37 +122,18 @@ public class TensorValue extends Value {
}
@Override
- public Value compare(TruthOperator operator, Value argument) {
- return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString())));
- }
-
- private Tensor compareTensor(TruthOperator operator, Tensor argument) {
- switch (operator) {
- case LARGER: return value.larger(argument);
- case LARGEREQUAL: return value.largerOrEqual(argument);
- case SMALLER: return value.smaller(argument);
- case SMALLEREQUAL: return value.smallerOrEqual(argument);
- case EQUAL: return value.equal(argument);
- case NOTEQUAL: return value.notEqual(argument);
- default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator);
- }
+ public boolean compare(TruthOperator operator, Value value) {
+ throw new UnsupportedOperationException("A tensor cannot be compared with any value");
}
@Override
- public Value function(Function function, Value arg) {
- if (arg instanceof TensorValue)
- return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString())));
+ public Value function(Function function, Value argument) {
+ if (function.equals(Function.min) && argument instanceof TensorValue)
+ return min(argument);
+ else if (function.equals(Function.max) && argument instanceof TensorValue)
+ return max(argument);
else
- return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble())));
- }
-
- private Tensor functionOnTensor(Function function, Tensor argument) {
- switch (function) {
- case min: return value.min(argument);
- case max: return value.max(argument);
- case atan2: return value.atan2(argument);
- default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
- }
+ return new TensorValue(value.apply((Double value) -> function.evaluate(value, argument.asDouble())));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 8ce18265231..e5680edc68a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -42,7 +42,7 @@ public abstract class Value {
public abstract Value divide(Value value);
/** Perform the comparison specified by the operator between this value and the given value */
- public abstract Value compare(TruthOperator operator, Value value);
+ public abstract boolean compare(TruthOperator operator,Value value);
/** Perform the given binary function on this value and the given value */
public abstract Value function(Function function,Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
index af05acb365a..882d16ebc1c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
@@ -8,9 +8,10 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import java.util.*;
/**
- * A node which returns the outcome of a comparison.
+ * A node which returns true or false depending on the outcome of a comparison.
*
* @author bratseth
+ * @since 5.1.21
*/
public class ComparisonNode extends BooleanNode {
@@ -47,9 +48,9 @@ public class ComparisonNode extends BooleanNode {
@Override
public Value evaluate(Context context) {
- Value leftValue = leftCondition.evaluate(context);
- Value rightValue = rightCondition.evaluate(context);
- return leftValue.compare(operator,rightValue);
+ Value leftValue=leftCondition.evaluate(context);
+ Value rightValue=rightCondition.evaluate(context);
+ return new BooleanValue(leftValue.compare(operator,rightValue));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
index 19b1a83ed99..675ce758faa 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
@@ -12,38 +12,31 @@ import static java.lang.Math.*;
*/
public enum Function implements Serializable {
- abs { public double evaluate(double x, double y) { return abs(x); } },
+ cosh { public double evaluate(double x, double y) { return cosh(x); } },
+ sinh { public double evaluate(double x, double y) { return sinh(x); } },
+ tanh { public double evaluate(double x, double y) { return tanh(x); } },
+ cos { public double evaluate(double x, double y) { return cos(x); } },
+ sin { public double evaluate(double x, double y) { return sin(x); } },
+ tan { public double evaluate(double x, double y) { return tan(x); } },
acos { public double evaluate(double x, double y) { return acos(x); } },
asin { public double evaluate(double x, double y) { return asin(x); } },
atan { public double evaluate(double x, double y) { return atan(x); } },
- ceil { public double evaluate(double x, double y) { return ceil(x); } },
- cos { public double evaluate(double x, double y) { return cos(x); } },
- cosh { public double evaluate(double x, double y) { return cosh(x); } },
- elu { public double evaluate(double x, double y) { return x<0 ? exp(x)-1 : x; } },
exp { public double evaluate(double x, double y) { return exp(x); } },
+ log10 { public double evaluate(double x, double y) { return log10(x); } },
+ log { public double evaluate(double x, double y) { return log(x); } },
+ sqrt { public double evaluate(double x, double y) { return sqrt(x); } },
+ ceil { public double evaluate(double x, double y) { return ceil(x); } },
fabs { public double evaluate(double x, double y) { return abs(x); } },
floor { public double evaluate(double x, double y) { return floor(x); } },
isNan { public double evaluate(double x, double y) { return Double.isNaN(x) ? 1.0 : 0.0; } },
- log { public double evaluate(double x, double y) { return log(x); } },
- log10 { public double evaluate(double x, double y) { return log10(x); } },
relu { public double evaluate(double x, double y) { return max(x,0); } },
- round { public double evaluate(double x, double y) { return round(x); } },
sigmoid { public double evaluate(double x, double y) { return 1.0 / (1.0 + exp(-1.0 * x)); } },
- sign { public double evaluate(double x, double y) { return x >= 0 ? 1 : -1; } },
- sin { public double evaluate(double x, double y) { return sin(x); } },
- sinh { public double evaluate(double x, double y) { return sinh(x); } },
- square { public double evaluate(double x, double y) { return x*x; } },
- sqrt { public double evaluate(double x, double y) { return sqrt(x); } },
- tan { public double evaluate(double x, double y) { return tan(x); } },
- tanh { public double evaluate(double x, double y) { return tanh(x); } },
-
atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } },
- fmod(2) { public double evaluate(double x, double y) { return IEEEremainder(x,y); } },
+ pow(2) { public double evaluate(double x, double y) { return pow(x,y); } },
ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } },
- max(2) { public double evaluate(double x, double y) { return max(x,y); } },
+ fmod(2) { public double evaluate(double x, double y) { return IEEEremainder(x,y); } },
min(2) { public double evaluate(double x, double y) { return min(x,y); } },
- mod(2) { public double evaluate(double x, double y) { return x % y; } },
- pow(2) { public double evaluate(double x, double y) { return pow(x,y); } };
+ max(2) { public double evaluate(double x, double y) { return max(x,y); } };
private final int arity;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
deleted file mode 100644
index 7b48288598d..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ /dev/null
@@ -1,122 +0,0 @@
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.google.common.collect.ImmutableList;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-
-import java.util.Collections;
-import java.util.Deque;
-import java.util.List;
-import java.util.function.DoubleBinaryOperator;
-import java.util.function.DoubleUnaryOperator;
-
-/**
- * A free, parametrized function
- *
- * @author bratseth
- */
-public class LambdaFunctionNode extends CompositeNode {
-
- private final ImmutableList<String> arguments;
- private final ExpressionNode functionExpression;
-
- public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) {
- // TODO: Verify that the function only accesses the arguments in mapperVariables
- this.arguments = ImmutableList.copyOf(arguments);
- this.functionExpression = functionExpression;
- }
-
- @Override
- public List<ExpressionNode> children() {
- return Collections.singletonList(functionExpression);
- }
-
- @Override
- public CompositeNode setChildren(List<ExpressionNode> children) {
- if ( children.size() != 1)
- throw new IllegalArgumentException("A lambda function must have a single child expression");
- return new LambdaFunctionNode(arguments, children.get(0));
- }
-
- @Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- return ("f(" + commaSeparated(arguments) + ")(" + functionExpression.toString(context, path, this)) + ")";
- }
-
- private String commaSeparated(List<String> list) {
- StringBuilder b = new StringBuilder();
- for (String element : list)
- b.append(element).append(",");
- if (b.length() > 0)
- b.setLength(b.length()-1);
- return b.toString();
- }
-
- /** Evaluate this in a context which must have the arguments bound */
- @Override
- public Value evaluate(Context context) {
- return functionExpression.evaluate(context);
- }
-
- /**
- * Returns this as a double unary operator
- *
- * @throws IllegalStateException if this has more than one argument
- */
- public DoubleUnaryOperator asDoubleUnaryOperator() {
- if (arguments.size() > 1)
- throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " +
- "Must have at most one argument " + " but has " + arguments);
- return new DoubleUnaryLambda();
- }
-
- /**
- * Returns this as a double binary operator
- *
- * @throws IllegalStateException if this has more than two arguments
- */
- public DoubleBinaryOperator asDoubleBinaryOperator() {
- if (arguments.size() > 2)
- throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " +
- "Must have at most two argument " + " but has " + arguments);
- return new DoubleBinaryLambda();
- }
-
- private class DoubleUnaryLambda implements DoubleUnaryOperator {
-
- @Override
- public double applyAsDouble(double operand) {
- MapContext context = new MapContext();
- if (arguments.size() > 0)
- context.put(arguments.get(0), operand);
- return evaluate(context).asDouble();
- }
-
- @Override
- public String toString() {
- return LambdaFunctionNode.this.toString();
- }
-
- }
-
- private class DoubleBinaryLambda implements DoubleBinaryOperator {
-
- @Override
- public double applyAsDouble(double left, double right) {
- MapContext context = new MapContext();
- if (arguments.size() > 0)
- context.put(arguments.get(0), left);
- if (arguments.size() > 1)
- context.put(arguments.get(1), right);
- return evaluate(context).asDouble();
- }
-
- @Override
- public String toString() {
- return LambdaFunctionNode.this.toString();
- }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
deleted file mode 100644
index 26d3f1dcc0e..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ /dev/null
@@ -1,111 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.google.common.annotations.Beta;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.functions.EvaluationContext;
-import com.yahoo.tensor.functions.PrimitiveTensorFunction;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.tensor.functions.ToStringContext;
-
-import java.util.Collections;
-import java.util.Deque;
-import java.util.List;
-import java.util.stream.Collectors;
-
-/**
- * A node which performs a tensor function
- *
- * @author bratseth
- */
- @Beta
-public class TensorFunctionNode extends CompositeNode {
-
- private final TensorFunction function;
-
- public TensorFunctionNode(TensorFunction function) {
- this.function = function;
- }
-
- @Override
- public List<ExpressionNode> children() {
- return function.functionArguments().stream()
- .map(f -> ((TensorFunctionExpressionNode)f).expression)
- .collect(Collectors.toList());
- }
-
- @Override
- public CompositeNode setChildren(List<ExpressionNode> children) {
- throw new UnsupportedOperationException("Not implemented");
- }
-
- @Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- // Serialize as primitive
- return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this));
- }
-
- @Override
- public Value evaluate(Context context) {
- return new TensorValue(function.evaluate(context));
- }
-
- public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
- return new TensorFunctionExpressionNode(node);
- }
-
- /**
- * A tensor function implemented by an expression.
- * This allows us to pass expressions as tensor function arguments.
- */
- public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction {
-
- /** An expression which produces a tensor */
- private final ExpressionNode expression;
-
- public TensorFunctionExpressionNode(ExpressionNode expression) {
- this.expression = expression;
- }
-
- @Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
-
- @Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
-
- @Override
- public Tensor evaluate(EvaluationContext context) {
- Value result = expression.evaluate((Context)context);
- if ( ! ( result instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " +
- "but this returns " + result + ", not a tensor");
- return ((TensorValue)result).asTensor();
- }
-
- @Override
- public String toString(ToStringContext c) {
- ExpressionNodeToStringContext context = (ExpressionNodeToStringContext)c;
- return expression.toString(context.context, context.path, context.parent);
- }
-
- }
-
- /** Allows passing serialization context arguments through TensorFunctions */
- private static class ExpressionNodeToStringContext implements ToStringContext {
-
- final SerializationContext context;
- final Deque<String> path;
- final CompositeNode parent;
-
- public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
- this.context = context;
- this.path = path;
- this.parent = parent;
- }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java
new file mode 100644
index 00000000000..af309b3e8d8
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java
@@ -0,0 +1,59 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * @author bratseth
+ */
+ @Beta
+public class TensorMatchNode extends CompositeNode {
+
+ private final ExpressionNode left, right;
+
+ public TensorMatchNode(ExpressionNode left, ExpressionNode right) {
+ this.left = left;
+ this.right = right;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ List<ExpressionNode> children = new ArrayList<>(2);
+ children.add(left);
+ children.add(right);
+ return children;
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if ( children.size() != 2)
+ throw new IllegalArgumentException("A match product must have two children");
+ return new TensorMatchNode(children.get(0), children.get(1));
+
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return "match(" + left.toString(context, path, parent) + ", " + right.toString(context, path, parent) + ")";
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ return asTensor(left.evaluate(context)).match(asTensor(right.evaluate(context)));
+ }
+
+ private TensorValue asTensor(Value value) {
+ if ( ! (value instanceof TensorValue))
+ throw new IllegalArgumentException("Attempted to take the tensor product with an argument which is " +
+ "not a tensor: " + value);
+ return (TensorValue)value;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java
new file mode 100644
index 00000000000..a1f83157e20
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java
@@ -0,0 +1,65 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * A node which sums over all cells in the argument tensor
+ *
+ * @author bratseth
+ */
+ @Beta
+public class TensorSumNode extends CompositeNode {
+
+ /** The tensor to sum */
+ private final ExpressionNode argument;
+
+ /** The dimension to sum over, or empty to sum all cells to a scalar */
+ private final Optional<String> dimension;
+
+ public TensorSumNode(ExpressionNode argument, Optional<String> dimension) {
+ this.argument = argument;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return Collections.singletonList(argument);
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if (children.size() != 1) throw new IllegalArgumentException("A tensor sum node must have one tensor argument");
+ return new TensorSumNode(children.get(0), dimension);
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return "sum(" +
+ argument.toString(context, path, parent) +
+ ( dimension.isPresent() ? ", " + dimension.get() : "" ) +
+ ")";
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ Value argumentValue = argument.evaluate(context);
+ if ( ! ( argumentValue instanceof TensorValue))
+ throw new IllegalArgumentException("Attempted to take the tensor sum of argument '" + argument + "', " +
+ "but this returns " + argumentValue + ", not a tensor");
+ TensorValue tensorArgument = (TensorValue)argumentValue;
+ if (dimension.isPresent())
+ return tensorArgument.sum(dimension.get());
+ else
+ return tensorArgument.sum();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
index 932975f3b63..60fe19f909f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
@@ -15,8 +15,7 @@ public enum TruthOperator implements Serializable {
EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } },
APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } },
LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } },
- LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } },
- NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } };
+ LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } };
private final String operatorString;
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 0fcfdb5d40c..78ad665c414 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -21,9 +21,10 @@ import com.yahoo.searchlib.rankingexpression.rule.*;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.tensor.*;
-import com.yahoo.tensor.functions.*;
+import com.yahoo.tensor.MapTensor;
+import com.yahoo.tensor.TensorAddress;
import java.util.Collections;
+import java.util.Map;
import java.util.LinkedHashMap;
import java.util.Arrays;
import java.util.ArrayList;
@@ -59,83 +60,51 @@ TOKEN :
<RSQUARE: "]"> |
<LCURLY: "{"> |
<RCURLY: "}"> |
-
<ADD: "+"> |
<SUB: "-"> |
<DIV: "/"> |
<MUL: "*"> |
<DOT: "."> |
-
<DOLLAR: "$"> |
<COMMA: ","> |
<COLON: ":"> |
-
<LE: "<="> |
<LT: "<"> |
<EQ: "=="> |
- <NQ: "!="> |
<AQ: "~="> |
<GE: ">="> |
<GT: ">"> |
-
<STRING: ("\"" (~["\""] | "\\\"")* "\"") |
("'" (~["'"] | "\\'")* "'")> |
-
<IF: "if"> |
- <IN: "in"> |
- <F: "f"> |
-
- <ABS: "abs"> |
+ <COSH: "cosh"> |
+ <SINH: "sinh"> |
+ <TANH: "tanh"> |
+ <COS: "cos"> |
+ <SIN: "sin"> |
+ <TAN: "tan"> |
<ACOS: "acos"> |
<ASIN: "asin"> |
+ <ATAN2: "atan2"> |
<ATAN: "atan"> |
- <CEIL: "ceil"> |
- <COS: "cos"> |
- <COSH: "cosh"> |
- <ELU: "elu"> |
<EXP: "exp"> |
+ <LDEXP: "ldexp"> |
+ <LOG10: "log10"> |
+ <LOG: "log"> |
+ <POW: "pow"> |
+ <SQRT: "sqrt"> |
+ <CEIL: "ceil"> |
<FABS: "fabs"> |
<FLOOR: "floor"> |
+ <FMOD: "fmod"> |
+ <MIN: "min"> |
+ <MAX: "max"> |
<ISNAN: "isNan"> |
- <LOG: "log"> |
- <LOG10: "log10"> |
+ <IN: "in"> |
+ <SUM: "sum"> |
+ <MATCH: "match"> |
<RELU: "relu"> |
- <ROUND: "round"> |
<SIGMOID: "sigmoid"> |
- <SIGN: "sign"> |
- <SIN: "sin"> |
- <SINH: "sinh"> |
- <SQUARE: "square"> |
- <SQRT: "sqrt"> |
- <TAN: "tan"> |
- <TANH: "tanh"> |
-
- <ATAN2: "atan2"> |
- <FMOD: "fmod"> |
- <LDEXP: "ldexp"> |
- // MAX
- // MIN
- <MOD: "mod"> |
- <POW: "pow"> |
-
- <MAP: "map"> |
- <REDUCE: "reduce"> |
- <JOIN: "join"> |
- <RENAME: "rename"> |
- <TENSOR: "tensor"> |
- <L1_NORMALIZE: "l1_normalize"> |
- <L2_NORMALIZE: "l2_normalize"> |
- <MATMUL: "matmul"> |
- <SOFTMAX: "softmax"> |
- <XW_PLUS_B: "xw_plus_b"> |
-
- <AVG: "avg" > |
- <COUNT: "count"> |
- <PROD: "prod"> |
- <SUM: "sum"> |
- <MAX: "max"> |
- <MIN: "min"> |
-
<IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)>
}
@@ -206,7 +175,6 @@ TruthOperator comparator() : { }
( <LE> { return TruthOperator.SMALLEREQUAL; } |
<LT> { return TruthOperator.SMALLER; } |
<EQ> { return TruthOperator.EQUAL; } |
- <NQ> { return TruthOperator.NOTEQUAL; } |
<AQ> { return TruthOperator.APPROX_EQUAL; } |
<GE> { return TruthOperator.LARGEREQUAL; } |
<GT> { return TruthOperator.LARGER; } )
@@ -221,6 +189,7 @@ ExpressionNode value() :
{
( [ LOOKAHEAD(2) <SUB> { neg = true; } ]
( ret = constantPrimitive() |
+ ret = constantTensor() |
LOOKAHEAD(2) ret = ifExpression() |
LOOKAHEAD(2) ret = function() |
ret = feature() |
@@ -310,6 +279,7 @@ ExpressionNode arg() :
}
{
( ret = constantPrimitive() |
+ ret = constantTensor() |
LOOKAHEAD(2) ret = feature() |
name = identifier() { ret = new NameNode(name); } )
{ return ret; }
@@ -320,11 +290,11 @@ ExpressionNode function() :
ExpressionNode function;
}
{
- ( function = scalarOrTensorFunction() | function = tensorFunction() )
+ ( function = scalarFunction() | function = tensorFunction() )
{ return function; }
}
-FunctionNode scalarOrTensorFunction() :
+FunctionNode scalarFunction() :
{
Function function;
ExpressionNode arg1, arg2;
@@ -342,223 +312,61 @@ FunctionNode scalarOrTensorFunction() :
ExpressionNode tensorFunction() :
{
- ExpressionNode tensorExpression;
-}
-{
- (
- tensorExpression = tensorMap() |
- tensorExpression = tensorReduce() |
- tensorExpression = tensorReduceComposites() |
- tensorExpression = tensorJoin() |
- tensorExpression = tensorRename() |
- tensorExpression = tensorGenerate() |
- tensorExpression = tensorL1Normalize() |
- tensorExpression = tensorL2Normalize() |
- tensorExpression = tensorMatmul() |
- tensorExpression = tensorSoftmax() |
- tensorExpression = tensorXwPlusB()
- )
- { return tensorExpression; }
-}
-
-ExpressionNode tensorMap() :
-{
- ExpressionNode tensor;
- LambdaFunctionNode doubleMapper;
-}
-{
- <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE>
- { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor),
- doubleMapper.asDoubleUnaryOperator())); }
-}
-
-ExpressionNode tensorReduce() :
-{
- ExpressionNode tensor;
- Reduce.Aggregator aggregator;
- List<String> dimensions = null;
-}
-{
- <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
-}
-
-ExpressionNode tensorReduceComposites() :
-{
- ExpressionNode tensor;
- Reduce.Aggregator aggregator;
- List<String> dimensions = null;
-}
-{
- aggregator = tensorReduceAggregator()
- <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
-}
-
-ExpressionNode tensorJoin() :
-{
ExpressionNode tensor1, tensor2;
- LambdaFunctionNode doubleJoiner;
+ String dimension = null;
+ TensorAddress address = null;
}
{
- <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE>
- { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- doubleJoiner.asDoubleBinaryOperator())); }
-}
-
-ExpressionNode tensorRename() :
-{
- ExpressionNode tensor;
- List<String> fromDimensions, toDimensions;
-}
-{
- <RENAME> <LBRACE> tensor = expression() <COMMA>
- fromDimensions = bracedIdentifierList() <COMMA>
- toDimensions = bracedIdentifierList()
- <RBRACE>
- { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); }
-}
-
-// TODO: Notice that null is parsed below
-ExpressionNode tensorGenerate() :
-{
- TensorType type;
- LambdaFunctionNode generator;
-}
-{
- <TENSOR> <LBRACE> <RBRACE> <LBRACE>
- { return new TensorFunctionNode(new Generate(null, null)); }
-}
-
-ExpressionNode tensorL1Normalize() :
-{
- ExpressionNode tensor;
- String dimension;
-}
-{
- <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
-}
-
-ExpressionNode tensorL2Normalize() :
-{
- ExpressionNode tensor;
- String dimension;
-}
-{
- <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
-}
-
-ExpressionNode tensorMatmul() :
-{
- ExpressionNode tensor1, tensor2;
- String dimension;
-}
-{
- <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- dimension)); }
-}
-
-ExpressionNode tensorSoftmax() :
-{
- ExpressionNode tensor;
- String dimension;
-}
-{
- <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); }
-}
-
-ExpressionNode tensorXwPlusB() :
-{
- ExpressionNode tensor1, tensor2, tensor3;
- String dimension;
-}
-{
- <XW_PLUS_B> <LBRACE> tensor1 = expression() <COMMA>
- tensor2 = expression() <COMMA>
- tensor3 = expression() <COMMA>
- dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- TensorFunctionNode.wrapArgument(tensor3),
- dimension)); }
-}
-
-LambdaFunctionNode lambdaFunction() :
-{
- List<String> variables;
- ExpressionNode functionExpression;
-}
-{
- ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> )
- { return new LambdaFunctionNode(variables, functionExpression); }
-}
-
-Reduce.Aggregator tensorReduceAggregator() :
-{
-}
-{
- ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> )
- { return Reduce.Aggregator.valueOf(token.image); }
+ (
+ <SUM> <LBRACE> tensor1 = expression() ( <COMMA> dimension = identifier() )? <RBRACE>
+ { return new TensorSumNode(tensor1, Optional.ofNullable(dimension)); }
+ ) |
+ (
+ <MATCH> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <RBRACE>
+ { return new TensorMatchNode(tensor1, tensor2); }
+ )
}
// This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge
String tensorFunctionName() :
{
- Reduce.Aggregator aggregator;
}
{
- ( <F> { return token.image; } ) |
- ( <MAP> { return token.image; } ) |
- ( <REDUCE> { return token.image; } ) |
- ( <JOIN> { return token.image; } ) |
- ( <RENAME> { return token.image; } ) |
- ( <TENSOR> { return token.image; } ) |
- ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
+ ( <SUM> | <MATCH> )
+ { return token.image; }
}
Function unaryFunctionName() : { }
{
- <ABS> { return Function.abs; } |
+ <COS> { return Function.cos; } |
+ <SIN> { return Function.sin; } |
+ <TAN> { return Function.tan; } |
+ <COSH> { return Function.cosh; } |
+ <SINH> { return Function.sinh; } |
+ <TANH> { return Function.tanh; } |
<ACOS> { return Function.acos; } |
<ASIN> { return Function.asin; } |
<ATAN> { return Function.atan; } |
- <CEIL> { return Function.ceil; } |
- <COS> { return Function.cos; } |
- <COSH> { return Function.cosh; } |
- <ELU> { return Function.elu; } |
<EXP> { return Function.exp; } |
+ <LOG10> { return Function.log10; } |
+ <LOG> { return Function.log; } |
+ <SQRT> { return Function.sqrt; } |
+ <CEIL> { return Function.ceil; } |
<FABS> { return Function.fabs; } |
<FLOOR> { return Function.floor; } |
<ISNAN> { return Function.isNan; } |
- <LOG> { return Function.log; } |
- <LOG10> { return Function.log10; } |
<RELU> { return Function.relu; } |
- <ROUND> { return Function.round; } |
- <SIGMOID> { return Function.sigmoid; } |
- <SIGN> { return Function.sign; } |
- <SIN> { return Function.sin; } |
- <SINH> { return Function.sinh; } |
- <SQUARE> { return Function.square; } |
- <SQRT> { return Function.sqrt; } |
- <TAN> { return Function.tan; } |
- <TANH> { return Function.tanh; }
+ <SIGMOID> { return Function.sigmoid; }
}
Function binaryFunctionName() : { }
{
<ATAN2> { return Function.atan2; } |
- <FMOD> { return Function.fmod; } |
<LDEXP> { return Function.ldexp; } |
- <MAX> { return Function.max; } |
+ <POW> { return Function.pow; } |
+ <FMOD> { return Function.fmod; } |
<MIN> { return Function.min; } |
- <MOD> { return Function.mod; } |
- <POW> { return Function.pow; }
+ <MAX> { return Function.max; }
}
List<ExpressionNode> expressionList() :
@@ -597,28 +405,6 @@ String identifier() :
<IDENTIFIER> { return token.image; }
}
-List<String> identifierList() :
-{
- List<String> list = new ArrayList<String>();
- String element;
-}
-{
- ( element = identifier() { list.add(element); } )?
- ( <COMMA> element = identifier() { list.add(element); } ) *
- { return list; }
-}
-
-List<String> bracedIdentifierList() :
-{
- List<String> list = new ArrayList<String>();
- String element;
-}
-{
- ( element = identifier() { return Collections.singletonList(element); } )
- |
- ( <LBRACE> list = identifierList() <RBRACE> { return list; } )
-}
-
// An identifier or integer
String tag() :
{
@@ -629,16 +415,6 @@ String tag() :
<INTEGER> { return token.image; }
}
-List<String> tagCommaLeadingList() :
-{
- List<String> list = new ArrayList<String>();
- String element;
-}
-{
- ( <COMMA> element = tag() { list.add(element); } ) *
- { return list; }
-}
-
ConstantNode constantPrimitive() :
{
String sign = "";
@@ -658,3 +434,50 @@ Value primitiveValue() :
( <INTEGER> | <FLOAT> | <STRING> )
{ return Value.parse(sign + token.image); }
}
+
+ConstantNode constantTensor() :
+{
+ Value constantValue;
+}
+{
+ <LCURLY> constantValue = tensorContent() <RCURLY>
+ { return new ConstantNode(constantValue); }
+}
+
+TensorValue tensorContent() :
+{
+ Map<TensorAddress, Double> cells = new LinkedHashMap<TensorAddress, Double>();
+ TensorAddress address;
+ Double value;
+}
+{
+ ( address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) ?
+ ( <COMMA> address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) *
+ { return new TensorValue(new MapTensor(cells)); }
+}
+
+TensorAddress tensorAddress() :
+{
+ List<TensorAddress.Element> elements = new ArrayList<TensorAddress.Element>();
+ String dimension;
+ String label;
+}
+{
+ <LCURLY>
+ ( dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) ?
+ ( <COMMA> dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) *
+ <RCURLY>
+ { return TensorAddress.fromUnsorted(elements); }
+}
+
+String label() :
+{
+ String label;
+
+}
+{
+ ( label = tag() |
+ ( "-" { label = "-"; } ) )
+ { return label; }
+}
+
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index f28ff739b4c..24d7c82235c 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -6,10 +6,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
-import org.junit.Test;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.assertFalse;
+import junit.framework.TestCase;
import java.io.BufferedReader;
import java.io.File;
@@ -17,18 +14,15 @@ import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
/**
- * @author Simon Thoresen
- * @author bratseth
+ * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
*/
-public class RankingExpressionTestCase {
+public class RankingExpressionTestCase extends TestCase {
- @Test
public void testParamInFeature() throws ParseException {
assertParse("if (1 > 2, dotProduct(allparentid,query(cate1_parentid)), 2)",
"if ( 1 > 2,\n" +
@@ -37,7 +31,6 @@ public class RankingExpressionTestCase {
")");
}
- @Test
public void testDollarShorthand() throws ParseException {
assertParse("query(var1)", " $var1");
assertParse("query(var1)", " $var1 ");
@@ -51,7 +44,6 @@ public class RankingExpressionTestCase {
assertParse("if (if (f1.out < query(p1), 0, 1) < if (f2.out < query(p2), 0, 1), f3.out, query(p3))", "if(if(f1.out<$p1,0,1)<if(f2.out<$p2,0,1),f3.out,$p3)");
}
- @Test
public void testLookaheadIndefinitely() throws Exception {
ExecutorService exec = Executors.newSingleThreadExecutor();
Future<Boolean> future = exec.submit(new Callable<Boolean>() {
@@ -68,8 +60,7 @@ public class RankingExpressionTestCase {
assertTrue(future.get(60, TimeUnit.SECONDS));
}
- @Test
- public void testSelfRecursionSerialization() throws ParseException {
+ public void testSelfRecursionScript() throws ParseException {
List<ExpressionFunction> macros = new ArrayList<>();
macros.add(new ExpressionFunction("foo", null, new RankingExpression("foo")));
@@ -81,8 +72,7 @@ public class RankingExpressionTestCase {
}
}
- @Test
- public void testMacroCycleSerialization() throws ParseException {
+ public void testMacroCycleScript() throws ParseException {
List<ExpressionFunction> macros = new ArrayList<>();
macros.add(new ExpressionFunction("foo", null, new RankingExpression("bar")));
macros.add(new ExpressionFunction("bar", null, new RankingExpression("foo")));
@@ -95,48 +85,42 @@ public class RankingExpressionTestCase {
}
}
- @Test
- public void testSerialization() throws ParseException {
+ public void testScript() throws ParseException {
List<ExpressionFunction> macros = new ArrayList<>();
macros.add(new ExpressionFunction("foo", Arrays.asList("arg1", "arg2"), new RankingExpression("min(arg1, pow(arg2, 2))")));
macros.add(new ExpressionFunction("bar", Arrays.asList("arg1", "arg2"), new RankingExpression("arg1 * arg1 + 2 * arg1 * arg2 + arg2 * arg2")));
macros.add(new ExpressionFunction("baz", Arrays.asList("arg1", "arg2"), new RankingExpression("foo(1, 2) / bar(arg1, arg2)")));
macros.add(new ExpressionFunction("cox", null, new RankingExpression("10 + 08 * 1977")));
- assertSerialization(Arrays.asList(
- "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)",
- "min(5,pow(rankingExpression(foo@d1d1417259cdc651.573bbcd4be18f379),2))",
- "min(6,pow(7,2))",
- "min(1,pow(2,2))",
- "min(3,pow(4,2))",
- "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"), "foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros);
- assertSerialization(Arrays.asList(
- "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)",
- "min(1,pow(2,2))",
- "3 * 3 + 2 * 3 * 4 + 4 * 4"), "foo(1, 2) + bar(3, 4)", macros);
- assertSerialization(Arrays.asList(
- "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)",
- "min(1,pow(2,2))",
- "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)",
- "1 * 1 + 2 * 1 * 2 + 2 * 2"), "baz(1, 2)", macros);
- assertSerialization(Arrays.asList(
- "rankingExpression(cox)",
- "10 + 08 * 1977"), "cox", macros
- );
- }
-
- @Test
- public void testTensorSerialization() {
- assertSerialization("map(constant(tensor0), f(a)(cos(a)))",
- "map(constant(tensor0), f(a)(cos(a)))");
- assertSerialization("map(constant(tensor0), f(a)(cos(a))) + join(attribute(tensor1), map(reduce(map(attribute(tensor1), f(a)(a * a)), sum, x), f(a)(sqrt(a))), f(a,b)(a / b))",
- "map(constant(tensor0), f(a)(cos(a))) + l2_normalize(attribute(tensor1), x)");
- assertSerialization("join(reduce(join(reduce(join(constant(tensor0), attribute(tensor1), f(a,b)(a * b)), sum, x), attribute(tensor1), f(a,b)(a * b)), sum, y), query(tensor2), f(a,b)(a + b))",
- "xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)");
-
+ assertScript("foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros,
+ Arrays.asList(
+ "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)",
+ "min(5,pow(rankingExpression(foo@d1d1417259cdc651.573bbcd4be18f379),2))",
+ "min(6,pow(7,2))",
+ "min(1,pow(2,2))",
+ "min(3,pow(4,2))",
+ "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"
+ ));
+ assertScript("foo(1, 2) + bar(3, 4)", macros,
+ Arrays.asList(
+ "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)",
+ "min(1,pow(2,2))",
+ "3 * 3 + 2 * 3 * 4 + 4 * 4"
+ ));
+ assertScript("baz(1, 2)", macros,
+ Arrays.asList(
+ "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)",
+ "min(1,pow(2,2))",
+ "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)",
+ "1 * 1 + 2 * 1 * 2 + 2 * 2"
+ ));
+ assertScript("cox", macros,
+ Arrays.asList(
+ "rankingExpression(cox)",
+ "10 + 08 * 1977"
+ ));
}
- @Test
public void testBug3464208() throws ParseException {
List<ExpressionFunction> macros = new ArrayList<>();
macros.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69")));
@@ -151,11 +135,18 @@ public class RankingExpressionTestCase {
String expRhs = "(rankingExpression(log10tweetage) * rankingExpression(log10tweetage) * " +
"rankingExpression(log10tweetage)) + 5.0 * attribute(ythl)";
- assertSerialization(Arrays.asList(expLhs + " + " + expRhs, "69"), lhs + " + " + rhs, macros);
- assertSerialization(Arrays.asList(expLhs + " - " + expRhs, "69"), lhs + " - " + rhs, macros);
+ assertScript(lhs + " + " + rhs, macros,
+ Arrays.asList(
+ expLhs + " + " + expRhs,
+ "69"
+ ));
+ assertScript(lhs + " - " + rhs, macros,
+ Arrays.asList(
+ expLhs + " - " + expRhs,
+ "69"
+ ));
}
- @Test
public void testParse() throws ParseException, IOException {
BufferedReader reader = new BufferedReader(new FileReader("src/tests/rankingexpression/rankingexpressionlist"));
String line;
@@ -190,43 +181,36 @@ public class RankingExpressionTestCase {
}
}
- @Test
public void testIssue() throws ParseException {
assertEquals("feature.0", new RankingExpression("feature.0").toString());
assertEquals("if (1 > 2, 3, 4) + feature(arg1).out.out",
new RankingExpression("if ( 1 > 2 , 3 , 4 ) + feature ( arg1 ) . out.out").toString());
}
- @Test
public void testNegativeConstantArgument() throws ParseException {
assertEquals("foo(-1.2)", new RankingExpression("foo(-1.2)").toString());
}
- @Test
public void testNaming() throws ParseException {
RankingExpression test = new RankingExpression("a+b");
test.setName("test");
assertEquals("test: a + b", test.toString());
}
- @Test
public void testCondition() throws ParseException {
RankingExpression expression = new RankingExpression("if(1<2,3,4)");
assertTrue(expression.getRoot() instanceof IfNode);
}
- @Test
public void testFileImporting() throws ParseException {
RankingExpression expression = new RankingExpression(new File("src/test/files/simple.expression"));
assertEquals("simple: a + b", expression.toString());
}
- @Test
public void testNonCanonicalLegalStrings() throws ParseException {
assertParse("a * b + c * d", "a* (b) + \nc*d");
}
- @Test
public void testEquality() throws ParseException {
assertEquals(new RankingExpression("if ( attribute(foo)==\"BAR\",log(attribute(popularity)+5),log(fieldMatch(title).proximity)*fieldMatch(title).completeness)"),
new RankingExpression("if(attribute(foo)==\"BAR\", log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)"));
@@ -235,7 +219,6 @@ public class RankingExpressionTestCase {
new RankingExpression("if(attribute(foo)==\"BAR\", log(attribute(popularity)+5),log(fieldMatch(title).earliness) * fieldMatch(title).completeness)")));
}
- @Test
public void testSetMembershipConditions() throws ParseException {
assertEquals(new RankingExpression("if ( attribute(foo) in [\"FOO\", \"BAR\"],log(attribute(popularity)+5),log(fieldMatch(title).proximity)*fieldMatch(title).completeness)"),
new RankingExpression("if(attribute(foo) in [\"FOO\",\"BAR\"], log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)"));
@@ -248,7 +231,6 @@ public class RankingExpressionTestCase {
assertEquals(new RankingExpression("if (GENDER$ in [-1.0, 1.0], 1, 0)"), new RankingExpression("if (GENDER$ in [-1.0, 1.0], 1, 0)"));
}
- @Test
public void testComments() throws ParseException {
assertEquals(new RankingExpression("if ( attribute(foo) in [\"FOO\", \"BAR\"],\n" +
"# a comment\n" +
@@ -259,7 +241,6 @@ public class RankingExpressionTestCase {
new RankingExpression("if(attribute(foo) in [\"FOO\",\"BAR\"], log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)"));
}
- @Test
public void testIsNan() throws ParseException {
String strExpr = "if (isNan(attribute(foo)) == 1.0, 1.0, attribute(foo))";
RankingExpression expr = new RankingExpression(strExpr);
@@ -274,59 +255,27 @@ public class RankingExpressionTestCase {
assertEquals(expected, new RankingExpression(expression).toString());
}
- /** Test serialization with no macros */
- private void assertSerialization(String expectedSerialization, String expressionString) {
- String serializedExpression;
- try {
- RankingExpression expression = new RankingExpression(expressionString);
- // No macros -> expect one rank property
- serializedExpression = expression.getRankProperties(Collections.emptyList()).values().iterator().next();
- assertEquals(expectedSerialization, serializedExpression);
- }
- catch (ParseException e) {
- throw new IllegalArgumentException(e);
- }
-
- try {
- // No macros -> output should be parseable to a ranking expression
- // (but not the same one due to primitivization)
- RankingExpression reparsedExpression = new RankingExpression(serializedExpression);
- // Serializing the primitivized expression should yield the same expression again
- String reserializedExpression =
- reparsedExpression.getRankProperties(Collections.emptyList()).values().iterator().next();
- assertEquals(expectedSerialization, reserializedExpression);
- }
- catch (ParseException e) {
- throw new IllegalArgumentException("Could not parse the serialized expression", e);
- }
- }
-
- private void assertSerialization(List<String> expectedSerialization, String expressionString,
- List<ExpressionFunction> macros) {
- assertSerialization(expectedSerialization, expressionString, macros, false);
- }
- private void assertSerialization(List<String> expectedSerialization, String expressionString,
- List<ExpressionFunction> macros, boolean print) {
- try {
- if (print)
- System.out.println("Parsing expression '" + expressionString + "'.");
-
- RankingExpression expression = new RankingExpression(expressionString);
- Map<String, String> rankProperties = expression.getRankProperties(macros);
- if (print) {
- for (String key : rankProperties.keySet())
- System.out.println("Property '" + key + "': " + rankProperties.get(key));
+ private void assertScript(String expression, List<ExpressionFunction> macros, List<String> expectedScripts)
+ throws ParseException {
+ boolean print = false;
+ if (print)
+ System.out.println("Parsing expression '" + expression + "'.");
+
+ RankingExpression exp = new RankingExpression(expression);
+ Map<String, String> scripts = exp.getRankProperties(macros);
+ if (print) {
+ for (String key : scripts.keySet()) {
+ System.out.println("Script '" + key + "': " + scripts.get(key));
}
- for (int i = 0; i < expectedSerialization.size();) {
- String val = expectedSerialization.get(i++);
- assertTrue("Properties contains " + val, rankProperties.containsValue(val));
- }
- if (print)
- System.out.println("");
}
- catch (ParseException e) {
- throw new IllegalArgumentException(e);
+
+ for (Map.Entry<String, String> m : scripts.entrySet())
+ System.out.println(m);
+ for (int i = 0; i < expectedScripts.size();) {
+ String val = expectedScripts.get(i++);
+ assertTrue("Script contains " + val, scripts.containsValue(val));
}
+ if (print)
+ System.out.println("");
}
-
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 93800e2c246..b67a423181d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -20,7 +20,7 @@ import java.util.Set;
*/
public class EvaluationTestCase extends junit.framework.TestCase {
- private MapContext defaultContext;
+ private Context defaultContext;
@Override
protected void setUp() {
@@ -100,180 +100,201 @@ public class EvaluationTestCase extends junit.framework.TestCase {
@Test
public void testTensorEvaluation() {
- assertEvaluates("{}", "tensor0", "{}");
+ assertEvaluates("{}", "{}"); // empty
+ assertEvaluates("( {{x:-}:1} * {} )", "( {{x:-}:1} * {} )"); // empty with dimensions
- // tensor map
+ // sum(tensor)
+ assertEvaluates(5.0, "sum({{}:5.0})");
+ assertEvaluates(-5.0, "sum({{}:-5.0})");
+ assertEvaluates(12.5, "sum({ {d1:l1}:5.5, {d2:l2}:7.0 })");
+ assertEvaluates(0.0, "sum({ {d1:l1}:5.0, {d2:l2}:7.0, {}:-12.0})");
+
+ // scalar functions on tensors
assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }",
- "map(tensor0, f(x) (log10(x)))", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
- assertEvaluates("{ {}:4, {d1:l1}:9, {d1:l1,d2:l1 }:16 }",
- "map(tensor0, f(x) (x * x))", "{ {}:2, {d1:l1}:3, {d1:l1,d2:l1}:4 }");
- // -- tensor map composites
- assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }",
- "log10(tensor0)", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ "log10({ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 })");
+ assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }",
+ "5 * { {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }",
+ "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 } + 3");
+ assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }",
+ "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 } / 10");
assertEvaluates("{ {}:-10, {d1:l1}:-100, {d1:l1,d2:l1 }:-1000 }",
- "- tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ "- { {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
assertEvaluates("{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1 }:0 }",
- "min(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }");
+ "min({ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }, 0)");
assertEvaluates("{ {}:0, {d1:l1}:0, {d1:l1,d2:l1 }:10 }",
- "max(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }");
- // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence)
- assertEvaluates("{ {x:1}:1, {x:2}:2 }", "abs(tensor0)", "{ {x:1}:1, {x:2}:-2 }");
- assertEvaluates("{ {x:1}:0, {x:2}:0 }", "acos(tensor0)", "{ {x:1}:1, {x:2}:1 }");
- assertEvaluates("{ {x:1}:0, {x:2}:0 }", "asin(tensor0)", "{ {x:1}:0, {x:2}:0 }");
- assertEvaluates("{ {x:1}:0, {x:2}:0 }", "atan(tensor0)", "{ {x:1}:0, {x:2}:0 }");
- assertEvaluates("{ {x:1}:1, {x:2}:2 }", "ceil(tensor0)", "{ {x:1}:1, {x:2}:2 }");
- assertEvaluates("{ {x:1}:1, {x:2}:1 }", "cos(tensor0)", "{ {x:1}:0, {x:2}:0 }");
- assertEvaluates("{ {x:1}:1, {x:2}:1 }", "cosh(tensor0)", "{ {x:1}:0, {x:2}:0 }");
- assertEvaluates("{ {x:1}:1, {x:2}:2 }", "elu(tensor0)", "{ {x:1}:1, {x:2}:2 }");
- assertEvaluates("{ {x:1}:1, {x:2}:1 }", "exp(tensor0)", "{ {x:1}:0, {x:2}:0 }");
- assertEvaluates("{ {x:1}:1, {x:2}:2 }", "fabs(tensor0)", "{ {x:1}:1, {x:2}:2 }");
- assertEvaluates("{ {x:1}:1, {x:2}:2 }", "floor(tensor0)", "{ {x:1}:1, {x:2}:2 }");
- assertEvaluates("{ {x:1}:0, {x:2}:0 }", "isNan(tensor0)", "{ {x:1}:1, {x:2}:2 }");
- assertEvaluates("{ {x:1}:0, {x:2}:0 }", "log(tensor0)", "{ {x:1}:1, {x:2}:1 }");
- assertEvaluates("{ {x:1}:0, {x:2}:1 }", "log10(tensor0)", "{ {x:1}:1, {x:2}:10 }");
- assertEvaluates("{ {x:1}:0, {x:2}:2 }", "mod(tensor0, 3)", "{ {x:1}:3, {x:2}:8 }");
- assertEvaluates("{ {x:1}:1, {x:2}:8 }", "pow(tensor0, 3)", "{ {x:1}:1, {x:2}:2 }");
- assertEvaluates("{ {x:1}:1, {x:2}:2 }", "relu(tensor0)", "{ {x:1}:1, {x:2}:2 }");
- assertEvaluates("{ {x:1}:1, {x:2}:2 }", "round(tensor0)", "{ {x:1}:1, {x:2}:1.8 }");
- assertEvaluates("{ {x:1}:0.5, {x:2}:0.5 }", "sigmoid(tensor0)","{ {x:1}:0, {x:2}:0 }");
- assertEvaluates("{ {x:1}:1, {x:2}:-1 }", "sign(tensor0)", "{ {x:1}:3, {x:2}:-5 }");
- assertEvaluates("{ {x:1}:0, {x:2}:0 }", "sin(tensor0)", "{ {x:1}:0, {x:2}:0 }");
- assertEvaluates("{ {x:1}:0, {x:2}:0 }", "sinh(tensor0)", "{ {x:1}:0, {x:2}:0 }");
- assertEvaluates("{ {x:1}:1, {x:2}:4 }", "square(tensor0)", "{ {x:1}:1, {x:2}:2 }");
- assertEvaluates("{ {x:1}:1, {x:2}:3 }", "sqrt(tensor0)", "{ {x:1}:1, {x:2}:9 }");
- assertEvaluates("{ {x:1}:0, {x:2}:0 }", "tan(tensor0)", "{ {x:1}:0, {x:2}:0 }");
- assertEvaluates("{ {x:1}:0, {x:2}:0 }", "tanh(tensor0)", "{ {x:1}:0, {x:2}:0 }");
-
- // tensor reduce
- // -- reduce 2 dimensions
- assertEvaluates("{ {}:4 }",
- "reduce(tensor0, avg, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- assertEvaluates("{ {}:4 }",
- "reduce(tensor0, count, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- assertEvaluates("{ {}:105 }",
- "reduce(tensor0, prod, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- assertEvaluates("{ {}:16 }",
- "reduce(tensor0, sum, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- assertEvaluates("{ {}:7 }",
- "reduce(tensor0, max, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- assertEvaluates("{ {}:1 }",
- "reduce(tensor0, min, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- // -- reduce 2 by specifying no arguments
- assertEvaluates("{ {}:4 }",
- "reduce(tensor0, avg)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- // -- reduce 1 dimension
- assertEvaluates("{ {y:1}:2, {y:2}:6 }",
- "reduce(tensor0, avg, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- assertEvaluates("{ {y:1}:2, {y:2}:2 }",
- "reduce(tensor0, count, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- assertEvaluates("{ {y:1}:3, {y:2}:35 }",
- "reduce(tensor0, prod, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- assertEvaluates("{ {y:1}:4, {y:2}:12 }",
- "reduce(tensor0, sum, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- assertEvaluates("{ {y:1}:3, {y:2}:7 }",
- "reduce(tensor0, max, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- assertEvaluates("{ {y:1}:1, {y:2}:5 }",
- "reduce(tensor0, min, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- // -- reduce composites
- assertEvaluates("{ {}: 5 }", "sum(tensor0)", "5.0");
- assertEvaluates("{ {}:-5 }", "sum(tensor0)", "-5.0");
- assertEvaluates("{ {}:12.5 }", "sum(tensor0)", "{ {d1:l1}:5.5, {d2:l2}:7.0 }");
- assertEvaluates("{ {}: 0 }", "sum(tensor0)", "{ {d1:l1}:5.0, {d2:l2}:7.0, {}:-12.0}");
- assertEvaluates("{ {y:1}:4, {y:2}:12.0 }",
- "sum(tensor0, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- assertEvaluates("{ {x:1}:6, {x:2}:10.0 }",
- "sum(tensor0, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- assertEvaluates("{ {}:16 }",
- "sum(tensor0, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
-
- // tensor join
- assertEvaluates("{ {x:1,y:1}:15, {x:2,y:1}:35 }", "join(tensor0, tensor1, f(x,y) (x*y))", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
- // -- join composites
- assertEvaluates("{ }", "tensor0 * tensor0", "{}");
- assertEvaluates("tensor(x{},y{},z{}):{}", "( tensor0 * tensor1 ) * ( tensor2 * tensor1 )",
- "{{x:-}:1}", "{}", "{{y:-,z:-}:1}"); // empty dimensions are preserved
- assertEvaluates("tensor(x{}):{}",
- "tensor0 * tensor1", "{ {x:1}:3 }", "{ {x:2}:5 }");
+ "max({ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }, 0)");
+ assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + {{h:1}:1.0,{h:2}:1.0}");
+
+ // sum(tensor, dimension)
+ assertEvaluates("{ {y:1}:4.0, {y:2}:12.0 }",
+ "sum({ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }, x)");
+ assertEvaluates("{ {x:1}:6.0, {x:2}:10.0 }",
+ "sum({ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }, y)");
+
+ // tensor sum
+ assertEvaluates("{ }", "{} + {}");
+ assertEvaluates("{ {x:1}:3, {x:2}:5 }",
+ "{ {x:1}:3 } + { {x:2}:5 }");
+ assertEvaluates("{ {x:1}:8 }",
+ "{ {x:1}:3 } + { {x:1}:5 }");
+ assertEvaluates("{ {x:1}:3, {y:1}:5 }",
+ "{ {x:1}:3 } + { {y:1}:5 }");
+ assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }",
+ "{ {x:1}:3, {x:2}:7 } + { {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }",
+ "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } + { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
+ assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
+ "{ {x:1}:5, {x:1,y:1}:1 } + { {y:1,z:1}:7 }");
+ assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }",
+ "{ {x:1}:5, {x:1,y:1}:1 } + { {z:1}:11, {y:1,z:1}:7 }");
+ assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
+ "{ {}:5, {x:1,y:1}:1 } + { {y:1,z:1}:7 }");
+ assertEvaluates("{ {}:16, {x:1,y:1}:1, {y:1,z:1}:7 }",
+ "{ {}:5, {x:1,y:1}:1 } + { {}:11, {y:1,z:1}:7 }");
+
+ // tensor difference
+ assertEvaluates("{ }", "{} - {}");
+ assertEvaluates("{ {x:1}:3, {x:2}:-5 }",
+ "{ {x:1}:3 } - { {x:2}:5 }");
+ assertEvaluates("{ {x:1}:-2 }",
+ "{ {x:1}:3 } - { {x:1}:5 }");
+ assertEvaluates("{ {x:1}:3, {y:1}:-5 }",
+ "{ {x:1}:3 } - { {y:1}:5 }");
+ assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:-5 }",
+ "{ {x:1}:3, {x:2}:7 } - { {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:-7, {y:2,z:1}:-11, {y:1,z:2}:-13 }",
+ "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } - { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
+ assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:-7 }",
+ "{ {x:1}:5, {x:1,y:1}:1 } - { {y:1,z:1}:7 }");
+ assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:-11, {y:1,z:1}:-7 }",
+ "{ {x:1}:5, {x:1,y:1}:1 } - { {z:1}:11, {y:1,z:1}:7 }");
+ assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:-7 }",
+ "{ {}:5, {x:1,y:1}:1 } - { {y:1,z:1}:7 }");
+ assertEvaluates("{ {}:-6, {x:1,y:1}:1, {y:1,z:1}:-7 }",
+ "{ {}:5, {x:1,y:1}:1 } - { {}:11, {y:1,z:1}:7 }");
+ assertEvaluates("{ {x:1}:0 }",
+ "{ {x:1}:3 } - { {x:1}:3 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:1 }",
+ "{ {x:1}:3, {x:2}:1 } - { {x:1}:3 }");
+
+ // tensor product
+ assertEvaluates("{ }", "{} * {}");
+ assertEvaluates("( {{x:-,y:-,z:-}:1}*{} )", "( {{x:-}:1} * {} ) * ( {{y:-,z:-}:1} * {} )"); // empty dimensions are preserved
+ assertEvaluates("( {{x:-}:1} * {} )",
+ "{ {x:1}:3 } * { {x:2}:5 }");
assertEvaluates("{ {x:1}:15 }",
- "tensor0 * tensor1", "{ {x:1}:3 }", "{ {x:1}:5 }");
+ "{ {x:1}:3 } * { {x:1}:5 }");
assertEvaluates("{ {x:1,y:1}:15 }",
- "tensor0 * tensor1", "{ {x:1}:3 }", "{ {y:1}:5 }");
+ "{ {x:1}:3 } * { {y:1}:5 }");
assertEvaluates("{ {x:1,y:1}:15, {x:2,y:1}:35 }",
- "tensor0 * tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:8, {x:2,y:1}:12 }",
- "tensor0 + tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:-2, {x:2,y:1}:2 }",
- "tensor0 - tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:4 }",
- "tensor0 / tensor1", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }");
- assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:7 }",
- "max(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:3, {x:2,y:1}:5 }",
- "min(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ "{ {x:1}:3, {x:2}:7 } * { {y:1}:5 }");
assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,y:1,z:2}:13, {x:2,y:1,z:1}:21, {x:2,y:1,z:2}:39, {x:1,y:2,z:1}:55 }",
- "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
- assertEvaluates("{ {x:1,y:2,z:1}:35, {x:1,y:2,z:2}:65 }",
- "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:2,z:1}:7, {y:3,z:1}:11, {y:2,z:2}:13 }");
- assertEvaluates("{{x:1,y:1}:0.0}","tensor1 * tensor2 * tensor3", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1,y:1}:1 }", "{ {x:1,y:1}:1 }");
- assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }",
- "5 * tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
- assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }",
- "tensor0 + 3","{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
- assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }",
- "tensor0 / 10", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
- assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + tensor0", "{ {h:1}:1.0,{h:2}:1.0 }");
- assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:0 }",
- "atan2(tensor0, tensor1)", "{ {x:1}:0, {x:2}:0 }", "{ {y:1}:1 }");
- assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }",
- "tensor0 > tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
- "tensor0 < tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }",
- "tensor0 >= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
- "tensor0 <= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }",
- "tensor0 == tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
- "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }");
- // TODO
- // argmax
- // argmin
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
- "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }");
-
- // tensor rename
- assertEvaluates("{ {newX:1,y:2}:3 }", "rename(tensor0, x, newX)", "{ {x:1,y:2}:3.0 }");
- assertEvaluates("{ {x:2,y:1}:3 }", "rename(tensor0, (x, y), (y, x))", "{ {x:1,y:2}:3.0 }");
-
- // tensor generate - TODO
- // assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0, {x:2,y:2}:1, {x:1,y:2}:0 }", "tensor(x[2],y[2])(x==y)");
- // range
- // diag
- // fill
- // random
-
- // composite functions
- assertEvaluates("{ {x:1}:0.25, {x:2}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:1}:1, {x:2}:3 }");
- assertEvaluates("{ {x:1}:0.31622776601683794, {x:2}:0.9486832980505138 }", "l2_normalize(tensor0, x)", "{ {x:1}:1, {x:2}:3 }");
- assertEvaluates("{ {y:1}:81.0 }", "matmul(tensor0, tensor1, x)", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }");
- assertEvaluates("{ {x:1}:0.5, {x:2}:0.5 }", "softmax(tensor0, x)", "{ {x:1}:1, {x:2}:1 }", "{ {y:1}:1 }");
- assertEvaluates("{ {x:1,y:1}:88.0 }", "xw_plus_b(tensor0, tensor1, tensor2, x)", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }", "{ {x:1}:7 }");
-
- // expressions combining functions
- assertEvaluates(String.valueOf(7.5 + 45 + 1.7),
- "sum( " + // model computation:
- " tensor0 * tensor1 * tensor2 " + // - feature combinations
- " * tensor3" + // - model weights application
- ") + 1.7",
- "{ {x:1}:1, {x:2}:2 }", "{ {y:1}:3, {y:2}:4 }", "{ {z:1}:5 }",
- "{ {x:1,y:1,z:1}:0.5, {x:2,y:1,z:1}:1.5, {x:1,y:1,z:2}:4.5 }");
- assertEvaluates("1.0", "sum(tensor0 * tensor1 + 0.5)", "{ {x:1}:0, {x:2}:0 }", "{ {x:1}:1, {x:2}:1 }");
- assertEvaluates("0.0", "sum(tensor0 * tensor1 + 0.5)", "{}", "{ {x:1}:1, {x:2}:1 }");
+ "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } * { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
+ assertEvaluates("{ {x:1,y:1,z:1}:7 }",
+ "{ {x:1}:5, {x:1,y:1}:1 } * { {y:1,z:1}:7 }");
+ assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,z:1}:55 }",
+ "{ {x:1}:5, {x:1,y:1}:1 } * { {z:1}:11, {y:1,z:1}:7 }");
+ assertEvaluates("{ {x:1,y:1,z:1}:7 }",
+ "{ {}:5, {x:1,y:1}:1 } * { {y:1,z:1}:7 }");
+ assertEvaluates("{ {x:1,y:1,z:1}:7, {}:55 }",
+ "{ {}:5, {x:1,y:1}:1 } * { {}:11, {y:1,z:1}:7 }");
+
+ // match product
+ assertEvaluates("{ }", "match({}, {})");
+ assertEvaluates("( {{x:-}:1} * {} )",
+ "match({ {x:1}:3 }, { {x:2}:5 })");
+ assertEvaluates("{ {x:1}:15 }",
+ "match({ {x:1}:3 }, { {x:1}:5 })");
+ assertEvaluates("( {{x:-,y:-}:1} * {} )",
+ "match({ {x:1}:3 }, { {y:1}:5 })");
+ assertEvaluates("( {{x:-,y:-}:1} * {} )",
+ "match({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })");
+ assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )",
+ "match({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })");
+ assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )",
+ "match({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
+ assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )",
+ "match({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })");
+ assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )",
+ "match({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
+ assertEvaluates("( {{x:-,y:-,z:-}:1} * { {}:55 } )",
+ "match({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })");
+ assertEvaluates("( {{z:-}:1} * { {x:1}:15, {x:1,y:1}:7 } )",
+ "match({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })");
+
+ // min
+ assertEvaluates("{ {x:1}:3, {x:2}:5 }",
+ "min({ {x:1}:3 }, { {x:2}:5 })");
+ assertEvaluates("{ {x:1}:3 }",
+ "min({ {x:1}:3 }, { {x:1}:5 })");
+ assertEvaluates("{ {x:1}:3, {y:1}:5 }",
+ "min({ {x:1}:3 }, { {y:1}:5 })");
+ assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }",
+ "min({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }",
+ "min({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })");
+ assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
+ "min({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
+ assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }",
+ "min({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })");
+ assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
+ "min({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
+ assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
+ "min({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })");
+ assertEvaluates("{ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }",
+ "min({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })");
+
+ // max
+ assertEvaluates("{ {x:1}:3, {x:2}:5 }",
+ "max({ {x:1}:3 }, { {x:2}:5 })");
+ assertEvaluates("{ {x:1}:5 }",
+ "max({ {x:1}:3 }, { {x:1}:5 })");
+ assertEvaluates("{ {x:1}:3, {y:1}:5 }",
+ "max({ {x:1}:3 }, { {y:1}:5 })");
+ assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }",
+ "max({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }",
+ "max({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })");
+ assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
+ "max({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
+ assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }",
+ "max({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })");
+ assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
+ "max({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
+ assertEvaluates("{ {}:11, {x:1,y:1}:1, {y:1,z:1}:7 }",
+ "max({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })");
+ assertEvaluates("{ {}:5, {x:1}:5, {x:2}:4, {x:1,y:1}:7, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }",
+ "max({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })");
+
+ // Combined
+ assertEvaluates(7.5 + 45 + 1.7,
+ "sum( " + // model computation
+ " match( " + // model weight application
+ " { {x:1}:1, {x:2}:2 } * { {y:1}:3, {y:2}:4 } * { {z:1}:5 }, " + // feature combinations
+ " { {x:1,y:1,z:1}:0.5, {x:2,y:1,z:1}:1.5, {x:1,y:1,z:2}:4.5 }" + // model weights
+ "))+1.7");
+
+ // undefined is not the same as 0
+ assertEvaluates(1.0, "sum({ {x:1}:0, {x:2}:0 } * { {x:1}:1, {x:2}:1 } + 0.5)");
+ assertEvaluates(0.0, "sum({ } * { {x:1}:1, {x:2}:1 } + 0.5)");
// tensor result dimensions are given from argument dimensions, not the resulting values
- assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:1}:1 }", "{ {x:2}:1 }");
- assertEvaluates("tensor(x{},y{}):{{x:1}:1.0}", "tensor0 * tensor1", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1}:1 }");
+ assertEvaluates("x", "( {{x:-}:1.0} * {} )", "{ {x:1}:1 } * { {x:2}:1 }");
+ assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 }");
+
+ // demonstration of where this produces different results: { {x:1}:1 } with 2 dimensions ...
+ assertEvaluates("x, y", "( {{x:-,y:-}:1.0} * {} )","{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 } * { {x:1,y:1}:1 }");
+ // ... vs { {x:1}:1 } with only one dimension
+ assertEvaluates("x, y", "{{x:1,y:1}:1.0}", "{ {x:1}:1 } * { {x:1,y:1}:1 }");
+
+ // check that dimensions are preserved through other operations
+ String d2 = "{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 }"; // creates a 2d tensor with only an 1d value
+ assertEvaluates("x, y", "( {{x:-,y:-}:1.0} * {} )", "match(" + d2 + ", {})");
+ assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", d2 + " - {}");
+ assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", d2 + " + {}");
+ assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "min(1.5, " + d2 +")");
+ assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "max({{x:1}:0}, " + d2 +")");
}
public void testProgrammaticBuildingAndPrecedence() {
@@ -295,16 +316,12 @@ public class EvaluationTestCase extends junit.framework.TestCase {
assertEvaluates(77, "average(\"2*3\",\"pow(2,3)\")+average(\"2*3\",\"pow(2,3)\").timesten", context);
}
- private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) {
- MapContext context = defaultContext.thawedCopy();
- int argumentIndex = 0;
- for (String tensorArgument : tensorArguments)
- context.put("tensor" + (argumentIndex++), new TensorValue(MapTensor.from(tensorArgument)));
- return assertEvaluates(new TensorValue(MapTensor.from(expectedTensor)), expressionString, context);
+ private RankingExpression assertEvaluates(String tensorValue, String expressionString) {
+ return assertEvaluates(new TensorValue(MapTensor.from(tensorValue)), expressionString, defaultContext);
}
/** Validate also that the dimension of the resulting tensors are as expected */
- private RankingExpression assertEvaluates_old(String tensorDimensions, String resultTensor, String expressionString) {
+ private RankingExpression assertEvaluates(String tensorDimensions, String resultTensor, String expressionString) {
RankingExpression expression = assertEvaluates(new TensorValue(MapTensor.from(resultTensor)), expressionString, defaultContext);
TensorValue value = (TensorValue)expression.evaluate(defaultContext);
assertEquals(toSet(tensorDimensions), value.asTensor().dimensions());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java
index 08fdc9917a4..95c4402a612 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java
@@ -17,25 +17,22 @@ public class NeuralNetEvaluationTestCase {
/** "XOR" neural network, separate expression per layer */
@Test
public void testPerLayerExpression() {
- String input = "{ {x:1}:0, {x:2}:1 }"; // tensor0
- String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }"; // tensor1
- String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }"; // tensor2
- String firstLayerInput = "sum(tensor0 * tensor1, x) + tensor2";
+ String input = "{ {x:1}:0, {x:2}:1 }";
+
+ String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }";
+ String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }";
+ String firstLayerInput = "sum(" + input + "*" + firstLayerWeights + ", x) + " + firstLayerBias;
String firstLayerOutput = "min(1.0, max(0.0, 0.5 + " + firstLayerInput + "))"; // non-linearity, "poor man's sigmoid"
- assertEvaluates("{ {h:1}:1.0, {h:2}:0.0} }", firstLayerOutput, input, firstLayerWeights, firstLayerBias);
- String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }"; // tensor3
- String secondLayerBias = "{ {y:1}:-0.5 }"; // tensor4
- String secondLayerInput = "sum(" + firstLayerOutput + "* tensor3, h) + tensor4";
+ assertEvaluates("{ {h:1}:1.0, {h:2}:0.0} }", firstLayerOutput);
+ String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }";
+ String secondLayerBias = "{ {y:1}:-0.5 }";
+ String secondLayerInput = "sum(" + firstLayerOutput + "*" + secondLayerWeights + ", h) + " + secondLayerBias;
String secondLayerOutput = "min(1.0, max(0.0, 0.5 + " + secondLayerInput + "))"; // non-linearity, "poor man's sigmoid"
- assertEvaluates("{ {y:1}:1 }", secondLayerOutput, input, firstLayerWeights, firstLayerBias, secondLayerWeights, secondLayerBias);
+ assertEvaluates("{ {y:1}:1 }", secondLayerOutput);
}
- private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) {
- MapContext context = new MapContext();
- int argumentIndex = 0;
- for (String tensorArgument : tensorArguments)
- context.put("tensor" + (argumentIndex++), new TensorValue(MapTensor.from(tensorArgument)));
- return assertEvaluates(new TensorValue(MapTensor.from(expectedTensor)), expressionString, context);
+ private RankingExpression assertEvaluates(String tensorValue, String expressionString) {
+ return assertEvaluates(new TensorValue(MapTensor.from(tensorValue)), expressionString, new MapContext());
}
private RankingExpression assertEvaluates(Value value, String expressionString, Context context) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
index 61b230ab390..9d94ec0bc99 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
@@ -69,4 +69,12 @@ public class SimplifierTestCase {
assertEquals("a + (b + c) / 100000000.0", transformed.toString());
}
+ @Test
+ public void testSimplificationWithTensorConstants() throws ParseException {
+ new Simplifier().transform(new RankingExpression(
+ "sum(sum((tensorFromWeightedSet(query(wset_query),x)+" +
+ " tensorFromWeightedSet(attribute(wset),x)) * " +
+ " {{x:0,y:0}:54, {x:0,y:1} :69, {x:1,y:0} :72, {x:1,y:1} :93},x))"));
+ }
+
}
diff --git a/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java b/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java
index a54f1971d21..d70b55c66a2 100644
--- a/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java
+++ b/vdslib/src/main/java/com/yahoo/vdslib/state/ClusterState.java
@@ -20,7 +20,7 @@ public class ClusterState implements Cloneable {
private Map<Node, NodeState> nodeStates = new TreeMap<>();
// TODO: Change to one count for distributor and one for storage, rather than an array
- // TODO: RenameFunction, this is not the highest node count but the highest index
+ // TODO: Rename, this is not the highest node count but the highest index
private ArrayList<Integer> nodeCount = new ArrayList<>(2);
private String description = "";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java
index 4fd743e4724..3bda4159ca6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MapTensor.java
@@ -21,8 +21,6 @@ import java.util.function.UnaryOperator;
@Beta
public class MapTensor implements Tensor {
- // TODO: Enforce that all addresses are dense (and then avoid storing keys in TensorAddress)
-
private final ImmutableSet<String> dimensions;
private final ImmutableMap<TensorAddress, Double> cells;
@@ -33,7 +31,7 @@ public class MapTensor implements Tensor {
}
/** Creates a sparse tensor */
- public MapTensor(Set<String> dimensions, Map<TensorAddress, Double> cells) {
+ MapTensor(Set<String> dimensions, Map<TensorAddress, Double> cells) {
ensureValidDimensions(cells, dimensions);
this.dimensions = ImmutableSet.copyOf(dimensions);
this.cells = ImmutableMap.copyOf(cells);
@@ -54,41 +52,24 @@ public class MapTensor implements Tensor {
*/
public static MapTensor from(String s) {
s = s.trim();
- try {
- if (s.startsWith("tensor("))
- return fromTypedTensor(s);
- else if (s.startsWith("{"))
- return fromUntypedTensor(s, Collections.emptySet());
- else
- return fromNumber(Double.parseDouble(s));
- }
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" + s + "'");
- }
+ if ( s.startsWith("("))
+ return fromTensorWithEmptyDimensions(s);
+ else if ( s.startsWith("{"))
+ return fromTensor(s, Collections.emptySet());
+ else
+ throw new IllegalArgumentException("Excepted a string starting by { or (, got '" + s + "'");
}
- private static MapTensor fromTypedTensor(String s) {
- if ( ! s.startsWith("tensor(")) throw tensorFormatException(s);
- s = s.substring("tensor(".length());
- int typeSpecEnd = s.indexOf(")");
- if (typeSpecEnd < 0 ) throw tensorFormatException(s);
- String typeSpec = s.substring(0, typeSpecEnd);
-
- Set<String> dimensions = new HashSet<>();
- for (String dimensionSpec : typeSpec.split(",")) {
- dimensionSpec = dimensionSpec.trim();
- if ( ! dimensionSpec.endsWith("{}"))
- throw new IllegalArgumentException("Only mapped dimensions ({}) are supported, got '" + dimensionSpec + "'");
- dimensions.add(dimensionSpec.substring(0, dimensionSpec.length() - 2));
- }
-
- s = s.substring(typeSpec.length() + 1);
- if ( ! s.startsWith(":")) throw tensorFormatException(s);
+ private static MapTensor fromTensorWithEmptyDimensions(String s) {
s = s.substring(1).trim();
- return fromUntypedTensor(s, dimensions);
+ int multiplier = s.indexOf("*");
+ if (multiplier < 0 || ! s.endsWith(")"))
+ throw new IllegalArgumentException("Expected a tensor on the form ({dimension:-,...}*{{cells}}), got '" + s + "'");
+ MapTensor dimensionTensor = fromTensor(s.substring(0, multiplier).trim(), Collections.emptySet());
+ return fromTensor(s.substring(multiplier + 1, s.length() - 1), dimensionTensor.dimensions());
}
- private static MapTensor fromUntypedTensor(String s, Set<String> additionalDimensions) {
+ private static MapTensor fromTensor(String s, Set<String> additionalDimensions) {
s = s.trim().substring(1).trim();
ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
while (s.length() > 1) {
@@ -113,16 +94,6 @@ public class MapTensor implements Tensor {
dimensions.addAll(additionalDimensions);
return new MapTensor(dimensions, cellMap);
}
-
- private static MapTensor fromNumber(double number) {
- ImmutableMap.Builder<TensorAddress, Double> singleCell = new ImmutableMap.Builder<>();
- singleCell.put(TensorAddress.empty, number);
- return new MapTensor(ImmutableSet.of(), singleCell.build());
- }
-
- private static IllegalArgumentException tensorFormatException(String s) {
- return new IllegalArgumentException("Expected a tensor on the form tensor(dimensionspec):content, but got '" + s + "'");
- }
private static Double asDouble(TensorAddress address, String s) {
try {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java b/vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java
new file mode 100644
index 00000000000..074742acee1
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MatchProduct.java
@@ -0,0 +1,33 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import com.google.common.collect.ImmutableMap;
+
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Computes a <i>match product</i>, see {@link Tensor#match}
+ *
+ * @author bratseth
+ */
+class MatchProduct {
+
+ private final Set<String> dimensions;
+ private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
+
+ public MatchProduct(Tensor a, Tensor b) {
+ this.dimensions = TensorOperations.combineDimensions(a, b);
+ for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) {
+ Double sameValueInB = b.cells().get(aCell.getKey());
+ if (sameValueInB != null)
+ cells.put(aCell.getKey(), aCell.getValue() * sameValueInB);
+ }
+ }
+
+ /** Returns the result of taking this product */
+ public MapTensor result() {
+ return new MapTensor(dimensions, cells.build());
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 4b17f65ea21..41882738e89 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -2,25 +2,18 @@
package com.yahoo.tensor;
import com.google.common.annotations.Beta;
-import com.yahoo.tensor.functions.ConstantTensor;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.L1Normalize;
-import com.yahoo.tensor.functions.L2Normalize;
-import com.yahoo.tensor.functions.Matmul;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.Softmax;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleFunction;
import java.util.function.DoubleUnaryOperator;
-import java.util.function.Function;
+import java.util.function.UnaryOperator;
/**
* A multidimensional array which can be used in computations.
@@ -56,74 +49,128 @@ public interface Tensor {
/** Returns the value of a cell, or NaN if this cell does not exist/have no value */
double get(TensorAddress address);
- // ----------------- Primitive tensor functions
+ // ----------------- Level 0 functions
- default Tensor map(DoubleUnaryOperator mapper) {
- return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate();
+ default Tensor map(Tensor tensor, DoubleUnaryOperator mapper) {
+ throw new UnsupportedOperationException("Not implemented");
}
- /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */
- default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) {
- return new Reduce(new ConstantTensor(this), aggregator, dimensions).evaluate();
+ default Tensor reduce(Tensor tensor, String dimension,
+ DoubleBinaryOperator reductor, Optional<DoubleBinaryOperator> postTransformation) {
+ throw new UnsupportedOperationException("Not implemented");
}
- default Tensor join(Tensor argument, DoubleBinaryOperator combinator) {
- return new Join(new ConstantTensor(this), new ConstantTensor(argument), combinator).evaluate();
+ default Tensor join(Tensor tensorA, Tensor tensorB, DoubleBinaryOperator combinator) {
+ throw new UnsupportedOperationException("Not implemented");
}
- default Tensor rename(String fromDimension, String toDimension) {
- return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension),
- Collections.singletonList(toDimension)).evaluate();
+ // ----------------- Old stuff
+ /**
+ * Returns the <i>sparse tensor product</i> of this tensor and the argument tensor.
+ * This is the all-to-all combinations of cells in the argument tenors, except the combinations
+ * which have conflicting labels for the same dimension. The value of each combination is the product
+ * of the values of the two input cells. The dimensions of the tensor product is the set union of the
+ * dimensions of the argument tensors.
+ * <p>
+ * If there are no overlapping dimensions this is the regular tensor product.
+ * If the two tensors have exactly the same dimensions this is the Hadamard product.
+ * <p>
+ * The sparse tensor product is associative and commutative.
+ *
+ * @param argument the tensor to multiply by this
+ * @return the resulting tensor.
+ */
+ default Tensor multiply(Tensor argument) {
+ return new TensorProduct(this, argument).result();
}
- default Tensor rename(List<String> fromDimensions, List<String> toDimensions) {
- return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate();
- }
-
- static Tensor from(TensorType type, Function<List<Integer>, Double> valueSupplier) {
- return new Generate(type, valueSupplier).evaluate();
+ /**
+ * Returns the <i>match product</i> of two tensors.
+ * This returns a tensor which contains the <i>matching</i> cells in the two tensors, with their
+ * values multiplied.
+ * <p>
+ * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
+ * and have the value undefined for any non-shared dimension.
+ * <p>
+ * The dimensions of the resulting tensor is the set intersection of the two argument tensors.
+ * <p>
+ * If the two tensors have exactly the same dimensions, this is the Hadamard product.
+ */
+ default Tensor match(Tensor argument) {
+ return new MatchProduct(this, argument).result();
}
-
- // ----------------- Composite tensor functions which have a defined primitive mapping
-
- default Tensor l1Normalize(String dimension) {
- return new L1Normalize(new ConstantTensor(this), dimension).evaluate();
+
+ /**
+ * Returns a tensor which contains the cells of both argument tensors, where the value for
+ * any <i>matching</i> cell is the min of the two possible values.
+ * <p>
+ * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
+ * and have the value undefined for any non-shared dimension.
+ */
+ default Tensor min(Tensor argument) {
+ return new TensorMin(this, argument).result();
}
- default Tensor l2Normalize(String dimension) {
- return new L2Normalize(new ConstantTensor(this), dimension).evaluate();
+ /**
+ * Returns a tensor which contains the cells of both argument tensors, where the value for
+ * any <i>matching</i> cell is the max of the two possible values.
+ * <p>
+ * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
+ * and have the value undefined for any non-shared dimension.
+ */
+ default Tensor max(Tensor argument) {
+ return new TensorMax(this, argument).result();
}
- default Tensor matmul(Tensor argument, String dimension) {
- return new Matmul(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate();
+ /**
+ * Returns a tensor which contains the cells of both argument tensors, where the value for
+ * any <i>matching</i> cell is the sum of the two possible values.
+ * <p>
+ * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
+ * and have the value undefined for any non-shared dimension.
+ */
+ default Tensor add(Tensor argument) {
+ return new TensorSum(this, argument).result();
}
- default Tensor softmax(String dimension) {
- return new Softmax(new ConstantTensor(this), dimension).evaluate();
+ /**
+ * Returns a tensor which contains the cells of both argument tensors, where the value for
+ * any <i>matching</i> cell is the difference of the two possible values.
+ * <p>
+ * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
+ * and have the value undefined for any non-shared dimension.
+ */
+ default Tensor subtract(Tensor argument) {
+ return new TensorDifference(this, argument).result();
}
- // ----------------- Composite tensor functions mapped to primitives here on the fly
+ /**
+ * Returns a tensor with the same cells as this and the given function is applied to all its cell values.
+ *
+ * @param function the function to apply to all cells
+ * @return the tensor with the function applied to all the cells of this
+ */
+ default Tensor apply(UnaryOperator<Double> function) {
+ return new TensorFunction(this, function).result();
+ }
- default Tensor multiply(Tensor argument) { return join(argument, (a, b) -> (a * b )); }
- default Tensor add(Tensor argument) { return join(argument, (a, b) -> (a + b )); }
- default Tensor divide(Tensor argument) { return join(argument, (a, b) -> (a / b )); }
- default Tensor subtract(Tensor argument) { return join(argument, (a, b) -> (a - b )); }
- default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); }
- default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); }
- default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); }
- default Tensor larger(Tensor argument) { return join(argument, (a, b) -> ( a > b ? 1.0 : 0.0)); }
- default Tensor largerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a >= b ? 1.0 : 0.0)); }
- default Tensor smaller(Tensor argument) { return join(argument, (a, b) -> ( a < b ? 1.0 : 0.0)); }
- default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); }
- default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); }
- default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); }
+ /**
+ * Returns a tensor with the given dimension removed and cells which contains the sum of the values
+ * in the removed dimension.
+ */
+ default Tensor sum(String dimension) {
+ return new TensorDimensionSum(dimension, this).result();
+ }
- default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); }
- default Tensor count(List<String> dimensions) { return reduce(Reduce.Aggregator.count, dimensions); }
- default Tensor max(List<String> dimensions) { return reduce(Reduce.Aggregator.max, dimensions); }
- default Tensor min(List<String> dimensions) { return reduce(Reduce.Aggregator.min, dimensions); }
- default Tensor prod(List<String> dimensions) { return reduce(Reduce.Aggregator.prod, dimensions); }
- default Tensor sum(List<String> dimensions) { return reduce(Reduce.Aggregator.sum, dimensions); }
+ /**
+ * Returns the sum of all the cells of this tensor.
+ */
+ default double sum() {
+ double sum = 0;
+ for (Map.Entry<TensorAddress, Double> cell : cells().entrySet())
+ sum += cell.getValue();
+ return sum;
+ }
/**
* Returns true if the given tensor is mathematically equal to this:
@@ -179,28 +226,19 @@ public interface Tensor {
* @return the tensor on the standard string format
*/
static String toStandardString(Tensor tensor) {
- if ( emptyDimensions(tensor).size() > 0) // explicitly output type TODO: Always do that
- return typeToString(tensor) + ":" + contentToString(tensor);
+ Set<String> emptyDimensions = emptyDimensions(tensor);
+ if (emptyDimensions.size() > 0) // explicitly list empty dimensions
+ return "( " + unitTensorWithDimensions(emptyDimensions) + " * " + contentToString(tensor) + " )";
else
return contentToString(tensor);
}
- static String typeToString(Tensor tensor) {
- if (tensor.dimensions().isEmpty()) return "tensor()";
- StringBuilder b = new StringBuilder("tensor(");
- for (String dimension : tensor.dimensions())
- b.append(dimension).append("{},");
- b.setLength(b.length() -1);
- b.append(")");
- return b.toString();
- }
-
static String contentToString(Tensor tensor) {
- List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
- Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey());
+ List<Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
+ Collections.sort(cellEntries, Map.Entry.<TensorAddress, Double>comparingByKey());
StringBuilder b = new StringBuilder("{");
- for (java.util.Map.Entry<TensorAddress, Double> cell : cellEntries) {
+ for (Map.Entry<TensorAddress, Double> cell : cellEntries) {
b.append(cell.getKey()).append(":").append(cell.getValue());
b.append(",");
}
@@ -221,4 +259,8 @@ public interface Tensor {
return emptyDimensions;
}
+ static String unitTensorWithDimensions(Set<String> dimensions) {
+ return new MapTensor(Collections.singletonMap(TensorAddress.emptyWithDimensions(dimensions), 1.0)).toString();
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index e3c089de071..11c6a5f6685 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -8,11 +8,12 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
-import java.util.Optional;
import java.util.Set;
/**
* An immutable address to a tensor cell.
+ * This is sparse: Only dimensions which have a different label than "undefined" are
+ * explicitly included.
* <p>
* Tensor addresses are ordered by increasing size primarily, and by the natural order of the elements in sorted
* order secondarily.
@@ -65,6 +66,14 @@ public final class TensorAddress implements Comparable<TensorAddress> {
return TensorAddress.fromSorted(elements);
}
+ /** Creates an empty address with a set of dimensions */
+ public static TensorAddress emptyWithDimensions(Set<String> dimensions) {
+ List<Element> elements = new ArrayList<>(dimensions.size());
+ for (String dimension : dimensions)
+ elements.add(new Element(dimension, Element.undefinedLabel));
+ return TensorAddress.fromUnsorted(elements);
+ }
+
/** Returns an immutable list of the elements of this address in sorted order */
public List<Element> elements() { return elements; }
@@ -84,14 +93,6 @@ public final class TensorAddress implements Comparable<TensorAddress> {
return dimensions;
}
- /** Returns the label at the given dimension, or empty if this dimension is not present */
- public Optional<String> labelOfDimension(String dimension) {
- for (TensorAddress.Element element : elements)
- if (element.dimension().equals(dimension))
- return Optional.of(element.label());
- return Optional.empty();
- }
-
@Override
public int compareTo(TensorAddress other) {
int sizeComparison = Integer.compare(this.elements.size(), other.elements.size());
@@ -122,6 +123,7 @@ public final class TensorAddress implements Comparable<TensorAddress> {
public String toString() {
StringBuilder b = new StringBuilder("{");
for (TensorAddress.Element element : elements) {
+ //if (element.label() == Element.undefinedLabel) continue;
b.append(element.toString());
b.append(",");
}
@@ -134,13 +136,18 @@ public final class TensorAddress implements Comparable<TensorAddress> {
/** A tensor address element. Elements have the lexical order of the dimensions as natural order. */
public static class Element implements Comparable<Element> {
+ static final String undefinedLabel = "-";
+
private final String dimension;
private final String label;
private final int hashCode;
public Element(String dimension, String label) {
this.dimension = dimension;
- this.label = label;
+ if (label.equals(undefinedLabel))
+ this.label = undefinedLabel;
+ else
+ this.label = label;
this.hashCode = dimension.hashCode() + label.hashCode();
}
@@ -168,7 +175,9 @@ public final class TensorAddress implements Comparable<TensorAddress> {
@Override
public String toString() {
- return dimension + ":" + label;
+ StringBuilder b = new StringBuilder();
+ b.append(dimension).append(":").append(label);
+ return b.toString();
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java
new file mode 100644
index 00000000000..ceb003b1615
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java
@@ -0,0 +1,30 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Takes the difference between two tensors, see {@link Tensor#subtract}
+ *
+ * @author bratseth
+ */
+class TensorDifference {
+
+ private final Set<String> dimensions;
+ private final Map<TensorAddress, Double> cells = new HashMap<>();
+
+ public TensorDifference(Tensor a, Tensor b) {
+ this.dimensions = TensorOperations.combineDimensions(a, b);
+ cells.putAll(a.cells());
+ for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet())
+ cells.put(bCell.getKey(), a.cells().getOrDefault(bCell.getKey(), 0d) - bCell.getValue());
+ }
+
+ /** Returns the result of taking this sum */
+ public Tensor result() {
+ return new MapTensor(dimensions, cells);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java
new file mode 100644
index 00000000000..d15e5092476
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java
@@ -0,0 +1,35 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Takes the max of each cell of two tensors, see {@link Tensor#max}
+ *
+ * @author bratseth
+ */
+class TensorMax {
+
+ private final Set<String> dimensions;
+ private final Map<TensorAddress, Double> cells = new HashMap<>();
+
+ public TensorMax(Tensor a, Tensor b) {
+ dimensions = TensorOperations.combineDimensions(a, b);
+ cells.putAll(a.cells());
+ for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
+ Double aValue = a.cells().get(bCell.getKey());
+ if (aValue == null)
+ cells.put(bCell.getKey(), bCell.getValue());
+ else
+ cells.put(bCell.getKey(), Math.max(aValue, bCell.getValue()));
+ }
+ }
+
+ /** Returns the result of taking this sum */
+ public Tensor result() {
+ return new MapTensor(dimensions, cells);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java
new file mode 100644
index 00000000000..e389dea3883
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java
@@ -0,0 +1,33 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Takes the min of each cell of two tensors, see {@link Tensor#min}
+ *
+ * @author bratseth
+ */
+class TensorMin {
+
+ private final Set<String> dimensions;
+ private final Map<TensorAddress, Double> cells = new HashMap<>();
+
+ public TensorMin(Tensor a, Tensor b) {
+ dimensions = TensorOperations.combineDimensions(a, b);
+ cells.putAll(a.cells());
+ for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
+ Double aValue = a.cells().get(bCell.getKey());
+ if (aValue == null)
+ cells.put(bCell.getKey(), bCell.getValue());
+ else
+ cells.put(bCell.getKey(), Math.min(aValue, bCell.getValue()));
+ }
+ }
+
+ /** Returns the result of taking this sum */
+ public Tensor result() { return new MapTensor(dimensions, cells); }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java
new file mode 100644
index 00000000000..aca306b914c
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorOperations.java
@@ -0,0 +1,28 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import com.google.common.collect.ImmutableSet;
+
+import java.util.Set;
+
+/**
+ * Functions on tensors
+ *
+ * @author bratseth
+ */
+class TensorOperations {
+
+ /**
+ * A utility method which returns an ummutable set of the union of the dimensions
+ * of the two argument tensors.
+ *
+ * @return the combined dimensions as an unmodifiable set
+ */
+ static Set<String> combineDimensions(Tensor a, Tensor b) {
+ ImmutableSet.Builder<String> setBuilder = new ImmutableSet.Builder<>();
+ setBuilder.addAll(a.dimensions());
+ setBuilder.addAll(b.dimensions());
+ return setBuilder.build();
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java
new file mode 100644
index 00000000000..221bd985380
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorProduct.java
@@ -0,0 +1,93 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import com.google.common.collect.ImmutableMap;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.ListIterator;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Computes a <i>sparse tensor product</i>, see {@link Tensor#multiply}
+ *
+ * @author bratseth
+ */
+class TensorProduct {
+
+ private final Set<String> dimensionsA, dimensionsB;
+
+ private final Set<String> dimensions;
+ private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
+
+ public TensorProduct(Tensor a, Tensor b) {
+ dimensionsA = a.dimensions();
+ dimensionsB = b.dimensions();
+
+ // Dimension product
+ dimensions = TensorOperations.combineDimensions(a, b);
+
+ // Cell product (slow baseline implementation)
+ for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) {
+ for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
+ TensorAddress combinedAddress = combine(aCell.getKey(), bCell.getKey());
+ if (combinedAddress == null) continue; // not combinable
+ cells.put(combinedAddress, aCell.getValue() * bCell.getValue());
+ }
+ }
+ }
+
+ private TensorAddress combine(TensorAddress a, TensorAddress b) {
+ List<TensorAddress.Element> combined = new ArrayList<>();
+ combined.addAll(dense(a, dimensionsA));
+ combined.addAll(dense(b, dimensionsB));
+ Collections.sort(combined);
+ TensorAddress.Element previous = null;
+ for (ListIterator<TensorAddress.Element> i = combined.listIterator(); i.hasNext(); ) {
+ TensorAddress.Element current = i.next();
+ if (previous != null && previous.dimension().equals(current.dimension())) { // an overlapping dimension
+ if (previous.label().equals(current.label()))
+ i.remove(); // a match: remove the duplicate
+ else
+ return null; // no match: a combination isn't viable
+ }
+ previous = current;
+ }
+ return TensorAddress.fromSorted(sparse(combined));
+ }
+
+ /**
+ * Returns a set of tensor elements which contains an entry for each dimension including "undefined" values
+ * (which are not present in the sparse elements list).
+ */
+ private List<TensorAddress.Element> dense(TensorAddress sparse, Set<String> dimensions) {
+ if (sparse.elements().size() == dimensions.size()) return sparse.elements();
+
+ List<TensorAddress.Element> dense = new ArrayList<>(sparse.elements());
+ for (String dimension : dimensions) {
+ if ( ! sparse.hasDimension(dimension))
+ dense.add(new TensorAddress.Element(dimension, TensorAddress.Element.undefinedLabel));
+ }
+ return dense;
+ }
+
+ /**
+ * Removes any "undefined" entries from the given elements.
+ */
+ private List<TensorAddress.Element> sparse(List<TensorAddress.Element> dense) {
+ List<TensorAddress.Element> sparse = new ArrayList<>();
+ for (TensorAddress.Element element : dense) {
+ if ( ! element.label().equals(TensorAddress.Element.undefinedLabel))
+ sparse.add(element);
+ }
+ return sparse;
+ }
+
+ /** Returns the result of taking this product */
+ public Tensor result() {
+ return new MapTensor(dimensions, cells.build());
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java
new file mode 100644
index 00000000000..85dfa289bd3
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java
@@ -0,0 +1,29 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Takes the sum of two tensors, see {@link Tensor#add}
+ *
+ * @author bratseth
+ */
+class TensorSum {
+
+ private final Set<String> dimensions;
+ private final Map<TensorAddress, Double> cells = new HashMap<>();
+
+ public TensorSum(Tensor a, Tensor b) {
+ dimensions = TensorOperations.combineDimensions(a, b);
+ cells.putAll(a.cells());
+ for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
+ cells.put(bCell.getKey(), a.cells().getOrDefault(bCell.getKey(), 0d) + bCell.getValue());
+ }
+ }
+
+ /** Returns the result of taking this sum */
+ public Tensor result() { return new MapTensor(dimensions, cells); }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 31454e28baf..23cdc0e6051 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -1,7 +1,5 @@
package com.yahoo.tensor.functions;
-import com.yahoo.tensor.Tensor;
-
/**
* A composite tensor function is a tensor function which can be expressed (less tersely)
* as a tree of primitive tensor functions.
@@ -10,8 +8,4 @@ import com.yahoo.tensor.Tensor;
*/
public abstract class CompositeTensorFunction extends TensorFunction {
- /** Evaluates this by first converting it to a primitive function */
- @Override
- public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); }
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java
new file mode 100644
index 00000000000..113247be3bb
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Constant.java
@@ -0,0 +1,24 @@
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.MapTensor;
+
+/**
+ * A function which returns a constant tensor.
+ *
+ * @author bratseth
+ */
+public class Constant extends PrimitiveTensorFunction {
+
+ private final MapTensor constant;
+
+ public Constant(String tensorString) {
+ this.constant = MapTensor.from(tensorString);
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() { return this; }
+
+ @Override
+ public String toString() { return constant.toString(); }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
deleted file mode 100644
index 0727579a331..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ /dev/null
@@ -1,38 +0,0 @@
-package com.yahoo.tensor.functions;
-
-import com.yahoo.tensor.MapTensor;
-import com.yahoo.tensor.Tensor;
-
-import java.util.Collections;
-import java.util.List;
-
-/**
- * A function which returns a constant tensor.
- *
- * @author bratseth
- */
-public class ConstantTensor extends PrimitiveTensorFunction {
-
- private final Tensor constant;
-
- public ConstantTensor(String tensorString) {
- this.constant = MapTensor.from(tensorString);
- }
-
- public ConstantTensor(Tensor tensor) {
- this.constant = tensor;
- }
-
- @Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
-
- @Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
-
- @Override
- public Tensor evaluate(EvaluationContext context) { return constant; }
-
- @Override
- public String toString(ToStringContext context) { return constant.toString(); }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java
deleted file mode 100644
index 24a4c61a58c..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java
+++ /dev/null
@@ -1,14 +0,0 @@
-package com.yahoo.tensor.functions;
-
-/**
- * An evaluation context which is passed down to all nested functions during evaluation.
- * The default implementation is empty as this library does not in itself have any need for a
- * context.
- *
- * @author bratseth
- */
-public interface EvaluationContext {
-
- static EvaluationContext empty() { return new EvaluationContext() {}; }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
deleted file mode 100644
index c0e5776bf48..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ /dev/null
@@ -1,57 +0,0 @@
-package com.yahoo.tensor.functions;
-
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-
-import java.util.Collections;
-import java.util.List;
-import java.util.Objects;
-import java.util.function.Function;
-
-/**
- * An indexed tensor whose values are generated by a function
- *
- * @author bratseth
- */
-public class Generate extends PrimitiveTensorFunction {
-
- private final TensorType type;
- private final Function<List<Integer>, Double> generator;
-
- /**
- * Creates a generated tensor
- *
- * @param type the type of the tensor
- * @param generator the function generating values from a list of ints specifying the indexes of the
- * tensor cell which will receive the value
- * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
- */
- public Generate(TensorType type, Function<List<Integer>, Double> generator) {
- Objects.requireNonNull(type, "The argument tensor type cannot be null");
- Objects.requireNonNull(generator, "The argument function cannot be null");
- validateType(type);
- this.type = type;
- this.generator = generator;
- }
-
- private void validateType(TensorType type) {
- for (TensorType.Dimension dimension : type.dimensions())
- if (dimension.type() != TensorType.Dimension.Type.indexedBound)
- throw new IllegalArgumentException("A generated tensor can only have indexed bound dimensions");
- }
-
- @Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
-
- @Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
-
- @Override
- public Tensor evaluate(EvaluationContext context) {
- throw new UnsupportedOperationException("Not implemented"); // TODO
- }
-
- @Override
- public String toString(ToStringContext context) { return type + "(" + generator + ")"; }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 323da5906c3..4d945963fdf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -1,24 +1,9 @@
package com.yahoo.tensor.functions;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
-import com.yahoo.tensor.MapTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Optional;
-import java.util.Set;
import java.util.function.DoubleBinaryOperator;
/**
- * The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells
- * given by the cross product of the cells of the given tensors, having as values the value produced by
- * applying the given combinator function on the values from the two source cells.
+ * The join tensor function.
*
* @author bratseth
*/
@@ -28,9 +13,6 @@ public class Join extends PrimitiveTensorFunction {
private final DoubleBinaryOperator combinator;
public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) {
- Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
- Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
- Objects.requireNonNull(combinator, "The combinator function cannot be null");
this.argumentA = argumentA;
this.argumentB = argumentB;
this.combinator = combinator;
@@ -39,60 +21,15 @@ public class Join extends PrimitiveTensorFunction {
public TensorFunction argumentA() { return argumentA; }
public TensorFunction argumentB() { return argumentB; }
public DoubleBinaryOperator combinator() { return combinator; }
-
- @Override
- public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); }
-
+
@Override
public PrimitiveTensorFunction toPrimitive() {
return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
}
-
- @Override
- public String toString(ToStringContext context) {
- return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")";
- }
-
- private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
-
+
@Override
- public Tensor evaluate(EvaluationContext context) {
- Tensor a = argumentA.evaluate(context);
- Tensor b = argumentB.evaluate(context);
-
- // Dimension product
- Set<String> dimensions = combineDimensions(a, b);
-
- // Cell product (slow baseline implementation)
- ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
- for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) {
- for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
- TensorAddress combinedAddress = combineAddresses(aCell.getKey(), bCell.getKey());
- if (combinedAddress == null) continue; // not combinable
- cells.put(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue()));
- }
- }
-
- return new MapTensor(dimensions, cells.build());
+ public String toString() {
+ return "join(" + argumentA.toString() + ", " + argumentB.toString() + ", lambda(a, b) (...))";
}
- private Set<String> combineDimensions(Tensor a, Tensor b) {
- ImmutableSet.Builder<String> setBuilder = new ImmutableSet.Builder<>();
- setBuilder.addAll(a.dimensions());
- setBuilder.addAll(b.dimensions());
- return setBuilder.build();
- }
-
- private TensorAddress combineAddresses(TensorAddress a, TensorAddress b) {
- List<TensorAddress.Element> combined = new ArrayList<>(a.elements());
- for (TensorAddress.Element bElement : b.elements()) {
- Optional<String> aLabel = a.labelOfDimension(bElement.dimension());
- if ( ! aLabel.isPresent())
- combined.add(bElement);
- else if ( ! aLabel.get().equals(bElement.label()))
- return null; // incompatible
- }
- return TensorAddress.fromUnsorted(combined);
- }
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
deleted file mode 100644
index 4467b378b3f..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
+++ /dev/null
@@ -1,36 +0,0 @@
-package com.yahoo.tensor.functions;
-
-import java.util.Collections;
-import java.util.List;
-
-/**
- * @author bratseth
- */
-public class L1Normalize extends CompositeTensorFunction {
-
- private final TensorFunction argument;
- private final String dimension;
-
- public L1Normalize(TensorFunction argument, String dimension) {
- this.argument = argument;
- this.dimension = dimension;
- }
-
- @Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
-
- @Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
- // join(x, reduce(x, "avg", "dimension"), f(x,y) (x / y))
- return new Join(primitiveArgument,
- new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension),
- ScalarFunctions.divide());
- }
-
- @Override
- public String toString(ToStringContext context) {
- return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")";
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
deleted file mode 100644
index 0e96b43bd22..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
+++ /dev/null
@@ -1,38 +0,0 @@
-package com.yahoo.tensor.functions;
-
-import java.util.Collections;
-import java.util.List;
-
-/**
- * @author bratseth
- */
-public class L2Normalize extends CompositeTensorFunction {
-
- private final TensorFunction argument;
- private final String dimension;
-
- public L2Normalize(TensorFunction argument, String dimension) {
- this.argument = argument;
- this.dimension = dimension;
- }
-
- @Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
-
- @Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
- return new Join(primitiveArgument,
- new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()),
- Reduce.Aggregator.sum,
- dimension),
- ScalarFunctions.sqrt()),
- ScalarFunctions.divide());
- }
-
- @Override
- public String toString(ToStringContext context) {
- return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")";
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 5db88953c64..22dd08504d7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -1,17 +1,10 @@
package com.yahoo.tensor.functions;
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.MapTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-
-import java.util.Collections;
-import java.util.List;
-import java.util.Objects;
+import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
/**
- * The <i>map</i> tensor function produces a tensor where the given function is applied on each cell value.
+ * The join tensor function.
*
* @author bratseth
*/
@@ -21,8 +14,6 @@ public class Map extends PrimitiveTensorFunction {
private final DoubleUnaryOperator mapper;
public Map(TensorFunction argument, DoubleUnaryOperator mapper) {
- Objects.requireNonNull(argument, "The argument tensor cannot be null");
- Objects.requireNonNull(mapper, "The argument function cannot be null");
this.argument = argument;
this.mapper = mapper;
}
@@ -31,25 +22,13 @@ public class Map extends PrimitiveTensorFunction {
public DoubleUnaryOperator mapper() { return mapper; }
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
-
- @Override
public PrimitiveTensorFunction toPrimitive() {
return new Map(argument.toPrimitive(), mapper);
}
@Override
- public Tensor evaluate(EvaluationContext context) {
- Tensor argument = argument().evaluate(context);
- ImmutableMap.Builder<TensorAddress, Double> mappedCells = new ImmutableMap.Builder<>();
- for (java.util.Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet())
- mappedCells.put(cell.getKey(), mapper.applyAsDouble(cell.getValue()));
- return new MapTensor(argument.dimensions(), mappedCells.build());
- }
-
- @Override
- public String toString(ToStringContext context) {
- return "map(" + argument.toString(context) + ", " + mapper + ")";
+ public String toString() {
+ return "map(" + argument.toString() + ", lambda(a) (...))";
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
deleted file mode 100644
index 4492ab083d4..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ /dev/null
@@ -1,38 +0,0 @@
-package com.yahoo.tensor.functions;
-
-import com.google.common.collect.ImmutableList;
-
-import java.util.List;
-
-/**
- * @author bratseth
- */
-public class Matmul extends CompositeTensorFunction {
-
- private final TensorFunction argument1, argument2;
- private final String dimension;
-
- public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) {
- this.argument1 = argument1;
- this.argument2 = argument2;
- this.dimension = dimension;
- }
-
- @Override
- public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); }
-
- @Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument1 = argument1.toPrimitive();
- TensorFunction primitiveArgument2 = argument2.toPrimitive();
- return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- dimension);
- }
-
- @Override
- public String toString(ToStringContext context) {
- return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
index 91e58f4bf3b..9c0c9abaeb7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
@@ -1,7 +1,5 @@
package com.yahoo.tensor.functions;
-import com.yahoo.tensor.Tensor;
-
/**
* A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions.
* All tensor implementations must implement all primitive tensor functions.
@@ -10,5 +8,4 @@ import com.yahoo.tensor.Tensor;
* @author bratseth
*/
public abstract class PrimitiveTensorFunction extends TensorFunction {
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java
new file mode 100644
index 00000000000..09038a294ce
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Product.java
@@ -0,0 +1,27 @@
+package com.yahoo.tensor.functions;
+
+/**
+ * The product tensor function
+ *
+ * @author bratseth
+ */
+public class Product extends CompositeTensorFunction {
+
+ private final TensorFunction argumentA, argumentB;
+
+ public Product(TensorFunction argumentA, TensorFunction argumentB) {
+ this.argumentA = argumentA;
+ this.argumentB = argumentB;
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), (a, b) -> a * b);
+ }
+
+ @Override
+ public String toString() {
+ return "product(" + argumentA.toString() + ", " + argumentB.toString() + ")";
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index ef18cb61b17..4b306d376a6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -1,246 +1,38 @@
package com.yahoo.tensor.functions;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.MapTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
-import java.util.stream.Collectors;
+import java.util.Optional;
+import java.util.function.DoubleBinaryOperator;
/**
- * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
- * are collapsed to a single value using an aggregator function.
+ * The reduce tensor function.
*
* @author bratseth
*/
public class Reduce extends PrimitiveTensorFunction {
- public enum Aggregator { avg, count, prod, sum, max, min; }
-
private final TensorFunction argument;
- private final List<String> dimensions;
- private final Aggregator aggregator;
-
- /** Creates a reduce function reducing aLL dimensions */
- public Reduce(TensorFunction argument, Aggregator aggregator) {
- this(argument, aggregator, Collections.emptyList());
- }
+ private final String dimension;
+ private final DoubleBinaryOperator reductor;
+ private final Optional<DoubleBinaryOperator> postTransformation;
- /** Creates a reduce function reducing a single dimension */
- public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) {
- this(argument, aggregator, Collections.singletonList(dimension));
- }
-
- /**
- * Creates a reduce function.
- *
- * @param argument the tensor to reduce
- * @param aggregator the aggregator function to use
- * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced,
- * producing a dimensionless tensor (a scalar).
- * @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor
- */
- public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) {
- Objects.requireNonNull(argument, "The argument tensor cannot be null");
- Objects.requireNonNull(aggregator, "The aggregator cannot be null");
- Objects.requireNonNull(dimensions, "The dimensions cannot be null");
+ public Reduce(TensorFunction argument, String dimension,
+ DoubleBinaryOperator reductor, Optional<DoubleBinaryOperator> postTransformation) {
this.argument = argument;
- this.aggregator = aggregator;
- this.dimensions = ImmutableList.copyOf(dimensions);
+ this.dimension = dimension;
+ this.reductor = reductor;
+ this.postTransformation = postTransformation;
}
public TensorFunction argument() { return argument; }
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
-
- @Override
public PrimitiveTensorFunction toPrimitive() {
- return new Reduce(argument.toPrimitive(), aggregator, dimensions);
- }
-
- @Override
- public String toString(ToStringContext context) {
- return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
- }
-
- private String commaSeparated(List<String> list) {
- StringBuilder b = new StringBuilder();
- for (String element : list)
- b.append(", ").append(element);
- return b.toString();
+ return new Reduce(argument.toPrimitive(), dimension, reductor, postTransformation);
}
@Override
- public Tensor evaluate(EvaluationContext context) {
- Tensor argument = this.argument.evaluate(context);
-
- if ( ! dimensions.isEmpty() && ! argument.dimensions().containsAll(dimensions))
- throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
- dimensions + ": Not all those dimensions are present in this tensor");
-
- if (dimensions.isEmpty() || dimensions.size() == argument.dimensions().size())
- return reduceAll(argument);
-
- // Reduce dimensions
- Set<String> reducedDimensions = new HashSet<>(argument.dimensions());
- reducedDimensions.removeAll(dimensions);
-
- // Reduce cells
- Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
- for (Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) {
- TensorAddress reducedAddress = reduceDimensions(cell.getKey(), reducedDimensions);
- aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator));
- aggregatingCells.get(reducedAddress).aggregate(cell.getValue());
- }
- ImmutableMap.Builder<TensorAddress, Double> reducedCells = new ImmutableMap.Builder<>();
- for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
- reducedCells.put(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
- return new MapTensor(reducedDimensions, reducedCells.build());
- }
-
- private TensorAddress reduceDimensions(TensorAddress address, Set<String> reducedDimensions) {
- return TensorAddress.fromSorted(address.elements().stream()
- .filter(e -> reducedDimensions.contains(e.dimension()))
- .collect(Collectors.toList()));
- }
-
- private Tensor reduceAll(Tensor argument) {
- ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
- for (Double cellValue : argument.cells().values())
- valueAggregator.aggregate(cellValue);
- return new MapTensor(ImmutableMap.of(TensorAddress.empty, valueAggregator.aggregatedValue()));
- }
-
- private static abstract class ValueAggregator {
-
- public static ValueAggregator ofType(Aggregator aggregator) {
- switch (aggregator) {
- case avg : return new AvgAggregator();
- case count : return new CountAggregator();
- case prod : return new ProdAggregator();
- case sum : return new SumAggregator();
- case max : return new MaxAggregator();
- case min : return new MinAggregator();
- default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
- }
-
- }
-
- /** Add a new value to those aggregated by this */
- public abstract void aggregate(double value);
-
- /** Returns the value aggregated by this */
- public abstract double aggregatedValue();
-
- }
-
- private static class AvgAggregator extends ValueAggregator {
-
- private int valueCount = 0;
- private double valueSum = 0.0;
-
- @Override
- public void aggregate(double value) {
- valueCount++;
- valueSum+= value;
- }
-
- @Override
- public double aggregatedValue() {
- return valueSum / valueCount;
- }
-
- }
-
- private static class CountAggregator extends ValueAggregator {
-
- private int valueCount = 0;
-
- @Override
- public void aggregate(double value) {
- valueCount++;
- }
-
- @Override
- public double aggregatedValue() {
- return valueCount;
- }
-
- }
-
- private static class ProdAggregator extends ValueAggregator {
-
- private double valueProd = 1.0;
-
- @Override
- public void aggregate(double value) {
- valueProd *= value;
- }
-
- @Override
- public double aggregatedValue() {
- return valueProd;
- }
-
- }
-
- private static class SumAggregator extends ValueAggregator {
-
- private double valueSum = 0.0;
-
- @Override
- public void aggregate(double value) {
- valueSum += value;
- }
-
- @Override
- public double aggregatedValue() {
- return valueSum;
- }
-
- }
-
- private static class MaxAggregator extends ValueAggregator {
-
- private double maxValue = Double.MIN_VALUE;
-
- @Override
- public void aggregate(double value) {
- if (value > maxValue)
- maxValue = value;
- }
-
- @Override
- public double aggregatedValue() {
- return maxValue;
- }
-
- }
-
- private static class MinAggregator extends ValueAggregator {
-
- private double minValue = Double.MAX_VALUE;
-
- @Override
- public void aggregate(double value) {
- if (value < minValue)
- minValue = value;
- }
-
- @Override
- public double aggregatedValue() {
- return minValue;
- }
-
+ public String toString() {
+ return "reduce(" + argument.toString() + ", " + dimension + ", lambda(a, b) (...), lambda(a, b) (...))";
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
deleted file mode 100644
index 05af86c33e8..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ /dev/null
@@ -1,100 +0,0 @@
-package com.yahoo.tensor.functions;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.MapTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
-import java.util.stream.Collectors;
-
-/**
- * The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names.
- *
- * @author bratseth
- */
-public class Rename extends PrimitiveTensorFunction {
-
- private final TensorFunction argument;
- private final List<String> fromDimensions;
- private final List<String> toDimensions;
-
- public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
- Objects.requireNonNull(argument, "The argument tensor cannot be null");
- Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
- Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null");
- if (fromDimensions.size() < 1)
- throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension");
- if (fromDimensions.size() != toDimensions.size())
- throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " +
- fromDimensions.size() + " and " + toDimensions.size());
- this.argument = argument;
- this.fromDimensions = ImmutableList.copyOf(fromDimensions);
- this.toDimensions = ImmutableList.copyOf(toDimensions);
- }
-
- @Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
-
- @Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
-
- @Override
- public Tensor evaluate(EvaluationContext context) {
- Tensor tensor = argument.evaluate(context);
- Map<String, String> fromToMap = fromToMap();
- Set<String> renamedDimensions = tensor.dimensions().stream()
- .map((d) -> fromToMap.getOrDefault(d, d))
- .collect(Collectors.toSet());
-
- ImmutableMap.Builder<TensorAddress, Double> renamedCells = new ImmutableMap.Builder<>();
- for (Map.Entry<TensorAddress, Double> cell : tensor.cells().entrySet()) {
- TensorAddress renamedAddress = rename(cell.getKey(), fromToMap);
- renamedCells.put(renamedAddress, cell.getValue());
- }
- return new MapTensor(renamedDimensions, renamedCells.build());
- }
-
- private TensorAddress rename(TensorAddress address, Map<String, String> fromToMap) {
- List<TensorAddress.Element> renamedElements = new ArrayList<>();
- for (TensorAddress.Element element : address.elements()) {
- String toDimension = fromToMap.get(element.dimension());
- if (toDimension == null)
- renamedElements.add(element);
- else
- renamedElements.add(new TensorAddress.Element(toDimension, element.label()));
- }
- return TensorAddress.fromUnsorted(renamedElements);
- }
-
- @Override
- public String toString(ToStringContext context) {
- return "rename(" + argument.toString(context) + ", " +
- toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
- }
-
- private Map<String, String> fromToMap() {
- Map<String, String> map = new HashMap<>();
- for (int i = 0; i < fromDimensions.size(); i++)
- map.put(fromDimensions.get(i), toDimensions.get(i));
- return map;
- }
-
- private String toVectorString(List<String> elements) {
- if (elements.size() == 1)
- return elements.get(0);
- StringBuilder b = new StringBuilder("[");
- for (String element : elements)
- b.append(element).append(", ");
- b.setLength(b.length() - 2);
- return b.toString();
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
deleted file mode 100644
index 9438c6c533a..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ /dev/null
@@ -1,81 +0,0 @@
-package com.yahoo.tensor.functions;
-
-import java.util.function.DoubleBinaryOperator;
-import java.util.function.DoubleUnaryOperator;
-
-/**
- * Factory of scalar Java functions.
- * The purpose of this is to embellish anonymous functions with a runtime type
- * such that they can be inspected and will return a parseable toString.
- *
- * @author bratseth
- */
-public class ScalarFunctions {
-
- public static DoubleBinaryOperator add() { return new Addition(); }
- public static DoubleBinaryOperator multiply() { return new Multiplication(); }
- public static DoubleBinaryOperator divide() { return new Division(); }
- public static DoubleUnaryOperator square() { return new Square(); }
- public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
- public static DoubleUnaryOperator exp() { return new Exponent(); }
-
- public static class Addition implements DoubleBinaryOperator {
-
- @Override
- public double applyAsDouble(double left, double right) { return left + right; }
-
- @Override
- public String toString() { return "f(a,b)(a + b)"; }
-
- }
-
- public static class Multiplication implements DoubleBinaryOperator {
-
- @Override
- public double applyAsDouble(double left, double right) { return left * right; }
-
- @Override
- public String toString() { return "f(a,b)(a * b)"; }
-
- }
-
- public static class Division implements DoubleBinaryOperator {
-
- @Override
- public double applyAsDouble(double left, double right) { return left / right; }
-
- @Override
- public String toString() { return "f(a,b)(a / b)"; }
- }
-
- public static class Square implements DoubleUnaryOperator {
-
- @Override
- public double applyAsDouble(double operand) { return operand * operand; }
-
- @Override
- public String toString() { return "f(a)(a * a)"; }
-
- }
-
- public static class Sqrt implements DoubleUnaryOperator {
-
- @Override
- public double applyAsDouble(double operand) { return Math.sqrt(operand); }
-
- @Override
- public String toString() { return "f(a)(sqrt(a))"; }
-
- }
-
- public static class Exponent implements DoubleUnaryOperator {
-
- @Override
- public double applyAsDouble(double operand) { return Math.exp(operand); }
-
- @Override
- public String toString() { return "f(a)(exp(a))"; }
-
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
deleted file mode 100644
index b05b8172b42..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ /dev/null
@@ -1,37 +0,0 @@
-package com.yahoo.tensor.functions;
-
-import java.util.Collections;
-import java.util.List;
-
-/**
- * @author bratseth
- */
-public class Softmax extends CompositeTensorFunction {
-
- private final TensorFunction argument;
- private final String dimension;
-
- public Softmax(TensorFunction argument, String dimension) {
- this.argument = argument;
- this.dimension = dimension;
- }
-
- @Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
-
- @Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
- return new Join(new Map(primitiveArgument, ScalarFunctions.exp()),
- new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()),
- Reduce.Aggregator.sum,
- dimension),
- ScalarFunctions.divide());
- }
-
- @Override
- public String toString(ToStringContext context) {
- return "softmax(" + argument.toString(context) + ", " + dimension + ")";
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index a717292632e..95fca95a042 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -1,9 +1,5 @@
package com.yahoo.tensor.functions;
-import com.yahoo.tensor.Tensor;
-
-import java.util.List;
-
/**
* A representation of a tensor function which is able to be translated to a set of primitive
* tensor functions if necessary.
@@ -13,9 +9,6 @@ import java.util.List;
*/
public abstract class TensorFunction {
- /** Returns the function arguments of this node in the order they are applied */
- public abstract List<TensorFunction> functionArguments();
-
/**
* Translate this function - and all of its arguments recursively -
* to a tree of primitive functions only.
@@ -24,24 +17,4 @@ public abstract class TensorFunction {
*/
public abstract PrimitiveTensorFunction toPrimitive();
- /**
- * Evaluates this tensor.
- *
- * @param context a context which must be passed to all nexted functions when evaluating
- */
- public abstract Tensor evaluate(EvaluationContext context);
-
- /** Evaluate with no context */
- public final Tensor evaluate() { return evaluate(EvaluationContext.empty()); }
-
- /**
- * Return a string representation of this context.
- *
- * @param context a context which must be passed to all nexted functions when requesting the string value
- */
- public abstract String toString(ToStringContext context);
-
- @Override
- public final String toString() { return toString(ToStringContext.empty()); }
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java
deleted file mode 100644
index b71229703d2..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java
+++ /dev/null
@@ -1,14 +0,0 @@
-package com.yahoo.tensor.functions;
-
-/**
- * A context which is passed down to all nested functions when returning a string representation.
- * The default implementation is empty as this library does not in itself have any need for a
- * context.
- *
- * @author bratseth
- */
-public interface ToStringContext {
-
- static ToStringContext empty() { return new ToStringContext() {}; }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
deleted file mode 100644
index 1988c1d2390..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
+++ /dev/null
@@ -1,45 +0,0 @@
-package com.yahoo.tensor.functions;
-
-import com.google.common.collect.ImmutableList;
-
-import java.util.List;
-
-/**
- * @author bratseth
- */
-public class XwPlusB extends CompositeTensorFunction {
-
- private final TensorFunction x, w, b;
- private final String dimension;
-
- public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) {
- this.x = x;
- this.w = w;
- this.b = b;
- this.dimension = dimension;
- }
-
- @Override
- public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); }
-
- @Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveX = x.toPrimitive();
- TensorFunction primitiveW = w.toPrimitive();
- TensorFunction primitiveB = b.toPrimitive();
- return new Join(new Reduce(new Join(primitiveX, primitiveW, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- dimension),
- primitiveB,
- ScalarFunctions.add());
- }
-
- @Override
- public String toString(ToStringContext context) {
- return "xw_plus_b(" + x.toString(context) + ", " +
- w.toString(context) + ", " +
- b.toString(context) + ", " +
- dimension + ")";
- }
-
-}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java
index af2260e2f20..889b2851a08 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorBuilderTestCase.java
@@ -42,7 +42,7 @@ public class MapTensorBuilderTestCase {
Tensor tensor = new MapTensorBuilder().dimension("y").dimension("z").
cell().label("x", "0").value(1).build();
assertEquals(Sets.newHashSet("x", "y", "z"), tensor.dimensions());
- assertEquals("tensor(x{},y{},z{}):{{x:0}:1.0}", tensor.toString());
+ assertEquals("( {{y:-,z:-}:1.0} * {{x:0}:1.0} )", tensor.toString());
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java
index 0372f328811..13ea0e95dc8 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MapTensorTestCase.java
@@ -33,7 +33,7 @@ public class MapTensorTestCase {
fail("Expected parse error");
}
catch (IllegalArgumentException expected) {
- assertEquals("Excepted a number or a string starting by { or tensor(, got '--'", expected.getMessage());
+ assertEquals("Excepted a string starting by { or (, got '--'", expected.getMessage());
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
deleted file mode 100644
index e403bb56d56..00000000000
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ /dev/null
@@ -1,28 +0,0 @@
-package com.yahoo.tensor;
-
-import com.google.common.collect.ImmutableList;
-import org.junit.Test;
-import static org.junit.Assert.assertEquals;
-
-/**
- * Tests functionality on Tensor
- *
- * @author bratseth
- */
-public class TensorTestCase {
-
- /** This is mostly tested in searchlib - spot checking here */
- @Test
- public void testTensorComputation() {
- MapTensor tensor1 = MapTensor.from("{ {x:1}:3, {x:2}:7 }");
- MapTensor tensor2 = MapTensor.from("{ {y:1}:5 }");
- assertEquals(MapTensor.from("{ {x:1,y:1}:15, {x:2,y:1}:35 }"), tensor1.multiply(tensor2));
- assertEquals(MapTensor.from("{ {x:1,y:1}:12, {x:2,y:1}:28 }"), tensor1.join(tensor2, (a, b) -> a * b - a ));
- assertEquals(MapTensor.from("{ {x:1,y:1}:0, {x:2,y:1}:1 }"), tensor1.larger(tensor2));
- assertEquals(MapTensor.from("{ {y:1}:50.0 }"), tensor1.matmul(tensor2, "x"));
- assertEquals(MapTensor.from("{ {z:1}:3, {z:2}:7 }"), tensor1.rename("x", "z"));
- assertEquals(MapTensor.from("{ {y:1,x:1}:8, {y:2,x:1}:12 }"), tensor1.add(tensor2).rename(ImmutableList.of("x", "y"),
- ImmutableList.of("y", "x")));
- }
-
-}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
index cc9328f7274..501397e89bc 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
@@ -12,8 +12,8 @@ public class TensorFunctionTestCase {
@Test
public void testTranslation() {
- assertTranslated("join({{x:1}:1.0}, reduce({{x:1}:1.0}, sum, x), f(a,b)(a / b))",
- new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x"));
+ assertTranslated("join({{x:1}:1.0}, {{x:2}:1.0}, lambda(a, b) (...))",
+ new Product(new Constant("{{x:1}:1.0}"), new Constant("{{x:2}:1.0}")));
}
private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index c3a5e24afc2..8580868dfdf 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -52,14 +52,14 @@ public class SparseBinaryFormatTestCase {
@Test
public void testSerializationOfTensorsWithSparseTensorAddresses() {
assertSerialization("{{x:0}:2.0, {}:3.0}", Sets.newHashSet("x"));
- assertSerialization("tensor(x{},y{}):{{x:0}:2.0}", Sets.newHashSet("x", "y"));
- assertSerialization("tensor(x{},y{}):{{x:0}:2.0, {}:3.0}", Sets.newHashSet("x", "y"));
- assertSerialization("tensor(x{},y{}):{{x:0}:2.0,{x:1}:3.0}", Sets.newHashSet("x", "y"));
- assertSerialization("tensor(x{},y{},z{}):{{x:0,y:0}:2.0}", Sets.newHashSet("x", "y", "z"));
- assertSerialization("tensor(x{},y{},z{}):{{x:0,y:0}:2.0,{x:0,y:1}:3.0}", Sets.newHashSet("x", "y", "z"));
- assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0}:2.0}", Sets.newHashSet("x", "y", "z"));
- assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0}:2.0,{y:1,x:0}:3.0}", Sets.newHashSet("x", "y", "z"));
- assertSerialization("tensor(x{},y{},z{}):{{}:2.0,{x:0}:3.0,{x:0,y:0}:5.0}", Sets.newHashSet("x", "y", "z"));
+ assertSerialization("({{y:-}:1} * {{x:0}:2.0})", Sets.newHashSet("x", "y"));
+ assertSerialization("({{y:-}:1} * {{x:0}:2.0, {}:3.0})", Sets.newHashSet("x", "y"));
+ assertSerialization("({{y:-}:1} * {{x:0}:2.0,{x:1}:3.0})", Sets.newHashSet("x", "y"));
+ assertSerialization("({{z:-}:1} * {{x:0,y:0}:2.0})", Sets.newHashSet("x", "y", "z"));
+ assertSerialization("({{z:-}:1} * {{x:0,y:0}:2.0,{x:0,y:1}:3.0})", Sets.newHashSet("x", "y", "z"));
+ assertSerialization("({{z:-}:1} * {{y:0,x:0}:2.0})", Sets.newHashSet("x", "y", "z"));
+ assertSerialization("({{z:-}:1} * {{y:0,x:0}:2.0,{y:1,x:0}:3.0})", Sets.newHashSet("x", "y", "z"));
+ assertSerialization("({{z:-}:1} * {{}:2.0,{x:0}:3.0,{x:0,y:0}:5.0})", Sets.newHashSet("x", "y", "z"));
}
@Test