aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-05-21 09:33:32 +0200
committerLester Solbakken <lesters@oath.com>2021-05-21 09:33:32 +0200
commit6448742f804482946a7bf2d17723dca6b4100b73 (patch)
tree135038f0298f3e519ed8e4327cf1bf1915df4b39 /model-integration
parent864eb3da782e9795826ec78add953a76eeb2ea17 (diff)
Wire in stateless ONNX runtime evaluation
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java22
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java18
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java24
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java30
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java3
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java10
13 files changed, 127 insertions, 16 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
index c1f973300d6..68bebfa6183 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
@@ -11,6 +11,7 @@ import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.ValueInfo;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
@@ -25,6 +26,7 @@ import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.HashMap;
import java.util.Map;
+import java.util.Set;
import java.util.stream.Collectors;
@@ -38,7 +40,8 @@ class TensorConverter {
{
Map<String, OnnxTensor> result = new HashMap<>();
for (String name : tensorMap.keySet()) {
- Tensor vespaTensor = tensorMap.get(name);
+ Tensor vespaTensor = tensorMap.get(name);
+ name = toOnnxName(name, session.getInputInfo().keySet());
TensorInfo onnxTensorInfo = toTensorInfo(session.getInputInfo().get(name).getInfo());
OnnxTensor onnxTensor = toOnnxTensor(vespaTensor, onnxTensorInfo, env);
result.put(name, onnxTensor);
@@ -143,7 +146,22 @@ class TensorConverter {
}
static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) {
- return infoMap.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> toVespaType(e.getValue().getInfo())));
+ return infoMap.entrySet().stream().collect(Collectors.toMap(e -> asValidName(e.getKey()),
+ e -> toVespaType(e.getValue().getInfo())));
+ }
+
+ static String asValidName(String name) {
+ return OnnxImporter.asValidIdentifier(name);
+ }
+
+ static String toOnnxName(String name, Set<String> onnxNames) {
+ if (onnxNames.contains(name))
+ return name;
+ for (String onnxName : onnxNames) {
+ if (asValidName(onnxName).equals(name))
+ return onnxName;
+ }
+ throw new IllegalArgumentException("ONNX model has no input with name " + name);
}
static TensorType toVespaType(ValueInfo valueInfo) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 47fe66dd424..cf92cbc1e89 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -35,6 +35,7 @@ public class ImportedModel implements ImportedMlModel {
private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
private final String name;
private final String source;
+ private final ModelType modelType;
private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> inputs = new HashMap<>();
@@ -49,11 +50,12 @@ public class ImportedModel implements ImportedMlModel {
* @param name the name of this mode, containing only characters in [A-Za-z0-9_]
* @param source the source path (directory or file) of this model
*/
- public ImportedModel(String name, String source) {
+ public ImportedModel(String name, String source, ModelType modelType) {
if ( ! nameRegexp.matcher(name).matches())
throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'");
this.name = name;
this.source = source;
+ this.modelType = modelType;
}
/** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
@@ -64,6 +66,10 @@ public class ImportedModel implements ImportedMlModel {
@Override
public String source() { return source; }
+ /** Returns the original model type */
+ @Override
+ public ModelType modelType() { return modelType; }
+
@Override
public String toString() { return "imported model '" + name + "' from " + source; }
@@ -212,6 +218,16 @@ public class ImportedModel implements ImportedMlModel {
return values;
}
+ @Override
+ public boolean isNative() {
+ return true;
+ }
+
+ @Override
+ public ImportedModel asNative() {
+ return this;
+ }
+
/**
* A signature is a set of named inputs and outputs, where the inputs maps to input
* ("placeholder") names+types, and outputs maps to expressions nodes.
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 8f73cd02184..3f87bfa0beb 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -52,8 +53,10 @@ public abstract class ModelImporter implements MlModelImporter {
* Takes an IntermediateGraph and converts it to a ImportedModel containing
* the actual Vespa ranking expressions.
*/
- protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
- ImportedModel model = new ImportedModel(graph.name(), modelSource);
+ protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph,
+ String modelSource,
+ ImportedMlModel.ModelType modelType) {
+ ImportedModel model = new ImportedModel(graph.name(), modelSource, modelType);
log.log(Level.FINER, () -> "Intermediate graph created from '" + modelSource + "':\n" +
ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString()));
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java
index e40a06af042..b98dfa33320 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java
@@ -12,12 +12,20 @@ import java.util.Optional;
*/
public interface ImportedMlModel {
+ enum ModelType {
+ VESPA, XGBOOST, LIGHTGBM, TENSORFLOW, ONNX
+ }
+
String name();
String source();
+ ModelType modelType();
+
Optional<String> inputTypeSpec(String input);
Map<String, String> smallConstants();
Map<String, String> largeConstants();
Map<String, String> functions();
List<ImportedMlFunction> outputExpressions();
+ boolean isNative();
+ ImportedMlModel asNative();
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java
index 76caa652ad2..ef731730c84 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.lightgbm;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
@@ -39,7 +40,7 @@ public class LightGBMImporter extends ModelImporter {
@Override
public ImportedModel importModel(String modelName, String modelPath) {
try {
- ImportedModel model = new ImportedModel(modelName, modelPath);
+ ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.LIGHTGBM);
LightGBMParser parser = new LightGBMParser(modelPath);
RankingExpression expression = new RankingExpression(parser.toRankingExpression());
model.expression(modelName, expression);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java
new file mode 100644
index 00000000000..714321ec116
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/ImportedOnnxModel.java
@@ -0,0 +1,24 @@
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import onnx.Onnx;
+
+public class ImportedOnnxModel extends ImportedModel {
+
+ private final Onnx.ModelProto modelProto;
+
+ public ImportedOnnxModel(String name, String source, Onnx.ModelProto modelProto) {
+ super(name, source, ModelType.ONNX);
+ this.modelProto = modelProto;
+ }
+
+ @Override
+ public boolean isNative() {
+ return false;
+ }
+
+ @Override
+ public ImportedModel asNative() {
+ return OnnxImporter.convertModel(name(), source(), modelProto, ModelType.ONNX);
+ }
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
index b1c5dc8a0d8..8e4dd07fc73 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
@@ -5,6 +5,7 @@ package ai.vespa.rankingexpression.importer.onnx;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import onnx.Onnx;
import java.io.File;
@@ -31,11 +32,36 @@ public class OnnxImporter extends ModelImporter {
try (FileInputStream inputStream = new FileInputStream(modelPath)) {
Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
// long version = model.getOpsetImport(0).getVersion(); // opset version
- IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
- return convertIntermediateGraphToModel(graph, modelPath);
+
+ ImportedModel importedModel = new ImportedOnnxModel(modelName, modelPath, model);
+ for (int i = 0; i < model.getGraph().getOutputCount(); ++i) {
+ Onnx.ValueInfoProto output = model.getGraph().getOutput(i);
+ String outputName = asValidIdentifier(output.getName());
+ importedModel.expression(outputName, "onnxModel(" + modelName + ")." + outputName);
+ }
+ return importedModel;
+
} catch (IOException e) {
throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
}
}
+ public ImportedModel importModelAsNative(String modelName, String modelPath, ImportedMlModel.ModelType modelType) {
+ try (FileInputStream inputStream = new FileInputStream(modelPath)) {
+ Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
+ return convertModel(modelName, modelPath, model, modelType);
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
+ }
+ }
+
+ public static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
+ }
+
+ static ImportedModel convertModel(String name, String source, Onnx.ModelProto modelProto, ImportedMlModel.ModelType modelType) {
+ IntermediateGraph graph = GraphImporter.importGraph(name, modelProto);
+ return convertIntermediateGraphToModel(graph, source, modelType);
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
index a879c24b373..402c6562cc4 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
@@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.tensorflow;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.yahoo.collections.Pair;
import com.yahoo.io.IOUtils;
@@ -59,7 +60,7 @@ public class TensorFlowImporter extends ModelImporter {
public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
try {
IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
- return convertIntermediateGraphToModel(graph, modelDir);
+ return convertIntermediateGraphToModel(graph, modelDir, ImportedMlModel.ModelType.TENSORFLOW);
}
catch (IOException e) {
throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
@@ -77,7 +78,13 @@ public class TensorFlowImporter extends ModelImporter {
Pair<Integer, String> res = convertToOnnx(modelDir, convertedPath, opset);
if (res.getFirst() == 0) {
log.info("Conversion to ONNX with opset " + opset + " successful.");
- return onnxImporter.importModel(modelName, convertedPath);
+
+ /*
+ * For now we have to import tensorflow models as native Vespa expressions.
+ * The temporary ONNX file that is created by conversion needs to be put
+ * in the application package so it can be file distributed.
+ */
+ return onnxImporter.importModelAsNative(modelName, convertedPath, ImportedMlModel.ModelType.TENSORFLOW);
}
log.fine("Conversion to ONNX with opset " + opset + " failed. Reason: " + res.getSecond());
outputOfLastConversionAttempt = res.getSecond();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java
index 021fa1f7e51..95ecd678bd5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.vespa;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import ai.vespa.rankingexpression.importer.vespa.parser.ModelParser;
import ai.vespa.rankingexpression.importer.vespa.parser.ParseException;
@@ -28,7 +29,7 @@ public class VespaImporter extends ModelImporter {
@Override
public ImportedModel importModel(String modelName, String modelPath) {
try {
- ImportedModel model = new ImportedModel(modelName, modelPath);
+ ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.VESPA);
new ModelParser(new SimpleCharStream(IOUtils.readFile(new File(modelPath))), model).model();
return model;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
index 686cf6cd2df..5829ea77815 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.xgboost;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import ai.vespa.rankingexpression.importer.ImportedModel;
@@ -50,7 +51,7 @@ public class XGBoostImporter extends ModelImporter {
@Override
public ImportedModel importModel(String modelName, String modelPath) {
try {
- ImportedModel model = new ImportedModel(modelName, modelPath);
+ ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.XGBOOST);
XGBoostParser parser = new XGBoostParser(modelPath);
RankingExpression expression = new RankingExpression(parser.toRankingExpression());
model.expression(modelName, expression);
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 09455abc380..96bf2c64485 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -24,7 +24,7 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx").asNative();
// Check constants
assertEquals(2, model.largeConstants().size());
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
index f03c629df78..cd13afca77b 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
@@ -14,7 +14,7 @@ public class PyTorchImportTestCase extends TestableModel {
@Test
public void testPyTorchExport() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx").asNative();
Tensor onnxResult = evaluateVespa(model, "output", model.inputs());
assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.28258783057229725, -0.0685615853647904]]"), onnxResult);
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
index 04db902073b..4c7f6139032 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
@@ -14,6 +14,8 @@ import com.yahoo.tensor.Tensor;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
/**
* @author lesters
@@ -24,6 +26,10 @@ public class SimpleImportTestCase {
public void testSimpleOnnxModelImport() {
ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx");
+ assertFalse(model.isNative());
+ model = model.asNative();
+ assertTrue(model.isNative());
+
MapContext context = new MapContext();
context.put("query_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[4]):[0.1, 0.2, 0.3, 0.4]")));
context.put("attribute_tensor", new TensorValue(Tensor.from("tensor(d0[4],d1[1]):[0.1, 0.2, 0.3, 0.4]")));
@@ -35,7 +41,7 @@ public class SimpleImportTestCase {
@Test
public void testGather() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx").asNative();
MapContext context = new MapContext();
context.put("data", new TensorValue(Tensor.from("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]")));
@@ -49,7 +55,7 @@ public class SimpleImportTestCase {
@Test
public void testConcat() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx").asNative();
MapContext context = new MapContext();
context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]")));