diff options
Diffstat (limited to 'searchlib/src/main')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 26d3f1dcc0e..93d551ebfd7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.functions.EvaluationContext; import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; @@ -39,7 +40,10 @@ public class TensorFunctionNode extends CompositeNode { @Override public CompositeNode setChildren(List<ExpressionNode> children) { - throw new UnsupportedOperationException("Not implemented"); + List<TensorFunction> wrappedChildren = children.stream() + .map(TensorFunctionExpressionNode::new) + .collect(Collectors.toList()); + return new TensorFunctionNode(function.replaceArguments(wrappedChildren)); } @Override @@ -71,7 +75,23 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> functionArguments() { + if (expression instanceof CompositeNode) + return ((CompositeNode)expression).children().stream() + .map(TensorFunctionExpressionNode::new) + .collect(Collectors.toList()); + else + return Collections.emptyList(); + } + + @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if (arguments.size() == 0) return this; + List<ExpressionNode> unwrappedChildren = arguments.stream() + .map(arg -> ((TensorFunctionExpressionNode)arg).expression) + .collect(Collectors.toList()); + return new TensorFunctionExpressionNode(((CompositeNode)expression).setChildren(unwrappedChildren)); + } @Override public PrimitiveTensorFunction toPrimitive() { return this; } |