diff options
author | Lester Solbakken <lesters@oath.com> | 2020-06-12 12:16:56 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-06-12 12:16:56 +0200 |
commit | 0a886d74d4c9ffde41eef1f7e3c186b60b9f3726 (patch) | |
tree | e142b94341563b28a2b4a0e26fe77458749d2ed9 /model-integration/src/test | |
parent | 8de8ff4f87295d812d4e660f0216953726200c92 (diff) |
Import Tensorflow models vis ONNX conversion
Diffstat (limited to 'model-integration/src/test')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java | 43 |
1 files changed, 0 insertions, 43 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 35c853bd746..09455abc380 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -55,47 +55,4 @@ public class OnnxMnistSoftmaxImportTestCase { assertEquals("{Placeholder=tensor<float>(d0[],d1[784])}", output.argumentTypes().toString()); } - @Test - public void testComparisonBetweenOnnxAndTensorflow() { - String tfModelPath = "src/test/models/tensorflow/mnist_softmax/saved"; - String onnxModelPath = "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx"; - - Tensor argument = placeholderArgument(); - Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add"); - Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder", "add"); - - assertEquals("Operation 'add' produces equal results", tensorFlowResult, onnxResult); - } - - private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) { - ImportedModel model = new TensorFlowImporter().importModel("test", path); - return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); - } - - private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) { - ImportedModel model = new OnnxImporter().importModel("test", path); - return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); - } - - private Tensor evaluateExpression(RankingExpression expression, Context context, Tensor argument, String input) { - context.put(input, new TensorValue(argument)); - return expression.evaluate(context).asTensor(); - } - - private Context contextFrom(ImportedModel result) { - MapContext context = new MapContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); - return context; - } - - private Tensor placeholderArgument() { - Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", 784).build()); - for (int d0 = 0; d0 < 1; d0++) - for (int d1 = 0; d1 < 784; d1++) - b.cell(d1 * 1.0 / 784, d0, d1); - return b.build(); - } - - } |