summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-02 08:44:28 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-02 08:44:28 +0200
commit55236fc050998712ad6dc136e2b5e45c9d41538f (patch)
treefbbe27c1c048846bbe9cdb26c0f80feb97e94074
parent2efcdc1fcd6258d1aa314c972dea61d28912e2db (diff)
Don't expose generated functions
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java13
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java4
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java2
4 files changed, 15 insertions, 6 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 37c2d7961a8..5c8a53c9e83 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
@@ -29,6 +30,9 @@ public class Model {
/** Free functions */
private final ImmutableList<ExpressionFunction> functions;
+ /** The subset of the free functions which are public (additional non-public methods are generated during import) */
+ private final ImmutableList<ExpressionFunction> publicFunctions;
+
/** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */
private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions;
@@ -70,6 +74,9 @@ public class Model {
}
this.contextPrototypes = contextBuilder.build();
this.functions = ImmutableList.copyOf(functions.values());
+ this.publicFunctions = ImmutableList.copyOf(functions.values().stream()
+ .filter(f -> ! f.getName().startsWith(IntermediateOperation.FUNCTION_PREFIX))
+ .collect(Collectors.toList()));
// Optimize functions
ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>();
@@ -91,10 +98,12 @@ public class Model {
public String name() { return name; }
/**
- * Returns an immutable list of the free functions of this.
+ * Returns an immutable list of the free, public functions of this.
* The functions returned always specifies types of all arguments and the return value
*/
- public List<ExpressionFunction> functions() { return functions; }
+ public List<ExpressionFunction> functions() {
+ return publicFunctions;
+ }
/** Returns the given function, or throws a IllegalArgumentException if it does not exist */
ExpressionFunction requireFunction(String name) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 68c3b954675..cf7d208ed25 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -97,11 +97,11 @@ public class MlModelsImportingTest {
tfMnist);
// Function
- assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function
+ assertEquals(1, tfMnist.functions().size());
tester.assertFunction("serving_default.y",
"join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
tfMnist);
- ExpressionFunction function = tfMnist.functions().get(1);
+ ExpressionFunction function = tfMnist.functions().get(0);
assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
assertEquals(1, function.arguments().size());
assertEquals("input", function.arguments().get(0));
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index b915ee72a79..c52ea5f9047 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -128,7 +128,7 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testMnistSavedDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_saved";
- String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"imported_ml_function_mnist_saved_dnn_hidden1_add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]},{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}";
+ String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}";
assertResponse(url, 200, expected);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
index 34f5f1365a1..0eff8e8bc08 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
@@ -29,7 +29,7 @@ import java.util.function.Function;
*/
public abstract class IntermediateOperation {
- private final static String FUNCTION_PREFIX = "imported_ml_function_";
+ public final static String FUNCTION_PREFIX = "imported_ml_function_";
protected final String name;
protected final String modelName;