diff options
Diffstat (limited to 'model-evaluation')
3 files changed, 47 insertions, 31 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 5c8a53c9e83..e001204f650 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -61,12 +61,22 @@ public class Model { try { LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this); contextBuilder.put(function.getValue().getName(), context); + if ( ! function.getValue().returnType().isPresent()) { + functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); + } + for (String argument : context.arguments()) { - if (function.getValue().argumentTypes().get(argument) == null) - functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty)); + if (function.getValue().getName().startsWith(IntermediateOperation.FUNCTION_PREFIX)) { + // Internal (generated) functions do not have type info - add arguments + if (!function.getValue().arguments().contains(argument)) + functions.put(function.getKey(), function.getValue().withArgument(argument)); + } + else { + // External functions have type info (when not scalar) - add argument types + if (function.getValue().argumentTypes().get(argument) == null) + functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty)); + } } - if ( ! function.getValue().returnType().isPresent()) - functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index cf7d208ed25..db892dce593 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; /** * Tests instantiating models from rank-profiles configs. @@ -32,11 +33,10 @@ public class MlModelsImportingTest { // Function assertEquals(1, xgboost.functions().size()); - tester.assertFunction("xgboost_2_2", - "(optimized sum of condition trees of size 192 bytes)", - xgboost); - ExpressionFunction function = xgboost.functions().get(0); - assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); + ExpressionFunction function = tester.assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); + assertEquals("tensor()", function.returnType().get().toString()); assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments())); function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg))); @@ -52,14 +52,14 @@ public class MlModelsImportingTest { // Function assertEquals(1, onnxMnistSoftmax.functions().size()); - tester.assertFunction("default.add", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", - onnxMnistSoftmax); - ExpressionFunction function = onnxMnistSoftmax.functions().get(0); - assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); + ExpressionFunction function = + tester.assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); + assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("Placeholder", function.arguments().get(0)); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder")); + assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString()); // Evaluator assertEquals("tensor(d1[10],d2[784])", @@ -74,14 +74,14 @@ public class MlModelsImportingTest { // Function assertEquals(1, tfMnistSoftmax.functions().size()); - tester.assertFunction("serving_default.y", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", - tfMnistSoftmax); - ExpressionFunction function = tfMnistSoftmax.functions().get(0); - assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); + ExpressionFunction function = + tester.assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); + assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("Placeholder", function.arguments().get(0)); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder")); + assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString()); // Evaluator FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available @@ -92,20 +92,25 @@ public class MlModelsImportingTest { { Model tfMnist = tester.models().get("mnist_saved"); // Generated function - tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add", - "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", - tfMnist); + ExpressionFunction generatedFunction = + tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add", + "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", + tfMnist); + assertEquals("tensor(d3[300])", generatedFunction.returnType().get().toString()); + assertEquals(1, generatedFunction.arguments().size()); + assertEquals("input", generatedFunction.arguments().get(0)); + assertNull(null, generatedFunction.argumentTypes().get("input")); // TODO: Not available until we resolve all argument types // Function assertEquals(1, tfMnist.functions().size()); - tester.assertFunction("serving_default.y", - "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", - tfMnist); - ExpressionFunction function = tfMnist.functions().get(0); - assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); + ExpressionFunction function = + tester.assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("input", function.arguments().get(0)); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("input")); + assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("input").toString()); // Evaluator FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java index 50dd1d1d05f..bacdb52a201 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java @@ -49,12 +49,13 @@ public class ModelTester { .importFrom(config, constantsConfig); } - public void assertFunction(String name, String expression, Model model) { + public ExpressionFunction assertFunction(String name, String expression, Model model) { assertNotNull("Model is present in config", model); ExpressionFunction function = model.function(name); assertNotNull("Function '" + name + "' is in " + model, function); assertEquals(name, function.getName()); assertEquals(expression, function.getBody().getRoot().toString()); + return function; } public void assertBoundFunction(String name, String expression, Model model) { |