diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-01-25 13:01:18 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-01-25 13:01:18 +0100 |
commit | a0f6d44333202731d07139ba6f0256dd4443da78 (patch) | |
tree | 7e8aace9cae769ba9b1a4e0990c02d23117380ca | |
parent | 01f2897bce20939c5716fc19876c2541a3d9bbc5 (diff) |
Refactor: Extract test helper logic
4 files changed, 166 insertions, 125 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java new file mode 100644 index 00000000000..770ba168f19 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java @@ -0,0 +1,32 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import org.junit.Test; +import org.tensorflow.SavedModelBundle; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author lesters + */ +public class BatchNormImportTestCase { + + @Test + public void testBatchNormImport() { + TensorFlowImportTester tester = new TensorFlowImportTester(); + String modelDir = "src/test/files/integration/tensorflow/batch_norm/saved"; + SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); + TensorFlowModel result = new TensorFlowImporter().importModel(model); + TensorFlowModel.Signature signature = result.signature("serving_default"); + + assertEquals("Has skipped outputs", 0, result.signature("serving_default").skippedOutputs().size()); + + RankingExpression output = signature.outputExpression("y"); + assertNotNull(output); + assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); + tester.assertEqualResult(model, result, "X", output.getName()); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java new file mode 100644 index 00000000000..11063690e2a --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -0,0 +1,73 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; + +import java.nio.FloatBuffer; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author bratseth + */ +public class MnistSoftmaxImportTestCase { + + @Test + public void testMnistSoftmaxImport() { + TensorFlowImportTester tester = new TensorFlowImportTester(); + String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved"; + SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); + TensorFlowModel result = new TensorFlowImporter().importModel(model); + + // Check constants + assertEquals(2, result.constants().size()); + + Tensor constant0 = result.constants().get("Variable"); + assertNotNull(constant0); + assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), + constant0.type()); + assertEquals(7840, constant0.size()); + + Tensor constant1 = result.constants().get("Variable_1"); + assertNotNull(constant1); + assertEquals(new TensorType.Builder().indexed("d0", 10).build(), + constant1.type()); + assertEquals(10, constant1.size()); + + // Check signatures + assertEquals(1, result.signatures().size()); + TensorFlowModel.Signature signature = result.signatures().get("serving_default"); + assertNotNull(signature); + + // ... signature inputs + assertEquals(1, signature.inputs().size()); + TensorType argument0 = signature.inputArgument("x"); + assertNotNull(argument0); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); + + // ... signature outputs + assertEquals(1, signature.outputs().size()); + RankingExpression output = signature.outputExpression("y"); + assertNotNull(output); + assertEquals("add", output.getName()); + assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))", + output.getRoot().toString()); + + // Test execution + tester.assertEqualResult(model, result, "Placeholder", "Variable/read"); + tester.assertEqualResult(model, result, "Placeholder", "Variable_1/read"); + tester.assertEqualResult(model, result, "Placeholder", "MatMul"); + tester.assertEqualResult(model, result, "Placeholder", "add"); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java new file mode 100644 index 00000000000..5e5b474e445 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java @@ -0,0 +1,61 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; + +import java.nio.FloatBuffer; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * Helper for TensorFlow import tests. + * This currently assumes the TensorFlow model takes a single input named Placeholder, of type tensor(d0[1],d1[784]) + * + * @author bratseth + */ +public class TensorFlowImportTester { + + // Sizes of the "Placeholder" vector + private final int d0Size = 1; + private final int d1Size = 784; + + public void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) { + Tensor tfResult = tensorFlowExecute(model, inputName, operationName); + Context context = contextFrom(result); + Tensor placeholder = placeholderArgument(); + context.put(inputName, new TensorValue(placeholder)); + Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor(); + assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); + } + + private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { + Session.Runner runner = model.session().runner(); + org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, + FloatBuffer.allocate(d0Size * d1Size)); + runner.feed(inputName, placeholder); + List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); + assertEquals(1, results.size()); + return new TensorConverter().toVespaTensor(results.get(0)); + } + + private Context contextFrom(TensorFlowModel result) { + MapContext context = new MapContext(); + result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); + return context; + } + + private Tensor placeholderArgument() { + Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build()); + for (int d0 = 0; d0 < d0Size; d0++) + for (int d1 = 0; d1 < d1Size; d1++) + b.cell(0, d0, d1); + return b.build(); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java deleted file mode 100644 index c01b92fb1c7..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.junit.Test; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; - -import java.nio.FloatBuffer; -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - -/** - * @author bratseth - */ -public class TensorflowImportTestCase { - - @Test - public void testMnistSoftmaxImport() { - String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved"; - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - TensorFlowModel result = new TensorFlowImporter().importModel(model); - - // Check constants - assertEquals(2, result.constants().size()); - - Tensor constant0 = result.constants().get("Variable"); - assertNotNull(constant0); - assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), - constant0.type()); - assertEquals(7840, constant0.size()); - - Tensor constant1 = result.constants().get("Variable_1"); - assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d0", 10).build(), - constant1.type()); - assertEquals(10, constant1.size()); - - // Check signatures - assertEquals(1, result.signatures().size()); - TensorFlowModel.Signature signature = result.signatures().get("serving_default"); - assertNotNull(signature); - - // ... signature inputs - assertEquals(1, signature.inputs().size()); - TensorType argument0 = signature.inputArgument("x"); - assertNotNull(argument0); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); - - // ... signature outputs - assertEquals(1, signature.outputs().size()); - RankingExpression output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("add", output.getName()); - assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))", - output.getRoot().toString()); - - // Test execution - assertEqualResult(model, result, "Placeholder", "Variable/read"); - assertEqualResult(model, result, "Placeholder", "Variable_1/read"); - assertEqualResult(model, result, "Placeholder", "MatMul"); - assertEqualResult(model, result, "Placeholder", "add"); - } - - @Test - public void testBatchNormImport() { - String modelDir = "src/test/files/integration/tensorflow/batch_norm/saved"; - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - TensorFlowModel result = new TensorFlowImporter().importModel(model); - TensorFlowModel.Signature signature = result.signature("serving_default"); - - assertEquals("Has skipped outputs", 0, result.signature("serving_default").skippedOutputs().size()); - - RankingExpression output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); - assertEqualResult(model, result, "X", output.getName()); - - } - - private void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) { - Tensor tfResult = tensorFlowExecute(model, inputName, operationName); - Context context = contextFrom(result); - Tensor placeholder = placeholderArgument(); - context.put(inputName, new TensorValue(placeholder)); - Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor(); - assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); - } - - // Sizes of the "Placeholder" vector - private final int d0Size = 1; - private final int d1Size = 784; - - private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { - Session.Runner runner = model.session().runner(); - org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, - FloatBuffer.allocate(d0Size * d1Size)); - runner.feed(inputName, placeholder); - List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); - assertEquals(1, results.size()); - return new TensorConverter().toVespaTensor(results.get(0)); - } - - private Context contextFrom(TensorFlowModel result) { - MapContext context = new MapContext(); - result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); - return context; - } - - private Tensor placeholderArgument() { - Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build()); - for (int d0 = 0; d0 < d0Size; d0++) - for (int d1 = 0; d1 < d1Size; d1++) - b.cell(0, d0, d1); - return b.build(); - } - -} |