diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java | 89 |
1 files changed, 38 insertions, 51 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java index 70a7372dbe9..971c2c4f218 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java @@ -3,7 +3,6 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; -import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.document.Attribute; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; @@ -17,12 +16,12 @@ import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import java.util.List; -import java.util.Map; import java.util.Optional; /** @@ -36,32 +35,22 @@ import java.util.Optional; */ public class TensorTransformer extends ExpressionTransformer { - private Search search; - private RankProfile rankprofile; - private Map<String, RankProfile.Macro> macros; - - public TensorTransformer(RankProfile rankprofile) { - this.rankprofile = rankprofile; - this.search = rankprofile.getSearch(); - this.macros = rankprofile.getMacros(); - } - @Override - public ExpressionNode transform(ExpressionNode node) { + public ExpressionNode transform(ExpressionNode node, TransformContext context) { if (node instanceof CompositeNode) { - node = transformChildren((CompositeNode) node); + node = transformChildren((CompositeNode) node, context); } if (node instanceof FunctionNode) { - node = transformFunctionNode((FunctionNode) node); + node = transformFunctionNode((FunctionNode) node, ((RankProfileTransformContext)context).rankProfile()); } return node; } - private ExpressionNode transformFunctionNode(FunctionNode node) { + private ExpressionNode transformFunctionNode(FunctionNode node, RankProfile profile) { switch (node.getFunction()) { case min: case max: - return transformMaxAndMinFunctionNode(node); + return transformMaxAndMinFunctionNode(node, profile); } return node; } @@ -80,7 +69,7 @@ public class TensorTransformer extends ExpressionTransformer { * There is currently no guarantee that all cases will be found. For * instance, if-statements are problematic. */ - private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node) { + private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node, RankProfile profile) { if (node.children().size() != 2) { return node; } @@ -88,7 +77,7 @@ public class TensorTransformer extends ExpressionTransformer { Optional<String> dimension = dimensionName(node.children().get(1)); if (dimension.isPresent()) { try { - Context context = buildContext(arg1); + Context context = buildContext(arg1, profile); Value value = arg1.evaluate(context); if (isTensorWithDimension(value, dimension.get())) { return replaceMaxAndMinFunction(node); @@ -110,12 +99,10 @@ public class TensorTransformer extends ExpressionTransformer { } private boolean isTensorWithDimension(Value value, String dimension) { - if (value instanceof TensorValue) { - Tensor tensor = ((TensorValue) value).asTensor(); - TensorType type = tensor.type(); - return type.dimensionNames().contains(dimension); - } - return false; + if (value instanceof TensorValue) + return value.asTensor().type().dimensionNames().contains(dimension); + else + return false; } private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) { @@ -133,9 +120,9 @@ public class TensorTransformer extends ExpressionTransformer { * Creates an evaluation context by iterating through the expression tree, and * adding dummy values with correct types to the context. */ - private Context buildContext(ExpressionNode node) { + private Context buildContext(ExpressionNode node, RankProfile profile) { Context context = new MapContext(); - addRoot(node, context); + addRoot(node, context, profile); return context; } @@ -152,28 +139,28 @@ public class TensorTransformer extends ExpressionTransformer { return new TensorValue(empty); } - private void addRoot(ExpressionNode node, Context context) { - addChildren(node, context); + private void addRoot(ExpressionNode node, Context context, RankProfile profile) { + addChildren(node, context, profile); if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) node; - addIfAttribute(referenceNode, context); - addIfConstant(referenceNode, context); - addIfQuery(referenceNode, context); + addIfAttribute(referenceNode, context, profile); + addIfConstant(referenceNode, context, profile); + addIfQuery(referenceNode, context, profile); addIfTensorFrom(referenceNode, context); - addIfMacro(referenceNode, context); + addIfMacro(referenceNode, context, profile); } } - private void addChildren(ExpressionNode node, Context context) { + private void addChildren(ExpressionNode node, Context context, RankProfile profile) { if (node instanceof CompositeNode) { List<ExpressionNode> children = ((CompositeNode) node).children(); for (ExpressionNode child : children) { - addRoot(child, context); + addRoot(child, context, profile); } } } - private void addIfAttribute(ReferenceNode node, Context context) { + private void addIfAttribute(ReferenceNode node, Context context, RankProfile profile) { if (!node.getName().equals("attribute")) { return; } @@ -181,7 +168,7 @@ public class TensorTransformer extends ExpressionTransformer { return; } String attribute = node.children().get(0).toString(); - Attribute a = search.getAttribute(attribute); + Attribute a = profile.getSearch().getAttribute(attribute); if (a == null) { return; } @@ -196,7 +183,7 @@ public class TensorTransformer extends ExpressionTransformer { context.put(node.toString(), v); } - private void addIfConstant(ReferenceNode node, Context context) { + private void addIfConstant(ReferenceNode node, Context context, RankProfile profile) { if (!node.getName().equals(ConstantTensorTransformer.CONSTANT)) { return; } @@ -208,25 +195,25 @@ public class TensorTransformer extends ExpressionTransformer { child = ((CompositeNode) child).children().get(0); } String name = child.toString(); - addIfConstantInRankProfile(name, node, context); - addIfConstantInRankingConstants(name, node, context); + addIfConstantInRankProfile(name, node, context, profile); + addIfConstantInRankingConstants(name, node, context, profile); } - private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context) { - if (rankprofile.getConstants().containsKey(name)) { - context.put(node.toString(), rankprofile.getConstants().get(name)); + private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context, RankProfile profile) { + if (profile.getConstants().containsKey(name)) { + context.put(node.toString(), profile.getConstants().get(name)); } } - private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context) { - for (RankingConstant rankingConstant : search.getRankingConstants()) { + private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context, RankProfile profile) { + for (RankingConstant rankingConstant : profile.getSearch().getRankingConstants()) { if (rankingConstant.getName().equals(name)) { context.put(node.toString(), emptyTensorValue(rankingConstant.getTensorType())); } } } - private void addIfQuery(ReferenceNode node, Context context) { + private void addIfQuery(ReferenceNode node, Context context, RankProfile profile) { if (!node.getName().equals("query")) { return; } @@ -234,8 +221,8 @@ public class TensorTransformer extends ExpressionTransformer { return; } String name = node.children().get(0).toString(); - if (rankprofile.getQueryFeatureTypes().containsKey(name)) { - String type = rankprofile.getQueryFeatureTypes().get(name); + if (profile.getQueryFeatureTypes().containsKey(name)) { + String type = profile.getQueryFeatureTypes().get(name); Value v; if (type.contains("tensor")) { v = emptyTensorValue(TensorType.fromSpec(type)); @@ -267,13 +254,13 @@ public class TensorTransformer extends ExpressionTransformer { context.put(node.toString(), emptyTensorValue(type)); } - private void addIfMacro(ReferenceNode node, Context context) { - RankProfile.Macro macro = macros.get(node.getName()); + private void addIfMacro(ReferenceNode node, Context context, RankProfile profile) { + RankProfile.Macro macro = profile.getMacros().get(node.getName()); if (macro == null) { return; } ExpressionNode root = macro.getRankingExpression().getRoot(); - Context macroContext = buildContext(root); + Context macroContext = buildContext(root, profile); addMacroArguments(node, context, macro, macroContext); Value value = root.evaluate(macroContext); context.put(node.toString(), value); |