summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2016-11-25 18:21:25 +0100
committerGitHub <noreply@github.com>2016-11-25 18:21:25 +0100
commit11b208db7d2422828c90aafa638f059306acbc24 (patch)
tree63d3f766b7a046b13b2b4fdc8e633fe71134847c /searchlib/src/main
parent5400980ea6bbac6ef385d089b5e9f9b100ecae71 (diff)
Revert "Bratseth/tensor functions 3"
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java70
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java122
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java111
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java59
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java65
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java3
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj375
15 files changed, 298 insertions, 583 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 620c6fad0b4..0dff0414ac2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -2,7 +2,6 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
-import com.yahoo.tensor.functions.EvaluationContext;
import java.util.Set;
@@ -11,7 +10,7 @@ import java.util.Set;
*
* @author bratseth
*/
-public abstract class Context implements EvaluationContext {
+public abstract class Context {
/**
* <p>Returns the value of a simple variable name.</p>
@@ -42,7 +41,7 @@ public abstract class Context implements EvaluationContext {
* "main" (or only) value.
*/
public Value get(String name, Arguments arguments,String output) {
- if (arguments!=null && arguments.expressions().size() > 0)
+ if (arguments!=null && arguments.expressions().size()>0)
throw new UnsupportedOperationException(this + " does not support structured ranking expression variables, attempted to reference '" +
name + arguments + "'");
if (output==null)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index f8dcd8a6127..2bae382d5bd 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -39,8 +39,8 @@ public abstract class DoubleCompatibleValue extends Value {
}
@Override
- public Value compare(TruthOperator operator, Value value) {
- return new BooleanValue(operator.evaluate(asDouble(), value.asDouble()));
+ public boolean compare(TruthOperator operator, Value value) {
+ return operator.evaluate(asDouble(), value.asDouble());
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
index 0e0d793bfd1..028dad16d21 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
@@ -98,6 +98,16 @@ public final class DoubleValue extends DoubleCompatibleValue {
}
@Override
+ public boolean compare(TruthOperator operator, Value value) {
+ try {
+ return operator.evaluate(this.value, value.asDouble());
+ }
+ catch (UnsupportedOperationException e) {
+ throw unsupported("comparison",value);
+ }
+ }
+
+ @Override
public Value function(Function function, Value value) {
// use the tensor implementation of max and min if the argument is a tensor
if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
index 2dffe2a1100..9ee9a1f7a71 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
@@ -34,9 +34,11 @@ public class MapContext extends Context {
* Creates a map context from a map.
* The ownership of the map is transferred to this - it cannot be further modified by the caller.
* All the Values of the map will be frozen.
+ *
+ * @since 5.1.5
*/
public MapContext(Map<String,Value> bindings) {
- this.bindings = bindings;
+ this.bindings=bindings;
for (Value boundValue : bindings.values())
boundValue.freeze();
}
@@ -65,9 +67,6 @@ public class MapContext extends Context {
if (frozen) return bindings;
return Collections.unmodifiableMap(bindings);
}
-
- /** Returns a new, modifiable context containing all the bindings of this */
- public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); }
/** Returns an unmodifiable map of the names of this */
public @Override Set<String> names() {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index eb997ab818a..379b5755c7b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -68,10 +68,10 @@ public class StringValue extends Value {
}
@Override
- public Value compare(TruthOperator operator, Value value) {
+ public boolean compare(TruthOperator operator, Value value) {
if (operator.equals(TruthOperator.EQUAL))
- return new BooleanValue(this.equals(value));
- throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='");
+ return this.equals(value);
+ throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '='");
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index b1f4a7b20ca..12bede95aae 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -8,7 +8,6 @@ import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.TensorType;
-import java.util.Collections;
import java.util.Optional;
/**
@@ -18,7 +17,7 @@ import java.util.Optional;
*
* @author bratseth
*/
-@Beta
+ @Beta
public class TensorValue extends Value {
/** The tensor value of this */
@@ -54,7 +53,7 @@ public class TensorValue extends Value {
@Override
public Value negate() {
- return new TensorValue(value.map((value) -> -value));
+ return new TensorValue(value.apply((Double value) -> -value));
}
@Override
@@ -62,7 +61,7 @@ public class TensorValue extends Value {
if (argument instanceof TensorValue)
return new TensorValue(value.add(((TensorValue)argument).value));
else
- return new TensorValue(value.map((value) -> value + argument.asDouble()));
+ return new TensorValue(value.apply((Double value) -> value + argument.asDouble()));
}
@Override
@@ -70,7 +69,7 @@ public class TensorValue extends Value {
if (argument instanceof TensorValue)
return new TensorValue(value.subtract(((TensorValue) argument).value));
else
- return new TensorValue(value.map((value) -> value - argument.asDouble()));
+ return new TensorValue(value.apply((Double value) -> value - argument.asDouble()));
}
@Override
@@ -78,15 +77,35 @@ public class TensorValue extends Value {
if (argument instanceof TensorValue)
return new TensorValue(value.multiply(((TensorValue) argument).value));
else
- return new TensorValue(value.map((value) -> value * argument.asDouble()));
+ return new TensorValue(value.apply((Double value) -> value * argument.asDouble()));
}
@Override
public Value divide(Value argument) {
if (argument instanceof TensorValue)
- return new TensorValue(value.divide(((TensorValue) argument).value));
+ throw new UnsupportedOperationException("Two tensors cannot be divided");
else
- return new TensorValue(value.map((value) -> value / argument.asDouble()));
+ return new TensorValue(value.apply((Double value) -> value / argument.asDouble()));
+ }
+
+ public Value match(Value argument) {
+ return new TensorValue(value.match(asTensor(argument, "match")));
+ }
+
+ public Value min(Value argument) {
+ return new TensorValue(value.min(asTensor(argument, "min")));
+ }
+
+ public Value max(Value argument) {
+ return new TensorValue(value.max(asTensor(argument, "max")));
+ }
+
+ public Value sum(String dimension) {
+ return new TensorValue(value.sum(dimension));
+ }
+
+ public Value sum() {
+ return new DoubleValue(value.sum());
}
private Tensor asTensor(Value value, String operationName) {
@@ -103,37 +122,18 @@ public class TensorValue extends Value {
}
@Override
- public Value compare(TruthOperator operator, Value argument) {
- return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString())));
- }
-
- private Tensor compareTensor(TruthOperator operator, Tensor argument) {
- switch (operator) {
- case LARGER: return value.larger(argument);
- case LARGEREQUAL: return value.largerOrEqual(argument);
- case SMALLER: return value.smaller(argument);
- case SMALLEREQUAL: return value.smallerOrEqual(argument);
- case EQUAL: return value.equal(argument);
- case NOTEQUAL: return value.notEqual(argument);
- default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator);
- }
+ public boolean compare(TruthOperator operator, Value value) {
+ throw new UnsupportedOperationException("A tensor cannot be compared with any value");
}
@Override
- public Value function(Function function, Value arg) {
- if (arg instanceof TensorValue)
- return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString())));
+ public Value function(Function function, Value argument) {
+ if (function.equals(Function.min) && argument instanceof TensorValue)
+ return min(argument);
+ else if (function.equals(Function.max) && argument instanceof TensorValue)
+ return max(argument);
else
- return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble())));
- }
-
- private Tensor functionOnTensor(Function function, Tensor argument) {
- switch (function) {
- case min: return value.min(argument);
- case max: return value.max(argument);
- case atan2: return value.atan2(argument);
- default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
- }
+ return new TensorValue(value.apply((Double value) -> function.evaluate(value, argument.asDouble())));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 8ce18265231..e5680edc68a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -42,7 +42,7 @@ public abstract class Value {
public abstract Value divide(Value value);
/** Perform the comparison specified by the operator between this value and the given value */
- public abstract Value compare(TruthOperator operator, Value value);
+ public abstract boolean compare(TruthOperator operator,Value value);
/** Perform the given binary function on this value and the given value */
public abstract Value function(Function function,Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
index af05acb365a..882d16ebc1c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
@@ -8,9 +8,10 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import java.util.*;
/**
- * A node which returns the outcome of a comparison.
+ * A node which returns true or false depending on the outcome of a comparison.
*
* @author bratseth
+ * @since 5.1.21
*/
public class ComparisonNode extends BooleanNode {
@@ -47,9 +48,9 @@ public class ComparisonNode extends BooleanNode {
@Override
public Value evaluate(Context context) {
- Value leftValue = leftCondition.evaluate(context);
- Value rightValue = rightCondition.evaluate(context);
- return leftValue.compare(operator,rightValue);
+ Value leftValue=leftCondition.evaluate(context);
+ Value rightValue=rightCondition.evaluate(context);
+ return new BooleanValue(leftValue.compare(operator,rightValue));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
index 19b1a83ed99..675ce758faa 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
@@ -12,38 +12,31 @@ import static java.lang.Math.*;
*/
public enum Function implements Serializable {
- abs { public double evaluate(double x, double y) { return abs(x); } },
+ cosh { public double evaluate(double x, double y) { return cosh(x); } },
+ sinh { public double evaluate(double x, double y) { return sinh(x); } },
+ tanh { public double evaluate(double x, double y) { return tanh(x); } },
+ cos { public double evaluate(double x, double y) { return cos(x); } },
+ sin { public double evaluate(double x, double y) { return sin(x); } },
+ tan { public double evaluate(double x, double y) { return tan(x); } },
acos { public double evaluate(double x, double y) { return acos(x); } },
asin { public double evaluate(double x, double y) { return asin(x); } },
atan { public double evaluate(double x, double y) { return atan(x); } },
- ceil { public double evaluate(double x, double y) { return ceil(x); } },
- cos { public double evaluate(double x, double y) { return cos(x); } },
- cosh { public double evaluate(double x, double y) { return cosh(x); } },
- elu { public double evaluate(double x, double y) { return x<0 ? exp(x)-1 : x; } },
exp { public double evaluate(double x, double y) { return exp(x); } },
+ log10 { public double evaluate(double x, double y) { return log10(x); } },
+ log { public double evaluate(double x, double y) { return log(x); } },
+ sqrt { public double evaluate(double x, double y) { return sqrt(x); } },
+ ceil { public double evaluate(double x, double y) { return ceil(x); } },
fabs { public double evaluate(double x, double y) { return abs(x); } },
floor { public double evaluate(double x, double y) { return floor(x); } },
isNan { public double evaluate(double x, double y) { return Double.isNaN(x) ? 1.0 : 0.0; } },
- log { public double evaluate(double x, double y) { return log(x); } },
- log10 { public double evaluate(double x, double y) { return log10(x); } },
relu { public double evaluate(double x, double y) { return max(x,0); } },
- round { public double evaluate(double x, double y) { return round(x); } },
sigmoid { public double evaluate(double x, double y) { return 1.0 / (1.0 + exp(-1.0 * x)); } },
- sign { public double evaluate(double x, double y) { return x >= 0 ? 1 : -1; } },
- sin { public double evaluate(double x, double y) { return sin(x); } },
- sinh { public double evaluate(double x, double y) { return sinh(x); } },
- square { public double evaluate(double x, double y) { return x*x; } },
- sqrt { public double evaluate(double x, double y) { return sqrt(x); } },
- tan { public double evaluate(double x, double y) { return tan(x); } },
- tanh { public double evaluate(double x, double y) { return tanh(x); } },
-
atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } },
- fmod(2) { public double evaluate(double x, double y) { return IEEEremainder(x,y); } },
+ pow(2) { public double evaluate(double x, double y) { return pow(x,y); } },
ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } },
- max(2) { public double evaluate(double x, double y) { return max(x,y); } },
+ fmod(2) { public double evaluate(double x, double y) { return IEEEremainder(x,y); } },
min(2) { public double evaluate(double x, double y) { return min(x,y); } },
- mod(2) { public double evaluate(double x, double y) { return x % y; } },
- pow(2) { public double evaluate(double x, double y) { return pow(x,y); } };
+ max(2) { public double evaluate(double x, double y) { return max(x,y); } };
private final int arity;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
deleted file mode 100644
index 7b48288598d..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ /dev/null
@@ -1,122 +0,0 @@
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.google.common.collect.ImmutableList;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-
-import java.util.Collections;
-import java.util.Deque;
-import java.util.List;
-import java.util.function.DoubleBinaryOperator;
-import java.util.function.DoubleUnaryOperator;
-
-/**
- * A free, parametrized function
- *
- * @author bratseth
- */
-public class LambdaFunctionNode extends CompositeNode {
-
- private final ImmutableList<String> arguments;
- private final ExpressionNode functionExpression;
-
- public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) {
- // TODO: Verify that the function only accesses the arguments in mapperVariables
- this.arguments = ImmutableList.copyOf(arguments);
- this.functionExpression = functionExpression;
- }
-
- @Override
- public List<ExpressionNode> children() {
- return Collections.singletonList(functionExpression);
- }
-
- @Override
- public CompositeNode setChildren(List<ExpressionNode> children) {
- if ( children.size() != 1)
- throw new IllegalArgumentException("A lambda function must have a single child expression");
- return new LambdaFunctionNode(arguments, children.get(0));
- }
-
- @Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- return ("f(" + commaSeparated(arguments) + ")(" + functionExpression.toString(context, path, this)) + ")";
- }
-
- private String commaSeparated(List<String> list) {
- StringBuilder b = new StringBuilder();
- for (String element : list)
- b.append(element).append(",");
- if (b.length() > 0)
- b.setLength(b.length()-1);
- return b.toString();
- }
-
- /** Evaluate this in a context which must have the arguments bound */
- @Override
- public Value evaluate(Context context) {
- return functionExpression.evaluate(context);
- }
-
- /**
- * Returns this as a double unary operator
- *
- * @throws IllegalStateException if this has more than one argument
- */
- public DoubleUnaryOperator asDoubleUnaryOperator() {
- if (arguments.size() > 1)
- throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " +
- "Must have at most one argument " + " but has " + arguments);
- return new DoubleUnaryLambda();
- }
-
- /**
- * Returns this as a double binary operator
- *
- * @throws IllegalStateException if this has more than two arguments
- */
- public DoubleBinaryOperator asDoubleBinaryOperator() {
- if (arguments.size() > 2)
- throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " +
- "Must have at most two argument " + " but has " + arguments);
- return new DoubleBinaryLambda();
- }
-
- private class DoubleUnaryLambda implements DoubleUnaryOperator {
-
- @Override
- public double applyAsDouble(double operand) {
- MapContext context = new MapContext();
- if (arguments.size() > 0)
- context.put(arguments.get(0), operand);
- return evaluate(context).asDouble();
- }
-
- @Override
- public String toString() {
- return LambdaFunctionNode.this.toString();
- }
-
- }
-
- private class DoubleBinaryLambda implements DoubleBinaryOperator {
-
- @Override
- public double applyAsDouble(double left, double right) {
- MapContext context = new MapContext();
- if (arguments.size() > 0)
- context.put(arguments.get(0), left);
- if (arguments.size() > 1)
- context.put(arguments.get(1), right);
- return evaluate(context).asDouble();
- }
-
- @Override
- public String toString() {
- return LambdaFunctionNode.this.toString();
- }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
deleted file mode 100644
index 26d3f1dcc0e..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ /dev/null
@@ -1,111 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.google.common.annotations.Beta;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.functions.EvaluationContext;
-import com.yahoo.tensor.functions.PrimitiveTensorFunction;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.tensor.functions.ToStringContext;
-
-import java.util.Collections;
-import java.util.Deque;
-import java.util.List;
-import java.util.stream.Collectors;
-
-/**
- * A node which performs a tensor function
- *
- * @author bratseth
- */
- @Beta
-public class TensorFunctionNode extends CompositeNode {
-
- private final TensorFunction function;
-
- public TensorFunctionNode(TensorFunction function) {
- this.function = function;
- }
-
- @Override
- public List<ExpressionNode> children() {
- return function.functionArguments().stream()
- .map(f -> ((TensorFunctionExpressionNode)f).expression)
- .collect(Collectors.toList());
- }
-
- @Override
- public CompositeNode setChildren(List<ExpressionNode> children) {
- throw new UnsupportedOperationException("Not implemented");
- }
-
- @Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- // Serialize as primitive
- return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this));
- }
-
- @Override
- public Value evaluate(Context context) {
- return new TensorValue(function.evaluate(context));
- }
-
- public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
- return new TensorFunctionExpressionNode(node);
- }
-
- /**
- * A tensor function implemented by an expression.
- * This allows us to pass expressions as tensor function arguments.
- */
- public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction {
-
- /** An expression which produces a tensor */
- private final ExpressionNode expression;
-
- public TensorFunctionExpressionNode(ExpressionNode expression) {
- this.expression = expression;
- }
-
- @Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
-
- @Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
-
- @Override
- public Tensor evaluate(EvaluationContext context) {
- Value result = expression.evaluate((Context)context);
- if ( ! ( result instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " +
- "but this returns " + result + ", not a tensor");
- return ((TensorValue)result).asTensor();
- }
-
- @Override
- public String toString(ToStringContext c) {
- ExpressionNodeToStringContext context = (ExpressionNodeToStringContext)c;
- return expression.toString(context.context, context.path, context.parent);
- }
-
- }
-
- /** Allows passing serialization context arguments through TensorFunctions */
- private static class ExpressionNodeToStringContext implements ToStringContext {
-
- final SerializationContext context;
- final Deque<String> path;
- final CompositeNode parent;
-
- public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
- this.context = context;
- this.path = path;
- this.parent = parent;
- }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java
new file mode 100644
index 00000000000..af309b3e8d8
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java
@@ -0,0 +1,59 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * @author bratseth
+ */
+ @Beta
+public class TensorMatchNode extends CompositeNode {
+
+ private final ExpressionNode left, right;
+
+ public TensorMatchNode(ExpressionNode left, ExpressionNode right) {
+ this.left = left;
+ this.right = right;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ List<ExpressionNode> children = new ArrayList<>(2);
+ children.add(left);
+ children.add(right);
+ return children;
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if ( children.size() != 2)
+ throw new IllegalArgumentException("A match product must have two children");
+ return new TensorMatchNode(children.get(0), children.get(1));
+
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return "match(" + left.toString(context, path, parent) + ", " + right.toString(context, path, parent) + ")";
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ return asTensor(left.evaluate(context)).match(asTensor(right.evaluate(context)));
+ }
+
+ private TensorValue asTensor(Value value) {
+ if ( ! (value instanceof TensorValue))
+ throw new IllegalArgumentException("Attempted to take the tensor product with an argument which is " +
+ "not a tensor: " + value);
+ return (TensorValue)value;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java
new file mode 100644
index 00000000000..a1f83157e20
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java
@@ -0,0 +1,65 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * A node which sums over all cells in the argument tensor
+ *
+ * @author bratseth
+ */
+ @Beta
+public class TensorSumNode extends CompositeNode {
+
+ /** The tensor to sum */
+ private final ExpressionNode argument;
+
+ /** The dimension to sum over, or empty to sum all cells to a scalar */
+ private final Optional<String> dimension;
+
+ public TensorSumNode(ExpressionNode argument, Optional<String> dimension) {
+ this.argument = argument;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return Collections.singletonList(argument);
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if (children.size() != 1) throw new IllegalArgumentException("A tensor sum node must have one tensor argument");
+ return new TensorSumNode(children.get(0), dimension);
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return "sum(" +
+ argument.toString(context, path, parent) +
+ ( dimension.isPresent() ? ", " + dimension.get() : "" ) +
+ ")";
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ Value argumentValue = argument.evaluate(context);
+ if ( ! ( argumentValue instanceof TensorValue))
+ throw new IllegalArgumentException("Attempted to take the tensor sum of argument '" + argument + "', " +
+ "but this returns " + argumentValue + ", not a tensor");
+ TensorValue tensorArgument = (TensorValue)argumentValue;
+ if (dimension.isPresent())
+ return tensorArgument.sum(dimension.get());
+ else
+ return tensorArgument.sum();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
index 932975f3b63..60fe19f909f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
@@ -15,8 +15,7 @@ public enum TruthOperator implements Serializable {
EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } },
APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } },
LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } },
- LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } },
- NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } };
+ LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } };
private final String operatorString;
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 0fcfdb5d40c..78ad665c414 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -21,9 +21,10 @@ import com.yahoo.searchlib.rankingexpression.rule.*;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.tensor.*;
-import com.yahoo.tensor.functions.*;
+import com.yahoo.tensor.MapTensor;
+import com.yahoo.tensor.TensorAddress;
import java.util.Collections;
+import java.util.Map;
import java.util.LinkedHashMap;
import java.util.Arrays;
import java.util.ArrayList;
@@ -59,83 +60,51 @@ TOKEN :
<RSQUARE: "]"> |
<LCURLY: "{"> |
<RCURLY: "}"> |
-
<ADD: "+"> |
<SUB: "-"> |
<DIV: "/"> |
<MUL: "*"> |
<DOT: "."> |
-
<DOLLAR: "$"> |
<COMMA: ","> |
<COLON: ":"> |
-
<LE: "<="> |
<LT: "<"> |
<EQ: "=="> |
- <NQ: "!="> |
<AQ: "~="> |
<GE: ">="> |
<GT: ">"> |
-
<STRING: ("\"" (~["\""] | "\\\"")* "\"") |
("'" (~["'"] | "\\'")* "'")> |
-
<IF: "if"> |
- <IN: "in"> |
- <F: "f"> |
-
- <ABS: "abs"> |
+ <COSH: "cosh"> |
+ <SINH: "sinh"> |
+ <TANH: "tanh"> |
+ <COS: "cos"> |
+ <SIN: "sin"> |
+ <TAN: "tan"> |
<ACOS: "acos"> |
<ASIN: "asin"> |
+ <ATAN2: "atan2"> |
<ATAN: "atan"> |
- <CEIL: "ceil"> |
- <COS: "cos"> |
- <COSH: "cosh"> |
- <ELU: "elu"> |
<EXP: "exp"> |
+ <LDEXP: "ldexp"> |
+ <LOG10: "log10"> |
+ <LOG: "log"> |
+ <POW: "pow"> |
+ <SQRT: "sqrt"> |
+ <CEIL: "ceil"> |
<FABS: "fabs"> |
<FLOOR: "floor"> |
+ <FMOD: "fmod"> |
+ <MIN: "min"> |
+ <MAX: "max"> |
<ISNAN: "isNan"> |
- <LOG: "log"> |
- <LOG10: "log10"> |
+ <IN: "in"> |
+ <SUM: "sum"> |
+ <MATCH: "match"> |
<RELU: "relu"> |
- <ROUND: "round"> |
<SIGMOID: "sigmoid"> |
- <SIGN: "sign"> |
- <SIN: "sin"> |
- <SINH: "sinh"> |
- <SQUARE: "square"> |
- <SQRT: "sqrt"> |
- <TAN: "tan"> |
- <TANH: "tanh"> |
-
- <ATAN2: "atan2"> |
- <FMOD: "fmod"> |
- <LDEXP: "ldexp"> |
- // MAX
- // MIN
- <MOD: "mod"> |
- <POW: "pow"> |
-
- <MAP: "map"> |
- <REDUCE: "reduce"> |
- <JOIN: "join"> |
- <RENAME: "rename"> |
- <TENSOR: "tensor"> |
- <L1_NORMALIZE: "l1_normalize"> |
- <L2_NORMALIZE: "l2_normalize"> |
- <MATMUL: "matmul"> |
- <SOFTMAX: "softmax"> |
- <XW_PLUS_B: "xw_plus_b"> |
-
- <AVG: "avg" > |
- <COUNT: "count"> |
- <PROD: "prod"> |
- <SUM: "sum"> |
- <MAX: "max"> |
- <MIN: "min"> |
-
<IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)>
}
@@ -206,7 +175,6 @@ TruthOperator comparator() : { }
( <LE> { return TruthOperator.SMALLEREQUAL; } |
<LT> { return TruthOperator.SMALLER; } |
<EQ> { return TruthOperator.EQUAL; } |
- <NQ> { return TruthOperator.NOTEQUAL; } |
<AQ> { return TruthOperator.APPROX_EQUAL; } |
<GE> { return TruthOperator.LARGEREQUAL; } |
<GT> { return TruthOperator.LARGER; } )
@@ -221,6 +189,7 @@ ExpressionNode value() :
{
( [ LOOKAHEAD(2) <SUB> { neg = true; } ]
( ret = constantPrimitive() |
+ ret = constantTensor() |
LOOKAHEAD(2) ret = ifExpression() |
LOOKAHEAD(2) ret = function() |
ret = feature() |
@@ -310,6 +279,7 @@ ExpressionNode arg() :
}
{
( ret = constantPrimitive() |
+ ret = constantTensor() |
LOOKAHEAD(2) ret = feature() |
name = identifier() { ret = new NameNode(name); } )
{ return ret; }
@@ -320,11 +290,11 @@ ExpressionNode function() :
ExpressionNode function;
}
{
- ( function = scalarOrTensorFunction() | function = tensorFunction() )
+ ( function = scalarFunction() | function = tensorFunction() )
{ return function; }
}
-FunctionNode scalarOrTensorFunction() :
+FunctionNode scalarFunction() :
{
Function function;
ExpressionNode arg1, arg2;
@@ -342,223 +312,61 @@ FunctionNode scalarOrTensorFunction() :
ExpressionNode tensorFunction() :
{
- ExpressionNode tensorExpression;
-}
-{
- (
- tensorExpression = tensorMap() |
- tensorExpression = tensorReduce() |
- tensorExpression = tensorReduceComposites() |
- tensorExpression = tensorJoin() |
- tensorExpression = tensorRename() |
- tensorExpression = tensorGenerate() |
- tensorExpression = tensorL1Normalize() |
- tensorExpression = tensorL2Normalize() |
- tensorExpression = tensorMatmul() |
- tensorExpression = tensorSoftmax() |
- tensorExpression = tensorXwPlusB()
- )
- { return tensorExpression; }
-}
-
-ExpressionNode tensorMap() :
-{
- ExpressionNode tensor;
- LambdaFunctionNode doubleMapper;
-}
-{
- <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE>
- { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor),
- doubleMapper.asDoubleUnaryOperator())); }
-}
-
-ExpressionNode tensorReduce() :
-{
- ExpressionNode tensor;
- Reduce.Aggregator aggregator;
- List<String> dimensions = null;
-}
-{
- <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
-}
-
-ExpressionNode tensorReduceComposites() :
-{
- ExpressionNode tensor;
- Reduce.Aggregator aggregator;
- List<String> dimensions = null;
-}
-{
- aggregator = tensorReduceAggregator()
- <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
-}
-
-ExpressionNode tensorJoin() :
-{
ExpressionNode tensor1, tensor2;
- LambdaFunctionNode doubleJoiner;
+ String dimension = null;
+ TensorAddress address = null;
}
{
- <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE>
- { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- doubleJoiner.asDoubleBinaryOperator())); }
-}
-
-ExpressionNode tensorRename() :
-{
- ExpressionNode tensor;
- List<String> fromDimensions, toDimensions;
-}
-{
- <RENAME> <LBRACE> tensor = expression() <COMMA>
- fromDimensions = bracedIdentifierList() <COMMA>
- toDimensions = bracedIdentifierList()
- <RBRACE>
- { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); }
-}
-
-// TODO: Notice that null is parsed below
-ExpressionNode tensorGenerate() :
-{
- TensorType type;
- LambdaFunctionNode generator;
-}
-{
- <TENSOR> <LBRACE> <RBRACE> <LBRACE>
- { return new TensorFunctionNode(new Generate(null, null)); }
-}
-
-ExpressionNode tensorL1Normalize() :
-{
- ExpressionNode tensor;
- String dimension;
-}
-{
- <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
-}
-
-ExpressionNode tensorL2Normalize() :
-{
- ExpressionNode tensor;
- String dimension;
-}
-{
- <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
-}
-
-ExpressionNode tensorMatmul() :
-{
- ExpressionNode tensor1, tensor2;
- String dimension;
-}
-{
- <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- dimension)); }
-}
-
-ExpressionNode tensorSoftmax() :
-{
- ExpressionNode tensor;
- String dimension;
-}
-{
- <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); }
-}
-
-ExpressionNode tensorXwPlusB() :
-{
- ExpressionNode tensor1, tensor2, tensor3;
- String dimension;
-}
-{
- <XW_PLUS_B> <LBRACE> tensor1 = expression() <COMMA>
- tensor2 = expression() <COMMA>
- tensor3 = expression() <COMMA>
- dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- TensorFunctionNode.wrapArgument(tensor3),
- dimension)); }
-}
-
-LambdaFunctionNode lambdaFunction() :
-{
- List<String> variables;
- ExpressionNode functionExpression;
-}
-{
- ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> )
- { return new LambdaFunctionNode(variables, functionExpression); }
-}
-
-Reduce.Aggregator tensorReduceAggregator() :
-{
-}
-{
- ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> )
- { return Reduce.Aggregator.valueOf(token.image); }
+ (
+ <SUM> <LBRACE> tensor1 = expression() ( <COMMA> dimension = identifier() )? <RBRACE>
+ { return new TensorSumNode(tensor1, Optional.ofNullable(dimension)); }
+ ) |
+ (
+ <MATCH> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <RBRACE>
+ { return new TensorMatchNode(tensor1, tensor2); }
+ )
}
// This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge
String tensorFunctionName() :
{
- Reduce.Aggregator aggregator;
}
{
- ( <F> { return token.image; } ) |
- ( <MAP> { return token.image; } ) |
- ( <REDUCE> { return token.image; } ) |
- ( <JOIN> { return token.image; } ) |
- ( <RENAME> { return token.image; } ) |
- ( <TENSOR> { return token.image; } ) |
- ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
+ ( <SUM> | <MATCH> )
+ { return token.image; }
}
Function unaryFunctionName() : { }
{
- <ABS> { return Function.abs; } |
+ <COS> { return Function.cos; } |
+ <SIN> { return Function.sin; } |
+ <TAN> { return Function.tan; } |
+ <COSH> { return Function.cosh; } |
+ <SINH> { return Function.sinh; } |
+ <TANH> { return Function.tanh; } |
<ACOS> { return Function.acos; } |
<ASIN> { return Function.asin; } |
<ATAN> { return Function.atan; } |
- <CEIL> { return Function.ceil; } |
- <COS> { return Function.cos; } |
- <COSH> { return Function.cosh; } |
- <ELU> { return Function.elu; } |
<EXP> { return Function.exp; } |
+ <LOG10> { return Function.log10; } |
+ <LOG> { return Function.log; } |
+ <SQRT> { return Function.sqrt; } |
+ <CEIL> { return Function.ceil; } |
<FABS> { return Function.fabs; } |
<FLOOR> { return Function.floor; } |
<ISNAN> { return Function.isNan; } |
- <LOG> { return Function.log; } |
- <LOG10> { return Function.log10; } |
<RELU> { return Function.relu; } |
- <ROUND> { return Function.round; } |
- <SIGMOID> { return Function.sigmoid; } |
- <SIGN> { return Function.sign; } |
- <SIN> { return Function.sin; } |
- <SINH> { return Function.sinh; } |
- <SQUARE> { return Function.square; } |
- <SQRT> { return Function.sqrt; } |
- <TAN> { return Function.tan; } |
- <TANH> { return Function.tanh; }
+ <SIGMOID> { return Function.sigmoid; }
}
Function binaryFunctionName() : { }
{
<ATAN2> { return Function.atan2; } |
- <FMOD> { return Function.fmod; } |
<LDEXP> { return Function.ldexp; } |
- <MAX> { return Function.max; } |
+ <POW> { return Function.pow; } |
+ <FMOD> { return Function.fmod; } |
<MIN> { return Function.min; } |
- <MOD> { return Function.mod; } |
- <POW> { return Function.pow; }
+ <MAX> { return Function.max; }
}
List<ExpressionNode> expressionList() :
@@ -597,28 +405,6 @@ String identifier() :
<IDENTIFIER> { return token.image; }
}
-List<String> identifierList() :
-{
- List<String> list = new ArrayList<String>();
- String element;
-}
-{
- ( element = identifier() { list.add(element); } )?
- ( <COMMA> element = identifier() { list.add(element); } ) *
- { return list; }
-}
-
-List<String> bracedIdentifierList() :
-{
- List<String> list = new ArrayList<String>();
- String element;
-}
-{
- ( element = identifier() { return Collections.singletonList(element); } )
- |
- ( <LBRACE> list = identifierList() <RBRACE> { return list; } )
-}
-
// An identifier or integer
String tag() :
{
@@ -629,16 +415,6 @@ String tag() :
<INTEGER> { return token.image; }
}
-List<String> tagCommaLeadingList() :
-{
- List<String> list = new ArrayList<String>();
- String element;
-}
-{
- ( <COMMA> element = tag() { list.add(element); } ) *
- { return list; }
-}
-
ConstantNode constantPrimitive() :
{
String sign = "";
@@ -658,3 +434,50 @@ Value primitiveValue() :
( <INTEGER> | <FLOAT> | <STRING> )
{ return Value.parse(sign + token.image); }
}
+
+ConstantNode constantTensor() :
+{
+ Value constantValue;
+}
+{
+ <LCURLY> constantValue = tensorContent() <RCURLY>
+ { return new ConstantNode(constantValue); }
+}
+
+TensorValue tensorContent() :
+{
+ Map<TensorAddress, Double> cells = new LinkedHashMap<TensorAddress, Double>();
+ TensorAddress address;
+ Double value;
+}
+{
+ ( address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) ?
+ ( <COMMA> address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) *
+ { return new TensorValue(new MapTensor(cells)); }
+}
+
+TensorAddress tensorAddress() :
+{
+ List<TensorAddress.Element> elements = new ArrayList<TensorAddress.Element>();
+ String dimension;
+ String label;
+}
+{
+ <LCURLY>
+ ( dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) ?
+ ( <COMMA> dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) *
+ <RCURLY>
+ { return TensorAddress.fromUnsorted(elements); }
+}
+
+String label() :
+{
+ String label;
+
+}
+{
+ ( label = tag() |
+ ( "-" { label = "-"; } ) )
+ { return label; }
+}
+