diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-01-31 11:13:51 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-01-31 11:13:51 +0100 |
commit | a44edeba9f38c38c431d7b9b6e1ac454e2a0e610 (patch) | |
tree | 21600936cfe396492965764911652b49b4c22731 /searchlib/src/test/java/com | |
parent | 9c4ba9bf5b96b8c62a9b8c5a6c20a9175c698b70 (diff) |
Verify macros
Diffstat (limited to 'searchlib/src/test/java/com')
2 files changed, 24 insertions, 0 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index b59b4750911..445ccf231a7 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -2,10 +2,12 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.TensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * @author lesters @@ -15,6 +17,18 @@ public class DropoutImportTestCase { @Test public void testDropoutImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved"); + + // Check (provided) macros + assertEquals(1, model.get().macros().size()); + assertTrue(model.get().macros().containsKey("training/input")); + assertEquals("constant(\"training/input\")", model.get().macros().get("training/input").getRoot().toString()); + + // Check required macros + assertEquals(1, model.get().requiredMacros().size()); + assertTrue(model.get().requiredMacros().containsKey("X")); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().requiredMacros().get("X")); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index f12b9a2c628..01dd15d5fa0 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -8,6 +8,7 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * @author bratseth @@ -33,6 +34,15 @@ public class MnistSoftmaxImportTestCase { constant1.type()); assertEquals(10, constant1.size()); + // Check (provided) macros + assertEquals(0, model.get().macros().size()); + + // Check required macros + assertEquals(1, model.get().requiredMacros().size()); + assertTrue(model.get().requiredMacros().containsKey("Placeholder")); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().requiredMacros().get("Placeholder")); + // Check signatures assertEquals(1, model.get().signatures().size()); TensorFlowModel.Signature signature = model.get().signatures().get("serving_default"); |