summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-01 05:46:22 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-01 05:46:22 +0200
commit9c80048457caab3881f3319aadd0990f65c04937 (patch)
treed180b1a6a866b53e0c23657a31ebe836d641911f
parent8d80010a385f40d4bb852e6b11810692a67e90ed (diff)
Include argument type information in functions
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java2
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java63
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java23
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java6
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java44
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java58
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java1
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java4
-rw-r--r--model-evaluation/src/test/resources/config/models/rank-profiles.cfg24
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java8
10 files changed, 176 insertions, 57 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
index 9ed82b9eef5..d2d5ecbd5aa 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
@@ -54,7 +54,7 @@ public class MlModelsTest {
assertEquals("test", config.rankprofile(2).name());
RankProfilesConfig.Rankprofile.Fef test = config.rankprofile(2).fef();
- // Compare string content in a denser for that config:
+ // Compare profile content in a denser format than config:
StringBuilder b = new StringBuilder();
for (RankProfilesConfig.Rankprofile.Fef.Property p : test.property())
b.append(p.name()).append(": ").append(p.value()).append("\n");
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index 22bba9dd079..7d4db9daeff 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -36,6 +36,27 @@ import static org.junit.Assert.assertTrue;
*/
public class ModelEvaluationTest {
+ /** Tests that we do not load models (which would waste memory) when not requested */
+ @Test
+ public void testMl_serving_not_activated() {
+ Path appDir = Path.fromString("src/test/cfg/application/ml_serving_not_activated");
+ try {
+ ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir);
+ VespaModel model = tester.createVespaModel();
+ ContainerCluster cluster = model.getContainerClusters().get("container");
+ assertNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));
+
+ RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
+ cluster.getConfig(b);
+ RankProfilesConfig config = new RankProfilesConfig(b);
+
+ assertEquals(0, config.rankprofile().size());
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
+ }
+
@Test
public void testMl_serving() throws IOException {
Path appDir = Path.fromString("src/test/cfg/application/ml_serving");
@@ -58,27 +79,6 @@ public class ModelEvaluationTest {
}
}
- /** Tests that we do not load models (which will waste memory) when not requested */
- @Test
- public void testMl_serving_not_activated() {
- Path appDir = Path.fromString("src/test/cfg/application/ml_serving_not_activated");
- try {
- ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir);
- VespaModel model = tester.createVespaModel();
- ContainerCluster cluster = model.getContainerClusters().get("container");
- assertNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));
-
- RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
- cluster.getConfig(b);
- RankProfilesConfig config = new RankProfilesConfig(b);
-
- assertEquals(0, config.rankprofile().size());
- }
- finally {
- IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
- }
- }
-
private void assertHasMlModels(VespaModel model) {
ContainerCluster cluster = model.getContainerClusters().get("container");
assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));
@@ -90,6 +90,7 @@ public class ModelEvaluationTest {
RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
cluster.getConfig(b);
RankProfilesConfig config = new RankProfilesConfig(b);
+ System.out.println(config);
RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder();
cluster.getConfig(cb);
@@ -102,6 +103,12 @@ public class ModelEvaluationTest {
assertTrue(modelNames.contains("mnist_softmax"));
assertTrue(modelNames.contains("mnist_softmax_saved"));
+ // Compare profile content in a denser format than config:
+ StringBuilder sb = new StringBuilder();
+ for (RankProfilesConfig.Rankprofile.Fef.Property p : findProfile("mnist_saved", config).property())
+ sb.append(p.name()).append(": ").append(p.value()).append("\n");
+ assertEquals(mnistProfile, sb.toString());
+
ModelsEvaluator evaluator = new ModelsEvaluator(new ToleratingMissingConstantFilesRankProfilesConfigImporter(MockFileAcquirer.returnFile(null))
.importFrom(config, constantsConfig));
@@ -136,6 +143,20 @@ public class ModelEvaluationTest {
assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default", "y"));
}
+ private final String mnistProfile =
+ "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" +
+ "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).type: tensor(d3[300])\n" +
+ "rankingExpression(serving_default.y).rankingScript: join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" +
+ "rankingExpression(serving_default.y).x.type: tensor(d0[],d1[784])\n";
+
+ private RankProfilesConfig.Rankprofile.Fef findProfile(String name, RankProfilesConfig config) {
+ for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
+ if (profile.name().equals(name))
+ return profile.fef();
+ }
+ throw new IllegalArgumentException("No profile named " + name);
+ }
+
// We don't have function file distribution so just return empty tensor constants
private static class ToleratingMissingConstantFilesRankProfilesConfigImporter extends RankProfilesConfigImporter {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
index 00fcad94ce8..5bb22b23345 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
@@ -1,6 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import com.yahoo.collections.Pair;
+import com.yahoo.tensor.TensorType;
+
import java.util.Objects;
import java.util.Optional;
import java.util.regex.Matcher;
@@ -23,6 +26,8 @@ class FunctionReference {
private static final Pattern referencePattern =
Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?");
+ private static final Pattern argumentTypePattern =
+ Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.([a-zA-Z0-9_]+)\\.type?");
/** The name of the function referenced */
private final String name;
@@ -73,4 +78,22 @@ class FunctionReference {
return Optional.of(new FunctionReference(name, instance));
}
+ /**
+ * Returns a function reference and argument name string from the given serial form,
+ * or empty if the string is not a valid function argument serial form
+ */
+ static Optional<Pair<FunctionReference, String>> fromTypeArgumentSerial(String serialForm) {
+ Matcher expressionMatcher = argumentTypePattern.matcher(serialForm);
+ if ( ! expressionMatcher.matches()) return Optional.empty();
+
+ String name = expressionMatcher.group(1);
+ String instance = expressionMatcher.group(2);
+ String argument = expressionMatcher.group(3);
+ return Optional.of(new Pair<>(new FunctionReference(name, instance), argument));
+ }
+
+ public static FunctionReference fromName(String name) {
+ return new FunctionReference(name, null);
+ }
+
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 3fb43d73187..ac8f28677a4 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -45,7 +45,6 @@ public class Model {
Collection<ExpressionFunction> functions,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants) {
- // TODO: Optimize functions
this.name = name;
this.functions = ImmutableList.copyOf(functions);
@@ -79,7 +78,10 @@ public class Model {
public String name() { return name; }
- /** Returns an immutable list of the free functions of this */
+ /**
+ * Returns an immutable list of the free functions of this.
+ * The functions returned always specifies types of all arguments and the return value
+ */
public List<ExpressionFunction> functions() { return functions; }
/** Returns the given function, or throws a IllegalArgumentException if it does not exist */
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 7bea2d0825a..f48d76e86f3 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import com.yahoo.collections.Pair;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.io.GrowableByteBuffer;
@@ -18,7 +19,9 @@ import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -60,26 +63,34 @@ public class RankProfilesConfigImporter {
private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig)
throws ParseException {
- List<ExpressionFunction> functions = new ArrayList<>();
- Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>();
- SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo();
- ExpressionFunction firstPhase = null;
- ExpressionFunction secondPhase = null;
List<Constant> constants = readLargeConstants(constantsConfig);
+ Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>();
+ Map<FunctionReference, ExpressionFunction> referencedFunctions = new LinkedHashMap<>();
+ SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo();
+ ExpressionFunction firstPhase = null;
+ ExpressionFunction secondPhase = null;
for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
+ Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name());
if ( reference.isPresent()) {
List<String> arguments = new ArrayList<>(); // TODO: Arguments?
RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value());
+ ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), arguments, expression);
if (reference.get().isFree()) // make available in model under configured name
- functions.add(new ExpressionFunction(reference.get().functionName(), arguments, expression)); //
-
- // Make all functions, bound or not available under the name they are referenced by in expressions
- referencedFunctions.put(reference.get(),
- new ExpressionFunction(reference.get().serialForm(), arguments, expression));
+ functions.put(reference.get(), function);
+ // Make all functions, bound or not, available under the name they are referenced by in expressions
+ referencedFunctions.put(reference.get(), function);
+ }
+ else if (argumentType.isPresent()) { // Arguments always follows the function in properties
+ FunctionReference argReference = argumentType.get().getFirst();
+ ExpressionFunction function = referencedFunctions.get(argReference);
+ function = function.withArgument(argumentType.get().getSecond(), TensorType.fromSpec(property.value()));
+ if (argReference.isFree())
+ functions.put(argReference, function);
+ referencedFunctions.put(argReference, function);
}
else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to functions
firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(),
@@ -93,22 +104,23 @@ public class RankProfilesConfigImporter {
smallConstantsInfo.addIfSmallConstantInfo(property.name(), property.value());
}
}
- if (functionByName("firstphase", functions) == null && firstPhase != null) // may be already included, depending on body
- functions.add(firstPhase);
- if (functionByName("secondphase", functions) == null && secondPhase != null) // may be already included, depending on body
- functions.add(secondPhase);
+ if (functionByName("firstphase", functions.values()) == null && firstPhase != null) // may be already included, depending on body
+ functions.put(FunctionReference.fromName("firstphase"), firstPhase);
+ if (functionByName("secondphase", functions.values()) == null && secondPhase != null) // may be already included, depending on body
+ functions.put(FunctionReference.fromName("secondphase"), secondPhase);
constants.addAll(smallConstantsInfo.asConstants());
try {
- return new Model(profile.name(), functions, referencedFunctions, constants);
+ return new Model(profile.name(), functions.values(), referencedFunctions, constants);
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
}
}
- private ExpressionFunction functionByName(String name, List<ExpressionFunction> functions) {
+ // TODO: Replace by lookup in map
+ private ExpressionFunction functionByName(String name, Collection<ExpressionFunction> functions) {
for (ExpressionFunction function : functions)
if (function.getName().equals(name))
return function;
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 6e55c0c9a53..40ef2c65aaa 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -1,8 +1,12 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import org.junit.Test;
+import java.util.List;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
@@ -28,6 +32,15 @@ public class MlModelsImportingTest {
tester.assertFunction("xgboost_2_2",
"(optimized sum of condition trees of size 192 bytes)",
xgboost);
+
+
+ // Function
+ assertEquals(1, xgboost.functions().size());
+ ExpressionFunction function = xgboost.functions().get(0);
+ assertEquals("xgboost_2_2", function.getName());
+ // assertEquals("f109, f29, f56, f60", commaSeparated(xgboost.functions().get(0).arguments())); TODO
+
+ // Evaluator
FunctionEvaluator evaluator = xgboost.evaluatorOf();
assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
assertEquals(-8.17695, evaluator.evaluate().sum().asDouble(), delta);
@@ -41,6 +54,16 @@ public class MlModelsImportingTest {
onnxMnistSoftmax);
assertEquals("tensor(d1[10],d2[784])",
onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
+
+ // Function
+ assertEquals(1, onnxMnistSoftmax.functions().size());
+ ExpressionFunction function = onnxMnistSoftmax.functions().get(0);
+ // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO
+ assertEquals(1, function.arguments().size());
+ assertEquals("Placeholder", function.arguments().get(0));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
+
+ // Evaluator
FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available
assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta);
@@ -51,6 +74,16 @@ public class MlModelsImportingTest {
tester.assertFunction("serving_default.y",
"join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
tfMnistSoftmax);
+
+ // Function
+ assertEquals(1, tfMnistSoftmax.functions().size());
+ ExpressionFunction function = tfMnistSoftmax.functions().get(0);
+ // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO
+ assertEquals(1, function.arguments().size());
+ assertEquals("x", function.arguments().get(0));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x"));
+
+ // Evaluator
FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta);
@@ -59,16 +92,31 @@ public class MlModelsImportingTest {
{
Model tfMnist = tester.models().get("mnist_saved");
tester.assertFunction("serving_default.y",
- "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
+ "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
tfMnist);
- // Macro:
- tester.assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add",
+
+ // Generated function
+ tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add",
"join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
tfMnist);
- FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument
- assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+
+ // Function
+ assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function
+ ExpressionFunction function = tfMnist.functions().get(1);
+ // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO
+ assertEquals(1, function.arguments().size());
+ assertEquals("x", function.arguments().get(0));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x"));
+
+ // Evaluator
+ FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default");
+ assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
assertEquals(-0.714629131972222, evaluator.evaluate().sum().asDouble(), delta);
}
}
+ private String commaSeparated(List<?> items) {
+ return items.stream().map(item -> item.toString()).sorted().collect(Collectors.joining(", "));
+ }
+
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
index 9a3e59aed80..50dd1d1d05f 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -60,7 +60,6 @@ public class ModelTester {
public void assertBoundFunction(String name, String expression, Model model) {
ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get());
assertNotNull("Function '" + name + "' is present", function);
- assertEquals(name, function.getName());
assertEquals(expression, function.getBody().getRoot().toString());
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 6726f117c05..b92e13b640f 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -126,7 +126,7 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testMnistSavedDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_saved";
- String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]},{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}";
+ String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"imported_ml_function_mnist_saved_dnn_hidden1_add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]},{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}";
assertResponse(url, 200, expected);
}
@@ -140,7 +140,7 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testMnistSavedEvaluateDefaultFunctionShouldFail() {
String url = "http://localhost/model-evaluation/v1/mnist_saved/eval";
- String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_macro_mnist_saved_dnn_hidden1_add, serving_default.y\"}";
+ String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_function_mnist_saved_dnn_hidden1_add, serving_default.y\"}";
assertResponse(url, 404, expected);
}
diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
index 1cc36f75158..7980d157193 100644
--- a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
+++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
@@ -1,14 +1,22 @@
-rankprofile[0].name "mnist_saved"
-rankprofile[0].fef.property[0].name "rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add).rankingScript"
-rankprofile[0].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"
-rankprofile[0].fef.property[1].name "rankingExpression(serving_default.y).rankingScript"
-rankprofile[0].fef.property[1].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))"
+rankprofile[0].name "mnist_softmax"
+rankprofile[0].fef.property[0].name "rankingExpression(default.add).rankingScript"
+rankprofile[0].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))"
+rankprofile[0].fef.property[1].name "rankingExpression(default.add).Placeholder.type"
+rankprofile[0].fef.property[1].value "tensor(d0[],d1[784])"
rankprofile[1].name "xgboost_2_2"
rankprofile[1].fef.property[0].name "rankingExpression(xgboost_2_2).rankingScript"
rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)"
rankprofile[2].name "mnist_softmax_saved"
rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankingScript"
rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))"
-rankprofile[3].name "mnist_softmax"
-rankprofile[3].fef.property[0].name "rankingExpression(default.add).rankingScript"
-rankprofile[3].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))"
+rankprofile[2].fef.property[1].name "rankingExpression(serving_default.y).x.type"
+rankprofile[2].fef.property[1].value "tensor(d0[],d1[784])"
+rankprofile[3].name "mnist_saved"
+rankprofile[3].fef.property[0].name "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript"
+rankprofile[3].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"
+rankprofile[3].fef.property[1].name "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).type"
+rankprofile[3].fef.property[1].value "tensor(d3[300])"
+rankprofile[3].fef.property[2].name "rankingExpression(serving_default.y).rankingScript"
+rankprofile[3].fef.property[2].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))"
+rankprofile[3].fef.property[3].name "rankingExpression(serving_default.y).x.type"
+rankprofile[3].fef.property[3].value "tensor(d0[],d1[784])"
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index f6502a9801d..787b857839d 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -11,6 +11,7 @@ import com.yahoo.text.Utf8;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
@@ -97,7 +98,12 @@ public class ExpressionFunction {
return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType));
}
- public ExpressionFunction withArgumentTypes(Map<String, TensorType> argumentTypes) {
+ /** Returns a copy of this with the given argument and argument type added */
+ public ExpressionFunction withArgument(String argument, TensorType type) {
+ List<String> arguments = new ArrayList<>(this.arguments);
+ arguments.add(argument);
+ Map<String, TensorType> argumentTypes = new HashMap<>(this.argumentTypes);
+ argumentTypes.put(argument, type);
return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
}