summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java282
1 files changed, 282 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java
new file mode 100644
index 00000000000..e9723042d77
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java
@@ -0,0 +1,282 @@
+package com.yahoo.searchdefinition;
+
+import com.yahoo.searchdefinition.document.Attribute;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Reduce;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Transforms and simplifies tensor expressions.
+ *
+ * Currently transforms min(tensor,dim) and max(tensor,dim) to
+ * reduce(tensor,min/max,dim). This is necessary as the backend does
+ * not recognize these forms of min and max.
+ *
+ * @author lesters
+ */
+public class TensorTransformer extends ExpressionTransformer {
+
+ private Search search;
+ private RankProfile rankprofile;
+ private Map<String, RankProfile.Macro> macros;
+
+ public TensorTransformer(RankProfile rankprofile) {
+ this.rankprofile = rankprofile;
+ this.search = rankprofile.getSearch();
+ this.macros = rankprofile.getMacros();
+ }
+
+ @Override
+ public ExpressionNode transform(ExpressionNode node) {
+ if (node instanceof CompositeNode) {
+ node = transformChildren((CompositeNode) node);
+ }
+ if (node instanceof FunctionNode) {
+ node = transformFunctionNode((FunctionNode) node);
+ }
+ return node;
+ }
+
+ private ExpressionNode transformFunctionNode(FunctionNode node) {
+ switch (node.getFunction()) {
+ case min:
+ case max:
+ return transformMaxAndMinFunctionNode(node);
+ }
+ return node;
+ }
+
+ /**
+ * Transforms max and min functions if it can be proven that the first
+ * argument resolves to a tensor and the second argument is a valid
+ * dimension in the tensor. If these do not hold, the node will not
+ * be transformed.
+ *
+ * The test for whether or not the first argument resolves to a tensor
+ * is to evaluate that expression. All values used in the expression
+ * is bound to a context with dummy values with enough information to
+ * deduce tensor types.
+ *
+ * There is currently no guarantee that all cases will be found. For
+ * instance, if-statements are problematic.
+ */
+ private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node) {
+ if (node.children().size() != 2) {
+ return node;
+ }
+ ExpressionNode arg1 = node.children().get(0);
+ ExpressionNode arg2 = node.children().get(1);
+ if (!potentialDimensionName(arg2)) {
+ return node;
+ }
+ try {
+ String dimension = ((ReferenceNode) arg2).getName();
+ Context context = buildContext(arg1);
+ Value value = arg1.evaluate(context);
+ if (verifyTensorAndDimension(value, dimension)) {
+ return replaceMaxAndMinFunction(node);
+ }
+ } catch (Exception e) {
+ // Don't replace the expression in case of any errors, e.g. unknown values or rank features
+ }
+ return node;
+ }
+
+ private boolean potentialDimensionName(ExpressionNode arg) {
+ return arg instanceof ReferenceNode && ((ReferenceNode) arg).children().size() == 0;
+ }
+
+ private boolean verifyTensorAndDimension(Value value, String dimension) {
+ if (value instanceof TensorValue) {
+ Tensor tensor = ((TensorValue) value).asTensor();
+ TensorType type = tensor.type();
+ return type.dimensionNames().contains(dimension);
+ }
+ return false;
+ }
+
+ private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) {
+ ExpressionNode arg1 = node.children().get(0);
+ ExpressionNode arg2 = node.children().get(1);
+
+ TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1);
+ Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
+ String dimension = ((ReferenceNode) arg2).getName();
+
+ return new TensorFunctionNode(new Reduce(expression, aggregator, dimension));
+ }
+
+ /**
+ * Creates an evaluation context by iterating through the expression tree, and
+ * adding dummy values with correct types to the context.
+ */
+ private Context buildContext(ExpressionNode node) {
+ Context context = new MapContext();
+ addRoot(node, context);
+ return context;
+ }
+
+ private Value emptyStringValue() {
+ return new StringValue("");
+ }
+
+ private Value emptyDoubleValue() {
+ return new DoubleValue(0.0);
+ }
+
+ private Value emptyTensorValue(TensorType type) {
+ Tensor empty = Tensor.Builder.of(type).build();
+ return new TensorValue(empty);
+ }
+
+ private void addRoot(ExpressionNode node, Context context) {
+ addChildren(node, context);
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) node;
+ addIfAttribute(referenceNode, context);
+ addIfConstant(referenceNode, context);
+ addIfQuery(referenceNode, context);
+ addIfTensorFrom(referenceNode, context);
+ addIfMacro(referenceNode, context);
+ }
+ }
+
+ private void addChildren(ExpressionNode node, Context context) {
+ if (node instanceof CompositeNode) {
+ List<ExpressionNode> children = ((CompositeNode) node).children();
+ for (ExpressionNode child : children) {
+ addRoot(child, context);
+ }
+ }
+ }
+
+ private void addIfAttribute(ReferenceNode node, Context context) {
+ if (!node.getName().equals("attribute")) {
+ return;
+ }
+ if (node.children().size() == 0) {
+ return;
+ }
+ String attribute = node.children().get(0).toString();
+ Attribute a = search.getAttribute(attribute);
+ Value v;
+ if (a.getType() == Attribute.Type.STRING) {
+ v = emptyStringValue();
+ } else if (a.getType() == Attribute.Type.TENSOR) {
+ v = emptyTensorValue(a.tensorType().orElseThrow(RuntimeException::new));
+ } else {
+ v = emptyDoubleValue();
+ }
+ context.put(node.toString(), v);
+ }
+
+ private void addIfConstant(ReferenceNode node, Context context) {
+ if (!node.getName().equals("constant")) {
+ return;
+ }
+ if (node.children().size() != 1) {
+ return;
+ }
+ ExpressionNode child = node.children().get(0);
+ while (child instanceof CompositeNode && ((CompositeNode) child).children().size() > 0) {
+ child = ((CompositeNode) child).children().get(0);
+ }
+ String name = child.toString();
+ addIfConstantInRankProfile(name, node, context);
+ addIfConstantInRankingConstants(name, node, context);
+ }
+
+ private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context) {
+ if (rankprofile.getConstants().containsKey(name)) {
+ context.put(node.toString(), rankprofile.getConstants().get(name));
+ }
+ }
+
+ private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context) {
+ for (RankingConstant rankingConstant : search.getRankingConstants()) {
+ if (rankingConstant.getName().equals(name)) {
+ context.put(node.toString(), emptyTensorValue(rankingConstant.getTensorType()));
+ }
+ }
+ }
+
+ private void addIfQuery(ReferenceNode node, Context context) {
+ if (!node.getName().equals("query")) {
+ return;
+ }
+ if (node.children().size() != 1) {
+ return;
+ }
+ String name = node.children().get(0).toString();
+ if (rankprofile.getQueryFeatureTypes().containsKey(name)) {
+ String type = rankprofile.getQueryFeatureTypes().get(name);
+ Value v;
+ if (type.contains("tensor")) {
+ v = emptyTensorValue(TensorType.fromSpec(type));
+ } else if (type.equalsIgnoreCase("string")) {
+ v = emptyStringValue();
+ } else {
+ v = emptyDoubleValue();
+ }
+ context.put(node.toString(), v);
+ }
+ }
+
+ private void addIfTensorFrom(ReferenceNode node, Context context) {
+ if (!node.getName().startsWith("tensorFrom")) {
+ return;
+ }
+ if (node.children().size() < 1 || node.children().size() > 2) {
+ return;
+ }
+ ExpressionNode source = node.children().get(0);
+ if (source instanceof CompositeNode && ((CompositeNode) source).children().size() > 0) {
+ source = ((CompositeNode) source).children().get(0);
+ }
+ String dimension = source.toString();
+ if (node.children().size() == 2) {
+ dimension = node.children().get(1).toString();
+ }
+ TensorType type = (new TensorType.Builder()).mapped(dimension).build();
+ context.put(node.toString(), emptyTensorValue(type));
+ }
+
+ private void addIfMacro(ReferenceNode node, Context context) {
+ RankProfile.Macro macro = macros.get(node.getName());
+ if (macro == null) {
+ return;
+ }
+ ExpressionNode root = macro.getRankingExpression().getRoot();
+ Context macroContext = buildContext(root);
+ addMacroArguments(node, context, macro, macroContext);
+ Value value = root.evaluate(macroContext);
+ context.put(node.toString(), value);
+ }
+
+ private void addMacroArguments(ReferenceNode node, Context context, RankProfile.Macro macro, Context macroContext) {
+ if (macro.getFormalParams().size() > 0 && node.children().size() > 0) {
+ for (int i = 0; i < macro.getFormalParams().size() && i < node.children().size(); ++i) {
+ String param = macro.getFormalParams().get(i);
+ ExpressionNode argumentExpression = node.children().get(i);
+ Value arg = argumentExpression.evaluate(context);
+ macroContext.put(param, arg);
+ }
+ }
+ }
+
+}