diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-21 10:25:49 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-21 10:25:49 -0700 |
commit | 569c2f0d781b33c17aadf6929fef2388643e1d64 (patch) | |
tree | dc030bcca551415b44218701f762fac9e21e6a3a /searchlib/src/test/java/com | |
parent | 772b67da6040957bd975b2418f98d2f18ee69fc4 (diff) |
Propagate input type information
Diffstat (limited to 'searchlib/src/test/java/com')
5 files changed, 35 insertions, 32 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java index bf9684082f4..3a1c9ec9551 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java @@ -20,10 +20,11 @@ public class BatchNormImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - RankingExpression output = signature.outputExpression("y"); + ImportedModel.ExpressionWithInputs output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); - model.assertEqualResult("X", output.getName()); + assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.expression().getName()); + model.assertEqualResult("X", output.expression().getName()); + assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index a8f7542f3a4..4c35d843f5d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -19,22 +19,23 @@ public class DropoutImportTestCase { TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved"); // Check required functions - assertEquals(1, model.get().requiredFunctions().size()); - assertTrue(model.get().requiredFunctions().containsKey("X")); + assertEquals(1, model.get().inputs().size()); + assertTrue(model.get().inputs().containsKey("X")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().requiredFunctions().get("X")); + model.get().inputs().get("X")); ImportedModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - RankingExpression output = signature.outputExpression("y"); + ImportedModel.ExpressionWithInputs output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("outputs/Maximum", output.getName()); + assertEquals("outputs/Maximum", output.expression().getName()); assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", - output.getRoot().toString()); - model.assertEqualResult("X", output.getName()); + output.expression().getRoot().toString()); + model.assertEqualResult("X", output.expression().getName()); + assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java index add66eece1a..b3e281ad25d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java @@ -20,11 +20,10 @@ public class MnistImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - RankingExpression output = signature.outputExpression("y"); + ImportedModel.ExpressionWithInputs output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("dnn/outputs/add", output.getName()); - model.assertEqualResultSum("input", output.getName(), 0.00001); + assertEquals("dnn/outputs/add", output.expression().getName()); + model.assertEqualResultSum("input", output.expression().getName(), 0.00001); } - } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index e20ac16a691..b5655cfbfa5 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -27,27 +27,28 @@ public class OnnxMnistSoftmaxImportTestCase { Tensor constant0 = model.largeConstants().get("test_Variable"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), - constant0.type()); + constant0.type()); assertEquals(7840, constant0.size()); Tensor constant1 = model.largeConstants().get("test_Variable_1"); assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d1", 10).build(), - constant1.type()); + assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); - // Check required functions (inputs) - assertEquals(1, model.requiredFunctions().size()); - assertTrue(model.requiredFunctions().containsKey("Placeholder")); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.requiredFunctions().get("Placeholder")); + // Check inputs + assertEquals(1, model.inputs().size()); + assertTrue(model.inputs().containsKey("Placeholder")); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder")); - // Check outputs - RankingExpression output = model.defaultSignature().outputExpression("add"); + // Check signature + ImportedModel.ExpressionWithInputs output = model.defaultSignature().outputExpression("add"); assertNotNull(output); - assertEquals("add", output.getName()); + assertEquals("add", output.expression().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", - output.getRoot().toString()); + output.expression().getRoot().toString()); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), + model.inputs().get(model.defaultSignature().inputs().get("Placeholder"))); + assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.inputs().toString()); } @Test diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index ef28eb4678f..4a0362c0229 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -38,10 +38,10 @@ public class TensorFlowMnistSoftmaxImportTestCase { assertEquals(0, model.get().functions().size()); // Check required functions - assertEquals(1, model.get().requiredFunctions().size()); - assertTrue(model.get().requiredFunctions().containsKey("Placeholder")); + assertEquals(1, model.get().inputs().size()); + assertTrue(model.get().inputs().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().requiredFunctions().get("Placeholder")); + model.get().inputs().get("Placeholder")); // Check signatures assertEquals(1, model.get().signatures().size()); @@ -56,11 +56,12 @@ public class TensorFlowMnistSoftmaxImportTestCase { // ... signature outputs assertEquals(1, signature.outputs().size()); - RankingExpression output = signature.outputExpression("y"); + ImportedModel.ExpressionWithInputs output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("add", output.getName()); + assertEquals("add", output.expression().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", - output.getRoot().toString()); + output.expression().getRoot().toString()); + assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString()); // Test execution model.assertEqualResult("Placeholder", "MatMul"); |