aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com/yahoo
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-01-22 13:25:32 +0100
committerLester Solbakken <lesters@oath.com>2018-01-22 13:35:58 +0100
commitce34b6dd37afdce666e3b0b058c524ef9ebb5ef6 (patch)
tree3076e6377c4a938cffdd24c2b460e11f5833f47c /searchlib/src/test/java/com/yahoo
parent4148debe89932119346b102a81164921af007d00 (diff)
Add batch normalization test case
Diffstat (limited to 'searchlib/src/test/java/com/yahoo')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java91
1 files changed, 6 insertions, 85 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
index e22e4a36bab..cf4e64c74a1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
@@ -71,99 +71,20 @@ public class TensorflowImportTestCase {
}
@Test
- public void test3LayerMnistImport() {
- String modelDir = "src/test/files/integration/tensorflow/3_layer_mnist/saved";
+ public void testBatchNormImport() {
+ String modelDir = "src/test/files/integration/tensorflow/batch_norm/saved";
SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
TensorFlowModel result = new TensorFlowImporter().importModel(model);
+ TensorFlowModel.Signature signature = result.signature("serving_default");
- // Check constants
- assertEquals(8, result.constants().size());
-
- Tensor outputBias = result.constants().get("dnn/outputs/bias");
- assertNotNull(outputBias);
- assertEquals(new TensorType.Builder().indexed("d0", 10).build(), outputBias.type());
- assertEquals(10, outputBias.size());
-
- Tensor outputWeights = result.constants().get("dnn/outputs/weights");
- assertNotNull(outputWeights);
- assertEquals(new TensorType.Builder().indexed("d0", 40).indexed("d1", 10).build(), outputWeights.type());
- assertEquals(400, outputWeights.size());
-
- Tensor hidden3Bias = result.constants().get("dnn/hidden3/bias");
- assertNotNull(hidden3Bias);
- assertEquals(new TensorType.Builder().indexed("d0", 40).build(), hidden3Bias.type());
- assertEquals(40, hidden3Bias.size());
-
- Tensor hidden3Weights = result.constants().get("dnn/hidden3/weights");
- assertNotNull(hidden3Weights);
- assertEquals(new TensorType.Builder().indexed("d0", 100).indexed("d1", 40).build(), hidden3Weights.type());
- assertEquals(4000, hidden3Weights.size());
-
- Tensor hidden2Bias = result.constants().get("dnn/hidden2/bias");
- assertNotNull(hidden2Bias);
- assertEquals(new TensorType.Builder().indexed("d0", 100).build(), hidden2Bias.type());
- assertEquals(100, hidden2Bias.size());
-
- Tensor hidden2Weights = result.constants().get("dnn/hidden2/weights");
- assertNotNull(hidden2Weights);
- assertEquals(new TensorType.Builder().indexed("d0", 300).indexed("d1", 100).build(), hidden2Weights.type());
- assertEquals(30000, hidden2Weights.size());
-
- Tensor hidden1Bias = result.constants().get("dnn/hidden1/bias");
- assertNotNull(hidden1Bias);
- assertEquals(new TensorType.Builder().indexed("d0", 300).build(), hidden1Bias.type());
- assertEquals(300, hidden1Bias.size());
-
- Tensor hidden1Weights = result.constants().get("dnn/hidden1/weights");
- assertNotNull(hidden1Weights);
- assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 300).build(), hidden1Weights.type());
- assertEquals(235200, hidden1Weights.size());
-
- // Check signatures
- assertEquals(1, result.signatures().size());
- TensorFlowModel.Signature signature = result.signatures().get("serving_default");
- assertNotNull(signature);
+ assertEquals("Has skipped outputs", 0, result.signature("serving_default").skippedOutputs().size());
- // ... signature inputs
- assertEquals(1, signature.inputs().size());
- TensorType argument0 = signature.inputArgument("x");
- assertNotNull(argument0);
- assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
-
- // ... signature outputs
- assertEquals(1, signature.outputs().size());
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("dnn/outputs/add", output.getName());
- assertEquals("" +
- "join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(X, rename(constant('dnn/hidden1/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden1/bias'), d0, d1), f(a,b)(a + b)), f(a)(if (a < 0, exp(a) - 1, a))), rename(constant('dnn/hidden2/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden2/bias'), d0, d1), f(a,b)(a + b)), f(a)(max(0,a))), rename(constant('dnn/hidden3/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden3/bias'), d0, d1), f(a,b)(a + b)), f(a)(1 / (1 + exp(-a)))), rename(constant('dnn/outputs/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/outputs/bias'), d0, d1), f(a,b)(a + b))",
- toNonPrimitiveString(output));
-
- // Test constants
- assertEqualResult(model, result, "X", "dnn/hidden1/weights/read");
- assertEqualResult(model, result, "X", "dnn/hidden1/bias/read");
- assertEqualResult(model, result, "X", "dnn/hidden2/weights/read");
- assertEqualResult(model, result, "X", "dnn/hidden2/bias/read");
- assertEqualResult(model, result, "X", "dnn/hidden3/weights/read");
- assertEqualResult(model, result, "X", "dnn/hidden3/bias/read");
- assertEqualResult(model, result, "X", "dnn/outputs/weights/read");
- assertEqualResult(model, result, "X", "dnn/outputs/bias/read");
-
- // Test execution
- assertEqualResult(model, result, "X", "dnn/hidden1/MatMul");
- assertEqualResult(model, result, "X", "dnn/hidden1/add");
- assertEqualResult(model, result, "X", "dnn/hidden1/Elu");
- assertEqualResult(model, result, "X", "dnn/hidden2/MatMul");
- assertEqualResult(model, result, "X", "dnn/hidden2/add");
- assertEqualResult(model, result, "X", "dnn/hidden2/Relu");
- assertEqualResult(model, result, "X", "dnn/hidden3/MatMul");
- assertEqualResult(model, result, "X", "dnn/hidden3/add");
- assertEqualResult(model, result, "X", "dnn/hidden3/Sigmoid");
- assertEqualResult(model, result, "X", "dnn/outputs/MatMul");
- assertEqualResult(model, result, "X", "dnn/outputs/add");
+ assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName());
+ assertEqualResult(model, result, "X", output.getName());
}
-
private void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) {
Tensor tfResult = tensorFlowExecute(model, inputName, operationName);
Context context = contextFrom(result);