summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java154
1 files changed, 154 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
new file mode 100644
index 00000000000..bead2e7e7c9
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
@@ -0,0 +1,154 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.component.Version;
+import com.yahoo.config.FileReference;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.config.application.api.FileRegistry;
+import com.yahoo.path.Path;
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
+import com.yahoo.vespa.defaults.Defaults;
+import com.yahoo.vespa.model.container.search.QueryProfiles;
+import onnx.Onnx;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Paths;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Processes every "onnx-model" element in the schema. Parses the model file,
+ * adds missing input and output mappings (assigning default names), and
+ * adds tensor types to all model inputs and outputs.
+ *
+ * Must be processed before RankingExpressingTypeResolver.
+ *
+ * @author lesters
+ */
+public class OnnxModelTypeResolver extends Processor {
+
+ public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
+ super(search, deployLogger, rankProfileRegistry, queryProfiles);
+ }
+
+ @Override
+ public void process(boolean validate, boolean documentsOnly) {
+ if (documentsOnly) return;
+
+ for (Map.Entry<String, OnnxModel> entry : search.onnxModels().asMap().entrySet()) {
+ OnnxModel modelConfig = entry.getValue();
+ try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) {
+ Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
+
+ // Model inputs - if not defined, assumes a function is provided with a valid name
+ for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
+ String onnxInputName = valueInfo.getName();
+ String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName);
+ modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false);
+ modelConfig.addInputType(onnxInputName, valueInfo.getType());
+ }
+
+ // Model outputs
+ for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) {
+ String onnxOutputName = valueInfo.getName();
+ String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName);
+ modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false);
+ modelConfig.addOutputType(onnxOutputName, valueInfo.getType());
+ }
+
+ // Set the first output as default
+ if ( ! model.getGraph().getOutputList().isEmpty()) {
+ modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName());
+ }
+
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Unable to parse ONNX model", e);
+ }
+ }
+ }
+
+ static boolean modelFileExists(String path, ApplicationPackage app) {
+ Path pathInApplicationPackage = Path.fromString(path);
+ if (getFile(pathInApplicationPackage, app).exists()) {
+ return true;
+ }
+ if (getFileReference(pathInApplicationPackage, app).isPresent()) {
+ return true;
+ }
+ return false;
+ }
+
+ private InputStream openModelFile(Path path) throws FileNotFoundException {
+ ApplicationFile file;
+ Optional<FileReference> reference;
+ Path modelsPath = ApplicationPackage.MODELS_DIR.append(path);
+
+ if ((file = getFile(path)).exists()) {
+ return file.createInputStream();
+ }
+ if ((file = getFile(modelsPath)).exists()) {
+ return file.createInputStream();
+ }
+ if ((reference = getFileReference(path)).isPresent()) {
+ return openFromFileRepository(path, reference.get());
+ }
+ if ((reference = getFileReference(modelsPath)).isPresent()) {
+ return openFromFileRepository(modelsPath, reference.get());
+ }
+
+ throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " +
+ "in application package or file repository.");
+ }
+
+ private ApplicationFile getFile(Path path) {
+ return getFile(path, search.applicationPackage());
+ }
+
+ private static ApplicationFile getFile(Path path, ApplicationPackage app) {
+ return app.getFile(path);
+ }
+
+ private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException {
+ return new FileInputStream(new File(getFileRepositoryPath(path, reference.value())));
+ }
+
+ public static String getFileRepositoryPath(Path path, String fileReference) {
+ ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
+ String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
+ return Paths.get(fileRefDir, fileReference, path.getName()).toString();
+ }
+
+ private Optional<FileReference> getFileReference(Path path) {
+ return getFileReference(path, search.applicationPackage());
+ }
+
+ private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) {
+ Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app);
+ if (fileRegistry.isPresent()) {
+ for (FileRegistry.Entry file : fileRegistry.get().export()) {
+ if (file.relativePath.equals(path.toString())) {
+ return Optional.of(file.reference);
+ }
+ }
+ }
+ return Optional.empty();
+ }
+
+ private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) {
+ if (app == null) return Optional.empty();
+ Optional<Version> latest = app.getFileRegistries().keySet().stream().max(Version::compareTo);
+ return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get()));
+ }
+
+}