summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-24 14:42:06 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-24 14:42:06 +0100
commitbd6d16ba66a7b6745fc15a8b25dc7120fb5580ab (patch)
tree839ab08d3b894d86185aeb61595f261ae4fc5922 /searchlib/src/main
parentb55f0fa91a7301539f3b7a7a1fcd59f73a541fdc (diff)
Generalize tensor function handling
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java66
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java58
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java71
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java75
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj17
7 files changed, 20 insertions, 278 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index 593fa4bc45e..ef31bf6ba0d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -92,6 +92,11 @@ public class LambdaFunctionNode extends CompositeNode {
context.put(arguments.get(0), operand);
return evaluate(context).asDouble();
}
+
+ @Override
+ public String toString() {
+ return LambdaFunctionNode.this.toString();
+ }
}
@@ -107,6 +112,11 @@ public class LambdaFunctionNode extends CompositeNode {
return evaluate(context).asDouble();
}
+ @Override
+ public String toString() {
+ return LambdaFunctionNode.this.toString();
+ }
+
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index ebd79f65578..26d3f1dcc0e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -11,7 +11,6 @@ import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java
deleted file mode 100644
index 21455113578..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.google.common.annotations.Beta;
-import com.google.common.collect.ImmutableList;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.Tensor;
-
-import java.util.Deque;
-import java.util.List;
-
-/**
- * A node which joins two tensors
- *
- * @author bratseth
- */
- @Beta
-public class TensorJoinNode extends CompositeNode {
-
- /** The tensor to aggregate over */
- private final ExpressionNode argument1, argument2;
-
- private final LambdaFunctionNode doubleJoiner;
-
- public TensorJoinNode(ExpressionNode argument1, ExpressionNode argument2, LambdaFunctionNode doubleJoiner) {
- this.argument1 = argument1;
- this.argument2 = argument2;
- this.doubleJoiner = doubleJoiner;
- }
-
- @Override
- public List<ExpressionNode> children() {
- return ImmutableList.of(argument1, argument2, doubleJoiner);
- }
-
- @Override
- public CompositeNode setChildren(List<ExpressionNode> children) {
- if (children.size() != 3)
- throw new IllegalArgumentException("A tensor join node must have two tensors and one joiner");
- return new TensorJoinNode(children.get(0), children.get(1), (LambdaFunctionNode)children.get(2));
- }
-
- @Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- return "join(" + argument1.toString(context, path, parent) + ", " +
- argument2.toString(context, path, parent) + ", " +
- doubleJoiner.toString() + ")";
- }
-
- @Override
- public Value evaluate(Context context) {
- Tensor argument1Value = asTensor(argument1.evaluate(context), argument1);
- Tensor argument2Value = asTensor(argument2.evaluate(context), argument2);
- return new TensorValue(argument1Value.join(argument2Value, doubleJoiner.asDoubleBinaryOperator()));
- }
-
- private Tensor asTensor(Value value, ExpressionNode producingNode) {
- if ( ! ( value instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to join '" + producingNode + "', " +
- "but this returns " + value + ", not a tensor");
- return ((TensorValue)value).asTensor();
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java
deleted file mode 100644
index 0cb0da150b4..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java
+++ /dev/null
@@ -1,58 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.google.common.annotations.Beta;
-import com.google.common.collect.ImmutableList;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-
-import java.util.Deque;
-import java.util.List;
-
-/**
- * A node which maps the values of a tensor
- *
- * @author bratseth
- */
- @Beta
-public class TensorMapNode extends CompositeNode {
-
- /** The tensor to aggregate over */
- private final ExpressionNode argument;
-
- private final LambdaFunctionNode doubleMapper;
-
- public TensorMapNode(ExpressionNode argument, LambdaFunctionNode doubleMapper) {
- this.argument = argument;
- this.doubleMapper = doubleMapper;
- }
-
- @Override
- public List<ExpressionNode> children() {
- return ImmutableList.of(argument, doubleMapper);
- }
-
- @Override
- public CompositeNode setChildren(List<ExpressionNode> children) {
- if (children.size() != 2)
- throw new IllegalArgumentException("A tensor map node must have one tensor and one mapper");
- return new TensorMapNode(children.get(0), (LambdaFunctionNode)children.get(1));
- }
-
- @Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- return "map(" + argument.toString(context, path, parent) + ", " + doubleMapper.toString() + ")";
- }
-
- @Override
- public Value evaluate(Context context) {
- Value argumentValue = argument.evaluate(context);
- if ( ! ( argumentValue instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to map '" + argument + "', " +
- "but this returns " + argumentValue + ", not a tensor");
- TensorValue tensorArgument = (TensorValue)argumentValue;
- return new TensorValue(tensorArgument.asTensor().map(doubleMapper.asDoubleUnaryOperator()));
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
deleted file mode 100644
index d4b95d12fdd..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.google.common.annotations.Beta;
-import com.google.common.collect.ImmutableList;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.functions.Reduce;
-
-import java.util.Collections;
-import java.util.Deque;
-import java.util.List;
-
-/**
- * A node which performs a dimension reduction over a tensor
- *
- * @author bratseth
- */
- @Beta
-public class TensorReduceNode extends CompositeNode {
-
- /** The tensor to aggregate over */
- private final ExpressionNode argument;
-
- private final Reduce.Aggregator aggregator;
-
- /** The dimensions to sum over, or empty to sum all cells */
- private final ImmutableList<String> dimensions;
-
- public TensorReduceNode(ExpressionNode argument, Reduce.Aggregator aggregator, List<String> dimensions) {
- this.argument = argument;
- this.aggregator = aggregator;
- this.dimensions = ImmutableList.copyOf(dimensions);
- }
-
- @Override
- public List<ExpressionNode> children() {
- return Collections.singletonList(argument);
- }
-
- @Override
- public CompositeNode setChildren(List<ExpressionNode> children) {
- if (children.size() != 1) throw new IllegalArgumentException("A tensor reduce node must have one tensor argument");
- return new TensorReduceNode(children.get(0), aggregator, dimensions);
- }
-
- @Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- return "reduce(" + argument.toString(context, path, parent) + ", " +
- aggregator + leadingCommaSeparated(dimensions) + ")";
- }
-
- private String leadingCommaSeparated(List<String> list) {
- StringBuilder b = new StringBuilder();
- for (String element : list)
- b.append(", ").append(element);
- return b.toString();
- }
-
- @Override
- public Value evaluate(Context context) {
- Value argumentValue = argument.evaluate(context);
- if ( ! ( argumentValue instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to reduce '" + argument + "', " +
- "but this returns " + argumentValue + ", not a tensor");
- TensorValue tensorArgument = (TensorValue)argumentValue;
- return new TensorValue(tensorArgument.asTensor().reduce(aggregator, dimensions));
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java
deleted file mode 100644
index b7f21c215dc..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.google.common.annotations.Beta;
-import com.google.common.collect.ImmutableList;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-
-import java.util.Collections;
-import java.util.Deque;
-import java.util.List;
-
-/**
- * A node which performs a dimension rename in a tensor
- *
- * @author bratseth
- */
- @Beta
-public class TensorRenameNode extends CompositeNode {
-
- private final ExpressionNode argument;
-
- private final ImmutableList<String> fromDimensions, toDimensions;
-
- public TensorRenameNode(ExpressionNode argument, List<String> fromDimensions, List<String> toDimensions) {
- if (fromDimensions.size() < 1)
- throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension");
- if (fromDimensions.size() != toDimensions.size())
- throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " +
- fromDimensions.size() + " and " + toDimensions.size());
- this.argument = argument;
- this.fromDimensions = ImmutableList.copyOf(fromDimensions);
- this.toDimensions = ImmutableList.copyOf(toDimensions);
- }
-
- @Override
- public List<ExpressionNode> children() {
- return Collections.singletonList(argument);
- }
-
- @Override
- public CompositeNode setChildren(List<ExpressionNode> children) {
- if (children.size() != 1) throw new IllegalArgumentException("A tensor rename node must have one tensor argument");
- return new TensorRenameNode(children.get(0), fromDimensions, toDimensions);
- }
-
- @Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- return "rename(" + argument.toString(context, path, parent) + ", " +
- vector(fromDimensions) + ", " + vector(toDimensions) + ")";
- }
-
- private String vector(List<String> list) {
- if (list.size() == 1) return list.get(0);
-
- StringBuilder b = new StringBuilder("[");
- for (String element : list)
- b.append(element).append(",");
- b.setLength(b.length() - 1);
- b.append("]");
- return b.toString();
- }
-
- @Override
- public Value evaluate(Context context) {
- Value argumentValue = argument.evaluate(context);
- if ( ! ( argumentValue instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to rename dimensions in '" + argument + "', " +
- "but this returns " + argumentValue + ", not a tensor");
- TensorValue tensorArgument = (TensorValue)argumentValue;
- return new TensorValue(tensorArgument.asTensor().rename(fromDimensions, toDimensions));
- }
-
-}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index a2f6366f101..d5f4c67e62a 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -367,7 +367,8 @@ ExpressionNode tensorMap() :
}
{
<MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE>
- { return new TensorMapNode(tensor, doubleMapper); }
+ { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor),
+ doubleMapper.asDoubleUnaryOperator())); }
}
ExpressionNode tensorReduce() :
@@ -378,7 +379,7 @@ ExpressionNode tensorReduce() :
}
{
<REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorReduceNode(tensor, aggregator, dimensions); }
+ { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
}
ExpressionNode tensorReduceComposites() :
@@ -390,7 +391,7 @@ ExpressionNode tensorReduceComposites() :
{
aggregator = tensorReduceAggregator()
<LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorReduceNode(tensor, aggregator, dimensions); }
+ { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
}
ExpressionNode tensorJoin() :
@@ -400,7 +401,9 @@ ExpressionNode tensorJoin() :
}
{
<JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE>
- { return new TensorJoinNode(tensor1, tensor2, doubleJoiner); }
+ { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1),
+ TensorFunctionNode.wrapArgument(tensor2),
+ doubleJoiner.asDoubleBinaryOperator())); }
}
ExpressionNode tensorRename() :
@@ -413,10 +416,10 @@ ExpressionNode tensorRename() :
fromDimensions = bracedIdentifierList() <COMMA>
toDimensions = bracedIdentifierList()
<RBRACE>
- { return new TensorRenameNode(tensor, fromDimensions, toDimensions); }
+ { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); }
}
-// TODO
+// TODO: Notice that null is parsed below
ExpressionNode tensorGenerate() :
{
TensorType type;
@@ -424,7 +427,7 @@ ExpressionNode tensorGenerate() :
}
{
<TENSOR> <LBRACE> <RBRACE> <LBRACE>
- { return null; }
+ { return new TensorFunctionNode(new Generate(null, null)); }
}
ExpressionNode tensorL1Normalize() :