author    Lester Solbakken <lesters@oath.com>    2019-11-22 13:18:55 +0100
committer Lester Solbakken <lesters@oath.com>    2019-11-22 13:18:55 +0100
commit    742c9bf3accf86ba993243e3e42961ed0923edc6 (patch)
tree      5e91f5055e4b4f1518ba23452314ed339e0561ef /model-integration/src/test/java/ai
parent    64f76d40dbd4c2cfe91465ab81edb219e6f6b374 (diff)
Add PyTorch ONNX export test case
Diffstat (limited to 'model-integration/src/test/java/ai')
-rw-r--r--  model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java    22
-rw-r--r--  model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java            122
2 files changed, 144 insertions, 0 deletions
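
The test below reads a pre-exported model from src/test/models/pytorch/pytorch.onnx; the export script itself is not part of this diff. As a rough illustration, an ONNX file like this could be produced from PyTorch along the following lines (the network architecture, input width of 3, and tensor names are assumptions made for this sketch, not taken from the commit):

# Hypothetical export sketch; the real model behind pytorch.onnx is not shown in this diff.
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)   # assumed: 3 inputs, 2 outputs (the test expects a d1[2] result)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
dummy_input = torch.randn(1, 3)         # batch size 1, matching the d0[1] dimension asserted in the test
torch.onnx.export(model, dummy_input, "pytorch.onnx",
                  input_names=["input"], output_names=["output"])
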
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
new file mode 100644
index 00000000000..f03c629df78
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
@@ -0,0 +1,22 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author lesters
+ */
+public class PyTorchImportTestCase extends TestableModel {
+
+ @Test
+ public void testPyTorchExport() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx");
+ Tensor onnxResult = evaluateVespa(model, "output", model.inputs());
+ assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.28258783057229725, -0.0685615853647904]]"), onnxResult);
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java
new file mode 100644
index 00000000000..28a6ee902a0
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java
@@ -0,0 +1,122 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test helper that evaluates an imported model either through Vespa ranking
+ * expressions or natively in TensorFlow, so the two results can be compared.
+ */
+public class TestableModel {
+
+ Tensor evaluateVespa(ImportedModel model, String operationName, Map<String, TensorType> inputs) {
+ Context context = contextFrom(model);
+ for (Map.Entry<String, TensorType> entry : inputs.entrySet()) {
+ Tensor argument = vespaInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue());
+ context.put(entry.getKey(), new TensorValue(argument));
+ }
+ model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
+ RankingExpression expression = model.expressions().get(operationName);
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ optimizer.optimize(expression, (ContextIndex)context);
+ return expression.evaluate(context).asTensor();
+ }
+
+ Tensor evaluateTF(SavedModelBundle tensorFlowModel, String operationName, Map<String, TensorType> inputs) {
+ Session.Runner runner = tensorFlowModel.session().runner();
+ for (Map.Entry<String, TensorType> entry : inputs.entrySet()) {
+ try {
+ runner.feed(entry.getKey(), tensorFlowFloatInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue()));
+ } catch (Exception e) {
+ runner.feed(entry.getKey(), tensorFlowDoubleInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue()));
+ }
+ }
+ List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
+ assertEquals(1, results.size());
+ return TensorConverter.toVespaTensor(results.get(0));
+ }
+
+ private org.tensorflow.Tensor<?> tensorFlowFloatInputArgument(int d0Size, int d1Size) {
+ FloatBuffer fb1 = FloatBuffer.allocate(d0Size * d1Size);
+ int i = 0;
+ for (int d0 = 0; d0 < d0Size; d0++)
+ for (int d1 = 0; d1 < d1Size; ++d1)
+ fb1.put(i++, (float)(d1 * 1.0 / d1Size));
+ return org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb1);
+ }
+
+ private org.tensorflow.Tensor<?> tensorFlowDoubleInputArgument(int d0Size, int d1Size) {
+ DoubleBuffer db1 = DoubleBuffer.allocate(d0Size * d1Size);
+ int i = 0;
+ for (int d0 = 0; d0 < d0Size; d0++)
+ for (int d1 = 0; d1 < d1Size; ++d1)
+ db1.put(i++, d1 * 1.0 / d1Size);
+ return org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, db1);
+ }
+
+ private Tensor vespaInputArgument(int d0Size, int d1Size) {
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build());
+ for (int d0 = 0; d0 < d0Size; d0++)
+ for (int d1 = 0; d1 < d1Size; d1++)
+ b.cell(d1 * 1.0 / d1Size, d0, d1);
+ return b.build();
+ }
+
+ private void evaluateFunction(Context context, ImportedModel model, String functionName) {
+ if (!context.names().contains(functionName)) {
+ RankingExpression e = RankingExpression.from(model.functions().get(functionName));
+ evaluateFunctionDependencies(context, model, e.getRoot());
+ context.put(functionName, new TensorValue(e.evaluate(context).asTensor()));
+ }
+ }
+
+ private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node) {
+ if (node instanceof ReferenceNode) {
+ String name = node.toString();
+ if (model.functions().containsKey(name)) {
+ evaluateFunction(context, model, name);
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children()) {
+ evaluateFunctionDependencies(context, model, child);
+ }
+ }
+ }
+
+ static Context contextFrom(ImportedModel result) {
+ TestableModelContext context = new TestableModelContext();
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ return context;
+ }
+
+ private static class TestableModelContext extends MapContext implements ContextIndex {
+ @Override
+ public int size() {
+ return bindings().size();
+ }
+ @Override
+ public int getIndex(String name) {
+ throw new UnsupportedOperationException(this + " does not support index lookup by name");
+ }
+ }
+
+}
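
The expected values asserted in PyTorchImportTestCase are presumably reference results obtained by evaluating the exported model outside Vespa with the same input pattern TestableModel generates (cell value d1 * 1.0 / d1Size). A sketch of how such reference values could be reproduced with onnxruntime, again assuming an input width of 3 and the model path used by the test:

# Sketch only; the input width and model path are assumptions, not part of this commit.
import numpy as np
import onnxruntime as ort

d1_size = 3                                      # assumed input width
x = np.array([[d1 / d1_size for d1 in range(d1_size)]], dtype=np.float32)   # mirrors vespaInputArgument

session = ort.InferenceSession("src/test/models/pytorch/pytorch.onnx")
input_name = session.get_inputs()[0].name
print(session.run(None, {input_name: x})[0])     # values like these become the expected tensor in the test
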