summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2018-10-02 12:06:50 +0200
committerGitHub <noreply@github.com>2018-10-02 12:06:50 +0200
commit573a8c03e44c273f9649944cffbe6d6091b8aeb7 (patch)
tree6f5f498c768eae4726207932b6080bbe24dcaed7
parent553250535e399607d3363fc38753f10d9f47a78b (diff)
parent7cca94818a1d2dbde134dd728a2a4ce7e089bc04 (diff)
Merge pull request #7173 from vespa-engine/bratseth/model-evaluation-improvements
Bratseth/model evaluation improvements
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java34
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java33
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java15
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java59
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java3
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java11
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java16
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java2
12 files changed, 113 insertions, 82 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index 10de10bcdfe..adf1770284b 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -90,7 +90,7 @@ public class ModelEvaluationTest {
RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
cluster.getConfig(b);
RankProfilesConfig config = new RankProfilesConfig(b);
- System.out.println(config);
+ // System.out.println(config);
RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder();
cluster.getConfig(cb);
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index 78b30f0c873..4d1b5a97583 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -4,7 +4,6 @@ package ai.vespa.models.evaluation;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
@@ -31,22 +30,22 @@ public final class LazyArrayContext extends Context implements ContextIndex {
public final static Value defaultContextValue = DoubleValue.zero;
+ private final ExpressionFunction function;
+
private final IndexedBindings indexedBindings;
- private LazyArrayContext(IndexedBindings indexedBindings) {
+ private LazyArrayContext(ExpressionFunction function, IndexedBindings indexedBindings) {
+ this.function = function;
this.indexedBindings = indexedBindings.copy(this);
}
- /**
- * Create a fast lookup, lazy context for an expression.
- *
- * @param expression the expression to create a context for
- */
- LazyArrayContext(RankingExpression expression,
- Map<FunctionReference, ExpressionFunction> functions,
+ /** Create a fast lookup, lazy context for a function */
+ LazyArrayContext(ExpressionFunction function,
+ Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants,
Model model) {
- this.indexedBindings = new IndexedBindings(expression, functions, constants, this, model);
+ this.function = function;
+ this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model);
}
/**
@@ -76,7 +75,6 @@ public final class LazyArrayContext extends Context implements ContextIndex {
@Override
public TensorType getType(Reference reference) {
- // TODO: Add type information so we do not need to evaluate to get this
return get(requireIndexOf(reference.toString())).type();
}
@@ -128,7 +126,7 @@ public final class LazyArrayContext extends Context implements ContextIndex {
* in a different thread or for re-binding free variables.
*/
LazyArrayContext copy() {
- return new LazyArrayContext(indexedBindings);
+ return new LazyArrayContext(function, indexedBindings);
}
private static class IndexedBindings {
@@ -154,15 +152,15 @@ public final class LazyArrayContext extends Context implements ContextIndex {
* Creates indexed bindings for the given expressions.
* The given expression and functions may be inspected but cannot be stored.
*/
- IndexedBindings(RankingExpression expression,
- Map<FunctionReference, ExpressionFunction> functions,
+ IndexedBindings(ExpressionFunction function,
+ Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants,
LazyArrayContext owner,
Model model) {
// 1. Determine and prepare bind targets
Set<String> bindTargets = new LinkedHashSet<>();
Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation
- extractBindTargets(expression.getRoot(), functions, bindTargets, arguments);
+ extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments);
this.arguments = ImmutableSet.copyOf(arguments);
values = new Value[bindTargets.size()];
@@ -183,10 +181,10 @@ public final class LazyArrayContext extends Context implements ContextIndex {
values[index] = new TensorValue(constant.value());
}
- for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
- Integer index = nameToIndex.get(function.getKey().serialForm());
+ for (Map.Entry<FunctionReference, ExpressionFunction> referencedFunction : referencedFunctions.entrySet()) {
+ Integer index = nameToIndex.get(referencedFunction.getKey().serialForm());
if (index != null) // Referenced in this, so bind it
- values[index] = new LazyValue(function.getKey(), owner, model);
+ values[index] = new LazyValue(referencedFunction.getKey(), owner, model);
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java
index 4a1ee22d288..a7b536ba911 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java
@@ -42,7 +42,7 @@ class LazyValue extends Value {
@Override
public TensorType type() {
- return computedValue().type(); // TODO: Keep type information in this/ExpressionFunction to avoid computing here
+ return model.requireReferencedFunction(function).returnType().get();
}
@Override
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index fda1ae935ca..e001204f650 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
@@ -29,6 +30,9 @@ public class Model {
/** Free functions */
private final ImmutableList<ExpressionFunction> functions;
+ /** The subset of the free functions which are public (additional non-public methods are generated during import) */
+ private final ImmutableList<ExpressionFunction> publicFunctions;
+
/** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */
private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions;
@@ -55,14 +59,24 @@ public class Model {
ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
- LazyArrayContext context = new LazyArrayContext(function.getValue().getBody(), referencedFunctions, constants, this);
+ LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this);
contextBuilder.put(function.getValue().getName(), context);
+ if ( ! function.getValue().returnType().isPresent()) {
+ functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));
+ }
+
for (String argument : context.arguments()) {
- if (function.getValue().argumentTypes().get(argument) == null)
- functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty));
+ if (function.getValue().getName().startsWith(IntermediateOperation.FUNCTION_PREFIX)) {
+ // Internal (generated) functions do not have type info - add arguments
+ if (!function.getValue().arguments().contains(argument))
+ functions.put(function.getKey(), function.getValue().withArgument(argument));
+ }
+ else {
+ // External functions have type info (when not scalar) - add argument types
+ if (function.getValue().argumentTypes().get(argument) == null)
+ functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty));
+ }
}
- if ( ! function.getValue().returnType().isPresent())
- functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e);
@@ -70,6 +84,9 @@ public class Model {
}
this.contextPrototypes = contextBuilder.build();
this.functions = ImmutableList.copyOf(functions.values());
+ this.publicFunctions = ImmutableList.copyOf(functions.values().stream()
+ .filter(f -> ! f.getName().startsWith(IntermediateOperation.FUNCTION_PREFIX))
+ .collect(Collectors.toList()));
// Optimize functions
ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>();
@@ -91,10 +108,12 @@ public class Model {
public String name() { return name; }
/**
- * Returns an immutable list of the free functions of this.
+ * Returns an immutable list of the free, public functions of this.
* The functions returned always specifies types of all arguments and the return value
*/
- public List<ExpressionFunction> functions() { return functions; }
+ public List<ExpressionFunction> functions() {
+ return publicFunctions;
+ }
/** Returns the given function, or throws a IllegalArgumentException if it does not exist */
ExpressionFunction requireFunction(String name) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index fc01eb84ebd..ea2ce087bd8 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -115,18 +115,11 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
cursor.setString("function", compactedFunction);
cursor.setString("info", baseUrl(request) + model.name() + "/" + compactedFunction);
cursor.setString("eval", baseUrl(request) + model.name() + "/" + compactedFunction + "/" + EVALUATE);
- Cursor bindings = cursor.setArray("bindings");
- for (String bindingName : evaluator.context().names()) {
- // TODO: Use an API which exposes only the external binding names instead of this
- if (bindingName.startsWith("constant(")) {
- continue;
- }
- if (bindingName.startsWith("rankingExpression(")) {
- continue;
- }
+ Cursor bindings = cursor.setArray("arguments");
+ for (Map.Entry<String, TensorType> argument : evaluator.function().argumentTypes().entrySet()) {
Cursor binding = bindings.addObject();
- binding.setString("binding", bindingName);
- binding.setString("type", ""); // TODO: implement type information when available
+ binding.setString("name", argument.getKey());
+ binding.setString("type", argument.getValue().toString());
}
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 68c3b954675..db892dce593 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -10,6 +10,7 @@ import java.util.List;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
/**
* Tests instantiating models from rank-profiles configs.
@@ -32,11 +33,10 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, xgboost.functions().size());
- tester.assertFunction("xgboost_2_2",
- "(optimized sum of condition trees of size 192 bytes)",
- xgboost);
- ExpressionFunction function = xgboost.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get());
+ ExpressionFunction function = tester.assertFunction("xgboost_2_2",
+ "(optimized sum of condition trees of size 192 bytes)",
+ xgboost);
+ assertEquals("tensor()", function.returnType().get().toString());
assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments()));
function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg)));
@@ -52,14 +52,14 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, onnxMnistSoftmax.functions().size());
- tester.assertFunction("default.add",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
- onnxMnistSoftmax);
- ExpressionFunction function = onnxMnistSoftmax.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ ExpressionFunction function =
+ tester.assertFunction("default.add",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
+ onnxMnistSoftmax);
+ assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
+ assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString());
// Evaluator
assertEquals("tensor(d1[10],d2[784])",
@@ -74,14 +74,14 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, tfMnistSoftmax.functions().size());
- tester.assertFunction("serving_default.y",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
- tfMnistSoftmax);
- ExpressionFunction function = tfMnistSoftmax.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ ExpressionFunction function =
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
+ tfMnistSoftmax);
+ assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
+ assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString());
// Evaluator
FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
@@ -92,20 +92,25 @@ public class MlModelsImportingTest {
{
Model tfMnist = tester.models().get("mnist_saved");
// Generated function
- tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add",
- "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
- tfMnist);
+ ExpressionFunction generatedFunction =
+ tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add",
+ "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ assertEquals("tensor(d3[300])", generatedFunction.returnType().get().toString());
+ assertEquals(1, generatedFunction.arguments().size());
+ assertEquals("input", generatedFunction.arguments().get(0));
+ assertNull(null, generatedFunction.argumentTypes().get("input")); // TODO: Not available until we resolve all argument types
// Function
- assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function
- tester.assertFunction("serving_default.y",
- "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
- tfMnist);
- ExpressionFunction function = tfMnist.functions().get(1);
- assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ assertEquals(1, tfMnist.functions().size());
+ ExpressionFunction function =
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("input", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("input"));
+ assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("input").toString());
// Evaluator
FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default");
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
index 50dd1d1d05f..bacdb52a201 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -49,12 +49,13 @@ public class ModelTester {
.importFrom(config, constantsConfig);
}
- public void assertFunction(String name, String expression, Model model) {
+ public ExpressionFunction assertFunction(String name, String expression, Model model) {
assertNotNull("Model is present in config", model);
ExpressionFunction function = model.function(name);
assertNotNull("Function '" + name + "' is in " + model, function);
assertEquals(name, function.getName());
assertEquals(expression, function.getBody().getRoot().toString());
+ return function;
}
public void assertBoundFunction(String name, String expression, Model model) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index b915ee72a79..23f0fa7a571 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -55,7 +55,8 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testListModels() {
String url = "http://localhost/model-evaluation/v1";
- String expected = "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"}";
+ String expected =
+ "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"}";
assertResponse(url, 200, expected);
}
@@ -82,14 +83,14 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testMnistSoftmaxDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax";
- String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"bindings\":[{\"binding\":\"Placeholder\",\"type\":\"\"}]}]}";
+ String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}]}";
assertResponse(url, 200, expected);
}
@Test
public void testMnistSoftmaxTypeDetails() {
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/";
- String expected = "{\"model\":\"mnist_softmax\",\"function\":\"default.add\",\"info\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval\",\"bindings\":[{\"binding\":\"Placeholder\",\"type\":\"\"}]}";
+ String expected = "{\"model\":\"mnist_softmax\",\"function\":\"default.add\",\"info\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}";
assertResponse(url, 200, expected);
}
@@ -128,14 +129,14 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testMnistSavedDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_saved";
- String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"imported_ml_function_mnist_saved_dnn_hidden1_add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]},{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}";
+ String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}]}";
assertResponse(url, 200, expected);
}
@Test
public void testMnistSavedTypeDetails() {
String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/";
- String expected = "{\"model\":\"mnist_saved\",\"function\":\"serving_default.y\",\"info\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}";
+ String expected = "{\"model\":\"mnist_saved\",\"function\":\"serving_default.y\",\"info\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}";
assertResponse(url, 200, expected);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index 787b857839d..674571ff73e 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -98,12 +98,24 @@ public class ExpressionFunction {
return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType));
}
- /** Returns a copy of this with the given argument and argument type added */
- public ExpressionFunction withArgument(String argument, TensorType type) {
+ /** Returns a copy of this with the given argument added (if not already present) */
+ public ExpressionFunction withArgument(String argument) {
+ if (arguments.contains(argument)) return this;
+
List<String> arguments = new ArrayList<>(this.arguments);
arguments.add(argument);
+ return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
+ }
+
+ /** Returns a copy of this with the given argument (if not present) and argument type added */
+ public ExpressionFunction withArgument(String argument, TensorType type) {
+ List<String> arguments = new ArrayList<>(this.arguments);
+ if ( ! arguments.contains(argument))
+ arguments.add(argument);
+
Map<String, TensorType> argumentTypes = new HashMap<>(this.argumentTypes);
argumentTypes.put(argument, type);
+
return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java
index c42d9ecc37f..cd5f42ac05c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java
@@ -67,6 +67,13 @@ public class Reference extends TypeContext.Name {
}
/**
+ * Returns whether this is a simple identifier - no arguments or output
+ */
+ public boolean isIdentifier() {
+ return this.arguments.expressions().size() == 0 && output == null;
+ }
+
+ /**
* A <i>simple feature reference</i> is a reference with a single identifier argument
* (and an optional output).
*/
@@ -97,13 +104,6 @@ public class Reference extends TypeContext.Name {
}
}
- /**
- * Returns whether this is a simple identifier - no arguments or output
- */
- public boolean isIdentifier() {
- return this.arguments.expressions().size() == 0 && output == null;
- }
-
public Reference withArguments(Arguments arguments) {
return new Reference(name(), arguments, output);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
index 6235756d4e1..481b7f9397a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -196,7 +196,9 @@ public abstract class ModelImporter {
if (operation.rankingExpressionFunction().isPresent()) {
TensorFunction function = operation.rankingExpressionFunction().get();
try {
- model.function(operation.rankingExpressionFunctionName(), new RankingExpression(operation.rankingExpressionFunctionName(), function.toString()));
+ model.function(operation.rankingExpressionFunctionName(),
+ new RankingExpression(operation.rankingExpressionFunctionName(),
+ function.toString()));
}
catch (ParseException e) {
throw new RuntimeException("Tensorflow function " + function +
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
index 34f5f1365a1..0eff8e8bc08 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
@@ -29,7 +29,7 @@ import java.util.function.Function;
*/
public abstract class IntermediateOperation {
- private final static String FUNCTION_PREFIX = "imported_ml_function_";
+ public final static String FUNCTION_PREFIX = "imported_ml_function_";
protected final String name;
protected final String modelName;