summaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java')
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java7
1 files changed, 3 insertions, 4 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 40ef2c65aaa..287a2387b34 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -33,7 +33,6 @@ public class MlModelsImportingTest {
"(optimized sum of condition trees of size 192 bytes)",
xgboost);
-
// Function
assertEquals(1, xgboost.functions().size());
ExpressionFunction function = xgboost.functions().get(0);
@@ -58,7 +57,7 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, onnxMnistSoftmax.functions().size());
ExpressionFunction function = onnxMnistSoftmax.functions().get(0);
- // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO
+ assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
@@ -78,7 +77,7 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, tfMnistSoftmax.functions().size());
ExpressionFunction function = tfMnistSoftmax.functions().get(0);
- // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO
+ assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
assertEquals(1, function.arguments().size());
assertEquals("x", function.arguments().get(0));
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x"));
@@ -103,7 +102,7 @@ public class MlModelsImportingTest {
// Function
assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function
ExpressionFunction function = tfMnist.functions().get(1);
- // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO
+ assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
assertEquals(1, function.arguments().size());
assertEquals("x", function.arguments().get(0));
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x"));