aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-12-04 04:52:16 -0800
committerJon Bratseth <bratseth@verizonmedia.com>2019-12-04 04:52:16 -0800
commita71001d66ada9eaf4ae89d896fea60a39ea2056b (patch)
tree4cd063b18eb095407df8a4e244564d8bfc690de5 /searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
parent78b29e76f2bab63f7cec92f4c1fd9e7661602df7 (diff)
Propagate binding context to/from tensor functions
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java98
1 files changed, 83 insertions, 15 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index a0c261ae9d3..f510f38d7a7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -2,6 +2,8 @@
package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -16,8 +18,6 @@ import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
-import java.sql.Ref;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.LinkedHashMap;
@@ -71,7 +71,9 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext<Reference> context) { return function.type(context); }
+ public TensorType type(TypeContext<Reference> context) {
+ return function.type(context);
+ }
@Override
public Value evaluate(Context context) {
@@ -117,9 +119,16 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public String toString(ToStringContext c) {
- if (c instanceof ExpressionToStringContext) {
- ExpressionToStringContext context = (ExpressionToStringContext) c;
- return expression.toString(new StringBuilder(), context.context, context.path, context.parent).toString();
+ ToStringContext outermost = c;
+ while (outermost.wrapped() != null)
+ outermost = outermost.wrapped();
+
+ if (outermost instanceof ExpressionToStringContext) {
+ ExpressionToStringContext context = (ExpressionToStringContext)outermost;
+ return expression.toString(new StringBuilder(),
+ new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent),
+ context.path,
+ context.parent).toString();
}
else {
return expression.toString();
@@ -180,9 +189,17 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public String toString(ToStringContext c) {
- if (c instanceof ExpressionToStringContext) {
- ExpressionToStringContext context = (ExpressionToStringContext) c;
- return expression.toString(new StringBuilder(), context.context, context.path, context.parent).toString();
+ ToStringContext outermost = c;
+ while (outermost.wrapped() != null)
+ outermost = outermost.wrapped();
+
+ if (outermost instanceof ExpressionToStringContext) {
+ ExpressionToStringContext context = (ExpressionToStringContext)outermost;
+ return expression.toString(new StringBuilder(),
+ new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent),
+ context.path,
+ context.parent)
+ .toString();
}
else {
return expression.toString();
@@ -192,22 +209,73 @@ public class TensorFunctionNode extends CompositeNode {
}
/** Allows passing serialization context arguments through TensorFunctions */
- private static class ExpressionToStringContext implements ToStringContext {
+ private static class ExpressionToStringContext extends SerializationContext implements ToStringContext {
- final SerializationContext context;
- final Deque<String> path;
- final CompositeNode parent;
+ private final ToStringContext wrappedToStringContext;
+ private final SerializationContext wrappedSerializationContext;
+ private final Deque<String> path;
+ private final CompositeNode parent;
public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(),
null,
null);
- public ExpressionToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
- this.context = context;
+ ExpressionToStringContext(SerializationContext wrappedSerializationContext, Deque<String> path, CompositeNode parent) {
+ this(wrappedSerializationContext, null, path, parent);
+ }
+
+ ExpressionToStringContext(SerializationContext wrappedSerializationContext,
+ ToStringContext wrappedToStringContext,
+ Deque<String> path,
+ CompositeNode parent) {
+ this.wrappedSerializationContext = wrappedSerializationContext;
+ this.wrappedToStringContext = wrappedToStringContext;
this.path = path;
this.parent = parent;
}
+ /** Adds the serialization of a function */
+ public void addFunctionSerialization(String name, String expressionString) {
+ wrappedSerializationContext.addFunctionSerialization(name, expressionString);
+ }
+
+ /** Adds the serialization of the an argument type to a function */
+ public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) {
+ wrappedSerializationContext.addArgumentTypeSerialization(functionName, argumentName, type);
+ }
+
+ /** Adds the serialization of the return type of a function */
+ public void addFunctionTypeSerialization(String functionName, TensorType type) {
+ wrappedSerializationContext.addFunctionTypeSerialization(functionName, type);
+ }
+
+ public Map<String, String> serializedFunctions() {
+ return wrappedSerializationContext.serializedFunctions();
+ }
+
+ /** Returns a function or null if it isn't defined in this context */
+ public ExpressionFunction getFunction(String name) { return wrappedSerializationContext.getFunction(name); }
+
+ protected ImmutableMap<String, ExpressionFunction> functions() { return wrappedSerializationContext.functions(); }
+
+ public ToStringContext wrapped() { return wrappedToStringContext; }
+
+ /** Returns the resolution of an identifier, or null if it isn't defined in this context */
+ @Override
+ public String getBinding(String name) {
+ if (wrappedToStringContext != null && wrappedToStringContext.getBinding(name) != null)
+ return wrappedToStringContext.getBinding(name);
+ else
+ return wrappedSerializationContext.getBinding(name);
+ }
+
+ /** Returns a new context with the bindings replaced by the given bindings */
+ @Override
+ public ExpressionToStringContext withBindings(Map<String, String> bindings) {
+ return new ExpressionToStringContext(new SerializationContext(wrappedSerializationContext.functions().values(), bindings),
+ wrappedToStringContext, path, parent);
+ }
+
}
/** Turns an EvaluationContext into a Context */