blob: 41a09d9af8e39755e73da6b6763e0a9e6f6592ca (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
|
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.container;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.schema.derived.FileDistributedOnnxModels;
import com.yahoo.schema.derived.RankProfileList;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import com.yahoo.vespa.model.container.component.Handler;
import com.yahoo.vespa.model.container.component.SystemBindingPattern;
import java.nio.file.Path;
import java.util.Objects;
/**
 * Configuration of components for stateless model evaluation.
 *
 * <p>On construction this registers the models evaluator component and the model evaluation
 * REST handler on an application container cluster. It then acts as producer of the rank
 * profile, ranking constant, ONNX model and ranking expression configs derived from the global
 * rank profiles.
 *
 * @author bratseth
 */
public class ContainerModelEvaluation implements
        RankProfilesConfig.Producer,
        RankingConstantsConfig.Producer,
        OnnxModelsConfig.Producer,
        RankingExpressionsConfig.Producer {

    public static final String LINGUISTICS_BUNDLE_NAME = "linguistics-components";
    public static final String EVALUATION_BUNDLE_NAME = "model-evaluation";
    public static final String INTEGRATION_BUNDLE_NAME = "model-integration";
    // NOTE(review): unlike the other bundle names this one carries a ".jar" suffix; presumably
    // PlatformBundles.absoluteBundlePath treats such a name as a complete file name — confirm.
    public static final String ONNXRUNTIME_BUNDLE_NAME = "container-onnxruntime.jar";
    public static final String ONNX_RUNTIME_CLASS = "ai.vespa.modelintegration.evaluator.OnnxRuntime";

    private static final String EVALUATOR_NAME = ModelsEvaluator.class.getName();
    private static final String REST_HANDLER_NAME = "ai.vespa.models.handler.ModelsEvaluationHandler";
    private static final String REST_BINDING_PATH = "/model-evaluation/v1";

    public static final Path MODEL_EVALUATION_BUNDLE_FILE = PlatformBundles.absoluteBundlePath(EVALUATION_BUNDLE_NAME);
    public static final Path MODEL_INTEGRATION_BUNDLE_FILE = PlatformBundles.absoluteBundlePath(INTEGRATION_BUNDLE_NAME);
    public static final Path ONNXRUNTIME_BUNDLE_FILE = PlatformBundles.absoluteBundlePath(ONNXRUNTIME_BUNDLE_NAME);

    /** Global rank profiles, aka models */
    private final RankProfileList rankProfileList;

    /** Cluster specific ONNX model settings; may be null, in which case rankProfileList supplies the ONNX config. */
    private final FileDistributedOnnxModels onnxModels;

    /**
     * Creates the model evaluation configuration and registers its components on the cluster.
     *
     * @param cluster the container cluster to which the models evaluator component and the
     *                model evaluation REST handler are added
     * @param rankProfileList the global rank profiles (models); must not be null
     * @param onnxModels cluster specific ONNX model settings, or null to take the ONNX model
     *                   config from {@code rankProfileList} instead
     * @throws NullPointerException if {@code rankProfileList} is null
     */
    public ContainerModelEvaluation(ApplicationContainerCluster cluster,
                                    RankProfileList rankProfileList, FileDistributedOnnxModels onnxModels) {
        this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null");
        this.onnxModels = onnxModels;
        cluster.addSimpleComponent(EVALUATOR_NAME, null, EVALUATION_BUNDLE_NAME);
        cluster.addComponent(ContainerModelEvaluation.getHandler());
    }

    /** Delegates rank profile config production to the global rank profile list. */
    @Override
    public void getConfig(RankProfilesConfig.Builder builder) {
        rankProfileList.getConfig(builder);
    }

    /** Adds the ranking constants of the global rank profiles. */
    @Override
    public void getConfig(RankingConstantsConfig.Builder builder) {
        builder.constant(rankProfileList.getConstantsConfig());
    }

    /**
     * Adds the ONNX model config. Cluster specific settings take precedence when present;
     * otherwise the config is taken from the global rank profiles.
     */
    @Override
    public void getConfig(OnnxModelsConfig.Builder builder) {
        if (onnxModels != null) {
            builder.model(onnxModels.getConfig());
        } else {
            builder.model(rankProfileList.getOnnxConfig());
        }
    }

    /** Adds the ranking expressions of the global rank profiles. */
    @Override
    public void getConfig(RankingExpressionsConfig.Builder builder) {
        builder.expression(rankProfileList.getExpressionsConfig());
    }

    /**
     * Creates a new handler for the model evaluation REST API.
     *
     * <p>The handler is loaded from the model evaluation bundle and bound to
     * {@code /model-evaluation/v1} and every path below it.
     *
     * @return a freshly created handler instance
     */
    public static Handler getHandler() {
        Handler handler = new Handler(new ComponentModel(REST_HANDLER_NAME, null, EVALUATION_BUNDLE_NAME));
        handler.addServerBindings(
                SystemBindingPattern.fromHttpPath(REST_BINDING_PATH),
                SystemBindingPattern.fromHttpPath(REST_BINDING_PATH + "/*"));
        return handler;
    }

}
|