diff options
-rw-r--r-- | model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java | 6 | ||||
-rw-r--r-- | model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read | bin | 86 -> 86 bytes | |||
-rw-r--r-- | model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read | bin | 62733 -> 62733 bytes | |||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java | 9 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java) | 2 |
5 files changed, 10 insertions, 7 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index 716965784e3..f236bbd4467 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -14,6 +14,8 @@ import static org.junit.Assert.assertEquals; */ public class MlModelsImportingTest { + private static final double delta = 0.00000000001; + @Test public void testImportingModels() { ModelTester tester = new ModelTester("src/test/resources/config/models/"); @@ -28,6 +30,7 @@ public class MlModelsImportingTest { xgboost); FunctionEvaluator evaluator = xgboost.evaluatorOf(); assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-8.17695, evaluator.evaluate().sum().asDouble(), delta); } { @@ -40,6 +43,7 @@ public class MlModelsImportingTest { onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); } { @@ -49,6 +53,7 @@ public class MlModelsImportingTest { tfMnistSoftmax); FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); } { @@ -62,6 +67,7 @@ public class MlModelsImportingTest { tfMnist); FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-0.714629131972222, evaluator.evaluate().sum().asDouble(), delta); // TODO: Verify in TF native } } diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read Binary files differindex 5cc9575b971..4fa0eadb0d3 100644 --- a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read Binary files differindex 70a6fd42c91..e768328bff5 100644 --- a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index a7926cd2e02..bcfc6ce0a04 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -7,9 +7,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; -import org.tensorflow.SavedModelBundle; - -import java.io.IOException; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -21,7 +18,7 @@ import static org.junit.Assert.assertTrue; public class OnnxMnistSoftmaxImportTestCase { @Test - public void testMnistSoftmaxImport() throws IOException { + public void testMnistSoftmaxImport() { ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); // Check constants @@ -43,14 +40,14 @@ public class OnnxMnistSoftmaxImportTestCase { assertEquals(1, model.requiredMacros().size()); assertTrue(model.requiredMacros().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.requiredMacros().get("Placeholder")); + model.requiredMacros().get("Placeholder")); // Check outputs RankingExpression output = model.defaultSignature().outputExpression("add"); assertNotNull(output); assertEquals("add", output.getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", - output.getRoot().toString()); + output.getRoot().toString()); } @Test diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index bd7644be23b..dd6c8095e3c 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -13,7 +13,7 @@ import static org.junit.Assert.assertTrue; /** * @author bratseth */ -public class MnistSoftmaxImportTestCase { +public class TensorFlowMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { |