summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/Model.java8
-rw-r--r--linguistics/src/main/java/com/yahoo/language/Linguistics.java7
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java141
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/LazyValue.java146
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/Model.java22
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java9
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java3
-rw-r--r--model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java12
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java146
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ContextIndex.java20
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Optimizer.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java12
15 files changed, 457 insertions, 80 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/Model.java b/container-search/src/main/java/com/yahoo/search/query/Model.java
index dc7a61344cb..167bb312f61 100644
--- a/container-search/src/main/java/com/yahoo/search/query/Model.java
+++ b/container-search/src/main/java/com/yahoo/search/query/Model.java
@@ -78,11 +78,11 @@ public class Model implements Cloneable {
private String defaultIndex = null;
private Query.Type type = Query.Type.ALL;
private Query parent;
- private Set<String> sources=new LinkedHashSet<>();
- private Set<String> restrict=new LinkedHashSet<>();
+ private Set<String> sources = new LinkedHashSet<>();
+ private Set<String> restrict = new LinkedHashSet<>();
private String searchPath;
private String documentDbName = null;
- private Execution execution=new Execution(new Execution.Context(null, null, null, null, null));
+ private Execution execution = new Execution(new Execution.Context(null, null, null, null, null));
public Model(Query query) {
setParent(query);
@@ -101,7 +101,7 @@ public class Model implements Cloneable {
*/
@Deprecated
public void traceLanguage() {
- if (getParent().getTraceLevel()<2) return;
+ if (getParent().getTraceLevel() < 2) return;
if (language != null) {
getParent().trace("Language " + getLanguage() + " specified directly as a parameter", false, 2);
}
diff --git a/linguistics/src/main/java/com/yahoo/language/Linguistics.java b/linguistics/src/main/java/com/yahoo/language/Linguistics.java
index e275f189b0c..c3c0c049e99 100644
--- a/linguistics/src/main/java/com/yahoo/language/Linguistics.java
+++ b/linguistics/src/main/java/com/yahoo/language/Linguistics.java
@@ -41,7 +41,12 @@ public interface Linguistics {
CHARACTER_CLASSES
}
- /** The same as new com.yahoo.language.simple.SimpleLinguistics(). Prefer using that directly. */
+ /**
+ * The same as new com.yahoo.language.simple.SimpleLinguistics(). Prefer using that directly.
+ *
+ * @deprecated use new com.yahoo.language.simple.SimpleLinguistics()
+ */
+ @Deprecated // TODO: Remove this field on Vespa 7
Linguistics SIMPLE = new SimpleLinguistics();
/**
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
new file mode 100644
index 00000000000..59dd1fd7b12
--- /dev/null
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -0,0 +1,141 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.AbstractArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.TensorType;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * An array context supporting functions invocations implemented as lazy values.
+ *
+ * @author bratseth
+ */
+class LazyArrayContext extends Context implements ContextIndex {
+
+ /** The current values set */
+ private Value[] values;
+
+ private static DoubleValue constantZero = DoubleValue.frozen(0);
+
+ /**
+ * Create a fast lookup, lazy context for an expression.
+ * This instance should be reused indefinitely by a single thread.
+ *
+ * @param expression the expression to create a context for
+ */
+ LazyArrayContext(RankingExpression expression, List<ExpressionFunction> functions) {
+ values = new Value[doubleValues().length];
+ Arrays.fill(values, DoubleValue.zero);
+ }
+
+ @Override
+ protected void extractBindTargets(ExpressionNode node, Set<String> bindTargets) {
+ if (isFunctionReference(node)) {
+ ReferenceNode reference = (ReferenceNode)node;
+ bindTargets.add(reference.getArguments().expressions().get(0).toString());
+
+ }
+ else {
+ super.extractBindTargets(node, bindTargets);
+ }
+ }
+
+ private boolean isFunctionReference(ExpressionNode node) {
+ if ( ! (node instanceof ReferenceNode)) return false;
+
+ ReferenceNode reference = (ReferenceNode)node;
+ return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1;
+ }
+
+ /**
+ * Puts a value by name.
+ * The value will be frozen if it isn't already.
+ *
+ * @throws IllegalArgumentException if the name is not present in the ranking expression this was created with, and
+ * ignoredUnknownValues is false
+ */
+ @Override
+ public final void put(String name, Value value) {
+ Integer index = nameToIndex().get(name);
+ if (index == null) {
+ if (ignoreUnknownValues())
+ return;
+ else
+ throw new IllegalArgumentException("Value '" + name + "' is not known to " + this);
+ }
+ put(index, value);
+ }
+
+ /** Same as put(index,DoubleValue.frozen(value)) */
+ public final void put(int index, double value) {
+ put(index, DoubleValue.frozen(value));
+ }
+
+ /**
+ * Puts a value by index.
+ * The value will be frozen if it isn't already.
+ */
+ public final void put(int index, Value value) {
+ values[index] = value.freeze();
+ try {
+ doubleValues()[index] = value.asDouble();
+ }
+ catch (UnsupportedOperationException e) {
+ doubleValues()[index] = Double.NaN; // see getDouble below
+ }
+ }
+
+ @Override
+ public TensorType getType(Reference reference) {
+ Integer index = nameToIndex().get(reference.toString());
+ if (index == null) return null;
+ return values[index].type();
+ }
+
+ /** Perform a slow lookup by name */
+ @Override
+ public Value get(String name) {
+ Integer index = nameToIndex().get(name);
+ if (index == null) return DoubleValue.zero;
+ return values[index];
+ }
+
+ /** Perform a fast lookup by index */
+ @Override
+ public final Value get(int index) {
+ return values[index];
+ }
+
+ /** Perform a fast lookup directly of the value as a double. This is faster than get(index).asDouble() */
+ @Override
+ public final double getDouble(int index) {
+ double value = doubleValues()[index];
+ if (value == Double.NaN)
+ throw new UnsupportedOperationException("Value at " + index + " has no double representation");
+ return value;
+ }
+
+ /**
+ * Creates a clone of this context suitable for evaluating against the same ranking expression
+ * in a different thread (i.e, name name to index map, different value set.
+ */
+ public LazyArrayContext clone() {
+ LazyArrayContext clone = (LazyArrayContext)super.clone();
+ clone.values = new Value[nameToIndex().size()];
+ Arrays.fill(values, constantZero);
+ return clone;
+ }
+
+}
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/LazyValue.java b/model-inference/src/main/java/ai/vespa/models/evaluation/LazyValue.java
new file mode 100644
index 00000000000..b026007346d
--- /dev/null
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/LazyValue.java
@@ -0,0 +1,146 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.Function;
+import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+/**
+ * A Value which is computed from an expression when first requested.
+ * This is not multithread safe.
+ *
+ * @author bratseth
+ */
+class LazyValue extends Value {
+
+ /** The function computing the value of this */
+ private final ExpressionFunction function;
+
+ /** The context used to compute the function of this */
+ private final Context context;
+
+ private Value computedValue = null;
+
+ public LazyValue(ExpressionFunction function, Context context) {
+ this.function = function;
+ this.context = context;
+ }
+
+ private Value computedValue() {
+ if (computedValue == null)
+ computedValue = function.getBody().evaluate(context);
+ return computedValue;
+ }
+
+ @Override
+ public TensorType type() {
+ return computedValue().type(); // TODO: Keep type information in this/ExpressionFunction to avoid computing here
+ }
+
+ @Override
+ public double asDouble() {
+ return computedValue().asDouble();
+ }
+
+ @Override
+ public Tensor asTensor() {
+ return computedValue().asTensor();
+ }
+
+ @Override
+ public boolean hasDouble() {
+ return type().rank() == 0;
+ }
+
+ @Override
+ public boolean asBoolean() {
+ return computedValue().asBoolean();
+ }
+
+ @Override
+ public Value negate() {
+ return computedValue().negate();
+ }
+
+ @Override
+ public Value add(Value value) {
+ return computedValue().add(value);
+ }
+
+ @Override
+ public Value subtract(Value value) {
+ return computedValue().subtract(value);
+ }
+
+ @Override
+ public Value multiply(Value value) {
+ return computedValue().multiply(value);
+ }
+
+ @Override
+ public Value divide(Value value) {
+ return computedValue().divide(value);
+ }
+
+ @Override
+ public Value modulo(Value value) {
+ return computedValue().modulo(value);
+ }
+
+ @Override
+ public Value and(Value value) {
+ return computedValue().and(value);
+ }
+
+ @Override
+ public Value or(Value value) {
+ return computedValue().or(value);
+ }
+
+ @Override
+ public Value not() {
+ return computedValue().not();
+ }
+
+ @Override
+ public Value power(Value value) {
+ return computedValue().power(value);
+ }
+
+ @Override
+ public Value compare(TruthOperator operator, Value value) {
+ return computedValue().compare(operator, value);
+ }
+
+ @Override
+ public Value function(Function function, Value value) {
+ return computedValue().function(function, value);
+ }
+
+ @Override
+ public Value asMutable() {
+ return computedValue().asMutable();
+ }
+
+ @Override
+ public String toString() {
+ return "value of " + function;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == this) return true;
+ if (!(other instanceof Value)) return false;
+ return computedValue().equals(other);
+ }
+
+ @Override
+ public int hashCode() {
+ return computedValue().hashCode();
+ }
+
+}
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java b/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java
index cbeb1ca708c..5bccd526571 100644
--- a/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -3,6 +3,8 @@ package ai.vespa.models.evaluation;
import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import java.util.Collection;
import java.util.Collections;
@@ -50,9 +52,6 @@ public class Model {
}
- /** Returns an immutable list of the bound function instances of this */
- List<ExpressionFunction> boundFunctions() { return boundFunctions; }
-
/** Returns the function withe the given name, or null if none */ // TODO: Parameter overloading?
ExpressionFunction function(String name) {
for (ExpressionFunction function : functions)
@@ -61,6 +60,9 @@ public class Model {
return null;
}
+ /** Returns an immutable list of the bound function instances of this */
+ List<ExpressionFunction> boundFunctions() { return boundFunctions; }
+
/** Returns the function withe the given name, or null if none */ // TODO: Parameter overloading?
ExpressionFunction boundFunction(String name) {
for (ExpressionFunction function : boundFunctions)
@@ -69,6 +71,20 @@ public class Model {
return null;
}
+ /**
+ * Returns a function which can be used to evaluate the given function
+ *
+ * @throws IllegalArgumentException if the function is not present
+ */
+ public Context contextFor(String function) {
+ Context context = new LazyArrayContext(requireFunction(function).getBody(), boundFunctions);
+ for (ExpressionFunction boundFunction : boundFunctions) {
+ System.out.println("Binding " + boundFunction.getName());
+ context.put(boundFunction.getName(), new LazyValue(boundFunction, context));
+ }
+ return context;
+ }
+
@Override
public String toString() { return "model '" + name + "'"; }
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 6d1f0d885ae..35c6a269edd 100644
--- a/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -27,6 +27,15 @@ public class ModelsEvaluator {
public Map<String, Model> models() { return models; }
/**
+ * Returns a function which can be used to evaluate the given function in the given model
+ *
+ * @throws IllegalArgumentException if the function or model is not present
+ */
+ public Context contextFor(String modelName, String functionName) {
+ return requireModel(modelName).contextFor(functionName);
+ }
+
+ /**
* Evaluates the given function in the given model.
*
* @param modelName the name of the model to evaluate
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index a09ecffeda0..15d3ab4bf1f 100644
--- a/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -20,7 +20,8 @@ import java.util.regex.Pattern;
*/
class RankProfilesConfigImporter {
- private static final Pattern expressionPattern =
+ // TODO: Move to separate class ... or something
+ static final Pattern expressionPattern =
Pattern.compile("rankingExpression\\(([a-zA-Z0-9_]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.rankingScript");
/**
diff --git a/model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index 13b6dbb6dd9..9965a3c86ba 100644
--- a/model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -4,6 +4,7 @@ package ai.vespa.models.evaluation;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.config.subscription.FileSource;
import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
@@ -46,7 +47,7 @@ public class ModelsEvaluatorTest {
}
@Test
- public void testScalarArrayCnotextEvaluation() {
+ public void testScalarArrayContextEvaluation() {
ModelsEvaluator evaluator = createEvaluator();
ArrayContext context = new ArrayContext(evaluator.requireModel("macros").requireFunction("fourtimessum").getBody());
context.put("var1", Value.of(Tensor.from("{{x:0}:3,{x:1}:5}")));
@@ -63,4 +64,13 @@ public class ModelsEvaluatorTest {
assertEquals(Tensor.from("{{x:0}:40.0,{x:1}:64.0}"), evaluator.evaluate("macros", "fourtimessum", context));
}
+ @Test
+ public void testEvaluationDependingOnBoundMacro() {
+ ModelsEvaluator evaluator = createEvaluator();
+ Context context = evaluator.contextFor("macros", "secondphase");
+ context.put("match", 3);
+ context.put("rankboost", 5);
+ assertEquals(32.0, evaluator.evaluate("macros", "secondphase", context).asDouble(), delta);
+ }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java
index ed9fa346c11..893e31c7087 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/AbstractArrayContext.java
@@ -7,8 +7,6 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import java.util.Collections;
-import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
@@ -20,19 +18,15 @@ import java.util.Set;
*
* @author bratseth
*/
-public abstract class AbstractArrayContext extends Context implements Cloneable {
+public abstract class AbstractArrayContext extends Context implements Cloneable, ContextIndex {
private final boolean ignoreUnknownValues;
- /** The mapping from variable name to index */
- private final ImmutableMap<String, Integer> nameToIndex;
-
- /** The current values set, pre-converted to doubles */
- private double[] doubleValues;
-
/** The name of the ranking expression this was created for */
private final String rankingExpressionName;
+ private IndexedBindings indexedBindings;
+
/**
* Create a fast lookup context for an expression.
* This instance should be reused indefinitely by a single thread.
@@ -53,44 +47,50 @@ public abstract class AbstractArrayContext extends Context implements Cloneable
protected AbstractArrayContext(RankingExpression expression, boolean ignoreUnknownValues) {
this.ignoreUnknownValues = ignoreUnknownValues;
this.rankingExpressionName = expression.getName();
- Set<String> variables = new LinkedHashSet<>();
- extractVariables(expression.getRoot(),variables);
+ this.indexedBindings = new IndexedBindings(expression);
+ }
- doubleValues = new double[variables.size()];
+ protected final Map<String, Integer> nameToIndex() { return indexedBindings.nameToIndex(); }
+ protected final double[] doubleValues() { return indexedBindings.doubleValues(); }
+ protected final boolean ignoreUnknownValues() { return ignoreUnknownValues; }
- int i = 0;
- ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
- for (String variable : variables)
- nameToIndexBuilder.put(variable,i++);
- nameToIndex = nameToIndexBuilder.build();
+ public Set<String> names() {
+ return indexedBindings.names();
}
- private void extractVariables(ExpressionNode node,Set<String> variables) {
- if (node instanceof ReferenceNode) {
- ReferenceNode fNode=(ReferenceNode)node;
- if (fNode.getArguments().expressions().size()>0)
- throw new UnsupportedOperationException("Array lookup is not supported with features having arguments)");
- variables.add(fNode.toString());
- }
- else if (node instanceof CompositeNode) {
- CompositeNode cNode=(CompositeNode)node;
- for (ExpressionNode child : cNode.children())
- extractVariables(child,variables);
- }
+ /**
+ * Returns the index from a name.
+ *
+ * @throws NullPointerException is this name is not known to this context
+ */
+ @Override
+ public final int getIndex(String name) { return indexedBindings.nameToIndex.get(name); }
+
+ /** Returns the max number of variables which may be set in this */
+ @Override
+ public int size() { return indexedBindings.size(); }
+
+ /** Perform a fast lookup directly of the value as a double. This is faster than get(index).asDouble() */
+ @Override
+ public double getDouble(int index) {
+ return indexedBindings.getDouble(index);
}
- protected final Map<String, Integer> nameToIndex() { return nameToIndex; }
- protected final double[] doubleValues() { return doubleValues; }
- protected final boolean ignoreUnknownValues() { return ignoreUnknownValues; }
+ @Override
+ public String toString() {
+ return "fast lookup context for ranking expression '" + rankingExpressionName +
+ "' [" + size() + " variables]";
+ }
/**
* Creates a clone of this context suitable for evaluating against the same ranking expression
* in a different thread (i.e, name name to index map, different value set.
*/
+ @Override
public AbstractArrayContext clone() {
try {
- AbstractArrayContext clone=(AbstractArrayContext)super.clone();
- clone.doubleValues=new double[nameToIndex.size()];
+ AbstractArrayContext clone = (AbstractArrayContext)super.clone();
+ clone.indexedBindings = indexedBindings.clone();
return clone;
}
catch (CloneNotSupportedException e) {
@@ -98,34 +98,64 @@ public abstract class AbstractArrayContext extends Context implements Cloneable
}
}
- public Set<String> names() {
- return nameToIndex.keySet();
- }
+ private static class IndexedBindings implements Cloneable {
- /**
- * Returns the index from a name.
- *
- * @throws NullPointerException is this name is not known to this context
- */
- public final int getIndex(String name) {
- return nameToIndex.get(name);
- }
+ /** The mapping from variable name to index */
+ private final ImmutableMap<String, Integer> nameToIndex;
- /** Returns the max number of variables which may be set in this */
- public int size() {
- return doubleValues.length;
- }
+ /** The current values set, pre-converted to doubles */
+ private double[] doubleValues;
- /** Perform a fast lookup directly of the value as a double. This is faster than get(index).asDouble() */
- @Override
- public double getDouble(int index) {
- return doubleValues[index];
- }
+ public IndexedBindings(RankingExpression expression) {
+ Set<String> bindTargets = new LinkedHashSet<>();
+ extractBindTargets(expression.getRoot(), bindTargets);
+
+ doubleValues = new double[bindTargets.size()];
+
+ int i = 0;
+ ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
+ for (String variable : bindTargets)
+ nameToIndexBuilder.put(variable,i++);
+ nameToIndex = nameToIndexBuilder.build();
+ }
+
+ private void extractBindTargets(ExpressionNode node, Set<String> bindTargets) {
+ if (node instanceof ReferenceNode) {
+ if (((ReferenceNode)node).getArguments().expressions().size() > 0)
+ throw new UnsupportedOperationException("Can not bind " + node +
+ ": Array lookup is not supported with features having arguments)");
+ bindTargets.add(node.toString());
+ }
+ else if (node instanceof CompositeNode) {
+ CompositeNode cNode = (CompositeNode)node;
+ for (ExpressionNode child : cNode.children())
+ extractBindTargets(child, bindTargets);
+ }
+ }
+
+ public Map<String, Integer> nameToIndex() { return nameToIndex; }
+ public double[] doubleValues() { return doubleValues; }
+ public Set<String> names() { return nameToIndex.keySet(); }
+ public int getIndex(String name) { return nameToIndex.get(name); }
+ public int size() { return doubleValues.length; }
+ public double getDouble(int index) { return doubleValues[index]; }
+
+ /**
+ * Creates a clone of this context suitable for evaluating against the same ranking expression
+ * in a different thread (i.e, name name to index map, different value set.
+ */
+ @Override
+ public IndexedBindings clone() {
+ try {
+ IndexedBindings clone = (IndexedBindings)super.clone();
+ clone.doubleValues = new double[nameToIndex.size()];
+ return clone;
+ }
+ catch (CloneNotSupportedException e) {
+ throw new RuntimeException("Programming error");
+ }
+ }
- @Override
- public String toString() {
- return "fast lookup context for ranking expression '" + rankingExpressionName +
- "' [" + doubleValues.length + " variables]";
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
index ee5952d9aea..237c3a1d0b1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
@@ -119,7 +119,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
public ArrayContext clone() {
ArrayContext clone = (ArrayContext)super.clone();
clone.values = new Value[nameToIndex().size()];
- Arrays.fill(values,constantZero);
+ Arrays.fill(values, constantZero);
return clone;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ContextIndex.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ContextIndex.java
new file mode 100644
index 00000000000..4f1465cd1f5
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ContextIndex.java
@@ -0,0 +1,20 @@
+package com.yahoo.searchlib.rankingexpression.evaluation;
+
+/**
+ * Indexed context lookup methods
+ *
+ * @author bratseth
+ */
+public interface ContextIndex {
+
+ /** Returns the number of bound variables in this */
+ int size();
+
+ /**
+ * Returns the index from a name.
+ *
+ * @throws NullPointerException is this name is not known to this context
+ */
+ int getIndex(String name);
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java
index 836aadd9f70..b82173eabd5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java
@@ -44,7 +44,7 @@ public class ExpressionOptimizer {
return null;
}
- public OptimizationReport optimize(RankingExpression expression, AbstractArrayContext arrayContext) {
+ public OptimizationReport optimize(RankingExpression expression, ContextIndex arrayContext) {
OptimizationReport report = new OptimizationReport();
// Note: Order of optimizations matter
gbdtOptimizer.optimize(expression, arrayContext, report);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Optimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Optimizer.java
index fd9c02100f7..044b5b589a5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Optimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Optimizer.java
@@ -18,6 +18,6 @@ public abstract class Optimizer {
/** Returns whether this is enabled */
public boolean isEnabled() { return enabled; }
- public abstract void optimize(RankingExpression expression, AbstractArrayContext context, OptimizationReport report);
+ public abstract void optimize(RankingExpression expression, ContextIndex context, OptimizationReport report);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
index 8999be4745a..bb8b91eecab 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
@@ -2,8 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.AbstractArrayContext;
-import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
@@ -34,7 +33,7 @@ public class GBDTForestOptimizer extends Optimizer {
* @param report the optimization report to which actions of this is logged
*/
@Override
- public void optimize(RankingExpression expression, AbstractArrayContext context, OptimizationReport report) {
+ public void optimize(RankingExpression expression, ContextIndex context, OptimizationReport report) {
if ( ! isEnabled()) return;
this.report = report;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
index 74af3e576c1..787818b0f42 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
@@ -33,7 +33,7 @@ public class GBDTOptimizer extends Optimizer {
* @param report the optimization report to which actions of this is logged
*/
@Override
- public void optimize(RankingExpression expression, AbstractArrayContext context, OptimizationReport report) {
+ public void optimize(RankingExpression expression, ContextIndex context, OptimizationReport report) {
if (!isEnabled()) return;
this.report = report;
@@ -60,7 +60,7 @@ public class GBDTOptimizer extends Optimizer {
*
* @return the optimized expression
*/
- private ExpressionNode optimize(ExpressionNode node, AbstractArrayContext context) {
+ private ExpressionNode optimize(ExpressionNode node, ContextIndex context) {
if (node instanceof ArithmeticNode) {
Iterator<ExpressionNode> childIt = ((ArithmeticNode)node).children().iterator();
ExpressionNode ret = optimize(childIt.next(), context);
@@ -77,7 +77,7 @@ public class GBDTOptimizer extends Optimizer {
return node;
}
- private ExpressionNode createGBDTNode(IfNode cNode, AbstractArrayContext context) {
+ private ExpressionNode createGBDTNode(IfNode cNode,ContextIndex context) {
List<Double> values = new ArrayList<>();
try {
consumeNode(cNode, values, context);
@@ -93,7 +93,7 @@ public class GBDTOptimizer extends Optimizer {
/**
* Recursively consume nodes into the value list Returns the number of values produced by this.
*/
- private int consumeNode(ExpressionNode node, List<Double> values, AbstractArrayContext context) {
+ private int consumeNode(ExpressionNode node, List<Double> values, ContextIndex context) {
int beforeIndex = values.size();
if ( node instanceof IfNode) {
IfNode ifNode = (IfNode)node;
@@ -113,7 +113,7 @@ public class GBDTOptimizer extends Optimizer {
}
/** Consumes the if condition and return the size of the values resulting, for convenience */
- private int consumeIfCondition(ExpressionNode condition, List<Double> values, AbstractArrayContext context) {
+ private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) {
if (condition instanceof ComparisonNode) {
ComparisonNode comparison = (ComparisonNode)condition;
if (comparison.getOperator() == TruthOperator.SMALLER)
@@ -138,7 +138,7 @@ public class GBDTOptimizer extends Optimizer {
return values.size();
}
- private double getVariableIndex(ExpressionNode node, AbstractArrayContext context) {
+ private double getVariableIndex(ExpressionNode node, ContextIndex context) {
if (!(node instanceof ReferenceNode)) {
throw new IllegalArgumentException("Contained a left-hand comparison expression " +
"which was not a feature value but was: " + node);