summaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/test
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-02 10:01:36 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-02 10:01:36 +0200
commit8909fd9728591d8e00e7babc601c600b26d5acf4 (patch)
tree53231c4abb7857b8345c5125bb8539519f0d776e /model-evaluation/src/test
parent55236fc050998712ad6dc136e2b5e45c9d41538f (diff)
Be truthful about generated functions
Diffstat (limited to 'model-evaluation/src/test')
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java57
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java3
2 files changed, 33 insertions, 27 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index cf7d208ed25..db892dce593 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -10,6 +10,7 @@ import java.util.List;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
/**
* Tests instantiating models from rank-profiles configs.
@@ -32,11 +33,10 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, xgboost.functions().size());
- tester.assertFunction("xgboost_2_2",
- "(optimized sum of condition trees of size 192 bytes)",
- xgboost);
- ExpressionFunction function = xgboost.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get());
+ ExpressionFunction function = tester.assertFunction("xgboost_2_2",
+ "(optimized sum of condition trees of size 192 bytes)",
+ xgboost);
+ assertEquals("tensor()", function.returnType().get().toString());
assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments()));
function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg)));
@@ -52,14 +52,14 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, onnxMnistSoftmax.functions().size());
- tester.assertFunction("default.add",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
- onnxMnistSoftmax);
- ExpressionFunction function = onnxMnistSoftmax.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ ExpressionFunction function =
+ tester.assertFunction("default.add",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
+ onnxMnistSoftmax);
+ assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
+ assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString());
// Evaluator
assertEquals("tensor(d1[10],d2[784])",
@@ -74,14 +74,14 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, tfMnistSoftmax.functions().size());
- tester.assertFunction("serving_default.y",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
- tfMnistSoftmax);
- ExpressionFunction function = tfMnistSoftmax.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ ExpressionFunction function =
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
+ tfMnistSoftmax);
+ assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
+ assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString());
// Evaluator
FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
@@ -92,20 +92,25 @@ public class MlModelsImportingTest {
{
Model tfMnist = tester.models().get("mnist_saved");
// Generated function
- tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add",
- "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
- tfMnist);
+ ExpressionFunction generatedFunction =
+ tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add",
+ "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ assertEquals("tensor(d3[300])", generatedFunction.returnType().get().toString());
+ assertEquals(1, generatedFunction.arguments().size());
+ assertEquals("input", generatedFunction.arguments().get(0));
+ assertNull(null, generatedFunction.argumentTypes().get("input")); // TODO: Not available until we resolve all argument types
// Function
assertEquals(1, tfMnist.functions().size());
- tester.assertFunction("serving_default.y",
- "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
- tfMnist);
- ExpressionFunction function = tfMnist.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ ExpressionFunction function =
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("input", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("input"));
+ assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("input").toString());
// Evaluator
FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default");
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
index 50dd1d1d05f..bacdb52a201 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -49,12 +49,13 @@ public class ModelTester {
.importFrom(config, constantsConfig);
}
- public void assertFunction(String name, String expression, Model model) {
+ public ExpressionFunction assertFunction(String name, String expression, Model model) {
assertNotNull("Model is present in config", model);
ExpressionFunction function = model.function(name);
assertNotNull("Function '" + name + "' is in " + model, function);
assertEquals(name, function.getName());
assertEquals(expression, function.getBody().getRoot().toString());
+ return function;
}
public void assertBoundFunction(String name, String expression, Model model) {