diff options
Diffstat (limited to 'model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java')
-rw-r--r-- | model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java | 57 |
1 files changed, 31 insertions, 26 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index cf7d208ed25..db892dce593 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; /** * Tests instantiating models from rank-profiles configs. @@ -32,11 +33,10 @@ public class MlModelsImportingTest { // Function assertEquals(1, xgboost.functions().size()); - tester.assertFunction("xgboost_2_2", - "(optimized sum of condition trees of size 192 bytes)", - xgboost); - ExpressionFunction function = xgboost.functions().get(0); - assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); + ExpressionFunction function = tester.assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); + assertEquals("tensor()", function.returnType().get().toString()); assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments())); function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg))); @@ -52,14 +52,14 @@ public class MlModelsImportingTest { // Function assertEquals(1, onnxMnistSoftmax.functions().size()); - tester.assertFunction("default.add", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", - onnxMnistSoftmax); - ExpressionFunction function = onnxMnistSoftmax.functions().get(0); - assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); + ExpressionFunction function = + tester.assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); + assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("Placeholder", function.arguments().get(0)); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder")); + assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString()); // Evaluator assertEquals("tensor(d1[10],d2[784])", @@ -74,14 +74,14 @@ public class MlModelsImportingTest { // Function assertEquals(1, tfMnistSoftmax.functions().size()); - tester.assertFunction("serving_default.y", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", - tfMnistSoftmax); - ExpressionFunction function = tfMnistSoftmax.functions().get(0); - assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); + ExpressionFunction function = + tester.assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); + assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("Placeholder", function.arguments().get(0)); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder")); + assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString()); // Evaluator FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available @@ -92,20 +92,25 @@ public class MlModelsImportingTest { { Model tfMnist = tester.models().get("mnist_saved"); // Generated function - tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add", - "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", - tfMnist); + ExpressionFunction generatedFunction = + tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add", + "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", + tfMnist); + assertEquals("tensor(d3[300])", generatedFunction.returnType().get().toString()); + assertEquals(1, generatedFunction.arguments().size()); + assertEquals("input", generatedFunction.arguments().get(0)); + assertNull(null, generatedFunction.argumentTypes().get("input")); // TODO: Not available until we resolve all argument types // Function assertEquals(1, tfMnist.functions().size()); - tester.assertFunction("serving_default.y", - "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", - tfMnist); - ExpressionFunction function = tfMnist.functions().get(0); - assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); + ExpressionFunction function = + tester.assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("input", function.arguments().get(0)); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("input")); + assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("input").toString()); // Evaluator FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); |