summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-01 13:50:31 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-01 13:50:31 +0200
commitda1a20ab27fff180baf3f574774c3bbb57488fee (patch)
tree2e618b9588090cb32d2955aa0093f08eca09637d
parente803991178077a9d9833ce7f9c5aee539f6af787 (diff)
Expect the right arguments
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java13
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java12
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java2
3 files changed, 14 insertions, 13 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
index 593e7b54c10..e325c3d11b4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
@@ -15,17 +15,18 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test",
+ "src/test/files/integration/tensorflow/batch_norm/saved");
ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ExpressionFunction output = signature.outputExpression("y");
- assertNotNull(output);
- assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName());
- model.assertEqualResult("X", output.getBody().getName());
- assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
+ ExpressionFunction function = signature.outputExpression("y");
+ assertNotNull(function);
+ assertEquals("dnn/batch_normalization_3/batchnorm/add_1", function.getBody().getName());
+ model.assertEqualResult("X", function.getBody().getName());
+ assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index 59712c0152f..8ca5a9a7888 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -30,13 +30,13 @@ public class DropoutImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ExpressionFunction output = signature.outputExpression("y");
- assertNotNull(output);
- assertEquals("outputs/Maximum", output.getBody().getName());
+ ExpressionFunction function = signature.outputExpression("y");
+ assertNotNull(function);
+ assertEquals("outputs/Maximum", function.getBody().getName());
assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
- output.getBody().getRoot().toString());
- model.assertEqualResult("X", output.getBody().getName());
- assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
+ function.getBody().getRoot().toString());
+ model.assertEqualResult("X", function.getBody().getName());
+ assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
index 0a48ecfce21..feba40601e3 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
@@ -62,7 +62,7 @@ public class TensorFlowMnistSoftmaxImportTestCase {
assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
output.getBody().getRoot().toString());
- assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
+ assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString());
// Test execution
model.assertEqualResult("Placeholder", "MatMul");