summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-02-05 19:10:56 +0100
committerLester Solbakken <lesters@oath.com>2018-02-05 19:10:56 +0100
commit62de95451cf663f3f43532d2c4746eaa1b678d95 (patch)
tree7750facb074ba0b919f3904a74256d32d7bb772f
parent69ba083a5edb91609bfb4ac14c88160fb064add9 (diff)
Fix unit tests for large/small tensor constants
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java3
2 files changed, 5 insertions, 4 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index 01dd15d5fa0..ad5abd4c03d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -20,15 +20,15 @@ public class MnistSoftmaxImportTestCase {
TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved");
// Check constants
- assertEquals(2, model.get().constants().size());
+ assertEquals(2, model.get().largeConstants().size());
- Tensor constant0 = model.get().constants().get("Variable");
+ Tensor constant0 = model.get().largeConstants().get("Variable");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.get().constants().get("Variable_1");
+ Tensor constant1 = model.get().largeConstants().get("Variable_1");
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
constant1.type());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index 2c621fd2e92..ae7714b271a 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -57,7 +57,8 @@ public class TestableTensorFlowModel {
private Context contextFrom(TensorFlowModel result) {
MapContext context = new MapContext();
- result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
return context;
}