diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2018-10-02 12:06:50 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-10-02 12:06:50 +0200 |
commit | 573a8c03e44c273f9649944cffbe6d6091b8aeb7 (patch) | |
tree | 6f5f498c768eae4726207932b6080bbe24dcaed7 | |
parent | 553250535e399607d3363fc38753f10d9f47a78b (diff) | |
parent | 7cca94818a1d2dbde134dd728a2a4ce7e089bc04 (diff) |
Merge pull request #7173 from vespa-engine/bratseth/model-evaluation-improvements
Bratseth/model evaluation improvements
12 files changed, 113 insertions, 82 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index 10de10bcdfe..adf1770284b 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -90,7 +90,7 @@ public class ModelEvaluationTest { RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); cluster.getConfig(b); RankProfilesConfig config = new RankProfilesConfig(b); - System.out.println(config); + // System.out.println(config); RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder(); cluster.getConfig(cb); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 78b30f0c873..4d1b5a97583 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -4,7 +4,6 @@ package ai.vespa.models.evaluation; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; @@ -31,22 +30,22 @@ public final class LazyArrayContext extends Context implements ContextIndex { public final static Value defaultContextValue = DoubleValue.zero; + private final ExpressionFunction function; + private final IndexedBindings indexedBindings; - private LazyArrayContext(IndexedBindings indexedBindings) { + private LazyArrayContext(ExpressionFunction function, IndexedBindings indexedBindings) { + this.function = function; this.indexedBindings = indexedBindings.copy(this); } - /** - * Create a fast lookup, lazy context for an expression. - * - * @param expression the expression to create a context for - */ - LazyArrayContext(RankingExpression expression, - Map<FunctionReference, ExpressionFunction> functions, + /** Create a fast lookup, lazy context for a function */ + LazyArrayContext(ExpressionFunction function, + Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, Model model) { - this.indexedBindings = new IndexedBindings(expression, functions, constants, this, model); + this.function = function; + this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model); } /** @@ -76,7 +75,6 @@ public final class LazyArrayContext extends Context implements ContextIndex { @Override public TensorType getType(Reference reference) { - // TODO: Add type information so we do not need to evaluate to get this return get(requireIndexOf(reference.toString())).type(); } @@ -128,7 +126,7 @@ public final class LazyArrayContext extends Context implements ContextIndex { * in a different thread or for re-binding free variables. */ LazyArrayContext copy() { - return new LazyArrayContext(indexedBindings); + return new LazyArrayContext(function, indexedBindings); } private static class IndexedBindings { @@ -154,15 +152,15 @@ public final class LazyArrayContext extends Context implements ContextIndex { * Creates indexed bindings for the given expressions. * The given expression and functions may be inspected but cannot be stored. */ - IndexedBindings(RankingExpression expression, - Map<FunctionReference, ExpressionFunction> functions, + IndexedBindings(ExpressionFunction function, + Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, LazyArrayContext owner, Model model) { // 1. Determine and prepare bind targets Set<String> bindTargets = new LinkedHashSet<>(); Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation - extractBindTargets(expression.getRoot(), functions, bindTargets, arguments); + extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments); this.arguments = ImmutableSet.copyOf(arguments); values = new Value[bindTargets.size()]; @@ -183,10 +181,10 @@ public final class LazyArrayContext extends Context implements ContextIndex { values[index] = new TensorValue(constant.value()); } - for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { - Integer index = nameToIndex.get(function.getKey().serialForm()); + for (Map.Entry<FunctionReference, ExpressionFunction> referencedFunction : referencedFunctions.entrySet()) { + Integer index = nameToIndex.get(referencedFunction.getKey().serialForm()); if (index != null) // Referenced in this, so bind it - values[index] = new LazyValue(function.getKey(), owner, model); + values[index] = new LazyValue(referencedFunction.getKey(), owner, model); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java index 4a1ee22d288..a7b536ba911 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java @@ -42,7 +42,7 @@ class LazyValue extends Value { @Override public TensorType type() { - return computedValue().type(); // TODO: Keep type information in this/ExpressionFunction to avoid computing here + return model.requireReferencedFunction(function).returnType().get(); } @Override diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index fda1ae935ca..e001204f650 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; import com.yahoo.tensor.TensorType; import java.util.Arrays; @@ -29,6 +30,9 @@ public class Model { /** Free functions */ private final ImmutableList<ExpressionFunction> functions; + /** The subset of the free functions which are public (additional non-public methods are generated during import) */ + private final ImmutableList<ExpressionFunction> publicFunctions; + /** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */ private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions; @@ -55,14 +59,24 @@ public class Model { ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { - LazyArrayContext context = new LazyArrayContext(function.getValue().getBody(), referencedFunctions, constants, this); + LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this); contextBuilder.put(function.getValue().getName(), context); + if ( ! function.getValue().returnType().isPresent()) { + functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); + } + for (String argument : context.arguments()) { - if (function.getValue().argumentTypes().get(argument) == null) - functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty)); + if (function.getValue().getName().startsWith(IntermediateOperation.FUNCTION_PREFIX)) { + // Internal (generated) functions do not have type info - add arguments + if (!function.getValue().arguments().contains(argument)) + functions.put(function.getKey(), function.getValue().withArgument(argument)); + } + else { + // External functions have type info (when not scalar) - add argument types + if (function.getValue().argumentTypes().get(argument) == null) + functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty)); + } } - if ( ! function.getValue().returnType().isPresent()) - functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); @@ -70,6 +84,9 @@ public class Model { } this.contextPrototypes = contextBuilder.build(); this.functions = ImmutableList.copyOf(functions.values()); + this.publicFunctions = ImmutableList.copyOf(functions.values().stream() + .filter(f -> ! f.getName().startsWith(IntermediateOperation.FUNCTION_PREFIX)) + .collect(Collectors.toList())); // Optimize functions ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>(); @@ -91,10 +108,12 @@ public class Model { public String name() { return name; } /** - * Returns an immutable list of the free functions of this. + * Returns an immutable list of the free, public functions of this. * The functions returned always specifies types of all arguments and the return value */ - public List<ExpressionFunction> functions() { return functions; } + public List<ExpressionFunction> functions() { + return publicFunctions; + } /** Returns the given function, or throws a IllegalArgumentException if it does not exist */ ExpressionFunction requireFunction(String name) { diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index fc01eb84ebd..ea2ce087bd8 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -115,18 +115,11 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { cursor.setString("function", compactedFunction); cursor.setString("info", baseUrl(request) + model.name() + "/" + compactedFunction); cursor.setString("eval", baseUrl(request) + model.name() + "/" + compactedFunction + "/" + EVALUATE); - Cursor bindings = cursor.setArray("bindings"); - for (String bindingName : evaluator.context().names()) { - // TODO: Use an API which exposes only the external binding names instead of this - if (bindingName.startsWith("constant(")) { - continue; - } - if (bindingName.startsWith("rankingExpression(")) { - continue; - } + Cursor bindings = cursor.setArray("arguments"); + for (Map.Entry<String, TensorType> argument : evaluator.function().argumentTypes().entrySet()) { Cursor binding = bindings.addObject(); - binding.setString("binding", bindingName); - binding.setString("type", ""); // TODO: implement type information when available + binding.setString("name", argument.getKey()); + binding.setString("type", argument.getValue().toString()); } } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index 68c3b954675..db892dce593 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; /** * Tests instantiating models from rank-profiles configs. @@ -32,11 +33,10 @@ public class MlModelsImportingTest { // Function assertEquals(1, xgboost.functions().size()); - tester.assertFunction("xgboost_2_2", - "(optimized sum of condition trees of size 192 bytes)", - xgboost); - ExpressionFunction function = xgboost.functions().get(0); - assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); + ExpressionFunction function = tester.assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); + assertEquals("tensor()", function.returnType().get().toString()); assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments())); function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg))); @@ -52,14 +52,14 @@ public class MlModelsImportingTest { // Function assertEquals(1, onnxMnistSoftmax.functions().size()); - tester.assertFunction("default.add", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", - onnxMnistSoftmax); - ExpressionFunction function = onnxMnistSoftmax.functions().get(0); - assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); + ExpressionFunction function = + tester.assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); + assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("Placeholder", function.arguments().get(0)); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder")); + assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString()); // Evaluator assertEquals("tensor(d1[10],d2[784])", @@ -74,14 +74,14 @@ public class MlModelsImportingTest { // Function assertEquals(1, tfMnistSoftmax.functions().size()); - tester.assertFunction("serving_default.y", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", - tfMnistSoftmax); - ExpressionFunction function = tfMnistSoftmax.functions().get(0); - assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); + ExpressionFunction function = + tester.assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); + assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("Placeholder", function.arguments().get(0)); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder")); + assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString()); // Evaluator FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available @@ -92,20 +92,25 @@ public class MlModelsImportingTest { { Model tfMnist = tester.models().get("mnist_saved"); // Generated function - tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add", - "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", - tfMnist); + ExpressionFunction generatedFunction = + tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add", + "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", + tfMnist); + assertEquals("tensor(d3[300])", generatedFunction.returnType().get().toString()); + assertEquals(1, generatedFunction.arguments().size()); + assertEquals("input", generatedFunction.arguments().get(0)); + assertNull(null, generatedFunction.argumentTypes().get("input")); // TODO: Not available until we resolve all argument types // Function - assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function - tester.assertFunction("serving_default.y", - "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", - tfMnist); - ExpressionFunction function = tfMnist.functions().get(1); - assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); + assertEquals(1, tfMnist.functions().size()); + ExpressionFunction function = + tester.assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("input", function.arguments().get(0)); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("input")); + assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("input").toString()); // Evaluator FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java index 50dd1d1d05f..bacdb52a201 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java @@ -49,12 +49,13 @@ public class ModelTester { .importFrom(config, constantsConfig); } - public void assertFunction(String name, String expression, Model model) { + public ExpressionFunction assertFunction(String name, String expression, Model model) { assertNotNull("Model is present in config", model); ExpressionFunction function = model.function(name); assertNotNull("Function '" + name + "' is in " + model, function); assertEquals(name, function.getName()); assertEquals(expression, function.getBody().getRoot().toString()); + return function; } public void assertBoundFunction(String name, String expression, Model model) { diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index b915ee72a79..23f0fa7a571 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -55,7 +55,8 @@ public class ModelsEvaluationHandlerTest { @Test public void testListModels() { String url = "http://localhost/model-evaluation/v1"; - String expected = "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"}"; + String expected = + "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"}"; assertResponse(url, 200, expected); } @@ -82,14 +83,14 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSoftmaxDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax"; - String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"bindings\":[{\"binding\":\"Placeholder\",\"type\":\"\"}]}]}"; + String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}]}"; assertResponse(url, 200, expected); } @Test public void testMnistSoftmaxTypeDetails() { String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/"; - String expected = "{\"model\":\"mnist_softmax\",\"function\":\"default.add\",\"info\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval\",\"bindings\":[{\"binding\":\"Placeholder\",\"type\":\"\"}]}"; + String expected = "{\"model\":\"mnist_softmax\",\"function\":\"default.add\",\"info\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}"; assertResponse(url, 200, expected); } @@ -128,14 +129,14 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSavedDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_saved"; - String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"imported_ml_function_mnist_saved_dnn_hidden1_add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]},{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}"; + String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}]}"; assertResponse(url, 200, expected); } @Test public void testMnistSavedTypeDetails() { String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/"; - String expected = "{\"model\":\"mnist_saved\",\"function\":\"serving_default.y\",\"info\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}"; + String expected = "{\"model\":\"mnist_saved\",\"function\":\"serving_default.y\",\"info\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}"; assertResponse(url, 200, expected); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 787b857839d..674571ff73e 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -98,12 +98,24 @@ public class ExpressionFunction { return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType)); } - /** Returns a copy of this with the given argument and argument type added */ - public ExpressionFunction withArgument(String argument, TensorType type) { + /** Returns a copy of this with the given argument added (if not already present) */ + public ExpressionFunction withArgument(String argument) { + if (arguments.contains(argument)) return this; + List<String> arguments = new ArrayList<>(this.arguments); arguments.add(argument); + return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); + } + + /** Returns a copy of this with the given argument (if not present) and argument type added */ + public ExpressionFunction withArgument(String argument, TensorType type) { + List<String> arguments = new ArrayList<>(this.arguments); + if ( ! arguments.contains(argument)) + arguments.add(argument); + Map<String, TensorType> argumentTypes = new HashMap<>(this.argumentTypes); argumentTypes.put(argument, type); + return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java index c42d9ecc37f..cd5f42ac05c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java @@ -67,6 +67,13 @@ public class Reference extends TypeContext.Name { } /** + * Returns whether this is a simple identifier - no arguments or output + */ + public boolean isIdentifier() { + return this.arguments.expressions().size() == 0 && output == null; + } + + /** * A <i>simple feature reference</i> is a reference with a single identifier argument * (and an optional output). */ @@ -97,13 +104,6 @@ public class Reference extends TypeContext.Name { } } - /** - * Returns whether this is a simple identifier - no arguments or output - */ - public boolean isIdentifier() { - return this.arguments.expressions().size() == 0 && output == null; - } - public Reference withArguments(Arguments arguments) { return new Reference(name(), arguments, output); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index 6235756d4e1..481b7f9397a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -196,7 +196,9 @@ public abstract class ModelImporter { if (operation.rankingExpressionFunction().isPresent()) { TensorFunction function = operation.rankingExpressionFunction().get(); try { - model.function(operation.rankingExpressionFunctionName(), new RankingExpression(operation.rankingExpressionFunctionName(), function.toString())); + model.function(operation.rankingExpressionFunctionName(), + new RankingExpression(operation.rankingExpressionFunctionName(), + function.toString())); } catch (ParseException e) { throw new RuntimeException("Tensorflow function " + function + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java index 34f5f1365a1..0eff8e8bc08 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java @@ -29,7 +29,7 @@ import java.util.function.Function; */ public abstract class IntermediateOperation { - private final static String FUNCTION_PREFIX = "imported_ml_function_"; + public final static String FUNCTION_PREFIX = "imported_ml_function_"; protected final String name; protected final String modelName; |