From 62de95451cf663f3f43532d2c4746eaa1b678d95 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Mon, 5 Feb 2018 19:10:56 +0100 Subject: Fix unit tests for large/small tensor constants --- .../integration/tensorflow/MnistSoftmaxImportTestCase.java | 6 +++--- .../integration/tensorflow/TestableTensorFlowModel.java | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index 01dd15d5fa0..ad5abd4c03d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -20,15 +20,15 @@ public class MnistSoftmaxImportTestCase { TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved"); // Check constants - assertEquals(2, model.get().constants().size()); + assertEquals(2, model.get().largeConstants().size()); - Tensor constant0 = model.get().constants().get("Variable"); + Tensor constant0 = model.get().largeConstants().get("Variable"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.get().constants().get("Variable_1"); + Tensor constant1 = model.get().largeConstants().get("Variable_1"); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d0", 10).build(), constant1.type()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index 2c621fd2e92..ae7714b271a 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -57,7 +57,8 @@ public class TestableTensorFlowModel { private Context contextFrom(TensorFlowModel result) { MapContext context = new MapContext(); - result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); return context; } -- cgit v1.2.3