From 9dd9026441ca39df01dbdb456e3e38e402d28a5e Mon Sep 17 00:00:00 2001
From: Arne Juul
Date: Wed, 8 Mar 2023 13:08:24 +0000
Subject: add ExpressionNode computing an output from an ONNX model

---
 .../models/evaluation/OnnxExpressionNode.java      | 127 +++++++++++++++++++++
 .../java/ai/vespa/models/evaluation/OnnxModel.java |  25 ++++
 2 files changed, 152 insertions(+)
 create mode 100644 model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java

(limited to 'model-evaluation')

diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java
new file mode 100644
index 00000000000..a50d9e36d74
--- /dev/null
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java
@@ -0,0 +1,127 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * An expression node computing one output of an ONNX model, which makes it possible
+ * to evaluate an ONNX model anywhere in the ranking expression tree.
+ * The children of this node are the expressions producing the model inputs,
+ * in the same order as the input specs of the model.
+ */
+class OnnxExpressionNode extends CompositeNode {
+    private final OnnxModel model;
+    private final String onnxOutputName;
+    private final TensorType expectedType;
+    private final String outputAs;
+    // modelInputs.get(i) is the ONNX input name fed by the expression inputRefs.get(i)
+    private final List<String> modelInputs = new ArrayList<>();
+    private final List<ExpressionNode> inputRefs = new ArrayList<>();
+
+    /**
+     * Creates a node computing a single output of the given model.
+     *
+     * @param model the model to evaluate; one child reference is created per input spec
+     * @param onnxOutputName the name of the model output this node evaluates to
+     * @param expectedType the type returned by {@link #type}
+     * @param outputAs suffix appended to the serialized form; may be null or empty
+     * @throws IllegalArgumentException if an input source cannot be parsed as a reference
+     */
+    OnnxExpressionNode(OnnxModel model, String onnxOutputName, TensorType expectedType, String outputAs) {
+        this.model = model;
+        this.onnxOutputName = onnxOutputName;
+        this.expectedType = expectedType;
+        this.outputAs = outputAs;
+        for (var input : model.inputSpecs) {
+            modelInputs.add(input.onnxName);
+            var ref = parseOnnxInput(input.source)
+                    .orElseThrow(() -> new IllegalArgumentException(
+                            "Bad input source for ONNX model " + model.name() + ": '" + input.source + "'"));
+            inputRefs.add(new ReferenceNode(ref));
+        }
+    }
+
+    /**
+     * Parses the source string of a model input into a reference.
+     * {@code Reference.simple} is tried first, then {@code Reference.fromIdentifier}.
+     *
+     * @return the parsed reference, or empty if neither form accepts the input
+     */
+    static Optional<Reference> parseOnnxInput(String input) {
+        var optRef = Reference.simple(input);
+        if (optRef.isPresent()) {
+            return optRef;
+        }
+        try {
+            return Optional.of(Reference.fromIdentifier(input));
+        } catch (Exception ignored) {
+            // deliberately ignored: an unparseable source is reported to the caller as empty
+            return Optional.empty();
+        }
+    }
+
+    /** Returns an immutable snapshot of the input expressions. */
+    @Override
+    public List<ExpressionNode> children() { return List.copyOf(inputRefs); }
+
+    /**
+     * Replaces the input expressions of this node in place and returns this node.
+     *
+     * @throws IllegalArgumentException if the number of children differs from the number of inputs
+     */
+    @Override
+    public CompositeNode setChildren(List<ExpressionNode> children) {
+        if (inputRefs.size() != children.size()) {
+            throw new IllegalArgumentException("bad setChildren");
+        }
+        inputRefs.clear();
+        inputRefs.addAll(children);
+        return this;
+    }
+
+    /** Evaluates every input expression, feeds the results to the model and returns the selected output. */
+    @Override
+    public Value evaluate(Context context) {
+        Map<String, Tensor> inputs = new HashMap<>();
+        for (int i = 0; i < modelInputs.size(); i++) {
+            Value inputValue = inputRefs.get(i).evaluate(context);
+            inputs.put(modelInputs.get(i), inputValue.asTensor());
+        }
+        return new TensorValue(model.evaluate(inputs, onnxOutputName));
+    }
+
+    /** Returns the expected type given at construction; the context is not consulted. */
+    @Override
+    public TensorType type(TypeContext<Reference> context) { return expectedType; }
+
+    @Override
+    public int hashCode() { return Objects.hash("OnnxExpressionNode", model.name(), onnxOutputName); }
+
+    /** Appends "onnx_expression_node(modelName)", followed by ".outputAs" when outputAs is non-empty. */
+    @Override
+    public StringBuilder toString(StringBuilder b, SerializationContext context, Deque<String> path, CompositeNode parent) {
+        b.append("onnx_expression_node(").append(model.name()).append(")");
+        if (outputAs != null && ! outputAs.isEmpty()) {
+            b.append(".").append(outputAs);
+        }
+        return b;
+    }
+}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index cf97c20e881..59febf7cdbf 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -4,6 +4,9 @@ package ai.vespa.models.evaluation;
 import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
 import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
 import ai.vespa.modelintegration.evaluator.OnnxRuntime;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
 import com.yahoo.tensor.Tensor;
 import com.yahoo.tensor.TensorType;
 
@@ -64,6 +67,7 @@ class OnnxModel implements AutoCloseable {
     private final OnnxRuntime onnx;
 
     private OnnxEvaluator evaluator;
+    private final Map<String, ExpressionNode> exprPerOutput = new HashMap<>();
 
     OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxRuntime onnx) {
         this.name = name;
@@ -81,6 +85,7 @@ class OnnxModel implements AutoCloseable {
             evaluator = onnx.evaluatorOf(modelFile.getPath(), options);
             fillInputTypes(evaluator().getInputs());
             fillOutputTypes(evaluator().getOutputs());
+            fillOutputExpressions();
         }
     }
 
@@ -156,6 +161,26 @@ class OnnxModel implements AutoCloseable {
         return map;
     }
 
+    /** Creates one OnnxExpressionNode per output spec and registers it under the spec's outputAs name. */
+    void fillOutputExpressions() {
+        for (var spec : outputSpecs) {
+            var node = new OnnxExpressionNode(this, spec.onnxName, spec.expectedType, spec.outputAs);
+            exprPerOutput.put(spec.outputAs, node);
+        }
+    }
+
+    /**
+     * Returns the expression node registered for the given output name.
+     * When outputName is null and exactly one output is registered, that single node is returned;
+     * otherwise this is a plain map lookup, which yields null when nothing is registered under the name.
+     */
+    ExpressionNode getExpressionForOutput(String outputName) {
+        if (outputName == null && exprPerOutput.size() == 1) {
+            return exprPerOutput.values().iterator().next();
+        }
+        return exprPerOutput.get(outputName);
+    }
+
     public Tensor evaluate(Map<String, Tensor> inputs, String output) {
         var mapped = new HashMap<String, Tensor>();
         for (var spec : inputSpecs) {
-- 
cgit v1.2.3