summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-23 16:28:07 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-23 16:28:07 +0100
commiteb7a97abcdf72222d37459a986217ec5ad5cdacd (patch)
tree60ed52388d2896170dd6ae89bf39809303ec90d1 /searchlib/src/main
parentf65c80a1fb5fdc285ce0db63b3b1f039f5201505 (diff)
Implement tensor rename
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java76
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj45
3 files changed, 126 insertions, 9 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
index 65a1802c72d..4f73c632422 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -25,12 +26,12 @@ public class TensorReduceNode extends CompositeNode {
private final ReduceFunction.Aggregator aggregator;
/** The dimensions to sum over, or empty to sum all cells */
- private final List<String> dimensions;
+ private final ImmutableList<String> dimensions;
public TensorReduceNode(ExpressionNode argument, ReduceFunction.Aggregator aggregator, List<String> dimensions) {
this.argument = argument;
this.aggregator = aggregator;
- this.dimensions = dimensions;
+ this.dimensions = ImmutableList.copyOf(dimensions);
}
@Override
@@ -40,16 +41,17 @@ public class TensorReduceNode extends CompositeNode {
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
- if (children.size() != 1) throw new IllegalArgumentException("A tensor sum node must have one tensor argument");
+ if (children.size() != 1) throw new IllegalArgumentException("A tensor reduce node must have one tensor argument");
return new TensorReduceNode(children.get(0), aggregator, dimensions);
}
@Override
public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- return "reduce(" + argument.toString(context, path, parent) + ", \"" + aggregator + "\"" + commaSeparated(dimensions) + ")";
+ return "reduce(" + argument.toString(context, path, parent) + ", " +
+ aggregator + leadingCommaSeparated(dimensions) + ")";
}
- private String commaSeparated(List<String> list) {
+ private String leadingCommaSeparated(List<String> list) {
StringBuilder b = new StringBuilder();
for (String element : list)
b.append(", ").append(element);
@@ -60,7 +62,7 @@ public class TensorReduceNode extends CompositeNode {
public Value evaluate(Context context) {
Value argumentValue = argument.evaluate(context);
if ( ! ( argumentValue instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to take the tensor sum of argument '" + argument + "', " +
+ throw new IllegalArgumentException("Attempted to reduce '" + argument + "', " +
"but this returns " + argumentValue + ", not a tensor");
TensorValue tensorArgument = (TensorValue)argumentValue;
return new TensorValue(tensorArgument.asTensor().reduce(aggregator, dimensions));
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java
new file mode 100644
index 00000000000..17a08beba8b
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java
@@ -0,0 +1,76 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.functions.ReduceFunction;
+
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * A node which performs a dimension rename in a tensor
+ *
+ * @author bratseth
+ */
+ @Beta
+public class TensorRenameNode extends CompositeNode {
+
+ private final ExpressionNode argument;
+
+ private final ImmutableList<String> fromDimensions, toDimensions;
+
+ public TensorRenameNode(ExpressionNode argument, List<String> fromDimensions, List<String> toDimensions) {
+ if (fromDimensions.size() < 1)
+ throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension");
+ if (fromDimensions.size() != toDimensions.size())
+ throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " +
+ fromDimensions.size() + " and " + toDimensions.size());
+ this.argument = argument;
+ this.fromDimensions = ImmutableList.copyOf(fromDimensions);
+ this.toDimensions = ImmutableList.copyOf(toDimensions);
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return Collections.singletonList(argument);
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if (children.size() != 1) throw new IllegalArgumentException("A tensor rename node must have one tensor argument");
+ return new TensorRenameNode(children.get(0), fromDimensions, toDimensions);
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return "rename(" + argument.toString(context, path, parent) + ", " +
+ vector(fromDimensions) + ", " + vector(toDimensions) + ")";
+ }
+
+ private String vector(List<String> list) {
+ if (list.size() == 1) return list.get(0);
+
+ StringBuilder b = new StringBuilder("[");
+ for (String element : list)
+ b.append(element).append(",");
+ b.setLength(b.length() - 1);
+ b.append("]");
+ return b.toString();
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ Value argumentValue = argument.evaluate(context);
+ if ( ! ( argumentValue instanceof TensorValue))
+ throw new IllegalArgumentException("Attempted to rename dimensions in '" + argument + "', " +
+ "but this returns " + argumentValue + ", not a tensor");
+ TensorValue tensorArgument = (TensorValue)argumentValue;
+ return new TensorValue(tensorArgument.asTensor().rename(fromDimensions, toDimensions));
+ }
+
+}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 5a5d916f7e7..b3060e1ff45 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -21,8 +21,7 @@ import com.yahoo.searchlib.rankingexpression.rule.*;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.tensor.MapTensor;
-import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.*;
import com.yahoo.tensor.functions.*;
import java.util.Collections;
import java.util.Map;
@@ -104,6 +103,8 @@ TOKEN :
<MAP: "map"> |
<REDUCE: "reduce"> |
<JOIN: "join"> |
+ <RENAME: "rename"> |
+ <TENSOR: "tensor"> |
<AVG: "avg" > |
<COUNT: "count"> |
<PROD: "prod"> |
@@ -324,7 +325,9 @@ ExpressionNode tensorFunction() :
tensorExpression = tensorMap() |
tensorExpression = tensorReduce() |
tensorExpression = tensorReduceComposites() |
- tensorExpression = tensorJoin()
+ tensorExpression = tensorJoin() |
+ tensorExpression = tensorRename() |
+ tensorExpression = tensorGenerate()
)
{ return tensorExpression; }
}
@@ -372,6 +375,30 @@ ExpressionNode tensorJoin() :
{ return new TensorJoinNode(tensor1, tensor2, doubleJoiner); }
}
+ExpressionNode tensorRename() :
+{
+ ExpressionNode tensor;
+ List<String> fromDimensions, toDimensions;
+}
+{
+ <RENAME> <LBRACE> tensor = expression() <COMMA>
+ fromDimensions = identifierVector() <COMMA>
+ toDimensions = identifierVector()
+ <RBRACE>
+ { return new TensorRenameNode(tensor, fromDimensions, toDimensions); }
+}
+
+// TODO
+ExpressionNode tensorGenerate() :
+{
+ TensorType type;
+ LambdaFunctionNode generator;
+}
+{
+ <TENSOR> <LBRACE> <RBRACE> <LBRACE>
+ { return null; }
+}
+
LambdaFunctionNode lambdaFunction() :
{
List<String> variables;
@@ -400,6 +427,8 @@ String tensorFunctionName() :
( <MAP> { return token.image; } ) |
( <REDUCE> { return token.image; } ) |
( <JOIN> { return token.image; } ) |
+ ( <RENAME> { return token.image; } ) |
+ ( <TENSOR> { return token.image; } ) |
( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}
@@ -483,6 +512,16 @@ List<String> identifierList() :
{ return list; }
}
+List<String> identifierVector() :
+{
+ List<String> list = new ArrayList<String>();
+ String element;
+}
+{
+ ( element = identifier() { return Collections.singletonList(element); } )
+ |
+ ( <LSQUARE> list = identifierList() <RSQUARE> { return list; } )
+}
// An identifier or integer
String tag() :