diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-28 08:54:41 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-28 08:54:41 +0100 |
commit | b9161cb0f3eec983af285e01fae9b28756f038a0 (patch) | |
tree | 082316e679ccd2a81d8fff12fffe67c798811a2a /searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java | |
parent | 2f55986b4de9420e5728c5abbaafb69fb2f10a34 (diff) |
Propagate set/getChildren
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 26d3f1dcc0e..93d551ebfd7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.functions.EvaluationContext; import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; @@ -39,7 +40,10 @@ public class TensorFunctionNode extends CompositeNode { @Override public CompositeNode setChildren(List<ExpressionNode> children) { - throw new UnsupportedOperationException("Not implemented"); + List<TensorFunction> wrappedChildren = children.stream() + .map(TensorFunctionExpressionNode::new) + .collect(Collectors.toList()); + return new TensorFunctionNode(function.replaceArguments(wrappedChildren)); } @Override @@ -71,7 +75,23 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> functionArguments() { + if (expression instanceof CompositeNode) + return ((CompositeNode)expression).children().stream() + .map(TensorFunctionExpressionNode::new) + .collect(Collectors.toList()); + else + return Collections.emptyList(); + } + + @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if (arguments.size() == 0) return this; + List<ExpressionNode> unwrappedChildren = arguments.stream() + .map(arg -> ((TensorFunctionExpressionNode)arg).expression) + .collect(Collectors.toList()); + return new TensorFunctionExpressionNode(((CompositeNode)expression).setChildren(unwrappedChildren)); + } @Override public PrimitiveTensorFunction toPrimitive() { return this; } |