diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-08 15:23:49 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-08 15:23:49 +0100 |
commit | cdcafc6fc8b4417abab8c72bbce5c503533558ea (patch) | |
tree | ea4b1c1a64da79bf42270d1b59f10ae32013b4d8 /searchlib/src/main | |
parent | df287b9364b8088192146df70f5f4814ff6c94c1 (diff) |
Serialize scalar functions with context
Diffstat (limited to 'searchlib/src/main')
2 files changed, 43 insertions, 32 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index e6e49e07c34..4ffd40f00f7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -12,6 +12,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; @@ -21,7 +22,6 @@ import java.util.Deque; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.function.Function; import java.util.stream.Collectors; /** @@ -49,8 +49,8 @@ public class TensorFunctionNode extends CompositeNode { } private ExpressionNode toExpressionNode(TensorFunction f) { - if (f instanceof TensorFunctionExpressionNode) - return ((TensorFunctionExpressionNode)f).expression; + if (f instanceof ExpressionTensorFunction) + return ((ExpressionTensorFunction)f).expression; else return new TensorFunctionNode(f); } @@ -58,7 +58,7 @@ public class TensorFunctionNode extends CompositeNode { @Override public CompositeNode setChildren(List<ExpressionNode> children) { List<TensorFunction> wrappedChildren = children.stream() - .map(TensorFunctionExpressionNode::new) + .map(ExpressionTensorFunction::new) .collect(Collectors.toList()); return new TensorFunctionNode(function.withArguments(wrappedChildren)); } @@ -66,7 +66,7 @@ public class TensorFunctionNode extends CompositeNode { @Override public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) { // Serialize as primitive - return string.append(function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this))); + return string.append(function.toPrimitive().toString(new ExpressionToStringContext(context, path, this))); } @Override @@ -77,29 +77,29 @@ public class TensorFunctionNode extends CompositeNode { return new TensorValue(function.evaluate(context)); } - public static TensorFunctionExpressionNode wrap(ExpressionNode node) { - return new TensorFunctionExpressionNode(node); + public static ExpressionTensorFunction wrap(ExpressionNode node) { + return new ExpressionTensorFunction(node); } - public static Map<TensorAddress, Function<EvaluationContext<?>, Double>> wrap(Map<TensorAddress, ExpressionNode> nodes) { - Map<TensorAddress, Function<EvaluationContext<?>, Double>> closures = new LinkedHashMap<>(); + public static Map<TensorAddress, ScalarFunction> wrap(Map<TensorAddress, ExpressionNode> nodes) { + Map<TensorAddress, ScalarFunction> functions = new LinkedHashMap<>(); for (var entry : nodes.entrySet()) - closures.put(entry.getKey(), new ExpressionClosure(entry.getValue())); - return closures; + functions.put(entry.getKey(), new ExpressionScalarFunction(entry.getValue())); + return functions; } - public static List<Function<EvaluationContext<?>, Double>> wrap(List<ExpressionNode> nodes) { - List<Function<EvaluationContext<?>, Double>> closures = new ArrayList<>(); + public static List<ScalarFunction> wrap(List<ExpressionNode> nodes) { + List<ScalarFunction> functions = new ArrayList<>(); for (var entry : nodes) - closures.add(new ExpressionClosure(entry)); - return closures; + functions.add(new ExpressionScalarFunction(entry)); + return functions; } - private static class ExpressionClosure implements java.util.function.Function<EvaluationContext<?> , Double> { + private static class ExpressionScalarFunction implements ScalarFunction { private final ExpressionNode expression; - public ExpressionClosure(ExpressionNode expression) { + public ExpressionScalarFunction(ExpressionNode expression) { this.expression = expression; } @@ -110,7 +110,18 @@ public class TensorFunctionNode extends CompositeNode { @Override public String toString() { - return expression.toString(); + return toString(ExpressionToStringContext.empty); + } + + @Override + public String toString(ToStringContext c) { + if (c instanceof ExpressionToStringContext) { + ExpressionToStringContext context = (ExpressionToStringContext) c; + return expression.toString(new StringBuilder(),context.context, context.path, context.parent).toString(); + } + else { + return expression.toString(); + } } } @@ -119,12 +130,12 @@ public class TensorFunctionNode extends CompositeNode { * A tensor function implemented by an expression. * This allows us to pass expressions as tensor function arguments. */ - public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction { + public static class ExpressionTensorFunction extends PrimitiveTensorFunction { /** An expression which produces a tensor */ private final ExpressionNode expression; - public TensorFunctionExpressionNode(ExpressionNode expression) { + public ExpressionTensorFunction(ExpressionNode expression) { this.expression = expression; } @@ -132,7 +143,7 @@ public class TensorFunctionNode extends CompositeNode { public List<TensorFunction> arguments() { if (expression instanceof CompositeNode) return ((CompositeNode)expression).children().stream() - .map(TensorFunctionExpressionNode::new) + .map(ExpressionTensorFunction::new) .collect(Collectors.toList()); else return Collections.emptyList(); @@ -142,9 +153,9 @@ public class TensorFunctionNode extends CompositeNode { public TensorFunction withArguments(List<TensorFunction> arguments) { if (arguments.size() == 0) return this; List<ExpressionNode> unwrappedChildren = arguments.stream() - .map(arg -> ((TensorFunctionExpressionNode)arg).expression) + .map(arg -> ((ExpressionTensorFunction)arg).expression) .collect(Collectors.toList()); - return new TensorFunctionExpressionNode(((CompositeNode)expression).setChildren(unwrappedChildren)); + return new ExpressionTensorFunction(((CompositeNode)expression).setChildren(unwrappedChildren)); } @Override @@ -163,13 +174,13 @@ public class TensorFunctionNode extends CompositeNode { @Override public String toString() { - return toString(ExpressionNodeToStringContext.empty); + return toString(ExpressionToStringContext.empty); } @Override public String toString(ToStringContext c) { - if (c instanceof ExpressionNodeToStringContext) { - ExpressionNodeToStringContext context = (ExpressionNodeToStringContext) c; + if (c instanceof ExpressionToStringContext) { + ExpressionToStringContext context = (ExpressionToStringContext) c; return expression.toString(new StringBuilder(),context.context, context.path, context.parent).toString(); } else { @@ -180,17 +191,17 @@ public class TensorFunctionNode extends CompositeNode { } /** Allows passing serialization context arguments through TensorFunctions */ - private static class ExpressionNodeToStringContext implements ToStringContext { + private static class ExpressionToStringContext implements ToStringContext { final SerializationContext context; final Deque<String> path; final CompositeNode parent; - public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(new SerializationContext(), - null, - null); + public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(), + null, + null); - public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { + public ExpressionToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { this.context = context; this.path = path; this.parent = parent; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java index 6d687b015f1..9a38b5efc1f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java @@ -83,7 +83,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E ExpressionNode arg1 = node.children().get(0); ExpressionNode arg2 = node.children().get(1); - TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrap(arg1); + TensorFunctionNode.ExpressionTensorFunction expression = TensorFunctionNode.wrap(arg1); Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); String dimension = ((ReferenceNode) arg2).getName(); |