diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-23 16:28:07 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-23 16:28:07 +0100 |
commit | eb7a97abcdf72222d37459a986217ec5ad5cdacd (patch) | |
tree | 60ed52388d2896170dd6ae89bf39809303ec90d1 /searchlib/src/main | |
parent | f65c80a1fb5fdc285ce0db63b3b1f039f5201505 (diff) |
Implement tensor rename
Diffstat (limited to 'searchlib/src/main')
3 files changed, 126 insertions, 9 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java index 65a1802c72d..4f73c632422 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -25,12 +26,12 @@ public class TensorReduceNode extends CompositeNode { private final ReduceFunction.Aggregator aggregator; /** The dimensions to sum over, or empty to sum all cells */ - private final List<String> dimensions; + private final ImmutableList<String> dimensions; public TensorReduceNode(ExpressionNode argument, ReduceFunction.Aggregator aggregator, List<String> dimensions) { this.argument = argument; this.aggregator = aggregator; - this.dimensions = dimensions; + this.dimensions = ImmutableList.copyOf(dimensions); } @Override @@ -40,16 +41,17 @@ public class TensorReduceNode extends CompositeNode { @Override public CompositeNode setChildren(List<ExpressionNode> children) { - if (children.size() != 1) throw new IllegalArgumentException("A tensor sum node must have one tensor argument"); + if (children.size() != 1) throw new IllegalArgumentException("A tensor reduce node must have one tensor argument"); return new TensorReduceNode(children.get(0), aggregator, dimensions); } @Override public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { - return "reduce(" + argument.toString(context, path, parent) + ", \"" + aggregator + "\"" + commaSeparated(dimensions) + ")"; + return "reduce(" + argument.toString(context, path, parent) + ", " + + aggregator + leadingCommaSeparated(dimensions) + ")"; } - private String commaSeparated(List<String> list) { + private String leadingCommaSeparated(List<String> list) { StringBuilder b = new StringBuilder(); for (String element : list) b.append(", ").append(element); @@ -60,7 +62,7 @@ public class TensorReduceNode extends CompositeNode { public Value evaluate(Context context) { Value argumentValue = argument.evaluate(context); if ( ! ( argumentValue instanceof TensorValue)) - throw new IllegalArgumentException("Attempted to take the tensor sum of argument '" + argument + "', " + + throw new IllegalArgumentException("Attempted to reduce '" + argument + "', " + "but this returns " + argumentValue + ", not a tensor"); TensorValue tensorArgument = (TensorValue)argumentValue; return new TensorValue(tensorArgument.asTensor().reduce(aggregator, dimensions)); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java new file mode 100644 index 00000000000..17a08beba8b --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java @@ -0,0 +1,76 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.functions.ReduceFunction; + +import java.util.Collections; +import java.util.Deque; +import java.util.List; + +/** + * A node which performs a dimension rename in a tensor + * + * @author bratseth + */ + @Beta +public class TensorRenameNode extends CompositeNode { + + private final ExpressionNode argument; + + private final ImmutableList<String> fromDimensions, toDimensions; + + public TensorRenameNode(ExpressionNode argument, List<String> fromDimensions, List<String> toDimensions) { + if (fromDimensions.size() < 1) + throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension"); + if (fromDimensions.size() != toDimensions.size()) + throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " + + fromDimensions.size() + " and " + toDimensions.size()); + this.argument = argument; + this.fromDimensions = ImmutableList.copyOf(fromDimensions); + this.toDimensions = ImmutableList.copyOf(toDimensions); + } + + @Override + public List<ExpressionNode> children() { + return Collections.singletonList(argument); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if (children.size() != 1) throw new IllegalArgumentException("A tensor rename node must have one tensor argument"); + return new TensorRenameNode(children.get(0), fromDimensions, toDimensions); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return "rename(" + argument.toString(context, path, parent) + ", " + + vector(fromDimensions) + ", " + vector(toDimensions) + ")"; + } + + private String vector(List<String> list) { + if (list.size() == 1) return list.get(0); + + StringBuilder b = new StringBuilder("["); + for (String element : list) + b.append(element).append(","); + b.setLength(b.length() - 1); + b.append("]"); + return b.toString(); + } + + @Override + public Value evaluate(Context context) { + Value argumentValue = argument.evaluate(context); + if ( ! ( argumentValue instanceof TensorValue)) + throw new IllegalArgumentException("Attempted to rename dimensions in '" + argument + "', " + + "but this returns " + argumentValue + ", not a tensor"); + TensorValue tensorArgument = (TensorValue)argumentValue; + return new TensorValue(tensorArgument.asTensor().rename(fromDimensions, toDimensions)); + } + +} diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 5a5d916f7e7..b3060e1ff45 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -21,8 +21,7 @@ import com.yahoo.searchlib.rankingexpression.rule.*; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.tensor.MapTensor; -import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.*; import com.yahoo.tensor.functions.*; import java.util.Collections; import java.util.Map; @@ -104,6 +103,8 @@ TOKEN : <MAP: "map"> | <REDUCE: "reduce"> | <JOIN: "join"> | + <RENAME: "rename"> | + <TENSOR: "tensor"> | <AVG: "avg" > | <COUNT: "count"> | <PROD: "prod"> | @@ -324,7 +325,9 @@ ExpressionNode tensorFunction() : tensorExpression = tensorMap() | tensorExpression = tensorReduce() | tensorExpression = tensorReduceComposites() | - tensorExpression = tensorJoin() + tensorExpression = tensorJoin() | + tensorExpression = tensorRename() | + tensorExpression = tensorGenerate() ) { return tensorExpression; } } @@ -372,6 +375,30 @@ ExpressionNode tensorJoin() : { return new TensorJoinNode(tensor1, tensor2, doubleJoiner); } } +ExpressionNode tensorRename() : +{ + ExpressionNode tensor; + List<String> fromDimensions, toDimensions; +} +{ + <RENAME> <LBRACE> tensor = expression() <COMMA> + fromDimensions = identifierVector() <COMMA> + toDimensions = identifierVector() + <RBRACE> + { return new TensorRenameNode(tensor, fromDimensions, toDimensions); } +} + +// TODO +ExpressionNode tensorGenerate() : +{ + TensorType type; + LambdaFunctionNode generator; +} +{ + <TENSOR> <LBRACE> <RBRACE> <LBRACE> + { return null; } +} + LambdaFunctionNode lambdaFunction() : { List<String> variables; @@ -400,6 +427,8 @@ String tensorFunctionName() : ( <MAP> { return token.image; } ) | ( <REDUCE> { return token.image; } ) | ( <JOIN> { return token.image; } ) | + ( <RENAME> { return token.image; } ) | + ( <TENSOR> { return token.image; } ) | ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } @@ -483,6 +512,16 @@ List<String> identifierList() : { return list; } } +List<String> identifierVector() : +{ + List<String> list = new ArrayList<String>(); + String element; +} +{ + ( element = identifier() { return Collections.singletonList(element); } ) + | + ( <LSQUARE> list = identifierList() <RSQUARE> { return list; } ) +} // An identifier or integer String tag() : |