summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2016-11-26 22:45:20 +0100
committerGitHub <noreply@github.com>2016-11-26 22:45:20 +0100
commit2f55986b4de9420e5728c5abbaafb69fb2f10a34 (patch)
tree9a6a77f76d25620771dfe7ab5de49910c4321fc5 /searchlib/src/main
parent2bc82ba9d9698214e703f19039387609d82b12f8 (diff)
Revert "Revert "Bratseth/tensor functions 3""
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java70
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java122
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java111
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java59
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java65
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java3
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj375
15 files changed, 583 insertions, 298 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 0dff0414ac2..620c6fad0b4 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.tensor.functions.EvaluationContext;
import java.util.Set;
@@ -10,7 +11,7 @@ import java.util.Set;
*
* @author bratseth
*/
-public abstract class Context {
+public abstract class Context implements EvaluationContext {
/**
* <p>Returns the value of a simple variable name.</p>
@@ -41,7 +42,7 @@ public abstract class Context {
* "main" (or only) value.
*/
public Value get(String name, Arguments arguments,String output) {
- if (arguments!=null && arguments.expressions().size()>0)
+ if (arguments!=null && arguments.expressions().size() > 0)
throw new UnsupportedOperationException(this + " does not support structured ranking expression variables, attempted to reference '" +
name + arguments + "'");
if (output==null)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index 2bae382d5bd..f8dcd8a6127 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -39,8 +39,8 @@ public abstract class DoubleCompatibleValue extends Value {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
- return operator.evaluate(asDouble(), value.asDouble());
+ public Value compare(TruthOperator operator, Value value) {
+ return new BooleanValue(operator.evaluate(asDouble(), value.asDouble()));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
index 028dad16d21..0e0d793bfd1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
@@ -98,16 +98,6 @@ public final class DoubleValue extends DoubleCompatibleValue {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
- try {
- return operator.evaluate(this.value, value.asDouble());
- }
- catch (UnsupportedOperationException e) {
- throw unsupported("comparison",value);
- }
- }
-
- @Override
public Value function(Function function, Value value) {
// use the tensor implementation of max and min if the argument is a tensor
if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
index 9ee9a1f7a71..2dffe2a1100 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
@@ -34,11 +34,9 @@ public class MapContext extends Context {
* Creates a map context from a map.
* The ownership of the map is transferred to this - it cannot be further modified by the caller.
* All the Values of the map will be frozen.
- *
- * @since 5.1.5
*/
public MapContext(Map<String,Value> bindings) {
- this.bindings=bindings;
+ this.bindings = bindings;
for (Value boundValue : bindings.values())
boundValue.freeze();
}
@@ -67,6 +65,9 @@ public class MapContext extends Context {
if (frozen) return bindings;
return Collections.unmodifiableMap(bindings);
}
+
+ /** Returns a new, modifiable context containing all the bindings of this */
+ public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); }
/** Returns an unmodifiable map of the names of this */
public @Override Set<String> names() {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index 379b5755c7b..eb997ab818a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -68,10 +68,10 @@ public class StringValue extends Value {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
+ public Value compare(TruthOperator operator, Value value) {
if (operator.equals(TruthOperator.EQUAL))
- return this.equals(value);
- throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '='");
+ return new BooleanValue(this.equals(value));
+ throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='");
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 12bede95aae..b1f4a7b20ca 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.TensorType;
+import java.util.Collections;
import java.util.Optional;
/**
@@ -17,7 +18,7 @@ import java.util.Optional;
*
* @author bratseth
*/
- @Beta
+@Beta
public class TensorValue extends Value {
/** The tensor value of this */
@@ -53,7 +54,7 @@ public class TensorValue extends Value {
@Override
public Value negate() {
- return new TensorValue(value.apply((Double value) -> -value));
+ return new TensorValue(value.map((value) -> -value));
}
@Override
@@ -61,7 +62,7 @@ public class TensorValue extends Value {
if (argument instanceof TensorValue)
return new TensorValue(value.add(((TensorValue)argument).value));
else
- return new TensorValue(value.apply((Double value) -> value + argument.asDouble()));
+ return new TensorValue(value.map((value) -> value + argument.asDouble()));
}
@Override
@@ -69,7 +70,7 @@ public class TensorValue extends Value {
if (argument instanceof TensorValue)
return new TensorValue(value.subtract(((TensorValue) argument).value));
else
- return new TensorValue(value.apply((Double value) -> value - argument.asDouble()));
+ return new TensorValue(value.map((value) -> value - argument.asDouble()));
}
@Override
@@ -77,35 +78,15 @@ public class TensorValue extends Value {
if (argument instanceof TensorValue)
return new TensorValue(value.multiply(((TensorValue) argument).value));
else
- return new TensorValue(value.apply((Double value) -> value * argument.asDouble()));
+ return new TensorValue(value.map((value) -> value * argument.asDouble()));
}
@Override
public Value divide(Value argument) {
if (argument instanceof TensorValue)
- throw new UnsupportedOperationException("Two tensors cannot be divided");
+ return new TensorValue(value.divide(((TensorValue) argument).value));
else
- return new TensorValue(value.apply((Double value) -> value / argument.asDouble()));
- }
-
- public Value match(Value argument) {
- return new TensorValue(value.match(asTensor(argument, "match")));
- }
-
- public Value min(Value argument) {
- return new TensorValue(value.min(asTensor(argument, "min")));
- }
-
- public Value max(Value argument) {
- return new TensorValue(value.max(asTensor(argument, "max")));
- }
-
- public Value sum(String dimension) {
- return new TensorValue(value.sum(dimension));
- }
-
- public Value sum() {
- return new DoubleValue(value.sum());
+ return new TensorValue(value.map((value) -> value / argument.asDouble()));
}
private Tensor asTensor(Value value, String operationName) {
@@ -122,18 +103,37 @@ public class TensorValue extends Value {
}
@Override
- public boolean compare(TruthOperator operator, Value value) {
- throw new UnsupportedOperationException("A tensor cannot be compared with any value");
+ public Value compare(TruthOperator operator, Value argument) {
+ return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString())));
+ }
+
+ private Tensor compareTensor(TruthOperator operator, Tensor argument) {
+ switch (operator) {
+ case LARGER: return value.larger(argument);
+ case LARGEREQUAL: return value.largerOrEqual(argument);
+ case SMALLER: return value.smaller(argument);
+ case SMALLEREQUAL: return value.smallerOrEqual(argument);
+ case EQUAL: return value.equal(argument);
+ case NOTEQUAL: return value.notEqual(argument);
+ default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator);
+ }
}
@Override
- public Value function(Function function, Value argument) {
- if (function.equals(Function.min) && argument instanceof TensorValue)
- return min(argument);
- else if (function.equals(Function.max) && argument instanceof TensorValue)
- return max(argument);
+ public Value function(Function function, Value arg) {
+ if (arg instanceof TensorValue)
+ return new TensorValue(functionOnTensor(function, asTensor(arg, function.toString())));
else
- return new TensorValue(value.apply((Double value) -> function.evaluate(value, argument.asDouble())));
+ return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble())));
+ }
+
+ private Tensor functionOnTensor(Function function, Tensor argument) {
+ switch (function) {
+ case min: return value.min(argument);
+ case max: return value.max(argument);
+ case atan2: return value.atan2(argument);
+ default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
+ }
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index e5680edc68a..8ce18265231 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -42,7 +42,7 @@ public abstract class Value {
public abstract Value divide(Value value);
/** Perform the comparison specified by the operator between this value and the given value */
- public abstract boolean compare(TruthOperator operator,Value value);
+ public abstract Value compare(TruthOperator operator, Value value);
/** Perform the given binary function on this value and the given value */
public abstract Value function(Function function,Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
index 882d16ebc1c..af05acb365a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
@@ -8,10 +8,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import java.util.*;
/**
- * A node which returns true or false depending on the outcome of a comparison.
+ * A node which returns the outcome of a comparison.
*
* @author bratseth
- * @since 5.1.21
*/
public class ComparisonNode extends BooleanNode {
@@ -48,9 +47,9 @@ public class ComparisonNode extends BooleanNode {
@Override
public Value evaluate(Context context) {
- Value leftValue=leftCondition.evaluate(context);
- Value rightValue=rightCondition.evaluate(context);
- return new BooleanValue(leftValue.compare(operator,rightValue));
+ Value leftValue = leftCondition.evaluate(context);
+ Value rightValue = rightCondition.evaluate(context);
+ return leftValue.compare(operator,rightValue);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
index 675ce758faa..19b1a83ed99 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
@@ -12,31 +12,38 @@ import static java.lang.Math.*;
*/
public enum Function implements Serializable {
- cosh { public double evaluate(double x, double y) { return cosh(x); } },
- sinh { public double evaluate(double x, double y) { return sinh(x); } },
- tanh { public double evaluate(double x, double y) { return tanh(x); } },
- cos { public double evaluate(double x, double y) { return cos(x); } },
- sin { public double evaluate(double x, double y) { return sin(x); } },
- tan { public double evaluate(double x, double y) { return tan(x); } },
+ abs { public double evaluate(double x, double y) { return abs(x); } },
acos { public double evaluate(double x, double y) { return acos(x); } },
asin { public double evaluate(double x, double y) { return asin(x); } },
atan { public double evaluate(double x, double y) { return atan(x); } },
- exp { public double evaluate(double x, double y) { return exp(x); } },
- log10 { public double evaluate(double x, double y) { return log10(x); } },
- log { public double evaluate(double x, double y) { return log(x); } },
- sqrt { public double evaluate(double x, double y) { return sqrt(x); } },
ceil { public double evaluate(double x, double y) { return ceil(x); } },
+ cos { public double evaluate(double x, double y) { return cos(x); } },
+ cosh { public double evaluate(double x, double y) { return cosh(x); } },
+ elu { public double evaluate(double x, double y) { return x<0 ? exp(x)-1 : x; } },
+ exp { public double evaluate(double x, double y) { return exp(x); } },
fabs { public double evaluate(double x, double y) { return abs(x); } },
floor { public double evaluate(double x, double y) { return floor(x); } },
isNan { public double evaluate(double x, double y) { return Double.isNaN(x) ? 1.0 : 0.0; } },
+ log { public double evaluate(double x, double y) { return log(x); } },
+ log10 { public double evaluate(double x, double y) { return log10(x); } },
relu { public double evaluate(double x, double y) { return max(x,0); } },
+ round { public double evaluate(double x, double y) { return round(x); } },
sigmoid { public double evaluate(double x, double y) { return 1.0 / (1.0 + exp(-1.0 * x)); } },
+ sign { public double evaluate(double x, double y) { return x >= 0 ? 1 : -1; } },
+ sin { public double evaluate(double x, double y) { return sin(x); } },
+ sinh { public double evaluate(double x, double y) { return sinh(x); } },
+ square { public double evaluate(double x, double y) { return x*x; } },
+ sqrt { public double evaluate(double x, double y) { return sqrt(x); } },
+ tan { public double evaluate(double x, double y) { return tan(x); } },
+ tanh { public double evaluate(double x, double y) { return tanh(x); } },
+
atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } },
- pow(2) { public double evaluate(double x, double y) { return pow(x,y); } },
- ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } },
fmod(2) { public double evaluate(double x, double y) { return IEEEremainder(x,y); } },
+ ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } },
+ max(2) { public double evaluate(double x, double y) { return max(x,y); } },
min(2) { public double evaluate(double x, double y) { return min(x,y); } },
- max(2) { public double evaluate(double x, double y) { return max(x,y); } };
+ mod(2) { public double evaluate(double x, double y) { return x % y; } },
+ pow(2) { public double evaluate(double x, double y) { return pow(x,y); } };
private final int arity;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
new file mode 100644
index 00000000000..7b48288598d
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -0,0 +1,122 @@
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * A free, parametrized function
+ *
+ * @author bratseth
+ */
+public class LambdaFunctionNode extends CompositeNode {
+
+ private final ImmutableList<String> arguments;
+ private final ExpressionNode functionExpression;
+
+ public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) {
+ // TODO: Verify that the function only accesses the arguments in mapperVariables
+ this.arguments = ImmutableList.copyOf(arguments);
+ this.functionExpression = functionExpression;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return Collections.singletonList(functionExpression);
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if ( children.size() != 1)
+ throw new IllegalArgumentException("A lambda function must have a single child expression");
+ return new LambdaFunctionNode(arguments, children.get(0));
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return ("f(" + commaSeparated(arguments) + ")(" + functionExpression.toString(context, path, this)) + ")";
+ }
+
+ private String commaSeparated(List<String> list) {
+ StringBuilder b = new StringBuilder();
+ for (String element : list)
+ b.append(element).append(",");
+ if (b.length() > 0)
+ b.setLength(b.length()-1);
+ return b.toString();
+ }
+
+ /** Evaluate this in a context which must have the arguments bound */
+ @Override
+ public Value evaluate(Context context) {
+ return functionExpression.evaluate(context);
+ }
+
+ /**
+ * Returns this as a double unary operator
+ *
+ * @throws IllegalStateException if this has more than one argument
+ */
+ public DoubleUnaryOperator asDoubleUnaryOperator() {
+ if (arguments.size() > 1)
+ throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " +
+ "Must have at most one argument " + " but has " + arguments);
+ return new DoubleUnaryLambda();
+ }
+
+ /**
+ * Returns this as a double binary operator
+ *
+ * @throws IllegalStateException if this has more than two arguments
+ */
+ public DoubleBinaryOperator asDoubleBinaryOperator() {
+ if (arguments.size() > 2)
+ throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " +
+ "Must have at most two argument " + " but has " + arguments);
+ return new DoubleBinaryLambda();
+ }
+
+ private class DoubleUnaryLambda implements DoubleUnaryOperator {
+
+ @Override
+ public double applyAsDouble(double operand) {
+ MapContext context = new MapContext();
+ if (arguments.size() > 0)
+ context.put(arguments.get(0), operand);
+ return evaluate(context).asDouble();
+ }
+
+ @Override
+ public String toString() {
+ return LambdaFunctionNode.this.toString();
+ }
+
+ }
+
+ private class DoubleBinaryLambda implements DoubleBinaryOperator {
+
+ @Override
+ public double applyAsDouble(double left, double right) {
+ MapContext context = new MapContext();
+ if (arguments.size() > 0)
+ context.put(arguments.get(0), left);
+ if (arguments.size() > 1)
+ context.put(arguments.get(1), right);
+ return evaluate(context).asDouble();
+ }
+
+ @Override
+ public String toString() {
+ return LambdaFunctionNode.this.toString();
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
new file mode 100644
index 00000000000..26d3f1dcc0e
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -0,0 +1,111 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.EvaluationContext;
+import com.yahoo.tensor.functions.PrimitiveTensorFunction;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.functions.ToStringContext;
+
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * A node which performs a tensor function
+ *
+ * @author bratseth
+ */
+ @Beta
+public class TensorFunctionNode extends CompositeNode {
+
+ private final TensorFunction function;
+
+ public TensorFunctionNode(TensorFunction function) {
+ this.function = function;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return function.functionArguments().stream()
+ .map(f -> ((TensorFunctionExpressionNode)f).expression)
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ throw new UnsupportedOperationException("Not implemented");
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ // Serialize as primitive
+ return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this));
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ return new TensorValue(function.evaluate(context));
+ }
+
+ public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
+ return new TensorFunctionExpressionNode(node);
+ }
+
+ /**
+ * A tensor function implemented by an expression.
+ * This allows us to pass expressions as tensor function arguments.
+ */
+ public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction {
+
+ /** An expression which produces a tensor */
+ private final ExpressionNode expression;
+
+ public TensorFunctionExpressionNode(ExpressionNode expression) {
+ this.expression = expression;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() { return this; }
+
+ @Override
+ public Tensor evaluate(EvaluationContext context) {
+ Value result = expression.evaluate((Context)context);
+ if ( ! ( result instanceof TensorValue))
+ throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " +
+ "but this returns " + result + ", not a tensor");
+ return ((TensorValue)result).asTensor();
+ }
+
+ @Override
+ public String toString(ToStringContext c) {
+ ExpressionNodeToStringContext context = (ExpressionNodeToStringContext)c;
+ return expression.toString(context.context, context.path, context.parent);
+ }
+
+ }
+
+ /** Allows passing serialization context arguments through TensorFunctions */
+ private static class ExpressionNodeToStringContext implements ToStringContext {
+
+ final SerializationContext context;
+ final Deque<String> path;
+ final CompositeNode parent;
+
+ public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ this.context = context;
+ this.path = path;
+ this.parent = parent;
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java
deleted file mode 100644
index af309b3e8d8..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMatchNode.java
+++ /dev/null
@@ -1,59 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.google.common.annotations.Beta;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-
-import java.util.ArrayList;
-import java.util.Deque;
-import java.util.List;
-
-/**
- * @author bratseth
- */
- @Beta
-public class TensorMatchNode extends CompositeNode {
-
- private final ExpressionNode left, right;
-
- public TensorMatchNode(ExpressionNode left, ExpressionNode right) {
- this.left = left;
- this.right = right;
- }
-
- @Override
- public List<ExpressionNode> children() {
- List<ExpressionNode> children = new ArrayList<>(2);
- children.add(left);
- children.add(right);
- return children;
- }
-
- @Override
- public CompositeNode setChildren(List<ExpressionNode> children) {
- if ( children.size() != 2)
- throw new IllegalArgumentException("A match product must have two children");
- return new TensorMatchNode(children.get(0), children.get(1));
-
- }
-
- @Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- return "match(" + left.toString(context, path, parent) + ", " + right.toString(context, path, parent) + ")";
- }
-
- @Override
- public Value evaluate(Context context) {
- return asTensor(left.evaluate(context)).match(asTensor(right.evaluate(context)));
- }
-
- private TensorValue asTensor(Value value) {
- if ( ! (value instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to take the tensor product with an argument which is " +
- "not a tensor: " + value);
- return (TensorValue)value;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java
deleted file mode 100644
index a1f83157e20..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java
+++ /dev/null
@@ -1,65 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.google.common.annotations.Beta;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-
-import java.util.Collections;
-import java.util.Deque;
-import java.util.List;
-import java.util.Optional;
-
-/**
- * A node which sums over all cells in the argument tensor
- *
- * @author bratseth
- */
- @Beta
-public class TensorSumNode extends CompositeNode {
-
- /** The tensor to sum */
- private final ExpressionNode argument;
-
- /** The dimension to sum over, or empty to sum all cells to a scalar */
- private final Optional<String> dimension;
-
- public TensorSumNode(ExpressionNode argument, Optional<String> dimension) {
- this.argument = argument;
- this.dimension = dimension;
- }
-
- @Override
- public List<ExpressionNode> children() {
- return Collections.singletonList(argument);
- }
-
- @Override
- public CompositeNode setChildren(List<ExpressionNode> children) {
- if (children.size() != 1) throw new IllegalArgumentException("A tensor sum node must have one tensor argument");
- return new TensorSumNode(children.get(0), dimension);
- }
-
- @Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- return "sum(" +
- argument.toString(context, path, parent) +
- ( dimension.isPresent() ? ", " + dimension.get() : "" ) +
- ")";
- }
-
- @Override
- public Value evaluate(Context context) {
- Value argumentValue = argument.evaluate(context);
- if ( ! ( argumentValue instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to take the tensor sum of argument '" + argument + "', " +
- "but this returns " + argumentValue + ", not a tensor");
- TensorValue tensorArgument = (TensorValue)argumentValue;
- if (dimension.isPresent())
- return tensorArgument.sum(dimension.get());
- else
- return tensorArgument.sum();
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
index 60fe19f909f..932975f3b63 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
@@ -15,7 +15,8 @@ public enum TruthOperator implements Serializable {
EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } },
APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } },
LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } },
- LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } };
+ LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } },
+ NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } };
private final String operatorString;
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 78ad665c414..0fcfdb5d40c 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -21,10 +21,9 @@ import com.yahoo.searchlib.rankingexpression.rule.*;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.tensor.MapTensor;
-import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.*;
+import com.yahoo.tensor.functions.*;
import java.util.Collections;
-import java.util.Map;
import java.util.LinkedHashMap;
import java.util.Arrays;
import java.util.ArrayList;
@@ -60,51 +59,83 @@ TOKEN :
<RSQUARE: "]"> |
<LCURLY: "{"> |
<RCURLY: "}"> |
+
<ADD: "+"> |
<SUB: "-"> |
<DIV: "/"> |
<MUL: "*"> |
<DOT: "."> |
+
<DOLLAR: "$"> |
<COMMA: ","> |
<COLON: ":"> |
+
<LE: "<="> |
<LT: "<"> |
<EQ: "=="> |
+ <NQ: "!="> |
<AQ: "~="> |
<GE: ">="> |
<GT: ">"> |
+
<STRING: ("\"" (~["\""] | "\\\"")* "\"") |
("'" (~["'"] | "\\'")* "'")> |
+
<IF: "if"> |
- <COSH: "cosh"> |
- <SINH: "sinh"> |
- <TANH: "tanh"> |
- <COS: "cos"> |
- <SIN: "sin"> |
- <TAN: "tan"> |
+ <IN: "in"> |
+ <F: "f"> |
+
+ <ABS: "abs"> |
<ACOS: "acos"> |
<ASIN: "asin"> |
- <ATAN2: "atan2"> |
<ATAN: "atan"> |
- <EXP: "exp"> |
- <LDEXP: "ldexp"> |
- <LOG10: "log10"> |
- <LOG: "log"> |
- <POW: "pow"> |
- <SQRT: "sqrt"> |
<CEIL: "ceil"> |
+ <COS: "cos"> |
+ <COSH: "cosh"> |
+ <ELU: "elu"> |
+ <EXP: "exp"> |
<FABS: "fabs"> |
<FLOOR: "floor"> |
- <FMOD: "fmod"> |
- <MIN: "min"> |
- <MAX: "max"> |
<ISNAN: "isNan"> |
- <IN: "in"> |
- <SUM: "sum"> |
- <MATCH: "match"> |
+ <LOG: "log"> |
+ <LOG10: "log10"> |
<RELU: "relu"> |
+ <ROUND: "round"> |
<SIGMOID: "sigmoid"> |
+ <SIGN: "sign"> |
+ <SIN: "sin"> |
+ <SINH: "sinh"> |
+ <SQUARE: "square"> |
+ <SQRT: "sqrt"> |
+ <TAN: "tan"> |
+ <TANH: "tanh"> |
+
+ <ATAN2: "atan2"> |
+ <FMOD: "fmod"> |
+ <LDEXP: "ldexp"> |
+ // MAX
+ // MIN
+ <MOD: "mod"> |
+ <POW: "pow"> |
+
+ <MAP: "map"> |
+ <REDUCE: "reduce"> |
+ <JOIN: "join"> |
+ <RENAME: "rename"> |
+ <TENSOR: "tensor"> |
+ <L1_NORMALIZE: "l1_normalize"> |
+ <L2_NORMALIZE: "l2_normalize"> |
+ <MATMUL: "matmul"> |
+ <SOFTMAX: "softmax"> |
+ <XW_PLUS_B: "xw_plus_b"> |
+
+ <AVG: "avg" > |
+ <COUNT: "count"> |
+ <PROD: "prod"> |
+ <SUM: "sum"> |
+ <MAX: "max"> |
+ <MIN: "min"> |
+
<IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)>
}
@@ -175,6 +206,7 @@ TruthOperator comparator() : { }
( <LE> { return TruthOperator.SMALLEREQUAL; } |
<LT> { return TruthOperator.SMALLER; } |
<EQ> { return TruthOperator.EQUAL; } |
+ <NQ> { return TruthOperator.NOTEQUAL; } |
<AQ> { return TruthOperator.APPROX_EQUAL; } |
<GE> { return TruthOperator.LARGEREQUAL; } |
<GT> { return TruthOperator.LARGER; } )
@@ -189,7 +221,6 @@ ExpressionNode value() :
{
( [ LOOKAHEAD(2) <SUB> { neg = true; } ]
( ret = constantPrimitive() |
- ret = constantTensor() |
LOOKAHEAD(2) ret = ifExpression() |
LOOKAHEAD(2) ret = function() |
ret = feature() |
@@ -279,7 +310,6 @@ ExpressionNode arg() :
}
{
( ret = constantPrimitive() |
- ret = constantTensor() |
LOOKAHEAD(2) ret = feature() |
name = identifier() { ret = new NameNode(name); } )
{ return ret; }
@@ -290,11 +320,11 @@ ExpressionNode function() :
ExpressionNode function;
}
{
- ( function = scalarFunction() | function = tensorFunction() )
+ ( function = scalarOrTensorFunction() | function = tensorFunction() )
{ return function; }
}
-FunctionNode scalarFunction() :
+FunctionNode scalarOrTensorFunction() :
{
Function function;
ExpressionNode arg1, arg2;
@@ -312,61 +342,223 @@ FunctionNode scalarFunction() :
ExpressionNode tensorFunction() :
{
+ ExpressionNode tensorExpression;
+}
+{
+ (
+ tensorExpression = tensorMap() |
+ tensorExpression = tensorReduce() |
+ tensorExpression = tensorReduceComposites() |
+ tensorExpression = tensorJoin() |
+ tensorExpression = tensorRename() |
+ tensorExpression = tensorGenerate() |
+ tensorExpression = tensorL1Normalize() |
+ tensorExpression = tensorL2Normalize() |
+ tensorExpression = tensorMatmul() |
+ tensorExpression = tensorSoftmax() |
+ tensorExpression = tensorXwPlusB()
+ )
+ { return tensorExpression; }
+}
+
+ExpressionNode tensorMap() :
+{
+ ExpressionNode tensor;
+ LambdaFunctionNode doubleMapper;
+}
+{
+ <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE>
+ { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor),
+ doubleMapper.asDoubleUnaryOperator())); }
+}
+
+ExpressionNode tensorReduce() :
+{
+ ExpressionNode tensor;
+ Reduce.Aggregator aggregator;
+ List<String> dimensions = null;
+}
+{
+ <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
+ { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
+}
+
+ExpressionNode tensorReduceComposites() :
+{
+ ExpressionNode tensor;
+ Reduce.Aggregator aggregator;
+ List<String> dimensions = null;
+}
+{
+ aggregator = tensorReduceAggregator()
+ <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE>
+ { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
+}
+
+ExpressionNode tensorJoin() :
+{
ExpressionNode tensor1, tensor2;
- String dimension = null;
- TensorAddress address = null;
+ LambdaFunctionNode doubleJoiner;
}
{
- (
- <SUM> <LBRACE> tensor1 = expression() ( <COMMA> dimension = identifier() )? <RBRACE>
- { return new TensorSumNode(tensor1, Optional.ofNullable(dimension)); }
- ) |
- (
- <MATCH> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <RBRACE>
- { return new TensorMatchNode(tensor1, tensor2); }
- )
+ <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE>
+ { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1),
+ TensorFunctionNode.wrapArgument(tensor2),
+ doubleJoiner.asDoubleBinaryOperator())); }
+}
+
+ExpressionNode tensorRename() :
+{
+ ExpressionNode tensor;
+ List<String> fromDimensions, toDimensions;
+}
+{
+ <RENAME> <LBRACE> tensor = expression() <COMMA>
+ fromDimensions = bracedIdentifierList() <COMMA>
+ toDimensions = bracedIdentifierList()
+ <RBRACE>
+ { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); }
+}
+
+// TODO: Notice that null is parsed below
+ExpressionNode tensorGenerate() :
+{
+ TensorType type;
+ LambdaFunctionNode generator;
+}
+{
+ <TENSOR> <LBRACE> <RBRACE> <LBRACE>
+ { return new TensorFunctionNode(new Generate(null, null)); }
+}
+
+ExpressionNode tensorL1Normalize() :
+{
+ ExpressionNode tensor;
+ String dimension;
+}
+{
+ <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+}
+
+ExpressionNode tensorL2Normalize() :
+{
+ ExpressionNode tensor;
+ String dimension;
+}
+{
+ <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+}
+
+ExpressionNode tensorMatmul() :
+{
+ ExpressionNode tensor1, tensor2;
+ String dimension;
+}
+{
+ <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1),
+ TensorFunctionNode.wrapArgument(tensor2),
+ dimension)); }
+}
+
+ExpressionNode tensorSoftmax() :
+{
+ ExpressionNode tensor;
+ String dimension;
+}
+{
+ <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+}
+
+ExpressionNode tensorXwPlusB() :
+{
+ ExpressionNode tensor1, tensor2, tensor3;
+ String dimension;
+}
+{
+ <XW_PLUS_B> <LBRACE> tensor1 = expression() <COMMA>
+ tensor2 = expression() <COMMA>
+ tensor3 = expression() <COMMA>
+ dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1),
+ TensorFunctionNode.wrapArgument(tensor2),
+ TensorFunctionNode.wrapArgument(tensor3),
+ dimension)); }
+}
+
+LambdaFunctionNode lambdaFunction() :
+{
+ List<String> variables;
+ ExpressionNode functionExpression;
+}
+{
+ ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> )
+ { return new LambdaFunctionNode(variables, functionExpression); }
+}
+
+Reduce.Aggregator tensorReduceAggregator() :
+{
+}
+{
+ ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> )
+ { return Reduce.Aggregator.valueOf(token.image); }
}
// This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge
String tensorFunctionName() :
{
+ Reduce.Aggregator aggregator;
}
{
- ( <SUM> | <MATCH> )
- { return token.image; }
+ ( <F> { return token.image; } ) |
+ ( <MAP> { return token.image; } ) |
+ ( <REDUCE> { return token.image; } ) |
+ ( <JOIN> { return token.image; } ) |
+ ( <RENAME> { return token.image; } ) |
+ ( <TENSOR> { return token.image; } ) |
+ ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}
Function unaryFunctionName() : { }
{
- <COS> { return Function.cos; } |
- <SIN> { return Function.sin; } |
- <TAN> { return Function.tan; } |
- <COSH> { return Function.cosh; } |
- <SINH> { return Function.sinh; } |
- <TANH> { return Function.tanh; } |
+ <ABS> { return Function.abs; } |
<ACOS> { return Function.acos; } |
<ASIN> { return Function.asin; } |
<ATAN> { return Function.atan; } |
- <EXP> { return Function.exp; } |
- <LOG10> { return Function.log10; } |
- <LOG> { return Function.log; } |
- <SQRT> { return Function.sqrt; } |
<CEIL> { return Function.ceil; } |
+ <COS> { return Function.cos; } |
+ <COSH> { return Function.cosh; } |
+ <ELU> { return Function.elu; } |
+ <EXP> { return Function.exp; } |
<FABS> { return Function.fabs; } |
<FLOOR> { return Function.floor; } |
<ISNAN> { return Function.isNan; } |
+ <LOG> { return Function.log; } |
+ <LOG10> { return Function.log10; } |
<RELU> { return Function.relu; } |
- <SIGMOID> { return Function.sigmoid; }
+ <ROUND> { return Function.round; } |
+ <SIGMOID> { return Function.sigmoid; } |
+ <SIGN> { return Function.sign; } |
+ <SIN> { return Function.sin; } |
+ <SINH> { return Function.sinh; } |
+ <SQUARE> { return Function.square; } |
+ <SQRT> { return Function.sqrt; } |
+ <TAN> { return Function.tan; } |
+ <TANH> { return Function.tanh; }
}
Function binaryFunctionName() : { }
{
<ATAN2> { return Function.atan2; } |
- <LDEXP> { return Function.ldexp; } |
- <POW> { return Function.pow; } |
<FMOD> { return Function.fmod; } |
+ <LDEXP> { return Function.ldexp; } |
+ <MAX> { return Function.max; } |
<MIN> { return Function.min; } |
- <MAX> { return Function.max; }
+ <MOD> { return Function.mod; } |
+ <POW> { return Function.pow; }
}
List<ExpressionNode> expressionList() :
@@ -405,79 +597,64 @@ String identifier() :
<IDENTIFIER> { return token.image; }
}
-// An identifier or integer
-String tag() :
-{
- String name;
-}
-{
- name = identifier() { return name; } |
- <INTEGER> { return token.image; }
-}
-
-ConstantNode constantPrimitive() :
+List<String> identifierList() :
{
- String sign = "";
+ List<String> list = new ArrayList<String>();
+ String element;
}
{
- ( <SUB> { sign = "-";} ) ?
- ( <INTEGER> | <FLOAT> | <STRING> )
- { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); }
+ ( element = identifier() { list.add(element); } )?
+ ( <COMMA> element = identifier() { list.add(element); } ) *
+ { return list; }
}
-Value primitiveValue() :
+List<String> bracedIdentifierList() :
{
- String sign = "";
+ List<String> list = new ArrayList<String>();
+ String element;
}
{
- ( <SUB> { sign = "-";} ) ?
- ( <INTEGER> | <FLOAT> | <STRING> )
- { return Value.parse(sign + token.image); }
+ ( element = identifier() { return Collections.singletonList(element); } )
+ |
+ ( <LBRACE> list = identifierList() <RBRACE> { return list; } )
}
-ConstantNode constantTensor() :
+// An identifier or integer
+String tag() :
{
- Value constantValue;
+ String name;
}
{
- <LCURLY> constantValue = tensorContent() <RCURLY>
- { return new ConstantNode(constantValue); }
+ name = identifier() { return name; } |
+ <INTEGER> { return token.image; }
}
-TensorValue tensorContent() :
+List<String> tagCommaLeadingList() :
{
- Map<TensorAddress, Double> cells = new LinkedHashMap<TensorAddress, Double>();
- TensorAddress address;
- Double value;
+ List<String> list = new ArrayList<String>();
+ String element;
}
{
- ( address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) ?
- ( <COMMA> address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) *
- { return new TensorValue(new MapTensor(cells)); }
+ ( <COMMA> element = tag() { list.add(element); } ) *
+ { return list; }
}
-TensorAddress tensorAddress() :
+ConstantNode constantPrimitive() :
{
- List<TensorAddress.Element> elements = new ArrayList<TensorAddress.Element>();
- String dimension;
- String label;
+ String sign = "";
}
{
- <LCURLY>
- ( dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) ?
- ( <COMMA> dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) *
- <RCURLY>
- { return TensorAddress.fromUnsorted(elements); }
+ ( <SUB> { sign = "-";} ) ?
+ ( <INTEGER> | <FLOAT> | <STRING> )
+ { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); }
}
-String label() :
+Value primitiveValue() :
{
- String label;
-
+ String sign = "";
}
{
- ( label = tag() |
- ( "-" { label = "-"; } ) )
- { return label; }
+ ( <SUB> { sign = "-";} ) ?
+ ( <INTEGER> | <FLOAT> | <STRING> )
+ { return Value.parse(sign + token.image); }
}
-