summaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java')
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java57
1 files changed, 31 insertions, 26 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index cf7d208ed25..db892dce593 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -10,6 +10,7 @@ import java.util.List;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
/**
* Tests instantiating models from rank-profiles configs.
@@ -32,11 +33,10 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, xgboost.functions().size());
- tester.assertFunction("xgboost_2_2",
- "(optimized sum of condition trees of size 192 bytes)",
- xgboost);
- ExpressionFunction function = xgboost.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get());
+ ExpressionFunction function = tester.assertFunction("xgboost_2_2",
+ "(optimized sum of condition trees of size 192 bytes)",
+ xgboost);
+ assertEquals("tensor()", function.returnType().get().toString());
assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments()));
function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg)));
@@ -52,14 +52,14 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, onnxMnistSoftmax.functions().size());
- tester.assertFunction("default.add",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
- onnxMnistSoftmax);
- ExpressionFunction function = onnxMnistSoftmax.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ ExpressionFunction function =
+ tester.assertFunction("default.add",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
+ onnxMnistSoftmax);
+ assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
+ assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString());
// Evaluator
assertEquals("tensor(d1[10],d2[784])",
@@ -74,14 +74,14 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, tfMnistSoftmax.functions().size());
- tester.assertFunction("serving_default.y",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
- tfMnistSoftmax);
- ExpressionFunction function = tfMnistSoftmax.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ ExpressionFunction function =
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
+ tfMnistSoftmax);
+ assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
+ assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString());
// Evaluator
FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
@@ -92,20 +92,25 @@ public class MlModelsImportingTest {
{
Model tfMnist = tester.models().get("mnist_saved");
// Generated function
- tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add",
- "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
- tfMnist);
+ ExpressionFunction generatedFunction =
+ tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add",
+ "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ assertEquals("tensor(d3[300])", generatedFunction.returnType().get().toString());
+ assertEquals(1, generatedFunction.arguments().size());
+ assertEquals("input", generatedFunction.arguments().get(0));
+ assertNull(null, generatedFunction.argumentTypes().get("input")); // TODO: Not available until we resolve all argument types
// Function
assertEquals(1, tfMnist.functions().size());
- tester.assertFunction("serving_default.y",
- "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
- tfMnist);
- ExpressionFunction function = tfMnist.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ ExpressionFunction function =
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("input", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("input"));
+ assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("input").toString());
// Evaluator
FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default");