summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-10 10:04:27 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-10 10:04:27 +0100
commitba6a11e6e2674a2b5c1ef967319fb269f989a216 (patch)
treeb15c1c046989cafeed19d193fdb59634140d3db6
parent3e1477f5fda4a3dcd436a6d41843adc66e19f370 (diff)
Use a context for transform state
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java12
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java35
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java25
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java21
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java34
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java34
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java20
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java89
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java28
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java22
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java11
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java46
15 files changed, 213 insertions, 183 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
index 118fc8b6211..fa202770e26 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
@@ -63,11 +63,11 @@ public class DerivedConfiguration {
*/
public DerivedConfiguration(Search search, List<Search> abstractSearchList, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry) {
Validator.ensureNotNull("Search definition", search);
- if (!search.isProcessed()) {
+ if ( ! search.isProcessed()) {
throw new IllegalArgumentException("Search '" + search.getName() + "' not processed.");
}
this.search = search;
- if (!search.isDocumentsOnly()) {
+ if ( ! search.isDocumentsOnly()) {
streamingFields = new VsmFields(search);
streamingSummary = new VsmSummary(search);
}
@@ -160,15 +160,15 @@ public class DerivedConfiguration {
public Search getSearch() {
return search;
}
-
+
public RankProfileList getRankProfileList() {
return rankProfileList;
}
-
+
public VsmSummary getVsmSummary() {
return streamingSummary;
}
-
+
public VsmFields getVsmFields() {
return streamingFields;
}
@@ -180,7 +180,7 @@ public class DerivedConfiguration {
public Juniperrc getJuniperrc() {
return juniperrc;
}
-
+
public SummaryMap getSummaryMap() {
return summaryMap;
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java
index e061ead465e..f835e0a6ed1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java
@@ -8,11 +8,11 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.NameNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import java.util.Map;
/**
* Transforms named references to constant tensors with the rank feature 'constant'.
@@ -23,53 +23,44 @@ public class ConstantTensorTransformer extends ExpressionTransformer {
public static final String CONSTANT = "constant";
- private final Map<String, Value> constants;
- private final Map<String, String> rankPropertiesOutput;
-
- public ConstantTensorTransformer(Map<String, Value> constants,
- Map<String, String> rankPropertiesOutput) {
- this.constants = constants;
- this.rankPropertiesOutput = rankPropertiesOutput;
- }
-
@Override
- public ExpressionNode transform(ExpressionNode node) {
+ public ExpressionNode transform(ExpressionNode node, TransformContext context) {
if (node instanceof ReferenceNode) {
- return transformFeature((ReferenceNode) node);
+ return transformFeature((ReferenceNode) node, (RankProfileTransformContext)context);
} else if (node instanceof CompositeNode) {
- return transformChildren((CompositeNode) node);
+ return transformChildren((CompositeNode) node, context);
} else {
return node;
}
}
- private ExpressionNode transformFeature(ReferenceNode node) {
+ private ExpressionNode transformFeature(ReferenceNode node, RankProfileTransformContext context) {
if (!node.getArguments().isEmpty()) {
- return transformArguments(node);
+ return transformArguments(node, context);
} else {
- return transformConstantReference(node);
+ return transformConstantReference(node, context);
}
}
- private ExpressionNode transformArguments(ReferenceNode node) {
+ private ExpressionNode transformArguments(ReferenceNode node, TransformContext context) {
List<ExpressionNode> arguments = node.getArguments().expressions();
List<ExpressionNode> transformedArguments = new ArrayList<>(arguments.size());
for (ExpressionNode argument : arguments) {
- transformedArguments.add(transform(argument));
+ transformedArguments.add(transform(argument, context));
}
return node.setArguments(transformedArguments);
}
- private ExpressionNode transformConstantReference(ReferenceNode node) {
- Value value = constants.get(node.getName());
+ private ExpressionNode transformConstantReference(ReferenceNode node, RankProfileTransformContext context) {
+ Value value = context.constants().get(node.getName());
if (value == null || !(value instanceof TensorValue)) {
return node;
}
TensorValue tensorValue = (TensorValue)value;
String featureName = CONSTANT + "(" + node.getName() + ")";
String tensorType = tensorValue.asTensor().type().toString();
- rankPropertiesOutput.put(featureName + ".value", tensorValue.toString());
- rankPropertiesOutput.put(featureName + ".type", tensorType);
+ context.rankPropertiesOutput().put(featureName + ".value", tensorValue.toString());
+ context.rankPropertiesOutput().put(featureName + ".type", tensorType);
return new ReferenceNode("constant", Arrays.asList(new NameNode(node.getName())), null);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
index ee5cccccb29..d7a38f47766 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
@@ -1,33 +1,44 @@
package com.yahoo.searchdefinition.expressiontransforms;
+import com.google.common.collect.ImmutableList;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.transform.ConstantDereferencer;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.searchlib.rankingexpression.transform.Simplifier;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
+import java.util.List;
import java.util.Map;
/**
* The transformations done on ranking expressions done at config time before passing them on to the Vespa
* engine for execution.
*
+ * An instance of this class has scope of one complete deployment.
+ *
* @author bratseth
*/
public class ExpressionTransforms {
+ private final List<ExpressionTransformer> transforms =
+ ImmutableList.of(new TensorFlowFeatureConverter(),
+ new ConstantDereferencer(),
+ new ConstantTensorTransformer(),
+ new MacroInliner(),
+ new MacroShadower(),
+ new TensorTransformer(),
+ new Simplifier());
+
public RankingExpression transform(RankingExpression expression,
RankProfile rankProfile,
Map<String, Value> constants,
Map<String, RankProfile.Macro> inlineMacros,
Map<String, String> rankPropertiesOutput) {
- expression = new TensorFlowFeatureConverter(rankProfile).transform(expression);
- expression = new ConstantDereferencer(constants).transform(expression);
- expression = new ConstantTensorTransformer(constants, rankPropertiesOutput).transform(expression);
- expression = new MacroInliner(inlineMacros).transform(expression);
- expression = new MacroShadower(rankProfile.getMacros()).transform(expression);
- expression = new TensorTransformer(rankProfile).transform(expression);
- expression = new Simplifier().transform(expression);
+ TransformContext context = new RankProfileTransformContext(rankProfile, constants, inlineMacros, rankPropertiesOutput);
+ for (ExpressionTransformer transformer : transforms)
+ expression = transformer.transform(expression, context);
return expression;
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java
index a3933e6f8e2..6702955bae3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java
@@ -6,8 +6,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-
-import java.util.Map;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
/**
* Inlines macros in ranking expressions
@@ -16,25 +15,19 @@ import java.util.Map;
*/
public class MacroInliner extends ExpressionTransformer {
- private final Map<String, RankProfile.Macro> macros;
-
- public MacroInliner(Map<String, RankProfile.Macro> macros) {
- this.macros = macros;
- }
-
@Override
- public ExpressionNode transform(ExpressionNode node) {
+ public ExpressionNode transform(ExpressionNode node, TransformContext context) {
if (node instanceof ReferenceNode)
- return transformFeatureNode((ReferenceNode)node);
+ return transformFeatureNode((ReferenceNode)node, (RankProfileTransformContext)context);
if (node instanceof CompositeNode)
- return transformChildren((CompositeNode)node);
+ return transformChildren((CompositeNode)node, context);
return node;
}
- private ExpressionNode transformFeatureNode(ReferenceNode feature) {
- RankProfile.Macro macro = macros.get(feature.getName());
+ private ExpressionNode transformFeatureNode(ReferenceNode feature, RankProfileTransformContext context) {
+ RankProfile.Macro macro = context.inlineMacros().get(feature.getName());
if (macro == null) return feature;
- return transform(macro.getRankingExpression().getRoot()); // inline recursively and return
+ return transform(macro.getRankingExpression().getRoot(), context); // inline recursively and return
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java
index 1d9769d0d78..6eabb5ddcd4 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java
@@ -3,10 +3,12 @@ package com.yahoo.searchdefinition.expressiontransforms;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.rule.*;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-
-import java.util.Map;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
/**
* Transforms function nodes to reference nodes if a macro shadows a built-in function.
@@ -23,44 +25,38 @@ import java.util.Map;
*/
public class MacroShadower extends ExpressionTransformer {
- private final Map<String, RankProfile.Macro> macros;
-
- public MacroShadower(Map<String, RankProfile.Macro> macros) {
- this.macros = macros;
- }
-
@Override
- public RankingExpression transform(RankingExpression expression) {
+ public RankingExpression transform(RankingExpression expression, TransformContext context) {
String name = expression.getName();
ExpressionNode node = expression.getRoot();
- ExpressionNode result = transform(node);
+ ExpressionNode result = transform(node, context);
return new RankingExpression(name, result);
}
@Override
- public ExpressionNode transform(ExpressionNode node) {
+ public ExpressionNode transform(ExpressionNode node, TransformContext context) {
if (node instanceof FunctionNode)
- return transformFunctionNode((FunctionNode) node);
+ return transformFunctionNode((FunctionNode) node, context);
if (node instanceof CompositeNode)
- return transformChildren((CompositeNode)node);
+ return transformChildren((CompositeNode)node, context);
return node;
}
- protected ExpressionNode transformFunctionNode(FunctionNode function) {
+ protected ExpressionNode transformFunctionNode(FunctionNode function, TransformContext context) {
String name = function.getFunction().toString();
- RankProfile.Macro macro = macros.get(name);
+ RankProfile.Macro macro = ((RankProfileTransformContext)context).rankProfile().getMacros().get(name);
if (macro == null) {
- return transformChildren(function);
+ return transformChildren(function, context);
}
int functionArity = function.getFunction().arity();
int macroArity = macro.getFormalParams() != null ? macro.getFormalParams().size() : 0;
if (functionArity != macroArity) {
- return transformChildren(function);
+ return transformChildren(function, context);
}
ReferenceNode node = new ReferenceNode(name, function.children(), null);
- return transformChildren(node);
+ return transformChildren(node, context);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
new file mode 100644
index 00000000000..fb996d70607
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
@@ -0,0 +1,34 @@
+package com.yahoo.searchdefinition.expressiontransforms;
+
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
+
+import java.util.Map;
+
+/**
+ * Extends the transform context with rank profile information
+ *
+ * @author bratseth
+ */
+public class RankProfileTransformContext extends TransformContext {
+
+ private final RankProfile rankProfile;
+ private final Map<String, RankProfile.Macro> inlineMacros;
+ private final Map<String, String> rankPropertiesOutput;
+
+ RankProfileTransformContext(RankProfile rankProfile,
+ Map<String, Value> constants,
+ Map<String, RankProfile.Macro> inlineMacros,
+ Map<String, String> rankPropertiesOutput) {
+ super(constants);
+ this.rankProfile = rankProfile;
+ this.inlineMacros = inlineMacros;
+ this.rankPropertiesOutput = rankPropertiesOutput;
+ }
+
+ public RankProfile rankProfile() { return rankProfile; }
+ public Map<String, RankProfile.Macro> inlineMacros() { return inlineMacros; }
+ public Map<String, String> rankPropertiesOutput() { return rankPropertiesOutput; }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index e5886030d44..b7033d4ad9f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -1,6 +1,5 @@
package com.yahoo.searchdefinition.expressiontransforms;
-import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
@@ -10,6 +9,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import java.util.Map;
import java.util.Optional;
@@ -24,23 +24,18 @@ import java.util.Optional;
public class TensorFlowFeatureConverter extends ExpressionTransformer {
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
- private final RankProfile profile;
-
- public TensorFlowFeatureConverter(RankProfile profile) {
- this.profile = profile;
- }
@Override
- public ExpressionNode transform(ExpressionNode node) {
+ public ExpressionNode transform(ExpressionNode node, TransformContext context) {
if (node instanceof ReferenceNode)
- return transformFeature((ReferenceNode) node);
+ return transformFeature((ReferenceNode) node, (RankProfileTransformContext)context);
else if (node instanceof CompositeNode)
- return super.transformChildren((CompositeNode) node);
+ return super.transformChildren((CompositeNode) node, context);
else
return node;
}
- private ExpressionNode transformFeature(ReferenceNode feature) {
+ private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
try {
if ( ! feature.getName().equals("tensorflow")) return feature;
@@ -48,15 +43,16 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer {
throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
"the tensorflow model directory under [application]/models");
- // Find the specified expression
ImportResult result = tensorFlowImporter.importModel(asString(feature.getArguments().expressions().get(0)));
+
+ // Find the specified expression
ImportResult.Signature signature = chooseOrDefault("signatures", result.signatures(),
optionalArgument(1, feature.getArguments()));
String output = chooseOrDefault("outputs", signature.outputs(),
optionalArgument(2, feature.getArguments()));
// Add all constants
- result.constants().forEach((k, v) -> profile.addConstantTensor(k, new TensorValue(v)));
+ result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v)));
return result.expressions().get(output).getRoot();
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java
index 70a7372dbe9..971c2c4f218 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java
@@ -3,7 +3,6 @@ package com.yahoo.searchdefinition.expressiontransforms;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
-import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.document.Attribute;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
@@ -17,12 +16,12 @@ import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import java.util.List;
-import java.util.Map;
import java.util.Optional;
/**
@@ -36,32 +35,22 @@ import java.util.Optional;
*/
public class TensorTransformer extends ExpressionTransformer {
- private Search search;
- private RankProfile rankprofile;
- private Map<String, RankProfile.Macro> macros;
-
- public TensorTransformer(RankProfile rankprofile) {
- this.rankprofile = rankprofile;
- this.search = rankprofile.getSearch();
- this.macros = rankprofile.getMacros();
- }
-
@Override
- public ExpressionNode transform(ExpressionNode node) {
+ public ExpressionNode transform(ExpressionNode node, TransformContext context) {
if (node instanceof CompositeNode) {
- node = transformChildren((CompositeNode) node);
+ node = transformChildren((CompositeNode) node, context);
}
if (node instanceof FunctionNode) {
- node = transformFunctionNode((FunctionNode) node);
+ node = transformFunctionNode((FunctionNode) node, ((RankProfileTransformContext)context).rankProfile());
}
return node;
}
- private ExpressionNode transformFunctionNode(FunctionNode node) {
+ private ExpressionNode transformFunctionNode(FunctionNode node, RankProfile profile) {
switch (node.getFunction()) {
case min:
case max:
- return transformMaxAndMinFunctionNode(node);
+ return transformMaxAndMinFunctionNode(node, profile);
}
return node;
}
@@ -80,7 +69,7 @@ public class TensorTransformer extends ExpressionTransformer {
* There is currently no guarantee that all cases will be found. For
* instance, if-statements are problematic.
*/
- private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node) {
+ private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node, RankProfile profile) {
if (node.children().size() != 2) {
return node;
}
@@ -88,7 +77,7 @@ public class TensorTransformer extends ExpressionTransformer {
Optional<String> dimension = dimensionName(node.children().get(1));
if (dimension.isPresent()) {
try {
- Context context = buildContext(arg1);
+ Context context = buildContext(arg1, profile);
Value value = arg1.evaluate(context);
if (isTensorWithDimension(value, dimension.get())) {
return replaceMaxAndMinFunction(node);
@@ -110,12 +99,10 @@ public class TensorTransformer extends ExpressionTransformer {
}
private boolean isTensorWithDimension(Value value, String dimension) {
- if (value instanceof TensorValue) {
- Tensor tensor = ((TensorValue) value).asTensor();
- TensorType type = tensor.type();
- return type.dimensionNames().contains(dimension);
- }
- return false;
+ if (value instanceof TensorValue)
+ return value.asTensor().type().dimensionNames().contains(dimension);
+ else
+ return false;
}
private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) {
@@ -133,9 +120,9 @@ public class TensorTransformer extends ExpressionTransformer {
* Creates an evaluation context by iterating through the expression tree, and
* adding dummy values with correct types to the context.
*/
- private Context buildContext(ExpressionNode node) {
+ private Context buildContext(ExpressionNode node, RankProfile profile) {
Context context = new MapContext();
- addRoot(node, context);
+ addRoot(node, context, profile);
return context;
}
@@ -152,28 +139,28 @@ public class TensorTransformer extends ExpressionTransformer {
return new TensorValue(empty);
}
- private void addRoot(ExpressionNode node, Context context) {
- addChildren(node, context);
+ private void addRoot(ExpressionNode node, Context context, RankProfile profile) {
+ addChildren(node, context, profile);
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) node;
- addIfAttribute(referenceNode, context);
- addIfConstant(referenceNode, context);
- addIfQuery(referenceNode, context);
+ addIfAttribute(referenceNode, context, profile);
+ addIfConstant(referenceNode, context, profile);
+ addIfQuery(referenceNode, context, profile);
addIfTensorFrom(referenceNode, context);
- addIfMacro(referenceNode, context);
+ addIfMacro(referenceNode, context, profile);
}
}
- private void addChildren(ExpressionNode node, Context context) {
+ private void addChildren(ExpressionNode node, Context context, RankProfile profile) {
if (node instanceof CompositeNode) {
List<ExpressionNode> children = ((CompositeNode) node).children();
for (ExpressionNode child : children) {
- addRoot(child, context);
+ addRoot(child, context, profile);
}
}
}
- private void addIfAttribute(ReferenceNode node, Context context) {
+ private void addIfAttribute(ReferenceNode node, Context context, RankProfile profile) {
if (!node.getName().equals("attribute")) {
return;
}
@@ -181,7 +168,7 @@ public class TensorTransformer extends ExpressionTransformer {
return;
}
String attribute = node.children().get(0).toString();
- Attribute a = search.getAttribute(attribute);
+ Attribute a = profile.getSearch().getAttribute(attribute);
if (a == null) {
return;
}
@@ -196,7 +183,7 @@ public class TensorTransformer extends ExpressionTransformer {
context.put(node.toString(), v);
}
- private void addIfConstant(ReferenceNode node, Context context) {
+ private void addIfConstant(ReferenceNode node, Context context, RankProfile profile) {
if (!node.getName().equals(ConstantTensorTransformer.CONSTANT)) {
return;
}
@@ -208,25 +195,25 @@ public class TensorTransformer extends ExpressionTransformer {
child = ((CompositeNode) child).children().get(0);
}
String name = child.toString();
- addIfConstantInRankProfile(name, node, context);
- addIfConstantInRankingConstants(name, node, context);
+ addIfConstantInRankProfile(name, node, context, profile);
+ addIfConstantInRankingConstants(name, node, context, profile);
}
- private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context) {
- if (rankprofile.getConstants().containsKey(name)) {
- context.put(node.toString(), rankprofile.getConstants().get(name));
+ private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context, RankProfile profile) {
+ if (profile.getConstants().containsKey(name)) {
+ context.put(node.toString(), profile.getConstants().get(name));
}
}
- private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context) {
- for (RankingConstant rankingConstant : search.getRankingConstants()) {
+ private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context, RankProfile profile) {
+ for (RankingConstant rankingConstant : profile.getSearch().getRankingConstants()) {
if (rankingConstant.getName().equals(name)) {
context.put(node.toString(), emptyTensorValue(rankingConstant.getTensorType()));
}
}
}
- private void addIfQuery(ReferenceNode node, Context context) {
+ private void addIfQuery(ReferenceNode node, Context context, RankProfile profile) {
if (!node.getName().equals("query")) {
return;
}
@@ -234,8 +221,8 @@ public class TensorTransformer extends ExpressionTransformer {
return;
}
String name = node.children().get(0).toString();
- if (rankprofile.getQueryFeatureTypes().containsKey(name)) {
- String type = rankprofile.getQueryFeatureTypes().get(name);
+ if (profile.getQueryFeatureTypes().containsKey(name)) {
+ String type = profile.getQueryFeatureTypes().get(name);
Value v;
if (type.contains("tensor")) {
v = emptyTensorValue(TensorType.fromSpec(type));
@@ -267,13 +254,13 @@ public class TensorTransformer extends ExpressionTransformer {
context.put(node.toString(), emptyTensorValue(type));
}
- private void addIfMacro(ReferenceNode node, Context context) {
- RankProfile.Macro macro = macros.get(node.getName());
+ private void addIfMacro(ReferenceNode node, Context context, RankProfile profile) {
+ RankProfile.Macro macro = profile.getMacros().get(node.getName());
if (macro == null) {
return;
}
ExpressionNode root = macro.getRankingExpression().getRoot();
- Context macroContext = buildContext(root);
+ Context macroContext = buildContext(root, profile);
addMacroArguments(node, context, macro, macroContext);
Value value = root.evaluate(macroContext);
context.put(node.toString(), value);
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index e5693d24f0f..475fee62177 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -92,7 +92,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
private void assertContainsExpression(String expr, String transformedExpression) throws ParseException {
assertTrue("Expected expression '" + transformedExpression + "' not found",
- containsExpression(expr, transformedExpression));
+ containsExpression(expr, transformedExpression));
}
private boolean containsExpression(String expr, String transformedExpression) throws ParseException {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
index 46f79dcc6bd..1b8239ba643 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
@@ -10,7 +10,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import java.util.ArrayList;
import java.util.List;
-import java.util.Map;
/**
* Replaces "features" which found in the given constants by their constant value
@@ -19,40 +18,33 @@ import java.util.Map;
*/
public class ConstantDereferencer extends ExpressionTransformer {
- /** The map of constants to dereference */
- private final Map<String, Value> constants;
-
- public ConstantDereferencer(Map<String, Value> constants) {
- this.constants = constants;
- }
-
@Override
- public ExpressionNode transform(ExpressionNode node) {
+ public ExpressionNode transform(ExpressionNode node, TransformContext context) {
if (node instanceof ReferenceNode)
- return transformFeature((ReferenceNode) node);
+ return transformFeature((ReferenceNode) node, context);
else if (node instanceof CompositeNode)
- return transformChildren((CompositeNode)node);
+ return transformChildren((CompositeNode)node, context);
else
return node;
}
- private ExpressionNode transformFeature(ReferenceNode node) {
+ private ExpressionNode transformFeature(ReferenceNode node, TransformContext context) {
if (!node.getArguments().isEmpty())
- return transformArguments(node);
+ return transformArguments(node, context);
else
- return transformConstantReference(node);
+ return transformConstantReference(node, context);
}
- private ExpressionNode transformArguments(ReferenceNode node) {
+ private ExpressionNode transformArguments(ReferenceNode node, TransformContext context) {
List<ExpressionNode> arguments = node.getArguments().expressions();
List<ExpressionNode> transformedArguments = new ArrayList<>(arguments.size());
for (ExpressionNode argument : arguments)
- transformedArguments.add(transform(argument));
+ transformedArguments.add(transform(argument, context));
return node.setArguments(transformedArguments);
}
- private ExpressionNode transformConstantReference(ReferenceNode node) {
- Value value = constants.get(node.getName());
+ private ExpressionNode transformConstantReference(ReferenceNode node, TransformContext context) {
+ Value value = context.constants().get(node.getName());
if (value == null || (value instanceof TensorValue)) {
return node; // not a value constant reference
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
index bcc8b817641..c585c0dea1f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
@@ -15,22 +15,22 @@ import java.util.List;
*/
public abstract class ExpressionTransformer {
- public RankingExpression transform(RankingExpression expression) {
- return new RankingExpression(expression.getName(), transform(expression.getRoot()));
+ public RankingExpression transform(RankingExpression expression, TransformContext context) {
+ return new RankingExpression(expression.getName(), transform(expression.getRoot(), context));
}
/** Transforms an expression node and returns the transformed node */
- public abstract ExpressionNode transform(ExpressionNode node);
+ public abstract ExpressionNode transform(ExpressionNode node, TransformContext context);
/**
* Utility method which calls transform on each child of the given node and return the resulting transformed
* composite
*/
- protected CompositeNode transformChildren(CompositeNode node) {
+ protected CompositeNode transformChildren(CompositeNode node, TransformContext context) {
List<ExpressionNode> children = node.children();
List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
for (ExpressionNode child : children)
- transformedChildren.add(transform(child));
+ transformedChildren.add(transform(child, context));
return node.setChildren(transformedChildren);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index ebad0d5c21f..9e8491340b0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -1,7 +1,6 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.transform;
-import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
@@ -10,8 +9,8 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import java.util.ArrayList;
import java.util.List;
@@ -24,9 +23,9 @@ import java.util.List;
public class Simplifier extends ExpressionTransformer {
@Override
- public ExpressionNode transform(ExpressionNode node) {
+ public ExpressionNode transform(ExpressionNode node, TransformContext context) {
if (node instanceof CompositeNode)
- node = transformChildren((CompositeNode) node); // depth first
+ node = transformChildren((CompositeNode) node, context); // depth first
if (node instanceof IfNode)
node = transformIf((IfNode) node);
if (node instanceof EmbracedNode && hasSingleUndividableChild((EmbracedNode)node))
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java
new file mode 100644
index 00000000000..746ca3b3200
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java
@@ -0,0 +1,22 @@
+package com.yahoo.searchlib.rankingexpression.transform;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Map;
+
+/**
+ * Provides a context in which transforms on ranking expressions take place.
+ *
+ * @author bratseth
+ */
+public class TransformContext {
+
+ private final Map<String, Value> constants;
+
+ public TransformContext(Map<String, Value> constants) {
+ this.constants = constants;
+ }
+
+ public Map<String, Value> constants() { return constants; }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java
index 4035e499a6a..84e51835458 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java
@@ -5,11 +5,12 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import org.junit.Test;
-import static org.junit.Assert.*;
import java.util.HashMap;
import java.util.Map;
+import static org.junit.Assert.assertEquals;
+
/**
* @author bratseth
*/
@@ -17,14 +18,16 @@ public class ConstantDereferencerTestCase {
@Test
public void testConstantDereferencer() throws ParseException {
+ ConstantDereferencer c = new ConstantDereferencer();
+
Map<String, Value> constants = new HashMap<>();
constants.put("a", Value.parse("1.0"));
constants.put("b", Value.parse("2"));
constants.put("c", Value.parse("3.5"));
- ConstantDereferencer c = new ConstantDereferencer(constants);
+ TransformContext context = new TransformContext(constants);
- assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c")).toString());
- assertEquals("myMacro(1.0,2.0)", c.transform(new RankingExpression("myMacro(a, b)")).toString());
+ assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c"), context).toString());
+ assertEquals("myMacro(1.0,2.0)", c.transform(new RankingExpression("myMacro(a, b)"), context).toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
index f9d2472e306..8fac3395ac0 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
@@ -7,6 +7,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import org.junit.Test;
+
+import java.util.Collections;
+
import static org.junit.Assert.*;
/**
@@ -17,31 +20,33 @@ public class SimplifierTestCase {
@Test
public void testSimplify() throws ParseException {
Simplifier s = new Simplifier();
- assertEquals("a + b", s.transform(new RankingExpression("a + b")).toString());
- assertEquals("6.5", s.transform(new RankingExpression("1.0 + 2.0 + 3.5")).toString());
- assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )")).toString());
- assertEquals("6.5", s.transform(new RankingExpression("( 1.0 + 2.0 ) + 3.5 ")).toString());
- assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )")).toString());
- assertEquals("7.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + 1")).toString());
- assertEquals("6.5 + a", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + a")).toString());
- assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * 0.0")).toString());
- assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (0.0)")).toString());
- assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (1.0 - 1.0)")).toString());
- assertEquals("7.5", s.transform(new RankingExpression("if (2 > 0, 3.5 * 2 + 0.5, a *3 )")).toString());
- assertEquals("0.0", s.transform(new RankingExpression("0.0 * (1.3 + 7.0)")).toString());
- assertEquals("6.4", s.transform(new RankingExpression("max(0, 10.0-2.0)*(1-fabs(0.0-0.2))")).toString());
- assertEquals("(query(d) + query(b) - query(a)) * query(c) / query(e)", s.transform(new RankingExpression("(query(d) + query(b) - query(a)) * query(c) / query(e)")).toString());
- assertEquals("14.0", s.transform(new RankingExpression("5 + (2 + 3) + 4")).toString());
- assertEquals("28.0 + bar", s.transform(new RankingExpression("7.0 + 12.0 + 9.0 + bar")).toString());
- assertEquals("1.0 - 0.001 * attribute(number)", s.transform(new RankingExpression("1.0 - 0.001*attribute(number)")).toString());
- assertEquals("attribute(number) * 1.5 - 0.001 * attribute(number)", s.transform(new RankingExpression("attribute(number) * 1.5 - 0.001 * attribute(number)")).toString());
+ TransformContext c = new TransformContext(Collections.emptyMap());
+ assertEquals("a + b", s.transform(new RankingExpression("a + b"), c).toString());
+ assertEquals("6.5", s.transform(new RankingExpression("1.0 + 2.0 + 3.5"), c).toString());
+ assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )"), c).toString());
+ assertEquals("6.5", s.transform(new RankingExpression("( 1.0 + 2.0 ) + 3.5 "), c).toString());
+ assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )"), c).toString());
+ assertEquals("7.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + 1"), c).toString());
+ assertEquals("6.5 + a", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 ) + a"), c).toString());
+ assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * 0.0"), c).toString());
+ assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (0.0)"), c).toString());
+ assertEquals("7.5", s.transform(new RankingExpression("7.5 + ( 2.0 + 3.5 ) * (1.0 - 1.0)"), c).toString());
+ assertEquals("7.5", s.transform(new RankingExpression("if (2 > 0, 3.5 * 2 + 0.5, a *3 )"), c).toString());
+ assertEquals("0.0", s.transform(new RankingExpression("0.0 * (1.3 + 7.0)"), c).toString());
+ assertEquals("6.4", s.transform(new RankingExpression("max(0, 10.0-2.0)*(1-fabs(0.0-0.2))"), c).toString());
+ assertEquals("(query(d) + query(b) - query(a)) * query(c) / query(e)", s.transform(new RankingExpression("(query(d) + query(b) - query(a)) * query(c) / query(e)"), c).toString());
+ assertEquals("14.0", s.transform(new RankingExpression("5 + (2 + 3) + 4"), c).toString());
+ assertEquals("28.0 + bar", s.transform(new RankingExpression("7.0 + 12.0 + 9.0 + bar"), c).toString());
+ assertEquals("1.0 - 0.001 * attribute(number)", s.transform(new RankingExpression("1.0 - 0.001*attribute(number)"), c).toString());
+ assertEquals("attribute(number) * 1.5 - 0.001 * attribute(number)", s.transform(new RankingExpression("attribute(number) * 1.5 - 0.001 * attribute(number)"), c).toString());
}
// A black box test verifying we are not screwing up real expressions
@Test
public void testSimplifyComplexExpression() throws ParseException {
RankingExpression initial = new RankingExpression("sqrt(if (if (INFERRED * 0.9 < INFERRED, GMP, (1 + 1.1) * INFERRED) < INFERRED * INFERRED - INFERRED, if (GMP < 85.80799542793133 * GMP, INFERRED, if (GMP < GMP, tanh(INFERRED), log(76.89956221113943))), tanh(tanh(INFERRED))) * sqrt(sqrt(GMP + INFERRED)) * GMP ) + 13.5 * (1 - GMP) * pow(GMP * 0.1, 2 + 1.1 * 0)");
- RankingExpression simplified = new Simplifier().transform(initial);
+ TransformContext c = new TransformContext(Collections.emptyMap());
+ RankingExpression simplified = new Simplifier().transform(initial, c);
Context context = new MapContext();
context.put("INFERRED", 0.5);
@@ -65,7 +70,8 @@ public class SimplifierTestCase {
@Test
public void testParenthesisPreservation() throws ParseException {
Simplifier s = new Simplifier();
- CompositeNode transformed = (CompositeNode)s.transform(new RankingExpression("a + (b + c) / 100000000.0")).getRoot();
+ TransformContext c = new TransformContext(Collections.emptyMap());
+ CompositeNode transformed = (CompositeNode)s.transform(new RankingExpression("a + (b + c) / 100000000.0"), c).getRoot();
assertEquals("a + (b + c) / 100000000.0", transformed.toString());
}