diff options
274 files changed, 4135 insertions, 1122 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index e97f601fb74..26e9fa8a52a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -81,6 +81,7 @@ add_subdirectory(messagebus) add_subdirectory(messagebus_test) add_subdirectory(metrics) add_subdirectory(model-evaluation) +add_subdirectory(model-integration) add_subdirectory(node-repository) add_subdirectory(orchestrator) add_subdirectory(persistence) diff --git a/application/pom.xml b/application/pom.xml index 10d2a14721e..9d15ec86a21 100644 --- a/application/pom.xml +++ b/application/pom.xml @@ -36,6 +36,11 @@ </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> + <artifactId>model-integration</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> <artifactId>jrt</artifactId> <version>${project.version}</version> </dependency> diff --git a/clustercontroller-apps/src/main/java/com/yahoo/vespa/clustercontroller/apps/clustercontroller/ClusterControllerClusterConfigurer.java b/clustercontroller-apps/src/main/java/com/yahoo/vespa/clustercontroller/apps/clustercontroller/ClusterControllerClusterConfigurer.java index 6c8af75efac..b15cb2ad399 100644 --- a/clustercontroller-apps/src/main/java/com/yahoo/vespa/clustercontroller/apps/clustercontroller/ClusterControllerClusterConfigurer.java +++ b/clustercontroller-apps/src/main/java/com/yahoo/vespa/clustercontroller/apps/clustercontroller/ClusterControllerClusterConfigurer.java @@ -73,7 +73,6 @@ public class ClusterControllerClusterConfigurer { options.distributionBits = config.ideal_distribution_bits(); options.minNodeRatioPerGroup = config.min_node_ratio_per_group(); options.setMaxDeferredTaskVersionWaitTime(Duration.ofMillis((int)(config.max_deferred_task_version_wait_time_sec() * 1000))); - options.enableMultipleBucketSpaces = config.enable_multiple_bucket_spaces(); options.clusterHasGlobalDocumentTypes = config.cluster_has_global_document_types(); options.minMergeCompletionRatio = config.min_merge_completion_ratio(); } diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetController.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetController.java index 56fe679fc6a..005bf7971a5 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetController.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetController.java @@ -463,14 +463,9 @@ public class FleetController implements NodeStateOrHostInfoChangeHandler, NodeAd cluster.setSlobrokGenerationCount(0); } - // TODO don't hardcode bucket spaces - if (options.enableMultipleBucketSpaces) { - configuredBucketSpaces = Collections.unmodifiableSet( - Stream.of(FixedBucketSpaces.defaultSpace(), FixedBucketSpaces.globalSpace()) - .collect(Collectors.toSet())); - } else { - configuredBucketSpaces = Collections.emptySet(); - } + configuredBucketSpaces = Collections.unmodifiableSet( + Stream.of(FixedBucketSpaces.defaultSpace(), FixedBucketSpaces.globalSpace()) + .collect(Collectors.toSet())); stateVersionTracker.setMinMergeCompletionRatio(options.minMergeCompletionRatio); communicator.propagateOptions(options); diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetControllerOptions.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetControllerOptions.java index 31268e78338..e069dde1901 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetControllerOptions.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetControllerOptions.java @@ -119,9 +119,6 @@ public class FleetControllerOptions implements Cloneable { private Duration maxDeferredTaskVersionWaitTime = Duration.ofSeconds(30); - // TODO replace this flag with a set of bucket spaces instead - public boolean enableMultipleBucketSpaces = false; - public boolean clusterHasGlobalDocumentTypes = false; // TODO: Choose a default value @@ -233,7 +230,6 @@ public class FleetControllerOptions implements Cloneable { sb.append("<tr><td><nobr>Maximum node event log size</nobr></td><td align=\"right\">").append(eventNodeLogMaxSize).append("</td></tr>"); sb.append("<tr><td><nobr>Wanted distribution bits</nobr></td><td align=\"right\">").append(distributionBits).append("</td></tr>"); sb.append("<tr><td><nobr>Max deferred task version wait time</nobr></td><td align=\"right\">").append(maxDeferredTaskVersionWaitTime.toMillis()).append("ms</td></tr>"); - sb.append("<tr><td><nobr>Multiple bucket spaces enabled</nobr></td><td align=\"right\">").append(enableMultipleBucketSpaces).append("</td></tr>"); sb.append("<tr><td><nobr>Cluster has global document types configured</nobr></td><td align=\"right\">").append(clusterHasGlobalDocumentTypes).append("</td></tr>"); sb.append("</table>"); diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/MasterElectionTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/MasterElectionTest.java index 07d176745bc..9d6e39f244a 100644 --- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/MasterElectionTest.java +++ b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/MasterElectionTest.java @@ -501,7 +501,6 @@ public class MasterElectionTest extends FleetControllerTest { public void previously_published_state_is_taken_into_account_for_default_space_when_controller_bootstraps() throws Exception { startingTest("MasterElectionTest::previously_published_state_is_taken_into_account_for_default_space_when_controller_bootstraps"); FleetControllerOptions options = new FleetControllerOptions("mycluster"); - options.enableMultipleBucketSpaces = true; options.clusterHasGlobalDocumentTypes = true; options.masterZooKeeperCooldownPeriod = 1; options.minTimeBeforeFirstSystemStateBroadcast = 100000; @@ -545,7 +544,6 @@ public class MasterElectionTest extends FleetControllerTest { public void default_space_nodes_not_marked_as_maintenance_when_cluster_has_no_global_document_types() throws Exception { startingTest("MasterElectionTest::default_space_nodes_not_marked_as_maintenance_when_cluster_has_no_global_document_types"); FleetControllerOptions options = new FleetControllerOptions("mycluster"); - options.enableMultipleBucketSpaces = true; options.clusterHasGlobalDocumentTypes = false; options.masterZooKeeperCooldownPeriod = 1; options.minTimeBeforeFirstSystemStateBroadcast = 100000; diff --git a/component/src/main/java/com/yahoo/component/provider/ComponentRegistry.java b/component/src/main/java/com/yahoo/component/provider/ComponentRegistry.java index 03b0285639f..67de9c094e8 100644 --- a/component/src/main/java/com/yahoo/component/provider/ComponentRegistry.java +++ b/component/src/main/java/com/yahoo/component/provider/ComponentRegistry.java @@ -3,14 +3,11 @@ package com.yahoo.component.provider; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.yahoo.component.Component; import com.yahoo.component.ComponentId; import com.yahoo.component.ComponentSpecification; import com.yahoo.component.Version; import com.yahoo.component.VersionSpecification; -import java.util.ArrayList; -import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -148,7 +145,7 @@ public class ComponentRegistry<COMPONENT> { * * @return the matching version, or null if there are no matches */ - protected static Version findBestMatch(VersionSpecification versionSpec,Set<Version> versions) { + protected static Version findBestMatch(VersionSpecification versionSpec, Set<Version> versions) { Version bestMatch=null; for (Version version : versions) { //No version is set if getSpecifiedMajor() == null diff --git a/config-application-package/src/main/java/com/yahoo/config/application/XmlPreProcessor.java b/config-application-package/src/main/java/com/yahoo/config/application/XmlPreProcessor.java index 0bb160319c0..261a684353b 100644 --- a/config-application-package/src/main/java/com/yahoo/config/application/XmlPreProcessor.java +++ b/config-application-package/src/main/java/com/yahoo/config/application/XmlPreProcessor.java @@ -61,7 +61,7 @@ public class XmlPreProcessor { return input; } - private List<PreProcessor> setupChain() throws IOException { + private List<PreProcessor> setupChain() { List<PreProcessor> chain = new ArrayList<>(); chain.add(new IncludeProcessor(applicationDir)); chain.add(new OverrideProcessor(environment, region)); diff --git a/config-application-package/src/test/java/com/yahoo/config/application/MultiOverrideProcessorTest.java b/config-application-package/src/test/java/com/yahoo/config/application/MultiOverrideProcessorTest.java new file mode 100644 index 00000000000..c450e478c85 --- /dev/null +++ b/config-application-package/src/test/java/com/yahoo/config/application/MultiOverrideProcessorTest.java @@ -0,0 +1,136 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.config.application; + +import com.yahoo.config.provision.Environment; +import com.yahoo.config.provision.RegionName; +import org.custommonkey.xmlunit.XMLUnit; +import org.junit.Test; +import org.w3c.dom.Document; + +import javax.xml.transform.TransformerException; +import java.io.StringReader; + +/** + * Demonstrates that only the most specific match is retained and that this can be overridden by using ids. + * + * @author bratseth + */ +public class MultiOverrideProcessorTest { + + static { + XMLUnit.setIgnoreWhitespace(true); + } + + private static final String input = + "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" + + "<services version=\"1.0\" xmlns:deploy=\"vespa\">\n" + + " <container id='qrserver' version='1.0'>\n" + + " <component id=\"comp-B\" class=\"com.yahoo.ls.MyComponent\" bundle=\"lsbe-hv\">\n" + + " <config name=\"ls.config.resource-pool\">\n" + + " <resource>\n" + + " <item>\n" + + " <id>comp-B-item-0</id>\n" + + " <type></type>\n" + + " </item>\n" + + " <item deploy:environment=\"dev perf test staging prod\" deploy:region=\"us-west-1 us-east-3\">\n" + + " <id>comp-B-item-1</id>\n" + + " <type></type>\n" + + " </item>\n" + + " <item>\n" + + " <id>comp-B-item-2</id>\n" + + " <type></type>\n" + + " </item>\n" + + " </resource>\n" + + " </config>\n" + + " </component>\n" + + " </container>\n" + + "</services>\n"; + + private static final String inputWithIds = + "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" + + "<services version=\"1.0\" xmlns:deploy=\"vespa\">\n" + + " <container id='qrserver' version='1.0'>\n" + + " <component id=\"comp-B\" class=\"com.yahoo.ls.MyComponent\" bundle=\"lsbe-hv\">\n" + + " <config name=\"ls.config.resource-pool\">\n" + + " <resource>\n" + + " <item id='1'>\n" + + " <id>comp-B-item-0</id>\n" + + " <type></type>\n" + + " </item>\n" + + " <item id='2' deploy:environment=\"dev perf test staging prod\" deploy:region=\"us-west-1 us-east-3\">\n" + + " <id>comp-B-item-1</id>\n" + + " <type></type>\n" + + " </item>\n" + + " <item id='3'>\n" + + " <id>comp-B-item-2</id>\n" + + " <type></type>\n" + + " </item>\n" + + " </resource>\n" + + " </config>\n" + + " </component>\n" + + " </container>\n" + + "</services>\n"; + + @Test + public void testParsingDev() throws TransformerException { + String expected = + "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" + + "<services version=\"1.0\" xmlns:deploy=\"vespa\">\n" + + " <container id='qrserver' version='1.0'>\n" + + " <component id=\"comp-B\" class=\"com.yahoo.ls.MyComponent\" bundle=\"lsbe-hv\">\n" + + " <config name=\"ls.config.resource-pool\">\n" + + " <resource>\n" + + " <item>\n" + + " <id>comp-B-item-1</id>\n" + + " <type></type>\n" + + " </item>\n" + + " </resource>\n" + + " </config>\n" + + " </component>\n" + + " </container>\n" + + "</services>"; + assertOverride(Environment.dev, RegionName.defaultName(), expected); + } + + @Test + public void testParsingDevWithIds() throws TransformerException { + String expected = + "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" + + "<services version=\"1.0\" xmlns:deploy=\"vespa\">\n" + + " <container id='qrserver' version='1.0'>\n" + + " <component id=\"comp-B\" class=\"com.yahoo.ls.MyComponent\" bundle=\"lsbe-hv\">\n" + + " <config name=\"ls.config.resource-pool\">\n" + + " <resource>\n" + + " <item id='1'>\n" + + " <id>comp-B-item-0</id>\n" + + " <type></type>\n" + + " </item>\n" + + " <item id='2'>\n" + + " <id>comp-B-item-1</id>\n" + + " <type></type>\n" + + " </item>\n" + + " <item id='3'>\n" + + " <id>comp-B-item-2</id>\n" + + " <type></type>\n" + + " </item>\n" + + " </resource>\n" + + " </config>\n" + + " </component>\n" + + " </container>\n" + + "</services>"; + assertOverrideWithIds(Environment.dev, RegionName.defaultName(), expected); + } + + private void assertOverride(Environment environment, RegionName region, String expected) throws TransformerException { + Document inputDoc = Xml.getDocument(new StringReader(input)); + Document newDoc = new OverrideProcessor(environment, region).process(inputDoc); + TestBase.assertDocument(expected, newDoc); + } + + private void assertOverrideWithIds(Environment environment, RegionName region, String expected) throws TransformerException { + Document inputDoc = Xml.getDocument(new StringReader(inputWithIds)); + Document newDoc = new OverrideProcessor(environment, region).process(inputDoc); + TestBase.assertDocument(expected, newDoc); + } + +} diff --git a/config-application-package/src/test/java/com/yahoo/config/application/OverrideProcessorTest.java b/config-application-package/src/test/java/com/yahoo/config/application/OverrideProcessorTest.java index 62e6671120b..e4690418847 100644 --- a/config-application-package/src/test/java/com/yahoo/config/application/OverrideProcessorTest.java +++ b/config-application-package/src/test/java/com/yahoo/config/application/OverrideProcessorTest.java @@ -38,7 +38,7 @@ public class OverrideProcessorTest { " <document mode='index' type='music'/>\n" + " <document type='music2' mode='index' />\n" + " <document deploy:environment='prod' deploy:region='us-east-3' mode='index' type='music'/>\n" + - " <document deploy:environment='prod' deploy:region='us-east-3' mode='index' type='music2'/>\n" + + " <document deploy:environment='staging prod' deploy:region='us-east-3' mode='index' type='music2'/>\n" + " <document deploy:environment='prod' mode='index' type='music3'/>\n" + " <document deploy:environment='prod' deploy:region='us-west' mode='index' type='music4'/>\n" + " </documents>" + @@ -272,6 +272,7 @@ public class OverrideProcessorTest { " <redundancy>1</redundancy>" + " <documents>" + " <document mode='index' type='music'/>\n" + + " <document mode='index' type='music2'/>\n" + " <document type='music2' mode='index' />\n" + " </documents>" + " <nodes>" + diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlFunction.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlFunction.java new file mode 100644 index 00000000000..54cdf807878 --- /dev/null +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlFunction.java @@ -0,0 +1,37 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.config.model.api; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * An imported function of an imported machine-learned model + * + * @author bratseth + */ +public class ImportedMlFunction { + + private final String name; + private final List<String> arguments; + private final Map<String, String> argumentTypes; + private final String expression; + private final Optional<String> returnType; + + public ImportedMlFunction(String name, List<String> arguments, String expression, + Map<String, String> argumentTypes, Optional<String> returnType) { + this.name = name; + this.arguments = Collections.unmodifiableList(arguments); + this.expression = expression; + this.argumentTypes = Collections.unmodifiableMap(argumentTypes); + this.returnType = returnType; + } + + public String name() { return name; } + public List<String> arguments() { return arguments; } + public Map<String, String> argumentTypes() { return argumentTypes; } + public String expression() { return expression; } + public Optional<String> returnType() { return returnType; } + +} diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModel.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModel.java new file mode 100644 index 00000000000..078e4c239d6 --- /dev/null +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModel.java @@ -0,0 +1,23 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.config.model.api; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Config model view of an imported machine-learned model. + * + * @author bratseth + */ +public interface ImportedMlModel { + + String name(); + String source(); + Optional<String> inputTypeSpec(String input); + Map<String, String> smallConstants(); + Map<String, String> largeConstants(); + Map<String, String> functions(); + List<ImportedMlFunction> outputExpressions(); + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModels.java index 40d1ca8030a..aeef81788b8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModels.java @@ -1,13 +1,12 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.config.model.api; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.yahoo.path.Path; import java.io.File; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -19,71 +18,66 @@ import java.util.Optional; * * @author bratseth */ -public class ImportedModels { +public class ImportedMlModels { /** All imported models, indexed by their names */ - private final ImmutableMap<String, ImportedModel> importedModels; - - private static final ImmutableList<ModelImporter> importers = - ImmutableList.of(new TensorFlowImporter(), new OnnxImporter(), new XGBoostImporter()); + private final Map<String, ImportedMlModel> importedModels; /** Create a null imported models */ - public ImportedModels() { - importedModels = ImmutableMap.of(); + public ImportedMlModels() { + importedModels = Collections.emptyMap(); } - public ImportedModels(File modelsDirectory) { - Map<String, ImportedModel> models = new HashMap<>(); + public ImportedMlModels(File modelsDirectory, Collection<MlModelImporter> importers) { + Map<String, ImportedMlModel> models = new HashMap<>(); // Find all subdirectories recursively which contains a model we can read - importRecursively(modelsDirectory, models); - importedModels = ImmutableMap.copyOf(models); + importRecursively(modelsDirectory, models, importers); + importedModels = Collections.unmodifiableMap(models); } - private static void importRecursively(File dir, Map<String, ImportedModel> models) { + /** + * Returns the model at the given location in the application package. + * + * @param modelPath the path to this model (file or directory, depending on model type) + * under the application package, both from the root or relative to the + * models directory works + * @return the model at this path or null if none + */ + public ImportedMlModel get(File modelPath) { + return importedModels.get(toName(modelPath)); + } + + /** Returns an immutable collection of all the imported models */ + public Collection<ImportedMlModel> all() { + return importedModels.values(); + } + + private static void importRecursively(File dir, + Map<String, ImportedMlModel> models, + Collection<MlModelImporter> importers) { if ( ! dir.isDirectory()) return; Arrays.stream(dir.listFiles()).sorted().forEach(child -> { - Optional<ModelImporter> importer = findImporterOf(child); + Optional<MlModelImporter> importer = findImporterOf(child, importers); if (importer.isPresent()) { String name = toName(child); - ImportedModel existing = models.get(name); + ImportedMlModel existing = models.get(name); if (existing != null) throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + " both resolve to the model name '" + name + "'"); models.put(name, importer.get().importModel(name, child)); } else { - importRecursively(child, models); + importRecursively(child, models, importers); } }); } - private static Optional<ModelImporter> findImporterOf(File path) { + private static Optional<MlModelImporter> findImporterOf(File path, Collection<MlModelImporter> importers) { return importers.stream().filter(item -> item.canImport(path.toString())).findFirst(); } - /** - * Returns the model at the given location in the application package. - * - * @param modelPath the path to this model (file or directory, depending on model type) - * under the application package, both from the root or relative to the - * models directory works - * @return the model at this path or null if none - */ - public ImportedModel get(File modelPath) { - return importedModels.get(toName(modelPath)); - } - - public ImportedModel get(String modelName) { - return importedModels.get(modelName); - } - - /** Returns an immutable collection of all the imported models */ - public Collection<ImportedModel> all() { - return importedModels.values(); - } - private static String toName(File modelFile) { Path modelPath = Path.fromString(modelFile.toString()); if (modelFile.isFile()) diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/MlModelImporter.java b/config-model-api/src/main/java/com/yahoo/config/model/api/MlModelImporter.java new file mode 100644 index 00000000000..d24eeb2d55a --- /dev/null +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/MlModelImporter.java @@ -0,0 +1,16 @@ +package com.yahoo.config.model.api; + +import java.io.File; + +/** + * Config model view of a machine-learned model importer + * + * @author bratseth + */ +public interface MlModelImporter { + + boolean canImport(String modelPath); + + ImportedMlModel importModel(String modelName, File modelPath); + +} diff --git a/config-model/pom.xml b/config-model/pom.xml index a65a6a836ed..cfe11d1aca0 100644 --- a/config-model/pom.xml +++ b/config-model/pom.xml @@ -132,6 +132,12 @@ </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> + <artifactId>model-integration</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> <artifactId>vdslib</artifactId> <version>${project.version}</version> <scope>provided</scope> @@ -292,18 +298,6 @@ <artifactId>mockito-core</artifactId> <scope>test</scope> </dependency> - <dependency> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </dependency> - <dependency> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </dependency> - <dependency> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> - </dependency> </dependencies> <build> diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java index 574c25a2f84..06b2452dbd4 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.config.model.deploy; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.component.Version; import com.yahoo.component.Vtag; import com.yahoo.config.application.api.ApplicationPackage; @@ -9,6 +10,7 @@ import com.yahoo.config.application.api.FileRegistry; import com.yahoo.config.application.api.UnparsedConfigDefinition; import com.yahoo.config.model.api.ConfigDefinitionRepo; import com.yahoo.config.model.api.HostProvisioner; +import com.yahoo.config.model.api.MlModelImporter; import com.yahoo.config.model.api.Model; import com.yahoo.config.model.api.ValidationParameters; import com.yahoo.config.model.application.provider.BaseDeployLogger; @@ -21,7 +23,6 @@ import com.yahoo.config.provision.Zone; import com.yahoo.io.reader.NamedReader; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.SearchBuilder; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.vespa.config.ConfigDefinition; import com.yahoo.vespa.config.ConfigDefinitionBuilder; @@ -39,6 +40,7 @@ import java.io.IOException; import java.io.Reader; import java.time.Instant; import java.util.Collection; +import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; @@ -66,7 +68,7 @@ public class DeployState implements ConfigDefinitionStore { private final Zone zone; private final QueryProfiles queryProfiles; private final SemanticRules semanticRules; - private final ImportedModels importedModels; + private final ImportedMlModels importedModels; private final ValidationOverrides validationOverrides; private final Version wantedNodeVespaVersion; private final Instant now; @@ -80,11 +82,23 @@ public class DeployState implements ConfigDefinitionStore { return new Builder().applicationPackage(applicationPackage).build(); } - private DeployState(ApplicationPackage applicationPackage, SearchDocumentModel searchDocumentModel, RankProfileRegistry rankProfileRegistry, - FileRegistry fileRegistry, DeployLogger deployLogger, Optional<HostProvisioner> hostProvisioner, DeployProperties properties, - Optional<ApplicationPackage> permanentApplicationPackage, Optional<ConfigDefinitionRepo> configDefinitionRepo, - java.util.Optional<Model> previousModel, Set<Rotation> rotations, Zone zone, QueryProfiles queryProfiles, - SemanticRules semanticRules, Instant now, Version wantedNodeVespaVersion) { + private DeployState(ApplicationPackage applicationPackage, + SearchDocumentModel searchDocumentModel, + RankProfileRegistry rankProfileRegistry, + FileRegistry fileRegistry, + DeployLogger deployLogger, + Optional<HostProvisioner> hostProvisioner, + DeployProperties properties, + Optional<ApplicationPackage> permanentApplicationPackage, + Optional<ConfigDefinitionRepo> configDefinitionRepo, + java.util.Optional<Model> previousModel, + Set<Rotation> rotations, + Collection<MlModelImporter> modelImporters, + Zone zone, + QueryProfiles queryProfiles, + SemanticRules semanticRules, + Instant now, + Version wantedNodeVespaVersion) { this.logger = deployLogger; this.fileRegistry = fileRegistry; this.rankProfileRegistry = rankProfileRegistry; @@ -100,7 +114,8 @@ public class DeployState implements ConfigDefinitionStore { this.zone = zone; this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated - this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR)); + this.importedModels = new ImportedMlModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR), + modelImporters); this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml).orElse(ValidationOverrides.empty); this.wantedNodeVespaVersion = wantedNodeVespaVersion; @@ -215,7 +230,7 @@ public class DeployState implements ConfigDefinitionStore { public SemanticRules getSemanticRules() { return semanticRules; } /** The (machine learned) models imported from the models/ directory, as an unmodifiable map indexed by model name */ - public ImportedModels getImportedModels() { return importedModels; } + public ImportedMlModels getImportedModels() { return importedModels; } public Version getWantedNodeVespaVersion() { return wantedNodeVespaVersion; } @@ -232,6 +247,7 @@ public class DeployState implements ConfigDefinitionStore { private Optional<ConfigDefinitionRepo> configDefinitionRepo = Optional.empty(); private Optional<Model> previousModel = Optional.empty(); private Set<Rotation> rotations = new HashSet<>(); + private Collection<MlModelImporter> modelImporters = Collections.emptyList(); private Zone zone = Zone.defaultZone(); private Instant now = Instant.now(); private Version wantedNodeVespaVersion = Vtag.currentVersion; @@ -281,6 +297,11 @@ public class DeployState implements ConfigDefinitionStore { return this; } + public Builder modelImporters(Collection<MlModelImporter> modelImporters) { + this.modelImporters = modelImporters; + return this; + } + public Builder zone(Zone zone) { this.zone = zone; return this; @@ -305,9 +326,23 @@ public class DeployState implements ConfigDefinitionStore { QueryProfiles queryProfiles = new QueryProfilesBuilder().build(applicationPackage); SemanticRules semanticRules = new SemanticRuleBuilder().build(applicationPackage); SearchDocumentModel searchDocumentModel = createSearchDocumentModel(rankProfileRegistry, logger, queryProfiles, validationParameters); - return new DeployState(applicationPackage, searchDocumentModel, rankProfileRegistry, fileRegistry, logger, hostProvisioner, - properties, permanentApplicationPackage, configDefinitionRepo, previousModel, rotations, - zone, queryProfiles, semanticRules, now, wantedNodeVespaVersion); + return new DeployState(applicationPackage, + searchDocumentModel, + rankProfileRegistry, + fileRegistry, + logger, + hostProvisioner, + properties, + permanentApplicationPackage, + configDefinitionRepo, + previousModel, + rotations, + modelImporters, + zone, + queryProfiles, + semanticRules, + now, + wantedNodeVespaVersion); } private SearchDocumentModel createSearchDocumentModel(RankProfileRegistry rankProfileRegistry, diff --git a/config-model/src/main/java/com/yahoo/config/model/test/TestDriver.java b/config-model/src/main/java/com/yahoo/config/model/test/TestDriver.java index 0e442ad7993..e1dc29bb321 100644 --- a/config-model/src/main/java/com/yahoo/config/model/test/TestDriver.java +++ b/config-model/src/main/java/com/yahoo/config/model/test/TestDriver.java @@ -5,7 +5,6 @@ import com.google.common.annotations.Beta; import com.yahoo.component.Version; import com.yahoo.config.model.MapConfigModelRegistry; import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.config.model.application.provider.SchemaValidators; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.builder.xml.ConfigModelBuilder; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 78f61d7192d..7c0b90c35fa 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.search.query.profile.types.FieldDescription; @@ -16,7 +17,6 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; @@ -647,7 +647,7 @@ public class RankProfile implements Serializable, Cloneable { * Returns a copy of this where the content is optimized for execution. * Compiled profiles should never be modified. */ - public RankProfile compile(QueryProfileRegistry queryProfiles, ImportedModels importedModels) { + public RankProfile compile(QueryProfileRegistry queryProfiles, ImportedMlModels importedModels) { try { RankProfile compiled = this.clone(); compiled.compileThis(queryProfiles, importedModels); @@ -658,7 +658,7 @@ public class RankProfile implements Serializable, Cloneable { } } - private void compileThis(QueryProfileRegistry queryProfiles, ImportedModels importedModels) { + private void compileThis(QueryProfileRegistry queryProfiles, ImportedMlModels importedModels) { checkNameCollisions(getFunctions(), getConstants()); ExpressionTransforms expressionTransforms = new ExpressionTransforms(); @@ -688,7 +688,7 @@ public class RankProfile implements Serializable, Cloneable { private Map<String, RankingExpressionFunction> compileFunctions(Supplier<Map<String, RankingExpressionFunction>> functions, QueryProfileRegistry queryProfiles, - ImportedModels importedModels, + ImportedMlModels importedModels, Map<String, RankingExpressionFunction> inlineFunctions, ExpressionTransforms expressionTransforms) { Map<String, RankingExpressionFunction> compiledFunctions = new LinkedHashMap<>(); @@ -716,7 +716,7 @@ public class RankProfile implements Serializable, Cloneable { private RankingExpression compile(RankingExpression expression, QueryProfileRegistry queryProfiles, - ImportedModels importedModels, + ImportedMlModels importedModels, Map<String, Value> constants, Map<String, RankingExpressionFunction> inlineFunctions, ExpressionTransforms expressionTransforms) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java index e354c52092f..adefa5566ab 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java @@ -1,9 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; -import com.yahoo.config.FileReference; import com.yahoo.vespa.model.AbstractService; -import com.yahoo.vespa.model.utils.FileSender; import java.util.Collection; import java.util.Collections; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java index d8ec0b053ad..cf88886029f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java @@ -25,6 +25,8 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.Iterator; import java.util.LinkedList; import java.util.List; @@ -335,10 +337,21 @@ public class SearchBuilder { return createFromFile(fileName, new BaseDeployLogger()); } + /** + * Convenience factory methdd to create a SearchBuilder from multiple SD files. Only for testing. + */ + public static SearchBuilder createFromFiles(Collection<String> fileNames) throws IOException, ParseException { + return createFromFiles(fileNames, new BaseDeployLogger()); + } + public static SearchBuilder createFromFile(String fileName, DeployLogger logger) throws IOException, ParseException { return createFromFile(fileName, logger, new RankProfileRegistry(), new QueryProfileRegistry()); } + public static SearchBuilder createFromFiles(Collection<String> fileNames, DeployLogger logger) throws IOException, ParseException { + return createFromFiles(fileNames, logger, new RankProfileRegistry(), new QueryProfileRegistry()); + } + /** * Convenience factory method to import and build a {@link Search} object from a file. * @@ -354,10 +367,24 @@ public class SearchBuilder { RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryprofileRegistry) throws IOException, ParseException { + return createFromFiles(Collections.singletonList(fileName), deployLogger, + rankProfileRegistry, queryprofileRegistry); + } + + /** + * Convenience factory methdd to create a SearchBuilder from multiple SD files.. + */ + public static SearchBuilder createFromFiles(Collection<String> fileNames, + DeployLogger deployLogger, + RankProfileRegistry rankProfileRegistry, + QueryProfileRegistry queryprofileRegistry) + throws IOException, ParseException { SearchBuilder builder = new SearchBuilder(MockApplicationPackage.createEmpty(), rankProfileRegistry, queryprofileRegistry); - builder.importFile(fileName); + for (String fileName : fileNames) { + builder.importFile(fileName); + } builder.build(true, deployLogger); return builder; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 0c5c7733dda..7dc4b815da6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.derived; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.config.ConfigInstance; import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.config.application.api.DeployLogger; @@ -12,7 +13,6 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.derived.validation.Validation; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import java.io.IOException; import java.io.Writer; @@ -49,7 +49,7 @@ public class DerivedConfiguration { public DerivedConfiguration(Search search, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, - ImportedModels importedModels) { + ImportedMlModels importedModels) { this(search, new BaseDeployLogger(), rankProfileRegistry, queryProfiles, importedModels); } @@ -68,7 +68,7 @@ public class DerivedConfiguration { DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, - ImportedModels importedModels) { + ImportedMlModels importedModels) { Validator.ensureNotNull("Search definition", search); this.search = search; if ( ! search.isDocumentsOnly()) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index fcbfb47c597..4c117e44857 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -1,11 +1,11 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.derived; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.RankingConstants; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.Search; @@ -45,7 +45,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ AttributeFields attributeFields, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, - ImportedModels importedModels) { + ImportedMlModels importedModels) { setName(search == null ? "default" : search.getName()); this.rankingConstants = rankingConstants; deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields); @@ -53,7 +53,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ private void deriveRankProfiles(RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, - ImportedModels importedModels, + ImportedMlModels importedModels, Search search, AttributeFields attributeFields) { if (search != null) { // profiles belonging to a search have a default profile diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index a1b0e72051b..b7f515cedd4 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.derived; +import com.yahoo.config.model.api.ImportedMlModels; import com.google.common.collect.ImmutableList; import com.yahoo.collections.Pair; import com.yahoo.compress.Compressor; @@ -9,7 +10,6 @@ import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; @@ -50,7 +50,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { /** * Creates a raw rank profile from the given rank profile */ - public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, AttributeFields attributeFields) { + public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields) { this.name = rankProfile.getName(); compressedProperties = compress(new Deriver(rankProfile, queryProfiles, importedModels, attributeFields).derive()); } @@ -148,7 +148,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { /** * Creates a raw rank profile from the given rank profile */ - public Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, AttributeFields attributeFields) { + public Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields) { RankProfile compiled = rankProfile.compile(queryProfiles, importedModels); attributeTypes = compiled.getAttributeTypes(); queryFeatureTypes = compiled.getQueryFeatureTypes(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java index 2fe2dacf2ce..f20298cfe1a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java @@ -1,10 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.transform.TransformContext; import java.util.HashMap; @@ -19,13 +19,13 @@ public class RankProfileTransformContext extends TransformContext { private final RankProfile rankProfile; private final QueryProfileRegistry queryProfiles; - private final ImportedModels importedModels; + private final ImportedMlModels importedModels; private final Map<String, RankProfile.RankingExpressionFunction> inlineFunctions; private final Map<String, String> rankProperties = new HashMap<>(); public RankProfileTransformContext(RankProfile rankProfile, QueryProfileRegistry queryProfiles, - ImportedModels importedModels, + ImportedMlModels importedModels, Map<String, Value> constants, Map<String, RankProfile.RankingExpressionFunction> inlineFunctions) { super(constants); @@ -37,7 +37,7 @@ public class RankProfileTransformContext extends TransformContext { public RankProfile rankProfile() { return rankProfile; } public QueryProfileRegistry queryProfiles() { return queryProfiles; } - public ImportedModels importedModels() { return importedModels; } + public ImportedMlModels importedModels() { return importedModels; } public Map<String, RankProfile.RankingExpressionFunction> inlineFunctions() { return inlineFunctions; } public Map<String, String> rankProperties() { return rankProperties; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/BuiltInFieldSets.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/BuiltInFieldSets.java index a0c4c8adb2d..df189389348 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/BuiltInFieldSets.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/BuiltInFieldSets.java @@ -40,6 +40,9 @@ public class BuiltInFieldSets extends Processor { private void addDocumentFieldSet() { for (Field docField : search.getDocument().fieldSet()) { + if (docField instanceof SDField && ((SDField) docField).isExtraField()) { + continue; // skip + } search.fieldSets().addBuiltInFieldSetItem(DOC_FIELDSET_NAME, docField.getName()); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/MakeDefaultSummaryTheSuperSet.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/MakeDefaultSummaryTheSuperSet.java index 6f67c22d9d2..13eebc289a6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/MakeDefaultSummaryTheSuperSet.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/MakeDefaultSummaryTheSuperSet.java @@ -39,6 +39,7 @@ public class MakeDefaultSummaryTheSuperSet extends Processor { for (SummaryField summaryField : search.getUniqueNamedSummaryFields().values() ) { if (defaultSummary.getSummaryField(summaryField.getName()) != null) continue; if (summaryField.getTransform() == SummaryTransform.ATTRIBUTE) continue; + if (summaryField.getTransform() == SummaryTransform.ATTRIBUTECOMBINER) continue; defaultSummary.add(summaryField.clone()); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index fef198f7939..58fc08d15e7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model; +import com.yahoo.config.model.api.ImportedMlModel; import com.yahoo.config.ConfigBuilder; import com.yahoo.config.ConfigInstance; import com.yahoo.config.ConfigInstance.Builder; @@ -18,6 +19,7 @@ import com.yahoo.config.model.ConfigModelRepo; import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.model.api.FileDistribution; import com.yahoo.config.model.api.HostInfo; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.config.model.api.Model; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.producer.AbstractConfigProducer; @@ -33,8 +35,6 @@ import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.searchdefinition.processing.Processing; import com.yahoo.vespa.model.container.search.QueryProfiles; import com.yahoo.vespa.model.ml.ConvertedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.config.ConfigDefinitionKey; import com.yahoo.vespa.config.ConfigKey; import com.yahoo.vespa.config.ConfigPayload; @@ -150,7 +150,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri this(configModelRegistry, deployState, true, null); } - private VespaModel(ConfigModelRegistry configModelRegistry, DeployState deployState, boolean complete, FileDistributor fileDistributor) throws IOException, SAXException { + private VespaModel(ConfigModelRegistry configModelRegistry, DeployState deployState, boolean complete, FileDistributor fileDistributor) + throws IOException, SAXException { super("vespamodel"); this.validationOverrides = deployState.validationOverrides(); configModelRegistry = new VespaConfigModelRegistry(configModelRegistry); @@ -216,11 +217,11 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri * Creates a rank profile not attached to any search definition, for each imported model in the application package, * and adds it to the given rank profile registry. */ - private void createGlobalRankProfiles(DeployLogger deployLogger, ImportedModels importedModels, + private void createGlobalRankProfiles(DeployLogger deployLogger, ImportedMlModels importedModels, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { if ( ! importedModels.all().isEmpty()) { // models/ directory is available - for (ImportedModel model : importedModels.all()) { + for (ImportedMlModel model : importedModels.all()) { RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry); rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java index 2af9b297e9e..954f20f36c0 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java @@ -10,6 +10,7 @@ import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.model.api.ConfigChangeAction; import com.yahoo.config.model.api.ConfigModelPlugin; import com.yahoo.config.model.api.HostProvisioner; +import com.yahoo.config.model.api.MlModelImporter; import com.yahoo.config.model.api.Model; import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.api.ModelCreateResult; @@ -29,6 +30,8 @@ import org.xml.sax.SAXException; import java.io.IOException; import java.time.Clock; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.logging.Logger; @@ -41,19 +44,17 @@ public class VespaModelFactory implements ModelFactory { private static final Logger log = Logger.getLogger(VespaModelFactory.class.getName()); private final ConfigModelRegistry configModelRegistry; + private final Collection<MlModelImporter> modelImporters; private final Zone zone; private final Clock clock; private final Version version; /** Creates a factory for vespa models for this version of the source */ @Inject - public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry, Zone zone) { - this(Version.fromIntValues(VespaVersion.major, VespaVersion.minor, VespaVersion.micro), pluginRegistry, zone); - } - - /** Creates a factory for vespa models of a particular version */ - public VespaModelFactory(Version version, ComponentRegistry<ConfigModelPlugin> pluginRegistry, Zone zone) { - this.version = version; + public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry, + ComponentRegistry<MlModelImporter> modelImporters, + Zone zone) { + this.version = Version.fromIntValues(VespaVersion.major, VespaVersion.minor, VespaVersion.micro); List<ConfigModelBuilder> modelBuilders = new ArrayList<>(); for (ConfigModelPlugin plugin : pluginRegistry.allComponents()) { if (plugin instanceof ConfigModelBuilder) { @@ -61,6 +62,7 @@ public class VespaModelFactory implements ModelFactory { } } this.configModelRegistry = new MapConfigModelRegistry(modelBuilders); + this.modelImporters = modelImporters.allComponents(); this.zone = zone; this.clock = Clock.systemUTC(); } @@ -79,6 +81,7 @@ public class VespaModelFactory implements ModelFactory { } else { this.configModelRegistry = configModelRegistry; } + this.modelImporters = Collections.emptyList(); this.zone = Zone.defaultZone(); this.clock = clock; } @@ -137,6 +140,7 @@ public class VespaModelFactory implements ModelFactory { .properties(createDeployProperties(modelContext.properties())) .modelHostProvisioner(createHostProvisioner(modelContext)) .rotations(modelContext.properties().rotations()) + .modelImporters(modelImporters) .zone(zone) .now(clock.instant()) .wantedNodeVespaVersion(modelContext.wantedNodeVespaVersion()); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java index 29a758fee74..12613018ca7 100755 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java @@ -641,6 +641,8 @@ public final class ContainerCluster } if (jvmGCOptions != null) { jvmBuilder.gcopts(jvmGCOptions); + } else { + jvmBuilder.gcopts(G1GC); } builder.jvm(jvmBuilder); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 91df3fee6e8..0be25808541 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -462,16 +462,13 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { private static String buildJvmGCOptions(Zone zone, String jvmGCOPtions, boolean isHostedVespa) { if (jvmGCOPtions != null) { return jvmGCOPtions; - } else if (zone.system() == SystemName.dev) { - return ContainerCluster.G1GC; - } else if (isHostedVespa) { - return ((zone.environment() != Environment.prod) || RegionName.from("us-east-3").equals(zone.region())) - ? ContainerCluster.G1GC : ContainerCluster.CMS; + } else if ((zone.system() == SystemName.dev) || isHostedVespa) { + return null; } else { return ContainerCluster.CMS; } } - private String getJvmOptions(ContainerCluster cluster, Element nodesElement, DeployLogger deployLogger) { + private static String getJvmOptions(ContainerCluster cluster, Element nodesElement, DeployLogger deployLogger) { String jvmOptions = ""; if (nodesElement.hasAttribute(VespaDomBuilder.JVM_OPTIONS)) { jvmOptions = nodesElement.getAttribute(VespaDomBuilder.JVM_OPTIONS); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/TuningDispatch.java b/config-model/src/main/java/com/yahoo/vespa/model/content/TuningDispatch.java index 022611fa4f7..b3815994742 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/TuningDispatch.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/TuningDispatch.java @@ -7,7 +7,7 @@ package com.yahoo.vespa.model.content; public class TuningDispatch { private final Integer maxHitsPerPartition; - public enum DispatchPolicy { ROUNDROBIN, RANDOM}; + public enum DispatchPolicy { ROUNDROBIN, ADAPTIVE}; private final DispatchPolicy dispatchPolicy; private final Boolean useLocalNode; private final Double minGroupCoverage; @@ -48,6 +48,8 @@ public class TuningDispatch { } public Builder setDispatchPolicy(String policy) { if (policy == null) { + } else if ("random".equals(policy.toLowerCase())) { + dispatchPolicy = DispatchPolicy.ADAPTIVE; } else if ("round-robin".equals(policy.toLowerCase())) { dispatchPolicy = DispatchPolicy.ROUNDROBIN; } else { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java index 7bb6e3dabd4..27da41c2bfa 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java @@ -86,7 +86,6 @@ public class ContentCluster extends AbstractConfigProducer implements private final boolean isHostedVespa; private final Map<String, NewDocumentType> documentDefinitions; private final Set<NewDocumentType> globallyDistributedDocuments; - private boolean forceEnableMultipleBucketSpaces = false; private com.yahoo.vespa.model.content.StorageGroup rootGroup; private StorageCluster storageNodes; private DistributorCluster distributorNodes; @@ -266,10 +265,7 @@ public class ContentCluster extends AbstractConfigProducer implements } private void setupExperimental(ContentCluster cluster, ModelElement experimental) { - Boolean enableMultipleBucketSpaces = experimental.childAsBoolean("enable-multiple-bucket-spaces"); - if (enableMultipleBucketSpaces != null) { - cluster.forceEnableMultipleBucketSpaces = enableMultipleBucketSpaces; - } + // Put handling of experimental flags here } private void validateGroupSiblings(String cluster, StorageGroup group) { @@ -615,7 +611,6 @@ public class ContentCluster extends AbstractConfigProducer implements builder.min_distributor_up_ratio(0); builder.min_storage_up_ratio(0); } - builder.enable_multiple_bucket_spaces(true); // Telling the controller whether we actually _have_ global document types lets // it selectively enable or disable constraints that aren't needed when these // are not are present, even if full protocol and backend support is enabled @@ -750,9 +745,5 @@ public class ContentCluster extends AbstractConfigProducer implements docTypeBuilder.bucketspace(bucketSpace); builder.documenttype(docTypeBuilder); } - // NOTE: this config is kept around to allow the use of multiple bucket spaces - // on older versions of Vespa. It is for all intents and purposes a no-op in - // newer versions where multiple bucket spaces are enabled by default. - builder.enable_multiple_bucket_spaces(forceEnableMultipleBucketSpaces); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 30586b1e677..c834bea7be2 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -1,10 +1,12 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; +import com.yahoo.config.model.api.ImportedMlFunction; import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.api.ImportedMlModel; import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; @@ -19,7 +21,6 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; @@ -41,7 +42,6 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.BufferedReader; import java.io.File; import java.io.IOException; -import java.io.Reader; import java.io.StringReader; import java.io.UncheckedIOException; import java.util.ArrayList; @@ -71,12 +71,12 @@ public class ConvertedModel { private final ImmutableMap<String, ExpressionFunction> expressions; /** The source importedModel, or empty if this was created from a stored converted model */ - private final Optional<ImportedModel> sourceModel; + private final Optional<ImportedMlModel> sourceModel; private ConvertedModel(ModelName modelName, String modelDescription, Map<String, ExpressionFunction> expressions, - Optional<ImportedModel> sourceModel) { + Optional<ImportedMlModel> sourceModel) { this.modelName = modelName; this.modelDescription = modelDescription; this.expressions = ImmutableMap.copyOf(expressions); @@ -91,13 +91,13 @@ public class ConvertedModel { * @param pathIsFile true if that path (this kind of model) is stored in a file, false if it is in a directory */ public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) { - ImportedModel sourceModel = // TODO: Convert to name here, make sure its done just one way + ImportedMlModel sourceModel = // TODO: Convert to name here, make sure its done just one way context.importedModels().get(sourceModelFile(context.rankProfile().applicationPackage(), modelPath)); ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath, pathIsFile); if (sourceModel == null && ! new ModelStore(context.rankProfile().applicationPackage(), modelName).exists()) throw new IllegalArgumentException("No model '" + modelPath + "' is available. Available models: " + - context.importedModels().all().stream().map(ImportedModel::source).collect(Collectors.joining(", "))); + context.importedModels().all().stream().map(ImportedMlModel::source).collect(Collectors.joining(", "))); if (sourceModel != null) { return fromSource(modelName, @@ -117,7 +117,7 @@ public class ConvertedModel { String modelDescription, RankProfile rankProfile, QueryProfileRegistry queryProfileRegistry, - ImportedModel importedModel) { + ImportedMlModel importedModel) { ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName); return new ConvertedModel(modelName, modelDescription, @@ -188,7 +188,7 @@ public class ConvertedModel { // ----------------------- Static model conversion/storage below here - private static Map<String, ExpressionFunction> convertAndStore(ImportedModel model, + private static Map<String, ExpressionFunction> convertAndStore(ImportedMlModel model, RankProfile profile, QueryProfileRegistry queryProfiles, ModelStore store) { @@ -203,8 +203,9 @@ public class ConvertedModel { // Add expressions Map<String, ExpressionFunction> expressions = new HashMap<>(); - for (Pair<String, ExpressionFunction> output : model.outputExpressions()) { - addExpression(output.getSecond(), output.getFirst(), + for (ImportedMlFunction outputFunction : model.outputExpressions()) { + ExpressionFunction expression = asExpressionFunction(outputFunction); + addExpression(expression, expression.getName(), constantsReplacedByFunctions, model, store, profile, queryProfiles, expressions); @@ -219,10 +220,27 @@ public class ConvertedModel { return expressions; } + private static ExpressionFunction asExpressionFunction(ImportedMlFunction function) { + try { + Map<String, TensorType> argumentTypes = new HashMap<>(); + for (Map.Entry<String, String> entry : function.argumentTypes().entrySet()) + argumentTypes.put(entry.getKey(), TensorType.fromSpec(entry.getValue())); + + return new ExpressionFunction(function.name(), + function.arguments(), + new RankingExpression(function.expression()), + argumentTypes, + function.returnType().map(TensorType::fromSpec)); + } + catch (ParseException e) { + throw new IllegalArgumentException("Gor an illegal argument from importing " + function.name(), e); + } + } + private static void addExpression(ExpressionFunction expression, String expressionName, Set<String> constantsReplacedByFunctions, - ImportedModel model, + ImportedMlModel model, ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, @@ -249,7 +267,9 @@ public class ConvertedModel { return store.readExpressions(); } - private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, + String constantValueString) { + Tensor constantValue = Tensor.from(constantValueString); store.writeSmallConstant(constantName, constantValue); profile.addConstant(constantName, asValue(constantValue)); } @@ -259,7 +279,8 @@ public class ConvertedModel { QueryProfileRegistry queryProfiles, Set<String> constantsReplacedByFunctions, String constantName, - Tensor constantValue) { + String constantValueString) { + Tensor constantValue = Tensor.from(constantValueString); RankProfile.RankingExpressionFunction rankingExpressionFunctionOverridingConstant = profile.getFunctions().get(constantName); if (rankingExpressionFunctionOverridingConstant != null) { TensorType functionType = rankingExpressionFunctionOverridingConstant.function().getBody().type(profile.typeContext(queryProfiles)); @@ -302,19 +323,19 @@ public class ConvertedModel { * Verify that the inputs declared in the given expression exists in the given rank profile as functions, * and return tensors of the correct types. */ - private static void verifyInputs(RankingExpression expression, ImportedModel model, + private static void verifyInputs(RankingExpression expression, ImportedMlModel model, RankProfile profile, QueryProfileRegistry queryProfiles) { Set<String> functionNames = new HashSet<>(); addFunctionNamesIn(expression.getRoot(), functionNames, model); for (String functionName : functionNames) { - TensorType requiredType = model.inputs().get(functionName); - if (requiredType == null) continue; // Not a required function + Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); + if ( ! requiredType.isPresent()) continue; // Not a required function RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); if (rankingExpressionFunction == null) throw new IllegalArgumentException("Model refers input '" + functionName + - "' of type " + requiredType + " but this function is not present in " + - profile); + "' of type " + requiredType.get() + + " but this function is not present in " + profile); // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second // phase and summary features), as it may only resolve correctly given those bindings // Or, probably better, annotate the functions with type constraints here and verify during general @@ -322,12 +343,12 @@ public class ConvertedModel { TensorType actualType = rankingExpressionFunction.function().getBody().getRoot().type(profile.typeContext(queryProfiles)); if ( actualType == null) throw new IllegalArgumentException("Model refers input '" + functionName + - "' of type " + requiredType + + "' of type " + requiredType.get() + " which must be produced by a function in the rank profile, but " + "this function references a feature which is not declared"); - if ( ! actualType.isAssignableTo(requiredType)) + if ( ! actualType.isAssignableTo(requiredType.get())) throw new IllegalArgumentException("Model refers input '" + functionName + "'. " + - typeMismatchExplanation(requiredType, actualType)); + typeMismatchExplanation(requiredType.get(), actualType)); } } @@ -339,8 +360,8 @@ public class ConvertedModel { } /** Add the generated functions to the rank profile */ - private static void addGeneratedFunctions(ImportedModel model, RankProfile profile) { - model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, v.copy())); + private static void addGeneratedFunctions(ImportedMlModel model, RankProfile profile) { + model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, RankingExpression.from(v))); } /** @@ -348,7 +369,7 @@ public class ConvertedModel { * function specifies that a single exemplar should be evaluated, we can * reduce the batch dimension out. */ - private static void reduceBatchDimensions(RankingExpression expression, ImportedModel model, + private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model, RankProfile profile, QueryProfileRegistry queryProfiles) { TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); TensorType typeBeforeReducing = expression.getRoot().type(typeContext); @@ -376,7 +397,7 @@ public class ConvertedModel { expression.setRoot(root); } - private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model, + private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model, TypeContext<Reference> typeContext) { if (node instanceof TensorFunctionNode) { TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); @@ -384,7 +405,7 @@ public class ConvertedModel { List<ExpressionNode> children = ((TensorFunctionNode)node).children(); if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.inputs().containsKey(referenceNode.getName())) { + if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { return reduceBatchDimensionExpression(tensorFunction, typeContext); } } @@ -392,7 +413,7 @@ public class ConvertedModel { } if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) node; - if (model.inputs().containsKey(referenceNode.getName())) { + if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); } } @@ -482,13 +503,13 @@ public class ConvertedModel { return node; } - private static void addFunctionNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) { + private static void addFunctionNamesIn(ExpressionNode node, Set<String> names, ImportedMlModel model) { if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode)node; if (referenceNode.getOutput() == null) { // function references cannot specify outputs names.add(referenceNode.getName()); if (model.functions().containsKey(referenceNode.getName())) { - addFunctionNamesIn(model.functions().get(referenceNode.getName()).getRoot(), names, model); + addFunctionNamesIn(RankingExpression.from(model.functions().get(referenceNode.getName())).getRoot(), names, model); } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java b/config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java index 078fdf87c31..8158de692a9 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java @@ -7,6 +7,7 @@ import com.yahoo.config.model.producer.AbstractConfigProducer; import com.yahoo.vespa.model.AbstractService; import com.yahoo.vespa.model.application.validation.RestartConfigs; import com.yahoo.vespa.model.content.SearchCoverage; +import com.yahoo.vespa.model.content.TuningDispatch; import java.util.ArrayList; import java.util.List; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java index fe83148d7af..b3b530448fc 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java @@ -412,8 +412,8 @@ public class IndexedSearchCluster extends SearchCluster builder.minGroupCoverage(tuning.dispatch.minGroupCoverage); if (tuning.dispatch.policy != null) { switch (tuning.dispatch.policy) { - case RANDOM: - builder.distributionPolicy(DistributionPolicy.RANDOM); + case ADAPTIVE: + builder.distributionPolicy(DistributionPolicy.ADAPTIVE); break; case ROUNDROBIN: builder.distributionPolicy(DistributionPolicy.ROUNDROBIN); @@ -422,6 +422,15 @@ public class IndexedSearchCluster extends SearchCluster } builder.maxNodesDownPerGroup(rootDispatch.getMaxNodesDownPerFixedRow()); builder.useMultilevelDispatch(useMultilevelDispatchSetup()); + builder.searchableCopies(rootDispatch.getSearchableCopies()); + if (searchCoverage != null) { + if (searchCoverage.getMinimum() != null) + builder.minSearchCoverage(searchCoverage.getMinimum()); + if (searchCoverage.getMinWaitAfterCoverageFactor() != null) + builder.minWaitAfterCoverageFactor(searchCoverage.getMinWaitAfterCoverageFactor()); + if (searchCoverage.getMaxWaitAfterCoverageFactor() != null) + builder.maxWaitAfterCoverageFactor(searchCoverage.getMaxWaitAfterCoverageFactor()); + } } @Override diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java b/config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java index 5581b9b87ba..2e3c0681f75 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java @@ -43,7 +43,7 @@ public class Tuning extends AbstractConfigProducer implements PartitionsConfig.P } for (PartitionsConfig.Dataset.Builder dataset : builder.dataset) { switch (policy) { - case RANDOM: + case ADAPTIVE: dataset.useroundrobinforfixedrow(false); break; case ROUNDROBIN: diff --git a/config-model/src/main/resources/schema/content.rnc b/config-model/src/main/resources/schema/content.rnc index 58d22ea9b6f..c0313cd50ef 100644 --- a/config-model/src/main/resources/schema/content.rnc +++ b/config-model/src/main/resources/schema/content.rnc @@ -114,9 +114,9 @@ Content = element content { Documents? & ContentNodes? & TopGroup? & - Controllers? & + Controllers? # Contains experimental feature switches - Experimental? + #Experimental? } Controllers = @@ -368,9 +368,9 @@ TuningCompression = element compression { element level { xsd:nonNegativeInteger }? } -Experimental = element experimental { - element enable-multiple-bucket-spaces { xsd:boolean }? -} +#Experimental = element experimental { +# Put experimental flags here +#} Thread = element thread { ## The lowest priority this thread should handle. diff --git a/config-model/src/test/derived/inheritance/documentmanager.cfg b/config-model/src/test/derived/inheritance/documentmanager.cfg index 754144c0af9..47f697c80f7 100644 --- a/config-model/src/test/derived/inheritance/documentmanager.cfg +++ b/config-model/src/test/derived/inheritance/documentmanager.cfg @@ -79,8 +79,6 @@ datatype[].documenttype[].bodystruct -1989003153 datatype[].documenttype[].fieldsets{[document]}.fields[] "onlygrandparent" datatype[].documenttype[].fieldsets{[document]}.fields[] "onlymother" datatype[].documenttype[].fieldsets{[document]}.fields[] "overridden" -datatype[].documenttype[].fieldsets{[document]}.fields[] "rankfeatures" -datatype[].documenttype[].fieldsets{[document]}.fields[] "summaryfeatures" datatype[].id 2126589281 datatype[].structtype[].name "father.header" datatype[].structtype[].version 0 @@ -113,8 +111,6 @@ datatype[].documenttype[].bodystruct -1742340170 datatype[].documenttype[].fieldsets{[document]}.fields[] "onlyfather" datatype[].documenttype[].fieldsets{[document]}.fields[] "onlygrandparent" datatype[].documenttype[].fieldsets{[document]}.fields[] "overridden" -datatype[].documenttype[].fieldsets{[document]}.fields[] "rankfeatures" -datatype[].documenttype[].fieldsets{[document]}.fields[] "summaryfeatures" datatype[].id 81425825 datatype[].structtype[].name "child.header" datatype[].structtype[].version 0 @@ -151,5 +147,3 @@ datatype[].documenttype[].fieldsets{[document]}.fields[] "onlyfather" datatype[].documenttype[].fieldsets{[document]}.fields[] "onlygrandparent" datatype[].documenttype[].fieldsets{[document]}.fields[] "onlymother" datatype[].documenttype[].fieldsets{[document]}.fields[] "overridden" -datatype[].documenttype[].fieldsets{[document]}.fields[] "rankfeatures" -datatype[].documenttype[].fieldsets{[document]}.fields[] "summaryfeatures" diff --git a/config-model/src/test/derived/inheritfromgrandparent/documentmanager.cfg b/config-model/src/test/derived/inheritfromgrandparent/documentmanager.cfg index 97babb77bd1..8d5bc57ef31 100644 --- a/config-model/src/test/derived/inheritfromgrandparent/documentmanager.cfg +++ b/config-model/src/test/derived/inheritfromgrandparent/documentmanager.cfg @@ -72,8 +72,6 @@ datatype[].documenttype[].inherits[].name "document" datatype[].documenttype[].inherits[].version 0 datatype[].documenttype[].headerstruct 836075987 datatype[].documenttype[].bodystruct -389494616 -datatype[].documenttype[].fieldsets{[document]}.fields[] "rankfeatures" -datatype[].documenttype[].fieldsets{[document]}.fields[] "summaryfeatures" datatype[].id 81425825 datatype[].structtype[].name "child.header" datatype[].structtype[].version 0 @@ -101,5 +99,3 @@ datatype[].documenttype[].inherits[].version 0 datatype[].documenttype[].headerstruct 81425825 datatype[].documenttype[].bodystruct -126593034 datatype[].documenttype[].fieldsets{[document]}.fields[] "child_field" -datatype[].documenttype[].fieldsets{[document]}.fields[] "rankfeatures" -datatype[].documenttype[].fieldsets{[document]}.fields[] "summaryfeatures" diff --git a/config-model/src/test/derived/inheritfromparent/documentmanager.cfg b/config-model/src/test/derived/inheritfromparent/documentmanager.cfg index 6f99d9de53b..154b6524c33 100644 --- a/config-model/src/test/derived/inheritfromparent/documentmanager.cfg +++ b/config-model/src/test/derived/inheritfromparent/documentmanager.cfg @@ -55,7 +55,7 @@ datatype[].documenttype[].inherits[].name "document" datatype[].documenttype[].inherits[].version 0 datatype[].documenttype[].headerstruct 836075987 datatype[].documenttype[].bodystruct -389494616 -datatype[].documenttype[].fieldsets{[document]}.fields[] "weight_src" +datatype[].documenttype[].fieldsets{[]}.fields[] "weight_src" datatype[].id 81425825 datatype[].structtype[].name "child.header" datatype[].structtype[].version 0 @@ -82,8 +82,5 @@ datatype[].documenttype[].inherits[].name "parent" datatype[].documenttype[].inherits[].version 0 datatype[].documenttype[].headerstruct 81425825 datatype[].documenttype[].bodystruct -126593034 -datatype[].documenttype[].fieldsets{[document]}.fields[] "child_field" -datatype[].documenttype[].fieldsets{[document]}.fields[] "rankfeatures" -datatype[].documenttype[].fieldsets{[document]}.fields[] "summaryfeatures" -datatype[].documenttype[].fieldsets{[document]}.fields[] "weight" -datatype[].documenttype[].fieldsets{[document]}.fields[] "weight_src" +datatype[].documenttype[].fieldsets{[]}.fields[] "child_field" +datatype[].documenttype[].fieldsets{[]}.fields[] "weight_src" diff --git a/config-model/src/test/derived/inheritfromparent/documenttypes.cfg b/config-model/src/test/derived/inheritfromparent/documenttypes.cfg index d8493eefe95..70c4bc4297c 100644 --- a/config-model/src/test/derived/inheritfromparent/documenttypes.cfg +++ b/config-model/src/test/derived/inheritfromparent/documenttypes.cfg @@ -75,7 +75,7 @@ documenttype[].datatype[].sstruct.compression.type NONE documenttype[].datatype[].sstruct.compression.level 0 documenttype[].datatype[].sstruct.compression.threshold 95 documenttype[].datatype[].sstruct.compression.minsize 200 -documenttype[].fieldsets{[document]}.fields[] "weight_src" +documenttype[].fieldsets{[]}.fields[] "weight_src" documenttype[].id 746267614 documenttype[].name "child" documenttype[].version 0 @@ -118,8 +118,5 @@ documenttype[].datatype[].sstruct.compression.type NONE documenttype[].datatype[].sstruct.compression.level 0 documenttype[].datatype[].sstruct.compression.threshold 95 documenttype[].datatype[].sstruct.compression.minsize 200 -documenttype[].fieldsets{[document]}.fields[] "child_field" -documenttype[].fieldsets{[document]}.fields[] "rankfeatures" -documenttype[].fieldsets{[document]}.fields[] "summaryfeatures" -documenttype[].fieldsets{[document]}.fields[] "weight" -documenttype[].fieldsets{[document]}.fields[] "weight_src" +documenttype[].fieldsets{[]}.fields[] "child_field" +documenttype[].fieldsets{[]}.fields[] "weight_src" diff --git a/config-model/src/test/examples/position_base.sd b/config-model/src/test/examples/position_base.sd new file mode 100644 index 00000000000..8f8ff85cff2 --- /dev/null +++ b/config-model/src/test/examples/position_base.sd @@ -0,0 +1,8 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +search position_base { + document position_base { + field pos type position { + indexing: attribute + } + } +} diff --git a/config-model/src/test/examples/position_inherited.sd b/config-model/src/test/examples/position_inherited.sd new file mode 100644 index 00000000000..b3341e01f80 --- /dev/null +++ b/config-model/src/test/examples/position_inherited.sd @@ -0,0 +1,4 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +search position_inherited { + document position_inherited inherits position_base {} +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java index 07a36832094..131972ffc73 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java @@ -4,7 +4,7 @@ package com.yahoo.searchdefinition; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.yolean.Exceptions; import org.junit.Test; @@ -26,7 +26,7 @@ public class IncorrectRankingExpressionFileRefTestCase extends SearchDefinitionT Search search = SearchBuilder.buildFromFile("src/test/examples/incorrectrankingexpressionfileref.sd", registry, new QueryProfileRegistry()); - new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // cause rank profile parsing + new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedMlModels()); // cause rank profile parsing fail("parsing should have failed"); } catch (IllegalArgumentException e) { String message = Exceptions.toMessageString(e); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index 4df3add13c5..06761ad45bc 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -16,7 +16,7 @@ import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.util.Iterator; @@ -91,7 +91,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { assertEquals(8, rankProfile.getNumThreadsPerSearch()); assertEquals(70, rankProfile.getMinHitsPerThread()); assertEquals(1200, rankProfile.getNumSearchPartitions()); - RawRankProfile rawRankProfile = new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedModels(), attributeFields); + RawRankProfile rawRankProfile = new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedMlModels(), attributeFields); assertTrue(findProperty(rawRankProfile.configProperties(), "vespa.matching.termwise_limit").isPresent()); assertEquals("0.78", findProperty(rawRankProfile.configProperties(), "vespa.matching.termwise_limit").get()); assertTrue(findProperty(rawRankProfile.configProperties(), "vespa.matching.numthreadspersearch").isPresent()); @@ -126,7 +126,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { } private static void assertAttributeTypeSettings(RankProfile profile, Search search) { - RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search)); + RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(search)); assertEquals("tensor(x[10])", findProperty(rawProfile.configProperties(), "vespa.type.attribute.a").get()); assertEquals("tensor(y{})", findProperty(rawProfile.configProperties(), "vespa.type.attribute.b").get()); assertEquals("tensor(x[])", findProperty(rawProfile.configProperties(), "vespa.type.attribute.c").get()); @@ -168,7 +168,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { } private static void assertQueryFeatureTypeSettings(RankProfile profile, Search search) { - RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search)); + RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(search)); assertEquals("tensor(x[10])", findProperty(rawProfile.configProperties(), "vespa.type.query.tensor1").get()); assertEquals("tensor(y{})", findProperty(rawProfile.configProperties(), "vespa.type.query.tensor2").get()); assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.tensor3").isPresent()); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java index 8df3985fd24..1d6a75f039d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java @@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -55,7 +55,7 @@ public class RankPropertiesTestCase extends SearchDefinitionTestCase { assertEquals("query(a) = 1500", parent.getRankProperties().get(0).toString()); // Check derived model - RawRankProfile rawParent = new RawRankProfile(parent, new QueryProfileRegistry(), new ImportedModels(), attributeFields); + RawRankProfile rawParent = new RawRankProfile(parent, new QueryProfileRegistry(), new ImportedMlModels(), attributeFields); assertEquals("(query(a),1500)", rawParent.configProperties().get(0).toString()); } @@ -67,7 +67,7 @@ public class RankPropertiesTestCase extends SearchDefinitionTestCase { // Check derived model RawRankProfile rawChild = new RawRankProfile(rankProfileRegistry.get(search, "child"), new QueryProfileRegistry(), - new ImportedModels(), + new ImportedMlModels(), attributeFields); assertEquals("(query(a),2000)", rawChild.configProperties().get(0).toString()); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index 150469cc928..af6507f352d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -3,7 +3,7 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.yolean.Exceptions; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; @@ -67,19 +67,19 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile parent = rankProfileRegistry.get(s, "parent").compile(queryProfileRegistry, new ImportedModels()); + RankProfile parent = rankProfileRegistry.get(s, "parent").compile(queryProfileRegistry, new ImportedMlModels()); assertEquals("0.0", parent.getFirstPhaseRanking().getRoot().toString()); - RankProfile child1 = rankProfileRegistry.get(s, "child1").compile(queryProfileRegistry, new ImportedModels()); + RankProfile child1 = rankProfileRegistry.get(s, "child1").compile(queryProfileRegistry, new ImportedMlModels()); assertEquals("6.5", child1.getFirstPhaseRanking().getRoot().toString()); assertEquals("11.5", child1.getSecondPhaseRanking().getRoot().toString()); - RankProfile child2 = rankProfileRegistry.get(s, "child2").compile(queryProfileRegistry, new ImportedModels()); + RankProfile child2 = rankProfileRegistry.get(s, "child2").compile(queryProfileRegistry, new ImportedMlModels()); assertEquals("16.6", child2.getFirstPhaseRanking().getRoot().toString()); assertEquals("foo: 14.0", child2.getFunctions().get("foo").function().getBody().toString()); List<Pair<String, String>> rankProperties = new RawRankProfile(child2, queryProfileRegistry, - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(foo).rankingScript,14.0)", rankProperties.get(0).toString()); assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(2).toString()); @@ -110,7 +110,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); try { - rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); fail("Should have caused an exception"); } catch (IllegalArgumentException e) { @@ -171,7 +171,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase RankProfile profile = rankProfileRegistry.get(s, "test"); assertEquals("safeLog(popShareSlowDecaySignal,myValue)", profile.getFunctions().get("POP_SLOW_SCORE").function().getBody().getRoot().toString()); assertEquals("safeLog(popShareSlowDecaySignal,-9.21034037)", - profile.compile(new QueryProfileRegistry(), new ImportedModels()).getFunctions().get("POP_SLOW_SCORE").function().getBody().getRoot().toString()); + profile.compile(new QueryProfileRegistry(), new ImportedMlModels()).getFunctions().get("POP_SLOW_SCORE").function().getBody().getRoot().toString()); } @Test @@ -194,7 +194,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase Search s = builder.getSearch(); RankProfile profile = rankProfileRegistry.get(s, "test"); assertEquals("k1 + (k2 + k3) / 100000000.0", - profile.compile(new QueryProfileRegistry(), new ImportedModels()).getFunctions().get("rank_default").function().getBody().getRoot().toString()); + profile.compile(new QueryProfileRegistry(), new ImportedMlModels()).getFunctions().get("rank_default").function().getBody().getRoot().toString()); } @Test @@ -220,7 +220,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase Search s = builder.getSearch(); RankProfile profile = rankProfileRegistry.get(s, "test"); assertEquals("0.5 + 50 * (attribute(rating_yelp) - 3)", - profile.compile(new QueryProfileRegistry(), new ImportedModels()).getFunctions().get("rank_default").function().getBody().getRoot().toString()); + profile.compile(new QueryProfileRegistry(), new ImportedMlModels()).getFunctions().get("rank_default").function().getBody().getRoot().toString()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java index e507a6c48e4..368f6fec80e 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java @@ -6,8 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; -import com.yahoo.yolean.Exceptions; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.util.Optional; @@ -64,10 +63,10 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); - RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))", parent.getFirstPhaseRanking().getRoot().toString()); - RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("7.0 * (9 + attribute(a))", child.getFirstPhaseRanking().getRoot().toString()); } @@ -124,14 +123,14 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); - RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("17.0", parent.getFirstPhaseRanking().getRoot().toString()); assertEquals("0.0", parent.getSecondPhaseRanking().getRoot().toString()); assertEquals("10.0", getRankingExpression("foo", parent, s)); assertEquals("17.0", getRankingExpression("firstphase", parent, s)); assertEquals("0.0", getRankingExpression("secondphase", parent, s)); - RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("31.0 + bar + arg(4.0)", child.getFirstPhaseRanking().getRoot().toString()); assertEquals("24.0", child.getSecondPhaseRanking().getRoot().toString()); assertEquals("12.0", getRankingExpression("foo", child, s)); @@ -180,7 +179,7 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("attribute(a) + C + (attribute(b) + 1)", test.getFirstPhaseRanking().getRoot().toString()); assertEquals("attribute(a) + attribute(b)", getRankingExpression("C", test, s)); assertEquals("attribute(b) + 1", getRankingExpression("D", test, s)); @@ -211,7 +210,7 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase private String getRankingExpression(String name, RankProfile rankProfile, Search search) { Optional<String> rankExpression = - new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search)) + new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(search)) .configProperties() .stream() .filter(r -> r.getFirst().equals("rankingExpression(" + name + ").rankingScript")) diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 6a1e5b207c6..a0deedb404a 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -9,7 +9,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.util.List; @@ -45,10 +45,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, new QueryProfileRegistry(), - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(sin).rankingScript,x * x)", testRankProperties.get(0).toString()); @@ -89,10 +89,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, new QueryProfileRegistry(), - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(tan).rankingScript,x * x)", testRankProperties.get(0).toString()); @@ -139,10 +139,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, new QueryProfileRegistry(), - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(sin).rankingScript,x * x)", testRankProperties.get(0).toString()); @@ -203,10 +203,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedMlModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, queryProfiles, - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", testRankProperties.get(0).toString()); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java index 5e649c2e551..830b7d531c3 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java @@ -4,9 +4,8 @@ package com.yahoo.searchdefinition; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.yolean.Exceptions; -import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -27,7 +26,7 @@ public class RankingExpressionValidationTestCase extends SearchDefinitionTestCas try { RankProfileRegistry registry = new RankProfileRegistry(); Search search = importWithExpression(expression, registry); - new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // cause rank profile parsing + new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedMlModels()); // cause rank profile parsing fail("No exception on incorrect ranking expression " + expression); } catch (IllegalArgumentException e) { // Success diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/SearchDefinitionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/SearchDefinitionTestCase.java index fa4280d2236..21c7362f793 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/SearchDefinitionTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/SearchDefinitionTestCase.java @@ -32,6 +32,8 @@ public abstract class SearchDefinitionTestCase { writer.newLine(); writer.flush(); writer.close(); + System.err.println(e.getMessage() + " [not equal files: >>>"+expectedFile+"<<< and >>>"+cfgFile+"<<< in assertConfigFiles]"); + return; } throw new AssertionError(e.getMessage() + " [not equal files: >>>"+expectedFile+"<<< and >>>"+cfgFile+"<<< in assertConfigFiles]", e); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java index f39896f5779..5e8a4597a2d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java @@ -3,12 +3,11 @@ package com.yahoo.searchdefinition.derived; import com.yahoo.document.DocumenttypesConfig; import com.yahoo.document.config.DocumentmanagerConfig; -import com.yahoo.io.IOUtils; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.vespa.configmodel.producers.DocumentManager; import com.yahoo.vespa.configmodel.producers.DocumentTypes; @@ -38,7 +37,7 @@ public abstract class AbstractExportingTestCase extends SearchDefinitionTestCase DerivedConfiguration config = new DerivedConfiguration(builder.getSearch(searchDefinitionName), builder.getRankProfileRegistry(), builder.getQueryProfileRegistry(), - new ImportedModels()); + new ImportedMlModels()); return export(dirName, builder, config); } @@ -46,7 +45,7 @@ public abstract class AbstractExportingTestCase extends SearchDefinitionTestCase DerivedConfiguration config = new DerivedConfiguration(search, builder.getRankProfileRegistry(), builder.getQueryProfileRegistry(), - new ImportedModels()); + new ImportedMlModels()); return export(dirName, builder, config); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java index f4344c9b03c..2160dda45aa 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java @@ -9,12 +9,9 @@ import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; -import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; -import java.io.IOException; - /** * Tests deriving rank for files from search definitions * @@ -35,7 +32,7 @@ public class EmptyRankProfileTestCase extends SearchDefinitionTestCase { doc.addField(new SDField("c", DataType.STRING)); search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry()); - new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels()); + new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedMlModels()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java index 2bae285301c..a7821615f48 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java @@ -11,7 +11,7 @@ import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.processing.Processing; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; @@ -42,7 +42,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase { other.addRankSetting(new RankProfile.RankSetting("a", RankProfile.RankSetting.Type.LITERALBOOST, 333)); new Processing().process(search, new BaseDeployLogger(), rankProfileRegistry, new QueryProfiles(), true, false); - DerivedConfiguration derived=new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels()); + DerivedConfiguration derived=new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedMlModels()); // Check attribute fields derived.getAttributeFields(); // TODO: assert content @@ -73,7 +73,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase { other.addRankSetting(new RankProfile.RankSetting("a", RankProfile.RankSetting.Type.LITERALBOOST, 333)); search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry()); - DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(),new ImportedModels()); + DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(),new ImportedMlModels()); // Check il script addition assertIndexing(Arrays.asList("clear_state | guard { input a | tokenize normalize stem:\"SHORTEST\" | index a; }", @@ -100,7 +100,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase { field2.setLiteralBoost(20); search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry()); - new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels()); + new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedMlModels()); assertIndexing(Arrays.asList("clear_state | guard { input title | tokenize normalize stem:\"SHORTEST\" | summary title | index title; }", "clear_state | guard { input body | tokenize normalize stem:\"SHORTEST\" | summary body | index body; }", "clear_state | guard { input title | tokenize | index title_literal; }", diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java index 723cd58a34a..61d1cd36f56 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java @@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.io.File; @@ -36,7 +36,7 @@ public class SimpleInheritTestCase extends AbstractExportingTestCase { DerivedConfiguration config = new DerivedConfiguration(search, builder.getRankProfileRegistry(), new QueryProfileRegistry(), - new ImportedModels()); + new ImportedMlModels()); config.export(toDirName); checkDir(toDirName, expectedResultsDirName); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java index 8941b07101d..a34d4de4f51 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java @@ -10,7 +10,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.processing.Processing; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; @@ -34,7 +34,7 @@ public class TypeConversionTestCase extends SearchDefinitionTestCase { document.addField(a); new Processing().process(search, new BaseDeployLogger(), rankProfileRegistry, new QueryProfiles(), true, false); - DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels()); + DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedMlModels()); IndexInfo indexInfo = derived.getIndexInfo(); assertFalse(indexInfo.hasCommand("default", "compact-to-term")); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java index d38bce04617..ae70061696b 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java @@ -6,13 +6,11 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.DerivedConfiguration; -import com.yahoo.searchdefinition.derived.Deriver; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; -import org.junit.Ignore; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; -import java.io.File; + import java.io.IOException; import static org.junit.Assert.assertEquals; @@ -100,7 +98,7 @@ public class ImplicitSearchFieldsTestCase extends SearchDefinitionTestCase { sb.importFile("src/test/examples/nextgen/simple.sd"); sb.build(); assertNotNull(sb.getSearch()); - new DerivedConfiguration(sb.getSearch(), sb.getRankProfileRegistry(), new QueryProfileRegistry(), new ImportedModels()); + new DerivedConfiguration(sb.getSearch(), sb.getRankProfileRegistry(), new QueryProfileRegistry(), new ImportedMlModels()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/PositionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/PositionTestCase.java index d8749d8eb32..9cf555e2c9a 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/PositionTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/PositionTestCase.java @@ -6,6 +6,7 @@ import com.yahoo.document.PositionDataType; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.document.Attribute; +import com.yahoo.searchdefinition.document.FieldSet; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.vespa.documentmodel.SummaryField; import com.yahoo.vespa.documentmodel.SummaryTransform; @@ -14,6 +15,7 @@ import org.junit.Ignore; import org.junit.Test; import java.io.IOException; +import java.util.Arrays; import java.util.Iterator; import static org.junit.Assert.*; @@ -26,6 +28,17 @@ import static org.junit.Assert.*; public class PositionTestCase { @Test + public void inherited_position_zcurve_field_is_not_added_to_document_fieldset() throws Exception { + SearchBuilder sb = SearchBuilder.createFromFiles(Arrays.asList( + "src/test/examples/position_base.sd", + "src/test/examples/position_inherited.sd")); + + Search search = sb.getSearch("position_inherited"); + FieldSet fieldSet = search.getDocument().getFieldSets().builtInFieldSets().get("[document]"); // TODO why is this not public in BuiltInFieldSets? + assertFalse(fieldSet.getFieldNames().contains(PositionDataType.getZCurveFieldName("pos"))); + } + + @Test public void requireThatPositionCanBeAttribute() throws Exception { Search search = SearchBuilder.buildFromFile("src/test/examples/position_attribute.sd"); assertNull(search.getAttribute("pos")); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index cff9abb08ed..9df03f25cb3 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -1,7 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.google.common.collect.ImmutableList; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.api.MlModelImporter; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; @@ -10,7 +12,10 @@ import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import java.util.HashMap; import java.util.List; @@ -26,6 +31,9 @@ import static org.junit.Assert.assertEquals; */ class RankProfileSearchFixture { + private final ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new XGBoostImporter()); private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); private final QueryProfileRegistry queryProfileRegistry; private Search search; @@ -83,7 +91,8 @@ class RankProfileSearchFixture { public RankProfile compileRankProfile(String rankProfile, Path applicationDir) { RankProfile compiled = rankProfileRegistry.get(search, rankProfile) - .compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile())); + .compile(queryProfileRegistry, + new ImportedMlModels(applicationDir.toFile(), importers)); compiledRankProfiles.put(rankProfile, compiled); return compiled; } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java index fd048737b43..d8f1a2ba545 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java @@ -8,7 +8,7 @@ import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.io.IOException; @@ -39,7 +39,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase { List<Pair<String, String>> rankProperties = new RawRankProfile(functionsRankProfile, new QueryProfileRegistry(), - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(search)).configProperties(); assertEquals(6, rankProperties.size()); @@ -65,7 +65,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase { Search search = SearchBuilder.createFromDirectory("src/test/examples/rankingexpressioninfile", registry, new QueryProfileRegistry()).getSearch(); - new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // rank profile parsing happens during deriving + new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedMlModels()); // rank profile parsing happens during deriving } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index 6e3a227e2a9..fe1d722a49c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -17,7 +17,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.util.List; @@ -200,10 +200,10 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { "}\n"); builder.build(true, new BaseDeployLogger()); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedMlModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, queryProfiles, - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(s)).configProperties(); return testRankProperties; } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/VespaDomBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/VespaDomBuilderTest.java index 4bdb809f546..dff21904c75 100755 --- a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/VespaDomBuilderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/VespaDomBuilderTest.java @@ -12,9 +12,7 @@ import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithMockPkg; import org.junit.Test; import org.w3c.dom.Element; -import org.xml.sax.SAXException; -import java.io.IOException; import java.io.StringReader; import static org.hamcrest.CoreMatchers.is; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTest.java index 396fe3e0af5..aa1ac401014 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTest.java @@ -163,11 +163,8 @@ public class ContainerModelBuilderTest extends ContainerModelBuilderTestBase { @Test public void requireThatJvmGCOptionsIsHonoured() throws IOException, SAXException { - final Zone US_EAST_3 = new Zone(Environment.prod, RegionName.from("us-east-3")); verifyJvmGCOptions(false, Zone.defaultZone(),ContainerCluster.CMS); - verifyJvmGCOptions(false, US_EAST_3, ContainerCluster.CMS); - verifyJvmGCOptions(true, Zone.defaultZone(), ContainerCluster.CMS); - verifyJvmGCOptions(true, US_EAST_3, ContainerCluster.G1GC); + verifyJvmGCOptions(true, Zone.defaultZone(), ContainerCluster.G1GC); } @Test diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/ContentSearchClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/ContentSearchClusterTest.java index 0156128f7ca..83b4cfebca5 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/ContentSearchClusterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/ContentSearchClusterTest.java @@ -40,8 +40,11 @@ public class ContentSearchClusterTest { } private static ContentCluster createClusterWithGlobalType() throws Exception { - return createCluster(createClusterBuilderWithGlobalType().getXml(), - createSearchDefinitions("global", "regular")); + return createClusterFromBuilderAndDocTypes(createClusterBuilderWithGlobalType(), "global", "regular"); + } + + private static ContentCluster createClusterWithoutGlobalType() throws Exception { + return createClusterFromBuilderAndDocTypes(createClusterBuilderWithOnlyDefaultTypes(), "marve", "fleksnes"); } private static ContentCluster createClusterFromBuilderAndDocTypes(ContentClusterBuilder builder, String... docTypes) throws Exception { @@ -49,26 +52,10 @@ public class ContentSearchClusterTest { "<node distribution-key='0' hostalias='mockhost'/>", "<node distribution-key='1' hostalias='mockhost'/>", "</group>")); - builder.enableMultipleBucketSpaces(true); String clusterXml = builder.getXml(); return createCluster(clusterXml, createSearchDefinitions(docTypes)); } - private static ContentCluster createClusterWithMultipleBucketSpacesEnabled() throws Exception { - return createClusterFromBuilderAndDocTypes(createClusterBuilderWithGlobalType(), "global", "regular"); - } - - private static ContentCluster createClusterWithMultipleBucketSpacesEnabledButNoGlobalDocs() throws Exception { - return createClusterFromBuilderAndDocTypes(createClusterBuilderWithOnlyDefaultTypes(), "marve", "fleksnes"); - } - - private static ContentCluster createClusterWithGlobalDocsButNotMultipleSpacesEnabled() throws Exception { - return createCluster(createClusterBuilderWithGlobalType() - .enableMultipleBucketSpaces(false) - .getXml(), - createSearchDefinitions("global", "regular")); - } - private static ContentClusterBuilder createClusterBuilderWithGlobalType() { return new ContentClusterBuilder() .docTypes(Arrays.asList(DocType.indexGlobal("global"), DocType.index("regular"))); @@ -178,41 +165,18 @@ public class ContentSearchClusterTest { assertEquals(2, config.documenttype().size()); assertDocumentType("global", "global", config.documenttype(0)); assertDocumentType("regular", "default", config.documenttype(1)); - // Safeguard against flipping the switch - assertFalse(config.enable_multiple_bucket_spaces()); - } - - @Test - public void require_that_multiple_bucket_spaces_can_be_force_enabled() throws Exception { - ContentCluster cluster = createClusterWithMultipleBucketSpacesEnabled(); - { - BucketspacesConfig config = getBucketspacesConfig(cluster); - assertEquals(2, config.documenttype().size()); - assertDocumentType("global", "global", config.documenttype(0)); - assertDocumentType("regular", "default", config.documenttype(1)); - assertTrue(config.enable_multiple_bucket_spaces()); - } - { - assertTrue(getFleetcontrollerConfig(cluster).enable_multiple_bucket_spaces()); - } } @Test public void cluster_with_global_document_types_sets_cluster_controller_global_docs_config_option() throws Exception { - ContentCluster cluster = createClusterWithMultipleBucketSpacesEnabled(); + ContentCluster cluster = createClusterWithGlobalType(); assertTrue(getFleetcontrollerConfig(cluster).cluster_has_global_document_types()); } @Test public void cluster_without_global_document_types_unsets_cluster_controller_global_docs_config_option() throws Exception { - ContentCluster cluster = createClusterWithMultipleBucketSpacesEnabledButNoGlobalDocs(); + ContentCluster cluster = createClusterWithoutGlobalType(); assertFalse(getFleetcontrollerConfig(cluster).cluster_has_global_document_types()); } - @Test - public void controller_global_documents_config_always_enabled_even_without_experimental_flag_set() throws Exception { - ContentCluster cluster = createClusterWithGlobalDocsButNotMultipleSpacesEnabled(); - assertTrue(getFleetcontrollerConfig(cluster).cluster_has_global_document_types()); - } - } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/IndexedHierarchicDistributionTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/IndexedHierarchicDistributionTest.java index 4749ef314b4..f402bae8fd9 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/IndexedHierarchicDistributionTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/IndexedHierarchicDistributionTest.java @@ -184,7 +184,15 @@ public class IndexedHierarchicDistributionTest { " </group>", ""); } private ContentCluster getIllegalGroupsCluster() throws Exception { - return createCluster(createClusterXml(getOddGroupsClusterXml(), 4, 4)); + return createCluster(createClusterXml(getOddGroupsClusterXml(), Optional.of(getRoundRobinDispatchXml()), 4, 4)); + } + + private String getRoundRobinDispatchXml() { + return joinLines("<tuning>", + " <dispatch>", + " <dispatch-policy>round-robin</dispatch-policy>", + " </dispatch>", + "</tuning>"); } private String getRandomDispatchXml() { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/IndexingAndDocprocRoutingTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/IndexingAndDocprocRoutingTest.java index 071d51aae52..9cdb2606241 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/IndexingAndDocprocRoutingTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/IndexingAndDocprocRoutingTest.java @@ -1,7 +1,11 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.content; -import com.yahoo.messagebus.routing.*; +import com.yahoo.messagebus.routing.Hop; +import com.yahoo.messagebus.routing.HopBlueprint; +import com.yahoo.messagebus.routing.PolicyDirective; +import com.yahoo.messagebus.routing.Route; +import com.yahoo.messagebus.routing.RoutingTable; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ContainerCluster; @@ -16,7 +20,12 @@ import org.junit.Test; import org.xml.sax.SAXException; import java.io.IOException; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.hamcrest.Matchers.*; import static org.junit.Assert.assertNotNull; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/TuningDispatchTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/TuningDispatchTest.java index 14b7f045ca8..e84f256d6dc 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/TuningDispatchTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/TuningDispatchTest.java @@ -30,7 +30,17 @@ public class TuningDispatchTest { TuningDispatch dispatch = new TuningDispatch.Builder() .setDispatchPolicy("random") .build(); - assertTrue(TuningDispatch.DispatchPolicy.RANDOM == dispatch.getDispatchPolicy()); + assertTrue(TuningDispatch.DispatchPolicy.ADAPTIVE == dispatch.getDispatchPolicy()); + assertNull(dispatch.getMinGroupCoverage()); + assertNull(dispatch.getMinActiveDocsCoverage()); + } + + @Test + public void requireThatWeightedDispatchWork() { + TuningDispatch dispatch = new TuningDispatch.Builder() + .setDispatchPolicy("adaptive") + .build(); + assertTrue(TuningDispatch.DispatchPolicy.ADAPTIVE == dispatch.getDispatchPolicy()); assertNull(dispatch.getMinGroupCoverage()); assertNull(dispatch.getMinActiveDocsCoverage()); } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/cluster/ClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/cluster/ClusterTest.java index b4994e5d009..1a2eb93face 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/cluster/ClusterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/cluster/ClusterTest.java @@ -20,6 +20,7 @@ import java.util.List; import static com.yahoo.config.model.test.TestUtil.joinLines; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * @author Simon Thoresen Hult @@ -49,6 +50,7 @@ public class ClusterTest { " <max-wait-after-coverage-factor>0.58</max-wait-after-coverage-factor>", " </coverage>", "</search>")); + assertEquals(1, cluster.getSearch().getIndexed().getTLDs().size()); for (Dispatch tld : cluster.getSearch().getIndexed().getTLDs()) { PartitionsConfig.Builder builder = new PartitionsConfig.Builder(); tld.getConfig(builder); @@ -57,6 +59,22 @@ public class ClusterTest { assertEquals(0.23, config.dataset(0).higher_coverage_minsearchwait(), 1E-6); assertEquals(0.58, config.dataset(0).higher_coverage_maxsearchwait(), 1E-6); assertEquals(2, config.dataset(0).searchablecopies()); + assertTrue(config.dataset(0).useroundrobinforfixedrow()); + } + } + + @Test + public void requireThatDispatchTuningIsApplied() throws ParseException { + ContentCluster cluster = newContentCluster(joinLines("<search>", "</search>"), + joinLines("<tuning>", + "</tuning>")); + assertEquals(1, cluster.getSearch().getIndexed().getTLDs().size()); + for (Dispatch tld : cluster.getSearch().getIndexed().getTLDs()) { + PartitionsConfig.Builder builder = new PartitionsConfig.Builder(); + tld.getConfig(builder); + PartitionsConfig config = new PartitionsConfig(builder); + assertEquals(2, config.dataset(0).searchablecopies()); + assertTrue(config.dataset(0).useroundrobinforfixedrow()); } } @@ -70,34 +88,49 @@ public class ClusterTest { } private static ContentCluster newContentCluster(String contentSearchXml) throws ParseException { - return newContentCluster(contentSearchXml, false); + return newContentCluster(contentSearchXml, "", false); + } + + private static ContentCluster newContentCluster(String contentSearchXml, String searchNodeTuningXml) throws ParseException { + return newContentCluster(contentSearchXml, searchNodeTuningXml, false); } private static ContentCluster newContentCluster(String contentSearchXml, boolean globalDocType) throws ParseException { + return newContentCluster(contentSearchXml, "", globalDocType); + } + + private static ContentCluster newContentCluster(String contentSearchXml, String searchNodeTuningXml, boolean globalDocType) throws ParseException { ApplicationPackage app = new MockApplicationPackage.Builder() - .withHosts(joinLines("<hosts>", - " <host name='localhost'><alias>my_host</alias></host>", - "</hosts>")) - .withServices(joinLines("<services version='1.0'>", - " <admin version='2.0'>", - " <adminserver hostalias='my_host' />", - " </admin>", - " <content version='1.0'>", - " <redundancy>3</redundancy>", - " <documents>", - " " + getDocumentXml(globalDocType), - " </documents>", - " <engine>", - " <proton>", - " <searchable-copies>2</searchable-copies>", - " </proton>", - " </engine>", - " <group>", - " <node hostalias='my_host' distribution-key='0' />", - " </group>", - contentSearchXml, - " </content>", - "</services>")) + .withHosts(joinLines( + "<hosts>", + " <host name='localhost'><alias>my_host</alias></host>", + "</hosts>")) + .withServices(joinLines( + "<services version='1.0'>", + " <admin version='2.0'>", + " <adminserver hostalias='my_host' />", + " </admin>", + "<jdisc id='foo' version='1.0'>", + " <search />", + " <nodes><node hostalias='my_host' /></nodes>", + "</jdisc>", + " <content version='1.0'>", + " <redundancy>3</redundancy>", + " <documents>", + " " + getDocumentXml(globalDocType), + " </documents>", + " <engine>", + " <proton>", + " <searchable-copies>2</searchable-copies>", + searchNodeTuningXml, + " </proton>", + " </engine>", + " <group>", + " <node hostalias='my_host' distribution-key='0' />", + " </group>", + contentSearchXml, + " </content>", + "</services>")) .withSearchDefinitions(ApplicationPackageUtils.generateSearchDefinition("my_document")) .build(); List<Content> contents = new TestDriver().buildModel(app).getConfigModels(Content.class); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/cluster/DomTuningDispatchBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/cluster/DomTuningDispatchBuilderTest.java index 8c6fe110b0a..d9646c163e4 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/cluster/DomTuningDispatchBuilderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/cluster/DomTuningDispatchBuilderTest.java @@ -86,7 +86,7 @@ public class DomTuningDispatchBuilderTest { " </dispatch>" + " </tuning>" + "</content>"); - assertTrue(TuningDispatch.DispatchPolicy.RANDOM == dispatch.getDispatchPolicy()); + assertTrue(TuningDispatch.DispatchPolicy.ADAPTIVE == dispatch.getDispatchPolicy()); } private static TuningDispatch newTuningDispatch(String xml) throws Exception { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/utils/ContentClusterBuilder.java b/config-model/src/test/java/com/yahoo/vespa/model/content/utils/ContentClusterBuilder.java index 79bc9504659..95c57bb544c 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/utils/ContentClusterBuilder.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/utils/ContentClusterBuilder.java @@ -26,7 +26,6 @@ public class ContentClusterBuilder { private Optional<String> dispatchXml = Optional.empty(); private Optional<Double> protonDiskLimit = Optional.empty(); private Optional<Double> protonMemoryLimit = Optional.empty(); - private Optional<Boolean> enableMultipleBucketSpaces = Optional.empty(); public ContentClusterBuilder() { } @@ -78,11 +77,6 @@ public class ContentClusterBuilder { return this; } - public ContentClusterBuilder enableMultipleBucketSpaces(boolean value) { - this.enableMultipleBucketSpaces = Optional.of(value); - return this; - } - public ContentCluster build(MockRoot root) throws Exception { return ContentClusterUtils.createCluster(getXml(), root); } @@ -100,11 +94,6 @@ public class ContentClusterBuilder { if (dispatchXml.isPresent()) { xml += dispatchXml.get(); } - if (enableMultipleBucketSpaces.isPresent()) { - xml += joinLines("<experimental>", - "<enable-multiple-bucket-spaces>" + (enableMultipleBucketSpaces.get() ? "true" : "false") + "</enable-multiple-bucket-spaces>", - "</experimental>"); - } return xml + groupXml + "</content>"; } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java index 2ae629562d0..f5edc83da5c 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java @@ -1,11 +1,17 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; +import com.google.common.collect.ImmutableList; import com.yahoo.config.model.ApplicationPackageTester; +import com.yahoo.config.model.api.MlModelImporter; +import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankingConstant; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.model.VespaModel; @@ -26,6 +32,10 @@ import static org.junit.Assert.assertEquals; */ public class ImportedModelTester { + private final ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new XGBoostImporter()); + private final String modelName; private final Path applicationDir; @@ -36,7 +46,10 @@ public class ImportedModelTester { public VespaModel createVespaModel() { try { - return new VespaModel(ApplicationPackageTester.create(applicationDir.toString()).app()); + DeployState.Builder state = new DeployState.Builder(); + state.applicationPackage(ApplicationPackageTester.create(applicationDir.toString()).app()); + state.modelImporters(importers); + return new VespaModel(state.build()); } catch (SAXException | IOException e) { throw new RuntimeException(e); diff --git a/config-provisioning/pom.xml b/config-provisioning/pom.xml index 05f94973017..bccb3013360 100644 --- a/config-provisioning/pom.xml +++ b/config-provisioning/pom.xml @@ -30,17 +30,34 @@ Provisioning APIs. </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>config-bundle</artifactId> + <artifactId>component</artifactId> <version>${project.version}</version> <scope>provided</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>container-dev</artifactId> + <artifactId>configdefinitions</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>config-bundle</artifactId> <version>${project.version}</version> <scope>provided</scope> </dependency> <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.google.inject</groupId> + <artifactId>guice</artifactId> + <scope>provided</scope> + <classifier>no_aop</classifier> + </dependency> + <dependency> <groupId>com.yahoo.vespa</groupId> <artifactId>testutil</artifactId> <version>${project.version}</version> diff --git a/configdefinitions/src/vespa/bucketspaces.def b/configdefinitions/src/vespa/bucketspaces.def index 4db107ec1ee..c9468850018 100644 --- a/configdefinitions/src/vespa/bucketspaces.def +++ b/configdefinitions/src/vespa/bucketspaces.def @@ -10,5 +10,3 @@ documenttype[].name string ## The bucket space this document type belongs to. documenttype[].bucketspace string -## Switch to enable multiple bucket spaces in content layer and content nodes. -enable_multiple_bucket_spaces bool default=false diff --git a/configdefinitions/src/vespa/dispatch.def b/configdefinitions/src/vespa/dispatch.def index 487f8ac24c3..50989c3ef74 100644 --- a/configdefinitions/src/vespa/dispatch.def +++ b/configdefinitions/src/vespa/dispatch.def @@ -14,11 +14,23 @@ minGroupCoverage double default=100 maxNodesDownPerGroup int default=0 # Distribution policy for group selection -distributionPolicy enum { ROUNDROBIN, RANDOM } default=ROUNDROBIN +distributionPolicy enum { ROUNDROBIN, ADAPTIVE } default=ROUNDROBIN # Is multi-level dispatch configured for this cluster useMultilevelDispatch bool default=false +# Number of document copies +searchableCopies long default=1 + +# Minimum search coverage required before returning the results of a query +minSearchCoverage double default=100 + +# Minimum wait time for full coverage after minimum coverage is achieved, factored based on time left at minimum coverage +minWaitAfterCoverageFactor double default=0 + +# Maximum wait time for full coverage after minimum coverage is achieved, factored based on time left at minimum coverage +maxWaitAfterCoverageFactor double default=1 + # The unique key of a search node node[].key int diff --git a/configdefinitions/src/vespa/fleetcontroller.def b/configdefinitions/src/vespa/fleetcontroller.def index ca7ede28cb2..04c9e3b7c73 100644 --- a/configdefinitions/src/vespa/fleetcontroller.def +++ b/configdefinitions/src/vespa/fleetcontroller.def @@ -155,9 +155,6 @@ min_node_ratio_per_group double default=0.0 ## within this duration. max_deferred_task_version_wait_time_sec double default=30.0 -## Switch to enable multiple bucket spaces in cluster controller. -enable_multiple_bucket_spaces bool default=false - ## Whether or not the content cluster the controller has responsibility for ## contains any document types that are tagged as global. If this is true, ## global document-specific behavior is enabled that marks nodes down in the diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml index 5f60be8c202..3dd6e0090c5 100644 --- a/configserver/src/main/resources/configserver-app/services.xml +++ b/configserver/src/main/resources/configserver-app/services.xml @@ -55,6 +55,7 @@ <preprocess:include file='config-models.xml' required='false' /> <preprocess:include file='node-repository.xml' required='false' /> <preprocess:include file='hosted-vespa/routing-status.xml' required='false' /> + <preprocess:include file='model-integration.xml' required='true' /> <!-- TODO Vespa 7: Remove scoreboard.xml, replaced by metrics-packets.xml --> <preprocess:include file='hosted-vespa/scoreboard.xml' required='false' /> diff --git a/container-core/src/main/java/com/yahoo/container/handler/Coverage.java b/container-core/src/main/java/com/yahoo/container/handler/Coverage.java index 4a937068d81..84cc0734e7c 100644 --- a/container-core/src/main/java/com/yahoo/container/handler/Coverage.java +++ b/container-core/src/main/java/com/yahoo/container/handler/Coverage.java @@ -28,9 +28,9 @@ public class Coverage { EXPLICITLY_FULL, EXPLICITLY_INCOMPLETE, DOCUMENT_COUNT; } - private final static int DEGRADED_BY_MATCH_PHASE = 1; - private final static int DEGRADED_BY_TIMEOUT = 2; - private final static int DEGRADED_BY_ADAPTIVE_TIMEOUT = 4; + public final static int DEGRADED_BY_MATCH_PHASE = 1; + public final static int DEGRADED_BY_TIMEOUT = 2; + public final static int DEGRADED_BY_ADAPTIVE_TIMEOUT = 4; /** * Build an invalid instance to initiate manually. diff --git a/container-dev/pom.xml b/container-dev/pom.xml index 1b0f36e3adb..7d61882c085 100644 --- a/container-dev/pom.xml +++ b/container-dev/pom.xml @@ -104,18 +104,6 @@ <groupId>org.apache.httpcomponents</groupId> <artifactId>httpclient</artifactId> </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> - </exclusion> - <exclusion> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </exclusion> </exclusions> </dependency> <dependency> @@ -178,18 +166,6 @@ <groupId>xerces</groupId> <artifactId>xercesImpl</artifactId> </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> - </exclusion> - <exclusion> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </exclusion> </exclusions> </dependency> <dependency> diff --git a/container-disc/pom.xml b/container-disc/pom.xml index 62985192ad3..be4ac23a938 100644 --- a/container-disc/pom.xml +++ b/container-disc/pom.xml @@ -174,6 +174,7 @@ jdisc-security-filters-jar-with-dependencies.jar, jdisc_http_service-jar-with-dependencies.jar, model-evaluation-jar-with-dependencies.jar, + model-integration-jar-with-dependencies.jar, vespaclient-container-plugin-jar-with-dependencies.jar, vespa-athenz-jar-with-dependencies.jar, security-utils-jar-with-dependencies.jar, diff --git a/container-search/src/main/java/com/yahoo/fs4/PongPacket.java b/container-search/src/main/java/com/yahoo/fs4/PongPacket.java index 5a09c4d3d3d..13fb7d84408 100644 --- a/container-search/src/main/java/com/yahoo/fs4/PongPacket.java +++ b/container-search/src/main/java/com/yahoo/fs4/PongPacket.java @@ -12,9 +12,6 @@ import java.util.Optional; */ public class PongPacket extends BasicPacket { - @SuppressWarnings("unused") - private int lowPartitionId; // ignored (historical field) - private int dispatchTimestamp; @SuppressWarnings("unused") @@ -40,7 +37,7 @@ public class PongPacket extends BasicPacket { public void decodeBody(ByteBuffer buffer) { int features = buffer.getInt(); - lowPartitionId = buffer.getInt(); + buffer.getInt(); // Unused lowPartitionId dispatchTimestamp = buffer.getInt(); if ((features & MRF_MLD) != 0) { totalNodes = buffer.getInt(); diff --git a/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java b/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java index de4d9c9fe8b..f40550f1f70 100644 --- a/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java +++ b/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java @@ -1,6 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.fs4.mplex; +import com.yahoo.concurrent.SystemTimer; +import com.yahoo.fs4.BasicPacket; +import com.yahoo.fs4.ChannelTimeoutException; +import com.yahoo.fs4.Packet; +import com.yahoo.search.Query; +import com.yahoo.search.dispatch.ResponseMonitor; + import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -9,12 +16,6 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; -import com.yahoo.concurrent.SystemTimer; -import com.yahoo.fs4.BasicPacket; -import com.yahoo.fs4.ChannelTimeoutException; -import com.yahoo.fs4.Packet; -import com.yahoo.search.Query; - /** * This class is used to represent a "channel" in the FS4 protocol. * A channel represents a session between a client and the fdispatch. @@ -34,6 +35,7 @@ public class FS4Channel { volatile private BlockingQueue<BasicPacket> responseQueue; private Query query; private boolean isPingChannel = false; + private ResponseMonitor<FS4Channel> monitor; /** for unit testing. do not use */ protected FS4Channel () { @@ -197,6 +199,9 @@ public class FS4Channel { throws InterruptedException, InvalidChannelException { ensureValidQ().put(packet); + if(monitor != null) { + monitor.responseAvailable(this); + } } /** @@ -241,4 +246,7 @@ public class FS4Channel { return "fs4 channel " + channelId + (isValid() ? " [valid]" : " [invalid]"); } + public void setResponseMonitor(ResponseMonitor<FS4Channel> monitor) { + this.monitor = monitor; + } } diff --git a/container-search/src/main/java/com/yahoo/prelude/Pong.java b/container-search/src/main/java/com/yahoo/prelude/Pong.java index de2a1a4d8fb..ba3ff2eda00 100644 --- a/container-search/src/main/java/com/yahoo/prelude/Pong.java +++ b/container-search/src/main/java/com/yahoo/prelude/Pong.java @@ -9,7 +9,6 @@ import java.util.Optional; import com.yahoo.fs4.PongPacket; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.statistics.ElapsedTime; -import com.yahoo.container.protect.Error; /** * An answer from Ping. @@ -47,15 +46,6 @@ public class Pong { public int getErrorSize() { return errors.size(); } - - /** - * Returns the package causing this to exist, or empty if none - * - * @deprecated do not use - */ - // TODO: Remove on Vespa 7 - @Deprecated // OK - public Optional<PongPacket> getPongPacket() { return pongPacket; } /** Returns the number of active documents in the backend responding in this Pong, if available */ public Optional<Long> activeDocuments() { diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java index f5d082635ab..8fa8bdb66bf 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java @@ -50,7 +50,7 @@ public class FS4InvokerFactory { public SearchInvoker getSearchInvoker(Query query, Node node) { Backend backend = fs4ResourcePool.getBackend(node.hostname(), node.fs4port()); - return new FS4SearchInvoker(searcher, query, backend.openChannel(), node); + return new FS4SearchInvoker(searcher, query, backend.openChannel(), Optional.of(node)); } /** @@ -70,14 +70,14 @@ public class FS4InvokerFactory { * list is invalid and the remaining coverage is not sufficient */ public Optional<SearchInvoker> getSearchInvoker(Query query, int groupId, List<Node> nodes, boolean acceptIncompleteCoverage) { - Map<Integer, SearchInvoker> invokers = new HashMap<>(); + List<SearchInvoker> invokers = new ArrayList<>(nodes.size()); Set<Integer> failed = null; for (Node node : nodes) { boolean nodeAdded = false; if (node.isWorking()) { Backend backend = fs4ResourcePool.getBackend(node.hostname(), node.fs4port()); if (backend.probeConnection()) { - invokers.put(node.key(), new FS4SearchInvoker(searcher, query, backend.openChannel(), node)); + invokers.add(node.key(), new FS4SearchInvoker(searcher, query, backend.openChannel(), Optional.of(node))); nodeAdded = true; } } @@ -99,7 +99,7 @@ public class FS4InvokerFactory { } if (!searchCluster.isPartialGroupCoverageSufficient(groupId, success)) { if (acceptIncompleteCoverage) { - createCoverageErrorInvoker(invokers, nodes, failed); + invokers.add(createCoverageErrorInvoker(nodes, failed)); } else { return Optional.empty(); } @@ -107,13 +107,13 @@ public class FS4InvokerFactory { } if (invokers.size() == 1) { - return Optional.of(invokers.values().iterator().next()); + return Optional.of(invokers.get(0)); } else { - return Optional.of(new InterleavedSearchInvoker(invokers)); + return Optional.of(new InterleavedSearchInvoker(invokers, searchCluster)); } } - private void createCoverageErrorInvoker(Map<Integer, SearchInvoker> invokers, List<Node> nodes, Set<Integer> failed) { + private SearchInvoker createCoverageErrorInvoker(List<Node> nodes, Set<Integer> failed) { long activeDocuments = 0; StringBuilder down = new StringBuilder("Connection failure on nodes with distribution-keys: "); Integer key = null; @@ -129,7 +129,8 @@ public class FS4InvokerFactory { } } Coverage coverage = new Coverage(0, activeDocuments, 0); - invokers.put(key, new SearchErrorInvoker(ErrorMessage.createBackendCommunicationError(down.toString()), coverage)); + coverage.setNodesTried(1); + return new SearchErrorInvoker(ErrorMessage.createBackendCommunicationError(down.toString()), coverage); } public FillInvoker getFillInvoker(Query query, Node node) { diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java index 98676890bdf..da32cfc4fda 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java @@ -6,14 +6,13 @@ import com.yahoo.fs4.ChannelTimeoutException; import com.yahoo.fs4.Packet; import com.yahoo.fs4.QueryPacket; import com.yahoo.fs4.QueryResultPacket; -import com.yahoo.fs4.mplex.Backend; import com.yahoo.fs4.mplex.FS4Channel; import com.yahoo.fs4.mplex.InvalidChannelException; import com.yahoo.search.Query; import com.yahoo.search.Result; +import com.yahoo.search.dispatch.ResponseMonitor; import com.yahoo.search.dispatch.SearchInvoker; import com.yahoo.search.dispatch.searchcluster.Node; -import com.yahoo.search.result.Coverage; import com.yahoo.search.result.ErrorMessage; import java.io.IOException; @@ -30,29 +29,21 @@ import static java.util.Arrays.asList; * * @author ollivir */ -public class FS4SearchInvoker extends SearchInvoker { +public class FS4SearchInvoker extends SearchInvoker implements ResponseMonitor<FS4Channel> { private final VespaBackEndSearcher searcher; private FS4Channel channel; - private final Optional<Node> node; private ErrorMessage pendingSearchError = null; private Query query = null; private QueryPacket queryPacket = null; - public FS4SearchInvoker(VespaBackEndSearcher searcher, Query query, FS4Channel channel, Node node) { + public FS4SearchInvoker(VespaBackEndSearcher searcher, Query query, FS4Channel channel, Optional<Node> node) { + super(node); this.searcher = searcher; - this.node = Optional.of(node); this.channel = channel; channel.setQuery(query); - } - - // fdispatch code path - public FS4SearchInvoker(VespaBackEndSearcher searcher, Query query, Backend backend) { - this.searcher = searcher; - this.node = Optional.empty(); - this.channel = backend.openChannel(); - channel.setQuery(query); + channel.setResponseMonitor(this); } @Override @@ -68,6 +59,8 @@ public class FS4SearchInvoker extends SearchInvoker { this.query = query; this.queryPacket = queryPacket; + channel.setResponseMonitor(this); + try { boolean couldSend = channel.sendPacket(queryPacket); if (!couldSend) { @@ -115,7 +108,7 @@ public class FS4SearchInvoker extends SearchInvoker { searcher.addMetaInfo(query, queryPacket.getQueryPacketData(), resultPacket, result); - searcher.addUnfilledHits(result, resultPacket.getDocuments(), false, queryPacket.getQueryPacketData(), cacheKey, node.map(Node::key)); + searcher.addUnfilledHits(result, resultPacket.getDocuments(), false, queryPacket.getQueryPacketData(), cacheKey, distributionKey()); Packet[] packets; CacheControl cacheControl = searcher.getCacheControl(); PacketWrapper packetWrapper = cacheControl.lookup(cacheKey, query); @@ -130,7 +123,7 @@ public class FS4SearchInvoker extends SearchInvoker { } else { packets = new Packet[1]; packets[0] = resultPacket; - cacheControl.cache(cacheKey, query, new DocsumPacketKey[0], packets, node.map(Node::key)); + cacheControl.cache(cacheKey, query, new DocsumPacketKey[0], packets, distributionKey()); } } return asList(result); @@ -138,10 +131,7 @@ public class FS4SearchInvoker extends SearchInvoker { private List<Result> errorResult(ErrorMessage errorMessage) { Result error = new Result(query, errorMessage); - node.ifPresent(n -> { - Coverage coverage = new Coverage(0, n.getActiveDocuments(), 0); - error.setCoverage(coverage); - }); + getErrorCoverage().ifPresent(error::setCoverage); return Arrays.asList(error); } @@ -164,4 +154,9 @@ public class FS4SearchInvoker extends SearchInvoker { private boolean isLoggingFine() { return getLogger().isLoggable(Level.FINE); } + + @Override + public void responseAvailable(FS4Channel from) { + responseAvailable(); + } } diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java index a98c34295ee..209f6faefa0 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java @@ -222,7 +222,7 @@ public class FastSearcher extends VespaBackEndSearcher { if(direct.isPresent()) { return fs4InvokerFactory.getSearchInvoker(query, direct.get()); } - return new FS4SearchInvoker(this, query, dispatchBackend); + return new FS4SearchInvoker(this, query, dispatchBackend.openChannel(), Optional.empty()); } /** @@ -284,6 +284,7 @@ public class FastSearcher extends VespaBackEndSearcher { result.hits().addAll(partialResult.hits().asUnorderedHits()); } if (finalCoverage != null) { + adjustCoverageDegradedReason(finalCoverage); result.setCoverage(finalCoverage); } @@ -301,6 +302,18 @@ public class FastSearcher extends VespaBackEndSearcher { return result; } + private void adjustCoverageDegradedReason(Coverage coverage) { + int asked = coverage.getNodesTried(); + int answered = coverage.getNodes(); + if (asked > answered) { + int searchableCopies = (int) dispatcher.searchCluster().dispatchConfig().searchableCopies(); + int missingNodes = (asked - answered) - (searchableCopies - 1); + if (missingNodes > 0) { + coverage.setDegradedReason(com.yahoo.container.handler.Coverage.DEGRADED_BY_TIMEOUT); + } + } + } + private static @NonNull Optional<String> quotedSummaryClass(String summaryClass) { return Optional.of(summaryClass == null ? "[null]" : quote(summaryClass)); } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java index 0382f47457e..1ca64be7924 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java @@ -20,7 +20,6 @@ import com.yahoo.vespa.config.search.DispatchConfig; import java.util.Arrays; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.Set; @@ -134,7 +133,7 @@ public class Dispatcher extends AbstractComponent { int max = Integer.min(searchCluster.orderedGroups().size(), MAX_GROUP_SELECTION_ATTEMPTS); Set<Integer> rejected = null; for (int i = 0; i < max; i++) { - Optional<Group> groupInCluster = loadBalancer.takeGroupForQuery(rejected); + Optional<Group> groupInCluster = loadBalancer.takeGroup(rejected); if (!groupInCluster.isPresent()) { // No groups available break; diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java index 9ff43df87cf..83647b332e6 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java @@ -5,12 +5,25 @@ import com.yahoo.fs4.QueryPacket; import com.yahoo.prelude.fastsearch.CacheKey; import com.yahoo.search.Query; import com.yahoo.search.Result; +import com.yahoo.search.dispatch.searchcluster.SearchCluster; +import com.yahoo.search.result.ErrorMessage; +import com.yahoo.vespa.config.search.DispatchConfig; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.IdentityHashMap; import java.util.List; -import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static com.yahoo.container.handler.Coverage.DEGRADED_BY_ADAPTIVE_TIMEOUT; +import static com.yahoo.container.handler.Coverage.DEGRADED_BY_TIMEOUT; /** * InterleavedSearchInvoker uses multiple {@link SearchInvoker} objects to interface with content @@ -19,11 +32,25 @@ import java.util.Map; * * @author ollivir */ -public class InterleavedSearchInvoker extends SearchInvoker { - private final Collection<SearchInvoker> invokers; +public class InterleavedSearchInvoker extends SearchInvoker implements ResponseMonitor<SearchInvoker> { + private static final Logger log = Logger.getLogger(InterleavedSearchInvoker.class.getName()); + + private final Set<SearchInvoker> invokers; + private final SearchCluster searchCluster; + private final LinkedBlockingQueue<SearchInvoker> availableForProcessing; + private Query query; + + private boolean adaptiveTimeoutCalculated = false; + private long adaptiveTimeoutMin = 0; + private long adaptiveTimeoutMax = 0; + private long deadline = 0; - public InterleavedSearchInvoker(Map<Integer, SearchInvoker> invokers) { - this.invokers = new ArrayList<>(invokers.values()); + public InterleavedSearchInvoker(Collection<SearchInvoker> invokers, SearchCluster searchCluster) { + super(Optional.empty()); + this.invokers = Collections.newSetFromMap(new IdentityHashMap<>()); + this.invokers.addAll(invokers); + this.searchCluster = searchCluster; + this.availableForProcessing = newQueue(); } /** @@ -33,27 +60,109 @@ public class InterleavedSearchInvoker extends SearchInvoker { */ @Override protected void sendSearchRequest(Query query, QueryPacket queryPacket) throws IOException { + this.query = query; + invokers.forEach(invoker -> invoker.setMonitor(this)); + deadline = currentTime() + query.getTimeLeft(); + int originalHits = query.getHits(); int originalOffset = query.getOffset(); query.setHits(query.getHits() + query.getOffset()); query.setOffset(0); + for (SearchInvoker invoker : invokers) { invoker.sendSearchRequest(query, null); } + query.setHits(originalHits); query.setOffset(originalOffset); } @Override protected List<Result> getSearchResults(CacheKey cacheKey) throws IOException { + int requests = invokers.size(); + int responses = 0; List<Result> results = new ArrayList<>(); - for (SearchInvoker invoker : invokers) { - results.addAll(invoker.getSearchResults(cacheKey)); + long nextTimeout = query.getTimeLeft(); + try { + while (!invokers.isEmpty() && nextTimeout >= 0) { + SearchInvoker invoker = availableForProcessing.poll(nextTimeout, TimeUnit.MILLISECONDS); + if (invoker == null) { + if (log.isLoggable(Level.FINE)) { + log.fine("Search timed out with " + requests + " requests made, " + responses + " responses received"); + } + break; + } else { + invokers.remove(invoker); + results.addAll(invoker.getSearchResults(cacheKey)); + responses++; + } + nextTimeout = nextTimeout(requests, responses); + } + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while waiting for search results", e); } + + insertTimeoutErrors(results); return results; } + private void insertTimeoutErrors(List<Result> results) { + int degradedReason = adaptiveTimeoutCalculated ? DEGRADED_BY_ADAPTIVE_TIMEOUT : DEGRADED_BY_TIMEOUT; + + for (SearchInvoker invoker : invokers) { + Optional<Integer> dk = invoker.distributionKey(); + String message; + if (dk.isPresent()) { + message = "Backend communication timeout on node with distribution-key " + dk.get(); + } else { + message = "Backend communication timeout"; + } + Result error = new Result(query, ErrorMessage.createBackendCommunicationError(message)); + invoker.getErrorCoverage().ifPresent(coverage -> { + coverage.setDegradedReason(degradedReason); + error.setCoverage(coverage); + }); + results.add(error); + } + } + + private long nextTimeout(int requests, int responses) { + DispatchConfig config = searchCluster.dispatchConfig(); + double minimumCoverage = config.minSearchCoverage(); + + if (requests == responses || minimumCoverage >= 100.0) { + return query.getTimeLeft(); + } + int minimumResponses = (int) (requests * minimumCoverage / 100.0); + + if (responses < minimumResponses) { + return query.getTimeLeft(); + } + + long timeLeft = query.getTimeLeft(); + if (!adaptiveTimeoutCalculated) { + adaptiveTimeoutMin = (long) (timeLeft * config.minWaitAfterCoverageFactor()); + adaptiveTimeoutMax = (long) (timeLeft * config.maxWaitAfterCoverageFactor()); + adaptiveTimeoutCalculated = true; + } + + long now = currentTime(); + int pendingQueries = requests - responses; + double missWidth = ((100.0 - config.minSearchCoverage()) * requests) / 100.0 - 1.0; + double slopedWait = adaptiveTimeoutMin; + if (pendingQueries > 1 && missWidth > 0.0) { + slopedWait += ((adaptiveTimeoutMax - adaptiveTimeoutMin) * (pendingQueries - 1)) / missWidth; + } + long nextAdaptive = (long) slopedWait; + if (now + nextAdaptive >= deadline) { + return deadline - now; + } + deadline = now + nextAdaptive; + + return nextAdaptive; + } + @Override protected void release() { if (!invokers.isEmpty()) { @@ -61,4 +170,26 @@ public class InterleavedSearchInvoker extends SearchInvoker { invokers.clear(); } } + + @Override + public void responseAvailable(SearchInvoker from) { + if (availableForProcessing != null) { + availableForProcessing.add(from); + } + } + + @Override + protected void setMonitor(ResponseMonitor<SearchInvoker> monitor) { + // never to be called + } + + // For overriding in tests + protected long currentTime() { + return System.currentTimeMillis(); + } + + // For overriding in tests + protected LinkedBlockingQueue<SearchInvoker> newQueue() { + return new LinkedBlockingQueue<>(); + } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java b/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java index 222ae6a4ea0..df6384cf81c 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.dispatch; -import com.yahoo.search.Query; import com.yahoo.search.dispatch.searchcluster.Group; import com.yahoo.search.dispatch.searchcluster.SearchCluster; @@ -45,13 +44,13 @@ public class LoadBalancer { } /** - * Select and allocate the search cluster group which is to be used for the provided query. Callers <b>must</b> call + * Select and allocate the search cluster group which is to be used for the next search query. Callers <b>must</b> call * {@link #releaseGroup} symmetrically for each taken allocation. * * @param rejectedGroups if not null, the load balancer will only return groups with IDs not in the set * @return The node group to target, or <i>empty</i> if the internal dispatch logic cannot be used */ - public Optional<Group> takeGroupForQuery(Set<Integer> rejectedGroups) { + public Optional<Group> takeGroup(Set<Integer> rejectedGroups) { if (scoreboard == null) { return Optional.empty(); } @@ -60,7 +59,7 @@ public class LoadBalancer { } /** - * Release an allocation given by {@link #takeGroupForQuery}. The release must be done exactly once for each allocation. + * Release an allocation given by {@link #takeGroup}. The release must be done exactly once for each allocation. * * @param group * previously allocated group diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/ResponseMonitor.java b/container-search/src/main/java/com/yahoo/search/dispatch/ResponseMonitor.java new file mode 100644 index 00000000000..c2e81d43677 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/dispatch/ResponseMonitor.java @@ -0,0 +1,13 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +/** + * Classes implementing ResponseMonitor can be informed by monitored objects + * that a response is available for processing. The responseAvailable method + * must be thread-safe. + * + * @author ollivir + */ +public interface ResponseMonitor<T> { + void responseAvailable(T from); +} diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java index d5c505aa31b..01da3c20745 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java @@ -11,6 +11,7 @@ import com.yahoo.search.result.ErrorMessage; import java.io.IOException; import java.util.Arrays; import java.util.List; +import java.util.Optional; /** * A search invoker that will immediately produce an error that occurred during @@ -23,8 +24,10 @@ public class SearchErrorInvoker extends SearchInvoker { private final ErrorMessage message; private Query query; private final Coverage coverage; + private ResponseMonitor<SearchInvoker> monitor; public SearchErrorInvoker(ErrorMessage message, Coverage coverage) { + super(Optional.empty()); this.message = message; this.coverage = coverage; } @@ -36,6 +39,9 @@ public class SearchErrorInvoker extends SearchInvoker { @Override protected void sendSearchRequest(Query query, QueryPacket queryPacket) throws IOException { this.query = query; + if(monitor != null) { + monitor.responseAvailable(this); + } } @Override @@ -52,4 +58,8 @@ public class SearchErrorInvoker extends SearchInvoker { // nothing to do } + @Override + protected void setMonitor(ResponseMonitor<SearchInvoker> monitor) { + this.monitor = monitor; + } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java index 53e09823f32..2691b32d631 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java @@ -5,9 +5,12 @@ import com.yahoo.fs4.QueryPacket; import com.yahoo.prelude.fastsearch.CacheKey; import com.yahoo.search.Query; import com.yahoo.search.Result; +import com.yahoo.search.dispatch.searchcluster.Node; +import com.yahoo.search.result.Coverage; import java.io.IOException; import java.util.List; +import java.util.Optional; /** * SearchInvoker encapsulates an allocated connection for running a single search query. @@ -16,6 +19,13 @@ import java.util.List; * @author ollivir */ public abstract class SearchInvoker extends CloseableInvoker { + private final Optional<Node> node; + private ResponseMonitor<SearchInvoker> monitor; + + protected SearchInvoker(Optional<Node> node) { + this.node = node; + } + /** * Retrieve the hits for the given {@link Query}. The invoker may return more than one result, in which case the caller is responsible * for merging the results. If multiple results are returned and the search query had a hit offset other than zero, that offset is @@ -29,4 +39,26 @@ public abstract class SearchInvoker extends CloseableInvoker { protected abstract void sendSearchRequest(Query query, QueryPacket queryPacket) throws IOException; protected abstract List<Result> getSearchResults(CacheKey cacheKey) throws IOException; + + protected void setMonitor(ResponseMonitor<SearchInvoker> monitor) { + this.monitor = monitor; + } + + protected void responseAvailable() { + if(monitor != null) { + monitor.responseAvailable(this); + } + } + + protected Optional<Integer> distributionKey() { + return node.map(Node::key); + } + + protected Optional<Coverage> getErrorCoverage() { + if(node.isPresent()) { + return Optional.of(new Coverage(0, node.get().getActiveDocuments(), 0)); + } else { + return Optional.empty(); + } + } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java index b8d76906f70..8e278f78d7a 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java @@ -36,10 +36,7 @@ public class SearchCluster implements NodeManager<Node> { private static final Logger log = Logger.getLogger(SearchCluster.class.getName()); - /** The min active docs a group must have to be considered up, as a % of the average active docs of the other groups */ - private final double minActivedocsCoveragePercentage; - private final double minGroupCoverage; - private final int maxNodesDownPerGroup; + private final DispatchConfig dispatchConfig; private final int size; private final String clusterId; private final ImmutableMap<Integer, Group> groups; @@ -62,20 +59,14 @@ public class SearchCluster implements NodeManager<Node> { private final FS4ResourcePool fs4ResourcePool; public SearchCluster(String clusterId, DispatchConfig dispatchConfig, FS4ResourcePool fs4ResourcePool, int containerClusterSize, VipStatus vipStatus) { - this(clusterId, dispatchConfig.minActivedocsPercentage(), dispatchConfig.minGroupCoverage(), dispatchConfig.maxNodesDownPerGroup(), - toNodes(dispatchConfig), fs4ResourcePool, containerClusterSize, vipStatus); - } - - public SearchCluster(String clusterId, double minActivedocsCoverage, double minGroupCoverage, int maxNodesDownPerGroup, List<Node> nodes, FS4ResourcePool fs4ResourcePool, - int containerClusterSize, VipStatus vipStatus) { this.clusterId = clusterId; - this.minActivedocsCoveragePercentage = minActivedocsCoverage; - this.minGroupCoverage = minGroupCoverage; - this.maxNodesDownPerGroup = maxNodesDownPerGroup; - this.size = nodes.size(); + this.dispatchConfig = dispatchConfig; + this.size = dispatchConfig.node().size(); this.fs4ResourcePool = fs4ResourcePool; this.vipStatus = vipStatus; + List<Node> nodes = toNodes(dispatchConfig); + // Create groups ImmutableMap.Builder<Integer, Group> groupsBuilder = new ImmutableMap.Builder<>(); for (Map.Entry<Integer, List<Node>> group : nodes.stream().collect(Collectors.groupingBy(Node::group)).entrySet()) { @@ -143,6 +134,10 @@ public class SearchCluster implements NodeManager<Node> { return nodesBuilder.build(); } + public DispatchConfig dispatchConfig() { + return dispatchConfig; + } + /** Returns the number of nodes in this cluster (across all groups) */ public int size() { return size; } @@ -286,7 +281,7 @@ public class SearchCluster implements NodeManager<Node> { if (averageDocumentsInOtherGroups > 0) { double coverage = 100.0 * (double) activeDocuments / averageDocumentsInOtherGroups; - sufficientCoverage = coverage >= minActivedocsCoveragePercentage; + sufficientCoverage = coverage >= dispatchConfig.minActivedocsPercentage(); } if (sufficientCoverage) { sufficientCoverage = isGroupNodeCoverageSufficient(nodes); @@ -302,7 +297,8 @@ public class SearchCluster implements NodeManager<Node> { } } int numNodes = nodes.size(); - int nodesAllowedDown = maxNodesDownPerGroup + (int) (((double) numNodes * (100.0 - minGroupCoverage)) / 100.0); + int nodesAllowedDown = dispatchConfig.maxNodesDownPerGroup() + + (int) (((double) numNodes * (100.0 - dispatchConfig.minGroupCoverage())) / 100.0); return nodesUp + nodesAllowedDown >= numNodes; } @@ -325,7 +321,7 @@ public class SearchCluster implements NodeManager<Node> { */ public boolean isPartialGroupCoverageSufficient(int groupId, List<Node> nodes) { if (orderedGroups.size() == 1) { - return nodes.size() >= groupSize() - maxNodesDownPerGroup; + return nodes.size() >= groupSize() - dispatchConfig.maxNodesDownPerGroup(); } long sumOfActiveDocuments = 0; int otherGroups = 0; diff --git a/container-search/src/main/java/com/yahoo/search/pagetemplates/result/PageTemplatesXmlRenderer.java b/container-search/src/main/java/com/yahoo/search/pagetemplates/result/PageTemplatesXmlRenderer.java index b7d7188e77e..92e4bb7e5b8 100644 --- a/container-search/src/main/java/com/yahoo/search/pagetemplates/result/PageTemplatesXmlRenderer.java +++ b/container-search/src/main/java/com/yahoo/search/pagetemplates/result/PageTemplatesXmlRenderer.java @@ -179,7 +179,6 @@ public class PageTemplatesXmlRenderer extends AsynchronousSectionedRenderer<Resu private void renderHitGroup(XMLWriter writer, HitGroup hit) { if (hit.types().contains("section")) { - renderSection(writer, hit); // Renders /result/section } else if (hit.types().contains("meta")) { diff --git a/container-search/src/main/java/com/yahoo/search/rendering/XmlRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/XmlRenderer.java new file mode 100644 index 00000000000..2a822f89352 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/rendering/XmlRenderer.java @@ -0,0 +1,421 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.rendering; + +import com.yahoo.concurrent.CopyOnWriteHashMap; +import com.yahoo.io.ByteWriter; +import com.yahoo.net.URI; +import com.yahoo.prelude.fastsearch.GroupingListHit; +import com.yahoo.prelude.hitfield.HitField; +import com.yahoo.prelude.hitfield.JSONString; +import com.yahoo.prelude.hitfield.XMLString; +import com.yahoo.processing.rendering.AsynchronousSectionedRenderer; +import com.yahoo.processing.response.Data; +import com.yahoo.processing.response.DataList; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.grouping.result.HitRenderer; +import com.yahoo.search.query.context.QueryContext; +import com.yahoo.search.result.*; +import com.yahoo.text.Utf8String; +import com.yahoo.text.XML; +import com.yahoo.text.XMLWriter; +import com.yahoo.yolean.trace.TraceNode; +import com.yahoo.yolean.trace.TraceVisitor; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.io.Writer; +import java.nio.charset.Charset; +import java.nio.charset.CharsetEncoder; +import java.util.Iterator; +import java.util.concurrent.Executor; +import java.util.stream.Collectors; + +/** + * XML rendering of search results. This is NOT the default (but it once was). + * + * @author Tony Vaagenes + */ +@SuppressWarnings({ "rawtypes", "deprecation" }) +public final class XmlRenderer extends AsynchronousSectionedRenderer<Result> { + + public static final String DEFAULT_MIMETYPE = "text/xml"; + public static final String DEFAULT_ENCODING = "utf-8"; + + private static final Utf8String RESULT = new Utf8String("result"); + private static final Utf8String GROUP = new Utf8String("group"); + private static final Utf8String ID = new Utf8String("id"); + private static final Utf8String FIELD = new Utf8String("field"); + private static final Utf8String HIT = new Utf8String("hit"); + private static final Utf8String ERROR = new Utf8String("error"); + private static final Utf8String TOTAL_HIT_COUNT = new Utf8String("total-hit-count"); + private static final Utf8String QUERY_TIME = new Utf8String("querytime"); + private static final Utf8String SUMMARY_FETCH_TIME = new Utf8String("summaryfetchtime"); + private static final Utf8String SEARCH_TIME = new Utf8String("searchtime"); + private static final Utf8String NAME = new Utf8String("name"); + private static final Utf8String CODE = new Utf8String("code"); + private static final Utf8String COVERAGE_DOCS = new Utf8String("coverage-docs"); + private static final Utf8String COVERAGE_NODES = new Utf8String("coverage-nodes"); + private static final Utf8String COVERAGE_FULL = new Utf8String("coverage-full"); + private static final Utf8String COVERAGE = new Utf8String("coverage"); + private static final Utf8String RESULTS_FULL = new Utf8String("results-full"); + private static final Utf8String RESULTS = new Utf8String("results"); + private static final Utf8String TYPE = new Utf8String("type"); + private static final Utf8String RELEVANCY = new Utf8String("relevancy"); + private static final Utf8String SOURCE = new Utf8String("source"); + + + // this is shared between umpteen threads by design + private final CopyOnWriteHashMap<String, Utf8String> fieldNameMap = new CopyOnWriteHashMap<>(); + + private XMLWriter writer; + + public XmlRenderer() { + this(null); + } + + /** + * Creates an XML renderer using a custom executor. + * Using a custom executor is useful for tests to avoid creating new threads for each renderer registry. + */ + public XmlRenderer(Executor executor) { + super(executor); + } + + @Override + public void init() { + super.init(); + writer = null; + } + + @Override + public String getEncoding() { + if (getResult() == null + || getResult().getQuery() == null + || getResult().getQuery().getModel().getEncoding() == null) { + return DEFAULT_ENCODING; + } else { + return getResult().getQuery().getModel().getEncoding(); + } + } + + @Override + public String getMimeType() { + return DEFAULT_MIMETYPE; + } + + private XMLWriter wrapWriter(Writer writer) { + return XMLWriter.from(writer, 10, -1); + } + + private void header(XMLWriter writer, Result result) throws IOException { + // TODO: move setting this to Result + writer.xmlHeader(getRequestedEncoding(result.getQuery())); + writer.openTag(RESULT).attribute(TOTAL_HIT_COUNT, String.valueOf(result.getTotalHitCount())); + renderCoverageAttributes(result.getCoverage(false), writer); + renderTime(writer, result); + writer.closeStartTag(); + } + + private void renderTime(XMLWriter writer, Result result) { + if ( ! result.getQuery().getPresentation().getTiming()) return; + + final String threeDecimals = "%.3f"; + final double milli = .001d; + final long now = System.currentTimeMillis(); + final long searchTime = now - result.getElapsedTime().first(); + final double searchSeconds = ((double) searchTime) * milli; + + if (result.getElapsedTime().firstFill() != 0L) { + final long queryTime = result.getElapsedTime().weightedSearchTime(); + final long summaryFetchTime = result.getElapsedTime().weightedFillTime(); + final double querySeconds = ((double) queryTime) * milli; + final double summarySeconds = ((double) summaryFetchTime) * milli; + writer.attribute(QUERY_TIME, String.format(threeDecimals, querySeconds)); + writer.attribute(SUMMARY_FETCH_TIME, String.format(threeDecimals, summarySeconds)); + } + writer.attribute(SEARCH_TIME, String.format(threeDecimals, searchSeconds)); + } + + protected static void renderCoverageAttributes(Coverage coverage, XMLWriter writer) throws IOException { + if (coverage == null) return; + writer.attribute(COVERAGE_DOCS,coverage.getDocs()); + writer.attribute(COVERAGE_NODES,coverage.getNodes()); + writer.attribute(COVERAGE_FULL,coverage.getFull()); + writer.attribute(COVERAGE,coverage.getResultPercentage()); + writer.attribute(RESULTS_FULL,coverage.getFullResultSets()); + writer.attribute(RESULTS,coverage.getResultSets()); + } + + public void error(XMLWriter writer, Result result) throws IOException { + ErrorMessage error = result.hits().getError(); + writer.openTag(ERROR).attribute(CODE,error.getCode()).content(error.getMessage(),false).closeTag(); + } + + @SuppressWarnings("UnusedParameters") + protected void emptyResult(XMLWriter writer, Result result) throws IOException {} + + @SuppressWarnings("UnusedParameters") + public void queryContext(XMLWriter writer, QueryContext queryContext, Query owner) throws IOException { + if (owner.getTraceLevel()!=0) { + XMLWriter xmlWriter=XMLWriter.from(writer); + xmlWriter.openTag("meta").attribute("type", QueryContext.ID); + TraceNode traceRoot = owner.getModel().getExecution().trace().traceNode().root(); + traceRoot.accept(new RenderingVisitor(xmlWriter, owner.getStartTime())); + xmlWriter.closeTag(); + } + } + + private void renderSingularHit(XMLWriter writer, Hit hit) { + writer.openTag(HIT); + renderHitAttributes(writer, hit); + writer.closeStartTag(); + renderHitFields(writer, hit); + } + + private void renderHitFields(XMLWriter writer, Hit hit) { + renderSyntheticRelevanceField(writer, hit); + hit.forEachField((name, value) -> renderField(writer, name, value)); + } + + private void renderField(XMLWriter writer, String name, Object value) { + if (name.startsWith("$")) return; + + writeOpenFieldElement(writer, name); + renderFieldContent(writer, value); + writeCloseFieldElement(writer); + } + + private void renderFieldContent(XMLWriter writer, Object value) { + writer.escapedContent(asXML(value), false); + } + + private String asXML(Object value) { + if (value == null) + return "(null)"; + else if (value instanceof HitField) + return ((HitField)value).quotedContent(false); + else if (value instanceof StructuredData || value instanceof XMLString || value instanceof JSONString) + return value.toString(); + else + return XML.xmlEscape(value.toString(), false, '\u001f'); + } + + private void renderSyntheticRelevanceField(XMLWriter writer, Hit hit) { + String relevancyFieldName = "relevancy"; + Relevance relevance = hit.getRelevance(); + + // skip depending on hit type + if (relevance != null) { + renderSimpleField(writer, relevancyFieldName, relevance); + } + } + + private void renderSimpleField(XMLWriter writer, String relevancyFieldName, Relevance relevance) { + writeOpenFieldElement(writer, relevancyFieldName); + writer.content(relevance.toString(), false); + writeCloseFieldElement(writer); + } + + private void writeCloseFieldElement(XMLWriter writer) { + writer.closeTag(); + } + + private void writeOpenFieldElement(XMLWriter writer, String relevancyFieldName) { + Utf8String utf8 = fieldNameMap.get(relevancyFieldName); + if (utf8 == null) { + utf8 = new Utf8String(relevancyFieldName); + fieldNameMap.put(relevancyFieldName, utf8); + } + writer.openTag(FIELD).attribute(NAME, utf8); + writer.closeStartTag(); + } + + private void renderHitAttributes(XMLWriter writer, Hit hit) { + writer.attribute(TYPE, hit.types().stream().collect(Collectors.joining(" "))); + if (hit.getRelevance() != null) + writer.attribute(RELEVANCY, hit.getRelevance().toString()); + writer.attribute(SOURCE, hit.getSource()); + } + + private void renderHitGroup(XMLWriter writer, HitGroup hit) throws IOException { + if (HitRenderer.renderHeader(hit, writer)) { + // empty + } else if (hit.types().contains("grouphit")) { + // TODO Keep this? + renderHitGroupOfTypeGroupHit(writer, hit); + } else { + renderGroup(writer, hit); + } + } + + private void renderGroup(XMLWriter writer, HitGroup hit) { + writer.openTag(GROUP); + renderHitAttributes(writer, hit); + writer.closeStartTag(); + } + + private void renderHitGroupOfTypeGroupHit(XMLWriter writer, HitGroup hit) { + writer.openTag(HIT); + renderHitAttributes(writer, hit); + renderId(writer, hit); + writer.closeStartTag(); + } + + private void renderId(XMLWriter writer, HitGroup hit) { + URI uri = hit.getId(); + if (uri != null) { + writer.openTag(ID).content(uri.stringValue(),false).closeTag(); + } + } + + private boolean simpleRenderHit(XMLWriter writer, Hit hit) throws IOException { + if (hit instanceof DefaultErrorHit) { + return simpleRenderDefaultErrorHit(writer, (DefaultErrorHit) hit); + } else if (hit instanceof GroupingListHit) { + return true; + } else { + return false; + } + } + + public static boolean simpleRenderDefaultErrorHit(XMLWriter writer, ErrorHit defaultErrorHit) throws IOException { + writer.openTag("errordetails"); + for (Iterator i = defaultErrorHit.errorIterator(); i.hasNext();) { + ErrorMessage error = (ErrorMessage) i.next(); + renderMessageDefaultErrorHit(writer, error); + } + writer.closeTag(); + return true; + } + + public static void renderMessageDefaultErrorHit(XMLWriter writer, ErrorMessage error) throws IOException { + writer.openTag("error"); + writer.attribute("source", error.getSource()); + writer.attribute("error", error.getMessage()); + writer.attribute("code", Integer.toString(error.getCode())); + writer.content(error.getDetailedMessage(), false); + if (error.getCause()!=null) { + writer.openTag("cause"); + writer.content("\n", true); + StringWriter stackTrace=new StringWriter(); + error.getCause().printStackTrace(new PrintWriter(stackTrace)); + writer.content(stackTrace.toString(), true); + writer.closeTag(); + } + writer.closeTag(); + } + + public static final class RenderingVisitor extends TraceVisitor { + + private static final String tag = "p"; + private final XMLWriter writer; + private long baseTime; + + public RenderingVisitor(XMLWriter writer,long baseTime) { + this.writer=writer; + this.baseTime=baseTime; + } + + @Override + public void entering(TraceNode node) { + if (node.isRoot()) return; + writer.openTag(tag); + } + + @Override + public void leaving(TraceNode node) { + if (node.isRoot()) return; + writer.closeTag(); + } + + @Override + public void visit(TraceNode node) { + if (node.isRoot()) return; + if (node.payload()==null) return; + + writer.openTag(tag); + if (node.timestamp()!=0) + writer.content(node.timestamp()-baseTime,false).content(" ms: ", false); + writer.content(node.payload().toString(),false); + writer.closeTag(); + } + + } + + private Result getResult() { + Result r; + try { + r = (Result) getResponse(); + } catch (ClassCastException e) { + throw new IllegalArgumentException( + "XmlRenderer attempted used outside a search context, got a " + + getResponse().getClass().getName()); + } + return r; + } + + @Override + public void beginResponse(OutputStream stream) throws IOException { + Charset cs = Charset.forName(getRequestedEncoding(getResult().getQuery())); + CharsetEncoder encoder = cs.newEncoder(); + writer = wrapWriter(new ByteWriter(stream, encoder)); + + header(writer, getResult()); + if (getResult().hits().getError() != null || getResult().hits().getQuery().errors().size() > 0) { + error(writer, getResult()); + } + + if (getResult().getConcreteHitCount() == 0) { + emptyResult(writer, getResult()); + } + + if (getResult().getContext(false) != null) { + queryContext(writer, getResult().getContext(false), getResult().getQuery()); + } + + } + + /** Returns the encoding of the query, or the encoding given by the template if none is set */ + public final String getRequestedEncoding(Query query) { + String encoding = query.getModel().getEncoding(); + if (encoding != null) return encoding; + return getEncoding(); + } + + @Override + public void beginList(DataList<?> list) throws IOException { + if (getRecursionLevel() == 1) return; + + HitGroup hit = (HitGroup) list; + boolean renderedSimple = simpleRenderHit(writer, hit); + if (renderedSimple) return; + + renderHitGroup(writer, hit); + } + + @Override + public void data(Data data) throws IOException { + Hit hit = (Hit) data; + boolean renderedSimple = simpleRenderHit(writer, hit); + if (renderedSimple) return; + + renderSingularHit(writer, hit); + writer.closeTag(); + } + + @Override + public void endList(DataList<?> list) { + if (getRecursionLevel() > 1) + writer.closeTag(); + } + + @Override + public void endResponse() { + writer.closeTag(); + writer.close(); + } + +} diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java new file mode 100644 index 00000000000..69458f25f93 --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java @@ -0,0 +1,180 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +import com.yahoo.fs4.QueryPacket; +import com.yahoo.prelude.fastsearch.CacheKey; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.dispatch.searchcluster.Node; +import com.yahoo.search.dispatch.searchcluster.SearchCluster; +import com.yahoo.test.ManualClock; +import org.junit.Test; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static com.yahoo.search.dispatch.MockSearchCluster.createDispatchConfig; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * @author ollivir + */ +public class InterleavedSearchInvokerTest { + private ManualClock clock = new ManualClock(Instant.now()); + private Query query = new TestQuery(); + private LinkedList<Event> expectedEvents = new LinkedList<>(); + private List<SearchInvoker> invokers = new ArrayList<>(); + + @Test + public void requireThatAdaptiveTimeoutsAreNotUsedWithFullCoverageRequirement() throws IOException { + SearchCluster cluster = new MockSearchCluster("!", createDispatchConfig(100.0), 1, 3); + SearchInvoker invoker = createInterleavedInvoker(cluster, 3); + + expectedEvents.add(new Event(5000, 100, 0)); + expectedEvents.add(new Event(4900, 100, 1)); + expectedEvents.add(new Event(4800, 100, 2)); + + invoker.search(query, null, null); + + assertTrue("All test scenario events processed", expectedEvents.isEmpty()); + } + + @Test + public void requireThatTimeoutsAreNotMarkedAsAdaptive() throws IOException { + SearchCluster cluster = new MockSearchCluster("!", createDispatchConfig(100.0), 1, 3); + SearchInvoker invoker = createInterleavedInvoker(cluster, 3); + + expectedEvents.add(new Event(5000, 300, 0)); + expectedEvents.add(new Event(4700, 300, 1)); + expectedEvents.add(null); + + List<Result> results = invoker.search(query, null, null); + + assertTrue("All test scenario events processed", expectedEvents.isEmpty()); + assertNotNull("Last invoker is marked as an error", results.get(2).hits().getErrorHit()); + assertTrue("Timed out invoker is a normal timeout", results.get(2).getCoverage(false).isDegradedByTimeout()); + } + + @Test + public void requireThatAdaptiveTimeoutDecreasesTimeoutWhenCoverageIsReached() throws IOException { + SearchCluster cluster = new MockSearchCluster("!", createDispatchConfig(50.0), 1, 4); + SearchInvoker invoker = createInterleavedInvoker(cluster, 4); + + expectedEvents.add(new Event(5000, 100, 0)); + expectedEvents.add(new Event(4900, 100, 1)); + expectedEvents.add(new Event(2400, 100, 2)); + expectedEvents.add(new Event(0, 0, null)); + + List<Result> results = invoker.search(query, null, null); + + assertTrue("All test scenario events processed", expectedEvents.isEmpty()); + assertNotNull("Last invoker is marked as an error", results.get(3).hits().getErrorHit()); + assertTrue("Timed out invoker is an adaptive timeout", results.get(3).getCoverage(false).isDegradedByAdapativeTimeout()); + } + + private InterleavedSearchInvoker createInterleavedInvoker(SearchCluster searchCluster, int numInvokers) { + for (int i = 0; i < numInvokers; i++) { + invokers.add(new TestInvoker()); + } + + return new InterleavedSearchInvoker(invokers, searchCluster) { + @Override + protected long currentTime() { + return clock.millis(); + } + + @Override + protected LinkedBlockingQueue<SearchInvoker> newQueue() { + return new LinkedBlockingQueue<SearchInvoker>() { + @Override + public SearchInvoker poll(long timeout, TimeUnit timeUnit) throws InterruptedException { + assertFalse(expectedEvents.isEmpty()); + Event ev = expectedEvents.removeFirst(); + if (ev == null) { + return null; + } else { + return ev.process(query, timeout); + } + } + }; + } + }; + } + + private class Event { + Long expectedTimeout; + long delay; + Integer invokerIndex; + + public Event(Integer expectedTimeout, int delay, Integer invokerIndex) { + this.expectedTimeout = (long) expectedTimeout; + this.delay = delay; + this.invokerIndex = invokerIndex; + } + + public SearchInvoker process(Query query, long currentTimeout) { + if (expectedTimeout != null) { + assertEquals("Expecting timeout to be " + expectedTimeout, (long) expectedTimeout, currentTimeout); + } + clock.advance(Duration.ofMillis(delay)); + if (query.getTimeLeft() < 0) { + fail("Test sequence ran out of time window"); + } + if (invokerIndex == null) { + return null; + } else { + return invokers.get(invokerIndex); + } + } + } + + private class TestInvoker extends SearchInvoker { + protected TestInvoker() { + super(Optional.of(new Node(42, "?", 0, 0))); + } + + @Override + protected void sendSearchRequest(Query query, QueryPacket queryPacket) throws IOException { + } + + @Override + protected List<Result> getSearchResults(CacheKey cacheKey) throws IOException { + return Collections.singletonList(new Result(query)); + } + + @Override + protected void release() { + } + } + + public class TestQuery extends Query { + private long start = clock.millis(); + + public TestQuery() { + super(); + setTimeout(5000); + } + + @Override + public long getStartTime() { + return start; + } + + @Override + public long getDurationTime() { + return clock.millis() - start; + } + } +} diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java index 38a753360d8..c056423a9c4 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java @@ -7,9 +7,9 @@ import com.yahoo.search.dispatch.searchcluster.SearchCluster; import junit.framework.AssertionFailedError; import org.junit.Test; -import java.util.Arrays; import java.util.Optional; +import static com.yahoo.search.dispatch.MockSearchCluster.createDispatchConfig; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -22,10 +22,10 @@ public class LoadBalancerTest { @Test public void requreThatLoadBalancerServesSingleNodeSetups() { Node n1 = new Node(0, "test-node1", 0, 0); - SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1), null, 1, null); + SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1), null, 1, null); LoadBalancer lb = new LoadBalancer(cluster, true); - Optional<Group> grp = lb.takeGroupForQuery(null); + Optional<Group> grp = lb.takeGroup(null); Group group = grp.orElseGet(() -> { throw new AssertionFailedError("Expected a SearchCluster.Group"); }); @@ -36,10 +36,10 @@ public class LoadBalancerTest { public void requreThatLoadBalancerServesMultiGroupSetups() { Node n1 = new Node(0, "test-node1", 0, 0); Node n2 = new Node(1, "test-node2", 1, 1); - SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1, n2), null, 1, null); + SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1, n2), null, 1, null); LoadBalancer lb = new LoadBalancer(cluster, true); - Optional<Group> grp = lb.takeGroupForQuery(null); + Optional<Group> grp = lb.takeGroup(null); Group group = grp.orElseGet(() -> { throw new AssertionFailedError("Expected a SearchCluster.Group"); }); @@ -52,10 +52,10 @@ public class LoadBalancerTest { Node n2 = new Node(1, "test-node2", 1, 0); Node n3 = new Node(0, "test-node3", 0, 1); Node n4 = new Node(1, "test-node4", 1, 1); - SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1, n2, n3, n4), null, 2, null); + SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1, n2, n3, n4), null, 2, null); LoadBalancer lb = new LoadBalancer(cluster, true); - Optional<Group> grp = lb.takeGroupForQuery(null); + Optional<Group> grp = lb.takeGroup(null); assertThat(grp.isPresent(), is(true)); } @@ -63,18 +63,18 @@ public class LoadBalancerTest { public void requreThatLoadBalancerReturnsDifferentGroups() { Node n1 = new Node(0, "test-node1", 0, 0); Node n2 = new Node(1, "test-node2", 1, 1); - SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1, n2), null, 1, null); + SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1, n2), null, 1, null); LoadBalancer lb = new LoadBalancer(cluster, true); // get first group - Optional<Group> grp = lb.takeGroupForQuery(null); + Optional<Group> grp = lb.takeGroup(null); Group group = grp.get(); int id1 = group.id(); // release allocation lb.releaseGroup(group); // get second group - grp = lb.takeGroupForQuery(null); + grp = lb.takeGroup(null); group = grp.get(); assertThat(group.id(), not(equalTo(id1))); } @@ -83,16 +83,16 @@ public class LoadBalancerTest { public void requreThatLoadBalancerReturnsGroupWithShortestQueue() { Node n1 = new Node(0, "test-node1", 0, 0); Node n2 = new Node(1, "test-node2", 1, 1); - SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1, n2), null, 1, null); + SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1, n2), null, 1, null); LoadBalancer lb = new LoadBalancer(cluster, true); // get first group - Optional<Group> grp = lb.takeGroupForQuery(null); + Optional<Group> grp = lb.takeGroup(null); Group group = grp.get(); int id1 = group.id(); // get second group - grp = lb.takeGroupForQuery(null); + grp = lb.takeGroup(null); group = grp.get(); int id2 = group.id(); assertThat(id2, not(equalTo(id1))); @@ -100,7 +100,7 @@ public class LoadBalancerTest { lb.releaseGroup(group); // get third group - grp = lb.takeGroupForQuery(null); + grp = lb.takeGroup(null); group = grp.get(); assertThat(group.id(), equalTo(id2)); } diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java b/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java index fc505097472..f7b92419b52 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java @@ -6,9 +6,9 @@ import com.google.common.collect.ImmutableMultimap; import com.yahoo.search.dispatch.searchcluster.Group; import com.yahoo.search.dispatch.searchcluster.Node; import com.yahoo.search.dispatch.searchcluster.SearchCluster; +import com.yahoo.vespa.config.search.DispatchConfig; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Optional; @@ -22,7 +22,11 @@ public class MockSearchCluster extends SearchCluster { private final ImmutableMultimap<String, Node> nodesByHost; public MockSearchCluster(String clusterId, int groups, int nodesPerGroup) { - super(clusterId, 100, 100, 0, Collections.emptyList(), null, 1, null); + this(clusterId, createDispatchConfig(), groups, nodesPerGroup); + } + + public MockSearchCluster(String clusterId, DispatchConfig dispatchConfig, int groups, int nodesPerGroup) { + super(clusterId, dispatchConfig, null, 1, null); ImmutableMap.Builder<Integer, Group> groupBuilder = ImmutableMap.builder(); ImmutableMultimap.Builder<String, Node> hostBuilder = ImmutableMultimap.builder(); @@ -58,7 +62,7 @@ public class MockSearchCluster extends SearchCluster { } public Optional<Group> group(int n) { - if(n < numGroups) { + if (n < numGroups) { return Optional.of(groups.get(n)); } else { return Optional.empty(); @@ -80,4 +84,24 @@ public class MockSearchCluster extends SearchCluster { public void failed(Node node) { node.setWorking(false); } + + public static DispatchConfig createDispatchConfig(Node... nodes) { + return createDispatchConfig(100.0, nodes); + } + + public static DispatchConfig createDispatchConfig(double minSearchCoverage, Node... nodes) { + DispatchConfig.Builder builder = new DispatchConfig.Builder(); + builder.minActivedocsPercentage(88.0); + builder.minGroupCoverage(99.0); + builder.maxNodesDownPerGroup(0); + builder.minSearchCoverage(minSearchCoverage); + if(minSearchCoverage < 100.0) { + builder.minWaitAfterCoverageFactor(0); + builder.maxWaitAfterCoverageFactor(0.5); + } + for (Node n : nodes) { + builder.node(new DispatchConfig.Node.Builder().key(n.key()).host(n.hostname()).port(n.fs4port()).group(n.group())); + } + return new DispatchConfig(builder); + } } diff --git a/container/pom.xml b/container/pom.xml index 32a7947d6d5..529f01e0a40 100644 --- a/container/pom.xml +++ b/container/pom.xml @@ -53,18 +53,6 @@ <groupId>org.apache.commons</groupId> <artifactId>commons-lang3</artifactId> </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> - </exclusion> - <exclusion> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </exclusion> </exclusions> </dependency> </dependencies> diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/OwnershipIssues.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/OwnershipIssues.java index 6a69eb54d2c..b2aaa99ab7a 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/OwnershipIssues.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/OwnershipIssues.java @@ -35,4 +35,10 @@ public interface OwnershipIssues { */ void ensureResponse(IssueId issueId, Optional<Contact> contact); + /** + * Get the owner of an application, given its ownership issue ID. + * @param issueId ID of the ownership issue. + * @return The owner of the application, if it has been confirmed. + */ + Optional<User> getConfirmedOwner(IssueId issueId); } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/stubs/DummyOwnershipIssues.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/stubs/DummyOwnershipIssues.java index 14f252732fb..80365e71fb9 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/stubs/DummyOwnershipIssues.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/stubs/DummyOwnershipIssues.java @@ -21,4 +21,8 @@ public class DummyOwnershipIssues implements OwnershipIssues { } + @Override + public Optional<User> getConfirmedOwner(IssueId issueId) { + return Optional.empty(); + } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java index 619f8abc180..1baff026385 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java @@ -11,6 +11,7 @@ import com.yahoo.config.provision.HostName; import com.yahoo.config.provision.SystemName; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService.ApplicationMetrics; import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueId; +import com.yahoo.vespa.hosted.controller.api.integration.organization.User; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationActivity; import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationVersion; @@ -50,6 +51,7 @@ public class Application { private final Change change; private final Change outstandingChange; private final Optional<IssueId> ownershipIssueId; + private final Optional<User> owner; private final ApplicationMetrics metrics; private final Optional<RotationId> rotation; private final Map<HostName, RotationStatus> rotationStatus; @@ -58,23 +60,23 @@ public class Application { public Application(ApplicationId id, Instant now) { this(id, now, DeploymentSpec.empty, ValidationOverrides.empty, Collections.emptyMap(), new DeploymentJobs(OptionalLong.empty(), Collections.emptyList(), Optional.empty(), false), - Change.empty(), Change.empty(), Optional.empty(), new ApplicationMetrics(0, 0), + Change.empty(), Change.empty(), Optional.empty(), Optional.empty(), new ApplicationMetrics(0, 0), Optional.empty(), Collections.emptyMap()); } /** Used from persistence layer: Do not use */ public Application(ApplicationId id, Instant createdAt, DeploymentSpec deploymentSpec, ValidationOverrides validationOverrides, List<Deployment> deployments, DeploymentJobs deploymentJobs, Change change, - Change outstandingChange, Optional<IssueId> ownershipIssueId, ApplicationMetrics metrics, + Change outstandingChange, Optional<IssueId> ownershipIssueId, Optional<User> owner, ApplicationMetrics metrics, Optional<RotationId> rotation, Map<HostName, RotationStatus> rotationStatus) { this(id, createdAt, deploymentSpec, validationOverrides, deployments.stream().collect(Collectors.toMap(Deployment::zone, d -> d)), - deploymentJobs, change, outstandingChange, ownershipIssueId, metrics, rotation, rotationStatus); + deploymentJobs, change, outstandingChange, ownershipIssueId, owner, metrics, rotation, rotationStatus); } Application(ApplicationId id, Instant createdAt, DeploymentSpec deploymentSpec, ValidationOverrides validationOverrides, Map<ZoneId, Deployment> deployments, DeploymentJobs deploymentJobs, Change change, - Change outstandingChange, Optional<IssueId> ownershipIssueId, ApplicationMetrics metrics, + Change outstandingChange, Optional<IssueId> ownershipIssueId, Optional<User> owner, ApplicationMetrics metrics, Optional<RotationId> rotation, Map<HostName, RotationStatus> rotationStatus) { this.id = Objects.requireNonNull(id, "id cannot be null"); this.createdAt = Objects.requireNonNull(createdAt, "instant of creation cannot be null"); @@ -85,6 +87,7 @@ public class Application { this.change = Objects.requireNonNull(change, "change cannot be null"); this.outstandingChange = Objects.requireNonNull(outstandingChange, "outstandingChange cannot be null"); this.ownershipIssueId = Objects.requireNonNull(ownershipIssueId, "ownershipIssueId cannot be null"); + this.owner = Objects.requireNonNull(owner, "owner cannot be null"); this.metrics = Objects.requireNonNull(metrics, "metrics cannot be null"); this.rotation = Objects.requireNonNull(rotation, "rotation cannot be null"); this.rotationStatus = ImmutableMap.copyOf(Objects.requireNonNull(rotationStatus, "rotationStatus cannot be null")); @@ -139,6 +142,10 @@ public class Application { return ownershipIssueId; } + public Optional<User> owner() { + return owner; + } + /** Returns metrics for this */ public ApplicationMetrics metrics() { return metrics; diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java index 5951ceb4792..1e138dd5a4d 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java @@ -12,6 +12,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.MetricsService; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService.ApplicationMetrics; import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType; import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueId; +import com.yahoo.vespa.hosted.controller.api.integration.organization.User; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Change; @@ -49,6 +50,7 @@ public class LockedApplication { private final Change change; private final Change outstandingChange; private final Optional<IssueId> ownershipIssueId; + private Optional<User> owner; private final ApplicationMetrics metrics; private final Optional<RotationId> rotation; private final Map<HostName, RotationStatus> rotationStatus; @@ -64,7 +66,7 @@ public class LockedApplication { application.deploymentSpec(), application.validationOverrides(), application.deployments(), application.deploymentJobs(), application.change(), application.outstandingChange(), - application.ownershipIssueId(), application.metrics(), + application.ownershipIssueId(), application.owner(), application.metrics(), application.rotation(), application.rotationStatus()); } @@ -72,7 +74,7 @@ public class LockedApplication { private LockedApplication(Lock lock, ApplicationId id, Instant createdAt, DeploymentSpec deploymentSpec, ValidationOverrides validationOverrides, Map<ZoneId, Deployment> deployments, DeploymentJobs deploymentJobs, Change change, - Change outstandingChange, Optional<IssueId> ownershipIssueId, ApplicationMetrics metrics, + Change outstandingChange, Optional<IssueId> ownershipIssueId, Optional<User> owner, ApplicationMetrics metrics, Optional<RotationId> rotation, Map<HostName, RotationStatus> rotationStatus) { this.lock = lock; this.id = id; @@ -84,6 +86,7 @@ public class LockedApplication { this.change = change; this.outstandingChange = outstandingChange; this.ownershipIssueId = ownershipIssueId; + this.owner = owner; this.metrics = metrics; this.rotation = rotation; this.rotationStatus = rotationStatus; @@ -92,44 +95,44 @@ public class LockedApplication { /** Returns a read-only copy of this */ public Application get() { return new Application(id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs, change, - outstandingChange, ownershipIssueId, metrics, rotation, rotationStatus); + outstandingChange, ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication withBuiltInternally(boolean builtInternally) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs.withBuiltInternally(builtInternally), change, outstandingChange, - ownershipIssueId, metrics, rotation, rotationStatus); + ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication withProjectId(OptionalLong projectId) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs.withProjectId(projectId), change, outstandingChange, - ownershipIssueId, metrics, rotation, rotationStatus); + ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication withDeploymentIssueId(IssueId issueId) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs.with(issueId), change, outstandingChange, - ownershipIssueId, metrics, rotation, rotationStatus); + ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication withJobPause(JobType jobType, OptionalLong pausedUntil) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs.withPause(jobType, pausedUntil), change, outstandingChange, - ownershipIssueId, metrics, rotation, rotationStatus); + ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication withJobCompletion(long projectId, JobType jobType, JobStatus.JobRun completion, Optional<DeploymentJobs.JobError> jobError) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs.withCompletion(projectId, jobType, completion, jobError), - change, outstandingChange, ownershipIssueId, metrics, rotation, rotationStatus); + change, outstandingChange, ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication withJobTriggering(JobType jobType, JobStatus.JobRun job) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs.withTriggering(jobType, job), change, outstandingChange, - ownershipIssueId, metrics, rotation, rotationStatus); + ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication withNewDeployment(ZoneId zone, ApplicationVersion applicationVersion, Version version, @@ -179,54 +182,60 @@ public class LockedApplication { public LockedApplication withoutDeploymentJob(JobType jobType) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs.without(jobType), change, outstandingChange, - ownershipIssueId, metrics, rotation, rotationStatus); + ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication with(DeploymentSpec deploymentSpec) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs, change, outstandingChange, - ownershipIssueId, metrics, rotation, rotationStatus); + ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication with(ValidationOverrides validationOverrides) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs, change, outstandingChange, - ownershipIssueId, metrics, rotation, rotationStatus); + ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication withChange(Change change) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs, change, outstandingChange, - ownershipIssueId, metrics, rotation, rotationStatus); + ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication withOutstandingChange(Change outstandingChange) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs, change, outstandingChange, - ownershipIssueId, metrics, rotation, rotationStatus); + ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication withOwnershipIssueId(IssueId issueId) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs, change, outstandingChange, - Optional.ofNullable(issueId), metrics, rotation, rotationStatus); + Optional.ofNullable(issueId), owner, metrics, rotation, rotationStatus); + } + + public LockedApplication withOwner(User owner) { + return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, + deploymentJobs, change, outstandingChange, + ownershipIssueId, Optional.ofNullable(owner), metrics, rotation, rotationStatus); } public LockedApplication with(MetricsService.ApplicationMetrics metrics) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs, change, outstandingChange, - ownershipIssueId, metrics, rotation, rotationStatus); + ownershipIssueId, owner, metrics, rotation, rotationStatus); } public LockedApplication with(RotationId rotation) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs, change, outstandingChange, - ownershipIssueId, metrics, Optional.of(rotation), rotationStatus); + ownershipIssueId, owner, metrics, Optional.of(rotation), rotationStatus); } public LockedApplication withRotationStatus(Map<HostName, RotationStatus> rotationStatus) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs, change, - outstandingChange, ownershipIssueId, metrics, rotation, rotationStatus); + outstandingChange, ownershipIssueId, owner, metrics, rotation, rotationStatus); } /** Don't expose non-leaf sub-objects. */ @@ -239,7 +248,7 @@ public class LockedApplication { private LockedApplication with(Map<ZoneId, Deployment> deployments) { return new LockedApplication(lock, id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs, change, outstandingChange, - ownershipIssueId, metrics, rotation, rotationStatus); + ownershipIssueId, owner, metrics, rotation, rotationStatus); } @Override diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java index 633490b9299..eecf7a7b1a7 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java @@ -40,6 +40,7 @@ import java.io.UncheckedIOException; import java.net.URI; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -172,21 +173,23 @@ public class InternalStepRunner implements StepRunner { try { PrepareResponse prepareResponse = deployment.get().prepareResponse(); if ( ! prepareResponse.configChangeActions.refeedActions.stream().allMatch(action -> action.allowed)) { - logger.log("Deploy failed due to non-compatible changes that require re-feed. " + - "Your options are: \n" + - "1. Revert the incompatible changes.\n" + - "2. If you think it is safe in your case, you can override this validation, see\n" + - " http://docs.vespa.ai/documentation/reference/validation-overrides.html\n" + - "3. Deploy as a new application under a different name.\n" + - "Illegal actions:\n" + - prepareResponse.configChangeActions.refeedActions.stream() - .filter(action -> ! action.allowed) - .flatMap(action -> action.messages.stream()) - .collect(Collectors.joining("\n")) + "\n" + - "Details:\n" + - prepareResponse.log.stream() - .map(entry -> entry.message) - .collect(Collectors.joining("\n"))); + List<String> messages = new ArrayList<>(); + messages.add("Deploy failed due to non-compatible changes that require re-feed."); + messages.add("Your options are:"); + messages.add("1. Revert the incompatible changes."); + messages.add("2. If you think it is safe in your case, you can override this validation, see"); + messages.add(" http://docs.vespa.ai/documentation/reference/validation-overrides.html"); + messages.add("3. Deploy as a new application under a different name."); + messages.add("Illegal actions:"); + prepareResponse.configChangeActions.refeedActions.stream() + .filter(action -> ! action.allowed) + .flatMap(action -> action.messages.stream()) + .forEach(messages::add); + messages.add("Details:"); + prepareResponse.log.stream() + .map(entry -> entry.message) + .forEach(messages::add); + logger.log(messages); return Optional.of(deploymentFailed); } @@ -327,13 +330,13 @@ public class InternalStepRunner implements StepRunner { logger.log("Attempting to find endpoints ..."); Map<ZoneId, List<URI>> endpoints = deploymentEndpoints(id.application(), zones); - logger.log("Found endpoints:\n" + - endpoints.entrySet().stream() - .map(zoneEndpoints -> "- " + zoneEndpoints.getKey() + ":\n" + - zoneEndpoints.getValue().stream() - .map(uri -> " |-- " + uri) - .collect(Collectors.joining("\n"))) - .collect(Collectors.joining("\n"))); + List<String> messages = new ArrayList<>(); + messages.add("Found endpoints"); + endpoints.forEach((zone, uris) -> { + messages.add("- " + zone); + uris.forEach(uri -> messages.add(" |-- " + uri)); + }); + logger.log(messages); if ( ! endpoints.containsKey(id.type().zone(controller.system()))) { if (timedOut(deployment.get(), endpointTimeout)) { logger.log(WARNING, "Endpoints failed to show up within " + endpointTimeout.toMinutes() + " minutes!"); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java index 71b2988d840..ba947a0e79a 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java @@ -318,12 +318,7 @@ public class JobController { Optional<URI> testerEndpoint(RunId id) { ApplicationId tester = id.tester().id(); return controller.applications().getDeploymentEndpoints(new DeploymentId(tester, id.type().zone(controller.system()))) - .flatMap(uris -> uris.stream() - .filter(uri -> uri.getHost().contains(String.format("%s--%s--%s.", - tester.instance().value(), - tester.application().value(), - tester.tenant().value()))) - .findAny()); + .flatMap(uris -> uris.stream().findAny()); } // TODO jvenstad: Find a more appropriate way of doing this, at least when this is the only build service. diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmer.java index 2e6b3d1360d..b710b78682a 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmer.java @@ -38,6 +38,7 @@ public class ApplicationOwnershipConfirmer extends Maintainer { protected void maintain() { confirmApplicationOwnerships(); ensureConfirmationResponses(); + updateConfirmedApplicationOwners(); } /** File an ownership issue with the owners of all applications we know about. */ @@ -81,6 +82,22 @@ public class ApplicationOwnershipConfirmer extends Maintainer { }); } + private void updateConfirmedApplicationOwners() { + ApplicationList.from(controller().applications().asList()) + .withProjectId() + .hasProductionDeployment() + .asList() + .stream() + .filter(application -> application.ownershipIssueId().isPresent()) + .forEach(application -> { + IssueId ownershipIssueId = application.ownershipIssueId().get(); + ownershipIssues.getConfirmedOwner(ownershipIssueId).ifPresent(owner -> { + controller().applications().lockIfPresent(application.id(), lockedApplication -> + controller().applications().store(lockedApplication.withOwner(owner))); + }); + }); + } + private Tenant ownerOf(ApplicationId applicationId) { return controller().tenants().tenant(applicationId.tenant()) .orElseThrow(() -> new IllegalStateException("No tenant found for application " + applicationId)); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java index c69e77a43e1..9c087a101b9 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java @@ -16,6 +16,7 @@ import com.yahoo.vespa.hosted.controller.Application; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService.ApplicationMetrics; import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType; import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueId; +import com.yahoo.vespa.hosted.controller.api.integration.organization.User; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Change; @@ -61,6 +62,7 @@ public class ApplicationSerializer { private final String deployingField = "deployingField"; private final String outstandingChangeField = "outstandingChangeField"; private final String ownershipIssueIdField = "ownershipIssueId"; + private final String ownerField = "confirmedOwner"; private final String writeQualityField = "writeQuality"; private final String queryQualityField = "queryQuality"; private final String rotationField = "rotation"; @@ -146,6 +148,7 @@ public class ApplicationSerializer { toSlime(application.change(), root, deployingField); toSlime(application.outstandingChange(), root, outstandingChangeField); application.ownershipIssueId().ifPresent(issueId -> root.setString(ownershipIssueIdField, issueId.value())); + application.owner().ifPresent(owner -> root.setString(ownerField, owner.username())); root.setDouble(queryQualityField, application.metrics().queryServiceQuality()); root.setDouble(writeQualityField, application.metrics().writeServiceQuality()); application.rotation().ifPresent(rotation -> root.setString(rotationField, rotation.asString())); @@ -301,13 +304,14 @@ public class ApplicationSerializer { Change deploying = changeFromSlime(root.field(deployingField)); Change outstandingChange = changeFromSlime(root.field(outstandingChangeField)); Optional<IssueId> ownershipIssueId = optionalString(root.field(ownershipIssueIdField)).map(IssueId::from); + Optional<User> owner = optionalString(root.field(ownerField)).map(User::from); ApplicationMetrics metrics = new ApplicationMetrics(root.field(queryQualityField).asDouble(), root.field(writeQualityField).asDouble()); Optional<RotationId> rotation = rotationFromSlime(root.field(rotationField)); Map<HostName, RotationStatus> rotationStatus = rotationStatusFromSlime(root.field(rotationStatusField)); return new Application(id, createdAt, deploymentSpec, validationOverrides, deployments, deploymentJobs, deploying, - outstandingChange, ownershipIssueId, metrics, rotation, rotationStatus); + outstandingChange, ownershipIssueId, owner, metrics, rotation, rotationStatus); } private List<Deployment> deploymentsFromSlime(Inspector array) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java index c5db553219e..49da3867f76 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java @@ -492,6 +492,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { application.activity().lastWritesPerSecond().ifPresent(value -> activity.setDouble("lastWritesPerSecond", value)); application.ownershipIssueId().ifPresent(issueId -> object.setString("ownershipIssueId", issueId.value())); + application.owner().ifPresent(owner -> object.setString("owner", owner.username())); application.deploymentJobs().issueId().ifPresent(issueId -> object.setString("deploymentIssueId", issueId.value())); } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmerTest.java index 62653f29518..75c287e700f 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmerTest.java @@ -58,14 +58,12 @@ public class ApplicationOwnershipConfirmerTest { Optional<IssueId> issueId = Optional.of(IssueId.from("1")); issues.response = issueId; confirmer.maintain(); - confirmer.maintain(); assertFalse("No issue is stored for an application newer than 3 months.", propertyApp.get().ownershipIssueId().isPresent()); assertFalse("No issue is stored for an application newer than 3 months.", userApp.get().ownershipIssueId().isPresent()); tester.clock().advance(Duration.ofDays(91)); confirmer.maintain(); - confirmer.maintain(); assertEquals("Confirmation issue has been filed for property owned application.", issueId, propertyApp.get().ownershipIssueId()); assertEquals("Confirmation issue has been filed for user owned application.", issueId, userApp.get().ownershipIssueId()); @@ -75,7 +73,6 @@ public class ApplicationOwnershipConfirmerTest { // No new issue is created, so return empty now. issues.response = Optional.empty(); confirmer.maintain(); - confirmer.maintain(); assertEquals("Confirmation issue reference is not updated when no issue id is returned.", issueId, propertyApp.get().ownershipIssueId()); assertEquals("Confirmation issue reference is not updated when no issue id is returned.", issueId, userApp.get().ownershipIssueId()); @@ -86,16 +83,20 @@ public class ApplicationOwnershipConfirmerTest { tester.controller().applications().deactivate(userApp.get().id(), userApp.get().productionDeployments().keySet().stream().findAny().get()); assertTrue("No production deployments are listed for user.", userApp.get().productionDeployments().isEmpty()); confirmer.maintain(); - confirmer.maintain(); // Time has passed, and a new confirmation issue is in order for the property which is still in production. Optional<IssueId> issueId2 = Optional.of(IssueId.from("2")); issues.response = issueId2; confirmer.maintain(); - confirmer.maintain(); assertEquals("A new confirmation issue id is stored when something is returned to the maintainer.", issueId2, propertyApp.get().ownershipIssueId()); assertEquals("Confirmation issue for application without production deployments has not been filed.", issueId, userApp.get().ownershipIssueId()); + + assertFalse("No owner is stored for application", propertyApp.get().owner().isPresent()); + issues.owner = Optional.of(User.from("username")); + confirmer.maintain(); + assertEquals("Owner has been added to application", propertyApp.get().owner().get().username(), "username"); + } private class MockOwnershipIssues implements OwnershipIssues { @@ -103,6 +104,7 @@ public class ApplicationOwnershipConfirmerTest { private Optional<IssueId> response; private boolean escalatedToContact = false; private boolean escalatedToTerminator = false; + private Optional<User> owner = Optional.empty(); @Override public Optional<IssueId> confirmOwnership(Optional<IssueId> issueId, ApplicationId applicationId, User asignee, Contact contact) { @@ -115,6 +117,10 @@ public class ApplicationOwnershipConfirmerTest { else escalatedToTerminator = true; } + @Override + public Optional<User> getConfirmedOwner(IssueId issueId) { + return owner; + } } } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java index a6a6635032d..42ce696af89 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java @@ -12,6 +12,7 @@ import com.yahoo.vespa.hosted.controller.Application; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService; import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType; import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueId; +import com.yahoo.vespa.hosted.controller.api.integration.organization.User; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Change; @@ -106,6 +107,7 @@ public class ApplicationSerializerTest { Change.of(Version.fromString("6.7")), Change.of(ApplicationVersion.from(new SourceRevision("repo", "master", "deadcafe"), 42)), Optional.of(IssueId.from("1234")), + Optional.of(User.from("by-username")), new MetricsService.ApplicationMetrics(0.5, 0.9), Optional.of(new RotationId("my-rotation")), rotationStatus); @@ -138,6 +140,7 @@ public class ApplicationSerializerTest { assertEquals(original.outstandingChange(), serialized.outstandingChange()); assertEquals(original.ownershipIssueId(), serialized.ownershipIssueId()); + assertEquals(original.owner(), serialized.owner()); assertEquals(original.change(), serialized.change()); assertEquals(original.rotation().get(), serialized.rotation().get()); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java index 8b260765423..3d0489ab0a1 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java @@ -27,6 +27,7 @@ import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; import com.yahoo.vespa.hosted.controller.api.identifiers.ScrewdriverId; import com.yahoo.vespa.hosted.controller.api.identifiers.UserId; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService.ApplicationMetrics; +import com.yahoo.vespa.hosted.controller.api.integration.organization.User; import com.yahoo.vespa.hosted.controller.athenz.ApplicationAction; import com.yahoo.vespa.hosted.controller.athenz.HostedAthenzIdentities; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServerException; @@ -508,7 +509,8 @@ public class ApplicationApiTest extends ControllerContainerTest { private void addIssues(ContainerControllerTester tester, ApplicationId id) { tester.controller().applications().lockOrThrow(id, application -> tester.controller().applications().store(application.withDeploymentIssueId(IssueId.from("123")) - .withOwnershipIssueId(IssueId.from("321")))); + .withOwnershipIssueId(IssueId.from("321")) + .withOwner(User.from("owner-username")))); } @Test diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json index 70da148ef86..da6bf455857 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json @@ -231,5 +231,6 @@ "lastWritesPerSecond": 2.0 }, "ownershipIssueId": "321", + "owner": "owner-username", "deploymentIssueId": "123" } diff --git a/document/src/main/java/com/yahoo/document/Document.java b/document/src/main/java/com/yahoo/document/Document.java index 85049baf8b0..222ebe29c6d 100644 --- a/document/src/main/java/com/yahoo/document/Document.java +++ b/document/src/main/java/com/yahoo/document/Document.java @@ -286,6 +286,7 @@ public class Document extends StructuredFieldValue { /** * Get JSON representation of the document root and its children contained in a JSON object + * * @return JSON representation of document */ public String toJson() { diff --git a/documentgen-test/pom.xml b/documentgen-test/pom.xml index fb6515c3f35..fdc1b89a438 100644 --- a/documentgen-test/pom.xml +++ b/documentgen-test/pom.xml @@ -55,7 +55,6 @@ <arg>-Xlint:all</arg> <arg>-Xlint:-unchecked</arg> <arg>-Xlint:-serial</arg> - <arg>-Werror</arg> </compilerArgs> </configuration> </plugin> diff --git a/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java b/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java index b3daf5c296d..4c483072f5f 100644 --- a/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java +++ b/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java @@ -2,17 +2,11 @@ package com.yahoo.jrt; import com.yahoo.security.SslContextBuilder; -import com.yahoo.security.X509CertificateUtils; import com.yahoo.security.tls.TransportSecurityOptions; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; -import java.io.IOException; -import java.io.UncheckedIOException; import java.nio.channels.SocketChannel; -import java.nio.file.Files; -import java.security.cert.X509Certificate; -import java.util.List; /** * A {@link CryptoSocket} that creates {@link TlsCryptoSocket} instances. @@ -40,9 +34,10 @@ public class TlsCryptoEngine implements CryptoEngine { } private static SSLContext createSslContext(TransportSecurityOptions options) { - return new SslContextBuilder() - .withTrustStore(options.getCaCertificatesFile()) - .withKeyStore(options.getPrivateKeyFile(), options.getCertificatesFile()) - .build(); + SslContextBuilder builder = new SslContextBuilder(); + options.getCertificatesFile() + .ifPresent(certificates -> builder.withKeyStore(options.getPrivateKeyFile().get(), certificates)); + options.getCaCertificatesFile().ifPresent(builder::withTrustStore); + return builder.build(); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index e001204f650..55da2e78894 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -7,7 +7,6 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; import com.yahoo.tensor.TensorType; import java.util.Arrays; @@ -25,6 +24,9 @@ import java.util.stream.Collectors; @Beta public class Model { + /** The prefix generated by mode-integration/../IntermediateOperation */ + private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_"; + private final String name; /** Free functions */ @@ -66,7 +68,7 @@ public class Model { } for (String argument : context.arguments()) { - if (function.getValue().getName().startsWith(IntermediateOperation.FUNCTION_PREFIX)) { + if (function.getValue().getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)) { // Internal (generated) functions do not have type info - add arguments if (!function.getValue().arguments().contains(argument)) functions.put(function.getKey(), function.getValue().withArgument(argument)); @@ -85,7 +87,7 @@ public class Model { this.contextPrototypes = contextBuilder.build(); this.functions = ImmutableList.copyOf(functions.values()); this.publicFunctions = ImmutableList.copyOf(functions.values().stream() - .filter(f -> ! f.getName().startsWith(IntermediateOperation.FUNCTION_PREFIX)) + .filter(f -> ! f.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)) .collect(Collectors.toList())); // Optimize functions diff --git a/model-integration/CMakeLists.txt b/model-integration/CMakeLists.txt new file mode 100644 index 00000000000..26d5b3d1bbc --- /dev/null +++ b/model-integration/CMakeLists.txt @@ -0,0 +1,4 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +install_fat_java_artifact(model-integration) + +install(FILES src/main/config/model-integration.xml DESTINATION conf/configserver-app)
\ No newline at end of file diff --git a/model-integration/README b/model-integration/README new file mode 100644 index 00000000000..7b29ac16e34 --- /dev/null +++ b/model-integration/README @@ -0,0 +1,8 @@ +3rd party ML models and converters from these to ranking expresssions, provided as a separate bundle. + +This has two purposes +- Make converters (importers) available to config models while loading them in just a single instance even when + there are multiple config models. +- Make third party models directly available to the container. + +TensorFlow depends on JNI code which necessitates using a separate bundle to achieve the above.
\ No newline at end of file diff --git a/model-integration/pom.xml b/model-integration/pom.xml new file mode 100644 index 00000000000..5a2e7f0dbcd --- /dev/null +++ b/model-integration/pom.xml @@ -0,0 +1,134 @@ +<?xml version="1.0"?> +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 + http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>com.yahoo.vespa</groupId> + <artifactId>parent</artifactId> + <version>6-SNAPSHOT</version> + <relativePath>../parent/pom.xml</relativePath> + </parent> + <artifactId>model-integration</artifactId> + <version>6-SNAPSHOT</version> + <packaging>container-plugin</packaging> + <dependencies> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>annotations</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>config-model-api</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>searchlib</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>vespajlib</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + <scope>provided</scope> + </dependency> + + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + </dependency> + <dependency> + <groupId>org.tensorflow</groupId> + <artifactId>proto</artifactId> + </dependency> + <dependency> + <groupId>org.tensorflow</groupId> + <artifactId>tensorflow</artifactId> + </dependency> + </dependencies> + <build> + <plugins> + <plugin> + <groupId>com.yahoo.vespa</groupId> + <artifactId>bundle-plugin</artifactId> + <extensions>true</extensions> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-compiler-plugin</artifactId> + <configuration> + <compilerArgs> + <arg>-Xlint:rawtypes</arg> + <arg>-Xlint:unchecked</arg> + <arg>-Werror</arg> + </compilerArgs> + </configuration> + </plugin> + <plugin> + <groupId>com.github.os72</groupId> + <artifactId>protoc-jar-maven-plugin</artifactId> + <version>3.5.1.1</version> + <executions> + <execution> + <phase>generate-sources</phase> + <goals> + <goal>run</goal> + </goals> + <configuration> + <addSources>main</addSources> + <outputDirectory>${project.build.directory}/generated-sources/protobuf</outputDirectory> + <inputDirectories> + <include>src/main/protobuf</include> + </inputDirectories> + </configuration> + </execution> + </executions> + </plugin> + </plugins> + </build> + + <profiles> + <!-- Exclude TF JNI when building for rhel6, which needs a special, natively installed variant --> + <profile> + <id>rhel6</id> + <activation> + <property> + <name>target.env</name> + <value>rhel6</value> + </property> + </activation> + <dependencies> + <dependency> + <groupId>org.tensorflow</groupId> + <artifactId>tensorflow</artifactId> + <exclusions> + <exclusion> + <groupId>org.tensorflow</groupId> + <artifactId>libtensorflow_jni</artifactId> + </exclusion> + </exclusions> + </dependency> + </dependencies> + </profile> + </profiles> + +</project> diff --git a/model-integration/src/main/config/model-integration.xml b/model-integration/src/main/config/model-integration.xml new file mode 100644 index 00000000000..da45ce23575 --- /dev/null +++ b/model-integration/src/main/config/model-integration.xml @@ -0,0 +1,10 @@ +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<!-- Component which can import some ml model. + This is included into the config server services.xml to enable it to translate + model pseudofeatures in ranking expressions during config mddel building. + It is provided as separate bundles instead of being included in the config model + because some of these (TensorFlow) includes + JNI code, and so can only exist in one instance in the server. --> +<component id="ai.vespa.rankingexpression.importer.onnx.OnnxImporter" bundle="model-integration" /> +<component id="ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter" bundle="model-integration" /> +<component id="ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter" bundle="model-integration" /> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index 38f1d2329e2..9e9f66be700 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer; +package ai.vespa.rankingexpression.importer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; import java.util.ArrayDeque; import java.util.ArrayList; @@ -77,7 +77,7 @@ public class DimensionRenamer { * algorithm below needs to be adapted with a backtracking (tree) search * to find solutions. */ - public void solve(int maxIterations) { + private void solve(int maxIterations) { initialize(); // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts @@ -104,7 +104,7 @@ public class DimensionRenamer { // Then run this algorithm again. } - public void solve() { + void solve() { solve(100000); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 59ec66b7209..ec4e729f9c7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -1,22 +1,20 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.yahoo.collections.Pair; +import com.yahoo.config.model.api.ImportedMlFunction; +import com.yahoo.config.model.api.ImportedMlModel; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.regex.Pattern; @@ -25,7 +23,7 @@ import java.util.regex.Pattern; * * @author bratseth */ -public class ImportedModel { +public class ImportedModel implements ImportedMlModel { private static final String defaultSignatureName = "default"; @@ -54,27 +52,39 @@ public class ImportedModel { } /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ + @Override public String name() { return name; } /** Returns the source path (directory or file) of this model */ + @Override public String source() { return source; } /** Returns an immutable map of the inputs of this */ public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); } + @Override + public Optional<String> inputTypeSpec(String input) { + return Optional.ofNullable(inputs.get(input)).map(TensorType::toString); + } + /** - * Returns an immutable map of the small constants of this. - * These should have sizes up to a few kb at most, and correspond to constant - * values given in the TensorFlow or ONNX source. + * Returns an immutable map of the small constants of this, represented as strings on the standard tensor form. + * These should have sizes up to a few kb at most, and correspond to constant values given in the source model. */ - public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } + @Override + public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); } + + boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); } /** * Returns an immutable map of the large constants of this. * These can have sizes in gigabytes and must be distributed to nodes separately from configuration. * For TensorFlow this corresponds to Variable files stored separately. */ - public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } + @Override + public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); } + + boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); } /** * Returns an immutable map of the expressions of this - corresponding to graph nodes @@ -83,28 +93,31 @@ public class ImportedModel { */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } + // TODO: Most of the usage of the above can be replaced by a faster expressionNames method + /** * Returns an immutable map of the functions that are part of this model. * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification. */ - public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); } + @Override + public Map<String, String> functions() { return asExpressionStrings(functions); } /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } /** Returns the given signature. If it does not already exist it is added to this. */ - Signature signature(String name) { + public Signature signature(String name) { return signatures.computeIfAbsent(name, Signature::new); } /** Convenience method for returning a default signature */ - Signature defaultSignature() { return signature(defaultSignatureName); } + public Signature defaultSignature() { return signature(defaultSignatureName); } - void input(String name, TensorType argumentType) { inputs.put(name, argumentType); } - void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } - void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } - void expression(String name, RankingExpression expression) { expressions.put(name, expression); } - void function(String name, RankingExpression expression) { functions.put(name, expression); } + public void input(String name, TensorType argumentType) { inputs.put(name, argumentType); } + public void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } + public void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } + public void expression(String name, RankingExpression expression) { expressions.put(name, expression); } + public void function(String name, RankingExpression expression) { functions.put(name, expression); } /** * Returns all the output expressions of this indexed by name. The names consist of one or two parts @@ -112,43 +125,67 @@ public class ImportedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public List<Pair<String, ExpressionFunction>> outputExpressions() { - List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); + @Override + public List<ImportedMlFunction> outputExpressions() { + List<ImportedMlFunction> functions = new ArrayList<>(); for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) - expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(), - signatureEntry.getValue().outputExpression(outputEntry.getKey()) - .withName(signatureEntry.getKey() + "." + outputEntry.getKey()))); + functions.add(signatureEntry.getValue().outputFunction(outputEntry.getKey(), + signatureEntry.getKey() + "." + outputEntry.getKey())); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs - expressions.add(new Pair<>(signatureEntry.getKey(), - new ExpressionFunction(signatureEntry.getKey(), - new ArrayList<>(signatureEntry.getValue().inputs().values()), - expressions().get(signatureEntry.getKey()), - signatureEntry.getValue().inputMap(), - Optional.empty()))); + functions.add(new ImportedMlFunction(signatureEntry.getKey(), + new ArrayList<>(signatureEntry.getValue().inputs().values()), + expressions().get(signatureEntry.getKey()).getRoot().toString(), + asTensorTypeStrings(signatureEntry.getValue().inputMap()), + Optional.empty())); } if (signatures().isEmpty()) { // fallback for models without signatures if (expressions().size() == 1) { Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next(); - expressions.add(new Pair<>(singleEntry.getKey(), - new ExpressionFunction(singleEntry.getKey(), - new ArrayList<>(inputs.keySet()), - singleEntry.getValue(), - inputs, - Optional.empty()))); + functions.add(new ImportedMlFunction(singleEntry.getKey(), + new ArrayList<>(inputs.keySet()), + singleEntry.getValue().getRoot().toString(), + asTensorTypeStrings(inputs), + Optional.empty())); } else { for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { - expressions.add(new Pair<>(expressionEntry.getKey(), - new ExpressionFunction(expressionEntry.getKey(), - new ArrayList<>(inputs.keySet()), - expressionEntry.getValue(), - inputs, - Optional.empty()))); + functions.add(new ImportedMlFunction(expressionEntry.getKey(), + new ArrayList<>(inputs.keySet()), + expressionEntry.getValue().getRoot().toString(), + asTensorTypeStrings(inputs), + Optional.empty())); } } } - return expressions; + return functions; + } + + private Map<String, String> asTensorStrings(Map<String, Tensor> map) { + HashMap<String, String> values = new HashMap<>(); + for (Map.Entry<String, Tensor> entry : map.entrySet()) { + Tensor tensor = entry.getValue(); + // TODO: See Tensor.toStandardString + if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) + values.put(entry.getKey(), tensor.toString()); + else + values.put(entry.getKey(), tensor.type() + ":" + tensor); + } + return values; + } + + private static Map<String, String> asTensorTypeStrings(Map<String, TensorType> map) { + Map<String, String> stringMap = new HashMap<>(); + for (Map.Entry<String, TensorType> entry : map.entrySet()) + stringMap.put(entry.getKey(), entry.getValue().toString()); + return stringMap; + } + + private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) { + HashMap<String, String> values = new HashMap<>(); + for (Map.Entry<String, RankingExpression> entry : map.entrySet()) + values.put(entry.getKey(), entry.getValue().getRoot().toString()); + return values; } /** @@ -165,14 +202,14 @@ public class ImportedModel { private final Map<String, String> skippedOutputs = new HashMap<>(); private final List<String> importWarnings = new ArrayList<>(); - public Signature(String name) { + Signature(String name) { this.name = name; } public String name() { return name; } /** Returns the result this is part of */ - public ImportedModel owner() { return ImportedModel.this; } + ImportedModel owner() { return ImportedModel.this; } /** * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name @@ -181,7 +218,7 @@ public class ImportedModel { public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } /** Returns the name and type of all inputs in this signature as an immutable map */ - public Map<String, TensorType> inputMap() { + Map<String, TensorType> inputMap() { ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>(); // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to* // in the model, as these are the names which must actually be bound, if we are to avoid creating an @@ -217,6 +254,15 @@ public class ImportedModel { Optional.empty()); } + /** Returns the expression this output references as an imported function */ + public ImportedMlFunction outputFunction(String outputName, String functionName) { + return new ImportedMlFunction(functionName, + new ArrayList<>(inputs.values()), + owner().expressions().get(outputs.get(outputName)).getRoot().toString(), + asTensorTypeStrings(inputMap()), + Optional.empty()); + } + @Override public String toString() { return "signature '" + name + "'"; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index eee92862e7f..aec98d06874 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer; +package ai.vespa.rankingexpression.importer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; import java.util.Collection; import java.util.HashMap; @@ -68,7 +68,7 @@ public class IntermediateGraph { return index.values(); } - public void optimize() { + void optimize() { renameDimensions(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 481b7f9397a..0200a9032a5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -1,13 +1,12 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer; +import com.yahoo.config.model.api.MlModelImporter; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.Constant; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.functions.Rename; @@ -21,7 +20,7 @@ import java.util.Optional; import java.util.logging.Logger; /** - * Base class for importing ML models (ONNX/TensorFlow) as native Vespa + * Base class for importing ML models (ONNX/TensorFlow etc.) as native Vespa * ranking expressions. The general mechanism for import is for the * specific ML platform import implementations to create an * IntermediateGraph. This class offers common code to convert the @@ -29,25 +28,27 @@ import java.util.logging.Logger; * * @author lesters */ -public abstract class ModelImporter { +public abstract class ModelImporter implements MlModelImporter { private static final Logger log = Logger.getLogger(ModelImporter.class.getName()); /** Returns whether the file or directory at the given path is of the type which can be imported by this */ + @Override public abstract boolean canImport(String modelPath); - /** Imports the given model */ - public abstract ImportedModel importModel(String modelName, String modelPath); - + @Override public final ImportedModel importModel(String modelName, File modelPath) { return importModel(modelName, modelPath.toString()); } + /** Imports the given model */ + public abstract ImportedModel importModel(String modelName, String modelPath); + /** * Takes an IntermediateGraph and converts it to a ImportedModel containing * the actual Vespa ranking expressions. */ - static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { + protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { ImportedModel model = new ImportedModel(graph.name(), modelSource); graph.optimize(); @@ -72,17 +73,6 @@ public abstract class ModelImporter { } } - private static boolean isSignatureInput(ImportedModel model, IntermediateOperation operation) { - for (ImportedModel.Signature signature : model.signatures().values()) { - for (String inputName : signature.inputs().values()) { - if (inputName.equals(operation.name())) { - return true; - } - } - } - return false; - } - private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) { for (ImportedModel.Signature signature : model.signatures().values()) { for (String outputName : signature.outputs().values()) { @@ -97,7 +87,7 @@ public abstract class ModelImporter { /** * Convert intermediate representation to Vespa ranking expressions. */ - static void importExpressions(IntermediateGraph graph, ImportedModel model) { + private static void importExpressions(IntermediateGraph graph, ImportedModel model) { for (ImportedModel.Signature signature : model.signatures().values()) { for (String outputName : signature.outputs().values()) { try { @@ -134,7 +124,7 @@ public abstract class ModelImporter { private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) { String name = operation.vespaName(); - if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { + if (model.hasLargeConstant(name) || model.hasSmallConstant(name)) { return operation.function(); } @@ -178,7 +168,7 @@ public abstract class ModelImporter { } catch (ParseException e) { throw new RuntimeException("Imported function " + function + - " cannot be parsed as a ranking expression", e); + " cannot be parsed as a ranking expression", e); } } } @@ -201,7 +191,7 @@ public abstract class ModelImporter { function.toString())); } catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function + + throw new RuntimeException("Model function " + function + " cannot be parsed as a ranking expression", e); } } @@ -228,7 +218,7 @@ public abstract class ModelImporter { } /** - * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type. + * Log all model Variables (i.e file constants) imported as part of this with their ordered type. * This allows users to learn the exact types (including dimension order after renaming) of the Variables * such that these can be converted and fed to a parent document independently of the rest of the model * for fast model weight updates. @@ -237,7 +227,7 @@ public abstract class ModelImporter { for (IntermediateOperation operation : graph.operations()) { if ( ! (operation instanceof Constant)) continue; if ( ! operation.type().isPresent()) continue; // will not happen - log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() + + log.info("Importing model variable " + operation.name() + " as " + operation.vespaName() + " of type " + operation.type().get()); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java index 209d73a9f38..c4acfeb3235 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java @@ -1,6 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer; +package ai.vespa.rankingexpression.importer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.TensorTypeParser; @@ -200,7 +200,7 @@ public class OrderedTensorType { return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ... } - public static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) { + private static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) { OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); for (int i = 0; i < dims.size(); ++ i) { String dimensionName = dimensionPrefix + i; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index 3fe92440cae..dd2add973e4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -1,21 +1,21 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; +package ai.vespa.rankingexpression.importer.onnx; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape; +import ai.vespa.rankingexpression.importer.IntermediateGraph; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.operations.Argument; +import ai.vespa.rankingexpression.importer.operations.ConcatV2; +import ai.vespa.rankingexpression.importer.operations.Constant; +import ai.vespa.rankingexpression.importer.operations.Identity; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.Join; +import ai.vespa.rankingexpression.importer.operations.Map; +import ai.vespa.rankingexpression.importer.operations.MatMul; +import ai.vespa.rankingexpression.importer.operations.NoOp; +import ai.vespa.rankingexpression.importer.operations.Reshape; +import ai.vespa.rankingexpression.importer.operations.Shape; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; @@ -28,9 +28,9 @@ import java.util.stream.Collectors; * * @author lesters */ -public class GraphImporter { +class GraphImporter { - public static IntermediateOperation mapOperation(Onnx.NodeProto node, + private static IntermediateOperation mapOperation(Onnx.NodeProto node, List<IntermediateOperation> inputs, IntermediateGraph graph) { String nodeName = node.getName(); @@ -79,7 +79,7 @@ public class GraphImporter { return op; } - public static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) { + static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) { Onnx.GraphProto onnxGraph = model.getGraph(); IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java index e6bb5f40b3f..0a8a797a847 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java @@ -1,9 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.onnx; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter; +import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.IntermediateGraph; +import ai.vespa.rankingexpression.importer.ModelImporter; import onnx.Onnx; import java.io.File; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java index 18856d4a25f..f3d87d89c27 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; +package ai.vespa.rankingexpression.importer.onnx; import com.google.protobuf.ByteString; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import onnx.Onnx; @@ -17,9 +17,9 @@ import java.nio.FloatBuffer; * * @author lesters */ -public class TensorConverter { +class TensorConverter { - public static Tensor toVespaTensor(Onnx.TensorProto tensorProto, OrderedTensorType type) { + static Tensor toVespaTensor(Onnx.TensorProto tensorProto, OrderedTensorType type) { Values values = readValuesOf(tensorProto); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type()); for (int i = 0; i < values.size(); i++) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index 715c55d8323..f251a14213b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; +package ai.vespa.rankingexpression.importer.onnx; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import onnx.Onnx; @@ -11,9 +11,9 @@ import onnx.Onnx; * * @author lesters */ -public class TypeConverter { +class TypeConverter { - public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) { + static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) { Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape(); if (shape != null) { if (shape.getDimCount() != type.rank()) { @@ -30,11 +30,11 @@ public class TypeConverter { } } - public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { + static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ... } - public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { + private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { Onnx.TensorShapeProto shape = type.getTensorType().getShape(); OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); for (int i = 0; i < shape.getDimCount(); ++ i) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java new file mode 100644 index 00000000000..9599cf8627c --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java @@ -0,0 +1,5 @@ +@ExportPackage + +package ai.vespa.rankingexpression.importer.onnx; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java index 7fc2aae87d1..d6ea00ca453 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.Rename; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index 1b8c62fe0e9..a21fc5ff2f7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index 3c0f8569c47..41d421b1f5a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -1,10 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java index 5e4abeaa234..a1cc83296b0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 742ed8b89ab..8ae6d81b8d4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java index d29bd4b7a9e..c2787aa14d4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 0eff8e8bc08..60fba264635 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -1,11 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; @@ -180,7 +180,7 @@ public abstract class IntermediateOperation { /** * An interface mapping operation attributes to Vespa Values. - * Adapter for differences in ONNX/TensorFlow. + * Adapter for differences in different model types. */ public interface AttributeMap { Optional<Value> get(String key); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index 8413ed74118..fed95e13bb7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java index f54ae83052f..e0842d820f9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 52e223f9518..1dbfd6e40dc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java index 95a77c07590..4be220db9d5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java @@ -1,10 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java index 9d9eca47b1c..ce0c58971d0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java index 19ba146492c..4c5ce33b1b5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; import java.util.Collections; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java index b335fd7e1c5..e5e5c29f8f1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index e91c2305f7d..18f3cc1cc39 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; @@ -24,8 +24,6 @@ import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; -import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize; - public class Reshape extends IntermediateOperation { public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { @@ -52,7 +50,7 @@ public class Reshape extends IntermediateOperation { int size = cell.getValue().intValue(); if (size < 0) { size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / - tensorSize(inputType.type()).intValue(); + OrderedTensorType.tensorSize(inputType.type()).intValue(); } outputTypeBuilder.add(TensorType.Dimension.indexed( String.format("%s_%d", vespaName(), dimensionIndex), size)); @@ -82,7 +80,7 @@ public class Reshape extends IntermediateOperation { } public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if (!tensorSize(inputType).equals(tensorSize(outputType))) { + if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) { throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java index 927a4a368f9..dc690329a8d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; @@ -11,8 +11,8 @@ import com.yahoo.tensor.functions.TensorFunction; import java.util.List; import java.util.function.DoubleBinaryOperator; -import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.dimensionSize; -import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize; +import static ai.vespa.rankingexpression.importer.OrderedTensorType.dimensionSize; +import static ai.vespa.rankingexpression.importer.OrderedTensorType.tensorSize; public class Select extends IntermediateOperation { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java index da566909adc..361729a8c14 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java index c750c47e27e..2eeefcbe8a2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java index 0171d1ea171..131af8de065 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package ai.vespa.rankingexpression.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java index 89b75e8e3e2..ecb67f93d69 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java @@ -1,12 +1,12 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.NodeDef; @@ -20,15 +20,15 @@ import java.util.stream.Collectors; * * @author lesters */ -public class AttributeConverter implements IntermediateOperation.AttributeMap { +class AttributeConverter implements IntermediateOperation.AttributeMap { private final Map<String, AttrValue> attributeMap; - public AttributeConverter(NodeDef node) { + private AttributeConverter(NodeDef node) { attributeMap = node.getAttrMap(); } - public static AttributeConverter convert(NodeDef node) { + static AttributeConverter convert(NodeDef node) { return new AttributeConverter(node); } @@ -83,4 +83,5 @@ public class AttributeConverter implements IntermediateOperation.AttributeMap { } return Optional.empty(); } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index e1b292f9e61..cb838cd67b1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -1,29 +1,29 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Const; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ExpandDims; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Mean; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Merge; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.PlaceholderWithDefault; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Select; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Squeeze; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Switch; +import ai.vespa.rankingexpression.importer.IntermediateGraph; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.operations.Argument; +import ai.vespa.rankingexpression.importer.operations.ConcatV2; +import ai.vespa.rankingexpression.importer.operations.Const; +import ai.vespa.rankingexpression.importer.operations.Constant; +import ai.vespa.rankingexpression.importer.operations.ExpandDims; +import ai.vespa.rankingexpression.importer.operations.Identity; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.Join; +import ai.vespa.rankingexpression.importer.operations.Map; +import ai.vespa.rankingexpression.importer.operations.MatMul; +import ai.vespa.rankingexpression.importer.operations.Mean; +import ai.vespa.rankingexpression.importer.operations.Merge; +import ai.vespa.rankingexpression.importer.operations.NoOp; +import ai.vespa.rankingexpression.importer.operations.PlaceholderWithDefault; +import ai.vespa.rankingexpression.importer.operations.Reshape; +import ai.vespa.rankingexpression.importer.operations.Select; +import ai.vespa.rankingexpression.importer.operations.Shape; +import ai.vespa.rankingexpression.importer.operations.Squeeze; +import ai.vespa.rankingexpression.importer.operations.Switch; import com.yahoo.tensor.functions.ScalarFunctions; import org.tensorflow.SavedModelBundle; import org.tensorflow.Session; @@ -43,9 +43,9 @@ import java.util.stream.Collectors; * * @author lesters */ -public class GraphImporter { +class GraphImporter { - public static IntermediateOperation mapOperation(NodeDef node, + private static IntermediateOperation mapOperation(NodeDef node, List<IntermediateOperation> inputs, IntermediateGraph graph) { String nodeName = node.getName(); @@ -112,7 +112,7 @@ public class GraphImporter { return op; } - public static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException { + static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException { MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef()); IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); @@ -209,7 +209,7 @@ public class GraphImporter { throw new IllegalArgumentException("Could not find node '" + name + "'"); } - public static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { + static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { Session.Runner fetched = bundle.session().runner().fetch(name); List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); if (importedTensors.size() != 1) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java index d2d0acfc964..6c92ffa6055 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; +package ai.vespa.rankingexpression.importer.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -26,7 +26,7 @@ public class TensorConverter { return toVespaTensor(tfTensor, "d"); } - public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { + private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix); Values values = readValuesOf(tfTensor); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); @@ -35,7 +35,7 @@ public class TensorConverter { return builder.build(); } - public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) { + static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) { Values values = readValuesOf(tfTensor); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type()); for (int i = 0; i < values.size(); i++) { @@ -44,7 +44,7 @@ public class TensorConverter { return builder.build(); } - public static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) { + static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); Values values = readValuesOf(tensorProto); for (int i = 0; i < values.size(); ++i) { @@ -71,7 +71,7 @@ public class TensorConverter { return size; } - public static Long dimensionSize(TensorType.Dimension dim) { + private static Long dimensionSize(TensorType.Dimension dim) { return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); } @@ -182,8 +182,8 @@ public class TensorConverter { } private static abstract class ProtoValues extends Values { - protected final TensorProto tensorProto; - protected ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; } + final TensorProto tensorProto; + ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; } } private static class ProtoBoolValues extends ProtoValues { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java index 7c18e04bae7..2a406f92756 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java @@ -1,8 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter; +import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.IntermediateGraph; +import ai.vespa.rankingexpression.importer.ModelImporter; import org.tensorflow.SavedModelBundle; import java.io.File; @@ -47,7 +48,7 @@ public class TensorFlowImporter extends ModelImporter { } /** Imports a TensorFlow model */ - ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) { + public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) { try { IntermediateGraph graph = GraphImporter.importGraph(modelName, model); return convertIntermediateGraphToModel(graph, modelDir); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java index 67ad1edc312..63a605ce97a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; +package ai.vespa.rankingexpression.importer.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.NodeDef; @@ -15,9 +15,9 @@ import java.util.List; * * @author lesters */ -public class TypeConverter { +class TypeConverter { - public static void verifyType(NodeDef node, OrderedTensorType type) { + static void verifyType(NodeDef node, OrderedTensorType type) { TensorShapeProto shape = tensorFlowShape(node); if (shape != null) { if (shape.getDimCount() != type.rank()) { @@ -50,11 +50,11 @@ public class TypeConverter { return shapeList.get(0); // support multiple outputs? } - public static OrderedTensorType fromTensorFlowType(NodeDef node) { + static OrderedTensorType fromTensorFlowType(NodeDef node) { return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... } - public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { + private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); TensorShapeProto shape = tensorFlowShape(node); for (int i = 0; i < shape.getDimCount(); ++ i) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java index 25bac27f315..31cb60b5509 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java @@ -1,9 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; @@ -16,7 +14,7 @@ import java.nio.charset.StandardCharsets; * * @author bratseth */ -public class VariableConverter { +class VariableConverter { /** * Reads the tensor with the given TensorFlow name at the given model location, @@ -24,7 +22,7 @@ public class VariableConverter { * Note that order of dimensions in the tensor type does matter as the TensorFlow tensor * tensor dimensions are implicitly ordered. */ - public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) { + static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) { try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) { return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName, bundle), diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java new file mode 100644 index 00000000000..0840e584d25 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java @@ -0,0 +1,4 @@ +@ExportPackage +package ai.vespa.rankingexpression.importer.tensorflow; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java index 725f319a839..ac462cc39eb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java @@ -1,8 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.xgboost; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost.XGBoostParser; +import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.ModelImporter; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import java.io.File; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java index fef8bfec81d..2b215f816f5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost; +package ai.vespa.rankingexpression.importer.xgboost; import java.io.File; import java.io.IOException; @@ -13,7 +13,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; /** * @author grace-lam */ -public class XGBoostParser { +class XGBoostParser { private List<XGBoostTree> xgboostTrees; @@ -24,7 +24,7 @@ public class XGBoostParser { * @throws JsonProcessingException Fails JSON parsing. * @throws IOException Fails file reading. */ - public XGBoostParser(String filePath) throws JsonProcessingException, IOException { + XGBoostParser(String filePath) throws JsonProcessingException, IOException { this.xgboostTrees = new ArrayList<>(); ObjectMapper mapper = new ObjectMapper(); JsonNode forestNode = mapper.readTree(new File(filePath)); @@ -38,7 +38,7 @@ public class XGBoostParser { * * @return Vespa ranking expressions. */ - public String toRankingExpression() { + String toRankingExpression() { StringBuilder ret = new StringBuilder(); for (int i = 0; i < xgboostTrees.size(); i++) { ret.append(treeToRankExp(xgboostTrees.get(i))); @@ -55,7 +55,7 @@ public class XGBoostParser { * @param node XGBoost tree node to convert. * @return Vespa ranking expression for input node. */ - public String treeToRankExp(XGBoostTree node) { + private String treeToRankExp(XGBoostTree node) { if (node.isLeaf()) { return Double.toString(node.getLeaf()); } else { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostTree.java index 6bbc9abe8ae..e32e0f1eab5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostTree.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost; +package ai.vespa.rankingexpression.importer.xgboost; import java.util.List; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java new file mode 100644 index 00000000000..d310de9041b --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java @@ -0,0 +1,4 @@ +@ExportPackage +package ai.vespa.rankingexpression.importer.xgboost; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/model-integration/src/main/java/org/tensorflow/package-info.java b/model-integration/src/main/java/org/tensorflow/package-info.java new file mode 100644 index 00000000000..cc6335f0d38 --- /dev/null +++ b/model-integration/src/main/java/org/tensorflow/package-info.java @@ -0,0 +1,6 @@ +@ExportPackage +@PublicApi +package org.tensorflow; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/searchlib/src/main/protobuf/onnx.proto b/model-integration/src/main/protobuf/onnx.proto index dc6542867e0..dc6542867e0 100644 --- a/searchlib/src/main/protobuf/onnx.proto +++ b/model-integration/src/main/protobuf/onnx.proto diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java index b3dafff621c..cf8dd6e8e71 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import org.junit.Test; import static org.junit.Assert.assertTrue; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java index 55e1d234782..afe699d6e05 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 6d0ee0a906d..315456c2613 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -1,11 +1,13 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.onnx; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -21,18 +23,18 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx"); // Check constants assertEquals(2, model.largeConstants().size()); - Tensor constant0 = model.largeConstants().get("test_Variable"); + Tensor constant0 = Tensor.from(model.largeConstants().get("test_Variable")); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.largeConstants().get("test_Variable_1"); + Tensor constant1 = Tensor.from(model.largeConstants().get("test_Variable_1")); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); @@ -55,8 +57,8 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testComparisonBetweenOnnxAndTensorflow() { - String tfModelPath = "src/test/files/integration/tensorflow/mnist_softmax/saved"; - String onnxModelPath = "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"; + String tfModelPath = "src/test/models/tensorflow/mnist_softmax/saved"; + String onnxModelPath = "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx"; Tensor argument = placeholderArgument(); Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add"); @@ -82,8 +84,8 @@ public class OnnxMnistSoftmaxImportTestCase { private Context contextFrom(ImportedModel result) { MapContext context = new MapContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); return context; } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java index e325c3d11b4..1a072f54c89 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; +import ai.vespa.rankingexpression.importer.ImportedModel; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -16,7 +16,7 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("test", - "src/test/files/integration/tensorflow/batch_norm/saved"); + "src/test/models/tensorflow/batch_norm/saved"); ImportedModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java index 5b0cca6a940..37104ab43db 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; @@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; +import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -17,8 +17,6 @@ import org.tensorflow.Session; import java.nio.FloatBuffer; import java.util.List; -import static com.yahoo.searchlib.rankingexpression.integration.ml.TestableTensorFlowModel.contextFrom; - /** * Microbenchmark of imported ML model evaluation. * @@ -26,13 +24,13 @@ import static com.yahoo.searchlib.rankingexpression.integration.ml.TestableTenso */ public class BlogEvaluationBenchmark { - static final String modelDir = "src/test/files/integration/tensorflow/blog/saved"; + static final String modelDir = "src/test/models/tensorflow/blog/saved"; public static void main(String[] args) throws ParseException { SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel); - Context context = contextFrom(model); + Context context = TestableTensorFlowModel.contextFrom(model); Tensor u = generateInputTensor(); Tensor d = generateInputTensor(); context.put("input_u", new TensorValue(u)); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java index 8ca5a9a7888..5e20be051ea 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java @@ -1,9 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; +import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.tensor.TensorType; +import org.junit.Assert; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -17,18 +18,18 @@ public class DropoutImportTestCase { @Test public void testDropoutImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/dropout/saved"); // Check required functions - assertEquals(1, model.get().inputs().size()); + Assert.assertEquals(1, model.get().inputs().size()); assertTrue(model.get().inputs().containsKey("X")); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().inputs().get("X")); + Assert.assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().inputs().get("X")); ImportedModel.Signature signature = model.get().signature("serving_default"); - assertEquals("Has skipped outputs", - 0, model.get().signature("serving_default").skippedOutputs().size()); + Assert.assertEquals("Has skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); ExpressionFunction function = signature.outputExpression("y"); assertNotNull(function); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java index 3d8d5d5a570..28b91b3797a 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java @@ -1,8 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; +import ai.vespa.rankingexpression.importer.ImportedModel; +import org.junit.Assert; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -15,11 +16,11 @@ public class MnistImportTestCase { @Test public void testMnistImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/mnist/saved"); ImportedModel.Signature signature = model.get().signature("serving_default"); - assertEquals("Has skipped outputs", - 0, model.get().signature("serving_default").skippedOutputs().size()); + Assert.assertEquals("Has skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java index feba40601e3..be676186017 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java @@ -1,10 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; +import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import org.junit.Assert; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -18,34 +19,34 @@ public class TensorFlowMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/mnist_softmax/saved"); // Check constants - assertEquals(2, model.get().largeConstants().size()); + Assert.assertEquals(2, model.get().largeConstants().size()); - Tensor constant0 = model.get().largeConstants().get("test_Variable_read"); + Tensor constant0 = Tensor.from(model.get().largeConstants().get("test_Variable_read")); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.get().largeConstants().get("test_Variable_1_read"); + Tensor constant1 = Tensor.from(model.get().largeConstants().get("test_Variable_1_read")); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); // Check (provided) functions - assertEquals(0, model.get().functions().size()); + Assert.assertEquals(0, model.get().functions().size()); // Check required functions - assertEquals(1, model.get().inputs().size()); + Assert.assertEquals(1, model.get().inputs().size()); assertTrue(model.get().inputs().containsKey("Placeholder")); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().inputs().get("Placeholder")); + Assert.assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().inputs().get("Placeholder")); // Check signatures - assertEquals(1, model.get().signatures().size()); + Assert.assertEquals(1, model.get().signatures().size()); ImportedModel.Signature signature = model.get().signatures().get("serving_default"); assertNotNull(signature); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java index fbe7c5fac63..4ff0c96d369 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; @@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; +import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; @@ -93,8 +93,8 @@ public class TestableTensorFlowModel { static Context contextFrom(ImportedModel result) { TestableModelContext context = new TestableModelContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); return context; } @@ -108,7 +108,7 @@ public class TestableTensorFlowModel { private void evaluateFunction(Context context, ImportedModel model, String functionName) { if (!context.names().contains(functionName)) { - RankingExpression e = model.functions().get(functionName); + RankingExpression e = RankingExpression.from(model.functions().get(functionName)); evaluateFunctionDependencies(context, model, e.getRoot()); context.put(functionName, new TensorValue(e.evaluate(context).asTensor())); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverterTestCase.java index aabbdf33c9d..c9fffe143b4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverterTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import org.junit.Test; @@ -11,7 +11,7 @@ public class VariableConverterTestCase { @Test public void testConversion() { - byte[] converted = VariableConverter.importVariable("src/test/files/integration/tensorflow/mnist_softmax/saved", + byte[] converted = VariableConverter.importVariable("src/test/models/tensorflow/mnist_softmax/saved", "Variable_1", "tensor(d0[10],d1[1])"); assertEquals("{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"1\",\"d1\":\"0\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"2\",\"d1\":\"0\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"3\",\"d1\":\"0\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"4\",\"d1\":\"0\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"5\",\"d1\":\"0\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"6\",\"d1\":\"0\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"7\",\"d1\":\"0\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"8\",\"d1\":\"0\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"9\",\"d1\":\"0\"},\"value\":-0.2573896050453186}]}", diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java new file mode 100644 index 00000000000..965d5eb8577 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java @@ -0,0 +1,28 @@ +package ai.vespa.rankingexpression.importer.xgboost; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import ai.vespa.rankingexpression.importer.ImportedModel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * @author bratseth + */ +public class XGBoostImportTestCase { + + @Test + public void testXGBoost() { + ImportedModel model = new XGBoostImporter().importModel("test", "src/test/models/xgboost/xgboost.2.2.json"); + assertTrue("All inputs are scalar", model.inputs().isEmpty()); + assertEquals(1, model.expressions().size()); + System.out.println(model.expressions().keySet()); + RankingExpression expression = model.expressions().get("test"); + assertNotNull(expression); + assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)", + expression.getRoot().toString()); + } + +} diff --git a/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx b/model-integration/src/test/models/onnx/mnist_softmax/mnist_softmax.onnx Binary files differindex a86019bf53a..a86019bf53a 100644 --- a/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx +++ b/model-integration/src/test/models/onnx/mnist_softmax/mnist_softmax.onnx diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/batch_normalization_mnist.py b/model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py index bc6ea13ebc1..bc6ea13ebc1 100644 --- a/searchlib/src/test/files/integration/tensorflow/batch_norm/batch_normalization_mnist.py +++ b/model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/batch_norm/saved/saved_model.pbtxt index f3ce68a1cbd..f3ce68a1cbd 100644 --- a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/saved_model.pbtxt +++ b/model-integration/src/test/models/tensorflow/batch_norm/saved/saved_model.pbtxt diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001 Binary files differindex 875e8361e10..875e8361e10 100644 --- a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001 +++ b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.index Binary files differindex 46c7b258cf5..46c7b258cf5 100644 --- a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.index +++ b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.index diff --git a/searchlib/src/test/files/integration/tensorflow/blog/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/blog/saved/saved_model.pbtxt index a669e69b709..a669e69b709 100644 --- a/searchlib/src/test/files/integration/tensorflow/blog/saved/saved_model.pbtxt +++ b/model-integration/src/test/models/tensorflow/blog/saved/saved_model.pbtxt diff --git a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.data-00000-of-00001 Binary files differindex 1efd102aef9..1efd102aef9 100644 --- a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.data-00000-of-00001 +++ b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.index Binary files differindex 56c60dbe529..56c60dbe529 100644 --- a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.index +++ b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.index diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/dropout.py b/model-integration/src/test/models/tensorflow/dropout/dropout.py index 42c15cd2812..42c15cd2812 100644 --- a/searchlib/src/test/files/integration/tensorflow/dropout/dropout.py +++ b/model-integration/src/test/models/tensorflow/dropout/dropout.py diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/dropout/saved/saved_model.pbtxt index ad431f0460d..ad431f0460d 100644 --- a/searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt +++ b/model-integration/src/test/models/tensorflow/dropout/saved/saved_model.pbtxt diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 Binary files differindex 000c9b3a7b5..000c9b3a7b5 100644 --- a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 +++ b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.index Binary files differindex 9492ef4bde2..9492ef4bde2 100644 --- a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index +++ b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.index diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/mnist/saved/saved_model.pbtxt index eb926836576..eb926836576 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt +++ b/model-integration/src/test/models/tensorflow/mnist/saved/saved_model.pbtxt diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 Binary files differindex a7ca01888c7..a7ca01888c7 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 +++ b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.index Binary files differindex 7989c109a3a..7989c109a3a 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index +++ b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.index diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py b/model-integration/src/test/models/tensorflow/mnist/simple_mnist.py index 86a17e81f8f..86a17e81f8f 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py +++ b/model-integration/src/test/models/tensorflow/mnist/simple_mnist.py diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/model-integration/src/test/models/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py index 07a9fa4a213..07a9fa4a213 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py +++ b/model-integration/src/test/models/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/saved_model.pbtxt index 8100dfd594d..8100dfd594d 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt +++ b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/saved_model.pbtxt diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 Binary files differindex 8474aa0a04c..8474aa0a04c 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 +++ b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.index Binary files differindex cfcdac20409..cfcdac20409 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index +++ b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.index diff --git a/model-integration/src/test/models/xgboost/xgboost.2.2.json b/model-integration/src/test/models/xgboost/xgboost.2.2.json new file mode 100644 index 00000000000..f8949b47e52 --- /dev/null +++ b/model-integration/src/test/models/xgboost/xgboost.2.2.json @@ -0,0 +1,19 @@ +[ + { "nodeid": 0, "depth": 0, "split": "f29", "split_condition": -0.1234567, "yes": 1, "no": 2, "missing": 1, "children": [ + { "nodeid": 1, "depth": 1, "split": "f56", "split_condition": -0.242398, "yes": 3, "no": 4, "missing": 3, "children": [ + { "nodeid": 3, "leaf": 1.71218 }, + { "nodeid": 4, "leaf": -1.70044 } + ]}, + { "nodeid": 2, "depth": 1, "split": "f109", "split_condition": 0.8723473, "yes": 5, "no": 6, "missing": 5, "children": [ + { "nodeid": 5, "leaf": -1.94071 }, + { "nodeid": 6, "leaf": 1.85965 } + ]} + ]}, + { "nodeid": 0, "depth": 0, "split": "f60", "split_condition": -0.482947, "yes": 1, "no": 2, "missing": 1, "children": [ + { "nodeid": 1, "depth": 1, "split": "f29", "split_condition": -4.2387498, "yes": 3, "no": 4, "missing": 3, "children": [ + { "nodeid": 3, "leaf": 0.784718 }, + { "nodeid": 4, "leaf": -0.96853 } + ]}, + { "nodeid": 2, "leaf": -6.23624 } + ]} +]
\ No newline at end of file diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/component/ConfigServerInfo.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/component/ConfigServerInfo.java index ec911cc5600..2f7eb25f824 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/component/ConfigServerInfo.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/component/ConfigServerInfo.java @@ -4,12 +4,10 @@ package com.yahoo.vespa.hosted.node.admin.component; import com.yahoo.vespa.athenz.api.AthenzService; import java.net.URI; -import java.util.ArrayList; +import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.function.Function; - -import static java.util.stream.Collectors.toMap; +import java.util.stream.Collectors; /** * Information necessary to e.g. establish communication with the config servers @@ -19,15 +17,19 @@ import static java.util.stream.Collectors.toMap; public class ConfigServerInfo { private final List<String> configServerHostNames; private final URI loadBalancerEndpoint; - private final Map<String, URI> configServerURIs; private final AthenzService configServerIdentity; + private final Function<String, URI> configServerHostnameToUriMapper; + private final List<URI> configServerURIs; public ConfigServerInfo(String loadBalancerHostName, List<String> configServerHostNames, String scheme, int port, AthenzService configServerAthenzIdentity) { this.configServerHostNames = configServerHostNames; - this.configServerURIs = createConfigServerUris(scheme, configServerHostNames, port); this.loadBalancerEndpoint = createLoadBalancerEndpoint(loadBalancerHostName, scheme, port); this.configServerIdentity = configServerAthenzIdentity; + this.configServerHostnameToUriMapper = hostname -> URI.create(scheme + "://" + hostname + ":" + port); + this.configServerURIs = configServerHostNames.stream() + .map(configServerHostnameToUriMapper) + .collect(Collectors.collectingAndThen(Collectors.toList(), Collections::unmodifiableList)); } private static URI createLoadBalancerEndpoint(String loadBalancerHost, String scheme, int port) { @@ -39,16 +41,11 @@ public class ConfigServerInfo { } public List<URI> getConfigServerUris() { - return new ArrayList<>(configServerURIs.values()); + return configServerURIs; } public URI getConfigServerUri(String hostname) { - URI uri = configServerURIs.get(hostname); - if (uri == null) { - throw new IllegalArgumentException("There is no config server '" + hostname + "'"); - } - - return uri; + return configServerHostnameToUriMapper.apply(hostname); } public URI getLoadBalancerEndpoint() { @@ -58,14 +55,4 @@ public class ConfigServerInfo { public AthenzService getConfigServerIdentity() { return configServerIdentity; } - - private static Map<String, URI> createConfigServerUris( - String scheme, - List<String> configServerHosts, - int port) { - return configServerHosts.stream().collect(toMap( - Function.identity(), - hostname -> URI.create(scheme + "://" + hostname + ":" + port))); - } - } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java index 9d9ec9bc9b3..2fd40a1b486 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java @@ -36,7 +36,7 @@ import java.util.regex.Pattern; import static com.yahoo.vespa.hosted.node.admin.task.util.file.FileFinder.nameMatches; import static com.yahoo.vespa.hosted.node.admin.task.util.file.FileFinder.olderThan; -import static com.yahoo.vespa.hosted.node.admin.task.util.file.IOExceptionUtil.uncheck; +import static com.yahoo.yolean.Exceptions.uncheck; import static com.yahoo.vespa.hosted.node.admin.util.SecretAgentCheckConfig.nodeTypeToRole; /** @@ -113,15 +113,10 @@ public class StorageMaintainer { } if (context.nodeType() == NodeType.config || context.nodeType() == NodeType.controller) { - // configserver - Path configServerCheckPath = context.pathInNodeUnderVespaHome("libexec/yms/yms_check_ymonsb2"); - configs.add(new SecretAgentCheckConfig(nodeTypeToRole(context.nodeType()), 60, configServerCheckPath, - "-zero", "configserver") - .withTags(tags)); - // configserver-new + // configserver/controller Path configServerNewCheckPath = Paths.get("/usr/bin/curl"); - configs.add(new SecretAgentCheckConfig(nodeTypeToRole(context.nodeType())+"-new", 60, configServerNewCheckPath, + configs.add(new SecretAgentCheckConfig(nodeTypeToRole(context.nodeType()), 60, configServerNewCheckPath, "-s", "localhost:19071/yamas-metrics") .withTags(tags)); diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/coredump/CoredumpHandler.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/coredump/CoredumpHandler.java index a9d61d20f5b..13a1311b86f 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/coredump/CoredumpHandler.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/coredump/CoredumpHandler.java @@ -23,7 +23,7 @@ import java.util.regex.Pattern; import static com.yahoo.vespa.hosted.node.admin.task.util.file.FileFinder.nameEndsWith; import static com.yahoo.vespa.hosted.node.admin.task.util.file.FileFinder.nameMatches; import static com.yahoo.vespa.hosted.node.admin.task.util.file.FileFinder.nameStartsWith; -import static com.yahoo.vespa.hosted.node.admin.task.util.file.IOExceptionUtil.uncheck; +import static com.yahoo.yolean.Exceptions.uncheck; /** * Finds coredumps, collects metadata and reports them diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java index 25ed3ec4b59..98975dddb56 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java @@ -138,11 +138,13 @@ public class NodeAgentImpl implements NodeAgent { this.healthChecker = healthChecker; this.loopThread = new Thread(() -> { - try { - while (!terminated.get()) tick(); - } catch (Throwable t) { - numberOfUnhandledException++; - context.log(logger, LogLevel.ERROR, "Unhandled throwable, ignoring", t); + while (!terminated.get()) { + try { + tick(); + } catch (Throwable t) { + numberOfUnhandledException++; + context.log(logger, LogLevel.ERROR, "Unhandled throwable, ignoring", t); + } } }); this.loopThread.setName("tick-" + context.hostname()); diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/Editor.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/Editor.java index 7dcae199858..bbc3427433b 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/Editor.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/Editor.java @@ -13,7 +13,7 @@ import java.util.function.Consumer; import java.util.function.Supplier; import java.util.logging.Logger; -import static com.yahoo.vespa.hosted.node.admin.task.util.file.IOExceptionUtil.uncheck; +import static com.yahoo.yolean.Exceptions.uncheck; /** * An editor meant to edit small line-based files like /etc/fstab. diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/MakeDirectory.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/MakeDirectory.java index 2a88387f8fc..5f72ed7e9b8 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/MakeDirectory.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/MakeDirectory.java @@ -49,7 +49,7 @@ public class MakeDirectory { } } else { if (createParents) { - // We'll skip logginer system modification here, as we'll log about the creation + // We'll skip logging system modification here, as we'll log about the creation // of the directory next. path.createParents(); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/Template.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/Template.java index 14fea240baa..cef35803e98 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/Template.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/Template.java @@ -9,6 +9,8 @@ import java.io.StringWriter; import java.nio.file.Files; import java.nio.file.Path; +import static com.yahoo.yolean.Exceptions.uncheck; + /** * Uses the Velocity engine to render a template, to and from both String and Path objects. * @@ -30,7 +32,7 @@ public class Template { } public static Template at(Path templatePath) { - return of(IOExceptionUtil.uncheck(() -> new String(Files.readAllBytes(templatePath)))); + return of(uncheck(() -> new String(Files.readAllBytes(templatePath)))); } public static Template of(String template) { diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/UnixPath.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/UnixPath.java index 1b927cfc682..f573051eca6 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/UnixPath.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/UnixPath.java @@ -26,7 +26,8 @@ import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; -import static com.yahoo.vespa.hosted.node.admin.task.util.file.IOExceptionUtil.uncheck; +import static com.yahoo.yolean.Exceptions.uncheck; +import static com.yahoo.yolean.Exceptions.uncheckAndIgnore; /** * Thin wrapper around java.nio.file.Path, especially nice for UNIX-specific features. @@ -117,7 +118,7 @@ public class UnixPath { } public Optional<FileAttributes> getAttributesIfExists() { - return IOExceptionUtil.ifExists(this::getAttributes); + return Optional.ofNullable(uncheckAndIgnore(this::getAttributes, NoSuchFileException.class)); } public UnixPath createNewFile() { diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ChildProcess2Impl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ChildProcess2Impl.java index 67020270a99..6c8b15b10a6 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ChildProcess2Impl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ChildProcess2Impl.java @@ -12,7 +12,7 @@ import java.time.Instant; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; -import static com.yahoo.vespa.hosted.node.admin.task.util.file.IOExceptionUtil.uncheck; +import static com.yahoo.yolean.Exceptions.uncheck; /** * @author hakonhall diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ProcessFactoryImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ProcessFactoryImpl.java index 1c7a60a13fc..78c25897e1d 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ProcessFactoryImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ProcessFactoryImpl.java @@ -16,7 +16,7 @@ import java.util.List; import java.util.Set; import java.util.logging.Logger; -import static com.yahoo.vespa.hosted.node.admin.task.util.file.IOExceptionUtil.uncheck; +import static com.yahoo.yolean.Exceptions.uncheck; /** * @author hakonhall diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ProcessStarterImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ProcessStarterImpl.java index 43105f43dc0..3d2caaa2ce9 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ProcessStarterImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ProcessStarterImpl.java @@ -6,7 +6,7 @@ import com.yahoo.log.LogLevel; import java.util.logging.Logger; -import static com.yahoo.vespa.hosted.node.admin.task.util.file.IOExceptionUtil.uncheck; +import static com.yahoo.yolean.Exceptions.uncheck; /** * @author hakonhall diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerTester.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerTester.java index e22606104f1..db8520c63a8 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerTester.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerTester.java @@ -34,7 +34,7 @@ import java.util.Optional; import java.util.function.Function; import java.util.logging.Logger; -import static com.yahoo.vespa.hosted.node.admin.task.util.file.IOExceptionUtil.uncheck; +import static com.yahoo.yolean.Exceptions.uncheck; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainerTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainerTest.java index cf5d29d70f1..9ea5c87511b 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainerTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainerTest.java @@ -34,7 +34,7 @@ import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; -import static com.yahoo.vespa.hosted.node.admin.task.util.file.IOExceptionUtil.uncheck; +import static com.yahoo.yolean.Exceptions.uncheck; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; @@ -117,8 +117,8 @@ public class StorageMaintainerTest { public void configserver() { Path path = executeAs(NodeType.config); - assertChecks(path, "athenz-certificate-expiry", "configserver", "configserver-new", - "host-life", "ntp", "system-coredumps-processing", "zkbackupage"); + assertChecks(path, "athenz-certificate-expiry", "configserver", "host-life", + "ntp", "system-coredumps-processing", "zkbackupage"); assertCheckEnds(path.resolve("configserver.yaml"), " tags:\n" + @@ -132,8 +132,8 @@ public class StorageMaintainerTest { public void controller() { Path path = executeAs(NodeType.controller); - assertChecks(path, "athenz-certificate-expiry", "controller", "controller-new", "host-life", - "ntp", "system-coredumps-processing", "vespa", "vespa-health", "zkbackupage"); + assertChecks(path, "athenz-certificate-expiry", "controller", "host-life", "ntp", + "system-coredumps-processing", "vespa", "vespa-health", "zkbackupage"); // Do not set namespace for vespa metrics. WHY? diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/coredump/CoredumpHandlerTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/coredump/CoredumpHandlerTest.java index 8d599660ace..7779a74ae03 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/coredump/CoredumpHandlerTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/coredump/CoredumpHandlerTest.java @@ -30,7 +30,7 @@ import java.util.Set; import java.util.function.Supplier; import java.util.stream.Collectors; -import static com.yahoo.vespa.hosted.node.admin.task.util.file.IOExceptionUtil.uncheck; +import static com.yahoo.yolean.Exceptions.uncheck; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.mockito.ArgumentMatchers.any; diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java index acf62ae91b9..398f3fceca8 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.provision; import com.google.common.collect.ImmutableList; @@ -52,18 +52,25 @@ public class NodeList { return new NodeList(nodes.stream().filter(node -> node.allocation().get().membership().cluster().type().equals(type)).collect(Collectors.toList())); } + /** Returns the subset of nodes that are in the given state */ + public NodeList in(Node.State state) { + return nodes.stream() + .filter(node -> node.state() == state) + .collect(collectingAndThen(Collectors.toList(), NodeList::new)); + } + /** Returns the subset of nodes owned by the given application */ public NodeList owner(ApplicationId application) { return nodes.stream() - .filter(node -> node.allocation().map(a -> a.owner().equals(application)).orElse(false)) - .collect(collectingAndThen(Collectors.toList(), NodeList::new)); + .filter(node -> node.allocation().map(a -> a.owner().equals(application)).orElse(false)) + .collect(collectingAndThen(Collectors.toList(), NodeList::new)); } /** Returns the subset of nodes matching the given node type */ public NodeList nodeType(NodeType nodeType) { return nodes.stream() - .filter(node -> node.type() == nodeType) - .collect(collectingAndThen(Collectors.toList(), NodeList::new)); + .filter(node -> node.type() == nodeType) + .collect(collectingAndThen(Collectors.toList(), NodeList::new)); } /** Returns the parent nodes of the given child nodes */ diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancer.java new file mode 100644 index 00000000000..effc5b1a41d --- /dev/null +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancer.java @@ -0,0 +1,76 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.provision.lb; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Ordering; +import com.yahoo.config.provision.HostName; + +import java.util.List; +import java.util.Objects; + +/** + * Represents a load balancer for an application. + * + * @author mpolden + */ +public class LoadBalancer { + + private final LoadBalancerId id; + private final HostName hostname; + private final List<Integer> ports; + private final List<Real> reals; + private final boolean inactive; + + public LoadBalancer(LoadBalancerId id, HostName hostname, List<Integer> ports, List<Real> reals, boolean inactive) { + this.id = Objects.requireNonNull(id, "id must be non-null"); + this.hostname = Objects.requireNonNull(hostname, "hostname must be non-null"); + this.ports = Ordering.natural().immutableSortedCopy(requirePorts(ports)); + this.reals = ImmutableList.copyOf(Objects.requireNonNull(reals, "targets must be non-null")); + this.inactive = inactive; + } + + /** An identifier for this load balancer. The ID is unique inside the zone */ + public LoadBalancerId id() { + return id; + } + + /** Fully-qualified domain name of this load balancer. This hostname can be used for query and feed */ + public HostName hostname() { + return hostname; + } + + /** Listening port(s) of this load balancer */ + public List<Integer> ports() { + return ports; + } + + /** Real servers behind this load balancer */ + public List<Real> reals() { + return reals; + } + + /** + * Returns whether this load balancer is inactive. Inactive load balancers cannot be reactivated, and are + * eventually deleted + */ + public boolean inactive() { + return inactive; + } + + /** Return a copy of this that is set inactive */ + public LoadBalancer deactivate() { + return new LoadBalancer(id, hostname, ports, reals, true); + } + + private static List<Integer> requirePorts(List<Integer> ports) { + Objects.requireNonNull(ports, "ports must be non-null"); + if (ports.isEmpty()) { + throw new IllegalArgumentException("ports must be non-empty"); + } + if (!ports.stream().allMatch(port -> port >= 1 && port <= 65535)) { + throw new IllegalArgumentException("all ports must be >= 1 and <= 65535"); + } + return ports; + } + +} diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerId.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerId.java new file mode 100644 index 00000000000..1431f21de47 --- /dev/null +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerId.java @@ -0,0 +1,65 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.provision.lb; + +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterSpec; + +import java.util.Objects; + +/** + * Uniquely identifies a load balancer for an application's container cluster. + * + * @author mpolden + */ +public class LoadBalancerId { + + private final ApplicationId application; + private final ClusterSpec.Id cluster; + private final String serializedForm; + + public LoadBalancerId(ApplicationId application, ClusterSpec.Id cluster) { + this.application = Objects.requireNonNull(application, "application must be non-null"); + this.cluster = Objects.requireNonNull(cluster, "cluster must be non-null"); + this.serializedForm = serializedForm(application, cluster); + } + + public ApplicationId application() { + return application; + } + + public ClusterSpec.Id cluster() { + return cluster; + } + + /** Serialized form of this */ + public String serializedForm() { + return serializedForm; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LoadBalancerId that = (LoadBalancerId) o; + return Objects.equals(application, that.application) && + Objects.equals(cluster, that.cluster); + } + + @Override + public int hashCode() { + return Objects.hash(application, cluster); + } + + /** Create an instance from a serialized value on the form tenant:application:instance:cluster-id */ + public static LoadBalancerId fromSerializedForm(String value) { + int lastSeparator = value.lastIndexOf(":"); + ApplicationId application = ApplicationId.fromSerializedForm(value.substring(0, lastSeparator)); + ClusterSpec.Id cluster = ClusterSpec.Id.from(value.substring(lastSeparator + 1)); + return new LoadBalancerId(application, cluster); + } + + private static String serializedForm(ApplicationId application, ClusterSpec.Id cluster) { + return application.serializedForm() + ":" + cluster.value(); + } + +} diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerService.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerService.java new file mode 100644 index 00000000000..b5f59414c65 --- /dev/null +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerService.java @@ -0,0 +1,32 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.provision.lb; + +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterSpec; + +import java.util.List; + +/** + * A managed load balance service. + * + * @author mpolden + */ +public interface LoadBalancerService { + + /** Create a load balancer for given application cluster. Implementations are expected to be idempotent */ + LoadBalancer create(ApplicationId application, ClusterSpec.Id cluster, List<Real> reals); + + /** Permanently remove load balancer with given ID */ + void remove(LoadBalancerId loadBalancer); + + /** Returns the protocol supported by this load balancer service */ + Protocol protocol(); + + /** Load balancer protocols */ + enum Protocol { + ipv4, + ipv6, + dualstack + } + +} diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceProvider.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceProvider.java new file mode 100644 index 00000000000..7f37798b977 --- /dev/null +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceProvider.java @@ -0,0 +1,45 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.provision.lb; + +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.container.di.componentgraph.Provider; + +import java.util.List; + +/** + * A provider for a {@link LoadBalancerService}. This provides a default instance for cases where a component has not + * been explicitly configured. + * + * @author mpolden + */ +public class LoadBalancerServiceProvider implements Provider<LoadBalancerService> { + + private static final LoadBalancerService instance = new LoadBalancerService() { + + @Override + public LoadBalancer create(ApplicationId application, ClusterSpec.Id cluster, List<Real> reals) { + throw new UnsupportedOperationException(); + } + + @Override + public void remove(LoadBalancerId loadBalancer) { + throw new UnsupportedOperationException(); + } + + @Override + public Protocol protocol() { + throw new UnsupportedOperationException(); + } + + }; + + @Override + public LoadBalancerService get() { + return instance; + } + + @Override + public void deconstruct() {} + +} diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/Real.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/Real.java new file mode 100644 index 00000000000..784d58f103e --- /dev/null +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/Real.java @@ -0,0 +1,80 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.provision.lb; + +import com.google.common.net.InetAddresses; +import com.yahoo.config.provision.HostName; + +import java.util.Objects; + +/** + * Represents a server behind a load balancer. + * + * @author mpolden + */ +public class Real { + + private static int defaultPort = 4443; + + private final HostName hostname; + private final String ipAddress; + private final int port; + + public Real(HostName hostname, String ipAddress) { + this(hostname, ipAddress, defaultPort); + } + + public Real(HostName hostname, String ipAddress, int port) { + this.hostname = hostname; + this.ipAddress = requireIpAddress(ipAddress); + if (port < 1 || port > 65535) { + throw new IllegalArgumentException("port number must be >= 1 and <= 65535"); + } + this.port = port; + } + + /** The hostname of this real */ + public HostName hostname() { + return hostname; + } + + /** Target IP address for this real */ + public String ipAddress() { + return ipAddress; + } + + /** Target port for this real */ + public int port() { + return port; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Real real = (Real) o; + return port == real.port && + Objects.equals(hostname, real.hostname) && + Objects.equals(ipAddress, real.ipAddress); + } + + @Override + public int hashCode() { + return Objects.hash(hostname, ipAddress, port); + } + + @Override + public String toString() { + return "real server " + hostname + " (" + ipAddress + ":" + port + ")"; + } + + private static String requireIpAddress(String ipAddress) { + Objects.requireNonNull(ipAddress, "ipAddress must be non-null"); + try { + InetAddresses.forString(ipAddress); + return ipAddress; + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("ipAddress must be a valid IP address", e); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/package-info.java index 1530754cc43..cdcf58b81e3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/package-info.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. /** - * ONNX integration + * @author mpolden */ @ExportPackage -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.vespa.hosted.provision.lb; import com.yahoo.osgi.annotation.ExportPackage; diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDatabaseClient.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDatabaseClient.java index 632244cec69..d824f9fa53b 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDatabaseClient.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDatabaseClient.java @@ -16,6 +16,8 @@ import com.yahoo.vespa.curator.Lock; import com.yahoo.vespa.curator.transaction.CuratorOperations; import com.yahoo.vespa.curator.transaction.CuratorTransaction; import com.yahoo.vespa.hosted.provision.Node; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancer; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancerId; import com.yahoo.vespa.hosted.provision.node.Agent; import com.yahoo.vespa.hosted.provision.node.Status; @@ -34,6 +36,9 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; +import static java.util.stream.Collectors.collectingAndThen; +import static java.util.stream.Collectors.toMap; + /** * Client which reads and writes nodes to a curator database. * Nodes are stored in files named <code>/provision/v1/[nodestate]/[hostname]</code>. @@ -49,6 +54,7 @@ public class CuratorDatabaseClient { private static final Path root = Path.fromString("/provision/v1"); private static final Path lockRoot = root.append("locks"); + private static final Path loadBalancersRoot = root.append("loadBalancers"); private static final Duration defaultLockTimeout = Duration.ofMinutes(2); private final NodeSerializer nodeSerializer; @@ -76,6 +82,7 @@ public class CuratorDatabaseClient { curatorDatabase.create(inactiveJobsPath()); curatorDatabase.create(infrastructureVersionsPath()); curatorDatabase.create(osVersionsPath()); + curatorDatabase.create(loadBalancersRoot); } /** @@ -400,4 +407,50 @@ public class CuratorDatabaseClient { private Path osVersionsPath() { return root.append("osVersions"); } + + public Map<LoadBalancerId, LoadBalancer> readLoadBalancers() { + return curatorDatabase.getChildren(loadBalancersRoot).stream() + .map(LoadBalancerId::fromSerializedForm) + .map(this::readLoadBalancer) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(collectingAndThen(toMap(LoadBalancer::id, Function.identity()), + Collections::unmodifiableMap)); + } + + public List<LoadBalancer> readLoadBalancers(ApplicationId application) { + return readLoadBalancers().values().stream() + .filter(lb -> lb.id().application().equals(application)) + .collect(collectingAndThen(Collectors.toList(), Collections::unmodifiableList)); + } + + public Optional<LoadBalancer> readLoadBalancer(LoadBalancerId id) { + return read(loadBalancerPath(id), LoadBalancerSerializer::fromJson); + } + + public void writeLoadBalancer(LoadBalancer loadBalancer) { + Path path = loadBalancerPath(loadBalancer.id()); + curatorDatabase.create(path); + NestedTransaction transaction = new NestedTransaction(); + CuratorTransaction curatorTransaction = curatorDatabase.newCuratorTransactionIn(transaction); + curatorTransaction.add(CuratorOperations.setData(path.getAbsolute(), + LoadBalancerSerializer.toJson(loadBalancer))); + transaction.commit(); + } + + public void removeLoadBalancer(LoadBalancer loadBalancer) { + NestedTransaction transaction = new NestedTransaction(); + CuratorTransaction curatorTransaction = curatorDatabase.newCuratorTransactionIn(transaction); + curatorTransaction.add(CuratorOperations.delete(loadBalancerPath(loadBalancer.id()).getAbsolute())); + transaction.commit(); + } + + public Lock lockLoadBalancers() { + return lock(lockRoot.append("loadBalancersLock"), defaultLockTimeout); + } + + private Path loadBalancerPath(LoadBalancerId id) { + return loadBalancersRoot.append(id.serializedForm()); + } + } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializer.java new file mode 100644 index 00000000000..ba29fcf2920 --- /dev/null +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializer.java @@ -0,0 +1,78 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.provision.persistence; + +import com.yahoo.config.provision.HostName; +import com.yahoo.slime.ArrayTraverser; +import com.yahoo.slime.Cursor; +import com.yahoo.slime.Slime; +import com.yahoo.vespa.config.SlimeUtils; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancer; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancerId; +import com.yahoo.vespa.hosted.provision.lb.Real; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Serializer for load balancers. + * + * @author mpolden + */ +public class LoadBalancerSerializer { + + private static final String idField = "id"; + private static final String hostnameField = "hostname"; + private static final String inactiveField = "inactive"; + private static final String portsField = "ports"; + private static final String realsField = "reals"; + private static final String ipAddressField = "ipAddress"; + private static final String portField = "port"; + + public static byte[] toJson(LoadBalancer loadBalancer) { + Slime slime = new Slime(); + Cursor root = slime.setObject(); + + root.setString(idField, loadBalancer.id().serializedForm()); + root.setString(hostnameField, loadBalancer.hostname().toString()); + Cursor portArray = root.setArray(portsField); + loadBalancer.ports().forEach(portArray::addLong); + Cursor realArray = root.setArray(realsField); + loadBalancer.reals().forEach(real -> { + Cursor realObject = realArray.addObject(); + realObject.setString(hostnameField, real.hostname().value()); + realObject.setString(ipAddressField, real.ipAddress()); + realObject.setLong(portField, real.port()); + }); + root.setBool(inactiveField, loadBalancer.inactive()); + + try { + return SlimeUtils.toJsonBytes(slime); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public static LoadBalancer fromJson(byte[] data) { + Cursor object = SlimeUtils.jsonToSlime(data).get(); + + List<Real> reals = new ArrayList<>(); + object.field(realsField).traverse((ArrayTraverser) (i, realObject) -> { + reals.add(new Real(HostName.from(realObject.field(hostnameField).asString()), + realObject.field(ipAddressField).asString(), + (int) realObject.field(portField).asLong())); + + }); + + List<Integer> ports = new ArrayList<>(); + object.field(portsField).traverse((ArrayTraverser) (i, port) -> ports.add((int) port.asLong())); + + return new LoadBalancer(LoadBalancerId.fromSerializedForm(object.field(idField).asString()), + HostName.from(object.field(hostnameField).asString()), + ports, + reals, + object.field(inactiveField).asBool()); + } + +} diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java new file mode 100644 index 00000000000..cc2bdad70c2 --- /dev/null +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java @@ -0,0 +1,125 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.provision.provisioning; + +import com.google.common.net.InetAddresses; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.config.provision.HostName; +import com.yahoo.transaction.Mutex; +import com.yahoo.vespa.hosted.provision.Node; +import com.yahoo.vespa.hosted.provision.NodeList; +import com.yahoo.vespa.hosted.provision.NodeRepository; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancer; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancerId; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancerService; +import com.yahoo.vespa.hosted.provision.lb.Real; +import com.yahoo.vespa.hosted.provision.persistence.CuratorDatabaseClient; + +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Provides provisioning of load balancers for applications. + * + * @author mpolden + */ +public class LoadBalancerProvisioner { + + private final NodeRepository nodeRepository; + private final CuratorDatabaseClient db; + private final LoadBalancerService service; + + public LoadBalancerProvisioner(NodeRepository nodeRepository, LoadBalancerService service) { + this.nodeRepository = nodeRepository; + this.db = nodeRepository.database(); + this.service = service; + } + + /** + * Provision load balancer(s) for given application. + * + * If the application has multiple container clusters, one load balancer will be provisioned for each cluster. + */ + public Map<LoadBalancerId, LoadBalancer> provision(ApplicationId application) { + try (Mutex applicationLock = nodeRepository.lock(application)) { + try (Mutex loadBalancersLock = db.lockLoadBalancers()) { + Map<LoadBalancerId, LoadBalancer> loadBalancers = new LinkedHashMap<>(); + for (Map.Entry<ClusterSpec.Id, List<Node>> kv : activeContainers(application).entrySet()) { + LoadBalancer loadBalancer = create(application, kv.getKey(), kv.getValue()); + loadBalancers.put(loadBalancer.id(), loadBalancer); + db.writeLoadBalancer(loadBalancer); + } + return Collections.unmodifiableMap(loadBalancers); + } + } + } + + /** Deactivate all load balancers assigned to given application */ + public void deactivate(ApplicationId application) { + try (Mutex applicationLock = nodeRepository.lock(application)) { + try (Mutex loadBalancersLock = db.lockLoadBalancers()) { + if (!activeContainers(application).isEmpty()) { + throw new IllegalArgumentException(application + " has active containers, refusing to deactivate load balancers"); + } + db.readLoadBalancers(application) + .stream() + .map(LoadBalancer::deactivate) + .forEach(db::writeLoadBalancer); + } + } + } + + private LoadBalancer create(ApplicationId application, ClusterSpec.Id cluster, List<Node> nodes) { + Map<HostName, Set<String>> hostnameToIpAdresses = nodes.stream() + .collect(Collectors.toMap(node -> HostName.from(node.hostname()), + this::reachableIpAddresses)); + List<Real> reals = new ArrayList<>(); + hostnameToIpAdresses.forEach((hostname, ipAddresses) -> { + ipAddresses.forEach(ipAddress -> reals.add(new Real(hostname, ipAddress))); + }); + return service.create(application, cluster, reals); + } + + /** Returns a list of active containers for given application, grouped by cluster ID */ + private Map<ClusterSpec.Id, List<Node>> activeContainers(ApplicationId application) { + return new NodeList(nodeRepository.getNodes()) + .owner(application) + .in(Node.State.active) + .type(ClusterSpec.Type.container) + .asList() + .stream() + .collect(Collectors.groupingBy(n -> n.allocation().get().membership().cluster().id())); + } + + /** Find IP addresses reachable by the load balancer service */ + private Set<String> reachableIpAddresses(Node node) { + Set<String> reachable = new LinkedHashSet<>(node.ipAddresses()); + // Remove addresses unreachable by the load balancer service + switch (service.protocol()) { + case ipv4: + reachable.removeIf(this::isIpv6); + break; + case ipv6: + reachable.removeIf(this::isIpv4); + break; + } + return reachable; + } + + private boolean isIpv4(String ipAddress) { + return InetAddresses.forString(ipAddress) instanceof Inet4Address; + } + + private boolean isIpv6(String ipAddress) { + return InetAddresses.forString(ipAddress) instanceof Inet6Address; + } + +} diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/LoadBalancerServiceMock.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/LoadBalancerServiceMock.java new file mode 100644 index 00000000000..c4e595a3fcf --- /dev/null +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/LoadBalancerServiceMock.java @@ -0,0 +1,46 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.provision.testutils; + +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.config.provision.HostName; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancer; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancerId; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancerService; +import com.yahoo.vespa.hosted.provision.lb.Real; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * @author mpolden + */ +public class LoadBalancerServiceMock implements LoadBalancerService { + + private final Map<LoadBalancerId, LoadBalancer> loadBalancers = new HashMap<>(); + + @Override + public Protocol protocol() { + return Protocol.ipv4; + } + + @Override + public LoadBalancer create(ApplicationId application, ClusterSpec.Id cluster, List<Real> reals) { + LoadBalancer loadBalancer = new LoadBalancer( + new LoadBalancerId(application, cluster), + HostName.from("lb-" + application.toShortString() + "-" + cluster.value()), + Collections.singletonList(4443), + reals, + false); + loadBalancers.put(loadBalancer.id(), loadBalancer); + return loadBalancer; + } + + @Override + public void remove(LoadBalancerId loadBalancer) { + loadBalancers.remove(loadBalancer); + } + +} diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializerTest.java new file mode 100644 index 00000000000..5fcbab64429 --- /dev/null +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializerTest.java @@ -0,0 +1,45 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.provision.persistence; + +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.config.provision.HostName; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancer; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancerId; +import com.yahoo.vespa.hosted.provision.lb.Real; +import org.junit.Test; + +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; + +/** + * @author mpolden + */ +public class LoadBalancerSerializerTest { + + @Test + public void test_serialization() { + LoadBalancer loadBalancer = new LoadBalancer(new LoadBalancerId(ApplicationId.from("tenant1", + "application1", + "default"), + ClusterSpec.Id.from("qrs")), + HostName.from("lb-host"), + Arrays.asList(4080, 4443), + Arrays.asList(new Real(HostName.from("real-1"), + "127.0.0.1", + 4080), + new Real(HostName.from("real-2"), + "127.0.0.2", + 4080)), + false); + + LoadBalancer serialized = LoadBalancerSerializer.fromJson(LoadBalancerSerializer.toJson(loadBalancer)); + assertEquals(loadBalancer.id(), serialized.id()); + assertEquals(loadBalancer.hostname(), serialized.hostname()); + assertEquals(loadBalancer.ports(), serialized.ports()); + assertEquals(loadBalancer.inactive(), serialized.inactive()); + assertEquals(loadBalancer.reals(), serialized.reals()); + } + +} diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisionerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisionerTest.java new file mode 100644 index 00000000000..ab7fda20889 --- /dev/null +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisionerTest.java @@ -0,0 +1,150 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.provision.provisioning; + +import com.yahoo.component.Version; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.config.provision.HostName; +import com.yahoo.config.provision.HostSpec; +import com.yahoo.config.provision.Zone; +import com.yahoo.transaction.NestedTransaction; +import com.yahoo.vespa.hosted.provision.Node; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancer; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancerId; +import com.yahoo.vespa.hosted.provision.lb.LoadBalancerService; +import com.yahoo.vespa.hosted.provision.lb.Real; +import com.yahoo.vespa.hosted.provision.node.Agent; +import org.junit.Before; +import org.junit.Test; + +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * @author mpolden + */ +public class LoadBalancerProvisionerTest { + + private final ApplicationId app1 = ApplicationId.from("tenant1", "application1", "default"); + private final ApplicationId app2 = ApplicationId.from("tenant2", "application2", "default"); + + private ProvisioningTester tester; + private LoadBalancerService service; + private LoadBalancerProvisioner loadBalancerProvisioner; + + @Before + public void before() { + tester = new ProvisioningTester(Zone.defaultZone()); + service = tester.loadBalancerService(); + loadBalancerProvisioner = new LoadBalancerProvisioner(tester.nodeRepository(), service); + } + + @Test + public void provision_load_balancer() { + ClusterSpec.Id containerCluster1 = ClusterSpec.Id.from("qrs1"); + ClusterSpec.Id contentCluster = ClusterSpec.Id.from("content"); + tester.activate(app1, prepare(app1, + clusterRequest(ClusterSpec.Type.container, containerCluster1), + clusterRequest(ClusterSpec.Type.content, contentCluster))); + tester.activate(app2, prepare(app2, + clusterRequest(ClusterSpec.Type.container, ClusterSpec.Id.from("qrs")))); + + // Provision a load balancer for each application + Map<LoadBalancerId, LoadBalancer> loadBalancers = loadBalancerProvisioner.provision(app1); + loadBalancerProvisioner.provision(app2); + assertEquals(1, loadBalancers.size()); + + LoadBalancer loadBalancer = loadBalancers.values().iterator().next(); + assertEquals(loadBalancer.id().application(), app1); + assertEquals(loadBalancer.id().cluster(), containerCluster1); + assertEquals(loadBalancer.ports(), Collections.singletonList(4443)); + assertEquals(loadBalancer.reals().get(0).ipAddress(), "127.0.0.1"); + assertEquals(loadBalancer.reals().get(0).port(), 4443); + assertEquals(loadBalancer.reals().get(1).ipAddress(), "127.0.0.2"); + assertEquals(loadBalancer.reals().get(1).port(), 4443); + + // A container is failed + List<Node> containers = tester.getNodes(app1).type(ClusterSpec.Type.container).asList(); + Node container1 = containers.get(0); + Node container2 = containers.get(1); + tester.nodeRepository().fail(container1.hostname(), Agent.system, "Failed by unit test"); + + // Reprovisioning load balancer removes failed container + loadBalancer = loadBalancerProvisioner.provision(app1).values().iterator().next(); + assertEquals(1, loadBalancer.reals().size()); + assertEquals(container2.hostname(), loadBalancer.reals().get(0).hostname().value()); + + // Redeploying replaces failed node + tester.activate(app1, prepare(app1, + clusterRequest(ClusterSpec.Type.container, containerCluster1), + clusterRequest(ClusterSpec.Type.content, contentCluster))); + + // Reprovisioning load balancer adds the new node + Node container3 = tester.getNodes(app1).type(ClusterSpec.Type.container).asList().get(1); + loadBalancer = loadBalancerProvisioner.provision(app1).values().iterator().next(); + assertEquals(2, loadBalancer.reals().size()); + assertEquals(container3.hostname(), loadBalancer.reals().get(1).hostname().value()); + + // Add another container cluster + ClusterSpec.Id containerCluster2 = ClusterSpec.Id.from("qrs2"); + tester.activate(app1, prepare(app1, + clusterRequest(ClusterSpec.Type.container, containerCluster1), + clusterRequest(ClusterSpec.Type.container, containerCluster2), + clusterRequest(ClusterSpec.Type.content, contentCluster))); + + // Load balancer is provisioned for second container cluster + loadBalancers = loadBalancerProvisioner.provision(app1); + assertEquals(2, loadBalancers.size()); + List<HostName> activeContainers = tester.getNodes(app1, Node.State.active) + .type(ClusterSpec.Type.container).asList() + .stream() + .map(Node::hostname) + .map(HostName::from) + .sorted() + .collect(Collectors.toList()); + List<HostName> reals = loadBalancers.values().stream() + .flatMap(lb -> lb.reals().stream()) + .map(Real::hostname) + .sorted() + .collect(Collectors.toList()); + assertEquals(activeContainers, reals); + + // Removing load balancer with active containers fails + try { + loadBalancerProvisioner.deactivate(app1); + fail("Expected exception"); + } catch (IllegalArgumentException ignored) {} + + // Application and load balancer is removed + NestedTransaction removeTransaction = new NestedTransaction(); + tester.provisioner().remove(removeTransaction, app1); + removeTransaction.commit(); + + loadBalancerProvisioner.deactivate(app1); + List<LoadBalancer> assignedLoadBalancer = tester.nodeRepository().database().readLoadBalancers(app1); + assertEquals(2, loadBalancers.size()); + assertTrue("Load balancers marked for deletion", assignedLoadBalancer.stream().allMatch(LoadBalancer::inactive)); + } + + private ClusterSpec clusterRequest(ClusterSpec.Type type, ClusterSpec.Id id) { + return ClusterSpec.request(type, id, Version.fromString("6.42"), false); + } + + private Set<HostSpec> prepare(ApplicationId application, ClusterSpec... specs) { + tester.makeReadyNodes(specs.length * 2, "default"); + Set<HostSpec> allNodes = new LinkedHashSet<>(); + for (ClusterSpec spec : specs) { + allNodes.addAll(tester.prepare(application, spec, 2, 1, "default")); + } + return allNodes; + } + +} diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java index 81414c0ac2d..8c0937d34b8 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java @@ -28,6 +28,7 @@ import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.node.Agent; import com.yahoo.vespa.hosted.provision.node.filter.NodeHostFilter; import com.yahoo.vespa.hosted.provision.persistence.NameResolver; +import com.yahoo.vespa.hosted.provision.testutils.LoadBalancerServiceMock; import com.yahoo.vespa.hosted.provision.testutils.MockNameResolver; import com.yahoo.vespa.orchestrator.Orchestrator; import com.yahoo.vespa.service.monitor.application.ConfigServerApplication; @@ -68,6 +69,7 @@ public class ProvisioningTester { private final NodeRepositoryProvisioner provisioner; private final CapacityPolicies capacityPolicies; private final ProvisionLogger provisionLogger; + private final LoadBalancerServiceMock loadBalancerService; private int nextHost = 0; private int nextIP = 0; @@ -90,6 +92,7 @@ public class ProvisioningTester { this.orchestrator = mock(Orchestrator.class); doThrow(new RuntimeException()).when(orchestrator).acquirePermissionToRemove(any()); this.provisioner = new NodeRepositoryProvisioner(nodeRepository, nodeFlavors, zone); + this.loadBalancerService = new LoadBalancerServiceMock(); this.capacityPolicies = new CapacityPolicies(zone, nodeFlavors); this.provisionLogger = new NullProvisionLogger(); } @@ -126,6 +129,7 @@ public class ProvisioningTester { public Orchestrator orchestrator() { return orchestrator; } public ManualClock clock() { return clock; } public NodeRepositoryProvisioner provisioner() { return provisioner; } + public LoadBalancerServiceMock loadBalancerService() { return loadBalancerService; } public CapacityPolicies capacityPolicies() { return capacityPolicies; } public NodeList getNodes(ApplicationId id, Node.State ... inState) { return new NodeList(nodeRepository.getNodes(id, inState)); } @@ -95,6 +95,7 @@ <module>messagebus</module> <module>metrics</module> <module>model-evaluation</module> + <module>model-integration</module> <module>node-repository</module> <module>node-admin</module> <module>node-maintainer</module> diff --git a/searchcore/src/tests/fdispatch/randomrow/randomrow_test.cpp b/searchcore/src/tests/fdispatch/randomrow/randomrow_test.cpp index 217e994980b..41db91035ed 100644 --- a/searchcore/src/tests/fdispatch/randomrow/randomrow_test.cpp +++ b/searchcore/src/tests/fdispatch/randomrow/randomrow_test.cpp @@ -16,18 +16,25 @@ TEST("requireThatEmpyStateReturnsRowZero") TEST("requireThatDecayWorks") { + constexpr double SMALL = 0.000001; StateOfRows s(1, 1.0, 1000); s.updateSearchTime(1.0, 0); EXPECT_EQUAL(1.0, s.getRowState(0).getAverageSearchTime()); s.updateSearchTime(2.0, 0); - EXPECT_EQUAL(1.001, s.getRowState(0).getAverageSearchTime()); + EXPECT_APPROX(1.5, s.getRowState(0).getAverageSearchTime(), SMALL); s.updateSearchTime(2.0, 0); - EXPECT_APPROX(1.002, s.getRowState(0).getAverageSearchTime(), 0.0001); + EXPECT_APPROX(1.666667, s.getRowState(0).getAverageSearchTime(), SMALL); s.updateSearchTime(0.1, 0); s.updateSearchTime(0.1, 0); s.updateSearchTime(0.1, 0); s.updateSearchTime(0.1, 0); - EXPECT_APPROX(0.998396, s.getRowState(0).getAverageSearchTime(), 0.000001); + EXPECT_APPROX(0.771429, s.getRowState(0).getAverageSearchTime(), SMALL); + for (size_t i(0); i < 10000; i++) { + s.updateSearchTime(1.0, 0); + } + EXPECT_APPROX(1.0, s.getRowState(0).getAverageSearchTime(), SMALL); + s.updateSearchTime(0.1, 0); + EXPECT_APPROX(0.9991, s.getRowState(0).getAverageSearchTime(), SMALL); } TEST("requireWeightedSelectionWorks") diff --git a/searchcore/src/vespa/searchcore/config/partitions.def b/searchcore/src/vespa/searchcore/config/partitions.def index 5213d4ef72e..0b0a434b219 100644 --- a/searchcore/src/vespa/searchcore/config/partitions.def +++ b/searchcore/src/vespa/searchcore/config/partitions.def @@ -53,8 +53,8 @@ dataset[].mpp int default=1 ## for queries when using the FIXEDROW query distribution. dataset[].maxnodesdownperfixedrow int default=0 -## Use simple roundrobin or random. -dataset[].useroundrobinforfixedrow bool default=true +## Use simple roundrobin or adaptive based on latency. +dataset[].useroundrobinforfixedrow bool default=false ## specifies where a fdispatch or fsearch process can be contacted. ## must be in the format hostname:port/id where /id is optional. @@ -206,4 +206,4 @@ dataset[].min_activedocs_coverage double default=97.0 ## Decay rate used when loadbalancing between groups. ## Lower number will react faster to changes in cluster. -dataset[].latency_decay_rate double default=10000 +dataset[].latency_decay_rate double default=1000 diff --git a/searchcore/src/vespa/searchcore/fdispatch/program/engineadapter.h b/searchcore/src/vespa/searchcore/fdispatch/program/engineadapter.h index 30fe83009f6..add5f045d51 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/program/engineadapter.h +++ b/searchcore/src/vespa/searchcore/fdispatch/program/engineadapter.h @@ -35,12 +35,11 @@ public: typedef search::engine::DocsumClient DocsumClient; typedef search::engine::MonitorClient MonitorClient; - EngineAdapter(FastS_AppContext *appCtx, - FastOS_ThreadPool *threadPool); + EngineAdapter(FastS_AppContext *appCtx, FastOS_ThreadPool *threadPool); - virtual SearchReply::UP search(SearchRequest::Source request, SearchClient &client) override; - virtual DocsumReply::UP getDocsums(DocsumRequest::Source request, DocsumClient &client) override; - virtual MonitorReply::UP ping(MonitorRequest::UP request, MonitorClient &client) override; + SearchReply::UP search(SearchRequest::Source request, SearchClient &client) override; + DocsumReply::UP getDocsums(DocsumRequest::Source request, DocsumClient &client) override; + MonitorReply::UP ping(MonitorRequest::UP request, MonitorClient &client) override; }; } // namespace fdispatch diff --git a/searchcore/src/vespa/searchcore/fdispatch/search/plain_dataset.cpp b/searchcore/src/vespa/searchcore/fdispatch/search/plain_dataset.cpp index b48f1abfb51..4e1bc62b790 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/search/plain_dataset.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/search/plain_dataset.cpp @@ -126,8 +126,7 @@ FastS_PartitionMap::LinkIn(FastS_EngineBase *engine) _childmaxnodesSinceReload = std::max(_childmaxnodesSinceReload, _childmaxnodesNow); _childnodes += engine->_reported._actNodes; if (part._maxpartsNow <= engine->_reported._maxParts) { - _childmaxpartsNow += engine->_reported._maxParts - - part._maxpartsNow; + _childmaxpartsNow += engine->_reported._maxParts - part._maxpartsNow; _childmaxpartsSinceReload += std::max(_childmaxpartsSinceReload, _childmaxpartsNow); part._maxpartsNow = engine->_reported._maxParts; } @@ -196,7 +195,7 @@ FastS_PlainDataSet::FastS_PlainDataSet(FastS_AppContext *appCtx, FastS_DataSetDesc *desc) : FastS_DataSetBase(appCtx, desc), _partMap(desc), - _stateOfRows(_partMap.getNumRows(), 1.0, desc->GetQueryDistributionMode().getLatencyDecayRate()), + _stateOfRows(_partMap.getNumRows(), 0.010, desc->GetQueryDistributionMode().getLatencyDecayRate()), _MHPN_log(), _slowQueryLimitFactor(desc->GetSlowQueryLimitFactor()), _slowQueryLimitBias(desc->GetSlowQueryLimitBias()), diff --git a/searchcore/src/vespa/searchcore/fdispatch/search/rowstate.cpp b/searchcore/src/vespa/searchcore/fdispatch/search/rowstate.cpp index c4b0319e6cb..b0ca9f3463b 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/search/rowstate.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/search/rowstate.cpp @@ -6,10 +6,12 @@ namespace fdispatch { void RowState::updateSearchTime(double searchTime) { - _avgSearchTime = (searchTime + (_decayRate-1)*_avgSearchTime)/_decayRate; + _numQueries++; + double decayRate = std::min(_numQueries, _decayRate); + _avgSearchTime = (searchTime + (decayRate-1)*_avgSearchTime)/decayRate; } -StateOfRows::StateOfRows(size_t numRows, double initialValue, double decayRate) : +StateOfRows::StateOfRows(size_t numRows, double initialValue, uint64_t decayRate) : _rows(numRows, RowState(initialValue, decayRate)), _sumActiveDocs(0), _invalidActiveDocsCounter(0) { diff --git a/searchcore/src/vespa/searchcore/fdispatch/search/rowstate.h b/searchcore/src/vespa/searchcore/fdispatch/search/rowstate.h index 07bc769cfdd..00bf7f1bd91 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/search/rowstate.h +++ b/searchcore/src/vespa/searchcore/fdispatch/search/rowstate.h @@ -15,10 +15,11 @@ namespace fdispatch { **/ class RowState { public: - RowState(double initialValue, double decayRate) : + RowState(double initialValue, uint64_t decayRate) : + _decayRate(std::max(1ul, decayRate)), _avgSearchTime(initialValue), - _decayRate(decayRate), - _sumActiveDocs(0) + _sumActiveDocs(0), + _numQueries(0) { } double getAverageSearchTime() const { return _avgSearchTime; } double getAverageSearchTimeInverse() const { return 1.0/_avgSearchTime; } @@ -30,9 +31,10 @@ public: _sumActiveDocs = tmp; } private: - double _avgSearchTime; - double _decayRate; + const uint64_t _decayRate; + double _avgSearchTime; uint64_t _sumActiveDocs; + uint64_t _numQueries; }; /** @@ -43,7 +45,7 @@ private: **/ class StateOfRows { public: - StateOfRows(size_t numRows, double initial, double decayRate); + StateOfRows(size_t numRows, double initial, uint64_t decayRate); void updateSearchTime(double searchTime, uint32_t rowId); const RowState & getRowState(uint32_t rowId) const { return _rows[rowId]; } RowState & getRowState(uint32_t rowId) { return _rows[rowId]; } @@ -56,8 +58,8 @@ public: bool activeDocsValid() const { return _invalidActiveDocsCounter == 0; } private: std::vector<RowState> _rows; - uint64_t _sumActiveDocs; - size_t _invalidActiveDocsCounter; + uint64_t _sumActiveDocs; + size_t _invalidActiveDocsCounter; }; } diff --git a/searchlib/pom.xml b/searchlib/pom.xml index f1be4e96269..2dec6c92619 100644 --- a/searchlib/pom.xml +++ b/searchlib/pom.xml @@ -33,18 +33,13 @@ <groupId>com.yahoo.vespa</groupId> <artifactId>vespajlib</artifactId> <version>${project.version}</version> + <scope>provided</scope> </dependency> <dependency> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </dependency> - <dependency> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </dependency> - <dependency> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> + <groupId>com.yahoo.vespa</groupId> + <artifactId>config-model-api</artifactId> + <version>${project.version}</version> + <scope>provided</scope> </dependency> <dependency> <groupId>com.fasterxml.jackson.core</groupId> @@ -117,26 +112,6 @@ </execution> </executions> </plugin> - <plugin> - <groupId>com.github.os72</groupId> - <artifactId>protoc-jar-maven-plugin</artifactId> - <version>3.5.1.1</version> - <executions> - <execution> - <phase>generate-sources</phase> - <goals> - <goal>run</goal> - </goals> - <configuration> - <addSources>main</addSources> - <outputDirectory>${project.build.directory}/generated-sources/protobuf</outputDirectory> - <inputDirectories> - <include>src/main/protobuf</include> - </inputDirectories> - </configuration> - </execution> - </executions> - </plugin> </plugins> </build> </project> diff --git a/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.cpp b/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.cpp index e72caf9405b..015eb70c74a 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.cpp +++ b/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.cpp @@ -44,18 +44,18 @@ StructFields::StructFields(const vespalib::string &fieldName, const IAttributeMa _hasMapKey(false), _error(false) { - // Note: Doesn't handle imported attributes - std::vector<AttributeGuard> attrs; - attrMgr.getAttributeList(attrs); + std::vector<const search::attribute::IAttributeVector *> attrs; + auto attrCtx = attrMgr.createContext(); + attrCtx->getAttributeList(attrs); vespalib::string prefix = fieldName + "."; vespalib::string keyName = prefix + "key"; vespalib::string valuePrefix = prefix + "value."; - for (const auto &guard : attrs) { - vespalib::string name = guard->getName(); + for (const auto attr : attrs) { + vespalib::string name = attr->getName(); if (name.substr(0, prefix.size()) != prefix) { continue; } - auto collType = guard->getCollectionType(); + auto collType = attr->getCollectionType(); if (collType != CollectionType::Type::ARRAY) { LOG(warning, "Attribute %s is not an array attribute", name.c_str()); _error = true; diff --git a/security-utils/pom.xml b/security-utils/pom.xml index 7006a0f5f86..6f094f28362 100644 --- a/security-utils/pom.xml +++ b/security-utils/pom.xml @@ -43,6 +43,12 @@ <artifactId>hamcrest-library</artifactId> <scope>test</scope> </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>testutil</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> </dependencies> <build> <plugins> diff --git a/security-utils/src/main/java/com/yahoo/security/tls/TransportSecurityOptions.java b/security-utils/src/main/java/com/yahoo/security/tls/TransportSecurityOptions.java index 67466179634..82caf02223f 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/TransportSecurityOptions.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/TransportSecurityOptions.java @@ -1,13 +1,18 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.security.tls; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.security.tls.json.TransportSecurityOptionsJsonSerializer; +import com.yahoo.security.tls.policy.AuthorizedPeers; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; import java.nio.file.Path; -import java.nio.file.Paths; import java.util.Objects; import java.util.Optional; @@ -18,61 +23,88 @@ import java.util.Optional; */ public class TransportSecurityOptions { - private static final ObjectMapper mapper = new ObjectMapper(); - private final Path privateKeyFile; private final Path certificatesFile; private final Path caCertificatesFile; + private final AuthorizedPeers authorizedPeers; - public TransportSecurityOptions(String privateKeyFile, String certificatesFile, String caCertificatesFile) { - this(Paths.get(privateKeyFile), Paths.get(certificatesFile), Paths.get(caCertificatesFile)); + private TransportSecurityOptions(Builder builder) { + this.privateKeyFile = builder.privateKeyFile; + this.certificatesFile = builder.certificatesFile; + this.caCertificatesFile = builder.caCertificatesFile; + this.authorizedPeers = builder.authorizedPeers; } - public TransportSecurityOptions(Path privateKeyFile, Path certificatesFile, Path caCertificatesFile) { - this.privateKeyFile = privateKeyFile; - this.certificatesFile = certificatesFile; - this.caCertificatesFile = caCertificatesFile; + public Optional<Path> getPrivateKeyFile() { + return Optional.ofNullable(privateKeyFile); } - public Path getPrivateKeyFile() { - return privateKeyFile; + public Optional<Path> getCertificatesFile() { + return Optional.ofNullable(certificatesFile); } - public Path getCertificatesFile() { - return certificatesFile; + public Optional<Path> getCaCertificatesFile() { + return Optional.ofNullable(caCertificatesFile); } - public Path getCaCertificatesFile() { - return caCertificatesFile; + public Optional<AuthorizedPeers> getAuthorizedPeers() { + return Optional.ofNullable(authorizedPeers); } public static TransportSecurityOptions fromJsonFile(Path file) { - try { - return fromJsonNode(mapper.readTree(file.toFile())); + try (InputStream in = Files.newInputStream(file)) { + return new TransportSecurityOptionsJsonSerializer().deserialize(in); } catch (IOException e) { throw new UncheckedIOException(e); } } public static TransportSecurityOptions fromJson(String json) { - try { - return fromJsonNode(mapper.readTree(json)); + return new TransportSecurityOptionsJsonSerializer() + .deserialize(new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))); + } + + public String toJson() { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + new TransportSecurityOptionsJsonSerializer().serialize(out, this); + return new String(out.toByteArray(), StandardCharsets.UTF_8); + } + + public void toJsonFile(Path file) { + try (OutputStream out = Files.newOutputStream(file)) { + new TransportSecurityOptionsJsonSerializer().serialize(out, this); } catch (IOException e) { throw new UncheckedIOException(e); } } - private static TransportSecurityOptions fromJsonNode(JsonNode root) { - JsonNode filesNode = getField(root, "files"); - String privateKeyFile = getField(filesNode, "private-key").asText(); - String certificatesFile = getField(filesNode, "certificates").asText(); - String caCertificatesFile = getField(filesNode, "ca-certificates").asText(); - return new TransportSecurityOptions(privateKeyFile, certificatesFile, caCertificatesFile); - } + public static class Builder { + private Path privateKeyFile; + private Path certificatesFile; + private Path caCertificatesFile; + private AuthorizedPeers authorizedPeers; + + public Builder() {} - private static JsonNode getField(JsonNode root, String fieldName) { - return Optional.ofNullable(root.get(fieldName)) - .orElseThrow(() -> new IllegalArgumentException(String.format("'%s' field missing", fieldName))); + public Builder withCertificates(Path certificatesFile, Path privateKeyFile) { + this.certificatesFile = certificatesFile; + this.privateKeyFile = privateKeyFile; + return this; + } + + public Builder withCaCertificates(Path caCertificatesFile) { + this.caCertificatesFile = caCertificatesFile; + return this; + } + + public Builder withAuthorizedPeers(AuthorizedPeers authorizedPeers) { + this.authorizedPeers = authorizedPeers; + return this; + } + + public TransportSecurityOptions build() { + return new TransportSecurityOptions(this); + } } @Override @@ -81,6 +113,7 @@ public class TransportSecurityOptions { "privateKeyFile=" + privateKeyFile + ", certificatesFile=" + certificatesFile + ", caCertificatesFile=" + caCertificatesFile + + ", authorizedPeers=" + authorizedPeers + '}'; } @@ -91,11 +124,12 @@ public class TransportSecurityOptions { TransportSecurityOptions that = (TransportSecurityOptions) o; return Objects.equals(privateKeyFile, that.privateKeyFile) && Objects.equals(certificatesFile, that.certificatesFile) && - Objects.equals(caCertificatesFile, that.caCertificatesFile); + Objects.equals(caCertificatesFile, that.caCertificatesFile) && + Objects.equals(authorizedPeers, that.authorizedPeers); } @Override public int hashCode() { - return Objects.hash(privateKeyFile, certificatesFile, caCertificatesFile); + return Objects.hash(privateKeyFile, certificatesFile, caCertificatesFile, authorizedPeers); } }
\ No newline at end of file diff --git a/security-utils/src/main/java/com/yahoo/security/tls/TransportSecurityUtils.java b/security-utils/src/main/java/com/yahoo/security/tls/TransportSecurityUtils.java index 5595d33a9b5..adae2e82873 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/TransportSecurityUtils.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/TransportSecurityUtils.java @@ -26,6 +26,10 @@ public class TransportSecurityUtils { this.configValue = configValue; } + public String configValue() { + return configValue; + } + static MixedMode fromConfigValue(String configValue) { return Arrays.stream(values()) .filter(v -> v.configValue.equals(configValue)) diff --git a/security-utils/src/main/java/com/yahoo/security/tls/json/TransportSecurityOptionsEntity.java b/security-utils/src/main/java/com/yahoo/security/tls/json/TransportSecurityOptionsEntity.java new file mode 100644 index 00000000000..fbb98d7c382 --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/json/TransportSecurityOptionsEntity.java @@ -0,0 +1,41 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.json; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_EMPTY; + +/** + * Jackson bindings for transport security options + * + * @author bjorncs + */ +@JsonIgnoreProperties(ignoreUnknown = true) +class TransportSecurityOptionsEntity { + + @JsonProperty("files") Files files; + @JsonProperty("authorized-peers") @JsonInclude(NON_EMPTY) List<AuthorizedPeer> authorizedPeers; + + static class Files { + @JsonProperty("private-key") String privateKeyFile; + @JsonProperty("certificates") String certificatesFile; + @JsonProperty("ca-certificates") String caCertificatesFile; + } + + static class AuthorizedPeer { + @JsonProperty("required-credentials") List<RequiredCredential> requiredCredentials; + @JsonProperty("name") String name; + @JsonProperty("roles") @JsonInclude(NON_EMPTY) List<String> roles; + } + + static class RequiredCredential { + @JsonProperty("field") CredentialField field; + @JsonProperty("must-match") String matchExpression; + } + + enum CredentialField { CN, SAN_DNS } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/json/TransportSecurityOptionsJsonSerializer.java b/security-utils/src/main/java/com/yahoo/security/tls/json/TransportSecurityOptionsJsonSerializer.java new file mode 100644 index 00000000000..f75cb4bcfff --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/json/TransportSecurityOptionsJsonSerializer.java @@ -0,0 +1,162 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.json; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.security.tls.TransportSecurityOptions; +import com.yahoo.security.tls.json.TransportSecurityOptionsEntity.AuthorizedPeer; +import com.yahoo.security.tls.json.TransportSecurityOptionsEntity.CredentialField; +import com.yahoo.security.tls.json.TransportSecurityOptionsEntity.Files; +import com.yahoo.security.tls.json.TransportSecurityOptionsEntity.RequiredCredential; +import com.yahoo.security.tls.policy.AuthorizedPeers; +import com.yahoo.security.tls.policy.HostGlobPattern; +import com.yahoo.security.tls.policy.PeerPolicy; +import com.yahoo.security.tls.policy.RequiredPeerCredential; +import com.yahoo.security.tls.policy.Role; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toSet; + +/** + * @author bjorncs + */ +public class TransportSecurityOptionsJsonSerializer { + + private static final ObjectMapper mapper = new ObjectMapper(); + + public TransportSecurityOptions deserialize(InputStream in) { + try { + TransportSecurityOptionsEntity entity = mapper.readValue(in, TransportSecurityOptionsEntity.class); + return toTransportSecurityOptions(entity); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public void serialize(OutputStream out, TransportSecurityOptions options) { + try { + mapper.writerWithDefaultPrettyPrinter().writeValue(out, toTransportSecurityOptionsEntity(options)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static TransportSecurityOptions toTransportSecurityOptions(TransportSecurityOptionsEntity entity) { + TransportSecurityOptions.Builder builder = new TransportSecurityOptions.Builder(); + Files files = entity.files; + if (files != null) { + if (files.certificatesFile != null && files.privateKeyFile != null) { + builder.withCertificates(Paths.get(files.certificatesFile), Paths.get(files.privateKeyFile)); + } else if (files.certificatesFile != null || files.privateKeyFile != null) { + throw new IllegalArgumentException("Both 'private-key' and 'certificates' must be configured together"); + } + if (files.caCertificatesFile != null) { + builder.withCaCertificates(Paths.get(files.caCertificatesFile)); + } + } + List<AuthorizedPeer> authorizedPeersEntity = entity.authorizedPeers; + if (authorizedPeersEntity != null) { + if (authorizedPeersEntity.size() == 0) { + throw new IllegalArgumentException("'authorized-peers' cannot be empty"); + } + builder.withAuthorizedPeers(new AuthorizedPeers(toPeerPolicies(authorizedPeersEntity))); + } + return builder.build(); + } + + private static Set<PeerPolicy> toPeerPolicies(List<AuthorizedPeer> authorizedPeersEntity) { + return authorizedPeersEntity.stream() + .map(TransportSecurityOptionsJsonSerializer::toPeerPolicy) + .collect(toSet()); + } + + private static PeerPolicy toPeerPolicy(AuthorizedPeer authorizedPeer) { + if (authorizedPeer.name == null) { + throw missingFieldException("name"); + } + if (authorizedPeer.requiredCredentials == null) { + throw missingFieldException("required-credentials"); + } + return new PeerPolicy(authorizedPeer.name, toRoles(authorizedPeer.roles), toRequestPeerCredentials(authorizedPeer.requiredCredentials)); + } + + private static Set<Role> toRoles(List<String> roles) { + if (roles == null) return Collections.emptySet(); + return roles.stream() + .map(Role::new) + .collect(toSet()); + } + + private static List<RequiredPeerCredential> toRequestPeerCredentials(List<RequiredCredential> requiredCredentials) { + return requiredCredentials.stream() + .map(TransportSecurityOptionsJsonSerializer::toRequiredPeerCredential) + .collect(toList()); + } + + private static RequiredPeerCredential toRequiredPeerCredential(RequiredCredential requiredCredential) { + if (requiredCredential.field == null) { + throw missingFieldException("field"); + } + if (requiredCredential.matchExpression == null) { + throw missingFieldException("must-match"); + } + return new RequiredPeerCredential(toField(requiredCredential.field), new HostGlobPattern(requiredCredential.matchExpression)); + } + + private static RequiredPeerCredential.Field toField(CredentialField field) { + switch (field) { + case CN: return RequiredPeerCredential.Field.CN; + case SAN_DNS: return RequiredPeerCredential.Field.SAN_DNS; + default: throw new IllegalArgumentException("Invalid field type: " + field); + } + } + + private static TransportSecurityOptionsEntity toTransportSecurityOptionsEntity(TransportSecurityOptions options) { + TransportSecurityOptionsEntity entity = new TransportSecurityOptionsEntity(); + entity.files = new Files(); + options.getCaCertificatesFile().ifPresent(value -> entity.files.caCertificatesFile = value.toString()); + options.getCertificatesFile().ifPresent(value -> entity.files.certificatesFile = value.toString()); + options.getPrivateKeyFile().ifPresent(value -> entity.files.privateKeyFile = value.toString()); + options.getAuthorizedPeers().ifPresent( authorizedPeers -> { + entity.authorizedPeers = new ArrayList<>(); + for (PeerPolicy peerPolicy : authorizedPeers.peerPolicies()) { + AuthorizedPeer authorizedPeer = new AuthorizedPeer(); + authorizedPeer.name = peerPolicy.policyName(); + authorizedPeer.requiredCredentials = new ArrayList<>(); + for (RequiredPeerCredential requiredPeerCredential : peerPolicy.requiredCredentials()) { + RequiredCredential requiredCredential = new RequiredCredential(); + requiredCredential.field = toField(requiredPeerCredential.field()); + requiredCredential.matchExpression = requiredPeerCredential.pattern().asString(); + authorizedPeer.requiredCredentials.add(requiredCredential); + } + if (!peerPolicy.assumedRoles().isEmpty()) { + authorizedPeer.roles = new ArrayList<>(); + peerPolicy.assumedRoles().forEach(role -> authorizedPeer.roles.add(role.name())); + } + entity.authorizedPeers.add(authorizedPeer); + } + }); + return entity; + } + + private static CredentialField toField(RequiredPeerCredential.Field field) { + switch (field) { + case CN: return CredentialField.CN; + case SAN_DNS: return CredentialField.SAN_DNS; + default: throw new IllegalArgumentException("Invalid field type: " + field); + } + } + + private static IllegalArgumentException missingFieldException(String fieldName) { + return new IllegalArgumentException(String.format("'%s' missing", fieldName)); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/json/package-info.java b/security-utils/src/main/java/com/yahoo/security/tls/json/package-info.java new file mode 100644 index 00000000000..2aaf276b439 --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/json/package-info.java @@ -0,0 +1,8 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +/** + * @author bjorncs + */ +@ExportPackage +package com.yahoo.security.tls.json; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/security-utils/src/main/java/com/yahoo/security/tls/policy/AuthorizedPeers.java b/security-utils/src/main/java/com/yahoo/security/tls/policy/AuthorizedPeers.java new file mode 100644 index 00000000000..d62219b2ebe --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/policy/AuthorizedPeers.java @@ -0,0 +1,53 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.policy; + +import java.util.Collections; +import java.util.Objects; +import java.util.Set; + +/** + * @author bjorncs + */ +public class AuthorizedPeers { + + private final Set<PeerPolicy> peerPolicies; + + public AuthorizedPeers(Set<PeerPolicy> peerPolicies) { + this.peerPolicies = verifyPeerPolicies(peerPolicies); + } + + private Set<PeerPolicy> verifyPeerPolicies(Set<PeerPolicy> peerPolicies) { + long distinctNames = peerPolicies.stream() + .map(PeerPolicy::policyName) + .distinct() + .count(); + if (distinctNames != peerPolicies.size()) { + throw new IllegalArgumentException("'authorized-peers' contains entries with duplicate names"); + } + return Collections.unmodifiableSet(peerPolicies); + } + + public Set<PeerPolicy> peerPolicies() { + return peerPolicies; + } + + @Override + public String toString() { + return "AuthorizedPeers{" + + "peerPolicies=" + peerPolicies + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AuthorizedPeers that = (AuthorizedPeers) o; + return Objects.equals(peerPolicies, that.peerPolicies); + } + + @Override + public int hashCode() { + return Objects.hash(peerPolicies); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/policy/HostGlobPattern.java b/security-utils/src/main/java/com/yahoo/security/tls/policy/HostGlobPattern.java new file mode 100644 index 00000000000..c7acf5dfbeb --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/policy/HostGlobPattern.java @@ -0,0 +1,72 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.policy; + +import java.util.Objects; +import java.util.regex.Pattern; + +/** + * @author bjorncs + */ +public class HostGlobPattern { + + private final String pattern; + private final Pattern regexPattern; + + public HostGlobPattern(String pattern) { + this.pattern = pattern; + this.regexPattern = toRegexPattern(pattern); + } + + public String asString() { + return pattern; + } + + public boolean matches(String hostString) { + return regexPattern.matcher(hostString).matches(); + } + + private static Pattern toRegexPattern(String pattern) { + StringBuilder builder = new StringBuilder("^"); + for (char c : pattern.toCharArray()) { + if (c == '*') { + // Note: we explicitly stop matching at a dot separator boundary. + // This is to make host name matching less vulnerable to dirty tricks. + builder.append("[^.]*"); + } else if (c == '?') { + // Same applies for single chars; they should only match _within_ a dot boundary. + builder.append("[^.]"); + } else if (isRegexMetaCharacter(c)){ + builder.append("\\"); + builder.append(c); + } else { + builder.append(c); + } + } + builder.append('$'); + return Pattern.compile(builder.toString()); + } + + private static boolean isRegexMetaCharacter(char c) { + return "<([{\\^-=$!|]})?*+.>".indexOf(c) != -1; // note: includes '?' and '*' + } + + @Override + public String toString() { + return "HostGlobPattern{" + + "pattern='" + pattern + '\'' + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + HostGlobPattern that = (HostGlobPattern) o; + return Objects.equals(pattern, that.pattern); + } + + @Override + public int hashCode() { + return Objects.hash(pattern); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/policy/PeerPolicy.java b/security-utils/src/main/java/com/yahoo/security/tls/policy/PeerPolicy.java new file mode 100644 index 00000000000..294f8543f43 --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/policy/PeerPolicy.java @@ -0,0 +1,59 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.policy; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +/** + * @author bjorncs + */ +public class PeerPolicy { + + private final String policyName; + private final Set<Role> assumedRoles; + private final List<RequiredPeerCredential> requiredCredentials; + + public PeerPolicy(String policyName, Set<Role> assumedRoles, List<RequiredPeerCredential> requiredCredentials) { + this.policyName = policyName; + this.assumedRoles = assumedRoles; + this.requiredCredentials = Collections.unmodifiableList(requiredCredentials); + } + + public String policyName() { + return policyName; + } + + public Set<Role> assumedRoles() { + return assumedRoles; + } + + public List<RequiredPeerCredential> requiredCredentials() { + return requiredCredentials; + } + + @Override + public String toString() { + return "PeerPolicy{" + + "policyName='" + policyName + '\'' + + ", assumedRoles=" + assumedRoles + + ", requiredCredentials=" + requiredCredentials + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PeerPolicy that = (PeerPolicy) o; + return Objects.equals(policyName, that.policyName) && + Objects.equals(assumedRoles, that.assumedRoles) && + Objects.equals(requiredCredentials, that.requiredCredentials); + } + + @Override + public int hashCode() { + return Objects.hash(policyName, assumedRoles, requiredCredentials); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/policy/RequiredPeerCredential.java b/security-utils/src/main/java/com/yahoo/security/tls/policy/RequiredPeerCredential.java new file mode 100644 index 00000000000..4f028d8b1ab --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/policy/RequiredPeerCredential.java @@ -0,0 +1,50 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.policy; + +import java.util.Objects; + +/** + * @author bjorncs + */ +public class RequiredPeerCredential { + + public enum Field { CN, SAN_DNS } + + private final Field field; + private final HostGlobPattern pattern; + + public RequiredPeerCredential(Field field, HostGlobPattern pattern) { + this.field = field; + this.pattern = pattern; + } + + public Field field() { + return field; + } + + public HostGlobPattern pattern() { + return pattern; + } + + @Override + public String toString() { + return "RequiredPeerCredential{" + + "field=" + field + + ", pattern=" + pattern + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RequiredPeerCredential that = (RequiredPeerCredential) o; + return field == that.field && + Objects.equals(pattern, that.pattern); + } + + @Override + public int hashCode() { + return Objects.hash(field, pattern); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/policy/Role.java b/security-utils/src/main/java/com/yahoo/security/tls/policy/Role.java new file mode 100644 index 00000000000..6d64ccff2c5 --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/policy/Role.java @@ -0,0 +1,40 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.policy; + +import java.util.Objects; + +/** + * @author bjorncs + */ +public class Role { + + private final String name; + + public Role(String name) { + this.name = name; + } + + public String name() { + return name; + } + + @Override + public String toString() { + return "Role{" + + "name='" + name + '\'' + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Role role = (Role) o; + return Objects.equals(name, role.name); + } + + @Override + public int hashCode() { + return Objects.hash(name); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/policy/package-info.java b/security-utils/src/main/java/com/yahoo/security/tls/policy/package-info.java new file mode 100644 index 00000000000..4215bd25d3e --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/policy/package-info.java @@ -0,0 +1,8 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +/** + * @author bjorncs + */ +@ExportPackage +package com.yahoo.security.tls.policy; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/security-utils/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java b/security-utils/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java index 84f71cf8fc2..aa5509a23b2 100644 --- a/security-utils/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java +++ b/security-utils/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java @@ -17,20 +17,22 @@ import static org.junit.Assert.*; public class TransportSecurityOptionsTest { private static final Path TEST_CONFIG_FILE = Paths.get("src/test/resources/transport-security-options.json"); + private static final TransportSecurityOptions OPTIONS = new TransportSecurityOptions.Builder() + .withCertificates(Paths.get("certs.pem"), Paths.get("myhost.key")) + .withCaCertificates(Paths.get("my_cas.pem")) + .build(); @Test public void can_read_options_from_json_file() { - TransportSecurityOptions expectedOptions = new TransportSecurityOptions("myhost.key", "certs.pem", "my_cas.pem"); TransportSecurityOptions actualOptions = TransportSecurityOptions.fromJsonFile(TEST_CONFIG_FILE); - assertEquals(expectedOptions, actualOptions); + assertEquals(OPTIONS, actualOptions); } @Test public void can_read_options_from_json() throws IOException { String tlsJson = new String(Files.readAllBytes(TEST_CONFIG_FILE), StandardCharsets.UTF_8); - TransportSecurityOptions expectedOptions = new TransportSecurityOptions("myhost.key", "certs.pem", "my_cas.pem"); TransportSecurityOptions actualOptions = TransportSecurityOptions.fromJson(tlsJson); - assertEquals(expectedOptions, actualOptions); + assertEquals(OPTIONS, actualOptions); } } diff --git a/security-utils/src/test/java/com/yahoo/security/tls/json/TransportSecurityOptionsJsonSerializerTest.java b/security-utils/src/test/java/com/yahoo/security/tls/json/TransportSecurityOptionsJsonSerializerTest.java new file mode 100644 index 00000000000..5e611b1eba5 --- /dev/null +++ b/security-utils/src/test/java/com/yahoo/security/tls/json/TransportSecurityOptionsJsonSerializerTest.java @@ -0,0 +1,77 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.json; + +import com.yahoo.security.tls.TransportSecurityOptions; +import com.yahoo.security.tls.policy.AuthorizedPeers; +import com.yahoo.security.tls.policy.HostGlobPattern; +import com.yahoo.security.tls.policy.PeerPolicy; +import com.yahoo.security.tls.policy.RequiredPeerCredential; +import com.yahoo.security.tls.policy.Role; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; + +import static com.yahoo.security.tls.policy.RequiredPeerCredential.Field.CN; +import static com.yahoo.security.tls.policy.RequiredPeerCredential.Field.SAN_DNS; +import static com.yahoo.test.json.JsonTestHelper.assertJsonEquals; +import static java.util.Collections.singleton; +import static org.junit.Assert.assertEquals; + +/** + * @author bjorncs + */ +public class TransportSecurityOptionsJsonSerializerTest { + + @Rule public TemporaryFolder tempDirectory = new TemporaryFolder(); + + private static final Path TEST_CONFIG_FILE = Paths.get("src/test/resources/transport-security-options.json"); + + @Test + public void can_serialize_and_deserialize_transport_security_options() { + TransportSecurityOptions options = new TransportSecurityOptions.Builder() + .withCaCertificates(Paths.get("/path/to/ca-certs.pem")) + .withCertificates(Paths.get("/path/to/cert.pem"), Paths.get("/path/to/key.pem")) + .withAuthorizedPeers( + new AuthorizedPeers( + new HashSet<>(Arrays.asList( + new PeerPolicy("cfgserver", singleton(new Role("myrole")), Arrays.asList( + new RequiredPeerCredential(CN, new HostGlobPattern("mycfgserver")), + new RequiredPeerCredential(SAN_DNS, new HostGlobPattern("*.suffix.com")))), + new PeerPolicy("node", singleton(new Role("anotherrole")), Collections.singletonList(new RequiredPeerCredential(CN, new HostGlobPattern("hostname")))))))) + .build(); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + TransportSecurityOptionsJsonSerializer serializer = new TransportSecurityOptionsJsonSerializer(); + serializer.serialize(out, options); + TransportSecurityOptions deserializedOptions = serializer.deserialize(new ByteArrayInputStream(out.toByteArray())); + assertEquals(options, deserializedOptions); + } + + @Test + public void can_serialize_options_without_authorized_peers() throws IOException { + TransportSecurityOptions options = new TransportSecurityOptions.Builder() + .withCertificates(Paths.get("certs.pem"), Paths.get("myhost.key")) + .withCaCertificates(Paths.get("my_cas.pem")) + .build(); + File outputFile = tempDirectory.newFile(); + try (OutputStream out = Files.newOutputStream(outputFile.toPath())) { + new TransportSecurityOptionsJsonSerializer().serialize(out, options); + } + String expectedOutput = new String(Files.readAllBytes(TEST_CONFIG_FILE)); + String actualOutput = new String(Files.readAllBytes(outputFile.toPath())); + assertJsonEquals(expectedOutput, actualOutput); + } + +} diff --git a/security-utils/src/test/java/com/yahoo/security/tls/policy/AuthorizedPeersTest.java b/security-utils/src/test/java/com/yahoo/security/tls/policy/AuthorizedPeersTest.java new file mode 100644 index 00000000000..ce8249b9c6c --- /dev/null +++ b/security-utils/src/test/java/com/yahoo/security/tls/policy/AuthorizedPeersTest.java @@ -0,0 +1,27 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.policy; + +import org.junit.Test; + +import java.util.HashSet; +import java.util.List; + +import static com.yahoo.security.tls.policy.RequiredPeerCredential.Field.CN; +import static java.util.Arrays.asList; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; + +/** + * @author bjorncs + */ +public class AuthorizedPeersTest { + + @Test(expected = IllegalArgumentException.class) + public void throws_exception_on_peer_policies_with_duplicate_names() { + List<RequiredPeerCredential> requiredPeerCredential = singletonList(new RequiredPeerCredential(CN, new HostGlobPattern("mycfgserver"))); + PeerPolicy peerPolicy1 = new PeerPolicy("duplicate-name", singleton(new Role("role")), requiredPeerCredential); + PeerPolicy peerPolicy2 = new PeerPolicy("duplicate-name", singleton(new Role("anotherrole")), requiredPeerCredential); + new AuthorizedPeers(new HashSet<>(asList(peerPolicy1, peerPolicy2))); + } + +} diff --git a/security-utils/src/test/java/com/yahoo/security/tls/policy/HostGlobPatternTest.java b/security-utils/src/test/java/com/yahoo/security/tls/policy/HostGlobPatternTest.java new file mode 100644 index 00000000000..ebec5605621 --- /dev/null +++ b/security-utils/src/test/java/com/yahoo/security/tls/policy/HostGlobPatternTest.java @@ -0,0 +1,69 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.policy; + +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + + +/** + * @author bjorncs + */ +public class HostGlobPatternTest { + + @Test + public void glob_without_wildcards_matches_entire_string() { + assertTrue(globMatches("foo", "foo")); + assertFalse(globMatches("foo", "fooo")); + assertFalse(globMatches("foo", "ffoo")); + } + + @Test + public void wildcard_glob_can_match_prefix() { + assertTrue(globMatches("foo*", "foo")); + assertTrue(globMatches("foo*", "foobar")); + assertFalse(globMatches("foo*", "ffoo")); + } + + @Test + public void wildcard_glob_can_match_suffix() { + assertTrue(globMatches("*foo", "foo")); + assertTrue(globMatches("*foo", "ffoo")); + assertFalse(globMatches("*foo", "fooo")); + } + + @Test + public void wildcard_glob_can_match_substring() { + assertTrue(globMatches("f*o", "fo")); + assertTrue(globMatches("f*o", "foo")); + assertTrue(globMatches("f*o", "ffoo")); + assertFalse(globMatches("f*o", "boo")); + } + + @Test + public void wildcard_glob_does_not_cross_multiple_dot_delimiter_boundaries() { + assertTrue(globMatches("*.bar.baz", "foo.bar.baz")); + assertTrue(globMatches("*.bar.baz", ".bar.baz")); + assertFalse(globMatches("*.bar.baz", "zoid.foo.bar.baz")); + assertTrue(globMatches("foo.*.baz", "foo.bar.baz")); + assertFalse(globMatches("foo.*.baz", "foo.bar.zoid.baz")); + } + + @Test + public void single_char_glob_matches_non_dot_characters() { + assertTrue(globMatches("f?o", "foo")); + assertFalse(globMatches("f?o", "fooo")); + assertFalse(globMatches("f?o", "ffoo")); + assertFalse(globMatches("f?o", "f.o")); + } + + @Test + public void other_regex_meta_characters_are_matched_as_literal_characters() { + assertTrue(globMatches("<([{\\^-=$!|]})+.>", "<([{\\^-=$!|]})+.>")); + } + + private static boolean globMatches(String pattern, String value) { + return new HostGlobPattern(pattern).matches(value); + } +} diff --git a/simplemetrics/src/test/java/com/yahoo/metrics/simple/DimensionsCacheTest.java b/simplemetrics/src/test/java/com/yahoo/metrics/simple/DimensionsCacheTest.java index 8745843782c..0fde3bcf588 100644 --- a/simplemetrics/src/test/java/com/yahoo/metrics/simple/DimensionsCacheTest.java +++ b/simplemetrics/src/test/java/com/yahoo/metrics/simple/DimensionsCacheTest.java @@ -17,7 +17,7 @@ import com.yahoo.metrics.simple.UntypedMetric.AssumedType; /** * Functional test for point persistence layer. * - * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a> + * @author Steinar Knutsen */ public class DimensionsCacheTest { diff --git a/standalone-container/pom.xml b/standalone-container/pom.xml index 608d72c72f2..73d05d35df7 100644 --- a/standalone-container/pom.xml +++ b/standalone-container/pom.xml @@ -45,6 +45,12 @@ </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> + <artifactId>model-integration</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> <artifactId>defaults</artifactId> <version>${project.version}</version> <scope>provided</scope> diff --git a/standalone-container/vespa-standalone-container.spec b/standalone-container/vespa-standalone-container.spec index f051142c516..6143df6a446 100644 --- a/standalone-container/vespa-standalone-container.spec +++ b/standalone-container/vespa-standalone-container.spec @@ -54,6 +54,7 @@ declare -a modules=( jdisc_core jdisc_http_service model-evaluation + model-integration security-utils simplemetrics standalone-container diff --git a/storage/src/tests/distributor/externaloperationhandlertest.cpp b/storage/src/tests/distributor/externaloperationhandlertest.cpp index 54aca78d13d..a9956712679 100644 --- a/storage/src/tests/distributor/externaloperationhandlertest.cpp +++ b/storage/src/tests/distributor/externaloperationhandlertest.cpp @@ -62,12 +62,20 @@ class ExternalOperationHandlerTest : public CppUnit::TestFixture, .safe_time_not_reached.getLongValue("count"); } + int64_t safe_time_not_reached_metric_count(const metrics::LoadMetric<UpdateMetricSet>& metrics) const { + return metrics[documentapi::LoadType::DEFAULT].failures.safe_time_not_reached.getLongValue("count"); + } + int64_t concurrent_mutatations_metric_count( const metrics::LoadMetric<PersistenceOperationMetricSet>& metrics) const { return metrics[documentapi::LoadType::DEFAULT].failures .concurrent_mutations.getLongValue("count"); } + int64_t concurrent_mutatations_metric_count(const metrics::LoadMetric<UpdateMetricSet>& metrics) const { + return metrics[documentapi::LoadType::DEFAULT].failures.concurrent_mutations.getLongValue("count"); + } + void set_up_distributor_for_sequencing_test(); const vespalib::string _dummy_id{"id:foo:testdoctype1::bar"}; diff --git a/storage/src/tests/distributor/twophaseupdateoperationtest.cpp b/storage/src/tests/distributor/twophaseupdateoperationtest.cpp index f4b80f7961c..3c82931467e 100644 --- a/storage/src/tests/distributor/twophaseupdateoperationtest.cpp +++ b/storage/src/tests/distributor/twophaseupdateoperationtest.cpp @@ -204,7 +204,7 @@ TwoPhaseUpdateOperationTest::replyToMessage( api::ReturnCode::Result result) { std::shared_ptr<api::StorageMessage> msg2 = sender.commands.at(index); - UpdateCommand& updatec = dynamic_cast<UpdateCommand&>(*msg2); + auto& updatec = dynamic_cast<UpdateCommand&>(*msg2); std::unique_ptr<api::StorageReply> reply(updatec.makeReply()); static_cast<api::UpdateReply*>(reply.get())->setOldTimestamp(oldTimestamp); reply->setResult(api::ReturnCode(result, "")); @@ -222,7 +222,7 @@ TwoPhaseUpdateOperationTest::replyToPut( const std::string& traceMsg) { std::shared_ptr<api::StorageMessage> msg2 = sender.commands.at(index); - PutCommand& putc = dynamic_cast<PutCommand&>(*msg2); + auto& putc = dynamic_cast<PutCommand&>(*msg2); std::unique_ptr<api::StorageReply> reply(putc.makeReply()); reply->setResult(api::ReturnCode(result, "")); if (!traceMsg.empty()) { @@ -240,7 +240,7 @@ TwoPhaseUpdateOperationTest::replyToCreateBucket( api::ReturnCode::Result result) { std::shared_ptr<api::StorageMessage> msg2 = sender.commands.at(index); - CreateBucketCommand& putc = dynamic_cast<CreateBucketCommand&>(*msg2); + auto& putc = dynamic_cast<CreateBucketCommand&>(*msg2); std::unique_ptr<api::StorageReply> reply(putc.makeReply()); reply->setResult(api::ReturnCode(result, "")); callback.receive(sender, @@ -257,8 +257,7 @@ TwoPhaseUpdateOperationTest::replyToGet( api::ReturnCode::Result result, const std::string& traceMsg) { - const api::GetCommand& get( - static_cast<const api::GetCommand&>(*sender.commands.at(index))); + auto& get = static_cast<const api::GetCommand&>(*sender.commands.at(index)); std::shared_ptr<api::StorageReply> reply; if (haveDocument) { diff --git a/storage/src/tests/distributor/updateoperationtest.cpp b/storage/src/tests/distributor/updateoperationtest.cpp index 9ce862f5db8..63e2f4f00c1 100644 --- a/storage/src/tests/distributor/updateoperationtest.cpp +++ b/storage/src/tests/distributor/updateoperationtest.cpp @@ -61,7 +61,7 @@ public: } void replyToMessage(UpdateOperation& callback, MessageSenderStub& sender, uint32_t index, - uint64_t oldTimestamp, api::BucketInfo info = api::BucketInfo(2,4,6)); + uint64_t oldTimestamp, const api::BucketInfo& info = api::BucketInfo(2,4,6)); std::shared_ptr<UpdateOperation> sendUpdate(const std::string& bucketState); @@ -94,7 +94,7 @@ UpdateOperation_Test::sendUpdate(const std::string& bucketState) void UpdateOperation_Test::replyToMessage(UpdateOperation& callback, MessageSenderStub& sender, uint32_t index, - uint64_t oldTimestamp, api::BucketInfo info) + uint64_t oldTimestamp, const api::BucketInfo& info) { std::shared_ptr<api::StorageMessage> msg2 = sender.commands[index]; UpdateCommand* updatec = dynamic_cast<UpdateCommand*>(msg2.get()); @@ -123,6 +123,9 @@ UpdateOperation_Test::testSimple() std::string("UpdateReply(doc:test:test, BucketId(0x0000000000000000), " "timestamp 100, timestamp of updated doc: 90) ReturnCode(NONE)"), sender.getLastReply(true)); + + auto& metrics = getDistributor().getMetrics().updates[documentapi::LoadType::DEFAULT]; + CPPUNIT_ASSERT_EQUAL(0UL, metrics.diverging_timestamp_updates.getValue()); } void @@ -168,6 +171,9 @@ UpdateOperation_Test::testMultiNode() "node(idx=1,crc=0x2,docs=4/4,bytes=6/6,trusted=true,active=false,ready=false), " "node(idx=0,crc=0x2,docs=4/4,bytes=6/6,trusted=true,active=false,ready=false)"), dumpBucket(_bId)); + + auto& metrics = getDistributor().getMetrics().updates[documentapi::LoadType::DEFAULT]; + CPPUNIT_ASSERT_EQUAL(0UL, metrics.diverging_timestamp_updates.getValue()); } void @@ -188,5 +194,8 @@ UpdateOperation_Test::testMultiNodeInconsistentTimestamp() "timestamp 100, timestamp of updated doc: 120 Was inconsistent " "(best node 1)) ReturnCode(NONE)"), sender.getLastReply(true)); + + auto& metrics = getDistributor().getMetrics().updates[documentapi::LoadType::DEFAULT]; + CPPUNIT_ASSERT_EQUAL(1UL, metrics.diverging_timestamp_updates.getValue()); } diff --git a/storage/src/vespa/storage/common/storagecomponent.cpp b/storage/src/vespa/storage/common/storagecomponent.cpp index 1d6b563f6eb..21a4b8eea64 100644 --- a/storage/src/vespa/storage/common/storagecomponent.cpp +++ b/storage/src/vespa/storage/common/storagecomponent.cpp @@ -62,13 +62,6 @@ StorageComponent::setDistribution(DistributionSP distribution) } void -StorageComponent::enableMultipleBucketSpaces(bool value) -{ - std::lock_guard guard(_lock); - _enableMultipleBucketSpaces = value; -} - -void StorageComponent::setNodeStateUpdater(NodeStateUpdater& updater) { std::lock_guard guard(_lock); @@ -91,8 +84,7 @@ StorageComponent::StorageComponent(StorageComponentRegister& compReg, _bucketIdFactory(), _distribution(), _nodeStateUpdater(nullptr), - _lock(), - _enableMultipleBucketSpaces(false) + _lock() { compReg.registerStorageComponent(*this); } @@ -145,11 +137,4 @@ StorageComponent::getDistribution() const return _distribution; } -bool -StorageComponent::enableMultipleBucketSpaces() const -{ - std::lock_guard guard(_lock); - return _enableMultipleBucketSpaces; -} - } // storage diff --git a/storage/src/vespa/storage/common/storagecomponent.h b/storage/src/vespa/storage/common/storagecomponent.h index 168bdbe9aa6..901594ade99 100644 --- a/storage/src/vespa/storage/common/storagecomponent.h +++ b/storage/src/vespa/storage/common/storagecomponent.h @@ -83,7 +83,6 @@ public: void setPriorityConfig(const PriorityConfig&); void setBucketIdFactory(const document::BucketIdFactory&); void setDistribution(DistributionSP); - void enableMultipleBucketSpaces(bool value); StorageComponent(StorageComponentRegister&, vespalib::stringref name); virtual ~StorageComponent(); @@ -102,7 +101,6 @@ public: uint8_t getPriority(const documentapi::LoadType&) const; DistributionSP getDistribution() const; NodeStateUpdater& getStateUpdater() const; - bool enableMultipleBucketSpaces() const; private: vespalib::string _clusterName; @@ -115,7 +113,6 @@ private: DistributionSP _distribution; NodeStateUpdater* _nodeStateUpdater; mutable std::mutex _lock; - bool _enableMultipleBucketSpaces; }; struct StorageComponentRegister : public virtual framework::ComponentRegister diff --git a/storage/src/vespa/storage/distributor/CMakeLists.txt b/storage/src/vespa/storage/distributor/CMakeLists.txt index 2fd51433306..bd82c0e4d6e 100644 --- a/storage/src/vespa/storage/distributor/CMakeLists.txt +++ b/storage/src/vespa/storage/distributor/CMakeLists.txt @@ -34,6 +34,7 @@ vespa_add_library(storage_distributor statecheckers.cpp statusreporterdelegate.cpp throttlingoperationstarter.cpp + update_metric_set.cpp visitormetricsset.cpp $<TARGET_OBJECTS:storage_distributoroperation> $<TARGET_OBJECTS:storage_distributoroperationexternal> diff --git a/storage/src/vespa/storage/distributor/distributormetricsset.cpp b/storage/src/vespa/storage/distributor/distributormetricsset.cpp index b7725559b1d..927dc06182d 100644 --- a/storage/src/vespa/storage/distributor/distributormetricsset.cpp +++ b/storage/src/vespa/storage/distributor/distributormetricsset.cpp @@ -10,7 +10,7 @@ using metrics::MetricSet; DistributorMetricSet::DistributorMetricSet(const metrics::LoadTypeSet& lt) : MetricSet("distributor", {{"distributor"}}, ""), puts(lt, PersistenceOperationMetricSet("puts"), this), - updates(lt, PersistenceOperationMetricSet("updates"), this), + updates(lt, UpdateMetricSet(), this), update_puts(lt, PersistenceOperationMetricSet("update_puts"), this), update_gets(lt, PersistenceOperationMetricSet("update_gets"), this), removes(lt, PersistenceOperationMetricSet("removes"), this), diff --git a/storage/src/vespa/storage/distributor/distributormetricsset.h b/storage/src/vespa/storage/distributor/distributormetricsset.h index b6a429761a0..5a64027f500 100644 --- a/storage/src/vespa/storage/distributor/distributormetricsset.h +++ b/storage/src/vespa/storage/distributor/distributormetricsset.h @@ -2,6 +2,7 @@ #pragma once #include "persistence_operation_metric_set.h" +#include "update_metric_set.h" #include "visitormetricsset.h" #include <vespa/metrics/metrics.h> #include <vespa/documentapi/loadtypes/loadtypeset.h> @@ -12,7 +13,7 @@ class DistributorMetricSet : public metrics::MetricSet { public: metrics::LoadMetric<PersistenceOperationMetricSet> puts; - metrics::LoadMetric<PersistenceOperationMetricSet> updates; + metrics::LoadMetric<UpdateMetricSet> updates; metrics::LoadMetric<PersistenceOperationMetricSet> update_puts; metrics::LoadMetric<PersistenceOperationMetricSet> update_gets; metrics::LoadMetric<PersistenceOperationMetricSet> removes; diff --git a/storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.h b/storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.h index 6efce913e70..f9787141a19 100644 --- a/storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.h +++ b/storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.h @@ -18,6 +18,8 @@ class UpdateCommand; class CreateBucketReply; } +class UpdateMetricSet; + namespace distributor { class DistributorBucketSpace; @@ -120,7 +122,7 @@ private: void replyWithTasFailure(DistributorMessageSender& sender, vespalib::stringref message); - PersistenceOperationMetricSet& _updateMetric; + UpdateMetricSet& _updateMetric; PersistenceOperationMetricSet& _putMetric; PersistenceOperationMetricSet& _getMetric; std::shared_ptr<api::UpdateCommand> _updateCmd; diff --git a/storage/src/vespa/storage/distributor/operations/external/updateoperation.cpp b/storage/src/vespa/storage/distributor/operations/external/updateoperation.cpp index f5e67708fe5..c8f28391ec0 100644 --- a/storage/src/vespa/storage/distributor/operations/external/updateoperation.cpp +++ b/storage/src/vespa/storage/distributor/operations/external/updateoperation.cpp @@ -20,14 +20,15 @@ using document::BucketSpace; UpdateOperation::UpdateOperation(DistributorComponent& manager, DistributorBucketSpace &bucketSpace, const std::shared_ptr<api::UpdateCommand> & msg, - PersistenceOperationMetricSet& metric) + UpdateMetricSet& metric) : Operation(), _trackerInstance(metric, std::make_shared<api::UpdateReply>(*msg), manager, msg->getTimestamp()), _tracker(_trackerInstance), _msg(msg), _manager(manager), - _bucketSpace(bucketSpace) + _bucketSpace(bucketSpace), + _metrics(metric) { } @@ -149,6 +150,7 @@ UpdateOperation::onReceive(DistributorMessageSender& sender, reply.getBucket().toString().c_str(), _results[i].oldTs, _results[i].nodeId, _results[goodNode].oldTs, _results[goodNode].nodeId); + _metrics.diverging_timestamp_updates.inc(); replyToSend.setNodeWithNewestTimestamp(_results[goodNode].nodeId); _newestTimestampLocation.first = _results[goodNode].bucketId; diff --git a/storage/src/vespa/storage/distributor/operations/external/updateoperation.h b/storage/src/vespa/storage/distributor/operations/external/updateoperation.h index a23fd2ab876..cf432fa2305 100644 --- a/storage/src/vespa/storage/distributor/operations/external/updateoperation.h +++ b/storage/src/vespa/storage/distributor/operations/external/updateoperation.h @@ -15,6 +15,8 @@ class UpdateCommand; class CreateBucketReply; } +class UpdateMetricSet; + namespace distributor { class DistributorBucketSpace; @@ -25,7 +27,7 @@ public: UpdateOperation(DistributorComponent& manager, DistributorBucketSpace &bucketSpace, const std::shared_ptr<api::UpdateCommand> & msg, - PersistenceOperationMetricSet& metric); + UpdateMetricSet& metric); void onStart(DistributorMessageSender& sender) override; const char* getName() const override { return "update"; }; @@ -59,6 +61,7 @@ private: }; std::vector<OldTimestamp> _results; + UpdateMetricSet& _metrics; }; } diff --git a/storage/src/vespa/storage/distributor/update_metric_set.cpp b/storage/src/vespa/storage/distributor/update_metric_set.cpp new file mode 100644 index 00000000000..2505003d2ba --- /dev/null +++ b/storage/src/vespa/storage/distributor/update_metric_set.cpp @@ -0,0 +1,34 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "update_metric_set.h" +#include <vespa/metrics/loadmetric.hpp> +#include <vespa/metrics/summetric.hpp> + +namespace storage { + +using metrics::MetricSet; + +UpdateMetricSet::UpdateMetricSet(MetricSet* owner) + : PersistenceOperationMetricSet("updates", owner), + diverging_timestamp_updates("diverging_timestamp_updates", {}, + "Number of updates that report they were performed against " + "divergent version timestamps on different replicas", this) +{ +} + +UpdateMetricSet::~UpdateMetricSet() = default; + +MetricSet * +UpdateMetricSet::clone(std::vector<Metric::UP>& ownerList, CopyType copyType, + MetricSet* owner, bool includeUnused) const +{ + if (copyType == INACTIVE) { + return MetricSet::clone(ownerList, INACTIVE, owner, includeUnused); + } + return static_cast<MetricSet*>((new UpdateMetricSet(owner))->assignValues(*this)); +} + +} + +template class metrics::LoadMetric<storage::UpdateMetricSet>; +template class metrics::SumMetric<storage::UpdateMetricSet>; diff --git a/storage/src/vespa/storage/distributor/update_metric_set.h b/storage/src/vespa/storage/distributor/update_metric_set.h new file mode 100644 index 00000000000..de8474c949c --- /dev/null +++ b/storage/src/vespa/storage/distributor/update_metric_set.h @@ -0,0 +1,23 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "persistence_operation_metric_set.h" +#include <vespa/documentapi/loadtypes/loadtypeset.h> +#include <vespa/metrics/metrics.h> + +namespace storage { + +class UpdateMetricSet : public PersistenceOperationMetricSet { +public: + metrics::LongCountMetric diverging_timestamp_updates; + + explicit UpdateMetricSet(MetricSet* owner = nullptr); + ~UpdateMetricSet() override; + + MetricSet* clone(std::vector<Metric::UP>& ownerList, CopyType copyType, + MetricSet* owner, bool includeUnused) const override; +}; + +} // storage + + diff --git a/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp b/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp index 0a172ccd4e2..2f66027ccab 100644 --- a/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp +++ b/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp @@ -21,8 +21,7 @@ StorageComponentRegisterImpl::StorageComponentRegisterImpl() _bucketIdFactory(), _distribution(), _nodeStateUpdater(nullptr), - _bucketSpacesConfig(), - _enableMultipleBucketSpaces(false) + _bucketSpacesConfig() { } @@ -43,7 +42,6 @@ StorageComponentRegisterImpl::registerStorageComponent(StorageComponent& smc) smc.setPriorityConfig(_priorityConfig); smc.setBucketIdFactory(_bucketIdFactory); smc.setDistribution(_distribution); - smc.enableMultipleBucketSpaces(_enableMultipleBucketSpaces); } void @@ -133,19 +131,4 @@ StorageComponentRegisterImpl::setBucketSpacesConfig(const BucketspacesConfig& co _bucketSpacesConfig = config; } -void StorageComponentRegisterImpl::setEnableMultipleBucketSpaces(bool enabled) { - vespalib::LockGuard lock(_componentLock); - assert(!_enableMultipleBucketSpaces); // Cannot disable once enabled. - _enableMultipleBucketSpaces = enabled; - for (auto& component : _components) { - component->enableMultipleBucketSpaces(_enableMultipleBucketSpaces); - } -} - -bool StorageComponentRegisterImpl::enableMultipleBucketSpaces() const { - // We allow reading this outside _componentLock, as it should never be written - // again after startup. - return _enableMultipleBucketSpaces; -} - } // storage diff --git a/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h b/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h index 1621f4e71f3..2f8bfaab87b 100644 --- a/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h +++ b/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h @@ -36,7 +36,6 @@ class StorageComponentRegisterImpl lib::Distribution::SP _distribution; NodeStateUpdater* _nodeStateUpdater; BucketspacesConfig _bucketSpacesConfig; - bool _enableMultipleBucketSpaces; public: typedef std::unique_ptr<StorageComponentRegisterImpl> UP; @@ -67,10 +66,6 @@ public: virtual void setBucketIdFactory(const document::BucketIdFactory&); virtual void setDistribution(lib::Distribution::SP); virtual void setBucketSpacesConfig(const BucketspacesConfig&); - - virtual void setEnableMultipleBucketSpaces(bool enabled); // To be called during startup configuration phase only. - bool enableMultipleBucketSpaces() const; - }; } // storage diff --git a/storage/src/vespa/storage/storageserver/communicationmanager.cpp b/storage/src/vespa/storage/storageserver/communicationmanager.cpp index eee688b1fb2..5a1c2aed113 100644 --- a/storage/src/vespa/storage/storageserver/communicationmanager.cpp +++ b/storage/src/vespa/storage/storageserver/communicationmanager.cpp @@ -439,8 +439,7 @@ void CommunicationManager::configure(std::unique_ptr<CommunicationManagerConfig> _mbus = std::make_unique<mbus::RPCMessageBus>( mbus::ProtocolSet() .add(std::make_shared<documentapi::DocumentProtocol>(*_component.getLoadTypes(), _component.getTypeRepo())) - .add(std::make_shared<mbusprot::StorageProtocol>(_component.getTypeRepo(), *_component.getLoadTypes(), - _component.enableMultipleBucketSpaces())), + .add(std::make_shared<mbusprot::StorageProtocol>(_component.getTypeRepo(), *_component.getLoadTypes())), params, _configUri); @@ -781,7 +780,7 @@ void CommunicationManager::updateMessagebusProtocol( auto newDocumentProtocol = std::make_shared<documentapi::DocumentProtocol>(*_component.getLoadTypes(), repo); std::lock_guard<std::mutex> guard(_earlierGenerationsLock); _earlierGenerations.push_back(std::make_pair(now, _mbus->getMessageBus().putProtocol(newDocumentProtocol))); - auto newStorageProtocol = std::make_shared<mbusprot::StorageProtocol>(repo, *_component.getLoadTypes(), _component.enableMultipleBucketSpaces()); + auto newStorageProtocol = std::make_shared<mbusprot::StorageProtocol>(repo, *_component.getLoadTypes()); _earlierGenerations.push_back(std::make_pair(now, _mbus->getMessageBus().putProtocol(newStorageProtocol))); } } diff --git a/storage/src/vespa/storage/storageserver/storagenode.cpp b/storage/src/vespa/storage/storageserver/storagenode.cpp index ed33f3846c1..d159b6e5bdd 100644 --- a/storage/src/vespa/storage/storageserver/storagenode.cpp +++ b/storage/src/vespa/storage/storageserver/storagenode.cpp @@ -146,11 +146,6 @@ StorageNode::initialize() // and store them away, while having the config lock. subscribeToConfigs(); - // Multiple bucket spaces can only be enabled on startup and cannot be live reconfigured. - // A process restart is required to either enable or disable after the fact. - // TODO ensure config is tagged as 'restart' as a consequence - _context.getComponentRegister().setEnableMultipleBucketSpaces(_bucketSpacesConfig->enableMultipleBucketSpaces); - updateUpgradeFlag(*_clusterConfig); // First update some basics that doesn't depend on anything else to be diff --git a/storageapi/src/tests/mbusprot/storageprotocoltest.cpp b/storageapi/src/tests/mbusprot/storageprotocoltest.cpp index da7e8cb743e..f634667afd5 100644 --- a/storageapi/src/tests/mbusprot/storageprotocoltest.cpp +++ b/storageapi/src/tests/mbusprot/storageprotocoltest.cpp @@ -56,7 +56,7 @@ struct StorageProtocolTest : public CppUnit::TestFixture { _testDoc(_docMan.createDocument()), _testDocId(_testDoc->getId()), _bucket(makeDocumentBucket(document::BucketId(16, 0x51))), - _protocol(_docMan.getTypeRepoSP(), _loadTypes, true) + _protocol(_docMan.getTypeRepoSP(), _loadTypes) { _loadTypes.addLoadType(34, "foo", documentapi::Priority::PRI_NORMAL_2); } diff --git a/storageapi/src/vespa/storageapi/mbusprot/storageprotocol.cpp b/storageapi/src/vespa/storageapi/mbusprot/storageprotocol.cpp index f83188f7dd8..7e6be0a84f5 100644 --- a/storageapi/src/vespa/storageapi/mbusprot/storageprotocol.cpp +++ b/storageapi/src/vespa/storageapi/mbusprot/storageprotocol.cpp @@ -16,13 +16,11 @@ namespace storage::mbusprot { mbus::string StorageProtocol::NAME = "StorageProtocol"; StorageProtocol::StorageProtocol(const std::shared_ptr<const document::DocumentTypeRepo> repo, - const documentapi::LoadTypeSet& loadTypes, - bool configForcedBucketSpaceSerialization) + const documentapi::LoadTypeSet& loadTypes) : _serializer5_0(repo, loadTypes), _serializer5_1(repo, loadTypes), _serializer5_2(repo, loadTypes), - _serializer6_0(repo, loadTypes), - _configForcedBucketSpaceSerialization(configForcedBucketSpaceSerialization) + _serializer6_0(repo, loadTypes) { } @@ -106,14 +104,10 @@ StorageProtocol::encode(const vespalib::Version& version, } else if (version < version5_2) { return encodeMessage(_serializer5_1, routable, message, version5_1, version); } else { - if (_configForcedBucketSpaceSerialization) { - return encodeMessage(_serializer6_0, routable, message, version6_0, version); + if (version < version6_0) { + return encodeMessage(_serializer5_2, routable, message, version5_2, version); } else { - if (version < version6_0) { - return encodeMessage(_serializer5_2, routable, message, version5_2, version); - } else { - return encodeMessage(_serializer6_0, routable, message, version6_0, version); - } + return encodeMessage(_serializer6_0, routable, message, version6_0, version); } } @@ -184,14 +178,10 @@ StorageProtocol::decode(const vespalib::Version & version, } else if (version < version5_2) { return decodeMessage(_serializer5_1, data, type, version5_1, version); } else { - if (_configForcedBucketSpaceSerialization) { - return decodeMessage(_serializer6_0, data, type, version6_0, version); + if (version < version6_0) { + return decodeMessage(_serializer5_2, data, type, version5_2, version); } else { - if (version < version6_0) { - return decodeMessage(_serializer5_2, data, type, version5_2, version); - } else { - return decodeMessage(_serializer6_0, data, type, version6_0, version); - } + return decodeMessage(_serializer6_0, data, type, version6_0, version); } } } catch (std::exception & e) { diff --git a/storageapi/src/vespa/storageapi/mbusprot/storageprotocol.h b/storageapi/src/vespa/storageapi/mbusprot/storageprotocol.h index 56f271db1d0..1acd7c9675f 100644 --- a/storageapi/src/vespa/storageapi/mbusprot/storageprotocol.h +++ b/storageapi/src/vespa/storageapi/mbusprot/storageprotocol.h @@ -15,8 +15,7 @@ public: static mbus::string NAME; StorageProtocol(const std::shared_ptr<const document::DocumentTypeRepo>, - const documentapi::LoadTypeSet& loadTypes, - bool activateBucketSpaceSerialization = false); + const documentapi::LoadTypeSet& loadTypes); ~StorageProtocol(); const mbus::string& getName() const override { return NAME; } @@ -29,7 +28,6 @@ private: ProtocolSerialization5_1 _serializer5_1; ProtocolSerialization5_2 _serializer5_2; ProtocolSerialization6_0 _serializer6_0; - bool _configForcedBucketSpaceSerialization; }; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 483ccd330e0..1ee22c69c23 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -230,7 +230,7 @@ public interface Tensor { * @return the tensor on the standard string format */ static String toStandardString(Tensor tensor) { - if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) // explicitly output type TODO: Never do that? + if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) // explicitly output type TODO: Always do that return tensor.type() + ":" + contentToString(tensor); else return contentToString(tensor); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 000f33696f2..fa32d385004 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -76,32 +76,41 @@ class TensorParser { } private static Tensor fromCellString(Tensor.Builder builder, String s) { - s = s.trim().substring(1).trim(); - while (s.length() > 1) { - int keyOrTensorEnd = s.indexOf('}'); + int index = 1; + index = skipSpace(index, s); + while (index + 1 < s.length()) { + int keyOrTensorEnd = s.indexOf('}', index); TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type()); - if (keyOrTensorEnd < s.length() - 1) { // Key end: This has a key - otherwise TensorAdress is empty - addLabels(s.substring(0, keyOrTensorEnd + 1), addressBuilder); - s = s.substring(keyOrTensorEnd + 1).trim(); - if ( ! s.startsWith(":")) - throw new IllegalArgumentException("Expecting a ':' after " + s + ", got '" + s + "'"); - s = s.substring(1); + if (keyOrTensorEnd < s.length() - 1) { // Key end: This has a key - otherwise TensorAddress is empty + addLabels(s.substring(index, keyOrTensorEnd + 1), addressBuilder); + index = keyOrTensorEnd + 1; + index = skipSpace(index, s); + if ( s.charAt(index) != ':') + throw new IllegalArgumentException("Expecting a ':' after " + s.substring(index) + ", got '" + s + "'"); + index++; } - int valueEnd = s.indexOf(','); + int valueEnd = s.indexOf(',', index); if (valueEnd < 0) { // last value - valueEnd = s.indexOf("}"); + valueEnd = s.indexOf('}', index); if (valueEnd < 0) throw new IllegalArgumentException("A tensor string must end by '}'"); } TensorAddress address = addressBuilder.build(); - Double value = asDouble(address, s.substring(0, valueEnd).trim()); + Double value = asDouble(address, s.substring(index, valueEnd).trim()); builder.cell(address, value); - s = s.substring(valueEnd+1).trim(); + index = valueEnd+1; + index = skipSpace(index, s); } return builder.build(); } + private static int skipSpace(int index, String s) { + while (index < s.length() && s.charAt(index) == ' ') + index++; + return index; + } + /** Creates a tenor address from a string on the form {dimension1:label1,dimension2:label2,...} */ private static void addLabels(String mapAddressString, TensorAddress.Builder builder) { mapAddressString = mapAddressString.trim(); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 38a8329bff1..122b6019884 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -101,7 +101,7 @@ public class TensorTestCase { " {x:0,y:0,z:1}:1, {x:0,y:1,z:1}:2, {x:1,y:0,z:1}:2, {x:1,y:1,z:1}:3, {x:2,y:0,z:1}:3, {x:2,y:1,z:1}:4 }"), Tensor.range(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build())); assertEquals(Tensor.from("{ {x:0,y:0,z:0}:1, {x:0,y:1,z:0}:0, {x:1,y:0,z:0}:0, {x:1,y:1,z:0}:0, {x:2,y:0,z:0}:0, {x:2,y:1,z:0}:0, "+ - " {x:0,y:0,z:1}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:1}:1, {x:2,y:0,z:1}:0, {x:2,y:1,z:1}:00 }"), + " {x:0,y:0,z:1}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:1}:1, {x:2,y:0,z:1}:0, {x:2,y:1,z:1}:00 } "), Tensor.diag(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build())); assertEquals(Tensor.from("{ {x:1}:0, {x:3}:1, {x:9}:0 }"), Tensor.from("{ {x:1}:1, {x:3}:5, {x:9}:3 }").argmax("x")); } diff --git a/yolean/src/main/java/com/yahoo/yolean/Exceptions.java b/yolean/src/main/java/com/yahoo/yolean/Exceptions.java index 82677a14242..fa3eb412016 100644 --- a/yolean/src/main/java/com/yahoo/yolean/Exceptions.java +++ b/yolean/src/main/java/com/yahoo/yolean/Exceptions.java @@ -3,6 +3,9 @@ package com.yahoo.yolean; import java.io.IOException; import java.io.UncheckedIOException; +import java.nio.file.NoSuchFileException; +import java.util.Optional; +import java.util.function.Supplier; /** * Helper methods for handling exceptions @@ -70,6 +73,29 @@ public class Exceptions { } } + /** Similar to uncheck(), except an exceptionToIgnore exception is silently ignored. */ + public static <T extends IOException> void uncheckAndIgnore(RunnableThrowingIOException runnable, Class<T> exceptionToIgnore) { + try { + runnable.run(); + } catch (UncheckedIOException e) { + IOException cause = e.getCause(); + if (cause == null) throw e; + try { + cause.getClass().asSubclass(exceptionToIgnore); + } catch (ClassCastException f) { + throw e; + } + // Do nothing - OK + } catch (IOException e) { + try { + e.getClass().asSubclass(exceptionToIgnore); + } catch (ClassCastException f) { + throw new UncheckedIOException(e); + } + // Do nothing - OK + } + } + @FunctionalInterface public interface RunnableThrowingIOException { void run() throws IOException; @@ -98,6 +124,29 @@ public class Exceptions { } } + /** Similar to uncheck(), except null is returned if exceptionToIgnore is thrown. */ + public static <R, T extends IOException> R uncheckAndIgnore(SupplierThrowingIOException<R> supplier, Class<T> exceptionToIgnore) { + try { + return supplier.get(); + } catch (UncheckedIOException e) { + IOException cause = e.getCause(); + if (cause == null) throw e; + try { + cause.getClass().asSubclass(exceptionToIgnore); + } catch (ClassCastException f) { + throw e; + } + return null; + } catch (IOException e) { + try { + e.getClass().asSubclass(exceptionToIgnore); + } catch (ClassCastException f) { + throw new UncheckedIOException(e); + } + return null; + } + } + @FunctionalInterface public interface SupplierThrowingIOException<T> { T get() throws IOException; diff --git a/yolean/src/test/java/com/yahoo/yolean/ExceptionsTestCase.java b/yolean/src/test/java/com/yahoo/yolean/ExceptionsTestCase.java index 31e27fa2675..390018efdf7 100644 --- a/yolean/src/test/java/com/yahoo/yolean/ExceptionsTestCase.java +++ b/yolean/src/test/java/com/yahoo/yolean/ExceptionsTestCase.java @@ -5,8 +5,11 @@ import org.junit.Test; import java.io.IOException; import java.io.UncheckedIOException; +import java.nio.file.NoSuchFileException; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; /** * @author bratseth @@ -27,35 +30,38 @@ public class ExceptionsTestCase { @Test public void testUnchecks() { try { - Exceptions.uncheck(this::throwIO); + Exceptions.uncheck(this::throwNoSuchFileException); } catch (UncheckedIOException e) { - assertEquals("root cause", e.getCause().getMessage()); + assertEquals("filename", e.getCause().getMessage()); } try { - Exceptions.uncheck(this::throwIO, "additional %s", "info"); + Exceptions.uncheck(this::throwNoSuchFileException, "additional %s", "info"); } catch (UncheckedIOException e) { assertEquals("additional info", e.getMessage()); } try { - int i = Exceptions.uncheck(this::throwIOWithReturnValue); + int i = Exceptions.uncheck(this::throwNoSuchFileExceptionSupplier); } catch (UncheckedIOException e) { - assertEquals("root cause", e.getCause().getMessage()); + assertEquals("filename", e.getCause().getMessage()); } try { - int i = Exceptions.uncheck(this::throwIOWithReturnValue, "additional %s", "info"); + int i = Exceptions.uncheck(this::throwNoSuchFileExceptionSupplier, "additional %s", "info"); } catch (UncheckedIOException e) { assertEquals("additional info", e.getMessage()); } + + Exceptions.uncheckAndIgnore(this::throwNoSuchFileException, NoSuchFileException.class); + assertNull(Exceptions.uncheckAndIgnore(this::throwNoSuchFileExceptionSupplier, NoSuchFileException.class)); } - private void throwIO() throws IOException { - throw new IOException("root cause"); + private void throwNoSuchFileException() throws IOException { + throw new NoSuchFileException("filename"); } - private int throwIOWithReturnValue() throws IOException { - throw new IOException("root cause"); + private int throwNoSuchFileExceptionSupplier() throws IOException { + throw new NoSuchFileException("filename"); } } |