From da1a20ab27fff180baf3f574774c3bbb57488fee Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Mon, 1 Oct 2018 13:50:31 +0200 Subject: Expect the right arguments --- .../integration/ml/BatchNormImportTestCase.java | 13 +++++++------ .../integration/ml/DropoutImportTestCase.java | 12 ++++++------ .../ml/TensorFlowMnistSoftmaxImportTestCase.java | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java index 593e7b54c10..e325c3d11b4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java @@ -15,17 +15,18 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("test", + "src/test/files/integration/tensorflow/batch_norm/saved"); ImportedModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ExpressionFunction output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName()); - model.assertEqualResult("X", output.getBody().getName()); - assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + ExpressionFunction function = signature.outputExpression("y"); + assertNotNull(function); + assertEquals("dnn/batch_normalization_3/batchnorm/add_1", function.getBody().getName()); + model.assertEqualResult("X", function.getBody().getName()); + assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index 59712c0152f..8ca5a9a7888 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -30,13 +30,13 @@ public class DropoutImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ExpressionFunction output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("outputs/Maximum", output.getBody().getName()); + ExpressionFunction function = signature.outputExpression("y"); + assertNotNull(function); + assertEquals("outputs/Maximum", function.getBody().getName()); assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", - output.getBody().getRoot().toString()); - model.assertEqualResult("X", output.getBody().getName()); - assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + function.getBody().getRoot().toString()); + model.assertEqualResult("X", function.getBody().getName()); + assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index 0a48ecfce21..feba40601e3 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -62,7 +62,7 @@ public class TensorFlowMnistSoftmaxImportTestCase { assertEquals("add", output.getBody().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", output.getBody().getRoot().toString()); - assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString()); // Test execution model.assertEqualResult("Placeholder", "MatMul"); -- cgit v1.2.3