summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java6
-rw-r--r--model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_readbin86 -> 86 bytes
-rw-r--r--model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_readbin62733 -> 62733 bytes
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java9
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java)2
5 files changed, 10 insertions, 7 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 716965784e3..f236bbd4467 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -14,6 +14,8 @@ import static org.junit.Assert.assertEquals;
*/
public class MlModelsImportingTest {
+ private static final double delta = 0.00000000001;
+
@Test
public void testImportingModels() {
ModelTester tester = new ModelTester("src/test/resources/config/models/");
@@ -28,6 +30,7 @@ public class MlModelsImportingTest {
xgboost);
FunctionEvaluator evaluator = xgboost.evaluatorOf();
assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ assertEquals(-8.17695, evaluator.evaluate().sum().asDouble(), delta);
}
{
@@ -40,6 +43,7 @@ public class MlModelsImportingTest {
onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available
assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta);
}
{
@@ -49,6 +53,7 @@ public class MlModelsImportingTest {
tfMnistSoftmax);
FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta);
}
{
@@ -62,6 +67,7 @@ public class MlModelsImportingTest {
tfMnist);
FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument
assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ assertEquals(-0.714629131972222, evaluator.evaluate().sum().asDouble(), delta); // TODO: Verify in TF native
}
}
diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read
index 5cc9575b971..4fa0eadb0d3 100644
--- a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read
+++ b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read
Binary files differ
diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read
index 70a6fd42c91..e768328bff5 100644
--- a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read
+++ b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read
Binary files differ
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
index a7926cd2e02..bcfc6ce0a04 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
@@ -7,9 +7,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
-import org.tensorflow.SavedModelBundle;
-
-import java.io.IOException;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
@@ -21,7 +18,7 @@ import static org.junit.Assert.assertTrue;
public class OnnxMnistSoftmaxImportTestCase {
@Test
- public void testMnistSoftmaxImport() throws IOException {
+ public void testMnistSoftmaxImport() {
ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
// Check constants
@@ -43,14 +40,14 @@ public class OnnxMnistSoftmaxImportTestCase {
assertEquals(1, model.requiredMacros().size());
assertTrue(model.requiredMacros().containsKey("Placeholder"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.requiredMacros().get("Placeholder"));
+ model.requiredMacros().get("Placeholder"));
// Check outputs
RankingExpression output = model.defaultSignature().outputExpression("add");
assertNotNull(output);
assertEquals("add", output.getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
- output.getRoot().toString());
+ output.getRoot().toString());
}
@Test
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
index bd7644be23b..dd6c8095e3c 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
@@ -13,7 +13,7 @@ import static org.junit.Assert.assertTrue;
/**
* @author bratseth
*/
-public class MnistSoftmaxImportTestCase {
+public class TensorFlowMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() {