blob: 4e7988a20061d7d0dfbc238f88492a4415b078cc (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
|
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema.expressiontransforms;
import com.yahoo.schema.FeatureNames;
import com.yahoo.schema.RankProfile;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;

import java.util.HashSet;
import java.util.Set;
/**
 * Analyzes an expression to figure out what inputs it needs.
 *
 * <p>The expression tree is walked without being rewritten. Each input found is added to the
 * target set given to the constructor:
 * <ul>
 *   <li>references to zero-argument functions of the rank profile are expanded, and their bodies analyzed;</li>
 *   <li>other zero-argument references, {@code attribute(...)} and {@code query(...)} are recorded as-is;</li>
 *   <li>{@code constant(...)} references must name a known constant and record nothing;</li>
 *   <li>{@code onnx(model)} records the values of the model's input map.</li>
 * </ul>
 * Anything else is rejected with an {@link IllegalArgumentException}.
 *
 * @author arnej
 */
public class InputRecorder extends ExpressionTransformer<RankProfileTransformContext> {

    /** Receives every input found; this is the caller's set and is modified in place. */
    private final Set<String> neededInputs;

    /**
     * Names of the functions whose bodies are currently being expanded. Used to detect cyclic
     * function references, which would otherwise recurse until the stack overflows.
     */
    private final Set<String> functionsBeingExpanded = new HashSet<>();

    /**
     * @param target the set that found inputs are added to (modified in place)
     */
    public InputRecorder(Set<String> target) {
        this.neededInputs = target;
    }

    /**
     * Records the inputs needed by the given node and, for composite nodes, all its children.
     *
     * @return the node itself, or for composite nodes the result of {@code transformChildren};
     *         this transformer never rewrites nodes on its own
     * @throws IllegalArgumentException if the node type or a referenced feature cannot be handled
     */
    @Override
    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof ReferenceNode r) {
            handle(r, context);
            return node;
        }
        if (node instanceof CompositeNode c) {
            return transformChildren(c, context);
        }
        if (node instanceof ConstantNode) {
            return node;
        }
        throw new IllegalArgumentException("Cannot handle node type: "+ node + " [" + node.getClass() + "]");
    }

    /** Records the inputs needed by a single feature reference. */
    private void handle(ReferenceNode feature, RankProfileTransformContext context) {
        Reference ref = feature.reference();
        String name = ref.name();
        var args = ref.arguments();
        if (args.size() == 0) {
            var f = context.rankProfile().getFunctions().get(name);
            if (f != null && f.function().arguments().size() == 0) {
                // A zero-argument function: its inputs are whatever its body needs.
                expandFunction(name, f.function().getBody().getRoot(), context);
                return;
            }
            neededInputs.add(feature.toString());
            return;
        }
        if (args.size() == 1) {
            if (FeatureNames.isAttributeFeature(ref)) {
                neededInputs.add(feature.toString());
                return;
            }
            if (FeatureNames.isQueryFeature(ref)) {
                // get rid of this later, we should be able
                // to get it from the query
                neededInputs.add(feature.toString());
                return;
            }
            if (FeatureNames.isConstantFeature(ref)) {
                var allConstants = context.rankProfile().constants();
                if (allConstants.containsKey(ref)) {
                    // assumes we have the constant available during evaluation without any more wiring
                    return;
                }
                throw new IllegalArgumentException("unknown constant: " + feature);
            }
        }
        if ("onnx".equals(name)) {
            // Zero-argument "onnx" is handled above, so this only rejects extra arguments.
            if (args.size() != 1) {
                throw new IllegalArgumentException("expected name of ONNX model as argument: " + feature);
            }
            var arg = args.expressions().get(0);
            var models = context.rankProfile().onnxModels();
            var model = models.get(arg.toString());
            if (model == null) {
                throw new IllegalArgumentException("missing onnx model: " + arg);
            }
            for (String onnxInput : model.getInputMap().values()) {
                neededInputs.add(onnxInput);
            }
            return;
        }
        throw new IllegalArgumentException("cannot handle feature: " + feature);
    }

    /**
     * Records the inputs needed by the body of the named zero-argument function.
     *
     * @throws IllegalArgumentException if the function is reached again while its own body is
     *         still being expanded (a cyclic function reference)
     */
    private void expandFunction(String name, ExpressionNode body, RankProfileTransformContext context) {
        if ( ! functionsBeingExpanded.add(name)) {
            throw new IllegalArgumentException("cyclic reference to function: " + name);
        }
        try {
            transform(body, context);
        } finally {
            // Always unwind, so the recorder stays usable after a failure and sibling
            // (non-cyclic) references to the same function are still expanded.
            functionsBeingExpanded.remove(name);
        }
    }

}
|