diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-10-02 08:44:28 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-10-02 08:44:28 +0200 |
commit | 55236fc050998712ad6dc136e2b5e45c9d41538f (patch) | |
tree | fbbe27c1c048846bbe9cdb26c0f80feb97e94074 /model-evaluation/src | |
parent | 2efcdc1fcd6258d1aa314c972dea61d28912e2db (diff) |
Don't expose generated functions
Diffstat (limited to 'model-evaluation/src')
3 files changed, 14 insertions, 5 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 37c2d7961a8..5c8a53c9e83 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; import com.yahoo.tensor.TensorType; import java.util.Arrays; @@ -29,6 +30,9 @@ public class Model { /** Free functions */ private final ImmutableList<ExpressionFunction> functions; + /** The subset of the free functions which are public (additional non-public methods are generated during import) */ + private final ImmutableList<ExpressionFunction> publicFunctions; + /** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */ private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions; @@ -70,6 +74,9 @@ public class Model { } this.contextPrototypes = contextBuilder.build(); this.functions = ImmutableList.copyOf(functions.values()); + this.publicFunctions = ImmutableList.copyOf(functions.values().stream() + .filter(f -> ! f.getName().startsWith(IntermediateOperation.FUNCTION_PREFIX)) + .collect(Collectors.toList())); // Optimize functions ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>(); @@ -91,10 +98,12 @@ public class Model { public String name() { return name; } /** - * Returns an immutable list of the free functions of this. + * Returns an immutable list of the free, public functions of this. * The functions returned always specifies types of all arguments and the return value */ - public List<ExpressionFunction> functions() { return functions; } + public List<ExpressionFunction> functions() { + return publicFunctions; + } /** Returns the given function, or throws a IllegalArgumentException if it does not exist */ ExpressionFunction requireFunction(String name) { diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index 68c3b954675..cf7d208ed25 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -97,11 +97,11 @@ public class MlModelsImportingTest { tfMnist); // Function - assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function + assertEquals(1, tfMnist.functions().size()); tester.assertFunction("serving_default.y", "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", tfMnist); - ExpressionFunction function = tfMnist.functions().get(1); + ExpressionFunction function = tfMnist.functions().get(0); assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); assertEquals(1, function.arguments().size()); assertEquals("input", function.arguments().get(0)); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index b915ee72a79..c52ea5f9047 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -128,7 +128,7 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSavedDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_saved"; - String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"imported_ml_function_mnist_saved_dnn_hidden1_add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]},{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}"; + String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}"; assertResponse(url, 200, expected); } |