From bd6d16ba66a7b6745fc15a8b25dc7120fb5580ab Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Thu, 24 Nov 2016 14:42:06 +0100 Subject: Generalize tensor function handling --- .../rankingexpression/rule/LambdaFunctionNode.java | 10 +++ .../rankingexpression/rule/TensorFunctionNode.java | 1 - .../rankingexpression/rule/TensorJoinNode.java | 66 ------------------- .../rankingexpression/rule/TensorMapNode.java | 58 ----------------- .../rankingexpression/rule/TensorReduceNode.java | 71 -------------------- .../rankingexpression/rule/TensorRenameNode.java | 75 ---------------------- .../src/main/javacc/RankingExpressionParser.jj | 17 +++-- 7 files changed, 20 insertions(+), 278 deletions(-) delete mode 100644 searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java delete mode 100644 searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java delete mode 100644 searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java delete mode 100644 searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java (limited to 'searchlib/src/main') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index 593fa4bc45e..ef31bf6ba0d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -92,6 +92,11 @@ public class LambdaFunctionNode extends CompositeNode { context.put(arguments.get(0), operand); return evaluate(context).asDouble(); } + + @Override + public String toString() { + return LambdaFunctionNode.this.toString(); + } } @@ -107,6 +112,11 @@ public class LambdaFunctionNode extends CompositeNode { return evaluate(context).asDouble(); } + @Override + public String toString() { + return LambdaFunctionNode.this.toString(); + } + } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index ebd79f65578..26d3f1dcc0e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -11,7 +11,6 @@ import com.yahoo.tensor.functions.PrimitiveTensorFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; -import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.List; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java deleted file mode 100644 index 21455113578..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.google.common.annotations.Beta; -import com.google.common.collect.ImmutableList; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.Tensor; - -import java.util.Deque; -import java.util.List; - -/** - * A node which joins two tensors - * - * @author bratseth - */ - @Beta -public class TensorJoinNode extends CompositeNode { - - /** The tensor to aggregate over */ - private final ExpressionNode argument1, argument2; - - private final LambdaFunctionNode doubleJoiner; - - public TensorJoinNode(ExpressionNode argument1, ExpressionNode argument2, LambdaFunctionNode doubleJoiner) { - this.argument1 = argument1; - this.argument2 = argument2; - this.doubleJoiner = doubleJoiner; - } - - @Override - public List children() { - return ImmutableList.of(argument1, argument2, doubleJoiner); - } - - @Override - public CompositeNode setChildren(List children) { - if (children.size() != 3) - throw new IllegalArgumentException("A tensor join node must have two tensors and one joiner"); - return new TensorJoinNode(children.get(0), children.get(1), (LambdaFunctionNode)children.get(2)); - } - - @Override - public String toString(SerializationContext context, Deque path, CompositeNode parent) { - return "join(" + argument1.toString(context, path, parent) + ", " + - argument2.toString(context, path, parent) + ", " + - doubleJoiner.toString() + ")"; - } - - @Override - public Value evaluate(Context context) { - Tensor argument1Value = asTensor(argument1.evaluate(context), argument1); - Tensor argument2Value = asTensor(argument2.evaluate(context), argument2); - return new TensorValue(argument1Value.join(argument2Value, doubleJoiner.asDoubleBinaryOperator())); - } - - private Tensor asTensor(Value value, ExpressionNode producingNode) { - if ( ! ( value instanceof TensorValue)) - throw new IllegalArgumentException("Attempted to join '" + producingNode + "', " + - "but this returns " + value + ", not a tensor"); - return ((TensorValue)value).asTensor(); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java deleted file mode 100644 index 0cb0da150b4..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.google.common.annotations.Beta; -import com.google.common.collect.ImmutableList; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; - -import java.util.Deque; -import java.util.List; - -/** - * A node which maps the values of a tensor - * - * @author bratseth - */ - @Beta -public class TensorMapNode extends CompositeNode { - - /** The tensor to aggregate over */ - private final ExpressionNode argument; - - private final LambdaFunctionNode doubleMapper; - - public TensorMapNode(ExpressionNode argument, LambdaFunctionNode doubleMapper) { - this.argument = argument; - this.doubleMapper = doubleMapper; - } - - @Override - public List children() { - return ImmutableList.of(argument, doubleMapper); - } - - @Override - public CompositeNode setChildren(List children) { - if (children.size() != 2) - throw new IllegalArgumentException("A tensor map node must have one tensor and one mapper"); - return new TensorMapNode(children.get(0), (LambdaFunctionNode)children.get(1)); - } - - @Override - public String toString(SerializationContext context, Deque path, CompositeNode parent) { - return "map(" + argument.toString(context, path, parent) + ", " + doubleMapper.toString() + ")"; - } - - @Override - public Value evaluate(Context context) { - Value argumentValue = argument.evaluate(context); - if ( ! ( argumentValue instanceof TensorValue)) - throw new IllegalArgumentException("Attempted to map '" + argument + "', " + - "but this returns " + argumentValue + ", not a tensor"); - TensorValue tensorArgument = (TensorValue)argumentValue; - return new TensorValue(tensorArgument.asTensor().map(doubleMapper.asDoubleUnaryOperator())); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java deleted file mode 100644 index d4b95d12fdd..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.google.common.annotations.Beta; -import com.google.common.collect.ImmutableList; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.functions.Reduce; - -import java.util.Collections; -import java.util.Deque; -import java.util.List; - -/** - * A node which performs a dimension reduction over a tensor - * - * @author bratseth - */ - @Beta -public class TensorReduceNode extends CompositeNode { - - /** The tensor to aggregate over */ - private final ExpressionNode argument; - - private final Reduce.Aggregator aggregator; - - /** The dimensions to sum over, or empty to sum all cells */ - private final ImmutableList dimensions; - - public TensorReduceNode(ExpressionNode argument, Reduce.Aggregator aggregator, List dimensions) { - this.argument = argument; - this.aggregator = aggregator; - this.dimensions = ImmutableList.copyOf(dimensions); - } - - @Override - public List children() { - return Collections.singletonList(argument); - } - - @Override - public CompositeNode setChildren(List children) { - if (children.size() != 1) throw new IllegalArgumentException("A tensor reduce node must have one tensor argument"); - return new TensorReduceNode(children.get(0), aggregator, dimensions); - } - - @Override - public String toString(SerializationContext context, Deque path, CompositeNode parent) { - return "reduce(" + argument.toString(context, path, parent) + ", " + - aggregator + leadingCommaSeparated(dimensions) + ")"; - } - - private String leadingCommaSeparated(List list) { - StringBuilder b = new StringBuilder(); - for (String element : list) - b.append(", ").append(element); - return b.toString(); - } - - @Override - public Value evaluate(Context context) { - Value argumentValue = argument.evaluate(context); - if ( ! ( argumentValue instanceof TensorValue)) - throw new IllegalArgumentException("Attempted to reduce '" + argument + "', " + - "but this returns " + argumentValue + ", not a tensor"); - TensorValue tensorArgument = (TensorValue)argumentValue; - return new TensorValue(tensorArgument.asTensor().reduce(aggregator, dimensions)); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java deleted file mode 100644 index b7f21c215dc..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.google.common.annotations.Beta; -import com.google.common.collect.ImmutableList; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; - -import java.util.Collections; -import java.util.Deque; -import java.util.List; - -/** - * A node which performs a dimension rename in a tensor - * - * @author bratseth - */ - @Beta -public class TensorRenameNode extends CompositeNode { - - private final ExpressionNode argument; - - private final ImmutableList fromDimensions, toDimensions; - - public TensorRenameNode(ExpressionNode argument, List fromDimensions, List toDimensions) { - if (fromDimensions.size() < 1) - throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension"); - if (fromDimensions.size() != toDimensions.size()) - throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " + - fromDimensions.size() + " and " + toDimensions.size()); - this.argument = argument; - this.fromDimensions = ImmutableList.copyOf(fromDimensions); - this.toDimensions = ImmutableList.copyOf(toDimensions); - } - - @Override - public List children() { - return Collections.singletonList(argument); - } - - @Override - public CompositeNode setChildren(List children) { - if (children.size() != 1) throw new IllegalArgumentException("A tensor rename node must have one tensor argument"); - return new TensorRenameNode(children.get(0), fromDimensions, toDimensions); - } - - @Override - public String toString(SerializationContext context, Deque path, CompositeNode parent) { - return "rename(" + argument.toString(context, path, parent) + ", " + - vector(fromDimensions) + ", " + vector(toDimensions) + ")"; - } - - private String vector(List list) { - if (list.size() == 1) return list.get(0); - - StringBuilder b = new StringBuilder("["); - for (String element : list) - b.append(element).append(","); - b.setLength(b.length() - 1); - b.append("]"); - return b.toString(); - } - - @Override - public Value evaluate(Context context) { - Value argumentValue = argument.evaluate(context); - if ( ! ( argumentValue instanceof TensorValue)) - throw new IllegalArgumentException("Attempted to rename dimensions in '" + argument + "', " + - "but this returns " + argumentValue + ", not a tensor"); - TensorValue tensorArgument = (TensorValue)argumentValue; - return new TensorValue(tensorArgument.asTensor().rename(fromDimensions, toDimensions)); - } - -} diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index a2f6366f101..d5f4c67e62a 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -367,7 +367,8 @@ ExpressionNode tensorMap() : } { tensor = expression() doubleMapper = lambdaFunction() - { return new TensorMapNode(tensor, doubleMapper); } + { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor), + doubleMapper.asDoubleUnaryOperator())); } } ExpressionNode tensorReduce() : @@ -378,7 +379,7 @@ ExpressionNode tensorReduce() : } { tensor = expression() aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() - { return new TensorReduceNode(tensor, aggregator, dimensions); } + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } } ExpressionNode tensorReduceComposites() : @@ -390,7 +391,7 @@ ExpressionNode tensorReduceComposites() : { aggregator = tensorReduceAggregator() tensor = expression() dimensions = tagCommaLeadingList() - { return new TensorReduceNode(tensor, aggregator, dimensions); } + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } } ExpressionNode tensorJoin() : @@ -400,7 +401,9 @@ ExpressionNode tensorJoin() : } { tensor1 = expression() tensor2 = expression() doubleJoiner = lambdaFunction() - { return new TensorJoinNode(tensor1, tensor2, doubleJoiner); } + { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + doubleJoiner.asDoubleBinaryOperator())); } } ExpressionNode tensorRename() : @@ -413,10 +416,10 @@ ExpressionNode tensorRename() : fromDimensions = bracedIdentifierList() toDimensions = bracedIdentifierList() - { return new TensorRenameNode(tensor, fromDimensions, toDimensions); } + { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); } } -// TODO +// TODO: Notice that null is parsed below ExpressionNode tensorGenerate() : { TensorType type; @@ -424,7 +427,7 @@ ExpressionNode tensorGenerate() : } { - { return null; } + { return new TensorFunctionNode(new Generate(null, null)); } } ExpressionNode tensorL1Normalize() : -- cgit v1.2.3