summaryrefslogtreecommitdiffstats
path: root/model-inference/src/main/java/ai
diff options
context:
space:
mode:
authorArnstein Ressem <aressem@oath.com>2018-08-09 10:27:43 +0200
committerArnstein Ressem <aressem@oath.com>2018-08-09 10:27:43 +0200
commit6d61753ac389a884430be9e2eb9bbbd216ea4db5 (patch)
tree9a1b00c5ac40a0c0ea9d822658e225c90adab84b /model-inference/src/main/java/ai
parentb80a8292c21e0c0dd678024928077e3e268de789 (diff)
parente2887cb7299438c02bc49d888aaaf2e51631ace9 (diff)
Merge branch 'master' into aressem/kill-mbuild
Diffstat (limited to 'model-inference/src/main/java/ai')
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java57
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/FunctionReference.java76
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java213
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/LazyValue.java154
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/Model.java132
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java46
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java86
7 files changed, 764 insertions, 0 deletions
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-inference/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
new file mode 100644
index 00000000000..4acd6e483b4
--- /dev/null
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -0,0 +1,57 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+/**
+ * An evaluator which can be used to evaluate a single function once.
+ *
+ * @author bratseth
+ */
+// This wraps all access to the context and the ranking expression to avoid incorrect usage
+public class FunctionEvaluator {
+
+ private final ExpressionFunction function;
+ private final LazyArrayContext context;
+ private boolean evaluated = false;
+
+ FunctionEvaluator(ExpressionFunction function, LazyArrayContext context) {
+ this.function = function;
+ this.context = context;
+ }
+
+ /**
+ * Binds the given variable referred in this expression to the given value.
+ *
+ * @param name the variable to bind
+ * @param value the value this becomes bound to
+ * @return this for chaining
+ */
+ public FunctionEvaluator bind(String name, Tensor value) {
+ if (evaluated)
+ throw new IllegalStateException("You cannot bind a value in a used evaluator");
+ context.put(name, new TensorValue(value));
+ return this;
+ }
+
+ /**
+ * Binds the given variable referred in this expression to the given value.
+ * This is equivalent to <code>bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build())</code>
+ *
+ * @param name the variable to bind
+ * @param value the value this becomes bound to
+ * @return this for chaining
+ */
+ public FunctionEvaluator bind(String name, double value) {
+ return bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build());
+ }
+
+ public Tensor evaluate() {
+ evaluated = true;
+ return function.getBody().evaluate(context).asTensor();
+ }
+
+}
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-inference/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
new file mode 100644
index 00000000000..3b50cef6e2e
--- /dev/null
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
@@ -0,0 +1,76 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import java.util.Objects;
+import java.util.Optional;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * A reference to a function.
+ * The function may be
+ * - free: Callable from users of models, or
+ * - bound: Representing a specific invocation from another ranking expression.
+ * In bound functions, any arguments are replaced by the values supplied in the function invocation.
+ * Function references has a serial form (textual representation) used in ranking expressions received in ranking
+ * expression configurations.
+ *
+ * This is immutable.
+ *
+ * @author bratseth
+ */
+class FunctionReference {
+
+ private static final Pattern referencePattern =
+ Pattern.compile("rankingExpression\\(([a-zA-Z0-9_]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?");
+
+ /** The name of the function referenced */
+ private final String name;
+
+ /** The id of the specific invocation of the function, or null if it is free */
+ private final String instance;
+
+ private FunctionReference(String name, String instance) {
+ this.name = name;
+ this.instance = instance;
+ }
+
+ /** Returns the name of the function referenced */
+ String functionName() { return name; }
+
+ boolean isFree() {
+ return instance == null;
+ }
+
+ String serialForm() {
+ return "rankingExpression(" + name + (instance != null ? instance : "") + ")";
+ }
+
+ @Override
+ public String toString() { return "reference to function '" + name + "'" +
+ ( instance != null ? " instance '" + instance + "'" : ""); }
+
+ @Override
+ public int hashCode() { return Objects.hash(name, instance); }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == this) return true;
+ if ( ! (o instanceof FunctionReference)) return false;
+ FunctionReference other = (FunctionReference)o;
+ if ( ! Objects.equals(this.name, other.name)) return false;
+ if ( ! Objects.equals(this.instance, other.instance)) return false;
+ return true;
+ }
+
+ /** Returns a function reference from the given serial form, or empty if the string is not a valid reference */
+ static Optional<FunctionReference> fromSerial(String serialForm) {
+ Matcher expressionMatcher = referencePattern.matcher(serialForm);
+ if ( ! expressionMatcher.matches()) return Optional.empty();
+
+ String name = expressionMatcher.group(1);
+ String instance = expressionMatcher.group(2);
+ return Optional.of(new FunctionReference(name, instance));
+ }
+
+}
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
new file mode 100644
index 00000000000..2dcfd204077
--- /dev/null
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -0,0 +1,213 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.TensorType;
+
+import java.util.Arrays;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * An array context supporting functions invocations implemented as lazy values.
+ *
+ * @author bratseth
+ */
+final class LazyArrayContext extends Context implements ContextIndex {
+
+ private final IndexedBindings indexedBindings;
+
+ private LazyArrayContext(IndexedBindings indexedBindings) {
+ this.indexedBindings = indexedBindings.copy(this);
+ }
+
+ /**
+ * Create a fast lookup, lazy context for an expression.
+ *
+ * @param expression the expression to create a context for
+ */
+ LazyArrayContext(RankingExpression expression, Map<FunctionReference, ExpressionFunction> functions, Model model) {
+ this.indexedBindings = new IndexedBindings(expression, functions, this, model);
+ }
+
+ /**
+ * Puts a value by name.
+ * The value will be frozen if it isn't already.
+ *
+ * @throws IllegalArgumentException if the name is not present in the ranking expression this was created with, and
+ * ignoredUnknownValues is false
+ */
+ @Override
+ public void put(String name, Value value) {
+ put(requireIndexOf(name), value);
+ }
+
+ /** Same as put(index,DoubleValue.frozen(value)) */
+ public final void put(int index, double value) {
+ put(index, DoubleValue.frozen(value));
+ }
+
+ /**
+ * Puts a value by index.
+ * The value will be frozen if it isn't already.
+ */
+ public void put(int index, Value value) {
+ indexedBindings.set(index, value.freeze());
+ }
+
+ @Override
+ public TensorType getType(Reference reference) {
+ // TODO: Add type information so we do not need to evaluate to get this
+ return get(requireIndexOf(reference.toString())).type();
+ }
+
+ /** Perform a slow lookup by name */
+ @Override
+ public Value get(String name) {
+ return get(requireIndexOf(name));
+ }
+
+ /** Perform a fast lookup by index */
+ @Override
+ public Value get(int index) {
+ return indexedBindings.get(index);
+ }
+
+ @Override
+ public double getDouble(int index) {
+ double value = get(index).asDouble();
+ if (value == Double.NaN)
+ throw new UnsupportedOperationException("Value at " + index + " has no double representation");
+ return value;
+ }
+
+ @Override
+ public int getIndex(String name) {
+ return requireIndexOf(name);
+ }
+
+ @Override
+ public int size() {
+ return indexedBindings.names().size();
+ }
+
+ @Override
+ public Set<String> names() { return indexedBindings.names(); }
+
+ private Integer requireIndexOf(String name) {
+ Integer index = indexedBindings.indexOf(name);
+ if (index == null)
+ throw new IllegalArgumentException("Value '" + name + "' can not be bound in " + this);
+ return index;
+ }
+
+ /**
+ * Creates a copy of this context suitable for evaluating against the same ranking expression
+ * in a different thread or for re-binding free variables.
+ */
+ LazyArrayContext copy() {
+ return new LazyArrayContext(indexedBindings);
+ }
+
+ private static class IndexedBindings {
+
+ /** The mapping from variable name to index */
+ private final ImmutableMap<String, Integer> nameToIndex;
+
+ /** The current values set, pre-converted to doubles */
+ private final Value[] values;
+
+ private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values) {
+ this.nameToIndex = nameToIndex;
+ this.values = values;
+ }
+
+ /**
+ * Creates indexed bindings for the given expressions.
+ * The given expression and functions may be inspected but cannot be stored.
+ */
+ IndexedBindings(RankingExpression expression,
+ Map<FunctionReference, ExpressionFunction> functions,
+ LazyArrayContext owner,
+ Model model) {
+ Set<String> bindTargets = new LinkedHashSet<>();
+ extractBindTargets(expression.getRoot(), functions, bindTargets);
+
+ values = new Value[bindTargets.size()];
+ Arrays.fill(values, DoubleValue.zero);
+
+ int i = 0;
+ ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
+ for (String variable : bindTargets)
+ nameToIndexBuilder.put(variable,i++);
+ nameToIndex = nameToIndexBuilder.build();
+
+ for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
+ Integer index = nameToIndex.get(function.getKey().serialForm());
+ if (index != null) // Referenced in this, so bind it
+ values[index] = new LazyValue(function.getKey(), owner, model);
+ }
+ }
+
+ private void extractBindTargets(ExpressionNode node,
+ Map<FunctionReference, ExpressionFunction> functions,
+ Set<String> bindTargets) {
+ if (isFunctionReference(node)) {
+ FunctionReference reference = FunctionReference.fromSerial(node.toString()).get();
+ bindTargets.add(reference.serialForm());
+
+ extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets);
+ }
+ else if (isConstant(node)) {
+ // Ignore
+ }
+ else if (node instanceof ReferenceNode) {
+ bindTargets.add(node.toString());
+ }
+ else if (node instanceof CompositeNode) {
+ CompositeNode cNode = (CompositeNode)node;
+ for (ExpressionNode child : cNode.children())
+ extractBindTargets(child, functions, bindTargets);
+ }
+ }
+
+ private boolean isFunctionReference(ExpressionNode node) {
+ if ( ! (node instanceof ReferenceNode)) return false;
+
+ ReferenceNode reference = (ReferenceNode)node;
+ return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1;
+ }
+
+ private boolean isConstant(ExpressionNode node) {
+ if ( ! (node instanceof ReferenceNode)) return false;
+
+ ReferenceNode reference = (ReferenceNode)node;
+ return reference.getName().equals("value") && reference.getArguments().size() == 1;
+ }
+
+ Value get(int index) { return values[index]; }
+ void set(int index, Value value) { values[index] = value; }
+ Set<String> names() { return nameToIndex.keySet(); }
+ Integer indexOf(String name) { return nameToIndex.get(name); }
+
+ IndexedBindings copy(Context context) {
+ Value[] valueCopy = new Value[values.length];
+ for (int i = 0; i < values.length; i++)
+ valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue)values[i]).copyFor(context) : values[i];
+ return new IndexedBindings(nameToIndex, valueCopy);
+ }
+
+ }
+
+}
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/LazyValue.java b/model-inference/src/main/java/ai/vespa/models/evaluation/LazyValue.java
new file mode 100644
index 00000000000..4a1ee22d288
--- /dev/null
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/LazyValue.java
@@ -0,0 +1,154 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.Function;
+import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+/**
+ * A Value which is computed from an expression when first requested.
+ * This is not multithread safe.
+ *
+ * @author bratseth
+ */
+class LazyValue extends Value {
+
+ /** The reference to the function computing the value of this */
+ private final FunctionReference function;
+
+ /** The context used to compute the function of this */
+ private final Context context;
+
+ /** The model this is part of */
+ private final Model model;
+
+ private Value computedValue = null;
+
+ public LazyValue(FunctionReference function, Context context, Model model) {
+ this.function = function;
+ this.context = context;
+ this.model = model;
+ }
+
+ private Value computedValue() {
+ if (computedValue == null)
+ computedValue = model.requireReferencedFunction(function).getBody().evaluate(context);
+ return computedValue;
+ }
+
+ @Override
+ public TensorType type() {
+ return computedValue().type(); // TODO: Keep type information in this/ExpressionFunction to avoid computing here
+ }
+
+ @Override
+ public double asDouble() {
+ return computedValue().asDouble();
+ }
+
+ @Override
+ public Tensor asTensor() {
+ return computedValue().asTensor();
+ }
+
+ @Override
+ public boolean hasDouble() {
+ return type().rank() == 0;
+ }
+
+ @Override
+ public boolean asBoolean() {
+ return computedValue().asBoolean();
+ }
+
+ @Override
+ public Value negate() {
+ return computedValue().negate();
+ }
+
+ @Override
+ public Value add(Value value) {
+ return computedValue().add(value);
+ }
+
+ @Override
+ public Value subtract(Value value) {
+ return computedValue().subtract(value);
+ }
+
+ @Override
+ public Value multiply(Value value) {
+ return computedValue().multiply(value);
+ }
+
+ @Override
+ public Value divide(Value value) {
+ return computedValue().divide(value);
+ }
+
+ @Override
+ public Value modulo(Value value) {
+ return computedValue().modulo(value);
+ }
+
+ @Override
+ public Value and(Value value) {
+ return computedValue().and(value);
+ }
+
+ @Override
+ public Value or(Value value) {
+ return computedValue().or(value);
+ }
+
+ @Override
+ public Value not() {
+ return computedValue().not();
+ }
+
+ @Override
+ public Value power(Value value) {
+ return computedValue().power(value);
+ }
+
+ @Override
+ public Value compare(TruthOperator operator, Value value) {
+ return computedValue().compare(operator, value);
+ }
+
+ @Override
+ public Value function(Function function, Value value) {
+ return computedValue().function(function, value);
+ }
+
+ @Override
+ public Value asMutable() {
+ return computedValue().asMutable();
+ }
+
+ @Override
+ public String toString() {
+ return "value of " + function;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == this) return true;
+ if (!(other instanceof Value)) return false;
+ return computedValue().equals(other);
+ }
+
+ @Override
+ public int hashCode() {
+ return computedValue().hashCode();
+ }
+
+ LazyValue copyFor(Context context) {
+ return new LazyValue(this.function, context, model);
+ }
+
+}
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java b/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java
new file mode 100644
index 00000000000..ca739195867
--- /dev/null
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -0,0 +1,132 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * A named collection of functions
+ *
+ * @author bratseth
+ */
+public class Model {
+
+ private final String name;
+
+ /** Free functions */
+ private final ImmutableList<ExpressionFunction> functions;
+
+ /** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */
+ private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions;
+
+ /** Context prototypes, indexed by function name (as all invocations of the same function share the same context prototype) */
+ private final ImmutableMap<String, LazyArrayContext> contextPrototypes;
+
+ private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer();
+
+ public Model(String name, Collection<ExpressionFunction> functions) {
+ this(name, functions, Collections.emptyMap());
+ }
+
+ Model(String name, Collection<ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions) {
+ // TODO: Optimize functions
+ this.name = name;
+ this.functions = ImmutableList.copyOf(functions);
+
+ ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
+ for (ExpressionFunction function : functions) {
+ try {
+ contextBuilder.put(function.getName(), new LazyArrayContext(function.getBody(), referencedFunctions, this));
+ }
+ catch (RuntimeException e) {
+ throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e);
+ }
+ }
+ this.contextPrototypes = contextBuilder.build();
+
+ ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>();
+ for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) {
+ ExpressionFunction optimizedFunction = optimize(function.getValue(),
+ contextPrototypes.get(function.getKey().functionName()));
+ functionsBuilder.put(function.getKey(), optimizedFunction);
+ }
+ this.referencedFunctions = functionsBuilder.build();
+ }
+
+ /** Returns an optimized version of the given function */
+ private ExpressionFunction optimize(ExpressionFunction function, ContextIndex context) {
+ // Note: Optimization is in-place but we do not depend on that outside this method
+ expressionOptimizer.optimize(function.getBody(), context);
+ return function;
+ }
+
+ public String name() { return name; }
+
+ /** Returns an immutable list of the free functions of this */
+ public List<ExpressionFunction> functions() { return functions; }
+
+ /** Returns the given function, or throws a IllegalArgumentException if it does not exist */
+ ExpressionFunction requireFunction(String name) {
+ ExpressionFunction function = function(name);
+ if (function == null)
+ throw new IllegalArgumentException("No function named '" + name + "' in " + this + ". Available functions: " +
+ functions.stream().map(f -> f.getName()).collect(Collectors.joining(", ")));
+ return function;
+ }
+
+ /** Returns the given function, or throws a IllegalArgumentException if it does not exist */
+ private LazyArrayContext requireContextProprotype(String name) {
+ LazyArrayContext context = contextPrototypes.get(name);
+ if (context == null) // Implies function is not present
+ throw new IllegalArgumentException("No function named '" + name + "' in " + this + ". Available functions: " +
+ functions.stream().map(f -> f.getName()).collect(Collectors.joining(", ")));
+ return context;
+ }
+
+ /** Returns the function withe the given name, or null if none */ // TODO: Parameter overloading?
+ ExpressionFunction function(String name) {
+ for (ExpressionFunction function : functions)
+ if (function.getName().equals(name))
+ return function;
+ return null;
+ }
+
+ /** Returns an immutable map of the referenced function instances of this */
+ Map<FunctionReference, ExpressionFunction> referencedFunctions() { return referencedFunctions; }
+
+ /** Returns the given referred function, or throws a IllegalArgumentException if it does not exist */
+ ExpressionFunction requireReferencedFunction(FunctionReference reference) {
+ ExpressionFunction function = referencedFunctions.get(reference);
+ if (function == null)
+ throw new IllegalArgumentException("No " + reference + " in " + this + ". References: " +
+ referencedFunctions.keySet().stream()
+ .map(FunctionReference::serialForm)
+ .collect(Collectors.joining(", ")));
+ return function;
+ }
+
+ /**
+ * Returns an evaluator which can be used to evaluate the given function in a single thread once.
+
+ * Usage:
+ * <code>Tensor result = model.evaluatorOf("myFunction").bind("foo", value).bind("bar", value).evaluate()</code>
+ *
+ * @throws IllegalArgumentException if the function is not present
+ */
+ public FunctionEvaluator evaluatorOf(String function) { // TODO: Parameter overloading?
+ return new FunctionEvaluator(requireFunction(function), requireContextProprotype(function).copy());
+ }
+
+ @Override
+ public String toString() { return "model '" + name + "'"; }
+
+}
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
new file mode 100644
index 00000000000..b36e06e5505
--- /dev/null
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -0,0 +1,46 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Evaluates machine-learned models added to Vespa applications and available as config form.
+ * Usage:
+ * <code>Tensor result = evaluator.bind("foo", value).bind("bar", value").evaluate()</code>
+ *
+ * @author bratseth
+ */
+public class ModelsEvaluator {
+
+ private final ImmutableMap<String, Model> models;
+
+ public ModelsEvaluator(RankProfilesConfig config) {
+ models = ImmutableMap.copyOf(new RankProfilesConfigImporter().importFrom(config));
+ }
+
+ /** Returns the models of this as an immutable map */
+ public Map<String, Model> models() { return models; }
+
+ /**
+ * Returns a function which can be used to evaluate the given function in the given model
+ *
+ * @throws IllegalArgumentException if the function or model is not present
+ */
+ public FunctionEvaluator evaluatorOf(String modelName, String functionName) {
+ return requireModel(modelName).evaluatorOf(functionName);
+ }
+
+ /** Returns the given model, or throws a IllegalArgumentException if it does not exist */
+ Model requireModel(String name) {
+ Model model = models.get(name);
+ if (model == null)
+ throw new IllegalArgumentException("No model named '" + name + ". Available models: " +
+ models.keySet().stream().collect(Collectors.joining(", ")));
+ return model;
+ }
+
+}
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
new file mode 100644
index 00000000000..bfd6342218a
--- /dev/null
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -0,0 +1,86 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Converts RankProfilesConfig instances to RankingExpressions for evaluation
+ *
+ * @author bratseth
+ */
+class RankProfilesConfigImporter {
+
+ /**
+ * Returns a map of the models contained in this config, indexed on name.
+ * The map is modifiable and owned by the caller.
+ */
+ Map<String, Model> importFrom(RankProfilesConfig config) {
+ try {
+ Map<String, Model> models = new HashMap<>();
+ for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
+ Model model = importProfile(profile);
+ models.put(model.name(), model);
+ }
+ return models;
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException("Could not read rank profiles config - version mismatch?", e);
+ }
+ }
+
+ private Model importProfile(RankProfilesConfig.Rankprofile profile) throws ParseException {
+ List<ExpressionFunction> functions = new ArrayList<>();
+ Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>();
+ ExpressionFunction firstPhase = null;
+ ExpressionFunction secondPhase = null;
+ for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
+ Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
+ if ( reference.isPresent()) {
+ List<String> arguments = new ArrayList<>(); // TODO: Arguments?
+ RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value());
+
+ if (reference.get().isFree()) // make available in model under configured name
+ functions.add(new ExpressionFunction(reference.get().functionName(), arguments, expression)); //
+
+ // Make all functions, bound or not available under the name they are referenced by in expressions
+ referencedFunctions.put(reference.get(), new ExpressionFunction(reference.get().serialForm(), arguments, expression));
+ }
+ else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to macros
+ firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(),
+ new RankingExpression("first-phase", property.value()));
+ }
+ else if (property.name().equals("vespa.rank.secondphase")) { // Include in addition to macros
+ secondPhase = new ExpressionFunction("secondphase", new ArrayList<>(),
+ new RankingExpression("second-phase", property.value()));
+ }
+ }
+ if (functionByName("firstphase", functions) == null && firstPhase != null) // may be already included, depending on body
+ functions.add(firstPhase);
+ if (functionByName("secondphase", functions) == null && secondPhase != null) // may be already included, depending on body
+ functions.add(secondPhase);
+
+ try {
+ return new Model(profile.name(), functions, referencedFunctions);
+ }
+ catch (RuntimeException e) {
+ throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
+ }
+ }
+
+ private ExpressionFunction functionByName(String name, List<ExpressionFunction> functions) {
+ for (ExpressionFunction function : functions)
+ if (function.getName().equals(name))
+ return function;
+ return null;
+ }
+
+}