diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-01-25 12:51:58 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-01-25 12:51:58 +0100 |
commit | 01f2897bce20939c5716fc19876c2541a3d9bbc5 (patch) | |
tree | e9ee4d98a671607ffe4caa8168b78305a94240c5 /searchlib | |
parent | 819533f55f0f137d30c6828b7851a7e0d3010ed7 (diff) |
Minor improvements
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java | 23 |
1 files changed, 11 insertions, 12 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java index 13d042ee5dd..c01b92fb1c7 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java @@ -5,7 +5,6 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -61,7 +60,7 @@ public class TensorflowImportTestCase { assertNotNull(output); assertEquals("add", output.getName()); assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))", - toNonPrimitiveString(output)); + output.getRoot().toString()); // Test execution assertEqualResult(model, result, "Placeholder", "Variable/read"); @@ -95,9 +94,14 @@ public class TensorflowImportTestCase { assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); } + // Sizes of the "Placeholder" vector + private final int d0Size = 1; + private final int d1Size = 784; + private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { Session.Runner runner = model.session().runner(); - org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784)); + org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, + FloatBuffer.allocate(d0Size * d1Size)); runner.feed(inputName, placeholder); List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); assertEquals(1, results.size()); @@ -110,16 +114,11 @@ public class TensorflowImportTestCase { return context; } - private String toNonPrimitiveString(RankingExpression expression) { - // toString on the wrapping expression will map to primitives, which is harder to read - return ((TensorFunctionNode)expression.getRoot()).function().toString(); - } - private Tensor placeholderArgument() { - int size = 784; - Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build()); - for (int i = 0; i < size; i++) - b.cell(0, 0, i); + Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build()); + for (int d0 = 0; d0 < d0Size; d0++) + for (int d1 = 0; d1 < d1Size; d1++) + b.cell(0, d0, d1); return b.build(); } |