author | Lester Solbakken <lesters@oath.com> | 2019-11-22 13:18:55 +0100
---|---|---
committer | Lester Solbakken <lesters@oath.com> | 2019-11-22 13:18:55 +0100
commit | 742c9bf3accf86ba993243e3e42961ed0923edc6 (patch) |
tree | 5e91f5055e4b4f1518ba23452314ed339e0561ef /model-integration |
parent | 64f76d40dbd4c2cfe91465ab81edb219e6f6b374 (diff) |
Add PyTorch ONNX export test case
Diffstat (limited to 'model-integration')
4 files changed, 210 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
new file mode 100644
index 00000000000..f03c629df78
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
@@ -0,0 +1,22 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author lesters
+ */
+public class PyTorchImportTestCase extends TestableModel {
+
+    @Test
+    public void testPyTorchExport() {
+        ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx");
+        Tensor onnxResult = evaluateVespa(model, "output", model.inputs());
+        assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.28258783057229725, -0.0685615853647904]]"), onnxResult);
+    }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java
new file mode 100644
index 00000000000..28a6ee902a0
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java
@@ -0,0 +1,122 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+public class TestableModel {
+
+    Tensor evaluateVespa(ImportedModel model, String operationName, Map<String, TensorType> inputs) {
+        Context context = contextFrom(model);
+        for (Map.Entry<String, TensorType> entry : inputs.entrySet()) {
+            Tensor argument = vespaInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue());
+            context.put(entry.getKey(), new TensorValue(argument));
+        }
+        model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
+        RankingExpression expression = model.expressions().get(operationName);
+        ExpressionOptimizer optimizer = new ExpressionOptimizer();
+        optimizer.optimize(expression, (ContextIndex)context);
+        return expression.evaluate(context).asTensor();
+    }
+
+    Tensor evaluateTF(SavedModelBundle tensorFlowModel, String operationName, Map<String, TensorType> inputs) {
+        Session.Runner runner = tensorFlowModel.session().runner();
+        for (Map.Entry<String, TensorType> entry : inputs.entrySet()) {
+            try {
+                runner.feed(entry.getKey(), tensorFlowFloatInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue()));
+            } catch (Exception e) {
+                runner.feed(entry.getKey(), tensorFlowDoubleInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue()));
+            }
+        }
+        List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
+        assertEquals(1, results.size());
+        return TensorConverter.toVespaTensor(results.get(0));
+    }
+
+    private org.tensorflow.Tensor<?> tensorFlowFloatInputArgument(int d0Size, int d1Size) {
+        FloatBuffer fb1 = FloatBuffer.allocate(d0Size * d1Size);
+        int i = 0;
+        for (int d0 = 0; d0 < d0Size; d0++)
+            for (int d1 = 0; d1 < d1Size; ++d1)
+                fb1.put(i++, (float)(d1 * 1.0 / d1Size));
+        return org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb1);
+    }
+
+    private org.tensorflow.Tensor<?> tensorFlowDoubleInputArgument(int d0Size, int d1Size) {
+        DoubleBuffer fb1 = DoubleBuffer.allocate(d0Size * d1Size);
+        int i = 0;
+        for (int d0 = 0; d0 < d0Size; d0++)
+            for (int d1 = 0; d1 < d1Size; ++d1)
+                fb1.put(i++, (float)(d1 * 1.0 / d1Size));
+        return org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb1);
+    }
+
+    private Tensor vespaInputArgument(int d0Size, int d1Size) {
+        Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build());
+        for (int d0 = 0; d0 < d0Size; d0++)
+            for (int d1 = 0; d1 < d1Size; d1++)
+                b.cell(d1 * 1.0 / d1Size, d0, d1);
+        return b.build();
+    }
+
+    private void evaluateFunction(Context context, ImportedModel model, String functionName) {
+        if (!context.names().contains(functionName)) {
+            RankingExpression e = RankingExpression.from(model.functions().get(functionName));
+            evaluateFunctionDependencies(context, model, e.getRoot());
+            context.put(functionName, new TensorValue(e.evaluate(context).asTensor()));
+        }
+    }
+
+    private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node) {
+        if (node instanceof ReferenceNode) {
+            String name = node.toString();
+            if (model.functions().containsKey(name)) {
+                evaluateFunction(context, model, name);
+            }
+        }
+        else if (node instanceof CompositeNode) {
+            for (ExpressionNode child : ((CompositeNode)node).children()) {
+                evaluateFunctionDependencies(context, model, child);
+            }
+        }
+    }
+
+    static Context contextFrom(ImportedModel result) {
+        TestableModelContext context = new TestableModelContext();
+        result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+        result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+        return context;
+    }
+
+    private static class TestableModelContext extends MapContext implements ContextIndex {
+        @Override
+        public int size() {
+            return bindings().size();
+        }
+        @Override
+        public int getIndex(String name) {
+            throw new UnsupportedOperationException(this + " does not support index lookup by name");
+        }
+    }
+
+}
diff --git a/model-integration/src/test/models/pytorch/pytorch.onnx b/model-integration/src/test/models/pytorch/pytorch.onnx
new file mode 100644
index 00000000000..c940265b58b
--- /dev/null
+++ b/model-integration/src/test/models/pytorch/pytorch.onnx
Binary files differ
diff --git a/model-integration/src/test/models/pytorch/pytorch_test.py b/model-integration/src/test/models/pytorch/pytorch_test.py
new file mode 100755
index 00000000000..d2adb6c8974
--- /dev/null
+++ b/model-integration/src/test/models/pytorch/pytorch_test.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+
+import torch
+
+# ref: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
+
+# N is batch size; D_in is input dimension;
+# H is hidden dimension; D_out is output dimension.
+N, D_in, H, D_out = 1, 10, 5, 2
+
+# Create random Tensors to hold inputs and outputs
+x = torch.randn(N, D_in)
+y = torch.randn(N, D_out)
+
+# Use the nn package to define our model as a sequence of layers. nn.Sequential
+# is a Module which contains other Modules, and applies them in sequence to
+# produce its output. Each Linear Module computes output from input using a
+# linear function, and holds internal Tensors for its weight and bias.
+model = torch.nn.Sequential(
+    torch.nn.Linear(D_in, H),
+    torch.nn.ReLU(),
+    torch.nn.Linear(H, D_out),
+)
+
+# The nn package also contains definitions of popular loss functions; in this
+# case we will use Mean Squared Error (MSE) as our loss function.
+loss_fn = torch.nn.MSELoss(reduction='sum')
+
+learning_rate = 1e-4
+for t in range(500):
+    # Forward pass: compute predicted y by passing x to the model. Module objects
+    # override the __call__ operator so you can call them like functions. When
+    # doing so you pass a Tensor of input data to the Module and it produces
+    # a Tensor of output data.
+    y_pred = model(x)
+
+    # Compute and print loss. We pass Tensors containing the predicted and true
+    # values of y, and the loss function returns a Tensor containing the
+    # loss.
+    loss = loss_fn(y_pred, y)
+    if t % 100 == 99:
+        print(t, loss.item())
+
+    # Zero the gradients before running the backward pass.
+    model.zero_grad()
+
+    # Backward pass: compute gradient of the loss with respect to all the learnable
+    # parameters of the model. Internally, the parameters of each Module are stored
+    # in Tensors with requires_grad=True, so this call will compute gradients for
+    # all learnable parameters in the model.
+    loss.backward()
+
+    # Update the weights using gradient descent. Each parameter is a Tensor, so
+    # we can access its gradients like we did before.
+    with torch.no_grad():
+        for param in model.parameters():
+            param -= learning_rate * param.grad
+
+
+torch.onnx.export(model, x, "pytorch.onnx", verbose=True, opset_version=7, input_names=["input"], output_names=["output"])
+
+test_input = torch.tensor([ [i/D_in for i in range(D_in)] ])
+print(model(test_input)) # for Vespa validation
+
+
+
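Not part of the commit above, but a possible way to sanity-check the checked-in `pytorch.onnx` outside of Vespa is to evaluate it with onnxruntime (an assumed extra dependency, not used by this test) on the same 1x10 input that `pytorch_test.py` and `PyTorchImportTestCase` use. This is only a sketch: re-running `pytorch_test.py` retrains the network from random weights, so a freshly exported file will not reproduce the exact tensor hard-coded in the Java test.

```python
#!/usr/bin/env python3
# Hypothetical cross-check (not part of this commit): evaluate the checked-in
# pytorch.onnx with onnxruntime on the same input the Java test feeds to the
# imported model. Assumes `pip install onnxruntime numpy`.

import numpy as np
import onnxruntime as ort

D_in = 10  # must match D_in in pytorch_test.py

# Same deterministic input as the script's test_input and the Java test's
# vespaInputArgument: [0/10, 1/10, ..., 9/10] in a 1 x 10 batch, as float32.
test_input = np.array([[i / D_in for i in range(D_in)]], dtype=np.float32)

# The export used input_names=["input"] and output_names=["output"], so those
# are the names fed and fetched here.
session = ort.InferenceSession("pytorch.onnx")
output = session.run(["output"], {"input": test_input})[0]
print(output)  # for the checked-in model this should be close to [[0.2826, -0.0686]]
```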