diff options
author | Lester Solbakken <lesters@oath.com> | 2021-05-21 09:33:32 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-05-21 09:33:32 +0200 |
commit | 6448742f804482946a7bf2d17723dca6b4100b73 (patch) | |
tree | 135038f0298f3e519ed8e4327cf1bf1915df4b39 | |
parent | 864eb3da782e9795826ec78add953a76eeb2ea17 (diff) |
Wire in stateless ONNX runtime evaluation
33 files changed, 536 insertions, 95 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 41cb40da4d6..01d4042573c 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -266,7 +266,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement String modelConfigName = OnnxModelTransformer.getModelConfigName(reference); String modelOutput = OnnxModelTransformer.getModelOutput(reference, null); - reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); + reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); if ( ! featureTypes.containsKey(reference)) { throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'"); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index be246a143b2..95b291cf744 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -118,6 +118,9 @@ public class RankProfile implements Cloneable { private List<ImmutableSDField> allFieldsList; + /** Global onnx models not tied to a search definition */ + private OnnxModels onnxModels = new OnnxModels(); + /** * Creates a new rank profile for a particular search definition * @@ -139,11 +142,12 @@ public class RankProfile implements Cloneable { * @param name the name of the new profile * @param model the model owning this profile */ - public RankProfile(String name, VespaModel model, RankProfileRegistry rankProfileRegistry) { + public RankProfile(String name, VespaModel model, RankProfileRegistry rankProfileRegistry, OnnxModels onnxModels) { this.name = Objects.requireNonNull(name, "name cannot be null"); this.search = null; this.model = Objects.requireNonNull(model, "model cannot be null"); this.rankProfileRegistry = rankProfileRegistry; + this.onnxModels = onnxModels; } public String getName() { return name; } @@ -162,7 +166,7 @@ public class RankProfile implements Cloneable { } public Map<String, OnnxModel> onnxModels() { - return search != null ? search.onnxModels().asMap() : Collections.emptyMap(); + return search != null ? search.onnxModels().asMap() : onnxModels.asMap(); } private Stream<ImmutableSDField> allFields() { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 5337d58fb82..42fa1df802b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -75,6 +75,9 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ for (RankProfile rank : rankProfileRegistry.rankProfilesOf(search)) { if (search != null && "default".equals(rank.getName())) continue; + if (search == null) { + this.onnxModels.add(rank.onnxModels()); + } RawRankProfile rawRank = new RawRankProfile(rank, queryProfiles, importedModels, attributeFields, deployProperties); rankProfiles.put(rawRank.getName(), rawRank); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 6a497460c5f..3dab19699a3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -44,7 +44,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans FeatureArguments arguments = asFeatureArguments(feature.getArguments()); ConvertedModel convertedModel = convertedOnnxModels.computeIfAbsent(arguments.path(), - path -> ConvertedModel.fromSourceOrStore(path, true, context)); + path -> ConvertedModel.fromSourceOrStore(path, true, context, true)); return convertedModel.expression(arguments, context); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index d78a69b4802..5098796e409 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -30,6 +30,9 @@ import com.yahoo.config.model.producer.UserConfigRepo; import com.yahoo.config.provision.AllocatedHosts; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.container.QrConfig; +import com.yahoo.path.Path; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.OnnxModels; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.RankingConstants; @@ -57,6 +60,7 @@ import com.yahoo.vespa.model.filedistribution.FileDistributor; import com.yahoo.vespa.model.generic.service.ServiceCluster; import com.yahoo.vespa.model.ml.ConvertedModel; import com.yahoo.vespa.model.ml.ModelName; +import com.yahoo.vespa.model.ml.OnnxModelInfo; import com.yahoo.vespa.model.routing.Routing; import com.yahoo.vespa.model.search.AbstractSearchCluster; import com.yahoo.vespa.model.utils.internal.ReflectionUtil; @@ -277,6 +281,46 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri serviceClusters.addAll(builder.getClusters(deployState, this)); } + private OnnxModels onnxModelInfoFromSource(ImportedMlModel model) { + OnnxModels onnxModels = new OnnxModels(); + if (model.modelType().equals(ImportedMlModel.ModelType.ONNX)) { + String path = model.source(); + String applicationPath = this.applicationPackage.getFileReference(Path.fromString("")).toString(); + if (path.startsWith(applicationPath)) { + path = path.substring(applicationPath.length() + 1); + } + loadModelInfo(onnxModels, model.name(), path); + } + return onnxModels; + } + + private OnnxModels onnxModelInfoFromStore(String modelName) { + OnnxModels onnxModels = new OnnxModels(); + String path = ApplicationPackage.MODELS_DIR.append(modelName + ".onnx").toString(); + loadModelInfo(onnxModels, modelName, path); + return onnxModels; + } + + private void loadModelInfo(OnnxModels onnModels, String name, String path) { + boolean modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage); + if ( ! modelExists) { + path = ApplicationPackage.MODELS_DIR.append(path).toString(); + modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage); + } + if (modelExists) { + OnnxModel onnxModel = new OnnxModel(name, path); + OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), this.applicationPackage); + for (String onnxName : onnxModelInfo.getInputs()) { + onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); + } + for (String onnxName : onnxModelInfo.getOutputs()) { + onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); + } + onnxModel.setModelInfo(onnxModelInfo); + onnModels.add(onnxModel); + } + } + /** * Creates a rank profile not attached to any search definition, for each imported model in the application package, * and adds it to the given rank profile registry. @@ -286,7 +330,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri QueryProfiles queryProfiles) { if ( ! importedModels.all().isEmpty()) { // models/ directory is available for (ImportedMlModel model : importedModels.all()) { - RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry); + OnnxModels onnxModels = onnxModelInfoFromSource(model); + RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry, onnxModels); rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), model.name(), profile, queryProfiles.getRegistry(), model); @@ -298,7 +343,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri for (ApplicationFile generatedModelDir : generatedModelsDir.listFiles()) { String modelName = generatedModelDir.getPath().last(); if (modelName.contains(".")) continue; // Name space: Not a global profile - RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry); + OnnxModels onnxModels = onnxModelInfoFromStore(modelName); + RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry, onnxModels); rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 9086ca9f40e..da26ea9daf2 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -73,14 +73,23 @@ public class ConvertedModel { this.sourceModel = sourceModel; } + public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) { + return fromSourceOrStore(modelPath, pathIsFile, context, false); + } + /** * Create and store a converted model for a rank profile given from either an imported model, * or (if unavailable) from stored application package data. * * @param modelPath the path to the model * @param pathIsFile true if that path (this kind of model) is stored in a file, false if it is in a directory + * @param context the transform context + * @param convertToNative force conversion to native Vespa expressions (if applicable) */ - public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) { + public static ConvertedModel fromSourceOrStore(Path modelPath, + boolean pathIsFile, + RankProfileTransformContext context, + boolean convertToNative) { ImportedMlModel sourceModel = // TODO: Convert to name here, make sure its done just one way context.importedModels().get(sourceModelFile(context.rankProfile().applicationPackage(), modelPath)); ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath, pathIsFile); @@ -90,6 +99,9 @@ public class ConvertedModel { context.importedModels().all().stream().map(ImportedMlModel::source).collect(Collectors.joining(", "))); if (sourceModel != null) { + if (convertToNative && ! sourceModel.isNative()) { + sourceModel = sourceModel.asNative(); + } return fromSource(modelName, modelPath.toString(), context.rankProfile(), @@ -592,7 +604,7 @@ public class ConvertedModel { // Write content explicitly as a file on the file system as this is distributed using file distribution // - but only if this is a global model to avoid writing the same constants for each rank profile // where they are used - if (modelFiles.modelName.isGlobal()) { + if (modelFiles.modelName.isGlobal() || ! application.getFileReference(constantPath).exists()) { createIfNeeded(constantsPath); IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java index 7e33faadfc0..b6a9c855aeb 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java @@ -4,7 +4,7 @@ package com.yahoo.vespa.model.ml; import com.yahoo.path.Path; /** - * Models used in a rank profile has the rank profile name as name space while gGlobal model names have no namespace + * Models used in a rank profile has the rank profile name as name space while global model names have no namespace * * @author bratseth */ diff --git a/config-model/src/test/cfg/application/ml_serving/models/add_mul.onnx b/config-model/src/test/cfg/application/ml_serving/models/add_mul.onnx new file mode 100644 index 00000000000..ab054d112e9 --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving/models/add_mul.onnx @@ -0,0 +1,24 @@ + +add_mul.py:£ + +input1 +input2output1"Mul + +input1 +input2output2"Addadd_mulZ +input1 + + +Z +input2 + + +b +output1 + + +b +output2 + + +B
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax.onnx b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax.onnx Binary files differdeleted file mode 100644 index a86019bf53a..00000000000 --- a/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax.onnx +++ /dev/null diff --git a/config-model/src/test/cfg/application/ml_serving/models/sqrt.onnx b/config-model/src/test/cfg/application/ml_serving/models/sqrt.onnx new file mode 100644 index 00000000000..04a6420002c --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving/models/sqrt.onnx @@ -0,0 +1,11 @@ +sqrt.py:V + +input
out/layer/1:1"SqrtsqrtZ +input + + +b +
out/layer/1:1 + + +B
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/ml_serving/models/sqrt.py b/config-model/src/test/cfg/application/ml_serving/models/sqrt.py new file mode 100644 index 00000000000..b7b99b3850c --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving/models/sqrt.py @@ -0,0 +1,23 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +from onnx import helper, TensorProto + +INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, [1]) +OUTPUT = helper.make_tensor_value_info('out/layer/1:1', TensorProto.FLOAT, [1]) + +nodes = [ + helper.make_node( + 'Sqrt', + ['input'], + ['out/layer/1:1'], + ), +] +graph_def = helper.make_graph( + nodes, + 'sqrt', + [INPUT], + [OUTPUT], +) +model_def = helper.make_model(graph_def, producer_name='sqrt.py', opset_imports=[onnx.OperatorSetIdProto(version=12)]) +onnx.save(model_def, 'sqrt.onnx') diff --git a/config-model/src/test/cfg/application/onnx/files/add.onnx b/config-model/src/test/cfg/application/onnx/files/add.onnx new file mode 100644 index 00000000000..28318dbba4d --- /dev/null +++ b/config-model/src/test/cfg/application/onnx/files/add.onnx @@ -0,0 +1,16 @@ +add.py:f + +input1 +input2output"AddaddZ +input1 + + +Z +input2 + + +b +output + + +B
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/onnx/files/add.py b/config-model/src/test/cfg/application/onnx/files/add.py new file mode 100755 index 00000000000..63b7dc87796 --- /dev/null +++ b/config-model/src/test/cfg/application/onnx/files/add.py @@ -0,0 +1,26 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +import onnx +from onnx import helper, TensorProto + +INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1]) +INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1]) +OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1]) + +nodes = [ + helper.make_node( + 'Add', + ['input1', 'input2'], + ['output'], + ), +] +graph_def = helper.make_graph( + nodes, + 'add', + [ + INPUT_1, + INPUT_2 + ], + [OUTPUT], +) +model_def = helper.make_model(graph_def, producer_name='add.py', opset_imports=[onnx.OperatorSetIdProto(version=12)]) +onnx.save(model_def, 'add.onnx') diff --git a/config-model/src/test/cfg/application/onnx/models/mul.onnx b/config-model/src/test/cfg/application/onnx/models/mul.onnx new file mode 100644 index 00000000000..087e2c3427f --- /dev/null +++ b/config-model/src/test/cfg/application/onnx/models/mul.onnx @@ -0,0 +1,16 @@ +mul.py:f + +input1 +input2output"MulmulZ +input1 + + +Z +input2 + + +b +output + + +B
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/onnx/models/mul.py b/config-model/src/test/cfg/application/onnx/models/mul.py new file mode 100755 index 00000000000..db01561c355 --- /dev/null +++ b/config-model/src/test/cfg/application/onnx/models/mul.py @@ -0,0 +1,26 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +import onnx +from onnx import helper, TensorProto + +INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1]) +INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1]) +OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1]) + +nodes = [ + helper.make_node( + 'Mul', + ['input1', 'input2'], + ['output'], + ), +] +graph_def = helper.make_graph( + nodes, + 'mul', + [ + INPUT_1, + INPUT_2 + ], + [OUTPUT], +) +model_def = helper.make_model(graph_def, producer_name='mul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)]) +onnx.save(model_def, 'mul.onnx') diff --git a/config-model/src/test/cfg/application/onnx/searchdefinitions/test.sd b/config-model/src/test/cfg/application/onnx/searchdefinitions/test.sd new file mode 100644 index 00000000000..d49782ddf39 --- /dev/null +++ b/config-model/src/test/cfg/application/onnx/searchdefinitions/test.sd @@ -0,0 +1,27 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +search test { + + document test { + field document_value type tensor<float>(d0[1]) { + indexing: attribute + } + } + + onnx-model my_add { + file: files/add.onnx + input input1: attribute(document_value) + input input2: my_input_func + output output: out + } + + rank-profile test { + function my_function() { + expression: tensor<float>(d0[1])(1) + } + first-phase { + expression: onnx(my_add).out{d0:1} + } + } + +} diff --git a/config-model/src/test/cfg/application/onnx/services.xml b/config-model/src/test/cfg/application/onnx/services.xml new file mode 100644 index 00000000000..8731558c6f7 --- /dev/null +++ b/config-model/src/test/cfg/application/onnx/services.xml @@ -0,0 +1,22 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<services version="1.0"> + + <container version="1.0"> + <model-evaluation/> + <nodes> + <node hostalias="node1" /> + </nodes> + </container> + + <content id="test" version="1.0"> + <redundancy>1</redundancy> + <documents> + <document mode="index" type="test"/> + </documents> + <nodes> + <node distribution-key="0" hostalias="node1" /> + </nodes> + </content> + +</services> diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 40bf970a313..a64b36b327d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -44,30 +44,6 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testGlobalOnnxModel() throws IOException { - ImportedModelTester tester = new ImportedModelTester(name, applicationDir); - VespaModel model = tester.createVespaModel(); - tester.assertLargeConstant(name + "_layer_Variable_1", model, Optional.of(10L)); - tester.assertLargeConstant(name + "_layer_Variable", model, Optional.of(7840L)); - - // At this point the expression is stored - copy application to another location which do not have a models dir - Path storedAppDir = applicationDir.append("copy"); - try { - storedAppDir.toFile().mkdirs(); - IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); - IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), - storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); - ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir); - VespaModel storedModel = storedTester.createVespaModel(); - tester.assertLargeConstant(name + "_layer_Variable_1", storedModel, Optional.of(10L)); - tester.assertLargeConstant(name + "_layer_Variable", storedModel, Optional.of(7840L)); - } - finally { - IOUtils.recursiveDeleteDir(storedAppDir.toFile()); - } - } - - @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx_vespa('mnist_softmax.onnx')", diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index d0196ace766..bf35a002e3a 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -3,16 +3,13 @@ package com.yahoo.vespa.model.ml; import ai.vespa.models.evaluation.Model; import ai.vespa.models.evaluation.ModelsEvaluator; -import ai.vespa.models.evaluation.RankProfilesConfigImporter; import ai.vespa.models.handler.ModelsEvaluationHandler; import com.yahoo.component.ComponentId; -import com.yahoo.config.FileReference; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; -import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; @@ -21,7 +18,10 @@ import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.junit.Test; +import java.io.File; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -64,7 +64,7 @@ public class ModelEvaluationTest { Path storedAppDir = appDir.append("copy"); try { ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir); - assertHasMlModels(tester.createVespaModel()); + assertHasMlModels(tester.createVespaModel(), appDir); // At this point the expression is stored - copy application to another location which do not have a models dir storedAppDir.toFile().mkdirs(); @@ -72,7 +72,7 @@ public class ModelEvaluationTest { IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); ImportedModelTester storedTester = new ImportedModelTester("ml_serving", storedAppDir); - assertHasMlModels(storedTester.createVespaModel()); + assertHasMlModels(storedTester.createVespaModel(), appDir); } finally { IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); @@ -80,7 +80,7 @@ public class ModelEvaluationTest { } } - private void assertHasMlModels(VespaModel model) { + private void assertHasMlModels(VespaModel model, Path appDir) { ApplicationContainerCluster cluster = model.getContainerClusters().get("container"); assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName()))); @@ -100,12 +100,13 @@ public class ModelEvaluationTest { cluster.getConfig(ob); OnnxModelsConfig onnxModelsConfig = new OnnxModelsConfig(ob); - assertEquals(4, config.rankprofile().size()); + assertEquals(5, config.rankprofile().size()); Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); assertTrue(modelNames.contains("xgboost_2_2")); assertTrue(modelNames.contains("lightgbm_regression")); - assertTrue(modelNames.contains("mnist_softmax")); + assertTrue(modelNames.contains("add_mul")); assertTrue(modelNames.contains("small_constants_and_functions")); + assertTrue(modelNames.contains("sqrt")); // Compare profile content in a denser format than config: StringBuilder sb = new StringBuilder(); @@ -113,10 +114,14 @@ public class ModelEvaluationTest { sb.append(p.name()).append(": ").append(p.value()).append("\n"); assertEquals(profile, sb.toString()); - ModelsEvaluator evaluator = new ModelsEvaluator(new ToleratingMissingConstantFilesRankProfilesConfigImporter(MockFileAcquirer.returnFile(null)) - .importFrom(config, constantsConfig, onnxModelsConfig)); + Map<String, File> fileMap = new HashMap<>(); + for (OnnxModelsConfig.Model onnxModel : onnxModelsConfig.model()) { + fileMap.put(onnxModel.fileref().value(), appDir.append(onnxModel.fileref().value()).toFile()); + } + FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap); + ModelsEvaluator evaluator = new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, fileAcquirer); - assertEquals(4, evaluator.models().size()); + assertEquals(5, evaluator.models().size()); Model xgboost = evaluator.models().get("xgboost_2_2"); assertNotNull(xgboost); @@ -128,31 +133,29 @@ public class ModelEvaluationTest { assertNotNull(lightgbm.evaluatorOf()); assertNotNull(lightgbm.evaluatorOf("lightgbm_regression")); - Model onnx_mnist_softmax = evaluator.models().get("mnist_softmax"); - assertNotNull(onnx_mnist_softmax); - assertEquals(1, onnx_mnist_softmax.functions().size()); - assertNotNull(onnx_mnist_softmax.evaluatorOf()); - assertNotNull(onnx_mnist_softmax.evaluatorOf("default")); - assertNotNull(onnx_mnist_softmax.evaluatorOf("default", "add")); - assertNotNull(onnx_mnist_softmax.evaluatorOf("default.add")); - assertNotNull(onnx_mnist_softmax.evaluatorOf("add")); - assertNotNull(onnx_mnist_softmax.evaluatorOf("serving_default")); - assertNotNull(onnx_mnist_softmax.evaluatorOf("serving_default", "add")); - assertNotNull(onnx_mnist_softmax.evaluatorOf("serving_default.add")); - assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default.add")); - assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default", "add")); - assertNotNull(evaluator.evaluatorOf("mnist_softmax", "add")); - assertNotNull(evaluator.evaluatorOf("mnist_softmax", "serving_default.add")); - assertNotNull(evaluator.evaluatorOf("mnist_softmax", "serving_default", "add")); - assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), onnx_mnist_softmax.functions().get(0).argumentTypes().get("Placeholder")); + Model add_mul = evaluator.models().get("add_mul"); + assertNotNull(add_mul); + assertEquals(2, add_mul.functions().size()); + assertNotNull(add_mul.evaluatorOf("output1")); + assertNotNull(add_mul.evaluatorOf("output2")); + assertNotNull(evaluator.evaluatorOf("add_mul", "output1")); + assertNotNull(evaluator.evaluatorOf("add_mul", "output2")); + assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input1")); + assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input2")); + + Model sqrt = evaluator.models().get("sqrt"); + assertNotNull(sqrt); + assertEquals(1, sqrt.functions().size()); + assertNotNull(sqrt.evaluatorOf()); + assertNotNull(sqrt.evaluatorOf("out_layer_1_1")); // converted from "out/layer/1:1" + assertNotNull(evaluator.evaluatorOf("sqrt")); + assertNotNull(evaluator.evaluatorOf("sqrt", "out_layer_1_1")); + assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), sqrt.functions().get(0).argumentTypes().get("input")); } private final String profile = - "rankingExpression(imported_ml_function_small_constants_and_functions_exp_output).rankingScript: map(input, f(a)(exp(a)))\n" + - "rankingExpression(imported_ml_function_small_constants_and_functions_exp_output).type: tensor<float>(d0[3])\n" + - "rankingExpression(default.output).rankingScript: join(rankingExpression(imported_ml_function_small_constants_and_functions_exp_output), reduce(join(join(reduce(rankingExpression(imported_ml_function_small_constants_and_functions_exp_output), sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), 9.999999974752427E-7, f(a,b)(a + b)), sum, d0), f(a,b)(a / b))\n" + - "rankingExpression(default.output).input.type: tensor<float>(d0[3])\n" + - "rankingExpression(default.output).type: tensor<float>(d0[3])\n"; + "rankingExpression(output).rankingScript: onnxModel(small_constants_and_functions).output\n" + + "rankingExpression(output).type: tensor<float>(d0[3])\n"; private RankProfilesConfig.Rankprofile.Fef findProfile(String name, RankProfilesConfig config) { for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { @@ -162,17 +165,4 @@ public class ModelEvaluationTest { throw new IllegalArgumentException("No profile named " + name); } - // We don't have function file distribution so just return empty tensor constants - private static class ToleratingMissingConstantFilesRankProfilesConfigImporter extends RankProfilesConfigImporter { - - public ToleratingMissingConstantFilesRankProfilesConfigImporter(FileAcquirer fileAcquirer) { - super(fileAcquirer); - } - - protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) { - return Tensor.from(type, "{}"); - } - - } - } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java new file mode 100644 index 00000000000..5dea4a04229 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java @@ -0,0 +1,108 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import ai.vespa.models.evaluation.FunctionEvaluator; +import ai.vespa.models.evaluation.Model; +import ai.vespa.models.evaluation.ModelsEvaluator; +import ai.vespa.models.evaluation.RankProfilesConfigImporter; +import ai.vespa.models.handler.ModelsEvaluationHandler; +import com.yahoo.component.ComponentId; +import com.yahoo.config.FileReference; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** + * Tests stateless model evaluation (turned on by the "model-evaluation" tag in "container") + * for ONNX models. + * + * @author lesters + */ +public class StatelessOnnxEvaluationTest { + + @Test + public void testStatelessOnnxModelEvaluation() throws IOException { + Path appDir = Path.fromString("src/test/cfg/application/onnx"); + Path storedAppDir = appDir.append("copy"); + try { + ImportedModelTester tester = new ImportedModelTester("onnx_rt", appDir); + assertModelEvaluation(tester.createVespaModel(), appDir); + + // At this point the expression is stored - copy application to another location which does not have a models dir + storedAppDir.toFile().mkdirs(); + IOUtils.copy(appDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); + IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + IOUtils.copyDirectory(appDir.append(ApplicationPackage.SEARCH_DEFINITIONS_DIR).toFile(), + storedAppDir.append(ApplicationPackage.SEARCH_DEFINITIONS_DIR).toFile()); + ImportedModelTester storedTester = new ImportedModelTester("onnx_rt", storedAppDir); + assertModelEvaluation(storedTester.createVespaModel(), appDir); + + } finally { + IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + + private void assertModelEvaluation(VespaModel model, Path appDir) { + ApplicationContainerCluster cluster = model.getContainerClusters().get("container"); + assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName()))); + + RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); + cluster.getConfig(b); + RankProfilesConfig config = new RankProfilesConfig(b); + + RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder(); + cluster.getConfig(cb); + RankingConstantsConfig constantsConfig = new RankingConstantsConfig(cb); + + OnnxModelsConfig.Builder ob = new OnnxModelsConfig.Builder(); + cluster.getConfig(ob); + OnnxModelsConfig onnxModelsConfig = new OnnxModelsConfig(ob); + + assertEquals(1, config.rankprofile().size()); + Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); + assertTrue(modelNames.contains("mul")); + + // This is actually how ModelsEvaluator is injected + Map<String, File> fileMap = new HashMap<>(); + for (OnnxModelsConfig.Model onnxModel : onnxModelsConfig.model()) { + fileMap.put(onnxModel.fileref().value(), appDir.append(onnxModel.fileref().value()).toFile()); + } + FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap); + ModelsEvaluator modelsEvaluator = new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, fileAcquirer); + assertEquals(1, modelsEvaluator.models().size()); + + Model mul = modelsEvaluator.models().get("mul"); + FunctionEvaluator evaluator = mul.evaluatorOf(); // or "default.output" - or actually use name of model output + + Tensor input1 = Tensor.from("tensor<float>(d0[1]):[2]"); + Tensor input2 = Tensor.from("tensor<float>(d0[1]):[3]"); + Tensor output = evaluator.bind("input1", input1).bind("input2", input2).evaluate(); + assertEquals(6.0, output.sum().asDouble(), 1e-9); + + } + +} diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index c1f973300d6..68bebfa6183 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -11,6 +11,7 @@ import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; import ai.onnxruntime.TensorInfo; import ai.onnxruntime.ValueInfo; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -25,6 +26,7 @@ import java.nio.LongBuffer; import java.nio.ShortBuffer; import java.util.HashMap; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; @@ -38,7 +40,8 @@ class TensorConverter { { Map<String, OnnxTensor> result = new HashMap<>(); for (String name : tensorMap.keySet()) { - Tensor vespaTensor = tensorMap.get(name); + Tensor vespaTensor = tensorMap.get(name); + name = toOnnxName(name, session.getInputInfo().keySet()); TensorInfo onnxTensorInfo = toTensorInfo(session.getInputInfo().get(name).getInfo()); OnnxTensor onnxTensor = toOnnxTensor(vespaTensor, onnxTensorInfo, env); result.put(name, onnxTensor); @@ -143,7 +146,22 @@ class TensorConverter { } static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) { - return infoMap.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> toVespaType(e.getValue().getInfo()))); + return infoMap.entrySet().stream().collect(Collectors.toMap(e -> asValidName(e.getKey()), + e -> toVespaType(e.getValue().getInfo()))); + } + + static String asValidName(String name) { + return OnnxImporter.asValidIdentifier(name); + } + + static String toOnnxName(String name, Set<String> onnxNames) { + if (onnxNames.contains(name)) + return name; + for (String onnxName : onnxNames) { + if (asValidName(onnxName).equals(name)) + return onnxName; + } + throw new IllegalArgumentException("ONNX model has no input with name " + name); } static TensorType toVespaType(ValueInfo valueInfo) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 47fe66dd424..cf92cbc1e89 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -35,6 +35,7 @@ public class ImportedModel implements ImportedMlModel { private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); private final String name; private final String source; + private final ModelType modelType; private final Map<String, Signature> signatures = new HashMap<>(); private final Map<String, TensorType> inputs = new HashMap<>(); @@ -49,11 +50,12 @@ public class ImportedModel implements ImportedMlModel { * @param name the name of this mode, containing only characters in [A-Za-z0-9_] * @param source the source path (directory or file) of this model */ - public ImportedModel(String name, String source) { + public ImportedModel(String name, String source, ModelType modelType) { if ( ! nameRegexp.matcher(name).matches()) throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'"); this.name = name; this.source = source; + this.modelType = modelType; } /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ @@ -64,6 +66,10 @@ public class ImportedModel implements ImportedMlModel { @Override public String source() { return source; } + /** Returns the original model type */ + @Override + public ModelType modelType() { return modelType; } + @Override public String toString() { return "imported model '" + name + "' from " + source; } @@ -212,6 +218,16 @@ public class ImportedModel implements ImportedMlModel { return values; } + @Override + public boolean isNative() { + return true; + } + + @Override + public ImportedModel asNative() { + return this; + } + /** * A signature is a set of named inputs and outputs, where the inputs maps to input * ("placeholder") names+types, and outputs maps to expressions nodes. diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 8f73cd02184..3f87bfa0beb 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -52,8 +53,10 @@ public abstract class ModelImporter implements MlModelImporter { * Takes an IntermediateGraph and converts it to a ImportedModel containing * the actual Vespa ranking expressions. */ - protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { - ImportedModel model = new ImportedModel(graph.name(), modelSource); + protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, + String modelSource, + ImportedMlModel.ModelType modelType) { + ImportedModel model = new ImportedModel(graph.name(), modelSource, modelType); log.log(Level.FINER, () -> "Intermediate graph created from '" + modelSource + "':\n" + ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString())); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java index e40a06af042..b98dfa33320 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java @@ -12,12 +12,20 @@ import java.util.Optional; */ public interface ImportedMlModel { + enum ModelType { + VESPA, XGBOOST, LIGHTGBM, TENSORFLOW, ONNX + } + String name(); String source(); + ModelType modelType(); + Optional<String> inputTypeSpec(String input); Map<String, String> smallConstants(); Map<String, String> largeConstants(); Map<String, String> functions(); List<ImportedMlFunction> outputExpressions(); + boolean isNative(); + ImportedMlModel asNative(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java index 76caa652ad2..ef731730c84 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.lightgbm; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import com.fasterxml.jackson.databind.ObjectMapper; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; @@ -39,7 +40,7 @@ public class LightGBMImporter extends ModelImporter { @Override public ImportedModel importModel(String modelName, String modelPath) { try { - ImportedModel model = new ImportedModel(modelName, modelPath); + ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.LIGHTGBM); LightGBMParser parser = new LightGBMParser(modelPath); RankingExpression expression = new RankingExpression(parser.toRankingExpression()); model.expression(modelName, expression); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java new file mode 100644 index 00000000000..714321ec116 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java @@ -0,0 +1,24 @@ +package ai.vespa.rankingexpression.importer.onnx; + +import ai.vespa.rankingexpression.importer.ImportedModel; +import onnx.Onnx; + +public class ImportedOnnxModel extends ImportedModel { + + private final Onnx.ModelProto modelProto; + + public ImportedOnnxModel(String name, String source, Onnx.ModelProto modelProto) { + super(name, source, ModelType.ONNX); + this.modelProto = modelProto; + } + + @Override + public boolean isNative() { + return false; + } + + @Override + public ImportedModel asNative() { + return OnnxImporter.convertModel(name(), source(), modelProto, ModelType.ONNX); + } +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java index b1c5dc8a0d8..8e4dd07fc73 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java @@ -5,6 +5,7 @@ package ai.vespa.rankingexpression.importer.onnx; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import onnx.Onnx; import java.io.File; @@ -31,11 +32,36 @@ public class OnnxImporter extends ModelImporter { try (FileInputStream inputStream = new FileInputStream(modelPath)) { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); // long version = model.getOpsetImport(0).getVersion(); // opset version - IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph, modelPath); + + ImportedModel importedModel = new ImportedOnnxModel(modelName, modelPath, model); + for (int i = 0; i < model.getGraph().getOutputCount(); ++i) { + Onnx.ValueInfoProto output = model.getGraph().getOutput(i); + String outputName = asValidIdentifier(output.getName()); + importedModel.expression(outputName, "onnxModel(" + modelName + ")." + outputName); + } + return importedModel; + } catch (IOException e) { throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); } } + public ImportedModel importModelAsNative(String modelName, String modelPath, ImportedMlModel.ModelType modelType) { + try (FileInputStream inputStream = new FileInputStream(modelPath)) { + Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); + return convertModel(modelName, modelPath, model, modelType); + } catch (IOException e) { + throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); + } + } + + public static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); + } + + static ImportedModel convertModel(String name, String source, Onnx.ModelProto modelProto, ImportedMlModel.ModelType modelType) { + IntermediateGraph graph = GraphImporter.importGraph(name, modelProto); + return convertIntermediateGraphToModel(graph, source, modelType); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java index a879c24b373..402c6562cc4 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java @@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import com.yahoo.collections.Pair; import com.yahoo.io.IOUtils; @@ -59,7 +60,7 @@ public class TensorFlowImporter extends ModelImporter { public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) { try { IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph, modelDir); + return convertIntermediateGraphToModel(graph, modelDir, ImportedMlModel.ModelType.TENSORFLOW); } catch (IOException e) { throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); @@ -77,7 +78,13 @@ public class TensorFlowImporter extends ModelImporter { Pair<Integer, String> res = convertToOnnx(modelDir, convertedPath, opset); if (res.getFirst() == 0) { log.info("Conversion to ONNX with opset " + opset + " successful."); - return onnxImporter.importModel(modelName, convertedPath); + + /* + * For now we have to import tensorflow models as native Vespa expressions. + * The temporary ONNX file that is created by conversion needs to be put + * in the application package so it can be file distributed. + */ + return onnxImporter.importModelAsNative(modelName, convertedPath, ImportedMlModel.ModelType.TENSORFLOW); } log.fine("Conversion to ONNX with opset " + opset + " failed. Reason: " + res.getSecond()); outputOfLastConversionAttempt = res.getSecond(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java index 021fa1f7e51..95ecd678bd5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.vespa; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import ai.vespa.rankingexpression.importer.vespa.parser.ModelParser; import ai.vespa.rankingexpression.importer.vespa.parser.ParseException; @@ -28,7 +29,7 @@ public class VespaImporter extends ModelImporter { @Override public ImportedModel importModel(String modelName, String modelPath) { try { - ImportedModel model = new ImportedModel(modelName, modelPath); + ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.VESPA); new ModelParser(new SimpleCharStream(IOUtils.readFile(new File(modelPath))), model).model(); return model; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java index 686cf6cd2df..5829ea77815 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.xgboost; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.RankingExpression; import ai.vespa.rankingexpression.importer.ImportedModel; @@ -50,7 +51,7 @@ public class XGBoostImporter extends ModelImporter { @Override public ImportedModel importModel(String modelName, String modelPath) { try { - ImportedModel model = new ImportedModel(modelName, modelPath); + ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.XGBOOST); XGBoostParser parser = new XGBoostParser(modelPath); RankingExpression expression = new RankingExpression(parser.toRankingExpression()); model.expression(modelName, expression); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 09455abc380..96bf2c64485 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -24,7 +24,7 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx").asNative(); // Check constants assertEquals(2, model.largeConstants().size()); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java index f03c629df78..cd13afca77b 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java @@ -14,7 +14,7 @@ public class PyTorchImportTestCase extends TestableModel { @Test public void testPyTorchExport() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx").asNative(); Tensor onnxResult = evaluateVespa(model, "output", model.inputs()); assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.28258783057229725, -0.0685615853647904]]"), onnxResult); } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java index 04db902073b..4c7f6139032 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java @@ -14,6 +14,8 @@ import com.yahoo.tensor.Tensor; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; /** * @author lesters @@ -24,6 +26,10 @@ public class SimpleImportTestCase { public void testSimpleOnnxModelImport() { ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx"); + assertFalse(model.isNative()); + model = model.asNative(); + assertTrue(model.isNative()); + MapContext context = new MapContext(); context.put("query_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[4]):[0.1, 0.2, 0.3, 0.4]"))); context.put("attribute_tensor", new TensorValue(Tensor.from("tensor(d0[4],d1[1]):[0.1, 0.2, 0.3, 0.4]"))); @@ -35,7 +41,7 @@ public class SimpleImportTestCase { @Test public void testGather() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx").asNative(); MapContext context = new MapContext(); context.put("data", new TensorValue(Tensor.from("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]"))); @@ -49,7 +55,7 @@ public class SimpleImportTestCase { @Test public void testConcat() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx").asNative(); MapContext context = new MapContext(); context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]"))); |