// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; /** * An evaluator which can be used to evaluate a single function once. * * @author bratseth */ // This wraps all access to the context and the ranking expression to avoid incorrect usage public class FunctionEvaluator { private final ExpressionFunction function; private final LazyArrayContext context; private boolean evaluated = false; FunctionEvaluator(ExpressionFunction function, LazyArrayContext context) { this.function = function; this.context = context; } /** * Binds the given variable referred in this expression to the given value. * * @param name the variable to bind * @param value the value this becomes bound to * @return this for chaining */ public FunctionEvaluator bind(String name, Tensor value) { if (evaluated) throw new IllegalStateException("Cannot bind a new value in a used evaluator"); TensorType requiredType = function.getArgumentType(name); if (requiredType == null) throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function + ". Expected arguments: " + function.argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey()) .map(e -> e.getKey() + ": " + e.getValue()) .collect(Collectors.joining(", "))); if ( ! value.type().isAssignableTo(requiredType)) throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type()); context.put(name, new TensorValue(value)); return this; } /** * Binds the given variable referred in this expression to the given value. * This is equivalent to bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build()) * * @param name the variable to bind * @param value the value this becomes bound to * @return this for chaining */ public FunctionEvaluator bind(String name, double value) { return bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build()); } /** * Binds the given variable referred in this expression to the given value. * String values are not yet supported in tensors. * * @param name the variable to bind * @param value the value this becomes bound to * @return this for chaining */ public FunctionEvaluator bind(String name, String value) { if (evaluated) throw new IllegalStateException("Cannot bind a new value in a used evaluator"); context.put(name, new StringValue(value)); return this; } /** * Sets the default value to use for variables which are not bound * * @param value the default value * @return this for chaining */ public FunctionEvaluator setMissingValue(Tensor value) { if (evaluated) throw new IllegalStateException("Cannot change the missing value in a used evaluator"); context.setMissingValue(value); return this; } /** * Sets the default value to use for variables which are not bound * * @param value the default value * @return this for chaining */ public FunctionEvaluator setMissingValue(double value) { return setMissingValue(Tensor.Builder.of(TensorType.empty).cell(value).build()); } public Tensor evaluate() { function.argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey()) .forEach(argument -> checkArgument(argument.getKey(), argument.getValue())); evaluated = true; evaluateOnnxModels(); return function.getBody().evaluate(context).asTensor(); } private void checkArgument(String name, TensorType type) { if (context.isMissing(name)) throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type); if (! context.get(name).type().isAssignableTo(type)) throw new IllegalStateException("Argument '" + name + "' must be bound to a value of type " + type); } /** * Evaluate ONNX models (if not already evaluated) and add the result back to the context. */ private void evaluateOnnxModels() { for (Map.Entry entry : context().onnxModels().entrySet()) { String onnxFeature = entry.getKey(); String outputName = function.getName(); // Function name is output of model (sometimes) int idx = onnxFeature.indexOf(")."); if (idx > 0 && idx + 2 < onnxFeature.length()) { // explicitly specified as onnx(modelname).outputname ; pick the last part outputName = onnxFeature.substring(idx+2); } OnnxModel onnxModel = entry.getValue(); if (context.get(onnxFeature).equals(context.defaultValue())) { Map inputs = new HashMap<>(); for (Map.Entry input: onnxModel.inputs().entrySet()) { inputs.put(input.getKey(), context.get(input.getKey()).asTensor()); } Tensor result = onnxModel.evaluate(inputs, outputName); context.put(onnxFeature, new TensorValue(result)); } } } /** Returns the function evaluated by this */ public ExpressionFunction function() { return function; } public LazyArrayContext context() { return context; } }