blob: 67c016074de3126c42ee784ae33dbc6ed3789db8 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
|
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
/**
 * Makes it possible to evaluate an ONNX model anywhere in the ranking expression tree.
 *
 * <p>Each input declared by the model becomes one child of this node: a {@link ReferenceNode}
 * built from the input's source string. Evaluating this node evaluates every child, passes the
 * resulting tensors to the model keyed by ONNX input name, and returns the single output named
 * {@code onnxOutputName}.
 */
class OnnxExpressionNode extends CompositeNode {

    private final OnnxModel model;
    /** Name of the ONNX model output returned by {@link #evaluate}. */
    private final String onnxOutputName;
    /** Type reported by {@link #type}; this class does not check it against the evaluated result. */
    private final TensorType expectedType;
    /** Output name appended to the serialized form of this node; may be null or empty. */
    private final String outputAs;
    /** ONNX input names, index-aligned with {@link #inputRefs}. */
    private final List<String> modelInputs = new ArrayList<>();
    /** Child expressions producing the value for the ONNX input at the same index in {@link #modelInputs}. */
    private final List<ExpressionNode> inputRefs = new ArrayList<>();

    /**
     * Creates a node that evaluates output {@code onnxOutputName} of {@code model}.
     *
     * @param model the ONNX model whose input specs determine this node's children
     * @param onnxOutputName the model output to return from {@link #evaluate}
     * @param expectedType the type reported by {@link #type}
     * @param outputAs output name used in the serialized form; may be null or empty
     * @throws IllegalArgumentException if the source of any model input cannot be parsed as a reference
     */
    OnnxExpressionNode(OnnxModel model, String onnxOutputName, TensorType expectedType, String outputAs) {
        this.model = model;
        this.onnxOutputName = onnxOutputName;
        this.expectedType = expectedType;
        this.outputAs = outputAs;
        for (var input : model.inputSpecs) {
            // Report the offending source string itself; the spec object may not have a useful toString.
            var ref = parseOnnxInput(input.source).orElseThrow(() -> new IllegalArgumentException(
                    "Bad input source for ONNX model " + model.name() + ": '" + input.source +
                    "' (for input '" + input.onnxName + "')"));
            modelInputs.add(input.onnxName);
            inputRefs.add(new ReferenceNode(ref));
        }
    }

    /**
     * Parses an ONNX input source string into a reference.
     *
     * <p>{@link Reference#simple} is tried first. If it gives no result, {@link Reference#fromIdentifier}
     * is tried as a fallback.
     *
     * @param input the source string of a model input
     * @return the parsed reference, or empty if neither form accepts {@code input}
     */
    static Optional<Reference> parseOnnxInput(String input) {
        var simple = Reference.simple(input);
        if (simple.isPresent()) {
            return simple;
        }
        try {
            return Optional.of(Reference.fromIdentifier(input));
        } catch (Exception ignored) {
            // Deliberately best-effort: an unparseable source yields empty so that the caller
            // can raise an error that names the model and input.
            return Optional.empty();
        }
    }

    /** Returns a snapshot copy of the input expressions, in model-input order. */
    @Override
    public List<ExpressionNode> children() { return List.copyOf(inputRefs); }

    /**
     * Replaces the input expressions of this node <em>in place</em> and returns this same node.
     *
     * @param children the new input expressions, one per model input, in model-input order
     * @throws IllegalArgumentException if the number of children differs from the number of model inputs
     */
    @Override
    public CompositeNode setChildren(List<ExpressionNode> children) {
        if (inputRefs.size() != children.size()) {
            throw new IllegalArgumentException("bad setChildren: expected " + inputRefs.size() +
                                               " children but got " + children.size());
        }
        inputRefs.clear();
        inputRefs.addAll(children);
        return this;
    }

    /**
     * Evaluates every input expression and runs the model on the results.
     *
     * @return the model output named {@code onnxOutputName}, wrapped as a tensor value
     */
    @Override
    public Value evaluate(Context context) {
        Map<String, Tensor> inputs = new HashMap<>();
        // modelInputs and inputRefs are kept the same size by the constructor and setChildren.
        for (int i = 0; i < modelInputs.size(); i++) {
            Value inputValue = inputRefs.get(i).evaluate(context);
            inputs.put(modelInputs.get(i), inputValue.asTensor());
        }
        return new TensorValue(model.unmappedEvaluate(inputs, onnxOutputName));
    }

    /** Returns the expected type given at construction; the context is not consulted. */
    @Override
    public TensorType type(TypeContext<Reference> context) { return expectedType; }

    // NOTE(review): equals is not overridden here. If the inherited equals compares the serialized
    // form (which contains outputAs, not onnxOutputName), confirm that outputAs and onnxOutputName
    // correspond one-to-one per model, so that equal nodes always have equal hash codes.
    @Override
    public int hashCode() { return Objects.hash("OnnxExpressionNode", model.name(), onnxOutputName); }

    /** Serializes as {@code onnx_expression_node(<model>)}, followed by {@code .<outputAs>} when that is non-empty. */
    @Override
    public StringBuilder toString(StringBuilder b, SerializationContext context, Deque<String> path, CompositeNode parent) {
        b.append("onnx_expression_node(").append(model.name()).append(")");
        if (outputAs != null && !outputAs.isEmpty()) {
            b.append(".").append(outputAs);
        }
        return b;
    }

}
|