blob: 71100ca2b76d868c0591ba61d793819a060b21a6 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.onnx;
import ai.vespa.rankingexpression.importer.ImportedModel;
import com.yahoo.tensor.Tensor;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
/**
 * Imports the ONNX model stored under the pytorch test models directory and
 * verifies that evaluating its "output" function yields the expected tensor.
 *
 * @author lesters
 */
public class PyTorchImportTestCase extends TestableModel {

    /**
     * Imports {@code pytorch.onnx}, evaluates its "output" function and compares
     * the result against a fixed reference tensor.
     */
    @Test
    public void testPyTorchExport() {
        // Import the ONNX file under the name "test" and take its native form.
        OnnxImporter importer = new OnnxImporter();
        ImportedModel imported = importer.importModel("test", "src/test/models/pytorch/pytorch.onnx").asNative();

        // evaluateVespa is not defined in this file (presumably inherited from
        // TestableModel); it is handed the model, the function name and the model's inputs.
        Tensor actual = evaluateVespa(imported, "output", imported.inputs());

        // Reference value: a tensor of type tensor(d0[1],d1[2]) holding two cells.
        Tensor expected = Tensor.from("tensor(d0[1],d1[2]):[[0.28258783057229725, -0.0685615853647904]]");
        assertEquals(expected, actual);
    }

}
|