diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-25 22:59:28 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-25 22:59:28 +0200 |
commit | 727ae90506e72ed0a6695e2d7cb5c719f0152842 (patch) | |
tree | 1011ce314160c766e119a42c67daf6bc35980fe4 | |
parent | ccda281b6c60de0e6c7108a8532d7f7438ebd9ae (diff) | |
parent | b525b8d8efcf71b421db1e549e4f078514e26135 (diff) |
Merge pull request #6674 from vespa-engine/bratseth/generate-rank-profiles-for-all-models-part-8
Improve evaluation API
10 files changed, 208 insertions, 106 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index 629fa9624c5..b645af582e1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -112,7 +112,7 @@ public class ConvertedModel { if ( ! arguments.output().isPresent()) { List<Map.Entry<String, RankingExpression>> entriesWithTheRightPrefix = - expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(modelName + "." + arguments.signature().get() + ".")).collect(Collectors.toList()); + expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(arguments.signature().get() + ".")).collect(Collectors.toList()); if (entriesWithTheRightPrefix.size() < 1) throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() + missingExpressionMessageSuffix()); @@ -720,8 +720,7 @@ public class ConvertedModel { public Optional<String> output() { return output; } public String toName() { - return modelName + - (signature.isPresent() ? "." + signature.get() : "") + + return (signature.isPresent() ? signature.get() : "") + (output.isPresent() ? "." + output.get() : ""); } diff --git a/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java b/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java index ded8d88aa99..8331ada2271 100644 --- a/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java @@ -2,7 +2,6 @@ package com.yahoo.config.model; import com.google.common.io.Files; -import com.yahoo.component.Version; import com.yahoo.config.ConfigInstance; import com.yahoo.config.application.api.ApplicationMetaData; import com.yahoo.config.application.api.UnparsedConfigDefinition; @@ -18,9 +17,7 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.UnproperSearch; import com.yahoo.vespa.config.ConfigDefinition; import com.yahoo.vespa.config.ConfigDefinitionKey; -import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.model.VespaModel; -import com.yahoo.vespa.model.container.ContainerCluster; import com.yahoo.vespa.model.search.SearchDefinition; import org.junit.After; import org.junit.Rule; @@ -37,11 +34,9 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.jar.JarEntry; import java.util.jar.JarFile; import java.util.regex.Pattern; -import java.util.stream.Collectors; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.is; @@ -63,10 +58,10 @@ public class ApplicationDeployTest { @Test public void testVespaModel() throws SAXException, IOException { - FilesApplicationPackage app = createAppPkg(TESTDIR + "app1"); - assertThat(app.getApplicationName(), is("app1")); - VespaModel model = new VespaModel(app); - List<SearchDefinition> searchDefinitions = getSearchDefinitions(app); + ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "app1"); + assertThat(tester.app().getApplicationName(), is("app1")); + VespaModel model = new VespaModel(tester.app()); + List<SearchDefinition> searchDefinitions = tester.getSearchDefinitions(); assertEquals(searchDefinitions.size(), 5); for (SearchDefinition searchDefinition : searchDefinitions) { Search s = searchDefinition.getSearch(); @@ -90,11 +85,11 @@ public class ApplicationDeployTest { new File(TESTSDDIR + "product.sd"), new File(TESTSDDIR + "sock.sd")}; Arrays.sort(truth); - List<File> appSdFiles = app.getSearchDefinitionFiles(); + List<File> appSdFiles = tester.app().getSearchDefinitionFiles(); Collections.sort(appSdFiles); assertEquals(appSdFiles, Arrays.asList(truth)); - List<FilesApplicationPackage.Component> components = app.getComponents(); + List<FilesApplicationPackage.Component> components = tester.app().getComponents(); assertEquals(1, components.size()); Map<String, Bundle.DefEntry> defEntriesByName = defEntries2map(components.get(0).getDefEntries()); @@ -122,42 +117,24 @@ public class ApplicationDeployTest { } @Test - public void testMl_ServingApplication() throws SAXException, IOException { - FilesApplicationPackage app = createAppPkg(TESTDIR + "ml_serving"); - VespaModel model = new VespaModel(app); - ContainerCluster cluster = model.getContainerClusters().get("container"); - RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); - cluster.getConfig(b); - RankProfilesConfig config = new RankProfilesConfig(b); - assertEquals(3, config.rankprofile().size()); - Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); - assertTrue(modelNames.contains("xgboost_2_2_json")); - assertTrue(modelNames.contains("mnist_softmax_onnx")); - assertTrue(modelNames.contains("mnist_softmax_saved")); - } - - @Test public void testGetFile() throws IOException { - FilesApplicationPackage app = createAppPkg(TESTDIR + "app1"); - try (Reader foo = app.getFile(Path.fromString("files/foo.json")).createReader()) { + ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "app1"); + try (Reader foo = tester.app().getFile(Path.fromString("files/foo.json")).createReader()) { assertEquals(IOUtils.readAll(foo), "foo : foo\n"); } - try (Reader bar = app.getFile(Path.fromString("files/sub/bar.json")).createReader()) { + try (Reader bar = tester.app().getFile(Path.fromString("files/sub/bar.json")).createReader()) { assertEquals(IOUtils.readAll(bar), "bar : bar\n"); } - assertTrue(app.getFile(Path.createRoot()).exists()); - assertTrue(app.getFile(Path.createRoot()).isDirectory()); + assertTrue(tester.app().getFile(Path.createRoot()).exists()); + assertTrue(tester.app().getFile(Path.createRoot()).isDirectory()); } /* * Put a list of def entries to a map, with the name as key. This is done because the order * of the def entries in the list cannot be guaranteed. */ - private Map<String, Bundle.DefEntry> defEntries2map - (List<Bundle.DefEntry> defEntries) { - Map<String, Bundle.DefEntry> ret = - new HashMap<>(); - + private Map<String, Bundle.DefEntry> defEntries2map(List<Bundle.DefEntry> defEntries) { + Map<String, Bundle.DefEntry> ret = new HashMap<>(); for (Bundle.DefEntry def : defEntries) ret.put(def.defName, def); return ret; @@ -166,8 +143,8 @@ public class ApplicationDeployTest { @Test public void testSdFromDocprocBundle() throws IOException, SAXException { String appDir = "src/test/cfg/application/app_sdbundles"; - FilesApplicationPackage app = createAppPkg(appDir); - VespaModel model = new VespaModel(app); + ApplicationPackageTester tester = ApplicationPackageTester.create(appDir); + VespaModel model = new VespaModel(tester.app()); // Check that the resulting documentmanager config contains those types DocumentmanagerConfig.Builder b = new DocumentmanagerConfig.Builder(); model.getConfig(b, VespaModel.ROOT_CONFIGID); @@ -188,10 +165,10 @@ public class ApplicationDeployTest { } @Test - public void include_dirs_are_included() throws Exception { - FilesApplicationPackage app = createAppPkg(TESTDIR + "include_dirs"); + public void include_dirs_are_included() { + ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "include_dirs"); - List<String> includeDirs = app.getUserIncludeDirs(); + List<String> includeDirs = tester.app().getUserIncludeDirs(); assertThat(includeDirs, contains("jdisc_dir", "dir1", "dir2", "empty_dir")); } @@ -216,33 +193,33 @@ public class ApplicationDeployTest { public void testThatModelIsRebuiltWhenSearchDefinitionIsAdded() throws IOException { File tmpDir = tmpFolder.getRoot(); IOUtils.copyDirectory(new File(TESTDIR, "app1"), tmpDir); - FilesApplicationPackage app = createAppPkg(tmpDir.getAbsolutePath()); - assertEquals(5, getSearchDefinitions(app).size()); + ApplicationPackageTester tester = ApplicationPackageTester.create(tmpDir.getAbsolutePath()); + assertEquals(5, tester.getSearchDefinitions().size()); File sdDir = new File(tmpDir, "searchdefinitions"); File sd = new File(sdDir, "testfoo.sd"); IOUtils.writeFile(sd, "search testfoo { document testfoo { field bar type string { } } }", false); - assertEquals(6, getSearchDefinitions(app).size()); + assertEquals(6, tester.getSearchDefinitions().size()); } @Test public void testThatAppWithDeploymentXmlIsValid() throws IOException { File tmpDir = tmpFolder.getRoot(); IOUtils.copyDirectory(new File(TESTDIR, "app1"), tmpDir); - createAppPkg(tmpDir.getAbsolutePath()); + ApplicationPackageTester.create(tmpDir.getAbsolutePath()); } @Test(expected = IllegalArgumentException.class) public void testThatAppWithIllegalDeploymentXmlIsNotValid() throws IOException { File tmpDir = tmpFolder.getRoot(); IOUtils.copyDirectory(new File(TESTDIR, "app_invalid_deployment_xml"), tmpDir); - createAppPkg(tmpDir.getAbsolutePath()); + ApplicationPackageTester.create(tmpDir.getAbsolutePath()); } @Test public void testThatAppWithIllegalEmptyProdRegion() throws IOException { File tmpDir = tmpFolder.getRoot(); IOUtils.copyDirectory(new File(TESTDIR, "empty_prod_region_in_deployment_xml"), tmpDir); - createAppPkg(tmpDir.getAbsolutePath()); + ApplicationPackageTester.create(tmpDir.getAbsolutePath()); } @Test @@ -250,48 +227,13 @@ public class ApplicationDeployTest { File tmpDir = tmpFolder.getRoot(); IOUtils.copyDirectory(new File(TESTDIR, "invalid_parallel_deployment_xml"), tmpDir); try { - createAppPkg(tmpDir.getAbsolutePath()); + ApplicationPackageTester.create(tmpDir.getAbsolutePath()); fail("Expected exception"); } catch (IllegalArgumentException e) { assertThat(e.getMessage(), containsString("element \"delay\" not allowed here")); } } - private List<SearchDefinition> getSearchDefinitions(FilesApplicationPackage app) { - return new DeployState.Builder().applicationPackage(app).build().getSearchDefinitions(); - } - - public FilesApplicationPackage createAppPkg(String appPkg) throws IOException { - return createAppPkg(appPkg, true); - } - - public FilesApplicationPackage createAppPkgDoNotValidateXml(String appPkg) throws IOException { - return createAppPkg(appPkg, false); - } - - public FilesApplicationPackage createAppPkg(String appPkg, boolean validateXml) throws IOException { - final FilesApplicationPackage filesApplicationPackage = FilesApplicationPackage.fromFile(new File(appPkg)); - if (validateXml) { - ApplicationPackageXmlFilesValidator validator = - ApplicationPackageXmlFilesValidator.create(new File(appPkg), new Version(6)); - validator.checkApplication(); - validator.checkIncludedDirs(filesApplicationPackage); - } - return filesApplicationPackage; - } - - @Test - public void testThatNewServicesFileNameWorks() throws IOException { - String appPkg = TESTDIR + "newfilenames"; - assertEquals(appPkg + "/services.xml", createAppPkgDoNotValidateXml(appPkg).getServicesSource()); - } - - @Test - public void testThatNewHostsFileNameWorks() throws IOException { - String appPkg = TESTDIR + "newfilenames"; - assertEquals(appPkg + "/hosts.xml", createAppPkgDoNotValidateXml(appPkg).getHostSource()); - } - @Test public void testGetJars() throws IOException { String jarName = "src/test/cfg/application/app_sdbundles/components/testbundle.jar"; @@ -412,9 +354,9 @@ public class ApplicationDeployTest { } @Test(expected=IllegalArgumentException.class) - public void testDifferentNameOfSdFileAndSearchName() throws IOException { - FilesApplicationPackage app = createAppPkg(TESTDIR + "sdfilenametest"); - new DeployState.Builder().applicationPackage(app).build(); + public void testDifferentNameOfSdFileAndSearchName() { + ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "sdfilenametest"); + new DeployState.Builder().applicationPackage(tester.app()).build(); } } diff --git a/config-model/src/test/java/com/yahoo/config/model/ApplicationPackageTester.java b/config-model/src/test/java/com/yahoo/config/model/ApplicationPackageTester.java new file mode 100644 index 00000000000..3e052421684 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/config/model/ApplicationPackageTester.java @@ -0,0 +1,54 @@ +package com.yahoo.config.model; + +import com.yahoo.component.Version; +import com.yahoo.config.model.application.provider.ApplicationPackageXmlFilesValidator; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.vespa.model.search.SearchDefinition; + +import java.io.File; +import java.io.IOException; +import java.util.List; + +/** + * Helper for tests using a file application package + * + * @author bratseth + */ +public class ApplicationPackageTester { + + private final FilesApplicationPackage applicationPackage; + + private ApplicationPackageTester(String applicationPackageDir, boolean validate) { + try { + FilesApplicationPackage applicationPackage = + FilesApplicationPackage.fromFile(new File(applicationPackageDir)); + if (validate) { + ApplicationPackageXmlFilesValidator validator = + ApplicationPackageXmlFilesValidator.create(new File(applicationPackageDir), new Version(6)); + validator.checkApplication(); + validator.checkIncludedDirs(applicationPackage); + } + this.applicationPackage = applicationPackage; + } + catch (IOException e) { + throw new IllegalArgumentException("Could not create an application package from '" + + applicationPackageDir + "'", e); + } + } + + public FilesApplicationPackage app() { return applicationPackage; } + + public List<SearchDefinition> getSearchDefinitions() { + return new DeployState.Builder().applicationPackage(app()).build().getSearchDefinitions(); + } + + public static ApplicationPackageTester create(String applicationPackageDir) { + return new ApplicationPackageTester(applicationPackageDir, true); + } + + public static ApplicationPackageTester createWithoutValidation(String applicationPackageDir) { + return new ApplicationPackageTester(applicationPackageDir, false); + } + +} diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java new file mode 100644 index 00000000000..8eccc4e7d06 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java @@ -0,0 +1,66 @@ +package com.yahoo.config.model; + +import ai.vespa.models.evaluation.Model; +import ai.vespa.models.evaluation.ModelsEvaluator; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.container.ContainerCluster; +import org.junit.Test; +import org.xml.sax.SAXException; + +import java.io.IOException; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * @author bratseth + */ +public class ModelEvaluationTest { + + private static final String TESTDIR = "src/test/cfg/application/"; + + @Test + public void testMl_ServingApplication() throws SAXException, IOException { + ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "ml_serving"); + VespaModel model = new VespaModel(tester.app()); + ContainerCluster cluster = model.getContainerClusters().get("container"); + RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); + cluster.getConfig(b); + RankProfilesConfig config = new RankProfilesConfig(b); + System.out.println(config.rankprofile(2).toString()); + assertEquals(3, config.rankprofile().size()); + Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); + assertTrue(modelNames.contains("xgboost_2_2_json")); + assertTrue(modelNames.contains("mnist_softmax_onnx")); + assertTrue(modelNames.contains("mnist_softmax_saved")); + + ModelsEvaluator evaluator = new ModelsEvaluator(config); + + assertEquals(3, evaluator.models().size()); + Model xgboost = evaluator.models().get("xgboost_2_2_json"); + assertNotNull(xgboost); + assertNotNull(xgboost.evaluatorOf()); + assertNotNull(xgboost.evaluatorOf("xgboost_2_2_json")); + System.out.println("xgboost functions: " + xgboost.functions().stream().map(f -> f.getName()).collect(Collectors.joining(", "))); + + Model onnx = evaluator.models().get("mnist_softmax_onnx"); + assertNotNull(onnx); + assertNotNull(onnx.evaluatorOf()); + assertNotNull(onnx.evaluatorOf("default")); + assertNotNull(onnx.evaluatorOf("default", "add")); + System.out.println("onnx functions: " + onnx.functions().stream().map(f -> f.getName()).collect(Collectors.joining(", "))); + + Model tensorflow = evaluator.models().get("mnist_softmax_saved"); + assertNotNull(tensorflow); + assertNotNull(tensorflow.evaluatorOf()); + assertNotNull(tensorflow.evaluatorOf("serving_default")); + assertNotNull(tensorflow.evaluatorOf("serving_default", "y")); + System.out.println("tensorflow functions: " + tensorflow.functions().stream().map(f -> f.getName()).collect(Collectors.joining(", "))); + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 9bbc1347aeb..f67c85e2881 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -82,7 +82,8 @@ class RankProfileSearchFixture { } public RankProfile compileRankProfile(String rankProfile, Path applicationDir) { - RankProfile compiled = rankProfileRegistry.get(search, rankProfile).compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile())); + RankProfile compiled = rankProfileRegistry.get(search, rankProfile) + .compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile())); compiledRankProfiles.put(rankProfile, compiled); return compiled; } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 4db5f312cae..a96a3ce798b 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -169,7 +169,7 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx','y'): " + - "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: mnist_softmax_onnx.default.add", + "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.add", Exceptions.toMessageString(expected)); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index a212726efda..c317f07b87a 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -205,7 +205,7 @@ public class RankingExpressionWithTensorFlowTestCase { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved','serving_defaultz'): " + "No expressions named 'serving_defaultz' in model 'mnist_softmax/saved'. "+ - "Available expressions: mnist_softmax_saved.serving_default.y", + "Available expressions: serving_default.y", Exceptions.toMessageString(expected)); } } @@ -221,8 +221,8 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved','serving_default','x'): " + - "No expression 'mnist_softmax_saved.serving_default.x' in model 'mnist_softmax/saved'. " + - "Available expressions: mnist_softmax_saved.serving_default.y", + "No expression 'serving_default.x' in model 'mnist_softmax/saved'. " + + "Available expressions: serving_default.y", Exceptions.toMessageString(expected)); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java index 3b50cef6e2e..00fcad94ce8 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java @@ -22,7 +22,7 @@ import java.util.regex.Pattern; class FunctionReference { private static final Pattern referencePattern = - Pattern.compile("rankingExpression\\(([a-zA-Z0-9_]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?"); + Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?"); /** The name of the function referenced */ private final String name; diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index ca739195867..d8b7e82677c 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -116,14 +117,54 @@ public class Model { /** * Returns an evaluator which can be used to evaluate the given function in a single thread once. - + * * Usage: * <code>Tensor result = model.evaluatorOf("myFunction").bind("foo", value).bind("bar", value).evaluate()</code> * - * @throws IllegalArgumentException if the function is not present + * @param names the names identifying the function - this can be from 0 to 2, specifying function or "signature" + * name, and "output", respectively. Names which are unnecessary to determine the desired function + * uniquely (e.g if there is just one function or output) can be omitted. + * @throws IllegalArgumentException if the function is not present, or not uniquely identified by the names given */ - public FunctionEvaluator evaluatorOf(String function) { // TODO: Parameter overloading? - return new FunctionEvaluator(requireFunction(function), requireContextProprotype(function).copy()); + public FunctionEvaluator evaluatorOf(String ... names) { // TODO: Parameter overloading? + if (names.length == 0) { + if (functions.size() > 1) + throwUndeterminedFunction("More than one function is available in " + this + ", but no name is given"); + return evaluatorOf(functions.get(0)); + } + else if (names.length == 1) { + String name = names[0]; + ExpressionFunction function = function(name); + if (function != null) return evaluatorOf(function); + + List<ExpressionFunction> functionsStartingByName = + functions.stream().filter(f -> f.getName().startsWith(name + ".")).collect(Collectors.toList()); + if (functionsStartingByName.size() == 0) + throwUndeterminedFunction("No function '" + name + "' in " + this); + else if (functionsStartingByName.size() == 1) + return evaluatorOf(functionsStartingByName.get(0)); + else + throwUndeterminedFunction("Multiple functions start by '" + name + "' in " + this); + + } + else if (names.length == 2) { + String name = names[0] + "." + names[1]; + ExpressionFunction function = function(name); + if (function == null) throwUndeterminedFunction("No function '" + name + "' in " + this); + return evaluatorOf(function); + } + throw new IllegalArgumentException("No more than 2 names can be given when choosing a function, got " + + Arrays.toString(names)); + } + + /** Returns a single-use evaluator of a function */ + private FunctionEvaluator evaluatorOf(ExpressionFunction function) { + return new FunctionEvaluator(function, requireContextProprotype(function.getName()).copy()); + } + + private void throwUndeterminedFunction(String message) { + throw new IllegalArgumentException(message + ". Available functions: " + + functions.stream().map(f -> f.getName()).collect(Collectors.joining(", "))); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index 045844ee219..6716993e1dd 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -107,21 +107,20 @@ public class ImportedModel { List<Pair<String, RankingExpression>> names = new ArrayList<>(); for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) - names.add(new Pair<>(name + "." + signatureEntry.getKey() + "." + outputEntry.getKey(), + names.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(), expressions().get(outputEntry.getValue()))); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs - names.add(new Pair<>(name + "." + signatureEntry.getKey(), + names.add(new Pair<>(signatureEntry.getKey(), expressions().get(signatureEntry.getKey()))); } if (signatures().isEmpty()) { // fallback for models without signatures - if (expressions().size() == 1) {// Use just model name - names.add(new Pair<>(name, - expressions().values().iterator().next())); + if (expressions().size() == 1) { + Map.Entry<String, RankingExpression> singleEntry = expressions.entrySet().iterator().next(); + names.add(new Pair<>(singleEntry.getKey(), singleEntry.getValue())); } else { for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { - names.add(new Pair<>(name + "." + expressionEntry.getKey(), - expressionEntry.getValue())); + names.add(new Pair<>(expressionEntry.getKey(), expressionEntry.getValue())); } } } |