diff options
author | Lester Solbakken <lesters@oath.com> | 2018-05-28 11:33:17 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-05-28 11:33:17 +0200 |
commit | 88d06ec474f727d41963b6aa65c2382ccc01c3f5 (patch) | |
tree | d2a871f2e6870daadf674fca0f350692cbdc42a3 /searchlib/src/test | |
parent | 3c1334090cef6fb0891515040ad900702275ccea (diff) |
Add ONNX pseudo ranking feature
Diffstat (limited to 'searchlib/src/test')
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java index e118c2b885a..4b68cd40a08 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -24,18 +24,18 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() throws IOException { - OnnxModel model = new OnnxImporter().importModel("src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx", "add"); + OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); // Check constants assertEquals(2, model.largeConstants().size()); - Tensor constant0 = model.largeConstants().get("Variable_0"); + Tensor constant0 = model.largeConstants().get("test_Variable"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.largeConstants().get("Variable_1_0"); + Tensor constant1 = model.largeConstants().get("test_Variable_1"); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); @@ -43,15 +43,15 @@ public class OnnxMnistSoftmaxImportTestCase { // Check required macros (inputs) assertEquals(1, model.requiredMacros().size()); - assertTrue(model.requiredMacros().containsKey("Placeholder_0")); + assertTrue(model.requiredMacros().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.requiredMacros().get("Placeholder_0")); + model.requiredMacros().get("Placeholder")); // Check outputs - RankingExpression output = model.expressions().get("add"); + RankingExpression output = model.outputExpression("add"); assertNotNull(output); assertEquals("add", output.getName()); - assertEquals("join(reduce(join(rename(Placeholder_0, (d0, d1), (d0, d2)), constant(Variable_0), f(a,b)(a * b)), sum, d2), constant(Variable_1_0), f(a,b)(a + b))", + assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", output.getRoot().toString()); } @@ -62,7 +62,7 @@ public class OnnxMnistSoftmaxImportTestCase { Tensor argument = placeholderArgument(); Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add"); - Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder_0", "add"); + Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder", "add"); assertEquals("Operation 'add' produces equal results", tensorFlowResult, onnxResult); } @@ -74,7 +74,7 @@ public class OnnxMnistSoftmaxImportTestCase { } private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) { - OnnxModel model = new OnnxImporter().importModel(path, output); + OnnxModel model = new OnnxImporter().importModel("test", path); return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); } |