From 79ddd9f94394e03e6893839de7310be0563f8577 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Mon, 10 Sep 2018 10:43:32 +0200 Subject: Use correct model --- .../models/evaluation/MlModelsImportingTest.java | 6 ++ .../mnist_softmax_saved_layer_Variable_1_read | Bin 86 -> 86 bytes .../mnist_softmax_saved_layer_Variable_read | Bin 62733 -> 62733 bytes .../integration/ml/MnistSoftmaxImportTestCase.java | 70 --------------------- .../ml/OnnxMnistSoftmaxImportTestCase.java | 9 +-- .../ml/TensorFlowMnistSoftmaxImportTestCase.java | 70 +++++++++++++++++++++ 6 files changed, 79 insertions(+), 76 deletions(-) delete mode 100644 searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java create mode 100644 searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index 716965784e3..f236bbd4467 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -14,6 +14,8 @@ import static org.junit.Assert.assertEquals; */ public class MlModelsImportingTest { + private static final double delta = 0.00000000001; + @Test public void testImportingModels() { ModelTester tester = new ModelTester("src/test/resources/config/models/"); @@ -28,6 +30,7 @@ public class MlModelsImportingTest { xgboost); FunctionEvaluator evaluator = xgboost.evaluatorOf(); assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-8.17695, evaluator.evaluate().sum().asDouble(), delta); } { @@ -40,6 +43,7 @@ public class MlModelsImportingTest { onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); } { @@ -49,6 +53,7 @@ public class MlModelsImportingTest { tfMnistSoftmax); FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); } { @@ -62,6 +67,7 @@ public class MlModelsImportingTest { tfMnist); FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-0.714629131972222, evaluator.evaluate().sum().asDouble(), delta); // TODO: Verify in TF native } } diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read index 5cc9575b971..4fa0eadb0d3 100644 Binary files a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read and b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read differ diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read index 70a6fd42c91..e768328bff5 100644 Binary files a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read and b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read differ diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java deleted file mode 100644 index bd7644be23b..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - -/** - * @author bratseth - */ -public class MnistSoftmaxImportTestCase { - - @Test - public void testMnistSoftmaxImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved"); - - // Check constants - assertEquals(2, model.get().largeConstants().size()); - - Tensor constant0 = model.get().largeConstants().get("test_Variable_read"); - assertNotNull(constant0); - assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), - constant0.type()); - assertEquals(7840, constant0.size()); - - Tensor constant1 = model.get().largeConstants().get("test_Variable_1_read"); - assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d1", 10).build(), - constant1.type()); - assertEquals(10, constant1.size()); - - // Check (provided) macros - assertEquals(0, model.get().macros().size()); - - // Check required macros - assertEquals(1, model.get().requiredMacros().size()); - assertTrue(model.get().requiredMacros().containsKey("Placeholder")); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().requiredMacros().get("Placeholder")); - - // Check signatures - assertEquals(1, model.get().signatures().size()); - ImportedModel.Signature signature = model.get().signatures().get("serving_default"); - assertNotNull(signature); - - // ... signature inputs - assertEquals(1, signature.inputs().size()); - TensorType argument0 = signature.inputArgument("x"); - assertNotNull(argument0); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); - - // ... signature outputs - assertEquals(1, signature.outputs().size()); - RankingExpression output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("add", output.getName()); - assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", - output.getRoot().toString()); - - // Test execution - model.assertEqualResult("Placeholder", "MatMul"); - model.assertEqualResult("Placeholder", "add"); - } - -} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index a7926cd2e02..bcfc6ce0a04 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -7,9 +7,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; -import org.tensorflow.SavedModelBundle; - -import java.io.IOException; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -21,7 +18,7 @@ import static org.junit.Assert.assertTrue; public class OnnxMnistSoftmaxImportTestCase { @Test - public void testMnistSoftmaxImport() throws IOException { + public void testMnistSoftmaxImport() { ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); // Check constants @@ -43,14 +40,14 @@ public class OnnxMnistSoftmaxImportTestCase { assertEquals(1, model.requiredMacros().size()); assertTrue(model.requiredMacros().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.requiredMacros().get("Placeholder")); + model.requiredMacros().get("Placeholder")); // Check outputs RankingExpression output = model.defaultSignature().outputExpression("add"); assertNotNull(output); assertEquals("add", output.getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", - output.getRoot().toString()); + output.getRoot().toString()); } @Test diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java new file mode 100644 index 00000000000..dd6c8095e3c --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -0,0 +1,70 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.ml; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * @author bratseth + */ +public class TensorFlowMnistSoftmaxImportTestCase { + + @Test + public void testMnistSoftmaxImport() { + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved"); + + // Check constants + assertEquals(2, model.get().largeConstants().size()); + + Tensor constant0 = model.get().largeConstants().get("test_Variable_read"); + assertNotNull(constant0); + assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), + constant0.type()); + assertEquals(7840, constant0.size()); + + Tensor constant1 = model.get().largeConstants().get("test_Variable_1_read"); + assertNotNull(constant1); + assertEquals(new TensorType.Builder().indexed("d1", 10).build(), + constant1.type()); + assertEquals(10, constant1.size()); + + // Check (provided) macros + assertEquals(0, model.get().macros().size()); + + // Check required macros + assertEquals(1, model.get().requiredMacros().size()); + assertTrue(model.get().requiredMacros().containsKey("Placeholder")); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().requiredMacros().get("Placeholder")); + + // Check signatures + assertEquals(1, model.get().signatures().size()); + ImportedModel.Signature signature = model.get().signatures().get("serving_default"); + assertNotNull(signature); + + // ... signature inputs + assertEquals(1, signature.inputs().size()); + TensorType argument0 = signature.inputArgument("x"); + assertNotNull(argument0); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); + + // ... signature outputs + assertEquals(1, signature.outputs().size()); + RankingExpression output = signature.outputExpression("y"); + assertNotNull(output); + assertEquals("add", output.getName()); + assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", + output.getRoot().toString()); + + // Test execution + model.assertEqualResult("Placeholder", "MatMul"); + model.assertEqualResult("Placeholder", "add"); + } + +} -- cgit v1.2.3