summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java4
-rw-r--r--config/src/tests/failover/failover.cpp1
-rw-r--r--config/src/tests/frt/frt.cpp1
-rw-r--r--config/src/vespa/config/common/trace.cpp1
-rw-r--r--config/src/vespa/config/frt/frtconfigresponsev3.cpp1
-rw-r--r--config/src/vespa/config/frt/slimeconfigrequest.cpp2
-rw-r--r--config/src/vespa/config/print/fileconfigformatter.cpp1
-rw-r--r--configdefinitions/src/vespa/configserver.def4
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java8
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java20
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java21
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java2
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java2
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java2
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java2
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java2
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java2
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java2
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java2
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java2
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java2
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java2
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java2
-rw-r--r--container-core/abi-spec.json3
-rw-r--r--container-core/src/main/java/com/yahoo/container/core/config/HandlersConfigurerDi.java5
-rw-r--r--container-core/src/main/java/com/yahoo/container/handler/LogHandler.java53
-rw-r--r--container-core/src/main/java/com/yahoo/container/jdisc/ContentChannelOutputStream.java6
-rw-r--r--container-core/src/main/java/com/yahoo/container/jdisc/HttpRequest.java18
-rw-r--r--container-core/src/test/java/com/yahoo/container/handler/LogHandlerTest.java22
-rw-r--r--default_build_settings.cmake2
-rw-r--r--dist/vespa.spec8
-rw-r--r--eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp16
-rw-r--r--eval/src/tests/tensor/direct_sparse_tensor_builder/direct_sparse_tensor_builder_test.cpp4
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/.gitattributes1
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/dynamic.onnx27
-rwxr-xr-xeval/src/tests/tensor/onnx_wrapper/dynamic.py39
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp223
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/simple.onnx23
-rwxr-xr-xeval/src/tests/tensor/onnx_wrapper/simple.py33
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp12
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor.cpp9
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor.h1
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor.h5
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.h5
-rw-r--r--eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp271
-rw-r--r--eval/src/vespa/eval/tensor/dense/onnx_wrapper.h92
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp10
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.h1
-rw-r--r--eval/src/vespa/eval/tensor/tensor.h1
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp10
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.h1
-rw-r--r--metrics/src/tests/metricmanagertest.cpp1
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java5
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/OsUpgradeActivator.java4
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/OsVersion.java5
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingUpgrader.java2
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringUpgrader.java2
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java22
-rw-r--r--searchcore/src/tests/proton/common/pendinglidtracker_test.cpp4
-rw-r--r--searchcore/src/tests/proton/docsummary/docsummary.cpp1
-rw-r--r--searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp1
-rw-r--r--searchcore/src/tests/proton/server/feedstates_test.cpp4
-rw-r--r--searchcore/src/tests/proton/summaryengine/summaryengine.cpp1
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/documentdb.cpp9
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/feedstates.cpp4
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/feedstates.h1
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp44
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h1
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/tlcproxy.cpp3
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/visibilityhandler.cpp14
-rw-r--r--searchcore/src/vespa/searchcore/proton/summaryengine/docsum_by_slime.cpp1
-rw-r--r--searchlib/CMakeLists.txt1
-rw-r--r--searchlib/src/tests/attribute/bitvector/bitvector_test.cpp2
-rw-r--r--searchlib/src/tests/attribute/document_weight_iterator/document_weight_iterator_test.cpp61
-rw-r--r--searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp15
-rw-r--r--searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp2
-rw-r--r--searchlib/src/tests/tensor/direct_tensor_store/CMakeLists.txt9
-rw-r--r--searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp89
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp5
-rw-r--r--searchlib/src/tests/transactionlogstress/translogstress.cpp7
-rw-r--r--searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp17
-rw-r--r--searchlib/src/vespa/searchlib/attribute/i_document_weight_attribute.h17
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.h10
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp23
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multistringpostattribute.h10
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp29
-rw-r--r--searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp17
-rw-r--r--searchlib/src/vespa/searchlib/attribute/reference_attribute.h19
-rw-r--r--searchlib/src/vespa/searchlib/diskindex/bitvectorfile.h3
-rw-r--r--searchlib/src/vespa/searchlib/features/onnx_feature.cpp52
-rw-r--r--searchlib/src/vespa/searchlib/features/onnx_feature.h7
-rw-r--r--searchlib/src/vespa/searchlib/tensor/CMakeLists.txt3
-rw-r--r--searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h26
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp14
-rw-r--r--searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp52
-rw-r--r--searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h25
-rw-r--r--searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp62
-rw-r--r--searchlib/src/vespa/searchlib/tensor/direct_tensor_store.h34
-rw-r--r--searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp20
-rw-r--r--searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp12
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp24
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h10
-rw-r--r--searchlib/src/vespa/searchlib/test/weightedchildrenverifiers.h2
-rw-r--r--searchlib/src/vespa/searchlib/transactionlog/common.cpp6
-rw-r--r--searchlib/src/vespa/searchlib/transactionlog/domain.cpp36
-rw-r--r--searchlib/src/vespa/searchlib/transactionlog/session.cpp2
-rw-r--r--searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp17
-rw-r--r--staging_vespalib/src/vespa/vespalib/net/generic_state_handler.cpp1
-rw-r--r--storage/src/tests/storageserver/statereportertest.cpp2
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java3
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java2
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/SessionFactory.java3
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java25
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/FeedParams.java4
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/SessionParams.java13
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/Document.java58
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java26
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/MultiClusterSessionOutputStream.java13
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java13
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java141
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionFactory.java63
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ClusterConnection.java39
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java31
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnection.java59
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnectionFactory.java26
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueue.java44
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnection.java13
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnectionFactory.java13
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java290
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/DocumentSendInfo.java18
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/EndPointResultFactory.java13
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottler.java13
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java20
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/Runner.java7
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/FeedClientTest.java3
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/ManualClock.java55
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/QueueBoundsTest.java23
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/Server.java3
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java44
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/DocumentTest.java22
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java160
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/CloseableQTestCase.java20
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueueTest.java3
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/IOThreadTest.java73
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/NewIOThreadTest.java192
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottlerTest.java43
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessorTest.java31
-rw-r--r--vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/ClientFeederV3.java79
-rw-r--r--vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeederSettings.java60
-rw-r--r--vespalib/src/tests/slime/json_slime_benchmark.cpp2
-rw-r--r--vespalib/src/tests/slime/slime_binary_format_test.cpp1
-rw-r--r--vespalib/src/tests/slime/slime_json_format_test.cpp1
-rw-r--r--vespalib/src/tests/slime/slime_test.cpp7
-rw-r--r--vespalib/src/tests/trace/trace_serialization.cpp1
-rw-r--r--vespalib/src/vespa/vespalib/data/memory.h12
-rw-r--r--vespalib/src/vespa/vespalib/data/simple_buffer.cpp2
-rw-r--r--vespalib/src/vespa/vespalib/data/simple_buffer.h3
-rw-r--r--vespalib/src/vespa/vespalib/data/slime/slime.h1
-rw-r--r--vespalib/src/vespa/vespalib/datastore/datastore.h1
-rw-r--r--vespalib/src/vespa/vespalib/datastore/datastore.hpp8
-rw-r--r--vespalib/src/vespa/vespalib/datastore/unique_store_enumerator.h4
-rw-r--r--vespalib/src/vespa/vespalib/objects/nbostream.h3
-rw-r--r--vespalib/src/vespa/vespalib/util/arrayref.h18
163 files changed, 2589 insertions, 1097 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java b/config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java
index 0624028732f..a4737c9f54c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java
@@ -101,12 +101,12 @@ public class ContainerDocumentApi {
.collect(Collectors.toList());
// We can only use host resource for calculation if all container nodes in the cluster are homogeneous (in terms of vcpu)
if (vcpus.size() != 1 || vcpus.get(0) == 0) return FALLBACK_MAX_POOL_SIZE;
- return (int)Math.ceil(vcpus.get(0));
+ return Math.max(2, (int)Math.ceil(vcpus.get(0)));
}
private static int corePoolSize(int maxPoolSize, Options options) {
if (maxPoolSize == FALLBACK_MAX_POOL_SIZE) return FALLBACK_CORE_POOL_SIZE;
- return (int) Math.ceil(options.feedCoreThreadPoolSizeFactor * maxPoolSize);
+ return Math.max(1, (int)Math.ceil(options.feedCoreThreadPoolSizeFactor * maxPoolSize));
}
public static final class Options {
diff --git a/config/src/tests/failover/failover.cpp b/config/src/tests/failover/failover.cpp
index 17fa264fd32..0ca09b228f3 100644
--- a/config/src/tests/failover/failover.cpp
+++ b/config/src/tests/failover/failover.cpp
@@ -7,6 +7,7 @@
#include <vespa/fnet/frt/frt.h>
#include "config-my.h"
#include <vespa/vespalib/data/slime/slime.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <vespa/log/log.h>
LOG_SETUP("failover");
diff --git a/config/src/tests/frt/frt.cpp b/config/src/tests/frt/frt.cpp
index cf1ff9eca37..85b9789821d 100644
--- a/config/src/tests/frt/frt.cpp
+++ b/config/src/tests/frt/frt.cpp
@@ -10,6 +10,7 @@
#include <vespa/config/frt/frtconfigresponsev3.h>
#include <vespa/vespalib/data/slime/slime.h>
#include <vespa/vespalib/data/slime/json_format.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <vespa/fnet/fnet.h>
#include <vespa/fnet/frt/frt.h>
#include <vespa/fnet/frt/error.h>
diff --git a/config/src/vespa/config/common/trace.cpp b/config/src/vespa/config/common/trace.cpp
index 76310d08c7d..4edc9df60c3 100644
--- a/config/src/vespa/config/common/trace.cpp
+++ b/config/src/vespa/config/common/trace.cpp
@@ -3,6 +3,7 @@
#include <vespa/vespalib/trace/slime_trace_serializer.h>
#include <vespa/vespalib/trace/slime_trace_deserializer.h>
#include <vespa/vespalib/data/slime/slime.h>
+#include <vespa/vespalib/data/simple_buffer.h>
using namespace vespalib;
using namespace vespalib::slime;
diff --git a/config/src/vespa/config/frt/frtconfigresponsev3.cpp b/config/src/vespa/config/frt/frtconfigresponsev3.cpp
index 405391d99b6..b983c63c6a5 100644
--- a/config/src/vespa/config/frt/frtconfigresponsev3.cpp
+++ b/config/src/vespa/config/frt/frtconfigresponsev3.cpp
@@ -2,6 +2,7 @@
#include "frtconfigresponsev3.h"
#include "compressioninfo.h"
#include <vespa/fnet/frt/frt.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <vespa/log/log.h>
LOG_SETUP(".config.frt.frtconfigresponsev3");
diff --git a/config/src/vespa/config/frt/slimeconfigrequest.cpp b/config/src/vespa/config/frt/slimeconfigrequest.cpp
index 07626c1e274..696789f74c1 100644
--- a/config/src/vespa/config/frt/slimeconfigrequest.cpp
+++ b/config/src/vespa/config/frt/slimeconfigrequest.cpp
@@ -7,6 +7,8 @@
#include <vespa/config/common/configdefinition.h>
#include <vespa/config/common/trace.h>
#include <vespa/config/common/vespa_version.h>
+#include <vespa/vespalib/data/simple_buffer.h>
+
using namespace vespalib;
using namespace vespalib::slime;
diff --git a/config/src/vespa/config/print/fileconfigformatter.cpp b/config/src/vespa/config/print/fileconfigformatter.cpp
index 85e938dee8f..628a9daa530 100644
--- a/config/src/vespa/config/print/fileconfigformatter.cpp
+++ b/config/src/vespa/config/print/fileconfigformatter.cpp
@@ -4,6 +4,7 @@
#include <vespa/vespalib/stllike/asciistream.h>
#include <vespa/vespalib/util/exceptions.h>
#include <vespa/vespalib/data/slime/slime.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <cmath>
#include <vector>
diff --git a/configdefinitions/src/vespa/configserver.def b/configdefinitions/src/vespa/configserver.def
index cde539f25f4..7405f5f2d05 100644
--- a/configdefinitions/src/vespa/configserver.def
+++ b/configdefinitions/src/vespa/configserver.def
@@ -51,9 +51,7 @@ ztsUrl string default=""
# Maintainers
maintainerIntervalMinutes int default=30
-# TODO: Default set to a high value (1 year) => maintainer will not run, change when maintainer verified out in prod
-tenantsMaintainerIntervalMinutes int default=525600
-keepUnusedFileReferencesHours int default=4
+keepUnusedFileReferencesHours int default=2
# Bootstrapping
# How long bootstrapping can take before giving up (in seconds)
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java
index 4e44b9cae33..9f2d6d178be 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java
@@ -35,8 +35,8 @@ public class ConfigServerMaintenance extends AbstractComponent {
DefaultTimes defaults = new DefaultTimes(configserverConfig);
tenantsMaintainer = new TenantsMaintainer(applicationRepository, curator, flagSource, defaults.defaultInterval, Clock.systemUTC());
fileDistributionMaintainer = new FileDistributionMaintainer(applicationRepository, curator, defaults.defaultInterval, flagSource);
- sessionsMaintainer = new SessionsMaintainer(applicationRepository, curator, Duration.ofMinutes(1), flagSource);
- applicationPackageMaintainer = new ApplicationPackageMaintainer(applicationRepository, curator, Duration.ofMinutes(1), flagSource);
+ sessionsMaintainer = new SessionsMaintainer(applicationRepository, curator, Duration.ofSeconds(30), flagSource);
+ applicationPackageMaintainer = new ApplicationPackageMaintainer(applicationRepository, curator, Duration.ofSeconds(30), flagSource);
}
@Override
@@ -61,8 +61,8 @@ public class ConfigServerMaintenance extends AbstractComponent {
}
public void runBeforeBootstrap() {
- fileDistributionMaintainer.maintain();
- sessionsMaintainer.maintain();
+ fileDistributionMaintainer.lockAndMaintain();
+ sessionsMaintainer.lockAndMaintain();
}
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java
index 2e6180aeb81..cbfa59b26e4 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java
@@ -171,11 +171,11 @@ public class SessionRepository {
}
public void deleteExpiredSessions(Map<ApplicationId, Long> activeSessions) {
- log.log(Level.FINE, "Purging old sessions for tenant '" + tenantName + "'");
+ log.log(Level.FINE, () -> "Purging old sessions for tenant '" + tenantName + "'");
try {
for (LocalSession candidate : localSessionCache.getSessions()) {
Instant createTime = candidate.getCreateTime();
- log.log(Level.FINE, "Candidate session for deletion: " + candidate.getSessionId() + ", created: " + createTime);
+ log.log(Level.FINE, () -> "Candidate session for deletion: " + candidate.getSessionId() + ", created: " + createTime);
// Sessions with state other than ACTIVATE
if (hasExpired(candidate) && !isActiveSession(candidate)) {
@@ -196,7 +196,7 @@ public class SessionRepository {
} catch (Throwable e) {
log.log(Level.WARNING, "Error when purging old sessions ", e);
}
- log.log(Level.FINE, "Done purging old sessions");
+ log.log(Level.FINE, () -> "Done purging old sessions");
}
private boolean hasExpired(LocalSession candidate) {
@@ -210,7 +210,7 @@ public class SessionRepository {
public void deleteLocalSession(LocalSession session) {
long sessionId = session.getSessionId();
try (Lock lock = lock(sessionId)) {
- log.log(Level.FINE, "Deleting local session " + sessionId);
+ log.log(Level.FINE, () -> "Deleting local session " + sessionId);
SessionStateWatcher watcher = sessionStateWatchers.remove(sessionId);
if (watcher != null) watcher.close();
localSessionCache.removeSession(sessionId);
@@ -274,7 +274,7 @@ public class SessionRepository {
if (session == null) continue; // Internal sessions not in synch with zk, continue
if (session.getStatus() == Session.Status.ACTIVATE) continue;
if (sessionHasExpired(session.getCreateTime(), expiryTime, clock)) {
- log.log(Level.FINE, "Remote session " + sessionId + " for " + tenantName + " has expired, deleting it");
+ log.log(Level.FINE, () -> "Remote session " + sessionId + " for " + tenantName + " has expired, deleting it");
session.delete();
deleted++;
}
@@ -287,7 +287,7 @@ public class SessionRepository {
for (var lock : curator.getChildren(locksPath)) {
Path path = locksPath.append(lock);
if (zooKeeperNodeCreated(path).orElse(clock.instant()).isBefore(clock.instant().minus(expiryTime))) {
- log.log(Level.FINE, "Lock " + path + " has expired, deleting it");
+ log.log(Level.FINE, () -> "Lock " + path + " has expired, deleting it");
curator.delete(path);
deleted++;
}
@@ -485,7 +485,7 @@ public class SessionRepository {
long sessionId,
TimeoutBudget timeoutBudget,
Clock clock) {
- log.log(Level.FINE, TenantRepository.logPre(tenantName) + "Creating session " + sessionId + " in ZooKeeper");
+ log.log(Level.FINE, () -> TenantRepository.logPre(tenantName) + "Creating session " + sessionId + " in ZooKeeper");
SessionZooKeeperClient sessionZKClient = createSessionZooKeeperClient(sessionId);
sessionZKClient.createNewSession(clock.instant());
Curator.CompletionWaiter waiter = sessionZKClient.getUploadWaiter();
@@ -605,13 +605,13 @@ public class SessionRepository {
*/
public Optional<LocalSession> createLocalSessionUsingDistributedApplicationPackage(long sessionId) {
if (applicationRepo.hasLocalSession(sessionId)) {
- log.log(Level.FINE, "Local session for session id " + sessionId + " already exists");
+ log.log(Level.FINE, () -> "Local session for session id " + sessionId + " already exists");
return Optional.of(createSessionFromId(sessionId));
}
SessionZooKeeperClient sessionZKClient = createSessionZooKeeperClient(sessionId);
FileReference fileReference = sessionZKClient.readApplicationPackageReference();
- log.log(Level.FINE, "File reference for session id " + sessionId + ": " + fileReference);
+ log.log(Level.FINE, () -> "File reference for session id " + sessionId + ": " + fileReference);
if (fileReference != null) {
File rootDir = new File(Defaults.getDefaults().underVespaHome(componentRegistry.getConfigserverConfig().fileReferencesDir()));
File sessionDir;
@@ -626,7 +626,7 @@ public class SessionRepository {
}
ApplicationId applicationId = sessionZKClient.readApplicationId()
.orElseThrow(() -> new RuntimeException("Could not find application id for session " + sessionId));
- log.log(Level.INFO, "Creating local session for tenant '" + tenantName + "' with session id " + sessionId);
+ log.log(Level.FINE, () -> "Creating local session for tenant '" + tenantName + "' with session id " + sessionId);
LocalSession localSession = createLocalSession(sessionDir, applicationId, sessionId);
addLocalSession(localSession);
return Optional.of(localSession);
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java
index 6216e8ebfd6..ecbcb513c03 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java
@@ -97,17 +97,6 @@ public class TenantRepository {
*/
@Inject
public TenantRepository(GlobalComponentRegistry componentRegistry) {
- this(componentRegistry, true);
- }
-
- /**
- * Creates a new tenant repository
- *
- * @param componentRegistry a {@link com.yahoo.vespa.config.server.GlobalComponentRegistry}
- * @param useZooKeeperWatchForTenantChanges set to false for tests where you want to control adding and deleting
- * tenants yourself
- */
- public TenantRepository(GlobalComponentRegistry componentRegistry, boolean useZooKeeperWatchForTenantChanges) {
this.componentRegistry = componentRegistry;
ConfigserverConfig configserverConfig = componentRegistry.getConfigserverConfig();
this.bootstrapExecutor = Executors.newFixedThreadPool(configserverConfig.numParallelTenantLoaders());
@@ -124,13 +113,9 @@ public class TenantRepository {
createSystemTenants(configserverConfig);
curator.create(vespaPath);
- if (useZooKeeperWatchForTenantChanges) {
- this.directoryCache = Optional.of(curator.createDirectoryCache(tenantsPath.getAbsolute(), false, false, zkCacheExecutor));
- this.directoryCache.get().addListener(this::childEvent);
- this.directoryCache.get().start();
- } else {
- this.directoryCache = Optional.empty();
- }
+ this.directoryCache = Optional.of(curator.createDirectoryCache(tenantsPath.getAbsolute(), false, false, zkCacheExecutor));
+ this.directoryCache.get().addListener(this::childEvent);
+ this.directoryCache.get().start();
bootstrapTenants();
notifyTenantsLoaded();
checkForRemovedApplicationsService.scheduleWithFixedDelay(this::removeUnusedApplications,
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java
index cde115eec40..a1249838324 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java
@@ -75,7 +75,7 @@ public class TenantApplicationsTest {
.modelFactoryRegistry(createRegistry())
.reloadListener(listener)
.build();
- TenantRepository tenantRepository = new TenantRepository(componentRegistry, false);
+ TenantRepository tenantRepository = new TenantRepository(componentRegistry);
tenantRepository.addTenant(TenantRepository.HOSTED_VESPA_TENANT);
tenantRepository.addTenant(tenantName);
applications = TenantApplications.create(componentRegistry, tenantName);
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java
index 469168cedd4..67ac0b02133 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java
@@ -50,7 +50,7 @@ public class ApplicationContentHandlerTest extends ContentHandlerTestBase {
@Before
public void setupHandler() {
- TenantRepository tenantRepository = new TenantRepository(componentRegistry, false);
+ TenantRepository tenantRepository = new TenantRepository(componentRegistry);
tenantRepository.addTenant(tenantName1);
tenantRepository.addTenant(tenantName2);
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java
index f43154242fb..3a33d326c48 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java
@@ -80,7 +80,7 @@ public class ApplicationHandlerTest {
.provisioner(provisioner)
.modelFactoryRegistry(new ModelFactoryRegistry(modelFactories))
.build();
- tenantRepository = new TenantRepository(componentRegistry, false);
+ tenantRepository = new TenantRepository(componentRegistry);
tenantRepository.addTenant(mytenantName);
provisioner = new SessionHandlerTest.MockProvisioner();
orchestrator = new OrchestratorMock();
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java
index bef6369beb7..80a0b9edba6 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java
@@ -41,7 +41,7 @@ public class ListApplicationsHandlerTest {
@Before
public void setup() {
- TenantRepository tenantRepository = new TenantRepository(componentRegistry, false);
+ TenantRepository tenantRepository = new TenantRepository(componentRegistry);
tenantRepository.addTenant(mytenant);
tenantRepository.addTenant(foobar);
applicationRepo = tenantRepository.getTenant(mytenant).getApplicationRepo();
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java
index 135a5ef45c4..c3a7e82dff5 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java
@@ -72,7 +72,7 @@ public class SessionActiveHandlerTest {
.curator(new MockCurator())
.modelFactoryRegistry(new ModelFactoryRegistry(List.of((modelFactory))))
.build();
- TenantRepository tenantRepository = new TenantRepository(componentRegistry, false);
+ TenantRepository tenantRepository = new TenantRepository(componentRegistry);
tenantRepository.addTenant(tenantName);
applicationRepository = new ApplicationRepository.Builder()
.withTenantRepository(tenantRepository)
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java
index f0362db3b8a..d28404d8d72 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java
@@ -45,7 +45,7 @@ public class SessionContentHandlerTest extends ContentHandlerTestBase {
@Before
public void setupHandler() {
- tenantRepository = new TenantRepository(componentRegistry, false);
+ tenantRepository = new TenantRepository(componentRegistry);
tenantRepository.addTenant(tenantName);
ApplicationRepository applicationRepository = new ApplicationRepository.Builder()
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java
index 0f1b1543c83..513bf6352e8 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java
@@ -61,7 +61,7 @@ public class SessionCreateHandlerTest extends SessionHandlerTest {
@Before
public void setupRepo() {
- TenantRepository tenantRepository = new TenantRepository(componentRegistry, false);
+ TenantRepository tenantRepository = new TenantRepository(componentRegistry);
applicationRepository = new ApplicationRepository.Builder()
.withTenantRepository(tenantRepository)
.withProvisioner(new SessionHandlerTest.MockProvisioner())
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java
index 297fee94e5b..cc4f39b0789 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java
@@ -65,7 +65,7 @@ public class SessionPrepareHandlerTest extends SessionHandlerTest {
@Before
public void setupRepo() {
- tenantRepository = new TenantRepository(componentRegistry, false);
+ tenantRepository = new TenantRepository(componentRegistry);
tenantRepository.addTenant(tenant);
applicationRepository = new ApplicationRepository.Builder()
.withTenantRepository(tenantRepository)
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java b/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java
index 0f087e3f006..eb06f2f7017 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java
@@ -88,7 +88,7 @@ public class RpcTester implements AutoCloseable {
.configDefinitionRepo(new TestConfigDefinitionRepo())
.configServerConfig(configserverConfig)
.build();
- tenantRepository = new TenantRepository(componentRegistry, false);
+ tenantRepository = new TenantRepository(componentRegistry);
tenantRepository.addTenant(tenantName);
applicationRepository = new ApplicationRepository.Builder()
.withTenantRepository(tenantRepository)
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java
index bb87233f979..b89b63aed46 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java
@@ -65,7 +65,7 @@ public class SessionRepositoryTest {
.build())
.flagSource(flagSource)
.build();
- tenantRepository = new TenantRepository(globalComponentRegistry, false);
+ tenantRepository = new TenantRepository(globalComponentRegistry);
tenantRepository.addTenant(SessionRepositoryTest.tenantName);
applicationRepository = new ApplicationRepository.Builder()
.withTenantRepository(tenantRepository)
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java
index a31b06bbebb..8678c42eab4 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java
@@ -202,7 +202,7 @@ public class TenantRepositoryTest {
private static class FailingDuringBootstrapTenantRepository extends TenantRepository {
public FailingDuringBootstrapTenantRepository(GlobalComponentRegistry globalComponentRegistry) {
- super(globalComponentRegistry, false);
+ super(globalComponentRegistry);
}
@Override
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java
index 74833da6d66..ac596198fe5 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java
@@ -32,7 +32,7 @@ public class TenantTest {
}
private Tenant createTenant(String name) {
- TenantRepository tenantRepository = new TenantRepository(componentRegistry, false);
+ TenantRepository tenantRepository = new TenantRepository(componentRegistry);
TenantName tenantName = TenantName.from(name);
tenantRepository.addTenant(tenantName);
return tenantRepository.getTenant(tenantName);
diff --git a/container-core/abi-spec.json b/container-core/abi-spec.json
index aa2e5ccfa5f..9292a946e82 100644
--- a/container-core/abi-spec.json
+++ b/container-core/abi-spec.json
@@ -114,7 +114,8 @@
],
"methods": [
"public void <init>(java.util.concurrent.Executor, com.yahoo.container.core.LogHandlerConfig)",
- "public com.yahoo.container.jdisc.HttpResponse handle(com.yahoo.container.jdisc.HttpRequest)"
+ "public com.yahoo.container.jdisc.AsyncHttpResponse handle(com.yahoo.container.jdisc.HttpRequest)",
+ "public bridge synthetic com.yahoo.container.jdisc.HttpResponse handle(com.yahoo.container.jdisc.HttpRequest)"
],
"fields": []
},
diff --git a/container-core/src/main/java/com/yahoo/container/core/config/HandlersConfigurerDi.java b/container-core/src/main/java/com/yahoo/container/core/config/HandlersConfigurerDi.java
index 13d50b9b30f..25299978ecd 100644
--- a/container-core/src/main/java/com/yahoo/container/core/config/HandlersConfigurerDi.java
+++ b/container-core/src/main/java/com/yahoo/container/core/config/HandlersConfigurerDi.java
@@ -107,9 +107,8 @@ public class HandlersConfigurerDi {
super(osgiFramework);
this.osgiFramework = osgiFramework;
- OsgiImpl osgi = new OsgiImpl(osgiFramework);
- applicationBundleLoader = new ApplicationBundleLoader(osgi, new FileAcquirerBundleInstaller(fileAcquirer));
- platformBundleLoader = new PlatformBundleLoader(osgi);
+ applicationBundleLoader = new ApplicationBundleLoader(this, new FileAcquirerBundleInstaller(fileAcquirer));
+ platformBundleLoader = new PlatformBundleLoader(this);
}
diff --git a/container-core/src/main/java/com/yahoo/container/handler/LogHandler.java b/container-core/src/main/java/com/yahoo/container/handler/LogHandler.java
index b2a156862eb..4b23eafaa9c 100644
--- a/container-core/src/main/java/com/yahoo/container/handler/LogHandler.java
+++ b/container-core/src/main/java/com/yahoo/container/handler/LogHandler.java
@@ -3,13 +3,19 @@ package com.yahoo.container.handler;
import com.google.inject.Inject;
import com.yahoo.container.core.LogHandlerConfig;
+import com.yahoo.container.jdisc.AsyncHttpResponse;
+import com.yahoo.container.jdisc.ContentChannelOutputStream;
import com.yahoo.container.jdisc.HttpRequest;
-import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.container.jdisc.ThreadedHttpRequestHandler;
+import com.yahoo.jdisc.handler.CompletionHandler;
+import com.yahoo.jdisc.handler.ContentChannel;
+import java.io.IOException;
import java.io.OutputStream;
+import java.nio.ByteBuffer;
import java.time.Instant;
import java.util.Optional;
+import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.logging.Level;
@@ -28,25 +34,58 @@ public class LogHandler extends ThreadedHttpRequestHandler {
}
@Override
- public HttpResponse handle(HttpRequest request) {
-
+ public AsyncHttpResponse handle(HttpRequest request) {
Instant from = Optional.ofNullable(request.getProperty("from"))
.map(Long::valueOf).map(Instant::ofEpochMilli).orElse(Instant.MIN);
Instant to = Optional.ofNullable(request.getProperty("to"))
.map(Long::valueOf).map(Instant::ofEpochMilli).orElse(Instant.MAX);
-
Optional<String> hostname = Optional.ofNullable(request.getProperty("hostname"));
- return new HttpResponse(200) {
+ return new AsyncHttpResponse(200) {
@Override
- public void render(OutputStream outputStream) {
+ public void render(OutputStream output, ContentChannel networkChannel, CompletionHandler handler) {
try {
- logReader.writeLogs(outputStream, from, to, hostname);
+ OutputStream blockingOutput = new BlockingFlushContentChannelOutputStream(networkChannel);
+ logReader.writeLogs(blockingOutput, from, to, hostname);
+ blockingOutput.close();
}
catch (Throwable t) {
log.log(Level.WARNING, "Failed reading logs from " + from + " to " + to, t);
}
+ finally {
+ networkChannel.close(handler);
+ }
}
};
}
+
+
+ private static class BlockingFlushContentChannelOutputStream extends ContentChannelOutputStream {
+
+ private final ContentChannel channel;
+
+ public BlockingFlushContentChannelOutputStream(ContentChannel endpoint) {
+ super(endpoint);
+ this.channel = endpoint;
+ }
+
+ @Override
+ public void flush() throws IOException {
+ super.flush();
+ CountDownLatch latch = new CountDownLatch(1);
+ channel.write(ByteBuffer.allocate(0), // :'(
+ new CompletionHandler() {
+ @Override public void completed() { latch.countDown(); }
+ @Override public void failed(Throwable t) { latch.countDown(); }
+ });
+ try {
+ latch.await();
+ }
+ catch (InterruptedException e) {
+ throw new RuntimeException("Interrupted waiting for underlying IO to complete", e);
+ }
+ }
+
+ }
+
}
diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/ContentChannelOutputStream.java b/container-core/src/main/java/com/yahoo/container/jdisc/ContentChannelOutputStream.java
index 9c48955bf4c..329889e70c0 100644
--- a/container-core/src/main/java/com/yahoo/container/jdisc/ContentChannelOutputStream.java
+++ b/container-core/src/main/java/com/yahoo/container/jdisc/ContentChannelOutputStream.java
@@ -53,11 +53,7 @@ public class ContentChannelOutputStream extends OutputStream implements Writable
public void close() throws IOException {
// the endpoint is closed in a finally{} block inside AbstractHttpRequestHandler
// this class should be possible to close willynilly as it is exposed to plug-ins
- try {
- buffer.flush();
- } catch (RuntimeException e) {
- throw new IOException(Exceptions.toMessageString(e), e);
- }
+ flush();
}
/**
diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/HttpRequest.java b/container-core/src/main/java/com/yahoo/container/jdisc/HttpRequest.java
index e02ae152b3f..edd24fed515 100644
--- a/container-core/src/main/java/com/yahoo/container/jdisc/HttpRequest.java
+++ b/container-core/src/main/java/com/yahoo/container/jdisc/HttpRequest.java
@@ -517,11 +517,9 @@ public class HttpRequest {
}
/**
- * Access an HTTP header in the request. Multi-value headers are not
- * supported.
+ * Access an HTTP header in the request. Multi-value headers are not supported.
*
- * @param name
- * the name of an HTTP header
+ * @param name the name of an HTTP header
* @return the first pertinent value
*/
public String getHeader(String name) {
@@ -530,20 +528,12 @@ public class HttpRequest {
return parentRequest.headers().get(name).get(0);
}
- /**
- * Get the host segment of the URI of this request.
- *
- * @return the host name from the URI
- */
+ /** Get the host segment of the URI of this request. */
public String getHost() {
return getUri().getHost();
}
- /**
- * The port of the URI of this request.
- *
- * @return the port number of the URI
- */
+ /** The port of the URI of this request. */
public int getPort() {
return getUri().getPort();
}
diff --git a/container-core/src/test/java/com/yahoo/container/handler/LogHandlerTest.java b/container-core/src/test/java/com/yahoo/container/handler/LogHandlerTest.java
index 97aa8864eae..afe57579a97 100644
--- a/container-core/src/test/java/com/yahoo/container/handler/LogHandlerTest.java
+++ b/container-core/src/test/java/com/yahoo/container/handler/LogHandlerTest.java
@@ -1,17 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.handler;
+import com.yahoo.container.jdisc.AsyncHttpResponse;
import com.yahoo.container.jdisc.HttpRequest;
-import com.yahoo.container.jdisc.HttpResponse;
+import com.yahoo.jdisc.handler.ReadableContentChannel;
+import com.yahoo.yolean.Exceptions;
import org.junit.Test;
-import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.Executor;
+import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
@@ -24,20 +26,20 @@ public class LogHandlerTest {
{
String uri = "http://myhost.com:1111/logs?from=1000&to=2000";
- HttpResponse response = logHandler.handle(HttpRequest.createTestRequest(uri, com.yahoo.jdisc.http.HttpRequest.Method.GET));
- ByteArrayOutputStream bos = new ByteArrayOutputStream();
- response.render(bos);
+ AsyncHttpResponse response = logHandler.handle(HttpRequest.createTestRequest(uri, com.yahoo.jdisc.http.HttpRequest.Method.GET));
+ ReadableContentChannel out = new ReadableContentChannel();
+ new Thread(() -> Exceptions.uncheck(() -> response.render(null, out, null))).start();
String expectedResponse = "newer log";
- assertEquals(expectedResponse, bos.toString());
+ assertEquals(expectedResponse, new String(out.toStream().readAllBytes(), UTF_8));
}
{
String uri = "http://myhost.com:1111/logs?from=0&to=1000";
- HttpResponse response = logHandler.handle(HttpRequest.createTestRequest(uri, com.yahoo.jdisc.http.HttpRequest.Method.GET));
- ByteArrayOutputStream bos = new ByteArrayOutputStream();
- response.render(bos);
+ AsyncHttpResponse response = logHandler.handle(HttpRequest.createTestRequest(uri, com.yahoo.jdisc.http.HttpRequest.Method.GET));
+ ReadableContentChannel out = new ReadableContentChannel();
+ new Thread(() -> Exceptions.uncheck(() -> response.render(null, out, null))).start();
String expectedResponse = "older log";
- assertEquals(expectedResponse, bos.toString());
+ assertEquals(expectedResponse, new String(out.toStream().readAllBytes(), UTF_8));
}
}
diff --git a/default_build_settings.cmake b/default_build_settings.cmake
index 07a70c38d71..75399069619 100644
--- a/default_build_settings.cmake
+++ b/default_build_settings.cmake
@@ -79,7 +79,7 @@ endfunction()
function(setup_vespa_default_build_settings_fedora_33)
message("-- Setting up default build settings for fedora 33")
set(DEFAULT_EXTRA_INCLUDE_DIRECTORY "${VESPA_DEPS}/include" "/usr/include/openblas" PARENT_SCOPE)
- set(DEFAULT_VESPA_LLVM_VERSION "10" PARENT_SCOPE)
+ set(DEFAULT_VESPA_LLVM_VERSION "11" PARENT_SCOPE)
endfunction()
function(setup_vespa_default_build_settings_ubuntu_19_10)
diff --git a/dist/vespa.spec b/dist/vespa.spec
index e49bd74e545..8dae9ee17bc 100644
--- a/dist/vespa.spec
+++ b/dist/vespa.spec
@@ -97,8 +97,8 @@ BuildRequires: gmock-devel
%endif
%if 0%{?fc33}
BuildRequires: protobuf-devel
-BuildRequires: llvm-devel >= 10.0.0
-BuildRequires: boost-devel >= 1.69
+BuildRequires: llvm-devel >= 11.0.0
+BuildRequires: boost-devel >= 1.73
BuildRequires: gtest-devel
BuildRequires: gmock-devel
%endif
@@ -200,8 +200,8 @@ Requires: llvm-libs >= 10.0.0
%endif
%if 0%{?fc33}
Requires: protobuf
-Requires: llvm-libs >= 10.0.0
-%define _vespa_llvm_version 10
+Requires: llvm-libs >= 11.0.0
+%define _vespa_llvm_version 11
%endif
%define _extra_link_directory %{_vespa_deps_prefix}/lib64
%define _extra_include_directory %{_vespa_deps_prefix}/include;/usr/include/openblas
diff --git a/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp b/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp
index 3f4641ed2ee..08333fa30f3 100644
--- a/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp
+++ b/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp
@@ -174,4 +174,20 @@ TEST("require that dense tensor cells iterator works for 2d tensor") {
EXPECT_FALSE(itr.valid());
}
+TEST("require that memory used count is reasonable") {
+ Tensor::UP full = build2DTensor();
+ const DenseTensorView &full_view = dynamic_cast<const DenseTensorView &>(*full);
+ DenseTensorView ref_view(full_view.fast_type(), full_view.cellsRef());
+
+ size_t full_sz = full->count_memory_used();
+ size_t view_sz = full_view.count_memory_used();
+ size_t ref_sz = ref_view.count_memory_used();
+
+ EXPECT_EQUAL(ref_sz, sizeof(DenseTensorView));
+ EXPECT_LESS(ref_sz, full_sz);
+ EXPECT_EQUAL(full_sz, view_sz);
+ EXPECT_LESS(full_sz, 10000u);
+ EXPECT_GREATER(full_sz, sizeof(DenseTensor<double>));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/direct_sparse_tensor_builder/direct_sparse_tensor_builder_test.cpp b/eval/src/tests/tensor/direct_sparse_tensor_builder/direct_sparse_tensor_builder_test.cpp
index 86b6abedd39..f901b7775fd 100644
--- a/eval/src/tests/tensor/direct_sparse_tensor_builder/direct_sparse_tensor_builder_test.cpp
+++ b/eval/src/tests/tensor/direct_sparse_tensor_builder/direct_sparse_tensor_builder_test.cpp
@@ -99,6 +99,10 @@ TEST("Test essential object sizes") {
EXPECT_EQUAL(16u, sizeof(SparseTensorAddressRef));
EXPECT_EQUAL(24u, sizeof(std::pair<SparseTensorAddressRef, double>));
EXPECT_EQUAL(32u, sizeof(vespalib::hash_node<std::pair<SparseTensorAddressRef, double>>));
+ Tensor::UP tensor = buildTensor();
+ size_t used = tensor->count_memory_used();
+ EXPECT_GREATER(used, sizeof(SparseTensor));
+ EXPECT_LESS(used, 10000u);
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/onnx_wrapper/.gitattributes b/eval/src/tests/tensor/onnx_wrapper/.gitattributes
new file mode 100644
index 00000000000..62e8ad1e0a0
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/.gitattributes
@@ -0,0 +1 @@
+/*.onnx binary
diff --git a/eval/src/tests/tensor/onnx_wrapper/dynamic.onnx b/eval/src/tests/tensor/onnx_wrapper/dynamic.onnx
new file mode 100644
index 00000000000..95bbf36885a
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/dynamic.onnx
@@ -0,0 +1,27 @@
+
+dynamic.py:¦
+0
+ query_tensor
+attribute_tensormatmul"MatMul
+-
+ bias_tensorreduce" ReduceSum*
+axes@ 
+
+matmul
+reduceoutput"Adddynamic_scoringZ#
+ query_tensor
+
+batch
+Z"
+attribute_tensor
+ 
+
+Z+
+ bias_tensor
+
+batch
+ ÿÿÿÿÿÿÿÿÿb
+output
+
+batch
+B \ No newline at end of file
diff --git a/eval/src/tests/tensor/onnx_wrapper/dynamic.py b/eval/src/tests/tensor/onnx_wrapper/dynamic.py
new file mode 100755
index 00000000000..d098324fae8
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/dynamic.py
@@ -0,0 +1,39 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+from onnx import helper, TensorProto
+
+QUERY_TENSOR = helper.make_tensor_value_info('query_tensor', TensorProto.FLOAT, ['batch', 4])
+ATTRIBUTE_TENSOR = helper.make_tensor_value_info('attribute_tensor', TensorProto.FLOAT, [4, 1])
+BIAS_TENSOR = helper.make_tensor_value_info('bias_tensor', TensorProto.FLOAT, ['batch', -1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, ['batch', 1])
+
+nodes = [
+ helper.make_node(
+ 'MatMul',
+ ['query_tensor', 'attribute_tensor'],
+ ['matmul'],
+ ),
+ helper.make_node(
+ 'ReduceSum',
+ ['bias_tensor'],
+ ['reduce'],
+ axes=[1]
+ ),
+ helper.make_node(
+ 'Add',
+ ['matmul', 'reduce'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'dynamic_scoring',
+ [
+ QUERY_TENSOR,
+ ATTRIBUTE_TENSOR,
+ BIAS_TENSOR,
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='dynamic.py')
+onnx.save(model_def, 'dynamic.onnx')
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index 28a4a34b2e4..db2415e9969 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -10,83 +10,224 @@ using namespace vespalib::eval;
using namespace vespalib::tensor;
using vespalib::make_string_short::fmt;
+using TensorInfo = Onnx::TensorInfo;
+using DZ = Onnx::DimSize;
std::string get_source_dir() {
const char *dir = getenv("SOURCE_DIRECTORY");
return (dir ? dir : ".");
}
std::string source_dir = get_source_dir();
-std::string vespa_dir = source_dir + "/" + "../../../../..";
-std::string simple_model = vespa_dir + "/" + "model-integration/src/test/models/onnx/simple/simple.onnx";
+std::string simple_model = source_dir + "/simple.onnx";
+std::string dynamic_model = source_dir + "/dynamic.onnx";
-void dump_info(const char *ctx, const std::vector<OnnxWrapper::TensorInfo> &info) {
+void dump_info(const char *ctx, const std::vector<TensorInfo> &info) {
fprintf(stderr, "%s:\n", ctx);
for (size_t i = 0; i < info.size(); ++i) {
fprintf(stderr, " %s[%zu]: '%s' %s\n", ctx, i, info[i].name.c_str(), info[i].type_as_string().c_str());
}
}
-TEST(OnnxWrapperTest, onnx_model_can_be_inspected)
+TEST(WirePlannerTest, element_types_must_match) {
+ Onnx::WirePlanner planner;
+ ValueType type1 = ValueType::from_spec("tensor<float>(a[5])");
+ ValueType type2 = ValueType::from_spec("tensor<double>(a[5])");
+ TensorInfo info1 = TensorInfo{"info", {DZ(5)}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info2 = TensorInfo{"info", {DZ(5)}, TensorInfo::ElementType::DOUBLE};
+ EXPECT_TRUE(planner.bind_input_type(type1, info1));
+ EXPECT_FALSE(planner.bind_input_type(type2, info1));
+ EXPECT_FALSE(planner.bind_input_type(type1, info2));
+ EXPECT_TRUE(planner.bind_input_type(type2, info2));
+}
+
+TEST(WirePlannerTest, known_dimension_sizes_must_match) {
+ Onnx::WirePlanner planner;
+ ValueType type1 = ValueType::from_spec("tensor<float>(a[5],b[10])");
+ ValueType type2 = ValueType::from_spec("tensor<float>(a[10],b[5])");
+ ValueType type3 = ValueType::from_spec("tensor<float>(a[5],b[5])");
+ TensorInfo info = TensorInfo{"info", {DZ(5),DZ(5)}, TensorInfo::ElementType::FLOAT};
+ EXPECT_FALSE(planner.bind_input_type(type1, info));
+ EXPECT_FALSE(planner.bind_input_type(type2, info));
+ EXPECT_TRUE(planner.bind_input_type(type3, info));
+}
+
+TEST(WirePlannerTest, symbolic_dimension_sizes_must_match) {
+ Onnx::WirePlanner planner;
+ ValueType type1 = ValueType::from_spec("tensor<float>(a[5])");
+ ValueType type2 = ValueType::from_spec("tensor<float>(a[10])");
+ TensorInfo info = TensorInfo{"info", {DZ("dim")}, TensorInfo::ElementType::FLOAT};
+ EXPECT_TRUE(planner.bind_input_type(type1, info)); // binds 'dim' to 5
+ EXPECT_FALSE(planner.bind_input_type(type2, info));
+ EXPECT_TRUE(planner.bind_input_type(type1, info));
+}
+
+TEST(WirePlannerTest, unknown_dimension_sizes_match_anything) {
+ Onnx::WirePlanner planner;
+ ValueType type1 = ValueType::from_spec("tensor<float>(a[5])");
+ ValueType type2 = ValueType::from_spec("tensor<float>(a[10])");
+ TensorInfo info = TensorInfo{"info", {DZ()}, TensorInfo::ElementType::FLOAT};
+ EXPECT_TRUE(planner.bind_input_type(type1, info));
+ EXPECT_TRUE(planner.bind_input_type(type2, info));
+}
+
+TEST(WirePlannerTest, all_output_dimensions_must_be_bound) {
+ Onnx::WirePlanner planner;
+ ValueType type = ValueType::from_spec("tensor<float>(a[5],b[10])");
+ TensorInfo info1 = TensorInfo{"info", {DZ()}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info2 = TensorInfo{"info", {DZ("dim")}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info3 = TensorInfo{"info", {DZ("dim"),DZ()}, TensorInfo::ElementType::FLOAT};
+ EXPECT_TRUE(planner.make_output_type(info1).is_error());
+ EXPECT_TRUE(planner.make_output_type(info2).is_error());
+ EXPECT_TRUE(planner.make_output_type(info3).is_error());
+ EXPECT_TRUE(planner.bind_input_type(type, info3)); // binds 'dim' to 5
+ EXPECT_TRUE(planner.make_output_type(info1).is_error());
+ EXPECT_EQ(planner.make_output_type(info2).to_spec(), "tensor<float>(d0[5])");
+ EXPECT_TRUE(planner.make_output_type(info3).is_error());
+}
+
+TEST(WirePlannerTest, dimensions_resolve_left_to_right) {
+ Onnx::WirePlanner planner;
+ ValueType type1 = ValueType::from_spec("tensor<float>(a[5],b[10])");
+ ValueType type2 = ValueType::from_spec("tensor<float>(a[10],b[10])");
+ ValueType type3 = ValueType::from_spec("tensor<float>(a[5],b[5])");
+ TensorInfo info = TensorInfo{"info", {DZ("dim"),DZ("dim")}, TensorInfo::ElementType::FLOAT};
+ EXPECT_FALSE(planner.bind_input_type(type1, info)); // binds 'dim' to 5, then fails (5 != 10)
+ EXPECT_FALSE(planner.bind_input_type(type2, info));
+ EXPECT_TRUE(planner.bind_input_type(type3, info));
+}
+
+TEST(OnnxTest, simple_onnx_model_can_be_inspected)
{
- OnnxWrapper wrapper(simple_model, OnnxWrapper::Optimize::DISABLE);
- dump_info("inputs", wrapper.inputs());
- dump_info("outputs", wrapper.outputs());
- ASSERT_EQ(wrapper.inputs().size(), 3);
- ASSERT_EQ(wrapper.outputs().size(), 1);
+ Onnx model(simple_model, Onnx::Optimize::DISABLE);
+ dump_info("inputs", model.inputs());
+ dump_info("outputs", model.outputs());
+ ASSERT_EQ(model.inputs().size(), 3);
+ ASSERT_EQ(model.outputs().size(), 1);
//-------------------------------------------------------------------------
- EXPECT_EQ(wrapper.inputs()[0].name, "query_tensor");
- EXPECT_EQ(wrapper.inputs()[0].type_as_string(), "float[1][4]");
+ EXPECT_EQ(model.inputs()[0].name, "query_tensor");
+ EXPECT_EQ(model.inputs()[0].type_as_string(), "float[1][4]");
//-------------------------------------------------------------------------
- EXPECT_EQ(wrapper.inputs()[1].name, "attribute_tensor");
- EXPECT_EQ(wrapper.inputs()[1].type_as_string(), "float[4][1]");
+ EXPECT_EQ(model.inputs()[1].name, "attribute_tensor");
+ EXPECT_EQ(model.inputs()[1].type_as_string(), "float[4][1]");
//-------------------------------------------------------------------------
- EXPECT_EQ(wrapper.inputs()[2].name, "bias_tensor");
- EXPECT_EQ(wrapper.inputs()[2].type_as_string(), "float[1][1]");
+ EXPECT_EQ(model.inputs()[2].name, "bias_tensor");
+ EXPECT_EQ(model.inputs()[2].type_as_string(), "float[1][1]");
//-------------------------------------------------------------------------
- EXPECT_EQ(wrapper.outputs()[0].name, "output");
- EXPECT_EQ(wrapper.outputs()[0].type_as_string(), "float[1][1]");
+ EXPECT_EQ(model.outputs()[0].name, "output");
+ EXPECT_EQ(model.outputs()[0].type_as_string(), "float[1][1]");
}
-TEST(OnnxWrapperTest, onnx_model_can_be_evaluated)
+TEST(OnnxTest, dynamic_onnx_model_can_be_inspected)
{
- OnnxWrapper wrapper(simple_model, OnnxWrapper::Optimize::ENABLE);
+ Onnx model(dynamic_model, Onnx::Optimize::DISABLE);
+ dump_info("inputs", model.inputs());
+ dump_info("outputs", model.outputs());
+ ASSERT_EQ(model.inputs().size(), 3);
+ ASSERT_EQ(model.outputs().size(), 1);
+ //-------------------------------------------------------------------------
+ EXPECT_EQ(model.inputs()[0].name, "query_tensor");
+ EXPECT_EQ(model.inputs()[0].type_as_string(), "float[batch][4]");
+ //-------------------------------------------------------------------------
+ EXPECT_EQ(model.inputs()[1].name, "attribute_tensor");
+ EXPECT_EQ(model.inputs()[1].type_as_string(), "float[4][1]");
+ //-------------------------------------------------------------------------
+ EXPECT_EQ(model.inputs()[2].name, "bias_tensor");
+ EXPECT_EQ(model.inputs()[2].type_as_string(), "float[batch][]");
+ //-------------------------------------------------------------------------
+ EXPECT_EQ(model.outputs()[0].name, "output");
+ EXPECT_EQ(model.outputs()[0].type_as_string(), "float[batch][1]");
+}
+
+TEST(OnnxTest, simple_onnx_model_can_be_evaluated)
+{
+ Onnx model(simple_model, Onnx::Optimize::ENABLE);
+ Onnx::WirePlanner planner;
ValueType query_type = ValueType::from_spec("tensor<float>(a[1],b[4])");
std::vector<float> query_values({1.0, 2.0, 3.0, 4.0});
DenseTensorView query(query_type, TypedCells(query_values));
- EXPECT_TRUE(wrapper.inputs()[0].is_compatible(query_type));
- EXPECT_FALSE(wrapper.inputs()[1].is_compatible(query_type));
- EXPECT_FALSE(wrapper.inputs()[2].is_compatible(query_type));
+ EXPECT_TRUE(planner.bind_input_type(query_type, model.inputs()[0]));
ValueType attribute_type = ValueType::from_spec("tensor<float>(a[4],b[1])");
std::vector<float> attribute_values({5.0, 6.0, 7.0, 8.0});
DenseTensorView attribute(attribute_type, TypedCells(attribute_values));
- EXPECT_FALSE(wrapper.inputs()[0].is_compatible(attribute_type));
- EXPECT_TRUE(wrapper.inputs()[1].is_compatible(attribute_type));
- EXPECT_FALSE(wrapper.inputs()[2].is_compatible(attribute_type));
+ EXPECT_TRUE(planner.bind_input_type(attribute_type, model.inputs()[1]));
ValueType bias_type = ValueType::from_spec("tensor<float>(a[1],b[1])");
std::vector<float> bias_values({9.0});
DenseTensorView bias(bias_type, TypedCells(bias_values));
- EXPECT_FALSE(wrapper.inputs()[0].is_compatible(bias_type));
- EXPECT_FALSE(wrapper.inputs()[1].is_compatible(bias_type));
- EXPECT_TRUE(wrapper.inputs()[2].is_compatible(bias_type));
-
- MutableDenseTensorView output(wrapper.outputs()[0].make_compatible_type());
- EXPECT_EQ(output.fast_type().to_spec(), "tensor<float>(d0[1],d1[1])");
-
- OnnxWrapper::Params params;
- params.bind(0, query);
- params.bind(1, attribute);
- params.bind(2, bias);
- auto result = wrapper.eval(params);
-
- EXPECT_EQ(result.num_values(), 1);
- result.get(0, output);
- auto cells = output.cellsRef();
+ EXPECT_TRUE(planner.bind_input_type(bias_type, model.inputs()[2]));
+
+ EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(),
+ "tensor<float>(d0[1],d1[1])");
+
+ Onnx::WireInfo wire_info = planner.get_wire_info(model);
+ Onnx::EvalContext ctx(model, wire_info);
+
+ const Value &output = ctx.get_result(0);
+ EXPECT_EQ(output.type().to_spec(), "tensor<float>(d0[1],d1[1])");
+ //-------------------------------------------------------------------------
+ ctx.bind_param(0, query);
+ ctx.bind_param(1, attribute);
+ ctx.bind_param(2, bias);
+ ctx.eval();
+ auto cells = static_cast<const DenseTensorView&>(output).cellsRef();
EXPECT_EQ(cells.type, ValueType::CellType::FLOAT);
EXPECT_EQ(cells.size, 1);
EXPECT_EQ(cells.get(0), 79.0);
+ //-------------------------------------------------------------------------
+ std::vector<float> new_bias_values({10.0});
+ DenseTensorView new_bias(bias_type, TypedCells(new_bias_values));
+ ctx.bind_param(2, new_bias);
+ ctx.eval();
+ EXPECT_EQ(static_cast<const DenseTensorView&>(output).cellsRef().get(0), 80.0);
+ //-------------------------------------------------------------------------
+}
+
+TEST(OnnxTest, dynamic_onnx_model_can_be_evaluated)
+{
+ Onnx model(dynamic_model, Onnx::Optimize::ENABLE);
+ Onnx::WirePlanner planner;
+
+ ValueType query_type = ValueType::from_spec("tensor<float>(a[1],b[4])");
+ std::vector<float> query_values({1.0, 2.0, 3.0, 4.0});
+ DenseTensorView query(query_type, TypedCells(query_values));
+ EXPECT_TRUE(planner.bind_input_type(query_type, model.inputs()[0]));
+
+ ValueType attribute_type = ValueType::from_spec("tensor<float>(a[4],b[1])");
+ std::vector<float> attribute_values({5.0, 6.0, 7.0, 8.0});
+ DenseTensorView attribute(attribute_type, TypedCells(attribute_values));
+ EXPECT_TRUE(planner.bind_input_type(attribute_type, model.inputs()[1]));
+
+ ValueType bias_type = ValueType::from_spec("tensor<float>(a[1],b[2])");
+ std::vector<float> bias_values({4.0, 5.0});
+ DenseTensorView bias(bias_type, TypedCells(bias_values));
+ EXPECT_TRUE(planner.bind_input_type(bias_type, model.inputs()[2]));
+
+ EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(),
+ "tensor<float>(d0[1],d1[1])");
+
+ Onnx::WireInfo wire_info = planner.get_wire_info(model);
+ Onnx::EvalContext ctx(model, wire_info);
+
+ const Value &output = ctx.get_result(0);
+ EXPECT_EQ(output.type().to_spec(), "tensor<float>(d0[1],d1[1])");
+ //-------------------------------------------------------------------------
+ ctx.bind_param(0, query);
+ ctx.bind_param(1, attribute);
+ ctx.bind_param(2, bias);
+ ctx.eval();
+ auto cells = static_cast<const DenseTensorView&>(output).cellsRef();
+ EXPECT_EQ(cells.type, ValueType::CellType::FLOAT);
+ EXPECT_EQ(cells.size, 1);
+ EXPECT_EQ(cells.get(0), 79.0);
+ //-------------------------------------------------------------------------
+ std::vector<float> new_bias_values({5.0,6.0});
+ DenseTensorView new_bias(bias_type, TypedCells(new_bias_values));
+ ctx.bind_param(2, new_bias);
+ ctx.eval();
+ EXPECT_EQ(static_cast<const DenseTensorView&>(output).cellsRef().get(0), 81.0);
+ //-------------------------------------------------------------------------
}
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/tests/tensor/onnx_wrapper/simple.onnx b/eval/src/tests/tensor/onnx_wrapper/simple.onnx
new file mode 100644
index 00000000000..88ed0ef23f0
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/simple.onnx
@@ -0,0 +1,23 @@
+ simple.py:ã
+0
+ query_tensor
+attribute_tensormatmul"MatMul
+"
+matmul
+ bias_tensoroutput"Addsimple_scoringZ
+ query_tensor
+ 
+
+Z"
+attribute_tensor
+ 
+
+Z
+ bias_tensor
+ 
+
+b
+output
+ 
+
+B \ No newline at end of file
diff --git a/eval/src/tests/tensor/onnx_wrapper/simple.py b/eval/src/tests/tensor/onnx_wrapper/simple.py
new file mode 100755
index 00000000000..a3cd2425d58
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/simple.py
@@ -0,0 +1,33 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+from onnx import helper, TensorProto
+
+QUERY_TENSOR = helper.make_tensor_value_info('query_tensor', TensorProto.FLOAT, [1, 4])
+ATTRIBUTE_TENSOR = helper.make_tensor_value_info('attribute_tensor', TensorProto.FLOAT, [4, 1])
+BIAS_TENSOR = helper.make_tensor_value_info('bias_tensor', TensorProto.FLOAT, [1, 1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 1])
+
+nodes = [
+ helper.make_node(
+ 'MatMul',
+ ['query_tensor', 'attribute_tensor'],
+ ['matmul'],
+ ),
+ helper.make_node(
+ 'Add',
+ ['matmul', 'bias_tensor'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'simple_scoring',
+ [
+ QUERY_TENSOR,
+ ATTRIBUTE_TENSOR,
+ BIAS_TENSOR,
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='simple.py')
+onnx.save(model_def, 'simple.onnx')
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index 6f9bee025c9..ad182115054 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -200,7 +200,8 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
}
assert(pass_params == PassParams::LAZY);
assert(params.size() == 2);
- return builder.CreateCall(params[0], {params[1], builder.getInt64(idx)}, "resolve_param");
+ return builder.CreateCall(llvm::cast<llvm::FunctionType>(params[0]->getType()->getPointerElementType()),
+ params[0], {params[1], builder.getInt64(idx)}, "resolve_param");
}
//-------------------------------------------------------------------------
@@ -252,12 +253,14 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
llvm::Value *eval_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)eval_ptr), eval_funptr_t, "inject_eval");
llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)forest), builder.getVoidTy()->getPointerTo(), "inject_ctx");
if (pass_params == PassParams::ARRAY) {
- push(builder.CreateCall(eval_fun, {ctx, params[0]}, "call_eval"));
+ push(builder.CreateCall(llvm::cast<llvm::FunctionType>(eval_fun->getType()->getPointerElementType()),
+ eval_fun, {ctx, params[0]}, "call_eval"));
} else {
assert(pass_params == PassParams::LAZY);
llvm::PointerType *proxy_funptr_t = make_eval_forest_proxy_funptr_t();
llvm::Value *proxy_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)vespalib_eval_forest_proxy), proxy_funptr_t, "inject_eval_proxy");
- push(builder.CreateCall(proxy_fun, {eval_fun, ctx, params[0], params[1], builder.getInt64(stats.num_params)}));
+ push(builder.CreateCall(llvm::cast<llvm::FunctionType>(proxy_fun->getType()->getPointerElementType()),
+ proxy_fun, {eval_fun, ctx, params[0], params[1], builder.getInt64(stats.num_params)}));
}
return true;
}
@@ -411,7 +414,8 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
llvm::PointerType *funptr_t = make_check_membership_funptr_t();
llvm::Value *call_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)call_ptr), funptr_t, "inject_call_addr");
llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)state), builder.getVoidTy()->getPointerTo(), "inject_ctx");
- push(builder.CreateCall(call_fun, {ctx, lhs}, "call_check_membership"));
+ push(builder.CreateCall(llvm::cast<llvm::FunctionType>(call_fun->getType()->getPointerElementType()),
+ call_fun, {ctx, lhs}, "call_check_membership"));
} else {
// build explicit code to check all set members
llvm::Value *found = builder.getFalse();
diff --git a/eval/src/vespa/eval/eval/simple_tensor.cpp b/eval/src/vespa/eval/eval/simple_tensor.cpp
index a72a24be211..9ed28d87fee 100644
--- a/eval/src/vespa/eval/eval/simple_tensor.cpp
+++ b/eval/src/vespa/eval/eval/simple_tensor.cpp
@@ -769,5 +769,14 @@ SimpleTensor::decode(nbostream &input)
return builder.build();
}
+size_t
+SimpleTensor::count_memory_used() const {
+ size_t result = sizeof(SimpleTensor);
+ size_t addr_size = sizeof(Label) * _type.dimensions().size();
+ size_t cell_size = sizeof(Cell) + addr_size;
+ result += _cells.size() * cell_size;
+ return result;
+}
+
} // namespace vespalib::eval
} // namespace vespalib
diff --git a/eval/src/vespa/eval/eval/simple_tensor.h b/eval/src/vespa/eval/eval/simple_tensor.h
index cbf1ac99e05..052d7cb70bd 100644
--- a/eval/src/vespa/eval/eval/simple_tensor.h
+++ b/eval/src/vespa/eval/eval/simple_tensor.h
@@ -93,6 +93,7 @@ public:
static std::unique_ptr<SimpleTensor> concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension);
static void encode(const SimpleTensor &tensor, nbostream &output);
static std::unique_ptr<SimpleTensor> decode(nbostream &input);
+ size_t count_memory_used() const;
};
} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor.h b/eval/src/vespa/eval/tensor/dense/dense_tensor.h
index d0246fef635..4114661a074 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor.h
@@ -21,6 +21,11 @@ public:
// for unit tests
template <typename RCT>
bool operator==(const DenseTensor<RCT> &rhs) const;
+
+ size_t count_memory_used() const override {
+ return sizeof(DenseTensor) + (sizeof(CT) * _cells.size());
+ }
+
private:
eval::ValueType _type;
std::vector<CT> _cells;
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
index 93dd2dbedeb..a07a3eede77 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
@@ -18,6 +18,7 @@ public:
using CellsIterator = DenseTensorCellsIterator;
using Address = std::vector<eval::ValueType::Dimension::size_type>;
+ DenseTensorView(const DenseTensorView &rhs) : DenseTensorView(rhs._typeRef, rhs._cellsRef) {}
DenseTensorView(const eval::ValueType &type_in, TypedCells cells_in)
: _typeRef(type_in),
_cellsRef(cells_in)
@@ -43,6 +44,9 @@ public:
Tensor::UP clone() const override;
eval::TensorSpec toSpec() const override;
void accept(TensorVisitor &visitor) const override;
+ size_t count_memory_used() const override {
+ return sizeof(DenseTensorView);
+ }
template <typename T> static ConstArrayRef<T> typify_cells(const eval::Value &self) {
return static_cast<const DenseTensorView &>(self).cellsRef().typify<T>();
@@ -55,7 +59,6 @@ protected:
: _typeRef(type_in),
_cellsRef()
{}
- DenseTensorView(const DenseTensorView &rhs) : DenseTensorView(rhs._typeRef, rhs._cellsRef) {}
void initCellsRef(TypedCells cells_in) {
assert(_typeRef.cell_type() == cells_in.type);
diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
index 125095ff23e..88346213901 100644
--- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
+++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
@@ -18,31 +18,31 @@ namespace vespalib::tensor {
namespace {
-vespalib::string to_str(OnnxWrapper::TensorInfo::ElementType element_type) {
- if (element_type == OnnxWrapper::TensorInfo::ElementType::FLOAT) {
+vespalib::string to_str(Onnx::TensorInfo::ElementType element_type) {
+ if (element_type == Onnx::TensorInfo::ElementType::FLOAT) {
return "float";
}
- if (element_type == OnnxWrapper::TensorInfo::ElementType::DOUBLE) {
+ if (element_type == Onnx::TensorInfo::ElementType::DOUBLE) {
return "double";
}
return "???";
}
-ValueType::CellType as_cell_type(OnnxWrapper::TensorInfo::ElementType type) {
- if (type == OnnxWrapper::TensorInfo::ElementType::FLOAT) {
+ValueType::CellType as_cell_type(Onnx::TensorInfo::ElementType type) {
+ if (type == Onnx::TensorInfo::ElementType::FLOAT) {
return ValueType::CellType::FLOAT;
}
- if (type == OnnxWrapper::TensorInfo::ElementType::DOUBLE) {
+ if (type == Onnx::TensorInfo::ElementType::DOUBLE) {
return ValueType::CellType::DOUBLE;
}
abort();
}
-auto convert_optimize(OnnxWrapper::Optimize optimize) {
- if (optimize == OnnxWrapper::Optimize::ENABLE) {
+auto convert_optimize(Onnx::Optimize optimize) {
+ if (optimize == Onnx::Optimize::ENABLE) {
return ORT_ENABLE_ALL;
} else {
- assert(optimize == OnnxWrapper::Optimize::DISABLE);
+ assert(optimize == Onnx::Optimize::DISABLE);
return ORT_DISABLE_ALL;
}
}
@@ -81,37 +81,77 @@ public:
};
Ort::AllocatorWithDefaultOptions OnnxString::_alloc;
-std::vector<size_t> make_dimensions(const std::vector<int64_t> &shape) {
- std::vector<size_t> result;
- for (int64_t size: shape) {
- result.push_back(std::max(size, 0L));
- }
+std::vector<Onnx::DimSize> make_dimensions(const Ort::TensorTypeAndShapeInfo &tensor_info) {
+ std::vector<const char *> symbolic_sizes(tensor_info.GetDimensionsCount(), nullptr);
+ tensor_info.GetSymbolicDimensions(symbolic_sizes.data(), symbolic_sizes.size());
+ auto shape = tensor_info.GetShape();
+ std::vector<Onnx::DimSize> result;
+ for (size_t i = 0; i < shape.size(); ++i) {
+ if (shape[i] > 0) {
+ result.emplace_back(shape[i]);
+ } else if (symbolic_sizes[i] != nullptr) {
+ result.emplace_back(vespalib::string(symbolic_sizes[i]));
+ } else {
+ result.emplace_back();
+ }
+ }
return result;
}
-OnnxWrapper::TensorInfo::ElementType make_element_type(ONNXTensorElementDataType element_type) {
+Onnx::TensorInfo::ElementType make_element_type(ONNXTensorElementDataType element_type) {
if (element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
- return OnnxWrapper::TensorInfo::ElementType::FLOAT;
+ return Onnx::TensorInfo::ElementType::FLOAT;
} else if (element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
- return OnnxWrapper::TensorInfo::ElementType::DOUBLE;
+ return Onnx::TensorInfo::ElementType::DOUBLE;
} else {
- return OnnxWrapper::TensorInfo::ElementType::UNKNOWN;
+ return Onnx::TensorInfo::ElementType::UNKNOWN;
}
}
-OnnxWrapper::TensorInfo make_tensor_info(const OnnxString &name, const Ort::TypeInfo &type_info) {
+Onnx::TensorInfo make_tensor_info(const OnnxString &name, const Ort::TypeInfo &type_info) {
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
- auto shape = tensor_info.GetShape();
auto element_type = tensor_info.GetElementType();
- return OnnxWrapper::TensorInfo{vespalib::string(name.get()), make_dimensions(shape), make_element_type(element_type)};
+ return Onnx::TensorInfo{vespalib::string(name.get()), make_dimensions(tensor_info), make_element_type(element_type)};
}
}
+vespalib::string
+Onnx::DimSize::as_string() const
+{
+ if (is_known()) {
+ return fmt("[%zu]", value);
+ } else if (is_symbolic()) {
+ return fmt("[%s]", name.c_str());
+ } else {
+ return "[]";
+ }
+}
+
+vespalib::string
+Onnx::TensorInfo::type_as_string() const
+{
+ vespalib::string res = to_str(elements);
+ for (const auto &dim: dimensions) {
+ res += dim.as_string();
+ }
+ return res;
+}
+
+Onnx::TensorInfo::~TensorInfo() = default;
+
+//-----------------------------------------------------------------------------
+
+Onnx::WirePlanner::~WirePlanner() = default;
+
bool
-OnnxWrapper::TensorInfo::is_compatible(const eval::ValueType &type) const
+Onnx::WirePlanner::bind_input_type(const eval::ValueType &vespa_in, const TensorInfo &onnx_in)
{
- if ((elements == ElementType::UNKNOWN) || dimensions.empty()) {
+ const auto &type = vespa_in;
+ const auto &name = onnx_in.name;
+ const auto &dimensions = onnx_in.dimensions;
+ const auto &elements = onnx_in.elements;
+ if ((elements == TensorInfo::ElementType::UNKNOWN) || dimensions.empty()) {
return false;
}
if (type.cell_type() != as_cell_type(elements)) {
@@ -121,21 +161,41 @@ OnnxWrapper::TensorInfo::is_compatible(const eval::ValueType &type) const
return false;
}
for (size_t i = 0; i < dimensions.size(); ++i) {
- if (type.dimensions()[i].size != dimensions[i]) {
- return false;
+ if (dimensions[i].is_known()) {
+ if (dimensions[i].value != type.dimensions()[i].size) {
+ return false;
+ }
+ } else if (dimensions[i].is_symbolic()) {
+ auto &bound_size = _symbolic_sizes[dimensions[i].name];
+ if (bound_size == 0) {
+ bound_size = type.dimensions()[i].size;
+ } else if (bound_size != type.dimensions()[i].size) {
+ return false;
+ }
+ } else {
+ _unknown_sizes[std::make_pair(name,i)] = type.dimensions()[i].size;
}
}
return true;
}
eval::ValueType
-OnnxWrapper::TensorInfo::make_compatible_type() const
+Onnx::WirePlanner::make_output_type(const TensorInfo &onnx_out) const
{
- if ((elements == ElementType::UNKNOWN) || dimensions.empty()) {
+ const auto &dimensions = onnx_out.dimensions;
+ const auto &elements = onnx_out.elements;
+ if ((elements == TensorInfo::ElementType::UNKNOWN) || dimensions.empty()) {
return ValueType::error_type();
}
std::vector<ValueType::Dimension> dim_list;
- for (size_t dim_size: dimensions) {
+ for (const auto &dim: dimensions) {
+ size_t dim_size = dim.value;
+ if (dim.is_symbolic()) {
+ auto pos = _symbolic_sizes.find(dim.name);
+ if (pos != _symbolic_sizes.end()) {
+ dim_size = pos->second;
+ }
+ }
if ((dim_size == 0) || (dim_list.size() > 9)) {
return ValueType::error_type();
}
@@ -144,71 +204,131 @@ OnnxWrapper::TensorInfo::make_compatible_type() const
return ValueType::tensor_type(std::move(dim_list), as_cell_type(elements));
}
-vespalib::string
-OnnxWrapper::TensorInfo::type_as_string() const
+Onnx::WireInfo
+Onnx::WirePlanner::get_wire_info(const Onnx &model) const
{
- vespalib::string res = to_str(elements);
- for (size_t dim_size: dimensions) {
- if (dim_size == 0) {
- res += "[]";
- } else {
- res += fmt("[%zu]", dim_size);
+ WireInfo info;
+ for (const auto &input: model.inputs()) {
+ size_t input_idx = 0;
+ std::vector<int64_t> sizes;
+ for (const auto &dim: input.dimensions) {
+ if (dim.is_known()) {
+ sizes.push_back(dim.value);
+ } else if (dim.is_symbolic()) {
+ const auto &pos = _symbolic_sizes.find(dim.name);
+ assert(pos != _symbolic_sizes.end());
+ sizes.push_back(pos->second);
+ } else {
+ const auto &pos = _unknown_sizes.find(std::make_pair(input.name, input_idx));
+ assert(pos != _unknown_sizes.end());
+ sizes.push_back(pos->second);
+ }
+ ++input_idx;
}
+ info.input_sizes.push_back(sizes);
}
- return res;
+ for (const auto &output: model.outputs()) {
+ info.output_types.push_back(make_output_type(output));
+ }
+ return info;
}
-OnnxWrapper::TensorInfo::~TensorInfo() = default;
+//-----------------------------------------------------------------------------
-OnnxWrapper::Shared::Shared()
- : _env(ORT_LOGGING_LEVEL_WARNING, "vespa-onnx-wrapper")
+Ort::AllocatorWithDefaultOptions Onnx::EvalContext::_alloc;
+
+Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info)
+ : _model(model),
+ _wire_info(wire_info),
+ _cpu_memory(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault)),
+ _param_values(),
+ _result_values(),
+ _result_views()
{
+ assert(_wire_info.input_sizes.size() == _model.inputs().size());
+ assert(_wire_info.output_types.size() == _model.outputs().size());
+ for (const auto &input: _wire_info.input_sizes) {
+ (void) input;
+ _param_values.push_back(Ort::Value(nullptr));
+ }
+ std::vector<int64_t> dim_sizes;
+ size_t num_cells;
+ dim_sizes.reserve(16);
+ // NB: output type must be reference inside vector since the view does not copy it
+ for (const auto &output: _wire_info.output_types) {
+ num_cells = 1;
+ dim_sizes.clear();
+ for (const auto &dim: output.dimensions()) {
+ dim_sizes.push_back(dim.size);
+ num_cells *= dim.size;
+ }
+ if (output.cell_type() == ValueType::CellType::FLOAT) {
+ _result_values.push_back(Ort::Value::CreateTensor<float>(_alloc, dim_sizes.data(), dim_sizes.size()));
+ ConstArrayRef<float> cells(_result_values.back().GetTensorMutableData<float>(), num_cells);
+ _result_views.emplace_back(output, TypedCells(cells));
+ } else {
+ assert(output.cell_type() == ValueType::CellType::DOUBLE);
+ _result_values.push_back(Ort::Value::CreateTensor<double>(_alloc, dim_sizes.data(), dim_sizes.size()));
+ ConstArrayRef<double> cells(_result_values.back().GetTensorMutableData<double>(), num_cells);
+ _result_views.emplace_back(output, TypedCells(cells));
+ }
+ }
}
+Onnx::EvalContext::~EvalContext() = default;
+
void
-OnnxWrapper::Params::bind(size_t idx, const DenseTensorView &src)
+Onnx::EvalContext::bind_param(size_t i, const eval::Value &param)
{
- assert(idx == values.size());
- std::vector<int64_t> dim_sizes;
- for (const auto &dim: src.fast_type().dimensions()) {
- dim_sizes.push_back(dim.size);
- }
- auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
- if (src.fast_type().cell_type() == ValueType::CellType::FLOAT) {
+ // NB: dense tensors are always (sub)classes of DenseTensorView
+ const auto &cells_ref = static_cast<const DenseTensorView &>(param).cellsRef();
+ const auto &input_sizes = _wire_info.input_sizes;
+ if (cells_ref.type == ValueType::CellType::FLOAT) {
// NB: create requires non-const input
- auto cells = unconstify(src.cellsRef().typify<float>());
- values.push_back(Ort::Value::CreateTensor<float>(memory_info, cells.begin(), cells.size(), dim_sizes.data(), dim_sizes.size()));
- } else if (src.fast_type().cell_type() == ValueType::CellType::DOUBLE) {
+ auto cells = unconstify(cells_ref.typify<float>());
+ _param_values[i] = Ort::Value::CreateTensor<float>(_cpu_memory, cells.begin(), cells.size(), input_sizes[i].data(), input_sizes[i].size());
+ } else {
+ assert(cells_ref.type == ValueType::CellType::DOUBLE);
// NB: create requires non-const input
- auto cells = unconstify(src.cellsRef().typify<double>());
- values.push_back(Ort::Value::CreateTensor<double>(memory_info, cells.begin(), cells.size(), dim_sizes.data(), dim_sizes.size()));
+ auto cells = unconstify(cells_ref.typify<double>());
+ _param_values[i] = Ort::Value::CreateTensor<double>(_cpu_memory, cells.begin(), cells.size(), input_sizes[i].data(), input_sizes[i].size());
}
}
void
-OnnxWrapper::Result::get(size_t idx, MutableDenseTensorView &dst)
+Onnx::EvalContext::eval()
{
- assert(values[idx].IsTensor());
- auto meta = values[idx].GetTensorTypeAndShapeInfo();
- if (dst.fast_type().cell_type() == ValueType::CellType::FLOAT) {
- assert(meta.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
- ConstArrayRef<float> cells(values[idx].GetTensorMutableData<float>(), meta.GetElementCount());
- dst.setCells(TypedCells(cells));
- } else if (dst.fast_type().cell_type() == ValueType::CellType::DOUBLE) {
- assert(meta.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE);
- ConstArrayRef<double> cells(values[idx].GetTensorMutableData<double>(), meta.GetElementCount());
- dst.setCells(TypedCells(cells));
- }
+ // NB: Run requires non-const session
+ Ort::Session &session = const_cast<Ort::Session&>(_model._session);
+ Ort::RunOptions run_opts(nullptr);
+ session.Run(run_opts,
+ _model._input_name_refs.data(), _param_values.data(), _param_values.size(),
+ _model._output_name_refs.data(), _result_values.data(), _result_values.size());
}
-OnnxWrapper::Shared &
-OnnxWrapper::Shared::get() {
+const eval::Value &
+Onnx::EvalContext::get_result(size_t i) const
+{
+ return _result_views[i];
+}
+
+//-----------------------------------------------------------------------------
+
+Onnx::Shared::Shared()
+ : _env(ORT_LOGGING_LEVEL_WARNING, "vespa-onnx-wrapper")
+{
+}
+
+Onnx::Shared &
+Onnx::Shared::get() {
static Shared shared;
return shared;
}
+//-----------------------------------------------------------------------------
+
void
-OnnxWrapper::extract_meta_data()
+Onnx::extract_meta_data()
{
Ort::AllocatorWithDefaultOptions allocator;
size_t num_inputs = _session.GetInputCount();
@@ -227,7 +347,7 @@ OnnxWrapper::extract_meta_data()
}
}
-OnnxWrapper::OnnxWrapper(const vespalib::string &model_file, Optimize optimize)
+Onnx::Onnx(const vespalib::string &model_file, Optimize optimize)
: _shared(Shared::get()),
_options(),
_session(nullptr),
@@ -243,17 +363,6 @@ OnnxWrapper::OnnxWrapper(const vespalib::string &model_file, Optimize optimize)
extract_meta_data();
}
-OnnxWrapper::~OnnxWrapper() = default;
-
-OnnxWrapper::Result
-OnnxWrapper::eval(const Params &params) const
-{
- assert(params.values.size() == _inputs.size());
- Ort::RunOptions run_opts(nullptr);
- // NB: Run requires non-const session
- Ort::Session &session = const_cast<Ort::Session&>(_session);
- return Result(session.Run(run_opts, _input_name_refs.data(), params.values.data(), _inputs.size(),
- _output_name_refs.data(), _outputs.size()));
-}
+Onnx::~Onnx() = default;
}
diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h
index abe1da252c7..23ddbcb8885 100644
--- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h
+++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h
@@ -2,56 +2,101 @@
#pragma once
+#include "dense_tensor_view.h"
#include <onnxruntime/onnxruntime_cxx_api.h>
#include <vespa/vespalib/stllike/string.h>
#include <vespa/eval/eval/value_type.h>
#include <vector>
+#include <map>
-namespace vespalib::tensor {
+namespace vespalib::eval { struct Value; }
-class DenseTensorView;
-class MutableDenseTensorView;
+namespace vespalib::tensor {
/**
* Wrapper around an ONNX model handeled by onnxruntime.
+ *
+ * Create an Onnx object that will load your model and extract
+ * information about inputs and outputs. Use an Onnx::WirePlanner to
+ * bind vespa value types to each of the onnx model inputs. Ask the
+ * wire planner about the vespa value types corresponding to each of
+ * the model outputs for external wiring. Use the wire planner to make
+ * a WireInfo object which is a simple struct indicating the concrete
+ * onnx and vespa types to be used when converting inputs and
+ * outputs. Create an Onnx::EvalContex based on the model and the wire
+ * plan. Bind actual vespa values to the model inputs, invoke eval and
+ * inspect the results. See the unit test (tests/tensor/onnx_wrapper)
+ * for some examples.
**/
-class OnnxWrapper {
+class Onnx {
public:
// model optimization
enum class Optimize { ENABLE, DISABLE };
+ // the size of a dimension
+ struct DimSize {
+ size_t value;
+ vespalib::string name;
+ DimSize() : value(0), name() {}
+ DimSize(size_t size) : value(size), name() {}
+ DimSize(const vespalib::string &symbol) : value(0), name(symbol) {}
+ bool is_known() const { return (value > 0); }
+ bool is_symbolic() const { return !name.empty(); }
+ vespalib::string as_string() const;
+ };
+
// information about a single input or output tensor
struct TensorInfo {
enum class ElementType { FLOAT, DOUBLE, UNKNOWN };
vespalib::string name;
- std::vector<size_t> dimensions;
+ std::vector<DimSize> dimensions;
ElementType elements;
- bool is_compatible(const eval::ValueType &type) const;
- eval::ValueType make_compatible_type() const;
vespalib::string type_as_string() const;
~TensorInfo();
};
- // used to build model parameters
- class Params {
- friend class OnnxWrapper;
+ // how the model should be wired with inputs/outputs
+ struct WireInfo {
+ std::vector<std::vector<int64_t>> input_sizes;
+ std::vector<eval::ValueType> output_types;
+ WireInfo() : input_sizes(), output_types() {}
+ };
+
+ // planning how we should wire the model based on input types
+ class WirePlanner {
private:
- std::vector<Ort::Value> values;
+ std::map<vespalib::string,size_t> _symbolic_sizes;
+ std::map<std::pair<vespalib::string,size_t>,size_t> _unknown_sizes;
public:
- Params() : values() {}
- void bind(size_t idx, const DenseTensorView &src);
+ WirePlanner() : _symbolic_sizes(), _unknown_sizes() {}
+ ~WirePlanner();
+ bool bind_input_type(const eval::ValueType &vespa_in, const TensorInfo &onnx_in);
+ eval::ValueType make_output_type(const TensorInfo &onnx_out) const;
+ WireInfo get_wire_info(const Onnx &model) const;
};
- // used to inspect model results
- class Result {
- friend class OnnxWrapper;
+ // evaluation context; use one per thread and keep model/wire_info alive
+ // all parameter values are expected to be bound per evaluation
+ // output values are pre-allocated and will not change
+ class EvalContext {
private:
- std::vector<Ort::Value> values;
- Result(std::vector<Ort::Value> values_in) : values(std::move(values_in)) {}
+ static Ort::AllocatorWithDefaultOptions _alloc;
+
+ const Onnx &_model;
+ const WireInfo &_wire_info;
+ Ort::MemoryInfo _cpu_memory;
+ std::vector<Ort::Value> _param_values;
+ std::vector<Ort::Value> _result_values;
+ std::vector<DenseTensorView> _result_views;
+
public:
- static Result make_empty() { return Result({}); }
- size_t num_values() const { return values.size(); }
- void get(size_t idx, MutableDenseTensorView &dst);
+ EvalContext(const Onnx &model, const WireInfo &wire_info);
+ ~EvalContext();
+ size_t num_params() const { return _param_values.size(); }
+ size_t num_results() const { return _result_values.size(); }
+ void bind_param(size_t i, const eval::Value &param);
+ void eval();
+ const eval::Value &get_result(size_t i) const;
};
private:
@@ -76,11 +121,10 @@ private:
void extract_meta_data();
public:
- OnnxWrapper(const vespalib::string &model_file, Optimize optimize);
- ~OnnxWrapper();
+ Onnx(const vespalib::string &model_file, Optimize optimize);
+ ~Onnx();
const std::vector<TensorInfo> &inputs() const { return _inputs; }
const std::vector<TensorInfo> &outputs() const { return _outputs; }
- Result eval(const Params &params) const;
};
}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
index d183c33f5cd..db35de6786d 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
@@ -243,6 +243,16 @@ SparseTensor::remove(const CellValues &cellAddresses) const
return remover.build();
}
+size_t
+SparseTensor::count_memory_used() const
+{
+ size_t result = sizeof(SparseTensor) + _cells.getMemoryConsumption();
+ for (const auto &cell : _cells) {
+ result += cell.first.size();
+ }
+ return result;
+}
+
}
VESPALIB_HASH_MAP_INSTANTIATE(vespalib::tensor::SparseTensorAddressRef, double);
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
index e5ea639b460..6bd181e1895 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
@@ -53,6 +53,7 @@ public:
Tensor::UP clone() const override;
eval::TensorSpec toSpec() const override;
void accept(TensorVisitor &visitor) const override;
+ size_t count_memory_used() const override;
};
}
diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h
index d822c99a6d8..bef7309c609 100644
--- a/eval/src/vespa/eval/tensor/tensor.h
+++ b/eval/src/vespa/eval/tensor/tensor.h
@@ -59,6 +59,7 @@ public:
virtual Tensor::UP clone() const = 0; // want to remove, but needed by document
virtual eval::TensorSpec toSpec() const = 0;
virtual void accept(TensorVisitor &visitor) const = 0;
+ virtual size_t count_memory_used() const = 0;
using TypeList = std::initializer_list<std::reference_wrapper<const eval::ValueType>>;
static bool supported(TypeList types);
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
index 7c09bc4e4ab..fe73bf92063 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
@@ -54,6 +54,16 @@ WrappedSimpleTensor::accept(TensorVisitor &visitor) const
}
}
+size_t
+WrappedSimpleTensor::count_memory_used() const
+{
+ size_t result = sizeof(WrappedSimpleTensor);
+ if (_space) {
+ result += _space->count_memory_used();
+ }
+ return result;
+}
+
Tensor::UP
WrappedSimpleTensor::clone() const
{
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
index 12ee1237d67..6b549718a29 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
@@ -33,6 +33,7 @@ public:
eval::TensorSpec toSpec() const override;
double as_double() const override;
void accept(TensorVisitor &visitor) const override;
+ size_t count_memory_used() const override;
Tensor::UP clone() const override;
// functions below should not be used for this implementation
Tensor::UP apply(const CellFunction &) const override;
diff --git a/metrics/src/tests/metricmanagertest.cpp b/metrics/src/tests/metricmanagertest.cpp
index 6407bb73ecb..47515d1bc4c 100644
--- a/metrics/src/tests/metricmanagertest.cpp
+++ b/metrics/src/tests/metricmanagertest.cpp
@@ -11,6 +11,7 @@
#include <vespa/vespalib/stllike/asciistream.h>
#include <vespa/vespalib/util/xmlstream.h>
#include <vespa/vespalib/util/time.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <thread>
#include <vespa/log/log.h>
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java
index cbc5a44ae94..ad85235fc69 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java
@@ -96,6 +96,11 @@ public class NodeList extends AbstractFilteringList<Node, NodeList> {
.orElse(Version.emptyVersion)));
}
+ /** Returns the subset of nodes that are currently on a lower version than the given version */
+ public NodeList osVersionIsBefore(Version version) {
+ return matching(node -> node.status().osVersion().isBefore(version));
+ }
+
/** Returns the subset of nodes that are currently on the given OS version */
public NodeList onOsVersion(Version version) {
return matching(node -> node.status().osVersion().matches(version));
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/OsUpgradeActivator.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/OsUpgradeActivator.java
index be1190ccff4..0e3b6715ff1 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/OsUpgradeActivator.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/OsUpgradeActivator.java
@@ -26,8 +26,8 @@ public class OsUpgradeActivator extends NodeRepositoryMaintainer {
protected boolean maintain() {
for (var nodeType : NodeType.values()) {
if (!nodeType.isHost()) continue;
- var active = canUpgradeOsOf(nodeType);
- nodeRepository().osVersions().resumeUpgradeOf(nodeType, active);
+ boolean resume = canUpgradeOsOf(nodeType);
+ nodeRepository().osVersions().resumeUpgradeOf(nodeType, resume);
}
return true;
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/OsVersion.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/OsVersion.java
index 1216c060181..0385e2e3df6 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/OsVersion.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/OsVersion.java
@@ -43,6 +43,11 @@ public class OsVersion {
return wanted.isPresent() && !current.equals(wanted);
}
+ /** Returns whether this is before the given version */
+ public boolean isBefore(Version version) {
+ return current.isEmpty() || current.get().isBefore(version);
+ }
+
/** Returns whether current version matches given version */
public boolean matches(Version version) {
return current.isPresent() && current.get().equals(version);
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingUpgrader.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingUpgrader.java
index 03d04a5f6cf..74b288d77c5 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingUpgrader.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingUpgrader.java
@@ -38,7 +38,7 @@ public class DelegatingUpgrader implements Upgrader {
NodeList activeNodes = nodeRepository.list().nodeType(target.nodeType()).state(Node.State.active);
int numberToUpgrade = Math.max(0, maxActiveUpgrades - activeNodes.changingOsVersionTo(target.version()).size());
NodeList nodesToUpgrade = activeNodes.not().changingOsVersionTo(target.version())
- .not().onOsVersion(target.version())
+ .osVersionIsBefore(target.version())
.byIncreasingOsVersion()
.first(numberToUpgrade);
if (nodesToUpgrade.size() == 0) return;
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringUpgrader.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringUpgrader.java
index aebf14ab13f..b4e21b22cd2 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringUpgrader.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringUpgrader.java
@@ -46,7 +46,7 @@ public class RetiringUpgrader implements Upgrader {
Instant retiredAt = target.lastRetiredAt().orElse(Instant.EPOCH);
if (now.isBefore(retiredAt.plus(nodeBudget))) return; // Budget has not been spent yet
- activeNodes.not().onOsVersion(target.version())
+ activeNodes.osVersionIsBefore(target.version())
.not().deprovisioning()
.byIncreasingOsVersion()
.first(1)
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java
index 914008af227..6a41e766ace 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java
@@ -38,7 +38,7 @@ public class OsVersionsTest {
private final ApplicationId infraApplication = ApplicationId.from("hosted-vespa", "infra", "default");
@Test
- public void versions() {
+ public void upgrade() {
var versions = new OsVersions(tester.nodeRepository(), new DelegatingUpgrader(tester.nodeRepository(), Integer.MAX_VALUE));
provisionInfraApplication(10);
Supplier<List<Node>> hostNodes = () -> tester.nodeRepository().getNodes(NodeType.host);
@@ -50,18 +50,28 @@ public class OsVersionsTest {
assertEquals(version1, versions.targetFor(NodeType.host).get());
assertTrue("Per-node wanted OS version remains unset", hostNodes.get().stream().allMatch(node -> node.status().osVersion().wanted().isEmpty()));
+ // One host upgrades to a later version outside the control of orchestration
+ Node hostOnLaterVersion = hostNodes.get().get(0);
+ setCurrentVersion(List.of(hostOnLaterVersion), Version.fromString("8.1"));
+
// Upgrade OS again
var version2 = Version.fromString("7.2");
versions.setTarget(NodeType.host, version2, Optional.empty(), false);
assertEquals(version2, versions.targetFor(NodeType.host).get());
- // Target can be (de)activated
+ // Resume upgrade
versions.resumeUpgradeOf(NodeType.host, true);
- assertTrue("Target version activated", hostNodes.get().stream()
- .allMatch(node -> node.status().osVersion().wanted().isPresent()));
+ List<Node> allHosts = hostNodes.get();
+ assertTrue("Wanted version is set", allHosts.stream()
+ .filter(node -> !node.equals(hostOnLaterVersion))
+ .allMatch(node -> node.status().osVersion().wanted().isPresent()));
+ assertTrue("Wanted version is not set for host on later version",
+ allHosts.get(0).status().osVersion().wanted().isEmpty());
+
+ // Halt upgrade
versions.resumeUpgradeOf(NodeType.host, false);
- assertTrue("Target version deactivated", hostNodes.get().stream()
- .allMatch(node -> node.status().osVersion().wanted().isEmpty()));
+ assertTrue("Wanted version is unset", hostNodes.get().stream()
+ .allMatch(node -> node.status().osVersion().wanted().isEmpty()));
// Downgrading fails
try {
diff --git a/searchcore/src/tests/proton/common/pendinglidtracker_test.cpp b/searchcore/src/tests/proton/common/pendinglidtracker_test.cpp
index 575033ad19a..3b42a399888 100644
--- a/searchcore/src/tests/proton/common/pendinglidtracker_test.cpp
+++ b/searchcore/src/tests/proton/common/pendinglidtracker_test.cpp
@@ -12,6 +12,8 @@ constexpr uint32_t LID_1 = 1u;
const std::vector<uint32_t> LIDV_2_1_3({2u, LID_1, 3u});
const std::vector<uint32_t> LIDV_2_3({2u, 3u});
+namespace proton {
+
std::ostream &
operator << (std::ostream & os, ILidCommitState::State state) {
switch (state) {
@@ -28,6 +30,8 @@ operator << (std::ostream & os, ILidCommitState::State state) {
return os;
}
+}
+
void
verifyPhase1ProduceAndNeedCommit(PendingLidTrackerBase & tracker, ILidCommitState::State expected) {
EXPECT_EQUAL(ILidCommitState::State::COMPLETED, tracker.getState());
diff --git a/searchcore/src/tests/proton/docsummary/docsummary.cpp b/searchcore/src/tests/proton/docsummary/docsummary.cpp
index 7d27c3b21f4..92117e174e9 100644
--- a/searchcore/src/tests/proton/docsummary/docsummary.cpp
+++ b/searchcore/src/tests/proton/docsummary/docsummary.cpp
@@ -31,6 +31,7 @@
#include <vespa/searchlib/transactionlog/nosyncproxy.h>
#include <vespa/searchlib/transactionlog/translogserver.h>
#include <vespa/vespalib/data/slime/slime.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <vespa/vespalib/encoding/base64.h>
#include <vespa/config-bucketspaces.h>
#include <vespa/vespalib/testkit/testapp.h>
diff --git a/searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp b/searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp
index 4a01d7ae3e8..18b3a5c5d8e 100644
--- a/searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp
+++ b/searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp
@@ -23,7 +23,6 @@
#include <vespa/searchcore/proton/server/i_feed_handler_owner.h>
#include <vespa/searchcore/proton/server/ireplayconfig.h>
#include <vespa/searchcore/proton/test/dummy_feed_view.h>
-#include <vespa/searchlib/common/idestructorcallback.h>
#include <vespa/searchlib/index/docbuilder.h>
#include <vespa/searchlib/index/dummyfileheadercontext.h>
#include <vespa/searchlib/transactionlog/translogserver.h>
diff --git a/searchcore/src/tests/proton/server/feedstates_test.cpp b/searchcore/src/tests/proton/server/feedstates_test.cpp
index fd1e24c1f17..15083975824 100644
--- a/searchcore/src/tests/proton/server/feedstates_test.cpp
+++ b/searchcore/src/tests/proton/server/feedstates_test.cpp
@@ -1,7 +1,6 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
// Unit tests for feedstates.
-
#include <vespa/document/base/documentid.h>
#include <vespa/document/base/testdocrepo.h>
#include <vespa/document/bucket/bucketid.h>
@@ -102,11 +101,10 @@ struct RemoveOperationContext
RemoveOperationContext::RemoveOperationContext(search::SerialNum serial)
: doc_id("id:ns:doctypename::bar"),
op(BucketFactory::getBucketId(doc_id), Timestamp(10), doc_id),
- str(), packet()
+ str(), packet(std::make_unique<Packet>(0xf000))
{
op.serialize(str);
ConstBufferRef buf(str.data(), str.wp());
- packet = std::make_unique<Packet>();
packet->add(Packet::Entry(serial, FeedOperation::REMOVE, buf));
}
RemoveOperationContext::~RemoveOperationContext() = default;
diff --git a/searchcore/src/tests/proton/summaryengine/summaryengine.cpp b/searchcore/src/tests/proton/summaryengine/summaryengine.cpp
index 7cdd8d767c6..7e5e3527b1d 100644
--- a/searchcore/src/tests/proton/summaryengine/summaryengine.cpp
+++ b/searchcore/src/tests/proton/summaryengine/summaryengine.cpp
@@ -6,6 +6,7 @@
#include <vespa/searchlib/util/rawbuf.h>
#include <vespa/searchlib/util/slime_output_raw_buf_adapter.h>
#include <vespa/vespalib/data/databuffer.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <vespa/vespalib/util/compressor.h>
#include <vespa/searchsummary/docsummary/docsumwriter.h>
#include <vespa/metrics/metricset.h>
diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp
index 96a93f0ac16..028b5d38ae9 100644
--- a/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp
+++ b/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp
@@ -564,6 +564,9 @@ DocumentDB::close()
// Abort any ongoing maintenance
stopMaintenance();
+ _visibility.commit();
+ _writeService.sync();
+
// The attributes in the ready sub db is also the total set of attributes.
DocumentDBTaggedMetrics &metrics = getMetrics();
_metricsWireService.cleanAttributes(metrics.ready.attributes);
@@ -905,6 +908,11 @@ DocumentDB::syncFeedView()
return;
IFeedView::SP oldFeedView(_feedView.get());
IFeedView::SP newFeedView(_subDBs.getFeedView());
+
+ _writeService.sync();
+ _visibility.commit();
+ _writeService.sync();
+
_feedView.set(newFeedView);
_feedHandler.setActiveFeedView(newFeedView.get());
_subDBs.createRetrievers();
@@ -980,6 +988,7 @@ void
DocumentDB::stopMaintenance()
{
_maintenanceController.stop();
+ _writeService.sync();
}
void
diff --git a/searchcore/src/vespa/searchcore/proton/server/feedstates.cpp b/searchcore/src/vespa/searchcore/proton/server/feedstates.cpp
index c14ed3bb1d9..d01c25d9c1e 100644
--- a/searchcore/src/vespa/searchcore/proton/server/feedstates.cpp
+++ b/searchcore/src/vespa/searchcore/proton/server/feedstates.cpp
@@ -82,6 +82,8 @@ public:
_commitTimeTracker(100ms)
{ }
+ ~TransactionLogReplayPacketHandler() override = default;
+
void replay(const PutOperation &op) override {
_feed_view_ptr->handlePut(FeedToken(), op);
}
@@ -153,6 +155,8 @@ ReplayTransactionLogState::ReplayTransactionLogState(
_packet_handler(std::make_unique<TransactionLogReplayPacketHandler>(feed_view_ptr, bucketDBHandler, replay_config, config_store))
{ }
+ReplayTransactionLogState::~ReplayTransactionLogState() = default;
+
void
ReplayTransactionLogState::receive(const PacketWrapper::SP &wrap, Executor &executor) {
EntryHandler closure = makeClosure(&startDispatch, _packet_handler.get());
diff --git a/searchcore/src/vespa/searchcore/proton/server/feedstates.h b/searchcore/src/vespa/searchcore/proton/server/feedstates.h
index bf376bb8065..2cf0ee1a4dd 100644
--- a/searchcore/src/vespa/searchcore/proton/server/feedstates.h
+++ b/searchcore/src/vespa/searchcore/proton/server/feedstates.h
@@ -55,6 +55,7 @@ public:
IReplayConfig &replay_config,
FeedConfigStore &config_store);
+ ~ReplayTransactionLogState() override;
void handleOperation(FeedToken, FeedOperationUP op) override {
throwExceptionInHandleOperation(_doc_type_name, *op);
}
diff --git a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp
index f29e54ba725..e822b1de33e 100644
--- a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp
+++ b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp
@@ -6,6 +6,7 @@
#include "i_blockable_maintenance_job.h"
#include <vespa/searchcorespi/index/i_thread_service.h>
#include <vespa/vespalib/util/closuretask.h>
+#include <vespa/vespalib/util/lambdatask.h>
#include <vespa/vespalib/util/scheduledexecutor.h>
#include <vespa/log/log.h>
@@ -15,6 +16,7 @@ using document::BucketId;
using vespalib::Executor;
using vespalib::makeClosure;
using vespalib::makeTask;
+using vespalib::makeLambdaTask;
namespace proton {
@@ -84,8 +86,8 @@ MaintenanceController::registerJob(Executor & executor, IMaintenanceJob::UP job)
void
MaintenanceController::killJobs()
{
- // Called by master write thread during start/reconfig
- // Called by other thread during stop
+ // Called by master write thread
+ assert(_masterThread.isCurrentThread());
LOG(debug, "killJobs(): threadId=%zu", (size_t)FastOS_Thread::GetCurrentThreadId());
_periodicTimer.reset();
// No need to take _jobsLock as modification of _jobs also happens in master write thread.
@@ -94,24 +96,13 @@ MaintenanceController::killJobs()
}
_defaultExecutor.sync();
_defaultExecutor.sync();
- if (_masterThread.isCurrentThread()) {
- JobList tmpJobs = _jobs;
- {
- Guard guard(_jobsLock);
- _jobs.clear();
- }
- // Hold jobs until existing tasks have been drained
- _masterThread.execute(makeTask(makeClosure(this, &MaintenanceController::performHoldJobs, tmpJobs)));
- } else {
- // Wait for all tasks to be finished.
- // NOTE: We must sync 2 times as a task currently being executed can add a new
- // task to the executor as it might not see the new value of the stopped flag.
- _masterThread.sync();
- _masterThread.sync();
- // Clear jobs in master write thread, to avoid races
- _masterThread.execute(makeTask(makeClosure(this, &MaintenanceController::performClearJobs)));
- _masterThread.sync();
+ JobList tmpJobs = _jobs;
+ {
+ Guard guard(_jobsLock);
+ _jobs.clear();
}
+ // Hold jobs until existing tasks have been drained
+ _masterThread.execute(makeTask(makeClosure(this, &MaintenanceController::performHoldJobs, tmpJobs)));
}
void
@@ -123,21 +114,12 @@ MaintenanceController::performHoldJobs(JobList jobs)
}
void
-MaintenanceController::performClearJobs()
-{
- // Called by master write thread
- LOG(debug, "performClearJobs(): threadId=%zu", (size_t)FastOS_Thread::GetCurrentThreadId());
- Guard guard(_jobsLock);
- _jobs.clear();
-}
-
-
-void
MaintenanceController::stop()
{
assert(!_masterThread.isCurrentThread());
- _stopping = true;
- killJobs();
+ _masterThread.execute(makeLambdaTask([this]() { _stopping = true; killJobs(); }));
+ _masterThread.sync(); // Wait for killJobs()
+ _masterThread.sync(); // Wait for already scheduled maintenance jobs and performHoldJobs
}
void
diff --git a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h
index 3cfdeba4d34..ece92adebd0 100644
--- a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h
+++ b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h
@@ -90,7 +90,6 @@ private:
void addJobsToPeriodicTimer();
void restart();
void notifyThawedBucket(const document::BucketId &bucket) override;
- void performClearJobs();
void performHoldJobs(JobList jobs);
void registerJob(vespalib::Executor & executor, IMaintenanceJob::UP job);
};
diff --git a/searchcore/src/vespa/searchcore/proton/server/tlcproxy.cpp b/searchcore/src/vespa/searchcore/proton/server/tlcproxy.cpp
index bbd02d7efce..baba74c482c 100644
--- a/searchcore/src/vespa/searchcore/proton/server/tlcproxy.cpp
+++ b/searchcore/src/vespa/searchcore/proton/server/tlcproxy.cpp
@@ -15,9 +15,8 @@ void TlcProxy::commit(search::SerialNum serialNum, search::transactionlog::Type
const vespalib::nbostream &buf, DoneCallback onDone)
{
Packet::Entry entry(serialNum, type, vespalib::ConstBufferRef(buf.data(), buf.size()));
- Packet packet;
+ Packet packet(entry.serializedSize());
packet.add(entry);
- packet.close();
_tlsDirectWriter.commit(_domain, packet, std::move(onDone));
}
diff --git a/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.cpp b/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.cpp
index a3524ae79f3..3a44af517ee 100644
--- a/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.cpp
+++ b/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.cpp
@@ -2,10 +2,9 @@
#include "visibilityhandler.h"
#include <vespa/vespalib/util/isequencedtaskexecutor.h>
-#include <vespa/vespalib/util/closuretask.h>
+#include <vespa/vespalib/util/lambdatask.h>
-using vespalib::makeTask;
-using vespalib::makeClosure;
+using vespalib::makeLambdaTask;
namespace proton {
@@ -81,8 +80,7 @@ VisibilityHandler::startCommit(const std::lock_guard<std::mutex> &unused, bool f
(void) unused;
SerialNum current = _serial.getSerialNum();
if ((current > _lastCommitSerialNum) || force) {
- _writeService.master().execute(makeTask(makeClosure(this,
- &VisibilityHandler::performCommit, force)));
+ _writeService.master().execute(makeLambdaTask([this, force]() { performCommit(force);}));
return true;
}
return false;
@@ -95,8 +93,10 @@ VisibilityHandler::performCommit(bool force)
SerialNum current = _serial.getSerialNum();
if ((current > _lastCommitSerialNum) || force) {
IFeedView::SP feedView(_feedView.get());
- feedView->forceCommit(current);
- _lastCommitSerialNum = current;
+ if (feedView) {
+ feedView->forceCommit(current);
+ _lastCommitSerialNum = current;
+ }
}
}
diff --git a/searchcore/src/vespa/searchcore/proton/summaryengine/docsum_by_slime.cpp b/searchcore/src/vespa/searchcore/proton/summaryengine/docsum_by_slime.cpp
index 5bb36f9c828..edc3b86d9d3 100644
--- a/searchcore/src/vespa/searchcore/proton/summaryengine/docsum_by_slime.cpp
+++ b/searchcore/src/vespa/searchcore/proton/summaryengine/docsum_by_slime.cpp
@@ -20,7 +20,6 @@ using vespalib::Memory;
using vespalib::slime::Symbol;
using vespalib::slime::BinaryFormat;
using vespalib::slime::ArrayTraverser;
-using vespalib::SimpleBuffer;
using vespalib::DataBuffer;
using vespalib::ConstBufferRef;
using vespalib::compression::CompressionConfig;
diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt
index c8bfc0f1926..4d1a8a82211 100644
--- a/searchlib/CMakeLists.txt
+++ b/searchlib/CMakeLists.txt
@@ -217,6 +217,7 @@ vespa_define_module(
src/tests/sortspec
src/tests/stringenum
src/tests/tensor/dense_tensor_store
+ src/tests/tensor/direct_tensor_store
src/tests/tensor/distance_functions
src/tests/tensor/hnsw_index
src/tests/tensor/hnsw_saver
diff --git a/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp b/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp
index 24919fb2341..04d2dfe4d52 100644
--- a/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp
+++ b/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp
@@ -491,7 +491,7 @@ BitVectorTest::test(BasicType bt,
v->asDocumentWeightAttribute();
if (dwa != NULL) {
search::IDocumentWeightAttribute::LookupResult lres =
- dwa->lookup(getSearchStr<VectorType>());
+ dwa->lookup(getSearchStr<VectorType>(), dwa->get_dictionary_snapshot());
typedef search::queryeval::DocumentWeightSearchIterator DWSI;
typedef search::queryeval::SearchIterator SI;
TermFieldMatchData md;
diff --git a/searchlib/src/tests/attribute/document_weight_iterator/document_weight_iterator_test.cpp b/searchlib/src/tests/attribute/document_weight_iterator/document_weight_iterator_test.cpp
index cf1506a9118..d8a1d03f1a8 100644
--- a/searchlib/src/tests/attribute/document_weight_iterator/document_weight_iterator_test.cpp
+++ b/searchlib/src/tests/attribute/document_weight_iterator/document_weight_iterator_test.cpp
@@ -3,6 +3,7 @@
#include <vespa/searchlib/attribute/attribute.h>
#include <vespa/searchlib/attribute/attributefactory.h>
#include <vespa/searchlib/attribute/attributeguard.h>
+#include <vespa/searchlib/attribute/attribute_read_guard.h>
#include <vespa/searchlib/attribute/attributememorysavetarget.h>
#include <vespa/searchlib/attribute/attributevector.h>
#include <vespa/searchlib/attribute/attributevector.hpp>
@@ -22,6 +23,7 @@
#include <vespa/searchlib/test/searchiteratorverifier.h>
#include <vespa/searchlib/util/randomgenerator.h>
#include <vespa/vespalib/testkit/test_kit.h>
+#include <vespa/vespalib/test/insertion_operators.h>
#include <vespa/log/log.h>
LOG_SETUP("document_weight_iterator_test");
@@ -124,17 +126,17 @@ void verify_invalid_lookup(IDocumentWeightAttribute::LookupResult result) {
}
TEST_F("require that integer lookup works correctly", LongFixture) {
- verify_valid_lookup(f1.api->lookup("111"));
- verify_invalid_lookup(f1.api->lookup("222"));
+ verify_valid_lookup(f1.api->lookup("111", f1.api->get_dictionary_snapshot()));
+ verify_invalid_lookup(f1.api->lookup("222", f1.api->get_dictionary_snapshot()));
}
TEST_F("require string lookup works correctly", StringFixture) {
- verify_valid_lookup(f1.api->lookup("foo"));
- verify_invalid_lookup(f1.api->lookup("bar"));
+ verify_valid_lookup(f1.api->lookup("foo", f1.api->get_dictionary_snapshot()));
+ verify_invalid_lookup(f1.api->lookup("bar", f1.api->get_dictionary_snapshot()));
}
void verify_posting(const IDocumentWeightAttribute &api, const char *term) {
- auto result = api.lookup(term);
+ auto result = api.lookup(term, api.get_dictionary_snapshot());
ASSERT_TRUE(result.posting_idx.valid());
std::vector<DocumentWeightIterator> itr_store;
api.create(result.posting_idx, itr_store);
@@ -168,6 +170,53 @@ TEST_F("require that string iterators are created correctly", StringFixture) {
verify_posting(*f1.api, "foo");
}
+TEST_F("require that dictionary snapshot works", LongFixture)
+{
+ auto read_guard = f1.attr->makeReadGuard(false);
+ auto dictionary_snapshot = f1.api->get_dictionary_snapshot();
+ auto lookup1 = f1.api->lookup("111", dictionary_snapshot);
+ EXPECT_TRUE(lookup1.enum_idx.valid());
+ f1.attr->clearDoc(1);
+ f1.attr->clearDoc(5);
+ f1.attr->clearDoc(7);
+ f1.attr->commit();
+ auto lookup2 = f1.api->lookup("111", f1.api->get_dictionary_snapshot());
+ EXPECT_FALSE(lookup2.enum_idx.valid());
+ auto lookup3 = f1.api->lookup("111", dictionary_snapshot);
+ EXPECT_TRUE(lookup3.enum_idx.valid());
+ EXPECT_EQUAL(lookup1.enum_idx.ref(), lookup3.enum_idx.ref());
+}
+
+TEST_F("require that collect_folded works for string", StringFixture)
+{
+ StringAttribute *attr = static_cast<StringAttribute *>(f1.attr.get());
+ set_doc(attr, 2, "bar", 30);
+ attr->commit();
+ set_doc(attr, 3, "FOO", 30);
+ attr->commit();
+ auto dictionary_snapshot = f1.api->get_dictionary_snapshot();
+ auto lookup1 = f1.api->lookup("foo", dictionary_snapshot);
+ std::vector<vespalib::string> folded;
+ std::function<void(vespalib::datastore::EntryRef)> save_folded = [&folded,attr](vespalib::datastore::EntryRef enum_idx) { folded.emplace_back(attr->getFromEnum(enum_idx.ref())); };
+ f1.api->collect_folded(lookup1.enum_idx, dictionary_snapshot, save_folded);
+ std::vector<vespalib::string> expected_folded{"FOO", "foo"};
+ EXPECT_EQUAL(expected_folded, folded);
+}
+
+TEST_F("require that collect_folded works for integers", LongFixture)
+{
+ IntegerAttributeTemplate<int64_t> *attr = dynamic_cast<IntegerAttributeTemplate<int64_t> *>(f1.attr.get());
+ set_doc(attr, 2, int64_t(112), 30);
+ attr->commit();
+ auto dictionary_snapshot = f1.api->get_dictionary_snapshot();
+ auto lookup1 = f1.api->lookup("111", dictionary_snapshot);
+ std::vector<int64_t> folded;
+ std::function<void(vespalib::datastore::EntryRef)> save_folded = [&folded,attr](vespalib::datastore::EntryRef enum_idx) { folded.emplace_back(attr->getFromEnum(enum_idx.ref())); };
+ f1.api->collect_folded(lookup1.enum_idx, dictionary_snapshot, save_folded);
+ std::vector<int64_t> expected_folded{int64_t(111)};
+ EXPECT_EQUAL(expected_folded, folded);
+}
+
class Verifier : public search::test::SearchIteratorVerifier {
public:
Verifier();
@@ -176,7 +225,7 @@ public:
(void) strict;
const IDocumentWeightAttribute *api(_attr->asDocumentWeightAttribute());
ASSERT_TRUE(api != nullptr);
- auto dict_entry = api->lookup("123");
+ auto dict_entry = api->lookup("123", api->get_dictionary_snapshot());
ASSERT_TRUE(dict_entry.posting_idx.valid());
return std::make_unique<queryeval::DocumentWeightSearchIterator>(_tfmd, *api, dict_entry);
}
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
index cc6b8e0ce29..7a200a46ab2 100644
--- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
+++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
@@ -25,7 +25,8 @@ std::string get_source_dir() {
}
std::string source_dir = get_source_dir();
std::string vespa_dir = source_dir + "/" + "../../../../..";
-std::string simple_model = vespa_dir + "/" + "model-integration/src/test/models/onnx/simple/simple.onnx";
+std::string simple_model = vespa_dir + "/" + "eval/src/tests/tensor/onnx_wrapper/simple.onnx";
+std::string dynamic_model = vespa_dir + "/" + "eval/src/tests/tensor/onnx_wrapper/dynamic.onnx";
uint32_t default_docid = 1;
@@ -97,4 +98,16 @@ TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) {
EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0));
}
+TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) {
+ add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]");
+ add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
+ add_expr("bias_tensor", "tensor<float>(a[1],b[2]):[[4,5]]");
+ add_onnx("dynamic", dynamic_model);
+ compile(onnx_feature("dynamic"));
+ EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
+ EXPECT_EQ(get("onnxModel(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
+ EXPECT_EQ(get(2), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 84.0));
+ EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0));
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp b/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp
index 9761b0da2d7..f2c02d02080 100644
--- a/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp
+++ b/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp
@@ -674,7 +674,7 @@ private:
MatchParams match_params(_dummy_heap, _dummy_heap.getMinScore(), 1.0, 1);
std::vector<IDocumentWeightAttribute::LookupResult> dict_entries;
for (size_t i = 0; i < _num_children; ++i) {
- dict_entries.push_back(_helper.dwa().lookup(vespalib::make_string("%zu", i).c_str()));
+ dict_entries.push_back(_helper.dwa().lookup(vespalib::make_string("%zu", i).c_str(), _helper.dwa().get_dictionary_snapshot()));
}
return create_wand(_use_dwa, _tfmd, match_params, _weights, dict_entries, _helper.dwa(), strict);
}
diff --git a/searchlib/src/tests/tensor/direct_tensor_store/CMakeLists.txt b/searchlib/src/tests/tensor/direct_tensor_store/CMakeLists.txt
new file mode 100644
index 00000000000..14a70f25e3c
--- /dev/null
+++ b/searchlib/src/tests/tensor/direct_tensor_store/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(searchlib_direct_tensor_store_test_app TEST
+ SOURCES
+ direct_tensor_store_test.cpp
+ DEPENDS
+ searchlib
+ GTest::GTest
+)
+vespa_add_test(NAME searchlib_direct_tensor_store_test_app COMMAND searchlib_direct_tensor_store_test_app)
diff --git a/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp b/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp
new file mode 100644
index 00000000000..1003e461676
--- /dev/null
+++ b/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp
@@ -0,0 +1,89 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/searchlib/tensor/direct_tensor_store.h>
+#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/eval/tensor/tensor.h>
+
+using namespace search::tensor;
+
+using vespalib::datastore::EntryRef;
+using vespalib::eval::TensorSpec;
+using vespalib::tensor::DefaultTensorEngine;
+using vespalib::tensor::Tensor;
+
+vespalib::string tensor_spec("tensor(x{})");
+
+Tensor::UP
+make_tensor(const TensorSpec& spec)
+{
+ auto value = DefaultTensorEngine::ref().from_spec(spec);
+ auto* tensor = dynamic_cast<Tensor*>(value.get());
+ assert(tensor != nullptr);
+ value.release();
+ return Tensor::UP(tensor);
+}
+
+Tensor::UP
+make_tensor(double value)
+{
+ return make_tensor(TensorSpec(tensor_spec).add({{"x", "a"}}, value));
+}
+
+class DirectTensorStoreTest : public ::testing::Test {
+public:
+ DirectTensorStore store;
+
+ DirectTensorStoreTest() : store() {}
+
+ virtual ~DirectTensorStoreTest() {
+ store.clearHoldLists();
+ }
+
+ void expect_tensor(const Tensor* exp, EntryRef ref) {
+ const auto* act = store.get_tensor(ref);
+ ASSERT_TRUE(act);
+ EXPECT_EQ(exp, act);
+ }
+};
+
+TEST_F(DirectTensorStoreTest, can_set_and_get_tensor)
+{
+ auto t = make_tensor(5);
+ auto* exp = t.get();
+ auto ref = store.set_tensor(std::move(t));
+ expect_tensor(exp, ref);
+}
+
+TEST_F(DirectTensorStoreTest, invalid_ref_returns_nullptr)
+{
+ const auto* t = store.get_tensor(EntryRef());
+ EXPECT_FALSE(t);
+}
+
+TEST_F(DirectTensorStoreTest, hold_adds_entry_to_hold_list)
+{
+ auto ref = store.set_tensor(make_tensor(5));
+ auto mem_1 = store.getMemoryUsage();
+ store.holdTensor(ref);
+ auto mem_2 = store.getMemoryUsage();
+ EXPECT_GT(mem_2.allocatedBytesOnHold(), mem_1.allocatedBytesOnHold());
+}
+
+TEST_F(DirectTensorStoreTest, move_allocates_new_entry_and_puts_old_entry_on_hold)
+{
+ auto t = make_tensor(5);
+ auto* exp = t.get();
+ auto ref_1 = store.set_tensor(std::move(t));
+ auto mem_1 = store.getMemoryUsage();
+
+ auto ref_2 = store.move(ref_1);
+ auto mem_2 = store.getMemoryUsage();
+ EXPECT_NE(ref_1, ref_2);
+ expect_tensor(exp, ref_1);
+ expect_tensor(exp, ref_2);
+ EXPECT_GT(mem_2.allocatedBytesOnHold(), mem_1.allocatedBytesOnHold());
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
+
diff --git a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
index 9c896396de3..a5e0e1e2b6a 100644
--- a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
@@ -1,9 +1,7 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <sys/types.h>
-#include <sys/stat.h>
#include <fcntl.h>
-#include <stdio.h>
+#include <cstdio>
#include <unistd.h>
#include <chrono>
#include <cstdlib>
@@ -24,6 +22,7 @@
#include <vespa/vespalib/util/blockingthreadstackexecutor.h>
#include <vespa/vespalib/util/generationhandler.h>
#include <vespa/vespalib/util/lambdatask.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <vespa/log/log.h>
LOG_SETUP("stress_hnsw_mt");
diff --git a/searchlib/src/tests/transactionlogstress/translogstress.cpp b/searchlib/src/tests/transactionlogstress/translogstress.cpp
index 81a3006dbff..013ca81dcc9 100644
--- a/searchlib/src/tests/transactionlogstress/translogstress.cpp
+++ b/searchlib/src/tests/transactionlogstress/translogstress.cpp
@@ -8,7 +8,6 @@
#include <vespa/searchlib/index/dummyfileheadercontext.h>
#include <vespa/fastos/app.h>
#include <iostream>
-#include <stdexcept>
#include <sstream>
#include <thread>
@@ -223,7 +222,6 @@ FeederThread::~FeederThread() = default;
void
FeederThread::commitPacket()
{
- _packet.close();
const vespalib::nbostream& stream = _packet.getHandle();
if (!_session->commit(ConstBufferRef(stream.data(), stream.size()))) {
throw std::runtime_error(vespalib::make_string
@@ -238,8 +236,9 @@ FeederThread::commitPacket()
bool
FeederThread::addEntry(const Packet::Entry & e)
{
- //LOG(info, "FeederThread: add %s", EntryPrinter::toStr(e).c_str());
- return _packet.add(e);
+ if (_packet.sizeBytes() > 0xf000) return false;
+ _packet.add(e);
+ return true;
}
void
diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
index b9e4bf565ef..4ab80ebce7d 100644
--- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
@@ -322,6 +322,7 @@ private:
std::vector<int32_t> _weights;
std::vector<IDocumentWeightAttribute::LookupResult> _terms;
const IDocumentWeightAttribute &_attr;
+ vespalib::datastore::EntryRef _dictionary_snapshot;
public:
DirectWeightedSetBlueprint(const FieldSpec &field, const IDocumentWeightAttribute &attr, size_t size_hint)
@@ -329,7 +330,8 @@ public:
_estimate(),
_weights(),
_terms(),
- _attr(attr)
+ _attr(attr),
+ _dictionary_snapshot(_attr.get_dictionary_snapshot())
{
set_allow_termwise_eval(true);
_weights.reserve(size_hint);
@@ -337,7 +339,7 @@ public:
}
void addTerm(const vespalib::string &term, int32_t weight) {
- IDocumentWeightAttribute::LookupResult result = _attr.lookup(term);
+ IDocumentWeightAttribute::LookupResult result = _attr.lookup(term, _dictionary_snapshot);
HitEstimate childEst(result.posting_size, (result.posting_size == 0));
if (!childEst.empty) {
if (_estimate.empty) {
@@ -394,6 +396,7 @@ private:
std::vector<int32_t> _weights;
std::vector<IDocumentWeightAttribute::LookupResult> _terms;
const IDocumentWeightAttribute &_attr;
+ vespalib::datastore::EntryRef _dictionary_snapshot;
public:
DirectWandBlueprint(const FieldSpec &field, const IDocumentWeightAttribute &attr, uint32_t scoresToTrack,
@@ -406,14 +409,16 @@ public:
_scoresAdjustFrequency(queryeval::DEFAULT_PARALLEL_WAND_SCORES_ADJUST_FREQUENCY),
_weights(),
_terms(),
- _attr(attr)
+ _attr(attr),
+ _dictionary_snapshot(_attr.get_dictionary_snapshot())
+
{
_weights.reserve(size_hint);
_terms.reserve(size_hint);
}
void addTerm(const vespalib::string &term, int32_t weight) {
- IDocumentWeightAttribute::LookupResult result = _attr.lookup(term);
+ IDocumentWeightAttribute::LookupResult result = _attr.lookup(term, _dictionary_snapshot);
HitEstimate childEst(result.posting_size, (result.posting_size == 0));
if (!childEst.empty) {
if (_estimate.empty) {
@@ -464,6 +469,7 @@ class DirectAttributeBlueprint : public queryeval::SimpleLeafBlueprint
private:
vespalib::string _attrName;
const IDocumentWeightAttribute &_attr;
+ vespalib::datastore::EntryRef _dictionary_snapshot;
IDocumentWeightAttribute::LookupResult _dict_entry;
public:
@@ -472,7 +478,8 @@ public:
: SimpleLeafBlueprint(field),
_attrName(name),
_attr(attr),
- _dict_entry(_attr.lookup(term))
+ _dictionary_snapshot(_attr.get_dictionary_snapshot()),
+ _dict_entry(_attr.lookup(term, _dictionary_snapshot))
{
setEstimate(HitEstimate(_dict_entry.posting_size, (_dict_entry.posting_size == 0)));
}
diff --git a/searchlib/src/vespa/searchlib/attribute/i_document_weight_attribute.h b/searchlib/src/vespa/searchlib/attribute/i_document_weight_attribute.h
index ed184a7370e..e0cfd446da5 100644
--- a/searchlib/src/vespa/searchlib/attribute/i_document_weight_attribute.h
+++ b/searchlib/src/vespa/searchlib/attribute/i_document_weight_attribute.h
@@ -4,6 +4,8 @@
#include "postinglisttraits.h"
+#include <functional>
+
namespace search {
namespace query { class Node; }
@@ -17,11 +19,18 @@ struct IDocumentWeightAttribute
const uint32_t posting_size;
const int32_t min_weight;
const int32_t max_weight;
- LookupResult() : posting_idx(), posting_size(0), min_weight(0), max_weight(0) {}
- LookupResult(vespalib::datastore::EntryRef posting_idx_in, uint32_t posting_size_in, int32_t min_weight_in, int32_t max_weight_in)
- : posting_idx(posting_idx_in), posting_size(posting_size_in), min_weight(min_weight_in), max_weight(max_weight_in) {}
+ const vespalib::datastore::EntryRef enum_idx;
+ LookupResult() : posting_idx(), posting_size(0), min_weight(0), max_weight(0), enum_idx() {}
+ LookupResult(vespalib::datastore::EntryRef posting_idx_in, uint32_t posting_size_in, int32_t min_weight_in, int32_t max_weight_in, vespalib::datastore::EntryRef enum_idx_in)
+ : posting_idx(posting_idx_in), posting_size(posting_size_in), min_weight(min_weight_in), max_weight(max_weight_in), enum_idx(enum_idx_in) {}
};
- virtual LookupResult lookup(const vespalib::string &term) const = 0;
+ virtual vespalib::datastore::EntryRef get_dictionary_snapshot() const = 0;
+ virtual LookupResult lookup(const vespalib::string &term, vespalib::datastore::EntryRef dictionary_snapshot) const = 0;
+ /*
+ * Collect enum indexes (via callback) where folded
+ * (e.g. lowercased) value equals the folded value for enum_idx.
+ */
+ virtual void collect_folded(vespalib::datastore::EntryRef enum_idx, vespalib::datastore::EntryRef dictionary_snapshot, const std::function<void(vespalib::datastore::EntryRef)>& callback) const = 0;
virtual void create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const = 0;
virtual DocumentWeightIterator create(vespalib::datastore::EntryRef idx) const = 0;
virtual ~IDocumentWeightAttribute() {}
diff --git a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.h b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.h
index fa962f8d469..c09366cdaea 100644
--- a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.h
+++ b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.h
@@ -32,12 +32,14 @@ public:
using EnumStoreBatchUpdater = typename EnumStore::BatchUpdater;
private:
- struct DocumentWeightAttributeAdapter : IDocumentWeightAttribute {
+ struct DocumentWeightAttributeAdapter final : IDocumentWeightAttribute {
const MultiValueNumericPostingAttribute &self;
DocumentWeightAttributeAdapter(const MultiValueNumericPostingAttribute &self_in) : self(self_in) {}
- virtual LookupResult lookup(const vespalib::string &term) const override final;
- virtual void create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const override final;
- virtual DocumentWeightIterator create(vespalib::datastore::EntryRef idx) const override final;
+ vespalib::datastore::EntryRef get_dictionary_snapshot() const override;
+ LookupResult lookup(const vespalib::string &term, vespalib::datastore::EntryRef dictionary_snapshot) const override;
+ void collect_folded(vespalib::datastore::EntryRef enum_idx, vespalib::datastore::EntryRef dictionary_snapshot, const std::function<void(vespalib::datastore::EntryRef)>& callback) const override;
+ void create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const override;
+ DocumentWeightIterator create(vespalib::datastore::EntryRef idx) const override;
};
DocumentWeightAttributeAdapter _document_weight_attribute_adapter;
diff --git a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp
index 283c3da00b1..1fd1cd09bea 100644
--- a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp
@@ -83,11 +83,18 @@ MultiValueNumericPostingAttribute<B, M>::getSearch(QueryTermSimpleUP qTerm,
}
template <typename B, typename M>
+vespalib::datastore::EntryRef
+MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::get_dictionary_snapshot() const
+{
+ const Dictionary &dictionary = self._enumStore.get_posting_dictionary();
+ return dictionary.getFrozenView().getRoot();
+}
+
+template <typename B, typename M>
IDocumentWeightAttribute::LookupResult
-MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::lookup(const vespalib::string &term) const
+MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::lookup(const vespalib::string &term, vespalib::datastore::EntryRef dictionary_snapshot) const
{
const Dictionary &dictionary = self._enumStore.get_posting_dictionary();
- const FrozenDictionary frozenDictionary(dictionary.getFrozenView());
DictionaryConstIterator dictItr(vespalib::btree::BTreeNode::Ref(), dictionary.getAllocator());
char *end = nullptr;
@@ -95,13 +102,13 @@ MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::lookup(
if (*end == '\0') {
auto comp = self._enumStore.make_comparator(int_term);
- dictItr.lower_bound(frozenDictionary.getRoot(), EnumIndex(), comp);
+ dictItr.lower_bound(dictionary_snapshot, EnumIndex(), comp);
if (dictItr.valid() && !comp(EnumIndex(), dictItr.getKey())) {
vespalib::datastore::EntryRef pidx(dictItr.getData());
if (pidx.valid()) {
const PostingList &plist = self.getPostingList();
auto minmax = plist.getAggregated(pidx);
- return LookupResult(pidx, plist.frozenSize(pidx), minmax.getMin(), minmax.getMax());
+ return LookupResult(pidx, plist.frozenSize(pidx), minmax.getMin(), minmax.getMax(), dictItr.getKey());
}
}
}
@@ -110,6 +117,14 @@ MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::lookup(
template <typename B, typename M>
void
+MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::collect_folded(vespalib::datastore::EntryRef enum_idx, vespalib::datastore::EntryRef dictionary_snapshot, const std::function<void(vespalib::datastore::EntryRef)>& callback)const
+{
+ (void) dictionary_snapshot;
+ callback(enum_idx);
+}
+
+template <typename B, typename M>
+void
MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const
{
assert(idx.valid());
diff --git a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.h b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.h
index c755c5cb649..142879f4578 100644
--- a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.h
+++ b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.h
@@ -30,12 +30,14 @@ public:
using EnumStoreBatchUpdater = typename EnumStore::BatchUpdater;
private:
- struct DocumentWeightAttributeAdapter : IDocumentWeightAttribute {
+ struct DocumentWeightAttributeAdapter final : IDocumentWeightAttribute {
const MultiValueStringPostingAttributeT &self;
DocumentWeightAttributeAdapter(const MultiValueStringPostingAttributeT &self_in) : self(self_in) {}
- virtual LookupResult lookup(const vespalib::string &term) const override final;
- virtual void create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const override final;
- virtual DocumentWeightIterator create(vespalib::datastore::EntryRef idx) const override final;
+ vespalib::datastore::EntryRef get_dictionary_snapshot() const override;
+ LookupResult lookup(const vespalib::string &term, vespalib::datastore::EntryRef dictionary_snapshot) const override;
+ void collect_folded(vespalib::datastore::EntryRef enum_idx, vespalib::datastore::EntryRef dictionary_snapshot, const std::function<void(vespalib::datastore::EntryRef)>& callback) const override;
+ void create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const override;
+ DocumentWeightIterator create(vespalib::datastore::EntryRef idx) const override;
};
DocumentWeightAttributeAdapter _document_weight_attribute_adapter;
diff --git a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp
index 7bc62169b3c..4263eacfa52 100644
--- a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp
@@ -99,21 +99,28 @@ MultiValueStringPostingAttributeT<B, T>::getSearch(QueryTermSimpleUP qTerm,
template <typename B, typename T>
+vespalib::datastore::EntryRef
+MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::get_dictionary_snapshot() const
+{
+ const Dictionary &dictionary = self._enumStore.get_posting_dictionary();
+ return dictionary.getFrozenView().getRoot();
+}
+
+template <typename B, typename T>
IDocumentWeightAttribute::LookupResult
-MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::lookup(const vespalib::string &term) const
+MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::lookup(const vespalib::string &term, vespalib::datastore::EntryRef dictionary_snapshot) const
{
const Dictionary &dictionary = self._enumStore.get_posting_dictionary();
- const FrozenDictionary frozenDictionary(dictionary.getFrozenView());
DictionaryConstIterator dictItr(vespalib::btree::BTreeNode::Ref(), dictionary.getAllocator());
auto comp = self._enumStore.make_folded_comparator(term.c_str());
- dictItr.lower_bound(frozenDictionary.getRoot(), EnumIndex(), comp);
+ dictItr.lower_bound(dictionary_snapshot, EnumIndex(), comp);
if (dictItr.valid() && !comp(EnumIndex(), dictItr.getKey())) {
vespalib::datastore::EntryRef pidx(dictItr.getData());
if (pidx.valid()) {
const PostingList &plist = self.getPostingList();
auto minmax = plist.getAggregated(pidx);
- return LookupResult(pidx, plist.frozenSize(pidx), minmax.getMin(), minmax.getMax());
+ return LookupResult(pidx, plist.frozenSize(pidx), minmax.getMin(), minmax.getMax(), dictItr.getKey());
}
}
return LookupResult();
@@ -121,6 +128,20 @@ MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::lookup(
template <typename B, typename T>
void
+MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::collect_folded(vespalib::datastore::EntryRef enum_idx, vespalib::datastore::EntryRef dictionary_snapshot, const std::function<void(vespalib::datastore::EntryRef)>& callback) const
+{
+ const Dictionary &dictionary = self._enumStore.get_posting_dictionary();
+ DictionaryConstIterator dictItr(vespalib::btree::BTreeNode::Ref(), dictionary.getAllocator());
+ auto comp = self._enumStore.make_folded_comparator();
+ dictItr.lower_bound(dictionary_snapshot, enum_idx, comp);
+ while (dictItr.valid() && !comp(enum_idx, dictItr.getKey())) {
+ callback(dictItr.getKey());
+ ++dictItr;
+ }
+}
+
+template <typename B, typename T>
+void
MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const
{
assert(idx.valid());
diff --git a/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp b/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp
index cd53253ad0a..53f88246a0a 100644
--- a/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp
@@ -112,13 +112,14 @@ ReferenceAttribute::buildReverseMapping(EntryRef newRef, const std::vector<Rever
void
ReferenceAttribute::buildReverseMapping()
{
- std::vector<std::pair<EntryRef, uint32_t>> indices;
+ using EntryPair = std::pair<EntryRef, uint32_t>;
+ std::vector<EntryPair, vespalib::allocator_large<EntryPair>> indices;
uint32_t numDocs = _indices.size();
indices.reserve(numDocs);
for (uint32_t lid = 0; lid < numDocs; ++lid) {
EntryRef ref = _indices[lid];
if (ref.valid()) {
- indices.push_back(std::make_pair(ref, lid));
+ indices.emplace_back(ref, lid);
}
}
std::sort(indices.begin(), indices.end());
@@ -200,8 +201,7 @@ ReferenceAttribute::onUpdateStat()
std::unique_ptr<AttributeSaver>
ReferenceAttribute::onInitSave(vespalib::stringref fileName)
{
- vespalib::GenerationHandler::Guard guard(this->getGenerationHandler().
- takeGuard());
+ vespalib::GenerationHandler::Guard guard(this->getGenerationHandler().takeGuard());
return std::make_unique<ReferenceAttributeSaver>
(std::move(guard),
createAttributeHeader(fileName),
@@ -221,8 +221,7 @@ ReferenceAttribute::onLoad()
assert(attrReader.getEnumerated());
assert(!attrReader.hasIdx());
size_t numDocs(0);
- uint64_t numValues(0);
- numValues = attrReader.getEnumCount();
+ uint64_t numValues = attrReader.getEnumCount();
numDocs = numValues;
auto udatBuffer = attribute::LoadUtils::loadUDAT(*this);
const GenericHeader &header = udatBuffer->getHeader();
@@ -367,13 +366,13 @@ class TargetLidPopulator : public IGidToLidMapperVisitor
{
ReferenceAttribute &_attr;
public:
- TargetLidPopulator(ReferenceAttribute &attr)
+ explicit TargetLidPopulator(ReferenceAttribute &attr)
: IGidToLidMapperVisitor(),
_attr(attr)
{
}
- virtual ~TargetLidPopulator() override { }
- virtual void visit(const document::GlobalId &gid, uint32_t lid) const override {
+ ~TargetLidPopulator() override = default;
+ void visit(const document::GlobalId &gid, uint32_t lid) const override {
_attr.notifyReferencedPutNoCommit(gid, lid);
}
};
diff --git a/searchlib/src/vespa/searchlib/attribute/reference_attribute.h b/searchlib/src/vespa/searchlib/attribute/reference_attribute.h
index 706abc53819..1c138abf989 100644
--- a/searchlib/src/vespa/searchlib/attribute/reference_attribute.h
+++ b/searchlib/src/vespa/searchlib/attribute/reference_attribute.h
@@ -7,6 +7,7 @@
#include "reference_mappings.h"
#include <vespa/vespalib/datastore/unique_store.h>
#include <vespa/vespalib/util/rcuvector.h>
+#include <vespa/vespalib/stllike/allocator.h>
namespace search { class IGidToLidMapperFactory; }
@@ -28,7 +29,7 @@ public:
using GlobalId = document::GlobalId;
using ReferenceStore = vespalib::datastore::UniqueStore<Reference>;
using ReferenceStoreIndices = vespalib::RcuVectorBase<EntryRef>;
- using IndicesCopyVector = vespalib::Array<EntryRef>;
+ using IndicesCopyVector = std::vector<EntryRef, vespalib::allocator_large<EntryRef>>;
// Class used to map from target lid to source lids
using ReverseMapping = vespalib::btree::BTreeStore<uint32_t, vespalib::btree::BTreeNoLeafData,
vespalib::btree::NoAggregated,
@@ -45,14 +46,14 @@ private:
std::shared_ptr<IGidToLidMapperFactory> _gidToLidMapperFactory;
ReferenceMappings _referenceMappings;
- virtual void onAddDocs(DocId docIdLimit) override;
- virtual void removeOldGenerations(generation_t firstUsed) override;
- virtual void onGenerationChange(generation_t generation) override;
- virtual void onCommit() override;
- virtual void onUpdateStat() override;
- virtual std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override;
- virtual bool onLoad() override;
- virtual uint64_t getUniqueValueCount() const override;
+ void onAddDocs(DocId docIdLimit) override;
+ void removeOldGenerations(generation_t firstUsed) override;
+ void onGenerationChange(generation_t generation) override;
+ void onCommit() override;
+ void onUpdateStat() override;
+ std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override;
+ bool onLoad() override;
+ uint64_t getUniqueValueCount() const override;
bool considerCompact(const CompactionStrategy &compactionStrategy);
void compactWorst();
diff --git a/searchlib/src/vespa/searchlib/diskindex/bitvectorfile.h b/searchlib/src/vespa/searchlib/diskindex/bitvectorfile.h
index 00645810d62..e8341901585 100644
--- a/searchlib/src/vespa/searchlib/diskindex/bitvectorfile.h
+++ b/searchlib/src/vespa/searchlib/diskindex/bitvectorfile.h
@@ -6,6 +6,7 @@
#include <vespa/searchlib/common/bitvector.h>
#include <vespa/searchlib/common/tunefileinfo.h>
#include <vespa/vespalib/stllike/string.h>
+#include <vespa/vespalib/stllike/allocator.h>
#include "bitvectoridxfile.h"
namespace search::diskindex {
@@ -49,7 +50,7 @@ public:
class BitVectorCandidate
{
private:
- std::vector<uint32_t> _array;
+ std::vector<uint32_t, vespalib::allocator_large<uint32_t>> _array;
uint64_t _numDocs;
uint32_t _bitVectorLimit;
BitVector::UP _bv;
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
index f6d5c37b61d..7433021b9b6 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
@@ -3,7 +3,6 @@
#include "onnx_feature.h"
#include <vespa/searchlib/fef/properties.h>
#include <vespa/searchlib/fef/featureexecutor.h>
-#include <vespa/eval/tensor/dense/onnx_wrapper.h>
#include <vespa/eval/tensor/dense/dense_tensor_view.h>
#include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h>
#include <vespa/vespalib/util/stringfmt.h>
@@ -23,7 +22,7 @@ using vespalib::eval::ValueType;
using vespalib::make_string_short::fmt;
using vespalib::tensor::DenseTensorView;
using vespalib::tensor::MutableDenseTensorView;
-using vespalib::tensor::OnnxWrapper;
+using vespalib::tensor::Onnx;
namespace search::features {
@@ -33,37 +32,28 @@ namespace search::features {
class OnnxFeatureExecutor : public FeatureExecutor
{
private:
- const OnnxWrapper &_model;
- OnnxWrapper::Params _params;
- OnnxWrapper::Result _result;
- std::vector<MutableDenseTensorView> _views;
-
+ Onnx::EvalContext _eval_context;
public:
- OnnxFeatureExecutor(const OnnxWrapper &model)
- : _model(model), _params(), _result(OnnxWrapper::Result::make_empty()), _views()
- {
- _views.reserve(_model.outputs().size());
- for (const auto &output: _model.outputs()) {
- _views.emplace_back(output.make_compatible_type());
- }
- }
+ OnnxFeatureExecutor(const Onnx &model, const Onnx::WireInfo &wire_info)
+ : _eval_context(model, wire_info) {}
bool isPure() override { return true; }
- void execute(uint32_t) override {
- _params = OnnxWrapper::Params();
- for (size_t i = 0; i < _model.inputs().size(); ++i) {
- _params.bind(i, static_cast<const DenseTensorView&>(inputs().get_object(i).get()));
+ void handle_bind_outputs(vespalib::ArrayRef<fef::NumberOrObject>) override {
+ for (size_t i = 0; i < _eval_context.num_results(); ++i) {
+ outputs().set_object(i, _eval_context.get_result(i));
}
- _result = _model.eval(_params);
- for (size_t i = 0; i < _model.outputs().size(); ++i) {
- _result.get(i, _views[i]);
- outputs().set_object(i, _views[i]);
+ }
+ void execute(uint32_t) override {
+ for (size_t i = 0; i < _eval_context.num_params(); ++i) {
+ _eval_context.bind_param(i, inputs().get_object(i).get());
}
+ _eval_context.eval();
}
};
OnnxBlueprint::OnnxBlueprint()
: Blueprint("onnxModel"),
- _model(nullptr)
+ _model(nullptr),
+ _wire_info()
{
}
@@ -74,24 +64,25 @@ OnnxBlueprint::setup(const IIndexEnvironment &env,
const ParameterList &params)
{
auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP)
- ? OnnxWrapper::Optimize::DISABLE
- : OnnxWrapper::Optimize::ENABLE;
+ ? Onnx::Optimize::DISABLE
+ : Onnx::Optimize::ENABLE;
// Note: Using the fileref property with the model name as
// fallback to get a file name. This needs to be replaced with an
// actual file reference obtained through config when available.
vespalib::string file_name = env.getProperties().lookup(getName(), "fileref").get(params[0].getValue());
try {
- _model = std::make_unique<OnnxWrapper>(file_name, optimize);
+ _model = std::make_unique<Onnx>(file_name, optimize);
} catch (std::exception &ex) {
return fail("Model setup failed: %s", ex.what());
}
+ Onnx::WirePlanner planner;
for (size_t i = 0; i < _model->inputs().size(); ++i) {
const auto &model_input = _model->inputs()[i];
if (auto maybe_input = defineInput(fmt("rankingExpression(\"%s\")", model_input.name.c_str()), AcceptInput::OBJECT)) {
const FeatureType &feature_input = maybe_input.value();
assert(feature_input.is_object());
- if (!model_input.is_compatible(feature_input.type())) {
+ if (!planner.bind_input_type(feature_input.type(), model_input)) {
return fail("incompatible type for input '%s': %s -> %s", model_input.name.c_str(),
feature_input.type().to_spec().c_str(), model_input.type_as_string().c_str());
}
@@ -99,13 +90,14 @@ OnnxBlueprint::setup(const IIndexEnvironment &env,
}
for (size_t i = 0; i < _model->outputs().size(); ++i) {
const auto &model_output = _model->outputs()[i];
- ValueType output_type = model_output.make_compatible_type();
+ ValueType output_type = planner.make_output_type(model_output);
if (output_type.is_error()) {
return fail("unable to make compatible type for output '%s': %s -> error",
model_output.name.c_str(), model_output.type_as_string().c_str());
}
describeOutput(model_output.name, "output from onnx model", FeatureType::object(output_type));
}
+ _wire_info = planner.get_wire_info(*_model);
return true;
}
@@ -113,7 +105,7 @@ FeatureExecutor &
OnnxBlueprint::createExecutor(const IQueryEnvironment &, Stash &stash) const
{
assert(_model);
- return stash.create<OnnxFeatureExecutor>(*_model);
+ return stash.create<OnnxFeatureExecutor>(*_model, _wire_info);
}
}
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.h b/searchlib/src/vespa/searchlib/features/onnx_feature.h
index eb6e368ffbd..19c6338d2ee 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.h
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.h
@@ -3,8 +3,7 @@
#pragma once
#include <vespa/searchlib/fef/blueprint.h>
-
-namespace vespalib::tensor { class OnnxWrapper; }
+#include <vespa/eval/tensor/dense/onnx_wrapper.h>
namespace search::features {
@@ -13,7 +12,9 @@ namespace search::features {
**/
class OnnxBlueprint : public fef::Blueprint {
private:
- std::unique_ptr<vespalib::tensor::OnnxWrapper> _model;
+ using Onnx = vespalib::tensor::Onnx;
+ std::unique_ptr<Onnx> _model;
+ Onnx::WireInfo _wire_info;
public:
OnnxBlueprint();
~OnnxBlueprint() override;
diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
index 35615b255c0..851400e4806 100644
--- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
@@ -5,6 +5,8 @@ vespa_add_library(searchlib_tensor OBJECT
dense_tensor_attribute.cpp
dense_tensor_attribute_saver.cpp
dense_tensor_store.cpp
+ direct_tensor_attribute.cpp
+ direct_tensor_store.cpp
distance_function_factory.cpp
distance_functions.cpp
generic_tensor_attribute.cpp
@@ -20,6 +22,7 @@ vespa_add_library(searchlib_tensor OBJECT
nearest_neighbor_index.cpp
nearest_neighbor_index_saver.cpp
tensor_attribute.cpp
+ tensor_deserialize.cpp
tensor_store.cpp
DEPENDS
)
diff --git a/searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h b/searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h
new file mode 100644
index 00000000000..7c34b60e93d
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h
@@ -0,0 +1,26 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/fastlib/io/bufferedfile.h>
+#include <vespa/searchlib/attribute/readerbase.h>
+#include <vespa/searchlib/util/fileutil.h>
+
+namespace search::tensor {
+
+/**
+ * Utility for reading an attribute data file where
+ * the format is a sequence of blobs (size, byte[size]).
+ **/
+class BlobSequenceReader : public ReaderBase
+{
+private:
+ FileReader<uint32_t> _sizeReader;
+public:
+ BlobSequenceReader(AttributeVector &attr)
+ : ReaderBase(attr),
+ _sizeReader(*_datFile)
+ { }
+ uint32_t getNextSize() { return _sizeReader.readHostOrder(); }
+ void readBlob(void *buf, size_t len) { _datFile->ReadBuf(buf, len); }
+};
+
+} // namespace
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
index 76533839de7..37a042d4e7f 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
@@ -30,26 +30,26 @@ namespace {
constexpr uint32_t DENSE_TENSOR_ATTRIBUTE_VERSION = 1;
const vespalib::string tensorTypeTag("tensortype");
-class TensorReader : public ReaderBase
+class BlobSequenceReader : public ReaderBase
{
private:
static constexpr uint8_t tensorIsNotPresent = 0;
static constexpr uint8_t tensorIsPresent = 1;
public:
- TensorReader(AttributeVector &attr);
- ~TensorReader();
+ BlobSequenceReader(AttributeVector &attr);
+ ~BlobSequenceReader();
bool is_present();
void readTensor(void *buf, size_t len) { _datFile->ReadBuf(buf, len); }
};
-TensorReader::TensorReader(AttributeVector &attr)
+BlobSequenceReader::BlobSequenceReader(AttributeVector &attr)
: ReaderBase(attr)
{
}
-TensorReader::~TensorReader() = default;
+BlobSequenceReader::~BlobSequenceReader() = default;
bool
-TensorReader::is_present() {
+BlobSequenceReader::is_present() {
unsigned char detect;
_datFile->ReadBuf(&detect, sizeof(detect));
if (detect == tensorIsNotPresent) {
@@ -190,7 +190,7 @@ DenseTensorAttribute::getTensor(DocId docId, MutableDenseTensorView &tensor) con
bool
DenseTensorAttribute::onLoad()
{
- TensorReader tensorReader(*this);
+ BlobSequenceReader tensorReader(*this);
if (!tensorReader.hasData()) {
return false;
}
diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp
new file mode 100644
index 00000000000..f53d42442ba
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp
@@ -0,0 +1,52 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "direct_tensor_attribute.h"
+
+#include <vespa/eval/tensor/tensor.h>
+#include <vespa/fastlib/io/bufferedfile.h>
+#include <vespa/searchlib/attribute/readerbase.h>
+#include <vespa/searchlib/util/fileutil.h>
+#include <vespa/vespalib/util/array.h>
+
+#include "blob_sequence_reader.h"
+#include "tensor_deserialize.h"
+
+using vespalib::tensor::Tensor;
+
+namespace search::tensor {
+
+constexpr uint32_t TENSOR_ATTRIBUTE_VERSION = 0;
+
+bool
+DirectTensorAttribute::onLoad()
+{
+ BlobSequenceReader tensorReader(*this);
+ if (!tensorReader.hasData()) {
+ return false;
+ }
+ setCreateSerialNum(tensorReader.getCreateSerialNum());
+ assert(tensorReader.getVersion() == TENSOR_ATTRIBUTE_VERSION);
+ uint32_t numDocs = tensorReader.getDocIdLimit();
+ vespalib::Array<char> buffer(1024);
+ for (uint32_t lid = 0; lid < numDocs; ++lid) {
+ uint32_t tensorSize = tensorReader.getNextSize();
+ if (tensorSize != 0) {
+ if (tensorSize > buffer.size()) {
+ buffer.resize(tensorSize + 1024);
+ }
+ tensorReader.readBlob(&buffer[0], tensorSize);
+ setTensor(lid, deserialize_tensor(&buffer[0], tensorSize));
+ }
+ }
+ setNumDocs(numDocs);
+ setCommittedDocIdLimit(numDocs);
+ return true;
+}
+
+void
+DirectTensorAttribute::setTensor(DocId , std::unique_ptr<Tensor> )
+{
+ // XXX missing
+}
+
+} // namespace
diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h
new file mode 100644
index 00000000000..ae3cb222dba
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h
@@ -0,0 +1,25 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "tensor_attribute.h"
+
+namespace search::tensor {
+
+class DirectTensorAttribute : public TensorAttribute
+{
+ // XXX must have some sort of TensorStore here
+public:
+ DirectTensorAttribute(vespalib::stringref baseFileName, const Config &cfg);
+ virtual ~DirectTensorAttribute();
+ virtual void setTensor(DocId docId, const Tensor &tensor) override;
+ virtual std::unique_ptr<Tensor> getTensor(DocId docId) const override;
+ virtual void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const override;
+ virtual bool onLoad() override;
+ virtual std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override;
+ virtual void compactWorst() override;
+
+ void setTensor(DocId docId, std::unique_ptr<Tensor> tensor);
+};
+
+} // namespace search::tensor
diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp
new file mode 100644
index 00000000000..4e79315d4b1
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp
@@ -0,0 +1,62 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "direct_tensor_store.h"
+#include <vespa/eval/tensor/tensor.h>
+#include <vespa/vespalib/datastore/datastore.hpp>
+
+using vespalib::datastore::EntryRef;
+
+namespace search::tensor {
+
+constexpr size_t MIN_BUFFER_ARRAYS = 8192;
+
+DirectTensorStore::DirectTensorStore()
+ : TensorStore(_concrete_store),
+ _concrete_store(MIN_BUFFER_ARRAYS)
+{
+}
+
+const vespalib::tensor::Tensor*
+DirectTensorStore::get_tensor(EntryRef ref) const
+{
+ if (!ref.valid()) {
+ return nullptr;
+ }
+ auto entry = _concrete_store.getEntry(ref);
+ assert(entry);
+ return entry.get();
+}
+
+EntryRef
+DirectTensorStore::set_tensor(std::unique_ptr<Tensor> tensor)
+{
+ assert(tensor);
+ // TODO: Account for heap allocated memory
+ return _concrete_store.addEntry(TensorSP(tensor.release()));
+}
+
+void
+DirectTensorStore::holdTensor(EntryRef ref)
+{
+ if (!ref.valid()) {
+ return;
+ }
+ // TODO: Account for heap allocated memory
+ _concrete_store.holdElem(ref, 1);
+}
+
+EntryRef
+DirectTensorStore::move(EntryRef ref)
+{
+ if (!ref.valid()) {
+ return EntryRef();
+ }
+ auto old_tensor = _concrete_store.getEntry(ref);
+ assert(old_tensor);
+ // TODO: Account for heap allocated memory (regular + hold)
+ auto new_ref = _concrete_store.addEntry(old_tensor);
+ _concrete_store.holdElem(ref, 1);
+ return new_ref;
+}
+
+}
diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.h b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.h
new file mode 100644
index 00000000000..1073780a313
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.h
@@ -0,0 +1,34 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "tensor_store.h"
+#include <memory>
+
+namespace search::tensor {
+
+/**
+ * Class for storing heap allocated tensors, referenced by EntryRefs.
+ *
+ * Shared pointers to the tensors are stored in an underlying data store.
+ */
+class DirectTensorStore : public TensorStore {
+private:
+ // Note: Must use SP (instead of UP) because of fallbackCopy() and initializeReservedElements() in BufferType,
+ // and implementation of move().
+ using TensorSP = std::shared_ptr<Tensor>;
+ using DataStoreType = vespalib::datastore::DataStore<TensorSP>;
+
+ DataStoreType _concrete_store;
+
+public:
+ DirectTensorStore();
+
+ const Tensor* get_tensor(EntryRef ref) const;
+ EntryRef set_tensor(std::unique_ptr<Tensor> tensor);
+
+ void holdTensor(EntryRef ref) override;
+ EntryRef move(EntryRef ref) override;
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp
index aac199ae818..6864fb52120 100644
--- a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp
@@ -3,6 +3,7 @@
#include "generic_tensor_attribute.h"
#include "generic_tensor_attribute_saver.h"
#include "tensor_attribute.hpp"
+#include "blob_sequence_reader.h"
#include <vespa/eval/tensor/tensor.h>
#include <vespa/fastlib/io/bufferedfile.h>
#include <vespa/searchlib/attribute/readerbase.h>
@@ -18,19 +19,6 @@ namespace {
constexpr uint32_t TENSOR_ATTRIBUTE_VERSION = 0;
-class TensorReader : public ReaderBase
-{
-private:
- FileReader<uint32_t> _tensorSizeReader;
-public:
- TensorReader(AttributeVector &attr)
- : ReaderBase(attr),
- _tensorSizeReader(*_datFile)
- { }
- uint32_t getNextTensorSize() { return _tensorSizeReader.readHostOrder(); }
- void readTensor(void *buf, size_t len) { _datFile->ReadBuf(buf, len); }
-};
-
}
GenericTensorAttribute::GenericTensorAttribute(stringref name, const Config &cfg)
@@ -76,7 +64,7 @@ GenericTensorAttribute::getTensor(DocId, vespalib::tensor::MutableDenseTensorVie
bool
GenericTensorAttribute::onLoad()
{
- TensorReader tensorReader(*this);
+ BlobSequenceReader tensorReader(*this);
if (!tensorReader.hasData()) {
return false;
}
@@ -86,10 +74,10 @@ GenericTensorAttribute::onLoad()
_refVector.reset();
_refVector.unsafe_reserve(numDocs);
for (uint32_t lid = 0; lid < numDocs; ++lid) {
- uint32_t tensorSize = tensorReader.getNextTensorSize();
+ uint32_t tensorSize = tensorReader.getNextSize();
auto raw = _genericTensorStore.allocRawBuffer(tensorSize);
if (tensorSize != 0) {
- tensorReader.readTensor(raw.data, tensorSize);
+ tensorReader.readBlob(raw.data, tensorSize);
}
_refVector.push_back(raw.ref);
}
diff --git a/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp
index f19bef3ff21..8c695c32719 100644
--- a/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp
@@ -1,15 +1,14 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "generic_tensor_store.h"
+#include "tensor_deserialize.h"
#include <vespa/eval/tensor/tensor.h>
#include <vespa/eval/tensor/serialization/typed_binary_format.h>
-#include <vespa/document/util/serializableexceptions.h>
#include <vespa/vespalib/datastore/datastore.hpp>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/macro.h>
-using document::DeserializeException;
using vespalib::datastore::Handle;
using vespalib::tensor::Tensor;
using vespalib::tensor::TypedBinaryFormat;
@@ -95,14 +94,7 @@ GenericTensorStore::getTensor(EntryRef ref) const
if (raw.second == 0u) {
return std::unique_ptr<Tensor>();
}
- vespalib::nbostream wrapStream(raw.first, raw.second);
- auto tensor = TypedBinaryFormat::deserialize(wrapStream);
- if (wrapStream.size() != 0) {
- throw DeserializeException("Leftover bytes deserializing "
- "tensor attribute value.",
- VESPA_STRLOC);
- }
- return tensor;
+ return deserialize_tensor(raw.first, raw.second);
}
TensorStore::EntryRef
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp
new file mode 100644
index 00000000000..7998fba5941
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp
@@ -0,0 +1,24 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/document/util/serializableexceptions.h>
+#include <vespa/eval/tensor/serialization/typed_binary_format.h>
+#include <vespa/eval/tensor/tensor.h>
+#include <vespa/vespalib/objects/nbostream.h>
+
+using document::DeserializeException;
+using vespalib::tensor::Tensor;
+using vespalib::tensor::TypedBinaryFormat;
+
+namespace search::tensor {
+
+std::unique_ptr<Tensor> deserialize_tensor(const void *data, size_t size)
+{
+ vespalib::nbostream wrapStream(data, size);
+ auto tensor = TypedBinaryFormat::deserialize(wrapStream);
+ if (wrapStream.size() != 0) {
+ throw DeserializeException("Leftover bytes deserializing tensor attribute value.", VESPA_STRLOC);
+ }
+ return tensor;
+}
+
+} // namespace
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h
new file mode 100644
index 00000000000..f1dfa1ca173
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h
@@ -0,0 +1,10 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/tensor/tensor.h>
+
+namespace search::tensor {
+
+extern std::unique_ptr<vespalib::tensor::Tensor>
+deserialize_tensor(const void *data, size_t size);
+
+} // namespace
diff --git a/searchlib/src/vespa/searchlib/test/weightedchildrenverifiers.h b/searchlib/src/vespa/searchlib/test/weightedchildrenverifiers.h
index cabb108d2e1..2ef03ba97ef 100644
--- a/searchlib/src/vespa/searchlib/test/weightedchildrenverifiers.h
+++ b/searchlib/src/vespa/searchlib/test/weightedchildrenverifiers.h
@@ -63,7 +63,7 @@ public:
(void) strict;
std::vector<DocumentWeightIterator> children;
for (size_t i = 0; i < _num_children; ++i) {
- auto dict_entry = _helper.dwa().lookup(vespalib::make_string("%zu", i).c_str());
+ auto dict_entry = _helper.dwa().lookup(vespalib::make_string("%zu", i).c_str(), _helper.dwa().get_dictionary_snapshot());
_helper.dwa().create(dict_entry.posting_idx, children);
}
return create(std::move(children));
diff --git a/searchlib/src/vespa/searchlib/transactionlog/common.cpp b/searchlib/src/vespa/searchlib/transactionlog/common.cpp
index a5eaa61af12..ee7d265427c 100644
--- a/searchlib/src/vespa/searchlib/transactionlog/common.cpp
+++ b/searchlib/src/vespa/searchlib/transactionlog/common.cpp
@@ -22,7 +22,8 @@ int makeDirectory(const char * dir)
return retval;
}
-int64_t SerialNumRange::cmp(const SerialNumRange & b) const
+int64_t
+SerialNumRange::cmp(const SerialNumRange & b) const
{
int64_t diff(0);
if ( ! (contains(b) || b.contains(*this)) ) {
@@ -71,7 +72,8 @@ nbostream & Packet::Entry::deserialize(nbostream & os)
return os;
}
-nbostream & Packet::Entry::serialize(nbostream & os) const
+nbostream &
+Packet::Entry::serialize(nbostream & os) const
{
os << _unique << _type << static_cast<uint32_t>(_data.size());
os.write(_data.c_str(), _data.size());
diff --git a/searchlib/src/vespa/searchlib/transactionlog/domain.cpp b/searchlib/src/vespa/searchlib/transactionlog/domain.cpp
index 5a64d829183..5e7cfc74199 100644
--- a/searchlib/src/vespa/searchlib/transactionlog/domain.cpp
+++ b/searchlib/src/vespa/searchlib/transactionlog/domain.cpp
@@ -131,7 +131,7 @@ Domain::begin(const LockGuard & guard) const
assert(guard.locks(_lock));
SerialNum s(0);
if ( ! _parts.empty() ) {
- s = _parts.begin()->second->range().from();
+ s = _parts.cbegin()->second->range().from();
}
return s;
}
@@ -149,7 +149,7 @@ Domain::end(const LockGuard & guard) const
assert(guard.locks(_lock));
SerialNum s(0);
if ( ! _parts.empty() ) {
- s = _parts.rbegin()->second->range().to();
+ s = _parts.crbegin()->second->range().to();
}
return s;
}
@@ -203,7 +203,8 @@ Domain::triggerSyncNow()
}
}
-DomainPart::SP Domain::findPart(SerialNum s)
+DomainPart::SP
+Domain::findPart(SerialNum s)
{
LockGuard guard(_lock);
DomainPartList::iterator it(_parts.upper_bound(s));
@@ -220,12 +221,14 @@ DomainPart::SP Domain::findPart(SerialNum s)
return DomainPart::SP();
}
-uint64_t Domain::size() const
+uint64_t
+Domain::size() const
{
return size(LockGuard(_lock));
}
-uint64_t Domain::size(const LockGuard & guard) const
+uint64_t
+Domain::size(const LockGuard & guard) const
{
(void) guard;
assert(guard.locks(_lock));
@@ -236,7 +239,8 @@ uint64_t Domain::size(const LockGuard & guard) const
return sz;
}
-SerialNum Domain::findOldestActiveVisit() const
+SerialNum
+Domain::findOldestActiveVisit() const
{
SerialNum oldestActive(std::numeric_limits<SerialNum>::max());
LockGuard guard(_sessionLock);
@@ -249,7 +253,8 @@ SerialNum Domain::findOldestActiveVisit() const
return oldestActive;
}
-void Domain::cleanSessions()
+void
+Domain::cleanSessions()
{
if ( _sessions.empty()) {
return;
@@ -269,7 +274,8 @@ void Domain::cleanSessions()
namespace {
-void waitPendingSync(vespalib::Monitor &syncMonitor, bool &pendingSync)
+void
+waitPendingSync(vespalib::Monitor &syncMonitor, bool &pendingSync)
{
MonitorGuard guard(syncMonitor);
while (pendingSync) {
@@ -302,7 +308,8 @@ void Domain::commit(const Packet & packet)
cleanSessions();
}
-bool Domain::erase(SerialNum to)
+bool
+Domain::erase(SerialNum to)
{
bool retval(true);
/// Do not erase the last element
@@ -321,8 +328,9 @@ bool Domain::erase(SerialNum to)
return retval;
}
-int Domain::visit(const Domain::SP & domain, SerialNum from, SerialNum to,
- std::unique_ptr<Session::Destination> dest)
+int
+Domain::visit(const Domain::SP & domain, SerialNum from, SerialNum to,
+ std::unique_ptr<Session::Destination> dest)
{
assert(this == domain.get());
cleanSessions();
@@ -334,7 +342,8 @@ int Domain::visit(const Domain::SP & domain, SerialNum from, SerialNum to,
return id;
}
-int Domain::startSession(int sessionId)
+int
+Domain::startSession(int sessionId)
{
int retval(-1);
LockGuard guard(_sessionLock);
@@ -350,7 +359,8 @@ int Domain::startSession(int sessionId)
return retval;
}
-int Domain::closeSession(int sessionId)
+int
+Domain::closeSession(int sessionId)
{
_commitExecutor.sync();
int retval(-1);
diff --git a/searchlib/src/vespa/searchlib/transactionlog/session.cpp b/searchlib/src/vespa/searchlib/transactionlog/session.cpp
index e703c32484f..dda840808ce 100644
--- a/searchlib/src/vespa/searchlib/transactionlog/session.cpp
+++ b/searchlib/src/vespa/searchlib/transactionlog/session.cpp
@@ -31,7 +31,7 @@ Session::VisitTask::run()
bool
Session::visit(FastOS_FileInterface & file, DomainPart & dp) {
- Packet packet;
+ Packet packet(size_t(-1));
bool more(false);
if (dp.isClosed()) {
more = dp.visit(file, _range, packet);
diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp
index a3528c4f615..caef792704a 100644
--- a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp
+++ b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp
@@ -18,6 +18,9 @@ using vespalib::make_string;
using vespalib::stringref;
using vespalib::IllegalArgumentException;
using search::common::FileHeaderContext;
+using std::make_shared;
+using std::runtime_error;
+using namespace std::chrono_literals;
namespace search::transactionlog {
@@ -31,10 +34,10 @@ class SyncHandler : public FNET_Task
SerialNum _syncTo;
public:
- SyncHandler(FRT_Supervisor *supervisor, FRT_RPCRequest *req,const Domain::SP &domain,
+ SyncHandler(FRT_Supervisor *supervisor, FRT_RPCRequest *req, const Domain::SP &domain,
const TransLogServer::Session::SP &session, SerialNum syncTo);
- ~SyncHandler();
+ ~SyncHandler() override;
void PerformTask() override;
};
@@ -157,17 +160,17 @@ bool
TransLogServer::onStop()
{
LOG(info, "Stopping TLS");
- _reqQ.push(NULL);
+ _reqQ.push(nullptr);
return true;
}
void
TransLogServer::run()
{
- FRT_RPCRequest *req(NULL);
+ FRT_RPCRequest *req(nullptr);
bool hasPacket(false);
do {
- for (req = NULL; (hasPacket = _reqQ.pop(req, 60000)) && (req != NULL); req = NULL) {
+ for (req = nullptr; (hasPacket = _reqQ.pop(req, 60000)) && (req != nullptr); req = nullptr) {
bool immediate = true;
if (strcmp(req->GetMethodName(), "domainSessionClose") == 0) {
domainSessionClose(req);
@@ -675,7 +678,7 @@ TransLogServer::finiSession(FRT_RPCRequest *req)
{
FNET_Connection *conn = req->GetConnection();
void *vctx = conn->GetContext()._value.VOIDP;
- conn->GetContextPT()->_value.VOIDP = NULL;
+ conn->GetContextPT()->_value.VOIDP = nullptr;
Session::SP *sessionspp = static_cast<Session::SP *>(vctx);
delete sessionspp;
}
@@ -696,7 +699,7 @@ TransLogServer::domainSync(FRT_RPCRequest *req)
Domain::SP domain(findDomain(domainName));
Session::SP session(getSession(req));
- if (domain.get() == nullptr) {
+ if ( ! domain) {
FRT_Values &rvals = *req->GetReturn();
rvals.AddInt32(0);
rvals.AddInt64(0);
diff --git a/staging_vespalib/src/vespa/vespalib/net/generic_state_handler.cpp b/staging_vespalib/src/vespa/vespalib/net/generic_state_handler.cpp
index a18c032fdff..b8fc432e95b 100644
--- a/staging_vespalib/src/vespa/vespalib/net/generic_state_handler.cpp
+++ b/staging_vespalib/src/vespa/vespalib/net/generic_state_handler.cpp
@@ -2,6 +2,7 @@
#include "generic_state_handler.h"
#include <vespa/vespalib/data/slime/slime.h>
+#include <vespa/vespalib/data/simple_buffer.h>
namespace vespalib {
diff --git a/storage/src/tests/storageserver/statereportertest.cpp b/storage/src/tests/storageserver/statereportertest.cpp
index dc8094275d1..a7d18b21516 100644
--- a/storage/src/tests/storageserver/statereportertest.cpp
+++ b/storage/src/tests/storageserver/statereportertest.cpp
@@ -10,8 +10,8 @@
#include <tests/common/dummystoragelink.h>
#include <vespa/config/common/exceptions.h>
#include <vespa/vespalib/data/slime/slime.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <vespa/vespalib/gtest/gtest.h>
-#include <vespa/vespalib/util/time.h>
#include <thread>
#include <vespa/log/log.h>
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java
index 4d50905da7b..62c15fcea27 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java
@@ -5,6 +5,7 @@ package com.yahoo.vespa.http.client;
import com.yahoo.vespa.http.client.config.SessionParams;
import com.yahoo.vespa.http.client.core.api.FeedClientImpl;
+import java.time.Clock;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
@@ -24,7 +25,7 @@ public class FeedClientFactory {
* @return newly created FeedClient API object.
*/
public static FeedClient create(SessionParams sessionParams, FeedClient.ResultCallback resultCallback) {
- return new FeedClientImpl(sessionParams, resultCallback, createTimeoutExecutor());
+ return new FeedClientImpl(sessionParams, resultCallback, createTimeoutExecutor(), Clock.systemUTC());
}
static ScheduledThreadPoolExecutor createTimeoutExecutor() {
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java
index 473b9494ba4..16374ec07cc 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java
@@ -78,7 +78,6 @@ public class Result {
private final Endpoint endpoint;
private final Exception exception;
private final String traceMessage;
- private final long timeStampMillis = System.currentTimeMillis();
public Detail(Endpoint endpoint, ResultType resultType, String traceMessage, Exception e) {
this.endpoint = endpoint;
@@ -133,7 +132,6 @@ public class Result {
b.append(" trace='").append(traceMessage).append("'");
if (endpoint != null)
b.append(" endpoint=").append(endpoint);
- b.append(" resultTimeLocally=").append(timeStampMillis).append("\n");
return b.toString();
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/SessionFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/SessionFactory.java
index b03a2541cd0..b7423f75c87 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/SessionFactory.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/SessionFactory.java
@@ -5,6 +5,7 @@ import com.yahoo.vespa.http.client.config.Cluster;
import com.yahoo.vespa.http.client.config.Endpoint;
import com.yahoo.vespa.http.client.config.SessionParams;
+import java.time.Clock;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
@@ -30,7 +31,7 @@ public final class SessionFactory {
@SuppressWarnings("deprecation")
static Session createInternal(SessionParams params) {
- return new com.yahoo.vespa.http.client.core.api.SessionImpl(params, createTimeoutExecutor());
+ return new com.yahoo.vespa.http.client.core.api.SessionImpl(params, createTimeoutExecutor(), Clock.systemUTC());
}
static ScheduledThreadPoolExecutor createTimeoutExecutor() {
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java
index 1accbd51ac7..2417a4acf71 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java
@@ -42,6 +42,7 @@ public final class ConnectionParams {
private int maxRetries = 100;
private long minTimeBetweenRetriesMs = 700;
private boolean dryRun = false;
+ private boolean runThreads = true;
private int traceLevel = 0;
private int traceEveryXOperation = 0;
private boolean printTraceToStdErr = true;
@@ -191,10 +192,8 @@ public final class ConnectionParams {
}
/**
- * Don't send data to gateway, just pretend that everything is fine.
- *
- * @param dryRun true if enabled.
- * @return pointer to builder.
+ * Set to true to skip making network connections and instead
+ * let requests complete successfully with no effect.
*/
public Builder setDryRun(boolean dryRun) {
this.dryRun = dryRun;
@@ -202,6 +201,15 @@ public final class ConnectionParams {
}
/**
+ * Set to false to skip starting io threads, such that any operation must be driven by a calling thread.
+ * Useful for testing.
+ */
+ public Builder setRunThreads(boolean runThreads) {
+ this.runThreads = runThreads;
+ return this;
+ }
+
+ /**
* Set the min time between retries when temporarily failing against a gateway.
*
* @param minTimeBetweenRetries the min time value
@@ -274,6 +282,7 @@ public final class ConnectionParams {
maxRetries,
minTimeBetweenRetriesMs,
dryRun,
+ runThreads,
traceLevel,
traceEveryXOperation,
printTraceToStdErr,
@@ -293,6 +302,8 @@ public final class ConnectionParams {
return dryRun;
}
+ public boolean runThreads() { return runThreads; }
+
public int getMaxRetries() {
return maxRetries;
}
@@ -330,6 +341,7 @@ public final class ConnectionParams {
public Path getCertificate() { return certificate; }
public Path getCaCertificates() { return caCertificates; }
}
+
private final SSLContext sslContext;
private final Path privateKey;
private final Path certificate;
@@ -344,6 +356,7 @@ public final class ConnectionParams {
private final int maxRetries;
private final long minTimeBetweenRetriesMs;
private final boolean dryRun;
+ private final boolean runThreads;
private final int traceLevel;
private final int traceEveryXOperation;
private final boolean printTraceToStdErr;
@@ -363,6 +376,7 @@ public final class ConnectionParams {
int maxRetries,
long minTimeBetweenRetriesMs,
boolean dryRun,
+ boolean runThreads,
int traceLevel,
int traceEveryXOperation,
boolean printTraceToStdErr,
@@ -384,6 +398,7 @@ public final class ConnectionParams {
this.maxRetries = maxRetries;
this.minTimeBetweenRetriesMs = minTimeBetweenRetriesMs;
this.dryRun = dryRun;
+ this.runThreads = runThreads;
this.traceLevel = traceLevel;
this.traceEveryXOperation = traceEveryXOperation;
this.printTraceToStdErr = printTraceToStdErr;
@@ -435,6 +450,8 @@ public final class ConnectionParams {
return dryRun;
}
+ public boolean runThreads() { return runThreads; }
+
public int getTraceLevel() {
return traceLevel;
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/FeedParams.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/FeedParams.java
index d623db3834c..200bedb90da 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/FeedParams.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/FeedParams.java
@@ -172,6 +172,10 @@ public final class FeedParams {
return this;
}
+ /**
+ * Sets the number of milliseconds until we respond with a timeout for a document operation
+ * if we still have not received a response.
+ */
public Builder setLocalQueueTimeOut(long timeOutMs) {
this.localQueueTimeOut = timeOutMs;
return this;
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/SessionParams.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/SessionParams.java
index 3131206f148..bf07e3ea634 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/SessionParams.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/SessionParams.java
@@ -141,13 +141,12 @@ public final class SessionParams {
private final ErrorReporter errorReport;
private int throttlerMinSize;
- private SessionParams(
- Collection<Cluster> clusters,
- FeedParams feedParams,
- ConnectionParams connectionParams,
- int clientQueueSize,
- ErrorReporter errorReporter,
- int throttlerMinSize) {
+ private SessionParams(Collection<Cluster> clusters,
+ FeedParams feedParams,
+ ConnectionParams connectionParams,
+ int clientQueueSize,
+ ErrorReporter errorReporter,
+ int throttlerMinSize) {
this.clusters = Collections.unmodifiableList(new ArrayList<>(clusters));
this.feedParams = feedParams;
this.connectionParams = connectionParams;
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/Document.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/Document.java
index bc38155d07a..98fd2f9da84 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/Document.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/Document.java
@@ -7,53 +7,53 @@ import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.StandardCharsets;
+import java.time.Instant;
+import java.util.Objects;
import java.util.concurrent.ThreadLocalRandom;
/**
+ * A document operation
+ *
* @author Einar M R Rosenvinge
*/
final public class Document {
private final String documentId;
private final ByteBuffer data;
- private final long createTimeMillis = System.currentTimeMillis();
- // This is initialized lazily to reduce work on calling thread (which is the thread calling the API).
+ private final Instant createTime;
+ // This is initialized lazily to reduce work on calling thread (which is the thread calling the API)
private String operationId = null;
private final Object context;
- private long queueInsertTimestampMillis;
+ private Instant queueInsertTime;
- public Document(String documentId, byte[] data, Object context) {
- this.documentId = documentId;
- this.context = context;
- this.data = ByteBuffer.wrap(data);
+ public Document(String documentId, byte[] data, Object context, Instant createTime) {
+ this(documentId, null, ByteBuffer.wrap(data), context, createTime);
}
- public Document(String documentId, String operationId, CharSequence data, Object context) {
+ public Document(String documentId, String operationId, CharSequence data, Object context, Instant createTime) {
+ this(documentId, operationId, encode(data, documentId), context, createTime);
+ }
+
+ private Document(String documentId, String operationId, ByteBuffer data, Object context, Instant createTime) {
this.documentId = documentId;
this.operationId = operationId;
+ this.data = data;
this.context = context;
- try {
- this.data = StandardCharsets.UTF_8.newEncoder().encode(CharBuffer.wrap(data));
- } catch (CharacterCodingException e) {
- throw new RuntimeException("Error encoding document data into UTF8 " + documentId, e);
- }
+ this.createTime = Objects.requireNonNull(createTime, "createTime cannot be null");
+ this.queueInsertTime = createTime;
}
- public void resetQueueTime() {
- queueInsertTimestampMillis = System.currentTimeMillis();
+ public void setQueueInsertTime(Instant queueInsertTime) {
+ this.queueInsertTime = queueInsertTime;
}
- public long timeInQueueMillis() {
- return System.currentTimeMillis() - queueInsertTimestampMillis;
- }
+ public Instant getQueueInsertTime() { return queueInsertTime; }
public CharSequence getDataAsString() {
return StandardCharsets.UTF_8.decode(data.asReadOnlyBuffer());
}
- public Object getContext() {
- return context;
- }
+ public Object getContext() { return context; }
public static class DocumentException extends IOException {
private static final long serialVersionUID = 29832833292L;
@@ -63,9 +63,7 @@ final public class Document {
}
}
- public String getDocumentId() {
- return documentId;
- }
+ public String getDocumentId() { return documentId; }
public ByteBuffer getData() {
return data.asReadOnlyBuffer();
@@ -75,9 +73,7 @@ final public class Document {
return data.remaining();
}
- public long createTimeMillis() {
- return createTimeMillis;
- }
+ public Instant createTime() { return createTime; }
public String getOperationId() {
if (operationId == null) {
@@ -89,4 +85,12 @@ final public class Document {
@Override
public String toString() { return "document '" + documentId + "'"; }
+ private static ByteBuffer encode(CharSequence data, String documentId) {
+ try {
+ return StandardCharsets.UTF_8.newEncoder().encode(CharBuffer.wrap(data));
+ } catch (CharacterCodingException e) {
+ throw new RuntimeException("Error encoding document data into UTF8 " + documentId, e);
+ }
+ }
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java
index 7238a0c4ba7..a950cb545de 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java
@@ -11,6 +11,7 @@ import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor;
import java.nio.charset.CharsetEncoder;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
+import java.time.Clock;
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.ScheduledThreadPoolExecutor;
@@ -23,25 +24,28 @@ import java.util.concurrent.TimeUnit;
*/
public class FeedClientImpl implements FeedClient {
+ private final Clock clock;
private final OperationProcessor operationProcessor;
private final long closeTimeoutMs;
private final long sleepTimeMs = 500;
public FeedClientImpl(SessionParams sessionParams,
ResultCallback resultCallback,
- ScheduledThreadPoolExecutor timeoutExecutor) {
- this.closeTimeoutMs = (10 + 3 * sessionParams.getConnectionParams().getMaxRetries()) * (
- sessionParams.getFeedParams().getServerTimeout(TimeUnit.MILLISECONDS) +
- sessionParams.getFeedParams().getClientTimeout(TimeUnit.MILLISECONDS));
+ ScheduledThreadPoolExecutor timeoutExecutor,
+ Clock clock) {
+ this.clock = clock;
+ this.closeTimeoutMs = (10 + 3 * sessionParams.getConnectionParams().getMaxRetries()) *
+ (sessionParams.getFeedParams().getServerTimeout(TimeUnit.MILLISECONDS) +
+ sessionParams.getFeedParams().getClientTimeout(TimeUnit.MILLISECONDS));
this.operationProcessor = new OperationProcessor(
- new IncompleteResultsThrottler(
- sessionParams.getThrottlerMinSize(),
- sessionParams.getClientQueueSize(),
- ()->System.currentTimeMillis(),
- new ThrottlePolicy()),
+ new IncompleteResultsThrottler(sessionParams.getThrottlerMinSize(),
+ sessionParams.getClientQueueSize(),
+ clock,
+ new ThrottlePolicy()),
resultCallback,
sessionParams,
- timeoutExecutor);
+ timeoutExecutor,
+ clock);
}
@Override
@@ -50,7 +54,7 @@ public class FeedClientImpl implements FeedClient {
charsetEncoder.onMalformedInput(CodingErrorAction.REPORT);
charsetEncoder.onUnmappableCharacter(CodingErrorAction.REPORT);
- Document document = new Document(documentId, operationId, documentData, context);
+ Document document = new Document(documentId, operationId, documentData, context, clock.instant());
operationProcessor.sendDocument(document);
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/MultiClusterSessionOutputStream.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/MultiClusterSessionOutputStream.java
index bf55a46277d..e09cecf7161 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/MultiClusterSessionOutputStream.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/MultiClusterSessionOutputStream.java
@@ -6,6 +6,7 @@ import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
+import java.time.Clock;
/**
* Class for wiring up the Session API. It is the return value of stream() in the Session API.
@@ -17,19 +18,21 @@ class MultiClusterSessionOutputStream extends ByteArrayOutputStream {
private final CharSequence documentId;
private final OperationProcessor operationProcessor;
private final Object context;
+ private final Clock clock;
- public MultiClusterSessionOutputStream(
- CharSequence documentId,
- OperationProcessor operationProcessor,
- Object context) {
+ public MultiClusterSessionOutputStream(CharSequence documentId,
+ OperationProcessor operationProcessor,
+ Object context,
+ Clock clock) {
this.documentId = documentId;
this.context = context;
this.operationProcessor = operationProcessor;
+ this.clock = clock;
}
@Override
public void close() throws IOException {
- Document document = new Document(documentId.toString(), toByteArray(), context);
+ Document document = new Document(documentId.toString(), toByteArray(), context, clock.instant());
operationProcessor.sendDocument(document);
super.close();
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java
index a5c97351347..a68d7eb7524 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java
@@ -9,6 +9,7 @@ import com.yahoo.vespa.http.client.core.operationProcessor.IncompleteResultsThro
import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor;
import java.io.OutputStream;
+import java.time.Clock;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledThreadPoolExecutor;
@@ -23,14 +24,15 @@ public class SessionImpl implements com.yahoo.vespa.http.client.Session {
private final OperationProcessor operationProcessor;
private final BlockingQueue<Result> resultQueue = new LinkedBlockingQueue<>();
+ private final Clock clock;
-
- public SessionImpl(SessionParams sessionParams, ScheduledThreadPoolExecutor timeoutExecutor) {
+ public SessionImpl(SessionParams sessionParams, ScheduledThreadPoolExecutor timeoutExecutor, Clock clock) {
+ this.clock = clock;
this.operationProcessor = new OperationProcessor(
new IncompleteResultsThrottler(
sessionParams.getThrottlerMinSize(),
sessionParams.getClientQueueSize(),
- ()->System.currentTimeMillis(),
+ clock,
new ThrottlePolicy()),
new FeedClient.ResultCallback() {
@Override
@@ -39,12 +41,13 @@ public class SessionImpl implements com.yahoo.vespa.http.client.Session {
}
},
sessionParams,
- timeoutExecutor);
+ timeoutExecutor,
+ clock);
}
@Override
public OutputStream stream(CharSequence documentId) {
- return new MultiClusterSessionOutputStream(documentId, operationProcessor, null);
+ return new MultiClusterSessionOutputStream(documentId, operationProcessor, null, clock);
}
@Override
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java
index d510ce4b7ea..a46b2e67fe1 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java
@@ -20,6 +20,7 @@ import org.apache.http.client.HttpClient;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.InputStreamEntity;
+import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.message.BasicHeader;
@@ -28,10 +29,10 @@ import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
-import java.net.InetAddress;
-import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
+import java.time.Clock;
+import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -49,84 +50,89 @@ import java.util.zip.GZIPOutputStream;
*/
class ApacheGatewayConnection implements GatewayConnection {
- private static Logger log = Logger.getLogger(ApacheGatewayConnection.class.getName());
+ private static final Logger log = Logger.getLogger(ApacheGatewayConnection.class.getName());
private static final ObjectMapper mapper = new ObjectMapper();
private static final String PATH = "/reserved-for-internal-use/feedapi?";
- private final List<Integer> SUPPORTED_VERSIONS = new ArrayList<>();
private static final byte[] START_OF_FEED_XML = "<vespafeed>\n".getBytes(StandardCharsets.UTF_8);
private static final byte[] END_OF_FEED_XML = "\n</vespafeed>\n".getBytes(StandardCharsets.UTF_8);
private static final byte[] START_OF_FEED_JSON = "[".getBytes(StandardCharsets.UTF_8);
private static final byte[] END_OF_FEED_JSON = "]".getBytes(StandardCharsets.UTF_8);
+
+ private final List<Integer> supportedVersions = new ArrayList<>();
private final byte[] startOfFeed;
private final byte[] endOfFeed;
private final Endpoint endpoint;
private final FeedParams feedParams;
private final String clusterSpecificRoute;
private final ConnectionParams connectionParams;
- private HttpClient httpClient;
+ private CloseableHttpClient httpClient;
+ private Instant connectionTime = null;
+ private Instant lastPollTime = null;
private String sessionId;
private final String clientId;
private int negotiatedVersion = -1;
private final HttpClientFactory httpClientFactory;
private final String shardingKey = UUID.randomUUID().toString().substring(0, 5);
-
- ApacheGatewayConnection(
- Endpoint endpoint,
- FeedParams feedParams,
- String clusterSpecificRoute,
- ConnectionParams connectionParams,
- HttpClientFactory httpClientFactory,
- String clientId) {
- SUPPORTED_VERSIONS.add(3);
- this.endpoint = validate(endpoint);
+ private final Clock clock;
+
+ ApacheGatewayConnection(Endpoint endpoint,
+ FeedParams feedParams,
+ String clusterSpecificRoute,
+ ConnectionParams connectionParams,
+ HttpClientFactory httpClientFactory,
+ String clientId,
+ Clock clock) {
+ supportedVersions.add(3);
+ this.endpoint = endpoint;
this.feedParams = feedParams;
this.clusterSpecificRoute = clusterSpecificRoute;
this.httpClientFactory = httpClientFactory;
this.connectionParams = connectionParams;
this.httpClient = null;
- boolean isJson = feedParams.getDataFormat() == FeedParams.DataFormat.JSON_UTF8;
- if (isJson) {
+ this.clientId = clientId;
+ this.clock = clock;
+
+ if (feedParams.getDataFormat() == FeedParams.DataFormat.JSON_UTF8) {
startOfFeed = START_OF_FEED_JSON;
endOfFeed = END_OF_FEED_JSON;
} else {
startOfFeed = START_OF_FEED_XML;
endOfFeed = END_OF_FEED_XML;
}
- this.clientId = clientId;
- if (this.clientId == null)
- throw new IllegalArgumentException("Got no client Id.");
}
- private static Endpoint validate(Endpoint endpoint) {
- try {
- InetAddress.getByName(endpoint.getHostname());
- return endpoint;
- }
- catch (UnknownHostException e) {
- throw new IllegalArgumentException("Unknown host: " + endpoint);
- }
+ @Override
+ public InputStream write(List<Document> docs) throws ServerResponseException, IOException {
+ return write(docs, false, connectionParams.getUseCompression());
}
@Override
- public InputStream writeOperations(List<Document> docs) throws ServerResponseException, IOException {
- return write(docs, false, connectionParams.getUseCompression());
+ public InputStream poll() throws ServerResponseException, IOException {
+ lastPollTime = clock.instant();
+ return write(Collections.<Document>emptyList(), false, false);
}
@Override
+ public Instant lastPollTime() { return lastPollTime; }
+
+ @Override
public InputStream drain() throws ServerResponseException, IOException {
- return write(Collections.<Document>emptyList(), true /* drain */, false /* use compression */);
+ return write(Collections.<Document>emptyList(), true, false);
}
@Override
public boolean connect() {
- log.fine("Attempting to connect to " + endpoint);
- if (httpClient != null) {
+ log.fine(() -> "Attempting to connect to " + endpoint);
+ if (httpClient != null)
log.log(Level.WARNING, "Previous httpClient still exists.");
- }
httpClient = httpClientFactory.createClient();
+ connectionTime = clock.instant();
return httpClient != null;
}
+ @Override
+ public Instant connectionTime() { return connectionTime; }
+
// Protected for easier testing only.
protected static InputStreamEntity zipAndCreateEntity(final InputStream inputStream) throws IOException {
byte[] buffer = new byte[4096];
@@ -184,7 +190,7 @@ class ApacheGatewayConnection implements GatewayConnection {
private HttpPost createPost(boolean drain, boolean useCompression, boolean isHandshake) {
HttpPost httpPost = new HttpPost(createUri());
- for (int v : SUPPORTED_VERSIONS) {
+ for (int v : supportedVersions) {
httpPost.addHeader(Headers.VERSION, "" + v);
}
if (sessionId != null) {
@@ -194,11 +200,7 @@ class ApacheGatewayConnection implements GatewayConnection {
httpPost.setHeader(Headers.CLIENT_ID, clientId);
}
httpPost.setHeader(Headers.SHARDING_KEY, shardingKey);
- if (drain) {
- httpPost.setHeader(Headers.DRAIN, "true");
- } else {
- httpPost.setHeader(Headers.DRAIN, "false");
- }
+ httpPost.setHeader(Headers.DRAIN, drain ? "true" : "false");
if (clusterSpecificRoute != null) {
httpPost.setHeader(Headers.ROUTE, feedParams.getRoute());
} else {
@@ -246,13 +248,9 @@ class ApacheGatewayConnection implements GatewayConnection {
private InputStream executePost(HttpPost httpPost) throws ServerResponseException, IOException {
HttpResponse response;
try {
- if (httpClient == null) {
+ if (httpClient == null)
throw new IOException("Trying to executePost while not having a connection/http client");
- }
response = httpClient.execute(httpPost);
- } catch (IOException e) {
- httpPost.abort();
- throw e;
} catch (Exception e) {
httpPost.abort();
throw e;
@@ -270,18 +268,14 @@ class ApacheGatewayConnection implements GatewayConnection {
private void verifyServerResponseCode(HttpResponse response) throws ServerResponseException {
StatusLine statusLine = response.getStatusLine();
+ int statusCode = statusLine.getStatusCode();
+
// We use code 261-299 to report errors related to internal transitive errors that the tenants should not care
// about to avoid masking more serious errors.
- int statusCode = statusLine.getStatusCode();
- if (statusCode > 199 && statusCode < 260) {
- return;
- }
- if (statusCode == 299) {
- throw new ServerResponseException(429, "Too many requests.");
- }
- String message = tryGetDetailedErrorMessage(response)
- .orElseGet(statusLine::getReasonPhrase);
- throw new ServerResponseException(statusLine.getStatusCode(), message);
+ if (statusCode > 199 && statusCode < 260) return;
+ if (statusCode == 299) throw new ServerResponseException(429, "Too many requests.");
+ throw new ServerResponseException(statusCode,
+ tryGetDetailedErrorMessage(response).orElseGet(statusLine::getReasonPhrase));
}
private static Optional<String> tryGetDetailedErrorMessage(HttpResponse response) {
@@ -305,7 +299,7 @@ class ApacheGatewayConnection implements GatewayConnection {
if (negotiatedVersion == 3) {
if (clientId == null || !clientId.equals(serverHeaderVal)) {
String message = "Running using v3. However, server responds with different session " +
- "than client has set; " + serverHeaderVal + " vs client code " + clientId;
+ "than client has set; " + serverHeaderVal + " vs client code " + clientId;
log.severe(message);
throw new ServerResponseException(message);
}
@@ -314,14 +308,12 @@ class ApacheGatewayConnection implements GatewayConnection {
if (sessionId == null) { //this must be the first request
log.finer("Got session ID from server: " + serverHeaderVal);
this.sessionId = serverHeaderVal;
- return;
} else {
if (!sessionId.equals(serverHeaderVal)) {
- log.info("Request has been routed to a server which does not recognize the client session."
- + " Most likely cause is upgrading of cluster, transitive error.");
- throw new ServerResponseException(
- "Session ID received from server ('" + serverHeaderVal
- + "') does not match cached session ID ('" + sessionId + "')");
+ log.info("Request has been routed to a server which does not recognize the client session." +
+ " Most likely cause is upgrading of cluster, transitive error.");
+ throw new ServerResponseException("Session ID received from server ('" + serverHeaderVal +
+ "') does not match cached session ID ('" + sessionId + "')");
}
}
}
@@ -336,9 +328,9 @@ class ApacheGatewayConnection implements GatewayConnection {
} catch (NumberFormatException nfe) {
throw new ServerResponseException("Got bad protocol version from server: " + nfe.getMessage());
}
- if (!SUPPORTED_VERSIONS.contains(serverVersion)) {
+ if (!supportedVersions.contains(serverVersion)) {
throw new ServerResponseException("Unsupported version: " + serverVersion
- + ". Supported versions: " + SUPPORTED_VERSIONS);
+ + ". Supported versions: " + supportedVersions);
}
if (negotiatedVersion == -1) {
if (log.isLoggable(Level.FINE)) {
@@ -387,6 +379,13 @@ class ApacheGatewayConnection implements GatewayConnection {
@Override
public void close() {
+ try {
+ if (httpClient != null)
+ httpClient.close();
+ }
+ catch (IOException e) {
+ log.log(Level.WARNING, "Failed closing HTTP client", e);
+ }
httpClient = null;
}
@@ -403,7 +402,7 @@ class ApacheGatewayConnection implements GatewayConnection {
this.useSsl = useSsl;
}
- public HttpClient createClient() {
+ public CloseableHttpClient createClient() {
HttpClientBuilder clientBuilder;
if (connectionParams.useTlsConfigFromEnvironment()) {
clientBuilder = VespaHttpClientBuilder.create();
@@ -428,12 +427,9 @@ class ApacheGatewayConnection implements GatewayConnection {
}
clientBuilder.setMaxConnPerRoute(1);
clientBuilder.setMaxConnTotal(1);
- clientBuilder.setConnectionTimeToLive(connectionParams.getConnectionTimeToLive().getSeconds(), TimeUnit.SECONDS);
clientBuilder.setUserAgent(String.format("vespa-http-client (%s)", Vtag.currentVersion));
clientBuilder.setDefaultHeaders(Collections.singletonList(new BasicHeader(Headers.CLIENT_VERSION, Vtag.currentVersion)));
clientBuilder.disableContentCompression();
- // Try to disable the disabling to see if system tests become stable again.
- // clientBuilder.disableAutomaticRetries();
RequestConfig.Builder requestConfigBuilder = RequestConfig.custom();
requestConfigBuilder.setSocketTimeout(0);
if (connectionParams.getProxyHost() != null) {
@@ -441,17 +437,16 @@ class ApacheGatewayConnection implements GatewayConnection {
}
clientBuilder.setDefaultRequestConfig(requestConfigBuilder.build());
- log.fine("Creating HttpClient: " + " ConnectionTimeout "
- + " SocketTimeout 0 secs "
- + " proxyhost (can be null) " + connectionParams.getProxyHost()
- + ":" + connectionParams.getProxyPort()
+ log.fine(() -> "Creating HttpClient:" +
+ " ConnectionTimeout " + connectionParams.getConnectionTimeToLive().getSeconds() + " seconds" +
+ " proxyhost (can be null) " + connectionParams.getProxyHost() + ":" + connectionParams.getProxyPort()
+ (useSsl ? " using ssl " : " not using ssl")
);
return clientBuilder.build();
}
}
- // Note: Using deprecated setSslcontext() to allow httpclient 4.4 on classpath (e.g unexpected Maven dependency resolution for test classpath)
+ // Note: Using deprecated setSslContext() to allow httpclient 4.4 on classpath (e.g unexpected Maven dependency resolution for test classpath)
@SuppressWarnings("deprecation")
private static void setSslContext(HttpClientBuilder builder, SSLContext sslContext) {
builder.setSslcontext(sslContext);
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionFactory.java
new file mode 100644
index 00000000000..31ec8aa06a2
--- /dev/null
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionFactory.java
@@ -0,0 +1,63 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.http.client.core.communication;
+
+import com.yahoo.vespa.http.client.config.ConnectionParams;
+import com.yahoo.vespa.http.client.config.Endpoint;
+import com.yahoo.vespa.http.client.config.FeedParams;
+
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+import java.time.Clock;
+import java.util.Objects;
+
+/**
+ * @author bratseth
+ */
+public class ApacheGatewayConnectionFactory implements GatewayConnectionFactory {
+
+ private final Endpoint endpoint;
+ private final FeedParams feedParams;
+ private final String clusterSpecificRoute;
+ private final ConnectionParams connectionParams;
+ private final ApacheGatewayConnection.HttpClientFactory httpClientFactory;
+ private final String clientId;
+ private final Clock clock;
+
+ public ApacheGatewayConnectionFactory(Endpoint endpoint,
+ FeedParams feedParams,
+ String clusterSpecificRoute,
+ ConnectionParams connectionParams,
+ ApacheGatewayConnection.HttpClientFactory httpClientFactory,
+ String clientId,
+ Clock clock) {
+ this.endpoint = validate(endpoint);
+ this.feedParams = feedParams;
+ this.clusterSpecificRoute = clusterSpecificRoute;
+ this.httpClientFactory = httpClientFactory;
+ this.connectionParams = connectionParams;
+ this.clientId = Objects.requireNonNull(clientId, "clientId cannot be null");
+ this.clock = clock;
+ }
+
+ private static Endpoint validate(Endpoint endpoint) {
+ try {
+ InetAddress.getByName(endpoint.getHostname());
+ return endpoint;
+ }
+ catch (UnknownHostException e) {
+ throw new IllegalArgumentException("Unknown host: " + endpoint);
+ }
+ }
+
+ @Override
+ public GatewayConnection newConnection() {
+ return new ApacheGatewayConnection(endpoint,
+ feedParams,
+ clusterSpecificRoute,
+ connectionParams,
+ httpClientFactory,
+ clientId,
+ clock);
+ }
+
+}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ClusterConnection.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ClusterConnection.java
index d254cd0bab8..8e55e59b3f4 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ClusterConnection.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ClusterConnection.java
@@ -14,7 +14,10 @@ import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor;
import java.io.IOException;
import java.io.StringWriter;
+import java.time.Clock;
+import java.time.Duration;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
@@ -43,7 +46,8 @@ public class ClusterConnection implements AutoCloseable {
Cluster cluster,
int clusterId,
int clientQueueSizePerCluster,
- ScheduledThreadPoolExecutor timeoutExecutor) {
+ ScheduledThreadPoolExecutor timeoutExecutor,
+ Clock clock) {
if (cluster.getEndpoints().isEmpty())
throw new IllegalArgumentException("At least a single endpoint is required in " + cluster);
@@ -53,7 +57,7 @@ public class ClusterConnection implements AutoCloseable {
throw new IllegalArgumentException("At least 1 persistent connection per endpoint is required in " + cluster);
int maxInFlightPerSession = Math.max(1, feedParams.getMaxInFlightRequests() / totalNumberOfEndpointsInThisCluster);
- documentQueue = new DocumentQueue(clientQueueSizePerCluster);
+ documentQueue = new DocumentQueue(clientQueueSizePerCluster, clock);
ioThreadGroup = operationProcessor.getIoThreadGroup();
singleEndpoint = cluster.getEndpoints().size() == 1 ? cluster.getEndpoints().get(0) : null;
Double idlePollFrequency = feedParams.getIdlePollFrequency();
@@ -66,28 +70,33 @@ public class ClusterConnection implements AutoCloseable {
timeoutExecutor,
feedParams.getServerTimeout(TimeUnit.MILLISECONDS) + feedParams.getClientTimeout(TimeUnit.MILLISECONDS));
for (int i = 0; i < connectionParams.getNumPersistentConnectionsPerEndpoint(); i++) {
- GatewayConnection gatewayConnection;
+ GatewayConnectionFactory connectionFactory;
if (connectionParams.isDryRun()) {
- gatewayConnection = new DryRunGatewayConnection(endpoint);
+ connectionFactory = new DryRunGatewayConnectionFactory(endpoint, clock);
} else {
- gatewayConnection = new ApacheGatewayConnection(endpoint,
- feedParams,
- cluster.getRoute(),
- connectionParams,
- new ApacheGatewayConnection.HttpClientFactory(connectionParams, endpoint.isUseSsl()),
- operationProcessor.getClientId()
+ connectionFactory = new ApacheGatewayConnectionFactory(endpoint,
+ feedParams,
+ cluster.getRoute(),
+ connectionParams,
+ new ApacheGatewayConnection.HttpClientFactory(connectionParams, endpoint.isUseSsl()),
+ operationProcessor.getClientId(),
+ clock
);
}
IOThread ioThread = new IOThread(operationProcessor.getIoThreadGroup(),
+ endpoint,
endpointResultQueue,
- gatewayConnection,
+ connectionFactory,
clusterId,
feedParams.getMaxChunkSizeBytes(),
maxInFlightPerSession,
- feedParams.getLocalQueueTimeOut(),
+ Duration.ofMillis(feedParams.getLocalQueueTimeOut()),
documentQueue,
feedParams.getMaxSleepTimeMs(),
- idlePollFrequency);
+ connectionParams.getConnectionTimeToLive(),
+ connectionParams.runThreads(),
+ idlePollFrequency,
+ clock);
ioThreads.add(ioThread);
}
}
@@ -160,6 +169,10 @@ public class ClusterConnection implements AutoCloseable {
return stringWriter.toString();
}
+ public List<IOThread> ioThreads() {
+ return Collections.unmodifiableList(ioThreads);
+ }
+
@Override
public boolean equals(Object o) {
return (this == o) || (o instanceof ClusterConnection && clusterId == ((ClusterConnection) o).clusterId);
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java
index 16bf881963f..3536013e043 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java
@@ -3,6 +3,8 @@ package com.yahoo.vespa.http.client.core.communication;
import com.yahoo.vespa.http.client.core.Document;
+import java.time.Clock;
+import java.time.Duration;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
@@ -11,8 +13,8 @@ import java.util.Optional;
import java.util.concurrent.TimeUnit;
/**
- * Document queue that only gives you document operations on documents for which there are no
- * already in flight operations for.
+ * Shared document queue that gives clients operations on documents which do not have operations already in flight.
+ * This is multithread safe.
*
* @author dybis
*/
@@ -21,10 +23,12 @@ class DocumentQueue {
private final Deque<Document> queue;
private final int maxSize;
private boolean closed = false;
+ private final Clock clock;
- DocumentQueue(int maxSize) {
+ DocumentQueue(int maxSize, Clock clock) {
this.maxSize = maxSize;
this.queue = new ArrayDeque<>(maxSize);
+ this.clock = clock;
}
List<Document> removeAllDocuments() {
@@ -39,7 +43,7 @@ class DocumentQueue {
}
void put(Document document, boolean calledFromIoThreadGroup) throws InterruptedException {
- document.resetQueueTime();
+ document.setQueueInsertTime(clock.instant());
synchronized (queue) {
while (!closed && (queue.size() >= maxSize) && !calledFromIoThreadGroup) {
queue.wait();
@@ -56,9 +60,9 @@ class DocumentQueue {
synchronized (queue) {
long remainingToWait = unit.toMillis(timeout);
while (queue.isEmpty()) {
- long startTime = System.currentTimeMillis();
+ long startTime = clock.millis();
queue.wait(remainingToWait);
- remainingToWait -= (System.currentTimeMillis() - startTime);
+ remainingToWait -= (clock.millis() - startTime);
if (remainingToWait <= 0) {
break;
}
@@ -106,16 +110,15 @@ class DocumentQueue {
return previousState;
}
- Optional<Document> pollDocumentIfTimedoutInQueue(long localQueueTimeOut) {
+ Optional<Document> pollDocumentIfTimedoutInQueue(Duration localQueueTimeOut) {
synchronized (queue) {
- if (queue.isEmpty()) {
- return Optional.empty();
- }
+ if (queue.isEmpty()) return Optional.empty();
+
Document document = queue.peek();
- if (document.timeInQueueMillis() > localQueueTimeOut) {
- return Optional.of(queue.poll());
- }
- return Optional.empty();
+ if (document.getQueueInsertTime().plus(localQueueTimeOut).isBefore(clock.instant()))
+ return Optional.ofNullable(queue.poll());
+ else
+ return Optional.empty();
}
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnection.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnection.java
index 23ab5e36e14..129fc000271 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnection.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnection.java
@@ -5,13 +5,14 @@ import com.yahoo.vespa.http.client.config.Endpoint;
import com.yahoo.vespa.http.client.core.Document;
import com.yahoo.vespa.http.client.core.ErrorCode;
import com.yahoo.vespa.http.client.core.OperationStatus;
-import com.yahoo.vespa.http.client.core.ServerResponseException;
import java.io.ByteArrayInputStream;
-import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
+import java.time.Clock;
+import java.time.Instant;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
/**
@@ -22,40 +23,78 @@ import java.util.List;
public class DryRunGatewayConnection implements GatewayConnection {
private final Endpoint endpoint;
+ private final Clock clock;
+ private Instant connectionTime = null;
+ private Instant lastPollTime = null;
- public DryRunGatewayConnection(Endpoint endpoint) {
+ /** Set to true to hold off responding with a result to any incoming operations until this is set false */
+ private boolean hold = false;
+ private List<Document> held = new ArrayList<>();
+
+ public DryRunGatewayConnection(Endpoint endpoint, Clock clock) {
this.endpoint = endpoint;
+ this.clock = clock;
}
@Override
- public InputStream writeOperations(List<Document> docs) throws ServerResponseException, IOException {
+ public InputStream write(List<Document> docs) {
StringBuilder result = new StringBuilder();
- for (Document doc : docs) {
- OperationStatus operationStatus = new OperationStatus("ok", doc.getOperationId(), ErrorCode.OK, false, "");
- result.append(operationStatus.render());
+ if (hold) {
+ held.addAll(docs);
+ }
+ else {
+ for (Document doc : held)
+ result.append(okResponse(doc).render());
+ held.clear();
+ for (Document doc : docs)
+ result.append(okResponse(doc).render());
}
return new ByteArrayInputStream(result.toString().getBytes(StandardCharsets.UTF_8));
}
+ public void hold(boolean hold) {
+ this.hold = hold;
+ }
+
+ @Override
+ public InputStream poll() {
+ lastPollTime = clock.instant();
+ return write(new ArrayList<>());
+ }
+
@Override
- public InputStream drain() throws ServerResponseException, IOException {
- return writeOperations(new ArrayList<Document>());
+ public Instant lastPollTime() { return lastPollTime; }
+
+ @Override
+ public InputStream drain() {
+ return write(new ArrayList<>());
}
@Override
public boolean connect() {
+ connectionTime = clock.instant();
return true;
}
@Override
+ public Instant connectionTime() { return connectionTime; }
+
+ @Override
public Endpoint getEndpoint() {
return endpoint;
}
@Override
- public void handshake() throws ServerResponseException, IOException { }
+ public void handshake() { }
@Override
public void close() { }
+ /** Returns the document currently held in this */
+ public List<Document> held() { return Collections.unmodifiableList(held); }
+
+ private OperationStatus okResponse(Document document) {
+ return new OperationStatus("ok", document.getOperationId(), ErrorCode.OK, false, "");
+ }
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnectionFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnectionFactory.java
new file mode 100644
index 00000000000..a234dba6b8e
--- /dev/null
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnectionFactory.java
@@ -0,0 +1,26 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.http.client.core.communication;
+
+import com.yahoo.vespa.http.client.config.Endpoint;
+
+import java.time.Clock;
+
+/**
+ * @author bratseth
+ */
+public class DryRunGatewayConnectionFactory implements GatewayConnectionFactory {
+
+ private final Endpoint endpoint;
+ private final Clock clock;
+
+ public DryRunGatewayConnectionFactory(Endpoint endpoint, Clock clock) {
+ this.endpoint = endpoint;
+ this.clock = clock;
+ }
+
+ @Override
+ public GatewayConnection newConnection() {
+ return new DryRunGatewayConnection(endpoint, clock);
+ }
+
+}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueue.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueue.java
index cd146cf0e87..1dd8b3bf3ec 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueue.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueue.java
@@ -15,24 +15,29 @@ import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
/**
+ * The shared queue of operation results.
+ * This is multithread safe.
+ *
* @author Einar M R Rosenvinge
*/
class EndpointResultQueue {
- private static Logger log = Logger.getLogger(EndpointResultQueue.class.getName());
+ private static final Logger log = Logger.getLogger(EndpointResultQueue.class.getName());
private final OperationProcessor operationProcessor;
+
+ /** The currently in flight operations */
private final Map<String, TimerFuture> futureByOperation = new HashMap<>();
+
private final Endpoint endpoint;
private final int clusterId;
private final ScheduledThreadPoolExecutor timer;
private final long totalTimeoutMs;
- EndpointResultQueue(
- OperationProcessor operationProcessor,
- Endpoint endpoint,
- int clusterId,
- ScheduledThreadPoolExecutor timer,
- long totalTimeoutMs) {
+ EndpointResultQueue(OperationProcessor operationProcessor,
+ Endpoint endpoint,
+ int clusterId,
+ ScheduledThreadPoolExecutor timer,
+ long totalTimeoutMs) {
this.operationProcessor = operationProcessor;
this.endpoint = endpoint;
this.clusterId = clusterId;
@@ -64,25 +69,23 @@ class EndpointResultQueue {
TimerFuture timerFuture = futureByOperation.remove(result.getOperationId());
if (timerFuture == null) {
if (duplicateGivesWarning) {
- log.warning(
- "Result for ID '" + result.getOperationId() + "' received from '" + endpoint
- + "', but we have no record of a sent operation. Either something is wrong on the server side "
- + "(bad VIP usage?), or we have somehow received duplicate results, "
- + "or operation was received _after_ client-side timeout.");
+ log.warning("Result for ID '" + result.getOperationId() + "' received from '" + endpoint +
+ "', but we have no record of a sent operation. Either something is wrong on the server side " +
+ "(bad VIP usage?), or we have somehow received duplicate results, " +
+ "or operation was received _after_ client-side timeout.");
}
return;
}
timerFuture.getFuture().cancel(false);
}
- //Called only from ScheduledThreadPoolExecutor thread in DocumentTimerTask.run(), see below
+ /** Called only from ScheduledThreadPoolExecutor thread in DocumentTimerTask.run(), see below */
private synchronized void timeout(String operationId) {
TimerFuture timerFuture = futureByOperation.remove(operationId);
if (timerFuture == null) {
- log.finer(
- "Timeout of operation '" + operationId + "', but operation "
- + "not found in map. Result was probably received just-in-time from server, while timeout "
- + "task could not be cancelled.");
+ log.finer("Timeout of operation '" + operationId + "', but operation " +
+ "not found in map. Result was probably received just-in-time from server, while timeout " +
+ "task could not be cancelled.");
return;
}
EndpointResult endpointResult = EndPointResultFactory.createTransientError(
@@ -108,6 +111,7 @@ class EndpointResultQueue {
}
private class DocumentTimerTask implements Runnable {
+
private final String operationId;
private DocumentTimerTask(String operationId) {
@@ -118,17 +122,21 @@ class EndpointResultQueue {
public void run() {
timeout(operationId);
}
+
}
- private class TimerFuture {
+ private static class TimerFuture {
+
private final ScheduledFuture<?> future;
public TimerFuture(ScheduledFuture<?> future) {
this.future = future;
}
+
private ScheduledFuture<?> getFuture() {
return future;
}
+
}
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnection.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnection.java
index 3e5bdfe3056..ce1edb83fa2 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnection.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnection.java
@@ -6,12 +6,23 @@ import com.yahoo.vespa.http.client.core.Document;
import com.yahoo.vespa.http.client.core.ServerResponseException;
import java.io.IOException;
import java.io.InputStream;
+import java.time.Instant;
import java.util.List;
public interface GatewayConnection {
- InputStream writeOperations(List<Document> docs) throws ServerResponseException, IOException;
+ /** Returns the time this connected over the network, or null if not connected yet */
+ Instant connectionTime();
+ /** Returns the last time poll was called on this, or null if never */
+ Instant lastPollTime();
+
+ InputStream write(List<Document> docs) throws ServerResponseException, IOException;
+
+ /** Returns any operation results that are ready now */
+ InputStream poll() throws ServerResponseException, IOException;
+
+ /** Attempt to drain all outstanding operations, even if this leads to blocking */
InputStream drain() throws ServerResponseException, IOException;
boolean connect();
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnectionFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnectionFactory.java
new file mode 100644
index 00000000000..d27aa850995
--- /dev/null
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnectionFactory.java
@@ -0,0 +1,13 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.http.client.core.communication;
+
+/**
+ * Creates gateway connections on request
+ *
+ * @author bratseth
+ */
+public interface GatewayConnectionFactory {
+
+ GatewayConnection newConnection();
+
+}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java
index 0d916002964..2417208fba3 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java
@@ -13,8 +13,13 @@ import com.yahoo.vespa.http.client.core.ServerResponseException;
import java.io.IOException;
import java.io.InputStream;
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
+import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Random;
@@ -31,23 +36,40 @@ import java.util.logging.Logger;
*/
class IOThread implements Runnable, AutoCloseable {
- private static Logger log = Logger.getLogger(IOThread.class.getName());
+ private static final Logger log = Logger.getLogger(IOThread.class.getName());
+
private final Endpoint endpoint;
- private final GatewayConnection client;
+ private final GatewayConnectionFactory connectionFactory;
private final DocumentQueue documentQueue;
private final EndpointResultQueue resultQueue;
+
+ /** The thread running this, or null if it does not run a thread (meaning tick() must be called from the outside) */
private final Thread thread;
private final int clusterId;
private final CountDownLatch running = new CountDownLatch(1);
private final CountDownLatch stopSignal = new CountDownLatch(1);
private final int maxChunkSizeBytes;
private final int maxInFlightRequests;
- private final long localQueueTimeOut;
+ private final Duration localQueueTimeOut;
+ private final Duration maxOldConnectionPollInterval;
private final GatewayThrottler gatewayThrottler;
+ private final Duration connectionTimeToLive;
private final long pollIntervalUS;
+ private final Clock clock;
private final Random random = new Random();
- private enum ThreadState { DISCONNECTED, CONNECTED, SESSION_SYNCED };
+ private GatewayConnection currentConnection;
+ private ConnectionState connectionState = ConnectionState.DISCONNECTED;
+
+ /**
+ * Previous connections on which we have sent operations and are still waiting for the result
+ * (so all connections in this are in state SESSION_SYNCED).
+ * We need to drain results on the connection where they were sent to make sure we request results on
+ * the node which received the operation also when going through a VIP.
+ */
+ private final List<GatewayConnection> oldConnections = new ArrayList<>();
+
+ private enum ConnectionState { DISCONNECTED, CONNECTED, SESSION_SYNCED };
private final AtomicInteger wrongSessionDetectedCounter = new AtomicInteger(0);
private final AtomicInteger wrongVersionDetectedCounter = new AtomicInteger(0);
private final AtomicInteger problemStatusCodeFromServerCounter = new AtomicInteger(0);
@@ -59,70 +81,49 @@ class IOThread implements Runnable, AutoCloseable {
private final AtomicInteger lastGatewayProcessTimeMillis = new AtomicInteger(0);
IOThread(ThreadGroup ioThreadGroup,
+ Endpoint endpoint,
EndpointResultQueue endpointResultQueue,
- GatewayConnection client,
+ GatewayConnectionFactory connectionFactory,
int clusterId,
int maxChunkSizeBytes,
int maxInFlightRequests,
- long localQueueTimeOut,
+ Duration localQueueTimeOut,
DocumentQueue documentQueue,
long maxSleepTimeMs,
- double idlePollFrequency) {
+ Duration connectionTimeToLive,
+ boolean runThreads,
+ double idlePollFrequency,
+ Clock clock) {
+ this.endpoint = endpoint;
this.documentQueue = documentQueue;
- this.endpoint = client.getEndpoint();
- this.client = client;
+ this.connectionFactory = connectionFactory;
+ this.currentConnection = connectionFactory.newConnection();
this.resultQueue = endpointResultQueue;
this.clusterId = clusterId;
this.maxChunkSizeBytes = maxChunkSizeBytes;
this.maxInFlightRequests = maxInFlightRequests;
+ this.connectionTimeToLive = connectionTimeToLive;
this.gatewayThrottler = new GatewayThrottler(maxSleepTimeMs);
- //Ensure that pollInterval is in the range [1us, 10s]
- this.pollIntervalUS = Math.max(1, (long)(1000000.0/Math.max(0.1, idlePollFrequency)));
- this.thread = new Thread(ioThreadGroup, this, "IOThread " + endpoint);
- thread.setDaemon(true);
+ this.pollIntervalUS = Math.max(1, (long)(1000000.0/Math.max(0.1, idlePollFrequency))); // ensure range [1us, 10s]
+ this.clock = clock;
this.localQueueTimeOut = localQueueTimeOut;
- thread.start();
+ this.maxOldConnectionPollInterval = localQueueTimeOut.dividedBy(10).toMillis() > pollIntervalUS / 1000
+ ? localQueueTimeOut.dividedBy(10)
+ : Duration.ofMillis(pollIntervalUS / 1000);
+ if (runThreads) {
+ this.thread = new Thread(ioThreadGroup, this, "IOThread " + endpoint);
+ thread.setDaemon(true);
+ thread.start();
+ }
+ else {
+ this.thread = null;
+ }
}
public Endpoint getEndpoint() {
return endpoint;
}
- public static class ConnectionStats {
-
- // NOTE: These fields are accessed by reflection in JSON serialization
-
- public final int wrongSessionDetectedCounter;
- public final int wrongVersionDetectedCounter;
- public final int problemStatusCodeFromServerCounter;
- public final int executeProblemsCounter;
- public final int docsReceivedCounter;
- public final int statusReceivedCounter;
- public final int pendingDocumentStatusCount;
- public final int successfullHandshakes;
- public final int lastGatewayProcessTimeMillis;
-
- ConnectionStats(int wrongSessionDetectedCounter,
- int wrongVersionDetectedCounter,
- int problemStatusCodeFromServerCounter,
- int executeProblemsCounter,
- int docsReceivedCounter,
- int statusReceivedCounter,
- int pendingDocumentStatusCount,
- int successfullHandshakes,
- int lastGatewayProcessTimeMillis) {
- this.wrongSessionDetectedCounter = wrongSessionDetectedCounter;
- this.wrongVersionDetectedCounter = wrongVersionDetectedCounter;
- this.problemStatusCodeFromServerCounter = problemStatusCodeFromServerCounter;
- this.executeProblemsCounter = executeProblemsCounter;
- this.docsReceivedCounter = docsReceivedCounter;
- this.statusReceivedCounter = statusReceivedCounter;
- this.pendingDocumentStatusCount = pendingDocumentStatusCount;
- this.successfullHandshakes = successfullHandshakes;
- this.lastGatewayProcessTimeMillis = lastGatewayProcessTimeMillis;
- }
- }
-
/**
* Returns a snapshot of counters. Threadsafe.
*/
@@ -152,18 +153,21 @@ class IOThread implements Runnable, AutoCloseable {
if (size > 0) {
log.info("We have outstanding operations (" + size + ") , trying to fetch responses.");
try {
- processResponse(client.drain());
+ for (GatewayConnection oldConnection : oldConnections)
+ processResponse(oldConnection.drain());
+ processResponse(currentConnection.drain());
} catch (Throwable e) {
log.log(Level.SEVERE, "Some failures while trying to get latest responses from vespa.", e);
}
}
try {
- client.close();
+ for (GatewayConnection oldConnection : oldConnections)
+ oldConnection.close();
+ currentConnection.close();
} finally {
// If there is still documents in the queue, fail them.
- drainDocumentQueueWhenFailingPermanently(new Exception(
- "Closed call, did not manage to process everything so failing this document."));
+ drainDocumentQueueWhenFailingPermanently(new Exception("Closed call, did not manage to process everything so failing this document."));
}
log.fine("Session to " + endpoint + " closed.");
@@ -184,7 +188,7 @@ class IOThread implements Runnable, AutoCloseable {
int chunkSizeBytes = 0;
try {
drainFirstDocumentsInQueueIfOld();
- Document doc = documentQueue.poll(maxWaitUnits, timeUnit);
+ Document doc = thread != null ? documentQueue.poll(maxWaitUnits, timeUnit) : documentQueue.poll();
if (doc != null) {
docsForSendChunk.add(doc);
chunkSizeBytes = doc.size();
@@ -236,12 +240,12 @@ class IOThread implements Runnable, AutoCloseable {
private InputStream sendAndReceive(List<Document> docs) throws IOException, ServerResponseException {
try {
// Post the new docs and get async responses for other posts.
- return client.writeOperations(docs);
+ return currentConnection.write(docs);
} catch (ServerResponseException ser) {
markDocumentAsFailed(docs, ser);
throw ser;
} catch (Exception e) {
- markDocumentAsFailed(docs, new ServerResponseException(e.getMessage()));
+ markDocumentAsFailed(docs, new ServerResponseException(Exceptions.toMessageString(e)));
throw e;
}
}
@@ -274,11 +278,11 @@ class IOThread implements Runnable, AutoCloseable {
private ProcessResponse feedDocumentAndProcessResults(List<Document> docs)
throws ServerResponseException, IOException {
addDocumentsToResultQueue(docs);
- long startTime = System.currentTimeMillis();
+ long startTime = clock.millis();
InputStream serverResponse = sendAndReceive(docs);
ProcessResponse processResponse = processResponse(serverResponse);
- lastGatewayProcessTimeMillis.set((int) (System.currentTimeMillis() - startTime));
+ lastGatewayProcessTimeMillis.set((int) (clock.millis() - startTime));
return processResponse;
}
@@ -309,28 +313,30 @@ class IOThread implements Runnable, AutoCloseable {
return processResponse;
}
- /** Given a current thread state, take the appropriate action and return the resulting new thread state */
- private ThreadState cycle(ThreadState threadState) {
- switch(threadState) {
+ /** Given a current connection state, take the appropriate action and return the resulting new connection state */
+ private ConnectionState cycle(ConnectionState connectionState) {
+ switch(connectionState) {
case DISCONNECTED:
try {
- if (! client.connect()) {
+ if (! currentConnection.connect()) {
log.log(Level.WARNING, "Could not connect to endpoint: '" + endpoint + "'. Will re-try.");
drainFirstDocumentsInQueueIfOld();
- return ThreadState.DISCONNECTED;
+ return ConnectionState.DISCONNECTED;
}
- return ThreadState.CONNECTED;
+ return ConnectionState.CONNECTED;
} catch (Throwable throwable1) {
drainFirstDocumentsInQueueIfOld();
log.log(Level.INFO, "Failed connecting to endpoint: '" + endpoint
+ "'. Will re-try connecting. Failed with '" + Exceptions.toMessageString(throwable1) + "'",throwable1);
executeProblemsCounter.incrementAndGet();
- return ThreadState.DISCONNECTED;
+ return ConnectionState.DISCONNECTED;
}
case CONNECTED:
try {
- client.handshake();
+ if (isStale(currentConnection))
+ return refreshConnection(connectionState);
+ currentConnection.handshake();
successfulHandshakes.getAndIncrement();
} catch (ServerResponseException ser) {
@@ -340,46 +346,49 @@ class IOThread implements Runnable, AutoCloseable {
drainFirstDocumentsInQueueIfOld();
resultQueue.onEndpointError(new FeedProtocolException(ser.getResponseCode(), ser.getResponseString(), ser, endpoint));
- return ThreadState.CONNECTED;
+ return ConnectionState.CONNECTED;
} catch (Throwable throwable) { // This cover IOException as well
executeProblemsCounter.incrementAndGet();
resultQueue.onEndpointError(new FeedConnectException(throwable, endpoint));
log.log(Level.INFO, "Failed talking to endpoint. Handshake with server endpoint '" + endpoint
+ "' failed. Will re-try handshake. Failed with '" + Exceptions.toMessageString(throwable) + "'",throwable);
drainFirstDocumentsInQueueIfOld();
- client.close();
- return ThreadState.DISCONNECTED;
+ currentConnection.close();
+ return ConnectionState.DISCONNECTED;
}
- return ThreadState.SESSION_SYNCED;
+ return ConnectionState.SESSION_SYNCED;
case SESSION_SYNCED:
try {
+ if (isStale(currentConnection))
+ return refreshConnection(connectionState);
ProcessResponse processResponse = pullAndProcessData(pollIntervalUS);
gatewayThrottler.handleCall(processResponse.transitiveErrorCount);
}
catch (ServerResponseException ser) {
- log.log(Level.INFO, "Problems while handing data over to endpoint '" + endpoint
- + "'. Will re-try. Endpoint responded with an unexpected HTTP response code. '"
- + Exceptions.toMessageString(ser) + "'",ser);
- return ThreadState.CONNECTED;
+ log.log(Level.INFO, "Problems while handing data over to endpoint '" + endpoint +
+ "'. Will re-try. Endpoint responded with an unexpected HTTP response code. '"
+ + Exceptions.toMessageString(ser) + "'",ser);
+ return ConnectionState.CONNECTED;
}
- catch (Throwable e) { // Covers IOException as well
- log.log(Level.INFO, "Problems while handing data over to endpoint '" + endpoint
- + "'. Will re-try. Connection level error. Failed with '" + Exceptions.toMessageString(e) + "'", e);
- client.close();
- return ThreadState.DISCONNECTED;
+ catch (Throwable e) {
+ log.log(Level.INFO, "Problems while handing data over to endpoint '" + endpoint +
+ "'. Will re-try. Connection level error. Failed with '" +
+ Exceptions.toMessageString(e) + "'", e);
+ currentConnection.close();
+ return ConnectionState.DISCONNECTED;
}
- return ThreadState.SESSION_SYNCED;
+ return ConnectionState.SESSION_SYNCED;
default: {
log.severe("Should never get here.");
- client.close();
- return ThreadState.DISCONNECTED;
+ currentConnection.close();
+ return ConnectionState.DISCONNECTED;
}
}
}
- private void sleepIfProblemsGettingSyncedConnection(ThreadState newState, ThreadState oldState) {
- if (newState == ThreadState.SESSION_SYNCED) return;
- if (newState == ThreadState.CONNECTED && oldState == ThreadState.DISCONNECTED) return;
+ private void sleepIfProblemsGettingSyncedConnection(ConnectionState newState, ConnectionState oldState) {
+ if (newState == ConnectionState.SESSION_SYNCED) return;
+ if (newState == ConnectionState.CONNECTED && oldState == ConnectionState.DISCONNECTED) return;
try {
// Take it easy we have problems getting a connection up.
if (stopSignal.getCount() > 0 || !documentQueue.isEmpty()) {
@@ -391,16 +400,19 @@ class IOThread implements Runnable, AutoCloseable {
@Override
public void run() {
- ThreadState threadState = ThreadState.DISCONNECTED;
- while (stopSignal.getCount() > 0 || !documentQueue.isEmpty()) {
- ThreadState oldState = threadState;
- threadState = cycle(threadState);
- sleepIfProblemsGettingSyncedConnection(threadState, oldState);
-
- }
+ while (stopSignal.getCount() > 0 || !documentQueue.isEmpty())
+ tick();
log.finer(toString() + " exiting, documentQueue.size()=" + documentQueue.size());
running.countDown();
+ }
+ /** Do one iteration of work. Should be called from the single worker thread of this. */
+ public void tick() {
+ ConnectionState oldState = connectionState;
+ connectionState = cycle(connectionState);
+ checkOldConnections();
+ if (thread != null)
+ sleepIfProblemsGettingSyncedConnection(connectionState, oldState);
}
private void drainFirstDocumentsInQueueIfOld() {
@@ -410,14 +422,14 @@ class IOThread implements Runnable, AutoCloseable {
EndpointResult endpointResult = EndPointResultFactory.createTransientError(
endpoint, document.get().getOperationId(),
- new Exception("Not sending document operation, timed out in queue after "
- + document.get().timeInQueueMillis() + " ms."));
+ new Exception("Not sending document operation, timed out in queue after " +
+ (clock.millis() - document.get().getQueueInsertTime().toEpochMilli()) + " ms."));
resultQueue.failOperation(endpointResult, clusterId);
}
}
private void drainDocumentQueueWhenFailingPermanently(Exception exception) {
- //first, clear sentOperations:
+ // first, clear sentOperations:
resultQueue.failPending(exception);
for (Document document : documentQueue.removeAllDocuments()) {
@@ -427,4 +439,92 @@ class IOThread implements Runnable, AutoCloseable {
}
}
+ private boolean isStale(GatewayConnection connection) {
+ return connection.connectionTime() != null
+ && connection.connectionTime().plus(connectionTimeToLive).isBefore(clock.instant());
+ }
+
+ private ConnectionState refreshConnection(ConnectionState currentConnectionState) {
+ if (currentConnectionState == ConnectionState.SESSION_SYNCED)
+ oldConnections.add(currentConnection);
+ currentConnection = connectionFactory.newConnection();
+ return ConnectionState.DISCONNECTED;
+ }
+
+ private void checkOldConnections() {
+ for (Iterator<GatewayConnection> i = oldConnections.iterator(); i.hasNext(); ) {
+ GatewayConnection connection = i.next();
+ if (closingTime(connection).isBefore(clock.instant())) {
+ connection.close();
+ i.remove();
+ }
+ else if (timeToPoll(connection)) {
+ try {
+ processResponse(connection.poll());
+ }
+ catch (Exception e) {
+ // Old connection; best effort
+ }
+ }
+ }
+ }
+
+ private Instant closingTime(GatewayConnection connection) {
+ return connection.connectionTime().plus(connectionTimeToLive).plus(localQueueTimeOut);
+ }
+
+ private boolean timeToPoll(GatewayConnection connection) {
+ if (connection.lastPollTime() == null) return true;
+
+ // Poll less the closer the connection comes to closing time
+ double newness = ( closingTime(connection).toEpochMilli() - clock.millis() ) /
+ (double)localQueueTimeOut.toMillis();
+ if (newness < 0) return true; // connection retired prematurely
+ if (newness > 1) return false; // closing time reached
+ Duration pollInterval = Duration.ofMillis(pollIntervalUS / 1000 +
+ (long)((1 - newness) * ( maxOldConnectionPollInterval.toMillis() - pollIntervalUS / 1000)));
+ return connection.lastPollTime().plus(pollInterval).isBefore(clock.instant());
+ }
+
+ public static class ConnectionStats {
+
+ // NOTE: These fields are accessed by reflection in JSON serialization
+
+ public final int wrongSessionDetectedCounter;
+ public final int wrongVersionDetectedCounter;
+ public final int problemStatusCodeFromServerCounter;
+ public final int executeProblemsCounter;
+ public final int docsReceivedCounter;
+ public final int statusReceivedCounter;
+ public final int pendingDocumentStatusCount;
+ public final int successfullHandshakes;
+ public final int lastGatewayProcessTimeMillis;
+
+ ConnectionStats(int wrongSessionDetectedCounter,
+ int wrongVersionDetectedCounter,
+ int problemStatusCodeFromServerCounter,
+ int executeProblemsCounter,
+ int docsReceivedCounter,
+ int statusReceivedCounter,
+ int pendingDocumentStatusCount,
+ int successfullHandshakes,
+ int lastGatewayProcessTimeMillis) {
+ this.wrongSessionDetectedCounter = wrongSessionDetectedCounter;
+ this.wrongVersionDetectedCounter = wrongVersionDetectedCounter;
+ this.problemStatusCodeFromServerCounter = problemStatusCodeFromServerCounter;
+ this.executeProblemsCounter = executeProblemsCounter;
+ this.docsReceivedCounter = docsReceivedCounter;
+ this.statusReceivedCounter = statusReceivedCounter;
+ this.pendingDocumentStatusCount = pendingDocumentStatusCount;
+ this.successfullHandshakes = successfullHandshakes;
+ this.lastGatewayProcessTimeMillis = lastGatewayProcessTimeMillis;
+ }
+ }
+
+ /** For testing. Returns the current connection of this. Not thread safe. */
+ public GatewayConnection currentConnection() { return currentConnection; }
+
+ /** For testing. Returns a snapshot of the old connections of this. Not thread safe. */
+ public List<GatewayConnection> oldConnections() { return new ArrayList<>(oldConnections); }
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/DocumentSendInfo.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/DocumentSendInfo.java
index 883cea7e6f0..27ad88c123e 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/DocumentSendInfo.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/DocumentSendInfo.java
@@ -4,6 +4,7 @@ package com.yahoo.vespa.http.client.core.operationProcessor;
import com.yahoo.vespa.http.client.Result;
import com.yahoo.vespa.http.client.core.Document;
+import java.time.Clock;
import java.util.HashMap;
import java.util.Map;
@@ -18,24 +19,25 @@ class DocumentSendInfo {
// This is lazily populated as normal cases does not require retries.
private Map<Integer, Integer> attemptedRetriesByClusterId = null;
private final StringBuilder localTrace;
+ private final Clock clock;
- DocumentSendInfo(Document document, boolean traceThisDoc) {
+ DocumentSendInfo(Document document, boolean traceThisDoc, Clock clock) {
this.document = document;
- localTrace = traceThisDoc
- ? new StringBuilder("\n" + document.createTimeMillis() + " Trace starting " + "\n")
- : null;
+ localTrace = traceThisDoc ? new StringBuilder("\n" + document.createTime() + " Trace starting " + "\n")
+ : null;
+ this.clock = clock;
}
boolean addIfNotAlreadyThere(Result.Detail detail, int clusterId) {
if (detailByClusterId.containsKey(clusterId)) {
if (localTrace != null) {
- localTrace.append(System.currentTimeMillis() + " Got duplicate detail, ignoring this: "
- + detail.toString() + "\n");
+ localTrace.append(clock.millis() + " Got duplicate detail, ignoring this: " +
+ detail.toString() + "\n");
}
return false;
}
if (localTrace != null) {
- localTrace.append(System.currentTimeMillis() + " Got detail: " + detail.toString() + "\n");
+ localTrace.append(clock.millis() + " Got detail: " + detail.toString() + "\n");
}
detailByClusterId.put(clusterId, detail);
return true;
@@ -60,7 +62,7 @@ class DocumentSendInfo {
retries++;
attemptedRetriesByClusterId.put(clusterId, retries);
if (localTrace != null) {
- localTrace.append(System.currentTimeMillis() + " Asked about retrying for cluster ID "
+ localTrace.append(clock.millis() + " Asked about retrying for cluster ID "
+ clusterId + ", number of retries is " + retries + " Detail:\n" + detail.toString());
}
return retries;
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/EndPointResultFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/EndPointResultFactory.java
index 205153a7a00..3d662eca3e7 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/EndPointResultFactory.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/EndPointResultFactory.java
@@ -21,12 +21,11 @@ import java.util.logging.Logger;
*/
public final class EndPointResultFactory {
- private static Logger log = Logger.getLogger(EndPointResultFactory.class.getName());
-
+ private static final Logger log = Logger.getLogger(EndPointResultFactory.class.getName());
private static final String EMPTY_MESSAGE = "-";
- public static Collection<EndpointResult> createResult(
- Endpoint endpoint, InputStream inputStream) throws IOException {
+ public static Collection<EndpointResult> createResult(Endpoint endpoint,
+ InputStream inputStream) throws IOException {
List<EndpointResult> results = new ArrayList<>();
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(inputStream, StandardCharsets.US_ASCII))) {
@@ -82,9 +81,9 @@ public final class EndPointResultFactory {
return new EndpointResult(
reply.operationId,
new Result.Detail(endpoint,
- replyToResultType(reply),
- reply.traceMessage,
- exception));
+ replyToResultType(reply),
+ reply.traceMessage,
+ exception));
} catch (Throwable t) {
throw new IllegalArgumentException("Bad result line from server: '" + line + "'", t);
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottler.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottler.java
index 7cf4e32a880..ebeee802303 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottler.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottler.java
@@ -3,6 +3,7 @@ package com.yahoo.vespa.http.client.core.operationProcessor;
import com.yahoo.vespa.http.client.core.ThrottlePolicy;
+import java.time.Clock;
import java.util.concurrent.ThreadLocalRandom;
/**
@@ -57,6 +58,7 @@ public class IncompleteResultsThrottler {
/**
* Creates the throttler.
+ *
* @param minInFlightValue the throttler will never throttle beyond this limit.
* @param maxInFlightValue the throttler will never throttle above this limit. If zero, no limit.
* @param clock use to calculate window size. Can be null if minWindowSize and maxInFlightValue are equal.
@@ -68,7 +70,7 @@ public class IncompleteResultsThrottler {
this.policy = policy;
this.clock = clock;
if (minInFlightValue != maxInFlightValue) {
- this.sampleStartTimeMs = clock.getTimeMillis();
+ this.sampleStartTimeMs = clock.millis();
}
setNewSemaphoreSize(INITIAL_MAX_IN_FLIGHT_VALUE);
}
@@ -96,10 +98,6 @@ public class IncompleteResultsThrottler {
}
}
- public interface Clock {
- long getTimeMillis();
- }
-
public void resultReady(boolean success) {
blocker.operationDone();
if (!success) {
@@ -147,9 +145,8 @@ public class IncompleteResultsThrottler {
}
private void adjustThrottling() {
- if (clock.getTimeMillis() < sampleStartTimeMs + phaseSizeMs) {
- return;
- }
+ if (clock.millis() < sampleStartTimeMs + phaseSizeMs) return;
+
sampleStartTimeMs += phaseSizeMs;
if (stabilizingPhasesLeft-- == 0) {
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java
index 692d90abe50..90d07104fef 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java
@@ -15,7 +15,9 @@ import com.yahoo.vespa.http.client.core.communication.ClusterConnection;
import java.math.BigInteger;
import java.security.SecureRandom;
+import java.time.Clock;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
@@ -55,16 +57,19 @@ public class OperationProcessor {
private final boolean traceToStderr;
private final ThreadGroup ioThreadGroup;
private final String clientId = new BigInteger(130, random).toString(32);
+ private final Clock clock;
public OperationProcessor(IncompleteResultsThrottler incompleteResultsThrottler,
FeedClient.ResultCallback resultCallback,
SessionParams sessionParams,
- ScheduledThreadPoolExecutor timeoutExecutor) {
+ ScheduledThreadPoolExecutor timeoutExecutor,
+ Clock clock) {
this.numDestinations = sessionParams.getClusters().size();
this.resultCallback = resultCallback;
this.incompleteResultsThrottler = incompleteResultsThrottler;
this.timeoutExecutor = timeoutExecutor;
this.ioThreadGroup = new ThreadGroup("operationprocessor");
+ this.clock = clock;
if (sessionParams.getClusters().isEmpty())
throw new IllegalArgumentException("Cannot feed to 0 clusters.");
@@ -82,7 +87,8 @@ public class OperationProcessor {
cluster,
i,
sessionParams.getClientQueueSize() / sessionParams.getClusters().size(),
- timeoutExecutor));
+ timeoutExecutor,
+ clock));
}
operationStats = new OperationStats(sessionParams, clusters, incompleteResultsThrottler);
maxRetries = sessionParams.getConnectionParams().getMaxRetries();
@@ -181,7 +187,7 @@ public class OperationProcessor {
}
}
if (blockedDocumentToSend != null) {
- sendToClusters(blockedDocumentToSend);
+ sendToClusters(blockedDocumentToSend, clock);
}
return result;
}
@@ -225,13 +231,13 @@ public class OperationProcessor {
inflightDocumentIds.add(document.getDocumentId());
}
- sendToClusters(document);
+ sendToClusters(document, clock);
}
- private void sendToClusters(Document document) {
+ private void sendToClusters(Document document, Clock clock) {
synchronized (monitor) {
boolean traceThisDoc = traceEveryXOperation > 0 && traceCounter++ % traceEveryXOperation == 0;
- docSendInfoByOperationId.put(document.getOperationId(), new DocumentSendInfo(document, traceThisDoc));
+ docSendInfoByOperationId.put(document.getOperationId(), new DocumentSendInfo(document, traceThisDoc, clock));
}
for (ClusterConnection clusterConnection : clusters) {
@@ -250,6 +256,8 @@ public class OperationProcessor {
}
}
+ public List<ClusterConnection> clusters() { return Collections.unmodifiableList(clusters); }
+
public String getStatsAsJson() {
return operationStats.getStatsAsJson();
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/Runner.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/Runner.java
index 7c034cab75f..926b4cf8c79 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/Runner.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/Runner.java
@@ -10,6 +10,7 @@ import com.yahoo.vespa.http.client.core.XmlFeedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
+import java.time.Clock;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
@@ -34,11 +35,11 @@ public class Runner {
boolean isJson,
AtomicInteger numSent,
boolean verbose) {
-
+ Clock clock = Clock.systemUTC();
if (verbose)
System.err.println("Now sending data.");
- long sendStartTime = System.currentTimeMillis();
+ long sendStartTime = clock.millis();
if (isJson) {
JsonReader.read(inputStream, feedClient, numSent);
} else {
@@ -49,7 +50,7 @@ public class Runner {
}
}
- long sendTotalTime = System.currentTimeMillis() - sendStartTime;
+ long sendTotalTime = clock.millis() - sendStartTime;
if (verbose)
System.err.println("Waiting for all results, sent " + numSent.get() + " docs.");
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/FeedClientTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/FeedClientTest.java
index aa47128f436..b70fbaf3096 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/FeedClientTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/FeedClientTest.java
@@ -11,6 +11,7 @@ import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
+import java.time.Clock;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals;
@@ -41,7 +42,7 @@ public class FeedClientTest {
resultsReceived.incrementAndGet();
};
- FeedClient feedClient = new FeedClientImpl(sessionParams, resultCallback, FeedClientFactory.createTimeoutExecutor());
+ FeedClient feedClient = new FeedClientImpl(sessionParams, resultCallback, FeedClientFactory.createTimeoutExecutor(), Clock.systemUTC());
@Test
public void testStreamAndClose() {
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/ManualClock.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/ManualClock.java
new file mode 100644
index 00000000000..b32d1eaa859
--- /dev/null
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/ManualClock.java
@@ -0,0 +1,55 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.http.client;
+
+import java.time.Clock;
+import java.time.Instant;
+import java.time.LocalDateTime;
+import java.time.ZoneId;
+import java.time.ZoneOffset;
+import java.time.format.DateTimeFormatter;
+import java.time.temporal.TemporalAmount;
+
+/**
+ * A clock which initially has the time of its creation but can only be advanced by calling advance
+ *
+ * @author bratseth
+ */
+public class ManualClock extends Clock {
+
+ private Instant currentTime = Instant.now();
+
+ public ManualClock() {}
+
+ public ManualClock(String utcIsoTime) {
+ this(at(utcIsoTime));
+ }
+
+ public ManualClock(Instant currentTime) {
+ this.currentTime = currentTime;
+ }
+
+ public void advance(TemporalAmount temporal) {
+ currentTime = currentTime.plus(temporal);
+ }
+
+ public void setInstant(Instant time) {
+ currentTime = time;
+ }
+
+ @Override
+ public Instant instant() { return currentTime; }
+
+ @Override
+ public ZoneId getZone() { return null; }
+
+ @Override
+ public Clock withZone(ZoneId zone) { return null; }
+
+ @Override
+ public long millis() { return currentTime.toEpochMilli(); }
+
+ public static Instant at(String utcIsoTime) {
+ return LocalDateTime.parse(utcIsoTime, DateTimeFormatter.ISO_DATE_TIME).atZone(ZoneOffset.UTC).toInstant();
+ }
+
+}
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/QueueBoundsTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/QueueBoundsTest.java
index 1f875e0dd72..0813cb36078 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/QueueBoundsTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/QueueBoundsTest.java
@@ -12,6 +12,7 @@ import org.junit.Test;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
+import java.time.Clock;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -24,6 +25,7 @@ import static com.yahoo.vespa.http.client.TestUtils.writeDocument;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.core.IsEqual.equalTo;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
@@ -78,7 +80,8 @@ public class QueueBoundsTest {
.build())
.setClientQueueSize(2)
.build(),
- SessionFactory.createTimeoutExecutor())) {
+ SessionFactory.createTimeoutExecutor(),
+ Clock.systemUTC())) {
FeederThread feeder = new FeederThread(session);
try {
feeder.start();
@@ -122,7 +125,8 @@ public class QueueBoundsTest {
.setNumPersistentConnectionsPerEndpoint(1)
.build())
.setClientQueueSize(6) //3 per cluster
- .build(), SessionFactory.createTimeoutExecutor())) {
+ .build(), SessionFactory.createTimeoutExecutor(),
+ Clock.systemUTC())) {
FeederThread feeder = new FeederThread(session);
try {
@@ -210,22 +214,23 @@ public class QueueBoundsTest {
.build())
.setClientQueueSize(1)
.build(),
- SessionFactory.createTimeoutExecutor())) {
+ SessionFactory.createTimeoutExecutor(),
+ Clock.systemUTC())) {
FeederThread feeder = new FeederThread(session);
feeder.start();
try {
{
System.out.println("We start with failed connection, post a document.");
assertFeedNotBlocking(feeder, 0);
- assertThat(session.results().size(), is(0));
+ assertEquals(0, session.results().size());
CountDownLatch lastPostFeed = assertFeedBlocking(feeder, 1);
System.out.println("No result so far.");
- assertThat(session.results().size(), is(0));
+ assertEquals(0, session.results().size());
System.out.println("Make connection ok.");
mockXmlParsingRequestHandler.setScenario(V3MockParsingRequestHandler.Scenario.ALL_OK);
assert(lastPostFeed.await(120, TimeUnit.SECONDS));
- assertThat(lastPostFeed.getCount(), equalTo(0L));
+ assertEquals(0L, lastPostFeed.getCount());
assertResultQueueSize(session, 2, 120, TimeUnit.SECONDS);
}
@@ -235,7 +240,7 @@ public class QueueBoundsTest {
{
assertFeedNotBlocking(feeder, 2);
System.out.println("Fed one document, fit in queue.");
- assertThat(session.results().size(), is(2));
+ assertEquals(2, session.results().size());
System.out.println("Fed one document more, wait for failure.");
assertFeedNotBlocking(feeder, 3);
@@ -249,12 +254,12 @@ public class QueueBoundsTest {
}
int errors = 0;
for (Result result : session.results()) {
- assertThat(result.getDetails().size(), is(1));
+ assertEquals(1, result.getDetails().size());
if (! result.isSuccess()) {
errors++;
}
}
- assertThat(errors, is(1));
+ assertEquals(1, errors);
} finally {
feeder.stop();
}
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/Server.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/Server.java
index 0821fa55e06..79a91d0b5f3 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/Server.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/Server.java
@@ -5,8 +5,7 @@ import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.server.handler.AbstractHandler;
/**
- * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
- * @since 5.1.20
+ * @author Einar M R Rosenvinge
*/
public final class Server implements AutoCloseable {
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java
index 1d70ce953e4..780de3e695c 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java
@@ -18,10 +18,10 @@ import java.util.Map;
import java.util.concurrent.TimeUnit;
import static com.yahoo.vespa.http.client.TestUtils.getResults;
-import static org.hamcrest.CoreMatchers.is;
-import static org.hamcrest.CoreMatchers.not;
-import static org.hamcrest.CoreMatchers.nullValue;
-import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
/**
*
@@ -79,34 +79,33 @@ public class V3HttpAPITest {
writeDocument(session);
Map<String, Result> results = getResults(session, 1);
- assertThat(results.size(), is(1));
+ assertEquals(1, results.size());
TestDocument document = documents.get(0);
Result r = results.remove(document.getDocumentId());
- assertThat(r, not(nullValue()));
- if (conditionNotMet) {
- assertThat(r.getDetails().iterator().next().getResultType(), is(Result.ResultType.CONDITION_NOT_MET));
- }
- assertThat(r.getDetails().toString(), r.isSuccess(), is(false));
- assertThat(results.isEmpty(), is(true));
+ assertNotNull(r);
+ if (conditionNotMet)
+ assertEquals(Result.ResultType.CONDITION_NOT_MET, r.getDetails().iterator().next().getResultType());
+ assertFalse(r.getDetails().toString(), r.isSuccess());
+ assertTrue(results.isEmpty());
}
}
@Test
- public void requireThatSingleDestinationWorks() throws Exception {
+ public void testSingleDestination() throws Exception {
try (Server server = new Server(new V3MockParsingRequestHandler(), 0);
- Session session = SessionFactory.create(Endpoint.create("localhost", server.getPort(), false))) {
+ Session session = SessionFactory.create(Endpoint.create("localhost", server.getPort(), false))) {
writeDocuments(session);
Map<String, Result> results = getResults(session, documents.size());
- assertThat(results.size(), is(documents.size()));
+ assertEquals(documents.size(), results.size());
for (TestDocument document : documents) {
Result r = results.remove(document.getDocumentId());
- assertThat(r, not(nullValue()));
- assertThat(r.getDetails().toString(), r.isSuccess(), is(true));
+ assertNotNull(r);
+ assertTrue(r.getDetails().toString(), r.isSuccess());
}
- assertThat(results.isEmpty(), is(true));
+ assertTrue(results.isEmpty());
}
}
@@ -169,15 +168,15 @@ public class V3HttpAPITest {
writeDocuments(session);
Map<String, Result> results = getResults(session, documents.size());
- assertThat(results.size(), is(documents.size()));
+ assertEquals(documents.size(), results.size());
for (TestDocument document : documents) {
Result r = results.remove(document.getDocumentId());
- assertThat(r, not(nullValue()));
- assertThat(r.getDetails().toString(), r.isSuccess(), is(false));
- assertThat(r.getDetails().iterator().next().getResultType(), is(Result.ResultType.TRANSITIVE_ERROR));
+ assertNotNull(r);
+ assertFalse(r.getDetails().toString(), r.isSuccess());
+ assertEquals(Result.ResultType.TRANSITIVE_ERROR, r.getDetails().iterator().next().getResultType());
}
- assertThat(results.isEmpty(), is(true));
+ assertTrue(results.isEmpty());
}
}
@@ -197,4 +196,5 @@ public class V3HttpAPITest {
testServerWithMock(new V3MockParsingRequestHandler(
200, V3MockParsingRequestHandler.Scenario.CONDITON_NOT_MET), false, true);
}
+
}
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/DocumentTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/DocumentTest.java
index b5c03eade51..ee2f021df6a 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/DocumentTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/DocumentTest.java
@@ -5,9 +5,9 @@ import org.junit.Test;
import java.nio.ByteBuffer;
import java.nio.ReadOnlyBufferException;
+import java.time.Clock;
-import static org.hamcrest.core.Is.is;
-import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertEquals;
public class DocumentTest {
@@ -15,25 +15,25 @@ public class DocumentTest {
public void simpleCaseOk() {
String docId = "doc id";
String docContent = "foo";
- Document document = new Document(docId, docContent.getBytes(), null);
- assertThat(document.getDocumentId(), is(docId));
- assertThat(document.getData(), is(ByteBuffer.wrap(docContent.getBytes())));
- assertThat(document.getDataAsString().toString(), is(docContent));
+ Document document = new Document(docId, docContent.getBytes(), null, Clock.systemUTC().instant());
+ assertEquals(docId, document.getDocumentId());
+ assertEquals(ByteBuffer.wrap(docContent.getBytes()), document.getData());
+ assertEquals(docContent, document.getDataAsString().toString());
// Make sure that data is not modified on retrieval.
- assertThat(document.getDataAsString().toString(), is(docContent));
- assertThat(document.getData(), is(ByteBuffer.wrap(docContent.getBytes())));
- assertThat(document.getDocumentId(), is(docId));
+ assertEquals(docContent, document.getDataAsString().toString());
+ assertEquals(ByteBuffer.wrap(docContent.getBytes()), document.getData());
+ assertEquals(docId, document.getDocumentId());
}
@Test(expected = ReadOnlyBufferException.class)
public void notMutablePutTest() {
- Document document = new Document("id", null, "data", null /* context */);
+ Document document = new Document("id", null, "data", null, Clock.systemUTC().instant());
document.getData().put("a".getBytes());
}
@Test(expected = ReadOnlyBufferException.class)
public void notMutableCompactTest() {
- Document document = new Document("id", null, "data", null /* context */);
+ Document document = new Document("id", null, "data", null, Clock.systemUTC().instant());
document.getData().compact();
}
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java
index 59a8b613e67..511e40c1c88 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java
@@ -14,9 +14,10 @@ import org.apache.http.HttpEntity;
import org.apache.http.HttpResponse;
import org.apache.http.ParseException;
import org.apache.http.StatusLine;
-import org.apache.http.client.HttpClient;
+import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.InputStreamEntity;
+import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.message.BasicHeader;
import org.junit.Rule;
import org.junit.Test;
@@ -27,13 +28,12 @@ import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
+import java.time.Clock;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.Mockito.any;
@@ -42,7 +42,6 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
-
public class ApacheGatewayConnectionTest {
@Rule
@@ -50,20 +49,18 @@ public class ApacheGatewayConnectionTest {
@Test
public void testProtocolV3() throws Exception {
- final Endpoint endpoint = Endpoint.create("localhost", 666, false);
- final FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build();
- final String clusterSpecificRoute = "";
- final ConnectionParams connectionParams = new ConnectionParams.Builder()
- .build();
- final List<Document> documents = new ArrayList<>();
+ Endpoint endpoint = Endpoint.create("localhost", 666, false);
+ FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build();
+ String clusterSpecificRoute = "";
+ ConnectionParams connectionParams = new ConnectionParams.Builder().build();
+ List<Document> documents = new ArrayList<>();
- final String vespaDocContent = "Hello, I a JSON doc.";
- final String docId = "42";
+ String vespaDocContent = "Hello, I a JSON doc.";
+ String docId = "42";
- final AtomicInteger requestsReceived = new AtomicInteger(0);
// This is the fake server, takes header client ID and uses this as session Id.
ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> {
- final Header clientIdHeader = post.getFirstHeader(Headers.CLIENT_ID);
+ Header clientIdHeader = post.getFirstHeader(Headers.CLIENT_ID);
return httpResponse(clientIdHeader.getValue(), "3");
});
@@ -74,21 +71,21 @@ public class ApacheGatewayConnectionTest {
clusterSpecificRoute,
connectionParams,
mockFactory,
- "clientId");
+ "clientId",
+ Clock.systemUTC());
apacheGatewayConnection.connect();
apacheGatewayConnection.handshake();
documents.add(createDoc(docId, vespaDocContent, true));
- apacheGatewayConnection.writeOperations(documents);
+ apacheGatewayConnection.write(documents);
}
@Test(expected=IllegalArgumentException.class)
public void testServerReturnsBadSessionInV3() throws Exception {
- final Endpoint endpoint = Endpoint.create("localhost", 666, false);
- final FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build();
- final String clusterSpecificRoute = "";
- final ConnectionParams connectionParams = new ConnectionParams.Builder()
- .build();
+ Endpoint endpoint = Endpoint.create("localhost", 666, false);
+ FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build();
+ String clusterSpecificRoute = "";
+ ConnectionParams connectionParams = new ConnectionParams.Builder().build();
// This is the fake server, returns wrong session Id.
ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> httpResponse("Wrong Id from server", "3"));
@@ -100,57 +97,36 @@ public class ApacheGatewayConnectionTest {
clusterSpecificRoute,
connectionParams,
mockFactory,
- "clientId");
+ "clientId",
+ Clock.systemUTC());
apacheGatewayConnection.connect();
- final List<Document> documents = new ArrayList<>();
- apacheGatewayConnection.writeOperations(documents);
- }
-
- @Test(expected=RuntimeException.class)
- public void testBadConfigParameters() throws Exception {
- final Endpoint endpoint = Endpoint.create("localhost", 666, false);
- final FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build();
- final String clusterSpecificRoute = "";
- final ConnectionParams connectionParams = new ConnectionParams.Builder()
- .build();
-
- final ApacheGatewayConnection.HttpClientFactory mockFactory =
- mock(ApacheGatewayConnection.HttpClientFactory.class);
-
- new ApacheGatewayConnection(
- endpoint,
- feedParams,
- clusterSpecificRoute,
- connectionParams,
- mockFactory,
- null);
+ List<Document> documents = new ArrayList<>();
+ apacheGatewayConnection.write(documents);
}
@Test
public void testJsonDocumentHeader() throws Exception {
- final Endpoint endpoint = Endpoint.create("localhost", 666, false);
- final FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build();
- final String clusterSpecificRoute = "";
- final ConnectionParams connectionParams = new ConnectionParams.Builder()
- .setUseCompression(true)
- .build();
- final List<Document> documents = new ArrayList<>();
+ Endpoint endpoint = Endpoint.create("localhost", 666, false);
+ FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build();
+ String clusterSpecificRoute = "";
+ ConnectionParams connectionParams = new ConnectionParams.Builder().setUseCompression(true).build();
+ List<Document> documents = new ArrayList<>();
- final String vespaDocContent ="Hello, I a JSON doc.";
- final String docId = "42";
+ String vespaDocContent ="Hello, I a JSON doc.";
+ String docId = "42";
- final AtomicInteger requestsReceived = new AtomicInteger(0);
+ AtomicInteger requestsReceived = new AtomicInteger(0);
// This is the fake server, checks that DATA_FORMAT header is set properly.
ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> {
- final Header header = post.getFirstHeader(Headers.DATA_FORMAT);
+ Header header = post.getFirstHeader(Headers.DATA_FORMAT);
if (requestsReceived.incrementAndGet() == 1) {
// This is handshake, it is not json.
assert (header == null);
return httpResponse("clientId", "3");
}
assertNotNull(header);
- assertThat(header.getValue(), is(FeedParams.DataFormat.JSON_UTF8.name()));
+ assertEquals(FeedParams.DataFormat.JSON_UTF8.name(), header.getValue());
// Test is done.
return httpResponse("clientId", "3");
});
@@ -162,24 +138,25 @@ public class ApacheGatewayConnectionTest {
clusterSpecificRoute,
connectionParams,
mockFactory,
- "clientId");
+ "clientId",
+ Clock.systemUTC());
apacheGatewayConnection.connect();
apacheGatewayConnection.handshake();
documents.add(createDoc(docId, vespaDocContent, true));
- apacheGatewayConnection.writeOperations(documents);
+ apacheGatewayConnection.write(documents);
}
@Test
public void testZipAndCreateEntity() throws IOException {
- final String testString = "Hello world";
+ String testString = "Hello world";
InputStream stream = new ByteArrayInputStream(testString.getBytes(StandardCharsets.UTF_8));
// Send in test data to method.
InputStreamEntity inputStreamEntity = ApacheGatewayConnection.zipAndCreateEntity(stream);
// Verify zipped data by comparing unzipped data with test data.
- final String rawContent = TestUtils.zipStreamToString(inputStreamEntity.getContent());
- assert(testString.equals(rawContent));
+ String rawContent = TestUtils.zipStreamToString(inputStreamEntity.getContent());
+ assertEquals(testString, rawContent);
}
/**
@@ -187,32 +164,28 @@ public class ApacheGatewayConnectionTest {
*/
@Test
public void testCompressedWriteOperations() throws Exception {
- final Endpoint endpoint = Endpoint.create("localhost", 666, false);
- final FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.XML_UTF8).build();
- final String clusterSpecificRoute = "";
- final ConnectionParams connectionParams = new ConnectionParams.Builder()
- .setUseCompression(true)
- .build();
- final List<Document> documents = new ArrayList<>();
+ Endpoint endpoint = Endpoint.create("localhost", 666, false);
+ FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.XML_UTF8).build();
+ String clusterSpecificRoute = "";
+ ConnectionParams connectionParams = new ConnectionParams.Builder().setUseCompression(true).build();
+ List<Document> documents = new ArrayList<>();
- final String vespaDocContent ="Hello, I am the document data.";
- final String docId = "42";
+ String vespaDocContent ="Hello, I am the document data.";
+ String docId = "42";
- final Document doc = createDoc(docId, vespaDocContent, false);
+ Document doc = createDoc(docId, vespaDocContent, false);
// When sending data on http client, check if it is compressed. If compressed, unzip, check result,
// and count down latch.
ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> {
- final Header header = post.getFirstHeader("Content-Encoding");
+ Header header = post.getFirstHeader("Content-Encoding");
if (header != null && header.getValue().equals("gzip")) {
final String rawContent = TestUtils.zipStreamToString(post.getEntity().getContent());
final String vespaHeaderText = "<vespafeed>\n";
final String vespaFooterText = "</vespafeed>\n";
- assertThat(rawContent, is(
- doc.getOperationId() + " 38\n" + vespaHeaderText + vespaDocContent + "\n"
- + vespaFooterText));
-
+ assertEquals(doc.getOperationId() + " 38\n" + vespaHeaderText + vespaDocContent + "\n" + vespaFooterText,
+ rawContent);
}
return httpResponse("clientId", "3");
});
@@ -227,13 +200,14 @@ public class ApacheGatewayConnectionTest {
clusterSpecificRoute,
connectionParams,
mockFactory,
- "clientId");
+ "clientId",
+ Clock.systemUTC());
apacheGatewayConnection.connect();
apacheGatewayConnection.handshake();
documents.add(doc);
- apacheGatewayConnection.writeOperations(documents);
+ apacheGatewayConnection.write(documents);
}
@Test
@@ -265,14 +239,15 @@ public class ApacheGatewayConnectionTest {
"",
connectionParams,
mockFactory,
- "clientId");
+ "clientId",
+ Clock.systemUTC());
apacheGatewayConnection.connect();
apacheGatewayConnection.handshake();
List<Document> documents = new ArrayList<>();
documents.add(createDoc("42", "content", true));
- apacheGatewayConnection.writeOperations(documents);
- apacheGatewayConnection.writeOperations(documents);
+ apacheGatewayConnection.write(documents);
+ apacheGatewayConnection.write(documents);
verify(headerProvider, times(3)).getHeaderValue(); // 1x connect(), 2x writeOperations()
}
@@ -293,17 +268,18 @@ public class ApacheGatewayConnectionTest {
"",
new ConnectionParams.Builder().build(),
mockFactory,
- "clientId");
+ "clientId",
+ Clock.systemUTC());
apacheGatewayConnection.connect();
apacheGatewayConnection.handshake();
- apacheGatewayConnection.writeOperations(Collections.singletonList(createDoc("42", "content", true)));
+ apacheGatewayConnection.write(Collections.singletonList(createDoc("42", "content", true)));
}
private static ApacheGatewayConnection.HttpClientFactory mockHttpClientFactory(HttpExecuteMock httpExecuteMock) throws IOException {
ApacheGatewayConnection.HttpClientFactory mockFactory =
mock(ApacheGatewayConnection.HttpClientFactory.class);
- HttpClient httpClientMock = mock(HttpClient.class);
+ CloseableHttpClient httpClientMock = mock(CloseableHttpClient.class);
when(mockFactory.createClient()).thenReturn(httpClientMock);
when(httpClientMock.execute(any())).thenAnswer((Answer) invocation -> {
Object[] args = invocation.getArguments();
@@ -317,16 +293,12 @@ public class ApacheGatewayConnectionTest {
HttpResponse execute(HttpPost httpPost) throws IOException;
}
- private Document createDoc(final String docId, final String content, boolean useJson) throws IOException {
- return new Document(docId, content.getBytes(), null /* context */);
+ private Document createDoc(String docId, String content, boolean useJson) {
+ return new Document(docId, content.getBytes(), null, Clock.systemUTC().instant());
}
- private void addMockedHeader(
- final HttpResponse httpResponseMock,
- final String name,
- final String value,
- HeaderElement[] elements) {
- final Header header = new Header() {
+ private void addMockedHeader(HttpResponse httpResponseMock, String name, String value, HeaderElement[] elements) {
+ Header header = new Header() {
@Override
public String getName() {
return name;
@@ -344,7 +316,7 @@ public class ApacheGatewayConnectionTest {
}
private HttpResponse httpResponse(String sessionIdInResult, String version) throws IOException {
- final HttpResponse httpResponseMock = mock(HttpResponse.class);
+ CloseableHttpResponse httpResponseMock = mock(CloseableHttpResponse.class);
StatusLine statusLineMock = mock(StatusLine.class);
when(httpResponseMock.getStatusLine()).thenReturn(statusLineMock);
@@ -365,7 +337,7 @@ public class ApacheGatewayConnectionTest {
}
private static HttpResponse createErrorHttpResponse(int statusCode, String reasonPhrase, String message) throws IOException {
- HttpResponse response = mock(HttpResponse.class);
+ CloseableHttpResponse response = mock(CloseableHttpResponse.class);
StatusLine statusLine = mock(StatusLine.class);
when(statusLine.getStatusCode()).thenReturn(statusCode);
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/CloseableQTestCase.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/CloseableQTestCase.java
index 35a06258f86..af354b8feea 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/CloseableQTestCase.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/CloseableQTestCase.java
@@ -4,21 +4,24 @@ package com.yahoo.vespa.http.client.core.communication;
import com.yahoo.vespa.http.client.core.Document;
import org.junit.Test;
+import java.time.Clock;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
public class CloseableQTestCase {
+
@Test
public void requestThatPutIsInterruptedOnClose() throws InterruptedException {
- final DocumentQueue q = new DocumentQueue(1);
- q.put(new Document("id", null, "data", null), false);
+ Clock clock = Clock.systemUTC();
+ DocumentQueue q = new DocumentQueue(1, clock);
+ q.put(new Document("id", null, "data", null, clock.instant()), false);
Thread t = new Thread(new Runnable() {
@Override
public void run() {
try {
Thread.sleep(3000);
} catch (InterruptedException e) {
-
}
q.close();
q.clear();
@@ -26,7 +29,7 @@ public class CloseableQTestCase {
});
t.start();
try {
- q.put(new Document("id2", null, "data2", null), false);
+ q.put(new Document("id2", null, "data2", null, Clock.systemUTC().instant()), false);
fail("This shouldn't have worked.");
} catch (IllegalStateException ise) {
// ok!
@@ -39,10 +42,11 @@ public class CloseableQTestCase {
@Test
public void requireThatSelfIsUnbounded() throws InterruptedException {
- DocumentQueue q = new DocumentQueue(1);
- q.put(new Document("1", null, "data", null), true);
- q.put(new Document("2", null, "data", null), true);
- q.put(new Document("3", null, "data", null), true);
+ DocumentQueue q = new DocumentQueue(1, Clock.systemUTC());
+ q.put(new Document("1", null, "data", null, Clock.systemUTC().instant()), true);
+ q.put(new Document("2", null, "data", null, Clock.systemUTC().instant()), true);
+ q.put(new Document("3", null, "data", null, Clock.systemUTC().instant()), true);
assertEquals(3, q.size());
}
+
}
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueueTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueueTest.java
index 0005bddeb73..da82079e992 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueueTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueueTest.java
@@ -20,8 +20,7 @@ import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
/**
- * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
- * @since 5.1.22
+ * @author Einar M R Rosenvinge
*/
public class EndpointResultQueueTest {
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/IOThreadTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/IOThreadTest.java
index e81638ded1c..59fb968906f 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/IOThreadTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/IOThreadTest.java
@@ -4,6 +4,7 @@ package com.yahoo.vespa.http.client.core.communication;
import com.yahoo.vespa.http.client.FeedConnectException;
import com.yahoo.vespa.http.client.FeedEndpointException;
import com.yahoo.vespa.http.client.FeedProtocolException;
+import com.yahoo.vespa.http.client.ManualClock;
import com.yahoo.vespa.http.client.Result;
import com.yahoo.vespa.http.client.V3HttpAPITest;
import com.yahoo.vespa.http.client.config.Endpoint;
@@ -16,6 +17,8 @@ import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
+import java.time.Clock;
+import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
@@ -35,21 +38,27 @@ import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
+// DO NOT ADD TESTS HERE, add to NewIOThreadTest
public class IOThreadTest {
private static final Endpoint ENDPOINT = Endpoint.create("myhost");
+ final Clock clock = Clock.systemUTC();
final EndpointResultQueue endpointResultQueue = mock(EndpointResultQueue.class);
final ApacheGatewayConnection apacheGatewayConnection = mock(ApacheGatewayConnection.class);
final String exceptionMessage = "SOME EXCEPTION FOO";
CountDownLatch latch = new CountDownLatch(1);
String docId1 = V3HttpAPITest.documents.get(0).getDocumentId();
Document doc1 = new Document(V3HttpAPITest.documents.get(0).getDocumentId(),
- V3HttpAPITest.documents.get(0).getContents(), null /* context */);
+ V3HttpAPITest.documents.get(0).getContents(),
+ null,
+ clock.instant());
String docId2 = V3HttpAPITest.documents.get(1).getDocumentId();
Document doc2 = new Document(V3HttpAPITest.documents.get(1).getDocumentId(),
- V3HttpAPITest.documents.get(1).getContents(), null /* context */);
- DocumentQueue documentQueue = new DocumentQueue(4);
+ V3HttpAPITest.documents.get(1).getContents(),
+ null,
+ clock.instant());
+ DocumentQueue documentQueue = new DocumentQueue(4, clock);
public IOThreadTest() {
when(apacheGatewayConnection.getEndpoint()).thenReturn(ENDPOINT);
@@ -57,20 +66,18 @@ public class IOThreadTest {
/**
* Set up mock so that it can handle both failDocument() and resultReceived().
+ *
* @param expectedDocIdFail on failure, this has to be the doc id, or the mock will fail.
* @param expectedDocIdOk on ok, this has to be the doc id, or the mock will fail.
* @param isTransient checked on failure, if different, the mock will fail.
* @param expectedException checked on failure, if exception toString is different, the mock will fail.
*/
- void setupEndpointResultQueueMock(String expectedDocIdFail, String expectedDocIdOk,boolean isTransient, String expectedException) {
-
+ void setupEndpointResultQueueMock(String expectedDocIdFail, String expectedDocIdOk, boolean isTransient, String expectedException) {
doAnswer(invocation -> {
EndpointResult endpointResult = (EndpointResult) invocation.getArguments()[0];
assertThat(endpointResult.getOperationId(), is(expectedDocIdFail));
- assertThat(endpointResult.getDetail().getException().toString(),
- containsString(expectedException));
- assertThat(endpointResult.getDetail().getResultType(), is(
- isTransient ? Result.ResultType.TRANSITIVE_ERROR : Result.ResultType.FATAL_ERROR));
+ assertThat(endpointResult.getDetail().getException().toString(), containsString(expectedException));
+ assertThat(endpointResult.getDetail().getResultType(), is(isTransient ? Result.ResultType.TRANSITIVE_ERROR : Result.ResultType.FATAL_ERROR));
latch.countDown();
return null;
@@ -86,7 +93,20 @@ public class IOThreadTest {
}
private IOThread createIOThread(int maxInFlightRequests, long localQueueTimeOut) {
- return new IOThread(null, endpointResultQueue, apacheGatewayConnection, 0, 0, maxInFlightRequests, localQueueTimeOut, documentQueue, 0, 10);
+ return new IOThread(null,
+ ENDPOINT,
+ endpointResultQueue,
+ new SingletonGatewayConnectionFactory(apacheGatewayConnection),
+ 0,
+ 0,
+ maxInFlightRequests,
+ Duration.ofMillis(localQueueTimeOut),
+ documentQueue,
+ 0,
+ Duration.ofSeconds(15),
+ true,
+ 10,
+ clock);
}
@Test
@@ -94,7 +114,7 @@ public class IOThreadTest {
when(apacheGatewayConnection.connect()).thenReturn(true);
InputStream serverResponse = new ByteArrayInputStream(
(docId1 + " OK Doc{20}fed").getBytes(StandardCharsets.UTF_8));
- when(apacheGatewayConnection.writeOperations(any())).thenReturn(serverResponse);
+ when(apacheGatewayConnection.write(any())).thenReturn(serverResponse);
setupEndpointResultQueueMock( "nope", docId1, true, exceptionMessage);
try (IOThread ioThread = createIOThread(10000, 10000)) {
ioThread.post(doc1);
@@ -103,9 +123,9 @@ public class IOThreadTest {
}
@Test
- public void requireThatSingleDocumentWriteErrorIsHandledProperly() throws Exception {
+ public void testDocumentWriteError() throws Exception {
when(apacheGatewayConnection.connect()).thenReturn(true);
- when(apacheGatewayConnection.writeOperations(any())).thenThrow(new IOException(exceptionMessage));
+ when(apacheGatewayConnection.write(any())).thenThrow(new IOException(exceptionMessage));
setupEndpointResultQueueMock(doc1.getOperationId(), "nope", true, exceptionMessage);
try (IOThread ioThread = createIOThread(10000, 10000)) {
ioThread.post(doc1);
@@ -114,11 +134,11 @@ public class IOThreadTest {
}
@Test
- public void requireThatTwoDocumentsFirstWriteErrorSecondOkIsHandledProperly() throws Exception {
+ public void testTwoDocumentsFirstWriteErrorSecondOk() throws Exception {
when(apacheGatewayConnection.connect()).thenReturn(true);
InputStream serverResponse = new ByteArrayInputStream(
(docId2 + " OK Doc{20}fed").getBytes(StandardCharsets.UTF_8));
- when(apacheGatewayConnection.writeOperations(any()))
+ when(apacheGatewayConnection.write(any()))
.thenThrow(new IOException(exceptionMessage))
.thenReturn(serverResponse);
latch = new CountDownLatch(2);
@@ -134,10 +154,8 @@ public class IOThreadTest {
@Test
public void testQueueTimeOutNoNoConnectionToServer() throws Exception {
when(apacheGatewayConnection.connect()).thenReturn(false);
- InputStream serverResponse = new ByteArrayInputStream(
- ("").getBytes(StandardCharsets.UTF_8));
- when(apacheGatewayConnection.writeOperations(any()))
- .thenReturn(serverResponse);
+ InputStream serverResponse = new ByteArrayInputStream(("").getBytes(StandardCharsets.UTF_8));
+ when(apacheGatewayConnection.write(any())).thenReturn(serverResponse);
setupEndpointResultQueueMock(doc1.getOperationId(), "nope", true,
"java.lang.Exception: Not sending document operation, timed out in queue after");
try (IOThread ioThread = createIOThread(10, 10)) {
@@ -147,7 +165,7 @@ public class IOThreadTest {
}
@Test
- public void requireThatEndpointProtocolExceptionsArePropagated()
+ public void testEndpointProtocolExceptionPropagation()
throws IOException, ServerResponseException, InterruptedException, TimeoutException, ExecutionException {
when(apacheGatewayConnection.connect()).thenReturn(true);
int errorCode = 403;
@@ -168,7 +186,7 @@ public class IOThreadTest {
}
@Test
- public void requireThatEndpointConnectExceptionsArePropagated()
+ public void testEndpointConnectExceptionsPropagation()
throws IOException, ServerResponseException, InterruptedException, TimeoutException, ExecutionException {
when(apacheGatewayConnection.connect()).thenReturn(true);
String errorMessage = "generic error message";
@@ -198,4 +216,17 @@ public class IOThreadTest {
return futureResult;
}
+ private static final class SingletonGatewayConnectionFactory implements GatewayConnectionFactory {
+
+ private final GatewayConnection singletonConnection;
+
+ SingletonGatewayConnectionFactory(GatewayConnection singletonConnection) {
+ this.singletonConnection = singletonConnection;
+ }
+
+ @Override
+ public GatewayConnection newConnection() { return singletonConnection; }
+
+ }
+
}
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/NewIOThreadTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/NewIOThreadTest.java
new file mode 100644
index 00000000000..615fa22a6cf
--- /dev/null
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/NewIOThreadTest.java
@@ -0,0 +1,192 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.http.client.core.communication;
+
+import com.yahoo.vespa.http.client.FeedClient;
+import com.yahoo.vespa.http.client.FeedEndpointException;
+import com.yahoo.vespa.http.client.ManualClock;
+import com.yahoo.vespa.http.client.Result;
+import com.yahoo.vespa.http.client.config.Cluster;
+import com.yahoo.vespa.http.client.config.ConnectionParams;
+import com.yahoo.vespa.http.client.config.Endpoint;
+import com.yahoo.vespa.http.client.config.SessionParams;
+import com.yahoo.vespa.http.client.core.Document;
+import com.yahoo.vespa.http.client.core.EndpointResult;
+import com.yahoo.vespa.http.client.core.ThrottlePolicy;
+import com.yahoo.vespa.http.client.core.operationProcessor.IncompleteResultsThrottler;
+import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor;
+import org.junit.Test;
+
+import java.time.Duration;
+import java.time.Instant;
+import java.util.List;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotSame;
+
+/**
+ * TODO: Migrate IOThreadTests here.
+ *
+ * @author bratseth
+ */
+public class NewIOThreadTest {
+
+ @Test
+ public void testBasics() {
+ OperationProcessorTester tester = new OperationProcessorTester();
+ assertEquals(0, tester.inflight());
+ assertEquals(0, tester.success());
+ assertEquals(0, tester.failures());
+ tester.send("doc1");
+ tester.send("doc2");
+ tester.send("doc3");
+ assertEquals(3, tester.inflight());
+ assertEquals(0, tester.success());
+ assertEquals(0, tester.failures());
+ tester.success("doc1");
+ tester.success("doc2");
+ tester.success("doc3");
+ assertEquals(0, tester.inflight());
+ assertEquals(3, tester.success());
+ assertEquals(0, tester.failures());
+ }
+
+ @Test
+ public void testPollingOldConnections() {
+ OperationProcessorTester tester = new OperationProcessorTester();
+ tester.tick(3);
+
+ assertEquals(1, tester.clusterConnections().size());
+ assertEquals(1, tester.clusterConnections().get(0).ioThreads().size());
+ IOThread ioThread = tester.clusterConnections().get(0).ioThreads().get(0);
+ DryRunGatewayConnection firstConnection = (DryRunGatewayConnection)ioThread.currentConnection();
+ assertEquals(0, ioThread.oldConnections().size());
+
+ firstConnection.hold(true);
+ tester.send("doc1");
+ tester.tick(1);
+
+ tester.clock().advance(Duration.ofSeconds(20)); // Default connection ttl is 15
+ tester.tick(3);
+
+ assertEquals(1, ioThread.oldConnections().size());
+ assertEquals(firstConnection, ioThread.oldConnections().get(0));
+ assertNotSame(firstConnection, ioThread.currentConnection());
+ assertEquals(20, firstConnection.lastPollTime().toEpochMilli() / 1000);
+
+ // Check old connection poll pattern (linear backoff)
+ assertLastPollTimeWhenAdvancing(21, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(22, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(23, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(24, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(24, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(26, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(26, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(28, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(28, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(30, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(30, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(32, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(32, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(34, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(34, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(34, 1, firstConnection, tester);
+ assertLastPollTimeWhenAdvancing(37, 1, firstConnection, tester);
+
+ tester.clock().advance(Duration.ofSeconds(200));
+ tester.tick(1);
+ assertEquals("Old connection is eventually removed", 0, ioThread.oldConnections().size());
+ }
+
+ private void assertLastPollTimeWhenAdvancing(int lastPollTimeSeconds,
+ int advanceSeconds,
+ DryRunGatewayConnection connection,
+ OperationProcessorTester tester) {
+ tester.clock().advance(Duration.ofSeconds(advanceSeconds));
+ tester.tick(1);
+ assertEquals(lastPollTimeSeconds, connection.lastPollTime().toEpochMilli() / 1000);
+ }
+
+ private static class OperationProcessorTester {
+
+ private final Endpoint endpoint;
+ private final int clusterId = 0;
+ private final ManualClock clock;
+ private final TestResultCallback resultCallback;
+ private final OperationProcessor operationProcessor;
+
+ public OperationProcessorTester() {
+ endpoint = Endpoint.create("test-endpoint");
+ SessionParams.Builder params = new SessionParams.Builder();
+ Cluster.Builder clusterParams = new Cluster.Builder();
+ clusterParams.addEndpoint(endpoint);
+ params.addCluster(clusterParams.build());
+ ConnectionParams.Builder connectionParams = new ConnectionParams.Builder();
+ connectionParams.setDryRun(true);
+ connectionParams.setRunThreads(false);
+ params.setConnectionParams(connectionParams.build());
+
+ clock = new ManualClock(Instant.ofEpochMilli(0));
+ resultCallback = new TestResultCallback();
+ operationProcessor = new OperationProcessor(new IncompleteResultsThrottler(1, 100, clock, new ThrottlePolicy()),
+ resultCallback,
+ params.build(),
+ new ScheduledThreadPoolExecutor(1),
+ clock);
+ }
+
+ public ManualClock clock() { return clock; }
+
+ /** Do n iteration of work in all io threads of this */
+ public void tick(int n) {
+ for (int i = 0; i < n; i++)
+ for (ClusterConnection cluster : operationProcessor.clusters())
+ for (IOThread thread : cluster.ioThreads())
+ thread.tick();
+ }
+
+ public void send(String documentId) {
+ operationProcessor.sendDocument(new Document(documentId, documentId, "data of " + documentId, null, clock.instant()));
+ }
+
+ public void success(String documentId) {
+ operationProcessor.resultReceived(new EndpointResult(documentId, new Result.Detail(endpoint)), clusterId);
+ }
+
+ public int inflight() {
+ return operationProcessor.getIncompleteResultQueueSize();
+ }
+
+ public int success() {
+ return resultCallback.successes;
+ }
+
+ public List<ClusterConnection> clusterConnections() {
+ return operationProcessor.clusters();
+ }
+
+ public int failures() {
+ return resultCallback.failures;
+ }
+
+ }
+
+ private static class TestResultCallback implements FeedClient.ResultCallback {
+
+ private int successes = 0;
+ private int failures = 0;
+
+ @Override
+ public void onCompletion(String docId, Result documentResult) {
+ successes++;
+ }
+
+ @Override
+ public void onEndpointException(FeedEndpointException exception) {
+ failures++;
+ }
+
+
+ }
+
+}
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottlerTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottlerTest.java
index baf6e2f2df3..ec929d68efb 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottlerTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottlerTest.java
@@ -1,9 +1,12 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.http.client.core.operationProcessor;
+import com.yahoo.vespa.http.client.ManualClock;
import com.yahoo.vespa.http.client.core.ThrottlePolicy;
import org.junit.Test;
+import java.time.Duration;
+import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
@@ -12,6 +15,7 @@ import java.util.Random;
import java.util.concurrent.atomic.AtomicLong;
import static org.hamcrest.core.Is.is;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.anyDouble;
@@ -42,14 +46,14 @@ public class IncompleteResultsThrottlerTest {
* @return median queue length.
*/
int getAverageQueue(int clientCount, int breakPoint, int simulationTimeMs) {
- final AtomicLong timeMs = new AtomicLong(0);
+ ManualClock clock = new ManualClock(Instant.ofEpochMilli(0));
ArrayList<IncompleteResultsThrottler> incompleteResultsThrottlers = new ArrayList<>();
MockServer mockServer = new MockServer(breakPoint);
for (int x = 0; x < clientCount; x++) {
IncompleteResultsThrottler incompleteResultsThrottler =
- new IncompleteResultsThrottler(10, 50000, () -> timeMs.get(), new ThrottlePolicy());
+ new IncompleteResultsThrottler(10, 50000, clock, new ThrottlePolicy());
incompleteResultsThrottlers.add(incompleteResultsThrottler);
}
long sum = 0;
@@ -68,8 +72,8 @@ public class IncompleteResultsThrottlerTest {
if (fastForward) {
time = mockServer.nextRequestFinished();
}
- timeMs.set(time);
- mockServer.moveTime(timeMs.get());
+ clock.setInstant(Instant.ofEpochMilli(time));
+ mockServer.moveTime(clock.instant().toEpochMilli());
for (int y = 0; y < clientCount; y++) {
// Fill up, but don't block as that would stop the simulation.
while (incompleteResultsThrottlers.get(y).availableCapacity() > 0) {
@@ -140,45 +144,46 @@ public class IncompleteResultsThrottlerTest {
}
}
- private void moveToNextCycle(final IncompleteResultsThrottler throttler, AtomicLong timeMs)
+ private void moveToNextCycle(final IncompleteResultsThrottler throttler, ManualClock clock)
throws InterruptedException {
waitForThreads();
// Enter an adaption phase, we don't care about this phase.
- timeMs.addAndGet(throttler.phaseSizeMs);
+ clock.advance(Duration.ofMillis(throttler.phaseSizeMs));
throttler.operationStart();
throttler.resultReady(false);
// Now enter the real next phase.
- timeMs.addAndGet(throttler.phaseSizeMs);
+ clock.advance(Duration.ofMillis(throttler.phaseSizeMs));
throttler.operationStart();
throttler.resultReady(false);
}
@Test
public void testInteractionWithPolicyByMockingPolicy() throws InterruptedException {
+ ManualClock clock = new ManualClock(Instant.ofEpochMilli(0));
final int MAX_SIZE = 1000;
final int MORE_THAN_MAX_SIZE = MAX_SIZE + 20;
final int SIZE_AFTER_CYCLE_FIRST = 30;
final int SIZE_AFTER_CYCLE_SECOND = 5000;
ThrottlePolicy policy = mock(ThrottlePolicy.class);
- final AtomicLong timeMs = new AtomicLong(0);
IncompleteResultsThrottler incompleteResultsThrottler =
- new IncompleteResultsThrottler(2, MAX_SIZE, ()->timeMs.get(), policy);
+ new IncompleteResultsThrottler(2, MAX_SIZE, clock, policy);
long bucketSizeMs = incompleteResultsThrottler.phaseSizeMs;
// Cycle 1 - Algorithm has fixed value for max-in-flight: INITIAL_MAX_IN_FLIGHT_VALUE.
// We post a few operations, not all finishing in this cycle. We explicitly do not fill the window
// size to test the argument about any requests blocked.
- assertThat(incompleteResultsThrottler.availableCapacity(),
- is(IncompleteResultsThrottler.INITIAL_MAX_IN_FLIGHT_VALUE));
+ assertEquals(IncompleteResultsThrottler.INITIAL_MAX_IN_FLIGHT_VALUE,
+ incompleteResultsThrottler.availableCapacity());
postOperations(20, incompleteResultsThrottler);
postSuccesses(15, incompleteResultsThrottler);
- moveToNextCycle(incompleteResultsThrottler, timeMs);
+ moveToNextCycle(incompleteResultsThrottler, clock);
// Cycle 2 - Algorithm has fixed value also for second iteration: SECOND_MAX_IN_FLIGHT_VALUE.
// Test verifies that this value is used, and insert a value to be used for next phase SIZE_AFTER_CYCLE_FIRST.
- assertThat(incompleteResultsThrottler.availableCapacity(),
- is(IncompleteResultsThrottler.SECOND_MAX_IN_FLIGHT_VALUE - 5)); // 5 slots already taken earlier
+ assertEquals("5 slots already taken earlier",
+ IncompleteResultsThrottler.SECOND_MAX_IN_FLIGHT_VALUE - 5,
+ incompleteResultsThrottler.availableCapacity());
postSuccesses(5, incompleteResultsThrottler);
when(policy.calcNewMaxInFlight(
anyDouble(), // Max performance change
@@ -188,12 +193,11 @@ public class IncompleteResultsThrottlerTest {
eq(IncompleteResultsThrottler.SECOND_MAX_IN_FLIGHT_VALUE), // current size
eq(false))) // is any request blocked, should be false since we only posted 20 docs.
.thenReturn(SIZE_AFTER_CYCLE_FIRST);
- moveToNextCycle(incompleteResultsThrottler, timeMs);
+ moveToNextCycle(incompleteResultsThrottler, clock);
// Cycle 3 - Test that value set in previous phase is used. Now return a very large number.
// However, this number should be cropped by the system (tested in next cycle).
- assertThat(incompleteResultsThrottler.availableCapacity(),
- is(SIZE_AFTER_CYCLE_FIRST));
+ assertEquals(SIZE_AFTER_CYCLE_FIRST, incompleteResultsThrottler.availableCapacity());
postOperations(MORE_THAN_MAX_SIZE, incompleteResultsThrottler);
postSuccesses(MORE_THAN_MAX_SIZE, incompleteResultsThrottler);
when(policy.calcNewMaxInFlight(
@@ -204,11 +208,10 @@ public class IncompleteResultsThrottlerTest {
eq(SIZE_AFTER_CYCLE_FIRST),// current size
eq(true))) // is any request blocked, should be true since we posted MORE_THAN_MAX_SIZE docs.
.thenReturn(SIZE_AFTER_CYCLE_SECOND);
- moveToNextCycle(incompleteResultsThrottler, timeMs);
+ moveToNextCycle(incompleteResultsThrottler, clock);
// Cycle 4 - Test that the large number from previous cycle is cropped and that max value is used instead.
- assertThat(incompleteResultsThrottler.availableCapacity(),
- is(MAX_SIZE));
+ assertEquals(MAX_SIZE, incompleteResultsThrottler.availableCapacity());
}
private long inversesU(int size, int sweetSpot) {
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessorTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessorTest.java
index 9753a180618..e4ae138054d 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessorTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessorTest.java
@@ -10,6 +10,7 @@ import com.yahoo.vespa.http.client.core.Document;
import com.yahoo.vespa.http.client.core.EndpointResult;
import org.junit.Test;
+import java.time.Clock;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
@@ -32,10 +33,10 @@ import static org.mockito.Mockito.when;
public class OperationProcessorTest {
final Queue<Result> queue = new ArrayDeque<>();
- final Document doc1 = new Document("id:a:type::b", null, "data doc 1", null);
- final Document doc1b = new Document("id:a:type::b", null, "data doc 1b", null);
- final Document doc2 = new Document("id:a:type::b2", null, "data doc 2", null);
- final Document doc3 = new Document("id:a:type::b3", null, "data doc 3", null);
+ final Document doc1 = new Document("id:a:type::b", null, "data doc 1", null, Clock.systemUTC().instant());
+ final Document doc1b = new Document("id:a:type::b", null, "data doc 1b", null, Clock.systemUTC().instant());
+ final Document doc2 = new Document("id:a:type::b2", null, "data doc 2", null, Clock.systemUTC().instant());
+ final Document doc3 = new Document("id:a:type::b3", null, "data doc 3", null, Clock.systemUTC().instant());
@Test
public void testBasic() {
@@ -49,7 +50,7 @@ public class OperationProcessorTest {
OperationProcessor q = new OperationProcessor(
new IncompleteResultsThrottler(1000, 1000, null, null),
(docId, documentResult) -> queue.add(documentResult),
- sessionParams, null);
+ sessionParams, null, Clock.systemUTC());
q.resultReceived(new EndpointResult("foo", new Result.Detail(null)), 0);
@@ -127,7 +128,7 @@ public class OperationProcessorTest {
OperationProcessor operationProcessor = new OperationProcessor(
new IncompleteResultsThrottler(1000, 1000, null, null),
(docId, documentResult) -> queue.add(documentResult),
- sessionParams, null);
+ sessionParams, null, Clock.systemUTC());
operationProcessor.sendDocument(doc1);
operationProcessor.sendDocument(doc1b);
@@ -165,7 +166,7 @@ public class OperationProcessorTest {
OperationProcessor operationProcessor = new OperationProcessor(
new IncompleteResultsThrottler(1000, 1000, null, null),
(docId, documentResult) -> queue.add(documentResult),
- sessionParams, null);
+ sessionParams, null, Clock.systemUTC());
operationProcessor.sendDocument(doc1);
operationProcessor.sendDocument(doc1b);
@@ -198,11 +199,11 @@ public class OperationProcessorTest {
OperationProcessor operationProcessor = new OperationProcessor(
new IncompleteResultsThrottler(1000, 1000, null, null),
(docId, documentResult) -> queue.add(documentResult),
- sessionParams, null);
+ sessionParams, null, Clock.systemUTC());
Queue<Document> documentQueue = new ArrayDeque<>();
for (int x = 0; x < 100; x++) {
- Document document = new Document("id:a:type::b", null, String.valueOf(x), null);
+ Document document = new Document("id:a:type::b", null, String.valueOf(x), null, Clock.systemUTC().instant());
operationProcessor.sendDocument(document);
documentQueue.add(document);
}
@@ -233,7 +234,7 @@ public class OperationProcessorTest {
OperationProcessor operationProcessor = new OperationProcessor(
new IncompleteResultsThrottler(1000, 1000, null, null),
(docId, documentResult) -> queue.add(documentResult),
- sessionParams, null);
+ sessionParams, null, Clock.systemUTC());
operationProcessor.sendDocument(doc1);
operationProcessor.sendDocument(doc1b); // Blocked
@@ -273,7 +274,7 @@ public class OperationProcessorTest {
OperationProcessor q = new OperationProcessor(
new IncompleteResultsThrottler(1000, 1000, null, null),
(docId, documentResult) -> queue.add(documentResult),
- sessionParams, null);
+ sessionParams, null, Clock.systemUTC());
q.sendDocument(doc1);
assertEquals(0, queue.size());
@@ -299,7 +300,7 @@ public class OperationProcessorTest {
OperationProcessor q = new OperationProcessor(
new IncompleteResultsThrottler(1000, 1000, null, null),
(docId, documentResult) -> queue.add(documentResult),
- sessionParams, null);
+ sessionParams, null, Clock.systemUTC());
q.sendDocument(doc1);
assertEquals(0, queue.size());
@@ -358,7 +359,7 @@ public class OperationProcessorTest {
OperationProcessor operationProcessor = new OperationProcessor(
new IncompleteResultsThrottler(1, 1, null, null),
(docId, documentResult) -> {},
- sessionParams, null);
+ sessionParams, null, Clock.systemUTC());
operationProcessor.sendDocument(doc1);
@@ -397,7 +398,7 @@ public class OperationProcessorTest {
(docId, documentResult) -> {
countDownLatch.countDown();
},
- sessionParams, executor);
+ sessionParams, executor, Clock.systemUTC());
// Will fail due to bogus host name, but will be retried.
operationProcessor.sendDocument(doc1);
@@ -425,7 +426,7 @@ public class OperationProcessorTest {
(docId, documentResult) -> {
countDownLatch.countDown();
},
- sessionParams, executor);
+ sessionParams, executor, Clock.systemUTC());
fail("Expected exception");
}
diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/ClientFeederV3.java b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/ClientFeederV3.java
index db1c2471752..5548f8fbc1f 100644
--- a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/ClientFeederV3.java
+++ b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/ClientFeederV3.java
@@ -156,45 +156,38 @@ class ClientFeederV3 {
}
private int getOverloadReturnCode(HttpRequest request) {
- if (request.getHeader(Headers.SILENTUPGRADE) != null ) {
- return 299;
- }
+ if (request.getHeader(Headers.SILENTUPGRADE) != null ) return 299;
return 429;
}
- private Optional<DocumentOperationMessageV3> pullMessageFromRequest(
- FeederSettings settings, InputStream requestInputStream, BlockingQueue<OperationStatus> repliesFromOldMessages) {
+ private Optional<DocumentOperationMessageV3> pullMessageFromRequest(FeederSettings settings,
+ InputStream requestInputStream,
+ BlockingQueue<OperationStatus> repliesFromOldMessages) {
while (true) {
Optional<String> operationId;
try {
operationId = streamReaderV3.getNextOperationId(requestInputStream);
+ if (operationId.isEmpty()) return Optional.empty();
} catch (IOException ioe) {
- if (log.isLoggable(Level.FINE)) {
- log.log(Level.FINE, Exceptions.toMessageString(ioe), ioe);
- }
- return Optional.empty();
- }
- if (! operationId.isPresent()) {
+ log.log(Level.FINE, () -> Exceptions.toMessageString(ioe));
return Optional.empty();
}
- DocumentOperationMessageV3 message;
try {
- message = getNextMessage(operationId.get(), requestInputStream, settings);
+ DocumentOperationMessageV3 message = getNextMessage(operationId.get(), requestInputStream, settings);
+ if (message != null)
+ setRoute(message, settings);
+ return Optional.ofNullable(message);
} catch (Exception e) {
- if (log.isLoggable(Level.WARNING)) {
- log.log(Level.WARNING, Exceptions.toMessageString(e));
- }
+ log.log(Level.WARNING, () -> Exceptions.toMessageString(e));
metric.add(MetricNames.PARSE_ERROR, 1, null);
- repliesFromOldMessages.add(new OperationStatus(
- Exceptions.toMessageString(e), operationId.get(), ErrorCode.ERROR, false, ""));
-
- continue;
+ repliesFromOldMessages.add(new OperationStatus(Exceptions.toMessageString(e),
+ operationId.get(),
+ ErrorCode.ERROR,
+ false,
+ ""));
}
- if (message != null)
- setRoute(message, settings);
- return Optional.ofNullable(message);
}
}
@@ -223,47 +216,45 @@ class ClientFeederV3 {
BlockingQueue<OperationStatus> repliesFromOldMessages,
AtomicInteger threadsAvailableForFeeding) throws InterruptedException {
while (true) {
- Optional<DocumentOperationMessageV3> msg = pullMessageFromRequest(settings, requestInputStream, repliesFromOldMessages);
+ Optional<DocumentOperationMessageV3> message = pullMessageFromRequest(settings,
+ requestInputStream,
+ repliesFromOldMessages);
- if (! msg.isPresent()) {
- break;
- }
- setMessageParameters(msg.get(), settings);
+ if (message.isEmpty()) break;
+ setMessageParameters(message.get(), settings);
Result result;
try {
- result = sendMessage(settings, msg.get(), threadsAvailableForFeeding);
+ result = sendMessage(settings, message.get(), threadsAvailableForFeeding);
} catch (RuntimeException e) {
- repliesFromOldMessages.add(createOperationStatus(msg.get().getOperationId(),
+ repliesFromOldMessages.add(createOperationStatus(message.get().getOperationId(),
Exceptions.toMessageString(e),
ErrorCode.ERROR,
false,
- msg.get().getMessage()));
+ message.get().getMessage()));
continue;
}
if (result.isAccepted()) {
outstandingOperations.incrementAndGet();
updateOpsPerSec();
- log(Level.FINE, "Sent message successfully, document id: ", msg.get().getOperationId());
+ log(Level.FINE, "Sent message successfully, document id: ", message.get().getOperationId());
} else if (!result.getError().isFatal()) {
- repliesFromOldMessages.add(createOperationStatus(msg.get().getOperationId(),
+ repliesFromOldMessages.add(createOperationStatus(message.get().getOperationId(),
result.getError().getMessage(),
ErrorCode.TRANSIENT_ERROR,
false,
- msg.get().getMessage()));
- continue;
+ message.get().getMessage()));
} else {
// should probably not happen, but everybody knows stuff that
// shouldn't happen, happens all the time
boolean isConditionNotMet = result.getError().getCode() == DocumentProtocol.ERROR_TEST_AND_SET_CONDITION_FAILED;
- repliesFromOldMessages.add(createOperationStatus(msg.get().getOperationId(),
+ repliesFromOldMessages.add(createOperationStatus(message.get().getOperationId(),
result.getError().getMessage(),
ErrorCode.ERROR,
isConditionNotMet,
- msg.get().getMessage()));
- continue;
+ message.get().getMessage()));
}
}
}
@@ -326,17 +317,11 @@ class ClientFeederV3 {
}
protected final void log(Level level, Object... msgParts) {
- StringBuilder s;
+ if (!log.isLoggable(level)) return;
- if (!log.isLoggable(level)) {
- return;
- }
-
- s = new StringBuilder();
- for (Object part : msgParts) {
+ StringBuilder s = new StringBuilder();
+ for (Object part : msgParts)
s.append(part.toString());
- }
-
log.log(level, s.toString());
}
diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeederSettings.java b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeederSettings.java
index e6d8a88d10b..909c643a006 100644
--- a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeederSettings.java
+++ b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeederSettings.java
@@ -6,6 +6,8 @@ import com.yahoo.messagebus.routing.Route;
import com.yahoo.vespa.http.client.config.FeedParams.DataFormat;
import com.yahoo.vespa.http.client.core.Headers;
+import java.util.Optional;
+
/**
* Wrapper for the feed feederSettings read from HTTP request.
*
@@ -14,7 +16,7 @@ import com.yahoo.vespa.http.client.core.Headers;
public class FeederSettings {
private static final Route DEFAULT_ROUTE = Route.parse("default");
- public final boolean drain;
+ public final boolean drain; // TODO: Implement drain=true
public final Route route;
public final boolean denyIfBusy;
public final DataFormat dataFormat;
@@ -22,55 +24,13 @@ public class FeederSettings {
public final Integer traceLevel;
public FeederSettings(HttpRequest request) {
- {
- String tmpDrain = request.getHeader(Headers.DRAIN);
- if (tmpDrain != null) {
- drain = Boolean.parseBoolean(tmpDrain);
- } else {
- drain = false;
- }
- }
- {
- String tmpRoute = request.getHeader(Headers.ROUTE);
- if (tmpRoute != null) {
- route = Route.parse(tmpRoute);
- } else {
- route = DEFAULT_ROUTE;
- }
- }
- {
- String tmpDenyIfBusy = request.getHeader(Headers.DENY_IF_BUSY);
- if (tmpDenyIfBusy != null) {
- denyIfBusy = Boolean.parseBoolean(tmpDenyIfBusy);
- } else {
- denyIfBusy = false;
- }
- }
- {
- // TODO: Change default to JSON on Vespa 8
- String tmpDataFormat = request.getHeader(Headers.DATA_FORMAT);
- if (tmpDataFormat != null) {
- dataFormat = DataFormat.valueOf(tmpDataFormat);
- } else {
- dataFormat = DataFormat.XML_UTF8;
- }
- }
- {
- String tmpDataFormat = request.getHeader(Headers.PRIORITY);
- if (tmpDataFormat != null) {
- priority = tmpDataFormat;
- } else {
- priority = null;
- }
- }
- {
- String tmpDataFormat = request.getHeader(Headers.TRACE_LEVEL);
- if (tmpDataFormat != null) {
- traceLevel = Integer.valueOf(tmpDataFormat);
- } else {
- traceLevel = null;
- }
- }
+ this.drain = Optional.ofNullable(request.getHeader(Headers.DRAIN)).map(Boolean::parseBoolean).orElse(false);
+ this.route = Optional.ofNullable(request.getHeader(Headers.ROUTE)).map(Route::parse).orElse(DEFAULT_ROUTE);
+ this.denyIfBusy = Optional.ofNullable(request.getHeader(Headers.DENY_IF_BUSY)).map(Boolean::parseBoolean).orElse(false);
+ // TODO: Change default to JSON on Vespa 8:
+ this.dataFormat = Optional.ofNullable(request.getHeader(Headers.DATA_FORMAT)).map(DataFormat::valueOf).orElse(DataFormat.XML_UTF8);
+ this.priority = request.getHeader(Headers.PRIORITY);
+ this.traceLevel = Optional.ofNullable(request.getHeader(Headers.TRACE_LEVEL)).map(Integer::valueOf).orElse(null);
}
}
diff --git a/vespalib/src/tests/slime/json_slime_benchmark.cpp b/vespalib/src/tests/slime/json_slime_benchmark.cpp
index 3c006bb89f7..36987843492 100644
--- a/vespalib/src/tests/slime/json_slime_benchmark.cpp
+++ b/vespalib/src/tests/slime/json_slime_benchmark.cpp
@@ -1,9 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/vespalib/data/slime/slime.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <vespa/vespalib/testkit/test_kit.h>
#include <iostream>
#include <fstream>
-#include <sstream>
using namespace vespalib::slime::convenience;
diff --git a/vespalib/src/tests/slime/slime_binary_format_test.cpp b/vespalib/src/tests/slime/slime_binary_format_test.cpp
index e6661cbf554..37ce6d5dfdf 100644
--- a/vespalib/src/tests/slime/slime_binary_format_test.cpp
+++ b/vespalib/src/tests/slime/slime_binary_format_test.cpp
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/vespalib/data/slime/slime.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include "type_traits.h"
#include <vespa/vespalib/util/stringfmt.h>
diff --git a/vespalib/src/tests/slime/slime_json_format_test.cpp b/vespalib/src/tests/slime/slime_json_format_test.cpp
index d1f77f09af1..df2f8b2e30b 100644
--- a/vespalib/src/tests/slime/slime_json_format_test.cpp
+++ b/vespalib/src/tests/slime/slime_json_format_test.cpp
@@ -3,6 +3,7 @@
#include <vespa/vespalib/data/slime/slime.h>
#include <vespa/vespalib/data/input.h>
#include <vespa/vespalib/data/memory_input.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <iostream>
#include <fstream>
diff --git a/vespalib/src/tests/slime/slime_test.cpp b/vespalib/src/tests/slime/slime_test.cpp
index 7e70dc3538e..e58b1599b8f 100644
--- a/vespalib/src/tests/slime/slime_test.cpp
+++ b/vespalib/src/tests/slime/slime_test.cpp
@@ -1,11 +1,14 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <vespa/log/log.h>
-LOG_SETUP("slime_test");
+
#include <vespa/vespalib/testkit/testapp.h>
#include <vespa/vespalib/data/slime/slime.h>
#include <vespa/vespalib/data/slime/strfmt.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <type_traits>
+#include <vespa/log/log.h>
+LOG_SETUP("slime_test");
+
using namespace vespalib::slime::convenience;
TEST("print sizes") {
diff --git a/vespalib/src/tests/trace/trace_serialization.cpp b/vespalib/src/tests/trace/trace_serialization.cpp
index 7658fe7f163..3182e46061a 100644
--- a/vespalib/src/tests/trace/trace_serialization.cpp
+++ b/vespalib/src/tests/trace/trace_serialization.cpp
@@ -3,6 +3,7 @@
#include <vespa/vespalib/trace/tracenode.h>
#include <vespa/vespalib/trace/slime_trace_serializer.h>
#include <vespa/vespalib/trace/slime_trace_deserializer.h>
+#include <vespa/vespalib/data/simple_buffer.h>
#include <vespa/log/log.h>
LOG_SETUP("trace_test");
diff --git a/vespalib/src/vespa/vespalib/data/memory.h b/vespalib/src/vespa/vespalib/data/memory.h
index 07767180b57..eee0a1a3e4f 100644
--- a/vespalib/src/vespa/vespalib/data/memory.h
+++ b/vespalib/src/vespa/vespalib/data/memory.h
@@ -15,14 +15,14 @@ struct Memory
const char *data;
size_t size;
- Memory() : data(nullptr), size(0) {}
- Memory(const char *d, size_t s) : data(d), size(s) {}
- Memory(const char *str) : data(str), size(strlen(str)) {}
- Memory(const std::string &str)
+ Memory() noexcept : data(nullptr), size(0) {}
+ Memory(const char *d, size_t s) noexcept : data(d), size(s) {}
+ Memory(const char *str) noexcept : data(str), size(strlen(str)) {}
+ Memory(const std::string &str) noexcept
: data(str.data()), size(str.size()) {}
- Memory(const vespalib::string &str)
+ Memory(const vespalib::string &str) noexcept
: data(str.data()), size(str.size()) {}
- Memory(vespalib::stringref str_ref)
+ Memory(vespalib::stringref str_ref) noexcept
: data(str_ref.data()), size(str_ref.size()) {}
vespalib::string make_string() const;
vespalib::stringref make_stringref() const { return stringref(data, size); }
diff --git a/vespalib/src/vespa/vespalib/data/simple_buffer.cpp b/vespalib/src/vespa/vespalib/data/simple_buffer.cpp
index 09ac4a4b830..7e3c5022fc5 100644
--- a/vespalib/src/vespa/vespalib/data/simple_buffer.cpp
+++ b/vespalib/src/vespa/vespalib/data/simple_buffer.cpp
@@ -11,7 +11,7 @@ SimpleBuffer::SimpleBuffer()
{
}
-SimpleBuffer::~SimpleBuffer() { }
+SimpleBuffer::~SimpleBuffer() = default;
Memory
SimpleBuffer::obtain()
diff --git a/vespalib/src/vespa/vespalib/data/simple_buffer.h b/vespalib/src/vespa/vespalib/data/simple_buffer.h
index f7d9543440f..3bcb43a3856 100644
--- a/vespalib/src/vespa/vespalib/data/simple_buffer.h
+++ b/vespalib/src/vespa/vespalib/data/simple_buffer.h
@@ -4,6 +4,7 @@
#include "input.h"
#include "output.h"
+#include <vespa/vespalib/stllike/allocator.h>
#include <iosfwd>
#include <vector>
@@ -20,7 +21,7 @@ class SimpleBuffer : public Input,
public Output
{
private:
- std::vector<char> _data;
+ std::vector<char, allocator_large<char>> _data;
size_t _used;
public:
diff --git a/vespalib/src/vespa/vespalib/data/slime/slime.h b/vespalib/src/vespa/vespalib/data/slime/slime.h
index aa44b38b353..6523cd1dac0 100644
--- a/vespalib/src/vespa/vespalib/data/slime/slime.h
+++ b/vespalib/src/vespa/vespalib/data/slime/slime.h
@@ -31,7 +31,6 @@
#include "external_data_value_factory.h"
#include <vespa/vespalib/data/input_reader.h>
#include <vespa/vespalib/data/output_writer.h>
-#include <vespa/vespalib/data/simple_buffer.h>
#include <vespa/vespalib/data/output.h>
namespace vespalib {
diff --git a/vespalib/src/vespa/vespalib/datastore/datastore.h b/vespalib/src/vespa/vespalib/datastore/datastore.h
index 193869a5591..4b908452b32 100644
--- a/vespalib/src/vespa/vespalib/datastore/datastore.h
+++ b/vespalib/src/vespa/vespalib/datastore/datastore.h
@@ -106,6 +106,7 @@ public:
DataStore(const DataStore &rhs) = delete;
DataStore &operator=(const DataStore &rhs) = delete;
DataStore();
+ DataStore(uint32_t min_arrays);
~DataStore();
EntryRef addEntry(const EntryType &e);
diff --git a/vespalib/src/vespa/vespalib/datastore/datastore.hpp b/vespalib/src/vespa/vespalib/datastore/datastore.hpp
index 6549425b022..ad35f3c7383 100644
--- a/vespalib/src/vespa/vespalib/datastore/datastore.hpp
+++ b/vespalib/src/vespa/vespalib/datastore/datastore.hpp
@@ -133,8 +133,14 @@ DataStoreT<RefT>::freeListRawAllocator(uint32_t typeId)
template <typename EntryType, typename RefT>
DataStore<EntryType, RefT>::DataStore()
+ : DataStore(RefType::offsetSize())
+{
+}
+
+template <typename EntryType, typename RefT>
+DataStore<EntryType, RefT>::DataStore(uint32_t min_arrays)
: ParentType(),
- _type(1, RefType::offsetSize(), RefType::offsetSize())
+ _type(1, min_arrays, RefType::offsetSize())
{
addType(&_type);
initActiveBuffers();
diff --git a/vespalib/src/vespa/vespalib/datastore/unique_store_enumerator.h b/vespalib/src/vespa/vespalib/datastore/unique_store_enumerator.h
index db545451a30..78597a53dc8 100644
--- a/vespalib/src/vespa/vespalib/datastore/unique_store_enumerator.h
+++ b/vespalib/src/vespa/vespalib/datastore/unique_store_enumerator.h
@@ -3,6 +3,7 @@
#pragma once
#include "i_unique_store_dictionary.h"
+#include <vespa/vespalib/stllike/allocator.h>
namespace vespalib::datastore {
@@ -18,9 +19,10 @@ template <typename RefT>
class UniqueStoreEnumerator {
public:
using RefType = RefT;
- using EnumValues = std::vector<std::vector<uint32_t>>;
private:
+ using UInt32Vector = std::vector<uint32_t, vespalib::allocator_large<uint32_t>>;
+ using EnumValues = std::vector<UInt32Vector>;
IUniqueStoreDictionary::ReadSnapshot::UP _dict_snapshot;
const DataStoreBase &_store;
EnumValues _enumValues;
diff --git a/vespalib/src/vespa/vespalib/objects/nbostream.h b/vespalib/src/vespa/vespalib/objects/nbostream.h
index daaea981b5a..c5b26d786b3 100644
--- a/vespalib/src/vespa/vespalib/objects/nbostream.h
+++ b/vespalib/src/vespa/vespalib/objects/nbostream.h
@@ -20,7 +20,7 @@ class nbostream
public:
using Buffer = Array<char>;
using Alloc = alloc::Alloc;
- enum State { ok=0, eof=0x01};
+ enum State { ok=0, eof=0x01, oob=0x02};
nbostream(size_t initialSize=1024);
protected:
nbostream(const void * buf, size_t sz, bool longLivedBuffer);
@@ -145,6 +145,7 @@ public:
const char * peek() const { return &_rbuf[_rp]; }
size_t rp() const { return _rp; }
nbostream & rp(size_t pos) { if (pos > _wp) fail(eof); _rp = pos; return *this; }
+ nbostream & wp(size_t pos) { if (pos > _wbuf.size()) fail(oob); _wp = pos; return *this; }
size_t wp() const { return _wp; }
State state() const { return _state; }
bool good() const { return _state == ok; }
diff --git a/vespalib/src/vespa/vespalib/util/arrayref.h b/vespalib/src/vespa/vespalib/util/arrayref.h
index 749395ff574..03634a7a094 100644
--- a/vespalib/src/vespa/vespalib/util/arrayref.h
+++ b/vespalib/src/vespa/vespalib/util/arrayref.h
@@ -13,11 +13,11 @@ namespace vespalib {
template <typename T>
class ArrayRef {
public:
- ArrayRef() : _v(nullptr), _sz(0) { }
- ArrayRef(T * v, size_t sz) : _v(v), _sz(sz) { }
+ ArrayRef() noexcept : _v(nullptr), _sz(0) { }
+ ArrayRef(T * v, size_t sz) noexcept : _v(v), _sz(sz) { }
template<typename A=std::allocator<T>>
- ArrayRef(std::vector<T, A> & v) : _v(&v[0]), _sz(v.size()) { }
- ArrayRef(Array<T> &v) : _v(&v[0]), _sz(v.size()) { }
+ ArrayRef(std::vector<T, A> & v) noexcept : _v(&v[0]), _sz(v.size()) { }
+ ArrayRef(Array<T> &v) noexcept : _v(&v[0]), _sz(v.size()) { }
T & operator [] (size_t i) { return _v[i]; }
const T & operator [] (size_t i) const { return _v[i]; }
size_t size() const { return _sz; }
@@ -32,12 +32,12 @@ private:
template <typename T>
class ConstArrayRef {
public:
- ConstArrayRef(const T *v, size_t sz) : _v(v), _sz(sz) { }
+ ConstArrayRef(const T *v, size_t sz) noexcept : _v(v), _sz(sz) { }
template<typename A=std::allocator<T>>
- ConstArrayRef(const std::vector<T, A> & v) : _v(&v[0]), _sz(v.size()) { }
- ConstArrayRef(const ArrayRef<T> & v) : _v(&v[0]), _sz(v.size()) { }
- ConstArrayRef(const Array<T> &v) : _v(&v[0]), _sz(v.size()) { }
- ConstArrayRef() : _v(nullptr), _sz(0) {}
+ ConstArrayRef(const std::vector<T, A> & v) noexcept : _v(&v[0]), _sz(v.size()) { }
+ ConstArrayRef(const ArrayRef<T> & v) noexcept : _v(&v[0]), _sz(v.size()) { }
+ ConstArrayRef(const Array<T> &v) noexcept : _v(&v[0]), _sz(v.size()) { }
+ ConstArrayRef() noexcept : _v(nullptr), _sz(0) {}
const T & operator [] (size_t i) const { return _v[i]; }
size_t size() const { return _sz; }
bool empty() const { return _sz == 0; }