diff options
Diffstat (limited to 'model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java')
-rw-r--r-- | model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index 40ef2c65aaa..287a2387b34 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -33,7 +33,6 @@ public class MlModelsImportingTest { "(optimized sum of condition trees of size 192 bytes)", xgboost); - // Function assertEquals(1, xgboost.functions().size()); ExpressionFunction function = xgboost.functions().get(0); @@ -58,7 +57,7 @@ public class MlModelsImportingTest { // Function assertEquals(1, onnxMnistSoftmax.functions().size()); ExpressionFunction function = onnxMnistSoftmax.functions().get(0); - // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO + assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); assertEquals(1, function.arguments().size()); assertEquals("Placeholder", function.arguments().get(0)); assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder")); @@ -78,7 +77,7 @@ public class MlModelsImportingTest { // Function assertEquals(1, tfMnistSoftmax.functions().size()); ExpressionFunction function = tfMnistSoftmax.functions().get(0); - // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO + assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); assertEquals(1, function.arguments().size()); assertEquals("x", function.arguments().get(0)); assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x")); @@ -103,7 +102,7 @@ public class MlModelsImportingTest { // Function assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function ExpressionFunction function = tfMnist.functions().get(1); - // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO + assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); assertEquals(1, function.arguments().size()); assertEquals("x", function.arguments().get(0)); assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x")); |