From b1f46fcd0495dbce905fb8b7318781f4cf5965a7 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Mon, 5 Feb 2018 16:04:42 +0100 Subject: Refactor TensorFlow import and add dimension renaming. --- .../tensorflow/DimensionRenamerTest.java | 49 ++++++++++++++++++++++ .../tensorflow/DropoutImportTestCase.java | 7 +--- .../tensorflow/MnistSoftmaxImportTestCase.java | 12 +++--- .../tensorflow/TestableTensorFlowModel.java | 10 +++-- 4 files changed, 62 insertions(+), 16 deletions(-) create mode 100644 searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java (limited to 'searchlib/src/test/java/com/yahoo') diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java new file mode 100644 index 00000000000..ebcfde54c70 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java @@ -0,0 +1,49 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import org.junit.Test; + +import static junit.framework.TestCase.assertTrue; + +public class DimensionRenamerTest { + + @Test + public void testMnistRenaming() { + DimensionRenamer renamer = new DimensionRenamer(); + + renamer.addDimension("first_dimension_of_x"); + renamer.addDimension("second_dimension_of_x"); + renamer.addDimension("first_dimension_of_w"); + renamer.addDimension("second_dimension_of_w"); + renamer.addDimension("first_dimension_of_b"); + + // which dimension to join on matmul + renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null); + + // other dimensions in matmul can't be equal + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null); + + // for efficiency, put dimension to join on innermost + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null); + renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null); + + // bias + renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null); + + renamer.solve(); + + String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get(); + String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get(); + String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get(); + String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get(); + String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get(); + + assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0); + assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0); + assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0); + assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0); + assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0); + + + } +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index 3b25bfe1b1e..f64d697d9b9 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -18,11 +18,6 @@ public class DropoutImportTestCase { public void testDropoutImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved"); - // Check (provided) macros - assertEquals(1, model.get().macros().size()); - assertTrue(model.get().macros().containsKey("training_input")); - assertEquals("constant(\"training_input\")", model.get().macros().get("training_input").getRoot().toString()); - // Check required macros assertEquals(1, model.get().requiredMacros().size()); assertTrue(model.get().requiredMacros().containsKey("X")); @@ -37,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/BiasAdd", output.getName()); - assertEquals("join(rename(reduce(join(X, rename(constant(\"outputs_kernel\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"outputs_bias\"), d0, d1), f(a,b)(a + b))", + assertEquals("join(reduce(join(rename(X, (d0, d1), (d0, d2)), constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index ad5abd4c03d..60dd3865aa1 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -22,15 +22,15 @@ public class MnistSoftmaxImportTestCase { // Check constants assertEquals(2, model.get().largeConstants().size()); - Tensor constant0 = model.get().largeConstants().get("Variable"); + Tensor constant0 = model.get().largeConstants().get("Variable_read"); assertNotNull(constant0); - assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), + assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.get().largeConstants().get("Variable_1"); + Tensor constant1 = model.get().largeConstants().get("Variable_1_read"); assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d0", 10).build(), + assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); @@ -59,12 +59,10 @@ public class MnistSoftmaxImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("add", output.getName()); - assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))", + assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"Variable_1_read\"), f(a,b)(a + b))", output.getRoot().toString()); // Test execution - model.assertEqualResult("Placeholder", "Variable/read"); - model.assertEqualResult("Placeholder", "Variable_1/read"); model.assertEqualResult("Placeholder", "MatMul"); model.assertEqualResult("Placeholder", "add"); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index ae7714b271a..1691756a64d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.tensorflow.SavedModelBundle; @@ -47,8 +48,11 @@ public class TestableTensorFlowModel { private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { Session.Runner runner = model.session().runner(); - org.tensorflow.Tensor placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, - FloatBuffer.allocate(d0Size * d1Size)); + FloatBuffer fb = FloatBuffer.allocate(d0Size * d1Size); + for (int i = 0; i < d1Size; ++i) { + fb.put(i, (float)(i * 1.0 / d1Size)); + } + org.tensorflow.Tensor placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb); runner.feed(inputName, placeholder); List> results = runner.fetch(operationName).run(); assertEquals(1, results.size()); @@ -66,7 +70,7 @@ public class TestableTensorFlowModel { Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build()); for (int d0 = 0; d0 < d0Size; d0++) for (int d1 = 0; d1 < d1Size; d1++) - b.cell(0, d0, d1); + b.cell(d1 * 1.0 / d1Size, d0, d1); return b.build(); } -- cgit v1.2.3