diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java | 178 |
1 files changed, 15 insertions, 163 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java index fd2a5fcf2e4..fe232299363 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java @@ -4,6 +4,7 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.document.Attribute; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; @@ -13,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.NameNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; @@ -75,11 +77,20 @@ public class TensorTransformer extends ExpressionTransformer<RankProfileTransfor return node; } - private Optional<String> dimensionName(ExpressionNode arg) { - if (arg instanceof ReferenceNode && ((ReferenceNode)arg).children().size() == 0) { - return Optional.of(((ReferenceNode) arg).getName()); + private Optional<String> dimensionName(ExpressionNode node) { + if (node instanceof ReferenceNode) { + Reference reference = ((ReferenceNode)node).reference(); + if (reference.isIdentifier()) + return Optional.of(reference.name()); + else + return Optional.empty(); + } + else if (node instanceof NameNode) { + return Optional.of(((NameNode)node).getValue()); + } + else { + return Optional.empty(); } - return Optional.empty(); } private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) { @@ -93,163 +104,4 @@ public class TensorTransformer extends ExpressionTransformer<RankProfileTransfor return new TensorFunctionNode(new Reduce(expression, aggregator, dimension)); } - /** - * Creates an evaluation context by iterating through the expression tree, and - * adding dummy values with correct types to the context. - */ - private Context buildContext(ExpressionNode node, RankProfile profile) { - Context context = new MapContext(); - addRoot(node, context, profile); - return context; - } - - private Value emptyStringValue() { - return new StringValue(""); - } - - private Value emptyDoubleValue() { - return new DoubleValue(0.0); - } - - private Value emptyTensorValue(TensorType type) { - Tensor empty = Tensor.Builder.of(type).build(); - return new TensorValue(empty); - } - - private void addRoot(ExpressionNode node, Context context, RankProfile profile) { - addChildren(node, context, profile); - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) node; - addIfAttribute(referenceNode, context, profile); - addIfConstant(referenceNode, context, profile); - addIfQuery(referenceNode, context, profile); - addIfTensorFrom(referenceNode, context); - addIfMacro(referenceNode, context, profile); - } - } - - private void addChildren(ExpressionNode node, Context context, RankProfile profile) { - if (node instanceof CompositeNode) { - List<ExpressionNode> children = ((CompositeNode) node).children(); - for (ExpressionNode child : children) { - addRoot(child, context, profile); - } - } - } - - private void addIfAttribute(ReferenceNode node, Context context, RankProfile profile) { - if (!node.getName().equals("attribute")) { - return; - } - if (node.children().size() == 0) { - return; - } - String attribute = node.children().get(0).toString(); - Attribute a = profile.getSearch().getAttribute(attribute); - if (a == null) { - return; - } - Value v; - if (a.getType() == Attribute.Type.STRING) { - v = emptyStringValue(); - } else if (a.getType() == Attribute.Type.TENSOR) { - v = emptyTensorValue(a.tensorType().orElseThrow(RuntimeException::new)); - } else { - v = emptyDoubleValue(); - } - context.put(node.toString(), v); - } - - private void addIfConstant(ReferenceNode node, Context context, RankProfile profile) { - if ( ! node.getName().equals(ConstantTensorTransformer.CONSTANT)) { - return; - } - if (node.children().size() != 1) { - return; - } - ExpressionNode child = node.children().get(0); - while (child instanceof CompositeNode && ((CompositeNode) child).children().size() > 0) { - child = ((CompositeNode) child).children().get(0); - } - String name = child.toString(); - addIfConstantInRankProfile(name, node, context, profile); - addIfConstantInRankingConstants(name, node, context, profile); - } - - private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context, RankProfile profile) { - if (profile.getConstants().containsKey(name)) { - context.put(node.toString(), profile.getConstants().get(name)); - } - } - - private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context, RankProfile profile) { - RankingConstant constant = profile.getSearch().getRankingConstants().get(name); - if (constant != null) - context.put(node.toString(), emptyTensorValue(constant.getTensorType())); - } - - private void addIfQuery(ReferenceNode node, Context context, RankProfile profile) { - if (!node.getName().equals("query")) { - return; - } - if (node.children().size() != 1) { - return; - } - String name = node.children().get(0).toString(); - if (profile.getQueryFeatureTypes().containsKey(name)) { - String type = profile.getQueryFeatureTypes().get(name); - Value v; - if (type.contains("tensor")) { - v = emptyTensorValue(TensorType.fromSpec(type)); - } else if (type.equalsIgnoreCase("string")) { - v = emptyStringValue(); - } else { - v = emptyDoubleValue(); - } - context.put(node.toString(), v); - } - } - - private void addIfTensorFrom(ReferenceNode node, Context context) { - if (!node.getName().startsWith("tensorFrom")) { - return; - } - if (node.children().size() < 1 || node.children().size() > 2) { - return; - } - ExpressionNode source = node.children().get(0); - if (source instanceof CompositeNode && ((CompositeNode) source).children().size() > 0) { - source = ((CompositeNode) source).children().get(0); - } - String dimension = source.toString(); - if (node.children().size() == 2) { - dimension = node.children().get(1).toString(); - } - TensorType type = (new TensorType.Builder()).mapped(dimension).build(); - context.put(node.toString(), emptyTensorValue(type)); - } - - private void addIfMacro(ReferenceNode node, Context context, RankProfile profile) { - RankProfile.Macro macro = profile.getMacros().get(node.getName()); - if (macro == null) { - return; - } - ExpressionNode root = macro.getRankingExpression().getRoot(); - Context macroContext = buildContext(root, profile); - addMacroArguments(node, context, macro, macroContext); - Value value = root.evaluate(macroContext); - context.put(node.toString(), value); - } - - private void addMacroArguments(ReferenceNode node, Context context, RankProfile.Macro macro, Context macroContext) { - if (macro.getFormalParams().size() > 0 && node.children().size() > 0) { - for (int i = 0; i < macro.getFormalParams().size() && i < node.children().size(); ++i) { - String param = macro.getFormalParams().get(i); - ExpressionNode argumentExpression = node.children().get(i); - Value arg = argumentExpression.evaluate(context); - macroContext.put(param, arg); - } - } - } - } |