diff options
author | Lester Solbakken <lesters@oath.com> | 2021-05-21 09:33:32 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-05-21 09:33:32 +0200 |
commit | 6448742f804482946a7bf2d17723dca6b4100b73 (patch) | |
tree | 135038f0298f3e519ed8e4327cf1bf1915df4b39 /model-integration | |
parent | 864eb3da782e9795826ec78add953a76eeb2ea17 (diff) |
Wire in stateless ONNX runtime evaluation
Diffstat (limited to 'model-integration')
13 files changed, 127 insertions, 16 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index c1f973300d6..68bebfa6183 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -11,6 +11,7 @@ import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; import ai.onnxruntime.TensorInfo; import ai.onnxruntime.ValueInfo; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -25,6 +26,7 @@ import java.nio.LongBuffer; import java.nio.ShortBuffer; import java.util.HashMap; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; @@ -38,7 +40,8 @@ class TensorConverter { { Map<String, OnnxTensor> result = new HashMap<>(); for (String name : tensorMap.keySet()) { - Tensor vespaTensor = tensorMap.get(name); + Tensor vespaTensor = tensorMap.get(name); + name = toOnnxName(name, session.getInputInfo().keySet()); TensorInfo onnxTensorInfo = toTensorInfo(session.getInputInfo().get(name).getInfo()); OnnxTensor onnxTensor = toOnnxTensor(vespaTensor, onnxTensorInfo, env); result.put(name, onnxTensor); @@ -143,7 +146,22 @@ class TensorConverter { } static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) { - return infoMap.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> toVespaType(e.getValue().getInfo()))); + return infoMap.entrySet().stream().collect(Collectors.toMap(e -> asValidName(e.getKey()), + e -> toVespaType(e.getValue().getInfo()))); + } + + static String asValidName(String name) { + return OnnxImporter.asValidIdentifier(name); + } + + static String toOnnxName(String name, Set<String> onnxNames) { + if (onnxNames.contains(name)) + return name; + for (String onnxName : onnxNames) { + if (asValidName(onnxName).equals(name)) + return onnxName; + } + throw new IllegalArgumentException("ONNX model has no input with name " + name); } static TensorType toVespaType(ValueInfo valueInfo) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 47fe66dd424..cf92cbc1e89 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -35,6 +35,7 @@ public class ImportedModel implements ImportedMlModel { private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); private final String name; private final String source; + private final ModelType modelType; private final Map<String, Signature> signatures = new HashMap<>(); private final Map<String, TensorType> inputs = new HashMap<>(); @@ -49,11 +50,12 @@ public class ImportedModel implements ImportedMlModel { * @param name the name of this mode, containing only characters in [A-Za-z0-9_] * @param source the source path (directory or file) of this model */ - public ImportedModel(String name, String source) { + public ImportedModel(String name, String source, ModelType modelType) { if ( ! nameRegexp.matcher(name).matches()) throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'"); this.name = name; this.source = source; + this.modelType = modelType; } /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ @@ -64,6 +66,10 @@ public class ImportedModel implements ImportedMlModel { @Override public String source() { return source; } + /** Returns the original model type */ + @Override + public ModelType modelType() { return modelType; } + @Override public String toString() { return "imported model '" + name + "' from " + source; } @@ -212,6 +218,16 @@ public class ImportedModel implements ImportedMlModel { return values; } + @Override + public boolean isNative() { + return true; + } + + @Override + public ImportedModel asNative() { + return this; + } + /** * A signature is a set of named inputs and outputs, where the inputs maps to input * ("placeholder") names+types, and outputs maps to expressions nodes. diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 8f73cd02184..3f87bfa0beb 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -52,8 +53,10 @@ public abstract class ModelImporter implements MlModelImporter { * Takes an IntermediateGraph and converts it to a ImportedModel containing * the actual Vespa ranking expressions. */ - protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { - ImportedModel model = new ImportedModel(graph.name(), modelSource); + protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, + String modelSource, + ImportedMlModel.ModelType modelType) { + ImportedModel model = new ImportedModel(graph.name(), modelSource, modelType); log.log(Level.FINER, () -> "Intermediate graph created from '" + modelSource + "':\n" + ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString())); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java index e40a06af042..b98dfa33320 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java @@ -12,12 +12,20 @@ import java.util.Optional; */ public interface ImportedMlModel { + enum ModelType { + VESPA, XGBOOST, LIGHTGBM, TENSORFLOW, ONNX + } + String name(); String source(); + ModelType modelType(); + Optional<String> inputTypeSpec(String input); Map<String, String> smallConstants(); Map<String, String> largeConstants(); Map<String, String> functions(); List<ImportedMlFunction> outputExpressions(); + boolean isNative(); + ImportedMlModel asNative(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java index 76caa652ad2..ef731730c84 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.lightgbm; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import com.fasterxml.jackson.databind.ObjectMapper; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; @@ -39,7 +40,7 @@ public class LightGBMImporter extends ModelImporter { @Override public ImportedModel importModel(String modelName, String modelPath) { try { - ImportedModel model = new ImportedModel(modelName, modelPath); + ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.LIGHTGBM); LightGBMParser parser = new LightGBMParser(modelPath); RankingExpression expression = new RankingExpression(parser.toRankingExpression()); model.expression(modelName, expression); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java new file mode 100644 index 00000000000..714321ec116 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java @@ -0,0 +1,24 @@ +package ai.vespa.rankingexpression.importer.onnx; + +import ai.vespa.rankingexpression.importer.ImportedModel; +import onnx.Onnx; + +public class ImportedOnnxModel extends ImportedModel { + + private final Onnx.ModelProto modelProto; + + public ImportedOnnxModel(String name, String source, Onnx.ModelProto modelProto) { + super(name, source, ModelType.ONNX); + this.modelProto = modelProto; + } + + @Override + public boolean isNative() { + return false; + } + + @Override + public ImportedModel asNative() { + return OnnxImporter.convertModel(name(), source(), modelProto, ModelType.ONNX); + } +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java index b1c5dc8a0d8..8e4dd07fc73 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java @@ -5,6 +5,7 @@ package ai.vespa.rankingexpression.importer.onnx; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import onnx.Onnx; import java.io.File; @@ -31,11 +32,36 @@ public class OnnxImporter extends ModelImporter { try (FileInputStream inputStream = new FileInputStream(modelPath)) { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); // long version = model.getOpsetImport(0).getVersion(); // opset version - IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph, modelPath); + + ImportedModel importedModel = new ImportedOnnxModel(modelName, modelPath, model); + for (int i = 0; i < model.getGraph().getOutputCount(); ++i) { + Onnx.ValueInfoProto output = model.getGraph().getOutput(i); + String outputName = asValidIdentifier(output.getName()); + importedModel.expression(outputName, "onnxModel(" + modelName + ")." + outputName); + } + return importedModel; + } catch (IOException e) { throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); } } + public ImportedModel importModelAsNative(String modelName, String modelPath, ImportedMlModel.ModelType modelType) { + try (FileInputStream inputStream = new FileInputStream(modelPath)) { + Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); + return convertModel(modelName, modelPath, model, modelType); + } catch (IOException e) { + throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); + } + } + + public static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); + } + + static ImportedModel convertModel(String name, String source, Onnx.ModelProto modelProto, ImportedMlModel.ModelType modelType) { + IntermediateGraph graph = GraphImporter.importGraph(name, modelProto); + return convertIntermediateGraphToModel(graph, source, modelType); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java index a879c24b373..402c6562cc4 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java @@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import com.yahoo.collections.Pair; import com.yahoo.io.IOUtils; @@ -59,7 +60,7 @@ public class TensorFlowImporter extends ModelImporter { public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) { try { IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph, modelDir); + return convertIntermediateGraphToModel(graph, modelDir, ImportedMlModel.ModelType.TENSORFLOW); } catch (IOException e) { throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); @@ -77,7 +78,13 @@ public class TensorFlowImporter extends ModelImporter { Pair<Integer, String> res = convertToOnnx(modelDir, convertedPath, opset); if (res.getFirst() == 0) { log.info("Conversion to ONNX with opset " + opset + " successful."); - return onnxImporter.importModel(modelName, convertedPath); + + /* + * For now we have to import tensorflow models as native Vespa expressions. + * The temporary ONNX file that is created by conversion needs to be put + * in the application package so it can be file distributed. + */ + return onnxImporter.importModelAsNative(modelName, convertedPath, ImportedMlModel.ModelType.TENSORFLOW); } log.fine("Conversion to ONNX with opset " + opset + " failed. Reason: " + res.getSecond()); outputOfLastConversionAttempt = res.getSecond(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java index 021fa1f7e51..95ecd678bd5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.vespa; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import ai.vespa.rankingexpression.importer.vespa.parser.ModelParser; import ai.vespa.rankingexpression.importer.vespa.parser.ParseException; @@ -28,7 +29,7 @@ public class VespaImporter extends ModelImporter { @Override public ImportedModel importModel(String modelName, String modelPath) { try { - ImportedModel model = new ImportedModel(modelName, modelPath); + ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.VESPA); new ModelParser(new SimpleCharStream(IOUtils.readFile(new File(modelPath))), model).model(); return model; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java index 686cf6cd2df..5829ea77815 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.xgboost; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.RankingExpression; import ai.vespa.rankingexpression.importer.ImportedModel; @@ -50,7 +51,7 @@ public class XGBoostImporter extends ModelImporter { @Override public ImportedModel importModel(String modelName, String modelPath) { try { - ImportedModel model = new ImportedModel(modelName, modelPath); + ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.XGBOOST); XGBoostParser parser = new XGBoostParser(modelPath); RankingExpression expression = new RankingExpression(parser.toRankingExpression()); model.expression(modelName, expression); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 09455abc380..96bf2c64485 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -24,7 +24,7 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx").asNative(); // Check constants assertEquals(2, model.largeConstants().size()); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java index f03c629df78..cd13afca77b 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java @@ -14,7 +14,7 @@ public class PyTorchImportTestCase extends TestableModel { @Test public void testPyTorchExport() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx").asNative(); Tensor onnxResult = evaluateVespa(model, "output", model.inputs()); assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.28258783057229725, -0.0685615853647904]]"), onnxResult); } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java index 04db902073b..4c7f6139032 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java @@ -14,6 +14,8 @@ import com.yahoo.tensor.Tensor; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; /** * @author lesters @@ -24,6 +26,10 @@ public class SimpleImportTestCase { public void testSimpleOnnxModelImport() { ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx"); + assertFalse(model.isNative()); + model = model.asNative(); + assertTrue(model.isNative()); + MapContext context = new MapContext(); context.put("query_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[4]):[0.1, 0.2, 0.3, 0.4]"))); context.put("attribute_tensor", new TensorValue(Tensor.from("tensor(d0[4],d1[1]):[0.1, 0.2, 0.3, 0.4]"))); @@ -35,7 +41,7 @@ public class SimpleImportTestCase { @Test public void testGather() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx").asNative(); MapContext context = new MapContext(); context.put("data", new TensorValue(Tensor.from("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]"))); @@ -49,7 +55,7 @@ public class SimpleImportTestCase { @Test public void testConcat() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx").asNative(); MapContext context = new MapContext(); context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]"))); |