author     Lester Solbakken <lesters@oath.com>  2021-05-21 09:33:32 +0200
committer  Lester Solbakken <lesters@oath.com>  2021-05-21 09:33:32 +0200
commit     6448742f804482946a7bf2d17723dca6b4100b73 (patch)
tree       135038f0298f3e519ed8e4327cf1bf1915df4b39
parent     864eb3da782e9795826ec78add953a76eeb2ea17 (diff)
Wire in stateless ONNX runtime evaluation
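
Imported ONNX models are now kept in ONNX form and evaluated through ONNX Runtime in the
stateless container (enabled by the <model-evaluation/> tag), instead of always being
converted to native Vespa ranking expressions; native conversion is still used by
OnnxFeatureConverter and for converted TensorFlow models.

A minimal usage sketch, following the new StatelessOnnxEvaluationTest added below (the model
name "mul" and the input names are taken from that test; the ModelsEvaluator is assumed to be
injected into a container component):

    import ai.vespa.models.evaluation.FunctionEvaluator;
    import ai.vespa.models.evaluation.ModelsEvaluator;
    import com.yahoo.tensor.Tensor;

    // Sketch only: evaluate the "mul" ONNX model from the test application package.
    static Tensor evaluateMul(ModelsEvaluator modelsEvaluator) {
        FunctionEvaluator evaluator = modelsEvaluator.evaluatorOf("mul"); // single-output model
        return evaluator
                .bind("input1", Tensor.from("tensor<float>(d0[1]):[2]"))
                .bind("input2", Tensor.from("tensor<float>(d0[1]):[3]"))
                .evaluate(); // 2 * 3 -> tensor<float>(d0[1]):[6]
    }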
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java  2
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java  8
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java  3
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java  2
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java  50
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java  16
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java  2
-rw-r--r--  config-model/src/test/cfg/application/ml_serving/models/add_mul.onnx  24
-rw-r--r--  config-model/src/test/cfg/application/ml_serving/models/mnist_softmax.onnx  bin 31758 -> 0 bytes
-rw-r--r--  config-model/src/test/cfg/application/ml_serving/models/sqrt.onnx  11
-rw-r--r--  config-model/src/test/cfg/application/ml_serving/models/sqrt.py  23
-rw-r--r--  config-model/src/test/cfg/application/onnx/files/add.onnx  16
-rwxr-xr-x  config-model/src/test/cfg/application/onnx/files/add.py  26
-rw-r--r--  config-model/src/test/cfg/application/onnx/models/mul.onnx  16
-rwxr-xr-x  config-model/src/test/cfg/application/onnx/models/mul.py  26
-rw-r--r--  config-model/src/test/cfg/application/onnx/searchdefinitions/test.sd  27
-rw-r--r--  config-model/src/test/cfg/application/onnx/services.xml  22
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java  24
-rw-r--r--  config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java  82
-rw-r--r--  config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java  108
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java  22
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java  18
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java  7
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java  8
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java  3
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java  24
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java  30
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java  11
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java  3
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java  3
-rw-r--r--  model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java  2
-rw-r--r--  model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java  2
-rw-r--r--  model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java  10
33 files changed, 536 insertions, 95 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index 41cb40da4d6..01d4042573c 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -266,7 +266,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);
- reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
+ reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
if ( ! featureTypes.containsKey(reference)) {
throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index be246a143b2..95b291cf744 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -118,6 +118,9 @@ public class RankProfile implements Cloneable {
private List<ImmutableSDField> allFieldsList;
+ /** Global onnx models not tied to a search definition */
+ private OnnxModels onnxModels = new OnnxModels();
+
/**
* Creates a new rank profile for a particular search definition
*
@@ -139,11 +142,12 @@ public class RankProfile implements Cloneable {
* @param name the name of the new profile
* @param model the model owning this profile
*/
- public RankProfile(String name, VespaModel model, RankProfileRegistry rankProfileRegistry) {
+ public RankProfile(String name, VespaModel model, RankProfileRegistry rankProfileRegistry, OnnxModels onnxModels) {
this.name = Objects.requireNonNull(name, "name cannot be null");
this.search = null;
this.model = Objects.requireNonNull(model, "model cannot be null");
this.rankProfileRegistry = rankProfileRegistry;
+ this.onnxModels = onnxModels;
}
public String getName() { return name; }
@@ -162,7 +166,7 @@ public class RankProfile implements Cloneable {
}
public Map<String, OnnxModel> onnxModels() {
- return search != null ? search.onnxModels().asMap() : Collections.emptyMap();
+ return search != null ? search.onnxModels().asMap() : onnxModels.asMap();
}
private Stream<ImmutableSDField> allFields() {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 5337d58fb82..42fa1df802b 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -75,6 +75,9 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
for (RankProfile rank : rankProfileRegistry.rankProfilesOf(search)) {
if (search != null && "default".equals(rank.getName())) continue;
+ if (search == null) {
+ this.onnxModels.add(rank.onnxModels());
+ }
RawRankProfile rawRank = new RawRankProfile(rank, queryProfiles, importedModels, attributeFields, deployProperties);
rankProfiles.put(rawRank.getName(), rawRank);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 6a497460c5f..3dab19699a3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -44,7 +44,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
FeatureArguments arguments = asFeatureArguments(feature.getArguments());
ConvertedModel convertedModel =
convertedOnnxModels.computeIfAbsent(arguments.path(),
- path -> ConvertedModel.fromSourceOrStore(path, true, context));
+ path -> ConvertedModel.fromSourceOrStore(path, true, context, true));
return convertedModel.expression(arguments, context);
}
catch (IllegalArgumentException | UncheckedIOException e) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index d78a69b4802..5098796e409 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -30,6 +30,9 @@ import com.yahoo.config.model.producer.UserConfigRepo;
import com.yahoo.config.provision.AllocatedHosts;
import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.container.QrConfig;
+import com.yahoo.path.Path;
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.OnnxModels;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.RankingConstants;
@@ -57,6 +60,7 @@ import com.yahoo.vespa.model.filedistribution.FileDistributor;
import com.yahoo.vespa.model.generic.service.ServiceCluster;
import com.yahoo.vespa.model.ml.ConvertedModel;
import com.yahoo.vespa.model.ml.ModelName;
+import com.yahoo.vespa.model.ml.OnnxModelInfo;
import com.yahoo.vespa.model.routing.Routing;
import com.yahoo.vespa.model.search.AbstractSearchCluster;
import com.yahoo.vespa.model.utils.internal.ReflectionUtil;
@@ -277,6 +281,46 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
serviceClusters.addAll(builder.getClusters(deployState, this));
}
+ private OnnxModels onnxModelInfoFromSource(ImportedMlModel model) {
+ OnnxModels onnxModels = new OnnxModels();
+ if (model.modelType().equals(ImportedMlModel.ModelType.ONNX)) {
+ String path = model.source();
+ String applicationPath = this.applicationPackage.getFileReference(Path.fromString("")).toString();
+ if (path.startsWith(applicationPath)) {
+ path = path.substring(applicationPath.length() + 1);
+ }
+ loadModelInfo(onnxModels, model.name(), path);
+ }
+ return onnxModels;
+ }
+
+ private OnnxModels onnxModelInfoFromStore(String modelName) {
+ OnnxModels onnxModels = new OnnxModels();
+ String path = ApplicationPackage.MODELS_DIR.append(modelName + ".onnx").toString();
+ loadModelInfo(onnxModels, modelName, path);
+ return onnxModels;
+ }
+
+ private void loadModelInfo(OnnxModels onnxModels, String name, String path) {
+ boolean modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage);
+ if ( ! modelExists) {
+ path = ApplicationPackage.MODELS_DIR.append(path).toString();
+ modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage);
+ }
+ if (modelExists) {
+ OnnxModel onnxModel = new OnnxModel(name, path);
+ OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), this.applicationPackage);
+ for (String onnxName : onnxModelInfo.getInputs()) {
+ onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
+ }
+ for (String onnxName : onnxModelInfo.getOutputs()) {
+ onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
+ }
+ onnxModel.setModelInfo(onnxModelInfo);
+ onnxModels.add(onnxModel);
+ }
+ }
+
/**
* Creates a rank profile not attached to any search definition, for each imported model in the application package,
* and adds it to the given rank profile registry.
@@ -286,7 +330,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
QueryProfiles queryProfiles) {
if ( ! importedModels.all().isEmpty()) { // models/ directory is available
for (ImportedMlModel model : importedModels.all()) {
- RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry);
+ OnnxModels onnxModels = onnxModelInfoFromSource(model);
+ RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry, onnxModels);
rankProfileRegistry.add(profile);
ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()),
model.name(), profile, queryProfiles.getRegistry(), model);
@@ -298,7 +343,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
for (ApplicationFile generatedModelDir : generatedModelsDir.listFiles()) {
String modelName = generatedModelDir.getPath().last();
if (modelName.contains(".")) continue; // Name space: Not a global profile
- RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry);
+ OnnxModels onnxModels = onnxModelInfoFromStore(modelName);
+ RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry, onnxModels);
rankProfileRegistry.add(profile);
ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile);
convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false));
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 9086ca9f40e..da26ea9daf2 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -73,14 +73,23 @@ public class ConvertedModel {
this.sourceModel = sourceModel;
}
+ public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) {
+ return fromSourceOrStore(modelPath, pathIsFile, context, false);
+ }
+
/**
* Create and store a converted model for a rank profile given from either an imported model,
* or (if unavailable) from stored application package data.
*
* @param modelPath the path to the model
* @param pathIsFile true if that path (this kind of model) is stored in a file, false if it is in a directory
+ * @param context the transform context
+ * @param convertToNative force conversion to native Vespa expressions (if applicable)
*/
- public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) {
+ public static ConvertedModel fromSourceOrStore(Path modelPath,
+ boolean pathIsFile,
+ RankProfileTransformContext context,
+ boolean convertToNative) {
ImportedMlModel sourceModel = // TODO: Convert to name here, make sure its done just one way
context.importedModels().get(sourceModelFile(context.rankProfile().applicationPackage(), modelPath));
ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath, pathIsFile);
@@ -90,6 +99,9 @@ public class ConvertedModel {
context.importedModels().all().stream().map(ImportedMlModel::source).collect(Collectors.joining(", ")));
if (sourceModel != null) {
+ if (convertToNative && ! sourceModel.isNative()) {
+ sourceModel = sourceModel.asNative();
+ }
return fromSource(modelName,
modelPath.toString(),
context.rankProfile(),
@@ -592,7 +604,7 @@ public class ConvertedModel {
// Write content explicitly as a file on the file system as this is distributed using file distribution
// - but only if this is a global model to avoid writing the same constants for each rank profile
// where they are used
- if (modelFiles.modelName.isGlobal()) {
+ if (modelFiles.modelName.isGlobal() || ! application.getFileReference(constantPath).exists()) {
createIfNeeded(constantsPath);
IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java
index 7e33faadfc0..b6a9c855aeb 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java
@@ -4,7 +4,7 @@ package com.yahoo.vespa.model.ml;
import com.yahoo.path.Path;
/**
- * Models used in a rank profile has the rank profile name as name space while gGlobal model names have no namespace
+ * Models used in a rank profile have the rank profile name as namespace, while global model names have no namespace
*
* @author bratseth
*/
diff --git a/config-model/src/test/cfg/application/ml_serving/models/add_mul.onnx b/config-model/src/test/cfg/application/ml_serving/models/add_mul.onnx
new file mode 100644
index 00000000000..ab054d112e9
--- /dev/null
+++ b/config-model/src/test/cfg/application/ml_serving/models/add_mul.onnx
Binary files differ
diff --git a/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax.onnx b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax.onnx
deleted file mode 100644
index a86019bf53a..00000000000
--- a/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax.onnx
+++ /dev/null
Binary files differ
diff --git a/config-model/src/test/cfg/application/ml_serving/models/sqrt.onnx b/config-model/src/test/cfg/application/ml_serving/models/sqrt.onnx
new file mode 100644
index 00000000000..04a6420002c
--- /dev/null
+++ b/config-model/src/test/cfg/application/ml_serving/models/sqrt.onnx
Binary files differ
diff --git a/config-model/src/test/cfg/application/ml_serving/models/sqrt.py b/config-model/src/test/cfg/application/ml_serving/models/sqrt.py
new file mode 100644
index 00000000000..b7b99b3850c
--- /dev/null
+++ b/config-model/src/test/cfg/application/ml_serving/models/sqrt.py
@@ -0,0 +1,23 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, [1])
+OUTPUT = helper.make_tensor_value_info('out/layer/1:1', TensorProto.FLOAT, [1])
+
+nodes = [
+ helper.make_node(
+ 'Sqrt',
+ ['input'],
+ ['out/layer/1:1'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'sqrt',
+ [INPUT],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='sqrt.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'sqrt.onnx')
diff --git a/config-model/src/test/cfg/application/onnx/files/add.onnx b/config-model/src/test/cfg/application/onnx/files/add.onnx
new file mode 100644
index 00000000000..28318dbba4d
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx/files/add.onnx
Binary files differ
diff --git a/config-model/src/test/cfg/application/onnx/files/add.py b/config-model/src/test/cfg/application/onnx/files/add.py
new file mode 100755
index 00000000000..63b7dc87796
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx/files/add.py
@@ -0,0 +1,26 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1])
+INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
+
+nodes = [
+ helper.make_node(
+ 'Add',
+ ['input1', 'input2'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'add',
+ [
+ INPUT_1,
+ INPUT_2
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='add.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'add.onnx')
diff --git a/config-model/src/test/cfg/application/onnx/models/mul.onnx b/config-model/src/test/cfg/application/onnx/models/mul.onnx
new file mode 100644
index 00000000000..087e2c3427f
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx/models/mul.onnx
Binary files differ
diff --git a/config-model/src/test/cfg/application/onnx/models/mul.py b/config-model/src/test/cfg/application/onnx/models/mul.py
new file mode 100755
index 00000000000..db01561c355
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx/models/mul.py
@@ -0,0 +1,26 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1])
+INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
+
+nodes = [
+ helper.make_node(
+ 'Mul',
+ ['input1', 'input2'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'mul',
+ [
+ INPUT_1,
+ INPUT_2
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='mul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'mul.onnx')
diff --git a/config-model/src/test/cfg/application/onnx/searchdefinitions/test.sd b/config-model/src/test/cfg/application/onnx/searchdefinitions/test.sd
new file mode 100644
index 00000000000..d49782ddf39
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx/searchdefinitions/test.sd
@@ -0,0 +1,27 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+search test {
+
+ document test {
+ field document_value type tensor<float>(d0[1]) {
+ indexing: attribute
+ }
+ }
+
+ onnx-model my_add {
+ file: files/add.onnx
+ input input1: attribute(document_value)
+ input input2: my_input_func
+ output output: out
+ }
+
+ rank-profile test {
+ function my_function() {
+ expression: tensor<float>(d0[1])(1)
+ }
+ first-phase {
+ expression: onnx(my_add).out{d0:1}
+ }
+ }
+
+}
diff --git a/config-model/src/test/cfg/application/onnx/services.xml b/config-model/src/test/cfg/application/onnx/services.xml
new file mode 100644
index 00000000000..8731558c6f7
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx/services.xml
@@ -0,0 +1,22 @@
+<?xml version="1.0" encoding="utf-8" ?>
+<!-- Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<services version="1.0">
+
+ <container version="1.0">
+ <model-evaluation/>
+ <nodes>
+ <node hostalias="node1" />
+ </nodes>
+ </container>
+
+ <content id="test" version="1.0">
+ <redundancy>1</redundancy>
+ <documents>
+ <document mode="index" type="test"/>
+ </documents>
+ <nodes>
+ <node distribution-key="0" hostalias="node1" />
+ </nodes>
+ </content>
+
+</services>
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 40bf970a313..a64b36b327d 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -44,30 +44,6 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testGlobalOnnxModel() throws IOException {
- ImportedModelTester tester = new ImportedModelTester(name, applicationDir);
- VespaModel model = tester.createVespaModel();
- tester.assertLargeConstant(name + "_layer_Variable_1", model, Optional.of(10L));
- tester.assertLargeConstant(name + "_layer_Variable", model, Optional.of(7840L));
-
- // At this point the expression is stored - copy application to another location which do not have a models dir
- Path storedAppDir = applicationDir.append("copy");
- try {
- storedAppDir.toFile().mkdirs();
- IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString());
- IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
- storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
- ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir);
- VespaModel storedModel = storedTester.createVespaModel();
- tester.assertLargeConstant(name + "_layer_Variable_1", storedModel, Optional.of(10L));
- tester.assertLargeConstant(name + "_layer_Variable", storedModel, Optional.of(7840L));
- }
- finally {
- IOUtils.recursiveDeleteDir(storedAppDir.toFile());
- }
- }
-
- @Test
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx_vespa('mnist_softmax.onnx')",
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index d0196ace766..bf35a002e3a 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -3,16 +3,13 @@ package com.yahoo.vespa.model.ml;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
-import ai.vespa.models.evaluation.RankProfilesConfigImporter;
import ai.vespa.models.handler.ModelsEvaluationHandler;
import com.yahoo.component.ComponentId;
-import com.yahoo.config.FileReference;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
-import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
@@ -21,7 +18,10 @@ import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.junit.Test;
+import java.io.File;
import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@@ -64,7 +64,7 @@ public class ModelEvaluationTest {
Path storedAppDir = appDir.append("copy");
try {
ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir);
- assertHasMlModels(tester.createVespaModel());
+ assertHasMlModels(tester.createVespaModel(), appDir);
// At this point the expression is stored - copy application to another location which do not have a models dir
storedAppDir.toFile().mkdirs();
@@ -72,7 +72,7 @@ public class ModelEvaluationTest {
IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
ImportedModelTester storedTester = new ImportedModelTester("ml_serving", storedAppDir);
- assertHasMlModels(storedTester.createVespaModel());
+ assertHasMlModels(storedTester.createVespaModel(), appDir);
}
finally {
IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
@@ -80,7 +80,7 @@ public class ModelEvaluationTest {
}
}
- private void assertHasMlModels(VespaModel model) {
+ private void assertHasMlModels(VespaModel model, Path appDir) {
ApplicationContainerCluster cluster = model.getContainerClusters().get("container");
assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));
@@ -100,12 +100,13 @@ public class ModelEvaluationTest {
cluster.getConfig(ob);
OnnxModelsConfig onnxModelsConfig = new OnnxModelsConfig(ob);
- assertEquals(4, config.rankprofile().size());
+ assertEquals(5, config.rankprofile().size());
Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet());
assertTrue(modelNames.contains("xgboost_2_2"));
assertTrue(modelNames.contains("lightgbm_regression"));
- assertTrue(modelNames.contains("mnist_softmax"));
+ assertTrue(modelNames.contains("add_mul"));
assertTrue(modelNames.contains("small_constants_and_functions"));
+ assertTrue(modelNames.contains("sqrt"));
// Compare profile content in a denser format than config:
StringBuilder sb = new StringBuilder();
@@ -113,10 +114,14 @@ public class ModelEvaluationTest {
sb.append(p.name()).append(": ").append(p.value()).append("\n");
assertEquals(profile, sb.toString());
- ModelsEvaluator evaluator = new ModelsEvaluator(new ToleratingMissingConstantFilesRankProfilesConfigImporter(MockFileAcquirer.returnFile(null))
- .importFrom(config, constantsConfig, onnxModelsConfig));
+ Map<String, File> fileMap = new HashMap<>();
+ for (OnnxModelsConfig.Model onnxModel : onnxModelsConfig.model()) {
+ fileMap.put(onnxModel.fileref().value(), appDir.append(onnxModel.fileref().value()).toFile());
+ }
+ FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap);
+ ModelsEvaluator evaluator = new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, fileAcquirer);
- assertEquals(4, evaluator.models().size());
+ assertEquals(5, evaluator.models().size());
Model xgboost = evaluator.models().get("xgboost_2_2");
assertNotNull(xgboost);
@@ -128,31 +133,29 @@ public class ModelEvaluationTest {
assertNotNull(lightgbm.evaluatorOf());
assertNotNull(lightgbm.evaluatorOf("lightgbm_regression"));
- Model onnx_mnist_softmax = evaluator.models().get("mnist_softmax");
- assertNotNull(onnx_mnist_softmax);
- assertEquals(1, onnx_mnist_softmax.functions().size());
- assertNotNull(onnx_mnist_softmax.evaluatorOf());
- assertNotNull(onnx_mnist_softmax.evaluatorOf("default"));
- assertNotNull(onnx_mnist_softmax.evaluatorOf("default", "add"));
- assertNotNull(onnx_mnist_softmax.evaluatorOf("default.add"));
- assertNotNull(onnx_mnist_softmax.evaluatorOf("add"));
- assertNotNull(onnx_mnist_softmax.evaluatorOf("serving_default"));
- assertNotNull(onnx_mnist_softmax.evaluatorOf("serving_default", "add"));
- assertNotNull(onnx_mnist_softmax.evaluatorOf("serving_default.add"));
- assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default.add"));
- assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default", "add"));
- assertNotNull(evaluator.evaluatorOf("mnist_softmax", "add"));
- assertNotNull(evaluator.evaluatorOf("mnist_softmax", "serving_default.add"));
- assertNotNull(evaluator.evaluatorOf("mnist_softmax", "serving_default", "add"));
- assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), onnx_mnist_softmax.functions().get(0).argumentTypes().get("Placeholder"));
+ Model add_mul = evaluator.models().get("add_mul");
+ assertNotNull(add_mul);
+ assertEquals(2, add_mul.functions().size());
+ assertNotNull(add_mul.evaluatorOf("output1"));
+ assertNotNull(add_mul.evaluatorOf("output2"));
+ assertNotNull(evaluator.evaluatorOf("add_mul", "output1"));
+ assertNotNull(evaluator.evaluatorOf("add_mul", "output2"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input1"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input2"));
+
+ Model sqrt = evaluator.models().get("sqrt");
+ assertNotNull(sqrt);
+ assertEquals(1, sqrt.functions().size());
+ assertNotNull(sqrt.evaluatorOf());
+ assertNotNull(sqrt.evaluatorOf("out_layer_1_1")); // converted from "out/layer/1:1"
+ assertNotNull(evaluator.evaluatorOf("sqrt"));
+ assertNotNull(evaluator.evaluatorOf("sqrt", "out_layer_1_1"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), sqrt.functions().get(0).argumentTypes().get("input"));
}
private final String profile =
- "rankingExpression(imported_ml_function_small_constants_and_functions_exp_output).rankingScript: map(input, f(a)(exp(a)))\n" +
- "rankingExpression(imported_ml_function_small_constants_and_functions_exp_output).type: tensor<float>(d0[3])\n" +
- "rankingExpression(default.output).rankingScript: join(rankingExpression(imported_ml_function_small_constants_and_functions_exp_output), reduce(join(join(reduce(rankingExpression(imported_ml_function_small_constants_and_functions_exp_output), sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), 9.999999974752427E-7, f(a,b)(a + b)), sum, d0), f(a,b)(a / b))\n" +
- "rankingExpression(default.output).input.type: tensor<float>(d0[3])\n" +
- "rankingExpression(default.output).type: tensor<float>(d0[3])\n";
+ "rankingExpression(output).rankingScript: onnxModel(small_constants_and_functions).output\n" +
+ "rankingExpression(output).type: tensor<float>(d0[3])\n";
private RankProfilesConfig.Rankprofile.Fef findProfile(String name, RankProfilesConfig config) {
for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
@@ -162,17 +165,4 @@ public class ModelEvaluationTest {
throw new IllegalArgumentException("No profile named " + name);
}
- // We don't have function file distribution so just return empty tensor constants
- private static class ToleratingMissingConstantFilesRankProfilesConfigImporter extends RankProfilesConfigImporter {
-
- public ToleratingMissingConstantFilesRankProfilesConfigImporter(FileAcquirer fileAcquirer) {
- super(fileAcquirer);
- }
-
- protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) {
- return Tensor.from(type, "{}");
- }
-
- }
-
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
new file mode 100644
index 00000000000..5dea4a04229
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
@@ -0,0 +1,108 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.ml;
+
+import ai.vespa.models.evaluation.FunctionEvaluator;
+import ai.vespa.models.evaluation.Model;
+import ai.vespa.models.evaluation.ModelsEvaluator;
+import ai.vespa.models.evaluation.RankProfilesConfigImporter;
+import ai.vespa.models.handler.ModelsEvaluationHandler;
+import com.yahoo.component.ComponentId;
+import com.yahoo.config.FileReference;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
+import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.container.ApplicationContainerCluster;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests stateless model evaluation (turned on by the "model-evaluation" tag in "container")
+ * for ONNX models.
+ *
+ * @author lesters
+ */
+public class StatelessOnnxEvaluationTest {
+
+ @Test
+ public void testStatelessOnnxModelEvaluation() throws IOException {
+ Path appDir = Path.fromString("src/test/cfg/application/onnx");
+ Path storedAppDir = appDir.append("copy");
+ try {
+ ImportedModelTester tester = new ImportedModelTester("onnx_rt", appDir);
+ assertModelEvaluation(tester.createVespaModel(), appDir);
+
+ // At this point the expression is stored - copy application to another location which does not have a models dir
+ storedAppDir.toFile().mkdirs();
+ IOUtils.copy(appDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString());
+ IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ IOUtils.copyDirectory(appDir.append(ApplicationPackage.SEARCH_DEFINITIONS_DIR).toFile(),
+ storedAppDir.append(ApplicationPackage.SEARCH_DEFINITIONS_DIR).toFile());
+ ImportedModelTester storedTester = new ImportedModelTester("onnx_rt", storedAppDir);
+ assertModelEvaluation(storedTester.createVespaModel(), appDir);
+
+ } finally {
+ IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ IOUtils.recursiveDeleteDir(storedAppDir.toFile());
+ }
+ }
+
+ private void assertModelEvaluation(VespaModel model, Path appDir) {
+ ApplicationContainerCluster cluster = model.getContainerClusters().get("container");
+ assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));
+
+ RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
+ cluster.getConfig(b);
+ RankProfilesConfig config = new RankProfilesConfig(b);
+
+ RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder();
+ cluster.getConfig(cb);
+ RankingConstantsConfig constantsConfig = new RankingConstantsConfig(cb);
+
+ OnnxModelsConfig.Builder ob = new OnnxModelsConfig.Builder();
+ cluster.getConfig(ob);
+ OnnxModelsConfig onnxModelsConfig = new OnnxModelsConfig(ob);
+
+ assertEquals(1, config.rankprofile().size());
+ Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet());
+ assertTrue(modelNames.contains("mul"));
+
+ // This is actually how ModelsEvaluator is injected
+ Map<String, File> fileMap = new HashMap<>();
+ for (OnnxModelsConfig.Model onnxModel : onnxModelsConfig.model()) {
+ fileMap.put(onnxModel.fileref().value(), appDir.append(onnxModel.fileref().value()).toFile());
+ }
+ FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap);
+ ModelsEvaluator modelsEvaluator = new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, fileAcquirer);
+ assertEquals(1, modelsEvaluator.models().size());
+
+ Model mul = modelsEvaluator.models().get("mul");
+ FunctionEvaluator evaluator = mul.evaluatorOf(); // or "default.output" - or actually use name of model output
+
+ Tensor input1 = Tensor.from("tensor<float>(d0[1]):[2]");
+ Tensor input2 = Tensor.from("tensor<float>(d0[1]):[3]");
+ Tensor output = evaluator.bind("input1", input1).bind("input2", input2).evaluate();
+ assertEquals(6.0, output.sum().asDouble(), 1e-9);
+
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
index c1f973300d6..68bebfa6183 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
@@ -11,6 +11,7 @@ import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.ValueInfo;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
@@ -25,6 +26,7 @@ import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.HashMap;
import java.util.Map;
+import java.util.Set;
import java.util.stream.Collectors;
@@ -38,7 +40,8 @@ class TensorConverter {
{
Map<String, OnnxTensor> result = new HashMap<>();
for (String name : tensorMap.keySet()) {
- Tensor vespaTensor = tensorMap.get(name);
+ Tensor vespaTensor = tensorMap.get(name);
+ name = toOnnxName(name, session.getInputInfo().keySet());
TensorInfo onnxTensorInfo = toTensorInfo(session.getInputInfo().get(name).getInfo());
OnnxTensor onnxTensor = toOnnxTensor(vespaTensor, onnxTensorInfo, env);
result.put(name, onnxTensor);
@@ -143,7 +146,22 @@ class TensorConverter {
}
static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) {
- return infoMap.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> toVespaType(e.getValue().getInfo())));
+ return infoMap.entrySet().stream().collect(Collectors.toMap(e -> asValidName(e.getKey()),
+ e -> toVespaType(e.getValue().getInfo())));
+ }
+
+ static String asValidName(String name) {
+ return OnnxImporter.asValidIdentifier(name);
+ }
+
+ static String toOnnxName(String name, Set<String> onnxNames) {
+ if (onnxNames.contains(name))
+ return name;
+ for (String onnxName : onnxNames) {
+ if (asValidName(onnxName).equals(name))
+ return onnxName;
+ }
+ throw new IllegalArgumentException("ONNX model has no input with name " + name);
}
static TensorType toVespaType(ValueInfo valueInfo) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 47fe66dd424..cf92cbc1e89 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -35,6 +35,7 @@ public class ImportedModel implements ImportedMlModel {
private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
private final String name;
private final String source;
+ private final ModelType modelType;
private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> inputs = new HashMap<>();
@@ -49,11 +50,12 @@ public class ImportedModel implements ImportedMlModel {
* @param name the name of this mode, containing only characters in [A-Za-z0-9_]
* @param source the source path (directory or file) of this model
*/
- public ImportedModel(String name, String source) {
+ public ImportedModel(String name, String source, ModelType modelType) {
if ( ! nameRegexp.matcher(name).matches())
throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'");
this.name = name;
this.source = source;
+ this.modelType = modelType;
}
/** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
@@ -64,6 +66,10 @@ public class ImportedModel implements ImportedMlModel {
@Override
public String source() { return source; }
+ /** Returns the original model type */
+ @Override
+ public ModelType modelType() { return modelType; }
+
@Override
public String toString() { return "imported model '" + name + "' from " + source; }
@@ -212,6 +218,16 @@ public class ImportedModel implements ImportedMlModel {
return values;
}
+ @Override
+ public boolean isNative() {
+ return true;
+ }
+
+ @Override
+ public ImportedModel asNative() {
+ return this;
+ }
+
/**
* A signature is a set of named inputs and outputs, where the inputs maps to input
* ("placeholder") names+types, and outputs maps to expressions nodes.
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 8f73cd02184..3f87bfa0beb 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -52,8 +53,10 @@ public abstract class ModelImporter implements MlModelImporter {
* Takes an IntermediateGraph and converts it to a ImportedModel containing
* the actual Vespa ranking expressions.
*/
- protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
- ImportedModel model = new ImportedModel(graph.name(), modelSource);
+ protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph,
+ String modelSource,
+ ImportedMlModel.ModelType modelType) {
+ ImportedModel model = new ImportedModel(graph.name(), modelSource, modelType);
log.log(Level.FINER, () -> "Intermediate graph created from '" + modelSource + "':\n" +
ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString()));
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java
index e40a06af042..b98dfa33320 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java
@@ -12,12 +12,20 @@ import java.util.Optional;
*/
public interface ImportedMlModel {
+ enum ModelType {
+ VESPA, XGBOOST, LIGHTGBM, TENSORFLOW, ONNX
+ }
+
String name();
String source();
+ ModelType modelType();
+
Optional<String> inputTypeSpec(String input);
Map<String, String> smallConstants();
Map<String, String> largeConstants();
Map<String, String> functions();
List<ImportedMlFunction> outputExpressions();
+ boolean isNative();
+ ImportedMlModel asNative();
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java
index 76caa652ad2..ef731730c84 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.lightgbm;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
@@ -39,7 +40,7 @@ public class LightGBMImporter extends ModelImporter {
@Override
public ImportedModel importModel(String modelName, String modelPath) {
try {
- ImportedModel model = new ImportedModel(modelName, modelPath);
+ ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.LIGHTGBM);
LightGBMParser parser = new LightGBMParser(modelPath);
RankingExpression expression = new RankingExpression(parser.toRankingExpression());
model.expression(modelName, expression);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java
new file mode 100644
index 00000000000..714321ec116
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java
@@ -0,0 +1,24 @@
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import onnx.Onnx;
+
+public class ImportedOnnxModel extends ImportedModel {
+
+ private final Onnx.ModelProto modelProto;
+
+ public ImportedOnnxModel(String name, String source, Onnx.ModelProto modelProto) {
+ super(name, source, ModelType.ONNX);
+ this.modelProto = modelProto;
+ }
+
+ @Override
+ public boolean isNative() {
+ return false;
+ }
+
+ @Override
+ public ImportedModel asNative() {
+ return OnnxImporter.convertModel(name(), source(), modelProto, ModelType.ONNX);
+ }
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
index b1c5dc8a0d8..8e4dd07fc73 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
@@ -5,6 +5,7 @@ package ai.vespa.rankingexpression.importer.onnx;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import onnx.Onnx;
import java.io.File;
@@ -31,11 +32,36 @@ public class OnnxImporter extends ModelImporter {
try (FileInputStream inputStream = new FileInputStream(modelPath)) {
Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
// long version = model.getOpsetImport(0).getVersion(); // opset version
- IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
- return convertIntermediateGraphToModel(graph, modelPath);
+
+ ImportedModel importedModel = new ImportedOnnxModel(modelName, modelPath, model);
+ for (int i = 0; i < model.getGraph().getOutputCount(); ++i) {
+ Onnx.ValueInfoProto output = model.getGraph().getOutput(i);
+ String outputName = asValidIdentifier(output.getName());
+ importedModel.expression(outputName, "onnxModel(" + modelName + ")." + outputName);
+ }
+ return importedModel;
+
} catch (IOException e) {
throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
}
}
+ public ImportedModel importModelAsNative(String modelName, String modelPath, ImportedMlModel.ModelType modelType) {
+ try (FileInputStream inputStream = new FileInputStream(modelPath)) {
+ Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
+ return convertModel(modelName, modelPath, model, modelType);
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
+ }
+ }
+
+ public static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
+ }
+
+ static ImportedModel convertModel(String name, String source, Onnx.ModelProto modelProto, ImportedMlModel.ModelType modelType) {
+ IntermediateGraph graph = GraphImporter.importGraph(name, modelProto);
+ return convertIntermediateGraphToModel(graph, source, modelType);
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
index a879c24b373..402c6562cc4 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
@@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.tensorflow;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.yahoo.collections.Pair;
import com.yahoo.io.IOUtils;
@@ -59,7 +60,7 @@ public class TensorFlowImporter extends ModelImporter {
public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
try {
IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
- return convertIntermediateGraphToModel(graph, modelDir);
+ return convertIntermediateGraphToModel(graph, modelDir, ImportedMlModel.ModelType.TENSORFLOW);
}
catch (IOException e) {
throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
@@ -77,7 +78,13 @@ public class TensorFlowImporter extends ModelImporter {
Pair<Integer, String> res = convertToOnnx(modelDir, convertedPath, opset);
if (res.getFirst() == 0) {
log.info("Conversion to ONNX with opset " + opset + " successful.");
- return onnxImporter.importModel(modelName, convertedPath);
+
+ /*
+ * For now we have to import tensorflow models as native Vespa expressions.
+ * The temporary ONNX file that is created by conversion needs to be put
+ * in the application package so it can be file distributed.
+ */
+ return onnxImporter.importModelAsNative(modelName, convertedPath, ImportedMlModel.ModelType.TENSORFLOW);
}
log.fine("Conversion to ONNX with opset " + opset + " failed. Reason: " + res.getSecond());
outputOfLastConversionAttempt = res.getSecond();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java
index 021fa1f7e51..95ecd678bd5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.vespa;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import ai.vespa.rankingexpression.importer.vespa.parser.ModelParser;
import ai.vespa.rankingexpression.importer.vespa.parser.ParseException;
@@ -28,7 +29,7 @@ public class VespaImporter extends ModelImporter {
@Override
public ImportedModel importModel(String modelName, String modelPath) {
try {
- ImportedModel model = new ImportedModel(modelName, modelPath);
+ ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.VESPA);
new ModelParser(new SimpleCharStream(IOUtils.readFile(new File(modelPath))), model).model();
return model;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
index 686cf6cd2df..5829ea77815 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.xgboost;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import ai.vespa.rankingexpression.importer.ImportedModel;
@@ -50,7 +51,7 @@ public class XGBoostImporter extends ModelImporter {
@Override
public ImportedModel importModel(String modelName, String modelPath) {
try {
- ImportedModel model = new ImportedModel(modelName, modelPath);
+ ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.XGBOOST);
XGBoostParser parser = new XGBoostParser(modelPath);
RankingExpression expression = new RankingExpression(parser.toRankingExpression());
model.expression(modelName, expression);
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 09455abc380..96bf2c64485 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -24,7 +24,7 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx").asNative();
// Check constants
assertEquals(2, model.largeConstants().size());
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
index f03c629df78..cd13afca77b 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
@@ -14,7 +14,7 @@ public class PyTorchImportTestCase extends TestableModel {
@Test
public void testPyTorchExport() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx").asNative();
Tensor onnxResult = evaluateVespa(model, "output", model.inputs());
assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.28258783057229725, -0.0685615853647904]]"), onnxResult);
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
index 04db902073b..4c7f6139032 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
@@ -14,6 +14,8 @@ import com.yahoo.tensor.Tensor;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
/**
* @author lesters
@@ -24,6 +26,10 @@ public class SimpleImportTestCase {
public void testSimpleOnnxModelImport() {
ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx");
+ assertFalse(model.isNative());
+ model = model.asNative();
+ assertTrue(model.isNative());
+
MapContext context = new MapContext();
context.put("query_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[4]):[0.1, 0.2, 0.3, 0.4]")));
context.put("attribute_tensor", new TensorValue(Tensor.from("tensor(d0[4],d1[1]):[0.1, 0.2, 0.3, 0.4]")));
@@ -35,7 +41,7 @@ public class SimpleImportTestCase {
@Test
public void testGather() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx").asNative();
MapContext context = new MapContext();
context.put("data", new TensorValue(Tensor.from("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]")));
@@ -49,7 +55,7 @@ public class SimpleImportTestCase {
@Test
public void testConcat() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx").asNative();
MapContext context = new MapContext();
context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]")));