diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-10 10:04:27 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-10 10:04:27 +0100 |
commit | ba6a11e6e2674a2b5c1ef967319fb269f989a216 (patch) | |
tree | b15c1c046989cafeed19d193fdb59634140d3db6 | |
parent | 3e1477f5fda4a3dcd436a6d41843adc66e19f370 (diff) |
Use a context for transform state
15 files changed, 213 insertions, 183 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 118fc8b6211..fa202770e26 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -63,11 +63,11 @@ public class DerivedConfiguration { */ public DerivedConfiguration(Search search, List<Search> abstractSearchList, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry) { Validator.ensureNotNull("Search definition", search); - if (!search.isProcessed()) { + if ( ! search.isProcessed()) { throw new IllegalArgumentException("Search '" + search.getName() + "' not processed."); } this.search = search; - if (!search.isDocumentsOnly()) { + if ( ! search.isDocumentsOnly()) { streamingFields = new VsmFields(search); streamingSummary = new VsmSummary(search); } @@ -160,15 +160,15 @@ public class DerivedConfiguration { public Search getSearch() { return search; } - + public RankProfileList getRankProfileList() { return rankProfileList; } - + public VsmSummary getVsmSummary() { return streamingSummary; } - + public VsmFields getVsmFields() { return streamingFields; } @@ -180,7 +180,7 @@ public class DerivedConfiguration { public Juniperrc getJuniperrc() { return juniperrc; } - + public SummaryMap getSummaryMap() { return summaryMap; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java index e061ead465e..f835e0a6ed1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java @@ -8,11 +8,11 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.NameNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Map; /** * Transforms named references to constant tensors with the rank feature 'constant'. @@ -23,53 +23,44 @@ public class ConstantTensorTransformer extends ExpressionTransformer { public static final String CONSTANT = "constant"; - private final Map<String, Value> constants; - private final Map<String, String> rankPropertiesOutput; - - public ConstantTensorTransformer(Map<String, Value> constants, - Map<String, String> rankPropertiesOutput) { - this.constants = constants; - this.rankPropertiesOutput = rankPropertiesOutput; - } - @Override - public ExpressionNode transform(ExpressionNode node) { + public ExpressionNode transform(ExpressionNode node, TransformContext context) { if (node instanceof ReferenceNode) { - return transformFeature((ReferenceNode) node); + return transformFeature((ReferenceNode) node, (RankProfileTransformContext)context); } else if (node instanceof CompositeNode) { - return transformChildren((CompositeNode) node); + return transformChildren((CompositeNode) node, context); } else { return node; } } - private ExpressionNode transformFeature(ReferenceNode node) { + private ExpressionNode transformFeature(ReferenceNode node, RankProfileTransformContext context) { if (!node.getArguments().isEmpty()) { - return transformArguments(node); + return transformArguments(node, context); } else { - return transformConstantReference(node); + return transformConstantReference(node, context); } } - private ExpressionNode transformArguments(ReferenceNode node) { + private ExpressionNode transformArguments(ReferenceNode node, TransformContext context) { List<ExpressionNode> arguments = node.getArguments().expressions(); List<ExpressionNode> transformedArguments = new ArrayList<>(arguments.size()); for (ExpressionNode argument : arguments) { - transformedArguments.add(transform(argument)); + transformedArguments.add(transform(argument, context)); } return node.setArguments(transformedArguments); } - private ExpressionNode transformConstantReference(ReferenceNode node) { - Value value = constants.get(node.getName()); + private ExpressionNode transformConstantReference(ReferenceNode node, RankProfileTransformContext context) { + Value value = context.constants().get(node.getName()); if (value == null || !(value instanceof TensorValue)) { return node; } TensorValue tensorValue = (TensorValue)value; String featureName = CONSTANT + "(" + node.getName() + ")"; String tensorType = tensorValue.asTensor().type().toString(); - rankPropertiesOutput.put(featureName + ".value", tensorValue.toString()); - rankPropertiesOutput.put(featureName + ".type", tensorType); + context.rankPropertiesOutput().put(featureName + ".value", tensorValue.toString()); + context.rankPropertiesOutput().put(featureName + ".type", tensorType); return new ReferenceNode("constant", Arrays.asList(new NameNode(node.getName())), null); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java index ee5cccccb29..d7a38f47766 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java @@ -1,33 +1,44 @@ package com.yahoo.searchdefinition.expressiontransforms; +import com.google.common.collect.ImmutableList; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.transform.ConstantDereferencer; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.searchlib.rankingexpression.transform.Simplifier; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; +import java.util.List; import java.util.Map; /** * The transformations done on ranking expressions done at config time before passing them on to the Vespa * engine for execution. * + * An instance of this class has scope of one complete deployment. + * * @author bratseth */ public class ExpressionTransforms { + private final List<ExpressionTransformer> transforms = + ImmutableList.of(new TensorFlowFeatureConverter(), + new ConstantDereferencer(), + new ConstantTensorTransformer(), + new MacroInliner(), + new MacroShadower(), + new TensorTransformer(), + new Simplifier()); + public RankingExpression transform(RankingExpression expression, RankProfile rankProfile, Map<String, Value> constants, Map<String, RankProfile.Macro> inlineMacros, Map<String, String> rankPropertiesOutput) { - expression = new TensorFlowFeatureConverter(rankProfile).transform(expression); - expression = new ConstantDereferencer(constants).transform(expression); - expression = new ConstantTensorTransformer(constants, rankPropertiesOutput).transform(expression); - expression = new MacroInliner(inlineMacros).transform(expression); - expression = new MacroShadower(rankProfile.getMacros()).transform(expression); - expression = new TensorTransformer(rankProfile).transform(expression); - expression = new Simplifier().transform(expression); + TransformContext context = new RankProfileTransformContext(rankProfile, constants, inlineMacros, rankPropertiesOutput); + for (ExpressionTransformer transformer : transforms) + expression = transformer.transform(expression, context); return expression; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java index a3933e6f8e2..6702955bae3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java @@ -6,8 +6,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; - -import java.util.Map; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; /** * Inlines macros in ranking expressions @@ -16,25 +15,19 @@ import java.util.Map; */ public class MacroInliner extends ExpressionTransformer { - private final Map<String, RankProfile.Macro> macros; - - public MacroInliner(Map<String, RankProfile.Macro> macros) { - this.macros = macros; - } - @Override - public ExpressionNode transform(ExpressionNode node) { + public ExpressionNode transform(ExpressionNode node, TransformContext context) { if (node instanceof ReferenceNode) - return transformFeatureNode((ReferenceNode)node); + return transformFeatureNode((ReferenceNode)node, (RankProfileTransformContext)context); if (node instanceof CompositeNode) - return transformChildren((CompositeNode)node); + return transformChildren((CompositeNode)node, context); return node; } - private ExpressionNode transformFeatureNode(ReferenceNode feature) { - RankProfile.Macro macro = macros.get(feature.getName()); + private ExpressionNode transformFeatureNode(ReferenceNode feature, RankProfileTransformContext context) { + RankProfile.Macro macro = context.inlineMacros().get(feature.getName()); if (macro == null) return feature; - return transform(macro.getRankingExpression().getRoot()); // inline recursively and return + return transform(macro.getRankingExpression().getRoot(), context); // inline recursively and return } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java index 1d9769d0d78..6eabb5ddcd4 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java @@ -3,10 +3,12 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.rule.*; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; - -import java.util.Map; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; /** * Transforms function nodes to reference nodes if a macro shadows a built-in function. @@ -23,44 +25,38 @@ import java.util.Map; */ public class MacroShadower extends ExpressionTransformer { - private final Map<String, RankProfile.Macro> macros; - - public MacroShadower(Map<String, RankProfile.Macro> macros) { - this.macros = macros; - } - @Override - public RankingExpression transform(RankingExpression expression) { + public RankingExpression transform(RankingExpression expression, TransformContext context) { String name = expression.getName(); ExpressionNode node = expression.getRoot(); - ExpressionNode result = transform(node); + ExpressionNode result = transform(node, context); return new RankingExpression(name, result); } @Override - public ExpressionNode transform(ExpressionNode node) { + public ExpressionNode transform(ExpressionNode node, TransformContext context) { if (node instanceof FunctionNode) - return transformFunctionNode((FunctionNode) node); + return transformFunctionNode((FunctionNode) node, context); if (node instanceof CompositeNode) - return transformChildren((CompositeNode)node); + return transformChildren((CompositeNode)node, context); return node; } - protected ExpressionNode transformFunctionNode(FunctionNode function) { + protected ExpressionNode transformFunctionNode(FunctionNode function, TransformContext context) { String name = function.getFunction().toString(); - RankProfile.Macro macro = macros.get(name); + RankProfile.Macro macro = ((RankProfileTransformContext)context).rankProfile().getMacros().get(name); if (macro == null) { - return transformChildren(function); + return transformChildren(function, context); } int functionArity = function.getFunction().arity(); int macroArity = macro.getFormalParams() != null ? macro.getFormalParams().size() : 0; if (functionArity != macroArity) { - return transformChildren(function); + return transformChildren(function, context); } ReferenceNode node = new ReferenceNode(name, function.children(), null); - return transformChildren(node); + return transformChildren(node, context); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java new file mode 100644 index 00000000000..fb996d70607 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java @@ -0,0 +1,34 @@ +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; + +import java.util.Map; + +/** + * Extends the transform context with rank profile information + * + * @author bratseth + */ +public class RankProfileTransformContext extends TransformContext { + + private final RankProfile rankProfile; + private final Map<String, RankProfile.Macro> inlineMacros; + private final Map<String, String> rankPropertiesOutput; + + RankProfileTransformContext(RankProfile rankProfile, + Map<String, Value> constants, + Map<String, RankProfile.Macro> inlineMacros, + Map<String, String> rankPropertiesOutput) { + super(constants); + this.rankProfile = rankProfile; + this.inlineMacros = inlineMacros; + this.rankPropertiesOutput = rankPropertiesOutput; + } + + public RankProfile rankProfile() { return rankProfile; } + public Map<String, RankProfile.Macro> inlineMacros() { return inlineMacros; } + public Map<String, String> rankPropertiesOutput() { return rankPropertiesOutput; } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index e5886030d44..b7033d4ad9f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -1,6 +1,5 @@ package com.yahoo.searchdefinition.expressiontransforms; -import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; @@ -10,6 +9,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; import java.util.Map; import java.util.Optional; @@ -24,23 +24,18 @@ import java.util.Optional; public class TensorFlowFeatureConverter extends ExpressionTransformer { private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); - private final RankProfile profile; - - public TensorFlowFeatureConverter(RankProfile profile) { - this.profile = profile; - } @Override - public ExpressionNode transform(ExpressionNode node) { + public ExpressionNode transform(ExpressionNode node, TransformContext context) { if (node instanceof ReferenceNode) - return transformFeature((ReferenceNode) node); + return transformFeature((ReferenceNode) node, (RankProfileTransformContext)context); else if (node instanceof CompositeNode) - return super.transformChildren((CompositeNode) node); + return super.transformChildren((CompositeNode) node, context); else return node; } - private ExpressionNode transformFeature(ReferenceNode feature) { + private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { try { if ( ! feature.getName().equals("tensorflow")) return feature; @@ -48,15 +43,16 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer { throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + "the tensorflow model directory under [application]/models"); - // Find the specified expression ImportResult result = tensorFlowImporter.importModel(asString(feature.getArguments().expressions().get(0))); + + // Find the specified expression ImportResult.Signature signature = chooseOrDefault("signatures", result.signatures(), optionalArgument(1, feature.getArguments())); String output = chooseOrDefault("outputs", signature.outputs(), optionalArgument(2, feature.getArguments())); // Add all constants - result.constants().forEach((k, v) -> profile.addConstantTensor(k, new TensorValue(v))); + result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v))); return result.expressions().get(output).getRoot(); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java index 70a7372dbe9..971c2c4f218 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java @@ -3,7 +3,6 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; -import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.document.Attribute; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; @@ -17,12 +16,12 @@ import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import java.util.List; -import java.util.Map; import java.util.Optional; /** @@ -36,32 +35,22 @@ import java.util.Optional; */ public class TensorTransformer extends ExpressionTransformer { - private Search search; - private RankProfile rankprofile; - private Map<String, RankProfile.Macro> macros; - - public TensorTransformer(RankProfile rankprofile) { - this.rankprofile = rankprofile; - this.search = rankprofile.getSearch(); - this.macros = rankprofile.getMacros(); - } - @Override - public ExpressionNode transform(ExpressionNode node) { + public ExpressionNode transform(ExpressionNode node, TransformContext context) { if (node instanceof CompositeNode) { - node = transformChildren((CompositeNode) node); + node = transformChildren((CompositeNode) node, context); } if (node instanceof FunctionNode) { - node = transformFunctionNode((FunctionNode) node); + node = transformFunctionNode((FunctionNode) node, ((RankProfileTransformContext)context).rankProfile()); } return node; } - private ExpressionNode transformFunctionNode(FunctionNode node) { + private ExpressionNode transformFunctionNode(FunctionNode node, RankProfile profile) { switch (node.getFunction()) { case min: case max: - return transformMaxAndMinFunctionNode(node); + return transformMaxAndMinFunctionNode(node, profile); } return node; } @@ -80,7 +69,7 @@ public class TensorTransformer extends ExpressionTransformer { * There is currently no guarantee that all cases will be found. For * instance, if-statements are problematic. */ - private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node) { + private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node, RankProfile profile) { if (node.children().size() != 2) { return node; } @@ -88,7 +77,7 @@ public class TensorTransformer extends ExpressionTransformer { Optional<String> dimension = dimensionName(node.children().get(1)); if (dimension.isPresent()) { try { - Context context = buildContext(arg1); + Context context = buildContext(arg1, profile); Value value = arg1.evaluate(context); if (isTensorWithDimension(value, dimension.get())) { return replaceMaxAndMinFunction(node); @@ -110,12 +99,10 @@ public class TensorTransformer extends ExpressionTransformer { } private boolean isTensorWithDimension(Value value, String dimension) { - if (value instanceof TensorValue) { - Tensor tensor = ((TensorValue) value).asTensor(); - TensorType type = tensor.type(); - return type.dimensionNames().contains(dimension); - } - return false; + if (value instanceof TensorValue) + return value.asTensor().type().dimensionNames().contains(dimension); + else + return false; } private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) { @@ -133,9 +120,9 @@ public class TensorTransformer extends ExpressionTransformer { * Creates an evaluation context by iterating through the expression tree, and * adding dummy values with correct types to the context. */ - private Context buildContext(ExpressionNode node) { + private Context buildContext(ExpressionNode node, RankProfile profile) { Context context = new MapContext(); - addRoot(node, context); + addRoot(node, context, profile); return context; } @@ -152,28 +139,28 @@ public class TensorTransformer extends ExpressionTransformer { return new TensorValue(empty); } - private void addRoot(ExpressionNode node, Context context) { - addChildren(node, context); + private void addRoot(ExpressionNode node, Context context, RankProfile profile) { + addChildren(node, context, profile); if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) node; - addIfAttribute(referenceNode, context); - addIfConstant(referenceNode, context); - addIfQuery(referenceNode, context); + addIfAttribute(referenceNode, context, profile); + addIfConstant(referenceNode, context, profile); + addIfQuery(referenceNode, context, profile); addIfTensorFrom(referenceNode, context); - addIfMacro(referenceNode, context); + addIfMacro(referenceNode, context, profile); } } - private void addChildren(ExpressionNode node, Context context) { + private void addChildren(ExpressionNode node, Context context, RankProfile profile) { if (node instanceof CompositeNode) { List<ExpressionNode> children = ((CompositeNode) node).children(); for (ExpressionNode child : children) { - addRoot(child, context); + addRoot(child, context, profile); } } } - private void addIfAttribute(ReferenceNode node, Context context) { + private void addIfAttribute(ReferenceNode node, Context context, RankProfile profile) { if (!node.getName().equals("attribute")) { return; } @@ -181,7 +168,7 @@ public class TensorTransformer extends ExpressionTransformer { return; } String attribute = node.children().get(0).toString(); - Attribute a = search.getAttribute(attribute); + Attribute a = profile.getSearch().getAttribute(attribute); if (a == null) { return; } @@ -196,7 +183,7 @@ public class TensorTransformer extends ExpressionTransformer { context.put(node.toString(), v); } - private void addIfConstant(ReferenceNode node, Context context) { + private void addIfConstant(ReferenceNode node, Context context, RankProfile profile) { if (!node.getName().equals(ConstantTensorTransformer.CONSTANT)) { return; } @@ -208,25 +195,25 @@ public class TensorTransformer extends ExpressionTransformer { child = ((CompositeNode) child).children().get(0); } String name = child.toString(); - addIfConstantInRankProfile(name, node, context); - addIfConstantInRankingConstants(name, node, context); + addIfConstantInRankProfile(name, node, context, profile); + addIfConstantInRankingConstants(name, node, context, profile); } - private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context) { - if (rankprofile.getConstants().containsKey(name)) { - context.put(node.toString(), rankprofile.getConstants().get(name)); + private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context, RankProfile profile) { + if (profile.getConstants().containsKey(name)) { + context.put(node.toString(), profile.getConstants().get(name)); } } - private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context) { - for (RankingConstant rankingConstant : search.getRankingConstants()) { + private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context, RankProfile profile) { + for (RankingConstant rankingConstant : profile.getSearch().getRankingConstants()) { if (rankingConstant.getName().equals(name)) { context.put(node.toString(), emptyTensorValue(rankingConstant.getTensorType())); } } } - private void addIfQuery(ReferenceNode node, Context context) { + private void addIfQuery(ReferenceNode node, Context context, RankProfile profile) { if (!node.getName().equals("query")) { return; } @@ -234,8 +221,8 @@ public class TensorTransformer extends ExpressionTransformer { return; } String name = node.children().get(0).toString(); - if (rankprofile.getQueryFeatureTypes().containsKey(name)) { - String type = rankprofile.getQueryFeatureTypes().get(name); + if (profile.getQueryFeatureTypes().containsKey(name)) { + String type = profile.getQueryFeatureTypes().get(name); Value v; if (type.contains("tensor")) { v = emptyTensorValue(TensorType.fromSpec(type)); @@ -267,13 +254,13 @@ public class TensorTransformer extends ExpressionTransformer { context.put(node.toString(), emptyTensorValue(type)); } - private void addIfMacro(ReferenceNode node, Context context) { - RankProfile.Macro macro = macros.get(node.getName()); + private void addIfMacro(ReferenceNode node, Context context, RankProfile profile) { + RankProfile.Macro macro = profile.getMacros().get(node.getName()); if (macro == null) { return; } ExpressionNode root = macro.getRankingExpression().getRoot(); - Context macroContext = buildContext(root); + Context macroContext = buildContext(root, profile); addMacroArguments(node, context, macro, macroContext); Value value = root.evaluate(macroContext); context.put(node.toString(), value); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index e5693d24f0f..475fee62177 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -92,7 +92,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { private void assertContainsExpression(String expr, String transformedExpression) throws ParseException { assertTrue("Expected expression '" + transformedExpression + "' not found", - containsExpression(expr, transformedExpression)); + containsExpression(expr, transformedExpression)); } private boolean containsExpression(String expr, String transformedExpression) throws ParseException { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java index 46f79dcc6bd..1b8239ba643 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java @@ -10,7 +10,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import java.util.ArrayList; import java.util.List; -import java.util.Map; /** * Replaces "features" which found in the given constants by their constant value @@ -19,40 +18,33 @@ import java.util.Map; */ public class ConstantDereferencer extends ExpressionTransformer { - /** The map of constants to dereference */ - private final Map<String, Value> constants; - - public ConstantDereferencer(Map<String, Value> constants) { - this.constants = constants; - } - @Override - public ExpressionNode transform(ExpressionNode node) { + public ExpressionNode transform(ExpressionNode node, TransformContext context) { if (node instanceof ReferenceNode) - return transformFeature((ReferenceNode) node); + return transformFeature((ReferenceNode) node, context); else if (node instanceof CompositeNode) - return transformChildren((CompositeNode)node); + return transformChildren((CompositeNode)node, context); else return node; } - private ExpressionNode transformFeature(ReferenceNode node) { + private ExpressionNode transformFeature(ReferenceNode node, TransformContext context) { if (!node.getArguments().isEmpty()) - return transformArguments(node); + return transformArguments(node, context); else - return transformConstantReference(node); + return transformConstantReference(node, context); } - private ExpressionNode transformArguments(ReferenceNode node) { + private ExpressionNode transformArguments(ReferenceNode node, TransformContext context) { List<ExpressionNode> arguments = node.getArguments().expressions(); List<ExpressionNode> transformedArguments = new ArrayList<>(arguments.size()); for (ExpressionNode argument : arguments) - transformedArguments.add(transform(argument)); + transformedArguments.add(transform(argument, context)); return node.setArguments(transformedArguments); } - private ExpressionNode transformConstantReference(ReferenceNode node) { - Value value = constants.get(node.getName()); + private ExpressionNode transformConstantReference(ReferenceNode node, TransformContext context) { + Value value = context.constants().get(node.getName()); if (value == null || (value instanceof TensorValue)) { return node; // not a value constant reference } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java index bcc8b817641..c585c0dea1f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java @@ -15,22 +15,22 @@ import java.util.List; */ public abstract class ExpressionTransformer { - public RankingExpression transform(RankingExpression expression) { - return new RankingExpression(expression.getName(), transform(expression.getRoot())); + public RankingExpression transform(RankingExpression expression, TransformContext context) { + return new RankingExpression(expression.getName(), transform(expression.getRoot(), context)); } /** Transforms an expression node and returns the transformed node */ - public abstract ExpressionNode transform(ExpressionNode node); + public abstract ExpressionNode transform(ExpressionNode node, TransformContext context); /** * Utility method which calls transform on each child of the given node and return the resulting transformed * composite */ - protected CompositeNode transformChildren(CompositeNode node) { + protected CompositeNode transformChildren(CompositeNode node, TransformContext context) { List<ExpressionNode> children = node.children(); List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); for (ExpressionNode child : children) - transformedChildren.add(transform(child)); + transformedChildren.add(transform(child, context)); return node.setChildren(transformedChildren); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java index ebad0d5c21f..9e8491340b0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.transform; -import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; @@ -10,8 +9,8 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import java.util.ArrayList; import java.util.List; @@ -24,9 +23,9 @@ import java.util.List; public class Simplifier extends ExpressionTransformer { @Override - public ExpressionNode transform(ExpressionNode node) { + public ExpressionNode transform(ExpressionNode node, TransformContext context) { if (node instanceof CompositeNode) - node = transformChildren((CompositeNode) node); // depth first + node = transformChildren((CompositeNode) node, context); // depth first if (node instanceof IfNode) node = transformIf((IfNode) node); if (node instanceof EmbracedNode && hasSingleUndividableChild((EmbracedNode)node)) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java new file mode 100644 index 00000000000..746ca3b3200 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java @@ -0,0 +1,22 @@ +package com.yahoo.searchlib.rankingexpression.transform; + +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Map; + +/** + * Provides a context in which transforms on ranking expressions take place. + * + * @author bratseth + */ +public class TransformContext { + + private final Map<String, Value> constants; + + public TransformContext(Map<String, Value> constants) { + this.constants = constants; + } + + public Map<String, Value> constants() { return constants; } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java index 4035e499a6a..84e51835458 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java @@ -5,11 +5,12 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import org.junit.Test; -import static org.junit.Assert.*; import java.util.HashMap; import java.util.Map; +import static org.junit.Assert.assertEquals; + /** * @author bratseth */ @@ -17,14 +18,16 @@ public class ConstantDereferencerTestCase { @Test public void testConstantDereferencer() throws ParseException { + ConstantDereferencer c = new ConstantDereferencer(); + Map<String, Value> constants = new HashMap<>(); constants.put("a", Value.parse("1.0")); constants.put("b", Value.parse("2")); constants.put("c", Value.parse("3.5")); - ConstantDereferencer c = new ConstantDereferencer(constants); + TransformContext context = new TransformContext(constants); - assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c")).toString()); - assertEquals("myMacro(1.0,2.0)", c.transform(new RankingExpression("myMacro(a, b)")).toString()); + assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c"), context).toString()); + assertEquals("myMacro(1.0,2.0)", c.transform(new RankingExpression("myMacro(a, b)"), context).toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java index f9d2472e306..8fac3395ac0 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java @@ -7,6 +7,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import org.junit.Test; + +import java.util.Collections; + import static org.junit.Assert.*; /** @@ -17,31 +20,33 @@ public class SimplifierTestCase { @Test public void testSimplify() throws ParseException { Simplifier s = new Simplifier(); - assertEquals("a + b", s.transform(new RankingExpression("a + b")).toString()); - assertEquals("6.5", s.transform(new RankingExpression("1.0 + 2.0 + 3.5")).toString()); - assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )")).toString()); - assertEquals("6.5", s.transform(new RankingExpression("( 1.0 + 2.0 ) + 3.5 ")).toString()); - assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )")).toString()); - assertEquals("7.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + 1")).toString()); - assertEquals("6.5 + a", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + a")).toString()); - assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * 0.0")).toString()); - assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (0.0)")).toString()); - assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (1.0 - 1.0)")).toString()); - assertEquals("7.5", s.transform(new RankingExpression("if (2 > 0, 3.5 * 2 + 0.5, a *3 )")).toString()); - assertEquals("0.0", s.transform(new RankingExpression("0.0 * (1.3 + 7.0)")).toString()); - assertEquals("6.4", s.transform(new RankingExpression("max(0, 10.0-2.0)*(1-fabs(0.0-0.2))")).toString()); - assertEquals("(query(d) + query(b) - query(a)) * query(c) / query(e)", s.transform(new RankingExpression("(query(d) + query(b) - query(a)) * query(c) / query(e)")).toString()); - assertEquals("14.0", s.transform(new RankingExpression("5 + (2 + 3) + 4")).toString()); - assertEquals("28.0 + bar", s.transform(new RankingExpression("7.0 + 12.0 + 9.0 + bar")).toString()); - assertEquals("1.0 - 0.001 * attribute(number)", s.transform(new RankingExpression("1.0 - 0.001*attribute(number)")).toString()); - assertEquals("attribute(number) * 1.5 - 0.001 * attribute(number)", s.transform(new RankingExpression("attribute(number) * 1.5 - 0.001 * attribute(number)")).toString()); + TransformContext c = new TransformContext(Collections.emptyMap()); + assertEquals("a + b", s.transform(new RankingExpression("a + b"), c).toString()); + assertEquals("6.5", s.transform(new RankingExpression("1.0 + 2.0 + 3.5"), c).toString()); + assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )"), c).toString()); + assertEquals("6.5", s.transform(new RankingExpression("( 1.0 + 2.0 ) + 3.5 "), c).toString()); + assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )"), c).toString()); + assertEquals("7.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + 1"), c).toString()); + assertEquals("6.5 + a", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + a"), c).toString()); + assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * 0.0"), c).toString()); + assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (0.0)"), c).toString()); + assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (1.0 - 1.0)"), c).toString()); + assertEquals("7.5", s.transform(new RankingExpression("if (2 > 0, 3.5 * 2 + 0.5, a *3 )"), c).toString()); + assertEquals("0.0", s.transform(new RankingExpression("0.0 * (1.3 + 7.0)"), c).toString()); + assertEquals("6.4", s.transform(new RankingExpression("max(0, 10.0-2.0)*(1-fabs(0.0-0.2))"), c).toString()); + assertEquals("(query(d) + query(b) - query(a)) * query(c) / query(e)", s.transform(new RankingExpression("(query(d) + query(b) - query(a)) * query(c) / query(e)"), c).toString()); + assertEquals("14.0", s.transform(new RankingExpression("5 + (2 + 3) + 4"), c).toString()); + assertEquals("28.0 + bar", s.transform(new RankingExpression("7.0 + 12.0 + 9.0 + bar"), c).toString()); + assertEquals("1.0 - 0.001 * attribute(number)", s.transform(new RankingExpression("1.0 - 0.001*attribute(number)"), c).toString()); + assertEquals("attribute(number) * 1.5 - 0.001 * attribute(number)", s.transform(new RankingExpression("attribute(number) * 1.5 - 0.001 * attribute(number)"), c).toString()); } // A black box test verifying we are not screwing up real expressions @Test public void testSimplifyComplexExpression() throws ParseException { RankingExpression initial = new RankingExpression("sqrt(if (if (INFERRED * 0.9 < INFERRED, GMP, (1 + 1.1) * INFERRED) < INFERRED * INFERRED - INFERRED, if (GMP < 85.80799542793133 * GMP, INFERRED, if (GMP < GMP, tanh(INFERRED), log(76.89956221113943))), tanh(tanh(INFERRED))) * sqrt(sqrt(GMP + INFERRED)) * GMP ) + 13.5 * (1 - GMP) * pow(GMP * 0.1, 2 + 1.1 * 0)"); - RankingExpression simplified = new Simplifier().transform(initial); + TransformContext c = new TransformContext(Collections.emptyMap()); + RankingExpression simplified = new Simplifier().transform(initial, c); Context context = new MapContext(); context.put("INFERRED", 0.5); @@ -65,7 +70,8 @@ public class SimplifierTestCase { @Test public void testParenthesisPreservation() throws ParseException { Simplifier s = new Simplifier(); - CompositeNode transformed = (CompositeNode)s.transform(new RankingExpression("a + (b + c) / 100000000.0")).getRoot(); + TransformContext c = new TransformContext(Collections.emptyMap()); + CompositeNode transformed = (CompositeNode)s.transform(new RankingExpression("a + (b + c) / 100000000.0"), c).getRoot(); assertEquals("a + (b + c) / 100000000.0", transformed.toString()); } |