aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/processing/OnnxModelConfigGenerator.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-05-19 12:03:06 +0200
committerJon Bratseth <bratseth@gmail.com>2022-05-19 12:03:06 +0200
commit5c24dc5c9642a8d9ed70aee4c950fd0678a1ebec (patch)
treebd9b74bf00c832456f0b83c1b2cd7010be387d68 /config-model/src/main/java/com/yahoo/schema/processing/OnnxModelConfigGenerator.java
parentf17c4fe7de4c55f5c4ee61897eab8c2f588d8405 (diff)
Rename the 'searchdefinition' package to 'schema'
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/processing/OnnxModelConfigGenerator.java')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/processing/OnnxModelConfigGenerator.java96
1 files changed, 96 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/schema/processing/OnnxModelConfigGenerator.java
new file mode 100644
index 00000000000..ce56a4320d3
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/schema/processing/OnnxModelConfigGenerator.java
@@ -0,0 +1,96 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.schema.processing;
+
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.schema.OnnxModel;
+import com.yahoo.schema.RankProfile;
+import com.yahoo.schema.RankProfileRegistry;
+import com.yahoo.schema.Schema;
+import com.yahoo.schema.expressiontransforms.OnnxModelTransformer;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.vespa.model.container.search.QueryProfiles;
+import com.yahoo.vespa.model.ml.OnnxModelInfo;
+
+import java.util.Map;
+
+/**
+ * Processes ONNX ranking features of the form:
+ *
+ * onnx("files/model.onnx", "path/to/output:1")
+ *
+ * And generates an "onnx-model" configuration as if it was defined in the profile:
+ *
+ * onnx-model files_model_onnx {
+ * file: "files/model.onnx"
+ * }
+ *
+ * Inputs and outputs are resolved in OnnxModelTypeResolver, which must be
+ * processed after this.
+ *
+ * @author lesters
+ */
+public class OnnxModelConfigGenerator extends Processor {
+
+ public OnnxModelConfigGenerator(Schema schema, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
+ super(schema, deployLogger, rankProfileRegistry, queryProfiles);
+ }
+
+ @Override
+ public void process(boolean validate, boolean documentsOnly) {
+ if (documentsOnly) return;
+ for (RankProfile profile : rankProfileRegistry.rankProfilesOf(schema)) {
+ if (profile.getFirstPhaseRanking() != null) {
+ process(profile.getFirstPhaseRanking().getRoot(), profile);
+ }
+ if (profile.getSecondPhaseRanking() != null) {
+ process(profile.getSecondPhaseRanking().getRoot(), profile);
+ }
+ for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) {
+ process(function.getValue().function().getBody().getRoot(), profile);
+ }
+ for (ReferenceNode feature : profile.getSummaryFeatures()) {
+ process(feature, profile);
+ }
+ }
+ }
+
+ private void process(ExpressionNode node, RankProfile profile) {
+ if (node instanceof ReferenceNode) {
+ process((ReferenceNode)node, profile);
+ } else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode) node).children()) {
+ process(child, profile);
+ }
+ }
+ }
+
+ private void process(ReferenceNode feature, RankProfile profile) {
+ if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) {
+ if (feature.getArguments().size() > 0) {
+ if (feature.getArguments().expressions().get(0) instanceof ConstantNode) {
+ ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0);
+ String path = OnnxModelTransformer.stripQuotes(node.toString());
+ String modelConfigName = OnnxModelTransformer.asValidIdentifier(path);
+
+ // Only add the configuration if the model can actually be found.
+ if ( ! OnnxModelInfo.modelExists(path, schema.applicationPackage())) {
+ path = ApplicationPackage.MODELS_DIR.append(path).toString();
+ if ( ! OnnxModelInfo.modelExists(path, schema.applicationPackage())) {
+ return;
+ }
+ }
+
+ OnnxModel onnxModel = profile.onnxModels().get(modelConfigName);
+ if (onnxModel == null)
+ profile.add(new OnnxModel(modelConfigName, path));
+ }
+ }
+ }
+ }
+
+}