summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-11-22 13:18:55 +0100
committerLester Solbakken <lesters@oath.com>2019-11-22 13:18:55 +0100
commit742c9bf3accf86ba993243e3e42961ed0923edc6 (patch)
tree5e91f5055e4b4f1518ba23452314ed339e0561ef /model-integration
parent64f76d40dbd4c2cfe91465ab81edb219e6f6b374 (diff)
Add PyTorch ONNX export test case
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java22
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java122
-rw-r--r--model-integration/src/test/models/pytorch/pytorch.onnxbin0 -> 617 bytes
-rwxr-xr-xmodel-integration/src/test/models/pytorch/pytorch_test.py66
4 files changed, 210 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
new file mode 100644
index 00000000000..f03c629df78
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
@@ -0,0 +1,22 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author lesters
+ */
+public class PyTorchImportTestCase extends TestableModel {
+
+ @Test
+ public void testPyTorchExport() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx");
+ Tensor onnxResult = evaluateVespa(model, "output", model.inputs());
+ assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.28258783057229725, -0.0685615853647904]]"), onnxResult);
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java
new file mode 100644
index 00000000000..28a6ee902a0
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java
@@ -0,0 +1,122 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+public class TestableModel {
+
+ Tensor evaluateVespa(ImportedModel model, String operationName, Map<String, TensorType> inputs) {
+ Context context = contextFrom(model);
+ for (Map.Entry<String, TensorType> entry : inputs.entrySet()) {
+ Tensor argument = vespaInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue());
+ context.put(entry.getKey(), new TensorValue(argument));
+ }
+ model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
+ RankingExpression expression = model.expressions().get(operationName);
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ optimizer.optimize(expression, (ContextIndex)context);
+ return expression.evaluate(context).asTensor();
+ }
+
+ Tensor evaluateTF(SavedModelBundle tensorFlowModel, String operationName, Map<String, TensorType> inputs) {
+ Session.Runner runner = tensorFlowModel.session().runner();
+ for (Map.Entry<String, TensorType> entry : inputs.entrySet()) {
+ try {
+ runner.feed(entry.getKey(), tensorFlowFloatInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue()));
+ } catch (Exception e) {
+ runner.feed(entry.getKey(), tensorFlowDoubleInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue()));
+ }
+ }
+ List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
+ assertEquals(1, results.size());
+ return TensorConverter.toVespaTensor(results.get(0));
+ }
+
+ private org.tensorflow.Tensor<?> tensorFlowFloatInputArgument(int d0Size, int d1Size) {
+ FloatBuffer fb1 = FloatBuffer.allocate(d0Size * d1Size);
+ int i = 0;
+ for (int d0 = 0; d0 < d0Size; d0++)
+ for (int d1 = 0; d1 < d1Size; ++d1)
+ fb1.put(i++, (float)(d1 * 1.0 / d1Size));
+ return org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb1);
+ }
+
+ private org.tensorflow.Tensor<?> tensorFlowDoubleInputArgument(int d0Size, int d1Size) {
+ DoubleBuffer fb1 = DoubleBuffer.allocate(d0Size * d1Size);
+ int i = 0;
+ for (int d0 = 0; d0 < d0Size; d0++)
+ for (int d1 = 0; d1 < d1Size; ++d1)
+ fb1.put(i++, (float)(d1 * 1.0 / d1Size));
+ return org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb1);
+ }
+
+ private Tensor vespaInputArgument(int d0Size, int d1Size) {
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build());
+ for (int d0 = 0; d0 < d0Size; d0++)
+ for (int d1 = 0; d1 < d1Size; d1++)
+ b.cell(d1 * 1.0 / d1Size, d0, d1);
+ return b.build();
+ }
+
+ private void evaluateFunction(Context context, ImportedModel model, String functionName) {
+ if (!context.names().contains(functionName)) {
+ RankingExpression e = RankingExpression.from(model.functions().get(functionName));
+ evaluateFunctionDependencies(context, model, e.getRoot());
+ context.put(functionName, new TensorValue(e.evaluate(context).asTensor()));
+ }
+ }
+
+ private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node) {
+ if (node instanceof ReferenceNode) {
+ String name = node.toString();
+ if (model.functions().containsKey(name)) {
+ evaluateFunction(context, model, name);
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children()) {
+ evaluateFunctionDependencies(context, model, child);
+ }
+ }
+ }
+
+ static Context contextFrom(ImportedModel result) {
+ TestableModelContext context = new TestableModelContext();
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ return context;
+ }
+
+ private static class TestableModelContext extends MapContext implements ContextIndex {
+ @Override
+ public int size() {
+ return bindings().size();
+ }
+ @Override
+ public int getIndex(String name) {
+ throw new UnsupportedOperationException(this + " does not support index lookup by name");
+ }
+ }
+
+}
diff --git a/model-integration/src/test/models/pytorch/pytorch.onnx b/model-integration/src/test/models/pytorch/pytorch.onnx
new file mode 100644
index 00000000000..c940265b58b
--- /dev/null
+++ b/model-integration/src/test/models/pytorch/pytorch.onnx
Binary files differ
diff --git a/model-integration/src/test/models/pytorch/pytorch_test.py b/model-integration/src/test/models/pytorch/pytorch_test.py
new file mode 100755
index 00000000000..d2adb6c8974
--- /dev/null
+++ b/model-integration/src/test/models/pytorch/pytorch_test.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+
+import torch
+
+# ref: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
+
+# N is batch size; D_in is input dimension;
+# H is hidden dimension; D_out is output dimension.
+N, D_in, H, D_out = 1, 10, 5, 2
+
+# Create random Tensors to hold inputs and outputs
+x = torch.randn(N, D_in)
+y = torch.randn(N, D_out)
+
+# Use the nn package to define our model as a sequence of layers. nn.Sequential
+# is a Module which contains other Modules, and applies them in sequence to
+# produce its output. Each Linear Module computes output from input using a
+# linear function, and holds internal Tensors for its weight and bias.
+model = torch.nn.Sequential(
+ torch.nn.Linear(D_in, H),
+ torch.nn.ReLU(),
+ torch.nn.Linear(H, D_out),
+)
+
+# The nn package also contains definitions of popular loss functions; in this
+# case we will use Mean Squared Error (MSE) as our loss function.
+loss_fn = torch.nn.MSELoss(reduction='sum')
+
+learning_rate = 1e-4
+for t in range(500):
+ # Forward pass: compute predicted y by passing x to the model. Module objects
+ # override the __call__ operator so you can call them like functions. When
+ # doing so you pass a Tensor of input data to the Module and it produces
+ # a Tensor of output data.
+ y_pred = model(x)
+
+ # Compute and print loss. We pass Tensors containing the predicted and true
+ # values of y, and the loss function returns a Tensor containing the
+ # loss.
+ loss = loss_fn(y_pred, y)
+ if t % 100 == 99:
+ print(t, loss.item())
+
+ # Zero the gradients before running the backward pass.
+ model.zero_grad()
+
+ # Backward pass: compute gradient of the loss with respect to all the learnable
+ # parameters of the model. Internally, the parameters of each Module are stored
+ # in Tensors with requires_grad=True, so this call will compute gradients for
+ # all learnable parameters in the model.
+ loss.backward()
+
+ # Update the weights using gradient descent. Each parameter is a Tensor, so
+ # we can access its gradients like we did before.
+ with torch.no_grad():
+ for param in model.parameters():
+ param -= learning_rate * param.grad
+
+
+torch.onnx.export(model, x, "pytorch.onnx", verbose=True, opset_version=7, input_names=["input"], output_names=["output"])
+
+test_input = torch.tensor([ [i/D_in for i in range(D_in)] ])
+print(model(test_input)) # for Vespa validation
+
+
+