summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java89
1 files changed, 38 insertions, 51 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java
index 70a7372dbe9..971c2c4f218 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java
@@ -3,7 +3,6 @@ package com.yahoo.searchdefinition.expressiontransforms;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
-import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.document.Attribute;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
@@ -17,12 +16,12 @@ import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import java.util.List;
-import java.util.Map;
import java.util.Optional;
/**
@@ -36,32 +35,22 @@ import java.util.Optional;
*/
public class TensorTransformer extends ExpressionTransformer {
- private Search search;
- private RankProfile rankprofile;
- private Map<String, RankProfile.Macro> macros;
-
- public TensorTransformer(RankProfile rankprofile) {
- this.rankprofile = rankprofile;
- this.search = rankprofile.getSearch();
- this.macros = rankprofile.getMacros();
- }
-
@Override
- public ExpressionNode transform(ExpressionNode node) {
+ public ExpressionNode transform(ExpressionNode node, TransformContext context) {
if (node instanceof CompositeNode) {
- node = transformChildren((CompositeNode) node);
+ node = transformChildren((CompositeNode) node, context);
}
if (node instanceof FunctionNode) {
- node = transformFunctionNode((FunctionNode) node);
+ node = transformFunctionNode((FunctionNode) node, ((RankProfileTransformContext)context).rankProfile());
}
return node;
}
- private ExpressionNode transformFunctionNode(FunctionNode node) {
+ private ExpressionNode transformFunctionNode(FunctionNode node, RankProfile profile) {
switch (node.getFunction()) {
case min:
case max:
- return transformMaxAndMinFunctionNode(node);
+ return transformMaxAndMinFunctionNode(node, profile);
}
return node;
}
@@ -80,7 +69,7 @@ public class TensorTransformer extends ExpressionTransformer {
* There is currently no guarantee that all cases will be found. For
* instance, if-statements are problematic.
*/
- private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node) {
+ private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node, RankProfile profile) {
if (node.children().size() != 2) {
return node;
}
@@ -88,7 +77,7 @@ public class TensorTransformer extends ExpressionTransformer {
Optional<String> dimension = dimensionName(node.children().get(1));
if (dimension.isPresent()) {
try {
- Context context = buildContext(arg1);
+ Context context = buildContext(arg1, profile);
Value value = arg1.evaluate(context);
if (isTensorWithDimension(value, dimension.get())) {
return replaceMaxAndMinFunction(node);
@@ -110,12 +99,10 @@ public class TensorTransformer extends ExpressionTransformer {
}
private boolean isTensorWithDimension(Value value, String dimension) {
- if (value instanceof TensorValue) {
- Tensor tensor = ((TensorValue) value).asTensor();
- TensorType type = tensor.type();
- return type.dimensionNames().contains(dimension);
- }
- return false;
+ if (value instanceof TensorValue)
+ return value.asTensor().type().dimensionNames().contains(dimension);
+ else
+ return false;
}
private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) {
@@ -133,9 +120,9 @@ public class TensorTransformer extends ExpressionTransformer {
* Creates an evaluation context by iterating through the expression tree, and
* adding dummy values with correct types to the context.
*/
- private Context buildContext(ExpressionNode node) {
+ private Context buildContext(ExpressionNode node, RankProfile profile) {
Context context = new MapContext();
- addRoot(node, context);
+ addRoot(node, context, profile);
return context;
}
@@ -152,28 +139,28 @@ public class TensorTransformer extends ExpressionTransformer {
return new TensorValue(empty);
}
- private void addRoot(ExpressionNode node, Context context) {
- addChildren(node, context);
+ private void addRoot(ExpressionNode node, Context context, RankProfile profile) {
+ addChildren(node, context, profile);
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) node;
- addIfAttribute(referenceNode, context);
- addIfConstant(referenceNode, context);
- addIfQuery(referenceNode, context);
+ addIfAttribute(referenceNode, context, profile);
+ addIfConstant(referenceNode, context, profile);
+ addIfQuery(referenceNode, context, profile);
addIfTensorFrom(referenceNode, context);
- addIfMacro(referenceNode, context);
+ addIfMacro(referenceNode, context, profile);
}
}
- private void addChildren(ExpressionNode node, Context context) {
+ private void addChildren(ExpressionNode node, Context context, RankProfile profile) {
if (node instanceof CompositeNode) {
List<ExpressionNode> children = ((CompositeNode) node).children();
for (ExpressionNode child : children) {
- addRoot(child, context);
+ addRoot(child, context, profile);
}
}
}
- private void addIfAttribute(ReferenceNode node, Context context) {
+ private void addIfAttribute(ReferenceNode node, Context context, RankProfile profile) {
if (!node.getName().equals("attribute")) {
return;
}
@@ -181,7 +168,7 @@ public class TensorTransformer extends ExpressionTransformer {
return;
}
String attribute = node.children().get(0).toString();
- Attribute a = search.getAttribute(attribute);
+ Attribute a = profile.getSearch().getAttribute(attribute);
if (a == null) {
return;
}
@@ -196,7 +183,7 @@ public class TensorTransformer extends ExpressionTransformer {
context.put(node.toString(), v);
}
- private void addIfConstant(ReferenceNode node, Context context) {
+ private void addIfConstant(ReferenceNode node, Context context, RankProfile profile) {
if (!node.getName().equals(ConstantTensorTransformer.CONSTANT)) {
return;
}
@@ -208,25 +195,25 @@ public class TensorTransformer extends ExpressionTransformer {
child = ((CompositeNode) child).children().get(0);
}
String name = child.toString();
- addIfConstantInRankProfile(name, node, context);
- addIfConstantInRankingConstants(name, node, context);
+ addIfConstantInRankProfile(name, node, context, profile);
+ addIfConstantInRankingConstants(name, node, context, profile);
}
- private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context) {
- if (rankprofile.getConstants().containsKey(name)) {
- context.put(node.toString(), rankprofile.getConstants().get(name));
+ private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context, RankProfile profile) {
+ if (profile.getConstants().containsKey(name)) {
+ context.put(node.toString(), profile.getConstants().get(name));
}
}
- private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context) {
- for (RankingConstant rankingConstant : search.getRankingConstants()) {
+ private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context, RankProfile profile) {
+ for (RankingConstant rankingConstant : profile.getSearch().getRankingConstants()) {
if (rankingConstant.getName().equals(name)) {
context.put(node.toString(), emptyTensorValue(rankingConstant.getTensorType()));
}
}
}
- private void addIfQuery(ReferenceNode node, Context context) {
+ private void addIfQuery(ReferenceNode node, Context context, RankProfile profile) {
if (!node.getName().equals("query")) {
return;
}
@@ -234,8 +221,8 @@ public class TensorTransformer extends ExpressionTransformer {
return;
}
String name = node.children().get(0).toString();
- if (rankprofile.getQueryFeatureTypes().containsKey(name)) {
- String type = rankprofile.getQueryFeatureTypes().get(name);
+ if (profile.getQueryFeatureTypes().containsKey(name)) {
+ String type = profile.getQueryFeatureTypes().get(name);
Value v;
if (type.contains("tensor")) {
v = emptyTensorValue(TensorType.fromSpec(type));
@@ -267,13 +254,13 @@ public class TensorTransformer extends ExpressionTransformer {
context.put(node.toString(), emptyTensorValue(type));
}
- private void addIfMacro(ReferenceNode node, Context context) {
- RankProfile.Macro macro = macros.get(node.getName());
+ private void addIfMacro(ReferenceNode node, Context context, RankProfile profile) {
+ RankProfile.Macro macro = profile.getMacros().get(node.getName());
if (macro == null) {
return;
}
ExpressionNode root = macro.getRankingExpression().getRoot();
- Context macroContext = buildContext(root);
+ Context macroContext = buildContext(root, profile);
addMacroArguments(node, context, macro, macroContext);
Value value = root.evaluate(macroContext);
context.put(node.toString(), value);