aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-28 08:54:41 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-28 08:54:41 +0100
commitb9161cb0f3eec983af285e01fae9b28756f038a0 (patch)
tree082316e679ccd2a81d8fff12fffe67c798811a2a /searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
parent2f55986b4de9420e5728c5abbaafb69fb2f10a34 (diff)
Propagate set/getChildren
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java24
1 files changed, 22 insertions, 2 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 26d3f1dcc0e..93d551ebfd7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.EvaluationContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
+import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
@@ -39,7 +40,10 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
- throw new UnsupportedOperationException("Not implemented");
+ List<TensorFunction> wrappedChildren = children.stream()
+ .map(TensorFunctionExpressionNode::new)
+ .collect(Collectors.toList());
+ return new TensorFunctionNode(function.replaceArguments(wrappedChildren));
}
@Override
@@ -71,7 +75,23 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> functionArguments() {
+ if (expression instanceof CompositeNode)
+ return ((CompositeNode)expression).children().stream()
+ .map(TensorFunctionExpressionNode::new)
+ .collect(Collectors.toList());
+ else
+ return Collections.emptyList();
+ }
+
+ @Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if (arguments.size() == 0) return this;
+ List<ExpressionNode> unwrappedChildren = arguments.stream()
+ .map(arg -> ((TensorFunctionExpressionNode)arg).expression)
+ .collect(Collectors.toList());
+ return new TensorFunctionExpressionNode(((CompositeNode)expression).setChildren(unwrappedChildren));
+ }
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }