aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/test
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-05-28 11:33:17 +0200
committerLester Solbakken <lesters@oath.com>2018-05-28 11:33:17 +0200
commit88d06ec474f727d41963b6aa65c2382ccc01c3f5 (patch)
treed2a871f2e6870daadf674fca0f350692cbdc42a3 /searchlib/src/test
parent3c1334090cef6fb0891515040ad900702275ccea (diff)
Add ONNX pseudo ranking feature
Diffstat (limited to 'searchlib/src/test')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java18
1 files changed, 9 insertions, 9 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
index e118c2b885a..4b68cd40a08 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -24,18 +24,18 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() throws IOException {
- OnnxModel model = new OnnxImporter().importModel("src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx", "add");
+ OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
// Check constants
assertEquals(2, model.largeConstants().size());
- Tensor constant0 = model.largeConstants().get("Variable_0");
+ Tensor constant0 = model.largeConstants().get("test_Variable");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.largeConstants().get("Variable_1_0");
+ Tensor constant1 = model.largeConstants().get("test_Variable_1");
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
constant1.type());
@@ -43,15 +43,15 @@ public class OnnxMnistSoftmaxImportTestCase {
// Check required macros (inputs)
assertEquals(1, model.requiredMacros().size());
- assertTrue(model.requiredMacros().containsKey("Placeholder_0"));
+ assertTrue(model.requiredMacros().containsKey("Placeholder"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.requiredMacros().get("Placeholder_0"));
+ model.requiredMacros().get("Placeholder"));
// Check outputs
- RankingExpression output = model.expressions().get("add");
+ RankingExpression output = model.outputExpression("add");
assertNotNull(output);
assertEquals("add", output.getName());
- assertEquals("join(reduce(join(rename(Placeholder_0, (d0, d1), (d0, d2)), constant(Variable_0), f(a,b)(a * b)), sum, d2), constant(Variable_1_0), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
output.getRoot().toString());
}
@@ -62,7 +62,7 @@ public class OnnxMnistSoftmaxImportTestCase {
Tensor argument = placeholderArgument();
Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add");
- Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder_0", "add");
+ Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder", "add");
assertEquals("Operation 'add' produces equal results", tensorFlowResult, onnxResult);
}
@@ -74,7 +74,7 @@ public class OnnxMnistSoftmaxImportTestCase {
}
private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) {
- OnnxModel model = new OnnxImporter().importModel(path, output);
+ OnnxModel model = new OnnxImporter().importModel("test", path);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}