aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-25 22:59:28 +0200
committerGitHub <noreply@github.com>2018-08-25 22:59:28 +0200
commit727ae90506e72ed0a6695e2d7cb5c719f0152842 (patch)
tree1011ce314160c766e119a42c67daf6bc35980fe4
parentccda281b6c60de0e6c7108a8532d7f7438ebd9ae (diff)
parentb525b8d8efcf71b421db1e549e4f078514e26135 (diff)
Merge pull request #6674 from vespa-engine/bratseth/generate-rank-profiles-for-all-models-part-8
Improve evaluation API
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java5
-rw-r--r--config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java114
-rw-r--r--config-model/src/test/java/com/yahoo/config/model/ApplicationPackageTester.java54
-rw-r--r--config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java66
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java6
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java49
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java13
10 files changed, 208 insertions, 106 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index 629fa9624c5..b645af582e1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -112,7 +112,7 @@ public class ConvertedModel {
if ( ! arguments.output().isPresent()) {
List<Map.Entry<String, RankingExpression>> entriesWithTheRightPrefix =
- expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(modelName + "." + arguments.signature().get() + ".")).collect(Collectors.toList());
+ expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(arguments.signature().get() + ".")).collect(Collectors.toList());
if (entriesWithTheRightPrefix.size() < 1)
throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() +
missingExpressionMessageSuffix());
@@ -720,8 +720,7 @@ public class ConvertedModel {
public Optional<String> output() { return output; }
public String toName() {
- return modelName +
- (signature.isPresent() ? "." + signature.get() : "") +
+ return (signature.isPresent() ? signature.get() : "") +
(output.isPresent() ? "." + output.get() : "");
}
diff --git a/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java b/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java
index ded8d88aa99..8331ada2271 100644
--- a/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java
+++ b/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java
@@ -2,7 +2,6 @@
package com.yahoo.config.model;
import com.google.common.io.Files;
-import com.yahoo.component.Version;
import com.yahoo.config.ConfigInstance;
import com.yahoo.config.application.api.ApplicationMetaData;
import com.yahoo.config.application.api.UnparsedConfigDefinition;
@@ -18,9 +17,7 @@ import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.UnproperSearch;
import com.yahoo.vespa.config.ConfigDefinition;
import com.yahoo.vespa.config.ConfigDefinitionKey;
-import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.model.VespaModel;
-import com.yahoo.vespa.model.container.ContainerCluster;
import com.yahoo.vespa.model.search.SearchDefinition;
import org.junit.After;
import org.junit.Rule;
@@ -37,11 +34,9 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.regex.Pattern;
-import java.util.stream.Collectors;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.is;
@@ -63,10 +58,10 @@ public class ApplicationDeployTest {
@Test
public void testVespaModel() throws SAXException, IOException {
- FilesApplicationPackage app = createAppPkg(TESTDIR + "app1");
- assertThat(app.getApplicationName(), is("app1"));
- VespaModel model = new VespaModel(app);
- List<SearchDefinition> searchDefinitions = getSearchDefinitions(app);
+ ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "app1");
+ assertThat(tester.app().getApplicationName(), is("app1"));
+ VespaModel model = new VespaModel(tester.app());
+ List<SearchDefinition> searchDefinitions = tester.getSearchDefinitions();
assertEquals(searchDefinitions.size(), 5);
for (SearchDefinition searchDefinition : searchDefinitions) {
Search s = searchDefinition.getSearch();
@@ -90,11 +85,11 @@ public class ApplicationDeployTest {
new File(TESTSDDIR + "product.sd"),
new File(TESTSDDIR + "sock.sd")};
Arrays.sort(truth);
- List<File> appSdFiles = app.getSearchDefinitionFiles();
+ List<File> appSdFiles = tester.app().getSearchDefinitionFiles();
Collections.sort(appSdFiles);
assertEquals(appSdFiles, Arrays.asList(truth));
- List<FilesApplicationPackage.Component> components = app.getComponents();
+ List<FilesApplicationPackage.Component> components = tester.app().getComponents();
assertEquals(1, components.size());
Map<String, Bundle.DefEntry> defEntriesByName =
defEntries2map(components.get(0).getDefEntries());
@@ -122,42 +117,24 @@ public class ApplicationDeployTest {
}
@Test
- public void testMl_ServingApplication() throws SAXException, IOException {
- FilesApplicationPackage app = createAppPkg(TESTDIR + "ml_serving");
- VespaModel model = new VespaModel(app);
- ContainerCluster cluster = model.getContainerClusters().get("container");
- RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
- cluster.getConfig(b);
- RankProfilesConfig config = new RankProfilesConfig(b);
- assertEquals(3, config.rankprofile().size());
- Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet());
- assertTrue(modelNames.contains("xgboost_2_2_json"));
- assertTrue(modelNames.contains("mnist_softmax_onnx"));
- assertTrue(modelNames.contains("mnist_softmax_saved"));
- }
-
- @Test
public void testGetFile() throws IOException {
- FilesApplicationPackage app = createAppPkg(TESTDIR + "app1");
- try (Reader foo = app.getFile(Path.fromString("files/foo.json")).createReader()) {
+ ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "app1");
+ try (Reader foo = tester.app().getFile(Path.fromString("files/foo.json")).createReader()) {
assertEquals(IOUtils.readAll(foo), "foo : foo\n");
}
- try (Reader bar = app.getFile(Path.fromString("files/sub/bar.json")).createReader()) {
+ try (Reader bar = tester.app().getFile(Path.fromString("files/sub/bar.json")).createReader()) {
assertEquals(IOUtils.readAll(bar), "bar : bar\n");
}
- assertTrue(app.getFile(Path.createRoot()).exists());
- assertTrue(app.getFile(Path.createRoot()).isDirectory());
+ assertTrue(tester.app().getFile(Path.createRoot()).exists());
+ assertTrue(tester.app().getFile(Path.createRoot()).isDirectory());
}
/*
* Put a list of def entries to a map, with the name as key. This is done because the order
* of the def entries in the list cannot be guaranteed.
*/
- private Map<String, Bundle.DefEntry> defEntries2map
- (List<Bundle.DefEntry> defEntries) {
- Map<String, Bundle.DefEntry> ret =
- new HashMap<>();
-
+ private Map<String, Bundle.DefEntry> defEntries2map(List<Bundle.DefEntry> defEntries) {
+ Map<String, Bundle.DefEntry> ret = new HashMap<>();
for (Bundle.DefEntry def : defEntries)
ret.put(def.defName, def);
return ret;
@@ -166,8 +143,8 @@ public class ApplicationDeployTest {
@Test
public void testSdFromDocprocBundle() throws IOException, SAXException {
String appDir = "src/test/cfg/application/app_sdbundles";
- FilesApplicationPackage app = createAppPkg(appDir);
- VespaModel model = new VespaModel(app);
+ ApplicationPackageTester tester = ApplicationPackageTester.create(appDir);
+ VespaModel model = new VespaModel(tester.app());
// Check that the resulting documentmanager config contains those types
DocumentmanagerConfig.Builder b = new DocumentmanagerConfig.Builder();
model.getConfig(b, VespaModel.ROOT_CONFIGID);
@@ -188,10 +165,10 @@ public class ApplicationDeployTest {
}
@Test
- public void include_dirs_are_included() throws Exception {
- FilesApplicationPackage app = createAppPkg(TESTDIR + "include_dirs");
+ public void include_dirs_are_included() {
+ ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "include_dirs");
- List<String> includeDirs = app.getUserIncludeDirs();
+ List<String> includeDirs = tester.app().getUserIncludeDirs();
assertThat(includeDirs, contains("jdisc_dir", "dir1", "dir2", "empty_dir"));
}
@@ -216,33 +193,33 @@ public class ApplicationDeployTest {
public void testThatModelIsRebuiltWhenSearchDefinitionIsAdded() throws IOException {
File tmpDir = tmpFolder.getRoot();
IOUtils.copyDirectory(new File(TESTDIR, "app1"), tmpDir);
- FilesApplicationPackage app = createAppPkg(tmpDir.getAbsolutePath());
- assertEquals(5, getSearchDefinitions(app).size());
+ ApplicationPackageTester tester = ApplicationPackageTester.create(tmpDir.getAbsolutePath());
+ assertEquals(5, tester.getSearchDefinitions().size());
File sdDir = new File(tmpDir, "searchdefinitions");
File sd = new File(sdDir, "testfoo.sd");
IOUtils.writeFile(sd, "search testfoo { document testfoo { field bar type string { } } }", false);
- assertEquals(6, getSearchDefinitions(app).size());
+ assertEquals(6, tester.getSearchDefinitions().size());
}
@Test
public void testThatAppWithDeploymentXmlIsValid() throws IOException {
File tmpDir = tmpFolder.getRoot();
IOUtils.copyDirectory(new File(TESTDIR, "app1"), tmpDir);
- createAppPkg(tmpDir.getAbsolutePath());
+ ApplicationPackageTester.create(tmpDir.getAbsolutePath());
}
@Test(expected = IllegalArgumentException.class)
public void testThatAppWithIllegalDeploymentXmlIsNotValid() throws IOException {
File tmpDir = tmpFolder.getRoot();
IOUtils.copyDirectory(new File(TESTDIR, "app_invalid_deployment_xml"), tmpDir);
- createAppPkg(tmpDir.getAbsolutePath());
+ ApplicationPackageTester.create(tmpDir.getAbsolutePath());
}
@Test
public void testThatAppWithIllegalEmptyProdRegion() throws IOException {
File tmpDir = tmpFolder.getRoot();
IOUtils.copyDirectory(new File(TESTDIR, "empty_prod_region_in_deployment_xml"), tmpDir);
- createAppPkg(tmpDir.getAbsolutePath());
+ ApplicationPackageTester.create(tmpDir.getAbsolutePath());
}
@Test
@@ -250,48 +227,13 @@ public class ApplicationDeployTest {
File tmpDir = tmpFolder.getRoot();
IOUtils.copyDirectory(new File(TESTDIR, "invalid_parallel_deployment_xml"), tmpDir);
try {
- createAppPkg(tmpDir.getAbsolutePath());
+ ApplicationPackageTester.create(tmpDir.getAbsolutePath());
fail("Expected exception");
} catch (IllegalArgumentException e) {
assertThat(e.getMessage(), containsString("element \"delay\" not allowed here"));
}
}
- private List<SearchDefinition> getSearchDefinitions(FilesApplicationPackage app) {
- return new DeployState.Builder().applicationPackage(app).build().getSearchDefinitions();
- }
-
- public FilesApplicationPackage createAppPkg(String appPkg) throws IOException {
- return createAppPkg(appPkg, true);
- }
-
- public FilesApplicationPackage createAppPkgDoNotValidateXml(String appPkg) throws IOException {
- return createAppPkg(appPkg, false);
- }
-
- public FilesApplicationPackage createAppPkg(String appPkg, boolean validateXml) throws IOException {
- final FilesApplicationPackage filesApplicationPackage = FilesApplicationPackage.fromFile(new File(appPkg));
- if (validateXml) {
- ApplicationPackageXmlFilesValidator validator =
- ApplicationPackageXmlFilesValidator.create(new File(appPkg), new Version(6));
- validator.checkApplication();
- validator.checkIncludedDirs(filesApplicationPackage);
- }
- return filesApplicationPackage;
- }
-
- @Test
- public void testThatNewServicesFileNameWorks() throws IOException {
- String appPkg = TESTDIR + "newfilenames";
- assertEquals(appPkg + "/services.xml", createAppPkgDoNotValidateXml(appPkg).getServicesSource());
- }
-
- @Test
- public void testThatNewHostsFileNameWorks() throws IOException {
- String appPkg = TESTDIR + "newfilenames";
- assertEquals(appPkg + "/hosts.xml", createAppPkgDoNotValidateXml(appPkg).getHostSource());
- }
-
@Test
public void testGetJars() throws IOException {
String jarName = "src/test/cfg/application/app_sdbundles/components/testbundle.jar";
@@ -412,9 +354,9 @@ public class ApplicationDeployTest {
}
@Test(expected=IllegalArgumentException.class)
- public void testDifferentNameOfSdFileAndSearchName() throws IOException {
- FilesApplicationPackage app = createAppPkg(TESTDIR + "sdfilenametest");
- new DeployState.Builder().applicationPackage(app).build();
+ public void testDifferentNameOfSdFileAndSearchName() {
+ ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "sdfilenametest");
+ new DeployState.Builder().applicationPackage(tester.app()).build();
}
}
diff --git a/config-model/src/test/java/com/yahoo/config/model/ApplicationPackageTester.java b/config-model/src/test/java/com/yahoo/config/model/ApplicationPackageTester.java
new file mode 100644
index 00000000000..3e052421684
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/config/model/ApplicationPackageTester.java
@@ -0,0 +1,54 @@
+package com.yahoo.config.model;
+
+import com.yahoo.component.Version;
+import com.yahoo.config.model.application.provider.ApplicationPackageXmlFilesValidator;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.vespa.model.search.SearchDefinition;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Helper for tests using a file application package
+ *
+ * @author bratseth
+ */
+public class ApplicationPackageTester {
+
+ private final FilesApplicationPackage applicationPackage;
+
+ private ApplicationPackageTester(String applicationPackageDir, boolean validate) {
+ try {
+ FilesApplicationPackage applicationPackage =
+ FilesApplicationPackage.fromFile(new File(applicationPackageDir));
+ if (validate) {
+ ApplicationPackageXmlFilesValidator validator =
+ ApplicationPackageXmlFilesValidator.create(new File(applicationPackageDir), new Version(6));
+ validator.checkApplication();
+ validator.checkIncludedDirs(applicationPackage);
+ }
+ this.applicationPackage = applicationPackage;
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not create an application package from '" +
+ applicationPackageDir + "'", e);
+ }
+ }
+
+ public FilesApplicationPackage app() { return applicationPackage; }
+
+ public List<SearchDefinition> getSearchDefinitions() {
+ return new DeployState.Builder().applicationPackage(app()).build().getSearchDefinitions();
+ }
+
+ public static ApplicationPackageTester create(String applicationPackageDir) {
+ return new ApplicationPackageTester(applicationPackageDir, true);
+ }
+
+ public static ApplicationPackageTester createWithoutValidation(String applicationPackageDir) {
+ return new ApplicationPackageTester(applicationPackageDir, false);
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
new file mode 100644
index 00000000000..8eccc4e7d06
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
@@ -0,0 +1,66 @@
+package com.yahoo.config.model;
+
+import ai.vespa.models.evaluation.Model;
+import ai.vespa.models.evaluation.ModelsEvaluator;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.container.ContainerCluster;
+import org.junit.Test;
+import org.xml.sax.SAXException;
+
+import java.io.IOException;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author bratseth
+ */
+public class ModelEvaluationTest {
+
+ private static final String TESTDIR = "src/test/cfg/application/";
+
+ @Test
+ public void testMl_ServingApplication() throws SAXException, IOException {
+ ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "ml_serving");
+ VespaModel model = new VespaModel(tester.app());
+ ContainerCluster cluster = model.getContainerClusters().get("container");
+ RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
+ cluster.getConfig(b);
+ RankProfilesConfig config = new RankProfilesConfig(b);
+ System.out.println(config.rankprofile(2).toString());
+ assertEquals(3, config.rankprofile().size());
+ Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet());
+ assertTrue(modelNames.contains("xgboost_2_2_json"));
+ assertTrue(modelNames.contains("mnist_softmax_onnx"));
+ assertTrue(modelNames.contains("mnist_softmax_saved"));
+
+ ModelsEvaluator evaluator = new ModelsEvaluator(config);
+
+ assertEquals(3, evaluator.models().size());
+ Model xgboost = evaluator.models().get("xgboost_2_2_json");
+ assertNotNull(xgboost);
+ assertNotNull(xgboost.evaluatorOf());
+ assertNotNull(xgboost.evaluatorOf("xgboost_2_2_json"));
+ System.out.println("xgboost functions: " + xgboost.functions().stream().map(f -> f.getName()).collect(Collectors.joining(", ")));
+
+ Model onnx = evaluator.models().get("mnist_softmax_onnx");
+ assertNotNull(onnx);
+ assertNotNull(onnx.evaluatorOf());
+ assertNotNull(onnx.evaluatorOf("default"));
+ assertNotNull(onnx.evaluatorOf("default", "add"));
+ System.out.println("onnx functions: " + onnx.functions().stream().map(f -> f.getName()).collect(Collectors.joining(", ")));
+
+ Model tensorflow = evaluator.models().get("mnist_softmax_saved");
+ assertNotNull(tensorflow);
+ assertNotNull(tensorflow.evaluatorOf());
+ assertNotNull(tensorflow.evaluatorOf("serving_default"));
+ assertNotNull(tensorflow.evaluatorOf("serving_default", "y"));
+ System.out.println("tensorflow functions: " + tensorflow.functions().stream().map(f -> f.getName()).collect(Collectors.joining(", ")));
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 9bbc1347aeb..f67c85e2881 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -82,7 +82,8 @@ class RankProfileSearchFixture {
}
public RankProfile compileRankProfile(String rankProfile, Path applicationDir) {
- RankProfile compiled = rankProfileRegistry.get(search, rankProfile).compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile()));
+ RankProfile compiled = rankProfileRegistry.get(search, rankProfile)
+ .compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile()));
compiledRankProfiles.put(rankProfile, compiled);
return compiled;
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 4db5f312cae..a96a3ce798b 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -169,7 +169,7 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx','y'): " +
- "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: mnist_softmax_onnx.default.add",
+ "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.add",
Exceptions.toMessageString(expected));
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index a212726efda..c317f07b87a 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -205,7 +205,7 @@ public class RankingExpressionWithTensorFlowTestCase {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved','serving_defaultz'): " +
"No expressions named 'serving_defaultz' in model 'mnist_softmax/saved'. "+
- "Available expressions: mnist_softmax_saved.serving_default.y",
+ "Available expressions: serving_default.y",
Exceptions.toMessageString(expected));
}
}
@@ -221,8 +221,8 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved','serving_default','x'): " +
- "No expression 'mnist_softmax_saved.serving_default.x' in model 'mnist_softmax/saved'. " +
- "Available expressions: mnist_softmax_saved.serving_default.y",
+ "No expression 'serving_default.x' in model 'mnist_softmax/saved'. " +
+ "Available expressions: serving_default.y",
Exceptions.toMessageString(expected));
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
index 3b50cef6e2e..00fcad94ce8 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
@@ -22,7 +22,7 @@ import java.util.regex.Pattern;
class FunctionReference {
private static final Pattern referencePattern =
- Pattern.compile("rankingExpression\\(([a-zA-Z0-9_]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?");
+ Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?");
/** The name of the function referenced */
private final String name;
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index ca739195867..d8b7e82677c 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
@@ -116,14 +117,54 @@ public class Model {
/**
* Returns an evaluator which can be used to evaluate the given function in a single thread once.
-
+ *
* Usage:
* <code>Tensor result = model.evaluatorOf("myFunction").bind("foo", value).bind("bar", value).evaluate()</code>
*
- * @throws IllegalArgumentException if the function is not present
+ * @param names the names identifying the function - this can be from 0 to 2, specifying function or "signature"
+ * name, and "output", respectively. Names which are unnecessary to determine the desired function
+ * uniquely (e.g if there is just one function or output) can be omitted.
+ * @throws IllegalArgumentException if the function is not present, or not uniquely identified by the names given
*/
- public FunctionEvaluator evaluatorOf(String function) { // TODO: Parameter overloading?
- return new FunctionEvaluator(requireFunction(function), requireContextProprotype(function).copy());
+ public FunctionEvaluator evaluatorOf(String ... names) { // TODO: Parameter overloading?
+ if (names.length == 0) {
+ if (functions.size() > 1)
+ throwUndeterminedFunction("More than one function is available in " + this + ", but no name is given");
+ return evaluatorOf(functions.get(0));
+ }
+ else if (names.length == 1) {
+ String name = names[0];
+ ExpressionFunction function = function(name);
+ if (function != null) return evaluatorOf(function);
+
+ List<ExpressionFunction> functionsStartingByName =
+ functions.stream().filter(f -> f.getName().startsWith(name + ".")).collect(Collectors.toList());
+ if (functionsStartingByName.size() == 0)
+ throwUndeterminedFunction("No function '" + name + "' in " + this);
+ else if (functionsStartingByName.size() == 1)
+ return evaluatorOf(functionsStartingByName.get(0));
+ else
+ throwUndeterminedFunction("Multiple functions start by '" + name + "' in " + this);
+
+ }
+ else if (names.length == 2) {
+ String name = names[0] + "." + names[1];
+ ExpressionFunction function = function(name);
+ if (function == null) throwUndeterminedFunction("No function '" + name + "' in " + this);
+ return evaluatorOf(function);
+ }
+ throw new IllegalArgumentException("No more than 2 names can be given when choosing a function, got " +
+ Arrays.toString(names));
+ }
+
+ /** Returns a single-use evaluator of a function */
+ private FunctionEvaluator evaluatorOf(ExpressionFunction function) {
+ return new FunctionEvaluator(function, requireContextProprotype(function.getName()).copy());
+ }
+
+ private void throwUndeterminedFunction(String message) {
+ throw new IllegalArgumentException(message + ". Available functions: " +
+ functions.stream().map(f -> f.getName()).collect(Collectors.joining(", ")));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index 045844ee219..6716993e1dd 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -107,21 +107,20 @@ public class ImportedModel {
List<Pair<String, RankingExpression>> names = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
- names.add(new Pair<>(name + "." + signatureEntry.getKey() + "." + outputEntry.getKey(),
+ names.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
expressions().get(outputEntry.getValue())));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
- names.add(new Pair<>(name + "." + signatureEntry.getKey(),
+ names.add(new Pair<>(signatureEntry.getKey(),
expressions().get(signatureEntry.getKey())));
}
if (signatures().isEmpty()) { // fallback for models without signatures
- if (expressions().size() == 1) {// Use just model name
- names.add(new Pair<>(name,
- expressions().values().iterator().next()));
+ if (expressions().size() == 1) {
+ Map.Entry<String, RankingExpression> singleEntry = expressions.entrySet().iterator().next();
+ names.add(new Pair<>(singleEntry.getKey(), singleEntry.getValue()));
}
else {
for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
- names.add(new Pair<>(name + "." + expressionEntry.getKey(),
- expressionEntry.getValue()));
+ names.add(new Pair<>(expressionEntry.getKey(), expressionEntry.getValue()));
}
}
}