summaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-21 10:25:49 -0700
committerJon Bratseth <bratseth@oath.com>2018-09-21 10:25:49 -0700
commit569c2f0d781b33c17aadf6929fef2388643e1d64 (patch)
treedc030bcca551415b44218701f762fac9e21e6a3a /searchlib/src/test/java/com
parent772b67da6040957bd975b2418f98d2f18ee69fc4 (diff)
Propagate input type information
Diffstat (limited to 'searchlib/src/test/java/com')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java7
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java15
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java7
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java25
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java13
5 files changed, 35 insertions, 32 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
index bf9684082f4..3a1c9ec9551 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
@@ -20,10 +20,11 @@ public class BatchNormImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- RankingExpression output = signature.outputExpression("y");
+ ImportedModel.ExpressionWithInputs output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName());
- model.assertEqualResult("X", output.getName());
+ assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.expression().getName());
+ model.assertEqualResult("X", output.expression().getName());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index a8f7542f3a4..4c35d843f5d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -19,22 +19,23 @@ public class DropoutImportTestCase {
TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved");
// Check required functions
- assertEquals(1, model.get().requiredFunctions().size());
- assertTrue(model.get().requiredFunctions().containsKey("X"));
+ assertEquals(1, model.get().inputs().size());
+ assertTrue(model.get().inputs().containsKey("X"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.get().requiredFunctions().get("X"));
+ model.get().inputs().get("X"));
ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- RankingExpression output = signature.outputExpression("y");
+ ImportedModel.ExpressionWithInputs output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("outputs/Maximum", output.getName());
+ assertEquals("outputs/Maximum", output.expression().getName());
assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
- output.getRoot().toString());
- model.assertEqualResult("X", output.getName());
+ output.expression().getRoot().toString());
+ model.assertEqualResult("X", output.expression().getName());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
index add66eece1a..b3e281ad25d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
@@ -20,11 +20,10 @@ public class MnistImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- RankingExpression output = signature.outputExpression("y");
+ ImportedModel.ExpressionWithInputs output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("dnn/outputs/add", output.getName());
- model.assertEqualResultSum("input", output.getName(), 0.00001);
+ assertEquals("dnn/outputs/add", output.expression().getName());
+ model.assertEqualResultSum("input", output.expression().getName(), 0.00001);
}
-
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
index e20ac16a691..b5655cfbfa5 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
@@ -27,27 +27,28 @@ public class OnnxMnistSoftmaxImportTestCase {
Tensor constant0 = model.largeConstants().get("test_Variable");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
- constant0.type());
+ constant0.type());
assertEquals(7840, constant0.size());
Tensor constant1 = model.largeConstants().get("test_Variable_1");
assertNotNull(constant1);
- assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
- constant1.type());
+ assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type());
assertEquals(10, constant1.size());
- // Check required functions (inputs)
- assertEquals(1, model.requiredFunctions().size());
- assertTrue(model.requiredFunctions().containsKey("Placeholder"));
- assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.requiredFunctions().get("Placeholder"));
+ // Check inputs
+ assertEquals(1, model.inputs().size());
+ assertTrue(model.inputs().containsKey("Placeholder"));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder"));
- // Check outputs
- RankingExpression output = model.defaultSignature().outputExpression("add");
+ // Check signature
+ ImportedModel.ExpressionWithInputs output = model.defaultSignature().outputExpression("add");
assertNotNull(output);
- assertEquals("add", output.getName());
+ assertEquals("add", output.expression().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
- output.getRoot().toString());
+ output.expression().getRoot().toString());
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"),
+ model.inputs().get(model.defaultSignature().inputs().get("Placeholder")));
+ assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.inputs().toString());
}
@Test
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
index ef28eb4678f..4a0362c0229 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
@@ -38,10 +38,10 @@ public class TensorFlowMnistSoftmaxImportTestCase {
assertEquals(0, model.get().functions().size());
// Check required functions
- assertEquals(1, model.get().requiredFunctions().size());
- assertTrue(model.get().requiredFunctions().containsKey("Placeholder"));
+ assertEquals(1, model.get().inputs().size());
+ assertTrue(model.get().inputs().containsKey("Placeholder"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.get().requiredFunctions().get("Placeholder"));
+ model.get().inputs().get("Placeholder"));
// Check signatures
assertEquals(1, model.get().signatures().size());
@@ -56,11 +56,12 @@ public class TensorFlowMnistSoftmaxImportTestCase {
// ... signature outputs
assertEquals(1, signature.outputs().size());
- RankingExpression output = signature.outputExpression("y");
+ ImportedModel.ExpressionWithInputs output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("add", output.getName());
+ assertEquals("add", output.expression().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
- output.getRoot().toString());
+ output.expression().getRoot().toString());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString());
// Test execution
model.assertEqualResult("Placeholder", "MatMul");