summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-01-25 12:51:58 +0100
committerJon Bratseth <bratseth@oath.com>2018-01-25 12:51:58 +0100
commit01f2897bce20939c5716fc19876c2541a3d9bbc5 (patch)
treee9ee4d98a671607ffe4caa8168b78305a94240c5 /searchlib
parent819533f55f0f137d30c6828b7851a7e0d3010ed7 (diff)
Minor improvements
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java23
1 files changed, 11 insertions, 12 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
index 13d042ee5dd..c01b92fb1c7 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
@@ -5,7 +5,6 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -61,7 +60,7 @@ public class TensorflowImportTestCase {
assertNotNull(output);
assertEquals("add", output.getName());
assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))",
- toNonPrimitiveString(output));
+ output.getRoot().toString());
// Test execution
assertEqualResult(model, result, "Placeholder", "Variable/read");
@@ -95,9 +94,14 @@ public class TensorflowImportTestCase {
assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
}
+ // Sizes of the "Placeholder" vector
+ private final int d0Size = 1;
+ private final int d1Size = 784;
+
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
Session.Runner runner = model.session().runner();
- org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784));
+ org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size },
+ FloatBuffer.allocate(d0Size * d1Size));
runner.feed(inputName, placeholder);
List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
assertEquals(1, results.size());
@@ -110,16 +114,11 @@ public class TensorflowImportTestCase {
return context;
}
- private String toNonPrimitiveString(RankingExpression expression) {
- // toString on the wrapping expression will map to primitives, which is harder to read
- return ((TensorFunctionNode)expression.getRoot()).function().toString();
- }
-
private Tensor placeholderArgument() {
- int size = 784;
- Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build());
- for (int i = 0; i < size; i++)
- b.cell(0, 0, i);
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build());
+ for (int d0 = 0; d0 < d0Size; d0++)
+ for (int d1 = 0; d1 < d1Size; d1++)
+ b.cell(0, d0, d1);
return b.build();
}