aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com/yahoo
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-02-05 16:04:42 +0100
committerLester Solbakken <lesters@oath.com>2018-02-22 12:54:34 +0100
commitb1f46fcd0495dbce905fb8b7318781f4cf5965a7 (patch)
treed0a0506fe66e5af4af2a927101a0eb9ed9420d38 /searchlib/src/test/java/com/yahoo
parente307df56eaaf5b0ebca5aefb7f7e0c5c3a970bdb (diff)
Refactor TensorFlow import and add dimension renaming.
Diffstat (limited to 'searchlib/src/test/java/com/yahoo')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java49
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java7
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java12
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java10
4 files changed, 62 insertions, 16 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
new file mode 100644
index 00000000000..ebcfde54c70
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
@@ -0,0 +1,49 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import org.junit.Test;
+
+import static junit.framework.TestCase.assertTrue;
+
+public class DimensionRenamerTest {
+
+ @Test
+ public void testMnistRenaming() {
+ DimensionRenamer renamer = new DimensionRenamer();
+
+ renamer.addDimension("first_dimension_of_x");
+ renamer.addDimension("second_dimension_of_x");
+ renamer.addDimension("first_dimension_of_w");
+ renamer.addDimension("second_dimension_of_w");
+ renamer.addDimension("first_dimension_of_b");
+
+ // which dimension to join on matmul
+ renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null);
+
+ // other dimensions in matmul can't be equal
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null);
+
+ // for efficiency, put dimension to join on innermost
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null);
+ renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null);
+
+ // bias
+ renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null);
+
+ renamer.solve();
+
+ String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get();
+ String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get();
+ String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get();
+ String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get();
+ String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get();
+
+ assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0);
+ assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0);
+ assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0);
+ assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0);
+ assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0);
+
+
+ }
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index 3b25bfe1b1e..f64d697d9b9 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -18,11 +18,6 @@ public class DropoutImportTestCase {
public void testDropoutImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved");
- // Check (provided) macros
- assertEquals(1, model.get().macros().size());
- assertTrue(model.get().macros().containsKey("training_input"));
- assertEquals("constant(\"training_input\")", model.get().macros().get("training_input").getRoot().toString());
-
// Check required macros
assertEquals(1, model.get().requiredMacros().size());
assertTrue(model.get().requiredMacros().containsKey("X"));
@@ -37,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/BiasAdd", output.getName());
- assertEquals("join(rename(reduce(join(X, rename(constant(\"outputs_kernel\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"outputs_bias\"), d0, d1), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(rename(X, (d0, d1), (d0, d2)), constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index ad5abd4c03d..60dd3865aa1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -22,15 +22,15 @@ public class MnistSoftmaxImportTestCase {
// Check constants
assertEquals(2, model.get().largeConstants().size());
- Tensor constant0 = model.get().largeConstants().get("Variable");
+ Tensor constant0 = model.get().largeConstants().get("Variable_read");
assertNotNull(constant0);
- assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
+ assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.get().largeConstants().get("Variable_1");
+ Tensor constant1 = model.get().largeConstants().get("Variable_1_read");
assertNotNull(constant1);
- assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
+ assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
constant1.type());
assertEquals(10, constant1.size());
@@ -59,12 +59,10 @@ public class MnistSoftmaxImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("add", output.getName());
- assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"Variable_1_read\"), f(a,b)(a + b))",
output.getRoot().toString());
// Test execution
- model.assertEqualResult("Placeholder", "Variable/read");
- model.assertEqualResult("Placeholder", "Variable_1/read");
model.assertEqualResult("Placeholder", "MatMul");
model.assertEqualResult("Placeholder", "add");
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index ae7714b271a..1691756a64d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.tensorflow.SavedModelBundle;
@@ -47,8 +48,11 @@ public class TestableTensorFlowModel {
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
Session.Runner runner = model.session().runner();
- org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size },
- FloatBuffer.allocate(d0Size * d1Size));
+ FloatBuffer fb = FloatBuffer.allocate(d0Size * d1Size);
+ for (int i = 0; i < d1Size; ++i) {
+ fb.put(i, (float)(i * 1.0 / d1Size));
+ }
+ org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb);
runner.feed(inputName, placeholder);
List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
assertEquals(1, results.size());
@@ -66,7 +70,7 @@ public class TestableTensorFlowModel {
Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build());
for (int d0 = 0; d0 < d0Size; d0++)
for (int d1 = 0; d1 < d1Size; d1++)
- b.cell(0, d0, d1);
+ b.cell(d1 * 1.0 / d1Size, d0, d1);
return b.build();
}