From 31b00d9cdbba6081c18dce9e2dae76c33e580557 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Thu, 25 Jan 2018 13:25:26 +0100 Subject: Refactor: Rename --- .../tensorflow/BatchNormImportTestCase.java | 9 ++- .../tensorflow/MnistSoftmaxImportTestCase.java | 20 +++--- .../tensorflow/TensorFlowImportTester.java | 71 ---------------------- .../tensorflow/TestableTensorFlowModel.java | 71 ++++++++++++++++++++++ 4 files changed, 85 insertions(+), 86 deletions(-) delete mode 100644 searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java create mode 100644 searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java (limited to 'searchlib/src/test/java/com') diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java index 3d028b0775e..c6ee586a78c 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java @@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; -import org.tensorflow.SavedModelBundle; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -15,16 +14,16 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { - TensorFlowImportTester tester = new TensorFlowImportTester("src/test/files/integration/tensorflow/batch_norm/saved"); - TensorFlowModel.Signature signature = tester.result().signature("serving_default"); + TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/batch_norm/saved"); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", - 0, tester.result().signature("serving_default").skippedOutputs().size()); + 0, model.get().signature("serving_default").skippedOutputs().size()); RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); - tester.assertEqualResult("X", output.getName()); + model.assertEqualResult("X", output.getName()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index 044a6917e00..f12b9a2c628 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -16,26 +16,26 @@ public class MnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { - TensorFlowImportTester tester = new TensorFlowImportTester("src/test/files/integration/tensorflow/mnist_softmax/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved"); // Check constants - assertEquals(2, tester.result().constants().size()); + assertEquals(2, model.get().constants().size()); - Tensor constant0 = tester.result().constants().get("Variable"); + Tensor constant0 = model.get().constants().get("Variable"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = tester.result().constants().get("Variable_1"); + Tensor constant1 = model.get().constants().get("Variable_1"); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d0", 10).build(), constant1.type()); assertEquals(10, constant1.size()); // Check signatures - assertEquals(1, tester.result().signatures().size()); - TensorFlowModel.Signature signature = tester.result().signatures().get("serving_default"); + assertEquals(1, model.get().signatures().size()); + TensorFlowModel.Signature signature = model.get().signatures().get("serving_default"); assertNotNull(signature); // ... signature inputs @@ -53,10 +53,10 @@ public class MnistSoftmaxImportTestCase { output.getRoot().toString()); // Test execution - tester.assertEqualResult("Placeholder", "Variable/read"); - tester.assertEqualResult("Placeholder", "Variable_1/read"); - tester.assertEqualResult("Placeholder", "MatMul"); - tester.assertEqualResult("Placeholder", "add"); + model.assertEqualResult("Placeholder", "Variable/read"); + model.assertEqualResult("Placeholder", "Variable_1/read"); + model.assertEqualResult("Placeholder", "MatMul"); + model.assertEqualResult("Placeholder", "add"); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java deleted file mode 100644 index c6623296d04..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java +++ /dev/null @@ -1,71 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; - -import java.nio.FloatBuffer; -import java.util.List; - -import static org.junit.Assert.assertEquals; - -/** - * Helper for TensorFlow import tests: Imports a model and provides asserts on it. - * This currently assumes the TensorFlow model takes a single input of type tensor(d0[1],d1[784]) - * - * @author bratseth - */ -public class TensorFlowImportTester { - - private SavedModelBundle model; - private TensorFlowModel result; - - // Sizes of the input vector - private final int d0Size = 1; - private final int d1Size = 784; - - public TensorFlowImportTester(String modelDir) { - model = SavedModelBundle.load(modelDir, "serve"); - result = new TensorFlowImporter().importModel(model); - } - - public TensorFlowModel result() { return result; } - - public void assertEqualResult(String inputName, String operationName) { - Tensor tfResult = tensorFlowExecute(model, inputName, operationName); - Context context = contextFrom(result); - Tensor placeholder = placeholderArgument(); - context.put(inputName, new TensorValue(placeholder)); - Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor(); - assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); - } - - private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { - Session.Runner runner = model.session().runner(); - org.tensorflow.Tensor placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, - FloatBuffer.allocate(d0Size * d1Size)); - runner.feed(inputName, placeholder); - List> results = runner.fetch(operationName).run(); - assertEquals(1, results.size()); - return new TensorConverter().toVespaTensor(results.get(0)); - } - - private Context contextFrom(TensorFlowModel result) { - MapContext context = new MapContext(); - result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); - return context; - } - - private Tensor placeholderArgument() { - Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build()); - for (int d0 = 0; d0 < d0Size; d0++) - for (int d1 = 0; d1 < d1Size; d1++) - b.cell(0, d0, d1); - return b.build(); - } - -} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java new file mode 100644 index 00000000000..186717d24cd --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -0,0 +1,71 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; + +import java.nio.FloatBuffer; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * Helper for TensorFlow import tests: Imports a model and provides asserts on it. + * This currently assumes the TensorFlow model takes a single input of type tensor(d0[1],d1[784]) + * + * @author bratseth + */ +public class TestableTensorFlowModel { + + private SavedModelBundle tensorFlowModel; + private TensorFlowModel model; + + // Sizes of the input vector + private final int d0Size = 1; + private final int d1Size = 784; + + public TestableTensorFlowModel(String modelDir) { + tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); + model = new TensorFlowImporter().importModel(tensorFlowModel); + } + + public TensorFlowModel get() { return model; } + + public void assertEqualResult(String inputName, String operationName) { + Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); + Context context = contextFrom(model); + Tensor placeholder = placeholderArgument(); + context.put(inputName, new TensorValue(placeholder)); + Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); + assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); + } + + private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { + Session.Runner runner = model.session().runner(); + org.tensorflow.Tensor placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, + FloatBuffer.allocate(d0Size * d1Size)); + runner.feed(inputName, placeholder); + List> results = runner.fetch(operationName).run(); + assertEquals(1, results.size()); + return new TensorConverter().toVespaTensor(results.get(0)); + } + + private Context contextFrom(TensorFlowModel result) { + MapContext context = new MapContext(); + result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); + return context; + } + + private Tensor placeholderArgument() { + Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build()); + for (int d0 = 0; d0 < d0Size; d0++) + for (int d1 = 0; d1 < d1Size; d1++) + b.cell(0, d0, d1); + return b.build(); + } + +} -- cgit v1.2.3