diff options
45 files changed, 1110 insertions, 461 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java index 9cd7fb24e42..2790f2ddf6e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java @@ -24,16 +24,18 @@ public class TensorFieldProcessor extends Processor { @Override public void process(boolean validate, boolean documentsOnly) { - if ( ! validate) return; - for (var field : search.allConcreteFields()) { if ( field.getDataType() instanceof TensorDataType ) { - validateIndexingScripsForTensorField(field); - validateAttributeSettingForTensorField(field); - processIndexSettingsForTensorField(field); + if (validate) { + validateIndexingScripsForTensorField(field); + validateAttributeSettingForTensorField(field); + } + processIndexSettingsForTensorField(field, validate); } else if (field.getDataType() instanceof CollectionDataType){ - validateDataTypeForCollectionField(field); + if (validate) { + validateDataTypeForCollectionField(field); + } } } } @@ -68,12 +70,12 @@ public class TensorFieldProcessor extends Processor { } } - private void processIndexSettingsForTensorField(SDField field) { + private void processIndexSettingsForTensorField(SDField field, boolean validate) { if (!field.doesIndexing()) { return; } if (isTensorTypeThatSupportsHnswIndex(field)) { - if (!field.doesAttributing()) { + if (validate && !field.doesAttributing()) { fail(search, field, "A tensor that has an index must also be an attribute."); } var index = field.getIndex(field.getName()); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsConsumersTest.java b/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsConsumersTest.java new file mode 100644 index 00000000000..b441f1e1993 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsConsumersTest.java @@ -0,0 +1,244 @@ +package com.yahoo.vespa.model.admin.metricsproxy; + +import ai.vespa.metricsproxy.core.ConsumersConfig; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.admin.monitoring.Metric; +import com.yahoo.vespa.model.admin.monitoring.MetricSet; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.TestMode.hosted; +import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.TestMode.self_hosted; +import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.checkMetric; +import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.consumersConfigFromModel; +import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.consumersConfigFromXml; +import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.getCustomConsumer; +import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.getModel; +import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.servicesWithAdminOnly; +import static com.yahoo.vespa.model.admin.monitoring.DefaultPublicConsumer.DEFAULT_PUBLIC_CONSUMER_ID; +import static com.yahoo.vespa.model.admin.monitoring.DefaultPublicMetrics.defaultPublicMetricSet; +import static com.yahoo.vespa.model.admin.monitoring.DefaultVespaMetrics.defaultVespaMetricSet; +import static com.yahoo.vespa.model.admin.monitoring.NetworkMetrics.networkMetricSet; +import static com.yahoo.vespa.model.admin.monitoring.SystemMetrics.systemMetricSet; +import static com.yahoo.vespa.model.admin.monitoring.VespaMetricSet.vespaMetricSet; +import static com.yahoo.vespa.model.admin.monitoring.VespaMetricsConsumer.VESPA_CONSUMER_ID; +import static java.util.Collections.singleton; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Tests for {@link MetricsProxyContainerCluster} related to metrics consumers. + * + * @author gjoranv + */ +public class MetricsConsumersTest { + + private static int numPublicDefaultMetrics = defaultPublicMetricSet.getMetrics().size(); + private static int numDefaultVespaMetrics = defaultVespaMetricSet.getMetrics().size(); + private static int numVespaMetrics = vespaMetricSet.getMetrics().size(); + private static int numSystemMetrics = systemMetricSet.getMetrics().size(); + private static int numNetworkMetrics = networkMetricSet.getMetrics().size(); + private static int numMetricsForVespaConsumer = numVespaMetrics + numSystemMetrics + numNetworkMetrics; + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void default_public_consumer_is_set_up_for_self_hosted() { + ConsumersConfig config = consumersConfigFromXml(servicesWithAdminOnly(), self_hosted); + assertEquals(2, config.consumer().size()); + assertEquals(config.consumer(1).name(), DEFAULT_PUBLIC_CONSUMER_ID); + + int numMetricsForPublicDefaultConsumer = defaultPublicMetricSet.getMetrics().size() + numSystemMetrics; + assertEquals(numMetricsForPublicDefaultConsumer, config.consumer(1).metric().size()); + } + + @Test + public void vespa_consumer_and_default_public_consumer_is_set_up_for_hosted() { + ConsumersConfig config = consumersConfigFromXml(servicesWithAdminOnly(), hosted); + assertEquals(2, config.consumer().size()); + assertEquals(config.consumer(0).name(), VESPA_CONSUMER_ID); + assertEquals(config.consumer(1).name(), DEFAULT_PUBLIC_CONSUMER_ID); + } + + @Test + public void vespa_consumer_is_always_present_and_has_all_vespa_metrics_and_all_system_metrics() { + ConsumersConfig config = consumersConfigFromXml(servicesWithAdminOnly(), self_hosted); + assertEquals(config.consumer(0).name(), VESPA_CONSUMER_ID); + assertEquals(numMetricsForVespaConsumer, config.consumer(0).metric().size()); + } + + @Test + public void vespa_consumer_can_be_amended_via_admin_object() { + VespaModel model = getModel(servicesWithAdminOnly(), self_hosted); + var additionalMetric = new Metric("additional-metric"); + model.getAdmin().setAdditionalDefaultMetrics(new MetricSet("amender-metrics", singleton(additionalMetric))); + + ConsumersConfig config = consumersConfigFromModel(model); + assertEquals(numMetricsForVespaConsumer + 1, config.consumer(0).metric().size()); + + ConsumersConfig.Consumer vespaConsumer = config.consumer(0); + assertTrue("Did not contain additional metric", checkMetric(vespaConsumer, additionalMetric)); + } + + @Test + public void vespa_is_a_reserved_consumer_id() { + assertReservedConsumerId("Vespa"); + } + + @Test + public void default_is_a_reserved_consumer_id() { + assertReservedConsumerId("default"); + } + + private void assertReservedConsumerId(String consumerId) { + String services = String.join("\n", + "<services>", + " <admin version='2.0'>", + " <adminserver hostalias='node1'/>", + " <metrics>", + " <consumer id='" + consumerId + "'/>", + " </metrics>", + " </admin>", + "</services>" + ); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("'" + consumerId + "' is not allowed as metrics consumer id"); + consumersConfigFromXml(services, self_hosted); + } + + @Test + public void vespa_consumer_id_is_allowed_for_hosted_infrastructure_applications() { + String services = String.join("\n", + "<services application-type='hosted-infrastructure'>", + " <admin version='4.0'>", + " <adminserver hostalias='node1'/>", + " <metrics>", + " <consumer id='Vespa'>", + " <metric id='custom.metric1'/>", + " </consumer>", + " </metrics>", + " </admin>", + "</services>" + ); + VespaModel hostedModel = getModel(services, hosted); + ConsumersConfig config = consumersConfigFromModel(hostedModel); + assertEquals(2, config.consumer().size()); + + // All default metrics are retained + ConsumersConfig.Consumer vespaConsumer = config.consumer(0); + assertEquals(numMetricsForVespaConsumer + 1, vespaConsumer.metric().size()); + + Metric customMetric1 = new Metric("custom.metric1"); + assertTrue("Did not contain metric: " + customMetric1, checkMetric(vespaConsumer, customMetric1)); + } + + @Test + public void consumer_id_is_case_insensitive() { + String services = String.join("\n", + "<services>", + " <admin version='2.0'>", + " <adminserver hostalias='node1'/>", + " <metrics>", + " <consumer id='A'/>", + " <consumer id='a'/>", + " </metrics>", + " </admin>", + "</services>" + ); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("'a' is used as id for two metrics consumers"); + consumersConfigFromXml(services, self_hosted); + } + + @Test + public void non_existent_metric_set_causes_exception() { + String services = String.join("\n", + "<services>", + " <admin version='2.0'>", + " <adminserver hostalias='node1'/>", + " <metrics>", + " <consumer id='consumer-with-non-existent-default-set'>", + " <metric-set id='non-existent'/>", + " </consumer>", + " </metrics>", + " </admin>", + "</services>" + ); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("No such metric-set: non-existent"); + consumersConfigFromXml(services, self_hosted); + } + + @Test + public void consumer_with_no_metric_set_has_its_own_metrics_plus_system_metrics_plus_default_vespa_metrics() { + String services = String.join("\n", + "<services>", + " <admin version='2.0'>", + " <adminserver hostalias='node1'/>", + " <metrics>", + " <consumer id='consumer-with-metrics-only'>", + " <metric id='custom.metric1'/>", + " <metric id='custom.metric2'/>", + " </consumer>", + " </metrics>", + " </admin>", + "</services>" + ); + ConsumersConfig.Consumer consumer = getCustomConsumer(services); + + assertEquals(numSystemMetrics + numDefaultVespaMetrics + 2, consumer.metric().size()); + + Metric customMetric1 = new Metric("custom.metric1"); + Metric customMetric2 = new Metric("custom.metric2"); + assertTrue("Did not contain metric: " + customMetric1, checkMetric(consumer, customMetric1)); + assertTrue("Did not contain metric: " + customMetric2, checkMetric(consumer, customMetric2)); + } + + @Test + public void consumer_with_default_public_metric_set_has_all_public_metrics_plus_all_system_metrics_plus_its_own() { + String services = String.join("\n", + "<services>", + " <admin version='2.0'>", + " <adminserver hostalias='node1'/>", + " <metrics>", + " <consumer id='consumer-with-public-default-set'>", + " <metric-set id='public'/>", + " <metric id='custom.metric'/>", + " </consumer>", + " </metrics>", + " </admin>", + "</services>" + ); + ConsumersConfig.Consumer consumer = getCustomConsumer(services); + + assertEquals(numPublicDefaultMetrics + numSystemMetrics + 1, consumer.metric().size()); + + Metric customMetric = new Metric("custom.metric"); + assertTrue("Did not contain metric: " + customMetric, checkMetric(consumer, customMetric)); + } + + @Test + public void consumer_with_vespa_metric_set_has_all_vespa_metrics_plus_all_system_metrics_plus_its_own() { + String services = String.join("\n", + "<services>", + " <admin version='2.0'>", + " <adminserver hostalias='node1'/>", + " <metrics>", + " <consumer id='consumer-with-vespa-set'>", + " <metric-set id='vespa'/>", + " <metric id='my.extra.metric'/>", + " </consumer>", + " </metrics>", + " </admin>", + "</services>" + ); + ConsumersConfig.Consumer consumer = getCustomConsumer(services); + assertEquals(numVespaMetrics + numSystemMetrics + 1, consumer.metric().size()); + + Metric customMetric = new Metric("my.extra.metric"); + assertTrue("Did not contain metric: " + customMetric, checkMetric(consumer, customMetric)); + } + +} diff --git a/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyContainerClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyContainerClusterTest.java index e1b42854642..bed77bd5c77 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyContainerClusterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyContainerClusterTest.java @@ -5,7 +5,6 @@ package com.yahoo.vespa.model.admin.metricsproxy; -import ai.vespa.metricsproxy.core.ConsumersConfig; import ai.vespa.metricsproxy.http.application.ApplicationMetricsHandler; import ai.vespa.metricsproxy.http.application.MetricsNodesConfig; import ai.vespa.metricsproxy.http.metrics.MetricsV1Handler; @@ -21,13 +20,9 @@ import com.yahoo.container.core.ApplicationMetadataConfig; import com.yahoo.search.config.QrStartConfig; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyContainerCluster.AppDimensionNames; -import com.yahoo.vespa.model.admin.monitoring.Metric; -import com.yahoo.vespa.model.admin.monitoring.MetricSet; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.Handler; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import java.util.Collection; @@ -39,22 +34,11 @@ import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.M import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.MY_TENANT; import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.TestMode.hosted; import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.TestMode.self_hosted; -import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.checkMetric; -import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.consumersConfigFromModel; -import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.consumersConfigFromXml; import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.getApplicationDimensionsConfig; -import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.getCustomConsumer; import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.getMetricsNodesConfig; import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.getModel; import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.getQrStartConfig; -import static com.yahoo.vespa.model.admin.monitoring.DefaultPublicConsumer.DEFAULT_PUBLIC_CONSUMER_ID; -import static com.yahoo.vespa.model.admin.monitoring.DefaultPublicMetrics.defaultPublicMetricSet; -import static com.yahoo.vespa.model.admin.monitoring.DefaultVespaMetrics.defaultVespaMetricSet; -import static com.yahoo.vespa.model.admin.monitoring.NetworkMetrics.networkMetricSet; -import static com.yahoo.vespa.model.admin.monitoring.SystemMetrics.systemMetricSet; -import static com.yahoo.vespa.model.admin.monitoring.VespaMetricSet.vespaMetricSet; -import static com.yahoo.vespa.model.admin.monitoring.VespaMetricsConsumer.VESPA_CONSUMER_ID; -import static java.util.Collections.singleton; +import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.servicesWithAdminOnly; import static java.util.stream.Collectors.toList; import static org.hamcrest.CoreMatchers.endsWith; import static org.hamcrest.CoreMatchers.hasItem; @@ -69,16 +53,6 @@ import static org.junit.Assert.assertTrue; */ public class MetricsProxyContainerClusterTest { - private static int numPublicDefaultMetrics = defaultPublicMetricSet.getMetrics().size(); - private static int numDefaultVespaMetrics = defaultVespaMetricSet.getMetrics().size(); - private static int numVespaMetrics = vespaMetricSet.getMetrics().size(); - private static int numSystemMetrics = systemMetricSet.getMetrics().size(); - private static int numNetworkMetrics = networkMetricSet.getMetrics().size(); - private static int numMetricsForVespaConsumer = numVespaMetrics + numSystemMetrics + numNetworkMetrics; - - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Test public void metrics_proxy_bundle_is_included_in_bundles_config() { VespaModel model = getModel(servicesWithAdminOnly(), self_hosted); @@ -134,203 +108,6 @@ public class MetricsProxyContainerClusterTest { } @Test - public void default_public_consumer_is_set_up_for_self_hosted() { - ConsumersConfig config = consumersConfigFromXml(servicesWithAdminOnly(), self_hosted); - assertEquals(2, config.consumer().size()); - assertEquals(config.consumer(1).name(), DEFAULT_PUBLIC_CONSUMER_ID); - - int numMetricsForPublicDefaultConsumer = defaultPublicMetricSet.getMetrics().size() + numSystemMetrics; - assertEquals(numMetricsForPublicDefaultConsumer, config.consumer(1).metric().size()); - } - - @Test - public void vespa_consumer_and_default_public_consumer_is_set_up_for_hosted() { - ConsumersConfig config = consumersConfigFromXml(servicesWithAdminOnly(), hosted); - assertEquals(2, config.consumer().size()); - assertEquals(config.consumer(0).name(), VESPA_CONSUMER_ID); - assertEquals(config.consumer(1).name(), DEFAULT_PUBLIC_CONSUMER_ID); - } - - @Test - public void vespa_consumer_is_always_present_and_has_all_vespa_metrics_and_all_system_metrics() { - ConsumersConfig config = consumersConfigFromXml(servicesWithAdminOnly(), self_hosted); - assertEquals(config.consumer(0).name(), VESPA_CONSUMER_ID); - assertEquals(numMetricsForVespaConsumer, config.consumer(0).metric().size()); - } - - @Test - public void vespa_consumer_can_be_amended_via_admin_object() { - VespaModel model = getModel(servicesWithAdminOnly(), self_hosted); - var additionalMetric = new Metric("additional-metric"); - model.getAdmin().setAdditionalDefaultMetrics(new MetricSet("amender-metrics", singleton(additionalMetric))); - - ConsumersConfig config = consumersConfigFromModel(model); - assertEquals(numMetricsForVespaConsumer + 1, config.consumer(0).metric().size()); - - ConsumersConfig.Consumer vespaConsumer = config.consumer(0); - assertTrue("Did not contain additional metric", checkMetric(vespaConsumer, additionalMetric)); - } - - @Test - public void vespa_is_a_reserved_consumer_id() { - assertReservedConsumerId("Vespa"); - } - - @Test - public void default_is_a_reserved_consumer_id() { - assertReservedConsumerId("default"); - } - - private void assertReservedConsumerId(String consumerId) { - String services = String.join("\n", - "<services>", - " <admin version='2.0'>", - " <adminserver hostalias='node1'/>", - " <metrics>", - " <consumer id='" + consumerId + "'/>", - " </metrics>", - " </admin>", - "</services>" - ); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("'" + consumerId + "' is not allowed as metrics consumer id"); - consumersConfigFromXml(services, self_hosted); - } - - @Test - public void vespa_consumer_id_is_allowed_for_hosted_infrastructure_applications() { - String services = String.join("\n", - "<services application-type='hosted-infrastructure'>", - " <admin version='4.0'>", - " <adminserver hostalias='node1'/>", - " <metrics>", - " <consumer id='Vespa'>", - " <metric id='custom.metric1'/>", - " </consumer>", - " </metrics>", - " </admin>", - "</services>" - ); - VespaModel hostedModel = getModel(services, hosted); - ConsumersConfig config = consumersConfigFromModel(hostedModel); - assertEquals(2, config.consumer().size()); - - // All default metrics are retained - ConsumersConfig.Consumer vespaConsumer = config.consumer(0); - assertEquals(numMetricsForVespaConsumer + 1, vespaConsumer.metric().size()); - - Metric customMetric1 = new Metric("custom.metric1"); - assertTrue("Did not contain metric: " + customMetric1, checkMetric(vespaConsumer, customMetric1)); - } - - @Test - public void consumer_id_is_case_insensitive() { - String services = String.join("\n", - "<services>", - " <admin version='2.0'>", - " <adminserver hostalias='node1'/>", - " <metrics>", - " <consumer id='A'/>", - " <consumer id='a'/>", - " </metrics>", - " </admin>", - "</services>" - ); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("'a' is used as id for two metrics consumers"); - consumersConfigFromXml(services, self_hosted); - } - - @Test - public void non_existent_metric_set_causes_exception() { - String services = String.join("\n", - "<services>", - " <admin version='2.0'>", - " <adminserver hostalias='node1'/>", - " <metrics>", - " <consumer id='consumer-with-non-existent-default-set'>", - " <metric-set id='non-existent'/>", - " </consumer>", - " </metrics>", - " </admin>", - "</services>" - ); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("No such metric-set: non-existent"); - consumersConfigFromXml(services, self_hosted); - } - - @Test - public void consumer_with_no_metric_set_has_its_own_metrics_plus_system_metrics_plus_default_vespa_metrics() { - String services = String.join("\n", - "<services>", - " <admin version='2.0'>", - " <adminserver hostalias='node1'/>", - " <metrics>", - " <consumer id='consumer-with-metrics-only'>", - " <metric id='custom.metric1'/>", - " <metric id='custom.metric2'/>", - " </consumer>", - " </metrics>", - " </admin>", - "</services>" - ); - ConsumersConfig.Consumer consumer = getCustomConsumer(services); - - assertEquals(numSystemMetrics + numDefaultVespaMetrics + 2, consumer.metric().size()); - - Metric customMetric1 = new Metric("custom.metric1"); - Metric customMetric2 = new Metric("custom.metric2"); - assertTrue("Did not contain metric: " + customMetric1, checkMetric(consumer, customMetric1)); - assertTrue("Did not contain metric: " + customMetric2, checkMetric(consumer, customMetric2)); - } - - @Test - public void consumer_with_default_public_metric_set_has_all_public_metrics_plus_all_system_metrics_plus_its_own() { - String services = String.join("\n", - "<services>", - " <admin version='2.0'>", - " <adminserver hostalias='node1'/>", - " <metrics>", - " <consumer id='consumer-with-public-default-set'>", - " <metric-set id='public'/>", - " <metric id='custom.metric'/>", - " </consumer>", - " </metrics>", - " </admin>", - "</services>" - ); - ConsumersConfig.Consumer consumer = getCustomConsumer(services); - - assertEquals(numPublicDefaultMetrics + numSystemMetrics + 1, consumer.metric().size()); - - Metric customMetric = new Metric("custom.metric"); - assertTrue("Did not contain metric: " + customMetric, checkMetric(consumer, customMetric)); - } - - @Test - public void consumer_with_vespa_metric_set_has_all_vespa_metrics_plus_all_system_metrics_plus_its_own() { - String services = String.join("\n", - "<services>", - " <admin version='2.0'>", - " <adminserver hostalias='node1'/>", - " <metrics>", - " <consumer id='consumer-with-vespa-set'>", - " <metric-set id='vespa'/>", - " <metric id='my.extra.metric'/>", - " </consumer>", - " </metrics>", - " </admin>", - "</services>" - ); - ConsumersConfig.Consumer consumer = getCustomConsumer(services); - assertEquals(numVespaMetrics + numSystemMetrics + 1, consumer.metric().size()); - - Metric customMetric = new Metric("my.extra.metric"); - assertTrue("Did not contain metric: " + customMetric, checkMetric(consumer, customMetric)); - } - - @Test public void hosted_application_propagates_application_dimensions() { VespaModel hostedModel = getModel(servicesWithAdminOnly(), hosted); ApplicationDimensionsConfig config = getApplicationDimensionsConfig(hostedModel); @@ -360,16 +137,6 @@ public class MetricsProxyContainerClusterTest { assertEquals(MetricsV1Handler.VALUES_PATH, node.metricsPath()); } - private static String servicesWithAdminOnly() { - return String.join("\n", - "<services>", - " <admin version='4.0'>", - " <adminserver hostalias='node1'/>", - " </admin>", - "</services>" - ); - } - private static String servicesWithTwoNodes() { return String.join("\n", "<services>", diff --git a/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyModelTester.java index 7cbc9db5eb2..8ecb13d7ae5 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyModelTester.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyModelTester.java @@ -55,6 +55,16 @@ class MetricsProxyModelTester { : CONTAINER_CONFIG_ID; } + static String servicesWithAdminOnly() { + return String.join("\n", + "<services>", + " <admin version='4.0'>", + " <adminserver hostalias='node1'/>", + " </admin>", + "</services>" + ); + } + static boolean checkMetric(ConsumersConfig.Consumer consumer, Metric metric) { for (ConsumersConfig.Consumer.Metric m : consumer.metric()) { if (metric.name.equals(m.name()) && metric.outputName.equals(m.outputname())) @@ -77,32 +87,32 @@ class MetricsProxyModelTester { } static ConsumersConfig consumersConfigFromModel(VespaModel model) { - return new ConsumersConfig((ConsumersConfig.Builder) model.getConfig(new ConsumersConfig.Builder(), CLUSTER_CONFIG_ID)); + return model.getConfig(ConsumersConfig.class, CLUSTER_CONFIG_ID); } static MetricsNodesConfig getMetricsNodesConfig(VespaModel model) { - return new MetricsNodesConfig((MetricsNodesConfig.Builder) model.getConfig(new MetricsNodesConfig.Builder(), CLUSTER_CONFIG_ID)); + return model.getConfig(MetricsNodesConfig.class, CLUSTER_CONFIG_ID); } static ApplicationDimensionsConfig getApplicationDimensionsConfig(VespaModel model) { - return new ApplicationDimensionsConfig((ApplicationDimensionsConfig.Builder) model.getConfig(new ApplicationDimensionsConfig.Builder(), CLUSTER_CONFIG_ID)); + return model.getConfig(ApplicationDimensionsConfig.class, CLUSTER_CONFIG_ID); } static QrStartConfig getQrStartConfig(VespaModel model) { - return new QrStartConfig((QrStartConfig.Builder) model.getConfig(new QrStartConfig.Builder(), CLUSTER_CONFIG_ID)); + return model.getConfig(QrStartConfig.class, CLUSTER_CONFIG_ID); } static NodeDimensionsConfig getNodeDimensionsConfig(VespaModel model, String configId) { - return new NodeDimensionsConfig((NodeDimensionsConfig.Builder) model.getConfig(new NodeDimensionsConfig.Builder(), configId)); + return model.getConfig(NodeDimensionsConfig.class, configId); } static VespaServicesConfig getVespaServicesConfig(String servicesXml) { VespaModel model = getModel(servicesXml, self_hosted); - return new VespaServicesConfig((VespaServicesConfig.Builder) model.getConfig(new VespaServicesConfig.Builder(), CONTAINER_CONFIG_ID)); + return model.getConfig(VespaServicesConfig.class, CONTAINER_CONFIG_ID); } static RpcConnectorConfig getRpcConnectorConfig(VespaModel model) { - return new RpcConnectorConfig((RpcConnectorConfig.Builder) model.getConfig(new RpcConnectorConfig.Builder(), CONTAINER_CONFIG_ID)); + return model.getConfig(RpcConnectorConfig.class, CONTAINER_CONFIG_ID); } } diff --git a/container-disc/src/main/sh/vespa-start-container-daemon.sh b/container-disc/src/main/sh/vespa-start-container-daemon.sh index f097d6d72bc..af429d56a75 100755 --- a/container-disc/src/main/sh/vespa-start-container-daemon.sh +++ b/container-disc/src/main/sh/vespa-start-container-daemon.sh @@ -93,7 +93,7 @@ configure_memory() { memory_options="${memory_options} -XX:MaxDirectMemorySize=${maxDirectMemorySize}m" if ((jvm_compressedClassSpaceSize != 0)); then - memory_options="${memory_options} -XX:CompressedClassSpaceSize=${compressedClassSpaceSize}m" + memory_options="${memory_options} -XX:CompressedClassSpaceSize=${jvm_compressedClassSpaceSize}m" fi if [ "${VESPA_USE_HUGEPAGES}" ]; then diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/certificates/EndpointCertificateMetadata.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/certificates/EndpointCertificateMetadata.java index 0aa0df8ae2b..171c5caa756 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/certificates/EndpointCertificateMetadata.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/certificates/EndpointCertificateMetadata.java @@ -18,25 +18,23 @@ public class EndpointCertificateMetadata { private final int version; private final Optional<String> request_id; private final Optional<List<String>> requestedDnsSans; + private final Optional<String> issuer; public EndpointCertificateMetadata(String keyName, String certName, int version) { - this.keyName = keyName; - this.certName = certName; - this.version = version; - this.request_id = Optional.empty(); - this.requestedDnsSans = Optional.empty(); + this(keyName, certName, version, Optional.empty(), Optional.empty(), Optional.empty()); + } + + public EndpointCertificateMetadata(String keyName, String certName, int version, String request_id, List<String> requestedDnsSans) { + this(keyName, certName, version, Optional.of(request_id), Optional.of(requestedDnsSans), Optional.empty()); } - public EndpointCertificateMetadata(String keyName, String certName, int version, Optional<String> request_id, Optional<List<String>> requestedDnsSans) { + public EndpointCertificateMetadata(String keyName, String certName, int version, Optional<String> request_id, Optional<List<String>> requestedDnsSans, Optional<String> issuer) { this.keyName = keyName; this.certName = certName; this.version = version; this.request_id = request_id; this.requestedDnsSans = requestedDnsSans; - } - - public EndpointCertificateMetadata(String keyName, String certName, int version, String request_id, List<String> requestedDnsSans) { - this(keyName, certName, version, Optional.of(request_id), Optional.of(requestedDnsSans)); + this.issuer = issuer; } public String keyName() { @@ -59,6 +57,10 @@ public class EndpointCertificateMetadata { return requestedDnsSans; } + public Optional<String> issuer() { + return issuer; + } + @Override public String toString() { return "EndpointCertificateMetadata{" + @@ -67,6 +69,7 @@ public class EndpointCertificateMetadata { ", version=" + version + ", request_id=" + request_id + ", requestedDnsSans=" + requestedDnsSans + + ", issuer=" + issuer + '}'; } @@ -79,11 +82,12 @@ public class EndpointCertificateMetadata { keyName.equals(that.keyName) && certName.equals(that.certName) && request_id.equals(that.request_id) && - requestedDnsSans.equals(that.requestedDnsSans); + requestedDnsSans.equals(that.requestedDnsSans) && + issuer.equals(that.issuer); } @Override public int hashCode() { - return Objects.hash(keyName, certName, version, request_id, requestedDnsSans); + return Objects.hash(keyName, certName, version, request_id, requestedDnsSans, issuer); } } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/certificates/EndpointCertificateMock.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/certificates/EndpointCertificateMock.java index 8e81400f3c8..c38ea158507 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/certificates/EndpointCertificateMock.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/certificates/EndpointCertificateMock.java @@ -7,6 +7,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.UUID; /** @@ -21,7 +22,7 @@ public class EndpointCertificateMock implements EndpointCertificateProvider { } @Override - public EndpointCertificateMetadata requestCaSignedCertificate(ApplicationId applicationId, List<String> dnsNames) { + public EndpointCertificateMetadata requestCaSignedCertificate(ApplicationId applicationId, List<String> dnsNames, Optional<EndpointCertificateMetadata> currentMetadata) { this.dnsNames.put(applicationId, dnsNames); String endpointCertificatePrefix = String.format("vespa.tls.%s.%s@%s", applicationId.tenant(), applicationId.application(), diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/certificates/EndpointCertificateProvider.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/certificates/EndpointCertificateProvider.java index 97d2bdb3343..9c5c25c1c71 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/certificates/EndpointCertificateProvider.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/certificates/EndpointCertificateProvider.java @@ -4,6 +4,7 @@ package com.yahoo.vespa.hosted.controller.api.integration.certificates; import com.yahoo.config.provision.ApplicationId; import java.util.List; +import java.util.Optional; /** * Generates an endpoint certificate for an application instance. @@ -12,7 +13,7 @@ import java.util.List; */ public interface EndpointCertificateProvider { - EndpointCertificateMetadata requestCaSignedCertificate(ApplicationId applicationId, List<String> dnsNames); + EndpointCertificateMetadata requestCaSignedCertificate(ApplicationId applicationId, List<String> dnsNames, Optional<EndpointCertificateMetadata> currentMetadata); List<EndpointCertificateMetadata> listCertificates(); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/endpointcertificates/EndpointCertificateManager.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/endpointcertificates/EndpointCertificateManager.java index d915da21603..23a3ffb42b6 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/endpointcertificates/EndpointCertificateManager.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/endpointcertificates/EndpointCertificateManager.java @@ -156,7 +156,8 @@ public class EndpointCertificateManager { storedMetaData.certName(), storedMetaData.version(), providerMetadata.request_id(), - providerMetadata.requestedDnsSans()); + providerMetadata.requestedDnsSans(), + Optional.empty()); if (mode == BackfillMode.DRYRUN) { log.log(LogLevel.INFO, "Would update stored metadata " + storedMetaData + " with data from provider: " + backfilledMetadata); @@ -176,7 +177,7 @@ public class EndpointCertificateManager { private EndpointCertificateMetadata provisionEndpointCertificate(Instance instance) { List<ZoneId> zones = zoneRegistry.zones().controllerUpgraded().zones().stream().map(ZoneApi::getId).collect(Collectors.toUnmodifiableList()); EndpointCertificateMetadata provisionedCertificateMetadata = endpointCertificateProvider - .requestCaSignedCertificate(instance.id(), dnsNamesOf(instance.id(), zones)); + .requestCaSignedCertificate(instance.id(), dnsNamesOf(instance.id(), zones), Optional.empty()); curator.writeEndpointCertificateMetadata(instance.id(), provisionedCertificateMetadata); return provisionedCertificateMetadata; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/CuratorDb.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/CuratorDb.java index ad2835e301f..eb86b1028e2 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/CuratorDb.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/CuratorDb.java @@ -42,7 +42,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.NavigableMap; @@ -521,8 +520,7 @@ public class CuratorDb { } public Optional<EndpointCertificateMetadata> readEndpointCertificateMetadata(ApplicationId applicationId) { - Optional<String> zkData = curator.getData(endpointCertificatePath(applicationId)).map(String::new); - return zkData.map(EndpointCertificateMetadataSerializer::fromJsonOrTlsSecretsKeysString); + return curator.getData(endpointCertificatePath(applicationId)).map(String::new).map(EndpointCertificateMetadataSerializer::fromJsonString); } public Map<ApplicationId, EndpointCertificateMetadata> readAllEndpointCertificateMetadata() { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/EndpointCertificateMetadataSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/EndpointCertificateMetadataSerializer.java index 653f224a02b..501d3a06d42 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/EndpointCertificateMetadataSerializer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/EndpointCertificateMetadataSerializer.java @@ -4,6 +4,7 @@ import com.yahoo.slime.Cursor; import com.yahoo.slime.Inspector; import com.yahoo.slime.Slime; import com.yahoo.slime.SlimeUtils; +import com.yahoo.slime.Type; import com.yahoo.vespa.hosted.controller.api.integration.certificates.EndpointCertificateMetadata; import java.util.List; @@ -33,6 +34,7 @@ public class EndpointCertificateMetadataSerializer { private final static String versionField = "version"; private final static String requestIdField = "requestId"; private final static String requestedDnsSansField = "requestedDnsSans"; + private final static String issuerField = "issuer"; public static Slime toSlime(EndpointCertificateMetadata metadata) { Slime slime = new Slime(); @@ -51,46 +53,31 @@ public class EndpointCertificateMetadataSerializer { } public static EndpointCertificateMetadata fromSlime(Inspector inspector) { - switch (inspector.type()) { - case STRING: // TODO: Remove once all are transmitted and stored as JSON - return new EndpointCertificateMetadata( - inspector.asString() + "-key", - inspector.asString() + "-cert", - 0 - ); - case OBJECT: { - Optional<String> request_id = inspector.field(requestIdField).valid() ? - Optional.of(inspector.field(requestIdField).asString()) : - Optional.empty(); + if (inspector.type() != Type.OBJECT) + throw new IllegalArgumentException("Unknown format encountered for endpoint certificate metadata!"); + Optional<String> request_id = inspector.field(requestIdField).valid() ? + Optional.of(inspector.field(requestIdField).asString()) : + Optional.empty(); - Optional<List<String>> requestedDnsSans = inspector.field(requestedDnsSansField).valid() ? - Optional.of(IntStream.range(0, inspector.field(requestedDnsSansField).entries()) - .mapToObj(i -> inspector.field(requestedDnsSansField).entry(i).asString()).collect(Collectors.toList())) : - Optional.empty(); + Optional<List<String>> requestedDnsSans = inspector.field(requestedDnsSansField).valid() ? + Optional.of(IntStream.range(0, inspector.field(requestedDnsSansField).entries()) + .mapToObj(i -> inspector.field(requestedDnsSansField).entry(i).asString()).collect(Collectors.toList())) : + Optional.empty(); - return new EndpointCertificateMetadata( - inspector.field(keyNameField).asString(), - inspector.field(certNameField).asString(), - Math.toIntExact(inspector.field(versionField).asLong()), - request_id, - requestedDnsSans - ); - } + Optional<String> issuer = inspector.field(issuerField).valid() ? + Optional.of(inspector.field(issuerField).asString()) : + Optional.empty(); - default: - throw new IllegalArgumentException("Unknown format encountered for endpoint certificate metadata!"); - } + return new EndpointCertificateMetadata( + inspector.field(keyNameField).asString(), + inspector.field(certNameField).asString(), + Math.toIntExact(inspector.field(versionField).asLong()), + request_id, + requestedDnsSans, + issuer); } - public static EndpointCertificateMetadata fromTlsSecretsKeysString(String tlsSecretsKeys) { - return fromSlime(new Slime().setString(tlsSecretsKeys)); - } - - public static EndpointCertificateMetadata fromJsonOrTlsSecretsKeysString(String zkdata) { - if (zkdata.strip().startsWith("{")) { - return fromSlime(SlimeUtils.jsonToSlime(zkdata).get()); - } else { - return fromTlsSecretsKeysString(zkdata); - } + public static EndpointCertificateMetadata fromJsonString(String zkdata) { + return fromSlime(SlimeUtils.jsonToSlime(zkdata).get()); } } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/EndpointCertificateMetadataSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/EndpointCertificateMetadataSerializerTest.java index 7428b9901a2..5f8a3eaa98a 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/EndpointCertificateMetadataSerializerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/EndpointCertificateMetadataSerializerTest.java @@ -29,17 +29,10 @@ public class EndpointCertificateMetadataSerializerTest { } @Test - public void deserializeFromString() { - assertEquals( - new EndpointCertificateMetadata("foo-key", "foo-cert", 0), - EndpointCertificateMetadataSerializer.fromJsonOrTlsSecretsKeysString("foo")); - } - - @Test public void deserializeFromJson() { assertEquals( sample, - EndpointCertificateMetadataSerializer.fromJsonOrTlsSecretsKeysString( + EndpointCertificateMetadataSerializer.fromJsonString( "{\"keyName\":\"keyName\",\"certName\":\"certName\",\"version\":1}")); } @@ -47,7 +40,7 @@ public class EndpointCertificateMetadataSerializerTest { public void deserializeFromJsonWithRequestMetadata() { assertEquals( sampleWithRequestMetadata, - EndpointCertificateMetadataSerializer.fromJsonOrTlsSecretsKeysString( + EndpointCertificateMetadataSerializer.fromJsonString( "{\"keyName\":\"keyName\",\"certName\":\"certName\",\"version\":1,\"requestId\":\"requestId\",\"requestedDnsSans\":[\"SAN1\",\"SAN2\"]}")); } }
\ No newline at end of file diff --git a/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java b/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java index a363bf52155..a140e87713c 100644 --- a/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java +++ b/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java @@ -21,7 +21,8 @@ public class TlsCryptoEngine implements CryptoEngine { @Override public TlsCryptoSocket createClientCryptoSocket(SocketChannel channel, Spec spec) { - SSLEngine sslEngine = tlsContext.createSslEngine(spec.host(), spec.port()); + String peerHost = spec.host() != null ? spec.host() : "localhost"; // Use localhost for wildcard address + SSLEngine sslEngine = tlsContext.createSslEngine(peerHost, spec.port()); sslEngine.setUseClientMode(true); return new TlsCryptoSocket(channel, sslEngine); } diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/telegraf/Telegraf.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/telegraf/Telegraf.java index bf4f0d4c49b..1c8401d003a 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/telegraf/Telegraf.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/telegraf/Telegraf.java @@ -11,6 +11,7 @@ import java.io.FileWriter; import java.io.InputStreamReader; import java.io.Reader; import java.io.Writer; +import java.util.logging.Logger; import static com.yahoo.yolean.Exceptions.uncheck; @@ -23,6 +24,8 @@ public class Telegraf extends AbstractComponent { private static final String TELEGRAF_CONFIG_TEMPLATE_PATH = "templates/telegraf.conf.vm"; private final TelegrafRegistry telegrafRegistry; + private static final Logger logger = Logger.getLogger(Telegraf.class.getName()); + @Inject public Telegraf(TelegrafRegistry telegrafRegistry, TelegrafConfig telegrafConfig) { this.telegrafRegistry = telegrafRegistry; @@ -44,10 +47,12 @@ public class Telegraf extends AbstractComponent { } private void restartTelegraf() { + logger.info("Restarting Telegraf"); executeCommand("service telegraf restart"); } private void stopTelegraf() { + logger.info("Stopping Telegraf"); executeCommand("service telegraf stop"); } diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/telegraf/TelegrafRegistry.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/telegraf/TelegrafRegistry.java index 429da5bb933..dbc08a96777 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/telegraf/TelegrafRegistry.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/telegraf/TelegrafRegistry.java @@ -1,9 +1,12 @@ // Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.metricsproxy.telegraf; +import com.yahoo.log.LogLevel; + import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.logging.Logger; /** * @author olaa @@ -12,11 +15,15 @@ public class TelegrafRegistry { private static final List<Telegraf> telegrafInstances = Collections.synchronizedList(new ArrayList<>()); + private static final Logger logger = Logger.getLogger(TelegrafRegistry.class.getName()); + public void addInstance(Telegraf telegraf) { + logger.log(LogLevel.DEBUG, () -> "Adding Telegraf instance to registry: " + telegraf); telegrafInstances.add(telegraf); } public void removeInstance(Telegraf telegraf) { + logger.log(LogLevel.DEBUG, () -> "Removing Telegraf instance from registry: " + telegraf); telegrafInstances.remove(telegraf); } diff --git a/metrics-proxy/src/main/resources/templates/telegraf.conf.vm b/metrics-proxy/src/main/resources/templates/telegraf.conf.vm index ff04dafe276..c427ee1ce4b 100644 --- a/metrics-proxy/src/main/resources/templates/telegraf.conf.vm +++ b/metrics-proxy/src/main/resources/templates/telegraf.conf.vm @@ -8,6 +8,11 @@ flush_interval = "${intervalSeconds}s" flush_jitter = "0s" precision = "" + logtarget = "file" + logfile = "/var/log/telegraf/telegraf.log" + logfile_rotation_interval = "1d" + logfile_rotation_max_size = "20MB" + logfile_rotation_max_archives = 5 #foreach( $cloudwatch in $cloudwatchPlugins ) # Configuration for AWS CloudWatch output. diff --git a/metrics-proxy/src/test/resources/telegraf-config-with-two-cloudwatch-plugins.txt b/metrics-proxy/src/test/resources/telegraf-config-with-two-cloudwatch-plugins.txt index 0dec2775a05..85656465901 100644 --- a/metrics-proxy/src/test/resources/telegraf-config-with-two-cloudwatch-plugins.txt +++ b/metrics-proxy/src/test/resources/telegraf-config-with-two-cloudwatch-plugins.txt @@ -8,6 +8,11 @@ flush_interval = "300s" flush_jitter = "0s" precision = "" + logtarget = "file" + logfile = "/var/log/telegraf/telegraf.log" + logfile_rotation_interval = "1d" + logfile_rotation_max_size = "20MB" + logfile_rotation_max_archives = 5 # Configuration for AWS CloudWatch output. [[outputs.cloudwatch]] diff --git a/searchcommon/src/vespa/searchcommon/attribute/config.cpp b/searchcommon/src/vespa/searchcommon/attribute/config.cpp index 53e57fd9c66..b4e05875820 100644 --- a/searchcommon/src/vespa/searchcommon/attribute/config.cpp +++ b/searchcommon/src/vespa/searchcommon/attribute/config.cpp @@ -17,7 +17,8 @@ Config::Config() : _growStrategy(), _compactionStrategy(), _predicateParams(), - _tensorType(vespalib::eval::ValueType::error_type()) + _tensorType(vespalib::eval::ValueType::error_type()), + _hnsw_index_params() { } @@ -34,7 +35,8 @@ Config::Config(BasicType bt, CollectionType ct, bool fastSearch_, bool huge_) _growStrategy(), _compactionStrategy(), _predicateParams(), - _tensorType(vespalib::eval::ValueType::error_type()) + _tensorType(vespalib::eval::ValueType::error_type()), + _hnsw_index_params() { } @@ -60,7 +62,8 @@ Config::operator==(const Config &b) const _compactionStrategy == b._compactionStrategy && _predicateParams == b._predicateParams && (_basicType.type() != BasicType::Type::TENSOR || - _tensorType == b._tensorType); + _tensorType == b._tensorType) && + _hnsw_index_params == b._hnsw_index_params; } } diff --git a/searchcommon/src/vespa/searchcommon/attribute/config.h b/searchcommon/src/vespa/searchcommon/attribute/config.h index 2f767061f7a..836fcfed84a 100644 --- a/searchcommon/src/vespa/searchcommon/attribute/config.h +++ b/searchcommon/src/vespa/searchcommon/attribute/config.h @@ -4,15 +4,21 @@ #include "basictype.h" #include "collectiontype.h" +#include "hnsw_index_params.h" #include "predicate_params.h" -#include <vespa/searchcommon/common/growstrategy.h> #include <vespa/searchcommon/common/compaction_strategy.h> +#include <vespa/searchcommon/common/growstrategy.h> #include <vespa/eval/eval/value_type.h> +#include <optional> namespace search::attribute { -class Config -{ +/** + * Configuration for an attribute vector. + * + * Used to determine which implementation to instantiate. + */ +class Config { public: Config(); Config(BasicType bt, CollectionType ct = CollectionType::SINGLE, @@ -29,6 +35,7 @@ public: bool huge() const { return _huge; } const PredicateParams &predicateParams() const { return _predicateParams; } vespalib::eval::ValueType tensorType() const { return _tensorType; } + const std::optional<HnswIndexParams>& hnsw_index_params() const { return _hnsw_index_params; } /** * Check if attribute posting list can consist of a bitvector in @@ -60,6 +67,10 @@ public: _tensorType = tensorType_in; return *this; } + Config& set_hnsw_index_params(const HnswIndexParams& params) { + _hnsw_index_params = params; + return *this; + } /** * Enable attribute posting list to consist of a bitvector in @@ -107,6 +118,7 @@ private: CompactionStrategy _compactionStrategy; PredicateParams _predicateParams; vespalib::eval::ValueType _tensorType; + std::optional<HnswIndexParams> _hnsw_index_params; }; } diff --git a/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h b/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h new file mode 100644 index 00000000000..9e98a8c5fb7 --- /dev/null +++ b/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h @@ -0,0 +1,32 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +namespace search::attribute { + +/** + * Configuration parameters for a hnsw index used together with a 1-dimensional indexed tensor + * for approximate nearest neighbor search. + */ +class HnswIndexParams { +private: + uint32_t _max_links_per_node; + uint32_t _neighbors_to_explore_at_insert; + +public: + HnswIndexParams(uint32_t max_links_per_node_in, + uint32_t neighbors_to_explore_at_insert_in) + : _max_links_per_node(max_links_per_node_in), + _neighbors_to_explore_at_insert(neighbors_to_explore_at_insert_in) + {} + + uint32_t max_links_per_node() const { return _max_links_per_node; } + uint32_t neighbors_to_explore_at_insert() const { return _neighbors_to_explore_at_insert; } + + bool operator==(const HnswIndexParams& rhs) const { + return _max_links_per_node == rhs._max_links_per_node && + _neighbors_to_explore_at_insert == rhs._neighbors_to_explore_at_insert; + } +}; + +} diff --git a/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp b/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp index 7d09b2aa0b8..850a967ed3d 100644 --- a/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp +++ b/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp @@ -278,6 +278,22 @@ AttributeManagerTest::testConfigConvert() AttributeVector::Config out = ConfigConverter::convert(a); EXPECT_EQUAL("tensor(x[5])", out.tensorType().to_spec()); } + { // hnsw index params (enabled) + CACA a; + a.index.hnsw.enabled = true; + a.index.hnsw.maxlinkspernode = 32; + a.index.hnsw.neighborstoexploreatinsert = 300; + auto out = ConfigConverter::convert(a); + EXPECT_TRUE(out.hnsw_index_params().has_value()); + EXPECT_EQUAL(32u, out.hnsw_index_params().value().max_links_per_node()); + EXPECT_EQUAL(300u, out.hnsw_index_params().value().neighbors_to_explore_at_insert()); + } + { // hnsw index params (disabled) + CACA a; + a.index.hnsw.enabled = false; + auto out = ConfigConverter::convert(a); + EXPECT_FALSE(out.hnsw_index_params().has_value()); + } } bool gt_attribute(const attribute::IAttributeVector * a, const attribute::IAttributeVector * b) { diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index 7e0fcdc0ccc..5089743a54a 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -1,34 +1,48 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/vespalib/testkit/test_kit.h> #include <vespa/document/base/exceptions.h> -#include <vespa/searchlib/tensor/tensor_attribute.h> -#include <vespa/searchlib/tensor/generic_tensor_attribute.h> -#include <vespa/searchlib/tensor/dense_tensor_attribute.h> -#include <vespa/searchlib/attribute/attributeguard.h> -#include <vespa/eval/tensor/tensor.h> -#include <vespa/eval/tensor/dense/dense_tensor.h> #include <vespa/eval/tensor/default_tensor_engine.h> -#include <vespa/vespalib/io/fileutil.h> -#include <vespa/vespalib/data/fileheader.h> +#include <vespa/eval/tensor/dense/dense_tensor.h> +#include <vespa/eval/tensor/tensor.h> #include <vespa/fastos/file.h> +#include <vespa/searchlib/attribute/attributeguard.h> +#include <vespa/searchlib/tensor/default_nearest_neighbor_index_factory.h> +#include <vespa/searchlib/tensor/dense_tensor_attribute.h> +#include <vespa/searchlib/tensor/doc_vector_access.h> +#include <vespa/searchlib/tensor/generic_tensor_attribute.h> +#include <vespa/searchlib/tensor/hnsw_index.h> +#include <vespa/searchlib/tensor/nearest_neighbor_index.h> +#include <vespa/searchlib/tensor/nearest_neighbor_index_factory.h> +#include <vespa/searchlib/tensor/tensor_attribute.h> +#include <vespa/vespalib/data/fileheader.h> +#include <vespa/vespalib/io/fileutil.h> +#include <vespa/vespalib/test/insertion_operators.h> +#include <vespa/vespalib/testkit/test_kit.h> + #include <vespa/log/log.h> LOG_SETUP("tensorattribute_test"); using document::WrongTensorTypeException; -using search::tensor::TensorAttribute; -using search::tensor::DenseTensorAttribute; -using search::tensor::GenericTensorAttribute; using search::AttributeGuard; using search::AttributeVector; -using vespalib::eval::ValueType; +using search::attribute::HnswIndexParams; +using search::tensor::DefaultNearestNeighborIndexFactory; +using search::tensor::DenseTensorAttribute; +using search::tensor::DocVectorAccess; +using search::tensor::GenericTensorAttribute; +using search::tensor::HnswIndex; +using search::tensor::NearestNeighborIndex; +using search::tensor::NearestNeighborIndexFactory; +using search::tensor::TensorAttribute; using vespalib::eval::TensorSpec; -using vespalib::tensor::Tensor; -using vespalib::tensor::DenseTensor; +using vespalib::eval::ValueType; using vespalib::tensor::DefaultTensorEngine; +using vespalib::tensor::DenseTensor; +using vespalib::tensor::Tensor; + +using DoubleVector = std::vector<double>; -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { static bool operator==(const Tensor &lhs, const Tensor &rhs) { @@ -36,10 +50,10 @@ static bool operator==(const Tensor &lhs, const Tensor &rhs) } } -} vespalib::string sparseSpec("tensor(x{},y{})"); vespalib::string denseSpec("tensor(x[2],y[3])"); +vespalib::string vec_2d_spec("tensor(x[2])"); Tensor::UP createTensor(const TensorSpec &spec) { auto value = DefaultTensorEngine::ref().from_spec(spec); @@ -52,6 +66,78 @@ Tensor::UP createTensor(const TensorSpec &spec) { return Tensor::UP(tensor); } +TensorSpec +vec_2d(double x0, double x1) +{ + return TensorSpec(vec_2d_spec).add({{"x", 0}}, x0).add({{"x", 1}}, x1); +} + +class MockNearestNeighborIndex : public NearestNeighborIndex { +private: + using Entry = std::pair<uint32_t, DoubleVector>; + using EntryVector = std::vector<Entry>; + + const DocVectorAccess& _vectors; + EntryVector _adds; + EntryVector _removes; + +public: + MockNearestNeighborIndex(const DocVectorAccess& vectors) + : _vectors(vectors), + _adds(), + _removes() + { + } + void clear() { + _adds.clear(); + _removes.clear(); + } + void expect_empty_add() const { + EXPECT_TRUE(_adds.empty()); + } + void expect_add(uint32_t exp_docid, const DoubleVector& exp_vector) const { + EXPECT_EQUAL(1u, _adds.size()); + EXPECT_EQUAL(exp_docid, _adds.back().first); + EXPECT_EQUAL(exp_vector, _adds.back().second); + } + void expect_adds(const EntryVector &exp_adds) const { + EXPECT_EQUAL(exp_adds, _adds); + } + void expect_empty_remove() const { + EXPECT_TRUE(_removes.empty()); + } + void expect_remove(uint32_t exp_docid, const DoubleVector& exp_vector) const { + EXPECT_EQUAL(1u, _removes.size()); + EXPECT_EQUAL(exp_docid, _removes.back().first); + EXPECT_EQUAL(exp_vector, _removes.back().second); + } + void add_document(uint32_t docid) override { + auto vector = _vectors.get_vector(docid).typify<double>(); + _adds.emplace_back(docid, DoubleVector(vector.begin(), vector.end())); + } + void remove_document(uint32_t docid) override { + auto vector = _vectors.get_vector(docid).typify<double>(); + _removes.emplace_back(docid, DoubleVector(vector.begin(), vector.end())); + } + std::vector<Neighbor> find_top_k(uint32_t k, vespalib::tensor::TypedCells vector, uint32_t explore_k) const override { + (void) k; + (void) vector; + (void) explore_k; + return std::vector<Neighbor>(); + } +}; + +class MockNearestNeighborIndexFactory : public NearestNeighborIndexFactory { + + std::unique_ptr<NearestNeighborIndex> make(const DocVectorAccess& vectors, + ValueType::CellType cell_type, + const search::attribute::HnswIndexParams& params) const override { + (void) params; + assert(cell_type == ValueType::CellType::DOUBLE); + return std::make_unique<MockNearestNeighborIndex>(vectors); + } +}; + struct Fixture { using BasicType = search::attribute::BasicType; @@ -61,16 +147,20 @@ struct Fixture Config _cfg; vespalib::string _name; vespalib::string _typeSpec; + std::unique_ptr<NearestNeighborIndexFactory> _index_factory; std::shared_ptr<TensorAttribute> _tensorAttr; std::shared_ptr<AttributeVector> _attr; bool _denseTensors; bool _useDenseTensorAttribute; Fixture(const vespalib::string &typeSpec, - bool useDenseTensorAttribute = false) + bool useDenseTensorAttribute = false, + bool enable_hnsw_index = false, + bool use_mock_index = false) : _cfg(BasicType::TENSOR, CollectionType::SINGLE), _name("test"), _typeSpec(typeSpec), + _index_factory(std::make_unique<DefaultNearestNeighborIndexFactory>()), _tensorAttr(), _attr(), _denseTensors(false), @@ -80,20 +170,40 @@ struct Fixture if (_cfg.tensorType().is_dense()) { _denseTensors = true; } + if (enable_hnsw_index) { + _cfg.set_hnsw_index_params(HnswIndexParams(4, 20)); + if (use_mock_index) { + _index_factory = std::make_unique<MockNearestNeighborIndexFactory>(); + } + } _tensorAttr = makeAttr(); _attr = _tensorAttr; _attr->addReservedDoc(); } + ~Fixture() {} std::shared_ptr<TensorAttribute> makeAttr() { if (_useDenseTensorAttribute) { assert(_denseTensors); - return std::make_shared<DenseTensorAttribute>(_name, _cfg); + return std::make_shared<DenseTensorAttribute>(_name, _cfg, *_index_factory); } else { return std::make_shared<GenericTensorAttribute>(_name, _cfg); } } + const DenseTensorAttribute& as_dense_tensor() const { + auto result = dynamic_cast<const DenseTensorAttribute*>(_tensorAttr.get()); + assert(result != nullptr); + return *result; + } + + MockNearestNeighborIndex& mock_index() { + assert(as_dense_tensor().nearest_neighbor_index() != nullptr); + auto mock_index = dynamic_cast<const MockNearestNeighborIndex*>(as_dense_tensor().nearest_neighbor_index()); + assert(mock_index != nullptr); + return *const_cast<MockNearestNeighborIndex*>(mock_index); + } + void ensureSpace(uint32_t docId) { while (_attr->getNumDocs() <= docId) { uint32_t newDocId = 0u; @@ -108,7 +218,15 @@ struct Fixture _attr->commit(); } - void setTensor(uint32_t docId, const Tensor &tensor) { + void set_tensor(uint32_t docid, const TensorSpec &spec) { + set_tensor_internal(docid, *createTensor(spec)); + } + + void set_empty_tensor(uint32_t docid) { + set_tensor_internal(docid, *_tensorAttr->getEmptyTensor()); + } + + void set_tensor_internal(uint32_t docId, const Tensor &tensor) { ensureSpace(docId); _tensorAttr->setTensor(docId, tensor); _attr->commit(); @@ -119,27 +237,18 @@ struct Fixture return _attr->getStatus(); } - void - assertGetNoTensor(uint32_t docId) { + void assertGetNoTensor(uint32_t docId) { AttributeGuard guard(_attr); Tensor::UP actTensor = _tensorAttr->getTensor(docId); EXPECT_FALSE(actTensor); } - void - assertGetTensor(const Tensor &expTensor, uint32_t docId) - { + void assertGetTensor(const TensorSpec &expSpec, uint32_t docId) { + Tensor::UP expTensor = createTensor(expSpec); AttributeGuard guard(_attr); Tensor::UP actTensor = _tensorAttr->getTensor(docId); EXPECT_TRUE(static_cast<bool>(actTensor)); - EXPECT_EQUAL(expTensor, *actTensor); - } - - void - assertGetTensor(const TensorSpec &expSpec, uint32_t docId) - { - Tensor::UP expTensor = createTensor(expSpec); - assertGetTensor(*expTensor, docId); + EXPECT_EQUAL(*expTensor, *actTensor); } void save() { @@ -154,23 +263,20 @@ struct Fixture EXPECT_TRUE(loadok); } - Tensor::UP expDenseTensor3() const - { - return createTensor(TensorSpec(denseSpec) - .add({{"x", 0}, {"y", 1}}, 11) - .add({{"x", 1}, {"y", 2}}, 0)); + TensorSpec expDenseTensor3() const { + return TensorSpec(denseSpec) + .add({{"x", 0}, {"y", 1}}, 11) + .add({{"x", 1}, {"y", 2}}, 0); } - Tensor::UP expDenseFillTensor() const - { - return createTensor(TensorSpec(denseSpec) - .add({{"x", 0}, {"y", 0}}, 5) - .add({{"x", 1}, {"y", 2}}, 0)); + TensorSpec expDenseFillTensor() const { + return TensorSpec(denseSpec) + .add({{"x", 0}, {"y", 0}}, 5) + .add({{"x", 1}, {"y", 2}}, 0); } - Tensor::UP expEmptyDenseTensor() const - { - return createTensor(TensorSpec(denseSpec)); + TensorSpec expEmptyDenseTensor() const { + return TensorSpec(denseSpec); } vespalib::string expEmptyDenseTensorSpec() const { @@ -200,21 +306,21 @@ Fixture::testSetTensorValue() EXPECT_EQUAL(5u, _attr->getNumDocs()); EXPECT_EQUAL(5u, _attr->getCommittedDocIdLimit()); TEST_DO(assertGetNoTensor(4)); - EXPECT_EXCEPTION(setTensor(4, *createTensor(TensorSpec("double"))), + EXPECT_EXCEPTION(set_tensor(4, TensorSpec("double")), WrongTensorTypeException, "but other tensor type is 'double'"); TEST_DO(assertGetNoTensor(4)); - setTensor(4, *_tensorAttr->getEmptyTensor()); + set_empty_tensor(4); if (_denseTensors) { - TEST_DO(assertGetTensor(*expEmptyDenseTensor(), 4)); - setTensor(3, *expDenseTensor3()); - TEST_DO(assertGetTensor(*expDenseTensor3(), 3)); + TEST_DO(assertGetTensor(expEmptyDenseTensor(), 4)); + set_tensor(3, expDenseTensor3()); + TEST_DO(assertGetTensor(expDenseTensor3(), 3)); } else { TEST_DO(assertGetTensor(TensorSpec(sparseSpec), 4)); - setTensor(3, *createTensor(TensorSpec(sparseSpec) - .add({{"x", ""}, {"y", ""}}, 11))); + set_tensor(3, TensorSpec(sparseSpec) + .add({{"x", ""}, {"y", ""}}, 11)); TEST_DO(assertGetTensor(TensorSpec(sparseSpec) - .add({{"x", ""}, {"y", ""}}, 11), 3)); + .add({{"x", ""}, {"y", ""}}, 11), 3)); } TEST_DO(assertGetNoTensor(2)); TEST_DO(clearTensor(3)); @@ -225,23 +331,23 @@ void Fixture::testSaveLoad() { ensureSpace(4); - setTensor(4, *_tensorAttr->getEmptyTensor()); + set_empty_tensor(4); if (_denseTensors) { - setTensor(3, *expDenseTensor3()); + set_tensor(3, expDenseTensor3()); } else { - setTensor(3, *createTensor(TensorSpec(sparseSpec) - .add({{"x", ""}, {"y", "1"}}, 11))); + set_tensor(3, TensorSpec(sparseSpec) + .add({{"x", ""}, {"y", "1"}}, 11)); } TEST_DO(save()); TEST_DO(load()); EXPECT_EQUAL(5u, _attr->getNumDocs()); EXPECT_EQUAL(5u, _attr->getCommittedDocIdLimit()); if (_denseTensors) { - TEST_DO(assertGetTensor(*expDenseTensor3(), 3)); - TEST_DO(assertGetTensor(*expEmptyDenseTensor(), 4)); + TEST_DO(assertGetTensor(expDenseTensor3(), 3)); + TEST_DO(assertGetTensor(expEmptyDenseTensor(), 4)); } else { TEST_DO(assertGetTensor(TensorSpec(sparseSpec) - .add({{"x", ""}, {"y", "1"}}, 11), 3)); + .add({{"x", ""}, {"y", "1"}}, 11), 3)); TEST_DO(assertGetTensor(TensorSpec(sparseSpec), 4)); } TEST_DO(assertGetNoTensor(2)); @@ -256,29 +362,28 @@ Fixture::testCompaction() return; } ensureSpace(4); - Tensor::UP emptytensor = _tensorAttr->getEmptyTensor(); - Tensor::UP emptyxytensor = createTensor(TensorSpec(sparseSpec)); - Tensor::UP simpletensor = createTensor(TensorSpec(sparseSpec) - .add({{"x", ""}, {"y", "1"}}, 11)); - Tensor::UP filltensor = createTensor(TensorSpec(sparseSpec) - .add({{"x", ""}, {"y", ""}}, 5)); + TensorSpec empty_xy_tensor(sparseSpec); + TensorSpec simple_tensor = TensorSpec(sparseSpec) + .add({{"x", ""}, {"y", "1"}}, 11); + TensorSpec fill_tensor = TensorSpec(sparseSpec) + .add({{"x", ""}, {"y", ""}}, 5); if (_denseTensors) { - emptyxytensor = expEmptyDenseTensor(); - simpletensor = expDenseTensor3(); - filltensor = expDenseFillTensor(); + empty_xy_tensor = expEmptyDenseTensor(); + simple_tensor = expDenseTensor3(); + fill_tensor = expDenseFillTensor(); } - setTensor(4, *emptytensor); - setTensor(3, *simpletensor); - setTensor(2, *filltensor); + set_empty_tensor(4); + set_tensor(3, simple_tensor); + set_tensor(2, fill_tensor); clearTensor(2); - setTensor(2, *filltensor); + set_tensor(2, fill_tensor); search::attribute::Status oldStatus = getStatus(); search::attribute::Status newStatus = oldStatus; uint64_t iter = 0; uint64_t iterLimit = 100000; for (; iter < iterLimit; ++iter) { clearTensor(2); - setTensor(2, *filltensor); + set_tensor(2, fill_tensor); newStatus = getStatus(); if (newStatus.getUsed() < oldStatus.getUsed()) { break; @@ -290,9 +395,9 @@ Fixture::testCompaction() "iter = %" PRIu64 ", memory usage %" PRIu64 ", -> %" PRIu64, iter, oldStatus.getUsed(), newStatus.getUsed()); TEST_DO(assertGetNoTensor(1)); - TEST_DO(assertGetTensor(*filltensor, 2)); - TEST_DO(assertGetTensor(*simpletensor, 3)); - TEST_DO(assertGetTensor(*emptyxytensor, 4)); + TEST_DO(assertGetTensor(fill_tensor, 2)); + TEST_DO(assertGetTensor(simple_tensor, 3)); + TEST_DO(assertGetTensor(empty_xy_tensor, 4)); } void @@ -357,4 +462,73 @@ TEST("Test dense tensors with dense tensor attribute") testAll([]() { return std::make_shared<Fixture>(denseSpec, true); }); } +TEST_F("Hnsw index is NOT instantiated in dense tensor attribute by default", + Fixture(vec_2d_spec, true, false)) +{ + const auto& tensor = f.as_dense_tensor(); + EXPECT_TRUE(tensor.nearest_neighbor_index() == nullptr); +} + +TEST_F("Hnsw index is instantiated in dense tensor attribute when specified in config", + Fixture(vec_2d_spec, true, true)) +{ + const auto& tensor = f.as_dense_tensor(); + ASSERT_TRUE(tensor.nearest_neighbor_index() != nullptr); + auto hnsw_index = dynamic_cast<const HnswIndex*>(tensor.nearest_neighbor_index()); + ASSERT_TRUE(hnsw_index != nullptr); + + const auto& cfg = hnsw_index->config(); + EXPECT_EQUAL(8u, cfg.max_links_at_level_0()); + EXPECT_EQUAL(4u, cfg.max_links_at_hierarchic_levels()); + EXPECT_EQUAL(20u, cfg.neighbors_to_explore_at_construction()); + EXPECT_TRUE(cfg.heuristic_select_neighbors()); +} + +class DenseTensorAttributeMockIndex : public Fixture { +public: + DenseTensorAttributeMockIndex() : Fixture(vec_2d_spec, true, true, true) {} +}; + +TEST_F("setTensor() updates nearest neighbor index", DenseTensorAttributeMockIndex) +{ + auto& index = f.mock_index(); + + f.set_tensor(1, vec_2d(3, 5)); + index.expect_add(1, {3, 5}); + index.expect_empty_remove(); + index.clear(); + + // Replaces previous value. + f.set_tensor(1, vec_2d(7, 9)); + index.expect_remove(1, {3, 5}); + index.expect_add(1, {7, 9}); +} + +TEST_F("clearDoc() updates nearest neighbor index", DenseTensorAttributeMockIndex) +{ + auto& index = f.mock_index(); + + // Nothing to clear. + f.clearTensor(1); + index.expect_empty_remove(); + index.expect_empty_add(); + + // Clears previous value. + f.set_tensor(1, vec_2d(3, 5)); + index.clear(); + f.clearTensor(1); + index.expect_remove(1, {3, 5}); + index.expect_empty_add(); +} + +TEST_F("onLoad() updates nearest neighbor index", DenseTensorAttributeMockIndex) +{ + f.set_tensor(1, vec_2d(3, 5)); + f.set_tensor(2, vec_2d(7, 9)); + f.save(); + f.load(); + auto& index = f.mock_index(); + index.expect_adds({{1, {3, 5}}, {2, {7, 9}}}); +} + TEST_MAIN() { TEST_RUN_ALL(); vespalib::unlink("test.dat"); } diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 7bc582ab442..691e80aeb9f 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -12,6 +12,7 @@ #include <vespa/searchlib/queryeval/simpleresult.h> #include <vespa/searchlib/tensor/dense_tensor_attribute.h> #include <vespa/vespalib/test/insertion_operators.h> +#include <vespa/searchlib/queryeval/nns_index_iterator.h> #include <vespa/log/log.h> LOG_SETUP("nearest_neighbor_test"); @@ -190,4 +191,70 @@ TEST("require that NearestNeighborIterator sets expected rawscore") { TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecFloat, denseSpecDouble)); } +TEST("require that NnsIndexIterator works as expected") { + std::vector<NnsIndexIterator::Hit> hits{{2,4.0}, {3,9.0}, {5,1.0}, {8,16.0}, {9,36.0}}; + auto md = MatchData::makeTestInstance(2, 2); + auto &tfmd = *(md->resolveTermField(0)); + auto search = NnsIndexIterator::create(true, tfmd, hits); + uint32_t docid = 1; + search->initFullRange(); + bool match = search->seek(docid); + EXPECT_FALSE(match); + EXPECT_FALSE(search->isAtEnd()); + EXPECT_EQUAL(2u, search->getDocId()); + docid = 2; + match = search->seek(docid); + EXPECT_TRUE(match); + EXPECT_FALSE(search->isAtEnd()); + EXPECT_EQUAL(docid, search->getDocId()); + search->unpack(docid); + EXPECT_EQUAL(2.0, tfmd.getRawScore()); + + docid = 3; + match = search->seek(docid); + EXPECT_TRUE(match); + EXPECT_FALSE(search->isAtEnd()); + EXPECT_EQUAL(docid, search->getDocId()); + search->unpack(docid); + EXPECT_EQUAL(3.0, tfmd.getRawScore()); + + docid = 4; + match = search->seek(docid); + EXPECT_FALSE(match); + EXPECT_FALSE(search->isAtEnd()); + EXPECT_EQUAL(5u, search->getDocId()); + + docid = 6; + match = search->seek(docid); + EXPECT_FALSE(match); + EXPECT_FALSE(search->isAtEnd()); + EXPECT_EQUAL(8u, search->getDocId()); + docid = 8; + search->unpack(docid); + EXPECT_EQUAL(4.0, tfmd.getRawScore()); + docid = 9; + match = search->seek(docid); + EXPECT_TRUE(match); + EXPECT_FALSE(search->isAtEnd()); + docid = 10; + match = search->seek(docid); + EXPECT_FALSE(match); + EXPECT_TRUE(search->isAtEnd()); + + docid = 4; + search->initRange(docid, 7); + match = search->seek(docid); + EXPECT_FALSE(match); + EXPECT_FALSE(search->isAtEnd()); + EXPECT_EQUAL(5u, search->getDocId()); + docid = 5; + search->unpack(docid); + EXPECT_EQUAL(1.0, tfmd.getRawScore()); + EXPECT_FALSE(search->isAtEnd()); + docid = 6; + match = search->seek(docid); + EXPECT_FALSE(match); + EXPECT_TRUE(search->isAtEnd()); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index c6246bb8434..1204ae1e9bc 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -48,8 +48,7 @@ using HnswIndexUP = std::unique_ptr<HnswIndex>; class HnswIndexTest : public ::testing::Test { public: FloatVectors vectors; - FloatSqEuclideanDistance distance_func; - LevelGenerator level_generator; + LevelGenerator* level_generator; HnswIndexUP index; HnswIndexTest() @@ -62,11 +61,14 @@ public: .set(7, {3, 5}).set(8, {0, 3}).set(9, {4, 5}); } void init(bool heuristic_select_neighbors) { - index = std::make_unique<HnswIndex>(vectors, distance_func, level_generator, + auto generator = std::make_unique<LevelGenerator>(); + level_generator = generator.get(); + index = std::make_unique<HnswIndex>(vectors, std::make_unique<FloatSqEuclideanDistance>(), + std::move(generator), HnswIndex::Config(2, 1, 10, heuristic_select_neighbors)); } void add_document(uint32_t docid, uint32_t max_level = 0) { - level_generator.level = max_level; + level_generator->level = max_level; index->add_document(docid); } void remove_document(uint32_t docid) { @@ -100,8 +102,10 @@ public: if (exp_hits.size() == k) { std::vector<uint32_t> expected_by_docid = exp_hits; std::sort(expected_by_docid.begin(), expected_by_docid.end()); - std::vector<uint32_t> got_by_docid = index->find_top_k(k, qv, k); - EXPECT_EQ(expected_by_docid, got_by_docid); + auto got_by_docid = index->find_top_k(k, qv, k); + for (idx = 0; idx < k; ++idx) { + EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid); + } } } }; diff --git a/searchlib/src/vespa/searchlib/attribute/configconverter.cpp b/searchlib/src/vespa/searchlib/attribute/configconverter.cpp index 535e81fc032..10e1a1edb52 100644 --- a/searchlib/src/vespa/searchlib/attribute/configconverter.cpp +++ b/searchlib/src/vespa/searchlib/attribute/configconverter.cpp @@ -73,6 +73,10 @@ ConfigConverter::convert(const AttributesConfig::Attribute & cfg) predicateParams.setBounds(cfg.lowerbound, cfg.upperbound); predicateParams.setDensePostingListThreshold(cfg.densepostinglistthreshold); retval.setPredicateParams(predicateParams); + if (cfg.index.hnsw.enabled) { + retval.set_hnsw_index_params(HnswIndexParams(cfg.index.hnsw.maxlinkspernode, + cfg.index.hnsw.neighborstoexploreatinsert)); + } if (retval.basicType().type() == BasicType::Type::TENSOR) { if (!cfg.tensortype.empty()) { retval.setTensorType(ValueType::from_spec(cfg.tensortype)); diff --git a/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt b/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt index de2919443ff..0dcb0393473 100644 --- a/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt @@ -32,6 +32,7 @@ vespa_add_library(searchlib_queryeval OBJECT nearest_neighbor_blueprint.cpp nearest_neighbor_iterator.cpp nearsearch.cpp + nns_index_iterator.cpp orsearch.cpp predicate_blueprint.cpp predicate_search.cpp diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index 8be6263221a..f9bce4bf7d1 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -3,6 +3,7 @@ #include "emptysearch.h" #include "nearest_neighbor_blueprint.h" #include "nearest_neighbor_iterator.h" +#include "nns_index_iterator.h" #include <vespa/searchlib/fef/termfieldmatchdataarray.h> #include <vespa/eval/tensor/dense/dense_tensor_view.h> #include <vespa/searchlib/tensor/dense_tensor_attribute.h> @@ -17,20 +18,47 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f _attr_tensor(attr_tensor), _query_tensor(std::move(query_tensor)), _target_num_hits(target_num_hits), - _distance_heap(target_num_hits) + _distance_heap(target_num_hits), + _found_hits() { setEstimate(HitEstimate(_attr_tensor.getNumDocs(), false)); } NearestNeighborBlueprint::~NearestNeighborBlueprint() = default; +void +NearestNeighborBlueprint::perform_top_k() +{ + auto nns_index = _attr_tensor.nearest_neighbor_index(); + if (nns_index) { + auto lhs_type = _query_tensor->fast_type(); + auto rhs_type = _attr_tensor.getTensorType(); + // XXX deal with different cell types later + if (lhs_type == rhs_type) { + auto lhs = _query_tensor->cellsRef(); + uint32_t k = _target_num_hits; + uint32_t explore_k = k + 100; // XXX hardcoded for now + _found_hits = nns_index->find_top_k(k, lhs, explore_k); + } + } +} + +void +NearestNeighborBlueprint::fetchPostings(const ExecuteInfo &execInfo) { + if (execInfo.isStrict()) { + perform_top_k(); + } +} + std::unique_ptr<SearchIterator> NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchDataArray& tfmda, bool strict) const { assert(tfmda.size() == 1); fef::TermFieldMatchData &tfmd = *tfmda[0]; // always search in only one field + if (strict && ! _found_hits.empty()) { + return NnsIndexIterator::create(strict, tfmd, _found_hits); + } const vespalib::tensor::DenseTensorView &qT = *_query_tensor; - return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor, _distance_heap); } diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h index 019f8e31842..ab4413c487a 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h @@ -3,6 +3,7 @@ #include "blueprint.h" #include "nearest_neighbor_distance_heap.h" +#include <vespa/searchlib/tensor/nearest_neighbor_index.h> namespace vespalib::tensor { class DenseTensorView; } namespace search::tensor { class DenseTensorAttribute; } @@ -21,7 +22,9 @@ private: std::unique_ptr<vespalib::tensor::DenseTensorView> _query_tensor; uint32_t _target_num_hits; mutable NearestNeighborDistanceHeap _distance_heap; + std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits; + void perform_top_k(); public: NearestNeighborBlueprint(const queryeval::FieldSpec& field, const tensor::DenseTensorAttribute& attr_tensor, @@ -38,6 +41,7 @@ public: bool strict) const override; void visitMembers(vespalib::ObjectVisitor& visitor) const override; bool always_needs_unpack() const override; + void fetchPostings(const ExecuteInfo &execInfo) override; }; } diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp new file mode 100644 index 00000000000..7ee985a0ba5 --- /dev/null +++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp @@ -0,0 +1,70 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "nns_index_iterator.h" +#include <vespa/searchlib/tensor/nearest_neighbor_index.h> +#include <cmath> + +using Hit = search::tensor::NearestNeighborIndex::Neighbor; + +namespace search::queryeval { + +/** + * Search iterator for K nearest neighbor matching, + * where the actual search is done up front and this class + * just iterates over a vector held by the blueprint. + **/ +class NeighborVectorIterator : public NnsIndexIterator +{ +private: + fef::TermFieldMatchData &_tfmd; + const std::vector<Hit> &_hits; + uint32_t _idx; + double _last_sq_dist; +public: + NeighborVectorIterator(fef::TermFieldMatchData &tfmd, + const std::vector<Hit> &hits) + : _tfmd(tfmd), + _hits(hits), + _idx(0), + _last_sq_dist(0.0) + {} + + void initRange(uint32_t begin_id, uint32_t end_id) override { + SearchIterator::initRange(begin_id, end_id); + _idx = 0; + } + + void doSeek(uint32_t docId) override { + while (_idx < _hits.size()) { + uint32_t hit_id = _hits[_idx].docid; + if (hit_id < docId) { + ++_idx; + } else if (hit_id < getEndId()) { + setDocId(hit_id); + _last_sq_dist = _hits[_idx].distance; + return; + } else { + _idx = _hits.size(); + } + } + setAtEnd(); + } + + void doUnpack(uint32_t docId) override { + _tfmd.setRawScore(docId, sqrt(_last_sq_dist)); + } + + Trinary is_strict() const override { return Trinary::True; } +}; + +std::unique_ptr<NnsIndexIterator> +NnsIndexIterator::create( + bool strict, + fef::TermFieldMatchData &tfmd, + const std::vector<Hit> &hits) +{ + assert(strict); + return std::make_unique<NeighborVectorIterator>(tfmd, hits); +} + +} // namespace diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h new file mode 100644 index 00000000000..62fa49aac46 --- /dev/null +++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h @@ -0,0 +1,21 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "searchiterator.h" +#include <vespa/searchlib/fef/termfieldmatchdata.h> +#include <vespa/searchlib/tensor/nearest_neighbor_index.h> + +namespace search::queryeval { + +class NnsIndexIterator : public SearchIterator +{ +public: + using Hit = search::tensor::NearestNeighborIndex::Neighbor; + static std::unique_ptr<NnsIndexIterator> create( + bool strict, + fef::TermFieldMatchData &tfmd, + const std::vector<Hit> &hits); +}; + +} // namespace diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt index 9175168248c..0bdcd53af77 100644 --- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt @@ -1,16 +1,18 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. vespa_add_library(searchlib_tensor OBJECT SOURCES + default_nearest_neighbor_index_factory.cpp dense_tensor_attribute.cpp dense_tensor_attribute_saver.cpp dense_tensor_store.cpp generic_tensor_attribute.cpp + generic_tensor_attribute_saver.cpp generic_tensor_store.cpp hnsw_index.cpp imported_tensor_attribute_vector.cpp imported_tensor_attribute_vector_read_guard.cpp + nearest_neighbor_index.cpp tensor_attribute.cpp - generic_tensor_attribute_saver.cpp tensor_store.cpp DEPENDS ) diff --git a/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp b/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp new file mode 100644 index 00000000000..68efe6417c0 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp @@ -0,0 +1,51 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "default_nearest_neighbor_index_factory.h" +#include "distance_functions.h" +#include "hnsw_index.h" +#include "random_level_generator.h" +#include <vespa/searchcommon/attribute/config.h> + +namespace search::tensor { + +using vespalib::eval::ValueType; + +namespace { + +class LevelZeroGenerator : public RandomLevelGenerator { + uint32_t max_level() override { return 0; } +}; + +DistanceFunction::UP +make_distance_function(ValueType::CellType cell_type) +{ + if (cell_type == ValueType::CellType::FLOAT) { + return std::make_unique<SquaredEuclideanDistance<float>>(); + } else { + return std::make_unique<SquaredEuclideanDistance<double>>(); + } +} + +RandomLevelGenerator::UP +make_random_level_generator() +{ + // TODO: Make generator that results in hierarchical graph. + return std::make_unique<LevelZeroGenerator>(); +} + +} + +std::unique_ptr<NearestNeighborIndex> +DefaultNearestNeighborIndexFactory::make(const DocVectorAccess& vectors, + vespalib::eval::ValueType::CellType cell_type, + const search::attribute::HnswIndexParams& params) const +{ + HnswIndex::Config cfg(params.max_links_per_node() * 2, + params.max_links_per_node(), + params.neighbors_to_explore_at_insert(), + true); + return std::make_unique<HnswIndex>(vectors, make_distance_function(cell_type), make_random_level_generator(), cfg); +} + +} + diff --git a/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.h b/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.h new file mode 100644 index 00000000000..ea784efdb51 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.h @@ -0,0 +1,19 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "nearest_neighbor_index_factory.h" + +namespace search::tensor { + +/** + * Factory that instantiates the production hnsw index. + */ +class DefaultNearestNeighborIndexFactory : public NearestNeighborIndexFactory { +public: + std::unique_ptr<NearestNeighborIndex> make(const DocVectorAccess& vectors, + vespalib::eval::ValueType::CellType cell_type, + const search::attribute::HnswIndexParams& params) const override; +}; + +} diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp index a2b9f136ed9..171340e07f1 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp @@ -2,6 +2,7 @@ #include "dense_tensor_attribute.h" #include "dense_tensor_attribute_saver.h" +#include "nearest_neighbor_index.h" #include "tensor_attribute.hpp" #include <vespa/eval/tensor/tensor.h> #include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h> @@ -55,11 +56,23 @@ TensorReader::is_present() { } -DenseTensorAttribute::DenseTensorAttribute(vespalib::stringref baseFileName, - const Config &cfg) +void +DenseTensorAttribute::consider_remove_from_index(DocId docid) +{ + if (_index && _refVector[docid].valid()) { + _index->remove_document(docid); + } +} + +DenseTensorAttribute::DenseTensorAttribute(vespalib::stringref baseFileName, const Config& cfg, + const NearestNeighborIndexFactory& index_factory) : TensorAttribute(baseFileName, cfg, _denseTensorStore), - _denseTensorStore(cfg.tensorType()) + _denseTensorStore(cfg.tensorType()), + _index() { + if (cfg.hnsw_index_params().has_value()) { + _index = index_factory.make(*this, cfg.tensorType().cell_type(), cfg.hnsw_index_params().value()); + } } @@ -69,12 +82,23 @@ DenseTensorAttribute::~DenseTensorAttribute() _tensorStore.clearHoldLists(); } +uint32_t +DenseTensorAttribute::clearDoc(DocId docId) +{ + consider_remove_from_index(docId); + return TensorAttribute::clearDoc(docId); +} + void DenseTensorAttribute::setTensor(DocId docId, const Tensor &tensor) { checkTensorType(tensor); + consider_remove_from_index(docId); EntryRef ref = _denseTensorStore.setTensor(tensor); setTensorRef(docId, ref); + if (_index) { + _index->add_document(docId); + } } @@ -120,6 +144,11 @@ DenseTensorAttribute::onLoad() auto raw = _denseTensorStore.allocRawBuffer(); tensorReader.readTensor(raw.data, _denseTensorStore.getBufSize()); _refVector.push_back(raw.ref); + if (_index) { + // This ensures that get_vector() (via getTensor()) is able to find the newly added tensor. + setCommittedDocIdLimit(lid + 1); + _index->add_document(lid); + } } else { _refVector.push_back(EntryRef()); } @@ -154,4 +183,12 @@ DenseTensorAttribute::getVersion() const return DENSE_TENSOR_ATTRIBUTE_VERSION; } +vespalib::tensor::TypedCells +DenseTensorAttribute::get_vector(uint32_t docid) const +{ + MutableDenseTensorView tensor_view(_denseTensorStore.type()); + getTensor(docid, tensor_view); + return tensor_view.cellsRef(); +} + } diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h index 593741cef39..f9a8a81b56b 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h @@ -2,35 +2,47 @@ #pragma once -#include "tensor_attribute.h" +#include "default_nearest_neighbor_index_factory.h" #include "dense_tensor_store.h" +#include "doc_vector_access.h" +#include "tensor_attribute.h" +#include <memory> -namespace vespalib { namespace tensor { class MutableDenseTensorView; }} +namespace vespalib::tensor { class MutableDenseTensorView; } -namespace search { +namespace search::tensor { -namespace tensor { +class NearestNeighborIndex; /** * Attribute vector class used to store dense tensors for all * documents in memory. */ -class DenseTensorAttribute : public TensorAttribute -{ +class DenseTensorAttribute : public TensorAttribute, public DocVectorAccess { +private: DenseTensorStore _denseTensorStore; + std::unique_ptr<NearestNeighborIndex> _index; + + void consider_remove_from_index(DocId docid); + public: - DenseTensorAttribute(vespalib::stringref baseFileName, const Config &cfg); + DenseTensorAttribute(vespalib::stringref baseFileName, const Config& cfg, + const NearestNeighborIndexFactory& index_factory = DefaultNearestNeighborIndexFactory()); virtual ~DenseTensorAttribute(); - virtual void setTensor(DocId docId, const Tensor &tensor) override; - virtual std::unique_ptr<Tensor> getTensor(DocId docId) const override; - virtual void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const override; - virtual bool onLoad() override; - virtual std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override; - virtual void compactWorst() override; - virtual uint32_t getVersion() const override; + // Implements TensorAttribute + uint32_t clearDoc(DocId docId) override; + void setTensor(DocId docId, const Tensor &tensor) override; + std::unique_ptr<Tensor> getTensor(DocId docId) const override; + void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const override; + bool onLoad() override; + std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override; + void compactWorst() override; + uint32_t getVersion() const override; + + // Implements DocVectorAccess + vespalib::tensor::TypedCells get_vector(uint32_t docid) const override; + + const NearestNeighborIndex* nearest_neighbor_index() const { return _index.get(); } }; - -} // namespace search::tensor - -} // namespace search +} diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function.h b/searchlib/src/vespa/searchlib/tensor/distance_function.h index 8dfb77ddccb..b682824c805 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function.h +++ b/searchlib/src/vespa/searchlib/tensor/distance_function.h @@ -2,6 +2,8 @@ #pragma once +#include <memory> + namespace vespalib::tensor { struct TypedCells; } namespace search::tensor { @@ -14,6 +16,7 @@ namespace search::tensor { */ class DistanceFunction { public: + using UP = std::unique_ptr<DistanceFunction>; virtual ~DistanceFunction() {} virtual double calc(const vespalib::tensor::TypedCells& lhs, const vespalib::tensor::TypedCells& rhs) const = 0; }; diff --git a/searchlib/src/vespa/searchlib/tensor/distance_functions.h b/searchlib/src/vespa/searchlib/tensor/distance_functions.h index 1e8727e92aa..494d1a859b6 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_functions.h +++ b/searchlib/src/vespa/searchlib/tensor/distance_functions.h @@ -3,6 +3,7 @@ #pragma once #include "distance_function.h" +#include <vespa/eval/tensor/dense/typed_cells.h> namespace search::tensor { diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index be53b758841..0d308206761 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -44,7 +44,7 @@ HnswIndex::max_links_for_level(uint32_t level) const uint32_t HnswIndex::make_node_for_document(uint32_t docid) { - uint32_t max_level = _level_generator.max_level(); + uint32_t max_level = _level_generator->max_level(); // TODO: Add capping on num_levels uint32_t num_levels = max_level + 1; // Note: The level array instance lives as long as the document is present in the index. @@ -170,11 +170,11 @@ double HnswIndex::calc_distance(const TypedCells& lhs, uint32_t rhs_docid) const { auto rhs = get_vector(rhs_docid); - return _distance_func.calc(lhs, rhs); + return _distance_func->calc(lhs, rhs); } HnswCandidate -HnswIndex::find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level) +HnswIndex::find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level) const { HnswCandidate nearest = entry_point; bool keep_searching = true; @@ -192,7 +192,7 @@ HnswIndex::find_nearest_in_layer(const TypedCells& input, const HnswCandidate& e } void -HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& best_neighbors, uint32_t level) +HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& best_neighbors, uint32_t level) const { NearestPriQ candidates; // TODO: Add proper handling of visited set. @@ -227,11 +227,11 @@ HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find, Fur } } -HnswIndex::HnswIndex(const DocVectorAccess& vectors, const DistanceFunction& distance_func, - RandomLevelGenerator& level_generator, const Config& cfg) +HnswIndex::HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func, + RandomLevelGenerator::UP level_generator, const Config& cfg) : _vectors(vectors), - _distance_func(distance_func), - _level_generator(level_generator), + _distance_func(std::move(distance_func)), + _level_generator(std::move(level_generator)), _cfg(cfg), _node_refs(), _nodes(make_default_node_store_config()), @@ -310,24 +310,32 @@ HnswIndex::remove_document(uint32_t docid) _node_refs[docid].store_release(invalid); } -std::vector<uint32_t> -HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) +struct NeighborsByDocId { + bool operator() (const NearestNeighborIndex::Neighbor &lhs, + const NearestNeighborIndex::Neighbor &rhs) + { + return (lhs.docid < rhs.docid); + } +}; + +std::vector<NearestNeighborIndex::Neighbor> +HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const { - std::vector<uint32_t> result; + std::vector<Neighbor> result; FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k)); while (candidates.size() > k) { candidates.pop(); } result.reserve(candidates.size()); for (const HnswCandidate & hit : candidates.peek()) { - result.emplace_back(hit.docid); + result.emplace_back(hit.docid, hit.distance); } - std::sort(result.begin(), result.end()); + std::sort(result.begin(), result.end(), NeighborsByDocId()); return result; } FurthestPriQ -HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k) +HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k) const { FurthestPriQ best_neighbors; if (_entry_level < 0) { diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index 814148072ca..800b88923b5 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -2,10 +2,12 @@ #pragma once +#include "distance_function.h" #include "doc_vector_access.h" #include "hnsw_index_utils.h" #include "hnsw_node.h" #include "nearest_neighbor_index.h" +#include "random_level_generator.h" #include <vespa/eval/tensor/dense/typed_cells.h> #include <vespa/searchlib/common/bitvector.h> #include <vespa/vespalib/datastore/array_store.h> @@ -15,9 +17,6 @@ namespace search::tensor { -class DistanceFunction; -class RandomLevelGenerator; - /** * Implementation of a hierarchical navigable small world graph (HNSW) * that is used for approximate K-nearest neighbor search. @@ -82,8 +81,8 @@ protected: using TypedCells = vespalib::tensor::TypedCells; const DocVectorAccess& _vectors; - const DistanceFunction& _distance_func; - RandomLevelGenerator& _level_generator; + DistanceFunction::UP _distance_func; + RandomLevelGenerator::UP _level_generator; Config _cfg; NodeRefVector _node_refs; NodeStore _nodes; @@ -124,18 +123,20 @@ protected: /** * Performs a greedy search in the given layer to find the candidate that is nearest the input vector. */ - HnswCandidate find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level); - void search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& found_neighbors, uint32_t level); + HnswCandidate find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level) const; + void search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& found_neighbors, uint32_t level) const; public: - HnswIndex(const DocVectorAccess& vectors, const DistanceFunction& distance_func, - RandomLevelGenerator& level_generator, const Config& cfg); + HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func, + RandomLevelGenerator::UP level_generator, const Config& cfg); ~HnswIndex() override; + const Config& config() const { return _cfg; } + void add_document(uint32_t docid) override; void remove_document(uint32_t docid) override; - std::vector<uint32_t> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) override; - FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k); + std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const override; + FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k) const; // TODO: Add support for generation handling and cleanup (transfer_hold_lists, trim_hold_lists) diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.cpp b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.cpp new file mode 100644 index 00000000000..f31230af381 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.cpp @@ -0,0 +1,3 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "nearest_neighbor_index.h" diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h index 2ae322fe76e..f933af0147e 100644 --- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h +++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h @@ -13,10 +13,20 @@ namespace search::tensor { */ class NearestNeighborIndex { public: + struct Neighbor { + uint32_t docid; + double distance; + Neighbor(uint32_t id, double dist) + : docid(id), distance(dist) + {} + Neighbor() : docid(0), distance(0.0) {} + }; virtual ~NearestNeighborIndex() {} virtual void add_document(uint32_t docid) = 0; virtual void remove_document(uint32_t docid) = 0; - virtual std::vector<uint32_t> find_top_k(uint32_t k, vespalib::tensor::TypedCells vector, uint32_t explore_k) = 0; + virtual std::vector<Neighbor> find_top_k(uint32_t k, + vespalib::tensor::TypedCells vector, + uint32_t explore_k) const = 0; }; } diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index_factory.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index_factory.h new file mode 100644 index 00000000000..c09403df5e0 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index_factory.h @@ -0,0 +1,26 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/eval/eval/value_type.h> +#include <memory> + +namespace search::attribute { class HnswIndexParams; } + +namespace search::tensor { + +class DocVectorAccess; +class NearestNeighborIndex; + +/** + * Factory interface used to instantiate an index used for (approximate) nearest neighbor search. + */ +class NearestNeighborIndexFactory { +public: + virtual ~NearestNeighborIndexFactory() {} + virtual std::unique_ptr<NearestNeighborIndex> make(const DocVectorAccess& vectors, + vespalib::eval::ValueType::CellType cell_type, + const search::attribute::HnswIndexParams& params) const = 0; +}; + +} diff --git a/searchlib/src/vespa/searchlib/tensor/random_level_generator.h b/searchlib/src/vespa/searchlib/tensor/random_level_generator.h index 0fcac977d9d..0f4c7c34445 100644 --- a/searchlib/src/vespa/searchlib/tensor/random_level_generator.h +++ b/searchlib/src/vespa/searchlib/tensor/random_level_generator.h @@ -2,6 +2,8 @@ #pragma once +#include <memory> + namespace search::tensor { /** @@ -9,6 +11,7 @@ namespace search::tensor { */ class RandomLevelGenerator { public: + using UP = std::unique_ptr<RandomLevelGenerator>; virtual ~RandomLevelGenerator() {} virtual uint32_t max_level() = 0; }; diff --git a/searchsummary/src/tests/docsummary/matched_elements_filter/matched_elements_filter_test.cpp b/searchsummary/src/tests/docsummary/matched_elements_filter/matched_elements_filter_test.cpp index 9019a212f3f..dff3acc5b89 100644 --- a/searchsummary/src/tests/docsummary/matched_elements_filter/matched_elements_filter_test.cpp +++ b/searchsummary/src/tests/docsummary/matched_elements_filter/matched_elements_filter_test.cpp @@ -246,6 +246,7 @@ TEST_F(MatchedElementsFilterTest, filters_elements_in_array_field_value_when_inp expect_filtered("array_in_doc", {0, 1, 2}, "[{'name':'a','weight':3}," "{'name':'b','weight':5}," "{'name':'c','weight':7}]"); + expect_filtered("array_in_doc", {0, 1, 100}, "[]"); } TEST_F(MatchedElementsFilterTest, struct_field_mapper_is_setup_for_array_field_value) @@ -276,6 +277,7 @@ TEST_F(MatchedElementsFilterTest, filters_elements_in_map_field_value_when_input expect_filtered("map_in_doc", {0, 1, 2}, "[{'key':'a','value':{'name':'a','weight':3}}," "{'key':'b','value':{'name':'b','weight':5}}," "{'key':'c','value':{'name':'c','weight':7}}]"); + expect_filtered("map_in_doc", {0, 1, 100}, "[]"); } TEST_F(MatchedElementsFilterTest, struct_field_mapper_is_setup_for_map_field_value) diff --git a/searchsummary/src/vespa/searchsummary/docsummary/summaryfieldconverter.cpp b/searchsummary/src/vespa/searchsummary/docsummary/summaryfieldconverter.cpp index 7368a199569..ada14bf17f5 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/summaryfieldconverter.cpp +++ b/searchsummary/src/vespa/searchsummary/docsummary/summaryfieldconverter.cpp @@ -389,9 +389,11 @@ private: MapFieldValueInserter map_inserter(_inserter, _tokenize); if (filter_matching_elements()) { assert(v.has_no_erased_keys()); - for (uint32_t id_to_keep : (*_matching_elems)) { - auto entry = v[id_to_keep]; - map_inserter.insert_entry(*entry.first, *entry.second); + if (!_matching_elems->empty() && _matching_elems->back() < v.size()) { + for (uint32_t id_to_keep : (*_matching_elems)) { + auto entry = v[id_to_keep]; + map_inserter.insert_entry(*entry.first, *entry.second); + } } } else { for (const auto &entry : v) { @@ -406,8 +408,10 @@ private: ArrayInserter ai(a); SlimeFiller conv(ai, _tokenize); if (filter_matching_elements()) { - for (uint32_t id_to_keep : (*_matching_elems)) { - value[id_to_keep].accept(conv); + if (!_matching_elems->empty() && _matching_elems->back() < value.size()) { + for (uint32_t id_to_keep : (*_matching_elems)) { + value[id_to_keep].accept(conv); + } } } else { for (const FieldValue &fv : value) { |